//! `punch_types/message.rs`: roles, messages, tool calls, and tool results for bouts (conversations).

use std::fmt;

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};

4/// The role of a message participant.
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum Role {
8    User,
9    Assistant,
10    System,
11    Tool,
12}
13
14impl std::fmt::Display for Role {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        match self {
17            Self::User => write!(f, "user"),
18            Self::Assistant => write!(f, "assistant"),
19            Self::System => write!(f, "system"),
20            Self::Tool => write!(f, "tool"),
21        }
22    }
23}
24
25/// A message in a bout (conversation).
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct Message {
28    /// The role of the message sender.
29    pub role: Role,
30    /// Text content of the message (may be empty for tool-only messages).
31    pub content: String,
32    /// Tool calls requested by the assistant.
33    #[serde(default, skip_serializing_if = "Vec::is_empty")]
34    pub tool_calls: Vec<ToolCall>,
35    /// Results from tool executions (for role = Tool).
36    #[serde(default, skip_serializing_if = "Vec::is_empty")]
37    pub tool_results: Vec<ToolCallResult>,
38    /// When the message was created.
39    pub timestamp: DateTime<Utc>,
40}
41
42impl Message {
43    /// Create a simple text message with the current timestamp.
44    pub fn new(role: Role, content: impl Into<String>) -> Self {
45        Self {
46            role,
47            content: content.into(),
48            tool_calls: Vec::new(),
49            tool_results: Vec::new(),
50            timestamp: Utc::now(),
51        }
52    }
53}
54
55/// A tool call requested by the assistant.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ToolCall {
58    /// Unique identifier for this tool call.
59    pub id: String,
60    /// Name of the tool to invoke.
61    pub name: String,
62    /// Input arguments as a JSON object.
63    pub input: serde_json::Value,
64}
65
66/// The result of a tool call execution.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ToolCallResult {
69    /// The ID of the tool call this result corresponds to.
70    pub id: String,
71    /// Output content from the tool.
72    pub content: String,
73    /// Whether the tool execution resulted in an error.
74    #[serde(default)]
75    pub is_error: bool,
76}
77
/// Unit tests for the role, message, and tool-call types and their serde behaviour.
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Role` variant, for table-driven tests.
    const ALL_ROLES: [Role; 4] = [Role::User, Role::Assistant, Role::System, Role::Tool];

    #[test]
    fn test_role_display() {
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::Tool.to_string(), "tool");
    }

    #[test]
    fn test_role_serde_roundtrip() {
        for role in ALL_ROLES {
            let json = serde_json::to_string(&role).expect("serialize");
            let deser: Role = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(deser, role);
        }
    }

    #[test]
    fn test_role_serde_values() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&Role::Tool).unwrap(), "\"tool\"");
    }

    #[test]
    fn test_role_serde_matches_display() {
        // The wire name and the Display name must not drift apart.
        for role in ALL_ROLES {
            let json = serde_json::to_string(&role).expect("serialize");
            assert_eq!(json, format!("\"{}\"", role));
        }
    }

    #[test]
    fn test_message_new() {
        let msg = Message::new(Role::User, "Hello world");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "Hello world");
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_results.is_empty());
    }

    #[test]
    fn test_message_new_empty_content() {
        let msg = Message::new(Role::Assistant, "");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content, "");
    }

    #[test]
    fn test_message_serde_roundtrip() {
        let msg = Message::new(Role::User, "test message");
        let json = serde_json::to_string(&msg).expect("serialize");
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.role, Role::User);
        assert_eq!(deser.content, "test message");
        assert_eq!(deser.timestamp, msg.timestamp);
        // The empty vectors are skipped on output and must come back via `default`.
        assert!(deser.tool_calls.is_empty());
        assert!(deser.tool_results.is_empty());
    }

    #[test]
    fn test_message_serde_skips_empty_vecs() {
        let msg = Message::new(Role::User, "hi");
        let value = serde_json::to_value(&msg).expect("serialize");
        // skip_serializing_if = "Vec::is_empty" means these keys should be absent.
        // Checking keys (not substrings) keeps the test independent of the content text.
        assert!(value.get("tool_calls").is_none());
        assert!(value.get("tool_results").is_none());
        assert!(value.get("role").is_some());
        assert!(value.get("content").is_some());
        assert!(value.get("timestamp").is_some());
    }

    #[test]
    fn test_tool_call_serde() {
        let call = ToolCall {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            input: serde_json::json!({"path": "/tmp/test.txt"}),
        };
        let json = serde_json::to_string(&call).expect("serialize");
        let deser: ToolCall = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.id, "call_123");
        assert_eq!(deser.name, "read_file");
        assert_eq!(deser.input["path"], "/tmp/test.txt");
    }

    #[test]
    fn test_tool_call_result_serde() {
        let result = ToolCallResult {
            id: "call_123".to_string(),
            content: "file contents here".to_string(),
            is_error: false,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let deser: ToolCallResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.id, "call_123");
        assert_eq!(deser.content, "file contents here");
        assert!(!deser.is_error);
    }

    #[test]
    fn test_tool_call_result_error() {
        let result = ToolCallResult {
            id: "call_456".to_string(),
            content: "Permission denied".to_string(),
            is_error: true,
        };
        assert!(result.is_error);
        // The error flag must survive a serde round trip, not just construction.
        let json = serde_json::to_string(&result).expect("serialize");
        let deser: ToolCallResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.id, "call_456");
        assert_eq!(deser.content, "Permission denied");
        assert!(deser.is_error);
    }

    #[test]
    fn test_tool_call_result_is_error_default() {
        // is_error has #[serde(default)], so missing field should be false
        let json = r#"{"id": "x", "content": "ok"}"#;
        let result: ToolCallResult = serde_json::from_str(json).expect("deserialize");
        assert!(!result.is_error);
    }

    #[test]
    fn test_message_with_tool_calls() {
        let mut msg = Message::new(Role::Assistant, "Let me check that file");
        msg.tool_calls.push(ToolCall {
            id: "tc1".to_string(),
            name: "read_file".to_string(),
            input: serde_json::json!({"path": "main.rs"}),
        });
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(json.contains("tool_calls"));
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.tool_calls.len(), 1);
        assert_eq!(deser.tool_calls[0].id, "tc1");
        assert_eq!(deser.tool_calls[0].name, "read_file");
        assert_eq!(deser.tool_calls[0].input["path"], "main.rs");
        assert!(deser.tool_results.is_empty());
    }

    #[test]
    fn test_message_with_tool_results() {
        let mut msg = Message::new(Role::Tool, "");
        msg.tool_results.push(ToolCallResult {
            id: "tc1".to_string(),
            content: "fn main() {}".to_string(),
            is_error: false,
        });
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(json.contains("tool_results"));
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.role, Role::Tool);
        assert_eq!(deser.tool_results.len(), 1);
        assert_eq!(deser.tool_results[0].id, "tc1");
        assert_eq!(deser.tool_results[0].content, "fn main() {}");
        assert!(!deser.tool_results[0].is_error);
        assert!(deser.tool_calls.is_empty());
    }

    #[test]
    fn test_role_equality() {
        assert_eq!(Role::User, Role::User);
        assert_ne!(Role::User, Role::Assistant);
    }

    #[test]
    fn test_role_hash() {
        let mut set = std::collections::HashSet::new();
        set.insert(Role::User);
        set.insert(Role::Assistant);
        set.insert(Role::User);
        assert_eq!(set.len(), 2);
    }
}