1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
10#[serde(tag = "type")]
11pub enum ContentPart {
12 #[serde(rename = "text")]
14 Text { text: String },
15 #[serde(rename = "image")]
17 Image {
18 media_type: String,
20 data: String,
22 },
23}
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
27#[serde(rename_all = "snake_case")]
28pub enum Role {
29 User,
30 Assistant,
31 System,
32 Tool,
33}
34
35impl std::fmt::Display for Role {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 match self {
38 Self::User => write!(f, "user"),
39 Self::Assistant => write!(f, "assistant"),
40 Self::System => write!(f, "system"),
41 Self::Tool => write!(f, "tool"),
42 }
43 }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct Message {
49 pub role: Role,
51 pub content: String,
53 #[serde(default, skip_serializing_if = "Vec::is_empty")]
55 pub tool_calls: Vec<ToolCall>,
56 #[serde(default, skip_serializing_if = "Vec::is_empty")]
58 pub tool_results: Vec<ToolCallResult>,
59 pub timestamp: DateTime<Utc>,
61 #[serde(default, skip_serializing_if = "Vec::is_empty")]
64 pub content_parts: Vec<ContentPart>,
65}
66
67impl Message {
68 pub fn new(role: Role, content: impl Into<String>) -> Self {
70 Self {
71 role,
72 content: content.into(),
73 tool_calls: Vec::new(),
74 tool_results: Vec::new(),
75 timestamp: Utc::now(),
76 content_parts: Vec::new(),
77 }
78 }
79
80 pub fn with_parts(role: Role, content: impl Into<String>, parts: Vec<ContentPart>) -> Self {
82 Self {
83 role,
84 content: content.into(),
85 tool_calls: Vec::new(),
86 tool_results: Vec::new(),
87 timestamp: Utc::now(),
88 content_parts: parts,
89 }
90 }
91
92 pub fn has_images(&self) -> bool {
94 self.content_parts
95 .iter()
96 .any(|p| matches!(p, ContentPart::Image { .. }))
97 }
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ToolCall {
103 pub id: String,
105 pub name: String,
107 pub input: serde_json::Value,
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct ToolCallResult {
114 pub id: String,
116 pub content: String,
118 #[serde(default)]
120 pub is_error: bool,
121 #[serde(default, skip_serializing_if = "Option::is_none")]
125 pub image: Option<ContentPart>,
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_ROLES: [Role; 4] = [Role::User, Role::Assistant, Role::System, Role::Tool];

    #[test]
    fn test_role_display() {
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::Tool.to_string(), "tool");
    }

    #[test]
    fn test_role_serde_roundtrip() {
        // Iterate a const array rather than allocating a `Vec` (clippy::useless_vec).
        for role in &ALL_ROLES {
            let json = serde_json::to_string(role).expect("serialize");
            let deser: Role = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(&deser, role);
        }
    }

    #[test]
    fn test_role_serde_values() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&Role::Tool).unwrap(), "\"tool\"");
    }

    #[test]
    fn test_role_display_matches_serde() {
        // The Display impl and the serde representation are maintained separately;
        // guard against them drifting apart.
        for role in &ALL_ROLES {
            let json = serde_json::to_string(role).expect("serialize");
            assert_eq!(json, format!("\"{}\"", role));
        }
    }

    #[test]
    fn test_message_new() {
        let msg = Message::new(Role::User, "Hello world");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "Hello world");
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_results.is_empty());
        assert!(msg.content_parts.is_empty());
    }

    #[test]
    fn test_message_new_empty_content() {
        let msg = Message::new(Role::Assistant, "");
        assert_eq!(msg.content, "");
    }

    #[test]
    fn test_message_serde_roundtrip() {
        let msg = Message::new(Role::User, "test message");
        let json = serde_json::to_string(&msg).expect("serialize");
        // The empty collections are skipped on output, so this also exercises
        // their `#[serde(default)]` on input.
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.role, Role::User);
        assert_eq!(deser.content, "test message");
        assert_eq!(deser.timestamp, msg.timestamp);
        assert!(deser.tool_calls.is_empty());
        assert!(deser.tool_results.is_empty());
        assert!(deser.content_parts.is_empty());
    }

    #[test]
    fn test_message_serde_skips_empty_vecs() {
        let msg = Message::new(Role::User, "hi");
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(!json.contains("tool_calls"));
        assert!(!json.contains("tool_results"));
    }

    #[test]
    fn test_tool_call_serde() {
        let call = ToolCall {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            input: serde_json::json!({"path": "/tmp/test.txt"}),
        };
        let json = serde_json::to_string(&call).expect("serialize");
        let deser: ToolCall = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.id, "call_123");
        assert_eq!(deser.name, "read_file");
        assert_eq!(deser.input["path"], "/tmp/test.txt");
    }

    #[test]
    fn test_tool_call_result_serde() {
        let result = ToolCallResult {
            id: "call_123".to_string(),
            content: "file contents here".to_string(),
            is_error: false,
            image: None,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let deser: ToolCallResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.id, "call_123");
        assert_eq!(deser.content, "file contents here");
        assert!(!deser.is_error);
        assert!(deser.image.is_none());
    }

    #[test]
    fn test_tool_call_result_error() {
        let result = ToolCallResult {
            id: "call_456".to_string(),
            content: "Permission denied".to_string(),
            is_error: true,
            image: None,
        };
        // Round-trip so the test exercises serialization of `is_error` rather
        // than merely re-reading the value assigned above.
        let json = serde_json::to_string(&result).expect("serialize");
        assert!(json.contains("\"is_error\":true"));
        let deser: ToolCallResult = serde_json::from_str(&json).expect("deserialize");
        assert!(deser.is_error);
        assert_eq!(deser.content, "Permission denied");
    }

    #[test]
    fn test_tool_call_result_is_error_default() {
        let json = r#"{"id": "x", "content": "ok"}"#;
        let result: ToolCallResult = serde_json::from_str(json).expect("deserialize");
        assert!(!result.is_error);
        assert!(result.image.is_none());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let mut msg = Message::new(Role::Assistant, "Let me check that file");
        msg.tool_calls.push(ToolCall {
            id: "tc1".to_string(),
            name: "read_file".to_string(),
            input: serde_json::json!({"path": "main.rs"}),
        });
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(json.contains("tool_calls"));
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.tool_calls.len(), 1);
        assert_eq!(deser.tool_calls[0].name, "read_file");
    }

    #[test]
    fn test_role_equality() {
        assert_eq!(Role::User, Role::User);
        assert_ne!(Role::User, Role::Assistant);
    }

    #[test]
    fn test_role_hash() {
        let mut set = std::collections::HashSet::new();
        set.insert(Role::User);
        set.insert(Role::Assistant);
        set.insert(Role::User);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_content_part_text_serde() {
        let part = ContentPart::Text {
            text: "hello".to_string(),
        };
        let json = serde_json::to_string(&part).expect("serialize");
        assert!(json.contains("\"type\":\"text\""));
        let deser: ContentPart = serde_json::from_str(&json).expect("deserialize");
        match deser {
            ContentPart::Text { text } => assert_eq!(text, "hello"),
            other => panic!("expected Text variant, got {:?}", other),
        }
    }

    #[test]
    fn test_content_part_image_serde() {
        let part = ContentPart::Image {
            media_type: "image/png".to_string(),
            data: "iVBORw0KGgo=".to_string(),
        };
        let json = serde_json::to_string(&part).expect("serialize");
        assert!(json.contains("\"type\":\"image\""));
        let deser: ContentPart = serde_json::from_str(&json).expect("deserialize");
        match deser {
            ContentPart::Image { media_type, data } => {
                assert_eq!(media_type, "image/png");
                assert_eq!(data, "iVBORw0KGgo=");
            }
            other => panic!("expected Image variant, got {:?}", other),
        }
    }

    #[test]
    fn test_message_with_parts() {
        let msg = Message::with_parts(
            Role::User,
            "What's in this image?",
            vec![ContentPart::Image {
                media_type: "image/png".to_string(),
                data: "abc123".to_string(),
            }],
        );
        assert!(msg.has_images());
        assert_eq!(msg.content, "What's in this image?");
        assert_eq!(msg.content_parts.len(), 1);
    }

    #[test]
    fn test_message_has_images_false() {
        let msg = Message::new(Role::User, "just text");
        assert!(!msg.has_images());
    }

    #[test]
    fn test_message_has_images_false_with_text_parts() {
        let msg = Message::with_parts(
            Role::User,
            "only text",
            vec![ContentPart::Text {
                text: "part".to_string(),
            }],
        );
        assert!(!msg.has_images());
    }

    #[test]
    fn test_message_content_parts_skipped_when_empty() {
        let msg = Message::new(Role::User, "hi");
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(!json.contains("content_parts"));
    }

    #[test]
    fn test_tool_call_result_with_image() {
        let result = ToolCallResult {
            id: "tc1".to_string(),
            content: "Screenshot captured".to_string(),
            is_error: false,
            image: Some(ContentPart::Image {
                media_type: "image/png".to_string(),
                data: "base64data".to_string(),
            }),
        };
        let json = serde_json::to_string(&result).expect("serialize");
        // Match the key itself: a bare `contains("image")` would also match the
        // "image/png" media type and pass even if the field were skipped.
        assert!(json.contains("\"image\":{"));
        let deser: ToolCallResult = serde_json::from_str(&json).expect("deserialize");
        match deser.image {
            Some(ContentPart::Image { media_type, data }) => {
                assert_eq!(media_type, "image/png");
                assert_eq!(data, "base64data");
            }
            other => panic!("expected Some(Image), got {:?}", other),
        }
    }

    #[test]
    fn test_tool_call_result_image_skipped_when_none() {
        let result = ToolCallResult {
            id: "tc1".to_string(),
            content: "ok".to_string(),
            is_error: false,
            image: None,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        assert!(!json.contains("\"image\""));
    }
}