use std::pin::Pin;
use futures::Stream;
use serde::{Deserialize, Serialize};
use crate::chat::{StopReason, ToolCall};
use crate::error::LlmError;
use crate::usage::Usage;
/// Heap-allocated, pinned stream of [`StreamEvent`]s for a streaming chat call.
///
/// Every item is a `Result`, so an [`LlmError`] can be yielded after earlier
/// successful events. The trait object is `Send`, so the stream can be moved
/// between threads/tasks.
// `+ 'static` spells out the default object lifetime of `Box<dyn _>`; it does
// not change the type.
pub type ChatStream =
    Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send + 'static>>;
/// A single item yielded by a [`ChatStream`].
///
/// The enum is `#[non_exhaustive]`, so `match`es outside this crate need a
/// wildcard arm.
///
/// NOTE: keep variant and field declaration order stable. Some serde formats
/// (e.g. bincode) encode enum variants by index and struct fields by position.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub enum StreamEvent {
    /// Incremental text fragment.
    TextDelta(String),
    /// Incremental reasoning fragment (presumably model "thinking" output — confirm).
    ReasoningDelta(String),
    /// Opens the tool call at position `index`, giving its `id` and tool `name`.
    ToolCallStart {
        index: u32,
        id: String,
        name: String,
    },
    /// Raw argument text for the tool call at `index`. A single chunk need not
    /// be valid JSON on its own.
    ToolCallDelta { index: u32, json_chunk: String },
    /// The finished [`ToolCall`] for `index` (presumably sent after its
    /// deltas — confirm against the providers).
    ToolCallComplete { index: u32, call: ToolCall },
    /// Token accounting, e.g. input/output token counts.
    Usage(Usage),
    /// End-of-response marker carrying the [`StopReason`]; expected to be the
    /// last event.
    Done { stop_reason: StopReason },
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Tool call fixture shared by the tool-call and serde tests.
    fn sample_tool_call() -> ToolCall {
        ToolCall {
            id: "tc_1".into(),
            name: "search".into(),
            arguments: serde_json::json!({"q": "rust"}),
        }
    }

    /// Usage fixture with non-default token counts.
    fn sample_usage() -> Usage {
        Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Usage::default()
        }
    }

    #[test]
    fn test_stream_event_text_delta_eq() {
        let a = StreamEvent::TextDelta("hello".into());
        let b = a.clone();
        assert_eq!(a, b);
    }

    #[test]
    fn test_stream_event_reasoning_delta_eq() {
        let a = StreamEvent::ReasoningDelta("step 1".into());
        assert_eq!(a, a.clone());
    }

    #[test]
    fn test_stream_event_tool_call_start() {
        let e = StreamEvent::ToolCallStart {
            index: 0,
            id: "tc_1".into(),
            name: "search".into(),
        };
        assert!(matches!(
            &e,
            StreamEvent::ToolCallStart { index: 0, id, name }
                if id == "tc_1" && name == "search"
        ));
    }

    #[test]
    fn test_stream_event_tool_call_delta() {
        let e = StreamEvent::ToolCallDelta {
            index: 0,
            json_chunk: r#"{"q":"#.into(),
        };
        assert_eq!(e, e.clone());
        // The chunk is stored verbatim even though it is not valid JSON by itself.
        assert!(matches!(
            &e,
            StreamEvent::ToolCallDelta { index: 0, json_chunk } if json_chunk == r#"{"q":"#
        ));
    }

    #[test]
    fn test_stream_event_tool_call_complete() {
        let call = sample_tool_call();
        let e = StreamEvent::ToolCallComplete {
            index: 0,
            call: call.clone(),
        };
        assert!(matches!(
            &e,
            StreamEvent::ToolCallComplete { call: c, .. } if *c == call
        ));
    }

    #[test]
    fn test_stream_event_usage() {
        let e = StreamEvent::Usage(sample_usage());
        assert_eq!(e, e.clone());
    }

    #[test]
    fn test_stream_event_done() {
        let e = StreamEvent::Done {
            stop_reason: StopReason::EndTurn,
        };
        assert!(matches!(
            &e,
            StreamEvent::Done { stop_reason } if *stop_reason == StopReason::EndTurn
        ));
    }

    /// The enum derives `Serialize`/`Deserialize`; make sure every variant
    /// survives a JSON round trip unchanged.
    #[test]
    fn test_stream_event_serde_round_trip() {
        let events = vec![
            StreamEvent::TextDelta("hello".into()),
            StreamEvent::ReasoningDelta("step 1".into()),
            StreamEvent::ToolCallStart {
                index: 1,
                id: "tc_1".into(),
                name: "search".into(),
            },
            StreamEvent::ToolCallDelta {
                index: 1,
                json_chunk: r#"{"q":"#.into(),
            },
            StreamEvent::ToolCallComplete {
                index: 1,
                call: sample_tool_call(),
            },
            StreamEvent::Usage(sample_usage()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ];
        for event in events {
            let json = serde_json::to_string(&event).expect("StreamEvent should serialize");
            let back: StreamEvent =
                serde_json::from_str(&json).expect("StreamEvent should deserialize");
            assert_eq!(back, event, "round trip changed the event; JSON was {json}");
        }
    }

    #[tokio::test]
    async fn test_chat_stream_collect() {
        let events = vec![
            Ok(StreamEvent::TextDelta("hello ".into())),
            Ok(StreamEvent::TextDelta("world".into())),
            Ok(StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));
        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected.len(), 3);
        assert!(collected.iter().all(Result::is_ok));

        // Check content and order, not just the item count.
        let events: Vec<StreamEvent> = collected.into_iter().filter_map(Result::ok).collect();
        let text: String = events
            .iter()
            .filter_map(|event| match event {
                StreamEvent::TextDelta(chunk) => Some(chunk.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(text, "hello world");
        assert!(matches!(
            events.last(),
            Some(StreamEvent::Done { stop_reason }) if *stop_reason == StopReason::EndTurn
        ));
    }

    #[tokio::test]
    async fn test_chat_stream_error_mid_stream() {
        let events = vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Ok(StreamEvent::TextDelta(" world".into())),
            Err(LlmError::Http {
                status: Some(http::StatusCode::INTERNAL_SERVER_ERROR),
                message: "server error".into(),
                retryable: true,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));
        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected.len(), 3);
        assert!(collected[0].is_ok());
        assert!(collected[1].is_ok());
        // Pin the specific error, not just "some error", so a wrong variant is caught.
        assert!(matches!(
            collected[2],
            Err(LlmError::Http {
                retryable: true,
                ..
            })
        ));
    }

    #[test]
    fn test_chat_stream_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<ChatStream>();
    }
}