rust_gpt/
chat.rs

//! # Chat API
//!
//! The chat API is used to have a conversation with the GPT-3.5 model that powers ChatGPT.
//!
//! The main structs used here are [`ChatResponse`] and [`ChatMessage`].
//!
//! ## Chat
//! [`Chat`] is a new, experimental struct for having a conversation with the GPT-3.5 model.
//! It automatically remembers the recent messages you send and the replies the model sends,
//! so the model can follow the conversation.
//!
//! See the [`ChatBuilder`] and [`Chat`] structs for more information.
12use std::{collections::VecDeque, error::Error};
13use tokio::sync::Mutex;
14
15use serde::{Deserialize, Serialize};
16
17use crate::SendRequest;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20/// Represents one of the messages sent to or received from the chat API.
21pub struct ChatMessage {
22    pub role: Role,
23    pub content: Option<String>,
24}
25
26impl Default for ChatMessage {
27    fn default() -> Self {
28        Self {
29            role: Role::User,
30            content: Some(String::new()),
31        }
32    }
33}
34
35#[derive(Debug, Deserialize, Serialize)]
36/// Represents the usage information returned by the chat API.
37pub struct Usage {
38    pub prompt_tokens: u32,
39    pub completion_tokens: u32,
40    pub total_tokens: u32,
41}
42
43#[derive(Debug, Deserialize, Serialize)]
44/// Represents the choice object returned by the chat API.
45pub struct ChatChoice {
46    pub index: u32,
47    pub message: ChatMessage,
48    pub finish_reason: Option<String>,
49}
50
51#[derive(Debug, Deserialize, Serialize)]
52/// Represents a response from the chat API.
53pub struct ChatResponse {
54    pub id: String,
55    pub object: String,
56    pub created: u64,
57    pub choices: Vec<ChatChoice>,
58    pub usage: Usage,
59}
/// The author of a chat message.
///
/// Each role is serialized as its lowercase name: `"user"`, `"assistant"` or `"system"`.
// `Copy`/`PartialEq`/`Eq`/`Hash` let callers compare roles (e.g. filter a history by role)
// and use them as map keys; a fieldless enum has no reason not to provide them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    /// A message written by the user.
    User,
    /// A reply generated by the model.
    Assistant,
    /// The system message placed first in a conversation.
    System,
}
67
68impl Serialize for Role {
69    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70    where
71        S: serde::Serializer,
72    {
73        serializer.serialize_str(&self.to_string())
74    }
75}
76
77impl<'de> Deserialize<'de> for Role {
78    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79    where
80        D: serde::Deserializer<'de>,
81    {
82        let s = String::deserialize(deserializer)?;
83
84        Role::try_from(s.as_str()).map_err(serde::de::Error::custom)
85    }
86}
87
88impl ToString for Role {
89    fn to_string(&self) -> String {
90        match self {
91            Role::User => "user",
92            Role::Assistant => "assistant",
93            Role::System => "system",
94        }
95        .to_string()
96    }
97}
98
99impl TryFrom<&str> for Role {
100    type Error = Box<dyn Error>;
101
102    fn try_from(value: &str) -> Result<Self, Self::Error> {
103        match value {
104            "user" => Ok(Role::User),
105            "assistant" => Ok(Role::Assistant),
106            "system" => Ok(Role::System),
107            _ => Err("Invalid Role".into()),
108        }
109    }
110}
111
// ----------------------------------------------------
// Experimental (unstable) stateful chat session API.
114
115/// Builds a [`Chat`] struct for initiating a chat session.
116pub struct ChatBuilder {
117    system: ChatMessage,
118    chat_parameters: ChatParameters,
119    api_key: String,
120    model: crate::ChatModel,
121    len: usize,
122}
123
124impl ChatBuilder {
125    /// Creates a new [`ChatBuilder`] with the given model and API key.
126    pub fn new(model: crate::ChatModel, api_key: String) -> Self {
127        let default_msg = ChatMessage {
128            role: Role::System,
129            ..Default::default()
130        };
131
132        ChatBuilder {
133            model,
134            api_key,
135            system: default_msg,
136            chat_parameters: ChatParameters::default(),
137            len: 5,
138        }
139    }
140
141    /// Sets the amount of user messages that are stored in the chat session.
142    pub fn len(mut self, len: usize) -> Self {
143        self.len = len;
144        self
145    }
146
147    /// Sets the system message that is sent to the chat API
148    pub fn system(mut self, system: ChatMessage) -> Self {
149        self.system = system;
150        self
151    }
152
153    /// Sets the temperature
154    pub fn temperature(mut self, temperature: f32) -> Self {
155        self.chat_parameters.temperature = Some(temperature);
156        self
157    }
158
159    /// Sets the maximum amount of tokens that can be generated by the chat API.
160    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
161        self.chat_parameters.max_tokens = Some(max_tokens);
162        self
163    }
164
165    /// Sets the top_p parameter
166    pub fn top_p(mut self, top_p: f32) -> Self {
167        self.chat_parameters.top_p = Some(top_p);
168        self
169    }
170
171    /// Sets the presence penalty
172    pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
173        self.chat_parameters.presence_penalty = Some(presence_penalty);
174        self
175    }
176
177    /// Sets the frequency penalty
178    pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
179        self.chat_parameters.frequency_penalty = Some(frequency_penalty);
180        self
181    }
182
183    /// Sets the user
184    pub fn user(mut self, user: String) -> Self {
185        self.chat_parameters.user = Some(user);
186        self
187    }
188
189    /// Builds the [`Chat`] struct.
190    pub fn build(self) -> Chat {
191        Chat::new(
192            self.system,
193            self.model,
194            self.len,
195            self.api_key,
196            self.chat_parameters,
197        )
198    }
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
202#[doc(hidden)]
203#[derive(Default)]
204pub struct ChatParameters {
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub temperature: Option<f32>,
207    #[serde(skip_serializing_if = "Option::is_none")]
208    pub max_tokens: Option<u32>,
209    #[serde(skip_serializing_if = "Option::is_none")]
210    pub top_p: Option<f32>,
211    #[serde(skip_serializing_if = "Option::is_none")]
212    pub presence_penalty: Option<f32>,
213    #[serde(skip_serializing_if = "Option::is_none")]
214    pub frequency_penalty: Option<f32>,
215    #[serde(skip_serializing_if = "Option::is_none")]
216    pub user: Option<String>,
217}
218
219/// A struct that represents a chat session
220/// This struct makes it easy to interact with the chat api, as well as remembering messages.
221/// This struct guarantees that messages are sent and stored in the order that [`ask`] is called.
222///
223/// [`ask`]: #method.ask
224///
225/// Requests to the API are only sent when [`get_response`] is called.
226///
227/// [`get_response`]: #method.get_response
228///
229/// You can build a new chat session with [`ChatBuilder`].
230///
231/// [`ChatBuilder`]: ./struct.ChatBuilder.html
232pub struct Chat {
233    system: ChatMessage,
234    chat_parameters: ChatParameters,
235    api_key: String,
236    model: crate::ChatModel,
237    len: usize,
238    messages: Mutex<VecDeque<ChatMessage>>,
239    message_queue: Mutex<VecDeque<ChatMessage>>,
240}
241
242impl Chat {
243    fn new<T: ToString>(
244        system: ChatMessage,
245        model: crate::ChatModel,
246        len: usize,
247        api_key: T,
248        chat_parameters: ChatParameters,
249    ) -> Self {
250        Self {
251            system,
252            chat_parameters,
253            api_key: api_key.to_string(),
254            model,
255            len: len * 2 + 2,
256            messages: Mutex::new(VecDeque::new()),
257            message_queue: Mutex::new(VecDeque::new()),
258        }
259    }
260
261    /// Get the messages that have been sent and received including the system and assistan messages.
262    pub async fn get_messages(&self) -> Vec<ChatMessage> {
263        let mut messages = self.messages.lock().await.clone();
264
265        messages.push_front(self.system.clone());
266
267        messages.into()
268    }
269
270    /// Adds a message to the queue to be sent to the API.
271    pub async fn ask(&self, message: &str) -> Result<(), Box<dyn Error>> {
272        let msg = ChatMessage {
273            role: Role::User,
274            content: Some(message.to_string()),
275        };
276
277        self.message_queue.lock().await.push_back(msg);
278        Ok(())
279    }
280
281    /// Sends the message history to the API including the last question asked, and returns the response.
282    pub async fn get_response(&self, user: Option<String>) -> Result<ChatMessage, Box<dyn Error>> {
283
284        // the pushing and popping is in reverse order because we want to order the messages
285        // in the API from oldest to newest.
286
287        let msg = if let Some(message) = self.message_queue.lock().await.pop_front() {
288            message
289        } else {
290            return Err("No message to send".into());
291        };
292
293        let mut messages = self.messages.lock().await;
294
295        if messages.len() >= self.len {
296            messages.pop_front();
297            messages.pop_front();
298            // pop the oldest user + assistant message
299        }
300
301        messages.push_back(msg.clone());
302
303        let mut to_send = messages.clone();
304        to_send.push_front(self.system.clone());
305
306        let builder = crate::RequestBuilder::new(self.model.clone(), self.api_key.clone())
307            .messages(to_send.into())
308            .chat_parameters(self.chat_parameters.clone());
309
310        let builder = if let Some(user) = user {
311            builder.user(user)
312        } else {
313            builder
314        };
315
316        let req = builder.build_chat();
317
318        let resp = match req.send().await {
319            Ok(resp) => resp,
320            Err(e) => {
321                messages.pop_back(); // remove the message we just added
322                return Err(e.into());
323            }
324        };
325
326        let message = resp.choices[0].message.clone();
327
328        messages.push_back(message.clone());
329
330        Ok(message)
331    }
332}