roam_codegen/targets/swift/
server.rs1use facet_core::Shape;
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use roam_types::{MethodDescriptor, ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
8
9use super::decode::{generate_decode_stmt_with_cursor, generate_inline_decode};
10use super::encode::generate_encode_closure;
11use super::types::{format_doc, is_channel, swift_type_server_arg, swift_type_server_return};
12use crate::code_writer::CodeWriter;
13use crate::cw_writeln;
14use crate::render::hex_u64;
15
/// Returns the name of the private Swift helper that dispatches one method.
///
/// The helper name is `method_name` prefixed with `dispatch_`. Callers in this
/// file pass the method name after converting it to lowerCamelCase.
fn dispatch_helper_name(method_name: &str) -> String {
    let mut helper = String::from("dispatch_");
    helper.push_str(method_name);
    helper
}
19
20fn extract_initial_credit(shape: &'static Shape) -> u32 {
21 shape
22 .const_params
23 .iter()
24 .find(|cp| cp.name == "N")
25 .map(|cp| cp.value as u32)
26 .unwrap_or(16)
27}
28
29pub fn generate_server(service: &ServiceDescriptor) -> String {
31 let mut out = String::new();
32 out.push_str(&generate_handler_protocol(service));
33 out.push_str(&generate_channeling_dispatcher(service));
35 out
36}
37
38fn generate_handler_protocol(service: &ServiceDescriptor) -> String {
40 let mut out = String::new();
41 let service_name = service.service_name.to_upper_camel_case();
42
43 if let Some(doc) = &service.doc {
44 out.push_str(&format_doc(doc, ""));
45 }
46 out.push_str(&format!("public protocol {service_name}Handler {{\n"));
47
48 for method in service.methods {
49 let method_name = method.method_name.to_lower_camel_case();
50
51 if let Some(doc) = &method.doc {
52 out.push_str(&format_doc(doc, " "));
53 }
54
55 let args: Vec<String> = method
57 .args
58 .iter()
59 .map(|a| {
60 format!(
61 "{}: {}",
62 a.name.to_lower_camel_case(),
63 swift_type_server_arg(a.shape)
64 )
65 })
66 .collect();
67
68 let ret_type = swift_type_server_return(method.return_shape);
69
70 if ret_type == "Void" {
71 out.push_str(&format!(
72 " func {method_name}({}) async throws\n",
73 args.join(", ")
74 ));
75 } else {
76 out.push_str(&format!(
77 " func {method_name}({}) async throws -> {ret_type}\n",
78 args.join(", ")
79 ));
80 }
81 }
82
83 out.push_str("}\n\n");
84 out
85}
86
87fn generate_channeling_dispatcher(service: &ServiceDescriptor) -> String {
89 let mut out = String::new();
90 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
91 let service_name = service.service_name.to_upper_camel_case();
92
93 cw_writeln!(
94 w,
95 "public final class {service_name}ChannelingDispatcher {{"
96 )
97 .unwrap();
98 {
99 let _indent = w.indent();
100 cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
101 w.writeln("private let registry: IncomingChannelRegistry")
102 .unwrap();
103 w.writeln("private let taskSender: TaskSender").unwrap();
104 w.blank_line().unwrap();
105
106 cw_writeln!(
107 w,
108 "public init(handler: {service_name}Handler, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender) {{"
109 )
110 .unwrap();
111 {
112 let _indent = w.indent();
113 w.writeln("self.handler = handler").unwrap();
114 w.writeln("self.registry = registry").unwrap();
115 w.writeln("self.taskSender = taskSender").unwrap();
116 }
117 w.writeln("}").unwrap();
118 w.blank_line().unwrap();
119
120 w.writeln(
122 "public func dispatch(methodId: UInt64, requestId: UInt64, channels: [UInt64], payload: Data) async {",
123 )
124 .unwrap();
125 {
126 let _indent = w.indent();
127 w.writeln("switch methodId {").unwrap();
128 for method in service.methods {
129 let method_name = method.method_name.to_lower_camel_case();
130 let method_id = crate::method_id(method);
131 let dispatch_name = dispatch_helper_name(&method_name);
132 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
133 cw_writeln!(
134 w,
135 " await {dispatch_name}(requestId: requestId, channels: channels, payload: payload)"
136 )
137 .unwrap();
138 }
139 w.writeln("default:").unwrap();
140 w.writeln(
141 " taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
142 )
143 .unwrap();
144 w.writeln("}").unwrap();
145 }
146 w.writeln("}").unwrap();
147 w.blank_line().unwrap();
148
149 generate_preregister_channels(&mut w, service);
151 w.blank_line().unwrap();
152
153 for method in service.methods {
155 generate_channeling_dispatch_method(&mut w, method);
156 w.blank_line().unwrap();
157 }
158 }
159 w.writeln("}").unwrap();
160 w.blank_line().unwrap();
161
162 out
163}
164
165fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDescriptor) {
167 w.writeln("/// Pre-register Rx channel IDs from request channels.")
168 .unwrap();
169 w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
170 .unwrap();
171 w.writeln("/// race conditions where Data arrives before channels are registered.")
172 .unwrap();
173 w.writeln("public static func preregisterChannels(methodId: UInt64, channels: [UInt64], registry: ChannelRegistry) async {")
174 .unwrap();
175 {
176 let _indent = w.indent();
177 w.writeln("switch methodId {").unwrap();
178
179 for method in service.methods {
180 let method_id = crate::method_id(method);
181 let has_rx_args = method.args.iter().any(|a| is_rx(a.shape));
182
183 if has_rx_args {
184 let channel_arg_count = method
185 .args
186 .iter()
187 .filter(|a| is_rx(a.shape) || is_tx(a.shape))
188 .count();
189 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
190 cw_writeln!(w, " guard channels.count >= {channel_arg_count} else {{").unwrap();
191 w.writeln(" return").unwrap();
192 w.writeln(" }").unwrap();
193 w.writeln(" var channelCursor = 0").unwrap();
194
195 for arg in method.args {
197 let arg_name = arg.name.to_lower_camel_case();
198 if is_rx(arg.shape) {
199 cw_writeln!(w, " let {arg_name}ChannelId = channels[channelCursor]")
201 .unwrap();
202 w.writeln(" channelCursor += 1").unwrap();
203 cw_writeln!(w, " await registry.markKnown({arg_name}ChannelId)")
204 .unwrap();
205 } else if is_tx(arg.shape) {
206 cw_writeln!(w, " _ = channels[channelCursor] // {arg_name}").unwrap();
207 w.writeln(" channelCursor += 1").unwrap();
208 }
209 }
210 }
211 }
212
213 w.writeln("default:").unwrap();
214 w.writeln(" break").unwrap();
215 w.writeln("}").unwrap();
216 }
217 w.writeln("}").unwrap();
218}
219
220fn generate_channeling_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDescriptor) {
222 let method_name = method.method_name.to_lower_camel_case();
223 let dispatch_name = dispatch_helper_name(&method_name);
224 let has_channeling = method.args.iter().any(|a| is_channel(a.shape));
225
226 cw_writeln!(
227 w,
228 "private func {dispatch_name}(requestId: UInt64, channels: [UInt64], payload: Data) async {{"
229 )
230 .unwrap();
231 {
232 let _indent = w.indent();
233 w.writeln("do {").unwrap();
234 {
235 let _indent = w.indent();
236 let has_payload_args = method
237 .args
238 .iter()
239 .any(|a| !is_rx(a.shape) && !is_tx(a.shape));
240 let has_channel_args = method.args.iter().any(|a| is_rx(a.shape) || is_tx(a.shape));
241 let cursor_var = if has_payload_args {
242 let name = unique_decode_cursor_name(method.args);
243 cw_writeln!(w, "var {name} = 0").unwrap();
244 Some(name)
245 } else {
246 None
247 };
248 if has_channel_args {
249 w.writeln("var channelCursor = 0").unwrap();
250 }
251
252 for arg in method.args {
254 let arg_name = arg.name.to_lower_camel_case();
255 generate_channeling_decode_arg(
256 w,
257 &arg_name,
258 arg.shape,
259 cursor_var.as_deref(),
260 "channels",
261 Some("channelCursor"),
262 );
263 }
264
265 let arg_names: Vec<String> = method
267 .args
268 .iter()
269 .map(|a| {
270 let name = a.name.to_lower_camel_case();
271 format!("{name}: {name}")
272 })
273 .collect();
274
275 let ret_type = swift_type_server_return(method.return_shape);
276
277 if has_channeling {
278 if ret_type == "Void" {
280 cw_writeln!(
281 w,
282 "try await handler.{method_name}({})",
283 arg_names.join(", ")
284 )
285 .unwrap();
286 } else {
287 cw_writeln!(
288 w,
289 "let result = try await handler.{method_name}({})",
290 arg_names.join(", ")
291 )
292 .unwrap();
293 }
294
295 for arg in method.args {
297 if is_tx(arg.shape) {
298 let arg_name = arg.name.to_lower_camel_case();
299 cw_writeln!(w, "{arg_name}.close()").unwrap();
300 }
301 }
302
303 if ret_type == "Void" {
305 w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
306 } else {
307 let encode_closure = generate_encode_closure(method.return_shape);
308 cw_writeln!(
309 w,
310 "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
311 )
312 .unwrap();
313 }
314 } else {
315 if ret_type == "Void" {
317 cw_writeln!(
318 w,
319 "try await handler.{method_name}({})",
320 arg_names.join(", ")
321 )
322 .unwrap();
323 w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
324 } else {
325 cw_writeln!(
326 w,
327 "let result = try await handler.{method_name}({})",
328 arg_names.join(", ")
329 )
330 .unwrap();
331 if let ShapeKind::Result { ok, err } = classify_shape(method.return_shape) {
333 let ok_encode = generate_encode_closure(ok);
334 let err_encode = generate_encode_closure(err);
335 cw_writeln!(
337 w,
338 "taskSender(.response(requestId: requestId, payload: {{ switch result {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1), UInt8(0)] + {err_encode}(e) }} }}()))"
339 )
340 .unwrap();
341 } else {
342 let encode_closure = generate_encode_closure(method.return_shape);
343 cw_writeln!(
344 w,
345 "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
346 )
347 .unwrap();
348 }
349 }
350 }
351 }
352 w.writeln("} catch {").unwrap();
353 {
354 let _indent = w.indent();
355 w.writeln(
356 "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError()))",
357 )
358 .unwrap();
359 }
360 w.writeln("}").unwrap();
361 }
362 w.writeln("}").unwrap();
363}
364
365fn generate_channeling_decode_arg(
367 w: &mut CodeWriter<&mut String>,
368 name: &str,
369 shape: &'static Shape,
370 cursor_var: Option<&str>,
371 channels_var: &str,
372 channel_cursor_var: Option<&str>,
373) {
374 match classify_shape(shape) {
375 ShapeKind::Rx { inner } => {
376 let inline_decode = generate_inline_decode(inner, "Data(bytes)", "off");
379 let channel_cursor_var =
380 channel_cursor_var.expect("channel cursor required for channeling args");
381 cw_writeln!(
382 w,
383 "guard {channel_cursor_var} < {channels_var}.count else {{ throw RoamError.decodeError(\"missing channel id for {name}\") }}"
384 )
385 .unwrap();
386 cw_writeln!(
387 w,
388 "let {name}ChannelId = {channels_var}[{channel_cursor_var}]"
389 )
390 .unwrap();
391 cw_writeln!(w, "{channel_cursor_var} += 1").unwrap();
392 cw_writeln!(
393 w,
394 "let {name}Receiver = await registry.register({name}ChannelId, initialCredit: {}, onConsumed: {{ [taskSender = self.taskSender] additional in taskSender(.grantCredit(channelId: {name}ChannelId, bytes: additional)) }})",
395 extract_initial_credit(shape)
396 )
397 .unwrap();
398 cw_writeln!(
399 w,
400 "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {{ bytes in"
401 )
402 .unwrap();
403 cw_writeln!(w, " var off = 0").unwrap();
404 cw_writeln!(w, " return try {inline_decode}").unwrap();
405 w.writeln("})").unwrap();
406 }
407 ShapeKind::Tx { inner } => {
408 let encode_closure = generate_encode_closure(inner);
411 let channel_cursor_var =
412 channel_cursor_var.expect("channel cursor required for channeling args");
413 cw_writeln!(
414 w,
415 "guard {channel_cursor_var} < {channels_var}.count else {{ throw RoamError.decodeError(\"missing channel id for {name}\") }}"
416 )
417 .unwrap();
418 cw_writeln!(
419 w,
420 "let {name}ChannelId = {channels_var}[{channel_cursor_var}]"
421 )
422 .unwrap();
423 cw_writeln!(w, "{channel_cursor_var} += 1").unwrap();
424 cw_writeln!(
425 w,
426 "let {name} = await createServerTx(channelId: {name}ChannelId, taskSender: taskSender, registry: registry, initialCredit: {}, serialize: ({encode_closure}))",
427 extract_initial_credit(shape)
428 )
429 .unwrap();
430 }
431 _ => {
432 let cursor_var = cursor_var.expect("payload cursor required for non-channel args");
434 let decode_stmt = generate_decode_stmt_with_cursor(shape, name, "", cursor_var);
435 for line in decode_stmt.lines() {
436 w.writeln(line).unwrap();
437 }
438 }
439 }
440}
441
442fn unique_decode_cursor_name(args: &[roam_types::ArgDescriptor]) -> String {
443 let arg_names: Vec<String> = args.iter().map(|a| a.name.to_lower_camel_case()).collect();
444 let mut candidate = String::from("cursor");
445 while arg_names.iter().any(|name| name == &candidate) {
446 candidate.push('_');
447 }
448 candidate
449}