1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
6#[serde(rename_all = "snake_case")]
7pub enum MessageRole {
8 User,
9 Assistant,
10 System,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
15pub struct MessageContent {
16 pub role: MessageRole,
17 pub content: String,
18}
19
20pub trait SDKMessageBase: Clone + Send + Sync + std::fmt::Debug {
22 fn session_id(&self) -> &str;
23 fn message_type(&self) -> MessageType;
24}
25
/// Discriminant identifying which kind of SDK message a value is.
///
/// Returned by [`SDKMessageBase::message_type`] and
/// [`SDKMessage::message_type`]. This enum derives no serde traits, so it is
/// an in-process tag only.
///
/// Being a fieldless enum it is `Copy`, and it implements `Eq` and `Hash` in
/// addition to `PartialEq` so it can be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Tag for [`SDKUserMessage`].
    User,
    /// Tag for [`SDKAssistantMessage`].
    Assistant,
    /// Tag for [`SDKSystemMessage`].
    System,
    /// Tag for [`SDKResultMessage`].
    Result,
    /// Tag for [`SDKPartialAssistantMessage`].
    PartialAssistant,
}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct SDKUserMessage {
39 pub session_id: String,
40 pub message: MessageContent,
41 pub parent_tool_use_id: Option<String>,
42}
43
44impl SDKMessageBase for SDKUserMessage {
45 fn session_id(&self) -> &str {
46 &self.session_id
47 }
48
49 fn message_type(&self) -> MessageType {
50 MessageType::User
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct SDKAssistantMessage {
57 pub session_id: String,
58 pub message: MessageContent,
59}
60
61impl SDKMessageBase for SDKAssistantMessage {
62 fn session_id(&self) -> &str {
63 &self.session_id
64 }
65
66 fn message_type(&self) -> MessageType {
67 MessageType::Assistant
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct SDKSystemMessage {
74 pub session_id: String,
75 pub message: MessageContent,
76}
77
78impl SDKMessageBase for SDKSystemMessage {
79 fn session_id(&self) -> &str {
80 &self.session_id
81 }
82
83 fn message_type(&self) -> MessageType {
84 MessageType::System
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct SDKResultMessage {
91 pub session_id: String,
92 pub result: serde_json::Value,
93 pub exit_code: i32,
94}
95
96impl SDKMessageBase for SDKResultMessage {
97 fn session_id(&self) -> &str {
98 &self.session_id
99 }
100
101 fn message_type(&self) -> MessageType {
102 MessageType::Result
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct SDKPartialAssistantMessage {
109 pub session_id: String,
110 pub message: MessageContent,
111 pub is_complete: bool,
112}
113
114impl SDKMessageBase for SDKPartialAssistantMessage {
115 fn session_id(&self) -> &str {
116 &self.session_id
117 }
118
119 fn message_type(&self) -> MessageType {
120 MessageType::PartialAssistant
121 }
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
126#[serde(tag = "type", rename_all = "snake_case")]
127pub enum SDKMessage {
128 User(SDKUserMessage),
129 Assistant(SDKAssistantMessage),
130 System(SDKSystemMessage),
131 Result(SDKResultMessage),
132 PartialAssistant(SDKPartialAssistantMessage),
133}
134
135impl SDKMessage {
136 pub fn session_id(&self) -> &str {
137 match self {
138 SDKMessage::User(m) => &m.session_id,
139 SDKMessage::Assistant(m) => &m.session_id,
140 SDKMessage::System(m) => &m.session_id,
141 SDKMessage::Result(m) => &m.session_id,
142 SDKMessage::PartialAssistant(m) => &m.session_id,
143 }
144 }
145
146 pub fn message_type(&self) -> MessageType {
147 match self {
148 SDKMessage::User(_) => MessageType::User,
149 SDKMessage::Assistant(_) => MessageType::Assistant,
150 SDKMessage::System(_) => MessageType::System,
151 SDKMessage::Result(_) => MessageType::Result,
152 SDKMessage::PartialAssistant(_) => MessageType::PartialAssistant,
153 }
154 }
155}
156
157impl SDKMessage {
159 pub fn is_user_message(&self) -> bool {
160 matches!(self, SDKMessage::User(_))
161 }
162
163 pub fn is_assistant_message(&self) -> bool {
164 matches!(self, SDKMessage::Assistant(_))
165 }
166
167 pub fn is_system_message(&self) -> bool {
168 matches!(self, SDKMessage::System(_))
169 }
170
171 pub fn is_result_message(&self) -> bool {
172 matches!(self, SDKMessage::Result(_))
173 }
174
175 pub fn is_partial_assistant_message(&self) -> bool {
176 matches!(self, SDKMessage::PartialAssistant(_))
177 }
178
179 pub fn from_assistant_text(content: &str) -> Self {
181 SDKMessage::Assistant(SDKAssistantMessage {
182 session_id: String::new(),
183 message: MessageContent {
184 role: MessageRole::Assistant,
185 content: content.to_string(),
186 },
187 })
188 }
189
190 pub fn from_result_value(result: serde_json::Value) -> Self {
192 SDKMessage::Result(SDKResultMessage {
193 session_id: String::new(),
194 result,
195 exit_code: 0,
196 })
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`MessageContent`] so each test states only the values it cares about.
    fn content(role: MessageRole, text: &str) -> MessageContent {
        MessageContent {
            role,
            content: text.to_string(),
        }
    }

    /// Builds a user message with no parent tool-use id.
    fn user_message(session_id: &str, text: &str) -> SDKUserMessage {
        SDKUserMessage {
            session_id: session_id.to_string(),
            message: content(MessageRole::User, text),
            parent_tool_use_id: None,
        }
    }

    /// Builds an assistant message.
    fn assistant_message(session_id: &str, text: &str) -> SDKAssistantMessage {
        SDKAssistantMessage {
            session_id: session_id.to_string(),
            message: content(MessageRole::Assistant, text),
        }
    }

    #[test]
    fn test_user_message_creation() {
        let msg = user_message("test-session", "Hello");

        assert_eq!(msg.session_id(), "test-session");
        assert_eq!(msg.message_type(), MessageType::User);

        let wrapped = SDKMessage::User(msg);
        assert!(wrapped.is_user_message());
    }

    #[test]
    fn test_assistant_message_creation() {
        let msg = assistant_message("test-session", "Hi there!");

        assert_eq!(msg.session_id(), "test-session");
        assert_eq!(msg.message_type(), MessageType::Assistant);

        let wrapped = SDKMessage::Assistant(msg);
        assert!(wrapped.is_assistant_message());
    }

    #[test]
    fn test_system_message_creation() {
        let msg = SDKSystemMessage {
            session_id: "test-session".to_string(),
            message: content(MessageRole::System, "System initialized"),
        };

        assert_eq!(msg.session_id(), "test-session");
        assert_eq!(msg.message_type(), MessageType::System);

        let wrapped = SDKMessage::System(msg);
        assert!(wrapped.is_system_message());
    }

    #[test]
    fn test_result_message_creation() {
        let msg = SDKResultMessage {
            session_id: "test-session".to_string(),
            result: serde_json::json!({"status": "success"}),
            exit_code: 0,
        };

        assert_eq!(msg.session_id(), "test-session");
        assert_eq!(msg.message_type(), MessageType::Result);

        let wrapped = SDKMessage::Result(msg);
        assert!(wrapped.is_result_message());
    }

    #[test]
    fn test_partial_assistant_message_creation() {
        let msg = SDKPartialAssistantMessage {
            session_id: "test-session".to_string(),
            message: content(MessageRole::Assistant, "Partial..."),
            is_complete: false,
        };

        assert_eq!(msg.session_id(), "test-session");
        assert_eq!(msg.message_type(), MessageType::PartialAssistant);

        let wrapped = SDKMessage::PartialAssistant(msg);
        assert!(wrapped.is_partial_assistant_message());
    }

    #[test]
    fn test_sdk_message_enum_user() {
        let msg = SDKMessage::User(user_message("s1", "test"));
        assert!(msg.is_user_message());
        assert_eq!(msg.session_id(), "s1");
    }

    #[test]
    fn test_sdk_message_enum_assistant() {
        let msg = SDKMessage::Assistant(assistant_message("s2", "response"));
        assert!(msg.is_assistant_message());
        assert_eq!(msg.session_id(), "s2");
    }

    #[test]
    fn test_message_role_serialization() {
        let role = MessageRole::User;
        let serialized = serde_json::to_string(&role).unwrap();
        assert_eq!(serialized, "\"user\"");

        let deserialized: MessageRole = serde_json::from_str(&serialized).unwrap();
        assert!(matches!(deserialized, MessageRole::User));
    }

    #[test]
    fn test_message_content_serialization() {
        let serialized =
            serde_json::to_string(&content(MessageRole::Assistant, "Hello")).unwrap();
        assert!(serialized.contains("\"role\":\"assistant\""));
        assert!(serialized.contains("\"content\":\"Hello\""));
    }

    #[test]
    fn test_type_guards_all_return_false_for_wrong_type() {
        let msg = SDKMessage::User(user_message("s1", "test"));

        assert!(!msg.is_assistant_message());
        assert!(!msg.is_system_message());
        assert!(!msg.is_result_message());
        assert!(!msg.is_partial_assistant_message());
    }

    #[test]
    fn test_message_debug_format() {
        let msg = user_message("debug-session", "Debug test");

        let debug_str = format!("{:?}", msg);
        assert!(debug_str.contains("debug-session"));
        assert!(debug_str.contains("Debug test"));
    }

    // Enum-level accessors must agree with the wrapped variant for every case.
    #[test]
    fn test_sdk_message_enum_message_type() {
        let cases = vec![
            (SDKMessage::User(user_message("s", "t")), MessageType::User),
            (
                SDKMessage::Assistant(assistant_message("s", "t")),
                MessageType::Assistant,
            ),
            (
                SDKMessage::System(SDKSystemMessage {
                    session_id: "s".to_string(),
                    message: content(MessageRole::System, "t"),
                }),
                MessageType::System,
            ),
            (
                SDKMessage::Result(SDKResultMessage {
                    session_id: "s".to_string(),
                    result: serde_json::Value::Null,
                    exit_code: 1,
                }),
                MessageType::Result,
            ),
            (
                SDKMessage::PartialAssistant(SDKPartialAssistantMessage {
                    session_id: "s".to_string(),
                    message: content(MessageRole::Assistant, "t"),
                    is_complete: true,
                }),
                MessageType::PartialAssistant,
            ),
        ];

        for (msg, expected) in cases {
            assert_eq!(msg.message_type(), expected);
            assert_eq!(msg.session_id(), "s");
        }
    }

    #[test]
    fn test_from_assistant_text() {
        let msg = SDKMessage::from_assistant_text("hello");
        assert!(msg.is_assistant_message());
        assert_eq!(msg.message_type(), MessageType::Assistant);
        // The constructor leaves the session id empty.
        assert_eq!(msg.session_id(), "");

        match msg {
            SDKMessage::Assistant(inner) => {
                assert!(matches!(inner.message.role, MessageRole::Assistant));
                assert_eq!(inner.message.content, "hello");
            }
            other => panic!("expected an assistant message, got {:?}", other),
        }
    }

    #[test]
    fn test_from_result_value() {
        let msg = SDKMessage::from_result_value(serde_json::json!({"ok": true}));
        assert!(msg.is_result_message());
        assert_eq!(msg.message_type(), MessageType::Result);
        assert_eq!(msg.session_id(), "");

        match msg {
            SDKMessage::Result(inner) => {
                assert_eq!(inner.exit_code, 0);
                assert_eq!(inner.result, serde_json::json!({"ok": true}));
            }
            other => panic!("expected a result message, got {:?}", other),
        }
    }

    // Pins the wire format: variant name in snake_case under the `type` key,
    // flattened next to the wrapped struct's own fields.
    #[test]
    fn test_sdk_message_tagged_serialization_round_trip() {
        let msg = SDKMessage::PartialAssistant(SDKPartialAssistantMessage {
            session_id: "s3".to_string(),
            message: content(MessageRole::Assistant, "par"),
            is_complete: false,
        });

        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "partial_assistant",
                "session_id": "s3",
                "message": {"role": "assistant", "content": "par"},
                "is_complete": false
            })
        );

        let back: SDKMessage = serde_json::from_value(value).unwrap();
        assert!(back.is_partial_assistant_message());
        assert_eq!(back.session_id(), "s3");
    }
}