Skip to main content

protoc_gen_rust_temporal/
parse.rs

1//! `DescriptorPool` -> `Vec<ServiceModel>` extraction.
2//!
3//! The descriptor pool is built in `main.rs` via `decode_file_descriptor_set`
4//! so that `temporal.v1.*` extensions on `MethodOptions` / `ServiceOptions`
5//! survive — prost-types would otherwise drop them silently.
6//!
7//! Parsing strategy: re-encode each extension `Value` back to bytes through
8//! `prost-reflect`'s `DynamicMessage` and decode into the strongly-typed
9//! prost message via `prost::Message::decode`. This avoids hand-walking
10//! `Value` trees.
11
12use std::collections::HashSet;
13use std::time::Duration;
14
15use anyhow::{Context, Result, anyhow};
16use prost::Message;
17use prost_reflect::{
18    DescriptorPool, DynamicMessage, ExtensionDescriptor, MethodDescriptor, ServiceDescriptor, Value,
19};
20
21use crate::model::{
22    ActivityModel, IdReusePolicy, IdTemplateSegment, ProtoType, QueryModel, QueryRef, ServiceModel,
23    SignalModel, SignalRef, UpdateModel, UpdateRef, WorkflowModel,
24};
25use crate::temporal::v1::{
26    ActivityOptions, IdReusePolicy as ProtoPolicy, QueryOptions, ServiceOptions, SignalOptions,
27    UpdateOptions, WorkflowOptions,
28};
29use heck::ToSnakeCase;
30
31const SERVICE_EXT: &str = "temporal.v1.service";
32const WORKFLOW_EXT: &str = "temporal.v1.workflow";
33const ACTIVITY_EXT: &str = "temporal.v1.activity";
34const SIGNAL_EXT: &str = "temporal.v1.signal";
35const QUERY_EXT: &str = "temporal.v1.query";
36const UPDATE_EXT: &str = "temporal.v1.update";
37
38struct ExtensionSet {
39    service: ExtensionDescriptor,
40    workflow: ExtensionDescriptor,
41    activity: ExtensionDescriptor,
42    signal: ExtensionDescriptor,
43    query: ExtensionDescriptor,
44    update: ExtensionDescriptor,
45}
46
47impl ExtensionSet {
48    fn load(pool: &DescriptorPool) -> Result<Self> {
49        Ok(Self {
50            service: get_ext(pool, SERVICE_EXT)?,
51            workflow: get_ext(pool, WORKFLOW_EXT)?,
52            activity: get_ext(pool, ACTIVITY_EXT)?,
53            signal: get_ext(pool, SIGNAL_EXT)?,
54            query: get_ext(pool, QUERY_EXT)?,
55            update: get_ext(pool, UPDATE_EXT)?,
56        })
57    }
58}
59
60fn get_ext(pool: &DescriptorPool, name: &str) -> Result<ExtensionDescriptor> {
61    pool.get_extension_by_name(name)
62        .ok_or_else(|| anyhow!("missing extension definition: {name}"))
63}
64
65pub fn parse(
66    pool: &DescriptorPool,
67    files_to_generate: &HashSet<String>,
68) -> Result<Vec<ServiceModel>> {
69    // Early-exit when none of the targets carry any services. This matters
70    // for buf v2 modules that include the vendored `temporal/v1/temporal.proto`
71    // alongside consumer protos: buf sends one CodeGeneratorRequest per
72    // target file, so the plugin gets invoked with the annotation schema
73    // itself as the target. That file declares the `temporal.v1.*`
74    // extensions but uses none of them, so loading `ExtensionSet` would
75    // fail when the request only contains the annotation schema and not
76    // a file that uses the extensions. Skipping the lookup keeps the
77    // plugin a no-op in that case, which is the right answer — there's
78    // nothing to render.
79    let has_any_services = pool
80        .files()
81        .filter(|f| files_to_generate.contains(f.name()))
82        .any(|f| f.services().next().is_some());
83    if !has_any_services {
84        return Ok(Vec::new());
85    }
86
87    let ext = ExtensionSet::load(pool)?;
88
89    let mut out = Vec::new();
90    for file in pool.files() {
91        if !files_to_generate.contains(file.name()) {
92            continue;
93        }
94        for service in file.services() {
95            if let Some(model) = parse_service(&file, &service, &ext)? {
96                out.push(model);
97            }
98        }
99    }
100    Ok(out)
101}
102
103fn parse_service(
104    file: &prost_reflect::FileDescriptor,
105    service: &ServiceDescriptor,
106    ext: &ExtensionSet,
107) -> Result<Option<ServiceModel>> {
108    let package = file.package_name().to_string();
109    let service_name = service.name().to_string();
110
111    let default_task_queue = service_default_task_queue(service, &ext.service)?;
112
113    let mut workflows = Vec::new();
114    let mut signals = Vec::new();
115    let mut queries = Vec::new();
116    let mut updates = Vec::new();
117    let mut activities = Vec::new();
118
119    for method in service.methods() {
120        match method_kind(&method, ext)? {
121            MethodKind::Workflow(opts) => {
122                workflows.push(workflow_from(&method, *opts, &package, &service_name)?);
123            }
124            MethodKind::Signal(opts) => {
125                signals.push(signal_from(&method, opts));
126            }
127            MethodKind::Query(opts) => {
128                queries.push(query_from(&method, opts));
129            }
130            MethodKind::Update(opts) => {
131                updates.push(update_from(&method, opts));
132            }
133            MethodKind::Activity(opts) => {
134                activities.push(activity_from(&method, *opts));
135            }
136            MethodKind::None => continue,
137        }
138    }
139
140    if workflows.is_empty()
141        && signals.is_empty()
142        && queries.is_empty()
143        && updates.is_empty()
144        && activities.is_empty()
145    {
146        return Ok(None);
147    }
148
149    Ok(Some(ServiceModel {
150        package,
151        service: service_name,
152        source_file: file.name().to_string(),
153        default_task_queue,
154        workflows,
155        signals,
156        queries,
157        updates,
158        activities,
159    }))
160}
161
162fn service_default_task_queue(
163    service: &ServiceDescriptor,
164    service_ext: &ExtensionDescriptor,
165) -> Result<Option<String>> {
166    let opts: DynamicMessage = service.options();
167    if !opts.has_extension(service_ext) {
168        return Ok(None);
169    }
170    let value = opts.get_extension(service_ext);
171    let bytes = encode_message_value(&value)?;
172    let parsed = ServiceOptions::decode(bytes.as_slice())?;
173    Ok((!parsed.task_queue.is_empty()).then_some(parsed.task_queue))
174}
175
176enum MethodKind {
177    // WorkflowOptions is ~700 bytes — boxed so MethodKind stays small.
178    Workflow(Box<WorkflowOptions>),
179    Activity(Box<ActivityOptions>),
180    Signal(SignalOptions),
181    Query(QueryOptions),
182    Update(UpdateOptions),
183    None,
184}
185
186fn method_kind(method: &MethodDescriptor, ext: &ExtensionSet) -> Result<MethodKind> {
187    let opts: DynamicMessage = method.options();
188
189    // A single rpc is expected to carry at most one Temporal annotation.
190    // First-match wins; validate.rs would reject a method that lands in two
191    // buckets, but in practice it cannot — only one extension field number
192    // can be set on a given MethodOptions.
193    if opts.has_extension(&ext.workflow) {
194        return decode_kind::<WorkflowOptions>(&opts.get_extension(&ext.workflow));
195    }
196    if opts.has_extension(&ext.activity) {
197        return decode_kind::<ActivityOptions>(&opts.get_extension(&ext.activity));
198    }
199    if opts.has_extension(&ext.signal) {
200        return decode_kind::<SignalOptions>(&opts.get_extension(&ext.signal));
201    }
202    if opts.has_extension(&ext.query) {
203        return decode_kind::<QueryOptions>(&opts.get_extension(&ext.query));
204    }
205    if opts.has_extension(&ext.update) {
206        return decode_kind::<UpdateOptions>(&opts.get_extension(&ext.update));
207    }
208    Ok(MethodKind::None)
209}
210
211trait IntoMethodKind {
212    fn into_kind(self) -> MethodKind;
213}
214
215impl IntoMethodKind for WorkflowOptions {
216    fn into_kind(self) -> MethodKind {
217        MethodKind::Workflow(Box::new(self))
218    }
219}
220impl IntoMethodKind for ActivityOptions {
221    fn into_kind(self) -> MethodKind {
222        MethodKind::Activity(Box::new(self))
223    }
224}
225impl IntoMethodKind for SignalOptions {
226    fn into_kind(self) -> MethodKind {
227        MethodKind::Signal(self)
228    }
229}
230impl IntoMethodKind for QueryOptions {
231    fn into_kind(self) -> MethodKind {
232        MethodKind::Query(self)
233    }
234}
235impl IntoMethodKind for UpdateOptions {
236    fn into_kind(self) -> MethodKind {
237        MethodKind::Update(self)
238    }
239}
240
241fn decode_kind<T: Message + Default + IntoMethodKind>(value: &Value) -> Result<MethodKind> {
242    let bytes = encode_message_value(value)?;
243    let parsed = T::decode(bytes.as_slice())?;
244    Ok(parsed.into_kind())
245}
246
247fn encode_message_value(value: &Value) -> Result<Vec<u8>> {
248    match value {
249        Value::Message(m) => Ok(m.encode_to_vec()),
250        other => Err(anyhow!("expected message extension, got {other:?}")),
251    }
252}
253
254fn workflow_from(
255    method: &MethodDescriptor,
256    opts: WorkflowOptions,
257    package: &str,
258    service_name: &str,
259) -> Result<WorkflowModel> {
260    let rpc_method = method.name().to_string();
261    let registered_name = if opts.name.is_empty() {
262        default_registered_name(package, service_name, &rpc_method)
263    } else {
264        opts.name
265    };
266
267    let id_expression = if opts.id.is_empty() {
268        None
269    } else {
270        Some(
271            parse_id_template(&opts.id, &method.input()).with_context(|| {
272                format!("parse (temporal.v1.workflow).id template on {service_name}.{rpc_method}")
273            })?,
274        )
275    };
276
277    Ok(WorkflowModel {
278        rpc_method,
279        registered_name,
280        input_type: ProtoType::new(method.input().full_name()),
281        output_type: ProtoType::new(method.output().full_name()),
282        task_queue: (!opts.task_queue.is_empty()).then_some(opts.task_queue),
283        id_expression,
284        id_reuse_policy: id_reuse_policy_from_proto(opts.id_reuse_policy),
285        execution_timeout: opts.execution_timeout.and_then(duration_from_proto),
286        run_timeout: opts.run_timeout.and_then(duration_from_proto),
287        task_timeout: opts.task_timeout.and_then(duration_from_proto),
288        aliases: opts.aliases,
289        attached_signals: opts
290            .signal
291            .into_iter()
292            .map(|s| SignalRef {
293                rpc_method: s.r#ref,
294                start: s.start,
295            })
296            .collect(),
297        attached_queries: opts
298            .query
299            .into_iter()
300            .map(|q| QueryRef {
301                rpc_method: q.r#ref,
302            })
303            .collect(),
304        attached_updates: opts
305            .update
306            .into_iter()
307            .map(|u| UpdateRef {
308                rpc_method: u.r#ref,
309                start: u.start,
310                validate: u.validate,
311            })
312            .collect(),
313    })
314}
315
316fn signal_from(method: &MethodDescriptor, opts: SignalOptions) -> SignalModel {
317    let rpc_method = method.name().to_string();
318    let registered_name = if opts.name.is_empty() {
319        rpc_method.clone()
320    } else {
321        opts.name
322    };
323    SignalModel {
324        rpc_method,
325        registered_name,
326        input_type: ProtoType::new(method.input().full_name()),
327        output_type: ProtoType::new(method.output().full_name()),
328    }
329}
330
331fn query_from(method: &MethodDescriptor, opts: QueryOptions) -> QueryModel {
332    let rpc_method = method.name().to_string();
333    let registered_name = if opts.name.is_empty() {
334        rpc_method.clone()
335    } else {
336        opts.name
337    };
338    QueryModel {
339        rpc_method,
340        registered_name,
341        input_type: ProtoType::new(method.input().full_name()),
342        output_type: ProtoType::new(method.output().full_name()),
343    }
344}
345
346fn update_from(method: &MethodDescriptor, opts: UpdateOptions) -> UpdateModel {
347    let rpc_method = method.name().to_string();
348    let registered_name = if opts.name.is_empty() {
349        rpc_method.clone()
350    } else {
351        opts.name
352    };
353    UpdateModel {
354        rpc_method,
355        registered_name,
356        input_type: ProtoType::new(method.input().full_name()),
357        output_type: ProtoType::new(method.output().full_name()),
358        validate: opts.validate,
359    }
360}
361
362fn activity_from(method: &MethodDescriptor, opts: ActivityOptions) -> ActivityModel {
363    let rpc_method = method.name().to_string();
364    let registered_name = if opts.name.is_empty() {
365        rpc_method.clone()
366    } else {
367        opts.name
368    };
369    ActivityModel {
370        rpc_method,
371        registered_name,
372        input_type: ProtoType::new(method.input().full_name()),
373        output_type: ProtoType::new(method.output().full_name()),
374    }
375}
376
/// Derives the default registered name for an rpc:
/// `<package>.<service>/<rpc>`, or `<service>/<rpc>` when the file has no
/// package.
fn default_registered_name(package: &str, service: &str, rpc: &str) -> String {
    let mut name = String::with_capacity(package.len() + service.len() + rpc.len() + 2);
    if !package.is_empty() {
        name.push_str(package);
        name.push('.');
    }
    name.push_str(service);
    name.push('/');
    name.push_str(rpc);
    name
}
384
385fn id_reuse_policy_from_proto(raw: i32) -> Option<IdReusePolicy> {
386    match ProtoPolicy::try_from(raw).ok()? {
387        ProtoPolicy::WorkflowIdReusePolicyUnspecified => None,
388        ProtoPolicy::WorkflowIdReusePolicyAllowDuplicate => Some(IdReusePolicy::AllowDuplicate),
389        ProtoPolicy::WorkflowIdReusePolicyAllowDuplicateFailedOnly => {
390            Some(IdReusePolicy::AllowDuplicateFailedOnly)
391        }
392        ProtoPolicy::WorkflowIdReusePolicyRejectDuplicate => Some(IdReusePolicy::RejectDuplicate),
393        ProtoPolicy::WorkflowIdReusePolicyTerminateIfRunning => {
394            Some(IdReusePolicy::TerminateIfRunning)
395        }
396    }
397}
398
399fn duration_from_proto(d: prost_types::Duration) -> Option<Duration> {
400    if d.seconds < 0 || d.nanos < 0 {
401        return None;
402    }
403    let secs = u64::try_from(d.seconds).ok()?;
404    let nanos = u32::try_from(d.nanos).ok()?;
405    Some(Duration::new(secs, nanos))
406}
407
408/// Parse a cludden-style id template into segments, resolving each
409/// `{{ .FieldName }}` reference against the workflow input descriptor.
410///
411/// Supports only the simple form `{{ .FieldName }}` (with optional
412/// whitespace inside the braces). More complex Go-template syntax
413/// (conditionals, functions, ranges) returns an error so users see the
414/// limitation up front rather than at runtime.
415fn parse_id_template(
416    template: &str,
417    input: &prost_reflect::MessageDescriptor,
418) -> Result<Vec<IdTemplateSegment>> {
419    let mut out = Vec::new();
420    let mut rest = template;
421    while let Some(open) = rest.find("{{") {
422        if open > 0 {
423            out.push(IdTemplateSegment::Literal(rest[..open].to_string()));
424        }
425        let after_open = &rest[open + 2..];
426        let close = after_open
427            .find("}}")
428            .ok_or_else(|| anyhow!("unterminated `{{{{` in id template {template:?}"))?;
429        let token = after_open[..close].trim();
430        let field_name = token
431            .strip_prefix('.')
432            .ok_or_else(|| {
433                anyhow!(
434                    "id template token {token:?} must start with `.` (only field references are supported; \
435                     conditionals / pipelines / functions are not implemented)"
436                )
437            })?
438            .trim();
439        if field_name.is_empty() {
440            anyhow::bail!("id template token has no field name after `.`");
441        }
442        if !field_name
443            .chars()
444            .all(|c| c.is_ascii_alphanumeric() || c == '_')
445        {
446            anyhow::bail!(
447                "id template token {field_name:?} contains unsupported characters \
448                 (only simple field references like `.Name` are supported)"
449            );
450        }
451        let rust_field = field_name.to_snake_case();
452        let known = input.fields().any(|f| f.name() == rust_field);
453        if !known {
454            anyhow::bail!(
455                "id template references `{field_name}` (looked up as `{rust_field}`) \
456                 but no such field exists on input message `{}`",
457                input.full_name()
458            );
459        }
460        out.push(IdTemplateSegment::Field(rust_field));
461        rest = &after_open[close + 2..];
462    }
463    if !rest.is_empty() {
464        out.push(IdTemplateSegment::Literal(rest.to_string()));
465    }
466    Ok(out)
467}