1use std::collections::HashMap;
2use std::str::FromStr;
3
4use anyhow::Result;
5use darling::{ast::NestedMeta, FromMeta};
6use darling::{FromDeriveInput, FromField};
7use proc_macro::TokenStream;
8use quote::quote;
9use serde::{Deserialize, Serialize};
10use syn::{parse_macro_input, DeriveInput, Ident};
11use syn::{Expr, ItemFn, LitStr, Stmt};
12#[proc_macro_attribute]
13pub fn complete(attr: TokenStream, item: TokenStream) -> proc_macro::TokenStream {
14 match common_simple(attr, item) {
15 Ok(output) => output,
16 Err(e) => TokenStream::from_str(e.to_string().as_str()).unwrap(),
17 }
18}
/// Arguments accepted by `#[complete(...)]`, parsed with darling's `FromMeta`.
#[derive(Debug, FromMeta)]
struct MacroArgs {
    /// Name of the function the generated code calls (with no arguments) to obtain the client.
    client: String,
    /// Model name passed to `copilot_rs::chat` on the blocking path; empty string when omitted.
    model: Option<String>,
    /// Temperature passed to `copilot_rs::chat` on the blocking path; 0.7 when omitted.
    temperature: Option<f32>,
    /// Token limit passed to `copilot_rs::chat` on the blocking path; 1024 when omitted.
    max_tokens: Option<u32>,
    /// Names of tool types whose `key()`, `desc()` and `inject` associated functions are
    /// collected into the map handed to `copilot_rs::chat`.
    tools: Option<Vec<LitStr>>,
    /// NOTE(review): accepted by the parser but not read by the expansion code in this
    /// file — confirm whether it is meant to be forwarded.
    response_format: Option<String>,
}
28
29fn common_simple(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
30 let attr_args = NestedMeta::parse_meta_list(attr.into())?;
31 let args = MacroArgs::from_list(&attr_args).unwrap();
32
33 let client = Ident::new(&args.client, proc_macro::Span::call_site().into());
34
35 let mut item: ItemFn = syn::parse(item)?;
36
37 let method_name = item.sig.ident.to_string();
38 let mut is_async = item.sig.asyncness.is_some();
39 let mut block = item.block;
40
41 let new_chat_method = format!("chat_{}", method_name);
42
43 if let Stmt::Expr(expr, _) = block.stmts.last_mut().unwrap() {
44 if let Expr::Await(m) = expr {
45 if let Expr::MethodCall(m) = m.base.as_mut() {
46 let method = &m.method;
47 if method == "async_chat" {
48 let ident = Ident::new(&new_chat_method, method.span());
49 m.method = ident;
50 }
51 }
52 }
53 if let Expr::MethodCall(m) = expr {
54 let method = &m.method;
55 if method == "chat" {
56 let ident = Ident::new(&new_chat_method, method.span());
57 m.method = ident;
58 is_async = false;
59 }
60 }
61 }
62
63 item.block = block;
65
66 let new_chat_method_ident = Ident::new(&new_chat_method, proc_macro::Span::call_site().into());
67
68 let new_chat_trait_name_ident = Ident::new(
69 &format!("RealChat{}", uuid::Uuid::new_v4()).replace("-", ""),
70 proc_macro::Span::call_site().into(),
71 );
72
73 let client_model = client;
74 let model = args.model.clone().unwrap_or_default();
75 let temperature = args.temperature.unwrap_or(0.7);
76 let max_tokens = args.max_tokens.unwrap_or(1024);
77 let functions = args
78 .tools
79 .as_ref()
80 .map(|v| v.iter().map(|v| Ident::new(v.value().as_str(), v.span())))
81 .map(|tools|quote! {
82 {
83 let mut hm = std::collections::HashMap::new();
84 #(hm.insert(#tools::key(),(#tools::desc(),#tools::inject as fn(std::collections::HashMap<String, serde_json::Value>) -> String));)*
85 hm
86 }
87 }).unwrap_or(quote! { std::collections::HashMap::new() });
88 if is_async {
89 let trait_def = quote! {
90 trait #new_chat_trait_name_ident {
91 async fn #new_chat_method_ident(&self) -> String;
92 }
93 };
94 let impl_def = quote! {
95 impl #new_chat_trait_name_ident for Vec<copilot_rs::PromptMessage> {
96 async fn #new_chat_method_ident(&self) -> String {
97 let model = #client_model();
98 copilot_rs::async_chat(&model, &self).await
99 }
100 }
101 };
102 let expanded = quote! {
103 #item
104
105 #trait_def
106
107 #impl_def
108 };
109
110 Ok(expanded.into())
111 } else {
112 let trait_def = quote! {
113 trait #new_chat_trait_name_ident {
114 fn #new_chat_method_ident(&self) -> String;
115 }
116 };
117
118 let impl_def = quote! {
119 impl #new_chat_trait_name_ident for Vec<copilot_rs::PromptMessage> {
120 fn #new_chat_method_ident(&self) -> String {
121 let client = #client_model();
122 let model = #model;
123 let temperature = #temperature;
124 let max_tokens = #max_tokens;
125 let functions = #functions;
126 copilot_rs::chat(&client,&self,model,temperature, max_tokens,functions)
127 }
128 }
129 };
130
131 let expanded = quote! {
132 #item
133
134 #trait_def
135
136 #impl_def
137 };
138
139 Ok(expanded.into())
140 }
141}
142
/// Container-level options for `#[derive(FunctionTool)]`, read from `#[props(...)]`.
#[derive(FromDeriveInput, Debug)]
#[darling(attributes(props), forward_attrs(allow, deny))]
struct FunctionToolOptions {
    /// Identifier of the type the derive is applied to.
    ident: Ident,
    /// Struct fields; enum variants are parsed as `()` and carry no data.
    data: darling::ast::Data<(), FunctionToolProperties>,
    /// Tool description from `#[props(desc = "...")]`; empty string when absent.
    #[darling(default)]
    desc: String,
}
151
/// Per-field options for `#[derive(FunctionTool)]`, read from `#[props(...)]`.
#[derive(Debug, FromField)]
#[darling(attributes(props), forward_attrs(allow, deny))]
struct FunctionToolProperties {
    /// Field name; `None` for tuple-struct fields.
    ident: Option<Ident>,
    /// Declared field type, used to derive the advertised parameter type.
    ty: syn::Type,
    /// Parameter description; required on every field (`#[props(desc = "...")]`).
    desc: String,
    /// Allowed values for the parameter, emitted as its `enum` list; empty when absent.
    #[darling(default)]
    choices: Vec<LitStr>,
}
161
162#[proc_macro_derive(FunctionTool, attributes(props))]
163pub fn derive_function_tool(input: TokenStream) -> TokenStream {
164 let input = parse_macro_input!(input as DeriveInput);
165
166 let parsed = FunctionToolOptions::from_derive_input(&input).unwrap();
167
168 let struct_name = &parsed.ident;
169 let struct_desc = parsed.desc;
170
171 let properties = parsed
172 .data
173 .take_struct()
174 .map(|v| v.fields)
175 .map(|v| {
176 v.iter().fold(HashMap::new(), |mut acc, field| {
177 let name = field
178 .ident
179 .as_ref()
180 .map(|v| v.to_string())
181 .unwrap_or_default();
182 let ty = match &field.ty {
183 syn::Type::Path(p) => p
184 .path
185 .segments
186 .first()
187 .map(|seg| seg.ident.to_string())
188 .unwrap_or_else(|| "unknown".to_string()),
189 _ => "unknown".to_string(),
190 };
191 let mut prop = Property::default();
192 prop.r#type = ty.to_lowercase();
193 prop.description = field.desc.clone();
194 prop.choices = if field.choices.is_empty() {
195 None
196 } else {
197 Some(field.choices.iter().map(|v| v.value()).collect())
198 };
199 acc.insert(name, prop);
200 acc
201 })
202 })
203 .unwrap_or_default();
204 let required = properties
205 .iter()
206 .filter(|(_k, v)| v.choices.is_none())
207 .map(|(k, _v)| k.clone())
208 .collect();
209 let struct_str = struct_name.to_string();
210 let desc_impl = ToolImpl::Function {
211 name: struct_str.clone(),
212 description: struct_desc,
213 parameters: Parameters {
214 r#type: default_type(),
215 properties,
216 required,
217 },
218 };
219
220 let json = serde_json::to_string(&desc_impl).unwrap();
221
222 let ret = quote! {
223 impl FunctionTool for #struct_name {
224 fn key() -> String {
225 #struct_str.to_string()
226 }
227 fn desc() -> String {
228 #json.to_string()
229
230 }
231 fn inject(args: std::collections::HashMap<String, serde_json::Value>) -> String {
232 let args = serde_json::to_string(&args).unwrap();
233 let c : #struct_name = serde_json::from_str(&args).unwrap();
234 c.exec()
235 }
236 }
237 };
238 ret.into()
239}
240
/// Tool description serialised into the string returned by a derived `desc()`.
///
/// Adjacently tagged, so it serialises as
/// `{"type":"function","function":{"name":..,"description":..,"parameters":..}}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", content = "function")]
enum ToolImpl {
    #[serde(rename = "function")]
    Function {
        /// Tool name (the derived struct's identifier).
        name: String,
        /// Tool description from the struct-level `#[props(desc = "...")]`.
        description: String,
        /// Schema of the tool's arguments.
        parameters: Parameters,
    },
}
251
/// Argument schema of a tool: a typed object with named properties.
#[derive(Debug, Deserialize, Serialize)]
struct Parameters {
    /// Schema type, serialised under the key `type`; filled from `default_type()` when
    /// missing during deserialisation.
    #[serde(default = "default_type")]
    r#type: String,
    /// One entry per struct field, keyed by field name.
    properties: HashMap<String, Property>,
    /// Names of the properties listed as required.
    required: Vec<String>,
}
/// The `type` value used for a tool's `parameters` object.
const DEFAULT_TYPE: &str = "object";

/// Owned copy of [`DEFAULT_TYPE`], usable as a serde `default = "..."` hook.
fn default_type() -> String {
    String::from(DEFAULT_TYPE)
}
264
/// Schema entry describing a single tool parameter.
#[derive(Debug, Deserialize, Serialize, Default)]
struct Property {
    /// Type name, serialised under the key `type`.
    r#type: String,
    /// Allowed values, serialised under the key `enum`; the key is omitted when `None`.
    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
    choices: Option<Vec<String>>,
    /// Human-readable description of the parameter.
    description: String,
}