adk_rust_macros/lib.rs
1//! # adk-macros
2//!
3//! Proc macros for ADK-Rust that eliminate tool registration boilerplate.
4//!
5//! ## `#[tool]`
6//!
7//! Turns an async function into a fully-wired `adk_tool::Tool` implementation:
8//!
9//! ```rust,ignore
10//! use adk_macros::tool;
11//! use schemars::JsonSchema;
12//! use serde::Deserialize;
13//!
14//! #[derive(Deserialize, JsonSchema)]
15//! struct WeatherArgs {
16//! /// The city to look up
17//! city: String,
18//! }
19//!
20//! /// Get the current weather for a city.
21//! #[tool]
22//! async fn get_weather(args: WeatherArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
23//! Ok(serde_json::json!({ "temp": 72, "city": args.city }))
24//! }
25//!
26//! // This generates a struct `GetWeather` that implements `adk_tool::Tool`.
27//! // Use it like: Arc::new(GetWeather)
28//! ```
29//!
30//! The macro:
31//! - Uses the function's doc comment as the tool description
32//! - Derives the JSON schema from the argument type via `schemars::schema_for!`
33//! - Names the tool after the function (snake_case)
34//! - Generates a zero-sized struct (PascalCase) implementing `Tool`
35
36use proc_macro::TokenStream;
37use quote::{format_ident, quote};
38use syn::{FnArg, ItemFn, Meta, Type, parse_macro_input};
39
40/// Attribute macro that generates a `Tool` implementation from an async function.
41///
42/// # Requirements
43///
44/// - The function must be `async`
45/// - It must take exactly one argument (the args struct) that implements
46/// `serde::de::DeserializeOwned` and `schemars::JsonSchema`
47/// - It must return `Result<serde_json::Value, adk_tool::AdkError>`
48/// - Doc comments become the tool description
49///
50/// # Attributes
51///
52/// Optional attributes can be passed to configure tool metadata:
53///
54/// - `read_only` — marks the tool as having no side effects (`is_read_only() → true`)
55/// - `concurrency_safe` — marks the tool as safe for concurrent execution (`is_concurrency_safe() → true`)
56/// - `long_running` — marks the tool as long-running (`is_long_running() → true`)
57///
58/// # Examples
59///
60/// ```rust,ignore
61/// /// Search the knowledge base for documents matching a query.
62/// #[tool]
63/// async fn search_docs(args: SearchArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
64/// // ...
65/// }
66///
67/// /// Look up cached data (read-only, safe for parallel dispatch).
68/// #[tool(read_only, concurrency_safe)]
69/// async fn cache_lookup(args: LookupArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
70/// // ...
71/// }
72///
73/// // Generated: pub struct SearchDocs; implements Tool
74/// // Use: agent_builder.tool(Arc::new(SearchDocs))
75/// ```
76#[proc_macro_attribute]
77pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
78 let input_fn = parse_macro_input!(item as ItemFn);
79
80 // Parse optional attributes: #[tool(read_only, concurrency_safe, long_running)]
81 let mut is_read_only = false;
82 let mut is_concurrency_safe = false;
83 let mut is_long_running = false;
84
85 if !attr.is_empty() {
86 let meta = parse_macro_input!(attr as ToolAttrs);
87 is_read_only = meta.read_only;
88 is_concurrency_safe = meta.concurrency_safe;
89 is_long_running = meta.long_running;
90 }
91
92 let fn_name = &input_fn.sig.ident;
93 let fn_vis = &input_fn.vis;
94
95 // Extract doc comments for description
96 let doc_lines: Vec<String> = input_fn
97 .attrs
98 .iter()
99 .filter(|attr| attr.path().is_ident("doc"))
100 .filter_map(|attr| {
101 if let syn::Meta::NameValue(nv) = &attr.meta
102 && let syn::Expr::Lit(lit) = &nv.value
103 && let syn::Lit::Str(s) = &lit.lit
104 {
105 return Some(s.value().trim().to_string());
106 }
107 None
108 })
109 .collect();
110
111 let description = if doc_lines.is_empty() {
112 fn_name.to_string().replace('_', " ")
113 } else {
114 doc_lines.join(" ")
115 };
116
117 let tool_name_str = fn_name.to_string();
118
119 // Generate PascalCase struct name: get_weather → GetWeather
120 let struct_name = format_ident!(
121 "{}",
122 tool_name_str
123 .split('_')
124 .map(|seg| {
125 let mut chars = seg.chars();
126 match chars.next() {
127 None => String::new(),
128 Some(c) => c.to_uppercase().to_string() + chars.as_str(),
129 }
130 })
131 .collect::<String>()
132 );
133
134 // Extract the single argument type
135 let args_type = extract_args_type(&input_fn);
136
137 // Check if we have a typed args parameter or no params
138 let (schema_gen, deserialize_call) = if let Some(args_ty) = &args_type {
139 (
140 quote! {
141 {
142 let mut schema = serde_json::to_value(
143 schemars::schema_for!(#args_ty)
144 ).unwrap_or_default();
145 // Strip fields that Gemini/LLM APIs don't accept
146 if let Some(obj) = schema.as_object_mut() {
147 obj.remove("$schema");
148 obj.remove("title");
149 }
150 // Simplify nullable types: {"type": ["string", "null"]} → {"type": "string"}
151 fn simplify_nullable(v: &mut serde_json::Value) {
152 match v {
153 serde_json::Value::Object(map) => {
154 if let Some(serde_json::Value::Array(types)) = map.get("type") {
155 let non_null: Vec<_> = types.iter()
156 .filter(|t| t.as_str() != Some("null"))
157 .cloned()
158 .collect();
159 if non_null.len() == 1 {
160 map.insert("type".to_string(), non_null[0].clone());
161 }
162 }
163 // Remove anyOf wrappers for simple nullable types
164 if let Some(serde_json::Value::Array(any_of)) = map.remove("anyOf") {
165 for variant in &any_of {
166 if let Some(obj) = variant.as_object() {
167 if obj.get("type").and_then(|t| t.as_str()) != Some("null") {
168 for (k, val) in obj {
169 map.insert(k.clone(), val.clone());
170 }
171 break;
172 }
173 }
174 }
175 }
176 for val in map.values_mut() {
177 simplify_nullable(val);
178 }
179 }
180 serde_json::Value::Array(arr) => {
181 for item in arr {
182 simplify_nullable(item);
183 }
184 }
185 _ => {}
186 }
187 }
188 simplify_nullable(&mut schema);
189 Some(schema)
190 }
191 },
192 quote! {
193 let typed_args: #args_ty = serde_json::from_value(args)
194 .map_err(|e| adk_tool::AdkError::tool(
195 format!("invalid arguments for '{}': {e}", #tool_name_str)
196 ))?;
197 #fn_name(typed_args).await
198 },
199 )
200 } else {
201 (
202 quote! { None },
203 quote! {
204 let _ = args;
205 #fn_name().await
206 },
207 )
208 };
209
210 // Check if the function signature includes ctx: Arc<dyn ToolContext>
211 let has_ctx = has_tool_context_param(&input_fn);
212 let execute_body = if has_ctx {
213 if let Some(args_ty) = &args_type {
214 quote! {
215 let typed_args: #args_ty = serde_json::from_value(args)
216 .map_err(|e| adk_tool::AdkError::tool(
217 format!("invalid arguments for '{}': {e}", #tool_name_str)
218 ))?;
219 #fn_name(ctx, typed_args).await
220 }
221 } else {
222 quote! {
223 let _ = args;
224 #fn_name(ctx).await
225 }
226 }
227 } else {
228 deserialize_call
229 };
230
231 // Generate optional trait method overrides
232 let read_only_override = if is_read_only {
233 quote! {
234 fn is_read_only(&self) -> bool { true }
235 }
236 } else {
237 quote! {}
238 };
239
240 let concurrency_safe_override = if is_concurrency_safe {
241 quote! {
242 fn is_concurrency_safe(&self) -> bool { true }
243 }
244 } else {
245 quote! {}
246 };
247
248 let long_running_override = if is_long_running {
249 quote! {
250 fn is_long_running(&self) -> bool { true }
251 }
252 } else {
253 quote! {}
254 };
255
256 let output = quote! {
257 // Keep the original function
258 #input_fn
259
260 /// Auto-generated tool struct for [`#fn_name`].
261 #fn_vis struct #struct_name;
262
263 #[adk_tool::async_trait]
264 impl adk_tool::Tool for #struct_name {
265 fn name(&self) -> &str {
266 #tool_name_str
267 }
268
269 fn description(&self) -> &str {
270 #description
271 }
272
273 fn parameters_schema(&self) -> Option<serde_json::Value> {
274 #schema_gen
275 }
276
277 #read_only_override
278 #concurrency_safe_override
279 #long_running_override
280
281 async fn execute(
282 &self,
283 ctx: std::sync::Arc<dyn adk_tool::ToolContext>,
284 args: serde_json::Value,
285 ) -> adk_tool::Result<serde_json::Value> {
286 #execute_body
287 }
288 }
289 };
290
291 output.into()
292}
293
294/// Extract the args type from the function signature.
295/// Skips any `Arc<dyn ToolContext>` parameter.
296fn extract_args_type(func: &ItemFn) -> Option<Type> {
297 for arg in &func.sig.inputs {
298 if let FnArg::Typed(pat_type) = arg {
299 // Skip context parameters (Arc<dyn ToolContext>)
300 let ty = &pat_type.ty;
301 let ty_str = quote!(#ty).to_string();
302 if ty_str.contains("ToolContext") {
303 continue;
304 }
305 return Some((*pat_type.ty).clone());
306 }
307 }
308 None
309}
310
311/// Check if the function has an Arc<dyn ToolContext> parameter.
312fn has_tool_context_param(func: &ItemFn) -> bool {
313 func.sig.inputs.iter().any(|arg| {
314 if let FnArg::Typed(pat_type) = arg {
315 let ty = &pat_type.ty;
316 let ty_str = quote!(#ty).to_string();
317 ty_str.contains("ToolContext")
318 } else {
319 false
320 }
321 })
322}
323
324/// Parsed attributes from `#[tool(read_only, concurrency_safe, long_running)]`.
325struct ToolAttrs {
326 read_only: bool,
327 concurrency_safe: bool,
328 long_running: bool,
329}
330
331impl syn::parse::Parse for ToolAttrs {
332 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
333 let mut attrs =
334 ToolAttrs { read_only: false, concurrency_safe: false, long_running: false };
335
336 let punctuated =
337 syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated(input)?;
338
339 for meta in punctuated {
340 if let Meta::Path(path) = &meta {
341 if path.is_ident("read_only") {
342 attrs.read_only = true;
343 } else if path.is_ident("concurrency_safe") {
344 attrs.concurrency_safe = true;
345 } else if path.is_ident("long_running") {
346 attrs.long_running = true;
347 } else {
348 return Err(syn::Error::new_spanned(
349 path,
350 "unknown tool attribute; expected `read_only`, `concurrency_safe`, or `long_running`",
351 ));
352 }
353 } else {
354 return Err(syn::Error::new_spanned(
355 meta,
356 "expected identifier (e.g., `read_only`), not key-value",
357 ));
358 }
359 }
360
361 Ok(attrs)
362 }
363}
364
365// ─── Functional API Macros ─────────────────────────────────────────────────────
366
367/// Attribute macro that generates a workflow agent struct from an async function.
368///
369/// The annotated function becomes the workflow body. The macro generates:
370/// - A PascalCase struct (e.g., `my_workflow` → `MyWorkflowAgent`)
371/// - A `new()` constructor accepting `Arc<dyn Checkpointer>`
372/// - An `invoke()` method that creates/restores `TaskContext`, validates state,
373/// creates checkpoints, calls the function, and persists the final checkpoint
374///
375/// # Requirements
376///
377/// - The function **must** be `async`
378/// - The function **must** accept `&mut TaskContext` as its sole parameter
379/// - The function **must** return `Result<Value>` (or equivalent)
380///
381/// # Example
382///
383/// ```rust,ignore
384/// use adk_graph::functional::TaskContext;
385/// use adk_graph::error::Result;
386/// use serde_json::Value;
387///
388/// #[entrypoint]
389/// async fn my_workflow(ctx: &mut TaskContext) -> Result<Value> {
390/// let data = step_a(ctx, "input").await?;
391/// let result = step_b(ctx, data).await?;
392/// Ok(result)
393/// }
394///
395/// // Generates: pub struct MyWorkflowAgent { ... }
396/// // with MyWorkflowAgent::new(checkpointer) and invoke(initial_state, config)
397/// ```
398#[proc_macro_attribute]
399pub fn entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream {
400 let input_fn = parse_macro_input!(item as ItemFn);
401
402 // Validate: must be async
403 if input_fn.sig.asyncness.is_none() {
404 return syn::Error::new_spanned(
405 input_fn.sig.fn_token,
406 "#[entrypoint] functions must be async",
407 )
408 .to_compile_error()
409 .into();
410 }
411
412 // Validate: must accept &mut TaskContext
413 let has_task_context = input_fn.sig.inputs.iter().any(|arg| {
414 if let FnArg::Typed(pat_type) = arg {
415 let full_str = quote!(#pat_type).to_string();
416 full_str.contains("TaskContext")
417 } else {
418 false
419 }
420 });
421
422 if !has_task_context {
423 return syn::Error::new_spanned(
424 &input_fn.sig,
425 "#[entrypoint] functions must accept `&mut TaskContext` as a parameter",
426 )
427 .to_compile_error()
428 .into();
429 }
430
431 let fn_name = &input_fn.sig.ident;
432 let fn_vis = &input_fn.vis;
433 let fn_name_str = fn_name.to_string();
434
435 // Generate PascalCase struct name: my_workflow → MyWorkflowAgent
436 let struct_name = format_ident!(
437 "{}Agent",
438 fn_name_str
439 .split('_')
440 .map(|seg| {
441 let mut chars = seg.chars();
442 match chars.next() {
443 None => String::new(),
444 Some(c) => c.to_uppercase().to_string() + chars.as_str(),
445 }
446 })
447 .collect::<String>()
448 );
449
450 let output = quote! {
451 // Preserve the original function for direct testing
452 #input_fn
453
454 /// Auto-generated workflow agent struct for [`#fn_name`].
455 ///
456 /// Created by the `#[entrypoint]` macro. Provides `new()` and `invoke()`
457 /// methods for executing the workflow with automatic checkpointing.
458 #fn_vis struct #struct_name {
459 checkpointer: std::sync::Arc<dyn adk_graph::checkpoint::Checkpointer>,
460 }
461
462 impl #struct_name {
463 /// Create a new workflow agent with the given checkpointer.
464 pub fn new(checkpointer: std::sync::Arc<dyn adk_graph::checkpoint::Checkpointer>) -> Self {
465 Self { checkpointer }
466 }
467
468 /// Invoke the workflow with an initial state and execution configuration.
469 ///
470 /// This method:
471 /// 1. Creates or restores a `TaskContext` from the last checkpoint
472 /// 2. Validates initial state against the configured schema
473 /// 3. Creates a checkpoint before execution
474 /// 4. Calls the annotated workflow function
475 /// 5. Persists the final checkpoint
476 /// 6. Returns the final workflow state
477 pub async fn invoke(
478 &self,
479 initial_state: adk_graph::state::State,
480 execution_config: adk_graph::node::ExecutionConfig,
481 ) -> adk_graph::error::Result<adk_graph::state::State> {
482 use adk_graph::checkpoint::Checkpointer;
483 use adk_graph::functional::ExecutionLog;
484 use adk_graph::state::Checkpoint;
485 use adk_graph::stream::StreamEvent;
486
487 let thread_id = execution_config.thread_id.clone();
488
489 // Try to restore from checkpoint if resuming
490 let (state, execution_log) = if execution_config.resume_from.is_some() {
491 match self.checkpointer.load(&thread_id).await? {
492 Some(checkpoint) => {
493 let log: ExecutionLog = checkpoint
494 .metadata
495 .get("execution_log")
496 .and_then(|v| serde_json::from_value(v.clone()).ok())
497 .unwrap_or_default();
498 (checkpoint.state, log)
499 }
500 None => (initial_state, ExecutionLog::new()),
501 }
502 } else {
503 (initial_state, ExecutionLog::new())
504 };
505
506 // Create broadcast channel for stream events
507 let (event_tx, _) = tokio::sync::broadcast::channel::<StreamEvent>(256);
508 let cancel_token = tokio_util::sync::CancellationToken::new();
509 let execution_log = std::sync::Arc::new(tokio::sync::RwLock::new(execution_log));
510
511 // Create TaskContext
512 let mut ctx = adk_graph::functional::TaskContext::new(
513 thread_id.clone(),
514 state,
515 self.checkpointer.clone(),
516 event_tx.clone(),
517 execution_log.clone(),
518 cancel_token,
519 None,
520 );
521
522 // Validate initial state against schema (if configured)
523 ctx.validate_state().map_err(|e| adk_graph::error::GraphError::Other(e.to_string()))?;
524
525 // Create pre-execution checkpoint
526 let pre_checkpoint = Checkpoint::new(
527 &thread_id,
528 ctx.state().clone(),
529 0,
530 vec![],
531 )
532 .with_metadata("phase", serde_json::Value::String("pre_execution".to_string()));
533 self.checkpointer.save(&pre_checkpoint).await?;
534
535 // Emit workflow start event
536 let _ = event_tx.send(StreamEvent::node_start(#fn_name_str, 0));
537
538 // Call the workflow function
539 let start = std::time::Instant::now();
540 let result = #fn_name(&mut ctx).await;
541
542 let duration = start.elapsed().as_millis() as u64;
543
544 match result {
545 Ok(_value) => {
546 // Persist final checkpoint
547 let step = execution_log.read().await.current_step();
548 let final_checkpoint = Checkpoint::new(
549 &thread_id,
550 ctx.state().clone(),
551 step,
552 vec![],
553 )
554 .with_metadata("phase", serde_json::Value::String("completed".to_string()))
555 .with_metadata(
556 "execution_log",
557 serde_json::to_value(&*execution_log.read().await)
558 .unwrap_or(serde_json::Value::Null),
559 );
560 self.checkpointer.save(&final_checkpoint).await?;
561
562 // Emit workflow end event
563 let _ = event_tx.send(StreamEvent::node_end(#fn_name_str, step, duration));
564
565 Ok(ctx.state().clone())
566 }
567 Err(e) => {
568 // Persist failure checkpoint
569 let step = execution_log.read().await.current_step();
570 let fail_checkpoint = Checkpoint::new(
571 &thread_id,
572 ctx.state().clone(),
573 step,
574 vec![],
575 )
576 .with_metadata("phase", serde_json::Value::String("failed".to_string()))
577 .with_metadata("error", serde_json::Value::String(e.to_string()))
578 .with_metadata(
579 "execution_log",
580 serde_json::to_value(&*execution_log.read().await)
581 .unwrap_or(serde_json::Value::Null),
582 );
583 let _ = self.checkpointer.save(&fail_checkpoint).await;
584
585 // Emit error event
586 let _ = event_tx.send(StreamEvent::error(&e.to_string(), Some(#fn_name_str)));
587
588 Err(e)
589 }
590 }
591 }
592 }
593 };
594
595 output.into()
596}
597
598/// Attribute macro that generates a task wrapper with checkpointing, retry, and streaming.
599///
600/// The annotated function becomes the inner task body. The macro generates a wrapper
601/// function (prefixed with `__task_`) that:
602/// - Checks `ExecutionLog` for cached results (resume-skip path)
603/// - Emits `StreamEvent::node_start` and `StreamEvent::node_end` events
604/// - Implements retry logic when `retry(max_attempts, backoff)` is specified
605/// - Calls `record_completion()` on success
606/// - Calls `record_failure()` after all retries are exhausted
607///
608/// # Requirements
609///
610/// - The function **must** be `async`
611/// - The function **must** accept `&mut TaskContext` as its first argument
612///
613/// # Attributes
614///
615/// - `retry(max_attempts = N, backoff = "Xs")` — retry on failure with exponential backoff
616/// - `rerun_on_resume` — always re-execute on workflow resume, skip cached results
617/// - `rerun_on_resume = true` / `rerun_on_resume = false` — explicit boolean form
618///
619/// # Examples
620///
621/// ```rust,ignore
622/// use adk_graph::functional::TaskContext;
623/// use adk_graph::error::Result;
624/// use serde_json::Value;
625///
626/// #[task(retry(max_attempts = 3, backoff = "1s"))]
627/// async fn step_a(ctx: &mut TaskContext, input: &str) -> Result<Value> {
628/// Ok(serde_json::json!({"processed": input}))
629/// }
630///
631/// #[task(rerun_on_resume)]
632/// async fn step_b(ctx: &mut TaskContext) -> Result<Value> {
633/// // This task always re-executes on resume, never uses cached results
634/// Ok(serde_json::json!({"timestamp": chrono::Utc::now().to_rfc3339()}))
635/// }
636///
637/// #[task(rerun_on_resume, retry(max_attempts = 2, backoff = "2s"))]
638/// async fn step_c(ctx: &mut TaskContext) -> Result<Value> {
639/// // Combined: re-executes on resume with retry logic
640/// Ok(serde_json::json!({"status": "ok"}))
641/// }
642///
643/// // Generates: async fn __task_step_a(ctx: &mut TaskContext, input: &str) -> Result<Value>
644/// // which wraps step_a with checkpoint/retry/streaming logic.
645/// ```
646#[proc_macro_attribute]
647pub fn task(attr: TokenStream, item: TokenStream) -> TokenStream {
648 let input_fn = parse_macro_input!(item as ItemFn);
649
650 // Validate: must be async
651 if input_fn.sig.asyncness.is_none() {
652 return syn::Error::new_spanned(input_fn.sig.fn_token, "#[task] functions must be async")
653 .to_compile_error()
654 .into();
655 }
656
657 // Validate: first argument must be &mut TaskContext
658 let has_task_context_first = input_fn
659 .sig
660 .inputs
661 .first()
662 .map(|arg| {
663 if let FnArg::Typed(pat_type) = arg {
664 let full_str = quote!(#pat_type).to_string();
665 full_str.contains("TaskContext")
666 } else {
667 false
668 }
669 })
670 .unwrap_or(false);
671
672 if !has_task_context_first {
673 return syn::Error::new_spanned(
674 &input_fn.sig,
675 "#[task] functions must accept `&mut TaskContext` as the first argument",
676 )
677 .to_compile_error()
678 .into();
679 }
680
681 // Parse retry attributes from #[task(retry(max_attempts = N, backoff = "Xs"))]
682 let task_attrs = parse_task_attrs(attr);
683
684 let fn_name = &input_fn.sig.ident;
685 let fn_vis = &input_fn.vis;
686 let fn_name_str = fn_name.to_string();
687 let wrapper_name = format_ident!("__task_{}", fn_name);
688
689 // Collect function parameters (all of them for the wrapper signature)
690 let params = &input_fn.sig.inputs;
691 let return_type = &input_fn.sig.output;
692
693 // Collect the argument names for forwarding the call (skip `ctx`)
694 let forward_args: Vec<_> = input_fn
695 .sig
696 .inputs
697 .iter()
698 .skip(1) // Skip ctx
699 .filter_map(|arg| if let FnArg::Typed(pat_type) = arg { Some(&pat_type.pat) } else { None })
700 .collect();
701
702 // Build the call expression
703 let call_expr = if forward_args.is_empty() {
704 quote! { #fn_name(ctx).await }
705 } else {
706 quote! { #fn_name(ctx, #(#forward_args),*).await }
707 };
708
709 // Generate retry logic or single-attempt logic
710 let execution_body = if let Some(retry_config) = &task_attrs.retry {
711 let max_attempts = retry_config.max_attempts;
712 let backoff_secs = retry_config.backoff_secs;
713 quote! {
714 let mut attempts: u32 = 0;
715 let max_attempts: u32 = #max_attempts;
716 let backoff = std::time::Duration::from_secs(#backoff_secs);
717
718 let result = loop {
719 attempts += 1;
720 match #call_expr {
721 Ok(value) => break Ok(value),
722 Err(e) if attempts < max_attempts => {
723 tokio::time::sleep(backoff * attempts).await;
724 continue;
725 }
726 Err(e) => {
727 ctx.record_failure(task_id, &e.to_string()).await?;
728 ctx.emit(adk_graph::stream::StreamEvent::error(
729 &e.to_string(),
730 Some(task_id),
731 ));
732 break Err(e);
733 }
734 }
735 };
736 }
737 } else {
738 quote! {
739 let result = match #call_expr {
740 Ok(value) => Ok(value),
741 Err(e) => {
742 ctx.record_failure(task_id, &e.to_string()).await?;
743 ctx.emit(adk_graph::stream::StreamEvent::error(
744 &e.to_string(),
745 Some(task_id),
746 ));
747 Err(e)
748 }
749 };
750 }
751 };
752
753 // Generate cache-check code based on rerun_on_resume flag
754 let cache_check = if task_attrs.rerun_on_resume {
755 // rerun_on_resume = true: skip cache check, always execute
756 quote! {}
757 } else {
758 // Default: check ExecutionLog for cached results (resume-skip path)
759 quote! {
760 // Check if already completed (resume path)
761 if let Some(cached_result) = ctx.get_cached_result(task_id).await {
762 return Ok(cached_result);
763 }
764 }
765 };
766
767 let output = quote! {
768 // Preserve the original function for direct testing
769 #input_fn
770
771 /// Auto-generated task wrapper for [`#fn_name`].
772 ///
773 /// Wraps the original function with:
774 /// - Resume-skip logic (checks `ExecutionLog` for cached results)
775 /// - `StreamEvent::node_start` / `StreamEvent::node_end` emission
776 /// - Retry logic (if configured)
777 /// - `record_completion()` on success
778 /// - `record_failure()` after all retries exhausted
779 #fn_vis async fn #wrapper_name(#params) #return_type {
780 let task_id = #fn_name_str;
781
782 #cache_check
783
784 // Emit task start event
785 let current_step = ctx.current_step().await;
786 ctx.emit(adk_graph::stream::StreamEvent::node_start(task_id, current_step));
787
788 let start = std::time::Instant::now();
789
790 #execution_body
791
792 if let Ok(ref value) = result {
793 // Record completion and checkpoint
794 ctx.record_completion(task_id, value).await?;
795 let duration = start.elapsed().as_millis() as u64;
796 let step = ctx.current_step().await;
797 ctx.emit(adk_graph::stream::StreamEvent::node_end(task_id, step, duration));
798 }
799
800 result
801 }
802 };
803
804 output.into()
805}
806
807// ─── Task Attribute Parsing ────────────────────────────────────────────────────
808
/// Parsed retry configuration from `#[task(retry(max_attempts = N, backoff = "Xs"))]`.
///
/// The `task` macro reads both fields to emit the retry loop of the generated
/// wrapper.
struct RetryConfig {
    // Total number of attempts the generated loop makes before giving up;
    // the parsers in this file start from a default of 3.
    max_attempts: u32,
    // Base delay in whole seconds; the generated loop multiplies it by the
    // attempt number before sleeping. The parsers start from a default of 1.
    backoff_secs: u64,
}
814
/// Parsed attributes from `#[task(...)]`.
///
/// Produced by `parse_task_attrs` and consumed by the `task` macro.
struct TaskAttrs {
    // `Some` when a `retry(...)` attribute was parsed; `None` makes the
    // wrapper call the task exactly once.
    retry: Option<RetryConfig>,
    // When `true`, the wrapper is generated without the cached-result check,
    // so the task runs again on resume.
    rerun_on_resume: bool,
}
820
821/// Parse task attributes from the attribute token stream.
822///
823/// Supports:
824/// - `#[task]` — no retry, no rerun
825/// - `#[task(retry(max_attempts = 3, backoff = "1s"))]` — with retry
826/// - `#[task(rerun_on_resume)]` — always re-execute on resume (skip cache)
827/// - `#[task(rerun_on_resume = true)]` — explicit boolean form
828/// - `#[task(rerun_on_resume, retry(max_attempts = 3, backoff = "1s"))]` — combined
829fn parse_task_attrs(attr: TokenStream) -> TaskAttrs {
830 if attr.is_empty() {
831 return TaskAttrs { retry: None, rerun_on_resume: false };
832 }
833
834 // Parse the attribute as a Meta list
835 let attr_meta: syn::Result<syn::Meta> = syn::parse(attr.clone());
836 if let Ok(syn::Meta::List(meta_list)) = attr_meta
837 && meta_list.path.is_ident("retry")
838 && let Some(retry) = parse_retry_from_meta_list(&meta_list)
839 {
840 return TaskAttrs { retry: Some(retry), rerun_on_resume: false };
841 }
842
843 // Try parsing as just the inner content of task(...)
844 // e.g., the attr stream is: `retry(max_attempts = 3, backoff = "1s")`
845 // or: `rerun_on_resume`
846 // or: `rerun_on_resume, retry(max_attempts = 3, backoff = "1s")`
847 let attr2: proc_macro2::TokenStream = attr.into();
848 let parsed: syn::Result<TaskAttrContent> = syn::parse2(attr2);
849 if let Ok(content) = parsed {
850 return TaskAttrs { retry: content.retry, rerun_on_resume: content.rerun_on_resume };
851 }
852
853 TaskAttrs { retry: None, rerun_on_resume: false }
854}
855
856/// Inner content parsed from `#[task(retry(...), rerun_on_resume)]`.
857struct TaskAttrContent {
858 retry: Option<RetryConfig>,
859 rerun_on_resume: bool,
860}
861
862impl syn::parse::Parse for TaskAttrContent {
863 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
864 let mut retry = None;
865 let mut rerun_on_resume = false;
866
867 // Parse comma-separated items: identifiers, name-value pairs, or calls like retry(...)
868 while !input.is_empty() {
869 let ident: syn::Ident = input.parse()?;
870
871 if ident == "retry" {
872 let content;
873 syn::parenthesized!(content in input);
874
875 let mut max_attempts: u32 = 3;
876 let mut backoff_secs: u64 = 1;
877
878 let pairs = syn::punctuated::Punctuated::<syn::MetaNameValue, syn::Token![,]>::parse_terminated(&content)?;
879
880 for pair in pairs {
881 if pair.path.is_ident("max_attempts")
882 && let syn::Expr::Lit(expr_lit) = &pair.value
883 && let syn::Lit::Int(lit_int) = &expr_lit.lit
884 {
885 max_attempts = lit_int.base10_parse().unwrap_or(3);
886 } else if pair.path.is_ident("backoff")
887 && let syn::Expr::Lit(expr_lit) = &pair.value
888 && let syn::Lit::Str(lit_str) = &expr_lit.lit
889 {
890 backoff_secs = parse_duration_str(&lit_str.value());
891 }
892 }
893
894 retry = Some(RetryConfig { max_attempts, backoff_secs });
895 } else if ident == "rerun_on_resume" {
896 // Accept both `rerun_on_resume` (flag, implies true)
897 // and `rerun_on_resume = true` / `rerun_on_resume = false`
898 if input.peek(syn::Token![=]) {
899 let _eq: syn::Token![=] = input.parse()?;
900 let lit: syn::LitBool = input.parse()?;
901 rerun_on_resume = lit.value;
902 } else {
903 rerun_on_resume = true;
904 }
905 } else {
906 return Err(syn::Error::new_spanned(
907 ident,
908 "unknown task attribute; expected `retry(...)` or `rerun_on_resume`",
909 ));
910 }
911
912 // Consume optional trailing comma
913 if input.peek(syn::Token![,]) {
914 let _comma: syn::Token![,] = input.parse()?;
915 }
916 }
917
918 Ok(TaskAttrContent { retry, rerun_on_resume })
919 }
920}
921
922/// Parse retry config from a `Meta::List` (e.g., `retry(max_attempts = 3, backoff = "1s")`).
923fn parse_retry_from_meta_list(meta_list: &syn::MetaList) -> Option<RetryConfig> {
924 let mut max_attempts: u32 = 3;
925 let mut backoff_secs: u64 = 1;
926
927 let pairs: syn::Result<syn::punctuated::Punctuated<syn::MetaNameValue, syn::Token![,]>> =
928 meta_list.parse_args_with(syn::punctuated::Punctuated::parse_terminated);
929
930 if let Ok(pairs) = pairs {
931 for pair in pairs {
932 if pair.path.is_ident("max_attempts")
933 && let syn::Expr::Lit(expr_lit) = &pair.value
934 && let syn::Lit::Int(lit_int) = &expr_lit.lit
935 {
936 max_attempts = lit_int.base10_parse().unwrap_or(3);
937 } else if pair.path.is_ident("backoff")
938 && let syn::Expr::Lit(expr_lit) = &pair.value
939 && let syn::Lit::Str(lit_str) = &expr_lit.lit
940 {
941 backoff_secs = parse_duration_str(&lit_str.value());
942 }
943 }
944 Some(RetryConfig { max_attempts, backoff_secs })
945 } else {
946 None
947 }
948}
949
/// Parse a duration string like "1s", "500ms", "2s" into whole seconds.
///
/// - `"<n>ms"` is converted to seconds, rounding **up**, so a non-zero
///   sub-second value such as `"500ms"` becomes 1 rather than collapsing to a
///   zero-length backoff. `"0ms"` stays 0.
/// - `"<n>s"` and a bare `"<n>"` are taken as seconds.
///
/// Surrounding whitespace is ignored. Defaults to 1 second if the numeric
/// part cannot be parsed.
fn parse_duration_str(s: &str) -> u64 {
    let s = s.trim();
    // Check "ms" suffix first (before "s" since "ms" ends with 's')
    if let Some(ms) = s.strip_suffix("ms") {
        return ms.parse::<u64>().map_or(1, |millis| millis.div_ceil(1000));
    }
    if let Some(secs) = s.strip_suffix('s') {
        return secs.parse::<u64>().unwrap_or(1);
    }
    // Try parsing as plain number (assume seconds)
    s.parse::<u64>().unwrap_or(1)
}