use async_trait::async_trait;
use futures::{Stream, StreamExt, stream};
use crate::{
Chunk, CompletionRequest, LlmProvider, StopReason, Usage, error::DummyError, request::ToolCall,
};
/// Progress notification handed to the observer callback of
/// `collect_turn_observed` while a turn's chunk stream is being drained.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TurnStreamEvent {
    /// One fragment of assistant text, forwarded as soon as it arrives.
    TextDelta(String),
    /// A `ToolCallStart` chunk was seen; its arguments may still be streaming.
    ToolStarted {
        /// Identifier carried by the originating `ToolCallStart` chunk.
        id: String,
        /// Name of the tool the model is invoking.
        name: String,
    },
}
/// Aggregate result of draining one model turn's chunk stream
/// (as produced by `collect_turn` / `collect_turn_observed`).
#[derive(Clone, Debug, Default)]
pub struct TurnOutput {
    /// All text deltas concatenated in the order they were received.
    pub text: String,
    /// Tool calls assembled from start, argument-delta and end chunks.
    pub tool_calls: Vec<ToolCall>,
    /// Value of the last `Usage` chunk seen; `Usage::default()` if none arrived.
    pub usage: Usage,
    /// Reason the turn stopped; `None` when the stream carried no `Stop` chunk.
    pub stop: Option<StopReason>,
}
pub async fn collect_turn<S, E>(stream: S) -> Result<TurnOutput, E>
where
S: Stream<Item = Result<Chunk, E>> + Unpin,
{
collect_turn_observed(stream, |_| {}).await
}
pub async fn collect_turn_observed<S, E, F>(mut stream: S, mut on_event: F) -> Result<TurnOutput, E>
where
S: Stream<Item = Result<Chunk, E>> + Unpin,
F: FnMut(TurnStreamEvent),
{
let mut out = TurnOutput::default();
let mut pending: Vec<ToolCall> = Vec::new();
while let Some(item) = stream.next().await {
match item? {
Chunk::TextDelta(s) => {
on_event(TurnStreamEvent::TextDelta(s.clone()));
out.text.push_str(&s);
}
Chunk::ToolCallStart {
id,
name,
signature,
} => {
on_event(TurnStreamEvent::ToolStarted {
id: id.clone(),
name: name.clone(),
});
pending.push(ToolCall {
id,
name,
args_json: String::new(),
signature,
});
}
Chunk::ToolCallArgsDelta {
id,
args_json_delta,
} => {
if let Some(tc) = pending.iter_mut().find(|tc| tc.id == id) {
tc.args_json.push_str(&args_json_delta);
}
}
Chunk::ToolCallEnd { id } => {
if let Some(pos) = pending.iter().position(|tc| tc.id == id) {
out.tool_calls.push(pending.remove(pos));
}
}
Chunk::Usage(u) => out.usage = u,
Chunk::Stop(r) => {
let keep_tool_use =
out.stop == Some(StopReason::ToolUse) && matches!(r, StopReason::EndTurn);
if !keep_tool_use {
out.stop = Some(r);
}
}
}
}
out.tool_calls.append(&mut pending);
Ok(out)
}
/// Environment variable that, when set to a non-empty tool name, makes
/// `StubProvider` emit a call to that tool instead of its canned text reply.
pub const STUB_TOOL_CALL_ENV: &str = "POLYCHROME_STUB_TOOL_CALL";

/// Returns the tool name configured via [`STUB_TOOL_CALL_ENV`].
///
/// An unset, empty, or non-Unicode value yields `None`.
fn stub_tool_name() -> Option<String> {
    match std::env::var(STUB_TOOL_CALL_ENV) {
        Ok(name) if !name.is_empty() => Some(name),
        _ => None,
    }
}
/// Deterministic, in-memory [`LlmProvider`] that never touches the network.
///
/// By default every request yields the text `"Hello from the stub provider."`,
/// a usage chunk (5 input / 4 output tokens) and `StopReason::EndTurn`.
///
/// If `STUB_TOOL_CALL_ENV` names a tool and the request contains no
/// `Content::ToolResult` yet, it instead yields a single call to that tool
/// (id `"stub-call-1"`, arguments `{}`) followed by `StopReason::ToolUse`.
#[derive(Clone, Copy, Default)]
pub struct StubProvider;

#[async_trait]
impl LlmProvider for StubProvider {
    type Error = DummyError;

    async fn complete(
        &self,
        req: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error> {
        // Only request a tool while no tool result has been fed back; once one
        // has, fall through to the plain text reply so the loop terminates.
        let requested_tool = stub_tool_name().filter(|_| {
            !req.messages
                .iter()
                .flat_map(|m| m.content.iter())
                .any(|c| matches!(c, crate::Content::ToolResult(_)))
        });

        let chunks = match requested_tool {
            Some(tool_name) => vec![
                Ok(Chunk::tool_call_start("stub-call-1", &tool_name)),
                Ok(Chunk::tool_call_args_delta("stub-call-1", "{}")),
                Ok(Chunk::tool_call_end("stub-call-1")),
                Ok(Chunk::Stop(StopReason::ToolUse)),
            ],
            None => vec![
                Ok(Chunk::text_delta("Hello from the ")),
                Ok(Chunk::text_delta("stub provider.")),
                Ok(Chunk::Usage(Usage {
                    input_tokens: 5,
                    output_tokens: 4,
                })),
                Ok(Chunk::Stop(StopReason::EndTurn)),
            ],
        };

        Ok(stream::iter(chunks).boxed())
    }
}
#[cfg(test)]
mod tests {
    // Test-only code: pedantic/nursery lints and doc requirements add noise here.
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]
    use super::*;

    #[tokio::test]
    async fn stub_provider_collects_into_text() {
        let stream = StubProvider
            .complete(CompletionRequest::new("stub"))
            .await
            .expect("stream opens");
        let out = collect_turn(stream).await.expect("collect");
        assert_eq!(out.text, "Hello from the stub provider.");
        assert!(out.tool_calls.is_empty());
        assert_eq!(out.usage.output_tokens, 4);
        assert_eq!(out.stop, Some(StopReason::EndTurn));
    }

    #[tokio::test]
    async fn collect_assembles_tool_call_from_deltas() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("calling ")),
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_args_delta("c1", r#"{"q":"#)),
            Ok(Chunk::tool_call_args_delta("c1", r#""rust"}"#)),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.text, "calling ");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].name, "search");
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"rust"}"#);
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn collect_keeps_parallel_tool_calls_with_deferred_ends() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c0", "search")),
            Ok(Chunk::tool_call_args_delta("c0", r#"{"q":"a"}"#)),
            Ok(Chunk::tool_call_start("c1", "fetch")),
            Ok(Chunk::tool_call_args_delta("c1", r#"{"u":"b"}"#)),
            Ok(Chunk::tool_call_end("c0")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.tool_calls.len(), 2, "both parallel calls preserved");
        assert_eq!(out.tool_calls[0].id, "c0");
        assert_eq!(out.tool_calls[0].name, "search");
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"a"}"#);
        assert_eq!(out.tool_calls[1].id, "c1");
        assert_eq!(out.tool_calls[1].name, "fetch");
        assert_eq!(out.tool_calls[1].args_json, r#"{"u":"b"}"#);
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn collect_flushes_a_call_left_open_at_eof() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c0", "search")),
            Ok(Chunk::tool_call_args_delta("c0", r#"{"q":"a"}"#)),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"a"}"#);
    }

    #[tokio::test]
    async fn args_delta_for_unknown_or_ended_call_is_ignored() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_args_delta("zz", "junk")),
            Ok(Chunk::tool_call_args_delta("c1", "{}")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::tool_call_args_delta("c1", "late")),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].args_json, "{}");
        assert_eq!(out.stop, None);
    }

    #[tokio::test]
    async fn observer_sees_text_and_tool_start_events_in_stream_order() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("hi ")),
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_args_delta("c1", "{}")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::text_delta("there")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let mut events = Vec::new();
        let out = collect_turn_observed(stream::iter(chunks), |e| events.push(e))
            .await
            .expect("collect");
        assert_eq!(
            events,
            vec![
                TurnStreamEvent::TextDelta("hi ".to_owned()),
                TurnStreamEvent::ToolStarted {
                    id: "c1".to_owned(),
                    name: "search".to_owned(),
                },
                TurnStreamEvent::TextDelta("there".to_owned()),
            ]
        );
        assert_eq!(out.text, "hi there");
        assert_eq!(out.tool_calls.len(), 1);
    }

    #[tokio::test]
    async fn tool_use_stop_is_sticky_against_later_end_turn() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
            Ok(Chunk::Stop(StopReason::EndTurn)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn hard_stop_wins_over_earlier_tool_use() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
            Ok(Chunk::Stop(StopReason::MaxTokens)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.stop, Some(StopReason::MaxTokens));
    }

    #[tokio::test]
    async fn collect_propagates_error() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("partial")),
            Err(DummyError::Other("mid-stream fault".to_owned())),
        ];
        let res = collect_turn(stream::iter(chunks)).await;
        assert!(res.is_err());
    }
}