1use serde::{Deserialize, Serialize};
28
29use crate::types::content::ContentBlock;
30
/// One message of the newline-delimited JSON stream, discriminated by its
/// `"type"` field.
///
/// Tag values are the snake_case variant names: `"system"`, `"assistant"`,
/// `"user"`, `"result"` and `"stream_event"`. Deserializing a line whose
/// `"type"` is none of these fails.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Message {
    /// `"type": "system"` — see [`SystemMessage`].
    System(SystemMessage),
    /// `"type": "assistant"` — see [`AssistantMessage`].
    Assistant(AssistantMessage),
    /// `"type": "user"` — see [`UserMessage`].
    User(UserMessage),
    /// `"type": "result"` — see [`ResultMessage`].
    Result(ResultMessage),
    /// `"type": "stream_event"` — see [`StreamEvent`]. The explicit rename
    /// spells out the same tag `rename_all = "snake_case"` would produce.
    #[serde(rename = "stream_event")]
    StreamEvent(StreamEvent),
}
53
54impl Message {
55 #[must_use]
57 pub fn session_id(&self) -> Option<&str> {
58 match self {
59 Self::System(m) => Some(&m.session_id),
60 Self::Assistant(m) => m.session_id.as_deref(),
61 Self::User(_) => None,
62 Self::Result(m) => m.session_id.as_deref(),
63 Self::StreamEvent(m) => Some(&m.session_id),
64 }
65 }
66
67 #[inline]
69 #[must_use]
70 pub fn is_error_result(&self) -> bool {
71 matches!(self, Self::Result(r) if r.is_error)
72 }
73
74 #[inline]
76 #[must_use]
77 pub fn is_stream_event(&self) -> bool {
78 matches!(self, Self::StreamEvent(_))
79 }
80
81 #[must_use]
84 pub fn assistant_text(&self) -> Option<String> {
85 if let Self::Assistant(a) = self {
86 let text: String = a
87 .message
88 .content
89 .iter()
90 .filter_map(|b| b.as_text())
91 .collect::<Vec<_>>()
92 .join("\n");
93 if text.is_empty() { None } else { Some(text) }
94 } else {
95 None
96 }
97 }
98}
99
/// Payload of a `"type": "system"` message.
///
/// Every named field falls back to its `Default` when the key is absent, so
/// partial messages such as `{"type":"system","subtype":"status"}` still parse.
/// Keys not listed here are kept in `extra`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SystemMessage {
    /// Kind of system message. `SessionInfo::try_from` accepts only `"init"`.
    #[serde(default)]
    pub subtype: String,

    /// Session identifier; empty string when absent.
    #[serde(default)]
    pub session_id: String,

    /// Working directory reported by the sender (presumably the producer's
    /// cwd — TODO confirm); empty string when absent.
    #[serde(default)]
    pub cwd: String,

    /// Tool names listed in the message; empty when absent.
    #[serde(default)]
    pub tools: Vec<String>,

    /// Per-server name/status entries; empty when absent.
    #[serde(default)]
    pub mcp_servers: Vec<McpServerStatus>,

    /// Model identifier string (e.g. `"claude-opus-4-5"`); empty when absent.
    #[serde(default)]
    pub model: String,

    /// Top-level keys not matched by the fields above, collected via
    /// `#[serde(flatten)]` so newer producers' fields are not lost.
    #[serde(flatten)]
    pub extra: serde_json::Value,
}
134
/// Name and status of one MCP server, as listed in a system message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct McpServerStatus {
    /// Server name. Required: an entry without it fails to deserialize.
    pub name: String,

    /// Status string (e.g. `"connected"`); empty when absent.
    #[serde(default)]
    pub status: String,
}
145
/// Payload of a `"type": "assistant"` message.
///
/// The model's reply is nested under the required `message` key. Other
/// top-level keys besides `session_id` are kept in `extra`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AssistantMessage {
    /// The nested reply body (id, content blocks, model, usage).
    pub message: AssistantMessageInner,

    /// Session identifier, if present. Serialized as `null` when `None`.
    #[serde(default)]
    pub session_id: Option<String>,

    /// Unrecognized top-level keys, collected via `#[serde(flatten)]`.
    #[serde(flatten)]
    pub extra: serde_json::Value,
}
162
/// Body of an assistant reply, found under the `message` key.
///
/// All fields default when absent. Unknown keys here are ignored, because
/// this struct has no flattened `extra`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AssistantMessageInner {
    /// Message identifier (e.g. `"msg_001"`); empty when absent.
    #[serde(default)]
    pub id: String,

    /// Content blocks in order. `Message::assistant_text` joins the text ones.
    #[serde(default)]
    pub content: Vec<ContentBlock>,

    /// Model identifier; empty when absent.
    #[serde(default)]
    pub model: String,

    /// Why generation stopped (e.g. `"end_turn"`), if reported.
    #[serde(default)]
    pub stop_reason: Option<String>,

    /// Token accounting for this reply; all-zero when absent.
    #[serde(default)]
    pub usage: Usage,
}
187
/// Payload of a `"type": "user"` message.
///
/// There is no dedicated `session_id` field. Every top-level key other than
/// `message` is collected into `extra`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UserMessage {
    /// The nested user turn (role + content). Required.
    pub message: UserMessageInner,

    /// Remaining top-level keys, collected via `#[serde(flatten)]`.
    #[serde(flatten)]
    pub extra: serde_json::Value,
}
200
/// Body of a user turn, found under the `message` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UserMessageInner {
    /// Role string (e.g. `"user"`); empty when absent.
    #[serde(default)]
    pub role: String,

    /// Content kept as raw JSON, `Null` when absent. Tests exercise a plain
    /// string; other shapes (presumably arrays of blocks — TODO confirm) are
    /// accepted without being parsed.
    #[serde(default)]
    pub content: serde_json::Value,
}
213
/// Payload of a `"type": "result"` message.
///
/// Every named field defaults when absent. Unrecognized keys are kept in
/// `extra`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ResultMessage {
    /// Result kind (e.g. `"success"`); empty when absent.
    #[serde(default)]
    pub subtype: String,

    /// Cost figure in USD, if reported. NOTE(review): how it differs from
    /// `total_cost_usd` is not visible here — confirm against the producer.
    #[serde(default)]
    pub cost_usd: Option<f64>,

    /// Duration in milliseconds (per the field name); 0 when absent.
    #[serde(default)]
    pub duration_ms: u64,

    /// API-side duration in milliseconds (per the field name); 0 when absent.
    #[serde(default)]
    pub duration_api_ms: u64,

    /// Error flag, read by `Message::is_error_result`; `false` when absent.
    #[serde(default)]
    pub is_error: bool,

    /// Number of turns reported; 0 when absent.
    #[serde(default)]
    pub num_turns: u32,

    /// Session identifier, if present.
    #[serde(default)]
    pub session_id: Option<String>,

    /// Total cost in USD, if reported.
    #[serde(default)]
    pub total_cost_usd: Option<f64>,

    /// Token accounting; all-zero when absent.
    #[serde(default)]
    pub usage: Usage,

    /// Optional result text (e.g. `"Task complete."`).
    #[serde(default)]
    pub result: Option<String>,

    /// Unrecognized top-level keys, collected via `#[serde(flatten)]`.
    #[serde(flatten)]
    pub extra: serde_json::Value,
}
265
/// Payload of a `"type": "stream_event"` message.
///
/// `uuid`, `session_id` and `event` are required. Unrecognized keys are kept
/// in `extra`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StreamEvent {
    /// Identifier of this event (e.g. `"evt-001"`).
    pub uuid: String,

    /// Session the event belongs to.
    pub session_id: String,

    /// Event body, kept as opaque JSON.
    pub event: serde_json::Value,

    /// Tool-use id this event is nested under, if any. The key is omitted
    /// entirely on serialization when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parent_tool_use_id: Option<String>,

    /// Unrecognized top-level keys, collected via `#[serde(flatten)]`.
    #[serde(flatten)]
    pub extra: serde_json::Value,
}
291
292#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
299pub struct Usage {
300 #[serde(default)]
302 pub input_tokens: u32,
303
304 #[serde(default)]
306 pub output_tokens: u32,
307
308 #[serde(default)]
310 pub cache_read_input_tokens: u32,
311
312 #[serde(default)]
314 pub cache_creation_input_tokens: u32,
315}
316
317impl Usage {
318 #[inline]
320 #[must_use]
321 pub fn total_tokens(&self) -> u32 {
322 self.input_tokens.saturating_add(self.output_tokens)
323 }
324
325 #[inline]
327 pub fn merge(&mut self, other: &Self) {
328 self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
329 self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
330 self.cache_read_input_tokens = self
331 .cache_read_input_tokens
332 .saturating_add(other.cache_read_input_tokens);
333 self.cache_creation_input_tokens = self
334 .cache_creation_input_tokens
335 .saturating_add(other.cache_creation_input_tokens);
336 }
337}
338
/// Session metadata copied out of a `system` message with subtype `"init"`.
///
/// Built via `SessionInfo::try_from(&SystemMessage)`. Each field is a clone
/// of the same-named `SystemMessage` field.
#[derive(Debug, Clone, PartialEq)]
pub struct SessionInfo {
    /// Session identifier from the init message.
    pub session_id: String,
    /// Working directory reported in the init message.
    pub cwd: String,
    /// Tool names listed in the init message.
    pub tools: Vec<String>,
    /// MCP server name/status entries from the init message.
    pub mcp_servers: Vec<McpServerStatus>,
    /// Model identifier from the init message.
    pub model: String,
}
356
357impl TryFrom<&SystemMessage> for SessionInfo {
358 type Error = crate::errors::Error;
359
360 fn try_from(msg: &SystemMessage) -> crate::errors::Result<Self> {
361 if msg.subtype != "init" {
362 return Err(crate::errors::Error::ControlProtocol(format!(
363 "expected system/init message, got subtype '{}'",
364 msg.subtype
365 )));
366 }
367 Ok(Self {
368 session_id: msg.session_id.clone(),
369 cwd: msg.cwd.clone(),
370 tools: msg.tools.clone(),
371 mcp_servers: msg.mcp_servers.clone(),
372 model: msg.model.clone(),
373 })
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::content::TextBlock;

    /// Serializes `value` to JSON text and parses it back, so a test can
    /// assert the decoded value equals the original.
    fn round_trip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de> + std::fmt::Debug + PartialEq,
    {
        let json = serde_json::to_string(value).expect("serialize");
        serde_json::from_str(&json).expect("deserialize")
    }

    // ---- SystemMessage ----

    #[test]
    fn system_message_round_trip() {
        let msg = Message::System(SystemMessage {
            subtype: "init".into(),
            session_id: "sess-1".into(),
            cwd: "/home/user".into(),
            tools: vec!["bash".into(), "read_file".into()],
            mcp_servers: vec![McpServerStatus {
                name: "my-server".into(),
                status: "connected".into(),
            }],
            model: "claude-opus-4-5".into(),
            extra: serde_json::Value::Object(Default::default()),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    #[test]
    fn system_message_from_ndjson() {
        let line = r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/tmp","tools":["bash"],"mcp_servers":[],"model":"claude-opus-4-5"}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match msg {
            Message::System(s) => {
                assert_eq!(s.subtype, "init");
                assert_eq!(s.session_id, "s1");
                assert_eq!(s.cwd, "/tmp");
                assert_eq!(s.tools, ["bash"]);
                assert_eq!(s.model, "claude-opus-4-5");
            }
            other => panic!("expected System, got {other:?}"),
        }
    }

    #[test]
    fn system_message_missing_fields_use_defaults() {
        let line = r#"{"type":"system","subtype":"status"}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        if let Message::System(s) = msg {
            assert_eq!(s.session_id, "");
            assert!(s.tools.is_empty());
        } else {
            panic!("expected System");
        }
    }

    #[test]
    fn system_message_extra_fields_preserved() {
        let line = r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m","future_field":"value"}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        // Fail loudly on an unexpected variant. An `if let` without `else`
        // would skip the assertion and let the test pass vacuously.
        match msg {
            Message::System(s) => assert_eq!(s.extra["future_field"], "value"),
            other => panic!("expected System, got {other:?}"),
        }
    }

    // ---- AssistantMessage ----

    #[test]
    fn assistant_message_round_trip() {
        let msg = Message::Assistant(AssistantMessage {
            message: AssistantMessageInner {
                id: "msg_001".into(),
                content: vec![ContentBlock::Text(TextBlock {
                    text: "Hello!".into(),
                })],
                model: "claude-opus-4-5".into(),
                stop_reason: Some("end_turn".into()),
                usage: Usage {
                    input_tokens: 10,
                    output_tokens: 5,
                    cache_read_input_tokens: 0,
                    cache_creation_input_tokens: 0,
                },
            },
            session_id: Some("sess-1".into()),
            extra: serde_json::Value::Object(Default::default()),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    #[test]
    fn assistant_message_from_ndjson() {
        let line = r#"{"type":"assistant","message":{"id":"m1","content":[{"type":"text","text":"Hi!"}],"model":"claude-opus-4-5","stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":3}},"session_id":"s1"}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match &msg {
            Message::Assistant(a) => {
                assert_eq!(a.message.id, "m1");
                assert_eq!(a.message.content.len(), 1);
                assert_eq!(a.message.usage.input_tokens, 5);
            }
            other => panic!("expected Assistant, got {other:?}"),
        }
    }

    #[test]
    fn assistant_message_text_helper() {
        let msg = Message::Assistant(AssistantMessage {
            message: AssistantMessageInner {
                id: "x".into(),
                content: vec![
                    ContentBlock::Text(TextBlock {
                        text: "line one".into(),
                    }),
                    ContentBlock::Text(TextBlock {
                        text: "line two".into(),
                    }),
                ],
                model: String::new(),
                stop_reason: None,
                usage: Usage::default(),
            },
            session_id: None,
            extra: serde_json::Value::Object(Default::default()),
        });
        assert_eq!(msg.assistant_text(), Some("line one\nline two".into()));
    }

    #[test]
    fn non_assistant_message_text_helper_returns_none() {
        let msg = Message::System(SystemMessage {
            subtype: "init".into(),
            session_id: String::new(),
            cwd: String::new(),
            tools: vec![],
            mcp_servers: vec![],
            model: String::new(),
            extra: serde_json::Value::Object(Default::default()),
        });
        assert_eq!(msg.assistant_text(), None);
    }

    // ---- UserMessage ----

    #[test]
    fn user_message_round_trip() {
        let msg = Message::User(UserMessage {
            message: UserMessageInner {
                role: "user".into(),
                content: serde_json::json!("What is Rust?"),
            },
            extra: serde_json::Value::Object(Default::default()),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    #[test]
    fn user_message_from_ndjson() {
        let line = r#"{"type":"user","message":{"role":"user","content":"hello"}}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        assert!(matches!(msg, Message::User(_)));
    }

    // ---- ResultMessage ----

    #[test]
    fn result_message_round_trip() {
        let msg = Message::Result(ResultMessage {
            subtype: "success".into(),
            cost_usd: Some(0.0042),
            duration_ms: 3500,
            duration_api_ms: 3100,
            is_error: false,
            num_turns: 2,
            session_id: Some("sess-1".into()),
            total_cost_usd: Some(0.0042),
            usage: Usage {
                input_tokens: 100,
                output_tokens: 50,
                cache_read_input_tokens: 20,
                cache_creation_input_tokens: 5,
            },
            result: Some("Task complete.".into()),
            extra: serde_json::Value::Object(Default::default()),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    #[test]
    fn result_message_from_ndjson() {
        let line = r#"{"type":"result","subtype":"success","cost_usd":0.01,"duration_ms":1000,"duration_api_ms":900,"is_error":false,"num_turns":1,"session_id":"s1","usage":{"input_tokens":50,"output_tokens":20}}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match msg {
            Message::Result(r) => {
                assert_eq!(r.subtype, "success");
                assert!(!r.is_error);
                assert_eq!(r.num_turns, 1);
            }
            other => panic!("expected Result, got {other:?}"),
        }
    }

    #[test]
    fn result_message_is_error_flag() {
        let r = ResultMessage {
            subtype: "error".into(),
            cost_usd: None,
            duration_ms: 0,
            duration_api_ms: 0,
            is_error: true,
            num_turns: 0,
            session_id: None,
            total_cost_usd: None,
            usage: Usage::default(),
            result: None,
            extra: serde_json::Value::Object(Default::default()),
        };
        let msg = Message::Result(r);
        assert!(msg.is_error_result());
    }

    #[test]
    fn is_error_result_false_for_non_result() {
        let msg = Message::System(SystemMessage {
            subtype: "init".into(),
            session_id: String::new(),
            cwd: String::new(),
            tools: vec![],
            mcp_servers: vec![],
            model: String::new(),
            extra: serde_json::Value::Object(Default::default()),
        });
        assert!(!msg.is_error_result());
    }

    // ---- Usage ----

    #[test]
    fn usage_default_is_zero() {
        let u = Usage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
        assert_eq!(u.cache_read_input_tokens, 0);
        assert_eq!(u.cache_creation_input_tokens, 0);
    }

    #[test]
    fn usage_total_tokens() {
        let u = Usage {
            input_tokens: 10,
            output_tokens: 20,
            ..Default::default()
        };
        assert_eq!(u.total_tokens(), 30);
    }

    #[test]
    fn usage_total_tokens_saturates_on_overflow() {
        let u = Usage {
            input_tokens: u32::MAX,
            output_tokens: 1,
            ..Default::default()
        };
        assert_eq!(u.total_tokens(), u32::MAX);
    }

    #[test]
    fn usage_merge() {
        let mut a = Usage {
            input_tokens: 10,
            output_tokens: 5,
            cache_read_input_tokens: 2,
            cache_creation_input_tokens: 1,
        };
        let b = Usage {
            input_tokens: 3,
            output_tokens: 7,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 4,
        };
        a.merge(&b);
        assert_eq!(a.input_tokens, 13);
        assert_eq!(a.output_tokens, 12);
        assert_eq!(a.cache_read_input_tokens, 2);
        assert_eq!(a.cache_creation_input_tokens, 5);
    }

    #[test]
    fn usage_round_trip() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 200,
            cache_read_input_tokens: 50,
            cache_creation_input_tokens: 10,
        };
        let json = serde_json::to_string(&u).unwrap();
        let decoded: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(u, decoded);
    }

    // ---- SessionInfo ----

    #[test]
    fn session_info_from_init_message() {
        let sys = SystemMessage {
            subtype: "init".into(),
            session_id: "s42".into(),
            cwd: "/workspace".into(),
            tools: vec!["bash".into()],
            mcp_servers: vec![],
            model: "claude-opus-4-5".into(),
            extra: serde_json::Value::Object(Default::default()),
        };
        let info = SessionInfo::try_from(&sys).unwrap();
        assert_eq!(info.session_id, "s42");
        assert_eq!(info.cwd, "/workspace");
        assert_eq!(info.tools, ["bash"]);
    }

    #[test]
    fn session_info_rejects_non_init_subtype() {
        let sys = SystemMessage {
            subtype: "status".into(),
            session_id: "s1".into(),
            cwd: "/".into(),
            tools: vec![],
            mcp_servers: vec![],
            model: String::new(),
            extra: serde_json::Value::Object(Default::default()),
        };
        let err = SessionInfo::try_from(&sys).unwrap_err();
        assert!(
            matches!(err, crate::errors::Error::ControlProtocol(_)),
            "expected ControlProtocol error, got {err:?}"
        );
    }

    // ---- Message::session_id ----

    #[test]
    fn message_session_id_system() {
        let msg = Message::System(SystemMessage {
            subtype: "init".into(),
            session_id: "s1".into(),
            cwd: String::new(),
            tools: vec![],
            mcp_servers: vec![],
            model: String::new(),
            extra: serde_json::Value::Object(Default::default()),
        });
        assert_eq!(msg.session_id(), Some("s1"));
    }

    #[test]
    fn message_session_id_result() {
        let msg = Message::Result(ResultMessage {
            subtype: String::new(),
            cost_usd: None,
            duration_ms: 0,
            duration_api_ms: 0,
            is_error: false,
            num_turns: 0,
            session_id: Some("s99".into()),
            total_cost_usd: None,
            usage: Usage::default(),
            result: None,
            extra: serde_json::Value::Object(Default::default()),
        });
        assert_eq!(msg.session_id(), Some("s99"));
    }

    #[test]
    fn message_session_id_user_is_none() {
        let msg = Message::User(UserMessage {
            message: UserMessageInner {
                role: "user".into(),
                content: serde_json::Value::Null,
            },
            extra: serde_json::Value::Object(Default::default()),
        });
        assert_eq!(msg.session_id(), None);
    }

    // ---- StreamEvent ----

    #[test]
    fn stream_event_from_ndjson() {
        let line = r#"{"type":"stream_event","uuid":"evt-001","session_id":"s1","event":{"kind":"progress","percent":50}}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match &msg {
            Message::StreamEvent(e) => {
                assert_eq!(e.uuid, "evt-001");
                assert_eq!(e.session_id, "s1");
                assert_eq!(e.event["kind"], "progress");
                assert_eq!(e.event["percent"], 50);
                assert!(e.parent_tool_use_id.is_none());
            }
            other => panic!("expected StreamEvent, got {other:?}"),
        }
        assert!(msg.is_stream_event());
        assert_eq!(msg.session_id(), Some("s1"));
    }

    #[test]
    fn stream_event_with_parent_tool_use_id() {
        let line = r#"{"type":"stream_event","uuid":"evt-002","session_id":"s1","event":{},"parent_tool_use_id":"tu_123"}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        if let Message::StreamEvent(e) = &msg {
            assert_eq!(e.parent_tool_use_id.as_deref(), Some("tu_123"));
        } else {
            panic!("expected StreamEvent");
        }
    }

    #[test]
    fn stream_event_round_trip() {
        let msg = Message::StreamEvent(StreamEvent {
            uuid: "evt-003".into(),
            session_id: "s2".into(),
            event: serde_json::json!({"status": "done"}),
            parent_tool_use_id: Some("tu_456".into()),
            extra: serde_json::Value::Object(Default::default()),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    #[test]
    fn is_stream_event_false_for_other() {
        let msg = Message::System(SystemMessage {
            subtype: "init".into(),
            session_id: String::new(),
            cwd: String::new(),
            tools: vec![],
            mcp_servers: vec![],
            model: String::new(),
            extra: serde_json::Value::Object(Default::default()),
        });
        assert!(!msg.is_stream_event());
    }
}