Skip to main content

roam_codegen/targets/swift/
client.rs

1//! Swift client generation.
2//!
3//! Generates caller protocol and client implementation for making RPC calls.
4
5use heck::{ToLowerCamelCase, ToUpperCamelCase};
6use roam_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
7
8use super::decode::generate_decode_stmt_from_with_cursor;
9use super::encode::generate_encode_expr;
10use super::types::{format_doc, is_channel, swift_type_client_arg, swift_type_client_return};
11use crate::code_writer::CodeWriter;
12use crate::cw_writeln;
13use crate::render::hex_u64;
14
15/// Generate complete client code (caller protocol + client implementation).
16pub fn generate_client(service: &ServiceDescriptor) -> String {
17    let mut out = String::new();
18    out.push_str(&generate_caller_protocol(service));
19    out.push_str(&generate_client_impl(service));
20    out
21}
22
23/// Generate caller protocol (for making calls to the service).
24fn generate_caller_protocol(service: &ServiceDescriptor) -> String {
25    let mut out = String::new();
26    let service_name = service.service_name.to_upper_camel_case();
27
28    if let Some(doc) = &service.doc {
29        out.push_str(&format_doc(doc, ""));
30    }
31    out.push_str(&format!("public protocol {service_name}Caller {{\n"));
32
33    for method in service.methods {
34        let method_name = method.method_name.to_lower_camel_case();
35
36        if let Some(doc) = &method.doc {
37            out.push_str(&format_doc(doc, "    "));
38        }
39
40        let args: Vec<String> = method
41            .args
42            .iter()
43            .map(|a| {
44                format!(
45                    "{}: {}",
46                    a.name.to_lower_camel_case(),
47                    swift_type_client_arg(a.shape)
48                )
49            })
50            .collect();
51
52        let ret_type = swift_type_client_return(method.return_shape);
53
54        if ret_type == "Void" {
55            out.push_str(&format!(
56                "    func {method_name}({}) async throws\n",
57                args.join(", ")
58            ));
59        } else {
60            out.push_str(&format!(
61                "    func {method_name}({}) async throws -> {ret_type}\n",
62                args.join(", ")
63            ));
64        }
65    }
66
67    out.push_str("}\n\n");
68    out
69}
70
71/// Generate client implementation (for making calls to the service).
72fn generate_client_impl(service: &ServiceDescriptor) -> String {
73    let mut out = String::new();
74    let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
75    let service_name = service.service_name.to_upper_camel_case();
76
77    w.writeln(&format!(
78        "public final class {service_name}Client: {service_name}Caller, Sendable {{"
79    ))
80    .unwrap();
81    {
82        let _indent = w.indent();
83        w.writeln("private let connection: RoamConnection").unwrap();
84        w.writeln("private let timeout: TimeInterval?").unwrap();
85        w.blank_line().unwrap();
86        w.writeln("public init(connection: RoamConnection, timeout: TimeInterval? = 30.0) {")
87            .unwrap();
88        {
89            let _indent = w.indent();
90            w.writeln("self.connection = connection").unwrap();
91            w.writeln("self.timeout = timeout").unwrap();
92        }
93        w.writeln("}").unwrap();
94
95        for method in service.methods {
96            w.blank_line().unwrap();
97            generate_client_method(&mut w, method, &service_name);
98        }
99    }
100    w.writeln("}").unwrap();
101    w.blank_line().unwrap();
102
103    out
104}
105
106/// Generate a single client method implementation.
107fn generate_client_method(
108    w: &mut CodeWriter<&mut String>,
109    method: &MethodDescriptor,
110    service_name: &str,
111) {
112    let method_name = method.method_name.to_lower_camel_case();
113    let method_id_name = method.method_name.to_lower_camel_case();
114
115    let args: Vec<String> = method
116        .args
117        .iter()
118        .map(|a| {
119            format!(
120                "{}: {}",
121                a.name.to_lower_camel_case(),
122                swift_type_client_arg(a.shape)
123            )
124        })
125        .collect();
126
127    let ret_type = swift_type_client_return(method.return_shape);
128    let has_streaming = method.args.iter().any(|a| is_channel(a.shape));
129
130    // Method signature
131    if ret_type == "Void" {
132        cw_writeln!(
133            w,
134            "public func {method_name}({}) async throws {{",
135            args.join(", ")
136        )
137        .unwrap();
138    } else {
139        cw_writeln!(
140            w,
141            "public func {method_name}({}) async throws -> {ret_type} {{",
142            args.join(", ")
143        )
144        .unwrap();
145    }
146
147    {
148        let _indent = w.indent();
149        let cursor_var = unique_decode_cursor_name(method.args);
150
151        if has_streaming {
152            generate_streaming_client_body(w, method, service_name, &method_id_name, &cursor_var);
153        } else {
154            // Encode arguments
155            generate_encode_args(w, method.args);
156
157            // Make call
158            let method_id = crate::method_id(method);
159            cw_writeln!(
160                w,
161                "let response = try await connection.call(methodId: {}, payload: payload, timeout: timeout)",
162                hex_u64(method_id)
163            )
164            .unwrap();
165            generate_response_decode(w, method, &cursor_var, "response");
166        }
167    }
168    w.writeln("}").unwrap();
169}
170
171/// Generate code to encode method arguments (for client).
172fn generate_encode_args(w: &mut CodeWriter<&mut String>, args: &[roam_types::ArgDescriptor]) {
173    if args.is_empty() {
174        w.writeln("let payload = Data()").unwrap();
175        return;
176    }
177
178    w.writeln("var payloadBytes: [UInt8] = []").unwrap();
179    for arg in args {
180        let arg_name = arg.name.to_lower_camel_case();
181        let encode_expr = generate_encode_expr(arg.shape, &arg_name);
182        cw_writeln!(w, "payloadBytes += {encode_expr}").unwrap();
183    }
184    w.writeln("let payload = Data(payloadBytes)").unwrap();
185}
186
187/// Generate client body for channeled methods.
188fn generate_streaming_client_body(
189    w: &mut CodeWriter<&mut String>,
190    method: &MethodDescriptor,
191    service_name: &str,
192    method_id_name: &str,
193    cursor_var: &str,
194) {
195    let service_name_lower = service_name.to_lower_camel_case();
196
197    // Bind channels
198    let arg_names: Vec<String> = method
199        .args
200        .iter()
201        .map(|a| a.name.to_lower_camel_case())
202        .collect();
203
204    w.writeln("// Bind channels using schema").unwrap();
205    w.writeln("await bindChannels(").unwrap();
206    {
207        let _indent = w.indent();
208        cw_writeln!(
209            w,
210            "schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args,"
211        )
212        .unwrap();
213        cw_writeln!(w, "args: [{}],", arg_names.join(", ")).unwrap();
214        w.writeln("allocator: connection.channelAllocator,")
215            .unwrap();
216        w.writeln("incomingRegistry: connection.incomingChannelRegistry,")
217            .unwrap();
218        w.writeln("taskSender: connection.taskSender,").unwrap();
219        cw_writeln!(w, "serializers: {service_name}Serializers()").unwrap();
220    }
221    w.writeln(")").unwrap();
222    w.blank_line().unwrap();
223
224    // Encode payload as the full argument tuple.
225    // Channel IDs are still included here for payload shape fidelity, and also sent
226    // in Request.channels for schema-driven discovery.
227    w.writeln("// Encode payload with channel IDs").unwrap();
228    w.writeln("var payloadBytes: [UInt8] = []").unwrap();
229    for arg in method.args {
230        let arg_name = arg.name.to_lower_camel_case();
231        if is_tx(arg.shape) || is_rx(arg.shape) {
232            cw_writeln!(w, "payloadBytes += encodeVarint({arg_name}.channelId)").unwrap();
233        } else {
234            let encode_expr = generate_encode_expr(arg.shape, &arg_name);
235            cw_writeln!(w, "payloadBytes += {encode_expr}").unwrap();
236        }
237    }
238    w.writeln("let payload = Data(payloadBytes)").unwrap();
239    cw_writeln!(
240        w,
241        "let channels = collectChannelIds(schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args, args: [{}])",
242        arg_names.join(", ")
243    )
244    .unwrap();
245    w.blank_line().unwrap();
246
247    // Make the call
248    let ret_type = swift_type_client_return(method.return_shape);
249    let method_id = crate::method_id(method);
250    let _ = ret_type;
251    cw_writeln!(
252        w,
253        "let response = try await connection.call(methodId: {}, payload: payload, channels: channels, timeout: timeout)",
254        hex_u64(method_id)
255    )
256    .unwrap();
257    generate_response_decode(w, method, cursor_var, "response");
258}
259
260fn unique_decode_cursor_name(args: &[roam_types::ArgDescriptor]) -> String {
261    let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
262    let mut candidate = String::from("cursor");
263    while arg_names.iter().any(|name| name == &candidate) {
264        candidate.push('_');
265    }
266    candidate
267}
268
269/// Generate code to decode the full wire response payload:
270/// `Result<T, RoamError<E>>`.
271fn generate_response_decode(
272    w: &mut CodeWriter<&mut String>,
273    method: &MethodDescriptor,
274    cursor_var: &str,
275    response_var: &str,
276) {
277    let ret_type = swift_type_client_return(method.return_shape);
278    let result_disc_var = format!("_{cursor_var}_resultDisc");
279    let error_code_var = format!("_{cursor_var}_errorCode");
280    let is_fallible = matches!(
281        classify_shape(method.return_shape),
282        ShapeKind::Result { .. }
283    );
284
285    cw_writeln!(w, "var {cursor_var} = 0").unwrap();
286    cw_writeln!(
287        w,
288        "let {result_disc_var} = try decodeVarint(from: {response_var}, offset: &{cursor_var})"
289    )
290    .unwrap();
291    cw_writeln!(w, "switch {result_disc_var} {{").unwrap();
292
293    w.writeln("case 0:").unwrap();
294    {
295        let _indent = w.indent();
296        if is_fallible {
297            let ShapeKind::Result { ok, .. } = classify_shape(method.return_shape) else {
298                unreachable!()
299            };
300            let decode_ok =
301                generate_decode_stmt_from_with_cursor(ok, "value", "", response_var, cursor_var);
302            for line in decode_ok.lines() {
303                w.writeln(line).unwrap();
304            }
305            w.writeln("return .success(value)").unwrap();
306        } else if ret_type == "Void" {
307            w.writeln("return").unwrap();
308        } else {
309            let decode_stmt = generate_decode_stmt_from_with_cursor(
310                method.return_shape,
311                "result",
312                "",
313                response_var,
314                cursor_var,
315            );
316            for line in decode_stmt.lines() {
317                w.writeln(line).unwrap();
318            }
319            w.writeln("return result").unwrap();
320        }
321    }
322
323    w.writeln("case 1:").unwrap();
324    {
325        let _indent = w.indent();
326        cw_writeln!(
327            w,
328            "let {error_code_var} = try decodeU8(from: {response_var}, offset: &{cursor_var})"
329        )
330        .unwrap();
331        cw_writeln!(w, "switch {error_code_var} {{").unwrap();
332
333        w.writeln("case 0:").unwrap();
334        {
335            let _indent = w.indent();
336            if is_fallible {
337                let ShapeKind::Result { err, .. } = classify_shape(method.return_shape) else {
338                    unreachable!()
339                };
340                let decode_err = generate_decode_stmt_from_with_cursor(
341                    err,
342                    "userError",
343                    "",
344                    response_var,
345                    cursor_var,
346                );
347                for line in decode_err.lines() {
348                    w.writeln(line).unwrap();
349                }
350                w.writeln("return .failure(userError)").unwrap();
351            } else {
352                w.writeln(
353                    "throw RoamError.decodeError(\"unexpected user error for infallible method\")",
354                )
355                .unwrap();
356            }
357        }
358        w.writeln("case 1:").unwrap();
359        w.writeln("    throw RoamError.unknownMethod").unwrap();
360        w.writeln("case 2:").unwrap();
361        w.writeln("    throw RoamError.decodeError(\"invalid payload\")")
362            .unwrap();
363        w.writeln("case 3:").unwrap();
364        w.writeln("    throw RoamError.cancelled").unwrap();
365        w.writeln("default:").unwrap();
366        cw_writeln!(
367            w,
368            "    throw RoamError.decodeError(\"invalid RoamError discriminant: \\({error_code_var})\")"
369        )
370        .unwrap();
371        w.writeln("}").unwrap();
372    }
373
374    w.writeln("default:").unwrap();
375    cw_writeln!(
376        w,
377        "    throw RoamError.decodeError(\"invalid Result discriminant: \\({result_disc_var})\")"
378    )
379    .unwrap();
380    w.writeln("}").unwrap();
381}