use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::content::ContentBlock;
/// A protocol message, discriminated by its JSON `"type"` field.
///
/// Variant names map to snake_case tags: `"system"`, `"assistant"`, `"user"`, `"result"`
/// and `"stream_event"`. Because the enum is internally tagged, the tag key is consumed
/// during deserialization and never reaches a variant's flattened `extra` map.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Message {
    /// Tagged `"system"`.
    System(SystemMessage),
    /// Tagged `"assistant"`; see [`Message::assistant_text`].
    Assistant(AssistantMessage),
    /// Tagged `"user"`.
    User(UserMessage),
    /// Tagged `"result"`; see [`Message::is_error_result`].
    Result(ResultMessage),
    /// Tagged `"stream_event"` — `rename_all = "snake_case"` already derives this name
    /// from the variant identifier, so no per-variant rename is needed.
    StreamEvent(StreamEvent),
}
impl Message {
    /// Returns this message's session id, or `None` when it is empty.
    ///
    /// Every variant carries a `session_id`, but the field is `#[serde(default)]`, so a
    /// message whose JSON omitted the key holds an empty string. That placeholder is
    /// reported as `None` rather than `Some("")` so callers cannot mistake it for a real id.
    pub fn session_id(&self) -> Option<&str> {
        let id = match self {
            Message::System(m) => &m.session_id,
            Message::Assistant(m) => &m.session_id,
            Message::User(m) => &m.session_id,
            Message::Result(m) => &m.session_id,
            Message::StreamEvent(m) => &m.session_id,
        };
        Some(id.as_str()).filter(|s| !s.is_empty())
    }

    /// Returns `true` only for a [`Message::Result`] whose `is_error` flag is set.
    pub fn is_error_result(&self) -> bool {
        matches!(self, Message::Result(r) if r.is_error)
    }

    /// Returns `true` if this is a [`Message::StreamEvent`].
    pub fn is_stream_event(&self) -> bool {
        matches!(self, Message::StreamEvent(_))
    }

    /// Concatenates the text of every `ContentBlock::Text` in an assistant message.
    ///
    /// Text blocks are joined in order with no separator, and non-text blocks are skipped.
    /// Returns `None` for non-assistant messages and for assistant messages without any
    /// text block. A text block holding an empty string still yields `Some("")`.
    pub fn assistant_text(&self) -> Option<String> {
        let Message::Assistant(m) = self else {
            return None;
        };
        let texts: Vec<&str> = m
            .message
            .content
            .iter()
            .filter_map(|block| match block {
                ContentBlock::Text(t) => Some(t.text.as_str()),
                _ => None,
            })
            .collect();
        (!texts.is_empty()).then(|| texts.concat())
    }
}
/// Payload of [`Message::System`] (wire tag `"system"`).
///
/// Every field is `#[serde(default)]`, so any of them may be absent from the JSON. Keys
/// not named here are collected into `extra` through `#[serde(flatten)]`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SystemMessage {
/// Subtype discriminator (e.g. `"init"`); empty when absent.
#[serde(default)]
pub subtype: String,
/// Session identifier; empty when absent.
#[serde(default)]
pub session_id: String,
/// Working-directory string; empty when absent.
#[serde(default)]
pub cwd: String,
/// Tool names (e.g. `"bash"`); empty when absent.
#[serde(default)]
pub tools: Vec<String>,
/// Per-server MCP status entries; empty when absent.
#[serde(default)]
pub mcp_servers: Vec<McpServerStatus>,
/// Model identifier string; empty when absent.
#[serde(default)]
pub model: String,
/// Unrecognized JSON keys, re-emitted inline on serialization.
/// Should hold a JSON object (or `null`): flattening any other `Value` kind fails to serialize.
#[serde(flatten)]
pub extra: Value,
}
/// One entry of [`SystemMessage::mcp_servers`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct McpServerStatus {
/// Server name; required (no `#[serde(default)]`), so an entry without it fails to parse.
pub name: String,
/// Status string (e.g. `"connected"`); empty when absent.
#[serde(default)]
pub status: String,
/// Unrecognized JSON keys, re-emitted inline on serialization.
/// Should hold a JSON object (or `null`): flattening any other `Value` kind fails to serialize.
#[serde(flatten)]
pub extra: Value,
}
/// Payload of [`Message::Assistant`] (wire tag `"assistant"`).
///
/// Keys other than `message` and `session_id` are ignored on deserialization, because
/// this struct has no flattened `extra` field.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AssistantMessage {
/// The nested message body; required.
pub message: AssistantMessageInner,
/// Session identifier; empty when absent.
#[serde(default)]
pub session_id: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AssistantMessageInner {
#[serde(default)]
pub role: String,
#[serde(default)]
pub content: Vec<ContentBlock>,
#[serde(default)]
pub model: String,
#[serde(default)]
pub stop_reason: String,
#[serde(default)]
pub stop_sequence: Option<String>,
#[serde(flatten)]
pub extra: Value,
}
/// Payload of [`Message::User`] (wire tag `"user"`).
///
/// Keys other than `message` and `session_id` are ignored on deserialization, because
/// this struct has no flattened `extra` field.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UserMessage {
/// The nested message body; required.
pub message: UserMessageInner,
/// Session identifier; empty when absent.
#[serde(default)]
pub session_id: String,
}
/// The nested `message` body of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UserMessageInner {
/// Role string; empty when absent.
#[serde(default)]
pub role: String,
/// Content blocks in order; empty when absent.
// NOTE(review): a plain-string `content` would not parse into `Vec<ContentBlock>` —
// confirm the producer always sends an array here.
#[serde(default)]
pub content: Vec<ContentBlock>,
/// Unrecognized JSON keys, re-emitted inline on serialization.
/// Should hold a JSON object (or `null`): flattening any other `Value` kind fails to serialize.
#[serde(flatten)]
pub extra: Value,
}
/// Payload of [`Message::Result`] (wire tag `"result"`).
///
/// Every field has a default, so a bare `{"type":"result"}` deserializes. Keys not named
/// here are collected into `extra`. Derives only `PartialEq`, because the `f64` fields
/// rule out `Eq`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ResultMessage {
/// Subtype string (e.g. `"success"` or `"error"`); empty when absent.
#[serde(default)]
pub subtype: String,
/// Error flag; this is what [`Message::is_error_result`] checks.
#[serde(default)]
pub is_error: bool,
/// Duration in milliseconds (per the `_ms` suffix); may be fractional.
#[serde(default)]
pub duration_ms: f64,
/// API-side duration in milliseconds (per the `_ms` suffix); may be fractional.
#[serde(default)]
pub duration_api_ms: f64,
/// Turn count; `0` when absent.
#[serde(default)]
pub num_turns: u32,
/// Session identifier; empty when absent.
#[serde(default)]
pub session_id: String,
/// Token usage; all zeros ([`Usage::default`]) when absent.
#[serde(default)]
pub usage: Usage,
/// Stop reason; empty when absent.
// NOTE(review): `default` covers a missing key only — an explicit JSON `null` here fails
// deserialization of the whole message. Confirm whether the producer can send `null`.
#[serde(default)]
pub stop_reason: String,
/// Unrecognized JSON keys, re-emitted inline on serialization.
/// Should hold a JSON object (or `null`): flattening any other `Value` kind fails to serialize.
#[serde(flatten)]
pub extra: Value,
}
/// Token counters carried by [`ResultMessage::usage`].
///
/// The container-level `#[serde(default)]` fills every field missing from the JSON with
/// the corresponding value from [`Usage::default`], which is `0` for each counter.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct Usage {
    /// Input token count.
    pub input_tokens: u32,
    /// Output token count.
    pub output_tokens: u32,
    /// Input tokens read from cache.
    pub cache_read_input_tokens: u32,
    /// Input tokens written to cache.
    pub cache_creation_input_tokens: u32,
    /// Thought (reasoning) token count.
    pub thought_tokens: u32,
}
/// Payload of [`Message::StreamEvent`] (wire tag `"stream_event"`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StreamEvent {
/// Event kind (e.g. `"tool_call_start"`); required, so an event without it fails to parse.
pub event_type: String,
/// Arbitrary event data; `Value::Null` when absent.
#[serde(default)]
pub data: Value,
/// Session identifier; empty when absent.
#[serde(default)]
pub session_id: String,
}
/// Session metadata record.
// NOTE(review): not referenced by `Message` in this file — confirm where it is produced.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SessionInfo {
/// Session identifier; required (no `#[serde(default)]`).
pub session_id: String,
/// Model identifier; empty when absent.
#[serde(default)]
pub model: String,
/// Tool names; empty when absent.
#[serde(default)]
pub tools: Vec<String>,
/// Unrecognized JSON keys, re-emitted inline on serialization.
/// Should hold a JSON object (or `null`): flattening any other `Value` kind fails to serialize.
#[serde(flatten)]
pub extra: Value,
}
/// A single plan item. Every field has a default, so `{}` deserializes to an all-empty entry.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PlanEntry {
/// Item text; empty when absent.
#[serde(default)]
pub content: String,
/// Priority label (e.g. `"high"`); free-form string, empty when absent.
#[serde(default)]
pub priority: String,
/// Status label (e.g. `"pending"`); free-form string, empty when absent.
#[serde(default)]
pub status: String,
/// Unrecognized JSON keys, re-emitted inline on serialization.
/// Should hold a JSON object (or `null`): flattening any other `Value` kind fails to serialize.
#[serde(flatten)]
pub extra: Value,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::content::TextBlock;
    use serde_json::json;

    fn system_msg(session_id: &str) -> Message {
        Message::System(SystemMessage {
            subtype: "init".to_owned(),
            session_id: session_id.to_owned(),
            cwd: "/tmp".to_owned(),
            tools: vec![],
            mcp_servers: vec![],
            model: "gemini-2.5-pro".to_owned(),
            extra: Value::Object(Default::default()),
        })
    }

    fn result_msg(session_id: &str, is_error: bool) -> Message {
        Message::Result(ResultMessage {
            subtype: if is_error { "error" } else { "success" }.to_owned(),
            is_error,
            duration_ms: 123.4,
            duration_api_ms: 100.0,
            num_turns: 1,
            session_id: session_id.to_owned(),
            usage: Usage::default(),
            stop_reason: "end_turn".to_owned(),
            extra: Value::Object(Default::default()),
        })
    }

    fn assistant_msg(session_id: &str, content: Vec<ContentBlock>) -> Message {
        Message::Assistant(AssistantMessage {
            message: AssistantMessageInner {
                role: "assistant".to_owned(),
                content,
                model: "gemini-2.5-pro".to_owned(),
                stop_reason: "end_turn".to_owned(),
                stop_sequence: None,
                extra: Value::Object(Default::default()),
            },
            session_id: session_id.to_owned(),
        })
    }

    fn stream_event_msg(session_id: &str) -> Message {
        Message::StreamEvent(StreamEvent {
            event_type: "tool_call_start".to_owned(),
            data: json!({ "tool": "bash" }),
            session_id: session_id.to_owned(),
        })
    }

    #[test]
    fn test_message_system_session_id() {
        let msg = system_msg("sess-abc");
        assert_eq!(msg.session_id(), Some("sess-abc"));
    }

    #[test]
    fn test_message_result_session_id() {
        let msg = result_msg("sess-xyz", false);
        assert_eq!(msg.session_id(), Some("sess-xyz"));
    }

    #[test]
    fn test_message_stream_event_session_id() {
        let msg = stream_event_msg("sess-ev");
        assert_eq!(msg.session_id(), Some("sess-ev"));
    }

    #[test]
    fn test_message_assistant_and_user_session_id() {
        let assistant = assistant_msg("sess-a", vec![]);
        assert_eq!(assistant.session_id(), Some("sess-a"));

        let user = Message::User(UserMessage {
            message: UserMessageInner {
                role: "user".to_owned(),
                content: vec![],
                extra: Value::Object(Default::default()),
            },
            session_id: "sess-u".to_owned(),
        });
        assert_eq!(user.session_id(), Some("sess-u"));
    }

    #[test]
    fn test_message_is_error_result_true() {
        let msg = result_msg("s1", true);
        assert!(msg.is_error_result(), "is_error=true must return true");
    }

    #[test]
    fn test_message_is_error_result_false() {
        let msg = result_msg("s1", false);
        assert!(!msg.is_error_result(), "is_error=false must return false");
    }

    #[test]
    fn test_message_is_error_result_non_result_variant() {
        let msg = system_msg("s1");
        assert!(!msg.is_error_result(), "non-Result variant must return false");
    }

    #[test]
    fn test_message_is_stream_event() {
        let msg = stream_event_msg("s1");
        assert!(msg.is_stream_event());
    }

    #[test]
    fn test_message_is_stream_event_false_for_system() {
        let msg = system_msg("s1");
        assert!(!msg.is_stream_event());
    }

    #[test]
    fn test_message_assistant_text_single() {
        let content = vec![ContentBlock::Text(TextBlock::new("hello world"))];
        let msg = assistant_msg("s1", content);
        assert_eq!(msg.assistant_text(), Some("hello world".to_owned()));
    }

    #[test]
    fn test_message_assistant_text_multiple_blocks_concatenated() {
        let content = vec![
            ContentBlock::Text(TextBlock::new("foo")),
            ContentBlock::Text(TextBlock::new("bar")),
        ];
        let msg = assistant_msg("s1", content);
        assert_eq!(msg.assistant_text(), Some("foobar".to_owned()));
    }

    #[test]
    fn test_message_assistant_text_empty() {
        let msg = assistant_msg("s1", vec![]);
        assert_eq!(
            msg.assistant_text(),
            None,
            "no content blocks must yield None"
        );
    }

    #[test]
    fn test_message_assistant_text_non_text_blocks_only() {
        use crate::types::content::ThinkingBlock;
        let content = vec![ContentBlock::Thinking(ThinkingBlock::new("reasoning..."))];
        let msg = assistant_msg("s1", content);
        assert_eq!(
            msg.assistant_text(),
            None,
            "no Text blocks must yield None"
        );
    }

    #[test]
    fn test_message_assistant_text_non_assistant_variant() {
        let msg = system_msg("s1");
        assert_eq!(msg.assistant_text(), None);
    }

    #[test]
    fn test_usage_default() {
        let usage = Usage::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
        assert_eq!(usage.cache_read_input_tokens, 0);
        assert_eq!(usage.cache_creation_input_tokens, 0);
        assert_eq!(usage.thought_tokens, 0);
    }

    #[test]
    fn test_message_serde_roundtrip_system() {
        let original = Message::System(SystemMessage {
            subtype: "init".to_owned(),
            session_id: "sess-roundtrip".to_owned(),
            cwd: "/workspace".to_owned(),
            tools: vec!["bash".to_owned(), "read_file".to_owned()],
            mcp_servers: vec![McpServerStatus {
                name: "filesystem".to_owned(),
                status: "connected".to_owned(),
                extra: Value::Object(Default::default()),
            }],
            model: "gemini-2.5-pro".to_owned(),
            extra: Value::Object(Default::default()),
        });
        let json = serde_json::to_string(&original).expect("serialize");
        let recovered: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(original, recovered);
    }

    #[test]
    fn test_message_serde_roundtrip_result() {
        let original = Message::Result(ResultMessage {
            subtype: "success".to_owned(),
            is_error: false,
            duration_ms: 450.75,
            duration_api_ms: 400.0,
            num_turns: 3,
            session_id: "sess-rt2".to_owned(),
            usage: Usage {
                input_tokens: 512,
                output_tokens: 128,
                cache_read_input_tokens: 64,
                cache_creation_input_tokens: 32,
                thought_tokens: 256,
            },
            stop_reason: "end_turn".to_owned(),
            extra: Value::Object(Default::default()),
        });
        let json = serde_json::to_string(&original).expect("serialize");
        let recovered: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(original, recovered);
    }

    #[test]
    fn test_message_serde_roundtrip_stream_event() {
        let original = Message::StreamEvent(StreamEvent {
            event_type: "plan_update".to_owned(),
            data: json!({ "step": 1, "action": "read_file" }),
            session_id: "sess-rt3".to_owned(),
        });
        let json = serde_json::to_string(&original).expect("serialize");
        let recovered: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(original, recovered);
    }

    #[test]
    fn test_message_deserialize_collects_unknown_fields_into_extra() {
        let raw =
            r#"{"type":"system","subtype":"init","session_id":"s1","permissionMode":"default"}"#;
        let msg: Message = serde_json::from_str(raw).expect("deserialize");
        let sys = match msg {
            Message::System(sys) => sys,
            other => panic!("expected a System message, got {other:?}"),
        };
        assert_eq!(sys.session_id, "s1");
        assert_eq!(sys.extra.get("permissionMode"), Some(&json!("default")));
        assert!(
            sys.extra.get("type").is_none(),
            "the enum tag must not leak into extra"
        );
        assert!(sys.cwd.is_empty());
        assert!(sys.tools.is_empty());
    }

    #[test]
    fn test_message_result_minimal_json_uses_defaults() {
        let msg: Message =
            serde_json::from_str(r#"{"type":"result"}"#).expect("all fields have defaults");
        assert!(!msg.is_error_result());
        match msg {
            Message::Result(r) => {
                assert_eq!(r.usage, Usage::default());
                assert_eq!(r.num_turns, 0);
                assert!(r.subtype.is_empty());
            }
            other => panic!("expected a Result message, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_entry_defaults() {
        let entry: PlanEntry =
            serde_json::from_str("{}").expect("empty object must deserialize via defaults");
        assert!(entry.content.is_empty());
        assert!(entry.priority.is_empty());
        assert!(entry.status.is_empty());
    }

    #[test]
    fn test_plan_entry_roundtrip() {
        let original = PlanEntry {
            content: "Analyze the repository structure".to_owned(),
            priority: "high".to_owned(),
            status: "pending".to_owned(),
            extra: Value::Object(Default::default()),
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let recovered: PlanEntry = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(original, recovered);
    }

    #[test]
    fn test_usage_thought_tokens_roundtrip() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 0,
            thought_tokens: 300,
        };
        let json = serde_json::to_string(&usage).expect("serialize");
        let recovered: Usage = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(recovered.thought_tokens, 300);
        // `Usage` is `Eq`, so pin every counter rather than just the one under focus.
        assert_eq!(recovered, usage);
    }
}