chatgpt_functions/
chat_context.rs1use std::fmt;
2
3use serde::{Deserialize, Serialize};
4
5use crate::{function_specification::FunctionSpecification, message::Message};
6
7#[derive(Clone, Debug, Serialize, Deserialize)]
8pub struct ChatContext {
9 pub model: String,
10 pub messages: Vec<Message>,
11 pub functions: Vec<FunctionSpecification>,
12 pub function_call: Option<String>,
13}
14
15impl ChatContext {
16 pub fn new(model: String) -> ChatContext {
19 ChatContext {
20 model,
21 messages: Vec::new(),
22 functions: Vec::new(),
23 function_call: None,
24 }
25 }
26
27 pub fn push_message(&mut self, message: Message) {
31 self.messages.push(message);
32 }
33
34 pub fn set_messages(&mut self, messages: Vec<Message>) {
38 self.messages = messages;
39 }
40
41 pub fn push_function(&mut self, functions: FunctionSpecification) {
46 self.functions.push(functions);
47 }
48
49 pub fn set_functions(&mut self, functions: Vec<FunctionSpecification>) {
53 self.functions = functions;
54 }
55
56 pub fn set_function_call(&mut self, function_call: String) {
59 self.function_call = Some(function_call);
60 }
61
62 pub fn last_content(&self) -> Option<String> {
66 match self.messages.last() {
67 Some(message) => {
68 if let Some(c) = message.content.clone() {
69 Some(c)
70 } else {
71 None
72 }
73 }
74 None => None,
75 }
76 }
77
78 pub fn last_function_call(&self) -> Option<(String, String)> {
83 match self.messages.last() {
84 Some(message) => {
85 if let Some(f) = message.function_call.clone() {
86 Some((f.name, f.arguments))
87 } else {
88 None
89 }
90 }
91 None => None,
92 }
93 }
94}
95
96impl fmt::Display for ChatContext {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 write!(f, "{{\"model\":\"{}\"", self.model)?;
100 if !self.messages.is_empty() {
101 write!(f, ",\"messages\":[")?;
102 for (i, message) in self.messages.iter().enumerate() {
103 write!(f, "{}", message)?;
104 if i < self.messages.len() - 1 {
105 write!(f, ",")?;
106 }
107 }
108 write!(f, "]")?;
109 }
110 if self.functions.len() > 0 {
111 write!(f, ",\"functions\":[")?;
112 for (i, function) in self.functions.iter().enumerate() {
113 write!(f, "{}", function)?;
114 if i < self.functions.len() - 1 {
115 write!(f, ",")?;
116 }
117 }
118 write!(f, "]")?;
119 }
120 if let Some(function_call) = &self.function_call {
121 write!(f, ",\"function_call\":\"{}\"", function_call)?;
122 }
123 write!(f, "}}")
124 }
125}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{
        function_specification::{Parameters, Property},
        message::{FunctionCall, MessageBuilder},
    };

    #[test]
    fn test_display_for_chat_context() {
        let mut chat_context = ChatContext::new("test_model".to_string());
        let message = MessageBuilder::new()
            .role("role".to_string())
            .content("Hello".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        let message = MessageBuilder::new()
            .role("bot".to_string())
            .content("Hi".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        assert_eq!(
            chat_context.to_string(),
            "{\"model\":\"test_model\",\"messages\":[{\"role\":\"role\",\"content\":\"Hello\"},{\"role\":\"bot\",\"content\":\"Hi\"}]}"
        );
    }

    #[test]
    fn test_display_chat_context_with_functions() {
        let mut chat_context = ChatContext::new("test_model".to_string());

        let mut properties = HashMap::new();
        properties.insert(
            "location".to_string(),
            Property {
                type_: "string".to_string(),
                description: Some("a dummy string".to_string()),
                enum_: None,
            },
        );
        let function = FunctionSpecification {
            name: "test_function".to_string(),
            description: Some("a dummy function to test the chat context".to_string()),
            parameters: Some(Parameters {
                type_: "object".to_string(),
                properties,
                required: vec!["location".to_string()],
            }),
        };
        chat_context.push_function(function);

        let message = MessageBuilder::new()
            .role("test".to_string())
            .content("hi".to_string())
            .name("test_function".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);

        assert_eq!(
            chat_context.to_string(),
            "{\"model\":\"test_model\",\"messages\":[{\"role\":\"test\",\"content\":\"hi\",\"name\":\"test_function\"}],\"functions\":[{\"name\":\"test_function\",\"description\":\"a dummy function to test the chat context\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"a dummy string\"}},\"required\":[\"location\"]}}]}"
        );
    }

    #[test]
    fn test_last_content() {
        let mut chat_context = ChatContext::new("model".to_string());

        // No messages yet.
        assert_eq!(chat_context.last_content(), None);

        // The most recent message has no content.
        let message = MessageBuilder::new()
            .role("role".to_string())
            .name("name".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        assert_eq!(chat_context.last_content(), None);

        // The most recent message has content.
        let message = MessageBuilder::new()
            .role("role".to_string())
            .content("content".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        assert_eq!(chat_context.last_content(), Some("content".to_string()));
    }

    #[test]
    fn test_last_function_call() {
        let mut chat_context = ChatContext::new("model".to_string());

        // No messages yet.
        assert_eq!(chat_context.last_function_call(), None);

        // The most recent message carries no function call.
        let message = MessageBuilder::new()
            .role("role".to_string())
            .name("name".to_string())
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        assert_eq!(chat_context.last_function_call(), None);

        // The most recent message carries a function call.
        let message = MessageBuilder::new()
            .role("role".to_string())
            .function_call(FunctionCall {
                name: "function".to_string(),
                arguments: "arguments".to_string(),
            })
            .build()
            .expect("Failed to build message");
        chat_context.push_message(message);
        assert_eq!(
            chat_context.last_function_call(),
            Some(("function".to_string(), "arguments".to_string()))
        );
    }
}