Skip to main content

roam_codegen/targets/swift/
types.rs

//! Swift type generation and collection.
//!
//! This module handles:
//! - Collecting named types (structs and enums) from service definitions
//! - Generating Swift type definitions (structs, enums)
//! - Converting Rust types (via their facet `Shape`) to Swift type strings
7
8use std::collections::HashSet;
9
10use facet_core::{ScalarType, Shape};
11use heck::ToLowerCamelCase;
12use roam_types::{
13    EnumInfo, RpcPlan, ServiceDescriptor, ShapeKind, StructInfo, VariantKind, classify_shape,
14    classify_variant, is_bytes, is_rx, is_tx,
15};
16
17/// Collect all named types (structs and enums with a name) from a service.
18/// Returns a vector of (name, Shape) pairs in dependency order.
19pub fn collect_named_types(service: &ServiceDescriptor) -> Vec<(String, &'static Shape)> {
20    let mut seen: HashSet<String> = HashSet::new();
21    let mut types = Vec::new();
22
23    fn visit(
24        shape: &'static Shape,
25        seen: &mut HashSet<String>,
26        types: &mut Vec<(String, &'static Shape)>,
27    ) {
28        match classify_shape(shape) {
29            ShapeKind::Struct(StructInfo {
30                name: Some(name),
31                fields,
32                ..
33            }) => {
34                if !seen.contains(name) {
35                    seen.insert(name.to_string());
36                    // Visit nested types first (dependencies before dependents)
37                    for field in fields {
38                        visit(field.shape(), seen, types);
39                    }
40                    types.push((name.to_string(), shape));
41                }
42            }
43            ShapeKind::Enum(EnumInfo {
44                name: Some(name),
45                variants,
46            }) => {
47                if !seen.contains(name) {
48                    seen.insert(name.to_string());
49                    // Visit nested types in variants
50                    for variant in variants {
51                        match classify_variant(variant) {
52                            VariantKind::Newtype { inner } => visit(inner, seen, types),
53                            VariantKind::Struct { fields } | VariantKind::Tuple { fields } => {
54                                for field in fields {
55                                    visit(field.shape(), seen, types);
56                                }
57                            }
58                            VariantKind::Unit => {}
59                        }
60                    }
61                    types.push((name.to_string(), shape));
62                }
63            }
64            ShapeKind::List { element }
65            | ShapeKind::Slice { element }
66            | ShapeKind::Option { inner: element }
67            | ShapeKind::Array { element, .. }
68            | ShapeKind::Set { element } => visit(element, seen, types),
69            ShapeKind::Map { key, value } => {
70                visit(key, seen, types);
71                visit(value, seen, types);
72            }
73            ShapeKind::Tuple { elements } => {
74                for param in elements {
75                    visit(param.shape, seen, types);
76                }
77            }
78            ShapeKind::Tx { inner } | ShapeKind::Rx { inner } => visit(inner, seen, types),
79            ShapeKind::Pointer { pointee } => visit(pointee, seen, types),
80            ShapeKind::Result { ok, err } => {
81                visit(ok, seen, types);
82                visit(err, seen, types);
83            }
84            _ => {}
85        }
86    }
87
88    for method in service.methods {
89        for arg in method.args {
90            visit(arg.shape, &mut seen, &mut types);
91        }
92        visit(method.return_shape, &mut seen, &mut types);
93    }
94
95    types
96}
97
98/// Generate Swift type definitions for all named types.
99pub fn generate_named_types(named_types: &[(String, &'static Shape)]) -> String {
100    let mut out = String::new();
101
102    for (name, shape) in named_types {
103        match classify_shape(shape) {
104            ShapeKind::Struct(StructInfo { fields, .. }) => {
105                out.push_str(&format!("public struct {name}: Codable, Sendable {{\n"));
106                for field in fields {
107                    let field_name = field.name.to_lower_camel_case();
108                    let field_type = swift_type_base(field.shape());
109                    out.push_str(&format!("    public var {field_name}: {field_type}\n"));
110                }
111                out.push('\n');
112                // Generate initializer
113                out.push_str("    public init(");
114                for (i, field) in fields.iter().enumerate() {
115                    if i > 0 {
116                        out.push_str(", ");
117                    }
118                    let field_name = field.name.to_lower_camel_case();
119                    let field_type = swift_type_base(field.shape());
120                    out.push_str(&format!("{field_name}: {field_type}"));
121                }
122                out.push_str(") {\n");
123                for field in fields {
124                    let field_name = field.name.to_lower_camel_case();
125                    out.push_str(&format!("        self.{field_name} = {field_name}\n"));
126                }
127                out.push_str("    }\n");
128                out.push_str("}\n\n");
129            }
130            ShapeKind::Enum(EnumInfo { variants, .. }) => {
131                // Add Error conformance if the enum name ends with "Error"
132                let protocols = if name.ends_with("Error") {
133                    "Codable, Sendable, Error"
134                } else {
135                    "Codable, Sendable"
136                };
137                out.push_str(&format!("public enum {name}: {protocols} {{\n"));
138                for variant in variants {
139                    let variant_name = variant.name.to_lower_camel_case();
140                    match classify_variant(variant) {
141                        VariantKind::Unit => {
142                            out.push_str(&format!("    case {variant_name}\n"));
143                        }
144                        VariantKind::Newtype { inner } => {
145                            let inner_type = swift_type_base(inner);
146                            out.push_str(&format!("    case {variant_name}({inner_type})\n"));
147                        }
148                        VariantKind::Tuple { fields } => {
149                            let field_types: Vec<_> =
150                                fields.iter().map(|f| swift_type_base(f.shape())).collect();
151                            out.push_str(&format!(
152                                "    case {variant_name}({})\n",
153                                field_types.join(", ")
154                            ));
155                        }
156                        VariantKind::Struct { fields } => {
157                            let field_decls: Vec<_> = fields
158                                .iter()
159                                .map(|f| {
160                                    format!(
161                                        "{}: {}",
162                                        f.name.to_lower_camel_case(),
163                                        swift_type_base(f.shape())
164                                    )
165                                })
166                                .collect();
167                            out.push_str(&format!(
168                                "    case {variant_name}({})\n",
169                                field_decls.join(", ")
170                            ));
171                        }
172                    }
173                }
174                out.push_str("}\n\n");
175            }
176            _ => {}
177        }
178    }
179
180    out
181}
182
183/// Convert ScalarType to Swift type string.
184pub fn swift_scalar_type(scalar: ScalarType) -> String {
185    match scalar {
186        ScalarType::Bool => "Bool".into(),
187        ScalarType::U8 => "UInt8".into(),
188        ScalarType::U16 => "UInt16".into(),
189        ScalarType::U32 => "UInt32".into(),
190        ScalarType::U64 => "UInt64".into(),
191        ScalarType::U128 => "UInt128".into(),
192        ScalarType::USize => "UInt".into(),
193        ScalarType::I8 => "Int8".into(),
194        ScalarType::I16 => "Int16".into(),
195        ScalarType::I32 => "Int32".into(),
196        ScalarType::I64 => "Int64".into(),
197        ScalarType::I128 => "Int128".into(),
198        ScalarType::ISize => "Int".into(),
199        ScalarType::F32 => "Float".into(),
200        ScalarType::F64 => "Double".into(),
201        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
202            "String".into()
203        }
204        ScalarType::Unit => "Void".into(),
205        _ => "Data".into(),
206    }
207}
208
209/// Convert Shape to Swift type string.
210pub fn swift_type_base(shape: &'static Shape) -> String {
211    // Check for bytes first
212    if is_bytes(shape) {
213        return "Data".into();
214    }
215
216    match classify_shape(shape) {
217        ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
218        ShapeKind::List { element } => format!("[{}]", swift_type_base(element)),
219        ShapeKind::Slice { element } => format!("[{}]", swift_type_base(element)),
220        ShapeKind::Option { inner } => format!("{}?", swift_type_base(inner)),
221        ShapeKind::Array { element, .. } => format!("[{}]", swift_type_base(element)),
222        ShapeKind::Map { key, value } => {
223            format!("[{}: {}]", swift_type_base(key), swift_type_base(value))
224        }
225        ShapeKind::Set { element } => format!("Set<{}>", swift_type_base(element)),
226        ShapeKind::Tuple { elements } => {
227            if elements.is_empty() {
228                "Void".into()
229            } else {
230                let types: Vec<_> = elements.iter().map(|p| swift_type_base(p.shape)).collect();
231                format!("({})", types.join(", "))
232            }
233        }
234        ShapeKind::Tx { inner } => format!("Tx<{}>", swift_type_base(inner)),
235        ShapeKind::Rx { inner } => format!("Rx<{}>", swift_type_base(inner)),
236        ShapeKind::Struct(StructInfo {
237            name: Some(name), ..
238        }) => name.to_string(),
239        ShapeKind::Enum(EnumInfo {
240            name: Some(name), ..
241        }) => name.to_string(),
242        ShapeKind::Struct(StructInfo {
243            name: None, fields, ..
244        }) => {
245            // Anonymous struct - use tuple-like representation
246            let types: Vec<_> = fields.iter().map(|f| swift_type_base(f.shape())).collect();
247            format!("({})", types.join(", "))
248        }
249        ShapeKind::Enum(EnumInfo {
250            name: None,
251            variants,
252        }) => {
253            // Anonymous enum - not well supported in Swift, use Any
254            let _ = variants; // suppress warning
255            "Any".into()
256        }
257        ShapeKind::Pointer { pointee } => swift_type_base(pointee),
258        ShapeKind::Result { ok, err } => {
259            format!("Result<{}, {}>", swift_type_base(ok), swift_type_base(err))
260        }
261        ShapeKind::TupleStruct { fields } => {
262            let types: Vec<_> = fields.iter().map(|f| swift_type_base(f.shape())).collect();
263            format!("({})", types.join(", "))
264        }
265        ShapeKind::Opaque => "Data".into(),
266    }
267}
268
269/// Convert Shape to Swift type string for client arguments.
270pub fn swift_type_client_arg(shape: &'static Shape) -> String {
271    match classify_shape(shape) {
272        ShapeKind::Tx { inner } => format!("UnboundTx<{}>", swift_type_base(inner)),
273        ShapeKind::Rx { inner } => format!("UnboundRx<{}>", swift_type_base(inner)),
274        _ => swift_type_base(shape),
275    }
276}
277
278/// Convert Shape to Swift type string for client returns.
279pub fn swift_type_client_return(shape: &'static Shape) -> String {
280    assert_no_channels_in_return_shape(shape);
281    match classify_shape(shape) {
282        ShapeKind::Scalar(ScalarType::Unit) => "Void".into(),
283        ShapeKind::Tuple { elements: [] } => "Void".into(),
284        _ => swift_type_base(shape),
285    }
286}
287
288/// Convert Shape to Swift type string for server/handler arguments.
289pub fn swift_type_server_arg(shape: &'static Shape) -> String {
290    match classify_shape(shape) {
291        ShapeKind::Tx { inner } => format!("Tx<{}>", swift_type_base(inner)),
292        ShapeKind::Rx { inner } => format!("Rx<{}>", swift_type_base(inner)),
293        _ => swift_type_base(shape),
294    }
295}
296
297/// Convert Shape to Swift type string for server returns.
298pub fn swift_type_server_return(shape: &'static Shape) -> String {
299    assert_no_channels_in_return_shape(shape);
300    match classify_shape(shape) {
301        ShapeKind::Scalar(ScalarType::Unit) => "Void".into(),
302        ShapeKind::Tuple { elements: [] } => "Void".into(),
303        _ => swift_type_base(shape),
304    }
305}
306
307/// Check if a shape represents a channel type (Tx or Rx).
308pub fn is_channel(shape: &'static Shape) -> bool {
309    is_tx(shape) || is_rx(shape)
310}
311
/// Format documentation comments for Swift.
///
/// Each line of `doc` becomes a `///` comment line prefixed with `indent` and
/// terminated by `\n`. Blank lines (empty or whitespace-only) are emitted as a
/// bare `///`, so the generated Swift has no trailing whitespace. Non-blank
/// lines are kept verbatim, including any trailing spaces, which Markdown may
/// treat as hard line breaks. An empty `doc` yields an empty string.
pub fn format_doc(doc: &str, indent: &str) -> String {
    doc.lines()
        .map(|line| {
            if line.trim().is_empty() {
                format!("{indent}///\n")
            } else {
                format!("{indent}/// {line}\n")
            }
        })
        .collect()
}
318
319fn assert_no_channels_in_return_shape(shape: &'static Shape) {
320    assert!(
321        RpcPlan::for_shape(shape).channel_locations.is_empty(),
322        "channels are not allowed in return types"
323    );
324}