use crate::common::Usage;
use crate::common::errors::{ErrorDetail, Result};
use crate::messages::request::content::ContentBlock;
use crate::messages::request::model::Model;
use crate::messages::response::Response;
use serde::{Deserialize, Serialize};
/// One event decoded from the streaming (Server-Sent Events) response.
///
/// The JSON `type` field selects the variant. Variant names map to wire names
/// in `snake_case`, e.g. `MessageStart` <-> `"message_start"` and
/// `ContentBlockDelta` <-> `"content_block_delta"`.
///
/// NOTE(review): there is no catch-all variant, so a payload whose `type` is
/// not listed here fails to deserialize instead of being skipped.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    /// Carries the message envelope (`id`, `model`, ...).
    MessageStart { message: Response },
    /// Announces the content block that will live at `index`.
    ContentBlockStart {
        index: usize,
        content_block: ContentBlock,
    },
    /// Payload-free event.
    Ping,
    /// An incremental update for the content block at `index`.
    ContentBlockDelta { index: usize, delta: Delta },
    /// Marks the end of the content block at `index`.
    ContentBlockStop { index: usize },
    /// Message-level update: stop information plus token usage.
    MessageDelta { delta: MessageDelta, usage: Usage },
    /// Marks the end of the message.
    MessageStop,
    /// An error reported in-stream.
    Error { error: ErrorDetail },
}
/// Incremental payload carried by a `content_block_delta` event.
///
/// The JSON `type` field selects the variant. Wire names are the `snake_case`
/// form of the variant names, e.g. `InputJsonDelta` <-> `"input_json_delta"`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Delta {
    /// A fragment of text to append to a text block.
    TextDelta { text: String },
    /// A fragment of tool-use input JSON; fragments are concatenated as raw text.
    InputJsonDelta { partial_json: String },
    /// A fragment of thinking text.
    ThinkingDelta { thinking: String },
    /// A signature string.
    SignatureDelta { signature: String },
}
/// Message-level fields carried by a `message_delta` event.
///
/// Both fields are optional on the wire: a missing key deserializes to `None`,
/// and `None` values are omitted when serializing.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessageDelta {
    /// Why generation stopped, if reported in this delta.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    /// The stop sequence that was matched, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
}
/// Field prefix of an SSE line naming the event type (payload-free for us).
const SSE_EVENT_PREFIX: &str = "event: ";
/// Field prefix of an SSE line carrying a JSON event payload.
const SSE_DATA_PREFIX: &str = "data: ";
/// Parses a single line of a Server-Sent Events stream into a [`StreamEvent`].
///
/// Only `data:` lines carry events. Every other line yields `Ok(None)`:
/// - blank lines;
/// - `event:` lines (the JSON payload repeats the event name in its `type` field);
/// - any other SSE field or comment;
/// - the `[DONE]` sentinel.
///
/// Following the SSE specification, the single space after the `data:` field
/// name is optional, so both `data: {...}` and `data:{...}` are accepted.
///
/// # Errors
///
/// Returns an error when a `data:` payload is not valid JSON for
/// [`StreamEvent`], including a payload whose `type` is not a known variant.
pub fn parse_sse_line(line: &str) -> Result<Option<StreamEvent>> {
    if line.trim().is_empty() {
        return Ok(None);
    }
    // Compare on the bare field names ("event:" / "data:") so that lines
    // written without the optional space after the colon are still recognised.
    if line.starts_with(SSE_EVENT_PREFIX.trim_end()) {
        return Ok(None);
    }
    let payload = match line.strip_prefix(SSE_DATA_PREFIX.trim_end()) {
        // Per the SSE spec, remove at most one leading space from the value.
        Some(rest) => rest.strip_prefix(' ').unwrap_or(rest),
        None => return Ok(None),
    };
    if payload.trim() == "[DONE]" {
        return Ok(None);
    }
    let event: StreamEvent = serde_json::from_str(payload)?;
    Ok(Some(event))
}
/// Incrementally rebuilds a streamed message from a sequence of [`StreamEvent`]s.
///
/// Events are fed, in arrival order, to `StreamAccumulator::process_event`.
#[derive(Debug, Default)]
pub struct StreamAccumulator {
    /// Concatenation of every text delta, across all text blocks.
    pub text: String,
    /// Raw input-JSON fragments per tool-use block, keyed by the block's tool id.
    /// Fragments are concatenated as text; nothing here parses them.
    pub tool_inputs: std::collections::HashMap<String, String>,
    /// Concatenation of every thinking delta.
    pub thinking: String,
    /// Content blocks, positioned by the stream's block `index`.
    pub content_blocks: Vec<ContentBlock>,
    /// Usage reported by the most recent `message_delta` event.
    pub usage: Option<Usage>,
    /// Stop reason from `message_delta`; `is_complete` checks whether this is `Some`.
    pub stop_reason: Option<String>,
    /// Model reported by `message_start`.
    pub model: Option<Model>,
    /// Message id reported by `message_start`.
    pub id: Option<String>,
}
impl StreamAccumulator {
pub fn new() -> Self {
StreamAccumulator::default()
}
pub fn process_event(&mut self, event: StreamEvent) {
match event {
StreamEvent::MessageStart { message } => {
self.id = Some(message.id);
self.model = Some(message.model);
}
StreamEvent::ContentBlockStart {
content_block,
index,
} => {
while self.content_blocks.len() <= index {
self.content_blocks.push(ContentBlock::Text {
text: String::new(),
cache_control: None,
});
}
self.content_blocks[index] = content_block;
}
StreamEvent::ContentBlockDelta { index, delta } => match delta {
Delta::TextDelta { text } => {
self.text.push_str(&text);
if let Some(ContentBlock::Text {
text: block_text, ..
}) = self.content_blocks.get_mut(index)
{
block_text.push_str(&text);
}
}
Delta::InputJsonDelta { partial_json } => {
if let Some(ContentBlock::ToolUse { id, .. }) = self.content_blocks.get(index) {
self.tool_inputs
.entry(id.clone())
.or_default()
.push_str(&partial_json);
}
}
Delta::ThinkingDelta { thinking } => {
self.thinking.push_str(&thinking);
}
Delta::SignatureDelta { .. } => {
}
},
StreamEvent::ContentBlockStop { .. } => {
}
StreamEvent::MessageDelta { delta, usage } => {
self.stop_reason = delta.stop_reason;
self.usage = Some(usage);
}
StreamEvent::MessageStop => {
}
StreamEvent::Ping => {
}
StreamEvent::Error { .. } => {
}
}
}
pub fn get_text(&self) -> &str {
&self.text
}
pub fn is_complete(&self) -> bool {
self.stop_reason.is_some()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw`, requiring both a successful parse and a resulting event.
    fn parsed_event(raw: &str) -> StreamEvent {
        parse_sse_line(raw).unwrap().unwrap()
    }

    /// Parses `raw`, requiring a successful parse that yields no event.
    fn assert_no_event(raw: &str) {
        let outcome = parse_sse_line(raw).unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn test_parse_text_delta() {
        let raw = r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
        if let StreamEvent::ContentBlockDelta { index, delta } = parsed_event(raw) {
            assert_eq!(index, 0);
            if let Delta::TextDelta { text } = delta {
                assert_eq!(text, "Hello");
            } else {
                panic!("Expected TextDelta");
            }
        } else {
            panic!("Expected ContentBlockDelta");
        }
    }

    #[test]
    fn test_parse_message_stop() {
        let evt = parsed_event(r#"data: {"type":"message_stop"}"#);
        assert!(matches!(evt, StreamEvent::MessageStop));
    }

    #[test]
    fn test_parse_ping() {
        let evt = parsed_event(r#"data: {"type":"ping"}"#);
        assert!(matches!(evt, StreamEvent::Ping));
    }

    #[test]
    fn test_parse_done() {
        assert_no_event("data: [DONE]");
    }

    #[test]
    fn test_parse_empty_line() {
        assert_no_event("");
    }

    #[test]
    fn test_parse_event_line() {
        assert_no_event("event: message_start");
    }

    #[test]
    fn test_accumulator_text() {
        let mut accumulator = StreamAccumulator::new();
        let opening = StreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::Text {
                text: String::new(),
                cache_control: None,
            },
        };
        accumulator.process_event(opening);
        for piece in ["Hello ", "world!"] {
            accumulator.process_event(StreamEvent::ContentBlockDelta {
                index: 0,
                delta: Delta::TextDelta {
                    text: piece.to_string(),
                },
            });
        }
        assert_eq!(accumulator.get_text(), "Hello world!");
    }

    #[test]
    fn test_accumulator_complete() {
        let mut accumulator = StreamAccumulator::new();
        assert!(!accumulator.is_complete());
        let finishing = StreamEvent::MessageDelta {
            delta: MessageDelta {
                stop_reason: Some("end_turn".to_string()),
                stop_sequence: None,
            },
            usage: Usage::new(10, 5),
        };
        accumulator.process_event(finishing);
        assert!(accumulator.is_complete());
        assert!(accumulator.usage.is_some());
    }
}