1use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use super::content::ContentBlock;
13
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
38#[serde(tag = "type", rename_all = "snake_case")]
39pub enum Message {
40 System(SystemMessage),
42 Assistant(AssistantMessage),
44 User(UserMessage),
46 Result(ResultMessage),
48 #[serde(rename = "stream_event")]
51 StreamEvent(StreamEvent),
52}
53
54impl Message {
55 pub fn session_id(&self) -> Option<&str> {
61 match self {
62 Message::System(m) => Some(&m.session_id),
63 Message::Assistant(m) => Some(&m.session_id),
64 Message::User(m) => Some(&m.session_id),
65 Message::Result(m) => Some(&m.session_id),
66 Message::StreamEvent(m) => Some(&m.session_id),
67 }
68 }
69
70 pub fn is_error_result(&self) -> bool {
72 matches!(self, Message::Result(r) if r.is_error)
73 }
74
75 pub fn is_stream_event(&self) -> bool {
77 matches!(self, Message::StreamEvent(_))
78 }
79
80 pub fn assistant_text(&self) -> Option<String> {
84 let Message::Assistant(m) = self else {
85 return None;
86 };
87 let texts: Vec<&str> = m
88 .message
89 .content
90 .iter()
91 .filter_map(|b| match b {
92 ContentBlock::Text(t) => Some(t.text.as_str()),
93 _ => None,
94 })
95 .collect();
96 if texts.is_empty() {
97 None
98 } else {
99 Some(texts.join(""))
100 }
101 }
102}
103
104#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
113pub struct SystemMessage {
114 #[serde(default)]
116 pub subtype: String,
117 #[serde(default)]
119 pub session_id: String,
120 #[serde(default)]
122 pub cwd: String,
123 #[serde(default)]
125 pub tools: Vec<String>,
126 #[serde(default)]
128 pub mcp_servers: Vec<McpServerStatus>,
129 #[serde(default)]
131 pub model: String,
132 #[serde(flatten)]
134 pub extra: Value,
135}
136
137#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
139pub struct McpServerStatus {
140 pub name: String,
142 #[serde(default)]
144 pub status: String,
145 #[serde(flatten)]
147 pub extra: Value,
148}
149
150#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
154pub struct AssistantMessage {
155 pub message: AssistantMessageInner,
157 #[serde(default)]
159 pub session_id: String,
160}
161
162#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
164pub struct AssistantMessageInner {
165 #[serde(default)]
167 pub role: String,
168 #[serde(default)]
170 pub content: Vec<ContentBlock>,
171 #[serde(default)]
173 pub model: String,
174 #[serde(default)]
176 pub stop_reason: String,
177 #[serde(default)]
179 pub stop_sequence: Option<String>,
180 #[serde(flatten)]
182 pub extra: Value,
183}
184
185#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
192pub struct UserMessage {
193 pub message: UserMessageInner,
195 #[serde(default)]
197 pub session_id: String,
198}
199
200#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
202pub struct UserMessageInner {
203 #[serde(default)]
205 pub role: String,
206 #[serde(default)]
208 pub content: Vec<ContentBlock>,
209 #[serde(flatten)]
211 pub extra: Value,
212}
213
214#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
221pub struct ResultMessage {
222 #[serde(default)]
224 pub subtype: String,
225 #[serde(default)]
227 pub is_error: bool,
228 #[serde(default)]
230 pub duration_ms: f64,
231 #[serde(default)]
233 pub duration_api_ms: f64,
234 #[serde(default)]
236 pub num_turns: u32,
237 #[serde(default)]
239 pub session_id: String,
240 #[serde(default)]
242 pub usage: Usage,
243 #[serde(default)]
245 pub stop_reason: String,
246 #[serde(flatten)]
248 pub extra: Value,
249}
250
251#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
258pub struct Usage {
259 #[serde(default)]
261 pub input_tokens: u32,
262 #[serde(default)]
264 pub output_tokens: u32,
265 #[serde(default)]
267 pub cache_read_input_tokens: u32,
268 #[serde(default)]
270 pub cache_creation_input_tokens: u32,
271 #[serde(default)]
273 pub thought_tokens: u32,
274}
275
276#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
285pub struct StreamEvent {
286 pub event_type: String,
288 #[serde(default)]
290 pub data: Value,
291 #[serde(default)]
293 pub session_id: String,
294}
295
296#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
300pub struct SessionInfo {
301 pub session_id: String,
303 #[serde(default)]
305 pub model: String,
306 #[serde(default)]
308 pub tools: Vec<String>,
309 #[serde(flatten)]
311 pub extra: Value,
312}
313
314#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
322pub struct PlanEntry {
323 #[serde(default)]
325 pub content: String,
326 #[serde(default)]
328 pub priority: String,
329 #[serde(default)]
331 pub status: String,
332 #[serde(flatten)]
334 pub extra: Value,
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::content::{TextBlock, ThinkingBlock};
    use serde_json::json;

    // ── Fixtures ────────────────────────────────────────────────────────────

    /// Empty JSON object, used for every `extra` field built by these tests.
    fn no_extra() -> Value {
        Value::Object(Default::default())
    }

    /// Builds a `Message::System` with the given session id.
    fn system_msg(session_id: &str) -> Message {
        let payload = SystemMessage {
            subtype: String::from("init"),
            session_id: String::from(session_id),
            cwd: String::from("/tmp"),
            tools: Vec::new(),
            mcp_servers: Vec::new(),
            model: String::from("gemini-2.5-pro"),
            extra: no_extra(),
        };
        Message::System(payload)
    }

    /// Builds a `Message::Result` whose subtype mirrors the `is_error` flag.
    fn result_msg(session_id: &str, is_error: bool) -> Message {
        let subtype = if is_error { "error" } else { "success" };
        let payload = ResultMessage {
            subtype: String::from(subtype),
            is_error,
            duration_ms: 123.4,
            duration_api_ms: 100.0,
            num_turns: 1,
            session_id: String::from(session_id),
            usage: Usage::default(),
            stop_reason: String::from("end_turn"),
            extra: no_extra(),
        };
        Message::Result(payload)
    }

    /// Builds a `Message::Assistant` carrying the given content blocks.
    fn assistant_msg(session_id: &str, content: Vec<ContentBlock>) -> Message {
        let inner = AssistantMessageInner {
            role: String::from("assistant"),
            content,
            model: String::from("gemini-2.5-pro"),
            stop_reason: String::from("end_turn"),
            stop_sequence: None,
            extra: no_extra(),
        };
        Message::Assistant(AssistantMessage {
            message: inner,
            session_id: String::from(session_id),
        })
    }

    /// Builds a `Message::StreamEvent` with a fixed event type and payload.
    fn stream_event_msg(session_id: &str) -> Message {
        let payload = StreamEvent {
            event_type: String::from("tool_call_start"),
            data: json!({ "tool": "bash" }),
            session_id: String::from(session_id),
        };
        Message::StreamEvent(payload)
    }

    /// Serializes `original` to JSON, parses it back, and asserts equality.
    fn assert_message_roundtrip(original: Message) {
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: Message = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(original, decoded);
    }

    // ── session_id ──────────────────────────────────────────────────────────

    #[test]
    fn test_message_system_session_id() {
        assert_eq!(system_msg("sess-abc").session_id(), Some("sess-abc"));
    }

    #[test]
    fn test_message_result_session_id() {
        assert_eq!(result_msg("sess-xyz", false).session_id(), Some("sess-xyz"));
    }

    #[test]
    fn test_message_stream_event_session_id() {
        assert_eq!(stream_event_msg("sess-ev").session_id(), Some("sess-ev"));
    }

    // ── is_error_result ─────────────────────────────────────────────────────

    #[test]
    fn test_message_is_error_result_true() {
        let failed = result_msg("s1", true);
        assert!(failed.is_error_result(), "is_error=true must return true");
    }

    #[test]
    fn test_message_is_error_result_false() {
        let succeeded = result_msg("s1", false);
        assert!(
            !succeeded.is_error_result(),
            "is_error=false must return false"
        );
    }

    #[test]
    fn test_message_is_error_result_non_result_variant() {
        let other = system_msg("s1");
        assert!(
            !other.is_error_result(),
            "non-Result variant must return false"
        );
    }

    // ── is_stream_event ─────────────────────────────────────────────────────

    #[test]
    fn test_message_is_stream_event() {
        assert!(stream_event_msg("s1").is_stream_event());
    }

    #[test]
    fn test_message_is_stream_event_false_for_system() {
        assert!(!system_msg("s1").is_stream_event());
    }

    // ── assistant_text ──────────────────────────────────────────────────────

    #[test]
    fn test_message_assistant_text_single() {
        let blocks = vec![ContentBlock::Text(TextBlock::new("hello world"))];
        let msg = assistant_msg("s1", blocks);
        assert_eq!(msg.assistant_text(), Some(String::from("hello world")));
    }

    #[test]
    fn test_message_assistant_text_multiple_blocks_concatenated() {
        let blocks = vec![
            ContentBlock::Text(TextBlock::new("foo")),
            ContentBlock::Text(TextBlock::new("bar")),
        ];
        let msg = assistant_msg("s1", blocks);
        assert_eq!(msg.assistant_text(), Some(String::from("foobar")));
    }

    #[test]
    fn test_message_assistant_text_empty() {
        let msg = assistant_msg("s1", Vec::new());
        assert_eq!(
            msg.assistant_text(),
            None,
            "no content blocks must yield None"
        );
    }

    #[test]
    fn test_message_assistant_text_non_text_blocks_only() {
        let blocks = vec![ContentBlock::Thinking(ThinkingBlock::new("reasoning..."))];
        let msg = assistant_msg("s1", blocks);
        assert_eq!(
            msg.assistant_text(),
            None,
            "no Text blocks must yield None"
        );
    }

    #[test]
    fn test_message_assistant_text_non_assistant_variant() {
        assert_eq!(system_msg("s1").assistant_text(), None);
    }

    // ── Usage ───────────────────────────────────────────────────────────────

    #[test]
    fn test_usage_default() {
        // Destructure exhaustively so a newly added counter cannot be missed.
        let Usage {
            input_tokens,
            output_tokens,
            cache_read_input_tokens,
            cache_creation_input_tokens,
            thought_tokens,
        } = Usage::default();
        assert_eq!(input_tokens, 0);
        assert_eq!(output_tokens, 0);
        assert_eq!(cache_read_input_tokens, 0);
        assert_eq!(cache_creation_input_tokens, 0);
        assert_eq!(thought_tokens, 0);
    }

    // ── Serde round-trips ───────────────────────────────────────────────────

    #[test]
    fn test_message_serde_roundtrip_system() {
        let server = McpServerStatus {
            name: String::from("filesystem"),
            status: String::from("connected"),
            extra: no_extra(),
        };
        assert_message_roundtrip(Message::System(SystemMessage {
            subtype: String::from("init"),
            session_id: String::from("sess-roundtrip"),
            cwd: String::from("/workspace"),
            tools: vec![String::from("bash"), String::from("read_file")],
            mcp_servers: vec![server],
            model: String::from("gemini-2.5-pro"),
            extra: no_extra(),
        }));
    }

    #[test]
    fn test_message_serde_roundtrip_result() {
        let usage = Usage {
            input_tokens: 512,
            output_tokens: 128,
            cache_read_input_tokens: 64,
            cache_creation_input_tokens: 32,
            thought_tokens: 256,
        };
        assert_message_roundtrip(Message::Result(ResultMessage {
            subtype: String::from("success"),
            is_error: false,
            duration_ms: 450.75,
            duration_api_ms: 400.0,
            num_turns: 3,
            session_id: String::from("sess-rt2"),
            usage,
            stop_reason: String::from("end_turn"),
            extra: no_extra(),
        }));
    }

    #[test]
    fn test_message_serde_roundtrip_stream_event() {
        assert_message_roundtrip(Message::StreamEvent(StreamEvent {
            event_type: String::from("plan_update"),
            data: json!({ "step": 1, "action": "read_file" }),
            session_id: String::from("sess-rt3"),
        }));
    }

    // ── PlanEntry ───────────────────────────────────────────────────────────

    #[test]
    fn test_plan_entry_defaults() {
        let parsed: PlanEntry =
            serde_json::from_str("{}").expect("empty object must deserialize via defaults");
        assert!(parsed.content.is_empty());
        assert!(parsed.priority.is_empty());
        assert!(parsed.status.is_empty());
    }

    #[test]
    fn test_plan_entry_roundtrip() {
        let original = PlanEntry {
            content: String::from("Analyze the repository structure"),
            priority: String::from("high"),
            status: String::from("pending"),
            extra: no_extra(),
        };

        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: PlanEntry = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(original, decoded);
    }

    #[test]
    fn test_usage_thought_tokens_roundtrip() {
        let original = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 0,
            thought_tokens: 300,
        };
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: Usage = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.thought_tokens, 300);
    }
}