1use serde::{Deserialize, Serialize};
28
29use crate::types::content::ContentBlock;
30
31#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
39#[serde(tag = "type", rename_all = "snake_case")]
40pub enum Message {
41 System(SystemMessage),
43 Assistant(AssistantMessage),
45 User(UserMessage),
47 Result(ResultMessage),
49 #[serde(rename = "stream_event")]
51 StreamEvent(StreamEvent),
52 #[serde(rename = "rate_limit_event")]
54 RateLimitEvent(RateLimitEvent),
55}
56
57impl Message {
58 #[must_use]
60 pub fn session_id(&self) -> Option<&str> {
61 match self {
62 Self::System(m) => Some(&m.session_id),
63 Self::Assistant(m) => m.session_id.as_deref(),
64 Self::User(_) => None,
65 Self::Result(m) => m.session_id.as_deref(),
66 Self::StreamEvent(m) => Some(&m.session_id),
67 Self::RateLimitEvent(m) => m.session_id.as_deref(),
68 }
69 }
70
71 #[inline]
73 #[must_use]
74 pub fn is_error_result(&self) -> bool {
75 matches!(self, Self::Result(r) if r.is_error)
76 }
77
78 #[inline]
80 #[must_use]
81 pub fn is_stream_event(&self) -> bool {
82 matches!(self, Self::StreamEvent(_))
83 }
84
85 #[must_use]
88 pub fn assistant_text(&self) -> Option<String> {
89 if let Self::Assistant(a) = self {
90 let text: String = a
91 .message
92 .content
93 .iter()
94 .filter_map(|b| b.as_text())
95 .collect::<Vec<_>>()
96 .join("\n");
97 if text.is_empty() { None } else { Some(text) }
98 } else {
99 None
100 }
101 }
102}
103
104#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
109pub struct SystemMessage {
110 #[serde(default)]
112 pub subtype: String,
113
114 #[serde(default)]
116 pub session_id: String,
117
118 #[serde(default)]
120 pub cwd: String,
121
122 #[serde(default)]
124 pub tools: Vec<String>,
125
126 #[serde(default)]
128 pub mcp_servers: Vec<McpServerStatus>,
129
130 #[serde(default)]
132 pub model: String,
133
134 #[serde(flatten)]
136 pub extra: serde_json::Value,
137}
138
139#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
141pub struct McpServerStatus {
142 pub name: String,
144
145 #[serde(default)]
147 pub status: String,
148}
149
150#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
154pub struct AssistantMessage {
155 pub message: AssistantMessageInner,
157
158 #[serde(default)]
160 pub session_id: Option<String>,
161
162 #[serde(flatten)]
164 pub extra: serde_json::Value,
165}
166
167#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
170pub struct AssistantMessageInner {
171 #[serde(default)]
173 pub id: String,
174
175 #[serde(default)]
177 pub content: Vec<ContentBlock>,
178
179 #[serde(default)]
181 pub model: String,
182
183 #[serde(default)]
185 pub stop_reason: Option<String>,
186
187 #[serde(default)]
189 pub usage: Usage,
190}
191
192#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
196pub struct UserMessage {
197 pub message: UserMessageInner,
199
200 #[serde(flatten)]
202 pub extra: serde_json::Value,
203}
204
205#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
207pub struct UserMessageInner {
208 #[serde(default)]
210 pub role: String,
211
212 #[serde(default)]
215 pub content: serde_json::Value,
216}
217
218#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
224pub struct ResultMessage {
225 #[serde(default)]
227 pub subtype: String,
228
229 #[serde(default)]
231 pub cost_usd: Option<f64>,
232
233 #[serde(default)]
235 pub duration_ms: u64,
236
237 #[serde(default)]
239 pub duration_api_ms: u64,
240
241 #[serde(default)]
243 pub is_error: bool,
244
245 #[serde(default)]
247 pub num_turns: u32,
248
249 #[serde(default)]
251 pub session_id: Option<String>,
252
253 #[serde(default)]
255 pub total_cost_usd: Option<f64>,
256
257 #[serde(default)]
259 pub usage: Usage,
260
261 #[serde(default)]
263 pub result: Option<String>,
264
265 #[serde(flatten)]
267 pub extra: serde_json::Value,
268}
269
270#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
277pub struct StreamEvent {
278 pub uuid: String,
280
281 pub session_id: String,
283
284 pub event: serde_json::Value,
286
287 #[serde(default, skip_serializing_if = "Option::is_none")]
289 pub parent_tool_use_id: Option<String>,
290
291 #[serde(flatten)]
293 pub extra: serde_json::Value,
294}
295
296#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
300pub struct RateLimitEvent {
301 #[serde(default, skip_serializing_if = "Option::is_none")]
303 pub session_id: Option<String>,
304
305 #[serde(flatten)]
307 pub extra: serde_json::Value,
308}
309
310#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
317pub struct Usage {
318 #[serde(default)]
320 pub input_tokens: u32,
321
322 #[serde(default)]
324 pub output_tokens: u32,
325
326 #[serde(default)]
328 pub cache_read_input_tokens: u32,
329
330 #[serde(default)]
332 pub cache_creation_input_tokens: u32,
333}
334
335impl Usage {
336 #[inline]
338 #[must_use]
339 pub fn total_tokens(&self) -> u32 {
340 self.input_tokens.saturating_add(self.output_tokens)
341 }
342
343 #[inline]
345 pub fn merge(&mut self, other: &Self) {
346 self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
347 self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
348 self.cache_read_input_tokens = self
349 .cache_read_input_tokens
350 .saturating_add(other.cache_read_input_tokens);
351 self.cache_creation_input_tokens = self
352 .cache_creation_input_tokens
353 .saturating_add(other.cache_creation_input_tokens);
354 }
355}
356
357#[derive(Debug, Clone, PartialEq)]
362pub struct SessionInfo {
363 pub session_id: String,
365 pub cwd: String,
367 pub tools: Vec<String>,
369 pub mcp_servers: Vec<McpServerStatus>,
371 pub model: String,
373}
374
375impl TryFrom<&SystemMessage> for SessionInfo {
376 type Error = crate::errors::Error;
377
378 fn try_from(msg: &SystemMessage) -> crate::errors::Result<Self> {
379 if msg.subtype != "init" {
380 return Err(crate::errors::Error::ControlProtocol(format!(
381 "expected system/init message, got subtype '{}'",
382 msg.subtype
383 )));
384 }
385 Ok(Self {
386 session_id: msg.session_id.clone(),
387 cwd: msg.cwd.clone(),
388 tools: msg.tools.clone(),
389 mcp_servers: msg.mcp_servers.clone(),
390 model: msg.model.clone(),
391 })
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::content::TextBlock;

    /// Serializes `value` to JSON and parses it back.
    fn round_trip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de> + std::fmt::Debug + PartialEq,
    {
        let json = serde_json::to_string(value).expect("serialize");
        serde_json::from_str(&json).expect("deserialize")
    }

    /// The value a flattened `extra` field holds when there are no unknown keys.
    fn empty_extra() -> serde_json::Value {
        serde_json::Value::Object(serde_json::Map::new())
    }

    /// A `SystemMessage` with the given subtype and session id; every other field empty.
    fn system_message(subtype: &str, session_id: &str) -> SystemMessage {
        SystemMessage {
            subtype: subtype.into(),
            session_id: session_id.into(),
            cwd: String::new(),
            tools: vec![],
            mcp_servers: vec![],
            model: String::new(),
            extra: empty_extra(),
        }
    }

    /// A `ResultMessage` with zeroed counters and no optional fields set.
    fn empty_result() -> ResultMessage {
        ResultMessage {
            subtype: String::new(),
            cost_usd: None,
            duration_ms: 0,
            duration_api_ms: 0,
            is_error: false,
            num_turns: 0,
            session_id: None,
            total_cost_usd: None,
            usage: Usage::default(),
            result: None,
            extra: empty_extra(),
        }
    }

    // ---- System ---------------------------------------------------------

    #[test]
    fn system_message_round_trip() {
        let msg = Message::System(SystemMessage {
            subtype: "init".into(),
            session_id: "sess-1".into(),
            cwd: "/home/user".into(),
            tools: vec!["bash".into(), "read_file".into()],
            mcp_servers: vec![McpServerStatus {
                name: "my-server".into(),
                status: "connected".into(),
            }],
            model: "claude-opus-4-5".into(),
            extra: empty_extra(),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    #[test]
    fn system_message_from_ndjson() {
        let line = r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/tmp","tools":["bash"],"mcp_servers":[],"model":"claude-opus-4-5"}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match msg {
            Message::System(s) => {
                assert_eq!(s.subtype, "init");
                assert_eq!(s.session_id, "s1");
                assert_eq!(s.cwd, "/tmp");
                assert_eq!(s.tools, ["bash"]);
                assert!(s.mcp_servers.is_empty());
                assert_eq!(s.model, "claude-opus-4-5");
            }
            other => panic!("expected System, got {other:?}"),
        }
    }

    #[test]
    fn system_message_missing_fields_use_defaults() {
        let line = r#"{"type":"system","subtype":"status"}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match msg {
            Message::System(s) => {
                assert_eq!(s.subtype, "status");
                assert_eq!(s.session_id, "");
                assert_eq!(s.cwd, "");
                assert!(s.tools.is_empty());
                assert!(s.mcp_servers.is_empty());
                assert_eq!(s.model, "");
            }
            other => panic!("expected System, got {other:?}"),
        }
    }

    #[test]
    fn system_message_extra_fields_preserved() {
        let line = r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m","future_field":"value"}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        // `match` with a panicking fallback: the previous `if let` passed vacuously
        // if the line ever deserialized into a different variant.
        match msg {
            Message::System(s) => {
                assert_eq!(s.extra["future_field"], "value");
                // Known keys are consumed by their typed fields, not duplicated in `extra`.
                assert!(s.extra.get("session_id").is_none());
            }
            other => panic!("expected System, got {other:?}"),
        }
    }

    // ---- Assistant ------------------------------------------------------

    #[test]
    fn assistant_message_round_trip() {
        let msg = Message::Assistant(AssistantMessage {
            message: AssistantMessageInner {
                id: "msg_001".into(),
                content: vec![ContentBlock::Text(TextBlock {
                    text: "Hello!".into(),
                })],
                model: "claude-opus-4-5".into(),
                stop_reason: Some("end_turn".into()),
                usage: Usage {
                    input_tokens: 10,
                    output_tokens: 5,
                    cache_read_input_tokens: 0,
                    cache_creation_input_tokens: 0,
                },
            },
            session_id: Some("sess-1".into()),
            extra: empty_extra(),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    #[test]
    fn assistant_message_from_ndjson() {
        let line = r#"{"type":"assistant","message":{"id":"m1","content":[{"type":"text","text":"Hi!"}],"model":"claude-opus-4-5","stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":3}},"session_id":"s1"}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match &msg {
            Message::Assistant(a) => {
                assert_eq!(a.message.id, "m1");
                assert_eq!(a.message.content.len(), 1);
                assert_eq!(a.message.stop_reason.as_deref(), Some("end_turn"));
                assert_eq!(a.message.usage.input_tokens, 5);
                assert_eq!(a.message.usage.output_tokens, 3);
                assert_eq!(a.message.usage.cache_read_input_tokens, 0);
            }
            other => panic!("expected Assistant, got {other:?}"),
        }
        assert_eq!(msg.session_id(), Some("s1"));
    }

    #[test]
    fn assistant_message_text_helper() {
        let msg = Message::Assistant(AssistantMessage {
            message: AssistantMessageInner {
                id: "x".into(),
                content: vec![
                    ContentBlock::Text(TextBlock {
                        text: "line one".into(),
                    }),
                    ContentBlock::Text(TextBlock {
                        text: "line two".into(),
                    }),
                ],
                model: String::new(),
                stop_reason: None,
                usage: Usage::default(),
            },
            session_id: None,
            extra: empty_extra(),
        });
        assert_eq!(msg.assistant_text(), Some("line one\nline two".into()));
    }

    #[test]
    fn assistant_text_is_none_without_text_blocks() {
        let msg = Message::Assistant(AssistantMessage {
            message: AssistantMessageInner {
                id: "x".into(),
                content: vec![],
                model: String::new(),
                stop_reason: None,
                usage: Usage::default(),
            },
            session_id: None,
            extra: empty_extra(),
        });
        assert_eq!(msg.assistant_text(), None);
    }

    #[test]
    fn non_assistant_message_text_helper_returns_none() {
        let msg = Message::System(system_message("init", ""));
        assert_eq!(msg.assistant_text(), None);
    }

    // ---- User -----------------------------------------------------------

    #[test]
    fn user_message_round_trip() {
        let msg = Message::User(UserMessage {
            message: UserMessageInner {
                role: "user".into(),
                content: serde_json::json!("What is Rust?"),
            },
            extra: empty_extra(),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    #[test]
    fn user_message_from_ndjson() {
        let line = r#"{"type":"user","message":{"role":"user","content":"hello"}}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match msg {
            Message::User(u) => {
                assert_eq!(u.message.role, "user");
                assert_eq!(u.message.content, "hello");
            }
            other => panic!("expected User, got {other:?}"),
        }
    }

    // ---- Result ---------------------------------------------------------

    #[test]
    fn result_message_round_trip() {
        let msg = Message::Result(ResultMessage {
            subtype: "success".into(),
            cost_usd: Some(0.0042),
            duration_ms: 3500,
            duration_api_ms: 3100,
            is_error: false,
            num_turns: 2,
            session_id: Some("sess-1".into()),
            total_cost_usd: Some(0.0042),
            usage: Usage {
                input_tokens: 100,
                output_tokens: 50,
                cache_read_input_tokens: 20,
                cache_creation_input_tokens: 5,
            },
            result: Some("Task complete.".into()),
            extra: empty_extra(),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    #[test]
    fn result_message_from_ndjson() {
        let line = r#"{"type":"result","subtype":"success","cost_usd":0.01,"duration_ms":1000,"duration_api_ms":900,"is_error":false,"num_turns":1,"session_id":"s1","usage":{"input_tokens":50,"output_tokens":20}}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match msg {
            Message::Result(r) => {
                assert_eq!(r.subtype, "success");
                assert!(!r.is_error);
                assert_eq!(r.num_turns, 1);
                assert_eq!(r.duration_ms, 1000);
                assert_eq!(r.duration_api_ms, 900);
                assert_eq!(r.session_id.as_deref(), Some("s1"));
                assert_eq!(r.usage.output_tokens, 20);
            }
            other => panic!("expected Result, got {other:?}"),
        }
    }

    #[test]
    fn result_message_is_error_flag() {
        let msg = Message::Result(ResultMessage {
            subtype: "error".into(),
            is_error: true,
            ..empty_result()
        });
        assert!(msg.is_error_result());
    }

    #[test]
    fn result_message_without_error_flag_is_not_error_result() {
        let msg = Message::Result(empty_result());
        assert!(!msg.is_error_result());
    }

    #[test]
    fn is_error_result_false_for_non_result() {
        let msg = Message::System(system_message("init", ""));
        assert!(!msg.is_error_result());
    }

    // ---- Usage ----------------------------------------------------------

    #[test]
    fn usage_default_is_zero() {
        let u = Usage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
        assert_eq!(u.cache_read_input_tokens, 0);
        assert_eq!(u.cache_creation_input_tokens, 0);
    }

    #[test]
    fn usage_missing_fields_default_to_zero() {
        let u: Usage = serde_json::from_str(r#"{"input_tokens":7}"#).unwrap();
        assert_eq!(
            u,
            Usage {
                input_tokens: 7,
                ..Default::default()
            }
        );
    }

    #[test]
    fn usage_total_tokens() {
        let u = Usage {
            input_tokens: 10,
            output_tokens: 20,
            ..Default::default()
        };
        assert_eq!(u.total_tokens(), 30);
    }

    #[test]
    fn usage_total_tokens_saturates_on_overflow() {
        let u = Usage {
            input_tokens: u32::MAX,
            output_tokens: 1,
            ..Default::default()
        };
        assert_eq!(u.total_tokens(), u32::MAX);
    }

    #[test]
    fn usage_merge() {
        let mut a = Usage {
            input_tokens: 10,
            output_tokens: 5,
            cache_read_input_tokens: 2,
            cache_creation_input_tokens: 1,
        };
        let b = Usage {
            input_tokens: 3,
            output_tokens: 7,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 4,
        };
        a.merge(&b);
        assert_eq!(a.input_tokens, 13);
        assert_eq!(a.output_tokens, 12);
        assert_eq!(a.cache_read_input_tokens, 2);
        assert_eq!(a.cache_creation_input_tokens, 5);
    }

    #[test]
    fn usage_round_trip() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 200,
            cache_read_input_tokens: 50,
            cache_creation_input_tokens: 10,
        };
        assert_eq!(round_trip(&u), u);
    }

    // ---- SessionInfo ----------------------------------------------------

    #[test]
    fn session_info_from_init_message() {
        let sys = SystemMessage {
            cwd: "/workspace".into(),
            tools: vec!["bash".into()],
            model: "claude-opus-4-5".into(),
            ..system_message("init", "s42")
        };
        let info = SessionInfo::try_from(&sys).unwrap();
        assert_eq!(info.session_id, "s42");
        assert_eq!(info.cwd, "/workspace");
        assert_eq!(info.tools, ["bash"]);
        assert_eq!(info.model, "claude-opus-4-5");
    }

    #[test]
    fn session_info_rejects_non_init_subtype() {
        let sys = SystemMessage {
            cwd: "/".into(),
            ..system_message("status", "s1")
        };
        let err = SessionInfo::try_from(&sys).unwrap_err();
        assert!(
            matches!(err, crate::errors::Error::ControlProtocol(_)),
            "expected ControlProtocol error, got {err:?}"
        );
    }

    // ---- Message::session_id --------------------------------------------

    #[test]
    fn message_session_id_system() {
        let msg = Message::System(system_message("init", "s1"));
        assert_eq!(msg.session_id(), Some("s1"));
    }

    #[test]
    fn message_session_id_result() {
        let msg = Message::Result(ResultMessage {
            session_id: Some("s99".into()),
            ..empty_result()
        });
        assert_eq!(msg.session_id(), Some("s99"));
    }

    #[test]
    fn message_session_id_user_is_none() {
        let msg = Message::User(UserMessage {
            message: UserMessageInner {
                role: "user".into(),
                content: serde_json::Value::Null,
            },
            extra: empty_extra(),
        });
        assert_eq!(msg.session_id(), None);
    }

    // ---- StreamEvent ----------------------------------------------------

    #[test]
    fn stream_event_from_ndjson() {
        let line = r#"{"type":"stream_event","uuid":"evt-001","session_id":"s1","event":{"kind":"progress","percent":50}}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match &msg {
            Message::StreamEvent(e) => {
                assert_eq!(e.uuid, "evt-001");
                assert_eq!(e.session_id, "s1");
                assert_eq!(e.event["kind"], "progress");
                assert_eq!(e.event["percent"], 50);
                assert!(e.parent_tool_use_id.is_none());
            }
            other => panic!("expected StreamEvent, got {other:?}"),
        }
        assert!(msg.is_stream_event());
        assert_eq!(msg.session_id(), Some("s1"));
    }

    #[test]
    fn stream_event_with_parent_tool_use_id() {
        let line = r#"{"type":"stream_event","uuid":"evt-002","session_id":"s1","event":{},"parent_tool_use_id":"tu_123"}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match &msg {
            Message::StreamEvent(e) => {
                assert_eq!(e.parent_tool_use_id.as_deref(), Some("tu_123"));
            }
            other => panic!("expected StreamEvent, got {other:?}"),
        }
    }

    #[test]
    fn stream_event_requires_uuid() {
        let line = r#"{"type":"stream_event","session_id":"s1","event":{}}"#;
        assert!(serde_json::from_str::<Message>(line).is_err());
    }

    #[test]
    fn stream_event_round_trip() {
        let msg = Message::StreamEvent(StreamEvent {
            uuid: "evt-003".into(),
            session_id: "s2".into(),
            event: serde_json::json!({"status": "done"}),
            parent_tool_use_id: Some("tu_456".into()),
            extra: empty_extra(),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    #[test]
    fn is_stream_event_false_for_other() {
        let msg = Message::System(system_message("init", ""));
        assert!(!msg.is_stream_event());
    }

    // ---- RateLimitEvent -------------------------------------------------

    #[test]
    fn rate_limit_event_from_ndjson() {
        let line = r#"{"type":"rate_limit_event","session_id":"s1","rate_limit_info":{"status":"allowed"}}"#;
        let msg: Message = serde_json::from_str(line).unwrap();
        match &msg {
            Message::RateLimitEvent(e) => {
                assert_eq!(e.session_id.as_deref(), Some("s1"));
                assert_eq!(e.extra["rate_limit_info"]["status"], "allowed");
            }
            other => panic!("expected RateLimitEvent, got {other:?}"),
        }
        assert_eq!(msg.session_id(), Some("s1"));
    }

    #[test]
    fn rate_limit_event_round_trip_without_session_id() {
        let msg = Message::RateLimitEvent(RateLimitEvent {
            session_id: None,
            extra: serde_json::json!({"limit_kind": "five_hour"}),
        });
        assert_eq!(round_trip(&msg), msg);
    }

    // ---- Tag handling ---------------------------------------------------

    #[test]
    fn unknown_message_type_is_rejected() {
        let line = r#"{"type":"definitely_not_a_message_type","session_id":"s1"}"#;
        assert!(serde_json::from_str::<Message>(line).is_err());
    }
}