Skip to main content

roam_codegen/targets/swift/
decode.rs

//! Swift decoding statement generation.
//!
//! Generates Swift code that decodes byte arrays into the Swift counterparts of Rust types.
4
5use facet_core::{ScalarType, Shape};
6use heck::ToLowerCamelCase;
7use roam_types::{
8    EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant, is_bytes,
9};
10
11/// Generate a Swift decode statement for a given shape.
12/// Returns code that decodes from `payload` at `cursor` into a variable named `var_name`.
13pub fn generate_decode_stmt(shape: &'static Shape, var_name: &str, indent: &str) -> String {
14    generate_decode_stmt_with_cursor(shape, var_name, indent, "cursor")
15}
16
17/// Generate a Swift decode statement for a given shape using a custom cursor variable name.
18pub fn generate_decode_stmt_with_cursor(
19    shape: &'static Shape,
20    var_name: &str,
21    indent: &str,
22    cursor_var: &str,
23) -> String {
24    generate_decode_stmt_from_with_cursor(shape, var_name, indent, "payload", cursor_var)
25}
26
27/// Generate a Swift decode statement for a given shape from a specific data variable.
28/// Returns code that decodes from `data_var` at `cursor` into a variable named `var_name`.
29pub fn generate_decode_stmt_from(
30    shape: &'static Shape,
31    var_name: &str,
32    indent: &str,
33    data_var: &str,
34) -> String {
35    generate_decode_stmt_from_with_cursor(shape, var_name, indent, data_var, "cursor")
36}
37
38/// Generate a Swift decode statement for a given shape from a specific data variable
39/// and using a custom cursor variable name.
40pub fn generate_decode_stmt_from_with_cursor(
41    shape: &'static Shape,
42    var_name: &str,
43    indent: &str,
44    data_var: &str,
45    cursor_var: &str,
46) -> String {
47    // Check for bytes first
48    if is_bytes(shape) {
49        return format!(
50            "{indent}let {var_name} = try decodeBytes(from: {data_var}, offset: &{cursor_var})\n"
51        );
52    }
53
54    match classify_shape(shape) {
55        ShapeKind::Scalar(scalar) => {
56            let decode_fn = swift_decode_fn(scalar);
57            format!(
58                "{indent}let {var_name} = try {decode_fn}(from: {data_var}, offset: &{cursor_var})\n"
59            )
60        }
61        ShapeKind::List { element }
62        | ShapeKind::Slice { element }
63        | ShapeKind::Array { element, .. } => {
64            let inner_decode = generate_decode_closure(element);
65            format!(
66                "{indent}let {var_name} = try decodeVec(from: {data_var}, offset: &{cursor_var}, decoder: {inner_decode})\n"
67            )
68        }
69        ShapeKind::Option { inner } => {
70            let inner_decode = generate_decode_closure(inner);
71            format!(
72                "{indent}let {var_name} = try decodeOption(from: {data_var}, offset: &{cursor_var}, decoder: {inner_decode})\n"
73            )
74        }
75        ShapeKind::Tuple { elements } if elements.len() == 2 => {
76            let a_decode = generate_decode_closure(elements[0].shape);
77            let b_decode = generate_decode_closure(elements[1].shape);
78            format!(
79                "{indent}let {var_name} = try decodeTuple2(from: {data_var}, offset: &{cursor_var}, decoderA: {a_decode}, decoderB: {b_decode})\n"
80            )
81        }
82        ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
83            let a_decode = generate_decode_closure(fields[0].shape());
84            let b_decode = generate_decode_closure(fields[1].shape());
85            format!(
86                "{indent}let {var_name} = try decodeTuple2(from: {data_var}, offset: &{cursor_var}, decoderA: {a_decode}, decoderB: {b_decode})\n"
87            )
88        }
89        ShapeKind::Struct(StructInfo {
90            name: Some(name),
91            fields,
92            ..
93        }) => {
94            // Named struct - decode fields inline and construct
95            let mut out = String::new();
96            for f in fields.iter() {
97                let field_name = f.name.to_lower_camel_case();
98                out.push_str(&generate_decode_stmt_from_with_cursor(
99                    f.shape(),
100                    &format!("_{var_name}_{field_name}"),
101                    indent,
102                    data_var,
103                    cursor_var,
104                ));
105            }
106            let field_inits: Vec<String> = fields
107                .iter()
108                .map(|f| {
109                    let field_name = f.name.to_lower_camel_case();
110                    format!("{field_name}: _{var_name}_{field_name}")
111                })
112                .collect();
113            out.push_str(&format!(
114                "{indent}let {var_name} = {name}({})\n",
115                field_inits.join(", ")
116            ));
117            out
118        }
119        ShapeKind::Enum(EnumInfo {
120            name: Some(name),
121            variants,
122            ..
123        }) => {
124            // Named enum - decode discriminant then decode variant
125            let mut out = String::new();
126            out.push_str(&format!(
127                "{indent}let _{var_name}_disc = try decodeVarint(from: {data_var}, offset: &{cursor_var})\n"
128            ));
129            out.push_str(&format!("{indent}let {var_name}: {name}\n"));
130            out.push_str(&format!("{indent}switch _{var_name}_disc {{\n"));
131            for (i, v) in variants.iter().enumerate() {
132                out.push_str(&format!("{indent}case {i}:\n"));
133                match classify_variant(v) {
134                    VariantKind::Unit => {
135                        out.push_str(&format!(
136                            "{indent}    {var_name} = .{}\n",
137                            v.name.to_lower_camel_case()
138                        ));
139                    }
140                    VariantKind::Newtype { inner } => {
141                        out.push_str(&generate_decode_stmt_from_with_cursor(
142                            inner,
143                            &format!("_{var_name}_val"),
144                            &format!("{indent}    "),
145                            data_var,
146                            cursor_var,
147                        ));
148                        out.push_str(&format!(
149                            "{indent}    {var_name} = .{}(_{var_name}_val)\n",
150                            v.name.to_lower_camel_case()
151                        ));
152                    }
153                    VariantKind::Tuple { fields } => {
154                        for (j, f) in fields.iter().enumerate() {
155                            out.push_str(&generate_decode_stmt_from_with_cursor(
156                                f.shape(),
157                                &format!("_{var_name}_f{j}"),
158                                &format!("{indent}    "),
159                                data_var,
160                                cursor_var,
161                            ));
162                        }
163                        let args: Vec<String> = (0..fields.len())
164                            .map(|j| format!("_{var_name}_f{j}"))
165                            .collect();
166                        out.push_str(&format!(
167                            "{indent}    {var_name} = .{}({})\n",
168                            v.name.to_lower_camel_case(),
169                            args.join(", ")
170                        ));
171                    }
172                    VariantKind::Struct { fields } => {
173                        for f in fields.iter() {
174                            let field_name = f.name.to_lower_camel_case();
175                            out.push_str(&generate_decode_stmt_from_with_cursor(
176                                f.shape(),
177                                &format!("_{var_name}_{field_name}"),
178                                &format!("{indent}    "),
179                                data_var,
180                                cursor_var,
181                            ));
182                        }
183                        let args: Vec<String> = fields
184                            .iter()
185                            .map(|f| {
186                                let field_name = f.name.to_lower_camel_case();
187                                format!("{field_name}: _{var_name}_{field_name}")
188                            })
189                            .collect();
190                        out.push_str(&format!(
191                            "{indent}    {var_name} = .{}({})\n",
192                            v.name.to_lower_camel_case(),
193                            args.join(", ")
194                        ));
195                    }
196                }
197            }
198            out.push_str(&format!("{indent}default:\n"));
199            out.push_str(&format!(
200                "{indent}    throw RoamError.decodeError(\"unknown enum variant\")\n"
201            ));
202            out.push_str(&format!("{indent}}}\n"));
203            out
204        }
205        ShapeKind::Pointer { pointee } => {
206            generate_decode_stmt_with_cursor(pointee, var_name, indent, cursor_var)
207        }
208        ShapeKind::Result { ok, err } => {
209            // Decode Result<T, E> - discriminant 0 = Ok, 1 = Err
210            let ok_type = super::types::swift_type_base(ok);
211            let err_type = super::types::swift_type_base(err);
212            let mut out = String::new();
213            out.push_str(&format!(
214                "{indent}let _{var_name}_disc = try decodeVarint(from: {data_var}, offset: &{cursor_var})\n"
215            ));
216            out.push_str(&format!(
217                "{indent}let {var_name}: Result<{ok_type}, {err_type}>\n"
218            ));
219            out.push_str(&format!("{indent}switch _{var_name}_disc {{\n"));
220            out.push_str(&format!("{indent}case 0:\n"));
221            out.push_str(&generate_decode_stmt_from_with_cursor(
222                ok,
223                &format!("_{var_name}_ok"),
224                &format!("{indent}    "),
225                data_var,
226                cursor_var,
227            ));
228            out.push_str(&format!(
229                "{indent}    {var_name} = .success(_{var_name}_ok)\n"
230            ));
231            out.push_str(&format!("{indent}case 1:\n"));
232            out.push_str(&generate_decode_stmt_from_with_cursor(
233                err,
234                &format!("_{var_name}_err"),
235                &format!("{indent}    "),
236                data_var,
237                cursor_var,
238            ));
239            out.push_str(&format!(
240                "{indent}    {var_name} = .failure(_{var_name}_err)\n"
241            ));
242            out.push_str(&format!("{indent}default:\n"));
243            out.push_str(&format!(
244                "{indent}    throw RoamError.decodeError(\"invalid Result discriminant\")\n"
245            ));
246            out.push_str(&format!("{indent}}}\n"));
247            out
248        }
249        _ => {
250            // Fallback for unsupported types
251            format!("{indent}let {var_name}: Any = () // unsupported type\n")
252        }
253    }
254}
255
256/// Generate a Swift decode closure for use with decodeVec, decodeOption, etc.
257pub fn generate_decode_closure(shape: &'static Shape) -> String {
258    if is_bytes(shape) {
259        return "{ data, off in try decodeBytes(from: data, offset: &off) }".into();
260    }
261
262    match classify_shape(shape) {
263        ShapeKind::Scalar(scalar) => {
264            let decode_fn = swift_decode_fn(scalar);
265            format!("{{ data, off in try {decode_fn}(from: data, offset: &off) }}")
266        }
267        ShapeKind::List { element } | ShapeKind::Slice { element } => {
268            let inner = generate_decode_closure(element);
269            format!("{{ data, off in try decodeVec(from: data, offset: &off, decoder: {inner}) }}")
270        }
271        ShapeKind::Option { inner } => {
272            let inner_closure = generate_decode_closure(inner);
273            format!(
274                "{{ data, off in try decodeOption(from: data, offset: &off, decoder: {inner_closure}) }}"
275            )
276        }
277        ShapeKind::Tuple { elements } if elements.len() == 2 => {
278            let a_decode = generate_decode_closure(elements[0].shape);
279            let b_decode = generate_decode_closure(elements[1].shape);
280            format!(
281                "{{ data, off in try decodeTuple2(from: data, offset: &off, decoderA: {a_decode}, decoderB: {b_decode}) }}"
282            )
283        }
284        ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
285            let a_decode = generate_decode_closure(fields[0].shape());
286            let b_decode = generate_decode_closure(fields[1].shape());
287            format!(
288                "{{ data, off in try decodeTuple2(from: data, offset: &off, decoderA: {a_decode}, decoderB: {b_decode}) }}"
289            )
290        }
291        ShapeKind::Struct(StructInfo {
292            name: Some(name),
293            fields,
294            ..
295        }) => {
296            // Generate inline struct decode closure
297            let mut code = "{ data, off in\n".to_string();
298            for f in fields.iter() {
299                let field_name = f.name.to_lower_camel_case();
300                let decode_call = generate_inline_decode(f.shape(), "data", "off");
301                code.push_str(&format!("    let _{field_name} = try {decode_call}\n"));
302            }
303            let field_inits: Vec<String> = fields
304                .iter()
305                .map(|f| {
306                    let field_name = f.name.to_lower_camel_case();
307                    format!("{field_name}: _{field_name}")
308                })
309                .collect();
310            code.push_str(&format!(
311                "    return {name}({})\n}}",
312                field_inits.join(", ")
313            ));
314            code
315        }
316        ShapeKind::Enum(EnumInfo {
317            name: Some(name),
318            variants,
319            ..
320        }) => {
321            // Generate inline enum decode closure
322            let mut code = format!(
323                "{{ data, off in\n    let disc = try decodeVarint(from: data, offset: &off)\n    let result: {name}\n    switch disc {{\n"
324            );
325            for (i, v) in variants.iter().enumerate() {
326                code.push_str(&format!("    case {i}:\n"));
327                match classify_variant(v) {
328                    VariantKind::Unit => {
329                        code.push_str(&format!(
330                            "        result = .{}\n",
331                            v.name.to_lower_camel_case()
332                        ));
333                    }
334                    VariantKind::Newtype { inner } => {
335                        let inner_decode = generate_inline_decode(inner, "data", "off");
336                        code.push_str(&format!(
337                            "        let val = try {inner_decode}\n        result = .{}(val)\n",
338                            v.name.to_lower_camel_case()
339                        ));
340                    }
341                    VariantKind::Tuple { fields } => {
342                        for (j, f) in fields.iter().enumerate() {
343                            let inner_decode = generate_inline_decode(f.shape(), "data", "off");
344                            code.push_str(&format!("        let f{j} = try {inner_decode}\n"));
345                        }
346                        let args: Vec<String> =
347                            (0..fields.len()).map(|j| format!("f{j}")).collect();
348                        code.push_str(&format!(
349                            "        result = .{}({})\n",
350                            v.name.to_lower_camel_case(),
351                            args.join(", ")
352                        ));
353                    }
354                    VariantKind::Struct { fields } => {
355                        for f in fields.iter() {
356                            let field_name = f.name.to_lower_camel_case();
357                            let inner_decode = generate_inline_decode(f.shape(), "data", "off");
358                            code.push_str(&format!(
359                                "        let _{field_name} = try {inner_decode}\n"
360                            ));
361                        }
362                        let args: Vec<String> = fields
363                            .iter()
364                            .map(|f| {
365                                let field_name = f.name.to_lower_camel_case();
366                                format!("{field_name}: _{field_name}")
367                            })
368                            .collect();
369                        code.push_str(&format!(
370                            "        result = .{}({})\n",
371                            v.name.to_lower_camel_case(),
372                            args.join(", ")
373                        ));
374                    }
375                }
376            }
377            code.push_str("    default:\n        throw RoamError.decodeError(\"unknown enum variant\")\n    }\n    return result\n}");
378            code
379        }
380        ShapeKind::Pointer { pointee } => generate_decode_closure(pointee),
381        _ => "{ _, _ in throw RoamError.decodeError(\"unsupported type\") }".into(),
382    }
383}
384
385/// Generate inline decode expression (for use in closures).
386pub fn generate_inline_decode(shape: &'static Shape, data_var: &str, offset_var: &str) -> String {
387    if is_bytes(shape) {
388        return format!("decodeBytes(from: {data_var}, offset: &{offset_var})");
389    }
390
391    match classify_shape(shape) {
392        ShapeKind::Scalar(scalar) => {
393            let decode_fn = swift_decode_fn(scalar);
394            format!("{decode_fn}(from: {data_var}, offset: &{offset_var})")
395        }
396        ShapeKind::List { element } | ShapeKind::Slice { element } => {
397            let inner = generate_decode_closure(element);
398            format!("decodeVec(from: {data_var}, offset: &{offset_var}, decoder: {inner})")
399        }
400        ShapeKind::Option { inner } => {
401            let inner_closure = generate_decode_closure(inner);
402            format!(
403                "decodeOption(from: {data_var}, offset: &{offset_var}, decoder: {inner_closure})"
404            )
405        }
406        ShapeKind::Pointer { pointee } => generate_inline_decode(pointee, data_var, offset_var),
407        _ => {
408            let decode_closure = generate_decode_closure(shape);
409            format!("({decode_closure})({data_var}, &{offset_var})")
410        }
411    }
412}
413
414/// Get the Swift decode function name for a scalar type.
415pub fn swift_decode_fn(scalar: ScalarType) -> &'static str {
416    match scalar {
417        ScalarType::Bool => "decodeBool",
418        ScalarType::U8 => "decodeU8",
419        ScalarType::I8 => "decodeI8",
420        ScalarType::U16 => "decodeU16",
421        ScalarType::I16 => "decodeI16",
422        ScalarType::U32 => "decodeU32",
423        ScalarType::I32 => "decodeI32",
424        ScalarType::U64 | ScalarType::USize => "decodeVarint",
425        ScalarType::I64 | ScalarType::ISize => "decodeI64",
426        ScalarType::F32 => "decodeF32",
427        ScalarType::F64 => "decodeF64",
428        ScalarType::Char | ScalarType::Str | ScalarType::CowStr | ScalarType::String => {
429            "decodeString"
430        }
431        ScalarType::Unit => "{ _, _ in () }",
432        _ => "decodeBytes", // fallback
433    }
434}