use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::ItemImpl;
use crate::agentic::helpers::{
AgentConfigAttrRemover, Asyncness, FunctionOutputInfo, get_asyncness, has_agent_config_attr,
has_async_trait_attr, is_constructor_method, is_static_method, trim_type_parameter,
};
use syn::visit_mut::VisitMut;
pub fn agent_implementation_impl(_attrs: TokenStream, item: TokenStream) -> TokenStream {
let mut impl_block = match parse_impl_block(&item) {
Ok(b) => b,
Err(e) => return e.to_compile_error().into(),
};
let has_async_trait_attribute = has_async_trait_attr(&impl_block);
if has_async_trait_attribute {
return syn::Error::new_spanned(
&impl_block.self_ty,
"#[async_trait] cannot be used along with #[agent_implementation]. #[agent_implementation] automatically handles async methods. Please remove it",
)
.to_compile_error()
.into();
}
let (impl_generics, ty_generics, where_clause) = impl_block.generics.split_for_impl();
let self_ty = &impl_block.self_ty;
let (trait_name_ident, trait_name_str_raw) = extract_trait_name(&impl_block);
let (match_arms, constructor_method) =
build_match_arms(&impl_block, trait_name_str_raw.to_string());
let constructor_method = match constructor_method {
Some(m) => m,
None => {
return syn::Error::new_spanned(
&impl_block.self_ty,
"No constructor found (a function returning Self is required)",
)
.to_compile_error()
.into();
}
};
let has_load_snapshot = impl_block.items.iter().any(|item| {
if let syn::ImplItem::Fn(method) = item {
method.sig.ident == "load_snapshot"
} else {
false
}
});
let has_save_snapshot = impl_block.items.iter().any(|item| {
if let syn::ImplItem::Fn(method) = item {
method.sig.ident == "save_snapshot"
} else {
false
}
});
if has_load_snapshot != has_save_snapshot {
return syn::Error::new_spanned(
&impl_block.self_ty,
"Both load_snapshot and save_snapshot must be implemented together, or neither should be implemented",
)
.to_compile_error()
.into();
}
let has_custom_snapshot = has_load_snapshot && has_save_snapshot;
let ctor_ident = &constructor_method.sig.ident;
let ctor_param_idents_and_types = extract_param_idents(constructor_method);
let ctor_param_idents: Vec<syn::Ident> = ctor_param_idents_and_types
.iter()
.map(|(ident, _)| ident.clone())
.collect();
let base_agent_impl = generate_base_agent_impl(
&impl_block,
&match_arms,
&trait_name_str_raw,
&impl_generics,
&ty_generics,
where_clause,
has_custom_snapshot,
);
let constructor_kind = get_asyncness(&constructor_method.sig);
let constructor_param_extraction_call_back = match constructor_kind {
Asyncness::Future => {
quote! {
let agent_instance_raw = <#self_ty>::#ctor_ident(#(#ctor_param_idents),*).await;
let agent_instance = Box::new(agent_instance_raw);
let agent_id = golem_rust::bindings::golem::api::host::get_self_metadata().agent_id;
golem_rust::agentic::register_agent_instance(
golem_rust::agentic::ResolvedAgent::new(agent_instance)
);
Ok(())
}
}
Asyncness::Immediate => {
quote! {
let agent_instance = Box::new(<#self_ty>::#ctor_ident(#(#ctor_param_idents),*));
let agent_id = golem_rust::bindings::golem::api::host::get_self_metadata().agent_id;
golem_rust::agentic::register_agent_instance(
golem_rust::agentic::ResolvedAgent::new(agent_instance)
);
Ok(())
}
}
};
let constructor_param_extraction = generate_constructor_extraction(
&ctor_param_idents_and_types,
&trait_name_str_raw,
constructor_param_extraction_call_back,
);
let initiator_ident = format_ident!("__{}Initiator", trait_name_ident);
let base_initiator_impl =
generate_initiator_impl(&initiator_ident, &constructor_param_extraction);
let register_initiator_fn =
generate_register_initiator_fn(&impl_block.self_ty, &trait_name_ident, &initiator_ident);
AgentConfigAttrRemover.visit_item_impl_mut(&mut impl_block);
quote! {
#impl_block
#base_agent_impl
#base_initiator_impl
#register_initiator_fn
}
.into()
}
/// Parses the annotated item as an `impl` block, returning syn's error on any other item kind.
fn parse_impl_block(item: &TokenStream) -> syn::Result<ItemImpl> {
    let tokens = item.clone();
    syn::parse(tokens)
}
/// Returns the last path segment of the implemented trait as an identifier, plus its string form.
///
/// # Panics
/// Panics when `impl_block` is an inherent impl (no trait).
fn extract_trait_name(impl_block: &syn::ItemImpl) -> (syn::Ident, String) {
    let trait_ident = match &impl_block.trait_ {
        Some((_, trait_path, _)) => trait_path.segments.last().unwrap().ident.clone(),
        None => panic!("Expected trait implementation, found none"),
    };
    let trait_name = trait_ident.to_string();
    (trait_ident, trait_name)
}
/// Collects `(identifier, typed pattern)` pairs for the method's typed parameters.
///
/// The receiver (`self`) is skipped, as is any typed parameter whose pattern is not a plain
/// identifier (e.g. `_` or a destructuring pattern).
fn extract_param_idents(method: &syn::ImplItemFn) -> Vec<(syn::Ident, syn::PatType)> {
    let typed_ident = |arg: &syn::FnArg| match arg {
        syn::FnArg::Typed(pat_ty) => match &*pat_ty.pat {
            syn::Pat::Ident(pat_ident) => Some((pat_ident.ident.clone(), pat_ty.clone())),
            _ => None,
        },
        _ => None,
    };
    method.sig.inputs.iter().filter_map(typed_ident).collect()
}
/// Builds one `invoke` match arm per agent method and locates the constructor.
///
/// The following are excluded from the arms:
/// - the constructor (per `is_constructor_method`);
/// - static methods;
/// - `load_snapshot` / `save_snapshot`.
///
/// The remaining methods are sorted by name. Their position in that order is the
/// `sorted_method_index` used to look up parameter schemas at runtime.
///
/// Returns the arms and the constructor, if any. If several methods qualify as constructors,
/// the last one wins.
fn build_match_arms(
    impl_block: &ItemImpl,
    agent_type_name: String,
) -> (Vec<proc_macro2::TokenStream>, Option<&syn::ImplItemFn>) {
    struct MethodInfo<'a> {
        method: &'a syn::ImplItemFn,
        name: String,
        param_idents: Vec<syn::Ident>,
    }

    // Depends only on the implementing type, so compute it once rather than per impl item.
    let agent_impl_type_name = match &*impl_block.self_ty {
        syn::Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.to_string(),
        _ => String::new(),
    };

    let mut constructor_method = None;
    let mut eligible_methods: Vec<MethodInfo> = Vec::new();
    for item in &impl_block.items {
        let method = match item {
            syn::ImplItem::Fn(method) => method,
            _ => continue,
        };
        if is_constructor_method(&method.sig, Some(&agent_impl_type_name)) {
            constructor_method = Some(method);
            continue;
        }
        if is_static_method(&method.sig) {
            continue;
        }
        if method.sig.ident == "load_snapshot" || method.sig.ident == "save_snapshot" {
            continue;
        }
        eligible_methods.push(MethodInfo {
            method,
            name: method.sig.ident.to_string(),
            param_idents: extract_param_idents(method)
                .into_iter()
                .map(|(ident, _)| ident)
                .collect(),
        });
    }

    eligible_methods.sort_by(|a, b| a.name.cmp(&b.name));

    let match_arms = eligible_methods
        .iter()
        .enumerate()
        .map(|(sorted_method_index, info)| {
            let method_name = info.name.as_str();
            let ident = &info.method.sig.ident;
            let param_idents = &info.param_idents;
            let fn_output_info = FunctionOutputInfo::from_signature(&info.method.sig);

            // Async and sync methods share one template; only the `.await` suffix differs.
            let await_suffix = match fn_output_info.async_ness {
                Asyncness::Future => quote! { .await },
                Asyncness::Immediate => quote! {},
            };

            let post_method_param_extraction_logic = if fn_output_info.is_unit {
                quote! {
                    let _ = self.#ident(#(#param_idents),*) #await_suffix;
                    Ok(golem_rust::golem_agentic::golem::agent::common::DataValue::Tuple(vec![]))
                }
            } else {
                quote! {
                    let result = self.#ident(#(#param_idents),*) #await_suffix;
                    <_ as golem_rust::agentic::Schema>::to_structured_value(result).map_err(|e| {
                        golem_rust::agentic::custom_error(format!(
                            "Failed serializing return value for method {}: {}",
                            #method_name, e
                        ))
                    }).and_then(|result_value| {
                        match result_value {
                            golem_rust::agentic::StructuredValue::Default(element_value) => {
                                Ok(golem_rust::golem_agentic::golem::agent::common::DataValue::Tuple(vec![element_value]))
                            },
                            golem_rust::agentic::StructuredValue::Multimodal(result) => {
                                Ok(golem_rust::golem_agentic::golem::agent::common::DataValue::Multimodal(result))
                            },
                            golem_rust::agentic::StructuredValue::AutoInjected(_) => {
                                Err(golem_rust::agentic::custom_error(format!(
                                    "Principal value cannot be returned from method {}",
                                    #method_name
                                )))
                            }
                        }
                    })
                }
            };

            let method_param_extraction = generate_method_param_extraction(
                param_idents,
                &agent_type_name,
                method_name,
                sorted_method_index,
                post_method_param_extraction_logic,
            );
            quote! {
                #method_name => {
                    #method_param_extraction
                }
            }
        })
        .collect();

    (match_arms, constructor_method)
}
/// Generates the body of one `invoke` match arm.
///
/// The generated code:
/// 1. wraps the incoming `DataValue` in a local `__InputVariant`;
/// 2. looks up the method's parameter schemas by `sorted_method_index`;
/// 3. decodes each parameter in declaration order;
/// 4. runs `post_method_param_extraction_logic`.
///
/// How each parameter is decoded:
/// - Tuple input: auto-injected parameters (currently only the principal) are filled from
///   `principal`. Every other parameter consumes the next tuple element.
/// - Multimodal input: the first parameter takes the whole element list.
fn generate_method_param_extraction(
    param_idents: &[syn::Ident],
    agent_type_name: &str,
    method_name: &str,
    sorted_method_index: usize,
    post_method_param_extraction_logic: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    let schema_lookup = quote! {
        let mut input_param_index: usize = 0;
        let __agent_type_name = golem_rust::agentic::AgentTypeName(#agent_type_name.to_string());
        let __param_schemas = golem_rust::agentic::get_method_parameter_types_by_index(
            &__agent_type_name,
            #sorted_method_index
        ).ok_or_else(|| {
            golem_rust::agentic::custom_error(format!(
                "Internal Error: Parameter schemas not found for agent: {}, method index: {}",
                #agent_type_name, #sorted_method_index
            ))
        })?;
    };

    let mut per_param_decoding: Vec<proc_macro2::TokenStream> =
        Vec::with_capacity(param_idents.len());
    for (original_method_param_idx, ident) in param_idents.iter().enumerate() {
        let ident_result = format_ident!("{}_result", ident);

        // Shared error mapping for every deserialization failure of this parameter.
        let parse_failure = quote! {
            |e| {
                golem_rust::agentic::invalid_input_error(format!("Failed parsing arg {} for method {}: {}", #original_method_param_idx, #method_name, e))
            }
        };

        let from_tuple = quote! {
            let enriched_schema = __param_schemas.get(#original_method_param_idx)
                .cloned()
                .ok_or_else(|| {
                    golem_rust::agentic::custom_error(format!(
                        "Internal Error: Parameter schema not found for agent: {}, method: {}, parameter index: {}",
                        #agent_type_name, #method_name, #original_method_param_idx
                    ))
                })?;
            match enriched_schema {
                golem_rust::agentic::EnrichedElementSchema::AutoInject(auto_injected_schema) => {
                    match auto_injected_schema {
                        golem_rust::agentic::AutoInjectedParamType::Principal => {
                            golem_rust::agentic::Schema::from_structured_value(golem_rust::agentic::StructuredValue::AutoInjected(golem_rust::agentic::AutoInjectedValue::Principal(principal.clone())), golem_rust::agentic::StructuredSchema::AutoInject(golem_rust::agentic::AutoInjectedParamType::Principal)).map_err(#parse_failure)
                        }
                    }
                }
                golem_rust::agentic::EnrichedElementSchema::ElementSchema(element_schema) => {
                    let element_value = if input_param_index < values.len() {
                        values[input_param_index].take().ok_or_else(|| {
                            golem_rust::agentic::invalid_input_error(format!("Argument already consumed in method {}", #method_name))
                        })?
                    } else {
                        return Err(golem_rust::agentic::invalid_input_error(format!("Missing arguments in method {}", #method_name)));
                    };
                    input_param_index += 1;
                    golem_rust::agentic::Schema::from_structured_value(golem_rust::agentic::StructuredValue::Default(element_value), golem_rust::agentic::StructuredSchema::Default(element_schema)).map_err(#parse_failure)
                }
            }
        };

        let from_multimodal = quote! {
            let deserialized_value = golem_rust::agentic::Schema::from_structured_value(golem_rust::agentic::StructuredValue::Multimodal(elements.take().unwrap_or_default()), golem_rust::agentic::StructuredSchema::Multimodal(vec![])).map_err(#parse_failure)?;
            Ok(deserialized_value)
        };

        per_param_decoding.push(quote! {
            let #ident_result = match &mut __input_variant {
                __InputVariant::Tuple(values) => {
                    #from_tuple
                },
                __InputVariant::Multimodal(elements) => {
                    #from_multimodal
                }
            };
            let #ident = #ident_result?;
        });
    }

    quote! {
        enum __InputVariant {
            Tuple(Vec<Option<golem_rust::golem_agentic::golem::agent::common::ElementValue>>),
            Multimodal(Option<Vec<(String, golem_rust::golem_agentic::golem::agent::common::ElementValue)>>),
        }
        let mut __input_variant = match input {
            golem_rust::golem_agentic::golem::agent::common::DataValue::Tuple(values) => {
                __InputVariant::Tuple(values.into_iter().map(Some).collect())
            },
            golem_rust::golem_agentic::golem::agent::common::DataValue::Multimodal(elements) => {
                __InputVariant::Multimodal(Some(elements))
            },
        };
        #schema_lookup
        #(#per_param_decoding)*
        #post_method_param_extraction_logic
    }
}
fn generate_base_agent_impl(
impl_block: &syn::ItemImpl,
match_arms: &[proc_macro2::TokenStream],
trait_name_str: &str,
impl_generics: &syn::ImplGenerics<'_>,
ty_generics: &syn::TypeGenerics<'_>,
where_clause: Option<&syn::WhereClause>,
has_custom_snapshot: bool,
) -> proc_macro2::TokenStream {
let self_ty = &impl_block.self_ty;
let snapshot_impl = if has_custom_snapshot {
quote! {
async fn load_snapshot_base(&mut self, bytes: Vec<u8>) -> Result<(), String> {
self.load_snapshot(bytes).await
}
async fn save_snapshot_base(&self) -> Result<golem_rust::agentic::SnapshotData, String> {
let data = self.save_snapshot().await?;
Ok(golem_rust::agentic::SnapshotData {
data,
mime_type: "application/octet-stream".to_string(),
})
}
}
} else {
quote! {
async fn load_snapshot_base(&mut self, bytes: Vec<u8>) -> Result<(), String> {
use golem_rust::agentic::snapshot_auto::SnapshotLoadFallback;
let mut helper = golem_rust::agentic::snapshot_auto::LoadHelper(self);
helper.snapshot_load(&bytes)
}
async fn save_snapshot_base(&self) -> Result<golem_rust::agentic::SnapshotData, String> {
use golem_rust::agentic::snapshot_auto::SnapshotSaveFallback;
let helper = golem_rust::agentic::snapshot_auto::SaveHelper(self);
helper.snapshot_save()
}
}
};
quote! {
#[golem_rust::async_trait::async_trait(?Send)]
impl #impl_generics golem_rust::agentic::BaseAgent for #self_ty #ty_generics #where_clause {
fn get_agent_id(&self) -> String {
golem_rust::agentic::get_agent_id().agent_id
}
async fn invoke(&mut self, method_name: String, input: golem_rust::golem_agentic::golem::agent::common::DataValue, principal: golem_rust::golem_agentic::golem::agent::common::Principal)
-> Result<golem_rust::golem_agentic::golem::agent::common::DataValue, golem_rust::golem_agentic::golem::agent::common::AgentError> {
match method_name.as_str() {
#(#match_arms,)*
_ => Err(golem_rust::agentic::invalid_method_error(method_name)),
}
}
fn get_definition(&self)
-> golem_rust::golem_agentic::golem::agent::common::AgentType {
golem_rust::agentic::get_agent_type_by_name(&golem_rust::agentic::AgentTypeName(#trait_name_str.to_string()))
.expect("Agent definition not found")
}
#snapshot_impl
}
}
}
fn generate_constructor_extraction(
ctor_params: &[(syn::Ident, syn::PatType)],
agent_type_name: &str,
call_back: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
let mut config_extractions = Vec::new();
let mut predecls = Vec::new();
let mut tuple_extractions = Vec::new();
let mut schema_param_index: usize = 0;
for (ident, pat_type) in ctor_params {
if has_agent_config_attr(pat_type) {
let ty = &pat_type.ty;
config_extractions.push(quote! {
let #ident: #ty = ::golem_rust::agentic::Config::new();
});
} else {
let ty = &pat_type.ty;
let idx = schema_param_index;
predecls.push(quote! {
let #ident: #ty;
});
tuple_extractions.push(quote! {
{
let enriched_schema = __ctor_schemas.get(#idx)
.cloned()
.ok_or_else(|| {
golem_rust::agentic::internal_error(format!(
"Constructor parameter schema not found for agent: {}, parameter index: {}",
#agent_type_name, #idx
))
})?;
match enriched_schema {
golem_rust::agentic::EnrichedElementSchema::AutoInject(auto_injected_schema) => {
match auto_injected_schema {
golem_rust::agentic::AutoInjectedParamType::Principal => {
#ident = golem_rust::agentic::Schema::from_structured_value(golem_rust::agentic::StructuredValue::AutoInjected(golem_rust::agentic::AutoInjectedValue::Principal(principal.clone())), golem_rust::agentic::StructuredSchema::AutoInject(golem_rust::agentic::AutoInjectedParamType::Principal)).map_err(|e| {
golem_rust::agentic::invalid_input_error(format!("Failed parsing constructor arg {}: {}", #idx, e))
})?;
}
}
}
golem_rust::agentic::EnrichedElementSchema::ElementSchema(element_schema) => {
let element_value = if input_param_index < values.len() {
values[input_param_index].take().ok_or_else(|| {
golem_rust::agentic::invalid_input_error(format!("Constructor argument already consumed for agent {}", #agent_type_name))
})?
} else {
return Err(golem_rust::agentic::invalid_input_error(format!("Missing constructor arguments for agent {}", #agent_type_name)));
};
input_param_index += 1;
#ident = golem_rust::agentic::Schema::from_structured_value(golem_rust::agentic::StructuredValue::Default(element_value), golem_rust::agentic::StructuredSchema::Default(element_schema)).map_err(|e| {
golem_rust::agentic::invalid_input_error(format!("Failed parsing constructor arg {}: {}", #idx, e))
})?;
}
}
}
});
schema_param_index += 1;
}
}
quote! {
let __agent_type_name = golem_rust::agentic::AgentTypeName(#agent_type_name.to_string());
#(#config_extractions)*
#(#predecls)*
match params {
golem_rust::golem_agentic::golem::agent::common::DataValue::Tuple(values) => {
let mut values: Vec<Option<golem_rust::golem_agentic::golem::agent::common::ElementValue>> = values.into_iter().map(Some).collect();
let mut input_param_index: usize = 0;
let __ctor_schemas = golem_rust::agentic::get_constructor_parameter_types(
&__agent_type_name,
).ok_or_else(|| {
golem_rust::agentic::internal_error(format!(
"Constructor parameter schemas not found for agent: {}",
#agent_type_name
))
})?;
#(#tuple_extractions)*
},
golem_rust::golem_agentic::golem::agent::common::DataValue::Multimodal(_) => {
return Err(golem_rust::agentic::internal_error("Multimodal constructor input not supported currently"));
},
}
#call_back
}
}
/// Emits the unit struct `initiator_ident` and its `golem_rust::agentic::AgentInitiator` impl.
///
/// The `initiate` body is `constructor_param_extraction`. That code reads the `params` and
/// `principal` arguments declared here.
fn generate_initiator_impl(
    initiator_ident: &syn::Ident,
    constructor_param_extraction: &proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    let agent_common = quote! { golem_rust::golem_agentic::golem::agent::common };
    let initiator_struct = quote! {
        struct #initiator_ident;
    };
    let initiator_trait_impl = quote! {
        #[golem_rust::async_trait::async_trait(?Send)]
        impl golem_rust::agentic::AgentInitiator for #initiator_ident {
            async fn initiate(&self, params: #agent_common::DataValue, principal: #agent_common::Principal)
                -> Result<(), #agent_common::AgentError> {
                #constructor_param_extraction
            }
        }
    };
    quote! {
        #initiator_struct
        #initiator_trait_impl
    }
}
/// Emits a `#[ctor]` function, named after the lowercased trait name, that runs at startup.
///
/// It first calls `__register_agent_type()` on the (type-parameter-trimmed) implementing type.
/// It then registers an `Arc` of the initiator under the trait's name.
fn generate_register_initiator_fn(
    self_ty: &syn::Type,
    agent_trait_ident: &syn::Ident,
    initiator_ident: &syn::Ident,
) -> proc_macro2::TokenStream {
    let agent_trait_name = agent_trait_ident.to_string();
    let ctor_fn_ident = format_ident!(
        "__register_agent_initiator_{}",
        agent_trait_name.to_lowercase()
    );
    let bare_impl_type = trim_type_parameter(self_ty);
    let bare_impl_type_ident = format_ident!("{}", bare_impl_type);

    quote! {
        ::golem_rust::ctor::__support::ctor_parse!(
            #[ctor] fn #ctor_fn_ident() {
                #bare_impl_type_ident::__register_agent_type();
                golem_rust::agentic::register_agent_initiator(
                    &#agent_trait_name,
                    std::sync::Arc::new(#initiator_ident)
                );
            }
        );
    }
}