use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{parse2, FnArg, ItemFn, Pat, Type};
use crate::paths;
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// The input is split on `_`; each non-empty segment has its first character
/// uppercased (using full Unicode case mapping, so one char may expand to
/// several) and the remainder copied verbatim. Empty segments, e.g. from
/// leading, trailing or doubled underscores, contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            out.extend(first.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
/// A function parameter annotated with `#[field]`.
///
/// Each one becomes a field of the generated middleware struct and a
/// parameter of the generated factory function.
struct FieldParam {
/// Binding name of the parameter; reused verbatim as the struct field name.
ident: syn::Ident,
/// Declared type of the parameter; reused as the struct field type.
ty: Type,
}
/// Splits the `#[field]` parameters out of `func`'s signature.
///
/// Returns the captured field parameters, in declaration order, together with
/// a copy of `func`. In the copy the parameters are all still present, but
/// the `#[field]` attribute has been removed from each of them so the copy
/// can be emitted as ordinary Rust.
///
/// # Errors
///
/// Fails when a `#[field]` parameter's pattern is anything other than a plain
/// identifier (for example a tuple or struct destructuring pattern).
fn extract_field_params(func: &ItemFn) -> syn::Result<(Vec<FieldParam>, ItemFn)> {
    fn is_field_attr(attr: &syn::Attribute) -> bool {
        attr.path().is_ident("field")
    }

    let mut stripped = func.clone();
    let mut captured = Vec::new();
    for input in stripped.sig.inputs.iter_mut() {
        // `self` receivers can never carry `#[field]`; skip them.
        let typed = match input {
            FnArg::Typed(typed) => typed,
            _ => continue,
        };
        if !typed.attrs.iter().any(is_field_attr) {
            continue;
        }
        let ident = match &*typed.pat {
            Pat::Ident(binding) => binding.ident.clone(),
            _ => {
                return Err(syn::Error::new_spanned(
                    &typed.pat,
                    "expected a simple identifier for #[field] parameter",
                ));
            }
        };
        captured.push(FieldParam {
            ident,
            ty: typed.ty.as_ref().clone(),
        });
        typed.attrs.retain(|attr| !is_field_attr(attr));
    }
    Ok((captured, stripped))
}
fn gen_middleware_struct(
vis: &syn::Visibility,
struct_name: &syn::Ident,
fields: &[FieldParam],
) -> TokenStream {
if fields.is_empty() {
quote! { #vis struct #struct_name; }
} else {
let field_defs: Vec<TokenStream> = fields
.iter()
.map(|f| {
let ident = &f.ident;
let ty = &f.ty;
quote! { #ident: #ty }
})
.collect();
quote! { #vis struct #struct_name { #(#field_defs),* } }
}
}
fn gen_middleware_factory(
vis: &syn::Visibility,
fn_name: &syn::Ident,
struct_name: &syn::Ident,
fields: &[FieldParam],
) -> TokenStream {
let mw_crate = paths::middleware_path();
if fields.is_empty() {
quote! {
#vis fn #fn_name() -> ::std::sync::Arc<dyn #mw_crate::AgentMiddleware> {
::std::sync::Arc::new(#struct_name)
}
}
} else {
let params: Vec<TokenStream> = fields
.iter()
.map(|f| {
let ident = &f.ident;
let ty = &f.ty;
quote! { #ident: #ty }
})
.collect();
let inits: Vec<&syn::Ident> = fields.iter().map(|f| &f.ident).collect();
quote! {
#vis fn #fn_name(#(#params),*) -> ::std::sync::Arc<dyn #mw_crate::AgentMiddleware> {
::std::sync::Arc::new(#struct_name { #(#inits),* })
}
}
}
}
fn gen_field_clones(fields: &[FieldParam]) -> Vec<TokenStream> {
fields
.iter()
.map(|f| {
let ident = &f.ident;
quote! { let #ident = self.#ident.clone(); }
})
.collect()
}
/// Borrows the identifier of every `#[field]` parameter, preserving
/// declaration order.
fn field_idents(fields: &[FieldParam]) -> Vec<&syn::Ident> {
    let mut idents = Vec::with_capacity(fields.len());
    for field in fields {
        idents.push(&field.ident);
    }
    idents
}
pub fn expand_before_agent(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
if !attr.is_empty() {
return Err(syn::Error::new_spanned(
attr,
"#[before_agent] does not accept arguments",
));
}
let func: ItemFn = parse2(item)?;
let (fields, clean_func) = extract_field_params(&func)?;
let fn_name = &func.sig.ident;
let vis = &func.vis;
let struct_name = format_ident!("{}Middleware", to_pascal_case(&fn_name.to_string()));
let impl_fn_name = format_ident!("{}_impl", fn_name);
let mut impl_func = clean_func;
impl_func.sig.ident = impl_fn_name.clone();
impl_func
.attrs
.retain(|a| !a.path().is_ident("before_agent"));
let struct_def = gen_middleware_struct(vis, &struct_name, &fields);
let factory = gen_middleware_factory(vis, fn_name, &struct_name, &fields);
let field_clones = gen_field_clones(&fields);
let fidents = field_idents(&fields);
let core_crate = paths::core_path();
let mw_crate = paths::middleware_path();
Ok(quote! {
#impl_func
#struct_def
#[::async_trait::async_trait]
impl #mw_crate::AgentMiddleware for #struct_name {
async fn before_agent(
&self,
messages: &mut Vec<#core_crate::Message>,
) -> Result<(), #core_crate::SynapticError> {
#(#field_clones)*
#impl_fn_name(#(#fidents,)* messages).await
}
}
#factory
})
}
pub fn expand_before_model(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
if !attr.is_empty() {
return Err(syn::Error::new_spanned(
attr,
"#[before_model] does not accept arguments",
));
}
let func: ItemFn = parse2(item)?;
let (fields, clean_func) = extract_field_params(&func)?;
let fn_name = &func.sig.ident;
let vis = &func.vis;
let struct_name = format_ident!("{}Middleware", to_pascal_case(&fn_name.to_string()));
let impl_fn_name = format_ident!("{}_impl", fn_name);
let mut impl_func = clean_func;
impl_func.sig.ident = impl_fn_name.clone();
impl_func
.attrs
.retain(|a| !a.path().is_ident("before_model"));
let struct_def = gen_middleware_struct(vis, &struct_name, &fields);
let factory = gen_middleware_factory(vis, fn_name, &struct_name, &fields);
let field_clones = gen_field_clones(&fields);
let fidents = field_idents(&fields);
let core_crate = paths::core_path();
let mw_crate = paths::middleware_path();
Ok(quote! {
#impl_func
#struct_def
#[::async_trait::async_trait]
impl #mw_crate::AgentMiddleware for #struct_name {
async fn before_model(
&self,
request: &mut #mw_crate::ModelRequest,
) -> Result<(), #core_crate::SynapticError> {
#(#field_clones)*
#impl_fn_name(#(#fidents,)* request).await
}
}
#factory
})
}
pub fn expand_after_model(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
if !attr.is_empty() {
return Err(syn::Error::new_spanned(
attr,
"#[after_model] does not accept arguments",
));
}
let func: ItemFn = parse2(item)?;
let (fields, clean_func) = extract_field_params(&func)?;
let fn_name = &func.sig.ident;
let vis = &func.vis;
let struct_name = format_ident!("{}Middleware", to_pascal_case(&fn_name.to_string()));
let impl_fn_name = format_ident!("{}_impl", fn_name);
let mut impl_func = clean_func;
impl_func.sig.ident = impl_fn_name.clone();
impl_func
.attrs
.retain(|a| !a.path().is_ident("after_model"));
let struct_def = gen_middleware_struct(vis, &struct_name, &fields);
let factory = gen_middleware_factory(vis, fn_name, &struct_name, &fields);
let field_clones = gen_field_clones(&fields);
let fidents = field_idents(&fields);
let core_crate = paths::core_path();
let mw_crate = paths::middleware_path();
Ok(quote! {
#impl_func
#struct_def
#[::async_trait::async_trait]
impl #mw_crate::AgentMiddleware for #struct_name {
async fn after_model(
&self,
request: &#mw_crate::ModelRequest,
response: &mut #mw_crate::ModelResponse,
) -> Result<(), #core_crate::SynapticError> {
#(#field_clones)*
#impl_fn_name(#(#fidents,)* request, response).await
}
}
#factory
})
}
pub fn expand_after_agent(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
if !attr.is_empty() {
return Err(syn::Error::new_spanned(
attr,
"#[after_agent] does not accept arguments",
));
}
let func: ItemFn = parse2(item)?;
let (fields, clean_func) = extract_field_params(&func)?;
let fn_name = &func.sig.ident;
let vis = &func.vis;
let struct_name = format_ident!("{}Middleware", to_pascal_case(&fn_name.to_string()));
let impl_fn_name = format_ident!("{}_impl", fn_name);
let mut impl_func = clean_func;
impl_func.sig.ident = impl_fn_name.clone();
impl_func
.attrs
.retain(|a| !a.path().is_ident("after_agent"));
let struct_def = gen_middleware_struct(vis, &struct_name, &fields);
let factory = gen_middleware_factory(vis, fn_name, &struct_name, &fields);
let field_clones = gen_field_clones(&fields);
let fidents = field_idents(&fields);
let core_crate = paths::core_path();
let mw_crate = paths::middleware_path();
Ok(quote! {
#impl_func
#struct_def
#[::async_trait::async_trait]
impl #mw_crate::AgentMiddleware for #struct_name {
async fn after_agent(
&self,
messages: &mut Vec<#core_crate::Message>,
) -> Result<(), #core_crate::SynapticError> {
#(#field_clones)*
#impl_fn_name(#(#fidents,)* messages).await
}
}
#factory
})
}
pub fn expand_wrap_model_call(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
if !attr.is_empty() {
return Err(syn::Error::new_spanned(
attr,
"#[wrap_model_call] does not accept arguments",
));
}
let func: ItemFn = parse2(item)?;
let (fields, clean_func) = extract_field_params(&func)?;
let fn_name = &func.sig.ident;
let vis = &func.vis;
let struct_name = format_ident!("{}Middleware", to_pascal_case(&fn_name.to_string()));
let impl_fn_name = format_ident!("{}_impl", fn_name);
let mut impl_func = clean_func;
impl_func.sig.ident = impl_fn_name.clone();
impl_func
.attrs
.retain(|a| !a.path().is_ident("wrap_model_call"));
let struct_def = gen_middleware_struct(vis, &struct_name, &fields);
let factory = gen_middleware_factory(vis, fn_name, &struct_name, &fields);
let field_clones = gen_field_clones(&fields);
let fidents = field_idents(&fields);
let core_crate = paths::core_path();
let mw_crate = paths::middleware_path();
Ok(quote! {
#impl_func
#struct_def
#[::async_trait::async_trait]
impl #mw_crate::AgentMiddleware for #struct_name {
async fn wrap_model_call(
&self,
request: #mw_crate::ModelRequest,
next: &dyn #mw_crate::ModelCaller,
) -> Result<#mw_crate::ModelResponse, #core_crate::SynapticError> {
#(#field_clones)*
#impl_fn_name(#(#fidents,)* request, next).await
}
}
#factory
})
}
pub fn expand_wrap_tool_call(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
if !attr.is_empty() {
return Err(syn::Error::new_spanned(
attr,
"#[wrap_tool_call] does not accept arguments",
));
}
let func: ItemFn = parse2(item)?;
let (fields, clean_func) = extract_field_params(&func)?;
let fn_name = &func.sig.ident;
let vis = &func.vis;
let struct_name = format_ident!("{}Middleware", to_pascal_case(&fn_name.to_string()));
let impl_fn_name = format_ident!("{}_impl", fn_name);
let mut impl_func = clean_func;
impl_func.sig.ident = impl_fn_name.clone();
impl_func
.attrs
.retain(|a| !a.path().is_ident("wrap_tool_call"));
let struct_def = gen_middleware_struct(vis, &struct_name, &fields);
let factory = gen_middleware_factory(vis, fn_name, &struct_name, &fields);
let field_clones = gen_field_clones(&fields);
let fidents = field_idents(&fields);
let core_crate = paths::core_path();
let mw_crate = paths::middleware_path();
Ok(quote! {
#impl_func
#struct_def
#[::async_trait::async_trait]
impl #mw_crate::AgentMiddleware for #struct_name {
async fn wrap_tool_call(
&self,
request: #mw_crate::ToolCallRequest,
next: &dyn #mw_crate::ToolCaller,
) -> Result<::serde_json::Value, #core_crate::SynapticError> {
#(#field_clones)*
#impl_fn_name(#(#fidents,)* request, next).await
}
}
#factory
})
}
pub fn expand_dynamic_prompt(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
if !attr.is_empty() {
return Err(syn::Error::new_spanned(
attr,
"#[dynamic_prompt] does not accept arguments",
));
}
let func: ItemFn = parse2(item)?;
let (fields, clean_func) = extract_field_params(&func)?;
let fn_name = &func.sig.ident;
let vis = &func.vis;
let struct_name = format_ident!("{}Middleware", to_pascal_case(&fn_name.to_string()));
let impl_fn_name = format_ident!("{}_impl", fn_name);
let mut impl_func = clean_func;
impl_func.sig.ident = impl_fn_name.clone();
impl_func
.attrs
.retain(|a| !a.path().is_ident("dynamic_prompt"));
let struct_def = gen_middleware_struct(vis, &struct_name, &fields);
let factory = gen_middleware_factory(vis, fn_name, &struct_name, &fields);
let field_clones = gen_field_clones(&fields);
let fidents = field_idents(&fields);
let core_crate = paths::core_path();
let mw_crate = paths::middleware_path();
Ok(quote! {
#impl_func
#struct_def
#[::async_trait::async_trait]
impl #mw_crate::AgentMiddleware for #struct_name {
async fn before_model(
&self,
request: &mut #mw_crate::ModelRequest,
) -> Result<(), #core_crate::SynapticError> {
#(#field_clones)*
let prompt = #impl_fn_name(#(#fidents,)* &request.messages);
request.system_prompt = Some(prompt);
Ok(())
}
}
#factory
})
}