protoc_gen_rust_temporal/
parse.rs1use std::collections::HashSet;
13use std::time::Duration;
14
15use anyhow::{Context, Result, anyhow};
16use prost::Message;
17use prost_reflect::{
18 DescriptorPool, DynamicMessage, ExtensionDescriptor, MethodDescriptor, ServiceDescriptor, Value,
19};
20
21use crate::model::{
22 ActivityModel, IdReusePolicy, IdTemplateSegment, ProtoType, QueryModel, QueryRef, ServiceModel,
23 SignalModel, SignalRef, UpdateModel, UpdateRef, WorkflowModel,
24};
25use crate::temporal::v1::{
26 ActivityOptions, IdReusePolicy as ProtoPolicy, QueryOptions, ServiceOptions, SignalOptions,
27 UpdateOptions, WorkflowOptions,
28};
29use heck::ToSnakeCase;
30
31const SERVICE_EXT: &str = "temporal.v1.service";
32const WORKFLOW_EXT: &str = "temporal.v1.workflow";
33const ACTIVITY_EXT: &str = "temporal.v1.activity";
34const SIGNAL_EXT: &str = "temporal.v1.signal";
35const QUERY_EXT: &str = "temporal.v1.query";
36const UPDATE_EXT: &str = "temporal.v1.update";
37
38struct ExtensionSet {
39 service: ExtensionDescriptor,
40 workflow: ExtensionDescriptor,
41 activity: ExtensionDescriptor,
42 signal: ExtensionDescriptor,
43 query: ExtensionDescriptor,
44 update: ExtensionDescriptor,
45}
46
47impl ExtensionSet {
48 fn load(pool: &DescriptorPool) -> Result<Self> {
49 Ok(Self {
50 service: get_ext(pool, SERVICE_EXT)?,
51 workflow: get_ext(pool, WORKFLOW_EXT)?,
52 activity: get_ext(pool, ACTIVITY_EXT)?,
53 signal: get_ext(pool, SIGNAL_EXT)?,
54 query: get_ext(pool, QUERY_EXT)?,
55 update: get_ext(pool, UPDATE_EXT)?,
56 })
57 }
58}
59
60fn get_ext(pool: &DescriptorPool, name: &str) -> Result<ExtensionDescriptor> {
61 pool.get_extension_by_name(name)
62 .ok_or_else(|| anyhow!("missing extension definition: {name}"))
63}
64
65pub fn parse(
66 pool: &DescriptorPool,
67 files_to_generate: &HashSet<String>,
68) -> Result<Vec<ServiceModel>> {
69 let has_any_services = pool
80 .files()
81 .filter(|f| files_to_generate.contains(f.name()))
82 .any(|f| f.services().next().is_some());
83 if !has_any_services {
84 return Ok(Vec::new());
85 }
86
87 let ext = ExtensionSet::load(pool)?;
88
89 let mut out = Vec::new();
90 for file in pool.files() {
91 if !files_to_generate.contains(file.name()) {
92 continue;
93 }
94 for service in file.services() {
95 if let Some(model) = parse_service(&file, &service, &ext)? {
96 out.push(model);
97 }
98 }
99 }
100 Ok(out)
101}
102
103fn parse_service(
104 file: &prost_reflect::FileDescriptor,
105 service: &ServiceDescriptor,
106 ext: &ExtensionSet,
107) -> Result<Option<ServiceModel>> {
108 let package = file.package_name().to_string();
109 let service_name = service.name().to_string();
110
111 let default_task_queue = service_default_task_queue(service, &ext.service)?;
112
113 let mut workflows = Vec::new();
114 let mut signals = Vec::new();
115 let mut queries = Vec::new();
116 let mut updates = Vec::new();
117 let mut activities = Vec::new();
118
119 for method in service.methods() {
120 match method_kind(&method, ext)? {
121 MethodKind::Workflow(opts) => {
122 workflows.push(workflow_from(&method, *opts, &package, &service_name)?);
123 }
124 MethodKind::Signal(opts) => {
125 signals.push(signal_from(&method, opts));
126 }
127 MethodKind::Query(opts) => {
128 queries.push(query_from(&method, opts));
129 }
130 MethodKind::Update(opts) => {
131 updates.push(update_from(&method, opts));
132 }
133 MethodKind::Activity(opts) => {
134 activities.push(activity_from(&method, *opts));
135 }
136 MethodKind::None => continue,
137 }
138 }
139
140 if workflows.is_empty()
141 && signals.is_empty()
142 && queries.is_empty()
143 && updates.is_empty()
144 && activities.is_empty()
145 {
146 return Ok(None);
147 }
148
149 Ok(Some(ServiceModel {
150 package,
151 service: service_name,
152 source_file: file.name().to_string(),
153 default_task_queue,
154 workflows,
155 signals,
156 queries,
157 updates,
158 activities,
159 }))
160}
161
162fn service_default_task_queue(
163 service: &ServiceDescriptor,
164 service_ext: &ExtensionDescriptor,
165) -> Result<Option<String>> {
166 let opts: DynamicMessage = service.options();
167 if !opts.has_extension(service_ext) {
168 return Ok(None);
169 }
170 let value = opts.get_extension(service_ext);
171 let bytes = encode_message_value(&value)?;
172 let parsed = ServiceOptions::decode(bytes.as_slice())?;
173 Ok((!parsed.task_queue.is_empty()).then_some(parsed.task_queue))
174}
175
176enum MethodKind {
177 Workflow(Box<WorkflowOptions>),
179 Activity(Box<ActivityOptions>),
180 Signal(SignalOptions),
181 Query(QueryOptions),
182 Update(UpdateOptions),
183 None,
184}
185
186fn method_kind(method: &MethodDescriptor, ext: &ExtensionSet) -> Result<MethodKind> {
187 let opts: DynamicMessage = method.options();
188
189 if opts.has_extension(&ext.workflow) {
194 return decode_kind::<WorkflowOptions>(&opts.get_extension(&ext.workflow));
195 }
196 if opts.has_extension(&ext.activity) {
197 return decode_kind::<ActivityOptions>(&opts.get_extension(&ext.activity));
198 }
199 if opts.has_extension(&ext.signal) {
200 return decode_kind::<SignalOptions>(&opts.get_extension(&ext.signal));
201 }
202 if opts.has_extension(&ext.query) {
203 return decode_kind::<QueryOptions>(&opts.get_extension(&ext.query));
204 }
205 if opts.has_extension(&ext.update) {
206 return decode_kind::<UpdateOptions>(&opts.get_extension(&ext.update));
207 }
208 Ok(MethodKind::None)
209}
210
211trait IntoMethodKind {
212 fn into_kind(self) -> MethodKind;
213}
214
215impl IntoMethodKind for WorkflowOptions {
216 fn into_kind(self) -> MethodKind {
217 MethodKind::Workflow(Box::new(self))
218 }
219}
220impl IntoMethodKind for ActivityOptions {
221 fn into_kind(self) -> MethodKind {
222 MethodKind::Activity(Box::new(self))
223 }
224}
225impl IntoMethodKind for SignalOptions {
226 fn into_kind(self) -> MethodKind {
227 MethodKind::Signal(self)
228 }
229}
230impl IntoMethodKind for QueryOptions {
231 fn into_kind(self) -> MethodKind {
232 MethodKind::Query(self)
233 }
234}
235impl IntoMethodKind for UpdateOptions {
236 fn into_kind(self) -> MethodKind {
237 MethodKind::Update(self)
238 }
239}
240
241fn decode_kind<T: Message + Default + IntoMethodKind>(value: &Value) -> Result<MethodKind> {
242 let bytes = encode_message_value(value)?;
243 let parsed = T::decode(bytes.as_slice())?;
244 Ok(parsed.into_kind())
245}
246
247fn encode_message_value(value: &Value) -> Result<Vec<u8>> {
248 match value {
249 Value::Message(m) => Ok(m.encode_to_vec()),
250 other => Err(anyhow!("expected message extension, got {other:?}")),
251 }
252}
253
254fn workflow_from(
255 method: &MethodDescriptor,
256 opts: WorkflowOptions,
257 package: &str,
258 service_name: &str,
259) -> Result<WorkflowModel> {
260 let rpc_method = method.name().to_string();
261 let registered_name = if opts.name.is_empty() {
262 default_registered_name(package, service_name, &rpc_method)
263 } else {
264 opts.name
265 };
266
267 let id_expression = if opts.id.is_empty() {
268 None
269 } else {
270 Some(
271 parse_id_template(&opts.id, &method.input()).with_context(|| {
272 format!("parse (temporal.v1.workflow).id template on {service_name}.{rpc_method}")
273 })?,
274 )
275 };
276
277 Ok(WorkflowModel {
278 rpc_method,
279 registered_name,
280 input_type: ProtoType::new(method.input().full_name()),
281 output_type: ProtoType::new(method.output().full_name()),
282 task_queue: (!opts.task_queue.is_empty()).then_some(opts.task_queue),
283 id_expression,
284 id_reuse_policy: id_reuse_policy_from_proto(opts.id_reuse_policy),
285 execution_timeout: opts.execution_timeout.and_then(duration_from_proto),
286 run_timeout: opts.run_timeout.and_then(duration_from_proto),
287 task_timeout: opts.task_timeout.and_then(duration_from_proto),
288 aliases: opts.aliases,
289 attached_signals: opts
290 .signal
291 .into_iter()
292 .map(|s| SignalRef {
293 rpc_method: s.r#ref,
294 start: s.start,
295 })
296 .collect(),
297 attached_queries: opts
298 .query
299 .into_iter()
300 .map(|q| QueryRef {
301 rpc_method: q.r#ref,
302 })
303 .collect(),
304 attached_updates: opts
305 .update
306 .into_iter()
307 .map(|u| UpdateRef {
308 rpc_method: u.r#ref,
309 start: u.start,
310 validate: u.validate,
311 })
312 .collect(),
313 })
314}
315
316fn signal_from(method: &MethodDescriptor, opts: SignalOptions) -> SignalModel {
317 let rpc_method = method.name().to_string();
318 let registered_name = if opts.name.is_empty() {
319 rpc_method.clone()
320 } else {
321 opts.name
322 };
323 SignalModel {
324 rpc_method,
325 registered_name,
326 input_type: ProtoType::new(method.input().full_name()),
327 output_type: ProtoType::new(method.output().full_name()),
328 }
329}
330
331fn query_from(method: &MethodDescriptor, opts: QueryOptions) -> QueryModel {
332 let rpc_method = method.name().to_string();
333 let registered_name = if opts.name.is_empty() {
334 rpc_method.clone()
335 } else {
336 opts.name
337 };
338 QueryModel {
339 rpc_method,
340 registered_name,
341 input_type: ProtoType::new(method.input().full_name()),
342 output_type: ProtoType::new(method.output().full_name()),
343 }
344}
345
346fn update_from(method: &MethodDescriptor, opts: UpdateOptions) -> UpdateModel {
347 let rpc_method = method.name().to_string();
348 let registered_name = if opts.name.is_empty() {
349 rpc_method.clone()
350 } else {
351 opts.name
352 };
353 UpdateModel {
354 rpc_method,
355 registered_name,
356 input_type: ProtoType::new(method.input().full_name()),
357 output_type: ProtoType::new(method.output().full_name()),
358 validate: opts.validate,
359 }
360}
361
362fn activity_from(method: &MethodDescriptor, opts: ActivityOptions) -> ActivityModel {
363 let rpc_method = method.name().to_string();
364 let registered_name = if opts.name.is_empty() {
365 rpc_method.clone()
366 } else {
367 opts.name
368 };
369 ActivityModel {
370 rpc_method,
371 registered_name,
372 input_type: ProtoType::new(method.input().full_name()),
373 output_type: ProtoType::new(method.output().full_name()),
374 }
375}
376
/// Default registration name used when a workflow option leaves `name` empty.
///
/// Produces `package.Service/Rpc`, or `Service/Rpc` when the proto file declares
/// no package.
fn default_registered_name(package: &str, service: &str, rpc: &str) -> String {
    match package {
        "" => format!("{service}/{rpc}"),
        pkg => format!("{pkg}.{service}/{rpc}"),
    }
}
384
385fn id_reuse_policy_from_proto(raw: i32) -> Option<IdReusePolicy> {
386 match ProtoPolicy::try_from(raw).ok()? {
387 ProtoPolicy::WorkflowIdReusePolicyUnspecified => None,
388 ProtoPolicy::WorkflowIdReusePolicyAllowDuplicate => Some(IdReusePolicy::AllowDuplicate),
389 ProtoPolicy::WorkflowIdReusePolicyAllowDuplicateFailedOnly => {
390 Some(IdReusePolicy::AllowDuplicateFailedOnly)
391 }
392 ProtoPolicy::WorkflowIdReusePolicyRejectDuplicate => Some(IdReusePolicy::RejectDuplicate),
393 ProtoPolicy::WorkflowIdReusePolicyTerminateIfRunning => {
394 Some(IdReusePolicy::TerminateIfRunning)
395 }
396 }
397}
398
399fn duration_from_proto(d: prost_types::Duration) -> Option<Duration> {
400 if d.seconds < 0 || d.nanos < 0 {
401 return None;
402 }
403 let secs = u64::try_from(d.seconds).ok()?;
404 let nanos = u32::try_from(d.nanos).ok()?;
405 Some(Duration::new(secs, nanos))
406}
407
408fn parse_id_template(
416 template: &str,
417 input: &prost_reflect::MessageDescriptor,
418) -> Result<Vec<IdTemplateSegment>> {
419 let mut out = Vec::new();
420 let mut rest = template;
421 while let Some(open) = rest.find("{{") {
422 if open > 0 {
423 out.push(IdTemplateSegment::Literal(rest[..open].to_string()));
424 }
425 let after_open = &rest[open + 2..];
426 let close = after_open
427 .find("}}")
428 .ok_or_else(|| anyhow!("unterminated `{{{{` in id template {template:?}"))?;
429 let token = after_open[..close].trim();
430 let field_name = token
431 .strip_prefix('.')
432 .ok_or_else(|| {
433 anyhow!(
434 "id template token {token:?} must start with `.` (only field references are supported; \
435 conditionals / pipelines / functions are not implemented)"
436 )
437 })?
438 .trim();
439 if field_name.is_empty() {
440 anyhow::bail!("id template token has no field name after `.`");
441 }
442 if !field_name
443 .chars()
444 .all(|c| c.is_ascii_alphanumeric() || c == '_')
445 {
446 anyhow::bail!(
447 "id template token {field_name:?} contains unsupported characters \
448 (only simple field references like `.Name` are supported)"
449 );
450 }
451 let rust_field = field_name.to_snake_case();
452 let known = input.fields().any(|f| f.name() == rust_field);
453 if !known {
454 anyhow::bail!(
455 "id template references `{field_name}` (looked up as `{rust_field}`) \
456 but no such field exists on input message `{}`",
457 input.full_name()
458 );
459 }
460 out.push(IdTemplateSegment::Field(rust_field));
461 rest = &after_open[close + 2..];
462 }
463 if !rest.is_empty() {
464 out.push(IdTemplateSegment::Literal(rest.to_string()));
465 }
466 Ok(out)
467}