Skip to main content

protoc_gen_rust_temporal/
validate.rs

1//! Cross-method invariants applied after `parse.rs` builds a `ServiceModel`.
2//!
3//! Errors here translate directly into `CodeGeneratorResponse.error` and
4//! surface to the user as `protoc` failures, so messages should pinpoint
5//! the service + rpc + offending option.
6
7use std::collections::{HashMap, HashSet};
8
9use anyhow::{Result, bail};
10
11use crate::model::ServiceModel;
12
13pub fn validate(model: &ServiceModel) -> Result<()> {
14    reject_rpc_collisions(model)?;
15    validate_workflows(model)?;
16    validate_signal_outputs(model)?;
17    validate_empty_with_start(model)?;
18    Ok(())
19}
20
21/// A single rpc method may carry at most one `temporal.v1.*` annotation;
22/// declaring two on the same rpc collapses to a single entry in `parse.rs`
23/// (first match wins), but two different annotation buckets pointing at the
24/// same method *name* — which can happen when an activity is named the same
25/// as a sibling workflow — would break the generated handle. Reject up front.
26fn reject_rpc_collisions(model: &ServiceModel) -> Result<()> {
27    let mut seen: HashMap<&str, &'static str> = HashMap::new();
28
29    let kinds: [(&'static str, Vec<&str>); 5] = [
30        (
31            "workflow",
32            model
33                .workflows
34                .iter()
35                .map(|w| w.rpc_method.as_str())
36                .collect(),
37        ),
38        (
39            "signal",
40            model
41                .signals
42                .iter()
43                .map(|s| s.rpc_method.as_str())
44                .collect(),
45        ),
46        (
47            "query",
48            model
49                .queries
50                .iter()
51                .map(|q| q.rpc_method.as_str())
52                .collect(),
53        ),
54        (
55            "update",
56            model
57                .updates
58                .iter()
59                .map(|u| u.rpc_method.as_str())
60                .collect(),
61        ),
62        (
63            "activity",
64            model
65                .activities
66                .iter()
67                .map(|a| a.rpc_method.as_str())
68                .collect(),
69        ),
70    ];
71
72    for (kind, names) in &kinds {
73        for name in names {
74            if let Some(prev) = seen.insert(name, kind) {
75                bail!(
76                    "{}.{name}: rpc carries conflicting Temporal annotations ({prev} and {kind}) — pick one",
77                    model.service,
78                );
79            }
80        }
81    }
82    Ok(())
83}
84
85fn validate_workflows(model: &ServiceModel) -> Result<()> {
86    let signal_methods: HashSet<&str> = model
87        .signals
88        .iter()
89        .map(|s| s.rpc_method.as_str())
90        .collect();
91    let query_methods: HashSet<&str> = model
92        .queries
93        .iter()
94        .map(|q| q.rpc_method.as_str())
95        .collect();
96    let update_methods: HashSet<&str> = model
97        .updates
98        .iter()
99        .map(|u| u.rpc_method.as_str())
100        .collect();
101
102    for wf in &model.workflows {
103        let effective_tq = wf
104            .task_queue
105            .as_deref()
106            .or(model.default_task_queue.as_deref());
107        if effective_tq.is_none() {
108            bail!(
109                "{}.{}: workflow has no task_queue — set either (temporal.v1.workflow).task_queue or service-level (temporal.v1.service).task_queue",
110                model.service,
111                wf.rpc_method,
112            );
113        }
114
115        for sref in &wf.attached_signals {
116            check_ref(
117                model,
118                wf,
119                &signal_methods,
120                &sref.rpc_method,
121                "signal",
122                "(temporal.v1.signal)",
123            )?;
124        }
125        for qref in &wf.attached_queries {
126            check_ref(
127                model,
128                wf,
129                &query_methods,
130                &qref.rpc_method,
131                "query",
132                "(temporal.v1.query)",
133            )?;
134        }
135        for uref in &wf.attached_updates {
136            check_ref(
137                model,
138                wf,
139                &update_methods,
140                &uref.rpc_method,
141                "update",
142                "(temporal.v1.update)",
143            )?;
144        }
145    }
146    Ok(())
147}
148
149fn check_ref(
150    model: &ServiceModel,
151    wf: &crate::model::WorkflowModel,
152    declared: &HashSet<&str>,
153    target: &str,
154    kind: &str,
155    expected_annotation: &str,
156) -> Result<()> {
157    if declared.contains(target) {
158        return Ok(());
159    }
160    bail!(
161        "{}.{}: workflow references {kind} \"{target}\" but no sibling rpc carries {expected_annotation}",
162        model.service,
163        wf.rpc_method,
164    );
165}
166
167/// `signal_with_start` / `update_with_start` free functions take both the
168/// workflow input and the signal/update input. Emitting them generically
169/// over Empty would require a combinatorial set of runtime functions or
170/// a `TypedPayload` adapter we don't ship yet. Reject the combination
171/// up front with a clear error so users wrap empty messages in a no-field
172/// struct (the canonical proto workaround).
173fn validate_empty_with_start(model: &ServiceModel) -> Result<()> {
174    for wf in &model.workflows {
175        for sref in &wf.attached_signals {
176            if !sref.start {
177                continue;
178            }
179            let Some(sig) = model
180                .signals
181                .iter()
182                .find(|s| s.rpc_method == sref.rpc_method)
183            else {
184                continue; // unresolved ref — caught earlier
185            };
186            if wf.input_type.is_empty || sig.input_type.is_empty {
187                bail!(
188                    "{}.{}: signal `{}` is marked start:true but {} input is google.protobuf.Empty; the with_start emit path doesn't support Empty payloads. Wrap the empty side in a single-field message and retry.",
189                    model.service,
190                    wf.rpc_method,
191                    sig.rpc_method,
192                    if wf.input_type.is_empty {
193                        "the workflow's"
194                    } else {
195                        "the signal's"
196                    },
197                );
198            }
199        }
200        for uref in &wf.attached_updates {
201            if !uref.start {
202                continue;
203            }
204            let Some(u) = model
205                .updates
206                .iter()
207                .find(|u| u.rpc_method == uref.rpc_method)
208            else {
209                continue;
210            };
211            if wf.input_type.is_empty || u.input_type.is_empty {
212                bail!(
213                    "{}.{}: update `{}` is marked start:true but {} input is google.protobuf.Empty; the with_start emit path doesn't support Empty payloads. Wrap the empty side in a single-field message and retry.",
214                    model.service,
215                    wf.rpc_method,
216                    u.rpc_method,
217                    if wf.input_type.is_empty {
218                        "the workflow's"
219                    } else {
220                        "the update's"
221                    },
222                );
223            }
224        }
225    }
226    Ok(())
227}
228
229fn validate_signal_outputs(model: &ServiceModel) -> Result<()> {
230    for sig in &model.signals {
231        if !sig.output_type.is_empty {
232            bail!(
233                "{}.{}: signal rpc must return google.protobuf.Empty, got {}",
234                model.service,
235                sig.rpc_method,
236                sig.output_type.full_name,
237            );
238        }
239    }
240    Ok(())
241}