// chatgpt_functions/chat_gpt.rs

1use anyhow::{Context, Result};
2use uuid::Uuid;
3
4use crate::{
5    chat_context::ChatContext, chat_response::ChatResponse,
6    function_specification::FunctionSpecification, message::Message,
7};
8
/// Model used when the builder is not given one explicitly.
const DEFAULT_MODEL: &str = "gpt-3.5-turbo-0613";
/// OpenAI chat-completions endpoint that all requests are POSTed to.
const URL: &str = "https://api.openai.com/v1/chat/completions";
11
/// Builder for [`ChatGPT`].
///
/// All fields are optional at construction time; only the OpenAI API token
/// is mandatory when `build()` is called.
pub struct ChatGPTBuilder {
    // Model name; falls back to `DEFAULT_MODEL` when unset.
    model: Option<String>,
    // OpenAI API token; required — `build()` fails without it.
    openai_api_token: Option<String>,
    // Conversation identifier; a random UUID v4 is generated when unset.
    session_id: Option<String>,
    // Initial chat context; a fresh context for the chosen model when unset.
    chat_context: Option<ChatContext>,
}
19
20impl ChatGPTBuilder {
21    pub fn new() -> Self {
22        ChatGPTBuilder {
23            model: None,
24            openai_api_token: None,
25            session_id: None,
26            chat_context: None,
27        }
28    }
29
30    pub fn model(mut self, model: String) -> Self {
31        self.model = Some(model);
32        self
33    }
34
35    pub fn openai_api_token(mut self, openai_api_token: String) -> Self {
36        self.openai_api_token = Some(openai_api_token);
37        self
38    }
39
40    pub fn session_id(mut self, session_id: String) -> Self {
41        self.session_id = Some(session_id);
42        self
43    }
44
45    pub fn chat_context(mut self, chat_context: ChatContext) -> Self {
46        self.chat_context = Some(chat_context);
47        self
48    }
49
50    pub fn build(self) -> Result<ChatGPT> {
51        let client = reqwest::Client::new();
52        let model = if let Some(m) = self.model {
53            m
54        } else {
55            DEFAULT_MODEL.to_string()
56        };
57        let openai_api_token = self
58            .openai_api_token
59            .context("OpenAI API token is missing")?;
60        let session_id = if let Some(s) = self.session_id {
61            s
62        } else {
63            Uuid::new_v4().to_string()
64        };
65        let chat_context = if let Some(c) = self.chat_context {
66            c
67        } else {
68            let mut c = ChatContext::new(model.clone());
69            c.model = model.clone();
70            c
71        };
72
73        Ok(ChatGPT {
74            client,
75            model,
76            openai_api_token,
77            session_id,
78            chat_context,
79        })
80    }
81}
82
/// The ChatGPT object: an HTTP client plus the conversation state needed
/// to call the OpenAI chat-completions API.
pub struct ChatGPT {
    // Reusable HTTP client for all calls to the OpenAI endpoint.
    client: reqwest::Client,
    /// Model name used for completions (e.g. `gpt-3.5-turbo-0613`).
    pub model: String,
    // Bearer token for the OpenAI API; deliberately private.
    openai_api_token: String,
    /// Identifier for this conversation, useful to track history.
    pub session_id: String,
    /// Accumulated messages and functions; serialized as the request body.
    pub chat_context: ChatContext,
}
91
impl ChatGPT {
    /// Assembles a [`ChatGPT`] object directly from its parts.
    ///
    /// Prefer [`ChatGPTBuilder`], which generates sensible defaults for the
    /// model, session ID and context; this constructor stores the values as
    /// given.
    /// # Arguments
    /// * `client` - HTTP client used for every API call
    /// * `model` - The model name to use for completions
    /// * `openai_api_token` - The API token from OpenAI
    /// * `session_id` - The session ID of the chatbot. Useful to track the conversation history
    /// * `chat_context` - The context of the chatbot
    /// # Example
    /// ```
    /// use chatgpt_functions::chat_gpt::ChatGPTBuilder;
    /// use anyhow::Result;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<()> {
    ///     let key = std::env::var("OPENAI_API_KEY").unwrap_or("test".to_string());
    ///     let mut gpt = ChatGPTBuilder::new().openai_api_token(key).build()?;
    ///     Ok(())
    /// }
    /// ```
    /// # Errors
    /// Never fails at present; the `Result` return type is kept so the
    /// signature can stay stable if validation is added later.
    /// # Remarks
    /// The API token can be found on the [OpenAI API keys](https://platform.openai.com/account/api-keys)
    pub fn new(
        client: reqwest::Client,
        model: String,
        openai_api_token: String,
        session_id: String,
        chat_context: ChatContext,
    ) -> Result<ChatGPT> {
        Ok(ChatGPT {
            client,
            model,
            openai_api_token,
            session_id,
            chat_context,
        })
    }

    /// Calls the OpenAI API to get a completion for the current context.
    ///
    /// The serialized context is sent as the request body. This function does
    /// NOT modify the context; use the `*_updating_context` variants to also
    /// record messages and replies in it.
    /// # Errors
    /// It returns an error if the HTTP request fails, if the response body
    /// cannot be read, or if the body cannot be parsed as a `ChatResponse`
    pub async fn completion(&mut self) -> Result<ChatResponse> {
        let response = self
            .client
            .post(URL)
            .bearer_auth(&self.openai_api_token)
            .header("Content-Type", "application/json")
            // Use Display trait to avoid sending None fields that the API would reject
            .body(self.chat_context.to_string())
            .send()
            .await
            .context(format!("Failed to receive the response from {}", URL))?
            .text()
            .await
            .context("Failed to retrieve the content of the response")?;

        // The raw body may contain literal newlines; strip them before parsing.
        let answer = parse_removing_newlines(response)?;
        Ok(answer)
    }

    /// Calls the OpenAI API using the current context, adding the content
    /// provided by the user.
    /// This is the preferred function to use for chat completions that work with context.
    ///
    /// This is a fully managed function: it updates the context with the
    /// message provided AND with the response from the AI.
    /// It calls `completion_with_user_content_updating_context` internally; it exists for convenience.
    /// # Arguments
    /// * `content` - The content of the user message
    /// # Errors
    /// It returns an error if the request fails or if the response cannot be parsed
    pub async fn completion_managed(&mut self, content: String) -> Result<ChatResponse> {
        self.completion_with_user_content_updating_context(content)
            .await
    }

    /// Calls the OpenAI API using a [`Message`] already prepared by the caller.
    ///
    /// This gives more flexibility, but requires knowledge of the library's
    /// internals; prefer the higher-level helpers.
    /// # Arguments
    /// * `message` - The message to send to the AI
    /// # Errors
    /// It returns an error if the request fails or if the response cannot be parsed
    /// # Remarks
    /// The context IS updated with the message provided, but NOT with the
    /// response from the AI.
    /// This function is used by the other functions of the library; it is not
    /// recommended to use it directly.
    pub async fn completion_with_message(&mut self, message: Message) -> Result<ChatResponse> {
        self.push_message(message);
        self.completion().await
    }

    /// Calls the OpenAI API using a `String` as the content of a user message.
    /// # Arguments
    /// * `content` - The content of the message
    /// # Errors
    /// It returns an error if the request fails or if the response cannot be parsed
    /// # Remarks
    /// The context IS updated with the message provided, but NOT with the
    /// response from the AI.
    /// This function is used by the other functions of the library; it is not
    /// recommended to use it directly.
    pub async fn completion_with_user_content(&mut self, content: String) -> Result<ChatResponse> {
        let message = Message::new_user_message(content);
        self.completion_with_message(message).await
    }

    /// Calls the OpenAI API with `content` as a user message, updating the
    /// context with both the message and the AI's reply.
    /// # Arguments
    /// * `content` - The content of the message
    /// # Errors
    /// It returns an error if the request fails or if the response cannot be parsed
    /// # Remarks
    /// Only the last choice of the response is recorded in the context.
    pub async fn completion_with_user_content_updating_context(
        &mut self,
        content: String,
    ) -> Result<ChatResponse> {
        let message = Message::new_user_message(content);
        self.completion_with_message_updating_context(message).await
    }

    /// Calls the OpenAI API with a prepared [`Message`], updating the context
    /// with both the message and the AI's reply.
    /// # Arguments
    /// * `message` - The message to send to the AI
    /// # Errors
    /// It returns an error if the request fails or if the response cannot be parsed
    /// # Remarks
    /// Important: the message received from the AI has to be modified when it
    /// is a function call, because the model still labels it as an assistant
    /// message (a quirk of the API). If stored unmodified, the next request
    /// would not conform to the message rules of the model.
    /// https://platform.openai.com/docs/api-reference/chat/create#chat/create-messages
    ///
    /// Only the LAST choice of the response is pushed into the context; a
    /// response with no choices leaves the context untouched (no panic).
    pub async fn completion_with_message_updating_context(
        &mut self,
        message: Message,
    ) -> Result<ChatResponse> {
        self.push_message(message);
        let response = self.completion().await?;
        if let Some(choice) = response.choices.last() {
            self.push_message(choice.message.clone());
        };
        Ok(response)
    }

    /// Pushes a message onto the context.
    ///
    /// Low level; used by the other functions of the library and not
    /// recommended for direct use.
    /// # Arguments
    /// * `message` - The message to push to the context
    pub fn push_message(&mut self, message: Message) {
        self.chat_context.push_message(message);
    }

    /// Replaces ALL messages in the context with the given ones.
    ///
    /// Low level; used by the other functions of the library and not
    /// recommended for direct use.
    /// # Arguments
    /// * `messages` - The messages to set in the context
    pub fn set_messages(&mut self, messages: Vec<Message>) {
        self.chat_context.set_messages(messages);
    }

    /// Pushes a function specification onto the context.
    ///
    /// Low level; used by the other functions of the library and not
    /// recommended for direct use.
    /// # Arguments
    /// * `function` - The function to push to the context
    pub fn push_function(&mut self, function: FunctionSpecification) {
        self.chat_context.push_function(function);
    }

    /// Replaces ALL function specifications in the context with the given ones.
    ///
    /// Low level; used by the other functions of the library and not
    /// recommended for direct use.
    /// # Arguments
    /// * `functions` - The vec of functions to set in the context
    pub fn set_functions(&mut self, functions: Vec<FunctionSpecification>) {
        self.chat_context.set_functions(functions);
    }

    /// Returns the content of the last message in the context, if any.
    pub fn last_content(&self) -> Option<String> {
        self.chat_context.last_content()
    }

    /// Returns the `(name, arguments)` of the last message's function call,
    /// if any.
    pub fn last_function(&self) -> Option<(String, String)> {
        self.chat_context.last_function_call()
    }
}
330
331fn parse_removing_newlines(response: String) -> Result<ChatResponse> {
332    let r = response.replace("\n", "");
333    let response: ChatResponse = serde_json::from_str(&r).context(format!(
334        "Could not parse the response. The object to parse: \n{}",
335        r
336    ))?;
337    Ok(response)
338}
339
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    // `FunctionCall` is imported once here; the per-test re-imports that
    // used to shadow it were redundant and have been removed.
    use crate::{function_specification::Parameters, message::FunctionCall};

    use super::*;

    #[test]
    fn test_chat_gpt_new() {
        let chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("123".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        // A generated session ID is a hyphenated UUID v4: 36 characters.
        assert_eq!(chat_gpt.session_id.len(), 36);
        assert_eq!(chat_gpt.chat_context.model, DEFAULT_MODEL);
        assert_eq!(chat_gpt.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_chat_gpt_new_with_everything() {
        let chat_gpt = ChatGPTBuilder::new()
            .session_id("session_id".to_string())
            .model("model".to_string())
            .openai_api_token("1234".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        assert_eq!(chat_gpt.session_id, "session_id");
        assert_eq!(chat_gpt.openai_api_token, "1234");
        assert_eq!(chat_gpt.chat_context.model, "model");
    }

    #[test]
    fn test_chat_gpt_push_message() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let message = Message::new_user_message("content".to_string());
        chat_gpt.push_message(message);
        assert_eq!(chat_gpt.chat_context.messages.len(), 1);
    }

    #[test]
    fn test_chat_gpt_set_message() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let message = Message::new_user_message("content".to_string());
        chat_gpt.set_messages(vec![message]);
        assert_eq!(chat_gpt.chat_context.messages.len(), 1);
    }

    #[test]
    fn test_chat_gpt_push_function() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let function = FunctionSpecification::new("function".to_string(), None, None);
        chat_gpt.push_function(function);
        assert_eq!(chat_gpt.chat_context.functions.len(), 1);
    }

    #[test]
    fn test_chat_gpt_set_function() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let function = FunctionSpecification::new(
            "function".to_string(),
            Some("Test function".to_string()),
            Some(Parameters {
                type_: "string".to_string(),
                properties: HashMap::new(),
                required: vec![],
            }),
        );
        chat_gpt.set_functions(vec![function]);
        assert_eq!(chat_gpt.chat_context.functions.len(), 1);

        let function = chat_gpt
            .chat_context
            .functions
            .first()
            .expect("Failed to get the function");
        assert_eq!(function.name, "function");
        assert_eq!(
            function
                .description
                .as_ref()
                .expect("Failed to get the description"),
            "Test function"
        );
        assert_eq!(
            function
                .parameters
                .as_ref()
                .expect("Failed to get the parameters")
                .type_,
            "string"
        );
    }

    #[test]
    fn test_parse_removing_newlines() {
        let r = r#"{
    "id": "chatcmpl-7Ut7jsNlTUO9k9L5kBF0uDAyG19pK",
    "object": "chat.completion",
    "created": 1687596091,
    "model": "gpt-3.5-turbo-0613",
    "choices": [
        {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": null,
            "function_call": {
                "name": "get_current_weather",
                "arguments": "{\n  \"location\": \"Madrid, Spain\"\n}"
            }
        },
        "finish_reason": "function_call"
        }
    ],
    "usage": {
        "prompt_tokens": 90,
        "completion_tokens": 19,
        "total_tokens": 109
    }
}"#
        .to_string();
        let response = parse_removing_newlines(r).expect("Failed to parse");
        let message = response
            .choices
            .first()
            .expect("There is no choice")
            .message
            .clone();

        assert_eq!(message.role, "assistant");
        assert_eq!(message.content, None);
        assert_eq!(message.name, None);
        // Escaped newlines (`\n` inside the JSON string) must survive parsing.
        assert_eq!(
            message.function_call,
            Some(FunctionCall {
                name: "get_current_weather".to_string(),
                arguments: "{\n  \"location\": \"Madrid, Spain\"\n}".to_string(),
            })
        );
    }

    #[test]
    fn test_fix_context_when_function_replied_with_content() {
        let r = r#"{"id":"chatcmpl-7VneSVRn9qJ1crw3m0V0kmnCq8Pnn","object":"chat.completion","created":1687813384,"choices":[{"index":0,"message":{"role":"assistant","function_call":{"name":"completion_managed","arguments":"{
    \"content\": \"Hi, model!\"
}"}},"finish_reason":"function_call"}],"usage":{"prompt_tokens":61,"completion_tokens":18,"total_tokens":79}}"#.to_string();
        let response = parse_removing_newlines(r).expect("Failed to parse");
        let message = response
            .choices
            .last()
            .expect("There is no choice")
            .message
            .clone();

        assert_eq!(message.role, "assistant");
        assert_eq!(message.content, None);
        assert_eq!(message.name, None);
        // Raw newlines inside the payload are stripped before parsing.
        assert_eq!(
            message.function_call,
            Some(FunctionCall {
                name: "completion_managed".to_string(),
                arguments: "{    \"content\": \"Hi, model!\"}".to_string(),
            })
        );
    }

    #[test]
    fn test_last_content() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let message = Message::new_user_message("content".to_string());
        chat_gpt.push_message(message);
        let message = Message::new_user_message("content2".to_string());
        chat_gpt.push_message(message);
        let message = Message::new_user_message("content3".to_string());
        chat_gpt.push_message(message);
        assert_eq!(chat_gpt.last_content(), Some("content3".to_string()));
    }

    #[test]
    fn test_last_content_empty() {
        let chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        assert_eq!(chat_gpt.last_content(), None);
    }

    #[test]
    fn test_last_function() {
        let mut chat_gpt = ChatGPTBuilder::new()
            .openai_api_token("key".to_string())
            .build()
            .expect("Failed to create ChatGPT");
        let mut msg = Message::new("function".to_string());
        msg.set_function_call(FunctionCall {
            name: "function".to_string(),
            arguments: "1".to_string(),
        });
        chat_gpt.push_message(msg);
        let mut msg = Message::new("function2".to_string());
        msg.set_function_call(FunctionCall {
            name: "function2".to_string(),
            arguments: "2".to_string(),
        });
        chat_gpt.push_message(msg);
        let mut msg = Message::new("function3".to_string());
        msg.set_function_call(FunctionCall {
            name: "function3".to_string(),
            arguments: "3".to_string(),
        });
        chat_gpt.push_message(msg);
        assert_eq!(
            chat_gpt.last_function(),
            Some(("function3".to_string(), "3".to_string()))
        );
    }
}