openai_utils/chat_completion_request.rs

use crate::chat_completion_delta::forward_stream;
use crate::error::{InternalError, OpenAIError};
use crate::error::UtilsResult;
use crate::{calculate_message_tokens, DeltaReceiver};
use crate::{Chat, OPENAI_API_KEY};
use crate::{Function, Message};
use log::{error, trace};
use reqwest::Method;
use reqwest_eventsource::RequestBuilderExt;
use schemars::JsonSchema;
use serde::Deserialize;
use serde_json::to_string_pretty;
use std::collections::HashMap;
use tokio::sync::mpsc;

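/// Request body for OpenAI's `POST /v1/chat/completions` endpoint; optional
/// fields are omitted from the serialized JSON when unset.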
#[derive(Debug, Clone, serde_derive::Serialize, serde_derive::Deserialize)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<Message>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub functions: Option<Vec<Function>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<u64, f64>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl ChatCompletionRequest {
    fn new() -> Self {
        Self {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![],
            functions: None,
            function_call: None,
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            stop: None,
            max_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            user: None,
        }
    }
}

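/// A reusable chat-completion client: it holds the model name, an optional
/// system message, the conversation history, and sampling parameters, and
/// builds a [`ChatCompletionRequest`] from them for each call.
///
/// A minimal usage sketch. It assumes the crate's `OPENAI_API_KEY` has already
/// been configured elsewhere and that `Message::new("user")` constructs a
/// user-role message, mirroring the `"system"` role used by
/// [`with_system_message`](AiAgent::with_system_message):
///
/// ```ignore
/// let mut agent = AiAgent::new("gpt-3.5-turbo")
///     .with_system_message("You are a helpful assistant.")
///     .with_temperature(0.7);
/// agent.push_message(Message::new("user").with_content("Hello!"));
/// let chat = agent.create().await?;
/// ```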
#[derive(Debug, Clone, serde_derive::Serialize, serde_derive::Deserialize)]
pub struct AiAgent {
    pub model: String,

    pub system_message: Option<Message>,

    pub messages: Vec<Message>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub functions: Option<Vec<Function>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<u64, f64>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl AiAgent {
    // request part
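    /// Assembles the request body, prepending the system message (if any) to
    /// the conversation history and setting `stream` to the given value.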
    pub fn build_request(&self, stream: bool) -> ChatCompletionRequest {
        let messages = if let Some(system_message) = &self.system_message {
            let mut messages = self.messages.clone();
            messages.insert(0, system_message.clone());
            messages
        } else {
            self.messages.clone()
        };

        ChatCompletionRequest {
            model: self.model.clone(),
            messages,
            functions: self.functions.clone(),
            function_call: self.function_call.clone(),
            temperature: self.temperature,
            top_p: self.top_p,
            n: self.n,
            stream: Some(stream),
            stop: self.stop.clone(),
            max_tokens: self.max_tokens,
            presence_penalty: self.presence_penalty,
            frequency_penalty: self.frequency_penalty,
            logit_bias: self.logit_bias.clone(),
            user: self.user.clone(),
        }
    }

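    /// Sends a non-streaming completion request and parses the response into a
    /// [`Chat`], or into the API's error payload if OpenAI returned an error.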
    pub async fn create(&self) -> UtilsResult<Chat> {
        let api_key = OPENAI_API_KEY.read()
            .expect("failed to get lock")
            .as_ref()
            .ok_or_else(|| InternalError::ConfigurationError("API key not set".to_string()))?
            .to_string();

        let request = self.build_request(false);
        trace!("request body: {}", to_string_pretty(&request).unwrap());
        let response = reqwest::Client::new()
            .post("https://api.openai.com/v1/chat/completions")
            .json(&request)
            .bearer_auth(api_key)
            .header("Content-Type", "application/json")
            .send()
            .await
            .map_err(InternalError::RequestBuildError)?;

        let body = response.text().await.map_err(InternalError::RequestBuildError)?;
        serialize(&body)
    }

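    /// Sends a streaming completion request over server-sent events and returns
    /// a [`DeltaReceiver`] that yields deltas as they arrive, seeded with an
    /// estimate of the prompt's token usage.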
    pub async fn create_stream(&self) -> UtilsResult<DeltaReceiver> {
        let api_key = OPENAI_API_KEY.read()
            .expect("failed to get lock")
            .as_ref()
            .ok_or_else(|| InternalError::ConfigurationError("API key not set".to_string()))?
            .to_string();

        let request = self.build_request(true);
        let (tx, rx) = mpsc::channel(64);
        trace!("request body: {}", to_string_pretty(&request).unwrap());
        let es = reqwest::Client::new()
            .request(Method::POST, "https://api.openai.com/v1/chat/completions")
            .json(&request)
            .bearer_auth(api_key)
            .header("Content-Type", "application/json")
            .eventsource()
            .expect("failed to create EventSource from the request builder");

        tokio::spawn(async move {
            if let Err(e) = forward_stream(es, tx).await {
                error!("Error in forward_stream: {}", e);
            }
        });

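        // Estimate the prompt's token usage up front. The constants follow the
        // commonly cited OpenAI chat accounting (roughly 4 tokens of overhead
        // per message plus 3 tokens priming the reply); treat this as an
        // approximation rather than an exact count.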
        let usage = request.messages.iter().fold(3, |acc, m| {
            acc + calculate_message_tokens(m) + 4
        });

        Ok(DeltaReceiver::from(rx, self, usage))
    }

    // builder part

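    /// Creates an agent for the given model with no messages and every optional
    /// parameter unset; configure it further via the `with_*` builder methods.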
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            system_message: None,
            messages: vec![],
            functions: None,
            function_call: None,
            temperature: None,
            top_p: None,
            n: None,
            stop: None,
            max_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            user: None,
        }
    }

    pub fn with_system_message(mut self, system_message: impl AsRef<str>) -> Self {
        self.system_message = Some(Message::new("system").with_content(system_message.as_ref()));
        self
    }

    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
        self.messages = messages;
        self
    }

    pub fn with_function_call(mut self, function_call: impl Into<String>) -> Self {
        self.function_call = Some(function_call.into());
        self
    }

    pub fn with_temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    pub fn with_top_p(mut self, top_p: f64) -> Self {
        self.top_p = Some(top_p);
        self
    }

    pub fn with_n(mut self, n: u64) -> Self {
        self.n = Some(n);
        self
    }

    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    pub fn with_presence_penalty(mut self, presence_penalty: f64) -> Self {
        self.presence_penalty = Some(presence_penalty);
        self
    }

    pub fn with_frequency_penalty(mut self, frequency_penalty: f64) -> Self {
        self.frequency_penalty = Some(frequency_penalty);
        self
    }

    pub fn with_logit_bias(mut self, logit_bias: HashMap<u64, f64>) -> Self {
        self.logit_bias = Some(logit_bias);
        self
    }

    pub fn with_user(mut self, user: impl Into<String>) -> Self {
        self.user = Some(user.into());
        self
    }

    // mutably update part

    pub fn push_message(&mut self, message: Message) {
        self.messages.push(message);
    }

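    /// Registers a function the model may call; the JSON schema for its
    /// arguments is derived from `FunctionArgs`, which must implement
    /// `schemars::JsonSchema`.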
    pub fn push_function<FunctionArgs, Func, T>(&mut self, function: &Func, function_name: &str)
    where
        FunctionArgs: JsonSchema,
        Func: FnMut(FunctionArgs) -> T,
    {
        self.functions
            .get_or_insert_with(Vec::new)
            .push(Function::from(function, function_name));
    }

    pub fn push_stop(&mut self, stop: impl Into<String>) {
        self.stop.get_or_insert_with(Vec::new).push(stop.into());
    }

    pub fn push_logit_bias(&mut self, logit_bias: (u64, f64)) {
        self.logit_bias
            .get_or_insert_with(HashMap::new)
            .insert(logit_bias.0, logit_bias.1);
    }
}

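/// Despite its name, this parses an OpenAI response body: it first tries to
/// deserialize `res` into `T` and, on failure, falls back to the API's error
/// envelope (`{ "error": ... }`). Panics if the body matches neither shape.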
pub fn serialize<'a, T: Deserialize<'a>>(res: &'a str) -> UtilsResult<T> {
    match serde_json::from_str::<T>(res) {
        Ok(chat) => Ok(chat),
        Err(_) => {
            #[derive(Deserialize)]
            struct TempWrapper {
                error: OpenAIError,
            }

            let err =
                serde_json::from_str::<TempWrapper>(res).unwrap_or_else(|_| panic!("{}", res));
            Err(err.error.into())
        }
    }
}