use crate::tasks::generate::ToolCall;
use crate::TokenUsage;
use std::collections::HashMap;
/// A single, provider-agnostic event decoded from a streaming model response.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// A fragment of assistant text.
    TextDelta(String),
    /// A new tool call began; subsequent `ToolCallDelta`s with the same
    /// `index` extend its arguments.
    ToolCallStart {
        name: String,
        /// Provider-supplied position of this call within the response.
        index: usize,
        /// Provider-assigned call id, when one is sent.
        id: Option<String>,
    },
    /// An incremental fragment of a tool call's JSON-arguments string.
    ToolCallDelta {
        index: usize,
        arguments_delta: String,
    },
    /// Token usage as reported by the provider during the stream.
    Usage {
        input_tokens: u64,
        output_tokens: u64,
    },
    /// Terminal event carrying the fully assembled response.
    Done {
        text: String,
        tool_calls: Vec<ToolCall>,
    },
}
/// Parses one OpenAI-style SSE line (`data: {...}`) into stream events.
///
/// Lines without a `data:` field, unparsable payloads, and the `[DONE]`
/// sentinel all yield no events. A single chunk can produce several events
/// (text delta, tool-call start/delta, usage).
pub fn parse_openai_sse_line(line: &str) -> Vec<StreamEvent> {
    // Per the SSE spec the space after the colon is optional; accept both.
    let data = match line.strip_prefix("data:") {
        Some(d) => d.strip_prefix(' ').unwrap_or(d),
        None => return Vec::new(),
    };
    if data.trim() == "[DONE]" {
        return Vec::new();
    }
    let json: serde_json::Value = match serde_json::from_str(data) {
        Ok(v) => v,
        // Skip malformed payloads rather than aborting the stream.
        Err(_) => return Vec::new(),
    };
    let mut events = Vec::new();
    if let Some(delta) = json.pointer("/choices/0/delta") {
        if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
            if !content.is_empty() {
                events.push(StreamEvent::TextDelta(content.to_string()));
            }
        }
        if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
            for tc in tool_calls {
                // OpenAI correlates fragments of a tool call via `index`.
                let index = tc.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
                if let Some(function) = tc.get("function") {
                    // `name` only appears on the first fragment of a call.
                    if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
                        let id = tc.get("id").and_then(|i| i.as_str()).map(|s| s.to_string());
                        events.push(StreamEvent::ToolCallStart {
                            name: name.to_string(),
                            index,
                            id,
                        });
                    }
                    if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
                        if !args.is_empty() {
                            events.push(StreamEvent::ToolCallDelta {
                                index,
                                arguments_delta: args.to_string(),
                            });
                        }
                    }
                }
            }
        }
    }
    // Usage arrives on a final chunk when the caller requested it
    // (e.g. via `stream_options.include_usage`).
    if let Some(usage) = json.get("usage") {
        let input = usage
            .get("prompt_tokens")
            .and_then(|n| n.as_u64())
            .unwrap_or(0);
        let output = usage
            .get("completion_tokens")
            .and_then(|n| n.as_u64())
            .unwrap_or(0);
        // Intermediate chunks may carry `usage: null`; suppress zero/zero.
        if input != 0 || output != 0 {
            events.push(StreamEvent::Usage {
                input_tokens: input,
                output_tokens: output,
            });
        }
    }
    events
}
/// Extracts a `Usage` event from an Anthropic `usage` JSON object.
///
/// Returns no event when both counters are zero or missing, so callers can
/// splice the result directly into their event stream.
fn anthropic_usage_events(usage: &serde_json::Value) -> Vec<StreamEvent> {
    let input = usage
        .get("input_tokens")
        .and_then(|n| n.as_u64())
        .unwrap_or(0);
    let output = usage
        .get("output_tokens")
        .and_then(|n| n.as_u64())
        .unwrap_or(0);
    if input == 0 && output == 0 {
        return Vec::new();
    }
    vec![StreamEvent::Usage {
        input_tokens: input,
        output_tokens: output,
    }]
}

/// Parses one Anthropic SSE event (`event_type` plus its JSON `data`) into
/// stream events. Unknown event types and malformed payloads yield no events.
pub fn parse_anthropic_sse_line(event_type: &str, data: &str) -> Vec<StreamEvent> {
    // Every handled event type carries a JSON payload; parse it once up front.
    let json: serde_json::Value = match serde_json::from_str(data) {
        Ok(v) => v,
        Err(_) => return Vec::new(),
    };
    match event_type {
        "content_block_delta" => {
            let Some(delta) = json.get("delta") else {
                return Vec::new();
            };
            match delta.get("type").and_then(|t| t.as_str()) {
                Some("text_delta") => match delta.get("text").and_then(|t| t.as_str()) {
                    Some(text) => vec![StreamEvent::TextDelta(text.to_string())],
                    None => Vec::new(),
                },
                Some("input_json_delta") => {
                    match delta.get("partial_json").and_then(|p| p.as_str()) {
                        Some(partial) => {
                            // `index` lives on the outer event, not the delta.
                            let index =
                                json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
                            vec![StreamEvent::ToolCallDelta {
                                index,
                                arguments_delta: partial.to_string(),
                            }]
                        }
                        None => Vec::new(),
                    }
                }
                _ => Vec::new(),
            }
        }
        "content_block_start" => {
            let Some(block) = json.get("content_block") else {
                return Vec::new();
            };
            // Only `tool_use` blocks start a tool call; text blocks deliver
            // their content through `content_block_delta` events instead.
            if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
                if let Some(name) = block.get("name").and_then(|n| n.as_str()) {
                    let index = json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
                    let id = block
                        .get("id")
                        .and_then(|i| i.as_str())
                        .map(|s| s.to_string());
                    return vec![StreamEvent::ToolCallStart {
                        name: name.to_string(),
                        index,
                        id,
                    }];
                }
            }
            Vec::new()
        }
        // `message_start` reports input tokens; `message_delta` reports the
        // final output-token count. Both share the same usage object shape.
        "message_start" => json
            .pointer("/message/usage")
            .map(anthropic_usage_events)
            .unwrap_or_default(),
        "message_delta" => json
            .get("usage")
            .map(anthropic_usage_events)
            .unwrap_or_default(),
        _ => Vec::new(),
    }
}
/// Incrementally assembles a complete response from a sequence of
/// [`StreamEvent`]s: text, tool calls, and (optionally) token usage.
#[derive(Default)]
pub struct StreamAccumulator {
    /// Concatenated assistant text received so far.
    pub text: String,
    /// Tool-call names, keyed by the provider-supplied call index.
    tool_names: HashMap<usize, String>,
    /// Accumulated JSON-argument fragments per call index.
    tool_args: HashMap<usize, String>,
    /// Provider-assigned call ids, when supplied.
    tool_ids: HashMap<usize, String>,
    /// Highest input-token count reported by the provider so far.
    input_tokens: u64,
    /// Highest output-token count reported by the provider so far.
    output_tokens: u64,
    /// True once any `Usage` event was pushed (even with zero counts).
    saw_usage: bool,
}
impl StreamAccumulator {
    /// Folds a single stream event into the accumulated state.
    ///
    /// Usage counters keep the maximum value seen so far, since providers
    /// may report usage incrementally (e.g. Anthropic sends input tokens in
    /// `message_start` and the final output count in `message_delta`).
    pub fn push(&mut self, event: &StreamEvent) {
        match event {
            StreamEvent::TextDelta(t) => self.text.push_str(t),
            StreamEvent::ToolCallStart { name, index, id } => {
                self.tool_names.insert(*index, name.clone());
                // Ensure an (empty) argument buffer exists even if no
                // argument deltas ever arrive for this call.
                self.tool_args.entry(*index).or_default();
                if let Some(id) = id {
                    self.tool_ids.insert(*index, id.clone());
                }
            }
            StreamEvent::ToolCallDelta {
                index,
                arguments_delta,
            } => {
                self.tool_args
                    .entry(*index)
                    .or_default()
                    .push_str(arguments_delta);
            }
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                self.saw_usage = true;
                self.input_tokens = self.input_tokens.max(*input_tokens);
                self.output_tokens = self.output_tokens.max(*output_tokens);
            }
            StreamEvent::Done { .. } => {}
        }
    }

    /// Finishes accumulation, returning the full text and assembled tool
    /// calls. Use [`Self::finish_with_usage`] to also obtain token usage.
    pub fn finish(self) -> (String, Vec<ToolCall>) {
        let (text, tool_calls, _) = self.finish_with_usage();
        (text, tool_calls)
    }

    /// Finishes accumulation, additionally returning token usage when the
    /// provider reported any.
    ///
    /// Tool calls are emitted in ascending `index` order. An argument buffer
    /// that does not parse as a JSON object falls back to empty arguments.
    pub fn finish_with_usage(mut self) -> (String, Vec<ToolCall>, Option<TokenUsage>) {
        let mut indices: Vec<usize> = self.tool_names.keys().copied().collect();
        indices.sort_unstable();
        let mut tool_calls = Vec::with_capacity(indices.len());
        for idx in indices {
            // `self` is consumed, so move entries out instead of cloning.
            let id = self.tool_ids.remove(&idx);
            let name = self.tool_names.remove(&idx).unwrap_or_default();
            let args_str = self.tool_args.remove(&idx).unwrap_or_default();
            let arguments: HashMap<String, serde_json::Value> =
                serde_json::from_str(&args_str).unwrap_or_default();
            tool_calls.push(ToolCall {
                id,
                name,
                arguments,
            });
        }
        let usage = if self.saw_usage {
            Some(TokenUsage {
                prompt_tokens: self.input_tokens,
                completion_tokens: self.output_tokens,
                total_tokens: self.input_tokens + self.output_tokens,
                // The accumulator has no model metadata; callers fill this in.
                context_window: 0,
            })
        } else {
            None
        };
        (self.text, tool_calls, usage)
    }
}
/// Strips an SSE field prefix (`name:` with an optional following space),
/// returning the field's value when `line` carries that field.
fn sse_field<'a>(line: &'a str, name: &str) -> Option<&'a str> {
    let rest = line.strip_prefix(name)?.strip_prefix(':')?;
    Some(rest.strip_prefix(' ').unwrap_or(rest))
}

/// Splits a raw SSE chunk into `(event_type, data)` pairs.
///
/// Events without an explicit `event:` field are labeled `"message"`, the
/// SSE specification's default event type. Multiple `data:` lines within one
/// event are joined with newlines, per the spec. A trailing event not yet
/// terminated by a blank line is still emitted (the chunk may end mid-event).
pub fn parse_sse_lines(chunk: &str) -> Vec<(String, String)> {
    let mut events = Vec::new();
    let mut current_event = String::new();
    let mut current_data = String::new();
    for line in chunk.lines() {
        if let Some(value) = sse_field(line, "event") {
            current_event = value.to_string();
        } else if let Some(value) = sse_field(line, "data") {
            // Spec behavior: successive data lines accumulate, newline-joined.
            if !current_data.is_empty() {
                current_data.push('\n');
            }
            current_data.push_str(value);
        } else if line.is_empty() && !current_data.is_empty() {
            // Blank line terminates the event; dispatch what we collected.
            let name = if current_event.is_empty() {
                "message"
            } else {
                current_event.as_str()
            };
            events.push((name.to_string(), std::mem::take(&mut current_data)));
            current_event.clear();
        }
    }
    if !current_data.is_empty() {
        let name = if current_event.is_empty() {
            "message"
        } else {
            current_event.as_str()
        };
        events.push((name.to_string(), current_data));
    }
    events
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_openai_text_delta() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::TextDelta(t) => assert_eq!(t, "Hello"),
            other => panic!("expected TextDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_tool_call_start() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"name":"edit_file"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "edit_file");
                assert_eq!(*index, 0);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_tool_call_delta() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallDelta {
                index,
                arguments_delta,
            } => {
                assert_eq!(*index, 0);
                assert!(arguments_delta.contains("path"));
            }
            other => panic!("expected ToolCallDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_multiple_tool_calls_in_chunk() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"name":"read_file"}},{"index":1,"function":{"name":"search"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 2);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "read_file");
                assert_eq!(*index, 0);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
        match &events[1] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "search");
                assert_eq!(*index, 1);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_done() {
        // The [DONE] sentinel carries no events.
        assert!(parse_openai_sse_line("data: [DONE]").is_empty());
    }

    #[test]
    fn parse_anthropic_text_delta() {
        let data = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"world"}}"#;
        let events = parse_anthropic_sse_line("content_block_delta", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::TextDelta(t) => assert_eq!(t, "world"),
            other => panic!("expected TextDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_tool_start() {
        let data = r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"t1","name":"search","input":{}}}"#;
        let events = parse_anthropic_sse_line("content_block_start", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "search");
                assert_eq!(*index, 1);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn accumulator_builds_result() {
        let mut acc = StreamAccumulator::default();
        acc.push(&StreamEvent::TextDelta("Hello ".into()));
        acc.push(&StreamEvent::TextDelta("world".into()));
        acc.push(&StreamEvent::ToolCallStart {
            name: "search".into(),
            index: 0,
            id: None,
        });
        acc.push(&StreamEvent::ToolCallDelta {
            index: 0,
            arguments_delta: r#"{"q":"test"}"#.into(),
        });
        let (text, tools) = acc.finish();
        assert_eq!(text, "Hello world");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search");
        assert!(tools[0].arguments.contains_key("q"));
    }

    #[test]
    fn parse_sse_lines_openai_format() {
        let chunk = "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\ndata: [DONE]\n\n";
        let events = parse_sse_lines(chunk);
        assert_eq!(events.len(), 2);
        // OpenAI streams omit `event:`; the SSE default type is "message".
        assert_eq!(events[0].0, "message");
        assert_eq!(events[1].1, "[DONE]");
    }

    #[test]
    fn parse_sse_lines_anthropic_format() {
        let chunk = "event: content_block_delta\ndata: {\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n\n";
        let events = parse_sse_lines(chunk);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].0, "content_block_delta");
    }

    #[test]
    fn parse_anthropic_message_start_emits_usage() {
        let data = r#"{"type":"message_start","message":{"id":"msg_1","role":"assistant","usage":{"input_tokens":245,"output_tokens":1}}}"#;
        let events = parse_anthropic_sse_line("message_start", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(*input_tokens, 245);
                assert_eq!(*output_tokens, 1);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_message_delta_emits_usage() {
        let data = r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":87}}"#;
        let events = parse_anthropic_sse_line("message_delta", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(*input_tokens, 0);
                assert_eq!(*output_tokens, 87);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_message_start_without_usage_is_empty() {
        let data = r#"{"type":"message_start","message":{"id":"msg_1"}}"#;
        assert!(parse_anthropic_sse_line("message_start", data).is_empty());
    }

    #[test]
    fn accumulator_tracks_usage_across_anthropic_stream() {
        let mut acc = StreamAccumulator::default();
        for event in parse_anthropic_sse_line(
            "message_start",
            r#"{"message":{"usage":{"input_tokens":245,"output_tokens":1}}}"#,
        ) {
            acc.push(&event);
        }
        for event in parse_anthropic_sse_line(
            "content_block_start",
            r#"{"index":0,"content_block":{"type":"text","text":""}}"#,
        ) {
            acc.push(&event);
        }
        for chunk in [
            r#"{"delta":{"type":"text_delta","text":"Hello"}}"#,
            r#"{"delta":{"type":"text_delta","text":", "}}"#,
            r#"{"delta":{"type":"text_delta","text":"world"}}"#,
        ] {
            for event in parse_anthropic_sse_line("content_block_delta", chunk) {
                acc.push(&event);
            }
        }
        for event in parse_anthropic_sse_line("message_delta", r#"{"usage":{"output_tokens":87}}"#)
        {
            acc.push(&event);
        }
        let (text, tools, usage) = acc.finish_with_usage();
        assert_eq!(text, "Hello, world");
        assert!(tools.is_empty());
        let usage = usage.expect("provider reported usage; must surface");
        assert_eq!(usage.prompt_tokens, 245);
        assert_eq!(usage.completion_tokens, 87);
        assert_eq!(usage.total_tokens, 332);
    }

    #[test]
    fn parse_openai_final_chunk_emits_usage() {
        let line = r#"data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":245,"completion_tokens":87,"total_tokens":332}}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(*input_tokens, 245);
                assert_eq!(*output_tokens, 87);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn accumulator_tracks_usage_across_openai_stream() {
        let mut acc = StreamAccumulator::default();
        for line in [
            r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#,
            r#"data: {"choices":[{"delta":{"content":", "}}]}"#,
            r#"data: {"choices":[{"delta":{"content":"world"}}]}"#,
            r#"data: {"id":"chatcmpl-1","choices":[],"usage":{"prompt_tokens":245,"completion_tokens":87}}"#,
        ] {
            for event in parse_openai_sse_line(line) {
                acc.push(&event);
            }
        }
        let (text, tools, usage) = acc.finish_with_usage();
        assert_eq!(text, "Hello, world");
        assert!(tools.is_empty());
        let usage = usage.expect("provider reported usage; must surface");
        assert_eq!(usage.prompt_tokens, 245);
        assert_eq!(usage.completion_tokens, 87);
        assert_eq!(usage.total_tokens, 332);
    }

    #[test]
    fn accumulator_returns_no_usage_when_provider_silent() {
        let mut acc = StreamAccumulator::default();
        acc.push(&StreamEvent::TextDelta("hi".into()));
        let (_, _, usage) = acc.finish_with_usage();
        assert!(usage.is_none());
    }
}