1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::completion::{
14 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, Message, Usage,
15};
16use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction};
17use futures::stream::{AbortHandle, Abortable};
18use futures::{Stream, StreamExt};
19use serde::{Deserialize, Serialize};
20use std::boxed::Box;
21use std::future::Future;
22use std::pin::Pin;
23use std::sync::atomic::AtomicBool;
24use std::task::{Context, Poll};
25
/// A single raw chunk emitted by a provider-specific completion stream,
/// before it is aggregated by [`StreamingCompletionResponse`].
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R: Clone> {
    /// A plain text delta from the model.
    Message(String),

    /// A tool invocation requested by the model.
    ToolCall {
        /// Provider-assigned identifier for this tool call.
        id: String,
        /// Secondary call identifier; not all providers supply one.
        call_id: Option<String>,
        /// Name of the tool to invoke.
        name: String,
        /// JSON-encoded arguments for the tool.
        arguments: serde_json::Value,
    },
    /// A reasoning ("thinking") text delta from the model.
    Reasoning { reasoning: String },

    /// The provider's final response payload, emitted once at stream end.
    FinalResponse(R),
}
46
/// Boxed, pinned stream of raw streaming choices.
/// On non-wasm targets the stream must be `Send` so it can cross threads.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed, pinned stream of raw streaming choices.
/// wasm32 is single-threaded, so no `Send` bound is required.
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
54
/// Wrapper around a provider stream that aggregates text, reasoning and tool
/// calls as chunks arrive, and supports cooperative cancellation via
/// [`StreamingCompletionResponse::cancel`].
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
    // Underlying provider stream, wrapped so `abort_handle` can stop it.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    // Handle used by `cancel()` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    // Message text accumulated across `Message` chunks.
    text: String,
    // Reasoning text accumulated across `Reasoning` chunks.
    reasoning: String,
    // Tool calls collected while streaming.
    tool_calls: Vec<ToolCall>,
    /// Aggregated assistant content; populated when the stream finishes.
    pub choice: OneOrMany<AssistantContent>,
    /// The provider's final response payload, if one was emitted.
    pub response: Option<R>,
    /// Guards against yielding more than one `FinalResponse` to consumers.
    pub final_response_yielded: AtomicBool,
}
72
73impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
74 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
75 let (abort_handle, abort_registration) = AbortHandle::new_pair();
76 let abortable_stream = Abortable::new(inner, abort_registration);
77 Self {
78 inner: abortable_stream,
79 abort_handle,
80 reasoning: String::new(),
81 text: "".to_string(),
82 tool_calls: vec![],
83 choice: OneOrMany::one(AssistantContent::text("")),
84 response: None,
85 final_response_yielded: AtomicBool::new(false),
86 }
87 }
88
89 pub fn cancel(&self) {
90 self.abort_handle.abort();
91 }
92}
93
94impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
95 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
96 CompletionResponse {
97 choice: value.choice,
98 usage: Usage::new(), raw_response: value.response,
100 }
101 }
102}
103
104impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
105 type Item = Result<StreamedAssistantContent<R>, CompletionError>;
106
107 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108 let stream = self.get_mut();
109
110 match Pin::new(&mut stream.inner).poll_next(cx) {
111 Poll::Pending => Poll::Pending,
112 Poll::Ready(None) => {
113 let mut choice = vec![];
116
117 stream.tool_calls.iter().for_each(|tc| {
118 choice.push(AssistantContent::ToolCall(tc.clone()));
119 });
120
121 if choice.is_empty() || !stream.text.is_empty() {
123 choice.insert(0, AssistantContent::text(stream.text.clone()));
124 }
125
126 stream.choice = OneOrMany::many(choice)
127 .expect("There should be at least one assistant message");
128
129 Poll::Ready(None)
130 }
131 Poll::Ready(Some(Err(err))) => {
132 if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
133 {
134 return Poll::Ready(None); }
136 Poll::Ready(Some(Err(err)))
137 }
138 Poll::Ready(Some(Ok(choice))) => match choice {
139 RawStreamingChoice::Message(text) => {
140 stream.text = format!("{}{}", stream.text, text.clone());
143 Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
144 }
145 RawStreamingChoice::Reasoning { reasoning } => {
146 stream.reasoning = format!("{}{}", stream.reasoning, reasoning.clone());
149 Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
150 reasoning,
151 }))))
152 }
153 RawStreamingChoice::ToolCall {
154 id,
155 name,
156 arguments,
157 call_id,
158 } => {
159 stream.tool_calls.push(ToolCall {
162 id: id.clone(),
163 call_id: call_id.clone(),
164 function: ToolFunction {
165 name: name.clone(),
166 arguments: arguments.clone(),
167 },
168 });
169 if let Some(call_id) = call_id {
170 Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call_with_call_id(
171 id, call_id, name, arguments,
172 ))))
173 } else {
174 Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call(
175 id, name, arguments,
176 ))))
177 }
178 }
179 RawStreamingChoice::FinalResponse(response) => {
180 if stream
181 .final_response_yielded
182 .load(std::sync::atomic::Ordering::SeqCst)
183 {
184 stream.poll_next_unpin(cx)
185 } else {
186 stream.response = Some(response.clone());
188 stream
189 .final_response_yielded
190 .store(true, std::sync::atomic::Ordering::SeqCst);
191 let final_response = StreamedAssistantContent::final_response(response);
192 Poll::Ready(Some(Ok(final_response)))
193 }
194 }
195 },
196 }
197 }
198}
199
/// High-level streaming interface: send a single prompt (no chat history)
/// and receive a [`StreamingCompletionResponse`].
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
    /// Streams a simple prompt to the model.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
208
/// High-level streaming interface: send a prompt together with prior chat
/// history and receive a [`StreamingCompletionResponse`].
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
    /// Streams a chat turn (prompt + history) to the model.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
218
/// Low-level streaming interface: produce a [`CompletionRequestBuilder`] for
/// further customization before the request is actually streamed.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Builds (but does not send) a streaming completion request.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
228
/// Adapter that erases the final-response type of a [`StreamingResult`]:
/// its `Stream` impl maps `FinalResponse(R)` chunks to `FinalResponse(())`.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
    // The typed stream whose final-response payload gets discarded.
    pub(crate) inner: StreamingResult<R>,
}
232
233impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
234 type Item = Result<RawStreamingChoice<()>, CompletionError>;
235
236 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237 let stream = self.get_mut();
238
239 match stream.inner.as_mut().poll_next(cx) {
240 Poll::Pending => Poll::Pending,
241 Poll::Ready(None) => Poll::Ready(None),
242 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
243 Poll::Ready(Some(Ok(chunk))) => match chunk {
244 RawStreamingChoice::FinalResponse(_) => {
245 Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
246 }
247 RawStreamingChoice::Message(m) => {
248 Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
249 }
250 RawStreamingChoice::Reasoning { reasoning } => {
251 Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning { reasoning })))
252 }
253 RawStreamingChoice::ToolCall {
254 id,
255 name,
256 arguments,
257 call_id,
258 } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
259 id,
260 name,
261 arguments,
262 call_id,
263 }))),
264 },
265 }
266 }
267}
268
269pub async fn stream_to_stdout<M: CompletionModel>(
271 agent: &Agent<M>,
272 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
273) -> Result<(), std::io::Error> {
274 let mut is_reasoning = false;
275 print!("Response: ");
276 while let Some(chunk) = stream.next().await {
277 match chunk {
278 Ok(StreamedAssistantContent::Text(text)) => {
279 if is_reasoning {
280 is_reasoning = false;
281 println!("\n---\n");
282 }
283 print!("{}", text.text);
284 std::io::Write::flush(&mut std::io::stdout())?;
285 }
286 Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
287 let res = agent
288 .tools
289 .call(
290 &tool_call.function.name,
291 tool_call.function.arguments.to_string(),
292 )
293 .await
294 .map_err(|e| std::io::Error::other(e.to_string()))?;
295 println!("\nResult: {res}");
296 }
297 Ok(StreamedAssistantContent::Final(res)) => {
298 let json_res = serde_json::to_string_pretty(&res).unwrap();
299 println!();
300 tracing::info!("Final result: {json_res}");
301 }
302 Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning })) => {
303 if !is_reasoning {
304 is_reasoning = true;
305 println!();
306 println!("Thinking: ");
307 }
308 print!("{reasoning}");
309 std::io::Write::flush(&mut std::io::stdout())?;
310 }
311 Err(e) => {
312 if e.to_string().contains("aborted") {
313 println!("\nStream cancelled.");
314 break;
315 }
316 eprintln!("Error: {e}");
317 break;
318 }
319 }
320 }
321
322 println!(); Ok(())
325}
326
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Dummy final-response payload for the mock stream.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    /// Builds a streaming response that emits three text chunks (100ms apart)
    /// followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // The previous wasm32/non-wasm32 cfg branches had identical bodies;
        // `Box::pin` coerces to the target-appropriate `StreamingResult`
        // alias on every platform, so one statement suffices.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// After `cancel()` is called mid-stream, polling must yield `None`
    /// rather than more chunks or an error.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning })) => {
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel once two content chunks have been observed.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }
}
406
/// A chunk as seen by consumers of [`StreamingCompletionResponse`]:
/// either assistant content (text / tool call / reasoning) or the provider's
/// final response payload. Serialized untagged, so each variant serializes
/// as its inner value.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A text delta.
    Text(Text),
    /// A tool invocation requested by the model.
    ToolCall(ToolCall),
    /// A reasoning ("thinking") delta.
    Reasoning(Reasoning),
    /// The provider's final response payload.
    Final(R),
}
416
417impl<R> StreamedAssistantContent<R>
418where
419 R: Clone + Unpin,
420{
421 pub fn text(text: &str) -> Self {
422 Self::Text(Text {
423 text: text.to_string(),
424 })
425 }
426
427 pub fn tool_call(
429 id: impl Into<String>,
430 name: impl Into<String>,
431 arguments: serde_json::Value,
432 ) -> Self {
433 Self::ToolCall(ToolCall {
434 id: id.into(),
435 call_id: None,
436 function: ToolFunction {
437 name: name.into(),
438 arguments,
439 },
440 })
441 }
442
443 pub fn tool_call_with_call_id(
444 id: impl Into<String>,
445 call_id: String,
446 name: impl Into<String>,
447 arguments: serde_json::Value,
448 ) -> Self {
449 Self::ToolCall(ToolCall {
450 id: id.into(),
451 call_id: Some(call_id),
452 function: ToolFunction {
453 name: name.into(),
454 arguments,
455 },
456 })
457 }
458
459 pub fn final_response(res: R) -> Self {
460 Self::Final(res)
461 }
462}