Skip to main content

roam_codegen/targets/typescript/
client.rs

1//! TypeScript client generation.
2//!
3//! Generates client interface and implementation for making RPC calls.
4//! The client uses the service descriptor for schema-driven encode/decode —
5//! no serialization code is generated here.
6
7use heck::{ToLowerCamelCase, ToUpperCamelCase};
8use roam_types::{ServiceDescriptor, ShapeKind, classify_shape, is_rx, is_tx};
9
10use super::types::{ts_type_client_arg, ts_type_client_return};
11
/// Format a doc comment as a TypeScript/JSDoc block.
///
/// Every line of `doc` is trimmed, and blank lines at the start and end are
/// dropped. If exactly one line remains it is rendered inline as
/// `/** text */`; otherwise the multi-line ` * ` form is used, with interior
/// blank lines kept as a bare ` *`. Every emitted line is prefixed by `indent`.
///
/// Any `*/` occurring in the text is escaped to `*\/`, because a literal `*/`
/// would terminate the JSDoc comment early and break the generated
/// TypeScript (e.g. a glob such as `src/**/*.rs` in a method's docs).
///
/// Returns an empty string when `doc` contains no non-blank lines.
fn format_doc_comment(doc: &str, indent: &str) -> String {
    let lines: Vec<String> = doc
        .lines()
        .map(|line| line.trim().replace("*/", "*\\/"))
        .collect();

    let Some(first) = lines.iter().position(|l| !l.is_empty()) else {
        return String::new();
    };
    // A non-blank line is known to exist, so `rposition` always succeeds.
    let last = lines.iter().rposition(|l| !l.is_empty()).unwrap_or(first);
    let body = &lines[first..=last];

    if let [only] = body {
        return format!("{indent}/** {only} */\n");
    }

    let mut out = format!("{indent}/**\n");
    for line in body {
        if line.is_empty() {
            out.push_str(&format!("{indent} *\n"));
        } else {
            out.push_str(&format!("{indent} * {line}\n"));
        }
    }
    out.push_str(&format!("{indent} */\n"));
    out
}
36
37/// Generate caller interface (for making calls to the service).
38///
39/// r[impl rpc.channel.binding] - Caller binds channels in args.
40pub fn generate_caller_interface(service: &ServiceDescriptor) -> String {
41    let mut out = String::new();
42    let service_name = service.service_name.to_upper_camel_case();
43
44    out.push_str(&format!("// Caller interface for {service_name}\n"));
45    out.push_str(&format!("export interface {service_name}Caller {{\n"));
46
47    for method in service.methods {
48        let method_name = method.method_name.to_lower_camel_case();
49        let args = method
50            .args
51            .iter()
52            .map(|a| {
53                format!(
54                    "{}: {}",
55                    a.name.to_lower_camel_case(),
56                    ts_type_client_arg(a.shape)
57                )
58            })
59            .collect::<Vec<_>>()
60            .join(", ");
61        let ret_ty = ts_type_client_return(method.return_shape);
62
63        if let Some(doc) = &method.doc {
64            out.push_str(&format_doc_comment(doc, "  "));
65        }
66        out.push_str(&format!(
67            "  {method_name}({args}): CallBuilder<{ret_ty}>;\n"
68        ));
69    }
70
71    out.push_str("}\n\n");
72    out
73}
74
75/// Generate client implementation.
76///
77/// Each method:
78/// 1. Looks up its `MethodDescriptor` from the service descriptor by index
79/// 2. Binds any channel args (via `bindChannels` if streaming)
80/// 3. Calls `caller.call({ method, args, descriptor, ... })`
81/// 4. The runtime encodes/decodes using the descriptor's schemas
82pub fn generate_client_impl(service: &ServiceDescriptor) -> String {
83    let mut out = String::new();
84    let service_name = service.service_name.to_upper_camel_case();
85    let service_name_lower = service.service_name.to_lower_camel_case();
86
87    out.push_str(&format!("// Client implementation for {service_name}\n"));
88    out.push_str(&format!(
89        "export class {service_name}Client implements {service_name}Caller {{\n"
90    ));
91    out.push_str("  private caller: Caller;\n\n");
92    out.push_str("  constructor(caller: Caller) {\n");
93    out.push_str("    this.caller = caller;\n");
94    out.push_str("  }\n\n");
95
96    for (method_idx, method) in service.methods.iter().enumerate() {
97        let method_name = method.method_name.to_lower_camel_case();
98
99        let has_streaming_args = method.args.iter().any(|a| is_tx(a.shape) || is_rx(a.shape));
100
101        let args = method
102            .args
103            .iter()
104            .map(|a| {
105                format!(
106                    "{}: {}",
107                    a.name.to_lower_camel_case(),
108                    ts_type_client_arg(a.shape)
109                )
110            })
111            .collect::<Vec<_>>()
112            .join(", ");
113
114        let ret_ty = ts_type_client_return(method.return_shape);
115
116        let args_record = if method.args.is_empty() {
117            "{}".to_string()
118        } else {
119            let fields: Vec<_> = method
120                .args
121                .iter()
122                .map(|a| a.name.to_lower_camel_case())
123                .collect();
124            format!("{{ {} }}", fields.join(", "))
125        };
126
127        if let Some(doc) = &method.doc {
128            out.push_str(&format_doc_comment(doc, "  "));
129        }
130        out.push_str(&format!(
131            "  {method_name}({args}): CallBuilder<{ret_ty}> {{\n"
132        ));
133
134        // Get the method descriptor by index (known at codegen time)
135        out.push_str(&format!(
136            "    const descriptor = {service_name_lower}_descriptor.methods[{method_idx}];\n"
137        ));
138
139        // Bind channel args if streaming
140        if has_streaming_args {
141            let arg_names: Vec<_> = method
142                .args
143                .iter()
144                .map(|a| a.name.to_lower_camel_case())
145                .collect();
146            out.push_str("    // Bind any Tx/Rx channels in arguments and collect channel IDs\n");
147            out.push_str(&format!(
148                "    const channels = bindChannels(\n      descriptor.args.elements,\n      [{}],\n      this.caller.getChannelAllocator(),\n      this.caller.getChannelRegistry(),\n      {service_name_lower}_descriptor.schema_registry,\n    );\n",
149                arg_names.join(", ")
150            ));
151        }
152
153        out.push_str("    return new CallBuilder(async (metadata) => {\n");
154
155        let is_fallible = matches!(
156            classify_shape(method.return_shape),
157            ShapeKind::Result { .. }
158        );
159
160        if is_fallible {
161            out.push_str("      try {\n");
162            out.push_str("        const value = await this.caller.call({\n");
163            out.push_str(&format!(
164                "          method: \"{}.{}\",\n",
165                service_name, method_name
166            ));
167            out.push_str(&format!("          args: {},\n", args_record));
168            out.push_str("          descriptor,\n");
169            out.push_str(&format!(
170                "          schemaRegistry: {service_name_lower}_descriptor.schema_registry,\n"
171            ));
172            if has_streaming_args {
173                out.push_str("          channels,\n");
174            }
175            out.push_str("          metadata,\n");
176            out.push_str("        });\n");
177            out.push_str(&format!(
178                "        return {{ ok: true, value }} as {ret_ty};\n"
179            ));
180            out.push_str("      } catch (e) {\n");
181            out.push_str("        if (e instanceof RpcError && e.isUserError()) {\n");
182            out.push_str(&format!(
183                "          return {{ ok: false, error: e.userError }} as {ret_ty};\n"
184            ));
185            out.push_str("        }\n");
186            out.push_str("        throw e;\n");
187            out.push_str("      }\n");
188            out.push_str("    });\n");
189        } else {
190            out.push_str("      const value = await this.caller.call({\n");
191            out.push_str(&format!(
192                "        method: \"{}.{}\",\n",
193                service_name, method_name
194            ));
195            out.push_str(&format!("        args: {},\n", args_record));
196            out.push_str("        descriptor,\n");
197            out.push_str(&format!(
198                "        schemaRegistry: {service_name_lower}_descriptor.schema_registry,\n"
199            ));
200            if has_streaming_args {
201                out.push_str("        channels,\n");
202            }
203            out.push_str("        metadata,\n");
204            out.push_str("      });\n");
205            out.push_str(&format!("      return value as {ret_ty};\n"));
206            out.push_str("    });\n");
207        }
208
209        out.push_str("  }\n\n");
210    }
211
212    out.push_str("}\n\n");
213    out
214}
215
216/// Generate a connect() helper function for WebSocket connections.
217pub fn generate_connect_function(service: &ServiceDescriptor) -> String {
218    let service_name = service.service_name.to_upper_camel_case();
219
220    let mut out = String::new();
221    out.push_str(&format!(
222        "/**\n * Connect to a {service_name} server over WebSocket.\n"
223    ));
224    out.push_str(" * @param url - WebSocket URL (e.g., \"ws://localhost:9000\")\n");
225    out.push_str(&format!(
226        " * @returns A connected {service_name}Client instance\n"
227    ));
228    out.push_str(" */\n");
229    out.push_str(&format!(
230        "export async function connect{service_name}(url: string): Promise<{service_name}Client> {{\n"
231    ));
232    out.push_str("  const transport = await connectWs(url);\n");
233    out.push_str("  const connection = await helloExchangeInitiator(transport, defaultHello());\n");
234    out.push_str(&format!(
235        "  return new {service_name}Client(connection.asCaller());\n"
236    ));
237    out.push_str("}\n\n");
238    out
239}
240
241/// Generate complete client code (interface + implementation + connect helper).
242pub fn generate_client(service: &ServiceDescriptor) -> String {
243    let mut out = String::new();
244    out.push_str(&generate_caller_interface(service));
245    out.push_str(&generate_client_impl(service));
246    out.push_str(&generate_connect_function(service));
247    out
248}