1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::completion::{
14 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, Message, Usage,
15};
16use crate::message::{AssistantContent, ToolCall, ToolFunction};
17use futures::stream::{AbortHandle, Abortable};
18use futures::{Stream, StreamExt};
19use std::boxed::Box;
20use std::future::Future;
21use std::pin::Pin;
22use std::task::{Context, Poll};
23
/// One item yielded by a [`StreamingResult`] stream.
///
/// `R` is the type carried by [`RawStreamingChoice::FinalResponse`].
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R: Clone> {
    /// A chunk of text. `StreamingCompletionResponse` appends each chunk to its
    /// accumulated text and also yields it to the consumer.
    Message(String),

    /// A tool call. `StreamingCompletionResponse` records it as a `ToolCall` and
    /// yields it to the consumer.
    ToolCall {
        /// Copied into `ToolCall::id`.
        id: String,
        /// Optional value copied into `ToolCall::call_id`.
        call_id: Option<String>,
        /// Copied into `ToolFunction::name`.
        name: String,
        /// Copied into `ToolFunction::arguments`.
        arguments: serde_json::Value,
    },

    /// A response value of type `R`. `StreamingCompletionResponse` stores it in its
    /// `response` field instead of yielding it.
    FinalResponse(R),
}
42
/// Pinned, boxed stream of `Result<RawStreamingChoice<R>, CompletionError>` items.
///
/// On every target except `wasm32` the boxed stream must also be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Pinned, boxed stream of `Result<RawStreamingChoice<R>, CompletionError>` items.
///
/// This is the `wasm32` variant: identical to the other definition but without the
/// `Send` bound.
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
50
/// Abortable stream of assistant content that also accumulates what it has received.
///
/// Built with `StreamingCompletionResponse::stream`; polled through its `Stream`
/// implementation, which fills in `choice` and `response` as a side effect.
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
    /// The raw stream, wrapped so that it can be aborted.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    /// Handle paired with `inner`; `cancel` calls `abort` on it.
    pub(crate) abort_handle: AbortHandle,
    /// Concatenation of every text chunk received so far.
    text: String,
    /// Every tool call received so far, in arrival order.
    tool_calls: Vec<ToolCall>,
    /// Aggregated content. Starts as a single empty text item and is rebuilt from
    /// `text` and `tool_calls` when the inner stream reports that it has ended.
    pub choice: OneOrMany<AssistantContent>,
    /// Payload of the `FinalResponse` chunk, or `None` while none has been received.
    pub response: Option<R>,
}
66
67impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
68 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
69 let (abort_handle, abort_registration) = AbortHandle::new_pair();
70 let abortable_stream = Abortable::new(inner, abort_registration);
71 Self {
72 inner: abortable_stream,
73 abort_handle,
74 text: "".to_string(),
75 tool_calls: vec![],
76 choice: OneOrMany::one(AssistantContent::text("")),
77 response: None,
78 }
79 }
80
81 pub fn cancel(&self) {
82 self.abort_handle.abort();
83 }
84}
85
86impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
87 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
88 CompletionResponse {
89 choice: value.choice,
90 usage: Usage::new(), raw_response: value.response,
92 }
93 }
94}
95
96impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
97 type Item = Result<AssistantContent, CompletionError>;
98
99 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100 let stream = self.get_mut();
101
102 match Pin::new(&mut stream.inner).poll_next(cx) {
103 Poll::Pending => Poll::Pending,
104 Poll::Ready(None) => {
105 let mut choice = vec![];
108
109 stream.tool_calls.iter().for_each(|tc| {
110 choice.push(AssistantContent::ToolCall(tc.clone()));
111 });
112
113 if choice.is_empty() || !stream.text.is_empty() {
115 choice.insert(0, AssistantContent::text(stream.text.clone()));
116 }
117
118 stream.choice = OneOrMany::many(choice)
119 .expect("There should be at least one assistant message");
120
121 Poll::Ready(None)
122 }
123 Poll::Ready(Some(Err(err))) => {
124 if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
125 {
126 return Poll::Ready(None); }
128 Poll::Ready(Some(Err(err)))
129 }
130 Poll::Ready(Some(Ok(choice))) => match choice {
131 RawStreamingChoice::Message(text) => {
132 stream.text = format!("{}{}", stream.text, text.clone());
135 Poll::Ready(Some(Ok(AssistantContent::text(text))))
136 }
137 RawStreamingChoice::ToolCall {
138 id,
139 name,
140 arguments,
141 call_id,
142 } => {
143 stream.tool_calls.push(ToolCall {
146 id: id.clone(),
147 call_id,
148 function: ToolFunction {
149 name: name.clone(),
150 arguments: arguments.clone(),
151 },
152 });
153 Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, arguments))))
154 }
155 RawStreamingChoice::FinalResponse(response) => {
156 stream.response = Some(response);
158
159 stream.poll_next_unpin(cx)
160 }
161 },
162 }
163 }
164}
165
/// Streaming interface that takes a single prompt.
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
    /// Accepts any value convertible into a [`Message`] and returns a future that
    /// resolves to a [`StreamingCompletionResponse`] or a [`CompletionError`].
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
174
/// Streaming interface that takes a prompt together with a chat history.
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
    /// Accepts any value convertible into a [`Message`] plus a list of earlier
    /// messages, and returns a future that resolves to a
    /// [`StreamingCompletionResponse`] or a [`CompletionError`].
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
184
/// Interface that produces a [`CompletionRequestBuilder`] for the model type `M`.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Accepts a prompt and a chat history and returns a future that resolves to a
    /// [`CompletionRequestBuilder`] or a [`CompletionError`].
    ///
    /// NOTE(review): presumably the builder is pre-populated from `prompt` and
    /// `chat_history` so the caller can adjust it before sending; confirm against
    /// the implementors.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
194
/// Crate-internal wrapper around a [`StreamingResult`]. Its `Stream` implementation
/// yields `RawStreamingChoice<()>` items, replacing the final response payload
/// with `()`.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
    /// The wrapped stream whose items are re-emitted.
    pub(crate) inner: StreamingResult<R>,
}
198
199impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
200 type Item = Result<RawStreamingChoice<()>, CompletionError>;
201
202 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
203 let stream = self.get_mut();
204
205 match stream.inner.as_mut().poll_next(cx) {
206 Poll::Pending => Poll::Pending,
207 Poll::Ready(None) => Poll::Ready(None),
208 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
209 Poll::Ready(Some(Ok(chunk))) => match chunk {
210 RawStreamingChoice::FinalResponse(_) => {
211 Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
212 }
213 RawStreamingChoice::Message(m) => {
214 Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
215 }
216 RawStreamingChoice::ToolCall {
217 id,
218 name,
219 arguments,
220 call_id,
221 } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
222 id,
223 name,
224 arguments,
225 call_id,
226 }))),
227 },
228 }
229 }
230}
231
232pub async fn stream_to_stdout<M: CompletionModel>(
234 agent: &Agent<M>,
235 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
236) -> Result<(), std::io::Error> {
237 print!("Response: ");
238 while let Some(chunk) = stream.next().await {
239 match chunk {
240 Ok(AssistantContent::Text(text)) => {
241 print!("{}", text.text);
242 std::io::Write::flush(&mut std::io::stdout())?;
243 }
244 Ok(AssistantContent::ToolCall(tool_call)) => {
245 let res = agent
246 .tools
247 .call(
248 &tool_call.function.name,
249 tool_call.function.arguments.to_string(),
250 )
251 .await
252 .map_err(|e| std::io::Error::other(e.to_string()))?;
253 println!("\nResult: {res}");
254 }
255 Err(e) => {
256 if e.to_string().contains("aborted") {
257 println!("\nStream cancelled.");
258 break;
259 }
260 eprintln!("Error: {e}");
261 break;
262 }
263 }
264 }
265
266 println!(); Ok(())
269}
270
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Stand-in for the final response payload.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        // Only constructed, never read: the test needs a payload, not its contents.
        #[allow(dead_code)]
        token_count: u32,
    }

    /// Builds a stream that emits three text chunks, pausing 100 ms after each,
    /// followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let chunks = stream! {
            for label in ["hello 1", "hello 2", "hello 3"] {
                yield Ok(RawStreamingChoice::Message(label.to_string()));
                sleep(Duration::from_millis(100)).await;
            }
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // The same expression satisfies both the wasm32 and the non-wasm32 alias.
        let boxed: StreamingResult<MockResponse> = Box::pin(chunks);

        StreamingCompletionResponse::stream(boxed)
    }

    /// Cancels the stream after two chunks and checks that nothing further is yielded.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut received = 0;
        while let Some(item) = stream.next().await {
            match item {
                Ok(content) => {
                    match content {
                        AssistantContent::Text(text) => {
                            print!("{}", text.text);
                            std::io::Write::flush(&mut std::io::stdout()).unwrap();
                        }
                        AssistantContent::ToolCall(tc) => {
                            println!("\nTool Call: {tc:?}");
                        }
                    }
                    received += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if received >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }
}