Skip to main content

punch_types/
message.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4/// A content part within a message — text or image.
5///
6/// Enables multimodal messages: screenshots from desktop automation,
7/// photos from Telegram, or any other image content that needs to flow
8/// through the LLM pipeline.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10#[serde(tag = "type")]
11pub enum ContentPart {
12    /// Plain text content.
13    #[serde(rename = "text")]
14    Text { text: String },
15    /// Base64-encoded image content.
16    #[serde(rename = "image")]
17    Image {
18        /// MIME type (e.g. "image/png", "image/jpeg").
19        media_type: String,
20        /// Base64-encoded image data.
21        data: String,
22    },
23}
24
/// The role of a message participant.
///
/// Each variant serializes to its snake_case name (`"user"`, `"assistant"`,
/// `"system"`, `"tool"`) via the `rename_all` attribute below. The type is
/// `Copy` and hashable, so it can be passed by value and used as a set or
/// map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`; the participant that requests tool calls
    /// (see `Message::tool_calls`).
    Assistant,
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"tool"`; the role used for messages that carry tool
    /// results (see `Message::tool_results`).
    Tool,
}
34
35impl std::fmt::Display for Role {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            Self::User => write!(f, "user"),
39            Self::Assistant => write!(f, "assistant"),
40            Self::System => write!(f, "system"),
41            Self::Tool => write!(f, "tool"),
42        }
43    }
44}
45
/// A message in a bout (conversation).
///
/// Serialization notes, as configured by the serde attributes below:
/// `tool_calls`, `tool_results` and `content_parts` are omitted from the
/// output when empty and fall back to an empty `Vec` when absent from the
/// input. `role`, `content` and `timestamp` carry no `default` attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// The role of the message sender.
    pub role: Role,
    /// Text content of the message (may be empty for tool-only messages).
    pub content: String,
    /// Tool calls requested by the assistant.
    ///
    /// Skipped during serialization when empty; defaults to empty when missing.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Results from tool executions (for role = Tool).
    ///
    /// Skipped during serialization when empty; defaults to empty when missing.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_results: Vec<ToolCallResult>,
    /// When the message was created (UTC).
    pub timestamp: DateTime<Utc>,
    /// Multimodal content parts (images, etc.). When non-empty, drivers should
    /// use these instead of `content` for multimodal-capable providers.
    ///
    /// Skipped during serialization when empty; defaults to empty when missing.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub content_parts: Vec<ContentPart>,
}
66
67impl Message {
68    /// Create a simple text message with the current timestamp.
69    pub fn new(role: Role, content: impl Into<String>) -> Self {
70        Self {
71            role,
72            content: content.into(),
73            tool_calls: Vec::new(),
74            tool_results: Vec::new(),
75            timestamp: Utc::now(),
76            content_parts: Vec::new(),
77        }
78    }
79
80    /// Create a message with multimodal content parts.
81    pub fn with_parts(role: Role, content: impl Into<String>, parts: Vec<ContentPart>) -> Self {
82        Self {
83            role,
84            content: content.into(),
85            tool_calls: Vec::new(),
86            tool_results: Vec::new(),
87            timestamp: Utc::now(),
88            content_parts: parts,
89        }
90    }
91
92    /// Returns true if this message contains any image content parts.
93    pub fn has_images(&self) -> bool {
94        self.content_parts
95            .iter()
96            .any(|p| matches!(p, ContentPart::Image { .. }))
97    }
98}
99
/// A tool call requested by the assistant.
///
/// Stored in `Message::tool_calls`; a `ToolCallResult` refers back to a call
/// through its matching `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique identifier for this tool call.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Input arguments as a JSON object.
    ///
    /// NOTE(review): the field is typed as an arbitrary `serde_json::Value`;
    /// nothing in this definition enforces that it is actually an object.
    pub input: serde_json::Value,
}
110
/// The result of a tool call execution.
///
/// Serialization notes, as configured by the serde attributes below:
/// `is_error` becomes `false` when absent from the input, and `image` is
/// omitted from the output when `None` and becomes `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallResult {
    /// The ID of the tool call this result corresponds to.
    pub id: String,
    /// Output content from the tool.
    pub content: String,
    /// Whether the tool execution resulted in an error.
    ///
    /// Defaults to `false` when the field is missing during deserialization.
    #[serde(default)]
    pub is_error: bool,
    /// Optional image content returned by the tool (e.g. screenshots).
    /// When present, drivers should include this as a vision input alongside
    /// the text content so the LLM can "see" it.
    ///
    /// NOTE(review): the type admits any `ContentPart`, including `Text`;
    /// nothing here restricts it to the `Image` variant.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub image: Option<ContentPart>,
}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` into a generic JSON tree so tests can check for the
    /// presence or absence of specific keys rather than searching the raw
    /// string, where unrelated content could produce false matches.
    fn to_json_value<T: serde::Serialize>(value: &T) -> serde_json::Value {
        serde_json::to_value(value).expect("serialize")
    }

    #[test]
    fn test_role_display() {
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::Tool.to_string(), "tool");
    }

    #[test]
    fn test_role_serde_roundtrip() {
        // A fixed list needs no heap allocation; a borrowed array is enough.
        for role in &[Role::User, Role::Assistant, Role::System, Role::Tool] {
            let json = serde_json::to_string(role).expect("serialize");
            let deser: Role = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(&deser, role);
        }
    }

    #[test]
    fn test_role_serde_values() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&Role::Tool).unwrap(), "\"tool\"");
        // Only the snake_case spelling is accepted on input.
        assert!(serde_json::from_str::<Role>("\"User\"").is_err());
    }

    #[test]
    fn test_message_new() {
        let msg = Message::new(Role::User, "Hello world");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "Hello world");
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_results.is_empty());
        assert!(msg.content_parts.is_empty());
    }

    #[test]
    fn test_message_new_empty_content() {
        let msg = Message::new(Role::Assistant, "");
        assert_eq!(msg.content, "");
    }

    #[test]
    fn test_message_serde_roundtrip() {
        let msg = Message::new(Role::User, "test message");
        let json = serde_json::to_string(&msg).expect("serialize");
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.role, Role::User);
        assert_eq!(deser.content, "test message");
        // The timestamp must survive the round trip unchanged.
        assert_eq!(deser.timestamp, msg.timestamp);
        // Fields skipped on output come back as their defaults.
        assert!(deser.tool_calls.is_empty());
        assert!(deser.tool_results.is_empty());
        assert!(deser.content_parts.is_empty());
    }

    #[test]
    fn test_message_serde_skips_empty_vecs() {
        let value = to_json_value(&Message::new(Role::User, "hi"));
        // skip_serializing_if = "Vec::is_empty" means these keys should be absent
        assert!(value.get("tool_calls").is_none());
        assert!(value.get("tool_results").is_none());
    }

    #[test]
    fn test_tool_call_serde() {
        let call = ToolCall {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            input: serde_json::json!({"path": "/tmp/test.txt"}),
        };
        let json = serde_json::to_string(&call).expect("serialize");
        let deser: ToolCall = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.id, "call_123");
        assert_eq!(deser.name, "read_file");
        assert_eq!(deser.input["path"], "/tmp/test.txt");
    }

    #[test]
    fn test_tool_call_result_serde() {
        let result = ToolCallResult {
            id: "call_123".to_string(),
            content: "file contents here".to_string(),
            is_error: false,
            image: None,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let deser: ToolCallResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.id, "call_123");
        assert_eq!(deser.content, "file contents here");
        assert!(!deser.is_error);
        assert!(deser.image.is_none());
    }

    #[test]
    fn test_tool_call_result_error() {
        let result = ToolCallResult {
            id: "call_456".to_string(),
            content: "Permission denied".to_string(),
            is_error: true,
            image: None,
        };
        assert!(result.is_error);
    }

    #[test]
    fn test_tool_call_result_is_error_default() {
        // is_error and image both have #[serde(default)], so missing fields
        // should come back as false / None.
        let json = r#"{"id": "x", "content": "ok"}"#;
        let result: ToolCallResult = serde_json::from_str(json).expect("deserialize");
        assert!(!result.is_error);
        assert!(result.image.is_none());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let mut msg = Message::new(Role::Assistant, "Let me check that file");
        msg.tool_calls.push(ToolCall {
            id: "tc1".to_string(),
            name: "read_file".to_string(),
            input: serde_json::json!({"path": "main.rs"}),
        });
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(json.contains("tool_calls"));
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.tool_calls.len(), 1);
        assert_eq!(deser.tool_calls[0].name, "read_file");
    }

    #[test]
    fn test_role_equality() {
        assert_eq!(Role::User, Role::User);
        assert_ne!(Role::User, Role::Assistant);
    }

    #[test]
    fn test_role_hash() {
        let mut set = std::collections::HashSet::new();
        set.insert(Role::User);
        set.insert(Role::Assistant);
        set.insert(Role::User);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_content_part_text_serde() {
        let part = ContentPart::Text {
            text: "hello".to_string(),
        };
        let json = serde_json::to_string(&part).expect("serialize");
        assert!(json.contains("\"type\":\"text\""));
        let deser: ContentPart = serde_json::from_str(&json).expect("deserialize");
        // Name the other variant explicitly so a newly added variant forces
        // this match to be revisited.
        match deser {
            ContentPart::Text { text } => assert_eq!(text, "hello"),
            ContentPart::Image { .. } => panic!("expected Text variant"),
        }
    }

    #[test]
    fn test_content_part_image_serde() {
        let part = ContentPart::Image {
            media_type: "image/png".to_string(),
            data: "iVBORw0KGgo=".to_string(),
        };
        let json = serde_json::to_string(&part).expect("serialize");
        assert!(json.contains("\"type\":\"image\""));
        let deser: ContentPart = serde_json::from_str(&json).expect("deserialize");
        match deser {
            ContentPart::Image { media_type, data } => {
                assert_eq!(media_type, "image/png");
                assert_eq!(data, "iVBORw0KGgo=");
            }
            ContentPart::Text { .. } => panic!("expected Image variant"),
        }
    }

    #[test]
    fn test_message_with_parts() {
        let msg = Message::with_parts(
            Role::User,
            "What's in this image?",
            vec![ContentPart::Image {
                media_type: "image/png".to_string(),
                data: "abc123".to_string(),
            }],
        );
        assert!(msg.has_images());
        assert_eq!(msg.content_parts.len(), 1);
        assert_eq!(msg.content, "What's in this image?");
    }

    #[test]
    fn test_message_with_parts_serde_roundtrip() {
        let msg = Message::with_parts(
            Role::User,
            "look",
            vec![
                ContentPart::Text {
                    text: "caption".to_string(),
                },
                ContentPart::Image {
                    media_type: "image/jpeg".to_string(),
                    data: "abc123".to_string(),
                },
            ],
        );
        let json = serde_json::to_string(&msg).expect("serialize");
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.content_parts.len(), 2);
        assert!(deser.has_images());
    }

    #[test]
    fn test_message_has_images_false() {
        let msg = Message::new(Role::User, "just text");
        assert!(!msg.has_images());

        // Text-only parts must not be mistaken for images.
        let text_only = Message::with_parts(
            Role::User,
            "just text",
            vec![ContentPart::Text {
                text: "still text".to_string(),
            }],
        );
        assert!(!text_only.has_images());
    }

    #[test]
    fn test_message_content_parts_skipped_when_empty() {
        let value = to_json_value(&Message::new(Role::User, "hi"));
        assert!(value.get("content_parts").is_none());
    }

    #[test]
    fn test_tool_call_result_with_image() {
        let result = ToolCallResult {
            id: "tc1".to_string(),
            content: "Screenshot captured".to_string(),
            is_error: false,
            image: Some(ContentPart::Image {
                media_type: "image/png".to_string(),
                data: "base64data".to_string(),
            }),
        };
        let json = serde_json::to_string(&result).expect("serialize");
        // Check the actual `image` key: a substring search for "image" would
        // also match the nested `"type":"image"` tag and the MIME type.
        let value: serde_json::Value = serde_json::from_str(&json).expect("parse");
        assert_eq!(value["image"]["type"], "image");
        assert_eq!(value["image"]["media_type"], "image/png");

        let deser: ToolCallResult = serde_json::from_str(&json).expect("deserialize");
        match deser.image {
            Some(ContentPart::Image { media_type, data }) => {
                assert_eq!(media_type, "image/png");
                assert_eq!(data, "base64data");
            }
            other => panic!("expected Some(Image), got {:?}", other),
        }
    }

    #[test]
    fn test_tool_call_result_image_skipped_when_none() {
        let result = ToolCallResult {
            id: "tc1".to_string(),
            content: "ok".to_string(),
            is_error: false,
            image: None,
        };
        assert!(to_json_value(&result).get("image").is_none());
    }
}