roam_codegen/targets/swift/
client.rs1use heck::{ToLowerCamelCase, ToUpperCamelCase};
6use roam_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
7
8use super::decode::generate_decode_stmt_from_with_cursor;
9use super::encode::generate_encode_expr;
10use super::types::{format_doc, is_channel, swift_type_client_arg, swift_type_client_return};
11use crate::code_writer::CodeWriter;
12use crate::cw_writeln;
13use crate::render::hex_u64;
14
15pub fn generate_client(service: &ServiceDescriptor) -> String {
17 let mut out = String::new();
18 out.push_str(&generate_caller_protocol(service));
19 out.push_str(&generate_client_impl(service));
20 out
21}
22
23fn generate_caller_protocol(service: &ServiceDescriptor) -> String {
25 let mut out = String::new();
26 let service_name = service.service_name.to_upper_camel_case();
27
28 if let Some(doc) = &service.doc {
29 out.push_str(&format_doc(doc, ""));
30 }
31 out.push_str(&format!("public protocol {service_name}Caller {{\n"));
32
33 for method in service.methods {
34 let method_name = method.method_name.to_lower_camel_case();
35
36 if let Some(doc) = &method.doc {
37 out.push_str(&format_doc(doc, " "));
38 }
39
40 let args: Vec<String> = method
41 .args
42 .iter()
43 .map(|a| {
44 format!(
45 "{}: {}",
46 a.name.to_lower_camel_case(),
47 swift_type_client_arg(a.shape)
48 )
49 })
50 .collect();
51
52 let ret_type = swift_type_client_return(method.return_shape);
53
54 if ret_type == "Void" {
55 out.push_str(&format!(
56 " func {method_name}({}) async throws\n",
57 args.join(", ")
58 ));
59 } else {
60 out.push_str(&format!(
61 " func {method_name}({}) async throws -> {ret_type}\n",
62 args.join(", ")
63 ));
64 }
65 }
66
67 out.push_str("}\n\n");
68 out
69}
70
71fn generate_client_impl(service: &ServiceDescriptor) -> String {
73 let mut out = String::new();
74 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
75 let service_name = service.service_name.to_upper_camel_case();
76
77 w.writeln(&format!(
78 "public final class {service_name}Client: {service_name}Caller, Sendable {{"
79 ))
80 .unwrap();
81 {
82 let _indent = w.indent();
83 w.writeln("private let connection: RoamConnection").unwrap();
84 w.writeln("private let timeout: TimeInterval?").unwrap();
85 w.blank_line().unwrap();
86 w.writeln("public init(connection: RoamConnection, timeout: TimeInterval? = 30.0) {")
87 .unwrap();
88 {
89 let _indent = w.indent();
90 w.writeln("self.connection = connection").unwrap();
91 w.writeln("self.timeout = timeout").unwrap();
92 }
93 w.writeln("}").unwrap();
94
95 for method in service.methods {
96 w.blank_line().unwrap();
97 generate_client_method(&mut w, method, &service_name);
98 }
99 }
100 w.writeln("}").unwrap();
101 w.blank_line().unwrap();
102
103 out
104}
105
106fn generate_client_method(
108 w: &mut CodeWriter<&mut String>,
109 method: &MethodDescriptor,
110 service_name: &str,
111) {
112 let method_name = method.method_name.to_lower_camel_case();
113 let method_id_name = method.method_name.to_lower_camel_case();
114
115 let args: Vec<String> = method
116 .args
117 .iter()
118 .map(|a| {
119 format!(
120 "{}: {}",
121 a.name.to_lower_camel_case(),
122 swift_type_client_arg(a.shape)
123 )
124 })
125 .collect();
126
127 let ret_type = swift_type_client_return(method.return_shape);
128 let has_streaming = method.args.iter().any(|a| is_channel(a.shape));
129
130 if ret_type == "Void" {
132 cw_writeln!(
133 w,
134 "public func {method_name}({}) async throws {{",
135 args.join(", ")
136 )
137 .unwrap();
138 } else {
139 cw_writeln!(
140 w,
141 "public func {method_name}({}) async throws -> {ret_type} {{",
142 args.join(", ")
143 )
144 .unwrap();
145 }
146
147 {
148 let _indent = w.indent();
149 let cursor_var = unique_decode_cursor_name(method.args);
150
151 if has_streaming {
152 generate_streaming_client_body(w, method, service_name, &method_id_name, &cursor_var);
153 } else {
154 generate_encode_args(w, method.args);
156
157 let method_id = crate::method_id(method);
159 cw_writeln!(
160 w,
161 "let response = try await connection.call(methodId: {}, payload: payload, timeout: timeout)",
162 hex_u64(method_id)
163 )
164 .unwrap();
165 generate_response_decode(w, method, &cursor_var, "response");
166 }
167 }
168 w.writeln("}").unwrap();
169}
170
171fn generate_encode_args(w: &mut CodeWriter<&mut String>, args: &[roam_types::ArgDescriptor]) {
173 if args.is_empty() {
174 w.writeln("let payload = Data()").unwrap();
175 return;
176 }
177
178 w.writeln("var payloadBytes: [UInt8] = []").unwrap();
179 for arg in args {
180 let arg_name = arg.name.to_lower_camel_case();
181 let encode_expr = generate_encode_expr(arg.shape, &arg_name);
182 cw_writeln!(w, "payloadBytes += {encode_expr}").unwrap();
183 }
184 w.writeln("let payload = Data(payloadBytes)").unwrap();
185}
186
187fn generate_streaming_client_body(
189 w: &mut CodeWriter<&mut String>,
190 method: &MethodDescriptor,
191 service_name: &str,
192 method_id_name: &str,
193 cursor_var: &str,
194) {
195 let service_name_lower = service_name.to_lower_camel_case();
196
197 let arg_names: Vec<String> = method
199 .args
200 .iter()
201 .map(|a| a.name.to_lower_camel_case())
202 .collect();
203
204 w.writeln("// Bind channels using schema").unwrap();
205 w.writeln("await bindChannels(").unwrap();
206 {
207 let _indent = w.indent();
208 cw_writeln!(
209 w,
210 "schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args,"
211 )
212 .unwrap();
213 cw_writeln!(w, "args: [{}],", arg_names.join(", ")).unwrap();
214 w.writeln("allocator: connection.channelAllocator,")
215 .unwrap();
216 w.writeln("incomingRegistry: connection.incomingChannelRegistry,")
217 .unwrap();
218 w.writeln("taskSender: connection.taskSender,").unwrap();
219 cw_writeln!(w, "serializers: {service_name}Serializers()").unwrap();
220 }
221 w.writeln(")").unwrap();
222 w.blank_line().unwrap();
223
224 w.writeln("// Encode payload with channel IDs").unwrap();
228 w.writeln("var payloadBytes: [UInt8] = []").unwrap();
229 for arg in method.args {
230 let arg_name = arg.name.to_lower_camel_case();
231 if is_tx(arg.shape) || is_rx(arg.shape) {
232 cw_writeln!(w, "payloadBytes += encodeVarint({arg_name}.channelId)").unwrap();
233 } else {
234 let encode_expr = generate_encode_expr(arg.shape, &arg_name);
235 cw_writeln!(w, "payloadBytes += {encode_expr}").unwrap();
236 }
237 }
238 w.writeln("let payload = Data(payloadBytes)").unwrap();
239 cw_writeln!(
240 w,
241 "let channels = collectChannelIds(schemas: {service_name_lower}_schemas[\"{method_id_name}\"]!.args, args: [{}])",
242 arg_names.join(", ")
243 )
244 .unwrap();
245 w.blank_line().unwrap();
246
247 let ret_type = swift_type_client_return(method.return_shape);
249 let method_id = crate::method_id(method);
250 let _ = ret_type;
251 cw_writeln!(
252 w,
253 "let response = try await connection.call(methodId: {}, payload: payload, channels: channels, timeout: timeout)",
254 hex_u64(method_id)
255 )
256 .unwrap();
257 generate_response_decode(w, method, cursor_var, "response");
258}
259
260fn unique_decode_cursor_name(args: &[roam_types::ArgDescriptor]) -> String {
261 let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
262 let mut candidate = String::from("cursor");
263 while arg_names.iter().any(|name| name == &candidate) {
264 candidate.push('_');
265 }
266 candidate
267}
268
269fn generate_response_decode(
272 w: &mut CodeWriter<&mut String>,
273 method: &MethodDescriptor,
274 cursor_var: &str,
275 response_var: &str,
276) {
277 let ret_type = swift_type_client_return(method.return_shape);
278 let result_disc_var = format!("_{cursor_var}_resultDisc");
279 let error_code_var = format!("_{cursor_var}_errorCode");
280 let is_fallible = matches!(
281 classify_shape(method.return_shape),
282 ShapeKind::Result { .. }
283 );
284
285 cw_writeln!(w, "var {cursor_var} = 0").unwrap();
286 cw_writeln!(
287 w,
288 "let {result_disc_var} = try decodeVarint(from: {response_var}, offset: &{cursor_var})"
289 )
290 .unwrap();
291 cw_writeln!(w, "switch {result_disc_var} {{").unwrap();
292
293 w.writeln("case 0:").unwrap();
294 {
295 let _indent = w.indent();
296 if is_fallible {
297 let ShapeKind::Result { ok, .. } = classify_shape(method.return_shape) else {
298 unreachable!()
299 };
300 let decode_ok =
301 generate_decode_stmt_from_with_cursor(ok, "value", "", response_var, cursor_var);
302 for line in decode_ok.lines() {
303 w.writeln(line).unwrap();
304 }
305 w.writeln("return .success(value)").unwrap();
306 } else if ret_type == "Void" {
307 w.writeln("return").unwrap();
308 } else {
309 let decode_stmt = generate_decode_stmt_from_with_cursor(
310 method.return_shape,
311 "result",
312 "",
313 response_var,
314 cursor_var,
315 );
316 for line in decode_stmt.lines() {
317 w.writeln(line).unwrap();
318 }
319 w.writeln("return result").unwrap();
320 }
321 }
322
323 w.writeln("case 1:").unwrap();
324 {
325 let _indent = w.indent();
326 cw_writeln!(
327 w,
328 "let {error_code_var} = try decodeU8(from: {response_var}, offset: &{cursor_var})"
329 )
330 .unwrap();
331 cw_writeln!(w, "switch {error_code_var} {{").unwrap();
332
333 w.writeln("case 0:").unwrap();
334 {
335 let _indent = w.indent();
336 if is_fallible {
337 let ShapeKind::Result { err, .. } = classify_shape(method.return_shape) else {
338 unreachable!()
339 };
340 let decode_err = generate_decode_stmt_from_with_cursor(
341 err,
342 "userError",
343 "",
344 response_var,
345 cursor_var,
346 );
347 for line in decode_err.lines() {
348 w.writeln(line).unwrap();
349 }
350 w.writeln("return .failure(userError)").unwrap();
351 } else {
352 w.writeln(
353 "throw RoamError.decodeError(\"unexpected user error for infallible method\")",
354 )
355 .unwrap();
356 }
357 }
358 w.writeln("case 1:").unwrap();
359 w.writeln(" throw RoamError.unknownMethod").unwrap();
360 w.writeln("case 2:").unwrap();
361 w.writeln(" throw RoamError.decodeError(\"invalid payload\")")
362 .unwrap();
363 w.writeln("case 3:").unwrap();
364 w.writeln(" throw RoamError.cancelled").unwrap();
365 w.writeln("default:").unwrap();
366 cw_writeln!(
367 w,
368 " throw RoamError.decodeError(\"invalid RoamError discriminant: \\({error_code_var})\")"
369 )
370 .unwrap();
371 w.writeln("}").unwrap();
372 }
373
374 w.writeln("default:").unwrap();
375 cw_writeln!(
376 w,
377 " throw RoamError.decodeError(\"invalid Result discriminant: \\({result_disc_var})\")"
378 )
379 .unwrap();
380 w.writeln("}").unwrap();
381}