roam_codegen/targets/typescript/
server.rs

1//! TypeScript server/handler generation.
2//!
3//! Generates server handler interface and method dispatch logic.
4
5use heck::{ToLowerCamelCase, ToUpperCamelCase};
6use roam_schema::{ServiceDetail, ShapeKind, classify_shape, is_rx, is_tx};
7
8use super::decode::{generate_decode_stmt_server, generate_decode_stmt_server_streaming};
9use super::encode::generate_encode_expr;
10use super::types::{is_fully_supported, ts_type_server_arg, ts_type_server_return};
11
12/// Generate handler interface (for handling incoming calls).
13///
14/// r[impl channeling.caller-pov] - Handler uses Rx for args, Tx for returns.
15pub fn generate_handler_interface(service: &ServiceDetail) -> String {
16    let mut out = String::new();
17    let service_name = service.name.to_upper_camel_case();
18
19    out.push_str(&format!("// Handler interface for {service_name}\n"));
20    out.push_str(&format!("export interface {service_name}Handler {{\n"));
21
22    for method in &service.methods {
23        let method_name = method.method_name.to_lower_camel_case();
24        // Handler args: Tx becomes Rx (receives), Rx becomes Tx (sends)
25        let args = method
26            .args
27            .iter()
28            .map(|a| {
29                format!(
30                    "{}: {}",
31                    a.name.to_lower_camel_case(),
32                    ts_type_server_arg(a.ty)
33                )
34            })
35            .collect::<Vec<_>>()
36            .join(", ");
37        // Handler returns
38        let ret_ty = ts_type_server_return(method.return_type);
39
40        out.push_str(&format!(
41            "  {method_name}({args}): Promise<{ret_ty}> | {ret_ty};\n"
42        ));
43    }
44
45    out.push_str("}\n\n");
46    out
47}
48
49/// Generate RPC method handlers map.
50pub fn generate_method_handlers(service: &ServiceDetail) -> String {
51    use crate::render::hex_u64;
52
53    let mut out = String::new();
54    let service_name = service.name.to_upper_camel_case();
55    let service_name_lower = service.name.to_lower_camel_case();
56
57    out.push_str(&format!("// Method handlers for {service_name}\n"));
58    out.push_str(&format!("export const {}_methodHandlers = new Map<bigint, MethodHandler<{service_name}Handler>>([\n", service_name_lower));
59
60    for method in &service.methods {
61        let method_name = method.method_name.to_lower_camel_case();
62        let id = crate::method_id(method);
63
64        // Check if this method uses streaming
65        let method_has_streaming = method.args.iter().any(|a| is_tx(a.ty) || is_rx(a.ty))
66            || is_tx(method.return_type)
67            || is_rx(method.return_type);
68
69        out.push_str(&format!(
70            "  [{}n, async (handler, payload) => {{\n",
71            hex_u64(id)
72        ));
73        out.push_str("    try {\n");
74
75        // Check if we can fully implement this method
76        let can_decode_args = method.args.iter().all(|a| is_fully_supported(a.ty));
77        let can_encode_return = is_fully_supported(method.return_type);
78
79        if can_decode_args && can_encode_return && !method_has_streaming {
80            // Non-streaming method - decode and call directly
81            out.push_str("      const buf = payload;\n");
82            out.push_str("      let offset = 0;\n");
83            for arg in &method.args {
84                let arg_name = arg.name.to_lower_camel_case();
85                let decode_stmt = generate_decode_stmt_server(arg.ty, &arg_name, "offset");
86                out.push_str(&format!("      {decode_stmt}\n"));
87            }
88            out.push_str(
89                "      if (offset !== buf.length) throw new Error(\"args: trailing bytes\");\n",
90            );
91
92            let arg_names = method
93                .args
94                .iter()
95                .map(|a| a.name.to_lower_camel_case())
96                .collect::<Vec<_>>()
97                .join(", ");
98            out.push_str(&format!(
99                "      const result = await handler.{method_name}({arg_names});\n"
100            ));
101
102            let encode_expr = generate_encode_expr(method.return_type, "result");
103            out.push_str(&format!("      return encodeResultOk({encode_expr});\n"));
104        } else {
105            // Streaming method - must use streaming dispatcher
106            out.push_str(
107                "      // Channeling method - use streamingDispatch() instead of simple RPC dispatch\n",
108            );
109            out.push_str("      return encodeResultErr(encodeInvalidPayload());\n");
110        }
111
112        out.push_str("    } catch (e) {\n");
113        out.push_str("      return encodeResultErr(encodeInvalidPayload());\n");
114        out.push_str("    }\n");
115        out.push_str("  }],\n");
116    }
117
118    out.push_str("]);\n\n");
119    out
120}
121
122/// Generate streaming method handlers.
123///
124/// These handlers receive the registry and taskSender to properly bind streams.
125pub fn generate_streaming_handlers(service: &ServiceDetail) -> String {
126    use crate::render::hex_u64;
127
128    let mut out = String::new();
129    let service_name = service.name.to_upper_camel_case();
130    let service_name_lower = service.name.to_lower_camel_case();
131
132    // Type for streaming method handler
133    out.push_str(&format!(
134        "// Streaming method handler type for {service_name}\n"
135    ));
136    out.push_str("export type ChannelingMethodHandler<H> = (\n  handler: H,\n  payload: Uint8Array,\n  requestId: bigint,\n  registry: ChannelRegistry,\n  taskSender: TaskSender,\n) => Promise<void>;\n\n");
137
138    // Generate streaming handlers map
139    out.push_str(&format!(
140        "// Streaming method handlers for {service_name}\n"
141    ));
142    out.push_str(&format!(
143        "export const {service_name_lower}_streamingHandlers = new Map<bigint, ChannelingMethodHandler<{service_name}Handler>>([\n"
144    ));
145
146    for method in &service.methods {
147        let method_name = method.method_name.to_lower_camel_case();
148        let id = crate::method_id(method);
149
150        out.push_str(&format!(
151            "  [{}n, async (handler, payload, requestId, registry, taskSender) => {{\n",
152            hex_u64(id)
153        ));
154        out.push_str("    try {\n");
155        out.push_str("      const buf = payload;\n");
156        out.push_str("      let offset = 0;\n");
157
158        // Decode all arguments with proper stream binding
159        for arg in &method.args {
160            let arg_name = arg.name.to_lower_camel_case();
161            let decode_stmt = generate_decode_stmt_server_streaming(
162                arg.ty,
163                &arg_name,
164                "offset",
165                "registry",
166                "taskSender",
167            );
168            out.push_str(&format!("      {decode_stmt}\n"));
169        }
170        out.push_str(
171            "      if (offset !== buf.length) throw new Error(\"args: trailing bytes\");\n",
172        );
173
174        // Call handler
175        let arg_names = method
176            .args
177            .iter()
178            .map(|a| a.name.to_lower_camel_case())
179            .collect::<Vec<_>>()
180            .join(", ");
181        out.push_str(&format!(
182            "      const result = await handler.{method_name}({arg_names});\n"
183        ));
184
185        // Close any Tx streams that were passed as arguments
186        for arg in &method.args {
187            if is_tx(arg.ty) {
188                let arg_name = arg.name.to_lower_camel_case();
189                out.push_str(&format!("      {arg_name}.close();\n"));
190            }
191        }
192
193        // Encode and send response via taskSender
194        // Check if return type is Result<T, E> - if so, encode as Result<T, RoamError<User(E)>>
195        if let ShapeKind::Result { ok, err } = classify_shape(method.return_type) {
196            // Handler returns { ok: true; value: T } | { ok: false; error: E }
197            // Wire format: [0] + T for success, [1, 0] + E for User error
198            let ok_encode = generate_encode_expr(ok, "result.value");
199            let err_encode = generate_encode_expr(err, "result.error");
200            out.push_str("      if (result.ok) {\n");
201            out.push_str(&format!(
202                "        taskSender({{ kind: 'response', requestId, payload: concat(encodeU8(0), {ok_encode}) }});\n"
203            ));
204            out.push_str("      } else {\n");
205            out.push_str(&format!(
206                "        taskSender({{ kind: 'response', requestId, payload: concat(encodeU8(1), encodeU8(0), {err_encode}) }});\n"
207            ));
208            out.push_str("      }\n");
209        } else {
210            let encode_expr = generate_encode_expr(method.return_type, "result");
211            out.push_str(&format!(
212                "      taskSender({{ kind: 'response', requestId, payload: encodeResultOk({encode_expr}) }});\n"
213            ));
214        }
215
216        out.push_str("    } catch (e) {\n");
217        out.push_str(
218            "      taskSender({ kind: 'response', requestId, payload: encodeResultErr(encodeInvalidPayload()) });\n",
219        );
220        out.push_str("    }\n");
221        out.push_str("  }],\n");
222    }
223
224    out.push_str("]);\n\n");
225    out
226}
227
228/// Generate complete server code (interface + handlers).
229pub fn generate_server(service: &ServiceDetail) -> String {
230    let mut out = String::new();
231
232    // Generate handler interface
233    out.push_str(&generate_handler_interface(service));
234
235    // Generate RPC method handlers
236    out.push_str(&generate_method_handlers(service));
237
238    // Check if any method uses streaming
239    let has_streaming = service.methods.iter().any(|m| {
240        m.args.iter().any(|a| is_tx(a.ty) || is_rx(a.ty))
241            || is_tx(m.return_type)
242            || is_rx(m.return_type)
243    });
244
245    // Generate streaming handlers if needed
246    if has_streaming {
247        out.push_str(&generate_streaming_handlers(service));
248    }
249
250    out
251}