1use anyhow::Result;
2use std::fmt;
3
4use serde::{Deserialize, Serialize};
5
6use crate::escape_json::EscapeJson;
7
8pub struct MessageBuilder {
10 role: Option<String>,
11 content: Option<String>,
12 name: Option<String>,
13 function_call: Option<FunctionCall>,
14}
15
16impl MessageBuilder {
17 pub fn new() -> MessageBuilder {
18 MessageBuilder {
19 role: None,
20 content: None,
21 name: None,
22 function_call: None,
23 }
24 }
25
26 pub fn role(mut self, role: String) -> MessageBuilder {
27 self.role = Some(role);
28 self
29 }
30
31 pub fn content(mut self, content: String) -> MessageBuilder {
32 self.content = Some(content);
33 self
34 }
35
36 pub fn name(mut self, name: String) -> MessageBuilder {
37 self.name = Some(name);
38 self
39 }
40
41 pub fn function_call(mut self, function_call: FunctionCall) -> MessageBuilder {
42 self.function_call = Some(function_call);
43 self
44 }
45
46 pub fn build(self) -> Result<Message> {
47 let role = self.role.unwrap_or_else(|| "user".to_string());
48 let content = self.content.map(|c| c.escape_json());
49 let name = self.name;
50 let function_call = self.function_call;
51
52 Ok(Message {
53 role,
54 content,
55 name,
56 function_call,
57 })
58 }
59}
60
61#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
62pub struct Message {
63 pub role: String,
64 pub content: Option<String>,
65 pub name: Option<String>,
66 pub function_call: Option<FunctionCall>,
67}
68
69#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
70pub struct FunctionCall {
71 pub name: String,
72 pub arguments: String,
73}
74
75impl Message {
76 pub fn new(role: String) -> Message {
77 Message {
78 role,
79 content: None,
80 name: None,
81 function_call: None,
82 }
83 }
84
85 pub fn new_user_message(content: String) -> Message {
86 let content = content.escape_json();
87 Message {
88 role: "user".to_string(),
89 content: Some(content),
90 name: None,
91 function_call: None,
92 }
93 }
94
95 pub fn set_content(&mut self, content: String) {
96 self.content = Some(content);
97 }
98
99 pub fn set_name(&mut self, name: String) {
100 self.name = Some(name);
101 }
102
103 pub fn set_function_call(&mut self, function_call: FunctionCall) {
104 self.function_call = Some(function_call);
105 }
106}
107
108impl fmt::Display for Message {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 write!(f, "{{\"role\":\"{}\"", self.role)?;
150 if let Some(content) = &self.content {
151 write!(f, ",\"content\":\"{}\"", content.escape_json())?;
152 } else {
153 write!(f, ",\"content\":\"\"")?;
154 }
155 if let Some(name) = &self.name {
156 write!(f, ",\"name\":\"{}\"", name)?;
157 }
158 if let Some(function_call) = &self.function_call {
159 write!(f, ",\"function_call\":{}", function_call)?;
160 }
161 write!(f, "}}")
162 }
163}
164
165impl fmt::Display for FunctionCall {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 write!(
171 f,
172 "{{\"name\":\"{}\",\"arguments\":\"{}\"}}",
173 self.name,
174 self.arguments.escape_json()
175 )
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    // Expected outputs are written as raw string literals so they read as the
    // JSON text they represent; inputs containing control characters keep
    // ordinary escaped literals.

    #[test]
    fn test_display_message() {
        let mut message = Message::new("role".to_string());
        assert_eq!(message.to_string(), r#"{"role":"role","content":""}"#);

        message.set_content(
            "content with \"quotes\" and a \nnewline, and other stuff like \\ \"\n\r\t\x08\x0C\""
                .to_string(),
        );
        assert_eq!(
            message.to_string(),
            r#"{"role":"role","content":"content with \"quotes\" and a \nnewline, and other stuff like \\ \"\n\r\t\b\f\""}"#
        );

        message.set_name("name".to_string());
        assert_eq!(
            message.to_string(),
            r#"{"role":"role","content":"content with \"quotes\" and a \nnewline, and other stuff like \\ \"\n\r\t\b\f\"","name":"name"}"#
        );

        message.set_function_call(FunctionCall {
            name: "name".to_string(),
            arguments: r#"{"example":"this"}"#.to_string(),
        });
        assert_eq!(
            message.to_string(),
            r#"{"role":"role","content":"content with \"quotes\" and a \nnewline, and other stuff like \\ \"\n\r\t\b\f\"","name":"name","function_call":{"name":"name","arguments":"{\"example\":\"this\"}"}}"#
        );
    }

    #[test]
    fn test_display_function_call_no_name() {
        let call = FunctionCall {
            name: String::new(),
            arguments: r#"{"example":"this"}"#.to_string(),
        };
        assert_eq!(
            call.to_string(),
            r#"{"name":"","arguments":"{\"example\":\"this\"}"}"#
        );
    }

    #[test]
    fn test_display_function_call_no_arguments() {
        let call = FunctionCall {
            name: "name".to_string(),
            arguments: "{}".to_string(),
        };
        assert_eq!(call.to_string(), r#"{"name":"name","arguments":"{}"}"#);
    }

    #[test]
    fn test_display_function_call() {
        let call = FunctionCall {
            name: "name".to_string(),
            arguments: r#"{"example":"this"}"#.to_string(),
        };
        assert_eq!(
            call.to_string(),
            r#"{"name":"name","arguments":"{\"example\":\"this\"}"}"#
        );
    }

    #[test]
    fn test_display_message_parsed_from_json_remove_newline() {
        let raw_json = r#"{
            "role": "assistant",
            "content": null,
            "function_call": {
                "name": "completion_managed",
                "arguments": "{\n \"content\": \"Hi model, how are you today?\"\n}"
            }
        }"#;
        let parsed: Message = serde_json::from_str(raw_json).expect("JSON was not well-formatted");

        assert_eq!(parsed.role, "assistant");
        assert_eq!(parsed.content, None);

        assert_eq!(
            parsed.to_string(),
            r#"{"role":"assistant","content":"","function_call":{"name":"completion_managed","arguments":"{\n \"content\": \"Hi model, how are you today?\"\n}"}}"#
        );

        let expected_call = FunctionCall {
            name: "completion_managed".to_string(),
            arguments: "{\n \"content\": \"Hi model, how are you today?\"\n}".to_string(),
        };
        assert_eq!(parsed.function_call, Some(expected_call));
    }

    #[test]
    fn test_message_new_user_message() {
        let message =
            Message::new_user_message("content with \"quotes\" and other' stuff \\".to_string());
        assert_eq!(
            message.to_string(),
            r#"{"role":"user","content":"content with \\\"quotes\\\" and other' stuff \\\\"}"#
        );
    }

    #[test]
    fn test_message_builder() {
        let call = FunctionCall {
            name: "name".to_string(),
            arguments: r#"{"example":"this"}"#.to_string(),
        };
        let message = MessageBuilder::new()
            .content("content with \"quotes\" and other/' stuff \\".to_string())
            .name("name".to_string())
            .role("role".to_string())
            .function_call(call)
            .build()
            .expect("MessageBuilder failed");

        assert_eq!(
            message.to_string(),
            r#"{"role":"role","content":"content with \\\"quotes\\\" and other/' stuff \\\\","name":"name","function_call":{"name":"name","arguments":"{\"example\":\"this\"}"}}"#
        );
    }
}