roam_codegen/targets/swift/
types.rs

1//! Swift type generation and collection.
2//!
3//! This module handles:
4//! - Collecting named types (structs and enums) from service definitions
5//! - Generating Swift type definitions (structs, enums)
6//! - Converting Rust types to Swift type strings
7
8use std::collections::HashSet;
9
10use facet_core::{ScalarType, Shape};
11use heck::ToLowerCamelCase;
12use roam_schema::{
13    EnumInfo, ServiceDetail, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant,
14    is_bytes, is_rx, is_tx,
15};
16
17/// Collect all named types (structs and enums with a name) from a service.
18/// Returns a vector of (name, Shape) pairs in dependency order.
19pub fn collect_named_types(service: &ServiceDetail) -> Vec<(String, &'static Shape)> {
20    let mut seen: HashSet<String> = HashSet::new();
21    let mut types = Vec::new();
22
23    fn visit(
24        shape: &'static Shape,
25        seen: &mut HashSet<String>,
26        types: &mut Vec<(String, &'static Shape)>,
27    ) {
28        match classify_shape(shape) {
29            ShapeKind::Struct(StructInfo {
30                name: Some(name),
31                fields,
32                ..
33            }) => {
34                if !seen.contains(name) {
35                    seen.insert(name.to_string());
36                    // Visit nested types first (dependencies before dependents)
37                    for field in fields {
38                        visit(field.shape(), seen, types);
39                    }
40                    types.push((name.to_string(), shape));
41                }
42            }
43            ShapeKind::Enum(EnumInfo {
44                name: Some(name),
45                variants,
46            }) => {
47                if !seen.contains(name) {
48                    seen.insert(name.to_string());
49                    // Visit nested types in variants
50                    for variant in variants {
51                        match classify_variant(variant) {
52                            VariantKind::Newtype { inner } => visit(inner, seen, types),
53                            VariantKind::Struct { fields } | VariantKind::Tuple { fields } => {
54                                for field in fields {
55                                    visit(field.shape(), seen, types);
56                                }
57                            }
58                            VariantKind::Unit => {}
59                        }
60                    }
61                    types.push((name.to_string(), shape));
62                }
63            }
64            ShapeKind::List { element }
65            | ShapeKind::Slice { element }
66            | ShapeKind::Option { inner: element }
67            | ShapeKind::Array { element, .. }
68            | ShapeKind::Set { element } => visit(element, seen, types),
69            ShapeKind::Map { key, value } => {
70                visit(key, seen, types);
71                visit(value, seen, types);
72            }
73            ShapeKind::Tuple { elements } => {
74                for param in elements {
75                    visit(param.shape, seen, types);
76                }
77            }
78            ShapeKind::Tx { inner } | ShapeKind::Rx { inner } => visit(inner, seen, types),
79            ShapeKind::Pointer { pointee } => visit(pointee, seen, types),
80            ShapeKind::Result { ok, err } => {
81                visit(ok, seen, types);
82                visit(err, seen, types);
83            }
84            _ => {}
85        }
86    }
87
88    for method in &service.methods {
89        for arg in &method.args {
90            visit(arg.ty, &mut seen, &mut types);
91        }
92        visit(method.return_type, &mut seen, &mut types);
93    }
94
95    types
96}
97
98/// Generate Swift type definitions for all named types.
99pub fn generate_named_types(named_types: &[(String, &'static Shape)]) -> String {
100    let mut out = String::new();
101
102    for (name, shape) in named_types {
103        match classify_shape(shape) {
104            ShapeKind::Struct(StructInfo { fields, .. }) => {
105                out.push_str(&format!("public struct {name}: Codable, Sendable {{\n"));
106                for field in fields {
107                    let field_name = field.name.to_lower_camel_case();
108                    let field_type = swift_type_base(field.shape());
109                    out.push_str(&format!("    public var {field_name}: {field_type}\n"));
110                }
111                out.push('\n');
112                // Generate initializer
113                out.push_str("    public init(");
114                for (i, field) in fields.iter().enumerate() {
115                    if i > 0 {
116                        out.push_str(", ");
117                    }
118                    let field_name = field.name.to_lower_camel_case();
119                    let field_type = swift_type_base(field.shape());
120                    out.push_str(&format!("{field_name}: {field_type}"));
121                }
122                out.push_str(") {\n");
123                for field in fields {
124                    let field_name = field.name.to_lower_camel_case();
125                    out.push_str(&format!("        self.{field_name} = {field_name}\n"));
126                }
127                out.push_str("    }\n");
128                out.push_str("}\n\n");
129            }
130            ShapeKind::Enum(EnumInfo { variants, .. }) => {
131                // Add Error conformance if the enum name ends with "Error"
132                let protocols = if name.ends_with("Error") {
133                    "Codable, Sendable, Error"
134                } else {
135                    "Codable, Sendable"
136                };
137                out.push_str(&format!("public enum {name}: {protocols} {{\n"));
138                for variant in variants {
139                    let variant_name = variant.name.to_lower_camel_case();
140                    match classify_variant(variant) {
141                        VariantKind::Unit => {
142                            out.push_str(&format!("    case {variant_name}\n"));
143                        }
144                        VariantKind::Newtype { inner } => {
145                            let inner_type = swift_type_base(inner);
146                            out.push_str(&format!("    case {variant_name}({inner_type})\n"));
147                        }
148                        VariantKind::Tuple { fields } => {
149                            let field_types: Vec<_> =
150                                fields.iter().map(|f| swift_type_base(f.shape())).collect();
151                            out.push_str(&format!(
152                                "    case {variant_name}({})\n",
153                                field_types.join(", ")
154                            ));
155                        }
156                        VariantKind::Struct { fields } => {
157                            let field_decls: Vec<_> = fields
158                                .iter()
159                                .map(|f| {
160                                    format!(
161                                        "{}: {}",
162                                        f.name.to_lower_camel_case(),
163                                        swift_type_base(f.shape())
164                                    )
165                                })
166                                .collect();
167                            out.push_str(&format!(
168                                "    case {variant_name}({})\n",
169                                field_decls.join(", ")
170                            ));
171                        }
172                    }
173                }
174                out.push_str("}\n\n");
175            }
176            _ => {}
177        }
178    }
179
180    out
181}
182
183/// Convert ScalarType to Swift type string.
184pub fn swift_scalar_type(scalar: ScalarType) -> String {
185    match scalar {
186        ScalarType::Bool => "Bool".into(),
187        ScalarType::U8 => "UInt8".into(),
188        ScalarType::U16 => "UInt16".into(),
189        ScalarType::U32 => "UInt32".into(),
190        ScalarType::U64 => "UInt64".into(),
191        ScalarType::U128 => "UInt128".into(),
192        ScalarType::USize => "UInt".into(),
193        ScalarType::I8 => "Int8".into(),
194        ScalarType::I16 => "Int16".into(),
195        ScalarType::I32 => "Int32".into(),
196        ScalarType::I64 => "Int64".into(),
197        ScalarType::I128 => "Int128".into(),
198        ScalarType::ISize => "Int".into(),
199        ScalarType::F32 => "Float".into(),
200        ScalarType::F64 => "Double".into(),
201        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
202            "String".into()
203        }
204        ScalarType::Unit => "Void".into(),
205        _ => "Data".into(),
206    }
207}
208
209/// Convert Shape to Swift type string.
210pub fn swift_type_base(shape: &'static Shape) -> String {
211    // Check for bytes first
212    if is_bytes(shape) {
213        return "Data".into();
214    }
215
216    match classify_shape(shape) {
217        ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
218        ShapeKind::List { element } => format!("[{}]", swift_type_base(element)),
219        ShapeKind::Slice { element } => format!("[{}]", swift_type_base(element)),
220        ShapeKind::Option { inner } => format!("{}?", swift_type_base(inner)),
221        ShapeKind::Array { element, .. } => format!("[{}]", swift_type_base(element)),
222        ShapeKind::Map { key, value } => {
223            format!("[{}: {}]", swift_type_base(key), swift_type_base(value))
224        }
225        ShapeKind::Set { element } => format!("Set<{}>", swift_type_base(element)),
226        ShapeKind::Tuple { elements } => {
227            if elements.is_empty() {
228                "Void".into()
229            } else {
230                let types: Vec<_> = elements.iter().map(|p| swift_type_base(p.shape)).collect();
231                format!("({})", types.join(", "))
232            }
233        }
234        ShapeKind::Tx { inner } => format!("Tx<{}>", swift_type_base(inner)),
235        ShapeKind::Rx { inner } => format!("Rx<{}>", swift_type_base(inner)),
236        ShapeKind::Struct(StructInfo {
237            name: Some(name), ..
238        }) => name.to_string(),
239        ShapeKind::Enum(EnumInfo {
240            name: Some(name), ..
241        }) => name.to_string(),
242        ShapeKind::Struct(StructInfo {
243            name: None, fields, ..
244        }) => {
245            // Anonymous struct - use tuple-like representation
246            let types: Vec<_> = fields.iter().map(|f| swift_type_base(f.shape())).collect();
247            format!("({})", types.join(", "))
248        }
249        ShapeKind::Enum(EnumInfo {
250            name: None,
251            variants,
252        }) => {
253            // Anonymous enum - not well supported in Swift, use Any
254            let _ = variants; // suppress warning
255            "Any".into()
256        }
257        ShapeKind::Pointer { pointee } => swift_type_base(pointee),
258        ShapeKind::Result { ok, err } => {
259            format!("Result<{}, {}>", swift_type_base(ok), swift_type_base(err))
260        }
261        ShapeKind::TupleStruct { fields } => {
262            let types: Vec<_> = fields.iter().map(|f| swift_type_base(f.shape())).collect();
263            format!("({})", types.join(", "))
264        }
265        ShapeKind::Opaque => "Data".into(),
266    }
267}
268
269/// Convert Shape to Swift type string for client arguments.
270pub fn swift_type_client_arg(shape: &'static Shape) -> String {
271    match classify_shape(shape) {
272        ShapeKind::Tx { inner } => format!("UnboundTx<{}>", swift_type_base(inner)),
273        ShapeKind::Rx { inner } => format!("UnboundRx<{}>", swift_type_base(inner)),
274        _ => swift_type_base(shape),
275    }
276}
277
278/// Convert Shape to Swift type string for client returns.
279pub fn swift_type_client_return(shape: &'static Shape) -> String {
280    match classify_shape(shape) {
281        ShapeKind::Tx { inner } => format!("UnboundTx<{}>", swift_type_base(inner)),
282        ShapeKind::Rx { inner } => format!("UnboundRx<{}>", swift_type_base(inner)),
283        ShapeKind::Scalar(ScalarType::Unit) => "Void".into(),
284        ShapeKind::Tuple { elements: [] } => "Void".into(),
285        _ => swift_type_base(shape),
286    }
287}
288
289/// Convert Shape to Swift type string for server/handler arguments.
290pub fn swift_type_server_arg(shape: &'static Shape) -> String {
291    match classify_shape(shape) {
292        ShapeKind::Tx { inner } => format!("Tx<{}>", swift_type_base(inner)),
293        ShapeKind::Rx { inner } => format!("Rx<{}>", swift_type_base(inner)),
294        _ => swift_type_base(shape),
295    }
296}
297
298/// Convert Shape to Swift type string for server returns.
299pub fn swift_type_server_return(shape: &'static Shape) -> String {
300    match classify_shape(shape) {
301        ShapeKind::Tx { inner } => format!("Tx<{}>", swift_type_base(inner)),
302        ShapeKind::Rx { inner } => format!("Rx<{}>", swift_type_base(inner)),
303        ShapeKind::Scalar(ScalarType::Unit) => "Void".into(),
304        ShapeKind::Tuple { elements: [] } => "Void".into(),
305        _ => swift_type_base(shape),
306    }
307}
308
309/// Check if a shape uses streaming (Tx or Rx).
310pub fn is_stream(shape: &'static Shape) -> bool {
311    is_tx(shape) || is_rx(shape)
312}
313
/// Format documentation text as Swift `///` doc-comment lines.
///
/// Each line of `doc` becomes `{indent}/// {line}` followed by `\n`. Empty or
/// whitespace-only lines are emitted as a bare `{indent}///`, so the generated
/// Swift has no trailing whitespace. An empty `doc` yields an empty string.
/// Both `\n` and `\r\n` line endings are accepted.
pub fn format_doc(doc: &str, indent: &str) -> String {
    doc.lines()
        .map(|line| {
            if line.trim().is_empty() {
                // Paragraph separator: no trailing space after `///`.
                format!("{indent}///\n")
            } else {
                format!("{indent}/// {line}\n")
            }
        })
        .collect()
}
320
#[cfg(test)]
mod tests {
    use super::*;
    use facet::Facet;
    use std::collections::HashMap;

    #[test]
    fn test_swift_type_base_primitives() {
        assert_eq!(swift_type_base(<bool as Facet>::SHAPE), "Bool");
        assert_eq!(swift_type_base(<u32 as Facet>::SHAPE), "UInt32");
        assert_eq!(swift_type_base(<i64 as Facet>::SHAPE), "Int64");
        assert_eq!(swift_type_base(<f32 as Facet>::SHAPE), "Float");
        assert_eq!(swift_type_base(<f64 as Facet>::SHAPE), "Double");
        assert_eq!(swift_type_base(<String as Facet>::SHAPE), "String");
        assert_eq!(swift_type_base(<Vec<u8> as Facet>::SHAPE), "Data");
        assert_eq!(swift_type_base(<() as Facet>::SHAPE), "Void");
    }

    #[test]
    fn test_swift_type_base_containers() {
        assert_eq!(swift_type_base(<Vec<i32> as Facet>::SHAPE), "[Int32]");
        assert_eq!(swift_type_base(<Option<String> as Facet>::SHAPE), "String?");
    }

    #[test]
    fn test_swift_type_base_nested_containers() {
        assert_eq!(swift_type_base(<Vec<Option<String>> as Facet>::SHAPE), "[String?]");
        // Byte buffers stay `Data` even when wrapped.
        assert_eq!(swift_type_base(<Option<Vec<u8>> as Facet>::SHAPE), "Data?");
        assert_eq!(
            swift_type_base(<HashMap<String, Vec<i64>> as Facet>::SHAPE),
            "[String: [Int64]]"
        );
    }

    #[test]
    fn test_unit_returns_are_void() {
        assert_eq!(swift_type_client_return(<() as Facet>::SHAPE), "Void");
        assert_eq!(swift_type_server_return(<() as Facet>::SHAPE), "Void");
    }

    #[test]
    fn test_non_stream_args_use_base_type() {
        assert_eq!(swift_type_client_arg(<u64 as Facet>::SHAPE), "UInt64");
        assert_eq!(swift_type_server_arg(<u64 as Facet>::SHAPE), "UInt64");
        assert!(!is_stream(<u64 as Facet>::SHAPE));
    }

    #[test]
    fn test_format_doc() {
        assert_eq!(format_doc("", "    "), "");
        assert_eq!(
            format_doc("Adds two numbers.", "    "),
            "    /// Adds two numbers.\n"
        );
    }
}