1use std::{borrow::Cow, collections::HashMap, iter::once, marker::PhantomData};
2
3pub use copilot_rs_macro::{complete, FunctionTool};
4use reqwest::header::{AUTHORIZATION, CONTENT_TYPE};
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7use typed_builder::TypedBuilder;
8
9pub trait FunctionTool {
10 fn key() -> String;
11 fn desc() -> String;
12 fn inject(args: HashMap<String, serde_json::Value>) -> String;
13}
/// Marker trait with no required items.
///
/// NOTE(review): nothing in the visible code implements or bounds on it —
/// presumably used by macro-generated types; confirm.
pub trait Structure {}
15
/// Something that can be executed to produce a textual result.
pub trait FunctionImplTrait {
    /// Runs the implementation and hands back its output as a `String`.
    fn exec(&self) -> String;
}
19
20#[derive(TypedBuilder, Debug, Serialize, Deserialize)]
21pub struct ChatModel {
22 pub chat_api_base: String,
23 pub chat_model_default: String,
24 pub chat_api_key: String,
25}
26type FuncImpl = fn(std::collections::HashMap<String, serde_json::Value>) -> String;
27
/// Zero-sized type tagged with a response type `T` (defaults to `String`).
///
/// NOTE(review): not referenced anywhere in the visible code — confirm it is
/// still needed.
struct NormalChat<T = String> {
    // Carries `T` at the type level only; no value of `T` is stored.
    _marker: PhantomData<T>,
}
31
32pub fn chat(
33 model: &ChatModel,
34 messages: &[PromptMessage],
35 chat_model: &str,
36 temperature: f32,
37 max_tokens: u32,
38 functions: HashMap<String, (String, FuncImpl)>,
39) -> String {
40 let tools: Vec<serde_json::Value> = functions
41 .iter()
42 .map(|(_, (v, _))| serde_json::from_str(v).unwrap())
43 .collect();
44 let client = reqwest::blocking::Client::new();
45 let mut headers = reqwest::header::HeaderMap::new();
46
47 headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
48 headers.insert(
49 AUTHORIZATION,
50 format!("Bearer {}", model.chat_api_key).parse().unwrap(),
51 );
52 let url = format!("{}/chat/completions", model.chat_api_base);
53 let common_builder = client.post(url).headers(headers);
54
55 let chat_model = if chat_model.is_empty() {
56 &model.chat_model_default
57 } else {
58 chat_model
59 };
60 let mut json = json!({
61 "model":chat_model,
62 "messages": messages,
63 "temperature": temperature,
64 "max_tokens": max_tokens,
65 "stream":false,
66 });
67 if !tools.is_empty() {
68 json["tools"] = serde_json::Value::Array(tools);
69 }
70
71 let builder = common_builder.try_clone().unwrap().json(&json);
72 let res = builder.send().unwrap().text().unwrap();
73 let res = serde_json::from_str::<ChatCompletion>(&res).unwrap();
74 if let Some(common_message) = res.choices.first().and_then(|v| v.message.as_ref()) {
75 if let Some(tool_calls) = &common_message.tool_calls {
76 let tool_messages = tool_calls
77 .first()
78 .map(|call| {
79 let call_name = &call.function.name;
80 let (_, call_func) = functions.get(call_name).unwrap();
81 let args = &call.function.arguments;
82 let args = args.replace("\\\"", "\"");
83 let args: HashMap<String, serde_json::Value> =
84 serde_json::from_str(&args).unwrap();
85 let result = call_func(args);
86 result.tool(call.id.clone())
87 })
88 .unwrap();
89 let tool_messages = vec![common_message.clone(), tool_messages];
90 let total_message = messages.iter().chain(&tool_messages).collect::<Vec<_>>();
91
92 let json = json!({
93 "model": model.chat_model_default,
94 "messages": total_message,
95 "temperature": temperature,
96 "max_tokens": max_tokens,
97 "stream":false,
98 });
99
100 let builder = common_builder.json(&json);
101 let res = builder.send().unwrap().text().unwrap();
102 let res = serde_json::from_str::<ChatCompletion>(&res).unwrap();
103 let r = res
104 .choices
105 .first()
106 .as_ref()
107 .unwrap()
108 .message
109 .as_ref()
110 .unwrap();
111 r.content.clone()
112 } else {
113 common_message.content.clone()
114 }
115 } else {
116 "none".to_string()
117 }
118}
119
120#[derive(Serialize, Deserialize, Debug, Clone)]
121#[serde(rename_all = "snake_case")]
122pub enum Role {
123 System,
124 User,
125 Assistant,
126 Tool,
127}
128
129#[derive(Serialize, Deserialize, Debug, Clone)]
130pub struct PromptMessage {
131 pub role: Role,
132 pub content: String,
133 #[serde(skip_serializing_if = "Option::is_none")]
134 tool_calls: Option<Vec<ToolCall>>,
135 #[serde(skip_serializing_if = "Option::is_none")]
136 tool_call_id: Option<String>,
137}
138
139#[derive(Serialize, Deserialize, Debug, Clone)]
140pub struct ToolCall {
141 id: String,
142 #[serde(rename = "type")]
143 ty: String,
144 function: Function,
145}
146#[derive(Serialize, Deserialize, Debug, Clone)]
147pub struct Function {
148 name: String,
149 arguments: String,
151}
152
/// Types that can start a chat.
///
/// Currently a placeholder: the provided `chat` method only returns the
/// literal `"chat"`.
pub trait Chat {
    /// Placeholder implementation that yields `"chat"`.
    fn chat(&self) -> String {
        String::from("chat")
    }
}
167
168impl Chat for Vec<PromptMessage> {}
169impl Chat for dyn AsRef<[PromptMessage]> {}
170
171pub trait IntoPrompt
172where
173 Self: ToString,
174{
175 fn system(&self) -> PromptMessage {
176 PromptMessage {
177 role: Role::System,
178 content: self.to_string(),
179 tool_calls: None,
180 tool_call_id: None,
181 }
182 }
183 fn user(&self) -> PromptMessage {
184 PromptMessage {
185 role: Role::User,
186 content: self.to_string(),
187 tool_calls: None,
188 tool_call_id: None,
189 }
190 }
191 fn assistant(&self) -> PromptMessage {
192 PromptMessage {
193 role: Role::Assistant,
194 content: self.to_string(),
195 tool_calls: None,
196 tool_call_id: None,
197 }
198 }
199 fn tool(&self, id: String) -> PromptMessage {
200 PromptMessage {
201 role: Role::Tool,
202 content: self.to_string(),
203 tool_calls: None,
204 tool_call_id: Some(id),
205 }
206 }
207}
208
209impl IntoPrompt for &str {}
210
211impl IntoPrompt for String {}
212
213#[derive(Debug, Deserialize, Default)]
214pub struct ChatCompletion {
215 choices: Vec<Choice>,
216 created: u64,
217 id: String,
218 model: String,
219 object: String,
220}
221
222impl ChatCompletion {
223 pub fn get_content(&self) -> Cow<str> {
224 if let Some(content) = self.choices[0]
225 .delta
226 .as_ref()
227 .and_then(|v| v.content.as_ref())
228 {
229 Cow::Borrowed(content)
230 } else if let Some(msg) = self.choices[0].message.as_ref() {
231 Cow::Borrowed(&msg.content)
232 } else {
233 Cow::Borrowed("")
234 }
235 }
236}
237
238#[derive(Debug, Deserialize)]
239struct Choice {
240 delta: Option<Delta>,
241 message: Option<PromptMessage>,
242 finish_reason: Option<String>,
243 index: u32,
244}
245
246#[derive(Debug, Deserialize)]
247struct Delta {
248 #[serde(skip_serializing_if = "Option::is_none")]
249 content: Option<String>,
250}