pub mod handle;
mod input;
pub use handle::{ChatStream, Input, InputStream, IntoInput, OutputStream, SendError};
use async_stream::try_stream;
use futures::{Stream, StreamExt, channel::mpsc, future::Either, stream::BoxStream};
use tools_rs::FunctionResponse;
use input::{InputSignal, apply_input_to_messages, next_input};
use crate::{
chat::{Chat, ToolCallPass, state::Unstructured},
error::{ChatError, ChatFailure},
traits::StreamProvider,
types::{
messages::{Messages, parts::PartEnum},
metadata::Metadata,
response::{ChatResponse, StreamEvent},
},
};
/// Classification of a single successful provider stream item, as produced
/// by `pump_event`.
pub(crate) enum Pump {
/// A non-terminal event that should be passed on to the consumer.
Yield(StreamEvent),
/// The payload of a `StreamEvent::Done`: the provider's final response.
Done(ChatResponse),
}
/// Classifies one item pulled from a provider stream.
///
/// * `StreamEvent::Done` becomes [`Pump::Done`] carrying the response.
/// * `StreamEvent::Structured` is forwarded as [`Pump::Yield`], and a clone
///   of its value is appended to `structured_buffer`.
/// * Every other event is forwarded unchanged as [`Pump::Yield`].
///
/// # Errors
///
/// An `Err` item is returned to the caller untouched.
pub(crate) fn pump_event(
    item: Result<StreamEvent, ChatError>,
    structured_buffer: &mut Vec<serde_json::Value>,
) -> Result<Pump, ChatError> {
    let pumped = match item? {
        StreamEvent::Done(response) => Pump::Done(response),
        StreamEvent::Structured(value) => {
            // Keep a copy for the caller's buffer; the event itself still
            // goes out to the consumer.
            structured_buffer.push(value.clone());
            Pump::Yield(StreamEvent::Structured(value))
        }
        other => Pump::Yield(other),
    };
    Ok(pumped)
}
impl<CP: StreamProvider, Output> Chat<CP, Output> {
    /// Runs one tool-call pass over the last message in `messages`.
    ///
    /// Returns the tool responses gathered by `collect_tool_results` together
    /// with the pass itself. When `messages` is empty no tool call is
    /// attempted and a default [`ToolCallPass`] is used.
    ///
    /// # Errors
    ///
    /// A failure from `tool_call` is wrapped in a [`ChatFailure`] that carries
    /// a clone of `last_metadata`.
    pub(crate) async fn stream_tool_pass(
        &self,
        messages: &mut Messages,
        last_metadata: &Option<Metadata>,
    ) -> Result<(Vec<FunctionResponse>, ToolCallPass), ChatFailure> {
        let pass = match messages.0.last_mut() {
            Some(last) => self.tool_call(last).await.map_err(|err| ChatFailure {
                err,
                metadata: last_metadata.clone(),
            })?,
            None => ToolCallPass::default(),
        };
        Ok((collect_tool_results(&pass, messages), pass))
    }

    /// Commits a finished provider response and runs a tool pass over it.
    ///
    /// In order: notifies the model through `on_stream_done`, folds the
    /// response metadata into `last_metadata` (extending an existing value, or
    /// storing a clone when none has been recorded yet), pushes a clone of the
    /// response content onto `messages`, and finally delegates to
    /// [`Self::stream_tool_pass`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::stream_tool_pass`]; the failure carries the metadata as
    /// it stands after the merge above.
    pub(crate) async fn stream_commit(
        &mut self,
        messages: &mut Messages,
        response: &ChatResponse,
        last_metadata: &mut Option<Metadata>,
    ) -> Result<(Vec<FunctionResponse>, ToolCallPass), ChatFailure> {
        self.model.on_stream_done(response);
        // Borrow the response metadata: merging only needs a reference, so a
        // clone is made only when it has to be stored.
        if let Some(metadata) = response.metadata.as_ref() {
            match last_metadata {
                Some(existing) => {
                    existing.extend(metadata);
                }
                None => {
                    *last_metadata = Some(metadata.clone());
                }
            }
        }
        messages.push(response.content.clone());
        // The freshly pushed content is now the last message, which is exactly
        // what the shared tool pass operates on.
        self.stream_tool_pass(messages, last_metadata).await
    }

    /// Drives the provider stream, yielding events to the consumer.
    ///
    /// At most `max_steps` steps are run (one when `max_steps` is unset). Each
    /// step:
    ///
    /// 1. Runs a tool pass over the current last message, yielding a
    ///    `ToolResult` per response. A paused pass yields `Paused` and ends
    ///    the stream.
    /// 2. Opens a provider stream and forwards its events. `Structured`
    ///    values are also buffered and appended to the final response's parts
    ///    once `Done` arrives.
    /// 3. On `Done`, commits the response via [`Self::stream_commit`], yields
    ///    any tool results, and then either pauses, moves to the next step
    ///    (tools were executed), or runs `after_strategy` and yields the
    ///    final `Done`.
    ///
    /// While `input` is present and open, every provider poll is raced against
    /// it: an `Apply` batch is written into `messages` and the provider stream
    /// is reopened within the same step, `Cancelled` ends the stream, and
    /// `Closed` stops further input polling.
    ///
    /// If the provider stream ends without a `Done` event the step is simply
    /// abandoned, and if the step budget runs out the stream ends without a
    /// final `Done`.
    ///
    /// # Errors
    ///
    /// Provider and tool failures are yielded as [`ChatFailure`] carrying the
    /// metadata accumulated so far; the stream ends after the first error.
    pub(crate) fn run_stream<'a>(
        &'a mut self,
        messages: &'a mut Messages,
        mut input: Option<mpsc::UnboundedReceiver<Input>>,
    ) -> impl Stream<Item = Result<StreamEvent, ChatFailure>> + Send + 'a
    where
        Output: Send + Sync,
    {
        try_stream! {
            let max_steps = self.max_steps.unwrap_or(1);
            let mut last_metadata: Option<Metadata> = None;
            // Flips to false once the input side reports `Closed`, after which
            // only the provider is polled.
            let mut input_open = input.is_some();
            'step: for _ in 0..max_steps {
                let (results, pass) = self.stream_tool_pass(messages, &last_metadata).await?;
                for fr in results {
                    yield StreamEvent::ToolResult(fr);
                }
                if let Some(reason) = pass.pause {
                    yield StreamEvent::Paused(reason);
                    return;
                }
                let decls = crate::chat::tool_declarations_from(&self.scoped_collections);
                let decls_dyn = decls
                    .as_ref()
                    .map(|d| d as &dyn crate::types::tools::ToolDeclarations);
                // Re-entered whenever new input arrives mid-stream; this does
                // not consume a step.
                'restart: loop {
                    let mut provider_stream = self
                        .model
                        .stream(messages, decls_dyn, self.model_options.as_ref())
                        .await
                        .map_err(|err| ChatFailure { err, metadata: last_metadata.clone() })?;
                    let mut final_response: Option<ChatResponse> = None;
                    let mut structured_buffer: Vec<serde_json::Value> = Vec::new();
                    loop {
                        let item = match input.as_mut().filter(|_| input_open) {
                            Some(rx) => {
                                let provider_next = provider_stream.next();
                                let input_next = next_input(rx);
                                // Race the next provider item against the next
                                // input signal.
                                match futures::future::select(
                                    Box::pin(provider_next),
                                    Box::pin(input_next),
                                )
                                .await
                                {
                                    Either::Left((item, _)) => item,
                                    Either::Right((InputSignal::Apply(batch), _)) => {
                                        for part in batch {
                                            apply_input_to_messages(messages, part);
                                        }
                                        // Drop the current provider stream and
                                        // reopen it against the updated messages.
                                        continue 'restart;
                                    }
                                    Either::Right((InputSignal::Cancelled, _)) => return,
                                    Either::Right((InputSignal::Closed, _)) => {
                                        input_open = false;
                                        continue;
                                    }
                                }
                            }
                            None => provider_stream.next().await,
                        };
                        match item {
                            Some(item) => {
                                match pump_event(item, &mut structured_buffer).map_err(|err| {
                                    ChatFailure { err, metadata: last_metadata.clone() }
                                })? {
                                    Pump::Done(response) => {
                                        final_response = Some(response);
                                        break;
                                    }
                                    Pump::Yield(event) => yield event,
                                }
                            }
                            // Provider stream exhausted without a `Done` event.
                            None => break,
                        }
                    }
                    if let Some(mut response) = final_response {
                        // Fold the structured values seen during streaming into
                        // the final response before it is committed.
                        for v in structured_buffer.drain(..) {
                            response.content.parts.push(PartEnum::Structured(v));
                        }
                        let (results, pass) =
                            self.stream_commit(messages, &response, &mut last_metadata).await?;
                        for fr in results {
                            yield StreamEvent::ToolResult(fr);
                        }
                        if let Some(reason) = pass.pause {
                            yield StreamEvent::Paused(reason);
                            return;
                        }
                        if pass.executed {
                            // Tools ran, so the model needs another turn.
                            continue 'step;
                        }
                        if let Some(strategy) = self.after_strategy.as_mut() {
                            strategy(messages, last_metadata.as_ref()).await;
                        }
                        yield StreamEvent::Done(response);
                        return;
                    }
                    break 'restart;
                }
            }
        }
    }
}
/// Gathers the tool responses attached to the last message after a tool pass.
///
/// Yields an empty vector when the pass did not execute anything or when
/// `messages` is empty; otherwise clones every response found on the last
/// message's tool parts.
fn collect_tool_results(pass: &ToolCallPass, messages: &Messages) -> Vec<FunctionResponse> {
    let newest = match (pass.executed, messages.0.last()) {
        (true, Some(content)) => content,
        _ => return Vec::new(),
    };
    newest
        .parts
        .tools()
        .filter_map(|tool| tool.response().cloned())
        .collect()
}
impl<CP: StreamProvider> Chat<CP, Unstructured> {
    /// Opens an event stream for an unstructured chat.
    ///
    /// Invokes `before_strategy` (when one is configured) on `messages` with
    /// no metadata, then returns the boxed stream produced by `run_stream`
    /// with no input channel attached.
    ///
    /// # Errors
    ///
    /// The visible body always returns `Ok`; failures surface as items of the
    /// returned stream.
    pub async fn stream<'a>(
        &'a mut self,
        messages: &'a mut Messages,
    ) -> Result<BoxStream<'a, Result<StreamEvent, ChatFailure>>, ChatFailure> {
        if let Some(before) = self.before_strategy.as_mut() {
            before(messages, None).await;
        }
        let events = self.run_stream(messages, None);
        Ok(events.boxed())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        error::ChatError,
        types::{
            messages::{
                Messages,
                content::{Content, RoleEnum},
                parts::{PartEnum, Parts},
            },
            options::ChatOptions,
            response::ChatResponse,
            tools::ToolDeclarations,
        },
    };
    use async_trait::async_trait;
    use serde_json::json;
    use std::collections::HashMap;
    use std::marker::PhantomData;

    /// Provider double that replays a fixed script of events on the first
    /// `stream` call and an empty stream on any later call.
    struct MockStreamProvider {
        events: Vec<Result<StreamEvent, ChatError>>,
    }

    #[async_trait]
    impl StreamProvider for MockStreamProvider {
        async fn stream(
            &mut self,
            _messages: &mut Messages,
            _tool_declarations: Option<&dyn ToolDeclarations>,
            _options: Option<&ChatOptions>,
        ) -> Result<futures::stream::BoxStream<'static, Result<StreamEvent, ChatError>>, ChatError>
        {
            // Hand the script over exactly once, leaving the mock empty.
            let scripted = std::mem::take(&mut self.events);
            Ok(Box::pin(futures::stream::iter(scripted)))
        }
    }

    /// Builds a single-step chat with no tools or strategies around the mock.
    fn chat_with(
        events: Vec<Result<StreamEvent, ChatError>>,
    ) -> Chat<MockStreamProvider, Unstructured> {
        let model = MockStreamProvider { events };
        Chat {
            model,
            output_shape: None,
            model_options: None,
            max_steps: Some(1),
            max_retries: None,
            retry_strategy: None,
            before_strategy: None,
            after_strategy: None,
            scoped_collections: Vec::new(),
            routing: HashMap::new(),
            _output: PhantomData,
        }
    }

    /// A `Done` event whose response has no parts and no metadata.
    fn done_event() -> StreamEvent {
        let content = Content {
            role: RoleEnum::Model,
            parts: Parts::default(),
            complete_reason: Default::default(),
        };
        StreamEvent::Done(ChatResponse {
            content,
            metadata: None,
        })
    }

    /// Drains the chat's stream, panicking on the first error item.
    async fn collect_stream(
        chat: &mut Chat<MockStreamProvider, Unstructured>,
        messages: &mut Messages,
    ) -> Vec<StreamEvent> {
        chat.stream(messages)
            .await
            .expect("stream open")
            .map(|outcome| outcome.expect("event ok"))
            .collect::<Vec<_>>()
            .await
    }

    /// Unwraps the response carried by a `Done` event, panicking otherwise.
    fn expect_done(event: &StreamEvent) -> &ChatResponse {
        let StreamEvent::Done(response) = event else {
            panic!("expected Done event");
        };
        response
    }

    /// The structured values among a response's parts, in order.
    fn structured_values(response: &ChatResponse) -> Vec<&serde_json::Value> {
        response
            .content
            .parts
            .0
            .iter()
            .filter_map(|part| match part {
                PartEnum::Structured(value) => Some(value),
                _ => None,
            })
            .collect()
    }

    #[tokio::test]
    async fn structured_events_flow_to_consumer_and_into_final_response() {
        let mut chat = chat_with(vec![
            Ok(StreamEvent::Structured(json!({"step": 1}))),
            Ok(StreamEvent::Structured(json!({"step": 2}))),
            Ok(done_event()),
        ]);
        let mut messages = Messages::default();

        let events = collect_stream(&mut chat, &mut messages).await;

        assert_eq!(events.len(), 3);
        assert!(matches!(events[0], StreamEvent::Structured(_)));
        assert!(matches!(events[1], StreamEvent::Structured(_)));

        let structured = structured_values(expect_done(&events[2]));
        assert_eq!(structured.len(), 2);
        assert_eq!(structured[0], &json!({"step": 1}));
        assert_eq!(structured[1], &json!({"step": 2}));
    }

    #[tokio::test]
    async fn structured_interleaved_with_text_preserves_event_order() {
        let mut chat = chat_with(vec![
            Ok(StreamEvent::TextChunk("hello ".into())),
            Ok(StreamEvent::Structured(json!({"step": 1}))),
            Ok(StreamEvent::TextChunk("world".into())),
            Ok(StreamEvent::Structured(json!({"step": 2}))),
            Ok(done_event()),
        ]);
        let mut messages = Messages::default();

        let events = collect_stream(&mut chat, &mut messages).await;

        assert_eq!(events.len(), 5);
        assert!(matches!(events[0], StreamEvent::TextChunk(ref t) if t == "hello "));
        assert!(matches!(events[1], StreamEvent::Structured(_)));
        assert!(matches!(events[2], StreamEvent::TextChunk(ref t) if t == "world"));
        assert!(matches!(events[3], StreamEvent::Structured(_)));

        // Only the structured values end up as parts; text chunks do not.
        let final_response = expect_done(&events[4]);
        let parts: Vec<&PartEnum> = final_response.content.parts.0.iter().collect();
        assert_eq!(parts.len(), 2);
        assert!(matches!(parts[0], PartEnum::Structured(v) if v == &json!({"step": 1})));
        assert!(matches!(parts[1], PartEnum::Structured(v) if v == &json!({"step": 2})));
    }

    #[tokio::test]
    async fn no_structured_events_leaves_final_response_untouched() {
        let mut chat = chat_with(vec![
            Ok(StreamEvent::TextChunk("just text".into())),
            Ok(done_event()),
        ]);
        let mut messages = Messages::default();

        let events = collect_stream(&mut chat, &mut messages).await;

        assert_eq!(events.len(), 2);
        let final_response = expect_done(&events[1]);
        assert!(final_response.content.parts.0.is_empty());
    }
}