use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use syn::{Path, ReturnType};
use super::{
naming::{param_struct_name, vtable_struct_name},
params::MethodInfo,
type_map::TypeMap,
};
pub fn vtable_tokens(
trait_path: &Path,
trait_path_str: &str,
methods: &[MethodInfo],
vis: &syn::Visibility,
type_map: &TypeMap,
) -> TokenStream {
let vtable_name = vtable_struct_name(trait_path_str);
let fields: Vec<TokenStream> = methods
.iter()
.map(|m| {
let field = &m.name;
quote! {
#field: ::std::sync::Arc<
dyn Fn(::serde_json::Value)
-> ::futures::future::BoxFuture<
'static,
::std::result::Result<
::rmcp::model::CallToolResult,
::rmcp::ErrorData,
>
>
+ Send + Sync
>,
}
})
.collect();
let closure_fields: Vec<TokenStream> =
methods
.iter()
.map(|m| {
let method_name = &m.name;
let param_struct = param_struct_name(&method_name.to_string());
let param_names: Vec<&syn::Ident> = m.params.iter().map(|p| &p.name).collect();
let param_types: Vec<&syn::Type> = m.params.iter().map(|p| p.ty.as_ref()).collect();
let param_conversions: Vec<TokenStream> = m.params.iter().map(|p| {
let name = &p.name;
let ty = &p.ty;
if super::params::is_str_ref(ty) {
return quote! { let #name = p.#name.as_str(); };
}
let substituted = type_map.apply_to_type(ty);
if substituted.to_token_stream().to_string() != ty.to_token_stream().to_string() {
let field_access = quote! { p.#name };
let conversion = type_map.proxy_decode(field_access, &substituted);
quote! { let #name = #conversion; }
} else {
quote! {
let #name = <#ty as ::elicitation::ElicitProxy>::from_proxy(p.#name);
}
}
}).collect();
let ret_ty = match &m.return_type {
ReturnType::Default => None,
ReturnType::Type(_, ty) => Some(ty.as_ref()),
};
let result_conversion = if let Some(ret) = ret_ty {
type_map.proxy_encode(quote! { result }, ret)
} else {
quote! { ::elicitation::ElicitProxy::into_proxy(result) }
};
let call = if m.has_self {
quote! {
#method_name: ::std::sync::Arc::new(|params| {
::std::boxed::Box::pin(async move {
let p: #param_struct = ::serde_json::from_value(params)
.map_err(|e| ::rmcp::ErrorData::invalid_params(
e.to_string(), None
))?;
let target: T = ::serde_json::from_value(p.target)
.map_err(|e| ::rmcp::ErrorData::invalid_params(
format!("failed to deserialize target: {e}"), None
))?;
#(#param_conversions)*
let result = <T as #trait_path>::#method_name(
&target,
#(#param_names,)*
);
let proxied = #result_conversion;
let text = ::serde_json::to_string(&proxied)
.map_err(|e| ::rmcp::ErrorData::internal_error(
e.to_string(), None
))?;
Ok(::rmcp::model::CallToolResult::success(vec![
::rmcp::model::Content::new(
::rmcp::model::RawContent::text(text),
None,
)
]))
})
}),
}
} else {
quote! {
#method_name: ::std::sync::Arc::new(|params| {
::std::boxed::Box::pin(async move {
let p: #param_struct = ::serde_json::from_value(params)
.map_err(|e| ::rmcp::ErrorData::invalid_params(
e.to_string(), None
))?;
#(#param_conversions)*
let result = <T as #trait_path>::#method_name(#(#param_names,)*);
let proxied = #result_conversion;
let text = ::serde_json::to_string(&proxied)
.map_err(|e| ::rmcp::ErrorData::internal_error(
e.to_string(), None
))?;
Ok(::rmcp::model::CallToolResult::success(vec![
::rmcp::model::Content::new(
::rmcp::model::RawContent::text(text),
None,
)
]))
})
}),
}
};
let _ = param_types;
call
})
.collect();
let vtable_doc = format!("Vtable holding dispatch closures for `{trait_path_str}` tools.");
quote! {
#[doc = #vtable_doc]
#vis struct #vtable_name {
#(#fields)*
}
impl #vtable_name {
fn for_type<T>() -> Self
where
T: #trait_path
+ ::serde::Serialize
+ ::serde::de::DeserializeOwned
+ ::schemars::JsonSchema
+ ::elicitation::Elicitation
+ Send + Sync + 'static,
{
Self {
#(#closure_fields)*
}
}
}
}
}