use super::error::ClaudeApiError;
use super::types::{ContentBlock, MessagesResponse, StopReason, Usage};
use futures_util::StreamExt;
use serde::Deserialize;
use std::pin::Pin;
/// One decoded server-sent event from a streaming Messages response.
///
/// The variant is chosen by the JSON `"type"` field, whose value is the variant name in
/// snake_case (e.g. `"content_block_delta"` selects [`StreamEvent::ContentBlockDelta`]). A
/// `"type"` value not listed here fails to deserialize.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    /// `message_start`: the message envelope that accumulation starts from.
    MessageStart { message: MessagesResponse },
    /// `content_block_start`: announces the content block at position `index`.
    ContentBlockStart {
        index: usize,
        content_block: ContentBlock,
    },
    /// `content_block_delta`: an incremental fragment for the block at `index`.
    ContentBlockDelta { index: usize, delta: ContentDelta },
    /// `content_block_stop` event for the block at `index`.
    ContentBlockStop { index: usize },
    /// `message_delta`: message-level fields (stop reason / sequence) plus optional usage.
    MessageDelta {
        delta: MessageDeltaPayload,
        // This type is only ever deserialized, so a serialization attribute would be a no-op;
        // a missing `usage` key deserializes to `None`.
        #[serde(default)]
        usage: Option<Usage>,
    },
    /// `message_stop`: carries no payload.
    MessageStop,
    /// `ping`: carries no payload.
    Ping,
    /// `error`: carries the error body sent in-band on the stream.
    Error { error: super::error::ApiErrorBody },
}
/// Incremental payload of a `content_block_delta` event.
///
/// Selected by the JSON `"type"` field, whose value is the snake_case variant name
/// (e.g. `"input_json_delta"`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ContentDelta {
    /// Text to append to a text block.
    TextDelta { text: String },
    /// Reasoning text to append to a thinking block.
    ThinkingDelta { thinking: String },
    /// A piece of a thinking block's signature.
    SignatureDelta { signature: String },
    /// A piece of a tool call's JSON input, meant to be concatenated with the other pieces
    /// before parsing.
    InputJsonDelta { partial_json: String },
}
/// Message-level fields carried inside a `message_delta` event.
#[derive(Debug, Clone, Deserialize)]
pub struct MessageDeltaPayload {
    /// Reason generation stopped, when reported.
    pub stop_reason: Option<StopReason>,
    /// Stop sequence reported with the delta, when present.
    pub stop_sequence: Option<String>,
}
/// Incremental parser for a `text/event-stream` (server-sent events) HTTP response body.
///
/// Raw bytes are buffered until a complete event (terminated by a blank line) is available, and
/// only complete events are decoded as UTF-8. Decoding each network chunk separately would
/// corrupt multi-byte characters that straddle a chunk boundary.
pub struct SseParser {
    /// Body chunks of the wrapped HTTP response.
    stream: Pin<Box<dyn futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
    /// Bytes received but not yet consumed as part of a complete event.
    buffer: Vec<u8>,
    /// Set once the body has ended or failed; no further reads are attempted afterwards.
    done: bool,
}

impl SseParser {
    /// Wraps `response`; its body is read lazily as events are requested.
    pub fn new(response: reqwest::Response) -> Self {
        Self {
            stream: Box::pin(response.bytes_stream()),
            buffer: Vec::new(),
            done: false,
        }
    }

    /// Returns the next decoded [`StreamEvent`], or `Ok(None)` once the stream is exhausted.
    ///
    /// Events without a `data:` field (e.g. SSE comment lines) are skipped. An unterminated
    /// event left at the end of the body is not dispatched.
    ///
    /// # Errors
    ///
    /// - [`ClaudeApiError::Network`] if reading the body fails; later calls return `Ok(None)`.
    /// - [`ClaudeApiError::StreamError`] if an event's data is not a valid [`StreamEvent`]. The
    ///   offending event is consumed, so the next call continues with the following event.
    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ClaudeApiError> {
        loop {
            // Drain already-buffered events first, including after EOF, so none are lost.
            if let Some(event) = self.try_parse_event()? {
                return Ok(Some(event));
            }
            if self.done {
                return Ok(None);
            }
            match self.stream.next().await {
                Some(Ok(bytes)) => self.buffer.extend_from_slice(&bytes),
                Some(Err(e)) => {
                    self.done = true;
                    return Err(ClaudeApiError::Network(e));
                }
                None => self.done = true,
            }
        }
    }

    /// Removes complete events from the front of the buffer until one carrying data is found,
    /// then decodes it. Returns `Ok(None)` when no complete data-carrying event is buffered.
    fn try_parse_event(&mut self) -> Result<Option<StreamEvent>, ClaudeApiError> {
        while let Some((event_end, next_start)) = Self::find_event_boundary(&self.buffer) {
            let raw: Vec<u8> = self.buffer.drain(..next_start).collect();
            // Boundaries fall on ASCII newlines, which never occur inside a UTF-8 multi-byte
            // sequence, so a complete event never splits a character. Invalid bytes are still
            // replaced rather than rejected.
            let raw_event = String::from_utf8_lossy(&raw[..event_end]);

            let mut event_name = "";
            let mut data_lines = Vec::new();
            for line in raw_event.lines() {
                if let Some(value) = line.strip_prefix("data:") {
                    // Per the SSE format, one space after the colon is not part of the value.
                    data_lines.push(value.strip_prefix(' ').unwrap_or(value));
                } else if let Some(value) = line.strip_prefix("event:") {
                    event_name = value.trim();
                }
            }

            // Nothing to decode (comment-only or field-less event): keep scanning rather than
            // reporting "no event" while further complete events may already be buffered.
            if data_lines.is_empty() {
                continue;
            }

            let data = data_lines.join("\n");
            return match serde_json::from_str::<StreamEvent>(&data) {
                Ok(event) => Ok(Some(event)),
                Err(e) => Err(ClaudeApiError::StreamError {
                    message: format!(
                        "Failed to parse SSE event '{}': {} (data: {})",
                        event_name, e, data
                    ),
                }),
            };
        }
        Ok(None)
    }

    /// Locates the first blank line in `buf`, which terminates an SSE event.
    ///
    /// Lines may end in `\n` or `\r\n`. Returns `(event_end, next_start)`: the event's field
    /// lines are `buf[..event_end]` and the following data starts at `buf[next_start..]`.
    /// Returns `None` if `buf` holds no complete event yet.
    fn find_event_boundary(buf: &[u8]) -> Option<(usize, usize)> {
        let mut line_start = 0;
        let mut i = 0;
        while i < buf.len() {
            let terminator_len = match buf[i] {
                b'\n' => 1,
                b'\r' if buf.get(i + 1) == Some(&b'\n') => 2,
                _ => {
                    i += 1;
                    continue;
                }
            };
            if i == line_start {
                // A terminator at the very start of a line marks a blank line.
                return Some((line_start, i + terminator_len));
            }
            i += terminator_len;
            line_start = i;
        }
        None
    }
}
/// Reassembles a complete [`MessagesResponse`] from a sequence of [`StreamEvent`]s.
///
/// Pass each event to [`MessageAccumulator::process_event`] in arrival order, then call
/// [`MessageAccumulator::finish`].
#[derive(Default)]
pub struct MessageAccumulator {
    /// Envelope from `message_start`, later updated by `message_delta`.
    response: Option<MessagesResponse>,
    /// One in-progress builder per started content block, in start order.
    blocks: Vec<BlockBuilder>,
}
struct BlockBuilder {
text: String,
thinking: String,
tool_id: String,
tool_name: String,
partial_json: String,
signature: String,
is_thinking: bool,
is_tool_use: bool,
}
impl BlockBuilder {
fn new_text() -> Self {
Self {
text: String::new(),
thinking: String::new(),
tool_id: String::new(),
tool_name: String::new(),
partial_json: String::new(),
signature: String::new(),
is_thinking: false,
is_tool_use: false,
}
}
fn new_thinking() -> Self {
Self {
is_thinking: true,
..Self::new_text()
}
}
fn new_tool(id: String, name: String) -> Self {
Self {
tool_id: id,
tool_name: name,
is_tool_use: true,
..Self::new_text()
}
}
fn finish(self) -> ContentBlock {
if self.is_thinking {
ContentBlock::Thinking {
thinking: self.thinking,
signature: if self.signature.is_empty() {
None
} else {
Some(self.signature)
},
}
} else if self.is_tool_use {
let input = serde_json::from_str(&self.partial_json)
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
ContentBlock::ToolUse {
id: self.tool_id,
name: self.tool_name,
input,
}
} else {
ContentBlock::Text {
text: self.text,
cache_control: None,
}
}
}
}
impl MessageAccumulator {
pub fn new() -> Self {
Self::default()
}
pub fn process_event(&mut self, event: &StreamEvent) {
match event {
StreamEvent::MessageStart { message } => {
self.response = Some(message.clone());
}
StreamEvent::ContentBlockStart { content_block, .. } => match content_block {
ContentBlock::Thinking { .. } => self.blocks.push(BlockBuilder::new_thinking()),
ContentBlock::ToolUse { id, name, .. } => {
self.blocks
.push(BlockBuilder::new_tool(id.clone(), name.clone()));
}
_ => self.blocks.push(BlockBuilder::new_text()),
},
StreamEvent::ContentBlockDelta { index, delta } => {
if let Some(block) = self.blocks.get_mut(*index) {
match delta {
ContentDelta::TextDelta { text } => block.text.push_str(text),
ContentDelta::ThinkingDelta { thinking } => {
block.thinking.push_str(thinking)
}
ContentDelta::InputJsonDelta { partial_json } => {
block.partial_json.push_str(partial_json);
}
ContentDelta::SignatureDelta { signature } => {
block.signature.push_str(signature);
}
}
}
}
StreamEvent::MessageDelta { delta, usage } => {
if let Some(ref mut resp) = self.response {
resp.stop_reason = delta.stop_reason.clone();
resp.stop_sequence = delta.stop_sequence.clone();
if let Some(u) = usage {
resp.usage.output_tokens = u.output_tokens;
}
}
}
_ => {}
}
}
pub fn finish(mut self) -> Result<MessagesResponse, ClaudeApiError> {
let mut response = self
.response
.take()
.ok_or_else(|| ClaudeApiError::StreamError {
message: "Stream ended without message_start event".to_string(),
})?;
response.content = self.blocks.into_iter().map(|b| b.finish()).collect();
Ok(response)
}
}