use crate::error::StreamResult;
use crate::events::AgentStreamEvent;
use crate::partial_response::{PartialResponse, ResponseDelta};
use futures::{Stream, StreamExt};
use pin_project_lite::pin_project;
use serde::de::DeserializeOwned;
use serdes_ai_core::{ModelResponse, RequestUsage};
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Lifecycle state of an `AgentStream`.
///
/// Within this file only `Pending`, `Streaming`, `Completed` and `Failed` are
/// ever assigned; `ProcessingTools` and `Retrying` are declared but never set here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamState {
/// Initial state set by `AgentStream::new`; left on the first poll.
Pending,
/// `RunStart` has been emitted and deltas are being pulled from the inner stream.
Streaming,
/// NOTE(review): presumably tool calls are being executed — not assigned in this file; confirm.
ProcessingTools,
/// NOTE(review): presumably a retry is in progress — not assigned in this file; confirm.
Retrying,
/// Terminal state: a `Finish` delta was seen or the inner stream ended.
Completed,
/// Terminal state: the inner stream yielded an error.
Failed,
}
/// Tunables controlling which events an `AgentStream` emits.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Whether partial outputs should be emitted.
    ///
    /// NOTE(review): not consulted by the code in this file — confirm where it is consumed.
    pub emit_partial_outputs: bool,
    /// Interval, in milliseconds, associated with partial outputs.
    ///
    /// NOTE(review): not consulted by the code in this file — confirm where it is consumed.
    pub partial_output_interval_ms: u64,
    /// When `false`, thinking deltas do not produce `ThinkingDelta` events.
    pub emit_thinking: bool,
    /// When `true`, tool-call argument fragments do not produce `ToolCallDelta` events.
    pub buffer_tool_args: bool,
}

impl Default for StreamConfig {
    /// Returns the default configuration: partial outputs and thinking enabled,
    /// a 100 ms partial-output interval, and tool arguments streamed unbuffered.
    fn default() -> Self {
        StreamConfig {
            emit_partial_outputs: true,
            partial_output_interval_ms: 100,
            emit_thinking: true,
            buffer_tool_args: false,
        }
    }
}
pin_project! {
/// Adapter that wraps a stream of `ResponseDelta` results and yields
/// `AgentStreamEvent`s, tracking run state, accumulated usage and the
/// response assembled so far.
pub struct AgentStream<S, Output> {
// Underlying delta stream; the only structurally pinned field.
#[pin]
inner: S,
// Identifier reported in `RunStart` / `RunComplete` events.
run_id: String,
// Step counter; starts at 0 and is incremented when `RunStart` is emitted.
step: u32,
// Current lifecycle state.
state: StreamState,
// Event-emission options.
config: StreamConfig,
// Response assembled from every delta applied so far.
partial_response: PartialResponse,
// Events produced but not yet handed to the consumer (FIFO).
pending_events: VecDeque<AgentStreamEvent<Output>>,
// Sum of all usage deltas received so far.
accumulated_usage: RequestUsage,
// Marker tying the struct to its `Output` type parameter.
_output: std::marker::PhantomData<Output>,
}
}
impl<S, Output> AgentStream<S, Output>
where
    S: Stream<Item = StreamResult<ResponseDelta>>,
    Output: DeserializeOwned,
{
    /// Wraps `inner` in a new stream identified by `run_id`.
    ///
    /// The stream starts in [`StreamState::Pending`] at step 0 with the default
    /// [`StreamConfig`], an empty partial response and zeroed usage.
    pub fn new(inner: S, run_id: impl Into<String>) -> Self {
        Self {
            inner,
            // Converted directly into the field; no intermediate copy is needed.
            run_id: run_id.into(),
            step: 0,
            state: StreamState::Pending,
            config: StreamConfig::default(),
            partial_response: PartialResponse::new(),
            pending_events: VecDeque::new(),
            accumulated_usage: RequestUsage::new(),
            _output: std::marker::PhantomData,
        }
    }

    /// Replaces the configuration and returns the stream (builder style).
    pub fn with_config(mut self, config: StreamConfig) -> Self {
        self.config = config;
        self
    }

    /// Identifier of the run this stream belongs to.
    pub fn run_id(&self) -> &str {
        &self.run_id
    }

    /// Current step number (0 until the stream has been polled).
    pub fn step(&self) -> u32 {
        self.step
    }

    /// Current lifecycle state.
    pub fn state(&self) -> StreamState {
        self.state
    }

    /// Borrow of the response assembled from the deltas applied so far.
    pub fn partial_response(&self) -> &PartialResponse {
        &self.partial_response
    }

    /// Owned snapshot of the response assembled so far.
    pub fn response_snapshot(&self) -> ModelResponse {
        self.partial_response.as_response()
    }

    /// Text content of the response assembled so far.
    pub fn text_content(&self) -> String {
        self.partial_response.text_content()
    }

    /// Usage accumulated from all usage deltas received so far.
    pub fn usage(&self) -> &RequestUsage {
        &self.accumulated_usage
    }

    /// Returns `true` once the stream has reached a terminal state
    /// (`Completed` or `Failed`).
    pub fn is_complete(&self) -> bool {
        matches!(self.state, StreamState::Completed | StreamState::Failed)
    }

    /// Applies `delta` to the partial response and queues the events it implies.
    ///
    /// Unlike the `Stream` implementation, a `Finish` delta only flips the state
    /// to `Completed` here; no `ResponseComplete` / `RunComplete` events are queued.
    // Not called anywhere in this file: `poll_next` works on the pin projection
    // and therefore cannot call a `&mut self` method, so it carries its own copy
    // of this logic. Kept for callers that hold an unpinned `&mut AgentStream`.
    #[allow(dead_code)]
    fn process_delta(&mut self, delta: ResponseDelta) {
        // None of the events below read `partial_response`, so applying the delta
        // first is equivalent and lets the match consume `delta` without cloning.
        self.partial_response.apply_delta(&delta);

        match delta {
            ResponseDelta::Text { index, content } => {
                self.pending_events.push_back(AgentStreamEvent::TextDelta {
                    content,
                    part_index: index,
                });
            }
            ResponseDelta::ToolCall {
                index,
                name,
                args,
                id,
            } => {
                // A name marks the start of a tool call.
                if let Some(name) = name {
                    self.pending_events
                        .push_back(AgentStreamEvent::ToolCallStart {
                            name,
                            tool_call_id: id,
                            index,
                        });
                }
                // Argument fragments are only surfaced when buffering is off.
                if let Some(args_delta) = args {
                    if !self.config.buffer_tool_args {
                        self.pending_events
                            .push_back(AgentStreamEvent::ToolCallDelta { args_delta, index });
                    }
                }
            }
            ResponseDelta::Thinking { index, content, .. } => {
                if self.config.emit_thinking {
                    self.pending_events
                        .push_back(AgentStreamEvent::ThinkingDelta { content, index });
                }
            }
            ResponseDelta::Finish { .. } => {
                self.state = StreamState::Completed;
            }
            ResponseDelta::Usage { usage } => {
                self.accumulated_usage = self.accumulated_usage.clone() + usage;
                self.pending_events
                    .push_back(AgentStreamEvent::UsageUpdate {
                        usage: self.accumulated_usage.clone(),
                    });
            }
        }
    }
}
impl<S, Output> Stream for AgentStream<S, Output>
where
S: Stream<Item = StreamResult<ResponseDelta>> + Unpin,
Output: DeserializeOwned + Clone,
{
type Item = AgentStreamEvent<Output>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
if let Some(event) = this.pending_events.pop_front() {
return Poll::Ready(Some(event));
}
if matches!(this.state, StreamState::Completed | StreamState::Failed) {
return Poll::Ready(None);
}
if *this.state == StreamState::Pending {
*this.state = StreamState::Streaming;
*this.step += 1;
return Poll::Ready(Some(AgentStreamEvent::RunStart {
run_id: this.run_id.clone(),
step: *this.step,
}));
}
match this.inner.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(delta))) => {
match &delta {
ResponseDelta::Text { index, content } => {
this.pending_events.push_back(AgentStreamEvent::TextDelta {
content: content.clone(),
part_index: *index,
});
}
ResponseDelta::ToolCall {
index,
name,
args,
id,
} => {
if let Some(name) = name {
this.pending_events
.push_back(AgentStreamEvent::ToolCallStart {
name: name.clone(),
tool_call_id: id.clone(),
index: *index,
});
}
if let Some(args) = args {
if !this.config.buffer_tool_args {
this.pending_events
.push_back(AgentStreamEvent::ToolCallDelta {
args_delta: args.clone(),
index: *index,
});
}
}
}
ResponseDelta::Thinking { index, content, .. } => {
if this.config.emit_thinking {
this.pending_events
.push_back(AgentStreamEvent::ThinkingDelta {
content: content.clone(),
index: *index,
});
}
}
ResponseDelta::Finish { .. } => {
*this.state = StreamState::Completed;
this.pending_events
.push_back(AgentStreamEvent::ResponseComplete {
response: this.partial_response.as_response(),
});
this.pending_events
.push_back(AgentStreamEvent::RunComplete {
run_id: this.run_id.clone(),
total_steps: *this.step,
});
}
ResponseDelta::Usage { usage } => {
*this.accumulated_usage = this.accumulated_usage.clone() + usage.clone();
this.pending_events
.push_back(AgentStreamEvent::UsageUpdate {
usage: this.accumulated_usage.clone(),
});
}
}
this.partial_response.apply_delta(&delta);
if let Some(event) = this.pending_events.pop_front() {
Poll::Ready(Some(event))
} else {
cx.waker().wake_by_ref();
Poll::Pending
}
}
Poll::Ready(Some(Err(e))) => {
*this.state = StreamState::Failed;
Poll::Ready(Some(AgentStreamEvent::Error {
message: e.to_string(),
recoverable: e.is_recoverable(),
}))
}
Poll::Ready(None) => {
if *this.state == StreamState::Streaming {
*this.state = StreamState::Completed;
this.pending_events
.push_back(AgentStreamEvent::ResponseComplete {
response: this.partial_response.as_response(),
});
this.pending_events
.push_back(AgentStreamEvent::RunComplete {
run_id: this.run_id.clone(),
total_steps: *this.step,
});
if let Some(event) = this.pending_events.pop_front() {
return Poll::Ready(Some(event));
}
}
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
}
/// Convenience adapters for any stream of `AgentStreamEvent`s.
pub trait AgentStreamExt<Output>: Stream<Item = AgentStreamEvent<Output>> + Sized {
    /// Keeps only text deltas, yielding one `TextDelta` per `TextDelta` event.
    fn text_deltas(self) -> TextDeltaStream<Self> {
        let accumulated = String::new();
        TextDeltaStream {
            inner: self,
            accumulated,
            emit_accumulated: false,
        }
    }

    /// Like [`text_deltas`](Self::text_deltas) but with the `emit_accumulated`
    /// flag set.
    ///
    /// The flag is not read by `TextDeltaStream::poll_next` in this file, so the
    /// yielded items are the same per-delta values; the running text is available
    /// through `TextDeltaStream::accumulated_text`.
    fn text_accumulated(self) -> TextDeltaStream<Self> {
        let accumulated = String::new();
        TextDeltaStream {
            inner: self,
            accumulated,
            emit_accumulated: true,
        }
    }

    /// Keeps only the values carried by `PartialOutput` and `FinalOutput` events.
    fn outputs(self) -> OutputStream<Self, Output> {
        let _output = std::marker::PhantomData;
        OutputStream {
            inner: self,
            _output,
        }
    }

    /// Keeps only the responses carried by `ResponseComplete` events.
    fn responses(self) -> ResponseStream<Self> {
        let inner = self;
        ResponseStream { inner }
    }
}
// Blanket impl: every stream of `AgentStreamEvent<Output>` gets the adapter methods.
impl<S, Output> AgentStreamExt<Output> for S where S: Stream<Item = AgentStreamEvent<Output>> {}
/// One fragment of streamed text together with its place in the running text.
///
/// When produced by `TextDeltaStream`, `position` and `total_length` are byte
/// lengths of the accumulated `String` before and after this fragment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextDelta {
    /// The text carried by this fragment.
    pub content: String,
    /// Offset at which this fragment starts.
    pub position: usize,
    /// Length of the accumulated text once this fragment is included.
    pub total_length: usize,
}

impl TextDelta {
    /// Builds a `TextDelta` from its three components, stored as given.
    pub fn new(content: String, position: usize, total_length: usize) -> Self {
        TextDelta {
            content,
            position,
            total_length,
        }
    }
}
pin_project! {
/// Stream adapter that extracts text deltas from a stream of agent events
/// while accumulating the full text seen so far.
pub struct TextDeltaStream<S> {
// Underlying event stream; structurally pinned.
#[pin]
inner: S,
// Concatenation of every text fragment yielded so far.
accumulated: String,
// Set by `text_deltas` (false) / `text_accumulated` (true); not read by `poll_next` in this file.
emit_accumulated: bool,
}
}
impl<S> TextDeltaStream<S> {
    /// Borrows the text accumulated from all fragments yielded so far.
    pub fn accumulated_text(&self) -> &str {
        self.accumulated.as_str()
    }

    /// Consumes the adapter, returning the accumulated text and dropping the
    /// inner stream.
    pub fn into_accumulated(self) -> String {
        let text = self.accumulated;
        text
    }
}
/// Yields a `TextDelta` for every `TextDelta` event of the inner stream and ends
/// on `RunComplete`, `Error`, or when the inner stream ends. All other events
/// are skipped.
impl<S, Output> Stream for TextDeltaStream<S>
where
    S: Stream<Item = AgentStreamEvent<Output>>,
{
    type Item = TextDelta;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            // Pull the next event, bailing out if none is available right now or
            // the inner stream is exhausted.
            let event = match this.inner.as_mut().poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(event)) => event,
            };

            match event {
                AgentStreamEvent::TextDelta { content, .. } => {
                    // Byte offsets into the accumulated text before/after appending.
                    let start = this.accumulated.len();
                    this.accumulated.push_str(&content);
                    let end = this.accumulated.len();
                    return Poll::Ready(Some(TextDelta::new(content, start, end)));
                }
                // Either terminal event ends the text stream.
                AgentStreamEvent::RunComplete { .. } | AgentStreamEvent::Error { .. } => {
                    return Poll::Ready(None);
                }
                // Not text-related: keep polling.
                _ => {}
            }
        }
    }
}
pin_project! {
/// Stream adapter that extracts output values from a stream of agent events.
pub struct OutputStream<S, Output> {
// Underlying event stream; structurally pinned.
#[pin]
inner: S,
// Marker tying the struct to its `Output` type parameter.
_output: std::marker::PhantomData<Output>,
}
}
/// Yields the payload of every `FinalOutput` and `PartialOutput` event and ends
/// on `RunComplete`, `Error`, or when the inner stream ends. All other events
/// are skipped.
impl<S, Output> Stream for OutputStream<S, Output>
where
    S: Stream<Item = AgentStreamEvent<Output>>,
{
    type Item = Output;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            let event = match this.inner.as_mut().poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(event)) => event,
            };

            match event {
                // Final and partial outputs are forwarded alike.
                AgentStreamEvent::FinalOutput { output }
                | AgentStreamEvent::PartialOutput { output } => {
                    return Poll::Ready(Some(output));
                }
                // Either terminal event ends the output stream.
                AgentStreamEvent::RunComplete { .. } | AgentStreamEvent::Error { .. } => {
                    return Poll::Ready(None);
                }
                // Unrelated event: keep polling.
                _ => {}
            }
        }
    }
}
pin_project! {
/// Stream adapter that extracts completed responses from a stream of agent events.
pub struct ResponseStream<S> {
// Underlying event stream; structurally pinned.
#[pin]
inner: S,
}
}
/// Yields the response of every `ResponseComplete` event and ends on
/// `RunComplete`, `Error`, or when the inner stream ends. All other events are
/// skipped.
impl<S, Output> Stream for ResponseStream<S>
where
    S: Stream<Item = AgentStreamEvent<Output>>,
{
    type Item = ModelResponse;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            let event = match this.inner.as_mut().poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(event)) => event,
            };

            match event {
                AgentStreamEvent::ResponseComplete { response } => {
                    return Poll::Ready(Some(response));
                }
                // Either terminal event ends the response stream.
                AgentStreamEvent::RunComplete { .. } | AgentStreamEvent::Error { .. } => {
                    return Poll::Ready(None);
                }
                // Unrelated event: keep polling.
                _ => {}
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Successful text delta for part index 0.
    fn text(content: &str) -> StreamResult<ResponseDelta> {
        Ok(ResponseDelta::Text {
            index: 0,
            content: content.to_string(),
        })
    }

    /// Successful finish delta with the `Stop` reason.
    fn finish() -> StreamResult<ResponseDelta> {
        Ok(ResponseDelta::Finish {
            reason: serdes_ai_core::FinishReason::Stop,
        })
    }

    // A full run yields RunStart first and at least four events overall.
    #[tokio::test]
    async fn test_agent_stream_basic() {
        let inner = stream::iter(vec![text("Hello"), text(", world!"), finish()]);
        let mut agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        let mut events = Vec::new();
        while let Some(event) = agent_stream.next().await {
            events.push(event);
        }

        assert!(events.len() >= 4);
        assert!(matches!(events[0], AgentStreamEvent::RunStart { .. }));
    }

    // `text_deltas` reports each fragment with its byte offsets.
    #[tokio::test]
    async fn test_text_deltas() {
        let inner = stream::iter(vec![text("Hello"), text(" world"), finish()]);
        let agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        let text_deltas: Vec<TextDelta> = agent_stream.text_deltas().collect().await;

        assert_eq!(text_deltas.len(), 2);

        let first = &text_deltas[0];
        assert_eq!(first.content, "Hello");
        assert_eq!(first.position, 0);
        assert_eq!(first.total_length, 5);

        let second = &text_deltas[1];
        assert_eq!(second.content, " world");
        assert_eq!(second.position, 5);
        assert_eq!(second.total_length, 11);
    }

    // `text_accumulated` still yields per-fragment content; the running text is
    // read back from the adapter afterwards.
    #[tokio::test]
    async fn test_text_accumulated() {
        let inner = stream::iter(vec![text("Hello"), text(" world"), finish()]);
        let agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        let mut stream = agent_stream.text_accumulated();
        let text_deltas: Vec<TextDelta> = (&mut stream).collect().await;

        assert_eq!(text_deltas.len(), 2);
        assert_eq!(text_deltas[0].content, "Hello");
        assert_eq!(text_deltas[1].content, " world");
        assert_eq!(stream.accumulated_text(), "Hello world");
    }

    // A freshly built stream is pending and not complete.
    #[tokio::test]
    async fn test_stream_state() {
        let inner = stream::iter(vec![text("Test")]);
        let agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        assert_eq!(agent_stream.state(), StreamState::Pending);
        assert!(!agent_stream.is_complete());
    }

    #[test]
    fn test_stream_config_default() {
        let config = StreamConfig::default();

        assert!(config.emit_partial_outputs);
        assert!(config.emit_thinking);
        assert!(!config.buffer_tool_args);
    }
}