roam_codegen/targets/swift/
decode.rs

//! Swift decoding statement generation.
//!
//! Generates Swift code that decodes byte buffers into Swift values mirroring the Rust
//! types described by `facet` shapes.
4
5use facet_core::{ScalarType, Shape};
6use heck::ToLowerCamelCase;
7use roam_schema::{
8    EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant, is_bytes,
9};
10
11/// Generate a Swift decode statement for a given shape.
12/// Returns code that decodes from `payload` at `offset` into a variable named `var_name`.
13pub fn generate_decode_stmt(shape: &'static Shape, var_name: &str, indent: &str) -> String {
14    generate_decode_stmt_from(shape, var_name, indent, "payload")
15}
16
17/// Generate a Swift decode statement for a given shape from a specific data variable.
18/// Returns code that decodes from `data_var` at `offset` into a variable named `var_name`.
19pub fn generate_decode_stmt_from(
20    shape: &'static Shape,
21    var_name: &str,
22    indent: &str,
23    data_var: &str,
24) -> String {
25    // Check for bytes first
26    if is_bytes(shape) {
27        return format!(
28            "{indent}let {var_name} = try decodeBytes(from: {data_var}, offset: &offset)\n"
29        );
30    }
31
32    match classify_shape(shape) {
33        ShapeKind::Scalar(scalar) => {
34            let decode_fn = swift_decode_fn(scalar);
35            format!("{indent}let {var_name} = try {decode_fn}(from: {data_var}, offset: &offset)\n")
36        }
37        ShapeKind::List { element }
38        | ShapeKind::Slice { element }
39        | ShapeKind::Array { element, .. } => {
40            let inner_decode = generate_decode_closure(element);
41            format!(
42                "{indent}let {var_name} = try decodeVec(from: {data_var}, offset: &offset, decoder: {inner_decode})\n"
43            )
44        }
45        ShapeKind::Option { inner } => {
46            let inner_decode = generate_decode_closure(inner);
47            format!(
48                "{indent}let {var_name} = try decodeOption(from: {data_var}, offset: &offset, decoder: {inner_decode})\n"
49            )
50        }
51        ShapeKind::Tuple { elements } if elements.len() == 2 => {
52            let a_decode = generate_decode_closure(elements[0].shape);
53            let b_decode = generate_decode_closure(elements[1].shape);
54            format!(
55                "{indent}let {var_name} = try decodeTuple2(from: {data_var}, offset: &offset, decoderA: {a_decode}, decoderB: {b_decode})\n"
56            )
57        }
58        ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
59            let a_decode = generate_decode_closure(fields[0].shape());
60            let b_decode = generate_decode_closure(fields[1].shape());
61            format!(
62                "{indent}let {var_name} = try decodeTuple2(from: {data_var}, offset: &offset, decoderA: {a_decode}, decoderB: {b_decode})\n"
63            )
64        }
65        ShapeKind::Struct(StructInfo {
66            name: Some(name),
67            fields,
68            ..
69        }) => {
70            // Named struct - decode fields inline and construct
71            let mut out = String::new();
72            for f in fields.iter() {
73                let field_name = f.name.to_lower_camel_case();
74                out.push_str(&generate_decode_stmt_from(
75                    f.shape(),
76                    &format!("_{var_name}_{field_name}"),
77                    indent,
78                    data_var,
79                ));
80            }
81            let field_inits: Vec<String> = fields
82                .iter()
83                .map(|f| {
84                    let field_name = f.name.to_lower_camel_case();
85                    format!("{field_name}: _{var_name}_{field_name}")
86                })
87                .collect();
88            out.push_str(&format!(
89                "{indent}let {var_name} = {name}({})\n",
90                field_inits.join(", ")
91            ));
92            out
93        }
94        ShapeKind::Enum(EnumInfo {
95            name: Some(name),
96            variants,
97            ..
98        }) => {
99            // Named enum - decode discriminant then decode variant
100            let mut out = String::new();
101            out.push_str(&format!(
102                "{indent}let _{var_name}_disc = try decodeU8(from: {data_var}, offset: &offset)\n"
103            ));
104            out.push_str(&format!("{indent}let {var_name}: {name}\n"));
105            out.push_str(&format!("{indent}switch _{var_name}_disc {{\n"));
106            for (i, v) in variants.iter().enumerate() {
107                out.push_str(&format!("{indent}case {i}:\n"));
108                match classify_variant(v) {
109                    VariantKind::Unit => {
110                        out.push_str(&format!(
111                            "{indent}    {var_name} = .{}\n",
112                            v.name.to_lower_camel_case()
113                        ));
114                    }
115                    VariantKind::Newtype { inner } => {
116                        out.push_str(&generate_decode_stmt_from(
117                            inner,
118                            &format!("_{var_name}_val"),
119                            &format!("{indent}    "),
120                            data_var,
121                        ));
122                        out.push_str(&format!(
123                            "{indent}    {var_name} = .{}(_{var_name}_val)\n",
124                            v.name.to_lower_camel_case()
125                        ));
126                    }
127                    VariantKind::Tuple { fields } => {
128                        for (j, f) in fields.iter().enumerate() {
129                            out.push_str(&generate_decode_stmt_from(
130                                f.shape(),
131                                &format!("_{var_name}_f{j}"),
132                                &format!("{indent}    "),
133                                data_var,
134                            ));
135                        }
136                        let args: Vec<String> = (0..fields.len())
137                            .map(|j| format!("_{var_name}_f{j}"))
138                            .collect();
139                        out.push_str(&format!(
140                            "{indent}    {var_name} = .{}({})\n",
141                            v.name.to_lower_camel_case(),
142                            args.join(", ")
143                        ));
144                    }
145                    VariantKind::Struct { fields } => {
146                        for f in fields.iter() {
147                            let field_name = f.name.to_lower_camel_case();
148                            out.push_str(&generate_decode_stmt_from(
149                                f.shape(),
150                                &format!("_{var_name}_{field_name}"),
151                                &format!("{indent}    "),
152                                data_var,
153                            ));
154                        }
155                        let args: Vec<String> = fields
156                            .iter()
157                            .map(|f| {
158                                let field_name = f.name.to_lower_camel_case();
159                                format!("{field_name}: _{var_name}_{field_name}")
160                            })
161                            .collect();
162                        out.push_str(&format!(
163                            "{indent}    {var_name} = .{}({})\n",
164                            v.name.to_lower_camel_case(),
165                            args.join(", ")
166                        ));
167                    }
168                }
169            }
170            out.push_str(&format!("{indent}default:\n"));
171            out.push_str(&format!(
172                "{indent}    throw RoamError.decodeError(\"unknown enum variant\")\n"
173            ));
174            out.push_str(&format!("{indent}}}\n"));
175            out
176        }
177        ShapeKind::Pointer { pointee } => generate_decode_stmt(pointee, var_name, indent),
178        ShapeKind::Result { ok, err } => {
179            // Decode Result<T, E> - discriminant 0 = Ok, 1 = Err
180            let ok_type = super::types::swift_type_base(ok);
181            let err_type = super::types::swift_type_base(err);
182            let mut out = String::new();
183            out.push_str(&format!(
184                "{indent}let _{var_name}_disc = try decodeU8(from: {data_var}, offset: &offset)\n"
185            ));
186            out.push_str(&format!(
187                "{indent}let {var_name}: Result<{ok_type}, {err_type}>\n"
188            ));
189            out.push_str(&format!("{indent}switch _{var_name}_disc {{\n"));
190            out.push_str(&format!("{indent}case 0:\n"));
191            out.push_str(&generate_decode_stmt_from(
192                ok,
193                &format!("_{var_name}_ok"),
194                &format!("{indent}    "),
195                data_var,
196            ));
197            out.push_str(&format!(
198                "{indent}    {var_name} = .success(_{var_name}_ok)\n"
199            ));
200            out.push_str(&format!("{indent}case 1:\n"));
201            out.push_str(&generate_decode_stmt_from(
202                err,
203                &format!("_{var_name}_err"),
204                &format!("{indent}    "),
205                data_var,
206            ));
207            out.push_str(&format!(
208                "{indent}    {var_name} = .failure(_{var_name}_err)\n"
209            ));
210            out.push_str(&format!("{indent}default:\n"));
211            out.push_str(&format!(
212                "{indent}    throw RoamError.decodeError(\"invalid Result discriminant\")\n"
213            ));
214            out.push_str(&format!("{indent}}}\n"));
215            out
216        }
217        _ => {
218            // Fallback for unsupported types
219            format!("{indent}let {var_name}: Any = () // unsupported type\n")
220        }
221    }
222}
223
224/// Generate a Swift decode closure for use with decodeVec, decodeOption, etc.
225pub fn generate_decode_closure(shape: &'static Shape) -> String {
226    if is_bytes(shape) {
227        return "{ data, off in try decodeBytes(from: data, offset: &off) }".into();
228    }
229
230    match classify_shape(shape) {
231        ShapeKind::Scalar(scalar) => {
232            let decode_fn = swift_decode_fn(scalar);
233            format!("{{ data, off in try {decode_fn}(from: data, offset: &off) }}")
234        }
235        ShapeKind::List { element } | ShapeKind::Slice { element } => {
236            let inner = generate_decode_closure(element);
237            format!("{{ data, off in try decodeVec(from: data, offset: &off, decoder: {inner}) }}")
238        }
239        ShapeKind::Option { inner } => {
240            let inner_closure = generate_decode_closure(inner);
241            format!(
242                "{{ data, off in try decodeOption(from: data, offset: &off, decoder: {inner_closure}) }}"
243            )
244        }
245        ShapeKind::Tuple { elements } if elements.len() == 2 => {
246            let a_decode = generate_decode_closure(elements[0].shape);
247            let b_decode = generate_decode_closure(elements[1].shape);
248            format!(
249                "{{ data, off in try decodeTuple2(from: data, offset: &off, decoderA: {a_decode}, decoderB: {b_decode}) }}"
250            )
251        }
252        ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
253            let a_decode = generate_decode_closure(fields[0].shape());
254            let b_decode = generate_decode_closure(fields[1].shape());
255            format!(
256                "{{ data, off in try decodeTuple2(from: data, offset: &off, decoderA: {a_decode}, decoderB: {b_decode}) }}"
257            )
258        }
259        ShapeKind::Struct(StructInfo {
260            name: Some(name),
261            fields,
262            ..
263        }) => {
264            // Generate inline struct decode closure
265            let mut code = "{ data, off in\n".to_string();
266            for f in fields.iter() {
267                let field_name = f.name.to_lower_camel_case();
268                let decode_call = generate_inline_decode(f.shape(), "data", "off");
269                code.push_str(&format!("    let _{field_name} = try {decode_call}\n"));
270            }
271            let field_inits: Vec<String> = fields
272                .iter()
273                .map(|f| {
274                    let field_name = f.name.to_lower_camel_case();
275                    format!("{field_name}: _{field_name}")
276                })
277                .collect();
278            code.push_str(&format!(
279                "    return {name}({})\n}}",
280                field_inits.join(", ")
281            ));
282            code
283        }
284        ShapeKind::Enum(EnumInfo {
285            name: Some(name),
286            variants,
287            ..
288        }) => {
289            // Generate inline enum decode closure
290            let mut code = format!(
291                "{{ data, off in\n    let disc = try decodeU8(from: data, offset: &off)\n    let result: {name}\n    switch disc {{\n"
292            );
293            for (i, v) in variants.iter().enumerate() {
294                code.push_str(&format!("    case {i}:\n"));
295                match classify_variant(v) {
296                    VariantKind::Unit => {
297                        code.push_str(&format!(
298                            "        result = .{}\n",
299                            v.name.to_lower_camel_case()
300                        ));
301                    }
302                    VariantKind::Newtype { inner } => {
303                        let inner_decode = generate_inline_decode(inner, "data", "off");
304                        code.push_str(&format!(
305                            "        let val = try {inner_decode}\n        result = .{}(val)\n",
306                            v.name.to_lower_camel_case()
307                        ));
308                    }
309                    VariantKind::Tuple { fields } => {
310                        for (j, f) in fields.iter().enumerate() {
311                            let inner_decode = generate_inline_decode(f.shape(), "data", "off");
312                            code.push_str(&format!("        let f{j} = try {inner_decode}\n"));
313                        }
314                        let args: Vec<String> =
315                            (0..fields.len()).map(|j| format!("f{j}")).collect();
316                        code.push_str(&format!(
317                            "        result = .{}({})\n",
318                            v.name.to_lower_camel_case(),
319                            args.join(", ")
320                        ));
321                    }
322                    VariantKind::Struct { fields } => {
323                        for f in fields.iter() {
324                            let field_name = f.name.to_lower_camel_case();
325                            let inner_decode = generate_inline_decode(f.shape(), "data", "off");
326                            code.push_str(&format!(
327                                "        let _{field_name} = try {inner_decode}\n"
328                            ));
329                        }
330                        let args: Vec<String> = fields
331                            .iter()
332                            .map(|f| {
333                                let field_name = f.name.to_lower_camel_case();
334                                format!("{field_name}: _{field_name}")
335                            })
336                            .collect();
337                        code.push_str(&format!(
338                            "        result = .{}({})\n",
339                            v.name.to_lower_camel_case(),
340                            args.join(", ")
341                        ));
342                    }
343                }
344            }
345            code.push_str("    default:\n        throw RoamError.decodeError(\"unknown enum variant\")\n    }\n    return result\n}");
346            code
347        }
348        ShapeKind::Pointer { pointee } => generate_decode_closure(pointee),
349        _ => "{ _, _ in throw RoamError.decodeError(\"unsupported type\") }".into(),
350    }
351}
352
353/// Generate inline decode expression (for use in closures).
354pub fn generate_inline_decode(shape: &'static Shape, data_var: &str, offset_var: &str) -> String {
355    if is_bytes(shape) {
356        return format!("decodeBytes(from: {data_var}, offset: &{offset_var})");
357    }
358
359    match classify_shape(shape) {
360        ShapeKind::Scalar(scalar) => {
361            let decode_fn = swift_decode_fn(scalar);
362            format!("{decode_fn}(from: {data_var}, offset: &{offset_var})")
363        }
364        ShapeKind::List { element } | ShapeKind::Slice { element } => {
365            let inner = generate_decode_closure(element);
366            format!("decodeVec(from: {data_var}, offset: &{offset_var}, decoder: {inner})")
367        }
368        ShapeKind::Option { inner } => {
369            let inner_closure = generate_decode_closure(inner);
370            format!(
371                "decodeOption(from: {data_var}, offset: &{offset_var}, decoder: {inner_closure})"
372            )
373        }
374        ShapeKind::Pointer { pointee } => generate_inline_decode(pointee, data_var, offset_var),
375        _ => "{ throw RoamError.decodeError(\"unsupported\") }()".to_string(),
376    }
377}
378
379/// Get the Swift decode function name for a scalar type.
380pub fn swift_decode_fn(scalar: ScalarType) -> &'static str {
381    match scalar {
382        ScalarType::Bool => "decodeBool",
383        ScalarType::U8 => "decodeU8",
384        ScalarType::I8 => "decodeI8",
385        ScalarType::U16 => "decodeU16",
386        ScalarType::I16 => "decodeI16",
387        ScalarType::U32 => "decodeU32",
388        ScalarType::I32 => "decodeI32",
389        ScalarType::U64 | ScalarType::USize => "decodeVarint",
390        ScalarType::I64 | ScalarType::ISize => "decodeI64",
391        ScalarType::F32 => "decodeF32",
392        ScalarType::F64 => "decodeF64",
393        ScalarType::Char | ScalarType::Str | ScalarType::CowStr | ScalarType::String => {
394            "decodeString"
395        }
396        ScalarType::Unit => "{ _, _ in () }",
397        _ => "decodeBytes", // fallback
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use facet::Facet;

    /// Fails the calling test unless `generated` contains every fragment in `fragments`.
    fn assert_contains_all(generated: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(generated.contains(fragment));
        }
    }

    #[test]
    fn test_decode_primitives() {
        let boolean = generate_decode_stmt(<bool as Facet>::SHAPE, "x", "    ");
        assert_contains_all(&boolean, &["decodeBool", "let x"]);

        let text = generate_decode_stmt(<String as Facet>::SHAPE, "msg", "    ");
        assert_contains_all(&text, &["decodeString", "let msg"]);
    }

    #[test]
    fn test_decode_vec() {
        let items = generate_decode_stmt(<Vec<i32> as Facet>::SHAPE, "items", "    ");
        assert_contains_all(&items, &["decodeVec", "decodeI32"]);
    }

    #[test]
    fn test_decode_option() {
        let optional = generate_decode_stmt(<Option<String> as Facet>::SHAPE, "val", "    ");
        assert_contains_all(&optional, &["decodeOption", "decodeString"]);
    }

    #[test]
    fn test_decode_bytes() {
        let bytes = generate_decode_stmt(<Vec<u8> as Facet>::SHAPE, "data", "    ");
        assert_contains_all(&bytes, &["decodeBytes"]);
    }

    #[test]
    fn test_inline_decode() {
        assert_eq!(
            generate_inline_decode(<u32 as Facet>::SHAPE, "buf", "pos"),
            "decodeU32(from: buf, offset: &pos)"
        );
    }
}