roam_codegen/targets/swift/
server.rs

1//! Swift server/handler generation.
2//!
3//! Generates handler protocol and dispatcher for routing incoming calls.
4
5use facet_core::{ScalarType, Shape};
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use roam_schema::{
8    MethodDetail, ServiceDetail, ShapeKind, StructInfo, classify_shape, is_rx, is_tx,
9};
10
11use super::decode::{generate_decode_stmt, generate_inline_decode};
12use super::encode::generate_encode_closure;
13use super::types::{format_doc, is_stream, swift_type_server_arg, swift_type_server_return};
14use crate::code_writer::CodeWriter;
15use crate::cw_writeln;
16use crate::render::hex_u64;
17
18/// Generate complete server code (handler protocol + dispatchers).
19pub fn generate_server(service: &ServiceDetail) -> String {
20    let mut out = String::new();
21    out.push_str(&generate_handler_protocol(service));
22    out.push_str(&generate_dispatcher(service));
23    out.push_str(&generate_streaming_dispatcher(service));
24    out
25}
26
27/// Generate handler protocol (for handling incoming calls).
28fn generate_handler_protocol(service: &ServiceDetail) -> String {
29    let mut out = String::new();
30    let service_name = service.name.to_upper_camel_case();
31
32    if let Some(doc) = &service.doc {
33        out.push_str(&format_doc(doc, ""));
34    }
35    out.push_str(&format!("public protocol {service_name}Handler {{\n"));
36
37    for method in &service.methods {
38        let method_name = method.method_name.to_lower_camel_case();
39
40        if let Some(doc) = &method.doc {
41            out.push_str(&format_doc(doc, "    "));
42        }
43
44        // Server perspective
45        let args: Vec<String> = method
46            .args
47            .iter()
48            .map(|a| {
49                format!(
50                    "{}: {}",
51                    a.name.to_lower_camel_case(),
52                    swift_type_server_arg(a.ty)
53                )
54            })
55            .collect();
56
57        let ret_type = swift_type_server_return(method.return_type);
58
59        if ret_type == "Void" {
60            out.push_str(&format!(
61                "    func {method_name}({}) async throws\n",
62                args.join(", ")
63            ));
64        } else {
65            out.push_str(&format!(
66                "    func {method_name}({}) async throws -> {ret_type}\n",
67                args.join(", ")
68            ));
69        }
70    }
71
72    out.push_str("}\n\n");
73    out
74}
75
76/// Generate dispatcher for handling incoming calls.
77fn generate_dispatcher(service: &ServiceDetail) -> String {
78    let mut out = String::new();
79    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
80    let service_name = service.name.to_upper_camel_case();
81
82    cw_writeln!(w, "public final class {service_name}Dispatcher {{").unwrap();
83    {
84        let _indent = w.indent();
85        cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
86        w.blank_line().unwrap();
87        cw_writeln!(w, "public init(handler: {service_name}Handler) {{").unwrap();
88        {
89            let _indent = w.indent();
90            w.writeln("self.handler = handler").unwrap();
91        }
92        w.writeln("}").unwrap();
93        w.blank_line().unwrap();
94
95        // Main dispatch method
96        w.writeln("public func dispatch(methodId: UInt64, payload: Data) async throws -> Data {")
97            .unwrap();
98        {
99            let _indent = w.indent();
100            w.writeln("switch methodId {").unwrap();
101            for method in &service.methods {
102                let method_name = method.method_name.to_lower_camel_case();
103                let method_id = crate::method_id(method);
104                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
105                cw_writeln!(
106                    w,
107                    "    return try await dispatch{method_name}(payload: payload)"
108                )
109                .unwrap();
110            }
111            w.writeln("default:").unwrap();
112            w.writeln("    throw RoamError.unknownMethod").unwrap();
113            w.writeln("}").unwrap();
114        }
115        w.writeln("}").unwrap();
116
117        // Individual dispatch methods
118        for method in &service.methods {
119            w.blank_line().unwrap();
120            generate_dispatch_method(&mut w, method);
121        }
122    }
123    w.writeln("}").unwrap();
124    w.blank_line().unwrap();
125
126    out
127}
128
129/// Generate a single dispatch method for non-streaming dispatcher.
130fn generate_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDetail) {
131    let method_name = method.method_name.to_lower_camel_case();
132    let has_streaming =
133        method.args.iter().any(|a| is_stream(a.ty)) || is_stream(method.return_type);
134
135    cw_writeln!(
136        w,
137        "private func dispatch{method_name}(payload: Data) async throws -> Data {{"
138    )
139    .unwrap();
140    {
141        let _indent = w.indent();
142
143        if has_streaming {
144            w.writeln("// TODO: Implement streaming dispatch").unwrap();
145            w.writeln("throw RoamError.notImplemented").unwrap();
146        } else {
147            // Decode arguments
148            generate_decode_args(w, &method.args);
149
150            // Call handler
151            let arg_names: Vec<String> = method
152                .args
153                .iter()
154                .map(|a| {
155                    let name = a.name.to_lower_camel_case();
156                    format!("{name}: {name}")
157                })
158                .collect();
159
160            let ret_type = swift_type_server_return(method.return_type);
161
162            if ret_type == "Void" {
163                cw_writeln!(
164                    w,
165                    "try await handler.{method_name}({})",
166                    arg_names.join(", ")
167                )
168                .unwrap();
169                w.writeln("return Data()").unwrap();
170            } else {
171                cw_writeln!(
172                    w,
173                    "let result = try await handler.{method_name}({})",
174                    arg_names.join(", ")
175                )
176                .unwrap();
177                let encode_closure = generate_encode_closure(method.return_type);
178                cw_writeln!(
179                    w,
180                    "return Data(encodeResultOk(result, encoder: {encode_closure}))"
181                )
182                .unwrap();
183            }
184        }
185    }
186    w.writeln("}").unwrap();
187}
188
189/// Generate code to decode method arguments (for dispatcher).
190fn generate_decode_args(w: &mut CodeWriter<&mut String>, args: &[roam_schema::ArgDetail]) {
191    if args.is_empty() {
192        w.writeln("// No arguments to decode").unwrap();
193        return;
194    }
195
196    w.writeln("var offset = 0").unwrap();
197    for arg in args {
198        let arg_name = arg.name.to_lower_camel_case();
199        let decode_stmt = generate_decode_stmt(arg.ty, &arg_name, "");
200        for line in decode_stmt.lines() {
201            w.writeln(line).unwrap();
202        }
203    }
204}
205
206/// Generate streaming dispatcher for handling incoming calls with channel support.
207fn generate_streaming_dispatcher(service: &ServiceDetail) -> String {
208    let mut out = String::new();
209    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
210    let service_name = service.name.to_upper_camel_case();
211
212    cw_writeln!(w, "public final class {service_name}StreamingDispatcher {{").unwrap();
213    {
214        let _indent = w.indent();
215        cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
216        w.writeln("private let registry: IncomingChannelRegistry")
217            .unwrap();
218        w.writeln("private let taskSender: TaskSender").unwrap();
219        w.blank_line().unwrap();
220
221        cw_writeln!(
222            w,
223            "public init(handler: {service_name}Handler, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender) {{"
224        )
225        .unwrap();
226        {
227            let _indent = w.indent();
228            w.writeln("self.handler = handler").unwrap();
229            w.writeln("self.registry = registry").unwrap();
230            w.writeln("self.taskSender = taskSender").unwrap();
231        }
232        w.writeln("}").unwrap();
233        w.blank_line().unwrap();
234
235        // Main dispatch method
236        w.writeln(
237            "public func dispatch(methodId: UInt64, requestId: UInt64, payload: Data) async {",
238        )
239        .unwrap();
240        {
241            let _indent = w.indent();
242            w.writeln("switch methodId {").unwrap();
243            for method in &service.methods {
244                let method_name = method.method_name.to_lower_camel_case();
245                let method_id = crate::method_id(method);
246                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
247                cw_writeln!(
248                    w,
249                    "    await dispatch{method_name}(requestId: requestId, payload: payload)"
250                )
251                .unwrap();
252            }
253            w.writeln("default:").unwrap();
254            w.writeln(
255                "    taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
256            )
257            .unwrap();
258            w.writeln("}").unwrap();
259        }
260        w.writeln("}").unwrap();
261        w.blank_line().unwrap();
262
263        // Generate preregisterChannels method
264        generate_preregister_channels(&mut w, service);
265        w.blank_line().unwrap();
266
267        // Individual dispatch methods
268        for method in &service.methods {
269            generate_streaming_dispatch_method(&mut w, method);
270            w.blank_line().unwrap();
271        }
272    }
273    w.writeln("}").unwrap();
274    w.blank_line().unwrap();
275
276    out
277}
278
279/// Generate preregisterChannels method.
280fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDetail) {
281    w.writeln("/// Pre-register channel IDs from a request payload.")
282        .unwrap();
283    w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
284        .unwrap();
285    w.writeln("/// race conditions where Data arrives before channels are registered.")
286        .unwrap();
287    w.writeln("public static func preregisterChannels(methodId: UInt64, payload: Data, registry: ChannelRegistry) async {")
288        .unwrap();
289    {
290        let _indent = w.indent();
291        w.writeln("switch methodId {").unwrap();
292
293        for method in &service.methods {
294            let method_id = crate::method_id(method);
295            let has_rx_args = method.args.iter().any(|a| is_rx(a.ty));
296
297            if has_rx_args {
298                cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
299                w.writeln("    do {").unwrap();
300                w.writeln("        var offset = 0").unwrap();
301
302                // Parse channel IDs and mark them as known
303                for arg in &method.args {
304                    let arg_name = arg.name.to_lower_camel_case();
305                    if is_rx(arg.ty) {
306                        // Schema Rx = client sends, server receives → need to preregister
307                        cw_writeln!(
308                            w,
309                            "        let {arg_name}ChannelId = try decodeVarint(from: payload, offset: &offset)"
310                        )
311                        .unwrap();
312                        cw_writeln!(w, "        await registry.markKnown({arg_name}ChannelId)")
313                            .unwrap();
314                    } else if is_tx(arg.ty) {
315                        // Schema Tx = server sends → just skip the varint
316                        cw_writeln!(
317                            w,
318                            "        _ = try decodeVarint(from: payload, offset: &offset) // {arg_name}"
319                        )
320                        .unwrap();
321                    } else {
322                        // Non-streaming arg - skip it based on type
323                        generate_skip_arg(w, &arg_name, arg.ty, "        ");
324                    }
325                }
326
327                w.writeln("    } catch {").unwrap();
328                w.writeln("        // Ignore parse errors - dispatch will handle them")
329                    .unwrap();
330                w.writeln("    }").unwrap();
331            }
332        }
333
334        w.writeln("default:").unwrap();
335        w.writeln("    break").unwrap();
336        w.writeln("}").unwrap();
337    }
338    w.writeln("}").unwrap();
339}
340
341/// Generate code to skip over an argument during preregistration.
342fn generate_skip_arg(
343    w: &mut CodeWriter<&mut String>,
344    name: &str,
345    shape: &'static Shape,
346    indent: &str,
347) {
348    use roam_schema::is_bytes;
349
350    if is_bytes(shape) {
351        cw_writeln!(
352            w,
353            "{indent}_ = try decodeBytes(from: payload, offset: &offset) // {name}"
354        )
355        .unwrap();
356        return;
357    }
358
359    match classify_shape(shape) {
360        ShapeKind::Scalar(scalar) => {
361            let skip_code = match scalar {
362                ScalarType::Bool | ScalarType::U8 | ScalarType::I8 => "offset += 1",
363                ScalarType::U16 | ScalarType::I16 => "offset += 2",
364                ScalarType::U32 | ScalarType::I32 | ScalarType::U64 | ScalarType::I64 => {
365                    "_ = try decodeVarint(from: payload, offset: &offset)"
366                }
367                ScalarType::F32 => "offset += 4",
368                ScalarType::F64 => "offset += 8",
369                ScalarType::Unit => "",
370                ScalarType::Char => "_ = try decodeVarint(from: payload, offset: &offset)",
371                _ => "// unknown scalar type",
372            };
373            if !skip_code.is_empty() {
374                cw_writeln!(w, "{indent}{skip_code} // {name}").unwrap();
375            }
376        }
377        ShapeKind::List { .. } | ShapeKind::Slice { .. } | ShapeKind::Array { .. } => {
378            cw_writeln!(
379                w,
380                "{indent}_ = try decodeBytes(from: payload, offset: &offset) // {name} (skipped)"
381            )
382            .unwrap();
383        }
384        ShapeKind::Option { .. } => {
385            cw_writeln!(w, "{indent}// TODO: skip option {name}").unwrap();
386        }
387        ShapeKind::Struct(StructInfo { fields, .. }) => {
388            // For structs, recursively skip each field
389            for field in fields {
390                let field_name = format!("{}.{}", name, field.name);
391                generate_skip_arg(w, &field_name, field.shape(), indent);
392            }
393        }
394        _ => {
395            cw_writeln!(w, "{indent}// TODO: skip {name}").unwrap();
396        }
397    }
398}
399
400/// Generate a single streaming dispatch method.
401fn generate_streaming_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDetail) {
402    let method_name = method.method_name.to_lower_camel_case();
403    let has_streaming =
404        method.args.iter().any(|a| is_stream(a.ty)) || is_stream(method.return_type);
405
406    cw_writeln!(
407        w,
408        "private func dispatch{method_name}(requestId: UInt64, payload: Data) async {{"
409    )
410    .unwrap();
411    {
412        let _indent = w.indent();
413        w.writeln("do {").unwrap();
414        {
415            let _indent = w.indent();
416            w.writeln("var offset = 0").unwrap();
417
418            // Decode arguments - for streaming, decode channel IDs and create Tx/Rx
419            for arg in &method.args {
420                let arg_name = arg.name.to_lower_camel_case();
421                generate_streaming_decode_arg(w, &arg_name, arg.ty);
422            }
423
424            // Call handler
425            let arg_names: Vec<String> = method
426                .args
427                .iter()
428                .map(|a| {
429                    let name = a.name.to_lower_camel_case();
430                    format!("{name}: {name}")
431                })
432                .collect();
433
434            let ret_type = swift_type_server_return(method.return_type);
435
436            if has_streaming {
437                // For streaming methods, close any Tx channels after handler completes
438                if ret_type == "Void" {
439                    cw_writeln!(
440                        w,
441                        "try await handler.{method_name}({})",
442                        arg_names.join(", ")
443                    )
444                    .unwrap();
445                } else {
446                    cw_writeln!(
447                        w,
448                        "let result = try await handler.{method_name}({})",
449                        arg_names.join(", ")
450                    )
451                    .unwrap();
452                }
453
454                // Close any Tx channels
455                for arg in &method.args {
456                    if is_tx(arg.ty) {
457                        let arg_name = arg.name.to_lower_camel_case();
458                        cw_writeln!(w, "{arg_name}.close()").unwrap();
459                    }
460                }
461
462                // Send response
463                if ret_type == "Void" {
464                    w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
465                } else {
466                    let encode_closure = generate_encode_closure(method.return_type);
467                    cw_writeln!(
468                        w,
469                        "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
470                    )
471                    .unwrap();
472                }
473            } else {
474                // Non-streaming method
475                if ret_type == "Void" {
476                    cw_writeln!(
477                        w,
478                        "try await handler.{method_name}({})",
479                        arg_names.join(", ")
480                    )
481                    .unwrap();
482                    w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
483                } else {
484                    cw_writeln!(
485                        w,
486                        "let result = try await handler.{method_name}({})",
487                        arg_names.join(", ")
488                    )
489                    .unwrap();
490                    // Check if return type is Result<T, E> - if so, encode as Result<T, RoamError<User(E)>>
491                    if let ShapeKind::Result { ok, err } = classify_shape(method.return_type) {
492                        let ok_encode = generate_encode_closure(ok);
493                        let err_encode = generate_encode_closure(err);
494                        // Wire format: [0] + T for success, [1, 0] + E for User error
495                        cw_writeln!(
496                            w,
497                            "taskSender(.response(requestId: requestId, payload: {{ switch result {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1), UInt8(0)] + {err_encode}(e) }} }}()))"
498                        )
499                        .unwrap();
500                    } else {
501                        let encode_closure = generate_encode_closure(method.return_type);
502                        cw_writeln!(
503                            w,
504                            "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
505                        )
506                        .unwrap();
507                    }
508                }
509            }
510        }
511        w.writeln("} catch {").unwrap();
512        {
513            let _indent = w.indent();
514            w.writeln(
515                "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError()))",
516            )
517            .unwrap();
518        }
519        w.writeln("}").unwrap();
520    }
521    w.writeln("}").unwrap();
522}
523
524/// Generate code to decode a single argument for streaming dispatch.
525fn generate_streaming_decode_arg(
526    w: &mut CodeWriter<&mut String>,
527    name: &str,
528    shape: &'static Shape,
529) {
530    match classify_shape(shape) {
531        ShapeKind::Rx { inner } => {
532            // Schema Rx = client passes Rx to method, sends via paired Tx
533            // Server needs to receive → create server Rx
534            let inline_decode = generate_inline_decode(inner, "Data(bytes)", "off");
535            cw_writeln!(
536                w,
537                "let {name}ChannelId = try decodeVarint(from: payload, offset: &offset)"
538            )
539            .unwrap();
540            cw_writeln!(
541                w,
542                "let {name}Receiver = await registry.register({name}ChannelId)"
543            )
544            .unwrap();
545            cw_writeln!(
546                w,
547                "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {{ bytes in"
548            )
549            .unwrap();
550            cw_writeln!(w, "    var off = 0").unwrap();
551            cw_writeln!(w, "    return try {inline_decode}").unwrap();
552            w.writeln("})").unwrap();
553        }
554        ShapeKind::Tx { inner } => {
555            // Schema Tx = client passes Tx to method, receives via paired Rx
556            // Server needs to send → create server Tx
557            let encode_closure = generate_encode_closure(inner);
558            cw_writeln!(
559                w,
560                "let {name}ChannelId = try decodeVarint(from: payload, offset: &offset)"
561            )
562            .unwrap();
563            cw_writeln!(
564                w,
565                "let {name} = createServerTx(channelId: {name}ChannelId, taskSender: taskSender, serialize: ({encode_closure}))"
566            )
567            .unwrap();
568        }
569        _ => {
570            // Non-streaming argument - use standard decode
571            let decode_stmt = generate_decode_stmt(shape, name, "");
572            for line in decode_stmt.lines() {
573                w.writeln(line).unwrap();
574            }
575        }
576    }
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use facet::Facet;
    use roam_schema::{ArgDetail, MethodDetail, ServiceDetail};
    use std::borrow::Cow;

    /// A documented one-method `Echo` service: `echo(message: String) -> String`.
    fn sample_service() -> ServiceDetail {
        let echo = MethodDetail {
            service_name: Cow::Borrowed("Echo"),
            method_name: Cow::Borrowed("echo"),
            args: vec![ArgDetail {
                name: Cow::Borrowed("message"),
                ty: <String as Facet>::SHAPE,
            }],
            return_type: <String as Facet>::SHAPE,
            doc: Some(Cow::Borrowed("Echo back the message")),
        };
        ServiceDetail {
            name: Cow::Borrowed("Echo"),
            doc: Some(Cow::Borrowed("Simple echo service")),
            methods: vec![echo],
        }
    }

    /// Assert that `code` contains every snippet in `needles`, showing the code on failure.
    fn assert_contains_all(code: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                code.contains(needle),
                "expected generated code to contain {needle:?}:\n{code}"
            );
        }
    }

    #[test]
    fn test_generate_handler_protocol() {
        let code = generate_handler_protocol(&sample_service());
        assert_contains_all(
            &code,
            &[
                "protocol EchoHandler",
                "func echo(message: String)",
                "async throws -> String",
            ],
        );
    }

    #[test]
    fn test_generate_dispatcher() {
        let code = generate_dispatcher(&sample_service());
        assert_contains_all(
            &code,
            &[
                "class EchoDispatcher",
                "EchoHandler",
                "dispatch(methodId:",
                "dispatchecho",
            ],
        );
    }

    #[test]
    fn test_generate_streaming_dispatcher() {
        let code = generate_streaming_dispatcher(&sample_service());
        assert_contains_all(
            &code,
            &[
                "class EchoStreamingDispatcher",
                "preregisterChannels",
                "IncomingChannelRegistry",
            ],
        );
    }
}