use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use serdes_ai_core::{ModelResponse, RequestUsage};
use std::fmt;
/// A single event emitted while streaming an agent run.
///
/// `Output` is the type carried by `PartialOutput` and `FinalOutput`; it
/// defaults to an untyped `serde_json::Value` (imported here as `JsonValue`).
///
/// With serde, events use the *internally tagged* representation: the variant
/// name, converted to `snake_case`, is stored under a `"type"` key next to the
/// variant's own fields, e.g.
/// `{"type": "text_delta", "content": "...", "part_index": 0}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentStreamEvent<Output = JsonValue> {
    /// Marks the start of a run.
    RunStart {
        /// Identifier of the run.
        run_id: String,
        /// Step counter at the time the run started.
        step: u32,
    },
    /// Marks the start of a model request.
    RequestStart {
        /// Step counter for this request.
        step: u32,
    },
    /// A fragment of streamed text.
    TextDelta {
        /// The text fragment itself.
        content: String,
        /// Index of the response part this fragment belongs to
        /// (assumed from the name — confirm against the producer).
        part_index: usize,
    },
    /// A tool call has begun streaming.
    ToolCallStart {
        /// Name of the tool being called.
        name: String,
        /// Identifier of the tool call, when one is available.
        tool_call_id: Option<String>,
        /// Index of the tool call; presumably used to correlate this event with
        /// the matching `ToolCallDelta` / `ToolCallComplete` events — TODO confirm.
        index: usize,
    },
    /// A raw, unparsed fragment of a tool call's arguments.
    ToolCallDelta {
        /// Argument text fragment (presumably partial JSON — verify with producer).
        args_delta: String,
        /// Index of the tool call this fragment belongs to.
        index: usize,
    },
    /// A tool call whose arguments are fully available.
    ToolCallComplete {
        /// Name of the tool being called.
        name: String,
        /// Complete arguments as a JSON value.
        args: JsonValue,
        /// Index of the tool call.
        index: usize,
    },
    /// The result of executing a tool.
    ToolResult {
        /// Name of the tool that produced the result.
        name: String,
        /// Result payload as a JSON value.
        result: JsonValue,
        /// Whether the tool call succeeded.
        success: bool,
        /// Index of the tool call this result answers.
        index: usize,
    },
    /// A fragment of model "thinking" content.
    ThinkingDelta {
        /// The thinking text fragment.
        content: String,
        /// Index associated with this fragment (semantics not visible here — confirm).
        index: usize,
    },
    /// An intermediate output value.
    PartialOutput {
        /// The partial output.
        output: Output,
    },
    /// A complete model response.
    ResponseComplete {
        /// The full response.
        response: ModelResponse,
    },
    /// Updated token/request usage.
    UsageUpdate {
        /// Usage figures reported for the request.
        usage: RequestUsage,
    },
    /// The final output of the run.
    FinalOutput {
        /// The final output.
        output: Output,
    },
    /// Marks the end of a run; counts as a terminal event.
    RunComplete {
        /// Identifier of the run.
        run_id: String,
        /// Number of steps the run took.
        total_steps: u32,
    },
    /// An error; counts as a terminal event regardless of `recoverable`.
    Error {
        /// Human-readable error message.
        message: String,
        /// Whether the producer flagged the error as recoverable.
        recoverable: bool,
    },
}
impl<Output> AgentStreamEvent<Output> {
    /// Creates a `RunStart` event for `run_id` at the given `step`.
    #[must_use]
    pub fn run_start(run_id: impl Into<String>, step: u32) -> Self {
        Self::RunStart {
            run_id: run_id.into(),
            step,
        }
    }

    /// Creates a `TextDelta` event carrying one fragment of streamed text.
    #[must_use]
    pub fn text_delta(content: impl Into<String>, part_index: usize) -> Self {
        Self::TextDelta {
            content: content.into(),
            part_index,
        }
    }

    /// Creates an `Error` event with the given message and recoverability flag.
    #[must_use]
    pub fn error(message: impl Into<String>, recoverable: bool) -> Self {
        Self::Error {
            message: message.into(),
            recoverable,
        }
    }

    /// Returns `true` for `RunComplete` and `Error` events.
    ///
    /// Every `Error` is treated as terminal, including ones whose
    /// `recoverable` flag is `true`. Callers that need to tell the two apart
    /// must inspect the `recoverable` field directly.
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        matches!(self, Self::RunComplete { .. } | Self::Error { .. })
    }

    /// Returns `true` if this is an `Error` event.
    #[must_use]
    pub fn is_error(&self) -> bool {
        matches!(self, Self::Error { .. })
    }

    /// Returns the text fragment of a `TextDelta` event, or `None` for any
    /// other variant.
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        match self {
            Self::TextDelta { content, .. } => Some(content.as_str()),
            _ => None,
        }
    }

    /// Borrows the output carried by a `PartialOutput` or `FinalOutput`
    /// event, or returns `None` for any other variant.
    #[must_use]
    pub fn as_output(&self) -> Option<&Output> {
        match self {
            Self::FinalOutput { output } | Self::PartialOutput { output } => Some(output),
            _ => None,
        }
    }

    /// Converts the event's output type from `Output` to `U`.
    ///
    /// `f` is applied to the payload of `PartialOutput` and `FinalOutput`
    /// (so it runs at most once); every other variant is moved across
    /// unchanged.
    #[must_use]
    pub fn map_output<U, F>(self, f: F) -> AgentStreamEvent<U>
    where
        F: FnOnce(Output) -> U,
    {
        match self {
            Self::RunStart { run_id, step } => AgentStreamEvent::RunStart { run_id, step },
            Self::RequestStart { step } => AgentStreamEvent::RequestStart { step },
            Self::TextDelta {
                content,
                part_index,
            } => AgentStreamEvent::TextDelta {
                content,
                part_index,
            },
            Self::ToolCallStart {
                name,
                tool_call_id,
                index,
            } => AgentStreamEvent::ToolCallStart {
                name,
                tool_call_id,
                index,
            },
            Self::ToolCallDelta { args_delta, index } => {
                AgentStreamEvent::ToolCallDelta { args_delta, index }
            }
            Self::ToolCallComplete { name, args, index } => {
                AgentStreamEvent::ToolCallComplete { name, args, index }
            }
            Self::ToolResult {
                name,
                result,
                success,
                index,
            } => AgentStreamEvent::ToolResult {
                name,
                result,
                success,
                index,
            },
            Self::ThinkingDelta { content, index } => {
                AgentStreamEvent::ThinkingDelta { content, index }
            }
            // The only two variants whose payload depends on `Output`.
            Self::PartialOutput { output } => AgentStreamEvent::PartialOutput { output: f(output) },
            Self::FinalOutput { output } => AgentStreamEvent::FinalOutput { output: f(output) },
            Self::ResponseComplete { response } => AgentStreamEvent::ResponseComplete { response },
            Self::UsageUpdate { usage } => AgentStreamEvent::UsageUpdate { usage },
            Self::RunComplete {
                run_id,
                total_steps,
            } => AgentStreamEvent::RunComplete {
                run_id,
                total_steps,
            },
            Self::Error {
                message,
                recoverable,
            } => AgentStreamEvent::Error {
                message,
                recoverable,
            },
        }
    }
}
impl<Output: fmt::Display> fmt::Display for AgentStreamEvent<Output> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::RunStart { run_id, .. } => write!(f, "[run_start] {}", run_id),
Self::RequestStart { step } => write!(f, "[request_start] step {}", step),
Self::TextDelta { content, .. } => write!(f, "{}", content),
Self::ToolCallStart { name, .. } => write!(f, "[tool_start] {}", name),
Self::ToolCallDelta { args_delta, .. } => write!(f, "{}", args_delta),
Self::ToolCallComplete { name, .. } => write!(f, "[tool_complete] {}", name),
Self::ToolResult { name, success, .. } => {
write!(
f,
"[tool_result] {} ({})",
name,
if *success { "ok" } else { "error" }
)
}
Self::ThinkingDelta { content, .. } => write!(f, "[thinking] {}", content),
Self::PartialOutput { output } => write!(f, "[partial] {}", output),
Self::ResponseComplete { .. } => write!(f, "[response_complete]"),
Self::UsageUpdate { .. } => write!(f, "[usage_update]"),
Self::FinalOutput { output } => write!(f, "[output] {}", output),
Self::RunComplete { run_id, .. } => write!(f, "[run_complete] {}", run_id),
Self::Error { message, .. } => write!(f, "[error] {}", message),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_event_creation() {
        let event: AgentStreamEvent<String> = AgentStreamEvent::run_start("run-123", 0);
        assert!(!event.is_terminal());
        assert!(!event.is_error());
    }

    #[test]
    fn test_text_delta() {
        let event: AgentStreamEvent<String> = AgentStreamEvent::text_delta("Hello", 0);
        assert_eq!(event.as_text(), Some("Hello"));
        assert!(event.as_output().is_none());
    }

    #[test]
    fn test_non_text_event_has_no_text() {
        let event: AgentStreamEvent<String> = AgentStreamEvent::run_start("run-123", 0);
        assert_eq!(event.as_text(), None);
    }

    #[test]
    fn test_terminal_events() {
        let complete: AgentStreamEvent<String> = AgentStreamEvent::RunComplete {
            run_id: "run-123".to_string(),
            total_steps: 1,
        };
        assert!(complete.is_terminal());
        let error: AgentStreamEvent<String> = AgentStreamEvent::error("oops", false);
        assert!(error.is_terminal());
        assert!(error.is_error());
    }

    #[test]
    fn test_as_output() {
        let partial: AgentStreamEvent<i32> = AgentStreamEvent::PartialOutput { output: 1 };
        let final_output: AgentStreamEvent<i32> = AgentStreamEvent::FinalOutput { output: 2 };
        assert_eq!(partial.as_output(), Some(&1));
        assert_eq!(final_output.as_output(), Some(&2));
    }

    #[test]
    fn test_map_output() {
        let event: AgentStreamEvent<i32> = AgentStreamEvent::FinalOutput { output: 42 };
        let mapped = event.map_output(|n| n.to_string());
        if let AgentStreamEvent::FinalOutput { output } = mapped {
            assert_eq!(output, "42");
        } else {
            panic!("Expected FinalOutput");
        }
    }

    #[test]
    fn test_map_output_partial() {
        let event: AgentStreamEvent<i32> = AgentStreamEvent::PartialOutput { output: 7 };
        let mapped = event.map_output(|n| n * 2);
        assert_eq!(mapped.as_output(), Some(&14));
    }

    #[test]
    fn test_map_output_passes_other_variants_through() {
        let event: AgentStreamEvent<i32> = AgentStreamEvent::text_delta("chunk", 3);
        let mapped: AgentStreamEvent<String> = event.map_output(|n| n.to_string());
        assert_eq!(mapped.as_text(), Some("chunk"));
    }

    #[test]
    fn test_display() {
        let event: AgentStreamEvent<String> = AgentStreamEvent::text_delta("test", 0);
        assert_eq!(format!("{}", event), "test");
    }

    #[test]
    fn test_display_tool_result_and_output() {
        let ok: AgentStreamEvent<String> = AgentStreamEvent::ToolResult {
            name: "search".to_string(),
            result: serde_json::Value::Null,
            success: true,
            index: 0,
        };
        assert_eq!(ok.to_string(), "[tool_result] search (ok)");

        let failed: AgentStreamEvent<String> = AgentStreamEvent::ToolResult {
            name: "search".to_string(),
            result: serde_json::Value::Null,
            success: false,
            index: 0,
        };
        assert_eq!(failed.to_string(), "[tool_result] search (error)");

        let output: AgentStreamEvent<String> = AgentStreamEvent::FinalOutput {
            output: "done".to_string(),
        };
        assert_eq!(output.to_string(), "[output] done");
    }

    // Pins the wire format produced by `#[serde(tag = "type", rename_all = "snake_case")]`
    // so an accidental attribute change is caught.
    #[test]
    fn test_serde_wire_format_round_trip() {
        let event: AgentStreamEvent<String> = AgentStreamEvent::text_delta("Hello", 2);
        let value = serde_json::to_value(&event).expect("event should serialize");
        assert_eq!(
            value,
            serde_json::json!({"type": "text_delta", "content": "Hello", "part_index": 2})
        );
        let decoded: AgentStreamEvent<String> =
            serde_json::from_value(value).expect("event should deserialize");
        assert_eq!(decoded.as_text(), Some("Hello"));
    }

    #[test]
    fn test_serde_run_complete_tag() {
        let event: AgentStreamEvent<String> = AgentStreamEvent::RunComplete {
            run_id: "run-1".to_string(),
            total_steps: 3,
        };
        let value = serde_json::to_value(&event).expect("event should serialize");
        assert_eq!(
            value,
            serde_json::json!({"type": "run_complete", "run_id": "run-1", "total_steps": 3})
        );
    }
}