Skip to main content

roam_codegen/targets/swift/
server.rs

1//! Swift server/handler generation.
2//!
3//! Generates handler protocol and dispatcher for routing incoming calls.
4
5use facet_core::Shape;
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use roam_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
8
9use super::decode::{generate_decode_stmt_with_cursor, generate_inline_decode};
10use super::encode::generate_encode_closure;
11use super::types::{format_doc, is_channel, swift_type_server_arg, swift_type_server_return};
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14use crate::render::hex_u64;
15
/// Name of the private Swift function that dispatches calls to `method_name`
/// (the `dispatch_` prefix followed by the given name, unchanged).
fn dispatch_helper_name(method_name: &str) -> String {
    let mut helper = String::from("dispatch_");
    helper.push_str(method_name);
    helper
}
19
20fn extract_initial_credit(shape: &'static Shape) -> u32 {
21    shape
22        .const_params
23        .iter()
24        .find(|cp| cp.name == "N")
25        .map(|cp| cp.value as u32)
26        .unwrap_or(16)
27}
28
29/// Generate complete server code (handler protocol + dispatchers).
30pub fn generate_server(service: &ServiceDescriptor) -> String {
31    let mut out = String::new();
32    out.push_str(&generate_handler_protocol(service));
33    // Emit only the channel-capable dispatcher.
34    out.push_str(&generate_channeling_dispatcher(service));
35    out
36}
37
38/// Generate handler protocol (for handling incoming calls).
39fn generate_handler_protocol(service: &ServiceDescriptor) -> String {
40    let mut out = String::new();
41    let service_name = service.service_name.to_upper_camel_case();
42
43    if let Some(doc) = &service.doc {
44        out.push_str(&format_doc(doc, ""));
45    }
46    out.push_str(&format!("public protocol {service_name}Handler {{\n"));
47
48    for method in service.methods {
49        let method_name = method.method_name.to_lower_camel_case();
50
51        if let Some(doc) = &method.doc {
52            out.push_str(&format_doc(doc, "    "));
53        }
54
55        // Server perspective
56        let args: Vec<String> = method
57            .args
58            .iter()
59            .map(|a| {
60                format!(
61                    "{}: {}",
62                    a.name.to_lower_camel_case(),
63                    swift_type_server_arg(a.shape)
64                )
65            })
66            .collect();
67
68        let ret_type = swift_type_server_return(method.return_shape);
69
70        if ret_type == "Void" {
71            out.push_str(&format!(
72                "    func {method_name}({}) async throws\n",
73                args.join(", ")
74            ));
75        } else {
76            out.push_str(&format!(
77                "    func {method_name}({}) async throws -> {ret_type}\n",
78                args.join(", ")
79            ));
80        }
81    }
82
83    out.push_str("}\n\n");
84    out
85}
86
87/// Generate channeling dispatcher for handling incoming calls with channel support.
88fn generate_channeling_dispatcher(service: &ServiceDescriptor) -> String {
89    let mut out = String::new();
90    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
91    let service_name = service.service_name.to_upper_camel_case();
92
93    cw_writeln!(
94        w,
95        "public final class {service_name}ChannelingDispatcher {{"
96    )
97    .unwrap();
98    {
99        let _indent = w.indent();
100        cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
101        w.writeln("private let registry: IncomingChannelRegistry")
102            .unwrap();
103        w.writeln("private let taskSender: TaskSender").unwrap();
104        w.blank_line().unwrap();
105
106        cw_writeln!(
107            w,
108            "public init(handler: {service_name}Handler, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender) {{"
109        )
110        .unwrap();
111        {
112            let _indent = w.indent();
113            w.writeln("self.handler = handler").unwrap();
114            w.writeln("self.registry = registry").unwrap();
115            w.writeln("self.taskSender = taskSender").unwrap();
116        }
117        w.writeln("}").unwrap();
118        w.blank_line().unwrap();
119
120        // Main dispatch method
121        w.writeln(
122            "public func dispatch(methodId: UInt64, requestId: UInt64, channels: [UInt64], payload: Data) async {",
123        )
124        .unwrap();
125        {
126            let _indent = w.indent();
127            w.writeln("switch methodId {").unwrap();
128            for method in service.methods {
129                let method_name = method.method_name.to_lower_camel_case();
130                let method_id = crate::method_id(method);
131                let dispatch_name = dispatch_helper_name(&method_name);
132                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
133                cw_writeln!(
134                    w,
135                    "    await {dispatch_name}(requestId: requestId, channels: channels, payload: payload)"
136                )
137                .unwrap();
138            }
139            w.writeln("default:").unwrap();
140            w.writeln(
141                "    taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
142            )
143            .unwrap();
144            w.writeln("}").unwrap();
145        }
146        w.writeln("}").unwrap();
147        w.blank_line().unwrap();
148
149        // Generate preregisterChannels method
150        generate_preregister_channels(&mut w, service);
151        w.blank_line().unwrap();
152
153        // Individual dispatch methods
154        for method in service.methods {
155            generate_channeling_dispatch_method(&mut w, method);
156            w.blank_line().unwrap();
157        }
158    }
159    w.writeln("}").unwrap();
160    w.blank_line().unwrap();
161
162    out
163}
164
165/// Generate preregisterChannels method.
166fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDescriptor) {
167    w.writeln("/// Pre-register Rx channel IDs from request channels.")
168        .unwrap();
169    w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
170        .unwrap();
171    w.writeln("/// race conditions where Data arrives before channels are registered.")
172        .unwrap();
173    w.writeln("public static func preregisterChannels(methodId: UInt64, channels: [UInt64], registry: ChannelRegistry) async {")
174        .unwrap();
175    {
176        let _indent = w.indent();
177        w.writeln("switch methodId {").unwrap();
178
179        for method in service.methods {
180            let method_id = crate::method_id(method);
181            let has_rx_args = method.args.iter().any(|a| is_rx(a.shape));
182
183            if has_rx_args {
184                let channel_arg_count = method
185                    .args
186                    .iter()
187                    .filter(|a| is_rx(a.shape) || is_tx(a.shape))
188                    .count();
189                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
190                cw_writeln!(w, "    guard channels.count >= {channel_arg_count} else {{").unwrap();
191                w.writeln("        return").unwrap();
192                w.writeln("    }").unwrap();
193                w.writeln("    var channelCursor = 0").unwrap();
194
195                // Channel IDs are provided in declaration order.
196                for arg in method.args {
197                    let arg_name = arg.name.to_lower_camel_case();
198                    if is_rx(arg.shape) {
199                        // Schema Rx = client sends, server receives → need to preregister
200                        cw_writeln!(w, "    let {arg_name}ChannelId = channels[channelCursor]")
201                            .unwrap();
202                        w.writeln("    channelCursor += 1").unwrap();
203                        cw_writeln!(w, "    await registry.markKnown({arg_name}ChannelId)")
204                            .unwrap();
205                    } else if is_tx(arg.shape) {
206                        cw_writeln!(w, "    _ = channels[channelCursor] // {arg_name}").unwrap();
207                        w.writeln("    channelCursor += 1").unwrap();
208                    }
209                }
210            }
211        }
212
213        w.writeln("default:").unwrap();
214        w.writeln("    break").unwrap();
215        w.writeln("}").unwrap();
216    }
217    w.writeln("}").unwrap();
218}
219
220/// Generate a single channeling dispatch method.
221fn generate_channeling_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDescriptor) {
222    let method_name = method.method_name.to_lower_camel_case();
223    let dispatch_name = dispatch_helper_name(&method_name);
224    let has_channeling = method.args.iter().any(|a| is_channel(a.shape));
225
226    cw_writeln!(
227        w,
228        "private func {dispatch_name}(requestId: UInt64, channels: [UInt64], payload: Data) async {{"
229    )
230    .unwrap();
231    {
232        let _indent = w.indent();
233        w.writeln("do {").unwrap();
234        {
235            let _indent = w.indent();
236            let has_payload_args = method
237                .args
238                .iter()
239                .any(|a| !is_rx(a.shape) && !is_tx(a.shape));
240            let has_channel_args = method.args.iter().any(|a| is_rx(a.shape) || is_tx(a.shape));
241            let cursor_var = if has_payload_args {
242                let name = unique_decode_cursor_name(method.args);
243                cw_writeln!(w, "var {name} = 0").unwrap();
244                Some(name)
245            } else {
246                None
247            };
248            if has_channel_args {
249                w.writeln("var channelCursor = 0").unwrap();
250            }
251
252            // Decode arguments - channel IDs come from Request.channels.
253            for arg in method.args {
254                let arg_name = arg.name.to_lower_camel_case();
255                generate_channeling_decode_arg(
256                    w,
257                    &arg_name,
258                    arg.shape,
259                    cursor_var.as_deref(),
260                    "channels",
261                    Some("channelCursor"),
262                );
263            }
264
265            // Call handler
266            let arg_names: Vec<String> = method
267                .args
268                .iter()
269                .map(|a| {
270                    let name = a.name.to_lower_camel_case();
271                    format!("{name}: {name}")
272                })
273                .collect();
274
275            let ret_type = swift_type_server_return(method.return_shape);
276
277            if has_channeling {
278                // For channeling methods, close any Tx channels after handler completes
279                if ret_type == "Void" {
280                    cw_writeln!(
281                        w,
282                        "try await handler.{method_name}({})",
283                        arg_names.join(", ")
284                    )
285                    .unwrap();
286                } else {
287                    cw_writeln!(
288                        w,
289                        "let result = try await handler.{method_name}({})",
290                        arg_names.join(", ")
291                    )
292                    .unwrap();
293                }
294
295                // Close any Tx channels
296                for arg in method.args {
297                    if is_tx(arg.shape) {
298                        let arg_name = arg.name.to_lower_camel_case();
299                        cw_writeln!(w, "{arg_name}.close()").unwrap();
300                    }
301                }
302
303                // Send response
304                if ret_type == "Void" {
305                    w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
306                } else {
307                    let encode_closure = generate_encode_closure(method.return_shape);
308                    cw_writeln!(
309                        w,
310                        "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
311                    )
312                    .unwrap();
313                }
314            } else {
315                // Non-channeling method
316                if ret_type == "Void" {
317                    cw_writeln!(
318                        w,
319                        "try await handler.{method_name}({})",
320                        arg_names.join(", ")
321                    )
322                    .unwrap();
323                    w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
324                } else {
325                    cw_writeln!(
326                        w,
327                        "let result = try await handler.{method_name}({})",
328                        arg_names.join(", ")
329                    )
330                    .unwrap();
331                    // Check if return type is Result<T, E> - if so, encode as Result<T, RoamError<User(E)>>
332                    if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
333                        let ok_encode = generate_encode_closure(ok);
334                        let err_encode = generate_encode_closure(err);
335                        // Wire format: [0] + T for success, [1, 0] + E for User error
336                        cw_writeln!(
337                            w,
338                            "taskSender(.response(requestId: requestId, payload: {{ switch result {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1), UInt8(0)] + {err_encode}(e) }} }}()))"
339                        )
340                        .unwrap();
341                    } else {
342                        let encode_closure = generate_encode_closure(method.return_shape);
343                        cw_writeln!(
344                            w,
345                            "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
346                        )
347                        .unwrap();
348                    }
349                }
350            }
351        }
352        w.writeln("} catch {").unwrap();
353        {
354            let _indent = w.indent();
355            w.writeln(
356                "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError()))",
357            )
358            .unwrap();
359        }
360        w.writeln("}").unwrap();
361    }
362    w.writeln("}").unwrap();
363}
364
365/// Generate code to decode a single argument for channeling dispatch.
366fn generate_channeling_decode_arg(
367    w: &mut CodeWriter<&mut String>,
368    name: &str,
369    shape: &'static Shape,
370    cursor_var: Option<&str>,
371    channels_var: &str,
372    channel_cursor_var: Option<&str>,
373) {
374    match classify_shape(shape) {
375        ShapeKind::Rx { inner } => {
376            // Schema Rx = client passes Rx to method, sends via paired Tx
377            // Server needs to receive → create server Rx
378            let inline_decode = generate_inline_decode(inner, "Data(bytes)", "off");
379            let channel_cursor_var =
380                channel_cursor_var.expect("channel cursor required for channeling args");
381            cw_writeln!(
382                w,
383                "guard {channel_cursor_var} < {channels_var}.count else {{ throw RoamError.decodeError(\"missing channel id for {name}\") }}"
384            )
385            .unwrap();
386            cw_writeln!(
387                w,
388                "let {name}ChannelId = {channels_var}[{channel_cursor_var}]"
389            )
390            .unwrap();
391            cw_writeln!(w, "{channel_cursor_var} += 1").unwrap();
392            cw_writeln!(
393                w,
394                "let {name}Receiver = await registry.register({name}ChannelId, initialCredit: {}, onConsumed: {{ [taskSender = self.taskSender] additional in taskSender(.grantCredit(channelId: {name}ChannelId, bytes: additional)) }})",
395                extract_initial_credit(shape)
396            )
397            .unwrap();
398            cw_writeln!(
399                w,
400                "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {{ bytes in"
401            )
402            .unwrap();
403            cw_writeln!(w, "    var off = 0").unwrap();
404            cw_writeln!(w, "    return try {inline_decode}").unwrap();
405            w.writeln("})").unwrap();
406        }
407        ShapeKind::Tx { inner } => {
408            // Schema Tx = client passes Tx to method, receives via paired Rx
409            // Server needs to send → create server Tx
410            let encode_closure = generate_encode_closure(inner);
411            let channel_cursor_var =
412                channel_cursor_var.expect("channel cursor required for channeling args");
413            cw_writeln!(
414                w,
415                "guard {channel_cursor_var} < {channels_var}.count else {{ throw RoamError.decodeError(\"missing channel id for {name}\") }}"
416            )
417            .unwrap();
418            cw_writeln!(
419                w,
420                "let {name}ChannelId = {channels_var}[{channel_cursor_var}]"
421            )
422            .unwrap();
423            cw_writeln!(w, "{channel_cursor_var} += 1").unwrap();
424            cw_writeln!(
425                w,
426                "let {name} = await createServerTx(channelId: {name}ChannelId, taskSender: taskSender, registry: registry, initialCredit: {}, serialize: ({encode_closure}))",
427                extract_initial_credit(shape)
428            )
429            .unwrap();
430        }
431        _ => {
432            // Non-channeling argument - use standard decode
433            let cursor_var = cursor_var.expect("payload cursor required for non-channel args");
434            let decode_stmt = generate_decode_stmt_with_cursor(shape, name, "", cursor_var);
435            for line in decode_stmt.lines() {
436                w.writeln(line).unwrap();
437            }
438        }
439    }
440}
441
442fn unique_decode_cursor_name(args: &[roam_types::ArgDescriptor]) -> String {
443    let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
444    let mut candidate = String::from("cursor");
445    while arg_names.iter().any(|name| name == &candidate) {
446        candidate.push('_');
447    }
448    candidate
449}