use anyhow::Result;
use serde::{Deserialize, Serialize};
use crate::agent::errors::SessionFailureKind;
/// An event delivered to [`StreamConsumer`]s while a session streams output.
///
/// Serialized as an internally tagged JSON object: the snake_case variant name
/// goes under the `"type"` key and the variant's fields sit alongside it, e.g.
/// `{"type":"content","content":"Hello"}` or `{"type":"turn_completed"}`.
/// Field declaration order is also the JSON key order, so renaming or
/// reordering fields here changes the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
/// A session has started; `session_id` identifies it.
SessionStarted {
session_id: String,
},
/// Content text. NOTE(review): whether this is an incremental delta or the
/// full text is not visible here — confirm with the producer.
Content {
content: String,
},
/// Reasoning text, carried separately from `Content`.
/// NOTE(review): delta vs. full text is not visible here — confirm.
Reasoning {
content: String,
},
/// A tool call has started.
ToolCallStarted {
tool_name: String,
/// Raw argument string. NOTE(review): presumably serialized JSON — confirm.
arguments: String,
},
/// A tool call finished and produced a result.
ToolCallCompleted {
tool_name: String,
/// Tool output as an arbitrary JSON value.
result: serde_json::Value,
/// Whether the call succeeded. NOTE(review): how this relates to
/// `ToolCallFailed` is not visible here — confirm with the producer.
success: bool,
/// Elapsed time; milliseconds per the field name.
duration_ms: f64,
},
/// A tool call failed; `error` carries the failure as a string.
ToolCallFailed {
tool_name: String,
error: String,
},
/// A turn has completed. Serializes as `{"type":"turn_completed"}`.
TurnCompleted,
/// A snapshot of aggregated usage (token counts, cost, request count).
UsageUpdate {
snapshot: crate::llm::usage::AggregatedUsage,
},
/// The stream is done. NOTE(review): presumably the final event — confirm.
Done,
/// An error with optional diagnostics. Every `Option` field is omitted from
/// the serialized JSON when `None` and deserializes to `None` when absent.
Error {
message: String,
#[serde(skip_serializing_if = "Option::is_none")]
failure_kind: Option<SessionFailureKind>,
#[serde(skip_serializing_if = "Option::is_none")]
provider: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
http_status: Option<u16>,
/// Raw request body as a string. NOTE(review): may be large or contain
/// sensitive data — check before logging/forwarding.
#[serde(skip_serializing_if = "Option::is_none")]
request_payload: Option<String>,
/// Raw response body as a string. NOTE(review): same caution as above.
#[serde(skip_serializing_if = "Option::is_none")]
response_payload: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
provider_response_id: Option<String>,
},
}
/// A sink that receives [`StreamEvent`]s one at a time.
///
/// The `Sync + Send` bounds let a consumer be shared across threads, e.g.
/// stored as a `Box<dyn StreamConsumer>`. Because `on_event` takes `&self`,
/// implementations that need to record state must use interior mutability
/// (such as a `Mutex`).
pub trait StreamConsumer: Sync + Send {
    /// Handles a single event. An `Err` is reported back to whoever
    /// dispatched the event.
    fn on_event(&self, event: &StreamEvent) -> Result<()>;
}
/// Fans each [`StreamEvent`] out to every registered [`StreamConsumer`].
///
/// Build one with [`MultiConsumer::new`] (or `Default`) and chain
/// [`MultiConsumer::add`]; events are delivered in registration order.
pub struct MultiConsumer {
/// Registered consumers, in the order they were added.
consumers: Vec<Box<dyn StreamConsumer>>,
}
impl MultiConsumer {
    /// Creates a `MultiConsumer` with no registered consumers.
    ///
    /// Dispatching an event to an empty `MultiConsumer` does nothing and
    /// returns `Ok(())`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            consumers: Vec::new(),
        }
    }

    /// Registers `consumer` and returns the updated `MultiConsumer`
    /// (by-value builder style).
    ///
    /// Consumers receive events in the order they were added.
    // `add` is a by-value builder method, not the `+` operator: appending a
    // boxed consumer is not arithmetic, so clippy's suggestion to implement
    // `std::ops::Add` does not fit; the name is kept for API compatibility.
    #[allow(clippy::should_implement_trait)]
    #[must_use = "`add` returns the updated MultiConsumer; dropping it discards the registered consumer"]
    pub fn add(mut self, consumer: Box<dyn StreamConsumer>) -> Self {
        self.consumers.push(consumer);
        self
    }
}
impl Default for MultiConsumer {
fn default() -> Self {
Self::new()
}
}
/// Forwards each event to every registered consumer in insertion order.
///
/// Dispatch stops at the first consumer that returns `Err`, and that error is
/// propagated unchanged; consumers registered after the failing one do not
/// see the event.
impl StreamConsumer for MultiConsumer {
    fn on_event(&self, event: &StreamEvent) -> Result<()> {
        self.consumers
            .iter()
            .try_for_each(|sink| sink.on_event(event))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::usage::AggregatedUsage;
    use std::sync::{Arc, Mutex};

    /// Records the `Debug` rendering of every event it receives.
    struct TestConsumer {
        events: Arc<Mutex<Vec<String>>>,
    }

    impl StreamConsumer for TestConsumer {
        fn on_event(&self, event: &StreamEvent) -> Result<()> {
            let mut events = self.events.lock().unwrap();
            events.push(format!("{:?}", event));
            Ok(())
        }
    }

    /// Always fails; used to pin `MultiConsumer`'s stop-at-first-error behavior.
    struct FailingConsumer;

    impl StreamConsumer for FailingConsumer {
        fn on_event(&self, _event: &StreamEvent) -> Result<()> {
            Err(anyhow::anyhow!("consumer failed"))
        }
    }

    #[test]
    fn test_multi_consumer() {
        let events1 = Arc::new(Mutex::new(Vec::new()));
        let events2 = Arc::new(Mutex::new(Vec::new()));
        let consumer1 = TestConsumer {
            events: events1.clone(),
        };
        let consumer2 = TestConsumer {
            events: events2.clone(),
        };
        let multi = MultiConsumer::new()
            .add(Box::new(consumer1))
            .add(Box::new(consumer2));
        let event = StreamEvent::Content {
            content: "test".to_string(),
        };
        multi.on_event(&event).unwrap();
        assert_eq!(events1.lock().unwrap().len(), 1);
        assert_eq!(events2.lock().unwrap().len(), 1);
    }

    #[test]
    fn test_multi_consumer_usage_update() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let consumer = TestConsumer {
            events: events.clone(),
        };
        let multi = MultiConsumer::new().add(Box::new(consumer));
        let snapshot = AggregatedUsage {
            total_input_tokens: 100,
            total_output_tokens: 50,
            total_cost_usd: 1.23,
            request_count: 2,
            ..AggregatedUsage::default()
        };
        multi
            .on_event(&StreamEvent::UsageUpdate { snapshot })
            .expect("usage update should be forwarded");
        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert!(
            captured[0].contains("UsageUpdate"),
            "Expected usage update event in captured output"
        );
    }

    #[test]
    fn test_empty_multi_consumer_is_ok() {
        MultiConsumer::default()
            .on_event(&StreamEvent::TurnCompleted)
            .expect("dispatching to no consumers should succeed");
    }

    #[test]
    fn test_multi_consumer_stops_at_first_error() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let multi = MultiConsumer::new()
            .add(Box::new(FailingConsumer))
            .add(Box::new(TestConsumer {
                events: events.clone(),
            }));
        let err = multi
            .on_event(&StreamEvent::Done)
            .expect_err("failing consumer should surface its error");
        assert_eq!(err.to_string(), "consumer failed");
        assert!(
            events.lock().unwrap().is_empty(),
            "consumers after the failing one should not receive the event"
        );
    }

    #[test]
    fn test_stream_event_serialization() {
        let event = StreamEvent::Content {
            content: "Hello".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        // Pin the exact wire format: a `contains("content")` check would pass
        // even without the `type` tag, because the tag value and field name match.
        assert_eq!(json, r#"{"type":"content","content":"Hello"}"#);
        let deserialized: StreamEvent = serde_json::from_str(&json).unwrap();
        match deserialized {
            StreamEvent::Content { content } => assert_eq!(content, "Hello"),
            _ => panic!("Wrong event type"),
        }
    }

    #[test]
    fn test_unit_variant_serialization() {
        assert_eq!(
            serde_json::to_string(&StreamEvent::TurnCompleted).unwrap(),
            r#"{"type":"turn_completed"}"#
        );
        assert_eq!(
            serde_json::to_string(&StreamEvent::Done).unwrap(),
            r#"{"type":"done"}"#
        );
    }

    #[test]
    fn test_error_event_omits_absent_optional_fields() {
        let event = StreamEvent::Error {
            message: "boom".to_string(),
            failure_kind: None,
            provider: None,
            model: Some("m".to_string()),
            http_status: Some(429),
            request_payload: None,
            response_payload: None,
            provider_response_id: None,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert_eq!(
            json,
            r#"{"type":"error","message":"boom","model":"m","http_status":429}"#
        );
        // Absent optional fields must deserialize back to `None`.
        let round_trip: StreamEvent = serde_json::from_str(&json).unwrap();
        match round_trip {
            StreamEvent::Error {
                message,
                failure_kind,
                provider,
                model,
                http_status,
                request_payload,
                response_payload,
                provider_response_id,
            } => {
                assert_eq!(message, "boom");
                assert!(failure_kind.is_none());
                assert!(provider.is_none());
                assert_eq!(model.as_deref(), Some("m"));
                assert_eq!(http_status, Some(429));
                assert!(request_payload.is_none());
                assert!(response_payload.is_none());
                assert!(provider_response_id.is_none());
            }
            _ => panic!("Wrong event type"),
        }
    }
}