1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum Role {
8 User,
9 Assistant,
10 System,
11 Tool,
12}
13
14impl std::fmt::Display for Role {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 match self {
17 Self::User => write!(f, "user"),
18 Self::Assistant => write!(f, "assistant"),
19 Self::System => write!(f, "system"),
20 Self::Tool => write!(f, "tool"),
21 }
22 }
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct Message {
28 pub role: Role,
30 pub content: String,
32 #[serde(default, skip_serializing_if = "Vec::is_empty")]
34 pub tool_calls: Vec<ToolCall>,
35 #[serde(default, skip_serializing_if = "Vec::is_empty")]
37 pub tool_results: Vec<ToolCallResult>,
38 pub timestamp: DateTime<Utc>,
40}
41
42impl Message {
43 pub fn new(role: Role, content: impl Into<String>) -> Self {
45 Self {
46 role,
47 content: content.into(),
48 tool_calls: Vec::new(),
49 tool_results: Vec::new(),
50 timestamp: Utc::now(),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ToolCall {
58 pub id: String,
60 pub name: String,
62 pub input: serde_json::Value,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ToolCallResult {
69 pub id: String,
71 pub content: String,
73 #[serde(default)]
75 pub is_error: bool,
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_role_display() {
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::Tool.to_string(), "tool");
    }

    #[test]
    fn test_role_serde_roundtrip() {
        // A fixed array is enough here; `vec!` would allocate needlessly
        // (clippy::useless_vec).
        let roles = [Role::User, Role::Assistant, Role::System, Role::Tool];
        for role in roles {
            let json = serde_json::to_string(&role).expect("serialize");
            let deser: Role = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(deser, role);
        }
    }

    #[test]
    fn test_role_serde_values() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&Role::Tool).unwrap(), "\"tool\"");
    }

    #[test]
    fn test_message_new() {
        let msg = Message::new(Role::User, "Hello world");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "Hello world");
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_results.is_empty());
    }

    #[test]
    fn test_message_new_empty_content() {
        let msg = Message::new(Role::Assistant, "");
        assert_eq!(msg.content, "");
    }

    #[test]
    fn test_message_serde_roundtrip() {
        let msg = Message::new(Role::User, "test message");
        let json = serde_json::to_string(&msg).expect("serialize");
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.role, Role::User);
        assert_eq!(deser.content, "test message");
        // chrono serializes DateTime<Utc> as RFC 3339 with all non-zero
        // sub-second digits, so the timestamp must survive unchanged.
        assert_eq!(deser.timestamp, msg.timestamp);
        assert!(deser.tool_calls.is_empty());
        assert!(deser.tool_results.is_empty());
    }

    #[test]
    fn test_message_serde_skips_empty_vecs() {
        let msg = Message::new(Role::User, "hi");
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(!json.contains("tool_calls"));
        assert!(!json.contains("tool_results"));
    }

    #[test]
    fn test_message_deserialize_defaults_missing_tool_fields() {
        // `#[serde(default)]` must let older/minimal payloads omit the vectors.
        let json = r#"{"role": "tool", "content": "done", "timestamp": "2024-01-01T00:00:00Z"}"#;
        let msg: Message = serde_json::from_str(json).expect("deserialize");
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.content, "done");
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_results.is_empty());
    }

    #[test]
    fn test_tool_call_serde() {
        let call = ToolCall {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            input: serde_json::json!({"path": "/tmp/test.txt"}),
        };
        let json = serde_json::to_string(&call).expect("serialize");
        let deser: ToolCall = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.id, "call_123");
        assert_eq!(deser.name, "read_file");
        assert_eq!(deser.input["path"], "/tmp/test.txt");
    }

    #[test]
    fn test_tool_call_result_serde() {
        let result = ToolCallResult {
            id: "call_123".to_string(),
            content: "file contents here".to_string(),
            is_error: false,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let deser: ToolCallResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.id, "call_123");
        assert_eq!(deser.content, "file contents here");
        assert!(!deser.is_error);
    }

    #[test]
    fn test_tool_call_result_error() {
        // Round-trip through serde so the test verifies that `is_error: true`
        // is actually serialized and read back, not just the field we set.
        let result = ToolCallResult {
            id: "call_456".to_string(),
            content: "Permission denied".to_string(),
            is_error: true,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let deser: ToolCallResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.id, "call_456");
        assert_eq!(deser.content, "Permission denied");
        assert!(deser.is_error);
    }

    #[test]
    fn test_tool_call_result_is_error_default() {
        let json = r#"{"id": "x", "content": "ok"}"#;
        let result: ToolCallResult = serde_json::from_str(json).expect("deserialize");
        assert!(!result.is_error);
    }

    #[test]
    fn test_message_with_tool_calls() {
        let mut msg = Message::new(Role::Assistant, "Let me check that file");
        msg.tool_calls.push(ToolCall {
            id: "tc1".to_string(),
            name: "read_file".to_string(),
            input: serde_json::json!({"path": "main.rs"}),
        });
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(json.contains("tool_calls"));
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.tool_calls.len(), 1);
        assert_eq!(deser.tool_calls[0].name, "read_file");
    }

    #[test]
    fn test_message_with_tool_results() {
        let mut msg = Message::new(Role::Tool, "");
        msg.tool_results.push(ToolCallResult {
            id: "tc1".to_string(),
            content: "fn main() {}".to_string(),
            is_error: false,
        });
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(json.contains("tool_results"));
        assert!(!json.contains("tool_calls"));
        let deser: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.tool_results.len(), 1);
        assert_eq!(deser.tool_results[0].id, "tc1");
        assert_eq!(deser.tool_results[0].content, "fn main() {}");
        assert!(!deser.tool_results[0].is_error);
    }

    #[test]
    fn test_role_equality() {
        assert_eq!(Role::User, Role::User);
        assert_ne!(Role::User, Role::Assistant);
    }

    #[test]
    fn test_role_hash() {
        let mut set = std::collections::HashSet::new();
        set.insert(Role::User);
        set.insert(Role::Assistant);
        set.insert(Role::User);
        assert_eq!(set.len(), 2);
    }
}