roam_codegen/targets/swift/
encode.rs

1//! Swift encoding expression generation.
2//!
3//! Generates Swift code that encodes Rust types into byte arrays.
4
5use facet_core::{ScalarType, Shape};
6use heck::ToLowerCamelCase;
7use roam_schema::{
8    EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant, is_bytes,
9};
10
11/// Generate a Swift encode expression for a given shape and value.
12pub fn generate_encode_expr(shape: &'static Shape, value: &str) -> String {
13    if is_bytes(shape) {
14        return format!("encodeBytes(Array({value}))");
15    }
16
17    match classify_shape(shape) {
18        ShapeKind::Scalar(scalar) => {
19            let encode_fn = swift_encode_fn(scalar);
20            format!("{encode_fn}({value})")
21        }
22        ShapeKind::List { element }
23        | ShapeKind::Slice { element }
24        | ShapeKind::Array { element, .. } => {
25            let inner_encode = generate_encode_closure(element);
26            format!("encodeVec({value}, encoder: {inner_encode})")
27        }
28        ShapeKind::Option { inner } => {
29            let inner_encode = generate_encode_closure(inner);
30            format!("encodeOption({value}, encoder: {inner_encode})")
31        }
32        ShapeKind::Tuple { elements } if elements.len() == 2 => {
33            let a_encode = generate_encode_closure(elements[0].shape);
34            let b_encode = generate_encode_closure(elements[1].shape);
35            format!("{a_encode}({value}.0) + {b_encode}({value}.1)")
36        }
37        ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
38            let a_encode = generate_encode_closure(fields[0].shape());
39            let b_encode = generate_encode_closure(fields[1].shape());
40            format!("{a_encode}({value}.0) + {b_encode}({value}.1)")
41        }
42        ShapeKind::Struct(StructInfo { fields, .. }) => {
43            // Encode each field and concatenate
44            let field_encodes: Vec<String> = fields
45                .iter()
46                .map(|f| {
47                    let field_name = f.name.to_lower_camel_case();
48                    generate_encode_expr(f.shape(), &format!("{value}.{field_name}"))
49                })
50                .collect();
51            if field_encodes.is_empty() {
52                "[]".into()
53            } else {
54                field_encodes.join(" + ")
55            }
56        }
57        ShapeKind::Pointer { pointee } => generate_encode_expr(pointee, value),
58        ShapeKind::Result { ok, err } => {
59            // Encode Result<T, E> - discriminant 0 = Ok, 1 = Err
60            let ok_encode = generate_encode_closure(ok);
61            let err_encode = generate_encode_closure(err);
62            format!(
63                "{{ switch {value} {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1)] + {err_encode}(e) }} }}()"
64            )
65        }
66        _ => "[]".into(), // fallback
67    }
68}
69
70/// Generate a Swift encode closure for use with encodeVec, encodeOption, etc.
71pub fn generate_encode_closure(shape: &'static Shape) -> String {
72    if is_bytes(shape) {
73        return "{ encodeBytes(Array($0)) }".into();
74    }
75
76    match classify_shape(shape) {
77        ShapeKind::Scalar(scalar) => {
78            let encode_fn = swift_encode_fn(scalar);
79            format!("{{ {encode_fn}($0) }}")
80        }
81        ShapeKind::List { element } | ShapeKind::Slice { element } => {
82            let inner = generate_encode_closure(element);
83            format!("{{ encodeVec($0, encoder: {inner}) }}")
84        }
85        ShapeKind::Option { inner } => {
86            let inner_closure = generate_encode_closure(inner);
87            format!("{{ encodeOption($0, encoder: {inner_closure}) }}")
88        }
89        ShapeKind::Tuple { elements } if elements.len() == 2 => {
90            let a_encode = generate_encode_closure(elements[0].shape);
91            let b_encode = generate_encode_closure(elements[1].shape);
92            format!("{{ {a_encode}($0.0) + {b_encode}($0.1) }}")
93        }
94        ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
95            let a_encode = generate_encode_closure(fields[0].shape());
96            let b_encode = generate_encode_closure(fields[1].shape());
97            format!("{{ {a_encode}($0.0) + {b_encode}($0.1) }}")
98        }
99        ShapeKind::Struct(StructInfo { fields, .. }) => {
100            // Generate inline struct encode closure
101            let field_encodes: Vec<String> = fields
102                .iter()
103                .map(|f| {
104                    let field_name = f.name.to_lower_camel_case();
105                    generate_encode_expr(f.shape(), &format!("$0.{field_name}"))
106                })
107                .collect();
108            if field_encodes.is_empty() {
109                "{ _ in [] }".into()
110            } else {
111                format!("{{ {} }}", field_encodes.join(" + "))
112            }
113        }
114        ShapeKind::Enum(EnumInfo {
115            name: Some(_name),
116            variants,
117            ..
118        }) => {
119            // Generate inline enum encode closure with switch
120            let mut code = "{ v in\n    switch v {\n".to_string();
121            for (i, v) in variants.iter().enumerate() {
122                let variant_name = v.name.to_lower_camel_case();
123                match classify_variant(v) {
124                    VariantKind::Unit => {
125                        code.push_str(&format!(
126                            "    case .{variant_name}:\n        return [UInt8({i})]\n"
127                        ));
128                    }
129                    VariantKind::Newtype { inner } => {
130                        let inner_encode = generate_encode_expr(inner, "val");
131                        code.push_str(&format!(
132                            "    case .{variant_name}(let val):\n        return [UInt8({i})] + {inner_encode}\n"
133                        ));
134                    }
135                    VariantKind::Tuple { fields } => {
136                        let bindings: Vec<String> =
137                            (0..fields.len()).map(|j| format!("f{j}")).collect();
138                        let field_encodes: Vec<String> = fields
139                            .iter()
140                            .enumerate()
141                            .map(|(j, f)| generate_encode_expr(f.shape(), &format!("f{j}")))
142                            .collect();
143                        code.push_str(&format!(
144                            "    case .{variant_name}({}):\n        return [UInt8({i})] + {}\n",
145                            bindings
146                                .iter()
147                                .map(|b| format!("let {b}"))
148                                .collect::<Vec<_>>()
149                                .join(", "),
150                            field_encodes.join(" + ")
151                        ));
152                    }
153                    VariantKind::Struct { fields } => {
154                        let bindings: Vec<String> = fields
155                            .iter()
156                            .map(|f| f.name.to_lower_camel_case())
157                            .collect();
158                        let field_encodes: Vec<String> = fields
159                            .iter()
160                            .map(|f| {
161                                let field_name = f.name.to_lower_camel_case();
162                                generate_encode_expr(f.shape(), &field_name)
163                            })
164                            .collect();
165                        code.push_str(&format!(
166                            "    case .{variant_name}({}):\n        return [UInt8({i})] + {}\n",
167                            bindings
168                                .iter()
169                                .map(|b| format!("let {b}"))
170                                .collect::<Vec<_>>()
171                                .join(", "),
172                            field_encodes.join(" + ")
173                        ));
174                    }
175                }
176            }
177            code.push_str("    }\n}");
178            code
179        }
180        ShapeKind::Pointer { pointee } => generate_encode_closure(pointee),
181        ShapeKind::Result { ok, err } => {
182            let ok_encode = generate_encode_closure(ok);
183            let err_encode = generate_encode_closure(err);
184            format!(
185                "{{ switch $0 {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1)] + {err_encode}(e) }} }}"
186            )
187        }
188        _ => "{ _ in [] }".into(), // fallback
189    }
190}
191
192/// Get the Swift encode function name for a scalar type.
193pub fn swift_encode_fn(scalar: ScalarType) -> &'static str {
194    match scalar {
195        ScalarType::Bool => "encodeBool",
196        ScalarType::U8 => "encodeU8",
197        ScalarType::I8 => "encodeI8",
198        ScalarType::U16 => "encodeU16",
199        ScalarType::I16 => "encodeI16",
200        ScalarType::U32 => "encodeU32",
201        ScalarType::I32 => "encodeI32",
202        ScalarType::U64 | ScalarType::USize => "encodeVarint",
203        ScalarType::I64 | ScalarType::ISize => "encodeI64",
204        ScalarType::F32 => "encodeF32",
205        ScalarType::F64 => "encodeF64",
206        ScalarType::Char | ScalarType::Str | ScalarType::CowStr | ScalarType::String => {
207            "encodeString"
208        }
209        ScalarType::Unit => "{ _ in [] }",
210        _ => "encodeBytes", // fallback
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use facet::Facet;

    /// Assert that `generated` mentions every name in `needles`.
    fn assert_mentions(generated: &str, needles: &[&str]) {
        for needle in needles.iter().copied() {
            assert!(
                generated.contains(needle),
                "expected `{needle}` in `{generated}`"
            );
        }
    }

    #[test]
    fn test_encode_primitives() {
        let cases = [
            (<bool as Facet>::SHAPE, "encodeBool(x)"),
            (<u32 as Facet>::SHAPE, "encodeU32(x)"),
            (<String as Facet>::SHAPE, "encodeString(x)"),
        ];
        for (shape, expected) in cases {
            assert_eq!(generate_encode_expr(shape, "x"), expected);
        }
    }

    #[test]
    fn test_encode_vec() {
        let generated = generate_encode_expr(<Vec<i32> as Facet>::SHAPE, "items");
        assert_mentions(&generated, &["encodeVec", "encodeI32"]);
    }

    #[test]
    fn test_encode_option() {
        let generated = generate_encode_expr(<Option<String> as Facet>::SHAPE, "val");
        assert_mentions(&generated, &["encodeOption", "encodeString"]);
    }

    #[test]
    fn test_encode_bytes() {
        let expected = "encodeBytes(Array(data))";
        assert_eq!(
            generate_encode_expr(<Vec<u8> as Facet>::SHAPE, "data"),
            expected
        );
    }
}