use std::collections::{HashMap, VecDeque};
use std::time::Instant;
use genai::chat::{ChatStreamEvent, StreamEnd};
use serde_json::Value;
use awaken_contract::contract::executor::InterruptSnapshot;
use awaken_contract::contract::inference::{StreamResult, TokenUsage};
use awaken_contract::contract::message::ToolCall;
use super::convert::{map_stop_reason, map_usage};
pub enum StreamOutput {
TextDelta(String),
ReasoningDelta(String),
ToolCallStart { id: String, name: String },
ToolCallDelta { id: String, args_delta: String },
None,
}
pub struct StreamCollector {
text: String,
tool_calls: HashMap<String, PartialToolCall>,
tool_call_order: Vec<String>,
usage: Option<TokenUsage>,
stop_reason: Option<awaken_contract::contract::inference::StopReason>,
pending_outputs: VecDeque<StreamOutput>,
end_seen: bool,
last_delta_at: Option<Instant>,
bytes_received: usize,
}
struct PartialToolCall {
id: String,
name: String,
arguments: String,
start_emitted: bool,
}
impl Default for StreamCollector {
fn default() -> Self {
Self::new()
}
}
impl StreamCollector {
pub fn new() -> Self {
Self {
text: String::new(),
end_seen: false,
tool_calls: HashMap::new(),
tool_call_order: Vec::new(),
usage: None,
stop_reason: None,
pending_outputs: VecDeque::new(),
last_delta_at: None,
bytes_received: 0,
}
}
pub fn last_delta_at(&self) -> Option<Instant> {
self.last_delta_at
}
pub fn text(&self) -> &str {
&self.text
}
pub fn interrupt_snapshot(&self) -> InterruptSnapshot {
let text = (!self.text.is_empty()).then(|| self.text.clone());
let partials = self.tool_call_order.iter().filter_map(|id| {
self.tool_calls
.get(id)
.map(|p| (p.id.clone(), p.name.clone(), p.arguments.clone()))
});
InterruptSnapshot::from_partials(text, partials, self.bytes_received)
}
pub fn take_pending_output(&mut self) -> Option<StreamOutput> {
self.pending_outputs.pop_front()
}
pub fn process(&mut self, event: ChatStreamEvent) -> StreamOutput {
match event {
ChatStreamEvent::Start => StreamOutput::None,
ChatStreamEvent::Chunk(chunk) => {
self.bytes_received += chunk.content.len();
self.last_delta_at = Some(Instant::now());
self.text.push_str(&chunk.content);
StreamOutput::TextDelta(chunk.content)
}
ChatStreamEvent::ReasoningChunk(chunk) => {
self.bytes_received += chunk.content.len();
self.last_delta_at = Some(Instant::now());
StreamOutput::ReasoningDelta(chunk.content)
}
ChatStreamEvent::ThoughtSignatureChunk(_) => StreamOutput::None,
ChatStreamEvent::ToolCallChunk(tool_chunk) => {
self.last_delta_at = Some(Instant::now());
let call = &tool_chunk.tool_call;
let id = call.call_id.clone();
let existing = self.tool_calls.get(&id);
let prev_args_len = existing.map(|e| e.arguments.len()).unwrap_or(0);
let entry = self.tool_calls.entry(id.clone()).or_insert_with(|| {
self.tool_call_order.push(id.clone());
PartialToolCall {
id: id.clone(),
name: String::new(),
arguments: String::new(),
start_emitted: false,
}
});
if !call.fn_name.is_empty() {
entry.name = call.fn_name.clone();
}
let args_str = match &call.fn_arguments {
Value::String(s) => s.clone(),
other => serde_json::to_string(other).unwrap_or_default(),
};
let delta = if args_str.len() > prev_args_len {
args_str[prev_args_len..].to_string()
} else {
String::new()
};
entry.arguments = args_str;
self.bytes_received += delta.len();
let should_emit_start = !entry.start_emitted && !entry.name.is_empty();
if should_emit_start {
entry.start_emitted = true;
if tool_args_have_content(&call.fn_arguments) {
self.pending_outputs.push_back(StreamOutput::ToolCallDelta {
id: id.clone(),
args_delta: entry.arguments.clone(),
});
}
StreamOutput::ToolCallStart {
id,
name: entry.name.clone(),
}
} else if entry.start_emitted && !delta.is_empty() {
StreamOutput::ToolCallDelta {
id,
args_delta: delta,
}
} else {
StreamOutput::None
}
}
ChatStreamEvent::End(end) => {
self.apply_end(end);
StreamOutput::None
}
}
}
pub fn take_usage(&mut self) -> Option<TokenUsage> {
self.usage.take()
}
pub fn end_seen(&self) -> bool {
self.end_seen
}
fn apply_end(&mut self, end: StreamEnd) {
self.end_seen = true;
if let Some(ref usage) = end.captured_usage {
self.usage = Some(map_usage(usage));
}
if let Some(ref reason) = end.captured_stop_reason {
self.stop_reason = map_stop_reason(reason);
}
if let Some(captured) = end.captured_tool_calls() {
self.tool_calls.clear();
self.tool_call_order.clear();
for call in captured {
let id = call.call_id.clone();
self.tool_call_order.push(id.clone());
self.tool_calls.insert(
id.clone(),
PartialToolCall {
id,
name: call.fn_name.clone(),
arguments: serde_json::to_string(&call.fn_arguments).unwrap_or_default(),
start_emitted: true,
},
);
}
}
if let Some(text) = end.captured_first_text() {
self.text = text.to_string();
}
}
pub fn finish(self) -> StreamResult {
let mut tool_calls: Vec<ToolCall> = Vec::with_capacity(self.tool_call_order.len());
let mut has_incomplete_tool_calls = false;
let mut remaining = self.tool_calls;
for call_id in &self.tool_call_order {
let Some(p) = remaining.remove(call_id) else {
continue;
};
if p.name.is_empty() {
has_incomplete_tool_calls = true;
continue;
}
let arguments: Value = serde_json::from_str(&p.arguments).unwrap_or(Value::Null);
if arguments.is_null() && !p.arguments.is_empty() {
has_incomplete_tool_calls = true;
continue;
}
tool_calls.push(ToolCall::new(p.id, p.name, arguments));
}
let content = if self.text.is_empty() {
vec![]
} else {
vec![awaken_contract::contract::content::ContentBlock::text(
self.text,
)]
};
StreamResult {
content,
tool_calls,
usage: self.usage,
stop_reason: self.stop_reason,
has_incomplete_tool_calls,
}
}
}
fn tool_args_have_content(arguments: &Value) -> bool {
match arguments {
Value::Null => false,
Value::String(s) => !s.is_empty(),
Value::Array(items) => !items.is_empty(),
Value::Object(fields) => !fields.is_empty(),
_ => true,
}
}
#[cfg(test)]
mod tests {
use super::*;
use genai::chat::{StreamChunk, ToolChunk};
#[test]
fn text_chunks_accumulate() {
let mut c = StreamCollector::new();
let o1 = c.process(ChatStreamEvent::Chunk(StreamChunk {
content: "Hello ".into(),
}));
assert!(matches!(o1, StreamOutput::TextDelta(ref s) if s == "Hello "));
let o2 = c.process(ChatStreamEvent::Chunk(StreamChunk {
content: "world".into(),
}));
assert!(matches!(o2, StreamOutput::TextDelta(ref s) if s == "world"));
let result = c.finish();
assert_eq!(result.text(), "Hello world");
}
#[test]
fn tool_call_chunks_accumulate() {
use genai::chat::{ToolCall as GToolCall, ToolChunk};
let mut c = StreamCollector::new();
let o1 = c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "c1".into(),
fn_name: "search".into(),
fn_arguments: serde_json::json!({}),
thought_signatures: None,
},
}));
assert!(
matches!(o1, StreamOutput::ToolCallStart { ref id, ref name } if id == "c1" && name == "search")
);
let result = c.finish();
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].name, "search");
}
#[test]
fn truncated_tool_call_json_dropped() {
let mut c = StreamCollector::new();
c.tool_call_order.push("c1".into());
c.tool_calls.insert(
"c1".into(),
super::PartialToolCall {
id: "c1".into(),
name: "search".into(),
arguments: r#"{"query": "rust"#.into(), start_emitted: true,
},
);
let result = c.finish();
assert!(result.tool_calls.is_empty());
}
#[test]
fn captured_end_overrides_streamed_text() {
use genai::chat::{StreamChunk, StreamEnd};
let mut c = StreamCollector::new();
c.process(ChatStreamEvent::Chunk(StreamChunk {
content: "partial".into(),
}));
let end = StreamEnd {
captured_content: Some(genai::chat::MessageContent::from_text("final text")),
..Default::default()
};
c.process(ChatStreamEvent::End(end));
let result = c.finish();
assert_eq!(result.text(), "final text");
}
#[test]
fn usage_mapped_from_end_event() {
use genai::chat::{StreamEnd, Usage};
let mut c = StreamCollector::new();
let end = StreamEnd {
captured_usage: Some(Usage {
prompt_tokens: Some(100),
completion_tokens: Some(50),
total_tokens: Some(150),
..Default::default()
}),
..Default::default()
};
c.process(ChatStreamEvent::End(end));
let result = c.finish();
let usage = result.usage.unwrap();
assert_eq!(usage.prompt_tokens, Some(100));
assert_eq!(usage.completion_tokens, Some(50));
assert_eq!(usage.total_tokens, Some(150));
}
#[test]
fn finish_with_no_events_returns_empty() {
let c = StreamCollector::new();
let result = c.finish();
assert!(result.content.is_empty());
assert!(result.tool_calls.is_empty());
assert!(result.usage.is_none());
}
#[test]
fn collector_accumulates_text_deltas() {
let mut c = StreamCollector::new();
let chunks = ["The ", "quick ", "brown ", "fox"];
for chunk in &chunks {
let out = c.process(ChatStreamEvent::Chunk(StreamChunk {
content: (*chunk).into(),
}));
assert!(matches!(out, StreamOutput::TextDelta(ref s) if s == chunk));
}
let result = c.finish();
assert_eq!(result.text(), "The quick brown fox");
assert_eq!(result.content.len(), 1);
}
#[test]
fn collector_tracks_tool_call_start_and_delta() {
use genai::chat::ToolCall as GToolCall;
let mut c = StreamCollector::new();
let o1 = c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "tc1".into(),
fn_name: "get_weather".into(),
fn_arguments: serde_json::json!({}),
thought_signatures: None,
},
}));
assert!(
matches!(o1, StreamOutput::ToolCallStart { ref id, ref name }
if id == "tc1" && name == "get_weather")
);
let o2 = c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "tc1".into(),
fn_name: "get_weather".into(),
fn_arguments: serde_json::json!({"city": "London"}),
thought_signatures: None,
},
}));
assert!(matches!(o2, StreamOutput::ToolCallDelta { ref id, .. } if id == "tc1"));
let result = c.finish();
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].name, "get_weather");
assert_eq!(result.tool_calls[0].arguments["city"], "London");
}
#[test]
fn collector_handles_usage_event() {
use genai::chat::{StreamEnd, Usage};
let mut c = StreamCollector::new();
let end = StreamEnd {
captured_usage: Some(Usage {
prompt_tokens: Some(200),
completion_tokens: Some(80),
total_tokens: Some(280),
prompt_tokens_details: None,
completion_tokens_details: None,
}),
..Default::default()
};
c.process(ChatStreamEvent::End(end));
let result = c.finish();
let usage = result.usage.expect("usage should be present");
assert_eq!(usage.prompt_tokens, Some(200));
assert_eq!(usage.completion_tokens, Some(80));
assert_eq!(usage.total_tokens, Some(280));
assert!(usage.cache_read_tokens.is_none());
assert!(usage.thinking_tokens.is_none());
}
#[test]
fn collector_handles_stop_event() {
use genai::chat::StreamEnd;
let mut c = StreamCollector::new();
let end = StreamEnd {
captured_stop_reason: Some(genai::chat::StopReason::Completed("stop".into())),
..Default::default()
};
c.process(ChatStreamEvent::End(end));
let result = c.finish();
assert_eq!(
result.stop_reason,
Some(awaken_contract::contract::inference::StopReason::EndTurn)
);
}
#[test]
fn collector_drops_truncated_tool_calls() {
let mut c = StreamCollector::new();
c.tool_call_order.push("valid".into());
c.tool_calls.insert(
"valid".into(),
PartialToolCall {
id: "valid".into(),
name: "search".into(),
arguments: r#"{"q":"hello"}"#.into(),
start_emitted: true,
},
);
c.tool_call_order.push("bad".into());
c.tool_calls.insert(
"bad".into(),
PartialToolCall {
id: "bad".into(),
name: "calc".into(),
arguments: r#"{"expr": "2+"#.into(), start_emitted: true,
},
);
let result = c.finish();
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].id, "valid");
}
#[test]
fn collector_preserves_tool_call_order() {
use genai::chat::ToolCall as GToolCall;
let mut c = StreamCollector::new();
let tool_names = ["alpha", "beta", "gamma"];
for (i, name) in tool_names.iter().enumerate() {
c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: format!("c{i}"),
fn_name: (*name).into(),
fn_arguments: serde_json::json!({}),
thought_signatures: None,
},
}));
}
let result = c.finish();
assert_eq!(result.tool_calls.len(), 3);
assert_eq!(result.tool_calls[0].name, "alpha");
assert_eq!(result.tool_calls[1].name, "beta");
assert_eq!(result.tool_calls[2].name, "gamma");
}
#[test]
fn collector_nameless_tool_call_dropped() {
let mut c = StreamCollector::new();
c.tool_call_order.push("c1".into());
c.tool_calls.insert(
"c1".into(),
PartialToolCall {
id: "c1".into(),
name: String::new(), arguments: r#"{"x":1}"#.into(),
start_emitted: false,
},
);
let result = c.finish();
assert!(result.tool_calls.is_empty());
}
#[test]
fn collector_reasoning_delta_not_accumulated_in_text() {
let mut c = StreamCollector::new();
c.process(ChatStreamEvent::Chunk(StreamChunk {
content: "visible".into(),
}));
let out = c.process(ChatStreamEvent::ReasoningChunk(StreamChunk {
content: "thinking...".into(),
}));
assert!(matches!(out, StreamOutput::ReasoningDelta(ref s) if s == "thinking..."));
let result = c.finish();
assert_eq!(result.text(), "visible");
}
#[test]
fn collector_start_event_emits_none() {
let mut c = StreamCollector::new();
let out = c.process(ChatStreamEvent::Start);
assert!(matches!(out, StreamOutput::None));
}
#[test]
fn collector_end_captured_tool_calls_override_streamed() {
use genai::chat::{MessageContent, StreamEnd, ToolCall as GToolCall};
let mut c = StreamCollector::new();
c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "streamed".into(),
fn_name: "old_tool".into(),
fn_arguments: serde_json::json!({}),
thought_signatures: None,
},
}));
let captured_call = GToolCall {
call_id: "captured".into(),
fn_name: "new_tool".into(),
fn_arguments: serde_json::json!({"a": 1}),
thought_signatures: None,
};
let end = StreamEnd {
captured_content: Some(MessageContent::from_tool_calls(vec![captured_call])),
..Default::default()
};
c.process(ChatStreamEvent::End(end));
let result = c.finish();
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].name, "new_tool");
assert_eq!(result.tool_calls[0].id, "captured");
}
#[test]
fn collector_handles_string_valued_accumulated_tool_args() {
use genai::chat::ToolCall as GToolCall;
let mut c = StreamCollector::new();
let o1 = c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "tc1".into(),
fn_name: "calculator".into(),
fn_arguments: Value::String(String::new()),
thought_signatures: None,
},
}));
assert!(matches!(o1, StreamOutput::ToolCallStart { ref name, .. } if name == "calculator"));
let o2 = c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "tc1".into(),
fn_name: "calculator".into(),
fn_arguments: Value::String(r#"{"ex"#.into()),
thought_signatures: None,
},
}));
assert!(matches!(o2, StreamOutput::ToolCallDelta { ref id, .. } if id == "tc1"));
let o3 = c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "tc1".into(),
fn_name: "calculator".into(),
fn_arguments: Value::String(r#"{"expression":"137*42"}"#.into()),
thought_signatures: None,
},
}));
assert!(matches!(o3, StreamOutput::ToolCallDelta { ref id, .. } if id == "tc1"));
let result = c.finish();
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].name, "calculator");
assert_eq!(result.tool_calls[0].arguments["expression"], "137*42");
}
#[test]
fn collector_emits_initial_object_args_after_tool_call_start() {
use genai::chat::ToolCall as GToolCall;
let mut c = StreamCollector::new();
let start = c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "tc1".into(),
fn_name: "render_openui_ui".into(),
fn_arguments: serde_json::json!({"prompt": "build a dashboard"}),
thought_signatures: None,
},
}));
assert!(matches!(
start,
StreamOutput::ToolCallStart { ref id, ref name }
if id == "tc1" && name == "render_openui_ui"
));
let delta = c
.take_pending_output()
.expect("initial tool arguments should be preserved");
assert!(matches!(
delta,
StreamOutput::ToolCallDelta {
ref id,
ref args_delta
} if id == "tc1" && args_delta == r#"{"prompt":"build a dashboard"}"#
));
assert!(c.take_pending_output().is_none());
let result = c.finish();
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].name, "render_openui_ui");
assert_eq!(
result.tool_calls[0].arguments["prompt"],
"build a dashboard"
);
}
#[test]
fn interrupt_snapshot_empty_when_nothing_received() {
let c = StreamCollector::new();
let snap = c.interrupt_snapshot();
assert!(snap.text.is_none());
assert!(snap.completed_tool_calls.is_empty());
assert!(snap.in_flight_tool.is_none());
assert_eq!(snap.bytes_received, 0);
}
#[test]
fn interrupt_snapshot_captures_text_only() {
let mut c = StreamCollector::new();
c.process(ChatStreamEvent::Chunk(StreamChunk {
content: "hello ".into(),
}));
c.process(ChatStreamEvent::Chunk(StreamChunk {
content: "world".into(),
}));
let snap = c.interrupt_snapshot();
assert_eq!(snap.text.as_deref(), Some("hello world"));
assert!(snap.completed_tool_calls.is_empty());
assert!(snap.in_flight_tool.is_none());
assert_eq!(snap.bytes_received, "hello world".len());
assert!(c.last_delta_at().is_some());
}
#[test]
fn interrupt_snapshot_captures_completed_and_in_flight_tool() {
use genai::chat::ToolCall as GToolCall;
let mut c = StreamCollector::new();
c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "c1".into(),
fn_name: "search".into(),
fn_arguments: serde_json::json!({"q": "rust"}),
thought_signatures: None,
},
}));
c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "c2".into(),
fn_name: "fetch".into(),
fn_arguments: Value::String(r#"{"url":"#.into()),
thought_signatures: None,
},
}));
let snap = c.interrupt_snapshot();
assert_eq!(snap.completed_tool_calls.len(), 1);
assert_eq!(snap.completed_tool_calls[0].name, "search");
let p = snap.in_flight_tool.expect("expected in-flight tool");
assert_eq!(p.id, "c2");
assert_eq!(p.name, "fetch");
assert_eq!(p.partial_args, r#"{"url":"#);
}
#[test]
fn duplicate_tool_use_id_is_deduplicated_to_single_entry() {
use genai::chat::ToolCall as GToolCall;
let mut c = StreamCollector::new();
c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "dup-1".into(),
fn_name: "search".into(),
fn_arguments: serde_json::json!({"q": "a"}),
thought_signatures: None,
},
}));
c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "dup-1".into(),
fn_name: "search".into(),
fn_arguments: serde_json::json!({"q": "abc"}),
thought_signatures: None,
},
}));
c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "dup-1".into(),
fn_name: "search".into(),
fn_arguments: serde_json::json!({"q": "abcdef"}),
thought_signatures: None,
},
}));
let result = c.finish();
assert_eq!(
result.tool_calls.len(),
1,
"duplicate ids must collapse to one entry, got: {:?}",
result.tool_calls
);
assert_eq!(result.tool_calls[0].id, "dup-1");
assert_eq!(result.tool_calls[0].arguments["q"], "abcdef");
}
#[test]
fn interrupt_snapshot_truncated_only_becomes_in_flight_with_empty_completed() {
use genai::chat::ToolCall as GToolCall;
let mut c = StreamCollector::new();
c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "c1".into(),
fn_name: "calc".into(),
fn_arguments: Value::String(r#"{"expr":"2+"#.into()),
thought_signatures: None,
},
}));
let snap = c.interrupt_snapshot();
assert!(snap.completed_tool_calls.is_empty());
let p = snap.in_flight_tool.expect("expected in-flight");
assert_eq!(p.name, "calc");
}
#[test]
fn last_delta_at_only_updates_on_deltas_not_start_or_end() {
use genai::chat::StreamEnd;
let mut c = StreamCollector::new();
c.process(ChatStreamEvent::Start);
assert!(c.last_delta_at().is_none());
c.process(ChatStreamEvent::Chunk(StreamChunk {
content: "x".into(),
}));
let t1 = c.last_delta_at().expect("should have timestamp");
c.process(ChatStreamEvent::End(StreamEnd::default()));
let t2 = c.last_delta_at().expect("preserved across end");
assert_eq!(t1, t2);
}
#[test]
fn collector_replays_buffered_args_when_tool_name_arrives_late() {
use genai::chat::ToolCall as GToolCall;
let mut c = StreamCollector::new();
let first = c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "tc1".into(),
fn_name: String::new(),
fn_arguments: Value::String(r#"{"prompt":"late"}"#.into()),
thought_signatures: None,
},
}));
assert!(matches!(first, StreamOutput::None));
assert!(c.take_pending_output().is_none());
let second = c.process(ChatStreamEvent::ToolCallChunk(ToolChunk {
tool_call: GToolCall {
call_id: "tc1".into(),
fn_name: "render_openui_ui".into(),
fn_arguments: Value::String(r#"{"prompt":"late"}"#.into()),
thought_signatures: None,
},
}));
assert!(matches!(
second,
StreamOutput::ToolCallStart { ref id, ref name }
if id == "tc1" && name == "render_openui_ui"
));
let delta = c
.take_pending_output()
.expect("buffered arguments should be replayed after ToolCallStart");
assert!(matches!(
delta,
StreamOutput::ToolCallDelta {
ref id,
ref args_delta
} if id == "tc1" && args_delta == r#"{"prompt":"late"}"#
));
}
}