use crate::client::error::LlmError;
use crate::client::models::{ContentBlockType, StreamEvent, Usage};
/// Prefix of an SSE `data` field line, including the single optional space
/// that conventionally follows the colon.
const SSE_DATA_PREFIX: &str = "data: ";
/// Error code reported when an SSE payload is not valid JSON.
const ERROR_SSE_PARSE: &str = "SSE_PARSE_ERROR";
/// Fallback message when a provider error payload carries no message text.
const MSG_UNKNOWN_ERROR: &str = "Unknown error";

/// One dispatched Server-Sent Event: the payload of its `data` field(s).
#[derive(Debug)]
pub struct SseEvent {
    /// Event payload; multiple `data:` lines of one event are joined with `\n`.
    pub data: String,
}

/// Bookkeeping carried across successive `parse_stream_event` calls for a
/// single stream.
#[derive(Debug, Default, Clone)]
pub struct StreamState {
    /// Running content-block index; the open text block (if any) uses it.
    pub block_index: usize,
    /// Whether a text content block is currently open (started, not stopped).
    pub text_block_started: bool,
    /// Indices of tool-use blocks that were started but not yet stopped, in
    /// start order.
    pub pending_tool_block_indices: Vec<usize>,
    /// Set once a `message-end` event has been processed.
    pub finished: bool,
}

/// Splits a buffer of raw SSE text into complete events.
///
/// Only lines terminated by `\n` are interpreted; a trailing `\r` is stripped
/// so CRLF streams work. `data:` lines accumulate into the pending event (one
/// optional space after the colon is removed, and several lines are joined
/// with `\n`). A blank line dispatches the pending event. Other fields
/// (`event:`, `id:`, comments) are ignored.
///
/// Returns the dispatched events plus the unconsumed tail of `buffer`, which
/// is everything after the last blank line, copied verbatim. The tail may
/// hold an unterminated event or a line the network cut in half (even in the
/// middle of the `data:` prefix). The caller should prepend it to the next
/// chunk before calling this again.
pub fn parse_sse_chunk(buffer: &str) -> (Vec<SseEvent>, String) {
    // Field name without the optional space, i.e. "data:".
    let data_field = SSE_DATA_PREFIX.trim_end();
    let mut events = Vec::new();
    let mut current_data: Option<String> = None;
    // Byte offset just past the most recent blank line; the tail handed back
    // to the caller starts here.
    let mut consumed = 0;
    let mut line_start = 0;

    while let Some(newline) = buffer[line_start..].find('\n') {
        let raw = &buffer[line_start..line_start + newline];
        line_start += newline + 1;
        let line = raw.strip_suffix('\r').unwrap_or(raw);

        if line.is_empty() {
            if let Some(data) = current_data.take() {
                events.push(SseEvent { data });
            }
            consumed = line_start;
        } else if let Some(value) = line.strip_prefix(data_field) {
            let value = value.strip_prefix(' ').unwrap_or(value);
            match current_data.as_mut() {
                Some(data) => {
                    data.push('\n');
                    data.push_str(value);
                }
                None => current_data = Some(value.to_owned()),
            }
        }
    }

    (events, buffer[consumed..].to_owned())
}
pub fn parse_stream_event(
sse: &SseEvent,
state: &mut StreamState,
) -> Result<Vec<StreamEvent>, LlmError> {
let json: serde_json::Value = serde_json::from_str(&sse.data)
.map_err(|e| LlmError::new(ERROR_SSE_PARSE, format!("Invalid JSON: {}", e)))?;
if let Some(error) = json.get("error") {
let error_msg = error["message"].as_str().unwrap_or(MSG_UNKNOWN_ERROR);
return Err(LlmError::new("COHERE_ERROR", error_msg));
}
let mut events = Vec::new();
let event_type = json["type"].as_str().unwrap_or("");
match event_type {
"message-start" => {
}
"content-start" => {
if !state.text_block_started {
events.push(StreamEvent::ContentBlockStart {
index: state.block_index,
block_type: ContentBlockType::Text,
});
state.text_block_started = true;
}
}
"content-delta" => {
if let Some(delta) = json.get("delta")
&& let Some(message) = delta.get("message")
&& let Some(content) = message.get("content")
&& let Some(text) = content.get("text").and_then(|t| t.as_str())
&& !text.is_empty()
{
if !state.text_block_started {
events.push(StreamEvent::ContentBlockStart {
index: state.block_index,
block_type: ContentBlockType::Text,
});
state.text_block_started = true;
}
events.push(StreamEvent::TextDelta {
index: state.block_index,
text: text.to_string(),
});
}
}
"content-end" => {
if state.text_block_started {
events.push(StreamEvent::ContentBlockStop {
index: state.block_index,
});
state.block_index += 1;
state.text_block_started = false;
}
}
"tool-call-start" => {
if let Some(delta) = json.get("delta")
&& let Some(tool_call) = delta.get("tool_call")
{
let id = tool_call["id"].as_str().unwrap_or("").to_string();
let name = tool_call["function"]["name"]
.as_str()
.unwrap_or("")
.to_string();
if state.text_block_started {
events.push(StreamEvent::ContentBlockStop {
index: state.block_index,
});
state.block_index += 1;
state.text_block_started = false;
}
let block_idx = state.block_index + state.pending_tool_block_indices.len();
events.push(StreamEvent::ContentBlockStart {
index: block_idx,
block_type: ContentBlockType::ToolUse { id, name },
});
state.pending_tool_block_indices.push(block_idx);
}
}
"tool-call-delta" => {
if let Some(delta) = json.get("delta")
&& let Some(tool_call) = delta.get("tool_call")
&& let Some(args) = tool_call["function"]["arguments"].as_str()
&& !args.is_empty()
{
if let Some(&block_idx) = state.pending_tool_block_indices.last() {
events.push(StreamEvent::InputJsonDelta {
index: block_idx,
json: args.to_string(),
});
}
}
}
"tool-call-end" => {
if let Some(block_idx) = state.pending_tool_block_indices.pop() {
events.push(StreamEvent::ContentBlockStop { index: block_idx });
}
}
"message-end" => {
state.finished = true;
if state.text_block_started {
events.push(StreamEvent::ContentBlockStop {
index: state.block_index,
});
state.text_block_started = false;
}
for &block_idx in &state.pending_tool_block_indices {
events.push(StreamEvent::ContentBlockStop { index: block_idx });
}
state.pending_tool_block_indices.clear();
let usage = json
.get("delta")
.and_then(|d| d.get("usage"))
.map(|u| Usage {
input_tokens: u["billed_units"]["input_tokens"].as_u64().unwrap_or(0) as u32,
output_tokens: u["billed_units"]["output_tokens"].as_u64().unwrap_or(0) as u32,
});
let finish_reason = json
.get("delta")
.and_then(|d| d.get("finish_reason"))
.and_then(|r| r.as_str())
.map(|r| match r {
"COMPLETE" => "end_turn".to_string(),
"MAX_TOKENS" => "max_tokens".to_string(),
"TOOL_CALL" => "tool_use".to_string(),
other => other.to_string(),
});
events.push(StreamEvent::MessageDelta {
stop_reason: finish_reason,
usage,
});
}
_ => {
}
}
Ok(events)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a raw JSON payload in an [`SseEvent`] for feeding the parser.
    fn sse(payload: &str) -> SseEvent {
        SseEvent {
            data: payload.to_owned(),
        }
    }

    #[test]
    fn test_parse_sse_chunk() {
        let input = "data: {\"test\":true}\n\ndata: {\"test\":false}\n\n";
        let (parsed, leftover) = parse_sse_chunk(input);
        let payloads: Vec<&str> = parsed.iter().map(|ev| ev.data.as_str()).collect();
        assert_eq!(payloads, ["{\"test\":true}", "{\"test\":false}"]);
        assert!(leftover.is_empty());
    }

    #[test]
    fn test_parse_incomplete_chunk() {
        let (parsed, leftover) = parse_sse_chunk("data: {\"test\":true}");
        assert!(parsed.is_empty());
        assert!(leftover.contains("{\"test\":true}"));
    }

    #[test]
    fn test_parse_text_delta() {
        let mut state = StreamState::default();
        let payload = r#"{"type":"content-delta","delta":{"message":{"content":{"text":"Hello"}}}}"#;
        let emitted = parse_stream_event(&sse(payload), &mut state).unwrap();
        assert!(emitted
            .iter()
            .any(|ev| matches!(ev, StreamEvent::TextDelta { text, .. } if text == "Hello")));
    }

    #[test]
    fn test_parse_message_end() {
        let mut state = StreamState::default();
        let payload = r#"{"type":"message-end","delta":{"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":10,"output_tokens":20}}}}"#;
        let emitted = parse_stream_event(&sse(payload), &mut state).unwrap();
        let saw_final_delta = emitted.iter().any(|ev| {
            matches!(ev, StreamEvent::MessageDelta { stop_reason: Some(reason), usage: Some(u) }
                if reason == "end_turn" && u.output_tokens == 20)
        });
        assert!(saw_final_delta);
        assert!(state.finished);
    }

    #[test]
    fn test_parse_error() {
        let mut state = StreamState::default();
        let payload = r#"{"error":{"message":"Invalid API key"}}"#;
        let outcome = parse_stream_event(&sse(payload), &mut state);
        assert!(outcome.is_err());
    }
}