Skip to main content

roam_codegen/targets/swift/
schema.rs

1//! Swift schema generation for runtime channel binding.
2//!
3//! Generates runtime schema information for channel discovery.
4
5use facet_core::{ScalarType, Shape};
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use roam_types::{
8    EnumInfo, ServiceDescriptor, ShapeKind, StructInfo, VariantKind, classify_shape,
9    classify_variant, is_bytes,
10};
11
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14
15/// Generate complete schema code (method schemas + serializers).
16pub fn generate_schemas(service: &ServiceDescriptor) -> String {
17    let mut out = String::new();
18    out.push_str(&generate_method_schemas(service));
19    out.push_str(&generate_serializers(service));
20    out
21}
22
23fn extract_initial_credit(shape: &'static Shape) -> u32 {
24    shape
25        .const_params
26        .iter()
27        .find(|cp| cp.name == "N")
28        .map(|cp| cp.value as u32)
29        .unwrap_or(16)
30}
31
32/// Generate method schemas for runtime channel binding.
33fn generate_method_schemas(service: &ServiceDescriptor) -> String {
34    let mut out = String::new();
35    let service_name = service.service_name.to_lower_camel_case();
36
37    out.push_str(&format!(
38        "public let {service_name}_schemas: [String: MethodSchema] = [\n"
39    ));
40
41    for method in service.methods {
42        let method_name = method.method_name.to_lower_camel_case();
43        out.push_str(&format!("    \"{method_name}\": MethodSchema(args: ["));
44
45        let schemas: Vec<String> = method
46            .args
47            .iter()
48            .map(|a| shape_to_schema(a.shape))
49            .collect();
50        out.push_str(&schemas.join(", "));
51
52        out.push_str("]),\n");
53    }
54
55    out.push_str("]\n\n");
56    out
57}
58
59/// Convert a Shape to its Swift Schema representation.
60fn shape_to_schema(shape: &'static Shape) -> String {
61    if is_bytes(shape) {
62        return ".bytes".into();
63    }
64
65    match classify_shape(shape) {
66        ShapeKind::Scalar(scalar) => match scalar {
67            ScalarType::Bool => ".bool".into(),
68            ScalarType::U8 => ".u8".into(),
69            ScalarType::U16 => ".u16".into(),
70            ScalarType::U32 => ".u32".into(),
71            ScalarType::U64 => ".u64".into(),
72            ScalarType::I8 => ".i8".into(),
73            ScalarType::I16 => ".i16".into(),
74            ScalarType::I32 => ".i32".into(),
75            ScalarType::I64 => ".i64".into(),
76            ScalarType::F32 => ".f32".into(),
77            ScalarType::F64 => ".f64".into(),
78            ScalarType::Str | ScalarType::CowStr | ScalarType::String => ".string".into(),
79            ScalarType::Unit => ".tuple(elements: [])".into(),
80            _ => ".bytes".into(), // fallback
81        },
82        ShapeKind::List { element } | ShapeKind::Slice { element } => {
83            format!(".vec(element: {})", shape_to_schema(element))
84        }
85        ShapeKind::Option { inner } => {
86            format!(".option(inner: {})", shape_to_schema(inner))
87        }
88        ShapeKind::Map { key, value } => {
89            format!(
90                ".map(key: {}, value: {})",
91                shape_to_schema(key),
92                shape_to_schema(value)
93            )
94        }
95        ShapeKind::Tx { inner } => format!(
96            ".tx(initialCredit: {}, element: {})",
97            extract_initial_credit(shape),
98            shape_to_schema(inner)
99        ),
100        ShapeKind::Rx { inner } => format!(
101            ".rx(initialCredit: {}, element: {})",
102            extract_initial_credit(shape),
103            shape_to_schema(inner)
104        ),
105        ShapeKind::Tuple { elements } => {
106            let inner: Vec<String> = elements.iter().map(|p| shape_to_schema(p.shape)).collect();
107            format!(".tuple(elements: [{}])", inner.join(", "))
108        }
109        ShapeKind::Struct(StructInfo { fields, .. }) => {
110            let field_strs: Vec<String> = fields
111                .iter()
112                .map(|f| format!("(\"{}\", {})", f.name, shape_to_schema(f.shape())))
113                .collect();
114            format!(".struct(fields: [{}])", field_strs.join(", "))
115        }
116        ShapeKind::Enum(EnumInfo { variants, .. }) => {
117            let variant_strs: Vec<String> = variants
118                .iter()
119                .map(|v| {
120                    let fields: Vec<String> = match classify_variant(v) {
121                        VariantKind::Unit => vec![],
122                        VariantKind::Newtype { inner } => vec![shape_to_schema(inner)],
123                        VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
124                            fields.iter().map(|f| shape_to_schema(f.shape())).collect()
125                        }
126                    };
127                    format!("(\"{}\", [{}])", v.name, fields.join(", "))
128                })
129                .collect();
130            format!(".enum(variants: [{}])", variant_strs.join(", "))
131        }
132        _ => ".bytes".into(), // fallback for unknown types
133    }
134}
135
136/// Generate serializers for runtime channel binding.
137fn generate_serializers(service: &ServiceDescriptor) -> String {
138    let mut out = String::new();
139    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
140    let service_name_upper = service.service_name.to_upper_camel_case();
141
142    cw_writeln!(
143        w,
144        "public struct {service_name_upper}Serializers: BindingSerializers {{"
145    )
146    .unwrap();
147    {
148        let _indent = w.indent();
149        w.writeln("public init() {}").unwrap();
150        w.blank_line().unwrap();
151
152        // txSerializer
153        w.writeln("public func txSerializer(for schema: Schema) -> @Sendable (Any) -> [UInt8] {")
154            .unwrap();
155        {
156            let _indent = w.indent();
157            w.writeln("switch schema {").unwrap();
158            w.writeln("case .bool: return { encodeBool($0 as! Bool) }")
159                .unwrap();
160            w.writeln("case .u8: return { encodeU8($0 as! UInt8) }")
161                .unwrap();
162            w.writeln("case .i8: return { encodeI8($0 as! Int8) }")
163                .unwrap();
164            w.writeln("case .u16: return { encodeU16($0 as! UInt16) }")
165                .unwrap();
166            w.writeln("case .i16: return { encodeI16($0 as! Int16) }")
167                .unwrap();
168            w.writeln("case .u32: return { encodeU32($0 as! UInt32) }")
169                .unwrap();
170            w.writeln("case .i32: return { encodeI32($0 as! Int32) }")
171                .unwrap();
172            w.writeln("case .u64: return { encodeVarint($0 as! UInt64) }")
173                .unwrap();
174            w.writeln("case .i64: return { encodeI64($0 as! Int64) }")
175                .unwrap();
176            w.writeln("case .f32: return { encodeF32($0 as! Float) }")
177                .unwrap();
178            w.writeln("case .f64: return { encodeF64($0 as! Double) }")
179                .unwrap();
180            w.writeln("case .string: return { encodeString($0 as! String) }")
181                .unwrap();
182            w.writeln("case .bytes: return { [UInt8]($0 as! Data) }")
183                .unwrap();
184            w.writeln(
185                "case .tx(_, _), .rx(_, _): fatalError(\"Channel schemas are not serialized directly\")",
186            )
187            .unwrap();
188            w.writeln(
189                "default: fatalError(\"Unsupported schema for Tx serialization: \\(schema)\")",
190            )
191            .unwrap();
192            w.writeln("}").unwrap();
193        }
194        w.writeln("}").unwrap();
195        w.blank_line().unwrap();
196
197        // rxDeserializer
198        w.writeln(
199            "public func rxDeserializer(for schema: Schema) -> @Sendable ([UInt8]) throws -> Any {",
200        )
201        .unwrap();
202        {
203            let _indent = w.indent();
204            w.writeln("switch schema {").unwrap();
205            w.writeln("case .bool: return { var o = 0; return try decodeBool(from: Data($0), offset: &o) }").unwrap();
206            w.writeln(
207                "case .u8: return { var o = 0; return try decodeU8(from: Data($0), offset: &o) }",
208            )
209            .unwrap();
210            w.writeln(
211                "case .i8: return { var o = 0; return try decodeI8(from: Data($0), offset: &o) }",
212            )
213            .unwrap();
214            w.writeln(
215                "case .u16: return { var o = 0; return try decodeU16(from: Data($0), offset: &o) }",
216            )
217            .unwrap();
218            w.writeln(
219                "case .i16: return { var o = 0; return try decodeI16(from: Data($0), offset: &o) }",
220            )
221            .unwrap();
222            w.writeln(
223                "case .u32: return { var o = 0; return try decodeU32(from: Data($0), offset: &o) }",
224            )
225            .unwrap();
226            w.writeln(
227                "case .i32: return { var o = 0; return try decodeI32(from: Data($0), offset: &o) }",
228            )
229            .unwrap();
230            w.writeln("case .u64: return { var o = 0; return try decodeVarint(from: Data($0), offset: &o) }").unwrap();
231            w.writeln(
232                "case .i64: return { var o = 0; return try decodeI64(from: Data($0), offset: &o) }",
233            )
234            .unwrap();
235            w.writeln(
236                "case .f32: return { var o = 0; return try decodeF32(from: Data($0), offset: &o) }",
237            )
238            .unwrap();
239            w.writeln(
240                "case .f64: return { var o = 0; return try decodeF64(from: Data($0), offset: &o) }",
241            )
242            .unwrap();
243            w.writeln("case .string: return { var o = 0; return try decodeString(from: Data($0), offset: &o) }").unwrap();
244            w.writeln("case .bytes: return { Data($0) }").unwrap();
245            w.writeln(
246                "case .tx(_, _), .rx(_, _): fatalError(\"Channel schemas are not deserialized directly\")",
247            )
248            .unwrap();
249            w.writeln(
250                "default: fatalError(\"Unsupported schema for Rx deserialization: \\(schema)\")",
251            )
252            .unwrap();
253            w.writeln("}").unwrap();
254        }
255        w.writeln("}").unwrap();
256    }
257    w.writeln("}").unwrap();
258    w.blank_line().unwrap();
259
260    out
261}