Skip to main content

roam_codegen/targets/swift/
server.rs

1//! Swift server/handler generation.
2//!
3//! Generates handler protocol and dispatcher for routing incoming calls.
4
5use facet_core::Shape;
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use roam_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
8
9use super::decode::{generate_decode_stmt_with_cursor, generate_inline_decode};
10use super::encode::generate_encode_closure;
11use super::types::{format_doc, is_channel, swift_type_server_arg, swift_type_server_return};
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14use crate::render::hex_u64;
15
/// Build the name of the private Swift helper that dispatches `method_name`.
///
/// The result is the literal prefix `dispatch_` followed by `method_name`, unchanged.
fn dispatch_helper_name(method_name: &str) -> String {
    const PREFIX: &str = "dispatch_";
    let mut helper = String::with_capacity(PREFIX.len() + method_name.len());
    helper.push_str(PREFIX);
    helper.push_str(method_name);
    helper
}
19
20/// Generate complete server code (handler protocol + dispatchers).
21pub fn generate_server(service: &ServiceDescriptor) -> String {
22    let mut out = String::new();
23    out.push_str(&generate_handler_protocol(service));
24    // Emit only the channel-capable dispatcher.
25    out.push_str(&generate_channeling_dispatcher(service));
26    out
27}
28
29/// Generate handler protocol (for handling incoming calls).
30fn generate_handler_protocol(service: &ServiceDescriptor) -> String {
31    let mut out = String::new();
32    let service_name = service.service_name.to_upper_camel_case();
33
34    if let Some(doc) = &service.doc {
35        out.push_str(&format_doc(doc, ""));
36    }
37    out.push_str(&format!("public protocol {service_name}Handler {{\n"));
38
39    for method in service.methods {
40        let method_name = method.method_name.to_lower_camel_case();
41
42        if let Some(doc) = &method.doc {
43            out.push_str(&format_doc(doc, "    "));
44        }
45
46        // Server perspective
47        let args: Vec<String> = method
48            .args
49            .iter()
50            .map(|a| {
51                format!(
52                    "{}: {}",
53                    a.name.to_lower_camel_case(),
54                    swift_type_server_arg(a.shape)
55                )
56            })
57            .collect();
58
59        let ret_type = swift_type_server_return(method.return_shape);
60
61        if ret_type == "Void" {
62            out.push_str(&format!(
63                "    func {method_name}({}) async throws\n",
64                args.join(", ")
65            ));
66        } else {
67            out.push_str(&format!(
68                "    func {method_name}({}) async throws -> {ret_type}\n",
69                args.join(", ")
70            ));
71        }
72    }
73
74    out.push_str("}\n\n");
75    out
76}
77
78/// Generate channeling dispatcher for handling incoming calls with channel support.
79fn generate_channeling_dispatcher(service: &ServiceDescriptor) -> String {
80    let mut out = String::new();
81    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
82    let service_name = service.service_name.to_upper_camel_case();
83
84    cw_writeln!(
85        w,
86        "public final class {service_name}ChannelingDispatcher {{"
87    )
88    .unwrap();
89    {
90        let _indent = w.indent();
91        cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
92        w.writeln("private let registry: IncomingChannelRegistry")
93            .unwrap();
94        w.writeln("private let taskSender: TaskSender").unwrap();
95        w.blank_line().unwrap();
96
97        cw_writeln!(
98            w,
99            "public init(handler: {service_name}Handler, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender) {{"
100        )
101        .unwrap();
102        {
103            let _indent = w.indent();
104            w.writeln("self.handler = handler").unwrap();
105            w.writeln("self.registry = registry").unwrap();
106            w.writeln("self.taskSender = taskSender").unwrap();
107        }
108        w.writeln("}").unwrap();
109        w.blank_line().unwrap();
110
111        // Main dispatch method
112        w.writeln(
113            "public func dispatch(methodId: UInt64, requestId: UInt64, channels: [UInt64], payload: Data) async {",
114        )
115        .unwrap();
116        {
117            let _indent = w.indent();
118            w.writeln("switch methodId {").unwrap();
119            for method in service.methods {
120                let method_name = method.method_name.to_lower_camel_case();
121                let method_id = crate::method_id(method);
122                let dispatch_name = dispatch_helper_name(&method_name);
123                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
124                cw_writeln!(
125                    w,
126                    "    await {dispatch_name}(requestId: requestId, channels: channels, payload: payload)"
127                )
128                .unwrap();
129            }
130            w.writeln("default:").unwrap();
131            w.writeln(
132                "    taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
133            )
134            .unwrap();
135            w.writeln("}").unwrap();
136        }
137        w.writeln("}").unwrap();
138        w.blank_line().unwrap();
139
140        // Generate preregisterChannels method
141        generate_preregister_channels(&mut w, service);
142        w.blank_line().unwrap();
143
144        // Individual dispatch methods
145        for method in service.methods {
146            generate_channeling_dispatch_method(&mut w, method);
147            w.blank_line().unwrap();
148        }
149    }
150    w.writeln("}").unwrap();
151    w.blank_line().unwrap();
152
153    out
154}
155
156/// Generate preregisterChannels method.
157fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDescriptor) {
158    w.writeln("/// Pre-register Rx channel IDs from request channels.")
159        .unwrap();
160    w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
161        .unwrap();
162    w.writeln("/// race conditions where Data arrives before channels are registered.")
163        .unwrap();
164    w.writeln("public static func preregisterChannels(methodId: UInt64, channels: [UInt64], registry: ChannelRegistry) async {")
165        .unwrap();
166    {
167        let _indent = w.indent();
168        w.writeln("switch methodId {").unwrap();
169
170        for method in service.methods {
171            let method_id = crate::method_id(method);
172            let has_rx_args = method.args.iter().any(|a| is_rx(a.shape));
173
174            if has_rx_args {
175                let channel_arg_count = method
176                    .args
177                    .iter()
178                    .filter(|a| is_rx(a.shape) || is_tx(a.shape))
179                    .count();
180                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
181                cw_writeln!(w, "    guard channels.count >= {channel_arg_count} else {{").unwrap();
182                w.writeln("        return").unwrap();
183                w.writeln("    }").unwrap();
184                w.writeln("    var channelCursor = 0").unwrap();
185
186                // Channel IDs are provided in declaration order.
187                for arg in method.args {
188                    let arg_name = arg.name.to_lower_camel_case();
189                    if is_rx(arg.shape) {
190                        // Schema Rx = client sends, server receives → need to preregister
191                        cw_writeln!(w, "    let {arg_name}ChannelId = channels[channelCursor]")
192                            .unwrap();
193                        w.writeln("    channelCursor += 1").unwrap();
194                        cw_writeln!(w, "    await registry.markKnown({arg_name}ChannelId)")
195                            .unwrap();
196                    } else if is_tx(arg.shape) {
197                        cw_writeln!(w, "    _ = channels[channelCursor] // {arg_name}").unwrap();
198                        w.writeln("    channelCursor += 1").unwrap();
199                    }
200                }
201            }
202        }
203
204        w.writeln("default:").unwrap();
205        w.writeln("    break").unwrap();
206        w.writeln("}").unwrap();
207    }
208    w.writeln("}").unwrap();
209}
210
211/// Generate a single channeling dispatch method.
212fn generate_channeling_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDescriptor) {
213    let method_name = method.method_name.to_lower_camel_case();
214    let dispatch_name = dispatch_helper_name(&method_name);
215    let has_channeling = method.args.iter().any(|a| is_channel(a.shape));
216
217    cw_writeln!(
218        w,
219        "private func {dispatch_name}(requestId: UInt64, channels: [UInt64], payload: Data) async {{"
220    )
221    .unwrap();
222    {
223        let _indent = w.indent();
224        w.writeln("do {").unwrap();
225        {
226            let _indent = w.indent();
227            let has_payload_args = method
228                .args
229                .iter()
230                .any(|a| !is_rx(a.shape) && !is_tx(a.shape));
231            let has_channel_args = method.args.iter().any(|a| is_rx(a.shape) || is_tx(a.shape));
232            let cursor_var = if has_payload_args {
233                let name = unique_decode_cursor_name(method.args);
234                cw_writeln!(w, "var {name} = 0").unwrap();
235                Some(name)
236            } else {
237                None
238            };
239            if has_channel_args {
240                w.writeln("var channelCursor = 0").unwrap();
241            }
242
243            // Decode arguments - channel IDs come from Request.channels.
244            for arg in method.args {
245                let arg_name = arg.name.to_lower_camel_case();
246                generate_channeling_decode_arg(
247                    w,
248                    &arg_name,
249                    arg.shape,
250                    cursor_var.as_deref(),
251                    "channels",
252                    Some("channelCursor"),
253                );
254            }
255
256            // Call handler
257            let arg_names: Vec<String> = method
258                .args
259                .iter()
260                .map(|a| {
261                    let name = a.name.to_lower_camel_case();
262                    format!("{name}: {name}")
263                })
264                .collect();
265
266            let ret_type = swift_type_server_return(method.return_shape);
267
268            if has_channeling {
269                // For channeling methods, close any Tx channels after handler completes
270                if ret_type == "Void" {
271                    cw_writeln!(
272                        w,
273                        "try await handler.{method_name}({})",
274                        arg_names.join(", ")
275                    )
276                    .unwrap();
277                } else {
278                    cw_writeln!(
279                        w,
280                        "let result = try await handler.{method_name}({})",
281                        arg_names.join(", ")
282                    )
283                    .unwrap();
284                }
285
286                // Close any Tx channels
287                for arg in method.args {
288                    if is_tx(arg.shape) {
289                        let arg_name = arg.name.to_lower_camel_case();
290                        cw_writeln!(w, "{arg_name}.close()").unwrap();
291                    }
292                }
293
294                // Send response
295                if ret_type == "Void" {
296                    w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
297                } else {
298                    let encode_closure = generate_encode_closure(method.return_shape);
299                    cw_writeln!(
300                        w,
301                        "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
302                    )
303                    .unwrap();
304                }
305            } else {
306                // Non-channeling method
307                if ret_type == "Void" {
308                    cw_writeln!(
309                        w,
310                        "try await handler.{method_name}({})",
311                        arg_names.join(", ")
312                    )
313                    .unwrap();
314                    w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
315                } else {
316                    cw_writeln!(
317                        w,
318                        "let result = try await handler.{method_name}({})",
319                        arg_names.join(", ")
320                    )
321                    .unwrap();
322                    // Check if return type is Result<T, E> - if so, encode as Result<T, RoamError<User(E)>>
323                    if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
324                        let ok_encode = generate_encode_closure(ok);
325                        let err_encode = generate_encode_closure(err);
326                        // Wire format: [0] + T for success, [1, 0] + E for User error
327                        cw_writeln!(
328                            w,
329                            "taskSender(.response(requestId: requestId, payload: {{ switch result {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1), UInt8(0)] + {err_encode}(e) }} }}()))"
330                        )
331                        .unwrap();
332                    } else {
333                        let encode_closure = generate_encode_closure(method.return_shape);
334                        cw_writeln!(
335                            w,
336                            "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
337                        )
338                        .unwrap();
339                    }
340                }
341            }
342        }
343        w.writeln("} catch {").unwrap();
344        {
345            let _indent = w.indent();
346            w.writeln(
347                "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError()))",
348            )
349            .unwrap();
350        }
351        w.writeln("}").unwrap();
352    }
353    w.writeln("}").unwrap();
354}
355
356/// Generate code to decode a single argument for channeling dispatch.
357fn generate_channeling_decode_arg(
358    w: &mut CodeWriter<&mut String>,
359    name: &str,
360    shape: &'static Shape,
361    cursor_var: Option<&str>,
362    channels_var: &str,
363    channel_cursor_var: Option<&str>,
364) {
365    match classify_shape(shape) {
366        ShapeKind::Rx { inner } => {
367            // Schema Rx = client passes Rx to method, sends via paired Tx
368            // Server needs to receive → create server Rx
369            let inline_decode = generate_inline_decode(inner, "Data(bytes)", "off");
370            let channel_cursor_var =
371                channel_cursor_var.expect("channel cursor required for channeling args");
372            cw_writeln!(
373                w,
374                "guard {channel_cursor_var} < {channels_var}.count else {{ throw RoamError.decodeError(\"missing channel id for {name}\") }}"
375            )
376            .unwrap();
377            cw_writeln!(
378                w,
379                "let {name}ChannelId = {channels_var}[{channel_cursor_var}]"
380            )
381            .unwrap();
382            cw_writeln!(w, "{channel_cursor_var} += 1").unwrap();
383            cw_writeln!(
384                w,
385                "let {name}Receiver = await registry.register({name}ChannelId)"
386            )
387            .unwrap();
388            cw_writeln!(
389                w,
390                "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {{ bytes in"
391            )
392            .unwrap();
393            cw_writeln!(w, "    var off = 0").unwrap();
394            cw_writeln!(w, "    return try {inline_decode}").unwrap();
395            w.writeln("})").unwrap();
396        }
397        ShapeKind::Tx { inner } => {
398            // Schema Tx = client passes Tx to method, receives via paired Rx
399            // Server needs to send → create server Tx
400            let encode_closure = generate_encode_closure(inner);
401            let channel_cursor_var =
402                channel_cursor_var.expect("channel cursor required for channeling args");
403            cw_writeln!(
404                w,
405                "guard {channel_cursor_var} < {channels_var}.count else {{ throw RoamError.decodeError(\"missing channel id for {name}\") }}"
406            )
407            .unwrap();
408            cw_writeln!(
409                w,
410                "let {name}ChannelId = {channels_var}[{channel_cursor_var}]"
411            )
412            .unwrap();
413            cw_writeln!(w, "{channel_cursor_var} += 1").unwrap();
414            cw_writeln!(
415                w,
416                "let {name} = createServerTx(channelId: {name}ChannelId, taskSender: taskSender, serialize: ({encode_closure}))"
417            )
418            .unwrap();
419        }
420        _ => {
421            // Non-channeling argument - use standard decode
422            let cursor_var = cursor_var.expect("payload cursor required for non-channel args");
423            let decode_stmt = generate_decode_stmt_with_cursor(shape, name, "", cursor_var);
424            for line in decode_stmt.lines() {
425                w.writeln(line).unwrap();
426            }
427        }
428    }
429}
430
431fn unique_decode_cursor_name(args: &[roam_types::ArgDescriptor]) -> String {
432    let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
433    let mut candidate = String::from("cursor");
434    while arg_names.iter().any(|name| name == &candidate) {
435        candidate.push('_');
436    }
437    candidate
438}