1use crate::chat_completion_delta::forward_stream;
2use crate::error::{InternalError, OpenAIError};
3use crate::error::UtilsResult;
4use crate::{calculate_message_tokens, DeltaReceiver};
5use crate::{Chat, OPENAI_API_KEY};
6use crate::{Function, Message};
7use log::{error, trace};
8use reqwest::Method;
9use reqwest_eventsource::RequestBuilderExt;
10use schemars::JsonSchema;
11use serde::Deserialize;
12use std::{collections::HashMap, vec};
13use serde_json::to_string_pretty;
14use tokio::sync::mpsc;
15
16#[derive(Debug, Clone, serde_derive::Serialize, serde_derive::Deserialize)]
17pub struct ChatCompletionRequest {
18 pub model: String,
19 pub messages: Vec<Message>,
20
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub functions: Option<Vec<Function>>,
23
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub function_call: Option<String>,
26
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub temperature: Option<f64>,
29
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub top_p: Option<f64>,
32
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub n: Option<u64>,
35
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub stream: Option<bool>,
38
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub stop: Option<Vec<String>>,
41
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub max_tokens: Option<u64>,
44
45 #[serde(skip_serializing_if = "Option::is_none")]
46 pub presence_penalty: Option<f64>,
47
48 #[serde(skip_serializing_if = "Option::is_none")]
49 pub frequency_penalty: Option<f64>,
50
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub logit_bias: Option<HashMap<u64, f64>>,
53
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub user: Option<String>,
56}
57
58impl ChatCompletionRequest {
59 fn new() -> Self {
60 Self {
61 model: "gpt-3.5-turbo".to_string(),
62 messages: vec![],
63 functions: None,
64 function_call: None,
65 temperature: None,
66 top_p: None,
67 n: None,
68 stream: None,
69 stop: None,
70 max_tokens: None,
71 presence_penalty: None,
72 frequency_penalty: None,
73 logit_bias: None,
74 user: None,
75 }
76 }
77}
78
79#[derive(Debug, Clone, serde_derive::Serialize, serde_derive::Deserialize)]
80pub struct AiAgent {
81 pub model: String,
82
83 pub system_message: Option<Message>,
84
85 pub messages: Vec<Message>,
86
87 #[serde(skip_serializing_if = "Option::is_none")]
88 pub functions: Option<Vec<Function>>,
89
90 #[serde(skip_serializing_if = "Option::is_none")]
91 pub function_call: Option<String>,
92
93 #[serde(skip_serializing_if = "Option::is_none")]
94 pub temperature: Option<f64>,
95
96 #[serde(skip_serializing_if = "Option::is_none")]
97 pub top_p: Option<f64>,
98
99 #[serde(skip_serializing_if = "Option::is_none")]
100 pub n: Option<u64>,
101
102 #[serde(skip_serializing_if = "Option::is_none")]
103 pub stop: Option<Vec<String>>,
104
105 #[serde(skip_serializing_if = "Option::is_none")]
106 pub max_tokens: Option<u64>,
107
108 #[serde(skip_serializing_if = "Option::is_none")]
109 pub presence_penalty: Option<f64>,
110
111 #[serde(skip_serializing_if = "Option::is_none")]
112 pub frequency_penalty: Option<f64>,
113
114 #[serde(skip_serializing_if = "Option::is_none")]
115 pub logit_bias: Option<HashMap<u64, f64>>,
116
117 #[serde(skip_serializing_if = "Option::is_none")]
118 pub user: Option<String>,
119}
120
121impl AiAgent {
122 pub fn build_request(&self, stream: bool) -> ChatCompletionRequest {
124 let messages = if let Some(system_message) = &self.system_message {
125 let mut messages = self.messages.clone();
126 messages.insert(0, system_message.clone());
127 messages
128 } else {
129 self.messages.clone()
130 };
131
132 ChatCompletionRequest {
133 model: self.model.clone(),
134 messages,
135 functions: self.functions.clone(),
136 function_call: self.function_call.clone(),
137 temperature: self.temperature,
138 top_p: self.top_p,
139 n: self.n,
140 stream: Some(stream),
141 stop: self.stop.clone(),
142 max_tokens: self.max_tokens,
143 presence_penalty: self.presence_penalty,
144 frequency_penalty: self.frequency_penalty,
145 logit_bias: self.logit_bias.clone(),
146 user: self.user.clone(),
147 }
148 }
149
150 pub async fn create(&self) -> UtilsResult<Chat> {
151 let api_key = OPENAI_API_KEY.read().expect("failed to get lock").clone().ok_or_else(|| InternalError::ConfigurationError("API key not set".to_string()))?;
152
153 trace!("request body: {}", to_string_pretty(&self.build_request(false)).unwrap());
154 let req = reqwest::Client::new()
155 .post("https://api.openai.com/v1/chat/completions")
156 .json(&self.build_request(false))
157 .bearer_auth(api_key)
158 .header("Content-Type", "application/json")
159 .send()
160 .await.map_err(|e| InternalError::RequestBuildError(e))?;
161
162 let res = req.text().await.map_err(|e| InternalError::RequestBuildError(e))?;
163 serialize(&res)
164 }
165
166 pub async fn create_stream(&self) -> UtilsResult<DeltaReceiver> {
167 let api_key = OPENAI_API_KEY.read()
168 .expect("failed to get lock")
169 .as_ref()
170 .ok_or_else(|| InternalError::ConfigurationError("API key not set".to_string()))?
171 .to_string();
172
173 let (tx, rx) = mpsc::channel(64);
174 trace!("request body: {}", to_string_pretty(&self.build_request(true)).unwrap());
175 let es = reqwest::Client::new()
176 .request(Method::POST, "https://api.openai.com/v1/chat/completions")
177 .json(&self.build_request(true))
178 .bearer_auth(api_key)
179 .header("Content-Type", "application/json")
180 .eventsource()
181 .expect("cannot create eventsource? shouldn't happen i think.");
182
183 tokio::spawn(async move {
184 if let Err(e) = forward_stream(es, tx).await {
185 error!("Error in forward_stream: {}", e);
186 }
187 });
188
189 let usage = self.build_request(true).messages.iter().fold(3, |acc, m| {
190 acc + calculate_message_tokens(m) + 4
191 });
192
193 Ok(DeltaReceiver::from(rx, self, usage))
194 }
195
196
197 pub fn new(model: impl Into<String>) -> Self {
200 Self {
201 model: model.into(),
202 system_message: None,
203 messages: vec![],
204 functions: None,
205 function_call: None,
206 temperature: None,
207 top_p: None,
208 n: None,
209 stop: None,
210 max_tokens: None,
211 presence_penalty: None,
212 frequency_penalty: None,
213 logit_bias: None,
214 user: None,
215 }
216 }
217
218 pub fn with_system_message<'a>(mut self, system_message: impl Into<&'a str>) -> Self {
219 self.system_message = Some(Message::new("system").with_content(system_message.into()));
220 self
221 }
222
223 pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
224 self.messages = messages;
225 self
226 }
227
228 pub fn with_function_call(mut self, function_call: impl Into<String>) -> Self {
229 self.function_call = Some(function_call.into());
230 self
231 }
232
233 pub fn with_temperature(mut self, temperature: f64) -> Self {
234 self.temperature = Some(temperature);
235 self
236 }
237
238 pub fn with_top_p(mut self, top_p: f64) -> Self {
239 self.top_p = Some(top_p);
240 self
241 }
242
243 pub fn with_n(mut self, n: u64) -> Self {
244 self.n = Some(n);
245 self
246 }
247
248 pub fn with_stop(mut self, stop: Vec<String>) -> Self {
249 self.stop = Some(stop);
250 self
251 }
252
253 pub fn with_max_tokens(mut self, max_tokens: u64) -> Self {
254 self.max_tokens = Some(max_tokens);
255 self
256 }
257
258 pub fn with_presence_penalty(mut self, presence_penalty: f64) -> Self {
259 self.presence_penalty = Some(presence_penalty);
260 self
261 }
262
263 pub fn with_frequency_penalty(mut self, frequency_penalty: f64) -> Self {
264 self.frequency_penalty = Some(frequency_penalty);
265 self
266 }
267
268 pub fn with_logit_bias(mut self, logit_bias: HashMap<u64, f64>) -> Self {
269 self.logit_bias = Some(logit_bias);
270 self
271 }
272
273 pub fn with_user(mut self, user: impl Into<String>) -> Self {
274 self.user = Some(user.into());
275 self
276 }
277
278 pub fn push_message(&mut self, message: Message) {
281 self.messages.push(message);
282 }
283
284 pub fn push_function<FunctionArgs, Func, T>(&mut self, function: &Func, function_name: &str)
285 where
286 FunctionArgs: JsonSchema,
287 Func: FnMut(FunctionArgs) -> T,
288 {
289 if let Some(functions) = &mut self.functions {
290 functions.push(Function::from(function, function_name));
291 } else {
292 self.functions = Some(vec![Function::from(function, function_name)]);
293 }
294 }
295
296 pub fn push_stop(&mut self, stop: impl Into<String>) {
297 if let Some(stops) = &mut self.stop {
298 stops.push(stop.into());
299 } else {
300 self.stop = Some(vec![stop.into()]);
301 }
302 }
303
304 pub fn push_logit_bias(&mut self, logit_bias: (u64, f64)) {
305 if let Some(logit_biases) = &mut self.logit_bias {
306 logit_biases.insert(logit_bias.0, logit_bias.1);
307 } else {
308 let mut logit_biases = HashMap::new();
309 logit_biases.insert(logit_bias.0, logit_bias.1);
310 self.logit_bias = Some(logit_biases);
311 }
312 }
313}
314
315pub fn serialize<'a, T: Deserialize<'a>>(res: &'a str) -> UtilsResult<T> {
316 match serde_json::from_str::<T>(res) {
317 Ok(chat) => Ok(chat),
318 Err(_) => {
319 #[derive(Deserialize)]
320 struct TempWrapper {
321 error: OpenAIError
322 }
323
324 let err =
325 serde_json::from_str::<TempWrapper>(res).unwrap_or_else(|_| panic!("{}", res));
326 Err(err.error.into())
327 }
328 }
329}