roam_codegen/targets/swift/
server.rs1use facet_core::Shape;
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use roam_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
8
9use super::decode::{generate_decode_stmt_with_cursor, generate_inline_decode};
10use super::encode::generate_encode_closure;
11use super::types::{format_doc, is_channel, swift_type_server_arg, swift_type_server_return};
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14use crate::render::hex_u64;
15
/// Returns the name of the private Swift helper that handles one method,
/// e.g. `getUser` becomes `dispatch_getUser`.
fn dispatch_helper_name(method_name: &str) -> String {
    ["dispatch_", method_name].concat()
}
19
20pub fn generate_server(service: &ServiceDescriptor) -> String {
22 let mut out = String::new();
23 out.push_str(&generate_handler_protocol(service));
24 out.push_str(&generate_channeling_dispatcher(service));
26 out
27}
28
29fn generate_handler_protocol(service: &ServiceDescriptor) -> String {
31 let mut out = String::new();
32 let service_name = service.service_name.to_upper_camel_case();
33
34 if let Some(doc) = &service.doc {
35 out.push_str(&format_doc(doc, ""));
36 }
37 out.push_str(&format!("public protocol {service_name}Handler {{\n"));
38
39 for method in service.methods {
40 let method_name = method.method_name.to_lower_camel_case();
41
42 if let Some(doc) = &method.doc {
43 out.push_str(&format_doc(doc, " "));
44 }
45
46 let args: Vec<String> = method
48 .args
49 .iter()
50 .map(|a| {
51 format!(
52 "{}: {}",
53 a.name.to_lower_camel_case(),
54 swift_type_server_arg(a.shape)
55 )
56 })
57 .collect();
58
59 let ret_type = swift_type_server_return(method.return_shape);
60
61 if ret_type == "Void" {
62 out.push_str(&format!(
63 " func {method_name}({}) async throws\n",
64 args.join(", ")
65 ));
66 } else {
67 out.push_str(&format!(
68 " func {method_name}({}) async throws -> {ret_type}\n",
69 args.join(", ")
70 ));
71 }
72 }
73
74 out.push_str("}\n\n");
75 out
76}
77
78fn generate_channeling_dispatcher(service: &ServiceDescriptor) -> String {
80 let mut out = String::new();
81 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
82 let service_name = service.service_name.to_upper_camel_case();
83
84 cw_writeln!(
85 w,
86 "public final class {service_name}ChannelingDispatcher {{"
87 )
88 .unwrap();
89 {
90 let _indent = w.indent();
91 cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
92 w.writeln("private let registry: IncomingChannelRegistry")
93 .unwrap();
94 w.writeln("private let taskSender: TaskSender").unwrap();
95 w.blank_line().unwrap();
96
97 cw_writeln!(
98 w,
99 "public init(handler: {service_name}Handler, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender) {{"
100 )
101 .unwrap();
102 {
103 let _indent = w.indent();
104 w.writeln("self.handler = handler").unwrap();
105 w.writeln("self.registry = registry").unwrap();
106 w.writeln("self.taskSender = taskSender").unwrap();
107 }
108 w.writeln("}").unwrap();
109 w.blank_line().unwrap();
110
111 w.writeln(
113 "public func dispatch(methodId: UInt64, requestId: UInt64, channels: [UInt64], payload: Data) async {",
114 )
115 .unwrap();
116 {
117 let _indent = w.indent();
118 w.writeln("switch methodId {").unwrap();
119 for method in service.methods {
120 let method_name = method.method_name.to_lower_camel_case();
121 let method_id = crate::method_id(method);
122 let dispatch_name = dispatch_helper_name(&method_name);
123 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
124 cw_writeln!(
125 w,
126 " await {dispatch_name}(requestId: requestId, channels: channels, payload: payload)"
127 )
128 .unwrap();
129 }
130 w.writeln("default:").unwrap();
131 w.writeln(
132 " taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
133 )
134 .unwrap();
135 w.writeln("}").unwrap();
136 }
137 w.writeln("}").unwrap();
138 w.blank_line().unwrap();
139
140 generate_preregister_channels(&mut w, service);
142 w.blank_line().unwrap();
143
144 for method in service.methods {
146 generate_channeling_dispatch_method(&mut w, method);
147 w.blank_line().unwrap();
148 }
149 }
150 w.writeln("}").unwrap();
151 w.blank_line().unwrap();
152
153 out
154}
155
156fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDescriptor) {
158 w.writeln("/// Pre-register Rx channel IDs from request channels.")
159 .unwrap();
160 w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
161 .unwrap();
162 w.writeln("/// race conditions where Data arrives before channels are registered.")
163 .unwrap();
164 w.writeln("public static func preregisterChannels(methodId: UInt64, channels: [UInt64], registry: ChannelRegistry) async {")
165 .unwrap();
166 {
167 let _indent = w.indent();
168 w.writeln("switch methodId {").unwrap();
169
170 for method in service.methods {
171 let method_id = crate::method_id(method);
172 let has_rx_args = method.args.iter().any(|a| is_rx(a.shape));
173
174 if has_rx_args {
175 let channel_arg_count = method
176 .args
177 .iter()
178 .filter(|a| is_rx(a.shape) || is_tx(a.shape))
179 .count();
180 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
181 cw_writeln!(w, " guard channels.count >= {channel_arg_count} else {{").unwrap();
182 w.writeln(" return").unwrap();
183 w.writeln(" }").unwrap();
184 w.writeln(" var channelCursor = 0").unwrap();
185
186 for arg in method.args {
188 let arg_name = arg.name.to_lower_camel_case();
189 if is_rx(arg.shape) {
190 cw_writeln!(w, " let {arg_name}ChannelId = channels[channelCursor]")
192 .unwrap();
193 w.writeln(" channelCursor += 1").unwrap();
194 cw_writeln!(w, " await registry.markKnown({arg_name}ChannelId)")
195 .unwrap();
196 } else if is_tx(arg.shape) {
197 cw_writeln!(w, " _ = channels[channelCursor] // {arg_name}").unwrap();
198 w.writeln(" channelCursor += 1").unwrap();
199 }
200 }
201 }
202 }
203
204 w.writeln("default:").unwrap();
205 w.writeln(" break").unwrap();
206 w.writeln("}").unwrap();
207 }
208 w.writeln("}").unwrap();
209}
210
211fn generate_channeling_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDescriptor) {
213 let method_name = method.method_name.to_lower_camel_case();
214 let dispatch_name = dispatch_helper_name(&method_name);
215 let has_channeling = method.args.iter().any(|a| is_channel(a.shape));
216
217 cw_writeln!(
218 w,
219 "private func {dispatch_name}(requestId: UInt64, channels: [UInt64], payload: Data) async {{"
220 )
221 .unwrap();
222 {
223 let _indent = w.indent();
224 w.writeln("do {").unwrap();
225 {
226 let _indent = w.indent();
227 let has_payload_args = method
228 .args
229 .iter()
230 .any(|a| !is_rx(a.shape) && !is_tx(a.shape));
231 let has_channel_args = method.args.iter().any(|a| is_rx(a.shape) || is_tx(a.shape));
232 let cursor_var = if has_payload_args {
233 let name = unique_decode_cursor_name(method.args);
234 cw_writeln!(w, "var {name} = 0").unwrap();
235 Some(name)
236 } else {
237 None
238 };
239 if has_channel_args {
240 w.writeln("var channelCursor = 0").unwrap();
241 }
242
243 for arg in method.args {
245 let arg_name = arg.name.to_lower_camel_case();
246 generate_channeling_decode_arg(
247 w,
248 &arg_name,
249 arg.shape,
250 cursor_var.as_deref(),
251 "channels",
252 Some("channelCursor"),
253 );
254 }
255
256 let arg_names: Vec<String> = method
258 .args
259 .iter()
260 .map(|a| {
261 let name = a.name.to_lower_camel_case();
262 format!("{name}: {name}")
263 })
264 .collect();
265
266 let ret_type = swift_type_server_return(method.return_shape);
267
268 if has_channeling {
269 if ret_type == "Void" {
271 cw_writeln!(
272 w,
273 "try await handler.{method_name}({})",
274 arg_names.join(", ")
275 )
276 .unwrap();
277 } else {
278 cw_writeln!(
279 w,
280 "let result = try await handler.{method_name}({})",
281 arg_names.join(", ")
282 )
283 .unwrap();
284 }
285
286 for arg in method.args {
288 if is_tx(arg.shape) {
289 let arg_name = arg.name.to_lower_camel_case();
290 cw_writeln!(w, "{arg_name}.close()").unwrap();
291 }
292 }
293
294 if ret_type == "Void" {
296 w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
297 } else {
298 let encode_closure = generate_encode_closure(method.return_shape);
299 cw_writeln!(
300 w,
301 "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
302 )
303 .unwrap();
304 }
305 } else {
306 if ret_type == "Void" {
308 cw_writeln!(
309 w,
310 "try await handler.{method_name}({})",
311 arg_names.join(", ")
312 )
313 .unwrap();
314 w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
315 } else {
316 cw_writeln!(
317 w,
318 "let result = try await handler.{method_name}({})",
319 arg_names.join(", ")
320 )
321 .unwrap();
322 if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
324 let ok_encode = generate_encode_closure(ok);
325 let err_encode = generate_encode_closure(err);
326 cw_writeln!(
328 w,
329 "taskSender(.response(requestId: requestId, payload: {{ switch result {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1), UInt8(0)] + {err_encode}(e) }} }}()))"
330 )
331 .unwrap();
332 } else {
333 let encode_closure = generate_encode_closure(method.return_shape);
334 cw_writeln!(
335 w,
336 "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
337 )
338 .unwrap();
339 }
340 }
341 }
342 }
343 w.writeln("} catch {").unwrap();
344 {
345 let _indent = w.indent();
346 w.writeln(
347 "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError()))",
348 )
349 .unwrap();
350 }
351 w.writeln("}").unwrap();
352 }
353 w.writeln("}").unwrap();
354}
355
356fn generate_channeling_decode_arg(
358 w: &mut CodeWriter<&mut String>,
359 name: &str,
360 shape: &'static Shape,
361 cursor_var: Option<&str>,
362 channels_var: &str,
363 channel_cursor_var: Option<&str>,
364) {
365 match classify_shape(shape) {
366 ShapeKind::Rx { inner } => {
367 let inline_decode = generate_inline_decode(inner, "Data(bytes)", "off");
370 let channel_cursor_var =
371 channel_cursor_var.expect("channel cursor required for channeling args");
372 cw_writeln!(
373 w,
374 "guard {channel_cursor_var} < {channels_var}.count else {{ throw RoamError.decodeError(\"missing channel id for {name}\") }}"
375 )
376 .unwrap();
377 cw_writeln!(
378 w,
379 "let {name}ChannelId = {channels_var}[{channel_cursor_var}]"
380 )
381 .unwrap();
382 cw_writeln!(w, "{channel_cursor_var} += 1").unwrap();
383 cw_writeln!(
384 w,
385 "let {name}Receiver = await registry.register({name}ChannelId)"
386 )
387 .unwrap();
388 cw_writeln!(
389 w,
390 "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {{ bytes in"
391 )
392 .unwrap();
393 cw_writeln!(w, " var off = 0").unwrap();
394 cw_writeln!(w, " return try {inline_decode}").unwrap();
395 w.writeln("})").unwrap();
396 }
397 ShapeKind::Tx { inner } => {
398 let encode_closure = generate_encode_closure(inner);
401 let channel_cursor_var =
402 channel_cursor_var.expect("channel cursor required for channeling args");
403 cw_writeln!(
404 w,
405 "guard {channel_cursor_var} < {channels_var}.count else {{ throw RoamError.decodeError(\"missing channel id for {name}\") }}"
406 )
407 .unwrap();
408 cw_writeln!(
409 w,
410 "let {name}ChannelId = {channels_var}[{channel_cursor_var}]"
411 )
412 .unwrap();
413 cw_writeln!(w, "{channel_cursor_var} += 1").unwrap();
414 cw_writeln!(
415 w,
416 "let {name} = createServerTx(channelId: {name}ChannelId, taskSender: taskSender, serialize: ({encode_closure}))"
417 )
418 .unwrap();
419 }
420 _ => {
421 let cursor_var = cursor_var.expect("payload cursor required for non-channel args");
423 let decode_stmt = generate_decode_stmt_with_cursor(shape, name, "", cursor_var);
424 for line in decode_stmt.lines() {
425 w.writeln(line).unwrap();
426 }
427 }
428 }
429}
430
431fn unique_decode_cursor_name(args: &[roam_types::ArgDescriptor]) -> String {
432 let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
433 let mut candidate = String::from("cursor");
434 while arg_names.iter().any(|name| name == &candidate) {
435 candidate.push('_');
436 }
437 candidate
438}