use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{FnArg, ItemFn, Meta, Type, parse_macro_input};
/// Turns an async tool function into an `adk_tool::Tool` implementation.
///
/// The function is emitted unchanged. Alongside it, a unit struct named after the function in
/// PascalCase is emitted (`get_weather` -> `GetWeather`), implementing `adk_tool::Tool`:
///
/// - `name()` is the function name.
/// - `description()` is the function's `///` doc text: lines are trimmed, blank lines are
///   dropped, and the rest are joined with single spaces. Without doc comments it is the
///   function name with `_` replaced by spaces.
/// - `parameters_schema()` is the `schemars` schema of the first parameter whose type does not
///   mention `ToolContext`, with `$schema`/`title` removed and single-type nullable
///   (`Option<T>`-style) constructs collapsed. It is `None` when there is no such parameter.
/// - `execute()` deserializes the JSON arguments into that parameter type and awaits the
///   function. When the function declares a `ToolContext` parameter, the context is passed as
///   the first argument.
///
/// Accepted flags: `read_only`, `concurrency_safe`, `long_running`. Each flag emits the
/// matching trait-method override returning `true`.
#[proc_macro_attribute]
pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
    let input_fn = parse_macro_input!(item as ItemFn);
    // `ToolAttrs` parses empty input as "all flags off", so no special case is needed.
    let flags = parse_macro_input!(attr as ToolAttrs);

    let fn_name = &input_fn.sig.ident;
    let fn_vis = &input_fn.vis;
    let tool_name_str = fn_name.to_string();

    // Blank `///` lines (paragraph breaks) are skipped. Otherwise they would leave double
    // spaces in the joined description.
    let doc_lines: Vec<String> = input_fn
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident("doc"))
        .filter_map(|attr| {
            if let syn::Meta::NameValue(nv) = &attr.meta
                && let syn::Expr::Lit(lit) = &nv.value
                && let syn::Lit::Str(s) = &lit.lit
            {
                return Some(s.value().trim().to_string());
            }
            None
        })
        .filter(|line| !line.is_empty())
        .collect();
    let description = if doc_lines.is_empty() {
        tool_name_str.replace('_', " ")
    } else {
        doc_lines.join(" ")
    };

    let struct_name = format_ident!(
        "{}",
        tool_name_str
            .split('_')
            .map(|seg| {
                let mut chars = seg.chars();
                match chars.next() {
                    None => String::new(),
                    Some(c) => c.to_uppercase().to_string() + chars.as_str(),
                }
            })
            .collect::<String>()
    );

    let args_type = extract_args_type(&input_fn);
    let has_ctx = has_tool_context_param(&input_fn);

    let schema_gen = match &args_type {
        Some(args_ty) => quote! {
            {
                let mut schema = serde_json::to_value(
                    schemars::schema_for!(#args_ty)
                ).unwrap_or_default();
                if let Some(obj) = schema.as_object_mut() {
                    obj.remove("$schema");
                    obj.remove("title");
                }
                fn simplify_nullable(v: &mut serde_json::Value) {
                    match v {
                        serde_json::Value::Object(map) => {
                            if let Some(serde_json::Value::Array(types)) = map.get("type") {
                                let non_null: Vec<_> = types.iter()
                                    .filter(|t| t.as_str() != Some("null"))
                                    .cloned()
                                    .collect();
                                if non_null.len() == 1 {
                                    map.insert("type".to_string(), non_null[0].clone());
                                }
                            }
                            if let Some(serde_json::Value::Array(any_of)) = map.remove("anyOf") {
                                // Collapse only when exactly one non-null variant exists
                                // (an `Option<T>`-style union). Real unions keep `anyOf`,
                                // so their alternatives are not silently dropped.
                                let single = {
                                    let mut candidates = any_of
                                        .iter()
                                        .filter_map(|variant| variant.as_object())
                                        .filter(|obj| {
                                            obj.get("type").and_then(|t| t.as_str()) != Some("null")
                                        });
                                    match (candidates.next(), candidates.next()) {
                                        (Some(only), None) => Some(only.clone()),
                                        _ => None,
                                    }
                                };
                                match single {
                                    Some(only) => {
                                        for (k, val) in only {
                                            map.insert(k, val);
                                        }
                                    }
                                    None => {
                                        map.insert(
                                            "anyOf".to_string(),
                                            serde_json::Value::Array(any_of),
                                        );
                                    }
                                }
                            }
                            for val in map.values_mut() {
                                simplify_nullable(val);
                            }
                        }
                        serde_json::Value::Array(arr) => {
                            for item in arr {
                                simplify_nullable(item);
                            }
                        }
                        _ => {}
                    }
                }
                simplify_nullable(&mut schema);
                Some(schema)
            }
        },
        None => quote! { None },
    };

    // The context, when declared, is passed first. A trailing comma in a call is valid Rust,
    // so `f(ctx,)` is fine when there is no arguments object.
    let ctx_arg = if has_ctx { quote!(ctx,) } else { quote!() };
    let execute_body = match &args_type {
        Some(args_ty) => quote! {
            let typed_args: #args_ty = serde_json::from_value(args)
                .map_err(|e| adk_tool::AdkError::tool(
                    format!("invalid arguments for '{}': {e}", #tool_name_str)
                ))?;
            #fn_name(#ctx_arg typed_args).await
        },
        None => quote! {
            let _ = args;
            #fn_name(#ctx_arg).await
        },
    };

    // `None` interpolates to nothing, so unset flags keep the trait's default methods.
    let read_only_override =
        flags.read_only.then(|| quote! { fn is_read_only(&self) -> bool { true } });
    let concurrency_safe_override = flags
        .concurrency_safe
        .then(|| quote! { fn is_concurrency_safe(&self) -> bool { true } });
    let long_running_override =
        flags.long_running.then(|| quote! { fn is_long_running(&self) -> bool { true } });

    let output = quote! {
        #input_fn
        #fn_vis struct #struct_name;
        #[adk_tool::async_trait]
        impl adk_tool::Tool for #struct_name {
            fn name(&self) -> &str {
                #tool_name_str
            }
            fn description(&self) -> &str {
                #description
            }
            fn parameters_schema(&self) -> Option<serde_json::Value> {
                #schema_gen
            }
            #read_only_override
            #concurrency_safe_override
            #long_running_override
            async fn execute(
                &self,
                ctx: std::sync::Arc<dyn adk_tool::ToolContext>,
                args: serde_json::Value,
            ) -> adk_tool::Result<serde_json::Value> {
                #execute_body
            }
        }
    };
    output.into()
}
fn extract_args_type(func: &ItemFn) -> Option<Type> {
for arg in &func.sig.inputs {
if let FnArg::Typed(pat_type) = arg {
let ty = &pat_type.ty;
let ty_str = quote!(#ty).to_string();
if ty_str.contains("ToolContext") {
continue;
}
return Some((*pat_type.ty).clone());
}
}
None
}
/// Reports whether any typed parameter's type mentions `ToolContext` in its token text.
///
/// `self` receivers never count.
fn has_tool_context_param(func: &ItemFn) -> bool {
    for input in &func.sig.inputs {
        if let FnArg::Typed(typed) = input {
            let ty = &typed.ty;
            if quote!(#ty).to_string().contains("ToolContext") {
                return true;
            }
        }
    }
    false
}
/// Flags accepted inside `#[tool(...)]`; a flag is off unless it is listed.
struct ToolAttrs {
    /// `read_only`: the generated impl overrides `is_read_only` to return `true`.
    read_only: bool,
    /// `concurrency_safe`: the generated impl overrides `is_concurrency_safe` to return `true`.
    concurrency_safe: bool,
    /// `long_running`: the generated impl overrides `is_long_running` to return `true`.
    long_running: bool,
}
impl syn::parse::Parse for ToolAttrs {
    /// Parses a comma-separated list of bare flag names, e.g. `read_only, long_running`.
    ///
    /// Empty input yields all flags off, and repeating a flag is harmless. Any other
    /// identifier, or any `key = value` / `name(...)` entry, is rejected with a spanned error.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let entries: syn::punctuated::Punctuated<Meta, syn::Token![,]> =
            syn::punctuated::Punctuated::parse_terminated(input)?;
        let mut flags =
            ToolAttrs { read_only: false, concurrency_safe: false, long_running: false };
        for entry in &entries {
            let Meta::Path(path) = entry else {
                return Err(syn::Error::new_spanned(
                    entry,
                    "expected identifier (e.g., `read_only`), not key-value",
                ));
            };
            let flag = if path.is_ident("read_only") {
                &mut flags.read_only
            } else if path.is_ident("concurrency_safe") {
                &mut flags.concurrency_safe
            } else if path.is_ident("long_running") {
                &mut flags.long_running
            } else {
                return Err(syn::Error::new_spanned(
                    path,
                    "unknown tool attribute; expected `read_only`, `concurrency_safe`, or `long_running`",
                ));
            };
            *flag = true;
        }
        Ok(flags)
    }
}
/// Turns an async workflow function into a checkpointed, resumable agent.
///
/// The annotated function must be `async` and take exactly one parameter, `&mut TaskContext`.
/// It is emitted unchanged. For `fn my_flow` this macro also emits `MyFlowAgent`, which holds an
/// `Arc<dyn adk_graph::checkpoint::Checkpointer>` and provides:
///
/// - `new(checkpointer)`, the constructor.
/// - `invoke(initial_state, execution_config)`, which runs the function once and returns the
///   resulting state. The function's own `Ok` value is discarded. When
///   `execution_config.resume_from` is set and `checkpointer.load(thread_id)` returns a
///   checkpoint, that checkpoint's state and `execution_log` metadata are used instead of
///   `initial_state` and an empty log. After `ctx.validate_state()`, a `pre_execution`
///   checkpoint is saved. After the call, a `completed` or `failed` checkpoint (carrying the
///   execution log, plus the error text on failure) is saved. Saving the `failed` checkpoint
///   is best-effort. `node_start`, `node_end` and `error` events go to the broadcast channel
///   that is also handed to the `TaskContext`.
#[proc_macro_attribute]
pub fn entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let input_fn = parse_macro_input!(item as ItemFn);
    if input_fn.sig.asyncness.is_none() {
        return syn::Error::new_spanned(
            input_fn.sig.fn_token,
            "#[entrypoint] functions must be async",
        )
        .to_compile_error()
        .into();
    }
    let has_task_context = input_fn.sig.inputs.iter().any(|arg| {
        if let FnArg::Typed(pat_type) = arg {
            let full_str = quote!(#pat_type).to_string();
            full_str.contains("TaskContext")
        } else {
            false
        }
    });
    if !has_task_context {
        return syn::Error::new_spanned(
            &input_fn.sig,
            "#[entrypoint] functions must accept `&mut TaskContext` as a parameter",
        )
        .to_compile_error()
        .into();
    }
    // `invoke` calls the function as `f(&mut ctx)`. Report extra parameters here instead of
    // letting them surface as an argument-count error inside the generated code.
    if input_fn.sig.inputs.len() != 1 {
        return syn::Error::new_spanned(
            &input_fn.sig.inputs,
            "#[entrypoint] functions must take exactly one parameter: `&mut TaskContext`",
        )
        .to_compile_error()
        .into();
    }
    let fn_name = &input_fn.sig.ident;
    let fn_vis = &input_fn.vis;
    let fn_name_str = fn_name.to_string();
    let struct_name = format_ident!(
        "{}Agent",
        fn_name_str
            .split('_')
            .map(|seg| {
                let mut chars = seg.chars();
                match chars.next() {
                    None => String::new(),
                    Some(c) => c.to_uppercase().to_string() + chars.as_str(),
                }
            })
            .collect::<String>()
    );
    let output = quote! {
        #input_fn
        #fn_vis struct #struct_name {
            checkpointer: std::sync::Arc<dyn adk_graph::checkpoint::Checkpointer>,
        }
        impl #struct_name {
            pub fn new(checkpointer: std::sync::Arc<dyn adk_graph::checkpoint::Checkpointer>) -> Self {
                Self { checkpointer }
            }
            pub async fn invoke(
                &self,
                initial_state: adk_graph::state::State,
                execution_config: adk_graph::node::ExecutionConfig,
            ) -> adk_graph::error::Result<adk_graph::state::State> {
                use adk_graph::checkpoint::Checkpointer;
                use adk_graph::functional::ExecutionLog;
                use adk_graph::state::Checkpoint;
                use adk_graph::stream::StreamEvent;
                let thread_id = execution_config.thread_id.clone();
                // NOTE(review): only `resume_from.is_some()` is consulted. The checkpoint
                // used is whatever `load(&thread_id)` returns; confirm that is intended.
                let (state, execution_log) = if execution_config.resume_from.is_some() {
                    match self.checkpointer.load(&thread_id).await? {
                        Some(checkpoint) => {
                            let log: ExecutionLog = checkpoint
                                .metadata
                                .get("execution_log")
                                .and_then(|v| serde_json::from_value(v.clone()).ok())
                                .unwrap_or_default();
                            (checkpoint.state, log)
                        }
                        None => (initial_state, ExecutionLog::new()),
                    }
                } else {
                    (initial_state, ExecutionLog::new())
                };
                let (event_tx, _) = tokio::sync::broadcast::channel::<StreamEvent>(256);
                let cancel_token = tokio_util::sync::CancellationToken::new();
                let execution_log = std::sync::Arc::new(tokio::sync::RwLock::new(execution_log));
                let mut ctx = adk_graph::functional::TaskContext::new(
                    thread_id.clone(),
                    state,
                    self.checkpointer.clone(),
                    event_tx.clone(),
                    execution_log.clone(),
                    cancel_token,
                    None,
                );
                ctx.validate_state().map_err(|e| adk_graph::error::GraphError::Other(e.to_string()))?;
                // Resume reads the log back from this metadata key. Storing it here keeps a
                // resumed run's log from being lost if the run is interrupted again before
                // it completes.
                let pre_checkpoint = Checkpoint::new(
                    &thread_id,
                    ctx.state().clone(),
                    0,
                    vec![],
                )
                .with_metadata("phase", serde_json::Value::String("pre_execution".to_string()))
                .with_metadata(
                    "execution_log",
                    serde_json::to_value(&*execution_log.read().await)
                        .unwrap_or(serde_json::Value::Null),
                );
                self.checkpointer.save(&pre_checkpoint).await?;
                let _ = event_tx.send(StreamEvent::node_start(#fn_name_str, 0));
                let start = std::time::Instant::now();
                let result = #fn_name(&mut ctx).await;
                let duration = start.elapsed().as_millis() as u64;
                match result {
                    Ok(_value) => {
                        let step = execution_log.read().await.current_step();
                        let final_checkpoint = Checkpoint::new(
                            &thread_id,
                            ctx.state().clone(),
                            step,
                            vec![],
                        )
                        .with_metadata("phase", serde_json::Value::String("completed".to_string()))
                        .with_metadata(
                            "execution_log",
                            serde_json::to_value(&*execution_log.read().await)
                                .unwrap_or(serde_json::Value::Null),
                        );
                        self.checkpointer.save(&final_checkpoint).await?;
                        let _ = event_tx.send(StreamEvent::node_end(#fn_name_str, step, duration));
                        Ok(ctx.state().clone())
                    }
                    Err(e) => {
                        let step = execution_log.read().await.current_step();
                        let fail_checkpoint = Checkpoint::new(
                            &thread_id,
                            ctx.state().clone(),
                            step,
                            vec![],
                        )
                        .with_metadata("phase", serde_json::Value::String("failed".to_string()))
                        .with_metadata("error", serde_json::Value::String(e.to_string()))
                        .with_metadata(
                            "execution_log",
                            serde_json::to_value(&*execution_log.read().await)
                                .unwrap_or(serde_json::Value::Null),
                        );
                        // Best-effort: the task error is what the caller needs to see.
                        let _ = self.checkpointer.save(&fail_checkpoint).await;
                        let _ = event_tx.send(StreamEvent::error(&e.to_string(), Some(#fn_name_str)));
                        Err(e)
                    }
                }
            }
        }
    };
    output.into()
}
#[proc_macro_attribute]
pub fn task(attr: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as ItemFn);
if input_fn.sig.asyncness.is_none() {
return syn::Error::new_spanned(input_fn.sig.fn_token, "#[task] functions must be async")
.to_compile_error()
.into();
}
let has_task_context_first = input_fn
.sig
.inputs
.first()
.map(|arg| {
if let FnArg::Typed(pat_type) = arg {
let full_str = quote!(#pat_type).to_string();
full_str.contains("TaskContext")
} else {
false
}
})
.unwrap_or(false);
if !has_task_context_first {
return syn::Error::new_spanned(
&input_fn.sig,
"#[task] functions must accept `&mut TaskContext` as the first argument",
)
.to_compile_error()
.into();
}
let task_attrs = parse_task_attrs(attr);
let fn_name = &input_fn.sig.ident;
let fn_vis = &input_fn.vis;
let fn_name_str = fn_name.to_string();
let wrapper_name = format_ident!("__task_{}", fn_name);
let params = &input_fn.sig.inputs;
let return_type = &input_fn.sig.output;
let forward_args: Vec<_> = input_fn
.sig
.inputs
.iter()
.skip(1) .filter_map(|arg| if let FnArg::Typed(pat_type) = arg { Some(&pat_type.pat) } else { None })
.collect();
let call_expr = if forward_args.is_empty() {
quote! { #fn_name(ctx).await }
} else {
quote! { #fn_name(ctx, #(#forward_args),*).await }
};
let execution_body = if let Some(retry_config) = &task_attrs.retry {
let max_attempts = retry_config.max_attempts;
let backoff_secs = retry_config.backoff_secs;
quote! {
let mut attempts: u32 = 0;
let max_attempts: u32 = #max_attempts;
let backoff = std::time::Duration::from_secs(#backoff_secs);
let result = loop {
attempts += 1;
match #call_expr {
Ok(value) => break Ok(value),
Err(e) if attempts < max_attempts => {
tokio::time::sleep(backoff * attempts).await;
continue;
}
Err(e) => {
ctx.record_failure(task_id, &e.to_string()).await?;
ctx.emit(adk_graph::stream::StreamEvent::error(
&e.to_string(),
Some(task_id),
));
break Err(e);
}
}
};
}
} else {
quote! {
let result = match #call_expr {
Ok(value) => Ok(value),
Err(e) => {
ctx.record_failure(task_id, &e.to_string()).await?;
ctx.emit(adk_graph::stream::StreamEvent::error(
&e.to_string(),
Some(task_id),
));
Err(e)
}
};
}
};
let cache_check = if task_attrs.rerun_on_resume {
quote! {}
} else {
quote! {
if let Some(cached_result) = ctx.get_cached_result(task_id).await {
return Ok(cached_result);
}
}
};
let output = quote! {
#input_fn
#fn_vis async fn #wrapper_name(#params) #return_type {
let task_id = #fn_name_str;
#cache_check
let current_step = ctx.current_step().await;
ctx.emit(adk_graph::stream::StreamEvent::node_start(task_id, current_step));
let start = std::time::Instant::now();
#execution_body
if let Ok(ref value) = result {
ctx.record_completion(task_id, value).await?;
let duration = start.elapsed().as_millis() as u64;
let step = ctx.current_step().await;
ctx.emit(adk_graph::stream::StreamEvent::node_end(task_id, step, duration));
}
result
}
};
output.into()
}
struct RetryConfig {
max_attempts: u32,
backoff_secs: u64,
}
struct TaskAttrs {
retry: Option<RetryConfig>,
rerun_on_resume: bool,
}
fn parse_task_attrs(attr: TokenStream) -> TaskAttrs {
if attr.is_empty() {
return TaskAttrs { retry: None, rerun_on_resume: false };
}
let attr_meta: syn::Result<syn::Meta> = syn::parse(attr.clone());
if let Ok(syn::Meta::List(meta_list)) = attr_meta
&& meta_list.path.is_ident("retry")
&& let Some(retry) = parse_retry_from_meta_list(&meta_list)
{
return TaskAttrs { retry: Some(retry), rerun_on_resume: false };
}
let attr2: proc_macro2::TokenStream = attr.into();
let parsed: syn::Result<TaskAttrContent> = syn::parse2(attr2);
if let Ok(content) = parsed {
return TaskAttrs { retry: content.retry, rerun_on_resume: content.rerun_on_resume };
}
TaskAttrs { retry: None, rerun_on_resume: false }
}
struct TaskAttrContent {
retry: Option<RetryConfig>,
rerun_on_resume: bool,
}
impl syn::parse::Parse for TaskAttrContent {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut retry = None;
let mut rerun_on_resume = false;
while !input.is_empty() {
let ident: syn::Ident = input.parse()?;
if ident == "retry" {
let content;
syn::parenthesized!(content in input);
let mut max_attempts: u32 = 3;
let mut backoff_secs: u64 = 1;
let pairs = syn::punctuated::Punctuated::<syn::MetaNameValue, syn::Token![,]>::parse_terminated(&content)?;
for pair in pairs {
if pair.path.is_ident("max_attempts")
&& let syn::Expr::Lit(expr_lit) = &pair.value
&& let syn::Lit::Int(lit_int) = &expr_lit.lit
{
max_attempts = lit_int.base10_parse().unwrap_or(3);
} else if pair.path.is_ident("backoff")
&& let syn::Expr::Lit(expr_lit) = &pair.value
&& let syn::Lit::Str(lit_str) = &expr_lit.lit
{
backoff_secs = parse_duration_str(&lit_str.value());
}
}
retry = Some(RetryConfig { max_attempts, backoff_secs });
} else if ident == "rerun_on_resume" {
if input.peek(syn::Token![=]) {
let _eq: syn::Token![=] = input.parse()?;
let lit: syn::LitBool = input.parse()?;
rerun_on_resume = lit.value;
} else {
rerun_on_resume = true;
}
} else {
return Err(syn::Error::new_spanned(
ident,
"unknown task attribute; expected `retry(...)` or `rerun_on_resume`",
));
}
if input.peek(syn::Token![,]) {
let _comma: syn::Token![,] = input.parse()?;
}
}
Ok(TaskAttrContent { retry, rerun_on_resume })
}
}
fn parse_retry_from_meta_list(meta_list: &syn::MetaList) -> Option<RetryConfig> {
let mut max_attempts: u32 = 3;
let mut backoff_secs: u64 = 1;
let pairs: syn::Result<syn::punctuated::Punctuated<syn::MetaNameValue, syn::Token![,]>> =
meta_list.parse_args_with(syn::punctuated::Punctuated::parse_terminated);
if let Ok(pairs) = pairs {
for pair in pairs {
if pair.path.is_ident("max_attempts")
&& let syn::Expr::Lit(expr_lit) = &pair.value
&& let syn::Lit::Int(lit_int) = &expr_lit.lit
{
max_attempts = lit_int.base10_parse().unwrap_or(3);
} else if pair.path.is_ident("backoff")
&& let syn::Expr::Lit(expr_lit) = &pair.value
&& let syn::Lit::Str(lit_str) = &expr_lit.lit
{
backoff_secs = parse_duration_str(&lit_str.value());
}
}
Some(RetryConfig { max_attempts, backoff_secs })
} else {
None
}
}
/// Parses a retry backoff string into whole seconds.
///
/// Accepted forms are `"250ms"`, `"2s"`, `"5m"`, `"1h"`, or a bare integer such as `"3"`
/// (read as seconds). Whitespace around the value and between number and unit is ignored.
///
/// Milliseconds are rounded *up* to whole seconds, because the backoff is stored with
/// one-second resolution. A non-zero value like `"500ms"` therefore never collapses to zero
/// delay. Minutes and hours saturate at `u64::MAX` seconds instead of overflowing.
/// Unparseable input falls back to 1 second.
fn parse_duration_str(s: &str) -> u64 {
    const FALLBACK_SECS: u64 = 1;
    let s = s.trim();
    let number = |digits: &str| digits.trim().parse::<u64>().ok();
    // "ms" is checked before "s" and "m", which are both suffixes of it.
    let secs = if let Some(ms) = s.strip_suffix("ms") {
        number(ms).map(|v| v.div_ceil(1000))
    } else if let Some(secs) = s.strip_suffix('s') {
        number(secs)
    } else if let Some(mins) = s.strip_suffix('m') {
        number(mins).map(|v| v.saturating_mul(60))
    } else if let Some(hours) = s.strip_suffix('h') {
        number(hours).map(|v| v.saturating_mul(3600))
    } else {
        number(s)
    };
    secs.unwrap_or(FALLBACK_SECS)
}