use crate::error::{Error, LlmError};
use crate::llm::{ChatChunk, ChunkStream, FinishReason, Message, Role, ToolCall};
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
pub(crate) async fn forward_chunks(
mut provider_stream: ChunkStream,
tx: &tokio::sync::mpsc::Sender<Result<ChatChunk, LlmError>>,
cancel: &CancellationToken,
max_bytes: usize,
) -> Result<(Message, FinishReason), Error> {
let mut content = String::with_capacity(1024);
let mut tool_calls: Vec<ToolCall> = Vec::with_capacity(2);
let mut finish_reason: Option<FinishReason> = None;
let mut bytes_seen: usize = 0;
loop {
tokio::select! {
biased;
_ = cancel.cancelled() => {
drop(provider_stream);
return Err(Error::Cancelled);
}
next = provider_stream.next() => {
match next {
Some(Ok(chunk)) => {
let chunk_bytes = chunk.delta.len()
+ chunk
.tool_calls
.iter()
.map(estimate_tool_call_bytes)
.sum::<usize>();
bytes_seen = bytes_seen.saturating_add(chunk_bytes);
if bytes_seen > max_bytes {
drop(provider_stream);
return Err(Error::Llm(LlmError::Server(
"response exceeded max_response_bytes cap".into(),
)));
}
content.push_str(&chunk.delta);
if !chunk.tool_calls.is_empty() {
tool_calls.extend(chunk.tool_calls.clone());
}
if let Some(fr) = chunk.finish_reason {
finish_reason = Some(fr);
}
if tx.send(Ok(chunk)).await.is_err() {
drop(provider_stream);
return Err(Error::Cancelled);
}
}
Some(Err(e)) => {
return Err(Error::Llm(e));
}
None => break,
}
}
}
}
let finish_reason = finish_reason.unwrap_or(FinishReason::Stop);
let msg = Message {
role: Role::Assistant,
content,
tool_calls,
tool_call_id: None,
};
Ok((msg, finish_reason))
}
/// Estimates the size of a tool call in bytes.
///
/// The estimate is the sum of the lengths of the call's `id`, its `name`, and its `args`
/// serialized to a JSON string. If serialization fails, the arguments count as zero bytes.
fn estimate_tool_call_bytes(call: &ToolCall) -> usize {
    let serialized_args_len = match serde_json::to_string(&call.args) {
        Ok(json) => json.len(),
        Err(_) => 0,
    };
    let header_len = call.id.len() + call.name.len();
    header_len + serialized_args_len
}