strands_agents_macros/
lib.rs

1//! Procedural macros for Strands agents.
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{parse_macro_input, Attribute, FnArg, ItemFn, Pat};
6
7/// Transforms a function into a Strands agent tool.
8///
9/// # Example
10///
11/// ```rust,ignore
12/// use strands::tool;
13///
14/// #[tool]
15/// /// Get the current weather for a location.
16/// async fn get_weather(location: String, units: Option<String>) -> String {
17///     format!("Weather in {}: 72°F", location)
18/// }
19/// ```
20///
21/// The generated tool can be used with an agent:
22///
23/// ```rust,ignore
24/// let agent = Agent::builder()
25///     .model(BedrockModel::default())
26///     .tools(vec![GetWeatherTool::new()])
27///     .build()?;
28/// ```
29#[proc_macro_attribute]
30pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
31    let input_fn = parse_macro_input!(item as ItemFn);
32
33    let fn_name = &input_fn.sig.ident;
34    let fn_name_str = fn_name.to_string();
35    let struct_name = format_ident!("{}Tool", to_pascal_case(&fn_name_str));
36
37    let description = extract_doc_comment(&input_fn.attrs);
38
39    let params: Vec<_> = input_fn
40        .sig
41        .inputs
42        .iter()
43        .filter_map(|arg| {
44            if let FnArg::Typed(pat_type) = arg {
45                if let Pat::Ident(pat_ident) = &*pat_type.pat {
46                    let name = pat_ident.ident.to_string();
47                    if name == "tool_context" || name == "agent" {
48                        return None;
49                    }
50                    let ty = &pat_type.ty;
51                    return Some((name, quote!(#ty).to_string()));
52                }
53            }
54            None
55        })
56        .collect();
57
58    let param_schemas: Vec<_> = params
59        .iter()
60        .map(|(name, _ty)| {
61            quote! {
62                properties.insert(
63                    #name.to_string(),
64                    serde_json::json!({
65                        "type": "string",
66                        "description": format!("Parameter {}", #name)
67                    })
68                );
69            }
70        })
71        .collect();
72
73    let required_params: Vec<_> = params
74        .iter()
75        .filter(|(_, ty)| !ty.contains("Option"))
76        .map(|(name, _)| name.clone())
77        .collect();
78
79    let is_async = input_fn.sig.asyncness.is_some();
80
81    let param_names: Vec<_> = params
82        .iter()
83        .map(|(name, _)| format_ident!("{}", name))
84        .collect();
85
86    let execute_call = if is_async {
87        quote! { #fn_name(#(#param_names),*).await }
88    } else {
89        quote! { #fn_name(#(#param_names),*) }
90    };
91
92    let param_extractions: Vec<_> = params
93        .iter()
94        .map(|(name, ty)| {
95            let ident = format_ident!("{}", name);
96            if ty.contains("Option") {
97                quote! {
98                    let #ident = tool_use.input.get(#name)
99                        .and_then(|v| serde_json::from_value(v.clone()).ok());
100                }
101            } else {
102                quote! {
103                    let #ident = tool_use.input.get(#name)
104                        .and_then(|v| serde_json::from_value(v.clone()).ok())
105                        .ok_or_else(|| format!("Missing required parameter: {}", #name))?;
106                }
107            }
108        })
109        .collect();
110
111    let expanded = quote! {
112        #input_fn
113
114        #[derive(Clone)]
115        pub struct #struct_name;
116
117        impl #struct_name {
118            pub fn new() -> Self { Self }
119        }
120
121        impl Default for #struct_name {
122            fn default() -> Self { Self::new() }
123        }
124
125        impl strands::tools::AgentTool for #struct_name {
126            fn tool_name(&self) -> &str {
127                #fn_name_str
128            }
129
130            fn tool_spec(&self) -> strands::types::tools::ToolSpec {
131                let mut properties = std::collections::HashMap::new();
132                #(#param_schemas)*
133
134                let required: Vec<String> = vec![#(#required_params.to_string()),*];
135
136                strands::types::tools::ToolSpec {
137                    name: #fn_name_str.to_string(),
138                    description: #description.to_string(),
139                    input_schema: strands::types::tools::InputSchema {
140                        json: serde_json::json!({
141                            "type": "object",
142                            "properties": properties,
143                            "required": required
144                        }),
145                    },
146                    output_schema: None,
147                }
148            }
149
150            fn tool_type(&self) -> &str {
151                "function"
152            }
153
154            fn stream(
155                &self,
156                tool_use: &strands::types::tools::ToolUse,
157                _invocation_state: &strands::tools::InvocationState,
158            ) -> strands::tools::ToolGenerator {
159                let tool_use = tool_use.clone();
160                Box::pin(async_stream::stream! {
161                    let result: Result<String, String> = (|| async {
162                        #(#param_extractions)*
163                        let output = #execute_call;
164                        Ok(output.to_string())
165                    })().await;
166
167                    let tool_result = match result {
168                        Ok(text) => strands::types::tools::ToolResult::success(
169                            &tool_use.tool_use_id,
170                            text,
171                        ),
172                        Err(e) => strands::types::tools::ToolResult::error(
173                            &tool_use.tool_use_id,
174                            e,
175                        ),
176                    };
177                    yield strands::tools::ToolEvent::Result(tool_result);
178                })
179            }
180        }
181    };
182
183    TokenStream::from(expanded)
184}
185
186fn extract_doc_comment(attrs: &[Attribute]) -> String {
187    let mut doc_lines = Vec::new();
188
189    for attr in attrs {
190        if attr.path().is_ident("doc") {
191            if let syn::Meta::NameValue(meta) = &attr.meta {
192                if let syn::Expr::Lit(expr_lit) = &meta.value {
193                    if let syn::Lit::Str(lit_str) = &expr_lit.lit {
194                        doc_lines.push(lit_str.value().trim().to_string());
195                    }
196                }
197            }
198        }
199    }
200
201    let mut description = Vec::new();
202    for line in doc_lines {
203        let lower = line.to_lowercase();
204        if lower.starts_with("# arg") || lower.starts_with("args:") || lower.starts_with("arguments:") {
205            break;
206        }
207        description.push(line);
208    }
209
210    description.join(" ").trim().to_string()
211}
212
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// The input is split on `_`. Each segment's first character is upper-cased
/// using Unicode rules, so one character may expand to several (`ß` → `SS`),
/// and the rest of the segment is copied unchanged. Empty segments, produced
/// by leading, trailing or repeated underscores, contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            pascal.extend(first.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}