1use std::{collections::VecDeque, error::Error};
13use tokio::sync::Mutex;
14
15use serde::{Deserialize, Serialize};
16
17use crate::SendRequest;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ChatMessage {
22 pub role: Role,
23 pub content: Option<String>,
24}
25
26impl Default for ChatMessage {
27 fn default() -> Self {
28 Self {
29 role: Role::User,
30 content: Some(String::new()),
31 }
32 }
33}
34
35#[derive(Debug, Deserialize, Serialize)]
36pub struct Usage {
38 pub prompt_tokens: u32,
39 pub completion_tokens: u32,
40 pub total_tokens: u32,
41}
42
43#[derive(Debug, Deserialize, Serialize)]
44pub struct ChatChoice {
46 pub index: u32,
47 pub message: ChatMessage,
48 pub finish_reason: Option<String>,
49}
50
51#[derive(Debug, Deserialize, Serialize)]
52pub struct ChatResponse {
54 pub id: String,
55 pub object: String,
56 pub created: u64,
57 pub choices: Vec<ChatChoice>,
58 pub usage: Usage,
59}
/// Author of a [`ChatMessage`].
///
/// Serialized as the lowercase strings `"user"`, `"assistant"` and `"system"`.
/// Equality and hashing are derived so roles can be compared (e.g. filtering history by
/// author) and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Role {
    /// A message written by the end user.
    User,
    /// A message produced by the model.
    Assistant,
    /// A system/instruction message.
    System,
}
67
68impl Serialize for Role {
69 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70 where
71 S: serde::Serializer,
72 {
73 serializer.serialize_str(&self.to_string())
74 }
75}
76
77impl<'de> Deserialize<'de> for Role {
78 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79 where
80 D: serde::Deserializer<'de>,
81 {
82 let s = String::deserialize(deserializer)?;
83
84 Role::try_from(s.as_str()).map_err(serde::de::Error::custom)
85 }
86}
87
88impl ToString for Role {
89 fn to_string(&self) -> String {
90 match self {
91 Role::User => "user",
92 Role::Assistant => "assistant",
93 Role::System => "system",
94 }
95 .to_string()
96 }
97}
98
99impl TryFrom<&str> for Role {
100 type Error = Box<dyn Error>;
101
102 fn try_from(value: &str) -> Result<Self, Self::Error> {
103 match value {
104 "user" => Ok(Role::User),
105 "assistant" => Ok(Role::Assistant),
106 "system" => Ok(Role::System),
107 _ => Err("Invalid Role".into()),
108 }
109 }
110}
111
112pub struct ChatBuilder {
117 system: ChatMessage,
118 chat_parameters: ChatParameters,
119 api_key: String,
120 model: crate::ChatModel,
121 len: usize,
122}
123
124impl ChatBuilder {
125 pub fn new(model: crate::ChatModel, api_key: String) -> Self {
127 let default_msg = ChatMessage {
128 role: Role::System,
129 ..Default::default()
130 };
131
132 ChatBuilder {
133 model,
134 api_key,
135 system: default_msg,
136 chat_parameters: ChatParameters::default(),
137 len: 5,
138 }
139 }
140
141 pub fn len(mut self, len: usize) -> Self {
143 self.len = len;
144 self
145 }
146
147 pub fn system(mut self, system: ChatMessage) -> Self {
149 self.system = system;
150 self
151 }
152
153 pub fn temperature(mut self, temperature: f32) -> Self {
155 self.chat_parameters.temperature = Some(temperature);
156 self
157 }
158
159 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
161 self.chat_parameters.max_tokens = Some(max_tokens);
162 self
163 }
164
165 pub fn top_p(mut self, top_p: f32) -> Self {
167 self.chat_parameters.top_p = Some(top_p);
168 self
169 }
170
171 pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
173 self.chat_parameters.presence_penalty = Some(presence_penalty);
174 self
175 }
176
177 pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
179 self.chat_parameters.frequency_penalty = Some(frequency_penalty);
180 self
181 }
182
183 pub fn user(mut self, user: String) -> Self {
185 self.chat_parameters.user = Some(user);
186 self
187 }
188
189 pub fn build(self) -> Chat {
191 Chat::new(
192 self.system,
193 self.model,
194 self.len,
195 self.api_key,
196 self.chat_parameters,
197 )
198 }
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
202#[doc(hidden)]
203#[derive(Default)]
204pub struct ChatParameters {
205 #[serde(skip_serializing_if = "Option::is_none")]
206 pub temperature: Option<f32>,
207 #[serde(skip_serializing_if = "Option::is_none")]
208 pub max_tokens: Option<u32>,
209 #[serde(skip_serializing_if = "Option::is_none")]
210 pub top_p: Option<f32>,
211 #[serde(skip_serializing_if = "Option::is_none")]
212 pub presence_penalty: Option<f32>,
213 #[serde(skip_serializing_if = "Option::is_none")]
214 pub frequency_penalty: Option<f32>,
215 #[serde(skip_serializing_if = "Option::is_none")]
216 pub user: Option<String>,
217}
218
219pub struct Chat {
233 system: ChatMessage,
234 chat_parameters: ChatParameters,
235 api_key: String,
236 model: crate::ChatModel,
237 len: usize,
238 messages: Mutex<VecDeque<ChatMessage>>,
239 message_queue: Mutex<VecDeque<ChatMessage>>,
240}
241
242impl Chat {
243 fn new<T: ToString>(
244 system: ChatMessage,
245 model: crate::ChatModel,
246 len: usize,
247 api_key: T,
248 chat_parameters: ChatParameters,
249 ) -> Self {
250 Self {
251 system,
252 chat_parameters,
253 api_key: api_key.to_string(),
254 model,
255 len: len * 2 + 2,
256 messages: Mutex::new(VecDeque::new()),
257 message_queue: Mutex::new(VecDeque::new()),
258 }
259 }
260
261 pub async fn get_messages(&self) -> Vec<ChatMessage> {
263 let mut messages = self.messages.lock().await.clone();
264
265 messages.push_front(self.system.clone());
266
267 messages.into()
268 }
269
270 pub async fn ask(&self, message: &str) -> Result<(), Box<dyn Error>> {
272 let msg = ChatMessage {
273 role: Role::User,
274 content: Some(message.to_string()),
275 };
276
277 self.message_queue.lock().await.push_back(msg);
278 Ok(())
279 }
280
281 pub async fn get_response(&self, user: Option<String>) -> Result<ChatMessage, Box<dyn Error>> {
283
284 let msg = if let Some(message) = self.message_queue.lock().await.pop_front() {
288 message
289 } else {
290 return Err("No message to send".into());
291 };
292
293 let mut messages = self.messages.lock().await;
294
295 if messages.len() >= self.len {
296 messages.pop_front();
297 messages.pop_front();
298 }
300
301 messages.push_back(msg.clone());
302
303 let mut to_send = messages.clone();
304 to_send.push_front(self.system.clone());
305
306 let builder = crate::RequestBuilder::new(self.model.clone(), self.api_key.clone())
307 .messages(to_send.into())
308 .chat_parameters(self.chat_parameters.clone());
309
310 let builder = if let Some(user) = user {
311 builder.user(user)
312 } else {
313 builder
314 };
315
316 let req = builder.build_chat();
317
318 let resp = match req.send().await {
319 Ok(resp) => resp,
320 Err(e) => {
321 messages.pop_back(); return Err(e.into());
323 }
324 };
325
326 let message = resp.choices[0].message.clone();
327
328 messages.push_back(message.clone());
329
330 Ok(message)
331 }
332}