1use facet_core::{ScalarType, Shape};
6use heck::{ToLowerCamelCase, ToUpperCamelCase};
7use roam_schema::{
8 MethodDetail, ServiceDetail, ShapeKind, StructInfo, classify_shape, is_rx, is_tx,
9};
10
11use super::decode::{generate_decode_stmt, generate_inline_decode};
12use super::encode::generate_encode_closure;
13use super::types::{format_doc, is_stream, swift_type_server_arg, swift_type_server_return};
14use crate::code_writer::CodeWriter;
15use crate::cw_writeln;
16use crate::render::hex_u64;
17
18pub fn generate_server(service: &ServiceDetail) -> String {
20 let mut out = String::new();
21 out.push_str(&generate_handler_protocol(service));
22 out.push_str(&generate_dispatcher(service));
23 out.push_str(&generate_streaming_dispatcher(service));
24 out
25}
26
27fn generate_handler_protocol(service: &ServiceDetail) -> String {
29 let mut out = String::new();
30 let service_name = service.name.to_upper_camel_case();
31
32 if let Some(doc) = &service.doc {
33 out.push_str(&format_doc(doc, ""));
34 }
35 out.push_str(&format!("public protocol {service_name}Handler {{\n"));
36
37 for method in &service.methods {
38 let method_name = method.method_name.to_lower_camel_case();
39
40 if let Some(doc) = &method.doc {
41 out.push_str(&format_doc(doc, " "));
42 }
43
44 let args: Vec<String> = method
46 .args
47 .iter()
48 .map(|a| {
49 format!(
50 "{}: {}",
51 a.name.to_lower_camel_case(),
52 swift_type_server_arg(a.ty)
53 )
54 })
55 .collect();
56
57 let ret_type = swift_type_server_return(method.return_type);
58
59 if ret_type == "Void" {
60 out.push_str(&format!(
61 " func {method_name}({}) async throws\n",
62 args.join(", ")
63 ));
64 } else {
65 out.push_str(&format!(
66 " func {method_name}({}) async throws -> {ret_type}\n",
67 args.join(", ")
68 ));
69 }
70 }
71
72 out.push_str("}\n\n");
73 out
74}
75
76fn generate_dispatcher(service: &ServiceDetail) -> String {
78 let mut out = String::new();
79 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
80 let service_name = service.name.to_upper_camel_case();
81
82 cw_writeln!(w, "public final class {service_name}Dispatcher {{").unwrap();
83 {
84 let _indent = w.indent();
85 cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
86 w.blank_line().unwrap();
87 cw_writeln!(w, "public init(handler: {service_name}Handler) {{").unwrap();
88 {
89 let _indent = w.indent();
90 w.writeln("self.handler = handler").unwrap();
91 }
92 w.writeln("}").unwrap();
93 w.blank_line().unwrap();
94
95 w.writeln("public func dispatch(methodId: UInt64, payload: Data) async throws -> Data {")
97 .unwrap();
98 {
99 let _indent = w.indent();
100 w.writeln("switch methodId {").unwrap();
101 for method in &service.methods {
102 let method_name = method.method_name.to_lower_camel_case();
103 let method_id = crate::method_id(method);
104 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
105 cw_writeln!(
106 w,
107 " return try await dispatch{method_name}(payload: payload)"
108 )
109 .unwrap();
110 }
111 w.writeln("default:").unwrap();
112 w.writeln(" throw RoamError.unknownMethod").unwrap();
113 w.writeln("}").unwrap();
114 }
115 w.writeln("}").unwrap();
116
117 for method in &service.methods {
119 w.blank_line().unwrap();
120 generate_dispatch_method(&mut w, method);
121 }
122 }
123 w.writeln("}").unwrap();
124 w.blank_line().unwrap();
125
126 out
127}
128
129fn generate_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDetail) {
131 let method_name = method.method_name.to_lower_camel_case();
132 let has_streaming =
133 method.args.iter().any(|a| is_stream(a.ty)) || is_stream(method.return_type);
134
135 cw_writeln!(
136 w,
137 "private func dispatch{method_name}(payload: Data) async throws -> Data {{"
138 )
139 .unwrap();
140 {
141 let _indent = w.indent();
142
143 if has_streaming {
144 w.writeln("// TODO: Implement streaming dispatch").unwrap();
145 w.writeln("throw RoamError.notImplemented").unwrap();
146 } else {
147 generate_decode_args(w, &method.args);
149
150 let arg_names: Vec<String> = method
152 .args
153 .iter()
154 .map(|a| {
155 let name = a.name.to_lower_camel_case();
156 format!("{name}: {name}")
157 })
158 .collect();
159
160 let ret_type = swift_type_server_return(method.return_type);
161
162 if ret_type == "Void" {
163 cw_writeln!(
164 w,
165 "try await handler.{method_name}({})",
166 arg_names.join(", ")
167 )
168 .unwrap();
169 w.writeln("return Data()").unwrap();
170 } else {
171 cw_writeln!(
172 w,
173 "let result = try await handler.{method_name}({})",
174 arg_names.join(", ")
175 )
176 .unwrap();
177 let encode_closure = generate_encode_closure(method.return_type);
178 cw_writeln!(
179 w,
180 "return Data(encodeResultOk(result, encoder: {encode_closure}))"
181 )
182 .unwrap();
183 }
184 }
185 }
186 w.writeln("}").unwrap();
187}
188
189fn generate_decode_args(w: &mut CodeWriter<&mut String>, args: &[roam_schema::ArgDetail]) {
191 if args.is_empty() {
192 w.writeln("// No arguments to decode").unwrap();
193 return;
194 }
195
196 w.writeln("var offset = 0").unwrap();
197 for arg in args {
198 let arg_name = arg.name.to_lower_camel_case();
199 let decode_stmt = generate_decode_stmt(arg.ty, &arg_name, "");
200 for line in decode_stmt.lines() {
201 w.writeln(line).unwrap();
202 }
203 }
204}
205
206fn generate_streaming_dispatcher(service: &ServiceDetail) -> String {
208 let mut out = String::new();
209 let mut w = CodeWriter::with_indent_spaces(&mut out, 4);
210 let service_name = service.name.to_upper_camel_case();
211
212 cw_writeln!(w, "public final class {service_name}StreamingDispatcher {{").unwrap();
213 {
214 let _indent = w.indent();
215 cw_writeln!(w, "private let handler: {service_name}Handler").unwrap();
216 w.writeln("private let registry: IncomingChannelRegistry")
217 .unwrap();
218 w.writeln("private let taskSender: TaskSender").unwrap();
219 w.blank_line().unwrap();
220
221 cw_writeln!(
222 w,
223 "public init(handler: {service_name}Handler, registry: IncomingChannelRegistry, taskSender: @escaping TaskSender) {{"
224 )
225 .unwrap();
226 {
227 let _indent = w.indent();
228 w.writeln("self.handler = handler").unwrap();
229 w.writeln("self.registry = registry").unwrap();
230 w.writeln("self.taskSender = taskSender").unwrap();
231 }
232 w.writeln("}").unwrap();
233 w.blank_line().unwrap();
234
235 w.writeln(
237 "public func dispatch(methodId: UInt64, requestId: UInt64, payload: Data) async {",
238 )
239 .unwrap();
240 {
241 let _indent = w.indent();
242 w.writeln("switch methodId {").unwrap();
243 for method in &service.methods {
244 let method_name = method.method_name.to_lower_camel_case();
245 let method_id = crate::method_id(method);
246 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
247 cw_writeln!(
248 w,
249 " await dispatch{method_name}(requestId: requestId, payload: payload)"
250 )
251 .unwrap();
252 }
253 w.writeln("default:").unwrap();
254 w.writeln(
255 " taskSender(.response(requestId: requestId, payload: encodeUnknownMethodError()))",
256 )
257 .unwrap();
258 w.writeln("}").unwrap();
259 }
260 w.writeln("}").unwrap();
261 w.blank_line().unwrap();
262
263 generate_preregister_channels(&mut w, service);
265 w.blank_line().unwrap();
266
267 for method in &service.methods {
269 generate_streaming_dispatch_method(&mut w, method);
270 w.blank_line().unwrap();
271 }
272 }
273 w.writeln("}").unwrap();
274 w.blank_line().unwrap();
275
276 out
277}
278
279fn generate_preregister_channels(w: &mut CodeWriter<&mut String>, service: &ServiceDetail) {
281 w.writeln("/// Pre-register channel IDs from a request payload.")
282 .unwrap();
283 w.writeln("/// Call this synchronously before spawning the dispatch task to avoid")
284 .unwrap();
285 w.writeln("/// race conditions where Data arrives before channels are registered.")
286 .unwrap();
287 w.writeln("public static func preregisterChannels(methodId: UInt64, payload: Data, registry: ChannelRegistry) async {")
288 .unwrap();
289 {
290 let _indent = w.indent();
291 w.writeln("switch methodId {").unwrap();
292
293 for method in &service.methods {
294 let method_id = crate::method_id(method);
295 let has_rx_args = method.args.iter().any(|a| is_rx(a.ty));
296
297 if has_rx_args {
298 cw_writeln!(w, "case {}:", hex_u64(method_id)).unwrap();
299 w.writeln(" do {").unwrap();
300 w.writeln(" var offset = 0").unwrap();
301
302 for arg in &method.args {
304 let arg_name = arg.name.to_lower_camel_case();
305 if is_rx(arg.ty) {
306 cw_writeln!(
308 w,
309 " let {arg_name}ChannelId = try decodeVarint(from: payload, offset: &offset)"
310 )
311 .unwrap();
312 cw_writeln!(w, " await registry.markKnown({arg_name}ChannelId)")
313 .unwrap();
314 } else if is_tx(arg.ty) {
315 cw_writeln!(
317 w,
318 " _ = try decodeVarint(from: payload, offset: &offset) // {arg_name}"
319 )
320 .unwrap();
321 } else {
322 generate_skip_arg(w, &arg_name, arg.ty, " ");
324 }
325 }
326
327 w.writeln(" } catch {").unwrap();
328 w.writeln(" // Ignore parse errors - dispatch will handle them")
329 .unwrap();
330 w.writeln(" }").unwrap();
331 }
332 }
333
334 w.writeln("default:").unwrap();
335 w.writeln(" break").unwrap();
336 w.writeln("}").unwrap();
337 }
338 w.writeln("}").unwrap();
339}
340
341fn generate_skip_arg(
343 w: &mut CodeWriter<&mut String>,
344 name: &str,
345 shape: &'static Shape,
346 indent: &str,
347) {
348 use roam_schema::is_bytes;
349
350 if is_bytes(shape) {
351 cw_writeln!(
352 w,
353 "{indent}_ = try decodeBytes(from: payload, offset: &offset) // {name}"
354 )
355 .unwrap();
356 return;
357 }
358
359 match classify_shape(shape) {
360 ShapeKind::Scalar(scalar) => {
361 let skip_code = match scalar {
362 ScalarType::Bool | ScalarType::U8 | ScalarType::I8 => "offset += 1",
363 ScalarType::U16 | ScalarType::I16 => "offset += 2",
364 ScalarType::U32 | ScalarType::I32 | ScalarType::U64 | ScalarType::I64 => {
365 "_ = try decodeVarint(from: payload, offset: &offset)"
366 }
367 ScalarType::F32 => "offset += 4",
368 ScalarType::F64 => "offset += 8",
369 ScalarType::Unit => "",
370 ScalarType::Char => "_ = try decodeVarint(from: payload, offset: &offset)",
371 _ => "// unknown scalar type",
372 };
373 if !skip_code.is_empty() {
374 cw_writeln!(w, "{indent}{skip_code} // {name}").unwrap();
375 }
376 }
377 ShapeKind::List { .. } | ShapeKind::Slice { .. } | ShapeKind::Array { .. } => {
378 cw_writeln!(
379 w,
380 "{indent}_ = try decodeBytes(from: payload, offset: &offset) // {name} (skipped)"
381 )
382 .unwrap();
383 }
384 ShapeKind::Option { .. } => {
385 cw_writeln!(w, "{indent}// TODO: skip option {name}").unwrap();
386 }
387 ShapeKind::Struct(StructInfo { fields, .. }) => {
388 for field in fields {
390 let field_name = format!("{}.{}", name, field.name);
391 generate_skip_arg(w, &field_name, field.shape(), indent);
392 }
393 }
394 _ => {
395 cw_writeln!(w, "{indent}// TODO: skip {name}").unwrap();
396 }
397 }
398}
399
400fn generate_streaming_dispatch_method(w: &mut CodeWriter<&mut String>, method: &MethodDetail) {
402 let method_name = method.method_name.to_lower_camel_case();
403 let has_streaming =
404 method.args.iter().any(|a| is_stream(a.ty)) || is_stream(method.return_type);
405
406 cw_writeln!(
407 w,
408 "private func dispatch{method_name}(requestId: UInt64, payload: Data) async {{"
409 )
410 .unwrap();
411 {
412 let _indent = w.indent();
413 w.writeln("do {").unwrap();
414 {
415 let _indent = w.indent();
416 w.writeln("var offset = 0").unwrap();
417
418 for arg in &method.args {
420 let arg_name = arg.name.to_lower_camel_case();
421 generate_streaming_decode_arg(w, &arg_name, arg.ty);
422 }
423
424 let arg_names: Vec<String> = method
426 .args
427 .iter()
428 .map(|a| {
429 let name = a.name.to_lower_camel_case();
430 format!("{name}: {name}")
431 })
432 .collect();
433
434 let ret_type = swift_type_server_return(method.return_type);
435
436 if has_streaming {
437 if ret_type == "Void" {
439 cw_writeln!(
440 w,
441 "try await handler.{method_name}({})",
442 arg_names.join(", ")
443 )
444 .unwrap();
445 } else {
446 cw_writeln!(
447 w,
448 "let result = try await handler.{method_name}({})",
449 arg_names.join(", ")
450 )
451 .unwrap();
452 }
453
454 for arg in &method.args {
456 if is_tx(arg.ty) {
457 let arg_name = arg.name.to_lower_camel_case();
458 cw_writeln!(w, "{arg_name}.close()").unwrap();
459 }
460 }
461
462 if ret_type == "Void" {
464 w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
465 } else {
466 let encode_closure = generate_encode_closure(method.return_type);
467 cw_writeln!(
468 w,
469 "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
470 )
471 .unwrap();
472 }
473 } else {
474 if ret_type == "Void" {
476 cw_writeln!(
477 w,
478 "try await handler.{method_name}({})",
479 arg_names.join(", ")
480 )
481 .unwrap();
482 w.writeln("taskSender(.response(requestId: requestId, payload: encodeResultOk((), encoder: { _ in [] })))").unwrap();
483 } else {
484 cw_writeln!(
485 w,
486 "let result = try await handler.{method_name}({})",
487 arg_names.join(", ")
488 )
489 .unwrap();
490 if let ShapeKind::Result { ok, err } = classify_shape(method.return_type) {
492 let ok_encode = generate_encode_closure(ok);
493 let err_encode = generate_encode_closure(err);
494 cw_writeln!(
496 w,
497 "taskSender(.response(requestId: requestId, payload: {{ switch result {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1), UInt8(0)] + {err_encode}(e) }} }}()))"
498 )
499 .unwrap();
500 } else {
501 let encode_closure = generate_encode_closure(method.return_type);
502 cw_writeln!(
503 w,
504 "taskSender(.response(requestId: requestId, payload: encodeResultOk(result, encoder: {encode_closure})))"
505 )
506 .unwrap();
507 }
508 }
509 }
510 }
511 w.writeln("} catch {").unwrap();
512 {
513 let _indent = w.indent();
514 w.writeln(
515 "taskSender(.response(requestId: requestId, payload: encodeInvalidPayloadError()))",
516 )
517 .unwrap();
518 }
519 w.writeln("}").unwrap();
520 }
521 w.writeln("}").unwrap();
522}
523
524fn generate_streaming_decode_arg(
526 w: &mut CodeWriter<&mut String>,
527 name: &str,
528 shape: &'static Shape,
529) {
530 match classify_shape(shape) {
531 ShapeKind::Rx { inner } => {
532 let inline_decode = generate_inline_decode(inner, "Data(bytes)", "off");
535 cw_writeln!(
536 w,
537 "let {name}ChannelId = try decodeVarint(from: payload, offset: &offset)"
538 )
539 .unwrap();
540 cw_writeln!(
541 w,
542 "let {name}Receiver = await registry.register({name}ChannelId)"
543 )
544 .unwrap();
545 cw_writeln!(
546 w,
547 "let {name} = createServerRx(channelId: {name}ChannelId, receiver: {name}Receiver, deserialize: {{ bytes in"
548 )
549 .unwrap();
550 cw_writeln!(w, " var off = 0").unwrap();
551 cw_writeln!(w, " return try {inline_decode}").unwrap();
552 w.writeln("})").unwrap();
553 }
554 ShapeKind::Tx { inner } => {
555 let encode_closure = generate_encode_closure(inner);
558 cw_writeln!(
559 w,
560 "let {name}ChannelId = try decodeVarint(from: payload, offset: &offset)"
561 )
562 .unwrap();
563 cw_writeln!(
564 w,
565 "let {name} = createServerTx(channelId: {name}ChannelId, taskSender: taskSender, serialize: ({encode_closure}))"
566 )
567 .unwrap();
568 }
569 _ => {
570 let decode_stmt = generate_decode_stmt(shape, name, "");
572 for line in decode_stmt.lines() {
573 w.writeln(line).unwrap();
574 }
575 }
576 }
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use facet::Facet;
    use roam_schema::{ArgDetail, MethodDetail, ServiceDetail};
    use std::borrow::Cow;

    /// Minimal fixture: `Echo.echo(message: String) -> String`, with doc comments on
    /// both the service and the method.
    fn sample_service() -> ServiceDetail {
        let string_shape = <String as Facet>::SHAPE;
        let echo = MethodDetail {
            service_name: Cow::Borrowed("Echo"),
            method_name: Cow::Borrowed("echo"),
            args: vec![ArgDetail {
                name: Cow::Borrowed("message"),
                ty: string_shape,
            }],
            return_type: string_shape,
            doc: Some(Cow::Borrowed("Echo back the message")),
        };
        ServiceDetail {
            name: Cow::Borrowed("Echo"),
            doc: Some(Cow::Borrowed("Simple echo service")),
            methods: vec![echo],
        }
    }

    /// Asserts that every snippet in `expected` occurs in `code`.
    /// On failure it reports the missing snippet and the full generated source.
    fn assert_contains_all(code: &str, expected: &[&str]) {
        for &snippet in expected {
            assert!(
                code.contains(snippet),
                "expected `{snippet}` in generated code:\n{code}"
            );
        }
    }

    #[test]
    fn test_generate_handler_protocol() {
        let code = generate_handler_protocol(&sample_service());
        assert_contains_all(
            &code,
            &[
                "protocol EchoHandler",
                "func echo(message: String)",
                "async throws -> String",
            ],
        );
    }

    #[test]
    fn test_generate_dispatcher() {
        let code = generate_dispatcher(&sample_service());
        assert_contains_all(
            &code,
            &[
                "class EchoDispatcher",
                "EchoHandler",
                "dispatch(methodId:",
                "dispatchecho",
            ],
        );
    }

    #[test]
    fn test_generate_streaming_dispatcher() {
        let code = generate_streaming_dispatcher(&sample_service());
        assert_contains_all(
            &code,
            &[
                "class EchoStreamingDispatcher",
                "preregisterChannels",
                "IncomingChannelRegistry",
            ],
        );
    }
}
633}