copilot_rs_macro/
lib.rs

1use std::collections::HashMap;
2use std::str::FromStr;
3
4use anyhow::Result;
5use darling::{ast::NestedMeta, FromMeta};
6use darling::{FromDeriveInput, FromField};
7use proc_macro::TokenStream;
8use quote::quote;
9use serde::{Deserialize, Serialize};
10use syn::{parse_macro_input, DeriveInput, Ident};
11use syn::{Expr, ItemFn, LitStr, Stmt};
12#[proc_macro_attribute]
13pub fn complete(attr: TokenStream, item: TokenStream) -> proc_macro::TokenStream {
14    match common_simple(attr, item) {
15        Ok(output) => output,
16        Err(e) => TokenStream::from_str(e.to_string().as_str()).unwrap(),
17    }
18}
19#[derive(Debug, FromMeta)]
20struct MacroArgs {
21    client: String,
22    model: Option<String>,
23    temperature: Option<f32>,
24    max_tokens: Option<u32>,
25    tools: Option<Vec<LitStr>>,
26    response_format: Option<String>,
27}
28
29fn common_simple(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
30    let attr_args = NestedMeta::parse_meta_list(attr.into())?;
31    let args = MacroArgs::from_list(&attr_args).unwrap();
32
33    let client = Ident::new(&args.client, proc_macro::Span::call_site().into());
34
35    let mut item: ItemFn = syn::parse(item)?;
36
37    let method_name = item.sig.ident.to_string();
38    let mut is_async = item.sig.asyncness.is_some();
39    let mut block = item.block;
40
41    let new_chat_method = format!("chat_{}", method_name);
42
43    if let Stmt::Expr(expr, _) = block.stmts.last_mut().unwrap() {
44        if let Expr::Await(m) = expr {
45            if let Expr::MethodCall(m) = m.base.as_mut() {
46                let method = &m.method;
47                if method == "async_chat" {
48                    let ident = Ident::new(&new_chat_method, method.span());
49                    m.method = ident;
50                }
51            }
52        }
53        if let Expr::MethodCall(m) = expr {
54            let method = &m.method;
55            if method == "chat" {
56                let ident = Ident::new(&new_chat_method, method.span());
57                m.method = ident;
58                is_async = false;
59            }
60        }
61    }
62
63    // 更新函数体
64    item.block = block;
65
66    let new_chat_method_ident = Ident::new(&new_chat_method, proc_macro::Span::call_site().into());
67
68    let new_chat_trait_name_ident = Ident::new(
69        &format!("RealChat{}", uuid::Uuid::new_v4()).replace("-", ""),
70        proc_macro::Span::call_site().into(),
71    );
72
73    let client_model = client;
74    let model = args.model.clone().unwrap_or_default();
75    let temperature = args.temperature.unwrap_or(0.7);
76    let max_tokens = args.max_tokens.unwrap_or(1024);
77    let functions = args
78        .tools
79        .as_ref()
80        .map(|v| v.iter().map(|v| Ident::new(v.value().as_str(), v.span())))
81        .map(|tools|quote! {
82            {
83                let mut hm = std::collections::HashMap::new();
84                #(hm.insert(#tools::key(),(#tools::desc(),#tools::inject as fn(std::collections::HashMap<String, serde_json::Value>) -> String));)*
85                hm
86            }
87        }).unwrap_or(quote! { std::collections::HashMap::new() });
88    if is_async {
89        let trait_def = quote! {
90            trait #new_chat_trait_name_ident {
91                async fn #new_chat_method_ident(&self) -> String;
92            }
93        };
94        let impl_def = quote! {
95            impl #new_chat_trait_name_ident for Vec<copilot_rs::PromptMessage> {
96                async fn #new_chat_method_ident(&self) -> String {
97                    let model = #client_model();
98                    copilot_rs::async_chat(&model, &self).await
99                }
100            }
101        };
102        let expanded = quote! {
103            #item
104
105            #trait_def
106
107            #impl_def
108        };
109
110        Ok(expanded.into())
111    } else {
112        let trait_def = quote! {
113            trait #new_chat_trait_name_ident {
114                fn #new_chat_method_ident(&self) -> String;
115            }
116        };
117
118        let impl_def = quote! {
119            impl #new_chat_trait_name_ident for Vec<copilot_rs::PromptMessage> {
120                fn #new_chat_method_ident(&self) -> String {
121                    let client = #client_model();
122                    let model = #model;
123                    let temperature = #temperature;
124                    let max_tokens = #max_tokens;
125                    let functions = #functions;
126                    copilot_rs::chat(&client,&self,model,temperature, max_tokens,functions)
127                }
128            }
129        };
130
131        let expanded = quote! {
132            #item
133
134            #trait_def
135
136            #impl_def
137        };
138
139        Ok(expanded.into())
140    }
141}
142
143#[derive(FromDeriveInput, Debug)]
144#[darling(attributes(props), forward_attrs(allow, deny))]
145struct FunctionToolOptions {
146    ident: Ident,
147    data: darling::ast::Data<(), FunctionToolProperties>,
148    #[darling(default)]
149    desc: String,
150}
151
152#[derive(Debug, FromField)]
153#[darling(attributes(props), forward_attrs(allow, deny))]
154struct FunctionToolProperties {
155    ident: Option<Ident>,
156    ty: syn::Type,
157    desc: String,
158    #[darling(default)]
159    choices: Vec<LitStr>,
160}
161
162#[proc_macro_derive(FunctionTool, attributes(props))]
163pub fn derive_function_tool(input: TokenStream) -> TokenStream {
164    let input = parse_macro_input!(input as DeriveInput);
165
166    let parsed = FunctionToolOptions::from_derive_input(&input).unwrap();
167
168    let struct_name = &parsed.ident;
169    let struct_desc = parsed.desc;
170
171    let properties = parsed
172        .data
173        .take_struct()
174        .map(|v| v.fields)
175        .map(|v| {
176            v.iter().fold(HashMap::new(), |mut acc, field| {
177                let name = field
178                    .ident
179                    .as_ref()
180                    .map(|v| v.to_string())
181                    .unwrap_or_default();
182                let ty = match &field.ty {
183                    syn::Type::Path(p) => p
184                        .path
185                        .segments
186                        .first()
187                        .map(|seg| seg.ident.to_string())
188                        .unwrap_or_else(|| "unknown".to_string()),
189                    _ => "unknown".to_string(),
190                };
191                let mut prop = Property::default();
192                prop.r#type = ty.to_lowercase();
193                prop.description = field.desc.clone();
194                prop.choices = if field.choices.is_empty() {
195                    None
196                } else {
197                    Some(field.choices.iter().map(|v| v.value()).collect())
198                };
199                acc.insert(name, prop);
200                acc
201            })
202        })
203        .unwrap_or_default();
204    let required = properties
205        .iter()
206        .filter(|(_k, v)| v.choices.is_none())
207        .map(|(k, _v)| k.clone())
208        .collect();
209    let struct_str = struct_name.to_string();
210    let desc_impl = ToolImpl::Function {
211        name: struct_str.clone(),
212        description: struct_desc,
213        parameters: Parameters {
214            r#type: default_type(),
215            properties,
216            required,
217        },
218    };
219
220    let json = serde_json::to_string(&desc_impl).unwrap();
221
222    let ret = quote! {
223        impl FunctionTool for #struct_name {
224            fn key() -> String {
225                #struct_str.to_string()
226            }
227            fn desc() -> String {
228                #json.to_string()
229
230            }
231            fn inject(args: std::collections::HashMap<String, serde_json::Value>) -> String {
232                let args = serde_json::to_string(&args).unwrap();
233                let c : #struct_name = serde_json::from_str(&args).unwrap();
234                c.exec()
235            }
236        }
237    };
238    ret.into()
239}
240
241#[derive(Debug, Deserialize, Serialize)]
242#[serde(tag = "type", content = "function")]
243enum ToolImpl {
244    #[serde(rename = "function")]
245    Function {
246        name: String,
247        description: String,
248        parameters: Parameters,
249    },
250}
251
252#[derive(Debug, Deserialize, Serialize)]
253struct Parameters {
254    #[serde(default = "default_type")]
255    r#type: String,
256    properties: HashMap<String, Property>,
257    required: Vec<String>,
258}
/// Schema type of a tool's argument object.
const DEFAULT_TYPE: &str = "object";

/// Returns an owned copy of [`DEFAULT_TYPE`]; used as a serde default.
fn default_type() -> String {
    String::from(DEFAULT_TYPE)
}
264
265#[derive(Debug, Deserialize, Serialize, Default)]
266struct Property {
267    r#type: String,
268    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
269    choices: Option<Vec<String>>,
270    description: String,
271}