1use std::pin::Pin;
42
43use futures::Stream;
44use serde::{Deserialize, Serialize};
45
46use crate::chat::{StopReason, ToolCall};
47use crate::error::LlmError;
48use crate::usage::Usage;
49
50pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>;
55
/// One incremental event yielded by a streaming chat response (see `ChatStream`).
///
/// NOTE(review): the derived `Serialize`/`Deserialize` impls follow declaration
/// order. Variant indices matter for index-based formats, and fields are emitted
/// in declaration order. Prefer appending new variants/fields over reordering.
/// Because of `#[non_exhaustive]`, `match`es outside this crate must keep a
/// wildcard arm, so adding a variant is not a breaking change for them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum StreamEvent {
    /// A fragment of generated text.
    /// Presumably consumers concatenate successive fragments in arrival order
    /// (the tests split "hello world" across two deltas). TODO confirm with
    /// the provider adapters.
    TextDelta(String),
    /// A fragment of reasoning output, kept in a variant separate from
    /// `TextDelta`. Meaning inferred from the name; confirm per provider.
    ReasoningDelta(String),
    /// Announces a tool call and its identity before its arguments arrive.
    ToolCallStart {
        /// Slot of this call within the response. Presumably shared by the
        /// matching `ToolCallDelta`/`ToolCallComplete` events. TODO confirm.
        index: u32,
        /// Identifier of the tool call (e.g. `"tc_1"` in the tests).
        id: String,
        /// Name of the tool being invoked.
        name: String,
    },
    /// A raw piece of a tool call's argument JSON.
    /// A single chunk need not be valid JSON on its own (the tests use `{"q":`).
    ToolCallDelta {
        /// Slot of the tool call this chunk belongs to (see `ToolCallStart::index`).
        index: u32,
        /// Unparsed fragment of the arguments text.
        json_chunk: String,
    },
    /// A fully assembled tool call.
    ToolCallComplete {
        /// Slot of the tool call that finished (see `ToolCallStart::index`).
        index: u32,
        /// The complete call, including its arguments value.
        call: ToolCall,
    },
    /// Token accounting for the response (the tests set `input_tokens` and
    /// `output_tokens`).
    Usage(Usage),
    /// Reports why generation stopped.
    /// Presumably the final event of a successful stream; this type does not
    /// enforce that ordering.
    Done {
        /// The reason generation stopped (e.g. `StopReason::EndTurn`).
        stop_reason: StopReason,
    },
}
96
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Serializes `event` to JSON and parses it back, so tests can check that the
    /// derived `Serialize` and `Deserialize` impls agree with each other.
    fn json_round_trip(event: &StreamEvent) -> StreamEvent {
        let json = serde_json::to_string(event).expect("StreamEvent must serialize to JSON");
        serde_json::from_str(&json).expect("serialized StreamEvent must deserialize")
    }

    /// A representative tool call shared by several tests.
    fn sample_tool_call() -> ToolCall {
        ToolCall {
            id: "tc_1".into(),
            name: "search".into(),
            arguments: serde_json::json!({"q": "rust"}),
        }
    }

    #[test]
    fn test_stream_event_text_delta_eq() {
        let a = StreamEvent::TextDelta("hello".into());
        let b = a.clone();
        assert_eq!(a, b);
        // `x == x.clone()` alone always holds for a derived impl; also check that
        // a different payload is told apart.
        assert_ne!(a, StreamEvent::TextDelta("hello!".into()));
    }

    #[test]
    fn test_stream_event_reasoning_delta_eq() {
        let a = StreamEvent::ReasoningDelta("step 1".into());
        assert_eq!(a, a.clone());
        // The same payload in a different variant must not compare equal.
        assert_ne!(a, StreamEvent::TextDelta("step 1".into()));
    }

    #[test]
    fn test_stream_event_tool_call_start() {
        let e = StreamEvent::ToolCallStart {
            index: 0,
            id: "tc_1".into(),
            name: "search".into(),
        };
        assert!(matches!(
            &e,
            StreamEvent::ToolCallStart { index: 0, id, name }
            if id == "tc_1" && name == "search"
        ));
    }

    #[test]
    fn test_stream_event_tool_call_delta() {
        let e = StreamEvent::ToolCallDelta {
            index: 0,
            json_chunk: r#"{"q":"#.into(),
        };
        assert_eq!(e, e.clone());
        // Chunks belonging to different tool-call slots must stay distinct.
        assert_ne!(
            e,
            StreamEvent::ToolCallDelta {
                index: 1,
                json_chunk: r#"{"q":"#.into(),
            }
        );
    }

    #[test]
    fn test_stream_event_tool_call_complete() {
        let call = sample_tool_call();
        let e = StreamEvent::ToolCallComplete {
            index: 0,
            call: call.clone(),
        };
        assert!(matches!(
            &e,
            StreamEvent::ToolCallComplete { call: c, .. } if *c == call
        ));
    }

    #[test]
    fn test_stream_event_usage() {
        let e = StreamEvent::Usage(Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Usage::default()
        });
        assert_eq!(e, e.clone());
        // A different token count must be observable through equality.
        assert_ne!(
            e,
            StreamEvent::Usage(Usage {
                input_tokens: 100,
                output_tokens: 51,
                ..Usage::default()
            })
        );
    }

    #[test]
    fn test_stream_event_done() {
        let e = StreamEvent::Done {
            stop_reason: StopReason::EndTurn,
        };
        assert!(matches!(
            &e,
            StreamEvent::Done { stop_reason } if *stop_reason == StopReason::EndTurn
        ));
    }

    #[test]
    fn test_stream_event_serde_round_trip() {
        // One instance of every variant: the enum derives Serialize/Deserialize,
        // so each must survive a JSON round trip unchanged.
        let events = [
            StreamEvent::TextDelta("hello".into()),
            StreamEvent::ReasoningDelta("step 1".into()),
            StreamEvent::ToolCallStart {
                index: 0,
                id: "tc_1".into(),
                name: "search".into(),
            },
            StreamEvent::ToolCallDelta {
                index: 0,
                json_chunk: r#"{"q":"#.into(),
            },
            StreamEvent::ToolCallComplete {
                index: 0,
                call: sample_tool_call(),
            },
            StreamEvent::Usage(Usage {
                input_tokens: 100,
                output_tokens: 50,
                ..Usage::default()
            }),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ];
        for event in &events {
            assert_eq!(&json_round_trip(event), event);
        }
    }

    #[tokio::test]
    async fn test_chat_stream_collect() {
        let expected = vec![
            StreamEvent::TextDelta("hello ".into()),
            StreamEvent::TextDelta("world".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ];
        let items = expected.clone().into_iter().map(Ok::<StreamEvent, LlmError>);
        let stream: ChatStream = Box::pin(futures::stream::iter(items));
        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected.len(), 3);
        assert!(collected.iter().all(Result::is_ok));

        // Order and payloads must come through the boxed stream unchanged.
        let events: Vec<StreamEvent> = collected.into_iter().flatten().collect();
        assert_eq!(events, expected);
        let text: String = events
            .iter()
            .filter_map(|e| match e {
                StreamEvent::TextDelta(t) => Some(t.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(text, "hello world");
    }

    #[tokio::test]
    async fn test_chat_stream_error_mid_stream() {
        let events = vec![
            Ok(StreamEvent::TextDelta("hello".into())),
            Ok(StreamEvent::TextDelta(" world".into())),
            Err(LlmError::Http {
                status: Some(http::StatusCode::INTERNAL_SERVER_ERROR),
                message: "server error".into(),
                retryable: true,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));
        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected.len(), 3);
        // Events before the failure are delivered intact and in order.
        assert!(matches!(&collected[0], Ok(StreamEvent::TextDelta(t)) if t == "hello"));
        assert!(matches!(&collected[1], Ok(StreamEvent::TextDelta(t)) if t == " world"));
        // The error arrives as an in-band item with its variant and flag preserved.
        assert!(matches!(
            &collected[2],
            Err(LlmError::Http {
                retryable: true,
                ..
            })
        ));
    }

    #[test]
    fn test_chat_stream_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<ChatStream>();
    }
}