use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse2, Data, DeriveInput, Error, Fields, GenericArgument, PathArguments, Result, Type};
pub fn expand_tool_input(input: TokenStream) -> Result<TokenStream> {
let input: DeriveInput = parse2(input)?;
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
Fields::Unnamed(_) => {
return Err(Error::new_spanned(
name,
"ToolInput can only be derived for structs with named fields",
))
}
Fields::Unit => {
return Err(Error::new_spanned(
name,
"ToolInput cannot be derived for unit structs",
))
}
},
Data::Enum(_) => {
return Err(Error::new_spanned(
name,
"ToolInput can only be derived for structs, not enums",
))
}
Data::Union(_) => {
return Err(Error::new_spanned(
name,
"ToolInput can only be derived for structs, not unions",
))
}
};
let mut property_schemas = Vec::new();
let mut required_fields = Vec::new();
for field in fields {
let field_name = field
.ident
.as_ref()
.expect("expected named field but found tuple field - this should not happen");
let field_name_str = field_name.to_string();
let field_ty = &field.ty;
let description = extract_doc_comment(&field.attrs);
let is_optional = is_option_type(field_ty);
let type_schema = type_to_json_schema(field_ty);
let schema_with_desc = if let Some(desc) = &description {
quote! {
{
let mut schema = #type_schema;
if let serde_json::Value::Object(ref mut obj) = schema {
obj.insert("description".to_string(), serde_json::Value::String(#desc.to_string()));
}
schema
}
}
} else {
type_schema
};
property_schemas.push(quote! {
properties.insert(#field_name_str.to_string(), #schema_with_desc);
});
if !is_optional {
required_fields.push(field_name_str);
}
}
let required_array = if required_fields.is_empty() {
quote!(serde_json::Value::Null)
} else {
quote! {
serde_json::Value::Array(
vec![#(serde_json::Value::String(#required_fields.to_string())),*]
)
}
};
let struct_name_str = name.to_string();
Ok(quote! {
impl #impl_generics #name #ty_generics #where_clause {
pub fn tool_input_schema() -> serde_json::Value {
let mut properties = serde_json::Map::new();
#(#property_schemas)*
let mut schema = serde_json::json!({
"type": "object",
"title": #struct_name_str,
});
if let serde_json::Value::Object(ref mut obj) = schema {
obj.insert("properties".to_string(), serde_json::Value::Object(properties));
let required = #required_array;
if required != serde_json::Value::Null {
obj.insert("required".to_string(), required);
}
}
schema
}
}
})
}
/// Collects the text of every `#[doc = "..."]` attribute (what `///` comments desugar to)
/// into one space-separated description.
///
/// Each line is trimmed. Blank lines, such as the empty `///` used as a paragraph break,
/// are skipped, so they produce neither double spaces nor an empty description. Attributes
/// other than `doc`, and `doc` attributes whose value is not a string literal, are ignored.
///
/// Returns `None` when no non-blank doc text is present.
fn extract_doc_comment(attrs: &[syn::Attribute]) -> Option<String> {
    let lines: Vec<String> = attrs
        .iter()
        .filter(|attr| attr.path().is_ident("doc"))
        .filter_map(|attr| match &attr.meta {
            syn::Meta::NameValue(syn::MetaNameValue {
                value:
                    syn::Expr::Lit(syn::ExprLit {
                        lit: syn::Lit::Str(text),
                        ..
                    }),
                ..
            }) => Some(text.value().trim().to_string()),
            _ => None,
        })
        .filter(|line| !line.is_empty())
        .collect();
    (!lines.is_empty()).then(|| lines.join(" "))
}
/// Reports whether `ty` is spelled as an `Option`, judging only by its final path segment.
///
/// Both `Option<T>` and `std::option::Option<T>` qualify; the generic arguments are not
/// inspected. Any non-path type (for example `&Option<T>`) is treated as not optional.
fn is_option_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => matches!(
            type_path.path.segments.last(),
            Some(last) if last.ident == "Option"
        ),
        _ => false,
    }
}
/// Returns `T` for a type spelled `Option<T>` (judged by the final path segment).
///
/// Returns `None` when `ty` is not a path type or its last segment is not `Option`. It also
/// returns `None` when there are no angle-bracketed arguments, or the first argument is not
/// a type.
fn get_option_inner_type(ty: &Type) -> Option<&Type> {
    let type_path = match ty {
        Type::Path(type_path) => type_path,
        _ => return None,
    };
    let last = type_path.path.segments.last()?;
    if last.ident != "Option" {
        return None;
    }
    match &last.arguments {
        PathArguments::AngleBracketed(generics) => match generics.args.first()? {
            GenericArgument::Type(inner) => Some(inner),
            _ => None,
        },
        _ => None,
    }
}
fn type_to_json_schema(ty: &Type) -> TokenStream {
if let Type::Path(path) = ty {
let path_str = quote!(#path).to_string().replace(' ', "");
match path_str.as_str() {
"String" | "&str" | "str" => quote!(serde_json::json!({"type": "string"})),
"i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64"
| "u128" | "usize" => {
quote!(serde_json::json!({"type": "integer"}))
}
"f32" | "f64" => quote!(serde_json::json!({"type": "number"})),
"bool" => quote!(serde_json::json!({"type": "boolean"})),
_ if path_str.starts_with("Option<") => {
if let Some(inner) = get_option_inner_type(ty) {
let inner_schema = type_to_json_schema(inner);
return inner_schema;
}
quote!(serde_json::json!({}))
}
_ if path_str.starts_with("Vec<") => {
if let Some(segment) = path.path.segments.last() {
if let PathArguments::AngleBracketed(args) = &segment.arguments {
if let Some(GenericArgument::Type(inner)) = args.args.first() {
let inner_schema = type_to_json_schema(inner);
return quote! {
serde_json::json!({
"type": "array",
"items": #inner_schema
})
};
}
}
}
quote!(serde_json::json!({"type": "array"}))
}
_ if path_str.starts_with("HashMap<") || path_str.starts_with("BTreeMap<") => {
quote!(serde_json::json!({
"type": "object",
"additionalProperties": true
}))
}
_ => {
quote!(serde_json::json!({"type": "object"}))
}
}
} else {
quote!(serde_json::json!({}))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_is_option_type() {
        let ty: Type = parse_quote!(Option<String>);
        assert!(is_option_type(&ty));
        let ty: Type = parse_quote!(String);
        assert!(!is_option_type(&ty));
    }

    #[test]
    fn test_is_option_type_qualified_and_nested() {
        // Detection looks only at the last path segment, so qualified paths count.
        let ty: Type = parse_quote!(std::option::Option<u8>);
        assert!(is_option_type(&ty));
        // An `Option` nested inside another type does not make the outer type optional.
        let ty: Type = parse_quote!(Vec<Option<u8>>);
        assert!(!is_option_type(&ty));
    }

    #[test]
    fn test_get_option_inner_type() {
        let ty: Type = parse_quote!(Option<u32>);
        let inner = get_option_inner_type(&ty).expect("Option<u32> has an inner type");
        assert_eq!(quote!(#inner).to_string(), "u32");
        let ty: Type = parse_quote!(String);
        assert!(get_option_inner_type(&ty).is_none());
    }

    #[test]
    fn test_extract_doc_comment() {
        let attrs: Vec<syn::Attribute> = vec![parse_quote!(#[doc = " This is a test "])];
        assert_eq!(
            extract_doc_comment(&attrs),
            Some("This is a test".to_string())
        );
    }

    #[test]
    fn test_extract_doc_comment_joins_lines_and_ignores_other_attrs() {
        let attrs: Vec<syn::Attribute> = vec![
            parse_quote!(#[doc = " First line."]),
            parse_quote!(#[serde(default)]),
            parse_quote!(#[doc = " Second line."]),
        ];
        assert_eq!(
            extract_doc_comment(&attrs),
            Some("First line. Second line.".to_string())
        );
        let attrs: Vec<syn::Attribute> = vec![parse_quote!(#[serde(default)])];
        assert_eq!(extract_doc_comment(&attrs), None);
        assert_eq!(extract_doc_comment(&[]), None);
    }
}