use std::collections::BTreeSet;
use std::fmt::Write;
use heck::ToLowerCamelCase;
use crate::config::Config;
use crate::model::*;
/// Renders the TypeScript client module for a single service.
///
/// Thin convenience wrapper over [`render_file`] for the one-service case.
pub fn render(model: &ServiceModel, config: &Config) -> String {
    let services = std::slice::from_ref(&model);
    render_file(services, config)
}
/// Renders one TypeScript client module covering every service in `services`.
///
/// The header is written once. It contains the generated-code banner, the
/// `@temporalio` imports, and type-only imports of every message type
/// returned by `collect_type_imports_multi`. After the header, each service
/// body follows in input order. The `// source:` banner and the `_pb` import
/// path are both derived from the first service's `source_file`.
///
/// # Panics
///
/// Panics if `services` is empty.
pub fn render_file(services: &[&ServiceModel], config: &Config) -> String {
    assert!(!services.is_empty(), "render_file requires ≥1 service");
    let first = services[0];
    let source_file = &first.source_file;
    let names = TypeNames::compute(services);

    let mut out = String::new();
    out.push_str("// Code generated by protoc-gen-ts-temporal. DO NOT EDIT.\n");
    writeln!(out, "// source: {source_file}").unwrap();
    out.push('\n');
    out.push_str(
        "import type { Client, WorkflowHandle, WorkflowStartOptions } from \"@temporalio/client\";\n",
    );
    out.push_str(
        "import { defineSignal, defineQuery, defineUpdate } from \"@temporalio/workflow\";\n",
    );

    // Message types live in the sibling protobuf-es output for the same .proto file.
    let type_imports = collect_type_imports_multi(services);
    if !type_imports.is_empty() {
        out.push_str("import {\n");
        for full_name in type_imports.iter() {
            writeln!(out, " type {},", local_type_name(full_name)).unwrap();
        }
        let basename = proto_basename(source_file);
        writeln!(out, "}} from \"./{basename}{}.ts\";", config.pb_suffix).unwrap();
    }
    out.push('\n');

    services
        .iter()
        .for_each(|svc| render_service_body(&mut out, svc, &names));
    out
}
/// Cross-service naming decisions for one generated file.
///
/// Several services in the same file may declare a workflow with the same RPC
/// method name. When they do, the identifiers derived from that name would
/// clash, so those workflows get their service name as a prefix.
struct TypeNames {
    /// Workflow RPC method names declared by more than one service.
    colliding: std::collections::HashSet<String>,
}

impl TypeNames {
    /// Records every workflow RPC method name that occurs more than once across `services`.
    fn compute(services: &[&ServiceModel]) -> Self {
        let mut seen = std::collections::HashSet::new();
        let mut colliding = std::collections::HashSet::new();
        for wf in services.iter().flat_map(|svc| &svc.workflows) {
            let name = wf.rpc_method.as_str();
            // `insert` returns false on the second (and later) sighting.
            if !seen.insert(name) {
                colliding.insert(name.to_owned());
            }
        }
        Self { colliding }
    }

    /// Returns `svc`'s service name when `wf`'s method name collides, else an empty string.
    fn workflow_prefix(&self, svc: &ServiceModel, wf: &WorkflowModel) -> String {
        if !self.colliding.contains(&wf.rpc_method) {
            return String::new();
        }
        svc.service.clone()
    }
}
/// Appends everything generated for one service, in this order:
/// - workflow-name constants;
/// - signal, query and update definitions;
/// - the per-workflow start-options types;
/// - the `<Service>Client` class;
/// - the per-workflow `…Run` handle classes;
/// - any signal-/update-with-start helper functions.
fn render_service_body(out: &mut String, model: &ServiceModel, names: &TypeNames) {
    let svc = model;
    render_workflow_name_constants(out, svc);
    render_signal_defs(out, svc);
    render_query_defs(out, svc);
    render_update_defs(out, svc);
    for wf in &svc.workflows {
        render_start_opts_type(out, names, svc, wf);
    }

    // Client class: a start method and a handle factory per workflow.
    writeln!(out, "export class {}Client {{", svc.service).unwrap();
    out.push_str(" constructor(private readonly client: Client) {}\n\n");
    for wf in &svc.workflows {
        render_start_method(out, names, svc, wf);
        render_handle_factory(out, names, svc, wf);
    }
    out.push_str("}\n\n");

    for wf in &svc.workflows {
        render_run_handle_class(out, names, svc, wf);
    }
    for wf in &svc.workflows {
        render_signal_with_start_helpers(out, names, svc, wf);
        render_update_with_start_helpers(out, names, svc, wf);
    }
}
/// Renders the schema-registration module for a single service.
///
/// Thin convenience wrapper over [`render_register_file`].
pub fn render_register(model: &ServiceModel, config: &Config) -> String {
    render_register_file(std::slice::from_ref(&model), config)
}
/// Renders a module that exports `schemas`, an array of the protobuf-es
/// `…Schema` descriptors of every message type that
/// `collect_type_imports_multi` gathers for `services`.
///
/// The `_pb` import path is derived from the first service's `source_file`.
///
/// # Panics
///
/// Panics if `services` is empty.
pub fn render_register_file(services: &[&ServiceModel], config: &Config) -> String {
    assert!(
        !services.is_empty(),
        "render_register_file requires ≥1 service"
    );
    let source_file = &services[0].source_file;
    // The same " XSchema," lines appear in both the import list and the array.
    let schema_lines: String = collect_type_imports_multi(services)
        .iter()
        .map(|full_name| format!(" {}Schema,\n", local_type_name(full_name)))
        .collect();

    let mut out = String::new();
    out.push_str("// Code generated by protoc-gen-ts-temporal. DO NOT EDIT.\n");
    writeln!(out, "// source: {source_file}").unwrap();
    out.push('\n');
    out.push_str("import type { DescMessage } from \"@bufbuild/protobuf\";\n");
    if !schema_lines.is_empty() {
        out.push_str("import {\n");
        out.push_str(&schema_lines);
        writeln!(
            out,
            "}} from \"./{}{}.ts\";",
            proto_basename(source_file),
            config.pb_suffix
        )
        .unwrap();
    }
    out.push('\n');
    out.push_str("export const schemas: readonly DescMessage[] = [\n");
    out.push_str(&schema_lines);
    out.push_str("];\n");
    out
}
/// Returns the file name of `source_file` with its directory removed and a
/// single trailing `.proto` extension stripped (`"a/b/foo.proto"` → `"foo"`).
///
/// Only one suffix is removed: `"x.proto.proto"` → `"x.proto"`.
/// `trim_end_matches` would wrongly reduce that to `"x"`. Directories are
/// split on `/`.
fn proto_basename(source_file: &str) -> String {
    let last = source_file.rsplit('/').next().unwrap_or(source_file);
    last.strip_suffix(".proto").unwrap_or(last).to_string()
}
/// Returns the last dot-separated segment of a fully-qualified protobuf type
/// name (`"pkg.v1.Foo"` → `"Foo"`). A name without dots is returned unchanged.
///
/// NOTE(review): a nested message such as `pkg.Outer.Inner` maps to `Inner`.
/// If the `_pb` output names nested messages differently, the generated
/// imports will not resolve. Confirm against the protobuf-es version in use.
fn local_type_name(full: &str) -> String {
    match full.rfind('.') {
        Some(dot) => full[dot + 1..].to_owned(),
        None => full.to_owned(),
    }
}
/// Collects, sorted and de-duplicated, the fully-qualified names of the
/// message types the generated code references for `services`.
/// Types flagged `is_empty` are skipped.
///
/// The set covers:
/// - workflow inputs and outputs;
/// - the input/output types of *every* signal, query and update declared on
///   each service, not only the ones attached to a workflow.
///
/// The second point matters because `render_signal_defs`,
/// `render_query_defs` and `render_update_defs` emit a `define*` constant,
/// and therefore a type reference, for each declaration. Collecting only
/// attached ones left an unattached signal's type un-imported.
fn collect_type_imports_multi(services: &[&ServiceModel]) -> Vec<String> {
    let mut seen: BTreeSet<String> = BTreeSet::new();
    let mut push = |t: &ProtoType| {
        if !t.is_empty {
            seen.insert(t.full_name.clone());
        }
    };
    for model in services {
        for wf in &model.workflows {
            push(&wf.input_type);
            push(&wf.output_type);
        }
        // Attached refs are resolved against these same lists, so iterating
        // the declarations directly is a superset of the attached types.
        for sig in &model.signals {
            push(&sig.input_type);
        }
        for q in &model.queries {
            push(&q.input_type);
            push(&q.output_type);
        }
        for u in &model.updates {
            push(&u.input_type);
            push(&u.output_type);
        }
    }
    seen.into_iter().collect()
}
/// Emits `export const <Service>Workflows = { … } as const;`. It maps each
/// workflow's RPC method to the name the client starts it under: the explicit
/// `registered_name` when set, otherwise `<package>.<Service>/<Method>`.
///
/// When at least one workflow declares aliases, it also emits
/// `<Service>WorkflowAliases`, listing alias names for those workflows only.
/// Emits nothing for a service without workflows.
fn render_workflow_name_constants(out: &mut String, model: &ServiceModel) {
    let svc = &model.service;
    if model.workflows.is_empty() {
        return;
    }

    writeln!(out, "export const {svc}Workflows = {{").unwrap();
    for wf in &model.workflows {
        let registered = match &wf.registered_name {
            Some(name) => name.clone(),
            None => format!("{}.{}/{}", model.package, svc, wf.rpc_method),
        };
        writeln!(out, " {}: \"{}\",", wf.rpc_method, registered).unwrap();
    }
    out.push_str("} as const;\n\n");

    let mut aliased = model
        .workflows
        .iter()
        .filter(|wf| !wf.aliases.is_empty())
        .peekable();
    if aliased.peek().is_none() {
        return;
    }
    writeln!(
        out,
        "/** Extra worker-side registration names per workflow. The client uses the canonical name in `{0}Workflows`; the worker should register under canonical + all aliases. */\nexport const {0}WorkflowAliases = {{",
        svc
    )
    .unwrap();
    for wf in aliased {
        let quoted: Vec<String> = wf.aliases.iter().map(|a| format!("\"{a}\"")).collect();
        writeln!(out, " {}: [{}] as const,", wf.rpc_method, quoted.join(", ")).unwrap();
    }
    out.push_str("} as const;\n\n");
}
/// Emits one `export const <method>Signal = defineSignal<[Args]>("<name>");`
/// per signal declared on the service, then a blank line. A signal with an
/// empty input type gets an empty argument tuple (`[]`). Emits nothing when
/// the service has no signals.
fn render_signal_defs(out: &mut String, model: &ServiceModel) {
    if !model.signals.is_empty() {
        for sig in &model.signals {
            let args = if sig.input_type.is_empty {
                String::new()
            } else {
                local_type_name(&sig.input_type.full_name)
            };
            writeln!(
                out,
                "export const {}Signal = defineSignal<[{args}]>(\"{}\");",
                sig.rpc_method.to_lower_camel_case(),
                sig.registered_name
            )
            .unwrap();
        }
        out.push('\n');
    }
}
/// Emits one `export const <method>Query = defineQuery<Ret, [Args]>("<name>");`
/// per query declared on the service, then a blank line. A query with an empty
/// input type gets an empty argument tuple (`[]`). Emits nothing when the
/// service has no queries.
fn render_query_defs(out: &mut String, model: &ServiceModel) {
    if model.queries.is_empty() {
        return;
    }
    for q in model.queries.iter() {
        let ident = q.rpc_method.to_lower_camel_case();
        let ret_ty = local_type_name(&q.output_type.full_name);
        let arg_ty = if q.input_type.is_empty {
            String::new()
        } else {
            local_type_name(&q.input_type.full_name)
        };
        writeln!(
            out,
            "export const {ident}Query = defineQuery<{ret_ty}, [{arg_ty}]>(\"{}\");",
            q.registered_name
        )
        .unwrap();
    }
    out.push('\n');
}
/// Emits one `export const <method>Update = defineUpdate<Ret, [Args]>("<name>");`
/// per update declared on the service, then a blank line. An update with an
/// empty input type gets an empty argument tuple (`[]`). Emits nothing when
/// the service has no updates.
fn render_update_defs(out: &mut String, model: &ServiceModel) {
    if model.updates.is_empty() {
        return;
    }
    for upd in &model.updates {
        let ret_ty = local_type_name(&upd.output_type.full_name);
        let arg_ty = (!upd.input_type.is_empty)
            .then(|| local_type_name(&upd.input_type.full_name))
            .unwrap_or_default();
        let ident = upd.rpc_method.to_lower_camel_case();
        writeln!(
            out,
            "export const {ident}Update = defineUpdate<{ret_ty}, [{arg_ty}]>(\"{}\");",
            upd.registered_name
        )
        .unwrap();
    }
    writeln!(out).unwrap();
}
/// Returns the TypeScript identifier stem for `wf`: its RPC method name,
/// prefixed by the service name when `TypeNames` marks it as colliding. Used
/// to name the `…Run` class, the `…StartOpts` type and the `…With…Start` helpers.
fn workflow_id(names: &TypeNames, svc: &ServiceModel, wf: &WorkflowModel) -> String {
    let mut id = names.workflow_prefix(svc, wf);
    id.push_str(&wf.rpc_method);
    id
}
/// Emits the non-exported `<Id>StartOpts` type alias: `WorkflowStartOptions`
/// without `args`, `taskQueue`, `workflowId` and `workflowType`, intersected
/// with an optional `workflowId`. A blank line follows.
fn render_start_opts_type(
    out: &mut String,
    names: &TypeNames,
    svc: &ServiceModel,
    wf: &WorkflowModel,
) {
    let alias = format!("{}StartOpts", workflow_id(names, svc, wf));
    out.push_str("type ");
    out.push_str(&alias);
    out.push_str(
        " = Omit<\n WorkflowStartOptions,\n \"args\" | \"taskQueue\" | \"workflowId\" | \"workflowType\"\n> & { workflowId?: string };\n",
    );
    out.push('\n');
}
fn render_start_method(
out: &mut String,
names: &TypeNames,
model: &ServiceModel,
wf: &WorkflowModel,
) {
let method = wf.rpc_method.to_lower_camel_case();
let in_ty = local_type_name(&wf.input_type.full_name);
let out_ty = local_type_name(&wf.output_type.full_name);
let wid = workflow_id(names, model, wf);
writeln!(out, " async {method}(").unwrap();
writeln!(out, " input: {in_ty},").unwrap();
writeln!(out, " opts: {wid}StartOpts = {{}}").unwrap();
writeln!(out, " ): Promise<{wid}Run> {{").unwrap();
writeln!(out, " const {{ workflowId, ...rest }} = opts;").unwrap();
writeln!(out, " const handle = await this.client.workflow.start<").unwrap();
writeln!(out, " (input: {in_ty}) => Promise<{out_ty}>").unwrap();
writeln!(
out,
" >({}Workflows.{}, {{",
model.service, wf.rpc_method
)
.unwrap();
writeln!(out, " args: [input],").unwrap();
writeln!(
out,
" taskQueue: \"{}\",",
model.resolved_task_queue(wf)
)
.unwrap();
writeln!(
out,
" workflowId: workflowId ?? `${{crypto.randomUUID()}}`,"
)
.unwrap();
writeln!(out, " ...rest,").unwrap();
writeln!(out, " }});").unwrap();
writeln!(out, " return new {wid}Run(handle);").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out).unwrap();
}
/// Emits the `<method>Handle(workflowId)` member of `<Service>Client`. It wraps
/// `client.workflow.getHandle` for an existing execution of `wf` in the typed
/// `<Id>Run` class. A blank line follows.
fn render_handle_factory(
    out: &mut String,
    names: &TypeNames,
    svc: &ServiceModel,
    wf: &WorkflowModel,
) {
    let run_class = format!("{}Run", workflow_id(names, svc, wf));
    let factory = wf.rpc_method.to_lower_camel_case() + "Handle";
    let signature = format!(
        "(input: {}) => Promise<{}>",
        local_type_name(&wf.input_type.full_name),
        local_type_name(&wf.output_type.full_name)
    );
    write!(
        out,
        " {factory}(workflowId: string): {run_class} {{\n return new {run_class}(\n this.client.workflow.getHandle<\n {signature}\n >(workflowId),\n );\n }}\n\n"
    )
    .unwrap();
}
/// Emits `export class <Id>Run`, a typed wrapper around the `WorkflowHandle`
/// of `wf`.
///
/// The class exposes `workflowId`, `result()`, and one async member per
/// signal, query and update attached to `wf`, in that order. Each member
/// forwards to `handle.signal`, `handle.query` or `handle.executeUpdate` with
/// the matching `define*` constant. When the declaration has an input type,
/// the member also passes an `input` argument.
///
/// # Panics
///
/// Panics if an attached ref does not name a declaration on `model`.
fn render_run_handle_class(
    out: &mut String,
    names: &TypeNames,
    model: &ServiceModel,
    wf: &WorkflowModel,
) {
    /// Writes ` async name(input?): Promise<ret> {`, ` return call;`, ` }`.
    fn member(out: &mut String, name: &str, input: Option<&str>, ret: &str, call: &str) {
        let params = input.map(|ty| format!("input: {ty}")).unwrap_or_default();
        writeln!(out, " async {name}({params}): Promise<{ret}> {{").unwrap();
        writeln!(out, " return {call};").unwrap();
        writeln!(out, " }}").unwrap();
    }

    let wid = workflow_id(names, model, wf);
    let wf_in = local_type_name(&wf.input_type.full_name);
    let wf_out = local_type_name(&wf.output_type.full_name);

    writeln!(out, "export class {wid}Run {{").unwrap();
    out.push_str(" constructor(\n");
    out.push_str(" private readonly handle: WorkflowHandle<\n");
    writeln!(out, " (input: {wf_in}) => Promise<{wf_out}>").unwrap();
    out.push_str(" >,\n");
    out.push_str(" ) {}\n");
    out.push_str(" get workflowId(): string { return this.handle.workflowId; }\n");
    writeln!(
        out,
        " async result(): Promise<{wf_out}> {{ return this.handle.result(); }}"
    )
    .unwrap();

    for r in &wf.attached_signals {
        let sig = model
            .signals
            .iter()
            .find(|s| s.rpc_method == r.method)
            .expect("validate.rs guarantees the ref resolves");
        let name = sig.rpc_method.to_lower_camel_case();
        let input = (!sig.input_type.is_empty).then(|| local_type_name(&sig.input_type.full_name));
        let extra = if input.is_some() { ", input" } else { "" };
        let call = format!("this.handle.signal({name}Signal{extra})");
        member(out, &name, input.as_deref(), "void", &call);
    }

    for r in &wf.attached_queries {
        let q = model
            .queries
            .iter()
            .find(|q| q.rpc_method == r.method)
            .expect("validate.rs guarantees the ref resolves");
        let name = q.rpc_method.to_lower_camel_case();
        let ret = local_type_name(&q.output_type.full_name);
        let input = (!q.input_type.is_empty).then(|| local_type_name(&q.input_type.full_name));
        let extra = if input.is_some() { ", input" } else { "" };
        let call = format!("this.handle.query({name}Query{extra})");
        member(out, &name, input.as_deref(), &ret, &call);
    }

    for r in &wf.attached_updates {
        let u = model
            .updates
            .iter()
            .find(|u| u.rpc_method == r.method)
            .expect("validate.rs guarantees the ref resolves");
        let name = u.rpc_method.to_lower_camel_case();
        let ret = local_type_name(&u.output_type.full_name);
        let input = (!u.input_type.is_empty).then(|| local_type_name(&u.input_type.full_name));
        // Updates take their arguments via an options object, not positionally.
        let extra = if input.is_some() { ", { args: [input] }" } else { "" };
        let call = format!("this.handle.executeUpdate({name}Update{extra})");
        member(out, &name, input.as_deref(), &ret, &call);
    }

    out.push_str("}\n\n");
}
/// Emits one `export async function <method>With<Id>Start(...)` per signal
/// attached to `wf` with `signal_with_start` set.
///
/// Each helper calls `client.workflow.signalWithStart` on the resolved task
/// queue. It uses the caller's `workflowId` or a `crypto.randomUUID()`
/// fallback, and wraps the returned handle in `<Id>Run`.
///
/// For a signal with a non-empty input, the helper takes an extra `signal`
/// parameter, forwarded as `signalArgs: [signal]`. The array must stay
/// mutable. Its expected type is the `SignalArgs` type argument, here the
/// mutable tuple `[signal: T]`. A `[signal] as const` literal is
/// `readonly [T]`, which TypeScript refuses to assign to a mutable tuple, so
/// that form made the generated file fail to type-check.
///
/// # Panics
///
/// Panics if an attached ref does not name a signal declared on `model`.
fn render_signal_with_start_helpers(
    out: &mut String,
    names: &TypeNames,
    model: &ServiceModel,
    wf: &WorkflowModel,
) {
    let wid = workflow_id(names, model, wf);
    let wf_in_ty = local_type_name(&wf.input_type.full_name);
    let wf_out_ty = local_type_name(&wf.output_type.full_name);
    for r in &wf.attached_signals {
        if !r.signal_with_start {
            continue;
        }
        let sig = model
            .signals
            .iter()
            .find(|s| s.rpc_method == r.method)
            .expect("validate.rs guarantees the ref resolves");
        let method = sig.rpc_method.to_lower_camel_case();
        let fn_name = format!("{method}With{wid}Start");
        let const_name = format!("{method}Signal");
        // (signal tuple type argument, extra parameter line, signalArgs line)
        let (sig_args, sig_param, sig_pass) = if sig.input_type.is_empty {
            (String::from("[]"), String::new(), " signalArgs: [] as [],")
        } else {
            let in_ty = local_type_name(&sig.input_type.full_name);
            (
                format!("[signal: {in_ty}]"),
                format!(" signal: {in_ty},\n"),
                // No `as const`: a readonly tuple is not assignable to `[signal: T]`.
                " signalArgs: [signal],",
            )
        };
        writeln!(out, "export async function {fn_name}(").unwrap();
        writeln!(out, " client: Client,").unwrap();
        writeln!(out, " input: {wf_in_ty},").unwrap();
        out.push_str(&sig_param);
        writeln!(out, " opts: {wid}StartOpts = {{}},").unwrap();
        writeln!(out, "): Promise<{wid}Run> {{").unwrap();
        writeln!(out, " const {{ workflowId, ...rest }} = opts;").unwrap();
        writeln!(out, " const handle = await client.workflow.signalWithStart<").unwrap();
        writeln!(out, " (input: {wf_in_ty}) => Promise<{wf_out_ty}>, {sig_args}").unwrap();
        writeln!(out, " >({}Workflows.{}, {{", model.service, wf.rpc_method).unwrap();
        writeln!(out, " args: [input],").unwrap();
        writeln!(out, " taskQueue: \"{}\",", model.resolved_task_queue(wf)).unwrap();
        writeln!(out, " workflowId: workflowId ?? `${{crypto.randomUUID()}}`,").unwrap();
        writeln!(out, " signal: {const_name},").unwrap();
        writeln!(out, "{sig_pass}").unwrap();
        writeln!(out, " ...rest,").unwrap();
        writeln!(out, " }});").unwrap();
        writeln!(out, " return new {wid}Run(handle);").unwrap();
        writeln!(out, "}}").unwrap();
        writeln!(out).unwrap();
    }
}
fn render_update_with_start_helpers(
out: &mut String,
names: &TypeNames,
model: &ServiceModel,
wf: &WorkflowModel,
) {
let wf_in_ty = local_type_name(&wf.input_type.full_name);
let wid = workflow_id(names, model, wf);
for r in &wf.attached_updates {
if !r.update_with_start {
continue;
}
let u = model
.updates
.iter()
.find(|u| u.rpc_method == r.method)
.expect("validate.rs guarantees the ref resolves");
let fn_name = format!("{}With{}Start", u.rpc_method.to_lower_camel_case(), wid);
let const_name = format!("{}Update", u.rpc_method.to_lower_camel_case());
let upd_out_ty = local_type_name(&u.output_type.full_name);
let upd_param = if u.input_type.is_empty {
String::new()
} else {
let in_ty = local_type_name(&u.input_type.full_name);
format!(" update: {in_ty},\n")
};
let upd_args = if u.input_type.is_empty {
"[]".to_string()
} else {
"[update]".to_string()
};
writeln!(out, "export async function {fn_name}(").unwrap();
writeln!(out, " client: Client,").unwrap();
writeln!(out, " input: {wf_in_ty},").unwrap();
write!(out, "{upd_param}").unwrap();
writeln!(out, " opts: {wid}StartOpts = {{}},").unwrap();
writeln!(
out,
"): Promise<{{ run: {wid}Run; updateResult: Promise<{upd_out_ty}> }}> {{"
)
.unwrap();
writeln!(out, " const {{ workflowId, ...rest }} = opts;").unwrap();
writeln!(
out,
" const wfId = workflowId ?? `${{crypto.randomUUID()}}`;"
)
.unwrap();
writeln!(
out,
" const updateHandle = await client.workflow.startUpdateWithStart("
)
.unwrap();
writeln!(out, " {const_name},").unwrap();
writeln!(out, " {{").unwrap();
writeln!(out, " args: {upd_args},").unwrap();
writeln!(out, " startWorkflowOperation: {{").unwrap();
writeln!(
out,
" workflowType: {}Workflows.{},",
model.service, wf.rpc_method
)
.unwrap();
writeln!(out, " args: [input],").unwrap();
writeln!(
out,
" taskQueue: \"{}\",",
model.resolved_task_queue(wf)
)
.unwrap();
writeln!(out, " workflowId: wfId,").unwrap();
writeln!(out, " ...rest,").unwrap();
writeln!(out, " }},").unwrap();
writeln!(out, " }},").unwrap();
writeln!(out, " );").unwrap();
let wf_out_ty = local_type_name(&wf.output_type.full_name);
writeln!(out, " const runHandle = client.workflow.getHandle<").unwrap();
writeln!(out, " (input: {wf_in_ty}) => Promise<{wf_out_ty}>").unwrap();
writeln!(out, " >(wfId);").unwrap();
writeln!(out, " return {{").unwrap();
writeln!(out, " run: new {wid}Run(runHandle),").unwrap();
writeln!(out, " updateResult: updateHandle.result(),").unwrap();
writeln!(out, " }};").unwrap();
writeln!(out, "}}").unwrap();
writeln!(out).unwrap();
}
}