use darling::FromMeta;
use proc_macro2::{Literal, TokenStream};
use quote::{format_ident, quote};
use syn::{Item, ItemMod};
use crate::tool::{self, ToolAttrs, ToolInfo};
/// Expands a bare `include!("path")` item into the items the included file contains.
///
/// The path is looked up under `$CARGO_MANIFEST_DIR/src/` first and then directly under
/// `$CARGO_MANIFEST_DIR/`. Returns `None` in any of these cases:
/// - the macro is not a plain `include!`;
/// - its argument is not a single string literal;
/// - `CARGO_MANIFEST_DIR` is unset;
/// - neither candidate path exists;
/// - the file cannot be read or does not parse as a list of items.
///
/// Callers then treat the invocation as opaque.
///
/// On success, a `const _: &[u8] = include_bytes!("<resolved path>");` item is appended to the
/// returned items. The file is read by this proc macro, which the compiler does not record as
/// a dependency. Without that item, edits to the included file would not trigger a rebuild,
/// and the expansion (tool list, user items) would go stale.
fn expand_include_item(mac_item: &syn::ItemMacro) -> Option<Vec<Item>> {
    if !mac_item.mac.path.is_ident("include") {
        return None;
    }
    let lit: syn::LitStr = syn::parse2(mac_item.mac.tokens.clone()).ok()?;
    let file_path_str = lit.value();
    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").ok()?;
    let manifest_path = std::path::Path::new(&manifest_dir);
    // NOTE(review): rustc's own `include!` resolves relative to the invoking file; a proc
    // macro cannot see that file on stable, so `src/` and the crate root are tried instead.
    let in_src = manifest_path.join("src").join(&file_path_str);
    let at_root = manifest_path.join(&file_path_str);
    let full_path = if in_src.exists() {
        in_src
    } else if at_root.exists() {
        at_root
    } else {
        return None;
    };
    let content = std::fs::read_to_string(&full_path).ok()?;
    let file: syn::File = syn::parse_str(&content).ok()?;

    // Register the file with the compiler so that changing it re-runs this expansion.
    let path_str = full_path.to_string_lossy().into_owned();
    let mut items = file.items;
    items.push(syn::parse_quote! {
        const _: &[u8] = include_bytes!(#path_str);
    });
    Some(items)
}
/// Arguments accepted by the component attribute, parsed from the attribute's meta via `darling`.
///
/// Every field is optional:
/// - `manifest` selects the manifest file. `generate` resolves it against `CARGO_MANIFEST_DIR`
///   and falls back to `act.toml`.
/// - The other fields are forwarded to `crate::manifest::Overrides`, which is combined with the
///   manifest contents by `crate::manifest::build_component_info`.
#[derive(Debug, FromMeta)]
pub struct ComponentAttrs {
/// Manifest file name/path relative to `CARGO_MANIFEST_DIR` (`generate` defaults to `act.toml`).
#[darling(default)]
pub manifest: Option<String>,
/// Override passed as `Overrides::name`.
#[darling(default)]
pub name: Option<String>,
/// Override passed as `Overrides::version`.
#[darling(default)]
pub version: Option<String>,
/// Override passed as `Overrides::description`.
#[darling(default)]
pub description: Option<String>,
/// Override passed as `Overrides::default_language`.
#[darling(default)]
pub default_language: Option<String>,
}
pub fn generate(attrs: ComponentAttrs, module: &ItemMod) -> syn::Result<TokenStream> {
let tools = extract_tools(module)?;
let session_hooks = extract_session_hooks(module)?;
let user_items = collect_user_items(module);
let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default();
let manifest_file = attrs.manifest.as_deref().unwrap_or("act.toml");
let manifest_path = std::path::Path::new(&manifest_dir).join(manifest_file);
let manifest = crate::manifest::read_manifest(&manifest_path).unwrap_or_else(|e| panic!("{e}"));
let overrides = crate::manifest::Overrides {
name: attrs.name,
version: attrs.version,
description: attrs.description,
default_language: attrs.default_language,
};
let info = crate::manifest::build_component_info(manifest, overrides);
let default_lang = info.std.default_language.as_deref().unwrap_or("en");
let comp_version = info.std.version.clone();
let comp_description = info.std.description.clone();
let mut cbor_buf = Vec::new();
ciborium::into_writer(&info, &mut cbor_buf).expect("CBOR encoding failed");
let act_component_cbor = cbor_buf;
let cbor_len = act_component_cbor.len();
let cbor_literal = Literal::byte_string(&act_component_cbor);
let version_len = comp_version.len();
let version_literal = Literal::byte_string(comp_version.as_bytes());
let description_len = comp_description.len();
let description_literal = Literal::byte_string(comp_description.as_bytes());
let tool_defs = tools.iter().map(|t| gen_tool_definition(t, default_lang));
let call_arms = tools.iter().map(|t| gen_call_arm(t, default_lang));
let arg_structs = tools
.iter()
.filter(|t| t.struct_args.is_none() && !t.args.is_empty())
.map(gen_arg_struct);
let session_provider_impl = match &session_hooks {
Some(h) => gen_session_provider_impl(h),
None => quote! {},
};
let manifest_tracking = if manifest_path.exists() {
let path_str = manifest_path.to_string_lossy().to_string();
quote! {
const _: &[u8] = include_bytes!(#path_str);
}
} else {
quote! {}
};
let output = quote! {
use ::act_sdk::__private::serde;
use ::act_sdk::__private::schemars;
wit_bindgen::generate!({
path: "wit",
world: "component-world",
generate_all,
});
#manifest_tracking
#[unsafe(link_section = "version")]
#[used]
static __ACT_VERSION_SECTION: [u8; #version_len] = *#version_literal;
#[unsafe(link_section = "description")]
#[used]
static __ACT_DESCRIPTION_SECTION: [u8; #description_len] = *#description_literal;
#[unsafe(link_section = "act:component")]
#[used]
static __ACT_COMPONENT_SECTION: [u8; #cbor_len] = *#cbor_literal;
#(#user_items)*
#(#arg_structs)*
fn __raw_to_wit(raw: ::act_sdk::context::RawToolEvent) -> exports::act::tools::tool_provider::ToolEvent {
match raw {
::act_sdk::context::RawToolEvent::Content { data, mime_type, metadata } => {
exports::act::tools::tool_provider::ToolEvent::Content(exports::act::tools::tool_provider::ContentPart {
data,
mime_type,
metadata,
})
}
::act_sdk::context::RawToolEvent::Error { kind, message, default_language: _ } => {
exports::act::tools::tool_provider::ToolEvent::Error(exports::act::tools::tool_provider::Error {
kind,
message: exports::act::tools::tool_provider::LocalizedString::Plain(message),
metadata: vec![],
})
}
}
}
struct __ActComponent;
export!(__ActComponent);
#session_provider_impl
impl exports::act::tools::tool_provider::Guest for __ActComponent {
async fn list_tools(
_metadata: Vec<(String, Vec<u8>)>,
) -> Result<exports::act::tools::tool_provider::ListToolsResponse, exports::act::tools::tool_provider::Error> {
Ok(exports::act::tools::tool_provider::ListToolsResponse {
metadata: vec![],
tools: vec![
#(#tool_defs),*
],
})
}
async fn call_tool(
__name: String,
__arguments: Vec<u8>,
__metadata: Vec<(String, Vec<u8>)>,
) -> exports::act::tools::tool_provider::ToolResult {
let __default_lang = #default_lang;
match __name.as_str() {
#(#call_arms)*
__other => exports::act::tools::tool_provider::ToolResult::Immediate(vec![
exports::act::tools::tool_provider::ToolEvent::Error(exports::act::tools::tool_provider::Error {
kind: ::act_sdk::constants::ERR_NOT_FOUND.to_string(),
message: exports::act::tools::tool_provider::LocalizedString::Plain(format!("Tool '{}' not found", __other)),
metadata: vec![],
})
])
}
}
}
};
Ok(output)
}
fn extract_tools(module: &ItemMod) -> syn::Result<Vec<ToolInfo>> {
let mut tools = Vec::new();
if let Some((_, items)) = &module.content {
for item in items {
extract_tools_from_item(item, &mut tools)?;
}
}
Ok(tools)
}
fn extract_tools_from_item(item: &Item, tools: &mut Vec<ToolInfo>) -> syn::Result<()> {
match item {
Item::Fn(func) => {
let tool_attr = func.attrs.iter().find(|a| a.path().is_ident("act_tool"));
if let Some(attr) = tool_attr {
let attrs = ToolAttrs::from_meta(&attr.meta).map_err(syn::Error::from)?;
let info = tool::parse_tool_fn(func, attrs)?;
tools.push(info);
}
}
Item::Macro(mac_item) => {
if let Some(expanded_items) = expand_include_item(mac_item) {
for sub_item in &expanded_items {
extract_tools_from_item(sub_item, tools)?;
}
}
}
_ => {}
}
Ok(())
}
/// Produces the token streams for the module's items as they should be re-emitted.
///
/// Macro-specific attributes are stripped, `use super::..` imports are dropped, and
/// `include!`s are expanded (see `collect_user_item`). An out-of-line module yields nothing.
fn collect_user_items(module: &ItemMod) -> Vec<TokenStream> {
    let mut emitted = Vec::new();
    for item in module.content.iter().flat_map(|(_, items)| items) {
        collect_user_item(item, &mut emitted);
    }
    emitted
}
/// Pushes the re-emittable tokens for a single module item onto `items`.
///
/// - Functions lose their `#[act_tool]` / `#[session_open]` / `#[session_close]` attributes.
///   Their parameters lose `#[doc]` and `#[args]` attributes, which are consumed by this macro.
/// - `use` items rooted at `super` are skipped.
/// - Expandable `include!`s are replaced by their items; other macros are kept verbatim.
/// - Every other item is emitted unchanged.
fn collect_user_item(item: &Item, items: &mut Vec<TokenStream>) {
    const FN_MACRO_ATTRS: [&str; 3] = ["act_tool", "session_open", "session_close"];
    const PARAM_MACRO_ATTRS: [&str; 2] = ["doc", "args"];

    match item {
        Item::Fn(func) => {
            let mut stripped = func.clone();
            stripped
                .attrs
                .retain(|attr| !FN_MACRO_ATTRS.iter().any(|name| attr.path().is_ident(name)));
            stripped.sig.inputs.iter_mut().for_each(|input| {
                if let syn::FnArg::Typed(typed) = input {
                    typed.attrs.retain(|attr| {
                        !PARAM_MACRO_ATTRS.iter().any(|name| attr.path().is_ident(name))
                    });
                }
            });
            items.push(quote! { #stripped });
        }
        Item::Use(use_item) if is_super_use(use_item) => {}
        Item::Macro(mac_item) => match expand_include_item(mac_item) {
            Some(expanded) => expanded
                .iter()
                .for_each(|sub_item| collect_user_item(sub_item, items)),
            None => items.push(quote! { #mac_item }),
        },
        other => items.push(quote! { #other }),
    }
}
/// Returns `true` when the `use` tree begins with `super::`. For a braced group, the result
/// is `true` if any member of the group does.
fn is_super_use(u: &syn::ItemUse) -> bool {
    fn leads_with_super(tree: &syn::UseTree) -> bool {
        if let syn::UseTree::Group(group) = tree {
            return group.items.iter().any(leads_with_super);
        }
        matches!(tree, syn::UseTree::Path(path) if path.ident == "super")
    }
    leads_with_super(&u.tree)
}
fn gen_tool_definition(tool: &ToolInfo, _default_lang: &str) -> TokenStream {
let name = &tool.tool_name;
let desc = &tool.description;
let schema_expr = if let Some(struct_type) = &tool.struct_args {
quote! {
{
let schema = ::act_sdk::__private::schemars::schema_for!(#struct_type);
::act_sdk::__private::serde_json::to_string(&schema)
.unwrap_or_else(|_| r#"{"type":"object"}"#.to_string())
}
}
} else if tool.args.is_empty() {
quote! { r#"{"type":"object","properties":{}}"#.to_string() }
} else {
let struct_name = gen_args_struct_ident(&tool.fn_ident);
quote! {
{
let schema = ::act_sdk::__private::schemars::schema_for!(#struct_name);
::act_sdk::__private::serde_json::to_string(&schema)
.unwrap_or_else(|_| r#"{"type":"object"}"#.to_string())
}
}
};
let mut metadata_entries = Vec::new();
if tool.read_only {
metadata_entries.push(quote! {
(::act_sdk::constants::META_READ_ONLY.to_string(), ::act_sdk::cbor::to_cbor(&true))
});
}
if tool.idempotent {
metadata_entries.push(quote! {
(::act_sdk::constants::META_IDEMPOTENT.to_string(), ::act_sdk::cbor::to_cbor(&true))
});
}
if tool.destructive {
metadata_entries.push(quote! {
(::act_sdk::constants::META_DESTRUCTIVE.to_string(), ::act_sdk::cbor::to_cbor(&true))
});
}
if tool.streaming {
metadata_entries.push(quote! {
(::act_sdk::constants::META_STREAMING.to_string(), ::act_sdk::cbor::to_cbor(&true))
});
}
if let Some(ms) = tool.timeout_ms {
metadata_entries.push(quote! {
(::act_sdk::constants::META_TIMEOUT_MS.to_string(), ::act_sdk::cbor::to_cbor(&#ms))
});
}
quote! {
exports::act::tools::tool_provider::ToolDefinition {
name: #name.to_string(),
description: exports::act::tools::tool_provider::LocalizedString::Plain(#desc.to_string()),
parameters_schema: #schema_expr,
metadata: vec![#(#metadata_entries),*],
}
}
}
/// Decides how a tool's success value is converted into response events.
///
/// Returns `true` when `::act_sdk::IntoResponse` is used, and `false` when
/// `::act_sdk::response::cbor_encode_response` is used instead.
///
/// Recognised (`true`) types:
/// - `()`, `String`, `Vec<u8>`;
/// - a shared `&str` with any lifetime;
/// - anything whose tokens start with `Json<` or `Content`.
///
/// Matching is syntactic, so the check only sees how the type is spelled.
///
/// The previous pure token-string check only recognised the exact spelling `&str`. As a result,
/// `&'static str` / `&'a str` were CBOR-encoded instead of going through `IntoResponse`. It also
/// missed parenthesised or invisible-group-wrapped types (e.g. types passed through
/// `macro_rules!`).
fn has_direct_into_response(ty: &syn::Type) -> bool {
    match ty {
        // `(T)` and invisible groups denote the same type as `T`.
        syn::Type::Paren(inner) => has_direct_into_response(&inner.elem),
        syn::Type::Group(inner) => has_direct_into_response(&inner.elem),
        // Any shared string slice, regardless of its lifetime annotation.
        syn::Type::Reference(reference) if reference.mutability.is_none() => matches!(
            &*reference.elem,
            syn::Type::Path(path) if path.qself.is_none() && path.path.is_ident("str")
        ),
        _ => {
            let s = quote!(#ty).to_string().replace(' ', "");
            s == "()"
                || s == "String"
                || s == "Vec<u8>"
                || s.starts_with("Json<")
                || s.starts_with("Content")
        }
    }
}
fn gen_call_arm(tool: &ToolInfo, _default_lang: &str) -> TokenStream {
let tool_name = &tool.tool_name;
let fn_ident = &tool.fn_ident;
let (deser_code, call_expr) = if let Some(struct_type) = &tool.struct_args {
let deser = quote! {
let __args: #struct_type = match ::act_sdk::cbor::from_cbor(&__arguments) {
Ok(v) => v,
Err(e) => {
return exports::act::tools::tool_provider::ToolResult::Immediate(vec![
exports::act::tools::tool_provider::ToolEvent::Error(exports::act::tools::tool_provider::Error {
kind: ::act_sdk::constants::ERR_INVALID_ARGS.to_string(),
message: exports::act::tools::tool_provider::LocalizedString::Plain(format!("Failed to deserialize arguments: {}", e)),
metadata: vec![],
})
]);
}
};
};
let call = if tool.has_context {
quote! { #fn_ident(__args, &mut __ctx) }
} else {
quote! { #fn_ident(__args) }
};
(deser, call)
} else if tool.args.is_empty() {
let call = if tool.has_context {
quote! { #fn_ident(&mut __ctx) }
} else {
quote! { #fn_ident() }
};
(quote! {}, call)
} else {
let struct_name = gen_args_struct_ident(fn_ident);
let field_names: Vec<_> = tool
.args
.iter()
.map(|a| format_ident!("{}", a.name))
.collect();
let deser = quote! {
let __args_struct: #struct_name = match ::act_sdk::cbor::from_cbor(&__arguments) {
Ok(v) => v,
Err(e) => {
return exports::act::tools::tool_provider::ToolResult::Immediate(vec![
exports::act::tools::tool_provider::ToolEvent::Error(exports::act::tools::tool_provider::Error {
kind: ::act_sdk::constants::ERR_INVALID_ARGS.to_string(),
message: exports::act::tools::tool_provider::LocalizedString::Plain(format!("Failed to deserialize arguments: {}", e)),
metadata: vec![],
})
]);
}
};
};
let call = if tool.has_context {
quote! { #fn_ident(#(__args_struct.#field_names),*, &mut __ctx) }
} else {
quote! { #fn_ident(#(__args_struct.#field_names),*) }
};
(deser, call)
};
let awaited_call = if tool.is_async {
quote! { #call_expr.await }
} else {
quote! { #call_expr }
};
let metadata_parse = if let Some(metadata_type) = &tool.metadata_type {
quote! {
let __metadata_val: #metadata_type = {
let mut __map = ::act_sdk::__private::serde_json::Map::new();
for (k, v) in &__metadata {
if let Ok(val) = ::act_sdk::cbor::from_cbor::<::act_sdk::__private::serde_json::Value>(v) {
__map.insert(k.clone(), val);
}
}
let __metadata_json = ::act_sdk::__private::serde_json::Value::Object(__map);
match ::act_sdk::__private::serde_json::from_value::<#metadata_type>(__metadata_json) {
Ok(v) => v,
Err(e) => {
let _ = __wit_writer.write_all(vec![
exports::act::tools::tool_provider::ToolEvent::Error(exports::act::tools::tool_provider::Error {
kind: ::act_sdk::constants::ERR_INVALID_ARGS.to_string(),
message: exports::act::tools::tool_provider::LocalizedString::Plain(format!("Failed to deserialize metadata: {}", e)),
metadata: vec![],
})
]).await;
return;
}
}
};
let mut __ctx = ::act_sdk::ActContext::__new(__metadata_val);
}
} else {
quote! {
let mut __ctx = ::act_sdk::ActContext::__new(());
}
};
let use_into_response = tool
.inner_return_type
.as_ref()
.is_none_or(has_direct_into_response);
let ok_response = if use_into_response {
quote! {
use ::act_sdk::IntoResponse;
let __response_events = __val.into_tool_events(__default_lang);
}
} else {
quote! {
let __response_events = ::act_sdk::response::cbor_encode_response(&__val, __default_lang);
}
};
if tool.has_context {
quote! {
#tool_name => {
#deser_code
let (mut __wit_writer, __reader) = wit_stream::new::<exports::act::tools::tool_provider::ToolEvent>();
wit_bindgen::spawn_local(async move {
#metadata_parse
let __result = #awaited_call;
let __ctx_events = __ctx.__take_events();
let mut __wit_events: Vec<exports::act::tools::tool_provider::ToolEvent> = __ctx_events
.into_iter()
.map(|e| __raw_to_wit(e))
.collect();
match __result {
Ok(__val) => {
#ok_response
__wit_events.extend(__response_events.into_iter().map(|e| __raw_to_wit(e)));
}
Err(__err) => {
__wit_events.push(exports::act::tools::tool_provider::ToolEvent::Error(exports::act::tools::tool_provider::Error {
kind: __err.kind.clone(),
message: exports::act::tools::tool_provider::LocalizedString::Plain(__err.message.clone()),
metadata: vec![],
}));
}
}
if !__wit_events.is_empty() {
let _ = __wit_writer.write_all(__wit_events).await;
}
});
exports::act::tools::tool_provider::ToolResult::Streaming(__reader)
}
}
} else {
quote! {
#tool_name => {
#deser_code
let __result = #awaited_call;
match __result {
Ok(__val) => {
#ok_response
let __wit_events: Vec<exports::act::tools::tool_provider::ToolEvent> = __response_events
.into_iter()
.map(|e| __raw_to_wit(e))
.collect();
exports::act::tools::tool_provider::ToolResult::Immediate(__wit_events)
}
Err(__err) => exports::act::tools::tool_provider::ToolResult::Immediate(vec![
exports::act::tools::tool_provider::ToolEvent::Error(exports::act::tools::tool_provider::Error {
kind: __err.kind.clone(),
message: exports::act::tools::tool_provider::LocalizedString::Plain(__err.message.clone()),
metadata: vec![],
})
])
}
}
}
}
}
/// Derives the name of the args struct generated for a tool function.
///
/// The name is `__` + the PascalCase form of the function name + `Args`; for example
/// `get_weather` becomes `__GetWeatherArgs`.
///
/// Raw identifiers are unrawed first, so `r#match` becomes `__MatchArgs`. Previously the
/// `r#` prefix leaked into the text (`__R#matchArgs`), which made `format_ident!` panic.
///
/// NOTE(review): names differing only in underscore runs (`foo_bar` / `foo__bar`) still map to
/// the same struct name.
fn gen_args_struct_ident(fn_ident: &syn::Ident) -> syn::Ident {
    let plain_name = syn::ext::IdentExt::unraw(fn_ident).to_string();
    let pascal: String = plain_name
        .split('_')
        .map(|s| {
            let mut c = s.chars();
            match c.next() {
                None => String::new(),
                Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
            }
        })
        .collect();
    format_ident!("__{}Args", pascal)
}
/// The component's session lifecycle functions, as found by `extract_session_hooks`.
struct SessionHooks {
/// Name of the `#[session_open]` function.
open_fn_ident: syn::Ident,
/// Type of the open function's single args parameter (see `parse_open_signature`).
open_args_ty: syn::Type,
/// Whether the open function is `async` (its call is then `.await`ed).
open_is_async: bool,
/// Name of the `#[session_close]` function.
close_fn_ident: syn::Ident,
/// Whether the close function is `async`; `validate_close_signature` rejects async close
/// functions, so this is `false` for any hooks that pass validation.
close_is_async: bool,
}
fn extract_session_hooks(module: &ItemMod) -> syn::Result<Option<SessionHooks>> {
let Some((_, items)) = &module.content else {
return Ok(None);
};
let mut open: Option<(syn::Ident, syn::Type, bool)> = None;
let mut close: Option<(syn::Ident, bool)> = None;
for item in items {
let Item::Fn(func) = item else { continue };
let has_open = func.attrs.iter().any(|a| a.path().is_ident("session_open"));
let has_close = func
.attrs
.iter()
.any(|a| a.path().is_ident("session_close"));
if has_open && has_close {
return Err(syn::Error::new_spanned(
&func.sig.ident,
"function cannot be both #[session_open] and #[session_close]",
));
}
if has_open {
if open.is_some() {
return Err(syn::Error::new_spanned(
&func.sig.ident,
"only one #[session_open] function is allowed per component",
));
}
let (_, args_ty) = parse_open_signature(func)?;
open = Some((
func.sig.ident.clone(),
args_ty,
func.sig.asyncness.is_some(),
));
}
if has_close {
if close.is_some() {
return Err(syn::Error::new_spanned(
&func.sig.ident,
"only one #[session_close] function is allowed per component",
));
}
validate_close_signature(func)?;
close = Some((func.sig.ident.clone(), func.sig.asyncness.is_some()));
}
}
match (open, close) {
(Some((oi, oa, o_async)), Some((ci, c_async))) => Ok(Some(SessionHooks {
open_fn_ident: oi,
open_args_ty: oa,
open_is_async: o_async,
close_fn_ident: ci,
close_is_async: c_async,
})),
(Some((ident, _, _)), None) => Err(syn::Error::new_spanned(
ident,
"#[session_open] requires a paired #[session_close] in the same module",
)),
(None, Some((ident, _))) => Err(syn::Error::new_spanned(
ident,
"#[session_close] requires a paired #[session_open] in the same module",
)),
(None, None) => Ok(None),
}
}
/// Validates that a `#[session_open]` function takes exactly one typed parameter.
///
/// Returns the parameter's binding name together with its type. The name falls back to
/// `__args` when the parameter uses a non-identifier pattern.
fn parse_open_signature(func: &syn::ItemFn) -> syn::Result<(syn::Ident, syn::Type)> {
    let typed_params: Vec<&syn::PatType> = func
        .sig
        .inputs
        .iter()
        .filter_map(|input| match input {
            syn::FnArg::Typed(pat_type) => Some(pat_type),
            _ => None,
        })
        .collect();
    let param = match typed_params.as_slice() {
        [] => {
            return Err(syn::Error::new_spanned(
                &func.sig,
                "#[session_open] function must take one args parameter (e.g. `fn open(args: OpenArgs)`)",
            ))
        }
        [only] => *only,
        _ => {
            return Err(syn::Error::new_spanned(
                &func.sig,
                "#[session_open] function must take exactly one args parameter",
            ))
        }
    };
    let binding = if let syn::Pat::Ident(pat_ident) = &*param.pat {
        pat_ident.ident.clone()
    } else {
        syn::Ident::new("__args", proc_macro2::Span::call_site())
    };
    Ok((binding, (*param.ty).clone()))
}
fn validate_close_signature(func: &syn::ItemFn) -> syn::Result<()> {
if func.sig.asyncness.is_some() {
return Err(syn::Error::new_spanned(
&func.sig,
"#[session_close] function must be sync (WIT close-session is sync)",
));
}
let typed_count = func
.sig
.inputs
.iter()
.filter(|i| matches!(i, syn::FnArg::Typed(_)))
.count();
if typed_count != 1 {
return Err(syn::Error::new_spanned(
&func.sig,
"#[session_close] function must take exactly one parameter (`session_id: String`)",
));
}
Ok(())
}
fn gen_session_provider_impl(hooks: &SessionHooks) -> TokenStream {
let open_ident = &hooks.open_fn_ident;
let close_ident = &hooks.close_fn_ident;
let open_args_ty = &hooks.open_args_ty;
let open_call = if hooks.open_is_async {
quote! { #open_ident(__args).await }
} else {
quote! { #open_ident(__args) }
};
let _ = hooks.close_is_async; let close_call = quote! { #close_ident(session_id) };
quote! {
impl exports::act::sessions::session_provider::Guest for __ActComponent {
async fn get_open_session_args_schema(
_metadata: Vec<(String, Vec<u8>)>,
) -> Result<String, exports::act::sessions::session_provider::Error> {
let schema = ::act_sdk::__private::schemars::schema_for!(#open_args_ty);
::act_sdk::__private::serde_json::to_string(&schema).map_err(|e| {
exports::act::sessions::session_provider::Error {
kind: ::act_sdk::constants::ERR_INTERNAL.to_string(),
message: exports::act::tools::tool_provider::LocalizedString::Plain(
format!("schema serialization failed: {e}")
),
metadata: vec![],
}
})
}
async fn open_session(
args: Vec<(String, Vec<u8>)>,
_metadata: Vec<(String, Vec<u8>)>,
) -> Result<exports::act::sessions::session_provider::Session, exports::act::sessions::session_provider::Error> {
let args_map: ::std::collections::BTreeMap<String, ::act_sdk::__private::serde_json::Value> = args
.into_iter()
.filter_map(|(k, v)| {
::act_sdk::cbor::from_cbor::<::act_sdk::__private::serde_json::Value>(&v)
.ok()
.map(|val| (k, val))
})
.collect();
let args_json = ::act_sdk::__private::serde_json::to_value(&args_map).unwrap_or(
::act_sdk::__private::serde_json::Value::Object(Default::default())
);
let __args: #open_args_ty = match ::act_sdk::__private::serde_json::from_value(args_json) {
Ok(v) => v,
Err(e) => {
return Err(exports::act::sessions::session_provider::Error {
kind: ::act_sdk::constants::ERR_INVALID_ARGS.to_string(),
message: exports::act::tools::tool_provider::LocalizedString::Plain(
format!("Failed to deserialize session args: {e}")
),
metadata: vec![],
});
}
};
match #open_call {
Ok(id) => Ok(exports::act::sessions::session_provider::Session {
id,
metadata: vec![],
}),
Err(err) => Err(exports::act::sessions::session_provider::Error {
kind: err.kind.clone(),
message: exports::act::tools::tool_provider::LocalizedString::Plain(err.message.clone()),
metadata: vec![],
}),
}
}
fn close_session(session_id: String) {
let _ = #close_call;
}
}
}
}
/// Generates the private `__<Name>Args` struct for a tool declared with loose parameters.
///
/// The struct has one `pub` field per parameter, carrying the parameter's doc text when present.
/// It derives serde `Deserialize` (for decoding call arguments) and schemars `JsonSchema`
/// (for the advertised parameters schema).
fn gen_arg_struct(tool: &ToolInfo) -> TokenStream {
    let struct_name = gen_args_struct_ident(&tool.fn_ident);
    let fields = tool.args.iter().map(|arg| {
        let field = format_ident!("{}", arg.name);
        let field_ty = &arg.ty;
        // `None` interpolates to nothing, so undocumented fields get no attribute.
        let doc_attr = arg.doc.as_ref().map(|doc| quote! { #[doc = #doc] });
        quote! {
            #doc_attr
            pub #field: #field_ty,
        }
    });
    quote! {
        #[derive(::act_sdk::__private::serde::Deserialize, ::act_sdk::__private::schemars::JsonSchema)]
        #[serde(crate = "::act_sdk::__private::serde")]
        #[schemars(crate = "::act_sdk::__private::schemars")]
        #[allow(non_camel_case_types)]
        struct #struct_name {
            #(#fields)*
        }
    }
}