use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::Path;
use super::{
naming::{
camel_to_snake, factory_struct_name, last_segment_str, param_struct_name,
vtable_struct_name,
},
params::MethodInfo,
};
/// Generates the runtime glue that exposes a trait's methods as dynamically
/// instantiated tools.
///
/// The emitted tokens contain:
///
/// * a module-level `static` map, named `<LAST_SEGMENT>_FACTORY_VTABLES`
///   (upper snake case of the last path segment of `trait_path_str`), from
///   `TypeId` to a boxed vtable of type `vtable_struct_name(trait_path_str)`;
/// * a unit struct named `factory_struct_name(trait_path_str)`, documented
///   with `factory_description`, with a `prime::<T>()` method that inserts the
///   vtable for `T` into that map;
/// * an `::elicitation::AnyToolFactory` impl whose `instantiate` looks up the
///   slot's `TypeId` in the map and builds one `DynamicToolDescriptor` per
///   entry of `methods`;
/// * an `::inventory::submit!` registration of the factory.
///
/// # Parameters
///
/// * `trait_path` – trait used as the bound on the generated `prime::<T>()`.
/// * `trait_path_str` – string form of the trait path; drives generated
///   identifiers and is returned from `trait_name()`.
/// * `factory_description` – doc string for the factory struct, also returned
///   from `factory_description()`.
/// * `methods` – trait methods exposed as tools, in emission order.
/// * `vis` – visibility applied to the generated factory struct.
pub fn factory_tokens(
    trait_path: &Path,
    trait_path_str: &str,
    factory_description: &str,
    methods: &[MethodInfo],
    vis: &syn::Visibility,
) -> TokenStream {
    let factory_name = factory_struct_name(trait_path_str);
    let vtable_name = vtable_struct_name(trait_path_str);
    let method_name_strs: Vec<String> = methods.iter().map(|m| m.name.to_string()).collect();

    // One vtable map per trait; its name is derived from the trait's final path segment.
    let last = last_segment_str(trait_path_str);
    let map_ident = proc_macro2::Ident::new(
        &format!("{}_FACTORY_VTABLES", camel_to_snake(last).to_uppercase()),
        Span::call_site(),
    );

    // One expression per method, each evaluating to a `DynamicToolDescriptor`.
    // They reference `prefix`, `vtable` and `slot`, which are bound in the
    // generated `instantiate` body below.
    let descriptor_builders: Vec<TokenStream> = methods
        .iter()
        .map(|m| {
            let method_str = m.name.to_string();
            let field = &m.name;
            let param_struct = param_struct_name(&method_str);
            quote! {
                {
                    let tool_name = format!("{prefix}__{method}", method = #method_str);
                    let schema_value = ::serde_json::to_value(
                        ::schemars::schema_for!(#param_struct)
                    ).unwrap_or(::serde_json::Value::Object(Default::default()));
                    let handler = vtable.#field.clone();
                    ::elicitation::DynamicToolDescriptor {
                        name: tool_name,
                        description: format!("Call `{}` on {}", #method_str, slot.type_name()),
                        schema: schema_value,
                        handler,
                    }
                }
            }
        })
        .collect();

    // Poisoned locks are recovered with `PoisonError::into_inner` rather than
    // panicking. The only writer is `prime`, and `HashMap::entry(..).or_insert_with`
    // inserts only after the constructor closure returns. A panic in
    // `for_type::<T>()` therefore leaves the map unchanged, and continuing with
    // the recovered guard is sound.
    //
    // `#[allow(unused_variables)]` on `instantiate` covers the zero-method case,
    // where the generated `vtable` and `prefix` bindings are never read.
    quote! {
        static #map_ident: ::std::sync::LazyLock<
            ::std::sync::RwLock<
                ::std::collections::HashMap<
                    ::std::any::TypeId,
                    ::std::boxed::Box<#vtable_name>,
                >
            >
        > = ::std::sync::LazyLock::new(Default::default);

        #[doc = #factory_description]
        #vis struct #factory_name;

        impl #factory_name {
            /// Registers the vtable for `T` with this factory if it is not already present.
            ///
            /// `instantiate` returns an error for any slot whose type has not been primed.
            pub fn prime<T>()
            where
                T: #trait_path
                    + ::serde::Serialize
                    + ::serde::de::DeserializeOwned
                    + ::schemars::JsonSchema
                    + ::elicitation::Elicitation
                    + Send + Sync + 'static,
            {
                let type_id = ::std::any::TypeId::of::<T>();
                let mut map = #map_ident
                    .write()
                    .unwrap_or_else(::std::sync::PoisonError::into_inner);
                map.entry(type_id).or_insert_with(|| {
                    Box::new(#vtable_name::for_type::<T>())
                });
            }
        }

        impl ::elicitation::AnyToolFactory for #factory_name {
            fn trait_name(&self) -> &'static str {
                #trait_path_str
            }

            fn factory_description(&self) -> &'static str {
                #factory_description
            }

            fn method_names(&self) -> &'static [&'static str] {
                &[#(#method_name_strs,)*]
            }

            #[allow(unused_variables)]
            fn instantiate(
                &self,
                slot: &dyn ::elicitation::AnyToolSlot,
            ) -> ::std::result::Result<
                ::std::vec::Vec<::elicitation::DynamicToolDescriptor>,
                ::rmcp::ErrorData,
            > {
                let type_id = slot.slot_type_id();
                let map = #map_ident
                    .read()
                    .unwrap_or_else(::std::sync::PoisonError::into_inner);
                let vtable = map.get(&type_id).ok_or_else(|| {
                    ::rmcp::ErrorData::invalid_params(
                        format!(
                            "`{}` has not been primed for type `{}`. \
                             Call register_type::<T>(prefix) at startup.",
                            #trait_path_str,
                            slot.type_name(),
                        ),
                        None,
                    )
                })?;
                let prefix = slot.prefix().to_string();
                let descriptors = vec![#(#descriptor_builders,)*];
                Ok(descriptors)
            }
        }

        ::inventory::submit!(::elicitation::ToolFactoryRegistration {
            trait_name: #trait_path_str,
            factory: &#factory_name,
        });
    }
}