Skip to main content

roam_codegen/targets/swift/
wire.rs

//! Swift wire type code generation.
//!
//! Generates a complete Swift source file containing all wire protocol types
//! (`MessageV7`, `MessagePayloadV7`, etc.) together with their encode/decode
//! methods. The generated code is driven by facet `Shape`s from
//! `roam_types::message`.
//!
//! The only special-cased type is `Payload` (`ShapeKind::Opaque`), which maps to
//! `OpaquePayloadV7` and supports both length-prefixed and trailing byte handling.
//! Everything else goes through the normal struct/enum codegen.
10
11use facet_core::{Field, ScalarType, Shape};
12use heck::ToLowerCamelCase;
13use roam_types::{
14    EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant, is_bytes,
15};
16
17/// A wire type to generate Swift code for.
18pub struct WireType {
19    /// The Swift name for this type (e.g. "MessageV7", "HelloV7")
20    pub swift_name: String,
21    /// The facet Shape for this type
22    pub shape: &'static Shape,
23}
24
25/// Generate a complete Swift wire types file.
26pub fn generate_wire_types(types: &[WireType]) -> String {
27    let mut out = String::new();
28    out.push_str("// @generated by roam-codegen\n");
29    out.push_str("// DO NOT EDIT — regenerate with `cargo xtask codegen --swift-wire`\n\n");
30    out.push_str("import Foundation\n\n");
31
32    // Preamble: error type + helpers + OpaquePayloadV7
33    out.push_str(&generate_preamble());
34
35    // Metadata typealias (Metadata is Vec<MetadataEntry> in Rust, we alias for convenience)
36    out.push_str("public typealias MetadataV7 = [MetadataEntryV7]\n\n");
37
38    // Generate each type
39    for wt in types {
40        out.push_str(&generate_one_type(&wt.swift_name, wt.shape, types));
41        out.push('\n');
42    }
43
44    // Generate factory methods extension on MessageV7
45    out.push_str(&generate_factory_methods(types));
46
47    out
48}
49
/// Produce the fixed Swift preamble shared by every generated file.
///
/// It consists of three sections, emitted in this order: the `WireV7Error`
/// enum, the `OpaquePayloadV7` struct, and the private decode helpers
/// (`decodeVarintU32V7`, `decodeStringV7`, `decodeBytesV7`).
fn generate_preamble() -> String {
    const ERROR_TYPE: &str = concat!(
        "public enum WireV7Error: Error, Equatable {\n",
        "    case truncated\n",
        "    case unknownVariant(UInt64)\n",
        "    case overflow\n",
        "    case invalidUtf8\n",
        "    case trailingBytes\n",
        "}\n\n",
    );

    const OPAQUE_PAYLOAD: &str = concat!(
        "public struct OpaquePayloadV7: Sendable, Equatable {\n",
        "    public var bytes: [UInt8]\n\n",
        "    public init(_ bytes: [UInt8]) {\n",
        "        self.bytes = bytes\n",
        "    }\n\n",
        "    func encode() -> [UInt8] {\n",
        "        encodeBytes(bytes)\n",
        "    }\n\n",
        "    static func decode(from data: Data, offset: inout Int) throws -> Self {\n",
        "        .init(Array(try decodeBytesV7(from: data, offset: &offset)))\n",
        "    }\n\n",
        "    /// Encode without a length prefix — for trailing fields only.\n",
        "    func encodeTrailing() -> [UInt8] {\n",
        "        bytes\n",
        "    }\n\n",
        "    /// Decode by consuming all remaining bytes — for trailing fields only.\n",
        "    static func decodeTrailing(from data: Data, offset: inout Int) -> Self {\n",
        "        let start = data.startIndex + offset\n",
        "        let remaining = Array(data[start...])\n",
        "        offset = data.count\n",
        "        return .init(remaining)\n",
        "    }\n",
        "}\n\n",
    );

    const DECODE_HELPERS: &str = concat!(
        "@inline(__always)\n",
        "private func decodeVarintU32V7(from data: Data, offset: inout Int) throws -> UInt32 {\n",
        "    let value = try decodeVarint(from: data, offset: &offset)\n",
        "    guard value <= UInt64(UInt32.max) else {\n",
        "        throw WireV7Error.overflow\n",
        "    }\n",
        "    return UInt32(value)\n",
        "}\n\n",
        "@inline(__always)\n",
        "private func decodeStringV7(from data: Data, offset: inout Int) throws -> String {\n",
        "    do {\n",
        "        return try decodeString(from: data, offset: &offset)\n",
        "    } catch PostcardError.invalidUtf8 {\n",
        "        throw WireV7Error.invalidUtf8\n",
        "    } catch PostcardError.truncated {\n",
        "        throw WireV7Error.truncated\n",
        "    } catch {\n",
        "        throw error\n",
        "    }\n",
        "}\n\n",
        "@inline(__always)\n",
        "private func decodeBytesV7(from data: Data, offset: inout Int) throws -> Data {\n",
        "    do {\n",
        "        return try decodeBytes(from: data, offset: &offset)\n",
        "    } catch PostcardError.truncated {\n",
        "        throw WireV7Error.truncated\n",
        "    } catch {\n",
        "        throw error\n",
        "    }\n",
        "}\n\n",
    );

    [ERROR_TYPE, OPAQUE_PAYLOAD, DECODE_HELPERS].concat()
}
130
131/// Get the wire Swift type for a field's shape, taking the `WireType` list into account.
132fn swift_wire_type(shape: &'static Shape, _field: Option<&Field>, types: &[WireType]) -> String {
133    if is_bytes(shape) {
134        return "[UInt8]".into();
135    }
136
137    match classify_shape(shape) {
138        ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
139        ShapeKind::List { element } | ShapeKind::Slice { element } => {
140            format!("[{}]", swift_wire_type(element, None, types))
141        }
142        ShapeKind::Option { inner } => {
143            format!("{}?", swift_wire_type(inner, None, types))
144        }
145        ShapeKind::Array { element, .. } => {
146            format!("[{}]", swift_wire_type(element, None, types))
147        }
148        ShapeKind::Struct(StructInfo {
149            name: Some(name), ..
150        }) => lookup_wire_name(name, types),
151        ShapeKind::Enum(EnumInfo {
152            name: Some(name), ..
153        }) => lookup_wire_name(name, types),
154        ShapeKind::Pointer { pointee } => swift_wire_type(pointee, _field, types),
155        ShapeKind::Opaque => "OpaquePayloadV7".into(),
156        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
157            swift_wire_type(fields[0].shape(), None, types)
158        }
159        _ => "Any /* unsupported */".into(),
160    }
161}
162
163fn swift_scalar_type(scalar: ScalarType) -> String {
164    match scalar {
165        ScalarType::Bool => "Bool".into(),
166        ScalarType::U8 => "UInt8".into(),
167        ScalarType::U16 => "UInt16".into(),
168        ScalarType::U32 => "UInt32".into(),
169        ScalarType::U64 | ScalarType::USize => "UInt64".into(),
170        ScalarType::I8 => "Int8".into(),
171        ScalarType::I16 => "Int16".into(),
172        ScalarType::I32 => "Int32".into(),
173        ScalarType::I64 | ScalarType::ISize => "Int64".into(),
174        ScalarType::F32 => "Float".into(),
175        ScalarType::F64 => "Double".into(),
176        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
177            "String".into()
178        }
179        ScalarType::Unit => "Void".into(),
180        _ => "Data".into(),
181    }
182}
183
184/// Look up the wire name for a Rust type name. If a matching WireType exists, use its
185/// swift_name; otherwise fall back to the Rust name.
186fn lookup_wire_name(rust_name: &str, types: &[WireType]) -> String {
187    for wt in types {
188        // Match by checking if the swift_name is derived from this rust_name.
189        // The xtask creates WireTypes with swift_name = rust_name + "V7".
190        // We can also check the Shape's type_identifier.
191        if wt.shape.type_identifier.ends_with(rust_name) {
192            return wt.swift_name.clone();
193        }
194    }
195    rust_name.to_string()
196}
197
198/// Generate a single wire type (struct or enum).
199fn generate_one_type(swift_name: &str, shape: &'static Shape, types: &[WireType]) -> String {
200    match classify_shape(shape) {
201        ShapeKind::Struct(StructInfo { fields, .. }) => {
202            generate_struct(swift_name, fields, types, swift_name == "MessageV7")
203        }
204        ShapeKind::Enum(EnumInfo { variants, .. }) => generate_enum(swift_name, variants, types),
205        _ => format!("// Unsupported shape for {swift_name}\n"),
206    }
207}
208
209/// Generate a Swift struct with encode/decode methods.
210fn generate_struct(name: &str, fields: &[Field], types: &[WireType], is_top_level: bool) -> String {
211    let mut out = String::new();
212
213    // Struct definition
214    out.push_str(&format!("public struct {name}: Sendable, Equatable {{\n"));
215    for f in fields {
216        let field_name = f.name.to_lower_camel_case();
217        let field_type = swift_wire_type(f.shape(), Some(f), types);
218        out.push_str(&format!("    public var {field_name}: {field_type}\n"));
219    }
220
221    // Initializer
222    if !fields.is_empty() {
223        out.push_str("\n    public init(");
224        for (i, f) in fields.iter().enumerate() {
225            if i > 0 {
226                out.push_str(", ");
227            }
228            let field_name = f.name.to_lower_camel_case();
229            let field_type = swift_wire_type(f.shape(), Some(f), types);
230            out.push_str(&format!("{field_name}: {field_type}"));
231        }
232        out.push_str(") {\n");
233        for f in fields {
234            let field_name = f.name.to_lower_camel_case();
235            out.push_str(&format!("        self.{field_name} = {field_name}\n"));
236        }
237        out.push_str("    }\n");
238    }
239
240    // Encode method
241    let vis = if is_top_level { "public " } else { "" };
242    out.push_str(&format!("\n    {vis}func encode() -> [UInt8] {{\n"));
243    if fields.is_empty() {
244        out.push_str("        []\n");
245    } else if fields.len() == 1 {
246        let f = &fields[0];
247        let expr = encode_field_expr(f, types);
248        out.push_str(&format!("        {expr}\n"));
249    } else {
250        out.push_str("        var out: [UInt8] = []\n");
251        for f in fields {
252            let expr = encode_field_expr(f, types);
253            out.push_str(&format!("        out += {expr}\n"));
254        }
255        out.push_str("        return out\n");
256    }
257    out.push_str("    }\n");
258
259    // Decode method (with offset)
260    out.push_str(&format!(
261        "\n    {vis}static func decode(from data: Data, offset: inout Int) throws -> Self {{\n"
262    ));
263    for f in fields {
264        let field_name = f.name.to_lower_camel_case();
265        let decode_expr = decode_field_expr(f, types);
266        out.push_str(&format!("        let {field_name} = {decode_expr}\n"));
267    }
268    let field_names: Vec<String> = fields
269        .iter()
270        .map(|f| {
271            let n = f.name.to_lower_camel_case();
272            format!("{n}: {n}")
273        })
274        .collect();
275    out.push_str(&format!(
276        "        return .init({})\n",
277        field_names.join(", ")
278    ));
279    out.push_str("    }\n");
280
281    // Top-level MessageV7 gets an extra decode(from:) without offset that checks trailing bytes
282    if is_top_level {
283        out.push_str("\n    public static func decode(from data: Data) throws -> Self {\n");
284        out.push_str("        var offset = 0\n");
285        out.push_str("        let result = try decode(from: data, offset: &offset)\n");
286        out.push_str("        guard offset == data.count else {\n");
287        out.push_str("            throw WireV7Error.trailingBytes\n");
288        out.push_str("        }\n");
289        out.push_str("        return result\n");
290        out.push_str("    }\n");
291    }
292
293    out.push_str("}\n");
294    out
295}
296
297/// Generate a Swift enum with varint-discriminanted encode/decode methods.
298fn generate_enum(name: &str, variants: &[facet_core::Variant], types: &[WireType]) -> String {
299    let mut out = String::new();
300
301    // Enum definition
302    out.push_str(&format!("public enum {name}: Sendable, Equatable {{\n"));
303    for v in variants {
304        let variant_name = v.name.to_lower_camel_case();
305        match classify_variant(v) {
306            VariantKind::Unit => {
307                out.push_str(&format!("    case {variant_name}\n"));
308            }
309            VariantKind::Newtype { inner } => {
310                let inner_type = swift_wire_type(inner, v.data.fields.first(), types);
311                out.push_str(&format!("    case {variant_name}({inner_type})\n"));
312            }
313            VariantKind::Tuple { fields } => {
314                let field_types: Vec<String> = fields
315                    .iter()
316                    .map(|f| swift_wire_type(f.shape(), Some(f), types))
317                    .collect();
318                out.push_str(&format!(
319                    "    case {variant_name}({})\n",
320                    field_types.join(", ")
321                ));
322            }
323            VariantKind::Struct { fields } => {
324                let field_decls: Vec<String> = fields
325                    .iter()
326                    .map(|f| {
327                        format!(
328                            "{}: {}",
329                            f.name.to_lower_camel_case(),
330                            swift_wire_type(f.shape(), Some(f), types)
331                        )
332                    })
333                    .collect();
334                out.push_str(&format!(
335                    "    case {variant_name}({})\n",
336                    field_decls.join(", ")
337                ));
338            }
339        }
340    }
341
342    // Encode
343    out.push_str("\n    func encode() -> [UInt8] {\n");
344    out.push_str("        switch self {\n");
345    for (i, v) in variants.iter().enumerate() {
346        let variant_name = v.name.to_lower_camel_case();
347        match classify_variant(v) {
348            VariantKind::Unit => {
349                out.push_str(&format!(
350                    "        case .{variant_name}:\n            return encodeVarint(UInt64({i}))\n"
351                ));
352            }
353            VariantKind::Newtype { inner } => {
354                let encode = encode_shape_expr(inner, "val", v.data.fields.first(), types);
355                out.push_str(&format!(
356                    "        case .{variant_name}(let val):\n            return encodeVarint(UInt64({i})) + {encode}\n"
357                ));
358            }
359            VariantKind::Tuple { fields } => {
360                let bindings: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
361                let binding_str = bindings
362                    .iter()
363                    .map(|b| format!("let {b}"))
364                    .collect::<Vec<_>>()
365                    .join(", ");
366                let encodes: Vec<String> = fields
367                    .iter()
368                    .enumerate()
369                    .map(|(j, f)| encode_shape_expr(f.shape(), &format!("f{j}"), Some(f), types))
370                    .collect();
371                out.push_str(&format!(
372                    "        case .{variant_name}({binding_str}):\n            return encodeVarint(UInt64({i})) + {}\n",
373                    encodes.join(" + ")
374                ));
375            }
376            VariantKind::Struct { fields } => {
377                let bindings: Vec<String> = fields
378                    .iter()
379                    .map(|f| f.name.to_lower_camel_case())
380                    .collect();
381                let binding_str = bindings
382                    .iter()
383                    .map(|b| format!("let {b}"))
384                    .collect::<Vec<_>>()
385                    .join(", ");
386                let encodes: Vec<String> = fields
387                    .iter()
388                    .map(|f| {
389                        encode_shape_expr(f.shape(), &f.name.to_lower_camel_case(), Some(f), types)
390                    })
391                    .collect();
392                out.push_str(&format!(
393                    "        case .{variant_name}({binding_str}):\n            return encodeVarint(UInt64({i})) + {}\n",
394                    encodes.join(" + ")
395                ));
396            }
397        }
398    }
399    out.push_str("        }\n");
400    out.push_str("    }\n");
401
402    // Decode
403    out.push_str("\n    static func decode(from data: Data, offset: inout Int) throws -> Self {\n");
404    out.push_str("        let disc = try decodeVarint(from: data, offset: &offset)\n");
405    out.push_str("        switch disc {\n");
406    for (i, v) in variants.iter().enumerate() {
407        let variant_name = v.name.to_lower_camel_case();
408        out.push_str(&format!("        case {i}:\n"));
409        match classify_variant(v) {
410            VariantKind::Unit => {
411                out.push_str(&format!("            return .{variant_name}\n"));
412            }
413            VariantKind::Newtype { inner } => {
414                let decode = decode_shape_expr(inner, v.data.fields.first(), types);
415                out.push_str(&format!("            return .{variant_name}({decode})\n"));
416            }
417            VariantKind::Tuple { fields } => {
418                for (j, f) in fields.iter().enumerate() {
419                    let decode = decode_shape_expr(f.shape(), Some(f), types);
420                    out.push_str(&format!("            let f{j} = {decode}\n"));
421                }
422                let args: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
423                out.push_str(&format!(
424                    "            return .{variant_name}({})\n",
425                    args.join(", ")
426                ));
427            }
428            VariantKind::Struct { fields } => {
429                for f in fields {
430                    let field_name = f.name.to_lower_camel_case();
431                    let decode = decode_shape_expr(f.shape(), Some(f), types);
432                    out.push_str(&format!("            let {field_name} = {decode}\n"));
433                }
434                let args: Vec<String> = fields
435                    .iter()
436                    .map(|f| {
437                        let n = f.name.to_lower_camel_case();
438                        format!("{n}: {n}")
439                    })
440                    .collect();
441                out.push_str(&format!(
442                    "            return .{variant_name}({})\n",
443                    args.join(", ")
444                ));
445            }
446        }
447    }
448    out.push_str("        default:\n");
449    out.push_str("            throw WireV7Error.unknownVariant(disc)\n");
450    out.push_str("        }\n");
451    out.push_str("    }\n");
452
453    out.push_str("}\n");
454    out
455}
456
457/// Generate an encode expression for a struct field.
458fn encode_field_expr(field: &Field, types: &[WireType]) -> String {
459    let field_name = field.name.to_lower_camel_case();
460    encode_shape_expr(field.shape(), &field_name, Some(field), types)
461}
462
463/// Generate an encode expression for a shape with a given value expression.
464fn encode_shape_expr(
465    shape: &'static Shape,
466    value: &str,
467    field: Option<&Field>,
468    types: &[WireType],
469) -> String {
470    let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
471
472    // Opaque type → OpaquePayloadV7
473    if matches!(classify_shape(shape), ShapeKind::Opaque) {
474        return if is_trailing {
475            format!("{value}.encodeTrailing()")
476        } else {
477            format!("{value}.encode()")
478        };
479    }
480
481    if is_bytes(shape) {
482        return format!("encodeBytes({value})");
483    }
484
485    match classify_shape(shape) {
486        ShapeKind::Scalar(scalar) => encode_scalar(scalar, value),
487        ShapeKind::List { element } | ShapeKind::Slice { element } => {
488            let inner = encode_element_closure(element, types);
489            format!("encodeVec({value}, encoder: {inner})")
490        }
491        ShapeKind::Option { inner } => {
492            let inner_closure = encode_element_closure(inner, types);
493            format!("encodeOption({value}, encoder: {inner_closure})")
494        }
495        ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
496            format!("{value}.encode()")
497        }
498        ShapeKind::Pointer { pointee } => encode_shape_expr(pointee, value, field, types),
499        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
500            encode_shape_expr(fields[0].shape(), value, field, types)
501        }
502        _ => format!("[] /* unsupported encode for {value} */"),
503    }
504}
505
506fn encode_scalar(scalar: ScalarType, value: &str) -> String {
507    match scalar {
508        ScalarType::Bool => format!("encodeBool({value})"),
509        ScalarType::U8 => format!("encodeU8({value})"),
510        ScalarType::I8 => format!("encodeI8({value})"),
511        ScalarType::U16 => format!("encodeU16({value})"),
512        ScalarType::I16 => format!("encodeI16({value})"),
513        ScalarType::U32 => format!("encodeVarint(UInt64({value}))"),
514        ScalarType::I32 => format!("encodeI32({value})"),
515        ScalarType::U64 | ScalarType::USize => format!("encodeVarint({value})"),
516        ScalarType::I64 | ScalarType::ISize => format!("encodeI64({value})"),
517        ScalarType::F32 => format!("encodeF32({value})"),
518        ScalarType::F64 => format!("encodeF64({value})"),
519        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
520            format!("encodeString({value})")
521        }
522        _ => "[] /* unsupported scalar */".to_string(),
523    }
524}
525
526/// Generate an encode closure for use with encodeVec.
527fn encode_element_closure(shape: &'static Shape, _types: &[WireType]) -> String {
528    if is_bytes(shape) {
529        return "{ encodeBytes($0) }".into();
530    }
531
532    match classify_shape(shape) {
533        ShapeKind::Scalar(scalar) => {
534            let expr = encode_scalar(scalar, "$0");
535            format!("{{ {expr} }}")
536        }
537        ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
538            "{ $0.encode() }".into()
539        }
540        ShapeKind::List { element } | ShapeKind::Slice { element } => {
541            let inner = encode_element_closure(element, _types);
542            format!("{{ encodeVec($0, encoder: {inner}) }}")
543        }
544        ShapeKind::Opaque => "{ $0.encode() }".into(),
545        ShapeKind::Pointer { pointee } => encode_element_closure(pointee, _types),
546        _ => "{ _ in [] }".into(),
547    }
548}
549
550/// Generate a decode expression for a struct field.
551fn decode_field_expr(field: &Field, types: &[WireType]) -> String {
552    decode_shape_expr(field.shape(), Some(field), types)
553}
554
555/// Generate a decode expression for a shape.
556fn decode_shape_expr(shape: &'static Shape, field: Option<&Field>, types: &[WireType]) -> String {
557    let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
558
559    // Opaque type → OpaquePayloadV7
560    if matches!(classify_shape(shape), ShapeKind::Opaque) {
561        return if is_trailing {
562            "OpaquePayloadV7.decodeTrailing(from: data, offset: &offset)".into()
563        } else {
564            "try OpaquePayloadV7.decode(from: data, offset: &offset)".into()
565        };
566    }
567
568    if is_bytes(shape) {
569        return "Array(try decodeBytesV7(from: data, offset: &offset))".into();
570    }
571
572    match classify_shape(shape) {
573        ShapeKind::Scalar(scalar) => decode_scalar(scalar),
574        ShapeKind::List { element } | ShapeKind::Slice { element } => {
575            let inner = decode_element_closure(element, types);
576            format!("try decodeVec(from: data, offset: &offset, decoder: {inner})")
577        }
578        ShapeKind::Option { inner } => {
579            let inner_closure = decode_element_closure(inner, types);
580            format!("try decodeOption(from: data, offset: &offset, decoder: {inner_closure})")
581        }
582        ShapeKind::Struct(StructInfo {
583            name: Some(name), ..
584        }) => {
585            let swift_name = lookup_wire_name(name, types);
586            format!("try {swift_name}.decode(from: data, offset: &offset)")
587        }
588        ShapeKind::Enum(EnumInfo {
589            name: Some(name), ..
590        }) => {
591            let swift_name = lookup_wire_name(name, types);
592            format!("try {swift_name}.decode(from: data, offset: &offset)")
593        }
594        ShapeKind::Pointer { pointee } => decode_shape_expr(pointee, field, types),
595        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
596            decode_shape_expr(fields[0].shape(), field, types)
597        }
598        _ => "nil /* unsupported decode */".into(),
599    }
600}
601
602fn decode_scalar(scalar: ScalarType) -> String {
603    match scalar {
604        ScalarType::Bool => "try decodeBool(from: data, offset: &offset)".into(),
605        ScalarType::U8 => "try decodeU8(from: data, offset: &offset)".into(),
606        ScalarType::I8 => "try decodeI8(from: data, offset: &offset)".into(),
607        ScalarType::U16 => "try decodeU16(from: data, offset: &offset)".into(),
608        ScalarType::I16 => "try decodeI16(from: data, offset: &offset)".into(),
609        ScalarType::U32 => "try decodeVarintU32V7(from: data, offset: &offset)".into(),
610        ScalarType::I32 => "try decodeI32(from: data, offset: &offset)".into(),
611        ScalarType::U64 | ScalarType::USize => {
612            "try decodeVarint(from: data, offset: &offset)".into()
613        }
614        ScalarType::I64 | ScalarType::ISize => "try decodeI64(from: data, offset: &offset)".into(),
615        ScalarType::F32 => "try decodeF32(from: data, offset: &offset)".into(),
616        ScalarType::F64 => "try decodeF64(from: data, offset: &offset)".into(),
617        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
618            "try decodeStringV7(from: data, offset: &offset)".into()
619        }
620        _ => "nil /* unsupported scalar decode */".into(),
621    }
622}
623
624/// Generate a decode closure for use with decodeVec.
625fn decode_element_closure(shape: &'static Shape, types: &[WireType]) -> String {
626    if is_bytes(shape) {
627        return "{ data, off in Array(try decodeBytesV7(from: data, offset: &off)) }".into();
628    }
629
630    match classify_shape(shape) {
631        ShapeKind::Scalar(scalar) => {
632            let expr = decode_scalar_with_params(scalar, "data", "off");
633            format!("{{ data, off in {expr} }}")
634        }
635        ShapeKind::Struct(StructInfo {
636            name: Some(name), ..
637        }) => {
638            let swift_name = lookup_wire_name(name, types);
639            format!("{{ data, off in try {swift_name}.decode(from: data, offset: &off) }}")
640        }
641        ShapeKind::Enum(EnumInfo {
642            name: Some(name), ..
643        }) => {
644            let swift_name = lookup_wire_name(name, types);
645            format!("{{ data, off in try {swift_name}.decode(from: data, offset: &off) }}")
646        }
647        ShapeKind::List { element } | ShapeKind::Slice { element } => {
648            let inner = decode_element_closure(element, types);
649            format!("{{ data, off in try decodeVec(from: data, offset: &off, decoder: {inner}) }}")
650        }
651        ShapeKind::Opaque => {
652            "{ data, off in try OpaquePayloadV7.decode(from: data, offset: &off) }".into()
653        }
654        ShapeKind::Pointer { pointee } => decode_element_closure(pointee, types),
655        _ => "{ _, _ in throw WireV7Error.truncated }".into(),
656    }
657}
658
659fn decode_scalar_with_params(scalar: ScalarType, data: &str, offset: &str) -> String {
660    match scalar {
661        ScalarType::Bool => format!("try decodeBool(from: {data}, offset: &{offset})"),
662        ScalarType::U8 => format!("try decodeU8(from: {data}, offset: &{offset})"),
663        ScalarType::I8 => format!("try decodeI8(from: {data}, offset: &{offset})"),
664        ScalarType::U16 => format!("try decodeU16(from: {data}, offset: &{offset})"),
665        ScalarType::I16 => format!("try decodeI16(from: {data}, offset: &{offset})"),
666        ScalarType::U32 => format!("try decodeVarintU32V7(from: {data}, offset: &{offset})"),
667        ScalarType::I32 => format!("try decodeI32(from: {data}, offset: &{offset})"),
668        ScalarType::U64 | ScalarType::USize => {
669            format!("try decodeVarint(from: {data}, offset: &{offset})")
670        }
671        ScalarType::I64 | ScalarType::ISize => {
672            format!("try decodeI64(from: {data}, offset: &{offset})")
673        }
674        ScalarType::F32 => format!("try decodeF32(from: {data}, offset: &{offset})"),
675        ScalarType::F64 => format!("try decodeF64(from: {data}, offset: &{offset})"),
676        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
677            format!("try decodeStringV7(from: {data}, offset: &{offset})")
678        }
679        _ => "nil /* unsupported scalar */".to_string(),
680    }
681}
682
683/// Generate convenience factory methods on MessageV7.
684fn generate_factory_methods(types: &[WireType]) -> String {
685    // Find the MessagePayload shape to inspect variants
686    let payload_wt = types.iter().find(|wt| wt.swift_name == "MessagePayloadV7");
687    let payload_wt = match payload_wt {
688        Some(wt) => wt,
689        None => return String::new(),
690    };
691
692    let variants = match classify_shape(payload_wt.shape) {
693        ShapeKind::Enum(EnumInfo { variants, .. }) => variants,
694        _ => return String::new(),
695    };
696
697    let mut out = String::new();
698    out.push_str("public extension MessageV7 {\n");
699
700    for v in variants {
701        let variant_name = v.name.to_lower_camel_case();
702        if let VariantKind::Newtype { inner } = classify_variant(v) {
703            let inner_swift = swift_wire_type(inner, None, types);
704            // Control messages use connId=0.
705            let is_control = matches!(
706                v.name,
707                "Hello" | "HelloYourself" | "ProtocolError" | "Ping" | "Pong"
708            );
709
710            if is_control {
711                out.push_str(&format!(
712                    "    static func {variant_name}(_ value: {inner_swift}) -> MessageV7 {{\n"
713                ));
714                out.push_str(&format!(
715                    "        MessageV7(connectionId: 0, payload: .{variant_name}(value))\n"
716                ));
717                out.push_str("    }\n\n");
718            } else {
719                out.push_str(&format!(
720                    "    static func {variant_name}(connId: UInt64, _ value: {inner_swift}) -> MessageV7 {{\n"
721                ));
722                out.push_str(&format!(
723                    "        MessageV7(connectionId: connId, payload: .{variant_name}(value))\n"
724                ));
725                out.push_str("    }\n\n");
726            }
727        }
728    }
729
730    // Additional ergonomic factory methods that flatten nested structs.
731    // These match the existing hand-coded API.
732    out.push_str("    static func protocolError(description: String) -> MessageV7 {\n");
733    out.push_str("        MessageV7(connectionId: 0, payload: .protocolError(.init(description: description)))\n");
734    out.push_str("    }\n\n");
735
736    out.push_str("    static func connectionOpen(connId: UInt64, settings: ConnectionSettingsV7, metadata: [MetadataEntryV7]) -> MessageV7 {\n");
737    out.push_str("        MessageV7(connectionId: connId, payload: .connectionOpen(.init(connectionSettings: settings, metadata: metadata)))\n");
738    out.push_str("    }\n\n");
739
740    out.push_str("    static func connectionAccept(connId: UInt64, settings: ConnectionSettingsV7, metadata: [MetadataEntryV7]) -> MessageV7 {\n");
741    out.push_str("        MessageV7(connectionId: connId, payload: .connectionAccept(.init(connectionSettings: settings, metadata: metadata)))\n");
742    out.push_str("    }\n\n");
743
744    out.push_str("    static func connectionReject(connId: UInt64, metadata: [MetadataEntryV7]) -> MessageV7 {\n");
745    out.push_str("        MessageV7(connectionId: connId, payload: .connectionReject(.init(metadata: metadata)))\n");
746    out.push_str("    }\n\n");
747
748    out.push_str("    static func connectionClose(connId: UInt64, metadata: [MetadataEntryV7]) -> MessageV7 {\n");
749    out.push_str("        MessageV7(connectionId: connId, payload: .connectionClose(.init(metadata: metadata)))\n");
750    out.push_str("    }\n\n");
751
752    out.push_str("    static func request(\n");
753    out.push_str("        connId: UInt64,\n");
754    out.push_str("        requestId: UInt64,\n");
755    out.push_str("        methodId: UInt64,\n");
756    out.push_str("        metadata: [MetadataEntryV7],\n");
757    out.push_str("        channels: [UInt64],\n");
758    out.push_str("        payload: [UInt8]\n");
759    out.push_str("    ) -> MessageV7 {\n");
760    out.push_str("        MessageV7(\n");
761    out.push_str("            connectionId: connId,\n");
762    out.push_str("            payload: .requestMessage(\n");
763    out.push_str("                .init(\n");
764    out.push_str("                    id: requestId,\n");
765    out.push_str("                    body: .call(.init(methodId: methodId, channels: channels, metadata: metadata, args: .init(payload)))\n");
766    out.push_str("                ))\n");
767    out.push_str("        )\n");
768    out.push_str("    }\n\n");
769
770    out.push_str("    static func response(\n");
771    out.push_str("        connId: UInt64,\n");
772    out.push_str("        requestId: UInt64,\n");
773    out.push_str("        metadata: [MetadataEntryV7],\n");
774    out.push_str("        channels: [UInt64],\n");
775    out.push_str("        payload: [UInt8]\n");
776    out.push_str("    ) -> MessageV7 {\n");
777    out.push_str("        MessageV7(\n");
778    out.push_str("            connectionId: connId,\n");
779    out.push_str("            payload: .requestMessage(\n");
780    out.push_str("                .init(\n");
781    out.push_str("                    id: requestId,\n");
782    out.push_str("                    body: .response(.init(channels: channels, metadata: metadata, ret: .init(payload)))\n");
783    out.push_str("                ))\n");
784    out.push_str("        )\n");
785    out.push_str("    }\n\n");
786
787    out.push_str("    static func cancel(connId: UInt64, requestId: UInt64, metadata: [MetadataEntryV7] = []) -> MessageV7 {\n");
788    out.push_str("        MessageV7(\n");
789    out.push_str("            connectionId: connId,\n");
790    out.push_str("            payload: .requestMessage(\n");
791    out.push_str("                .init(\n");
792    out.push_str("                    id: requestId,\n");
793    out.push_str("                    body: .cancel(.init(metadata: metadata))\n");
794    out.push_str("                ))\n");
795    out.push_str("        )\n");
796    out.push_str("    }\n\n");
797
798    out.push_str("    static func data(connId: UInt64, channelId: UInt64, payload: [UInt8]) -> MessageV7 {\n");
799    out.push_str("        MessageV7(\n");
800    out.push_str("            connectionId: connId,\n");
801    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .item(.init(item: .init(payload)))))\n");
802    out.push_str("        )\n");
803    out.push_str("    }\n\n");
804
805    out.push_str("    static func close(connId: UInt64, channelId: UInt64, metadata: [MetadataEntryV7] = []) -> MessageV7 {\n");
806    out.push_str("        MessageV7(\n");
807    out.push_str("            connectionId: connId,\n");
808    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .close(.init(metadata: metadata))))\n");
809    out.push_str("        )\n");
810    out.push_str("    }\n\n");
811
812    out.push_str("    static func reset(connId: UInt64, channelId: UInt64, metadata: [MetadataEntryV7] = []) -> MessageV7 {\n");
813    out.push_str("        MessageV7(\n");
814    out.push_str("            connectionId: connId,\n");
815    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .reset(.init(metadata: metadata))))\n");
816    out.push_str("        )\n");
817    out.push_str("    }\n\n");
818
819    out.push_str(
820        "    static func credit(connId: UInt64, channelId: UInt64, bytes: UInt32) -> MessageV7 {\n",
821    );
822    out.push_str("        MessageV7(\n");
823    out.push_str("            connectionId: connId,\n");
824    out.push_str("            payload: .channelMessage(.init(id: channelId, body: .grantCredit(.init(additional: bytes))))\n");
825    out.push_str("        )\n");
826    out.push_str("    }\n");
827
828    out.push_str("}\n");
829    out
830}