Skip to main content

roam_codegen/targets/swift/
encode.rs

1//! Swift encoding expression generation.
2//!
3//! Generates Swift code that encodes Rust types into byte arrays.
4
5use facet_core::{ScalarType, Shape};
6use heck::ToLowerCamelCase;
7use roam_types::{
8    EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant, is_bytes,
9};
10
11/// Generate a Swift encode expression for a given shape and value.
12pub fn generate_encode_expr(shape: &'static Shape, value: &str) -> String {
13    if is_bytes(shape) {
14        return format!("encodeBytes(Array({value}))");
15    }
16
17    match classify_shape(shape) {
18        ShapeKind::Scalar(scalar) => {
19            let encode_fn = swift_encode_fn(scalar);
20            format!("{encode_fn}({value})")
21        }
22        ShapeKind::List { element }
23        | ShapeKind::Slice { element }
24        | ShapeKind::Array { element, .. } => {
25            let inner_encode = generate_encode_closure(element);
26            format!("encodeVec({value}, encoder: {inner_encode})")
27        }
28        ShapeKind::Option { inner } => {
29            let inner_encode = generate_encode_closure(inner);
30            format!("encodeOption({value}, encoder: {inner_encode})")
31        }
32        ShapeKind::Tuple { elements } if elements.len() == 2 => {
33            let a_encode = generate_encode_closure(elements[0].shape);
34            let b_encode = generate_encode_closure(elements[1].shape);
35            format!("{a_encode}({value}.0) + {b_encode}({value}.1)")
36        }
37        ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
38            let a_encode = generate_encode_closure(fields[0].shape());
39            let b_encode = generate_encode_closure(fields[1].shape());
40            format!("{a_encode}({value}.0) + {b_encode}({value}.1)")
41        }
42        ShapeKind::Struct(StructInfo { fields, .. }) => {
43            // Encode each field and concatenate
44            let field_encodes: Vec<String> = fields
45                .iter()
46                .map(|f| {
47                    let field_name = f.name.to_lower_camel_case();
48                    generate_encode_expr(f.shape(), &format!("{value}.{field_name}"))
49                })
50                .collect();
51            if field_encodes.is_empty() {
52                "[]".into()
53            } else {
54                field_encodes.join(" + ")
55            }
56        }
57        ShapeKind::Enum(EnumInfo { .. }) => {
58            let encode_closure = generate_encode_closure(shape);
59            format!("{encode_closure}({value})")
60        }
61        ShapeKind::Pointer { pointee } => generate_encode_expr(pointee, value),
62        ShapeKind::Result { ok, err } => {
63            // Encode Result<T, E> - discriminant 0 = Ok, 1 = Err
64            let ok_encode = generate_encode_closure(ok);
65            let err_encode = generate_encode_closure(err);
66            format!(
67                "{{ switch {value} {{ case .success(let v): return encodeVarint(UInt64(0)) + {ok_encode}(v); case .failure(let e): return encodeVarint(UInt64(1)) + {err_encode}(e) }} }}()"
68            )
69        }
70        _ => "[]".into(), // fallback
71    }
72}
73
74/// Generate a Swift encode closure for use with encodeVec, encodeOption, etc.
75pub fn generate_encode_closure(shape: &'static Shape) -> String {
76    if is_bytes(shape) {
77        return "{ encodeBytes(Array($0)) }".into();
78    }
79
80    match classify_shape(shape) {
81        ShapeKind::Scalar(scalar) => {
82            let encode_fn = swift_encode_fn(scalar);
83            format!("{{ {encode_fn}($0) }}")
84        }
85        ShapeKind::List { element } | ShapeKind::Slice { element } => {
86            let inner = generate_encode_closure(element);
87            format!("{{ encodeVec($0, encoder: {inner}) }}")
88        }
89        ShapeKind::Option { inner } => {
90            let inner_closure = generate_encode_closure(inner);
91            format!("{{ encodeOption($0, encoder: {inner_closure}) }}")
92        }
93        ShapeKind::Tuple { elements } if elements.len() == 2 => {
94            let a_encode = generate_encode_closure(elements[0].shape);
95            let b_encode = generate_encode_closure(elements[1].shape);
96            format!("{{ {a_encode}($0.0) + {b_encode}($0.1) }}")
97        }
98        ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
99            let a_encode = generate_encode_closure(fields[0].shape());
100            let b_encode = generate_encode_closure(fields[1].shape());
101            format!("{{ {a_encode}($0.0) + {b_encode}($0.1) }}")
102        }
103        ShapeKind::Struct(StructInfo { fields, .. }) => {
104            // Generate inline struct encode closure
105            let field_encodes: Vec<String> = fields
106                .iter()
107                .map(|f| {
108                    let field_name = f.name.to_lower_camel_case();
109                    generate_encode_expr(f.shape(), &format!("$0.{field_name}"))
110                })
111                .collect();
112            if field_encodes.is_empty() {
113                "{ _ in [] }".into()
114            } else {
115                format!("{{ {} }}", field_encodes.join(" + "))
116            }
117        }
118        ShapeKind::Enum(EnumInfo {
119            name: Some(_name),
120            variants,
121            ..
122        }) => {
123            // Generate inline enum encode closure with switch
124            let mut code = "{ v in\n    switch v {\n".to_string();
125            for (i, v) in variants.iter().enumerate() {
126                let variant_name = v.name.to_lower_camel_case();
127                match classify_variant(v) {
128                    VariantKind::Unit => {
129                        code.push_str(&format!(
130                            "    case .{variant_name}:\n        return encodeVarint(UInt64({i}))\n"
131                        ));
132                    }
133                    VariantKind::Newtype { inner } => {
134                        let inner_encode = generate_encode_expr(inner, "val");
135                        code.push_str(&format!(
136                            "    case .{variant_name}(let val):\n        return encodeVarint(UInt64({i})) + {inner_encode}\n"
137                        ));
138                    }
139                    VariantKind::Tuple { fields } => {
140                        let bindings: Vec<String> =
141                            (0..fields.len()).map(|j| format!("f{j}")).collect();
142                        let field_encodes: Vec<String> = fields
143                            .iter()
144                            .enumerate()
145                            .map(|(j, f)| generate_encode_expr(f.shape(), &format!("f{j}")))
146                            .collect();
147                        code.push_str(&format!(
148                            "    case .{variant_name}({}):\n        return encodeVarint(UInt64({i})) + {}\n",
149                            bindings
150                                .iter()
151                                .map(|b| format!("let {b}"))
152                                .collect::<Vec<_>>()
153                                .join(", "),
154                            field_encodes.join(" + ")
155                        ));
156                    }
157                    VariantKind::Struct { fields } => {
158                        let bindings: Vec<String> = fields
159                            .iter()
160                            .map(|f| f.name.to_lower_camel_case())
161                            .collect();
162                        let field_encodes: Vec<String> = fields
163                            .iter()
164                            .map(|f| {
165                                let field_name = f.name.to_lower_camel_case();
166                                generate_encode_expr(f.shape(), &field_name)
167                            })
168                            .collect();
169                        code.push_str(&format!(
170                            "    case .{variant_name}({}):\n        return encodeVarint(UInt64({i})) + {}\n",
171                            bindings
172                                .iter()
173                                .map(|b| format!("let {b}"))
174                                .collect::<Vec<_>>()
175                                .join(", "),
176                            field_encodes.join(" + ")
177                        ));
178                    }
179                }
180            }
181            code.push_str("    }\n}");
182            code
183        }
184        ShapeKind::Pointer { pointee } => generate_encode_closure(pointee),
185        ShapeKind::Result { ok, err } => {
186            let ok_encode = generate_encode_closure(ok);
187            let err_encode = generate_encode_closure(err);
188            format!(
189                "{{ switch $0 {{ case .success(let v): return encodeVarint(UInt64(0)) + {ok_encode}(v); case .failure(let e): return encodeVarint(UInt64(1)) + {err_encode}(e) }} }}"
190            )
191        }
192        _ => "{ _ in [] }".into(), // fallback
193    }
194}
195
196/// Get the Swift encode function name for a scalar type.
197pub fn swift_encode_fn(scalar: ScalarType) -> &'static str {
198    match scalar {
199        ScalarType::Bool => "encodeBool",
200        ScalarType::U8 => "encodeU8",
201        ScalarType::I8 => "encodeI8",
202        ScalarType::U16 => "encodeU16",
203        ScalarType::I16 => "encodeI16",
204        ScalarType::U32 => "encodeU32",
205        ScalarType::I32 => "encodeI32",
206        ScalarType::U64 | ScalarType::USize => "encodeVarint",
207        ScalarType::I64 | ScalarType::ISize => "encodeI64",
208        ScalarType::F32 => "encodeF32",
209        ScalarType::F64 => "encodeF64",
210        ScalarType::Char | ScalarType::Str | ScalarType::CowStr | ScalarType::String => {
211            "encodeString"
212        }
213        ScalarType::Unit => "{ _ in [] }",
214        _ => "encodeBytes", // fallback
215    }
216}