chatgpt_functions/chat_context.rs

1use std::fmt;
2
3use serde::{Deserialize, Serialize};
4
5use crate::{function_specification::FunctionSpecification, message::Message};
6
7#[derive(Clone, Debug, Serialize, Deserialize)]
8pub struct ChatContext {
9    pub model: String,
10    pub messages: Vec<Message>,
11    pub functions: Vec<FunctionSpecification>,
12    pub function_call: Option<String>,
13}
14
15impl ChatContext {
16    /// Creates a new ChatContext with a model name
17    /// as a string. This is an internal function used by other functions.
18    pub fn new(model: String) -> ChatContext {
19        ChatContext {
20            model,
21            messages: Vec::new(),
22            functions: Vec::new(),
23            function_call: None,
24        }
25    }
26
27    /// Pushes a message in the chat context
28    /// as a Message. This is an internal function used by other functions.
29    /// It is recommended to use ChatGPT.push_message()
30    pub fn push_message(&mut self, message: Message) {
31        self.messages.push(message);
32    }
33
34    /// Sets the messages in the chat context
35    /// as a vector of Message.
36    /// This is an internal function used by other functions.
37    pub fn set_messages(&mut self, messages: Vec<Message>) {
38        self.messages = messages;
39    }
40
41    /// Pushes a function in the chat context
42    /// as a FunctionSpecification.
43    /// This is an internal function used by other functions.
44    /// It is recommended to use ChatGPT.push_function()
45    pub fn push_function(&mut self, functions: FunctionSpecification) {
46        self.functions.push(functions);
47    }
48
49    /// Sets the functions in the chat context
50    /// as a vector of FunctionSpecification.
51    /// This is an internal function used by other functions.
52    pub fn set_functions(&mut self, functions: Vec<FunctionSpecification>) {
53        self.functions = functions;
54    }
55
56    /// Sets the last message sent by the user or the bot
57    /// as a string. This is an internal function used by other functions.
58    pub fn set_function_call(&mut self, function_call: String) {
59        self.function_call = Some(function_call);
60    }
61
62    /// Returns the last message sent by the user or the bot
63    /// as a string. This is an internal function used by other functions.
64    /// It is recommended to use ChatGPT.last_content()
65    pub fn last_content(&self) -> Option<String> {
66        match self.messages.last() {
67            Some(message) => {
68                if let Some(c) = message.content.clone() {
69                    Some(c)
70                } else {
71                    None
72                }
73            }
74            None => None,
75        }
76    }
77
78    /// Returns the last function call in the chat context
79    /// as a tuple of the function name and the arguments.
80    /// This is an internal function used by other functions.
81    /// It is recommended to use ChatGPT.last_function_call()
82    pub fn last_function_call(&self) -> Option<(String, String)> {
83        match self.messages.last() {
84            Some(message) => {
85                if let Some(f) = message.function_call.clone() {
86                    Some((f.name, f.arguments))
87                } else {
88                    None
89                }
90            }
91            None => None,
92        }
93    }
94}
95
96// Print valid JSON for ChatContext, no commas if last field
97impl fmt::Display for ChatContext {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        write!(f, "{{\"model\":\"{}\"", self.model)?;
100        if !self.messages.is_empty() {
101            write!(f, ",\"messages\":[")?;
102            for (i, message) in self.messages.iter().enumerate() {
103                write!(f, "{}", message)?;
104                if i < self.messages.len() - 1 {
105                    write!(f, ",")?;
106                }
107            }
108            write!(f, "]")?;
109        }
110        if self.functions.len() > 0 {
111            write!(f, ",\"functions\":[")?;
112            for (i, function) in self.functions.iter().enumerate() {
113                write!(f, "{}", function)?;
114                if i < self.functions.len() - 1 {
115                    write!(f, ",")?;
116                }
117            }
118            write!(f, "]")?;
119        }
120        if let Some(function_call) = &self.function_call {
121            write!(f, ",\"function_call\":\"{}\"", function_call)?;
122        }
123        write!(f, "}}")
124    }
125}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{
        function_specification::{Parameters, Property},
        message::{FunctionCall, MessageBuilder},
    };

    #[test]
    fn test_display_for_empty_chat_context() {
        let chat_context = ChatContext::new("test_model".to_string());
        // Only the model is emitted when nothing else has been set.
        assert_eq!(chat_context.to_string(), "{\"model\":\"test_model\"}");
    }

    #[test]
    fn test_display_for_chat_context() {
        let mut chat_context = ChatContext::new("test_model".to_string());
        let message = MessageBuilder::new()
            .role("role".to_string())
            .content("Hello".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        let message = MessageBuilder::new()
            .role("bot".to_string())
            .content("Hi".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        assert_eq!(
            chat_context.to_string(),
            "{\"model\":\"test_model\",\"messages\":[{\"role\":\"role\",\"content\":\"Hello\"},{\"role\":\"bot\",\"content\":\"Hi\"}]}"
        );
    }

    #[test]
    fn test_display_chat_context_with_functions() {
        let mut chat_context = ChatContext::new("test_model".to_string());

        // Add a function to the chat context
        let mut properties = HashMap::new();
        properties.insert(
            "location".to_string(),
            Property {
                type_: "string".to_string(),
                description: Some("a dummy string".to_string()),
                enum_: None,
            },
        );
        let function = FunctionSpecification {
            name: "test_function".to_string(),
            description: Some("a dummy function to test the chat context".to_string()),
            parameters: Some(Parameters {
                type_: "object".to_string(),
                properties,
                required: vec!["location".to_string()],
            }),
        };
        chat_context.push_function(function);

        // Add a message to the chat context
        let message = MessageBuilder::new()
            .role("test".to_string())
            .content("hi".to_string())
            .name("test_function".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);

        // The model, the messages and the functions are printed; function_call
        // is unset and therefore omitted.
        assert_eq!(
            chat_context.to_string(),
            "{\"model\":\"test_model\",\"messages\":[{\"role\":\"test\",\"content\":\"hi\",\"name\":\"test_function\"}],\"functions\":[{\"name\":\"test_function\",\"description\":\"a dummy function to test the chat context\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"a dummy string\"}},\"required\":[\"location\"]}}]}"
        );
    }

    #[test]
    fn test_display_chat_context_with_function_call() {
        let mut chat_context = ChatContext::new("test_model".to_string());
        chat_context.set_function_call("auto".to_string());
        assert_eq!(
            chat_context.to_string(),
            "{\"model\":\"test_model\",\"function_call\":\"auto\"}"
        );
    }

    #[test]
    fn test_last_content() {
        let mut chat_context = ChatContext::new("model".to_string());

        // Test with no messages
        assert_eq!(chat_context.last_content(), None);

        // Test with a message with no content
        let message = MessageBuilder::new()
            .role("role".to_string())
            .name("name".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        assert_eq!(chat_context.last_content(), None);

        // Test with a message with content
        let message = MessageBuilder::new()
            .role("role".to_string())
            .content("content".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        assert_eq!(chat_context.last_content(), Some("content".to_string()));
    }

    #[test]
    fn test_last_function_call() {
        let mut chat_context = ChatContext::new("model".to_string());

        // Test with no messages
        assert_eq!(chat_context.last_function_call(), None);

        // Test with a message with no function call
        let message = MessageBuilder::new()
            .role("role".to_string())
            .name("name".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        assert_eq!(chat_context.last_function_call(), None);

        // Test with a message with function call
        let message = MessageBuilder::new()
            .role("role".to_string())
            .function_call(FunctionCall {
                name: "function".to_string(),
                arguments: "arguments".to_string(),
            })
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        assert_eq!(
            chat_context.last_function_call(),
            Some(("function".to_string(), "arguments".to_string()))
        );
    }
}