Skip to main content

metaxy_cli/codegen/
client.rs

1use super::common::{GENERATED_HEADER, is_void_input};
2use super::typescript::{emit_jsdoc, rust_type_to_ts};
3use crate::model::{Manifest, ProcedureKind};
4
/// TypeScript source for the exported `RpcError` class.
///
/// `rpcFetch` throws an `RpcError` for every non-OK HTTP response. `status` holds the
/// HTTP status code, `data` holds the error response body as decoded by `rpcFetch`, and
/// the constructor sets `name` to `"RpcError"`.
const ERROR_CLASS: &str = r#"export class RpcError extends Error {
  readonly status: number;
  readonly data: unknown;

  constructor(status: number, message: string, data?: unknown) {
    super(message);
    this.name = "RpcError";
    this.status = status;
    this.data = data;
  }
}"#;
17
/// TypeScript source for `RequestContext`, the argument of the `onRequest` hook.
///
/// `rpcFetch` builds a fresh context for every attempt, with a copy of the merged
/// client and per-call headers. The hook may mutate `headers`, and that object is what
/// gets sent. For GET requests, `url` already contains the serialized `?input=` parameter.
/// NOTE(review): `rpcFetch` fetches its own local `url`, so reassigning `ctx.url` in the
/// hook has no effect on the request.
const REQUEST_CONTEXT_INTERFACE: &str = r#"export interface RequestContext {
  procedure: string;
  method: "GET" | "POST";
  url: string;
  headers: Record<string, string>;
  input?: unknown;
}"#;
26
/// TypeScript source for `ResponseContext`, the argument of the `onResponse` hook.
///
/// The hook only runs for successful responses. `data` is the same value `rpcFetch`
/// returns to the caller. `duration` is the number of milliseconds since `rpcFetch`
/// started, so it includes earlier failed attempts and retry delays.
const RESPONSE_CONTEXT_INTERFACE: &str = r#"export interface ResponseContext {
  procedure: string;
  method: "GET" | "POST";
  url: string;
  response: Response;
  data: unknown;
  duration: number;
}"#;
36
/// TypeScript source for `ErrorContext`, the argument of the `onError` hook.
///
/// `attempt` is 1-based. `willRetry` reports whether `rpcFetch` will make another
/// attempt after the hook returns. `error` is an `RpcError` for HTTP error statuses;
/// otherwise it is whatever the fetch call (or response handling) threw.
const ERROR_CONTEXT_INTERFACE: &str = r#"export interface ErrorContext {
  procedure: string;
  method: "GET" | "POST";
  url: string;
  error: unknown;
  attempt: number;
  willRetry: boolean;
}"#;
46
/// TypeScript source for `RetryPolicy`, the client-level retry configuration.
///
/// - `attempts`: retries made in addition to the first request (`rpcFetch` uses
///   `1 + attempts` as its attempt limit).
/// - `delay`: milliseconds to wait before the next attempt. It is either a constant or a
///   function of the 1-based attempt number that just failed.
/// - `retryOn`: HTTP statuses that may be retried; `rpcFetch` falls back to
///   `DEFAULT_RETRY_ON` when it is omitted.
///
/// Only GET requests and procedures listed in `IDEMPOTENT_MUTATIONS` are ever retried.
const RETRY_POLICY_INTERFACE: &str = r#"export interface RetryPolicy {
  attempts: number;
  delay: number | ((attempt: number) => number);
  retryOn?: number[];
}"#;
53
/// TypeScript source for `RpcClientConfig`, the argument of `createRpcClient`.
///
/// Notes on how the generated runtime uses these fields:
/// - `headers`: when given as a function (sync or async), it is evaluated once per
///   `rpcFetch` call, before the retry loop.
/// - `timeout`: milliseconds. It is the lowest-priority timeout, after the per-call
///   `timeout` and the per-procedure `PROCEDURE_TIMEOUTS` entry.
/// - `serialize`: used for the GET `input` query parameter, the POST body, and the
///   dedup key. `deserialize` is only applied to successful response bodies.
/// - `dedupe`: `query` treats it as `true` when neither the call nor the config sets it.
const CONFIG_INTERFACE: &str = r#"export interface RpcClientConfig {
  baseUrl: string;
  fetch?: typeof globalThis.fetch;
  headers?:
    | Record<string, string>
    | (() => Record<string, string> | Promise<Record<string, string>>);
  onRequest?: (ctx: RequestContext) => void | Promise<void>;
  onResponse?: (ctx: ResponseContext) => void | Promise<void>;
  onError?: (ctx: ErrorContext) => void | Promise<void>;
  retry?: RetryPolicy;
  timeout?: number;
  serialize?: (input: unknown) => string;
  deserialize?: (text: string) => unknown;
  // AbortSignal for cancelling all requests made by this client.
  signal?: AbortSignal;
  dedupe?: boolean;
}"#;
72
/// TypeScript source for `CallOptions`: per-call overrides of client-level defaults for a single request.
///
/// - `headers` are merged over the client headers.
/// - `timeout` takes precedence over `PROCEDURE_TIMEOUTS` and `config.timeout`.
/// - `signal` is combined with the client-level `config.signal`.
/// - `dedupe` overrides `config.dedupe` for this query call.
const CALL_OPTIONS_INTERFACE: &str = r#"export interface CallOptions {
  headers?: Record<string, string>;
  timeout?: number;
  signal?: AbortSignal;
  dedupe?: boolean;
}"#;
80
/// TypeScript source for `dedupKey`, which builds the in-flight map key `"<procedure>:<serialized input>"`.
///
/// An `undefined` input serializes to the empty string. Otherwise the input goes through
/// `config.serialize`, or `JSON.stringify` when no serializer is configured, so
/// deduplication is only as precise as that serialization (for example,
/// `JSON.stringify` is sensitive to property order).
const DEDUP_KEY_FN: &str = r#"function dedupKey(procedure: string, input: unknown, config: RpcClientConfig): string {
  const serialized = input === undefined
    ? ""
    : config.serialize
      ? config.serialize(input)
      : JSON.stringify(input);
  return procedure + ":" + serialized;
}"#;
90
/// TypeScript source for `wrapWithSignal`: wraps a shared promise so that a per-caller AbortSignal can reject independently.
///
/// Behavior of the emitted function:
/// - Without a signal, the shared promise is returned unchanged.
/// - With an already-aborted signal, it rejects immediately with `signal.reason`.
/// - Otherwise the returned promise settles with the shared promise, or rejects with
///   `signal.reason` if the signal aborts first.
///
/// Aborting only rejects this wrapper; the underlying shared promise is left running.
/// The abort listener is removed once the shared promise settles.
const WRAP_WITH_SIGNAL_FN: &str = r#"function wrapWithSignal<T>(promise: Promise<T>, signal?: AbortSignal): Promise<T> {
  if (!signal) return promise;
  if (signal.aborted) return Promise.reject(signal.reason);
  return new Promise<T>((resolve, reject) => {
    const onAbort = () => reject(signal.reason);
    signal.addEventListener("abort", onAbort, { once: true });
    promise.then(
      (value) => { signal.removeEventListener("abort", onAbort); resolve(value); },
      (error) => { signal.removeEventListener("abort", onAbort); reject(error); },
    );
  });
}"#;
104
/// TypeScript source for the internal `rpcFetch` helper shared by `query` and `mutate`,
/// plus its `unwrapResult` helper.
///
/// The emitted `rpcFetch` does the following:
/// - Builds the URL. GET inputs are serialized into the `input` query parameter; POST
///   inputs become a JSON body.
/// - Merges client and per-call headers, then lets `onRequest` adjust them on every attempt.
/// - Combines the client signal, the per-call signal, and an optional timeout signal.
/// - Retries GETs and `IDEMPOTENT_MUTATIONS` entries on network errors, timeouts, and the
///   statuses in `retryOn`. It never retries after a caller-requested abort.
/// - Returns `data` from a `{ result: { data } }` envelope, or the whole body otherwise.
///
/// It references `PROCEDURE_TIMEOUTS`, `IDEMPOTENT_MUTATIONS`, `RpcError`, and the
/// config/context interfaces, all of which must be emitted earlier in the same file.
const FETCH_HELPER: &str = r#"const DEFAULT_RETRY_ON = [408, 429, 500, 502, 503, 504];

// Extracts `data` from a `{ result: { data } }` envelope; any other shape is returned as-is.
// Checking for the keys instead of falling back with `??` keeps a legitimate `null`
// payload from being replaced by the whole envelope.
function unwrapResult(json: unknown): unknown {
  if (typeof json === "object" && json !== null && "result" in json) {
    const result = (json as { result: unknown }).result;
    if (typeof result === "object" && result !== null && "data" in result) {
      return (result as { data: unknown }).data;
    }
  }
  return json;
}

async function rpcFetch(
  config: RpcClientConfig,
  method: "GET" | "POST",
  procedure: string,
  input?: unknown,
  callOptions?: CallOptions,
): Promise<unknown> {
  let url = `${config.baseUrl}/${procedure}`;
  const customHeaders = typeof config.headers === "function"
    ? await config.headers()
    : config.headers;
  const baseHeaders: Record<string, string> = { ...customHeaders, ...callOptions?.headers };

  if (method === "GET" && input !== undefined) {
    const serialized = config.serialize ? config.serialize(input) : JSON.stringify(input);
    url += `?input=${encodeURIComponent(serialized)}`;
  } else if (method === "POST" && input !== undefined) {
    baseHeaders["Content-Type"] = "application/json";
  }

  const fetchFn = config.fetch ?? globalThis.fetch;
  const maxAttempts = 1 + (config.retry?.attempts ?? 0);
  const retryOn = config.retry?.retryOn ?? DEFAULT_RETRY_ON;
  const effectiveTimeout = callOptions?.timeout ?? PROCEDURE_TIMEOUTS[procedure] ?? config.timeout;
  const start = Date.now();

  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    const reqCtx: RequestContext = { procedure, method, url, headers: { ...baseHeaders }, input };
    await config.onRequest?.(reqCtx);

    const init: RequestInit = { method, headers: reqCtx.headers };
    if (method === "POST" && input !== undefined) {
      init.body = config.serialize ? config.serialize(input) : JSON.stringify(input);
    }

    let timeoutId: ReturnType<typeof setTimeout> | undefined;
    const signals: AbortSignal[] = [];
    if (config.signal) signals.push(config.signal);
    if (callOptions?.signal) signals.push(callOptions.signal);
    if (effectiveTimeout) {
      const controller = new AbortController();
      timeoutId = setTimeout(() => controller.abort(), effectiveTimeout);
      signals.push(controller.signal);
    }
    if (signals.length > 0) {
      init.signal = signals.length === 1 ? signals[0] : AbortSignal.any(signals);
    }

    const isRetryable = attempt < maxAttempts && (method === "GET" || IDEMPOTENT_MUTATIONS.has(procedure));

    try {
      const res = await fetchFn(url, init);

      if (!res.ok) {
        // Read the body exactly once: a failed `res.json()` has already consumed the stream,
        // so a follow-up `res.text()` could never recover the raw error text.
        const text = await res.text().catch(() => null);
        let data: unknown = null;
        if (text) {
          try {
            data = JSON.parse(text);
          } catch {
            data = text;
          }
        }
        const rpcError = new RpcError(
          res.status,
          `RPC error on "${procedure}": ${res.status} ${res.statusText}`,
          data,
        );
        const canRetry = retryOn.includes(res.status) && isRetryable;
        await config.onError?.({ procedure, method, url, error: rpcError, attempt, willRetry: canRetry });
        if (!canRetry) throw rpcError;
      } else {
        const json = config.deserialize ? config.deserialize(await res.text()) : await res.json();
        const result = unwrapResult(json);
        const duration = Date.now() - start;
        await config.onResponse?.({ procedure, method, url, response: res, data: result, duration });
        return result;
      }
    } catch (err) {
      if (err instanceof RpcError) throw err;
      // An abort requested by the caller (client-level or per-call signal) is final;
      // only network failures and timeouts are retried.
      const cancelled = config.signal?.aborted === true || callOptions?.signal?.aborted === true;
      const willRetry = isRetryable && !cancelled;
      await config.onError?.({ procedure, method, url, error: err, attempt, willRetry });
      if (!willRetry) throw err;
    } finally {
      if (timeoutId !== undefined) clearTimeout(timeoutId);
    }

    if (config.retry) {
      const d = typeof config.retry.delay === "function"
        ? config.retry.delay(attempt) : config.retry.delay;
      await new Promise(r => setTimeout(r, d));
    }
  }
}"#;
198
199/// Generates the complete `rpc-client.ts` file content from a manifest.
200///
201/// The output includes:
202/// 1. Auto-generation header
203/// 2. Re-export of `Procedures` type from the types file
204/// 3. `RpcError` class for structured error handling
205/// 4. Internal `rpcFetch` helper
206/// 5. `createRpcClient` factory function with fully typed `query` / `mutate` methods
207pub fn generate_client_file(
208    manifest: &Manifest,
209    types_import_path: &str,
210    preserve_docs: bool,
211) -> String {
212    let mut out = String::with_capacity(2048);
213
214    // Header
215    out.push_str(GENERATED_HEADER);
216    out.push('\n');
217
218    // Collect all user-defined type names (structs + enums) for import
219    let type_names: Vec<&str> = manifest
220        .structs
221        .iter()
222        .map(|s| s.name.as_str())
223        .chain(manifest.enums.iter().map(|e| e.name.as_str()))
224        .collect();
225
226    // Import Procedures type (and any referenced types) from the types file
227    if type_names.is_empty() {
228        emit!(
229            out,
230            "import type {{ Procedures }} from \"{types_import_path}\";\n"
231        );
232        emit!(out, "export type {{ Procedures }};\n");
233    } else {
234        let types_csv = type_names.join(", ");
235        emit!(
236            out,
237            "import type {{ Procedures, {types_csv} }} from \"{types_import_path}\";\n"
238        );
239        emit!(out, "export type {{ Procedures, {types_csv} }};\n");
240    }
241
242    // Error class
243    emit!(out, "{ERROR_CLASS}\n");
244
245    // Lifecycle hook context interfaces
246    emit!(out, "{REQUEST_CONTEXT_INTERFACE}\n");
247    emit!(out, "{RESPONSE_CONTEXT_INTERFACE}\n");
248    emit!(out, "{ERROR_CONTEXT_INTERFACE}\n");
249
250    // Retry policy interface
251    emit!(out, "{RETRY_POLICY_INTERFACE}\n");
252
253    // Client config interface
254    emit!(out, "{CONFIG_INTERFACE}\n");
255
256    // Per-call options interface
257    emit!(out, "{CALL_OPTIONS_INTERFACE}\n");
258
259    // Per-procedure timeout defaults (ms)
260    generate_procedure_timeouts(manifest, &mut out);
261
262    // Idempotent mutations set (for retry gating)
263    generate_idempotent_mutations(manifest, &mut out);
264
265    // Internal fetch helper
266    emit!(out, "{FETCH_HELPER}\n");
267
268    // Dedup helpers (only when the manifest has queries)
269    let has_queries = manifest
270        .procedures
271        .iter()
272        .any(|p| p.kind == ProcedureKind::Query);
273    if has_queries {
274        emit!(out, "{DEDUP_KEY_FN}\n");
275        emit!(out, "{WRAP_WITH_SIGNAL_FN}\n");
276    }
277
278    // Type helpers for ergonomic API
279    generate_type_helpers(&mut out);
280    out.push('\n');
281
282    // Client factory
283    generate_client_factory(manifest, preserve_docs, &mut out);
284
285    out
286}
287
288/// Emits the `PROCEDURE_TIMEOUTS` record mapping procedure names to their default timeout in ms.
289fn generate_procedure_timeouts(manifest: &Manifest, out: &mut String) {
290    let entries: Vec<_> = manifest
291        .procedures
292        .iter()
293        .filter_map(|p| p.timeout_ms.map(|ms| format!("  \"{}\": {}", p.name, ms)))
294        .collect();
295
296    if entries.is_empty() {
297        emit!(
298            out,
299            "const PROCEDURE_TIMEOUTS: Record<string, number> = {{}};\n"
300        );
301    } else {
302        emit!(out, "const PROCEDURE_TIMEOUTS: Record<string, number> = {{");
303        for entry in &entries {
304            emit!(out, "{entry},");
305        }
306        emit!(out, "}};\n");
307    }
308}
309
310/// Emits the `IDEMPOTENT_MUTATIONS` set listing mutations marked as safe to retry.
311fn generate_idempotent_mutations(manifest: &Manifest, out: &mut String) {
312    let names: Vec<_> = manifest
313        .procedures
314        .iter()
315        .filter(|p| p.idempotent)
316        .map(|p| format!("\"{}\"", p.name))
317        .collect();
318
319    if names.is_empty() {
320        emit!(
321            out,
322            "const IDEMPOTENT_MUTATIONS: Set<string> = new Set();\n"
323        );
324    } else {
325        emit!(
326            out,
327            "const IDEMPOTENT_MUTATIONS: Set<string> = new Set([{}]);\n",
328            names.join(", ")
329        );
330    }
331}
332
333/// Emits utility types that power the typed client API.
334fn generate_type_helpers(out: &mut String) {
335    emit!(out, "type QueryKey = keyof Procedures[\"queries\"];");
336    emit!(out, "type MutationKey = keyof Procedures[\"mutations\"];");
337    emit!(
338        out,
339        "type QueryInput<K extends QueryKey> = Procedures[\"queries\"][K][\"input\"];"
340    );
341    emit!(
342        out,
343        "type QueryOutput<K extends QueryKey> = Procedures[\"queries\"][K][\"output\"];"
344    );
345    emit!(
346        out,
347        "type MutationInput<K extends MutationKey> = Procedures[\"mutations\"][K][\"input\"];"
348    );
349    emit!(
350        out,
351        "type MutationOutput<K extends MutationKey> = Procedures[\"mutations\"][K][\"output\"];"
352    );
353}
354
355/// Generates the `createRpcClient` factory using an interface for typed overloads.
356fn generate_client_factory(manifest: &Manifest, preserve_docs: bool, out: &mut String) {
357    let queries: Vec<_> = manifest
358        .procedures
359        .iter()
360        .filter(|p| p.kind == ProcedureKind::Query)
361        .collect();
362    let mutations: Vec<_> = manifest
363        .procedures
364        .iter()
365        .filter(|p| p.kind == ProcedureKind::Mutation)
366        .collect();
367    let has_queries = !queries.is_empty();
368    let has_mutations = !mutations.is_empty();
369
370    // Partition queries and mutations by void/non-void input
371    let void_queries: Vec<_> = queries.iter().filter(|p| is_void_input(p)).collect();
372    let non_void_queries: Vec<_> = queries.iter().filter(|p| !is_void_input(p)).collect();
373    let void_mutations: Vec<_> = mutations.iter().filter(|p| is_void_input(p)).collect();
374    let non_void_mutations: Vec<_> = mutations.iter().filter(|p| !is_void_input(p)).collect();
375
376    let query_mixed = !void_queries.is_empty() && !non_void_queries.is_empty();
377    let mutation_mixed = !void_mutations.is_empty() && !non_void_mutations.is_empty();
378
379    // Emit VOID_QUERIES/VOID_MUTATIONS sets when mixed void/non-void exists
380    if query_mixed {
381        let names: Vec<_> = void_queries
382            .iter()
383            .map(|p| format!("\"{}\"", p.name))
384            .collect();
385        emit!(
386            out,
387            "const VOID_QUERIES: Set<string> = new Set([{}]);",
388            names.join(", ")
389        );
390        out.push('\n');
391    }
392    if mutation_mixed {
393        let names: Vec<_> = void_mutations
394            .iter()
395            .map(|p| format!("\"{}\"", p.name))
396            .collect();
397        emit!(
398            out,
399            "const VOID_MUTATIONS: Set<string> = new Set([{}]);",
400            names.join(", ")
401        );
402        out.push('\n');
403    }
404
405    // Emit the RpcClient interface with overloaded method signatures
406    emit!(out, "export interface RpcClient {{");
407
408    if has_queries {
409        generate_query_overloads(manifest, preserve_docs, out);
410    }
411
412    if has_mutations {
413        if has_queries {
414            out.push('\n');
415        }
416        generate_mutation_overloads(manifest, preserve_docs, out);
417    }
418
419    emit!(out, "}}");
420    out.push('\n');
421
422    // Emit the factory function
423    emit!(
424        out,
425        "export function createRpcClient(config: RpcClientConfig): RpcClient {{"
426    );
427
428    if has_queries {
429        emit!(
430            out,
431            "  const inflight = new Map<string, Promise<unknown>>();\n"
432        );
433    }
434
435    emit!(out, "  return {{");
436
437    if has_queries {
438        emit!(
439            out,
440            "    query(key: QueryKey, ...args: unknown[]): Promise<unknown> {{"
441        );
442
443        // Extract input and callOptions into locals based on void/non-void branching
444        if query_mixed {
445            emit!(out, "      let input: unknown;");
446            emit!(out, "      let callOptions: CallOptions | undefined;");
447            emit!(out, "      if (VOID_QUERIES.has(key)) {{");
448            emit!(out, "        input = undefined;");
449            emit!(
450                out,
451                "        callOptions = args[0] as CallOptions | undefined;"
452            );
453            emit!(out, "      }} else {{");
454            emit!(out, "        input = args[0];");
455            emit!(
456                out,
457                "        callOptions = args[1] as CallOptions | undefined;"
458            );
459            emit!(out, "      }}");
460        } else if !void_queries.is_empty() {
461            emit!(out, "      const input = undefined;");
462            emit!(
463                out,
464                "      const callOptions = args[0] as CallOptions | undefined;"
465            );
466        } else {
467            emit!(out, "      const input = args[0];");
468            emit!(
469                out,
470                "      const callOptions = args[1] as CallOptions | undefined;"
471            );
472        }
473
474        // Dedup logic
475        emit!(
476            out,
477            "      const shouldDedupe = callOptions?.dedupe ?? config.dedupe ?? true;"
478        );
479        emit!(out, "      if (shouldDedupe) {{");
480        emit!(out, "        const k = dedupKey(key, input, config);");
481        emit!(out, "        const existing = inflight.get(k);");
482        emit!(
483            out,
484            "        if (existing) return wrapWithSignal(existing, callOptions?.signal);"
485        );
486        emit!(
487            out,
488            "        const promise = rpcFetch(config, \"GET\", key, input, callOptions)"
489        );
490        emit!(out, "          .finally(() => inflight.delete(k));");
491        emit!(out, "        inflight.set(k, promise);");
492        emit!(
493            out,
494            "        return wrapWithSignal(promise, callOptions?.signal);"
495        );
496        emit!(out, "      }}");
497        emit!(
498            out,
499            "      return rpcFetch(config, \"GET\", key, input, callOptions);"
500        );
501        emit!(out, "    }},");
502    }
503
504    if has_mutations {
505        emit!(
506            out,
507            "    mutate(key: MutationKey, ...args: unknown[]): Promise<unknown> {{"
508        );
509        if mutation_mixed {
510            // Mixed: use VOID_MUTATIONS set to branch at runtime
511            emit!(out, "      if (VOID_MUTATIONS.has(key)) {{");
512            emit!(
513                out,
514                "        return rpcFetch(config, \"POST\", key, undefined, args[0] as CallOptions | undefined);"
515            );
516            emit!(out, "      }}");
517            emit!(
518                out,
519                "      return rpcFetch(config, \"POST\", key, args[0], args[1] as CallOptions | undefined);"
520            );
521        } else if !void_mutations.is_empty() {
522            // All void: args[0] is always CallOptions
523            emit!(
524                out,
525                "      return rpcFetch(config, \"POST\", key, undefined, args[0] as CallOptions | undefined);"
526            );
527        } else {
528            // All non-void: args[0] is input, args[1] is CallOptions
529            emit!(
530                out,
531                "      return rpcFetch(config, \"POST\", key, args[0], args[1] as CallOptions | undefined);"
532            );
533        }
534        emit!(out, "    }},");
535    }
536
537    emit!(out, "  }} as RpcClient;");
538    emit!(out, "}}");
539}
540
541/// Generates query overload signatures for the RpcClient interface.
542fn generate_query_overloads(manifest: &Manifest, preserve_docs: bool, out: &mut String) {
543    let (void_queries, non_void_queries): (Vec<_>, Vec<_>) = manifest
544        .procedures
545        .iter()
546        .filter(|p| p.kind == ProcedureKind::Query)
547        .partition(|p| is_void_input(p));
548
549    // Overload signatures for void-input queries (no input argument required)
550    for proc in &void_queries {
551        if preserve_docs && let Some(doc) = &proc.docs {
552            emit_jsdoc(doc, "  ", out);
553        }
554        let output_ts = proc
555            .output
556            .as_ref()
557            .map(rust_type_to_ts)
558            .unwrap_or_else(|| "void".to_string());
559        emit!(
560            out,
561            "  query(key: \"{}\"): Promise<{}>;",
562            proc.name,
563            output_ts,
564        );
565        emit!(
566            out,
567            "  query(key: \"{}\", options: CallOptions): Promise<{}>;",
568            proc.name,
569            output_ts,
570        );
571    }
572
573    // Overload signatures for non-void-input queries
574    for proc in &non_void_queries {
575        if preserve_docs && let Some(doc) = &proc.docs {
576            emit_jsdoc(doc, "  ", out);
577        }
578        let input_ts = proc
579            .input
580            .as_ref()
581            .map(rust_type_to_ts)
582            .unwrap_or_else(|| "void".to_string());
583        let output_ts = proc
584            .output
585            .as_ref()
586            .map(rust_type_to_ts)
587            .unwrap_or_else(|| "void".to_string());
588        emit!(
589            out,
590            "  query(key: \"{}\", input: {}): Promise<{}>;",
591            proc.name,
592            input_ts,
593            output_ts,
594        );
595        emit!(
596            out,
597            "  query(key: \"{}\", input: {}, options: CallOptions): Promise<{}>;",
598            proc.name,
599            input_ts,
600            output_ts,
601        );
602    }
603}
604
605/// Generates mutation overload signatures for the RpcClient interface.
606fn generate_mutation_overloads(manifest: &Manifest, preserve_docs: bool, out: &mut String) {
607    let (void_mutations, non_void_mutations): (Vec<_>, Vec<_>) = manifest
608        .procedures
609        .iter()
610        .filter(|p| p.kind == ProcedureKind::Mutation)
611        .partition(|p| is_void_input(p));
612
613    // Overload signatures for void-input mutations
614    for proc in &void_mutations {
615        if preserve_docs && let Some(doc) = &proc.docs {
616            emit_jsdoc(doc, "  ", out);
617        }
618        let output_ts = proc
619            .output
620            .as_ref()
621            .map(rust_type_to_ts)
622            .unwrap_or_else(|| "void".to_string());
623        emit!(
624            out,
625            "  mutate(key: \"{}\"): Promise<{}>;",
626            proc.name,
627            output_ts,
628        );
629        emit!(
630            out,
631            "  mutate(key: \"{}\", options: CallOptions): Promise<{}>;",
632            proc.name,
633            output_ts,
634        );
635    }
636
637    // Overload signatures for non-void-input mutations
638    for proc in &non_void_mutations {
639        if preserve_docs && let Some(doc) = &proc.docs {
640            emit_jsdoc(doc, "  ", out);
641        }
642        let input_ts = proc
643            .input
644            .as_ref()
645            .map(rust_type_to_ts)
646            .unwrap_or_else(|| "void".to_string());
647        let output_ts = proc
648            .output
649            .as_ref()
650            .map(rust_type_to_ts)
651            .unwrap_or_else(|| "void".to_string());
652        emit!(
653            out,
654            "  mutate(key: \"{}\", input: {}): Promise<{}>;",
655            proc.name,
656            input_ts,
657            output_ts,
658        );
659        emit!(
660            out,
661            "  mutate(key: \"{}\", input: {}, options: CallOptions): Promise<{}>;",
662            proc.name,
663            input_ts,
664            output_ts,
665        );
666    }
667}