roam_codegen/targets/typescript/
client.rs

//! TypeScript client generation.
//!
//! Generates the caller interface, the client class implementation, and a
//! WebSocket `connect` helper for making RPC calls.
4
5use heck::{ToLowerCamelCase, ToUpperCamelCase};
6use roam_schema::{ServiceDetail, ShapeKind, classify_shape, is_rx, is_tx};
7
8use super::types::{ts_type_client_arg, ts_type_client_return};
9
10/// Generate caller interface (for making calls to the service).
11///
12/// r[impl channeling.caller-pov] - Caller uses Tx for args, Rx for returns.
13pub fn generate_caller_interface(service: &ServiceDetail) -> String {
14    let mut out = String::new();
15    let service_name = service.name.to_upper_camel_case();
16
17    out.push_str(&format!("// Caller interface for {service_name}\n"));
18    out.push_str(&format!("export interface {service_name}Caller {{\n"));
19
20    for method in &service.methods {
21        let method_name = method.method_name.to_lower_camel_case();
22        // Caller args: Tx stays Tx, Rx stays Rx
23        let args = method
24            .args
25            .iter()
26            .map(|a| {
27                format!(
28                    "{}: {}",
29                    a.name.to_lower_camel_case(),
30                    ts_type_client_arg(a.ty)
31                )
32            })
33            .collect::<Vec<_>>()
34            .join(", ");
35        // Caller returns
36        let ret_ty = ts_type_client_return(method.return_type);
37
38        if let Some(doc) = &method.doc {
39            out.push_str(&format!("  /** {} */\n", doc));
40        }
41        out.push_str(&format!("  {method_name}({args}): Promise<{ret_ty}>;\n"));
42    }
43
44    out.push_str("}\n\n");
45    out
46}
47
48/// Generate client implementation (for making calls to the service).
49///
50/// Uses schema-driven encoding/decoding via `encodeWithSchema`/`decodeWithSchema`.
51pub fn generate_client_impl(service: &ServiceDetail) -> String {
52    use crate::render::hex_u64;
53
54    let mut out = String::new();
55    let service_name = service.name.to_upper_camel_case();
56    let service_name_lower = service.name.to_lower_camel_case();
57
58    out.push_str(&format!("// Client implementation for {service_name}\n"));
59    out.push_str(&format!(
60        "export class {service_name}Client<T extends MessageTransport = MessageTransport> implements {service_name}Caller {{\n"
61    ));
62    out.push_str("  private conn: Connection<T>;\n\n");
63    out.push_str("  constructor(conn: Connection<T>) {\n");
64    out.push_str("    this.conn = conn;\n");
65    out.push_str("  }\n\n");
66
67    for method in &service.methods {
68        let method_name = method.method_name.to_lower_camel_case();
69        let id = crate::method_id(method);
70
71        // Check if this method has streaming args (Tx or Rx)
72        let has_streaming_args = method.args.iter().any(|a| is_tx(a.ty) || is_rx(a.ty));
73
74        // Build args list
75        let args = method
76            .args
77            .iter()
78            .map(|a| {
79                format!(
80                    "{}: {}",
81                    a.name.to_lower_camel_case(),
82                    ts_type_client_arg(a.ty)
83                )
84            })
85            .collect::<Vec<_>>()
86            .join(", ");
87
88        // Return type
89        let ret_ty = ts_type_client_return(method.return_type);
90
91        if let Some(doc) = &method.doc {
92            out.push_str(&format!("  /** {} */\n", doc));
93        }
94        out.push_str(&format!(
95            "  async {method_name}({args}): Promise<{ret_ty}> {{\n"
96        ));
97
98        // Get schema reference
99        out.push_str(&format!(
100            "    const schema = {service_name_lower}_schemas.{method_name};\n"
101        ));
102
103        // If method has streaming args, bind channels first
104        if has_streaming_args {
105            let arg_names: Vec<_> = method
106                .args
107                .iter()
108                .map(|a| a.name.to_lower_camel_case())
109                .collect();
110            out.push_str("    // Bind any Tx/Rx channels in arguments and collect channel IDs\n");
111            out.push_str(&format!(
112                "    const channels = bindChannels(\n      schema.args,\n      [{}],\n      this.conn.getChannelAllocator(),\n      this.conn.getChannelRegistry(),\n      {service_name_lower}_serializers,\n    );\n",
113                arg_names.join(", ")
114            ));
115        }
116
117        // Encode payload using schema
118        if method.args.is_empty() {
119            out.push_str("    const payload = new Uint8Array(0);\n");
120        } else if method.args.len() == 1 {
121            let arg_name = method.args[0].name.to_lower_camel_case();
122            out.push_str(&format!(
123                "    const payload = encodeWithSchema({arg_name}, schema.args[0]);\n"
124            ));
125        } else {
126            // Multiple args - encode as tuple
127            let arg_names: Vec<_> = method
128                .args
129                .iter()
130                .map(|a| a.name.to_lower_camel_case())
131                .collect();
132            out.push_str(&format!(
133                "    const payload = encodeWithSchema([{}], {{ kind: 'tuple', elements: schema.args }});\n",
134                arg_names.join(", ")
135            ));
136        }
137
138        // Call the server
139        if has_streaming_args {
140            out.push_str(&format!(
141                "    const response = await this.conn.call({}n, payload, 30000, channels);\n",
142                hex_u64(id)
143            ));
144        } else {
145            out.push_str(&format!(
146                "    const response = await this.conn.call({}n, payload);\n",
147                hex_u64(id)
148            ));
149        }
150
151        // Check if this method returns Result<T, E>
152        let is_fallible = matches!(classify_shape(method.return_type), ShapeKind::Result { .. });
153
154        if is_fallible {
155            // Fallible method: handle both success and user error
156            out.push_str("    try {\n");
157            out.push_str("      const offset = decodeRpcResult(response, 0);\n");
158            out.push_str(
159                "      const value = decodeWithSchema(response, offset, schema.returns).value;\n",
160            );
161            out.push_str(&format!(
162                "      return {{ ok: true, value }} as {ret_ty};\n"
163            ));
164            out.push_str("    } catch (e) {\n");
165            out.push_str("      if (e instanceof RpcError && e.isUserError() && e.payload && schema.error) {\n");
166            out.push_str(
167                "        const error = decodeWithSchema(e.payload, 0, schema.error).value;\n",
168            );
169            out.push_str(&format!(
170                "        return {{ ok: false, error }} as {ret_ty};\n"
171            ));
172            out.push_str("      }\n");
173            out.push_str("      throw e;\n");
174            out.push_str("    }\n");
175        } else {
176            // Infallible method: just decode success
177            out.push_str("    const offset = decodeRpcResult(response, 0);\n");
178            out.push_str(
179                "    const result = decodeWithSchema(response, offset, schema.returns).value;\n",
180            );
181            out.push_str(&format!("    return result as {ret_ty};\n"));
182        }
183
184        out.push_str("  }\n\n");
185    }
186
187    out.push_str("}\n\n");
188    out
189}
190
191/// Generate a connect() helper function for WebSocket connections.
192pub fn generate_connect_function(service: &ServiceDetail) -> String {
193    use heck::ToUpperCamelCase;
194
195    let service_name = service.name.to_upper_camel_case();
196
197    let mut out = String::new();
198    out.push_str(&format!(
199        "/**\n * Connect to a {service_name} server over WebSocket.\n"
200    ));
201    out.push_str(" * @param url - WebSocket URL (e.g., \"ws://localhost:9000\")\n");
202    out.push_str(&format!(
203        " * @returns A connected {service_name}Client instance\n"
204    ));
205    out.push_str(" */\n");
206    out.push_str(&format!(
207        "export async function connect{service_name}(url: string): Promise<{service_name}Client<WsTransport>> {{\n"
208    ));
209    out.push_str("  const transport = await connectWs(url);\n");
210    out.push_str("  const connection = await helloExchangeInitiator(transport, defaultHello());\n");
211    out.push_str(&format!("  return new {service_name}Client(connection);\n"));
212    out.push_str("}\n\n");
213    out
214}
215
216/// Generate complete client code (interface + implementation + connect helper).
217pub fn generate_client(service: &ServiceDetail) -> String {
218    let mut out = String::new();
219    out.push_str(&generate_caller_interface(service));
220    out.push_str(&generate_client_impl(service));
221    out.push_str(&generate_connect_function(service));
222    out
223}