llm-worker-macros 0.2.0

llm-worker's proc macros
Documentation
//! llm-worker-macros - Tool生成用手続きマクロ
//!
//! `#[tool_registry]` と `#[tool]` マクロを提供し、
//! ユーザー定義のメソッドから `Tool` トレイト実装を自動生成する。

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
    Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input,
};

/// `impl` ブロックに付与し、内部の `#[tool]` 属性がついたメソッドからツールを生成するマクロ。
///
/// # Example
/// ```ignore
/// #[tool_registry]
/// impl MyApp {
///     /// ユーザー情報を取得する
///     /// 指定されたIDのユーザーをDBから検索します。
///     #[tool]
///     async fn get_user(&self, user_id: String) -> Result<User, Error> { ... }
/// }
/// ```
///
/// これにより以下が生成されます:
/// - `GetUserArgs` 構造体(引数用)
/// - `Tool_get_user` 構造体(Toolラッパー)
/// - `impl Tool for Tool_get_user`
/// - `impl MyApp { fn get_user_tool(&self) -> Tool_get_user }`
#[proc_macro_attribute]
pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut impl_block = parse_macro_input!(item as ItemImpl);
    let self_ty = &impl_block.self_ty;

    let mut generated_items = Vec::new();

    for item in &mut impl_block.items {
        if let ImplItem::Fn(method) = item {
            // #[tool] 属性を探す
            let mut is_tool = false;

            // 属性を走査してtoolがあるか確認し、削除する
            method.attrs.retain(|attr| {
                if attr.path().is_ident("tool") {
                    is_tool = true;
                    false // 属性を削除
                } else {
                    true
                }
            });

            if is_tool {
                let tool_impl = generate_tool_impl(self_ty, method);
                generated_items.push(tool_impl);
            }
        }
    }

    let expanded = quote! {
        #impl_block

        #(#generated_items)*
    };

    TokenStream::from(expanded)
}

/// ドキュメントコメントから説明文を抽出
fn extract_doc_comment(attrs: &[Attribute]) -> String {
    let mut lines = Vec::new();

    for attr in attrs {
        if attr.path().is_ident("doc") {
            if let Meta::NameValue(meta) = &attr.meta {
                if let syn::Expr::Lit(expr_lit) = &meta.value {
                    if let Lit::Str(lit_str) = &expr_lit.lit {
                        let line = lit_str.value();
                        // 先頭の空白を1つだけ除去(/// の後のスペース)
                        let trimmed = line.strip_prefix(' ').unwrap_or(&line);
                        lines.push(trimmed.to_string());
                    }
                }
            }
        }
    }

    lines.join("\n")
}

/// #[description = "..."] 属性から説明を抽出
fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
    for attr in attrs {
        if attr.path().is_ident("description") {
            if let Meta::NameValue(meta) = &attr.meta {
                if let syn::Expr::Lit(expr_lit) = &meta.value {
                    if let Lit::Str(lit_str) = &expr_lit.lit {
                        return Some(lit_str.value());
                    }
                }
            }
        }
    }
    None
}

/// メソッドからTool実装を生成
fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream {
    let sig = &method.sig;
    let method_name = &sig.ident;
    let tool_name = method_name.to_string();

    // 構造体名を生成(PascalCase変換)
    let pascal_name = to_pascal_case(&method_name.to_string());
    let tool_struct_name = format_ident!("Tool{}", pascal_name);
    let args_struct_name = format_ident!("{}Args", pascal_name);
    let definition_name = format_ident!("{}_definition", method_name);

    // ドキュメントコメントから説明を取得
    let description = extract_doc_comment(&method.attrs);
    let description = if description.is_empty() {
        format!("Tool: {}", tool_name)
    } else {
        description
    };

    // 引数を解析(selfを除く)
    let args: Vec<_> = sig
        .inputs
        .iter()
        .filter_map(|arg| {
            if let FnArg::Typed(pat_type) = arg {
                Some(pat_type)
            } else {
                None // selfを除外
            }
        })
        .collect();

    // 引数構造体のフィールドを生成
    let arg_fields: Vec<_> = args
        .iter()
        .map(|pat_type| {
            let pat = &pat_type.pat;
            let ty = &pat_type.ty;
            let desc = extract_description_attr(&pat_type.attrs);

            // パターンから識別子を抽出
            let field_name = if let Pat::Ident(pat_ident) = pat.as_ref() {
                &pat_ident.ident
            } else {
                panic!("Only simple identifiers are supported for tool arguments");
            };

            // #[description] があればschemarsのdocに変換
            if let Some(desc_str) = desc {
                quote! {
                    #[schemars(description = #desc_str)]
                    pub #field_name: #ty
                }
            } else {
                quote! {
                    pub #field_name: #ty
                }
            }
        })
        .collect();

    // execute内で引数を展開するコード
    let arg_names: Vec<_> = args
        .iter()
        .map(|pat_type| {
            if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
                let ident = &pat_ident.ident;
                quote! { args.#ident }
            } else {
                panic!("Only simple identifiers are supported");
            }
        })
        .collect();

    // メソッドが非同期かどうか
    let is_async = sig.asyncness.is_some();

    // 戻り値の型を解析してResult判定
    let awaiter = if is_async {
        quote! { .await }
    } else {
        quote! {}
    };

    // 戻り値がResultかどうかを判定
    let result_handling = if is_result_type(&sig.output) {
        quote! {
            match result {
                Ok(val) => Ok(format!("{:?}", val)),
                Err(e) => Err(::llm_worker::tool::ToolError::ExecutionFailed(format!("{}", e))),
            }
        }
    } else {
        quote! {
            Ok(format!("{:?}", result))
        }
    };

    // 引数がない場合は空のArgs構造体を作成
    let args_struct_def = if arg_fields.is_empty() {
        quote! {
            #[derive(serde::Deserialize, schemars::JsonSchema)]
            struct #args_struct_name {}
        }
    } else {
        quote! {
            #[derive(serde::Deserialize, schemars::JsonSchema)]
            struct #args_struct_name {
                #(#arg_fields),*
            }
        }
    };

    // 引数がない場合のexecute処理
    let execute_body = if args.is_empty() {
        quote! {
            // 引数なしでも空のJSONオブジェクトを許容
            let _: #args_struct_name = serde_json::from_str(input_json)
                .unwrap_or(#args_struct_name {});

            let result = self.ctx.#method_name()#awaiter;
            #result_handling
        }
    } else {
        quote! {
            let args: #args_struct_name = serde_json::from_str(input_json)
                .map_err(|e| ::llm_worker::tool::ToolError::InvalidArgument(e.to_string()))?;

            let result = self.ctx.#method_name(#(#arg_names),*)#awaiter;
            #result_handling
        }
    };

    quote! {
        #args_struct_def

        #[derive(Clone)]
        pub struct #tool_struct_name {
            ctx: #self_ty,
        }

        #[async_trait::async_trait]
        impl ::llm_worker::tool::Tool for #tool_struct_name {
            async fn execute(&self, input_json: &str) -> Result<String, ::llm_worker::tool::ToolError> {
                #execute_body
            }
        }

        impl #self_ty {
            /// ToolDefinition を取得(Worker への登録用)
            pub fn #definition_name(&self) -> ::llm_worker::tool::ToolDefinition {
                let ctx = self.clone();
                ::std::sync::Arc::new(move || {
                    let schema = schemars::schema_for!(#args_struct_name);
                    let meta = ::llm_worker::tool::ToolMeta::new(#tool_name)
                        .description(#description)
                        .input_schema(serde_json::to_value(schema).unwrap_or(serde_json::json!({})));
                    let tool: ::std::sync::Arc<dyn ::llm_worker::tool::Tool> =
                        ::std::sync::Arc::new(#tool_struct_name { ctx: ctx.clone() });
                    (meta, tool)
                })
            }
        }
    }
}

/// 戻り値の型がResultかどうかを判定
fn is_result_type(return_type: &ReturnType) -> bool {
    match return_type {
        ReturnType::Default => false,
        ReturnType::Type(_, ty) => {
            // Type::Pathの場合、最後のセグメントが"Result"かチェック
            if let Type::Path(type_path) = ty.as_ref() {
                if let Some(segment) = type_path.path.segments.last() {
                    return segment.ident == "Result";
                }
            }
            false
        }
    }
}

/// snake_case を PascalCase に変換
fn to_pascal_case(s: &str) -> String {
    s.split('_')
        .map(|part| {
            let mut chars = part.chars();
            match chars.next() {
                None => String::new(),
                Some(first) => first.to_uppercase().chain(chars).collect(),
            }
        })
        .collect()
}

/// マーカー属性。`tool_registry` によって処理されるため、ここでは何もしない。
#[proc_macro_attribute]
pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}

/// 引数属性用のマーカー。パース時に`tool_registry`で解釈される。
///
/// # Example
/// ```ignore
/// #[tool]
/// async fn get_user(
///     &self,
///     #[description = "取得したいユーザーのID"] user_id: String
/// ) -> Result<User, Error> { ... }
/// ```
#[proc_macro_attribute]
pub fn description(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}