Skip to main content

adk_rust_macros/
lib.rs

1//! # adk-macros
2//!
3//! Proc macros for ADK-Rust that eliminate tool registration boilerplate.
4//!
5//! ## `#[tool]`
6//!
7//! Turns an async function into a fully-wired `adk_tool::Tool` implementation:
8//!
9//! ```rust,ignore
10//! use adk_macros::tool;
11//! use schemars::JsonSchema;
12//! use serde::Deserialize;
13//!
14//! #[derive(Deserialize, JsonSchema)]
15//! struct WeatherArgs {
16//!     /// The city to look up
17//!     city: String,
18//! }
19//!
20//! /// Get the current weather for a city.
21//! #[tool]
22//! async fn get_weather(args: WeatherArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
23//!     Ok(serde_json::json!({ "temp": 72, "city": args.city }))
24//! }
25//!
26//! // This generates a struct `GetWeather` that implements `adk_tool::Tool`.
27//! // Use it like: Arc::new(GetWeather)
28//! ```
29//!
30//! The macro:
31//! - Uses the function's doc comment as the tool description
32//! - Derives the JSON schema from the argument type via `schemars::schema_for!`
33//! - Names the tool after the function (snake_case)
34//! - Generates a zero-sized struct (PascalCase) implementing `Tool`
35
36use proc_macro::TokenStream;
37use quote::{format_ident, quote};
38use syn::{FnArg, ItemFn, Meta, Type, parse_macro_input};
39
40/// Attribute macro that generates a `Tool` implementation from an async function.
41///
42/// # Requirements
43///
44/// - The function must be `async`
45/// - It must take exactly one argument (the args struct) that implements
46///   `serde::de::DeserializeOwned` and `schemars::JsonSchema`
47/// - It must return `Result<serde_json::Value, adk_tool::AdkError>`
48/// - Doc comments become the tool description
49///
50/// # Attributes
51///
52/// Optional attributes can be passed to configure tool metadata:
53///
54/// - `read_only` — marks the tool as having no side effects (`is_read_only() → true`)
55/// - `concurrency_safe` — marks the tool as safe for concurrent execution (`is_concurrency_safe() → true`)
56/// - `long_running` — marks the tool as long-running (`is_long_running() → true`)
57///
58/// # Examples
59///
60/// ```rust,ignore
61/// /// Search the knowledge base for documents matching a query.
62/// #[tool]
63/// async fn search_docs(args: SearchArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
64///     // ...
65/// }
66///
67/// /// Look up cached data (read-only, safe for parallel dispatch).
68/// #[tool(read_only, concurrency_safe)]
69/// async fn cache_lookup(args: LookupArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
70///     // ...
71/// }
72///
73/// // Generated: pub struct SearchDocs; implements Tool
74/// // Use: agent_builder.tool(Arc::new(SearchDocs))
75/// ```
76#[proc_macro_attribute]
77pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
78    let input_fn = parse_macro_input!(item as ItemFn);
79
80    // Parse optional attributes: #[tool(read_only, concurrency_safe, long_running)]
81    let mut is_read_only = false;
82    let mut is_concurrency_safe = false;
83    let mut is_long_running = false;
84
85    if !attr.is_empty() {
86        let meta = parse_macro_input!(attr as ToolAttrs);
87        is_read_only = meta.read_only;
88        is_concurrency_safe = meta.concurrency_safe;
89        is_long_running = meta.long_running;
90    }
91
92    let fn_name = &input_fn.sig.ident;
93    let fn_vis = &input_fn.vis;
94
95    // Extract doc comments for description
96    let doc_lines: Vec<String> = input_fn
97        .attrs
98        .iter()
99        .filter(|attr| attr.path().is_ident("doc"))
100        .filter_map(|attr| {
101            if let syn::Meta::NameValue(nv) = &attr.meta
102                && let syn::Expr::Lit(lit) = &nv.value
103                && let syn::Lit::Str(s) = &lit.lit
104            {
105                return Some(s.value().trim().to_string());
106            }
107            None
108        })
109        .collect();
110
111    let description = if doc_lines.is_empty() {
112        fn_name.to_string().replace('_', " ")
113    } else {
114        doc_lines.join(" ")
115    };
116
117    let tool_name_str = fn_name.to_string();
118
119    // Generate PascalCase struct name: get_weather → GetWeather
120    let struct_name = format_ident!(
121        "{}",
122        tool_name_str
123            .split('_')
124            .map(|seg| {
125                let mut chars = seg.chars();
126                match chars.next() {
127                    None => String::new(),
128                    Some(c) => c.to_uppercase().to_string() + chars.as_str(),
129                }
130            })
131            .collect::<String>()
132    );
133
134    // Extract the single argument type
135    let args_type = extract_args_type(&input_fn);
136
137    // Check if we have a typed args parameter or no params
138    let (schema_gen, deserialize_call) = if let Some(args_ty) = &args_type {
139        (
140            quote! {
141                {
142                    let mut schema = serde_json::to_value(
143                        schemars::schema_for!(#args_ty)
144                    ).unwrap_or_default();
145                    // Strip fields that Gemini/LLM APIs don't accept
146                    if let Some(obj) = schema.as_object_mut() {
147                        obj.remove("$schema");
148                        obj.remove("title");
149                    }
150                    // Simplify nullable types: {"type": ["string", "null"]} → {"type": "string"}
151                    fn simplify_nullable(v: &mut serde_json::Value) {
152                        match v {
153                            serde_json::Value::Object(map) => {
154                                if let Some(serde_json::Value::Array(types)) = map.get("type") {
155                                    let non_null: Vec<_> = types.iter()
156                                        .filter(|t| t.as_str() != Some("null"))
157                                        .cloned()
158                                        .collect();
159                                    if non_null.len() == 1 {
160                                        map.insert("type".to_string(), non_null[0].clone());
161                                    }
162                                }
163                                // Remove anyOf wrappers for simple nullable types
164                                if let Some(serde_json::Value::Array(any_of)) = map.remove("anyOf") {
165                                    for variant in &any_of {
166                                        if let Some(obj) = variant.as_object() {
167                                            if obj.get("type").and_then(|t| t.as_str()) != Some("null") {
168                                                for (k, val) in obj {
169                                                    map.insert(k.clone(), val.clone());
170                                                }
171                                                break;
172                                            }
173                                        }
174                                    }
175                                }
176                                for val in map.values_mut() {
177                                    simplify_nullable(val);
178                                }
179                            }
180                            serde_json::Value::Array(arr) => {
181                                for item in arr {
182                                    simplify_nullable(item);
183                                }
184                            }
185                            _ => {}
186                        }
187                    }
188                    simplify_nullable(&mut schema);
189                    Some(schema)
190                }
191            },
192            quote! {
193                let typed_args: #args_ty = serde_json::from_value(args)
194                    .map_err(|e| adk_tool::AdkError::tool(
195                        format!("invalid arguments for '{}': {e}", #tool_name_str)
196                    ))?;
197                #fn_name(typed_args).await
198            },
199        )
200    } else {
201        (
202            quote! { None },
203            quote! {
204                let _ = args;
205                #fn_name().await
206            },
207        )
208    };
209
210    // Check if the function signature includes ctx: Arc<dyn ToolContext>
211    let has_ctx = has_tool_context_param(&input_fn);
212    let execute_body = if has_ctx {
213        if let Some(args_ty) = &args_type {
214            quote! {
215                let typed_args: #args_ty = serde_json::from_value(args)
216                    .map_err(|e| adk_tool::AdkError::tool(
217                        format!("invalid arguments for '{}': {e}", #tool_name_str)
218                    ))?;
219                #fn_name(ctx, typed_args).await
220            }
221        } else {
222            quote! {
223                let _ = args;
224                #fn_name(ctx).await
225            }
226        }
227    } else {
228        deserialize_call
229    };
230
231    // Generate optional trait method overrides
232    let read_only_override = if is_read_only {
233        quote! {
234            fn is_read_only(&self) -> bool { true }
235        }
236    } else {
237        quote! {}
238    };
239
240    let concurrency_safe_override = if is_concurrency_safe {
241        quote! {
242            fn is_concurrency_safe(&self) -> bool { true }
243        }
244    } else {
245        quote! {}
246    };
247
248    let long_running_override = if is_long_running {
249        quote! {
250            fn is_long_running(&self) -> bool { true }
251        }
252    } else {
253        quote! {}
254    };
255
256    let output = quote! {
257        // Keep the original function
258        #input_fn
259
260        /// Auto-generated tool struct for [`#fn_name`].
261        #fn_vis struct #struct_name;
262
263        #[adk_tool::async_trait]
264        impl adk_tool::Tool for #struct_name {
265            fn name(&self) -> &str {
266                #tool_name_str
267            }
268
269            fn description(&self) -> &str {
270                #description
271            }
272
273            fn parameters_schema(&self) -> Option<serde_json::Value> {
274                #schema_gen
275            }
276
277            #read_only_override
278            #concurrency_safe_override
279            #long_running_override
280
281            async fn execute(
282                &self,
283                ctx: std::sync::Arc<dyn adk_tool::ToolContext>,
284                args: serde_json::Value,
285            ) -> adk_tool::Result<serde_json::Value> {
286                #execute_body
287            }
288        }
289    };
290
291    output.into()
292}
293
294/// Extract the args type from the function signature.
295/// Skips any `Arc<dyn ToolContext>` parameter.
296fn extract_args_type(func: &ItemFn) -> Option<Type> {
297    for arg in &func.sig.inputs {
298        if let FnArg::Typed(pat_type) = arg {
299            // Skip context parameters (Arc<dyn ToolContext>)
300            let ty = &pat_type.ty;
301            let ty_str = quote!(#ty).to_string();
302            if ty_str.contains("ToolContext") {
303                continue;
304            }
305            return Some((*pat_type.ty).clone());
306        }
307    }
308    None
309}
310
311/// Check if the function has an Arc<dyn ToolContext> parameter.
312fn has_tool_context_param(func: &ItemFn) -> bool {
313    func.sig.inputs.iter().any(|arg| {
314        if let FnArg::Typed(pat_type) = arg {
315            let ty = &pat_type.ty;
316            let ty_str = quote!(#ty).to_string();
317            ty_str.contains("ToolContext")
318        } else {
319            false
320        }
321    })
322}
323
324/// Parsed attributes from `#[tool(read_only, concurrency_safe, long_running)]`.
325struct ToolAttrs {
326    read_only: bool,
327    concurrency_safe: bool,
328    long_running: bool,
329}
330
331impl syn::parse::Parse for ToolAttrs {
332    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
333        let mut attrs =
334            ToolAttrs { read_only: false, concurrency_safe: false, long_running: false };
335
336        let punctuated =
337            syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated(input)?;
338
339        for meta in punctuated {
340            if let Meta::Path(path) = &meta {
341                if path.is_ident("read_only") {
342                    attrs.read_only = true;
343                } else if path.is_ident("concurrency_safe") {
344                    attrs.concurrency_safe = true;
345                } else if path.is_ident("long_running") {
346                    attrs.long_running = true;
347                } else {
348                    return Err(syn::Error::new_spanned(
349                        path,
350                        "unknown tool attribute; expected `read_only`, `concurrency_safe`, or `long_running`",
351                    ));
352                }
353            } else {
354                return Err(syn::Error::new_spanned(
355                    meta,
356                    "expected identifier (e.g., `read_only`), not key-value",
357                ));
358            }
359        }
360
361        Ok(attrs)
362    }
363}
364
365// ─── Functional API Macros ─────────────────────────────────────────────────────
366
367/// Attribute macro that generates a workflow agent struct from an async function.
368///
369/// The annotated function becomes the workflow body. The macro generates:
370/// - A PascalCase struct (e.g., `my_workflow` → `MyWorkflowAgent`)
371/// - A `new()` constructor accepting `Arc<dyn Checkpointer>`
372/// - An `invoke()` method that creates/restores `TaskContext`, validates state,
373///   creates checkpoints, calls the function, and persists the final checkpoint
374///
375/// # Requirements
376///
377/// - The function **must** be `async`
378/// - The function **must** accept `&mut TaskContext` as its sole parameter
379/// - The function **must** return `Result<Value>` (or equivalent)
380///
381/// # Example
382///
383/// ```rust,ignore
384/// use adk_graph::functional::TaskContext;
385/// use adk_graph::error::Result;
386/// use serde_json::Value;
387///
388/// #[entrypoint]
389/// async fn my_workflow(ctx: &mut TaskContext) -> Result<Value> {
390///     let data = step_a(ctx, "input").await?;
391///     let result = step_b(ctx, data).await?;
392///     Ok(result)
393/// }
394///
395/// // Generates: pub struct MyWorkflowAgent { ... }
396/// // with MyWorkflowAgent::new(checkpointer) and invoke(initial_state, config)
397/// ```
398#[proc_macro_attribute]
399pub fn entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream {
400    let input_fn = parse_macro_input!(item as ItemFn);
401
402    // Validate: must be async
403    if input_fn.sig.asyncness.is_none() {
404        return syn::Error::new_spanned(
405            input_fn.sig.fn_token,
406            "#[entrypoint] functions must be async",
407        )
408        .to_compile_error()
409        .into();
410    }
411
412    // Validate: must accept &mut TaskContext
413    let has_task_context = input_fn.sig.inputs.iter().any(|arg| {
414        if let FnArg::Typed(pat_type) = arg {
415            let full_str = quote!(#pat_type).to_string();
416            full_str.contains("TaskContext")
417        } else {
418            false
419        }
420    });
421
422    if !has_task_context {
423        return syn::Error::new_spanned(
424            &input_fn.sig,
425            "#[entrypoint] functions must accept `&mut TaskContext` as a parameter",
426        )
427        .to_compile_error()
428        .into();
429    }
430
431    let fn_name = &input_fn.sig.ident;
432    let fn_vis = &input_fn.vis;
433    let fn_name_str = fn_name.to_string();
434
435    // Generate PascalCase struct name: my_workflow → MyWorkflowAgent
436    let struct_name = format_ident!(
437        "{}Agent",
438        fn_name_str
439            .split('_')
440            .map(|seg| {
441                let mut chars = seg.chars();
442                match chars.next() {
443                    None => String::new(),
444                    Some(c) => c.to_uppercase().to_string() + chars.as_str(),
445                }
446            })
447            .collect::<String>()
448    );
449
450    let output = quote! {
451        // Preserve the original function for direct testing
452        #input_fn
453
454        /// Auto-generated workflow agent struct for [`#fn_name`].
455        ///
456        /// Created by the `#[entrypoint]` macro. Provides `new()` and `invoke()`
457        /// methods for executing the workflow with automatic checkpointing.
458        #fn_vis struct #struct_name {
459            checkpointer: std::sync::Arc<dyn adk_graph::checkpoint::Checkpointer>,
460        }
461
462        impl #struct_name {
463            /// Create a new workflow agent with the given checkpointer.
464            pub fn new(checkpointer: std::sync::Arc<dyn adk_graph::checkpoint::Checkpointer>) -> Self {
465                Self { checkpointer }
466            }
467
468            /// Invoke the workflow with an initial state and execution configuration.
469            ///
470            /// This method:
471            /// 1. Creates or restores a `TaskContext` from the last checkpoint
472            /// 2. Validates initial state against the configured schema
473            /// 3. Creates a checkpoint before execution
474            /// 4. Calls the annotated workflow function
475            /// 5. Persists the final checkpoint
476            /// 6. Returns the final workflow state
477            pub async fn invoke(
478                &self,
479                initial_state: adk_graph::state::State,
480                execution_config: adk_graph::node::ExecutionConfig,
481            ) -> adk_graph::error::Result<adk_graph::state::State> {
482                use adk_graph::checkpoint::Checkpointer;
483                use adk_graph::functional::ExecutionLog;
484                use adk_graph::state::Checkpoint;
485                use adk_graph::stream::StreamEvent;
486
487                let thread_id = execution_config.thread_id.clone();
488
489                // Try to restore from checkpoint if resuming
490                let (state, execution_log) = if execution_config.resume_from.is_some() {
491                    match self.checkpointer.load(&thread_id).await? {
492                        Some(checkpoint) => {
493                            let log: ExecutionLog = checkpoint
494                                .metadata
495                                .get("execution_log")
496                                .and_then(|v| serde_json::from_value(v.clone()).ok())
497                                .unwrap_or_default();
498                            (checkpoint.state, log)
499                        }
500                        None => (initial_state, ExecutionLog::new()),
501                    }
502                } else {
503                    (initial_state, ExecutionLog::new())
504                };
505
506                // Create broadcast channel for stream events
507                let (event_tx, _) = tokio::sync::broadcast::channel::<StreamEvent>(256);
508                let cancel_token = tokio_util::sync::CancellationToken::new();
509                let execution_log = std::sync::Arc::new(tokio::sync::RwLock::new(execution_log));
510
511                // Create TaskContext
512                let mut ctx = adk_graph::functional::TaskContext::new(
513                    thread_id.clone(),
514                    state,
515                    self.checkpointer.clone(),
516                    event_tx.clone(),
517                    execution_log.clone(),
518                    cancel_token,
519                    None,
520                );
521
522                // Validate initial state against schema (if configured)
523                ctx.validate_state().map_err(|e| adk_graph::error::GraphError::Other(e.to_string()))?;
524
525                // Create pre-execution checkpoint
526                let pre_checkpoint = Checkpoint::new(
527                    &thread_id,
528                    ctx.state().clone(),
529                    0,
530                    vec![],
531                )
532                .with_metadata("phase", serde_json::Value::String("pre_execution".to_string()));
533                self.checkpointer.save(&pre_checkpoint).await?;
534
535                // Emit workflow start event
536                let _ = event_tx.send(StreamEvent::node_start(#fn_name_str, 0));
537
538                // Call the workflow function
539                let start = std::time::Instant::now();
540                let result = #fn_name(&mut ctx).await;
541
542                let duration = start.elapsed().as_millis() as u64;
543
544                match result {
545                    Ok(_value) => {
546                        // Persist final checkpoint
547                        let step = execution_log.read().await.current_step();
548                        let final_checkpoint = Checkpoint::new(
549                            &thread_id,
550                            ctx.state().clone(),
551                            step,
552                            vec![],
553                        )
554                        .with_metadata("phase", serde_json::Value::String("completed".to_string()))
555                        .with_metadata(
556                            "execution_log",
557                            serde_json::to_value(&*execution_log.read().await)
558                                .unwrap_or(serde_json::Value::Null),
559                        );
560                        self.checkpointer.save(&final_checkpoint).await?;
561
562                        // Emit workflow end event
563                        let _ = event_tx.send(StreamEvent::node_end(#fn_name_str, step, duration));
564
565                        Ok(ctx.state().clone())
566                    }
567                    Err(e) => {
568                        // Persist failure checkpoint
569                        let step = execution_log.read().await.current_step();
570                        let fail_checkpoint = Checkpoint::new(
571                            &thread_id,
572                            ctx.state().clone(),
573                            step,
574                            vec![],
575                        )
576                        .with_metadata("phase", serde_json::Value::String("failed".to_string()))
577                        .with_metadata("error", serde_json::Value::String(e.to_string()))
578                        .with_metadata(
579                            "execution_log",
580                            serde_json::to_value(&*execution_log.read().await)
581                                .unwrap_or(serde_json::Value::Null),
582                        );
583                        let _ = self.checkpointer.save(&fail_checkpoint).await;
584
585                        // Emit error event
586                        let _ = event_tx.send(StreamEvent::error(&e.to_string(), Some(#fn_name_str)));
587
588                        Err(e)
589                    }
590                }
591            }
592        }
593    };
594
595    output.into()
596}
597
598/// Attribute macro that generates a task wrapper with checkpointing, retry, and streaming.
599///
600/// The annotated function becomes the inner task body. The macro generates a wrapper
601/// function (prefixed with `__task_`) that:
602/// - Checks `ExecutionLog` for cached results (resume-skip path)
603/// - Emits `StreamEvent::node_start` and `StreamEvent::node_end` events
604/// - Implements retry logic when `retry(max_attempts, backoff)` is specified
605/// - Calls `record_completion()` on success
606/// - Calls `record_failure()` after all retries are exhausted
607///
608/// # Requirements
609///
610/// - The function **must** be `async`
611/// - The function **must** accept `&mut TaskContext` as its first argument
612///
613/// # Attributes
614///
615/// - `retry(max_attempts = N, backoff = "Xs")` — retry on failure with exponential backoff
616/// - `rerun_on_resume` — always re-execute on workflow resume, skip cached results
617/// - `rerun_on_resume = true` / `rerun_on_resume = false` — explicit boolean form
618///
619/// # Examples
620///
621/// ```rust,ignore
622/// use adk_graph::functional::TaskContext;
623/// use adk_graph::error::Result;
624/// use serde_json::Value;
625///
626/// #[task(retry(max_attempts = 3, backoff = "1s"))]
627/// async fn step_a(ctx: &mut TaskContext, input: &str) -> Result<Value> {
628///     Ok(serde_json::json!({"processed": input}))
629/// }
630///
631/// #[task(rerun_on_resume)]
632/// async fn step_b(ctx: &mut TaskContext) -> Result<Value> {
633///     // This task always re-executes on resume, never uses cached results
634///     Ok(serde_json::json!({"timestamp": chrono::Utc::now().to_rfc3339()}))
635/// }
636///
637/// #[task(rerun_on_resume, retry(max_attempts = 2, backoff = "2s"))]
638/// async fn step_c(ctx: &mut TaskContext) -> Result<Value> {
639///     // Combined: re-executes on resume with retry logic
640///     Ok(serde_json::json!({"status": "ok"}))
641/// }
642///
643/// // Generates: async fn __task_step_a(ctx: &mut TaskContext, input: &str) -> Result<Value>
644/// // which wraps step_a with checkpoint/retry/streaming logic.
645/// ```
646#[proc_macro_attribute]
647pub fn task(attr: TokenStream, item: TokenStream) -> TokenStream {
648    let input_fn = parse_macro_input!(item as ItemFn);
649
650    // Validate: must be async
651    if input_fn.sig.asyncness.is_none() {
652        return syn::Error::new_spanned(input_fn.sig.fn_token, "#[task] functions must be async")
653            .to_compile_error()
654            .into();
655    }
656
657    // Validate: first argument must be &mut TaskContext
658    let has_task_context_first = input_fn
659        .sig
660        .inputs
661        .first()
662        .map(|arg| {
663            if let FnArg::Typed(pat_type) = arg {
664                let full_str = quote!(#pat_type).to_string();
665                full_str.contains("TaskContext")
666            } else {
667                false
668            }
669        })
670        .unwrap_or(false);
671
672    if !has_task_context_first {
673        return syn::Error::new_spanned(
674            &input_fn.sig,
675            "#[task] functions must accept `&mut TaskContext` as the first argument",
676        )
677        .to_compile_error()
678        .into();
679    }
680
681    // Parse retry attributes from #[task(retry(max_attempts = N, backoff = "Xs"))]
682    let task_attrs = parse_task_attrs(attr);
683
684    let fn_name = &input_fn.sig.ident;
685    let fn_vis = &input_fn.vis;
686    let fn_name_str = fn_name.to_string();
687    let wrapper_name = format_ident!("__task_{}", fn_name);
688
689    // Collect function parameters (all of them for the wrapper signature)
690    let params = &input_fn.sig.inputs;
691    let return_type = &input_fn.sig.output;
692
693    // Collect the argument names for forwarding the call (skip `ctx`)
694    let forward_args: Vec<_> = input_fn
695        .sig
696        .inputs
697        .iter()
698        .skip(1) // Skip ctx
699        .filter_map(|arg| if let FnArg::Typed(pat_type) = arg { Some(&pat_type.pat) } else { None })
700        .collect();
701
702    // Build the call expression
703    let call_expr = if forward_args.is_empty() {
704        quote! { #fn_name(ctx).await }
705    } else {
706        quote! { #fn_name(ctx, #(#forward_args),*).await }
707    };
708
709    // Generate retry logic or single-attempt logic
710    let execution_body = if let Some(retry_config) = &task_attrs.retry {
711        let max_attempts = retry_config.max_attempts;
712        let backoff_secs = retry_config.backoff_secs;
713        quote! {
714            let mut attempts: u32 = 0;
715            let max_attempts: u32 = #max_attempts;
716            let backoff = std::time::Duration::from_secs(#backoff_secs);
717
718            let result = loop {
719                attempts += 1;
720                match #call_expr {
721                    Ok(value) => break Ok(value),
722                    Err(e) if attempts < max_attempts => {
723                        tokio::time::sleep(backoff * attempts).await;
724                        continue;
725                    }
726                    Err(e) => {
727                        ctx.record_failure(task_id, &e.to_string()).await?;
728                        ctx.emit(adk_graph::stream::StreamEvent::error(
729                            &e.to_string(),
730                            Some(task_id),
731                        ));
732                        break Err(e);
733                    }
734                }
735            };
736        }
737    } else {
738        quote! {
739            let result = match #call_expr {
740                Ok(value) => Ok(value),
741                Err(e) => {
742                    ctx.record_failure(task_id, &e.to_string()).await?;
743                    ctx.emit(adk_graph::stream::StreamEvent::error(
744                        &e.to_string(),
745                        Some(task_id),
746                    ));
747                    Err(e)
748                }
749            };
750        }
751    };
752
753    // Generate cache-check code based on rerun_on_resume flag
754    let cache_check = if task_attrs.rerun_on_resume {
755        // rerun_on_resume = true: skip cache check, always execute
756        quote! {}
757    } else {
758        // Default: check ExecutionLog for cached results (resume-skip path)
759        quote! {
760            // Check if already completed (resume path)
761            if let Some(cached_result) = ctx.get_cached_result(task_id).await {
762                return Ok(cached_result);
763            }
764        }
765    };
766
767    let output = quote! {
768        // Preserve the original function for direct testing
769        #input_fn
770
771        /// Auto-generated task wrapper for [`#fn_name`].
772        ///
773        /// Wraps the original function with:
774        /// - Resume-skip logic (checks `ExecutionLog` for cached results)
775        /// - `StreamEvent::node_start` / `StreamEvent::node_end` emission
776        /// - Retry logic (if configured)
777        /// - `record_completion()` on success
778        /// - `record_failure()` after all retries exhausted
779        #fn_vis async fn #wrapper_name(#params) #return_type {
780            let task_id = #fn_name_str;
781
782            #cache_check
783
784            // Emit task start event
785            let current_step = ctx.current_step().await;
786            ctx.emit(adk_graph::stream::StreamEvent::node_start(task_id, current_step));
787
788            let start = std::time::Instant::now();
789
790            #execution_body
791
792            if let Ok(ref value) = result {
793                // Record completion and checkpoint
794                ctx.record_completion(task_id, value).await?;
795                let duration = start.elapsed().as_millis() as u64;
796                let step = ctx.current_step().await;
797                ctx.emit(adk_graph::stream::StreamEvent::node_end(task_id, step, duration));
798            }
799
800            result
801        }
802    };
803
804    output.into()
805}
806
807// ─── Task Attribute Parsing ────────────────────────────────────────────────────
808
/// Retry settings taken from `#[task(retry(max_attempts = N, backoff = "Xs"))]`.
struct RetryConfig {
    /// Upper bound on attempts; the generated retry loop stops retrying once
    /// its attempt counter is no longer below this value.
    max_attempts: u32,
    /// Base delay in whole seconds (fed to `Duration::from_secs` in the
    /// generated code and multiplied by the attempt number before sleeping).
    backoff_secs: u64,
}
814
815/// Parsed attributes from `#[task(...)]`.
816struct TaskAttrs {
817    retry: Option<RetryConfig>,
818    rerun_on_resume: bool,
819}
820
821/// Parse task attributes from the attribute token stream.
822///
823/// Supports:
824/// - `#[task]` — no retry, no rerun
825/// - `#[task(retry(max_attempts = 3, backoff = "1s"))]` — with retry
826/// - `#[task(rerun_on_resume)]` — always re-execute on resume (skip cache)
827/// - `#[task(rerun_on_resume = true)]` — explicit boolean form
828/// - `#[task(rerun_on_resume, retry(max_attempts = 3, backoff = "1s"))]` — combined
829fn parse_task_attrs(attr: TokenStream) -> TaskAttrs {
830    if attr.is_empty() {
831        return TaskAttrs { retry: None, rerun_on_resume: false };
832    }
833
834    // Parse the attribute as a Meta list
835    let attr_meta: syn::Result<syn::Meta> = syn::parse(attr.clone());
836    if let Ok(syn::Meta::List(meta_list)) = attr_meta
837        && meta_list.path.is_ident("retry")
838        && let Some(retry) = parse_retry_from_meta_list(&meta_list)
839    {
840        return TaskAttrs { retry: Some(retry), rerun_on_resume: false };
841    }
842
843    // Try parsing as just the inner content of task(...)
844    // e.g., the attr stream is: `retry(max_attempts = 3, backoff = "1s")`
845    // or: `rerun_on_resume`
846    // or: `rerun_on_resume, retry(max_attempts = 3, backoff = "1s")`
847    let attr2: proc_macro2::TokenStream = attr.into();
848    let parsed: syn::Result<TaskAttrContent> = syn::parse2(attr2);
849    if let Ok(content) = parsed {
850        return TaskAttrs { retry: content.retry, rerun_on_resume: content.rerun_on_resume };
851    }
852
853    TaskAttrs { retry: None, rerun_on_resume: false }
854}
855
856/// Inner content parsed from `#[task(retry(...), rerun_on_resume)]`.
857struct TaskAttrContent {
858    retry: Option<RetryConfig>,
859    rerun_on_resume: bool,
860}
861
862impl syn::parse::Parse for TaskAttrContent {
863    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
864        let mut retry = None;
865        let mut rerun_on_resume = false;
866
867        // Parse comma-separated items: identifiers, name-value pairs, or calls like retry(...)
868        while !input.is_empty() {
869            let ident: syn::Ident = input.parse()?;
870
871            if ident == "retry" {
872                let content;
873                syn::parenthesized!(content in input);
874
875                let mut max_attempts: u32 = 3;
876                let mut backoff_secs: u64 = 1;
877
878                let pairs = syn::punctuated::Punctuated::<syn::MetaNameValue, syn::Token![,]>::parse_terminated(&content)?;
879
880                for pair in pairs {
881                    if pair.path.is_ident("max_attempts")
882                        && let syn::Expr::Lit(expr_lit) = &pair.value
883                        && let syn::Lit::Int(lit_int) = &expr_lit.lit
884                    {
885                        max_attempts = lit_int.base10_parse().unwrap_or(3);
886                    } else if pair.path.is_ident("backoff")
887                        && let syn::Expr::Lit(expr_lit) = &pair.value
888                        && let syn::Lit::Str(lit_str) = &expr_lit.lit
889                    {
890                        backoff_secs = parse_duration_str(&lit_str.value());
891                    }
892                }
893
894                retry = Some(RetryConfig { max_attempts, backoff_secs });
895            } else if ident == "rerun_on_resume" {
896                // Accept both `rerun_on_resume` (flag, implies true)
897                // and `rerun_on_resume = true` / `rerun_on_resume = false`
898                if input.peek(syn::Token![=]) {
899                    let _eq: syn::Token![=] = input.parse()?;
900                    let lit: syn::LitBool = input.parse()?;
901                    rerun_on_resume = lit.value;
902                } else {
903                    rerun_on_resume = true;
904                }
905            } else {
906                return Err(syn::Error::new_spanned(
907                    ident,
908                    "unknown task attribute; expected `retry(...)` or `rerun_on_resume`",
909                ));
910            }
911
912            // Consume optional trailing comma
913            if input.peek(syn::Token![,]) {
914                let _comma: syn::Token![,] = input.parse()?;
915            }
916        }
917
918        Ok(TaskAttrContent { retry, rerun_on_resume })
919    }
920}
921
922/// Parse retry config from a `Meta::List` (e.g., `retry(max_attempts = 3, backoff = "1s")`).
923fn parse_retry_from_meta_list(meta_list: &syn::MetaList) -> Option<RetryConfig> {
924    let mut max_attempts: u32 = 3;
925    let mut backoff_secs: u64 = 1;
926
927    let pairs: syn::Result<syn::punctuated::Punctuated<syn::MetaNameValue, syn::Token![,]>> =
928        meta_list.parse_args_with(syn::punctuated::Punctuated::parse_terminated);
929
930    if let Ok(pairs) = pairs {
931        for pair in pairs {
932            if pair.path.is_ident("max_attempts")
933                && let syn::Expr::Lit(expr_lit) = &pair.value
934                && let syn::Lit::Int(lit_int) = &expr_lit.lit
935            {
936                max_attempts = lit_int.base10_parse().unwrap_or(3);
937            } else if pair.path.is_ident("backoff")
938                && let syn::Expr::Lit(expr_lit) = &pair.value
939                && let syn::Lit::Str(lit_str) = &expr_lit.lit
940            {
941                backoff_secs = parse_duration_str(&lit_str.value());
942            }
943        }
944        Some(RetryConfig { max_attempts, backoff_secs })
945    } else {
946        None
947    }
948}
949
/// Parse a duration string such as `"1s"`, `"500ms"`, or `"2"` into whole seconds.
///
/// Accepted forms (leading/trailing whitespace is ignored):
/// - `"<n>ms"` — milliseconds, rounded **up** to the next whole second, so a
///   non-zero sub-second value (e.g. `"500ms"`) never collapses to a
///   zero-second delay; `"0ms"` is still 0
/// - `"<n>s"` — seconds
/// - `"<n>"` — bare number, interpreted as seconds
///
/// Returns 1 when the numeric part cannot be parsed as a `u64`.
fn parse_duration_str(s: &str) -> u64 {
    const DEFAULT_SECS: u64 = 1;

    let s = s.trim();

    // Check "ms" before "s": every string ending in "ms" also ends in 's'.
    if let Some(millis) = s.strip_suffix("ms") {
        // `div_ceil` rounds up without the overflow risk of `(v + 999) / 1000`.
        return millis.parse::<u64>().map_or(DEFAULT_SECS, |v| v.div_ceil(1000));
    }

    // Either "<n>s" or a bare "<n>"; both are seconds.
    let secs = s.strip_suffix('s').unwrap_or(s);
    secs.parse::<u64>().unwrap_or(DEFAULT_SECS)
}