copilot_rs/
lib.rs

1use std::{borrow::Cow, collections::HashMap, iter::once, marker::PhantomData};
2
3pub use copilot_rs_macro::{complete, FunctionTool};
4use reqwest::header::{AUTHORIZATION, CONTENT_TYPE};
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7use typed_builder::TypedBuilder;
8
9pub trait FunctionTool {
10    fn key() -> String;
11    fn desc() -> String;
12    fn inject(args: HashMap<String, serde_json::Value>) -> String;
13}
/// Marker trait with no required items; its intended role is not visible in this file — TODO document.
pub trait Structure {
}
15
/// Something that can be executed to produce a textual result.
pub trait FunctionImplTrait {
    /// Runs the implementation and returns its output.
    fn exec(self: &Self) -> String;
}
19
20#[derive(TypedBuilder, Debug, Serialize, Deserialize)]
21pub struct ChatModel {
22    pub chat_api_base: String,
23    pub chat_model_default: String,
24    pub chat_api_key: String,
25}
26type FuncImpl = fn(std::collections::HashMap<String, serde_json::Value>) -> String;
27
/// Zero-sized type parameterized by `T` (default `String`); it is not constructed
/// anywhere in the code visible here.
struct NormalChat<T = String> {
    // Ties the otherwise-unused type parameter to the struct without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
31
32pub fn chat(
33    model: &ChatModel,
34    messages: &[PromptMessage],
35    chat_model: &str,
36    temperature: f32,
37    max_tokens: u32,
38    functions: HashMap<String, (String, FuncImpl)>,
39) -> String {
40    let tools: Vec<serde_json::Value> = functions
41        .iter()
42        .map(|(_, (v, _))| serde_json::from_str(v).unwrap())
43        .collect();
44    let client = reqwest::blocking::Client::new();
45    let mut headers = reqwest::header::HeaderMap::new();
46
47    headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
48    headers.insert(
49        AUTHORIZATION,
50        format!("Bearer {}", model.chat_api_key).parse().unwrap(),
51    );
52    let url = format!("{}/chat/completions", model.chat_api_base);
53    let common_builder = client.post(url).headers(headers);
54
55    let chat_model = if chat_model.is_empty() {
56        &model.chat_model_default
57    } else {
58        chat_model
59    };
60    let mut json = json!({
61        "model":chat_model,
62        "messages": messages,
63        "temperature": temperature,
64        "max_tokens": max_tokens,
65        "stream":false,
66    });
67    if !tools.is_empty() {
68        json["tools"] = serde_json::Value::Array(tools);
69    }
70
71    let builder = common_builder.try_clone().unwrap().json(&json);
72    let res = builder.send().unwrap().text().unwrap();
73    let res = serde_json::from_str::<ChatCompletion>(&res).unwrap();
74    if let Some(common_message) = res.choices.first().and_then(|v| v.message.as_ref()) {
75        if let Some(tool_calls) = &common_message.tool_calls {
76            let tool_messages = tool_calls
77                .first()
78                .map(|call| {
79                    let call_name = &call.function.name;
80                    let (_, call_func) = functions.get(call_name).unwrap();
81                    let args = &call.function.arguments;
82                    let args = args.replace("\\\"", "\"");
83                    let args: HashMap<String, serde_json::Value> =
84                        serde_json::from_str(&args).unwrap();
85                    let result = call_func(args);
86                    result.tool(call.id.clone())
87                })
88                .unwrap();
89            let tool_messages = vec![common_message.clone(), tool_messages];
90            let total_message = messages.iter().chain(&tool_messages).collect::<Vec<_>>();
91
92            let json = json!({
93                "model": model.chat_model_default,
94                "messages": total_message,
95                "temperature": temperature,
96                "max_tokens": max_tokens,
97                "stream":false,
98            });
99
100            let builder = common_builder.json(&json);
101            let res = builder.send().unwrap().text().unwrap();
102            let res = serde_json::from_str::<ChatCompletion>(&res).unwrap();
103            let r = res
104                .choices
105                .first()
106                .as_ref()
107                .unwrap()
108                .message
109                .as_ref()
110                .unwrap();
111            r.content.clone()
112        } else {
113            common_message.content.clone()
114        }
115    } else {
116        "none".to_string()
117    }
118}
119
120#[derive(Serialize, Deserialize, Debug, Clone)]
121#[serde(rename_all = "snake_case")]
122pub enum Role {
123    System,
124    User,
125    Assistant,
126    Tool,
127}
128
129#[derive(Serialize, Deserialize, Debug, Clone)]
130pub struct PromptMessage {
131    pub role: Role,
132    pub content: String,
133    #[serde(skip_serializing_if = "Option::is_none")]
134    tool_calls: Option<Vec<ToolCall>>,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    tool_call_id: Option<String>,
137}
138
139#[derive(Serialize, Deserialize, Debug, Clone)]
140pub struct ToolCall {
141    id: String,
142    #[serde(rename = "type")]
143    ty: String,
144    function: Function,
145}
146#[derive(Serialize, Deserialize, Debug, Clone)]
147pub struct Function {
148    name: String,
149    // #[serde(deserialize_with = "deserialize_map")]
150    arguments: String,
151}
152
153// fn deserialize_map<'de, D>(deserializer: D) -> Result<HashMap<String, serde_json::Value>, D::Error>
154// where
155//     D: serde::Deserializer<'de>,
156// {
157//     let json_string: String = Deserialize::deserialize(deserializer)?;
158//     let s = json_string.replace("\\\"", "\\");
159//     let map: HashMap<String, serde_json::Value> = serde_json::from_str(&s).unwrap();
160//     Ok(map)
161// }
/// Conversation-level chat entry point; the provided default is only a stub.
pub trait Chat {
    /// Returns the fixed string `"chat"` unless an implementor overrides it.
    fn chat(&self) -> String {
        String::from("chat")
    }
}
167
168impl Chat for Vec<PromptMessage> {}
169impl Chat for dyn AsRef<[PromptMessage]> {}
170
171pub trait IntoPrompt
172where
173    Self: ToString,
174{
175    fn system(&self) -> PromptMessage {
176        PromptMessage {
177            role: Role::System,
178            content: self.to_string(),
179            tool_calls: None,
180            tool_call_id: None,
181        }
182    }
183    fn user(&self) -> PromptMessage {
184        PromptMessage {
185            role: Role::User,
186            content: self.to_string(),
187            tool_calls: None,
188            tool_call_id: None,
189        }
190    }
191    fn assistant(&self) -> PromptMessage {
192        PromptMessage {
193            role: Role::Assistant,
194            content: self.to_string(),
195            tool_calls: None,
196            tool_call_id: None,
197        }
198    }
199    fn tool(&self, id: String) -> PromptMessage {
200        PromptMessage {
201            role: Role::Tool,
202            content: self.to_string(),
203            tool_calls: None,
204            tool_call_id: Some(id),
205        }
206    }
207}
208
209impl IntoPrompt for &str {}
210
211impl IntoPrompt for String {}
212
213#[derive(Debug, Deserialize, Default)]
214pub struct ChatCompletion {
215    choices: Vec<Choice>,
216    created: u64,
217    id: String,
218    model: String,
219    object: String,
220}
221
222impl ChatCompletion {
223    pub fn get_content(&self) -> Cow<str> {
224        if let Some(content) = self.choices[0]
225            .delta
226            .as_ref()
227            .and_then(|v| v.content.as_ref())
228        {
229            Cow::Borrowed(content)
230        } else if let Some(msg) = self.choices[0].message.as_ref() {
231            Cow::Borrowed(&msg.content)
232        } else {
233            Cow::Borrowed("")
234        }
235    }
236}
237
238#[derive(Debug, Deserialize)]
239struct Choice {
240    delta: Option<Delta>,
241    message: Option<PromptMessage>,
242    finish_reason: Option<String>,
243    index: u32,
244}
245
246#[derive(Debug, Deserialize)]
247struct Delta {
248    #[serde(skip_serializing_if = "Option::is_none")]
249    content: Option<String>,
250}