1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::completion::{
15 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
16 Message, Usage,
17};
18use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction};
19use futures::stream::{AbortHandle, Abortable};
20use futures::{Stream, StreamExt};
21use serde::{Deserialize, Serialize};
22use std::boxed::Box;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::atomic::AtomicBool;
26use std::task::{Context, Poll};
27
/// A single raw chunk emitted by a provider's streaming response, before it is
/// aggregated by [`StreamingCompletionResponse`].
///
/// `R` is the provider-specific type carried by the final response chunk.
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R: Clone> {
    /// A fragment of assistant text; [`StreamingCompletionResponse`] concatenates
    /// these fragments in arrival order.
    Message(String),

    /// A tool call requested by the model.
    ToolCall {
        /// Identifier of the tool call.
        id: String,
        /// Optional secondary call identifier, present only when the provider supplies one.
        call_id: Option<String>,
        /// Name of the tool to invoke.
        name: String,
        /// Tool arguments as a JSON value.
        arguments: serde_json::Value,
    },
    /// A fragment of model reasoning output.
    Reasoning {
        /// Optional identifier of the reasoning block.
        id: Option<String>,
        /// The reasoning text carried by this chunk.
        reasoning: String,
    },

    /// The provider's final response object. [`StreamingCompletionResponse`] only
    /// forwards the first one it sees.
    FinalResponse(R),
}
51
/// A boxed, pinned stream of raw streaming chunks (or completion errors).
///
/// On non-wasm targets the stream is additionally required to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// A boxed, pinned stream of raw streaming chunks (or completion errors).
///
/// On wasm32 the `Send` bound is omitted (presumably because wasm futures are not
/// required to be `Send` — confirm).
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
59
/// A cancellable stream over a provider's raw streaming output.
///
/// It forwards each chunk as a [`StreamedAssistantContent`] and accumulates the
/// streamed content so that an aggregated `choice` is available once the stream ends.
pub struct StreamingCompletionResponse<R: Clone + Unpin + GetTokenUsage> {
    /// The raw provider stream, wrapped so that `abort_handle` can cancel it.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    /// Handle used by `cancel` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    /// Text accumulated from all `Message` chunks received so far.
    text: String,
    /// Reasoning accumulated from all `Reasoning` chunks received so far.
    reasoning: String,
    /// Tool calls received so far, in arrival order.
    tool_calls: Vec<ToolCall>,
    /// Aggregated assistant content. It starts as a single empty text item and is
    /// rebuilt from `text` and `tool_calls` when the underlying stream ends.
    pub choice: OneOrMany<AssistantContent>,
    /// The first final response received from the provider, if any.
    pub response: Option<R>,
    /// Set once the first final response has been yielded; later final responses
    /// are skipped.
    pub final_response_yielded: AtomicBool,
}
77
78impl<R> StreamingCompletionResponse<R>
79where
80 R: Clone + Unpin + GetTokenUsage,
81{
82 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
83 let (abort_handle, abort_registration) = AbortHandle::new_pair();
84 let abortable_stream = Abortable::new(inner, abort_registration);
85 Self {
86 inner: abortable_stream,
87 abort_handle,
88 reasoning: String::new(),
89 text: "".to_string(),
90 tool_calls: vec![],
91 choice: OneOrMany::one(AssistantContent::text("")),
92 response: None,
93 final_response_yielded: AtomicBool::new(false),
94 }
95 }
96
97 pub fn cancel(&self) {
98 self.abort_handle.abort();
99 }
100}
101
102impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
103where
104 R: Clone + Unpin + GetTokenUsage,
105{
106 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
107 CompletionResponse {
108 choice: value.choice,
109 usage: Usage::new(), raw_response: value.response,
111 }
112 }
113}
114
115impl<R> Stream for StreamingCompletionResponse<R>
116where
117 R: Clone + Unpin + GetTokenUsage,
118{
119 type Item = Result<StreamedAssistantContent<R>, CompletionError>;
120
121 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122 let stream = self.get_mut();
123
124 match Pin::new(&mut stream.inner).poll_next(cx) {
125 Poll::Pending => Poll::Pending,
126 Poll::Ready(None) => {
127 let mut choice = vec![];
130
131 stream.tool_calls.iter().for_each(|tc| {
132 choice.push(AssistantContent::ToolCall(tc.clone()));
133 });
134
135 if choice.is_empty() || !stream.text.is_empty() {
137 choice.insert(0, AssistantContent::text(stream.text.clone()));
138 }
139
140 stream.choice = OneOrMany::many(choice)
141 .expect("There should be at least one assistant message");
142
143 Poll::Ready(None)
144 }
145 Poll::Ready(Some(Err(err))) => {
146 if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
147 {
148 return Poll::Ready(None); }
150 Poll::Ready(Some(Err(err)))
151 }
152 Poll::Ready(Some(Ok(choice))) => match choice {
153 RawStreamingChoice::Message(text) => {
154 stream.text = format!("{}{}", stream.text, text.clone());
157 Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
158 }
159 RawStreamingChoice::Reasoning { id, reasoning } => {
160 stream.reasoning = format!("{}{}", stream.reasoning, reasoning.clone());
163 Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
164 id,
165 reasoning: vec![stream.reasoning.clone()],
166 }))))
167 }
168 RawStreamingChoice::ToolCall {
169 id,
170 name,
171 arguments,
172 call_id,
173 } => {
174 stream.tool_calls.push(ToolCall {
177 id: id.clone(),
178 call_id: call_id.clone(),
179 function: ToolFunction {
180 name: name.clone(),
181 arguments: arguments.clone(),
182 },
183 });
184 if let Some(call_id) = call_id {
185 Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call_with_call_id(
186 id, call_id, name, arguments,
187 ))))
188 } else {
189 Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call(
190 id, name, arguments,
191 ))))
192 }
193 }
194 RawStreamingChoice::FinalResponse(response) => {
195 if stream
196 .final_response_yielded
197 .load(std::sync::atomic::Ordering::SeqCst)
198 {
199 stream.poll_next_unpin(cx)
200 } else {
201 stream.response = Some(response.clone());
203 stream
204 .final_response_yielded
205 .store(true, std::sync::atomic::Ordering::SeqCst);
206 let final_response = StreamedAssistantContent::final_response(response);
207 Poll::Ready(Some(Ok(final_response)))
208 }
209 }
210 },
211 }
212 }
213}
214
/// Types that can start a streaming prompt request from a single prompt.
///
/// NOTE(review): the type parameter `R` is constrained here but is not referenced
/// by the method signature.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: Send,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Builds a [`StreamingPromptRequest`] for `prompt`.
    fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<'_, M>;
}
225
/// Types that can start a streaming prompt request from a prompt plus prior chat
/// history.
///
/// NOTE(review): the type parameter `R` is constrained here but is not referenced
/// by the method signature.
pub trait StreamingChat<M, R>: Send + Sync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: Send,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Builds a [`StreamingPromptRequest`] for `prompt`, given the preceding
    /// `chat_history`.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<'_, M>;
}
240
/// Types that can produce a completion request builder for streaming use.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Returns a future resolving to a [`CompletionRequestBuilder`] for `prompt`
    /// and `chat_history`, or a [`CompletionError`] if the builder cannot be created.
    /// Presumably the caller configures the builder further and then streams it —
    /// confirm against implementors.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
250
251pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
252 pub(crate) inner: StreamingResult<R>,
253}
254
255impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
256 type Item = Result<RawStreamingChoice<()>, CompletionError>;
257
258 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
259 let stream = self.get_mut();
260
261 match stream.inner.as_mut().poll_next(cx) {
262 Poll::Pending => Poll::Pending,
263 Poll::Ready(None) => Poll::Ready(None),
264 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
265 Poll::Ready(Some(Ok(chunk))) => match chunk {
266 RawStreamingChoice::FinalResponse(_) => {
267 Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
268 }
269 RawStreamingChoice::Message(m) => {
270 Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
271 }
272 RawStreamingChoice::Reasoning { id, reasoning } => {
273 Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning { id, reasoning })))
274 }
275 RawStreamingChoice::ToolCall {
276 id,
277 name,
278 arguments,
279 call_id,
280 } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
281 id,
282 name,
283 arguments,
284 call_id,
285 }))),
286 },
287 }
288 }
289}
290
291pub async fn stream_to_stdout<M: CompletionModel>(
293 agent: &Agent<M>,
294 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
295) -> Result<(), std::io::Error> {
296 let mut is_reasoning = false;
297 print!("Response: ");
298 while let Some(chunk) = stream.next().await {
299 match chunk {
300 Ok(StreamedAssistantContent::Text(text)) => {
301 if is_reasoning {
302 is_reasoning = false;
303 println!("\n---\n");
304 }
305 print!("{}", text.text);
306 std::io::Write::flush(&mut std::io::stdout())?;
307 }
308 Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
309 let res = agent
310 .tools
311 .call(
312 &tool_call.function.name,
313 tool_call.function.arguments.to_string(),
314 )
315 .await
316 .map_err(|e| std::io::Error::other(e.to_string()))?;
317 println!("\nResult: {res}");
318 }
319 Ok(StreamedAssistantContent::Final(res)) => {
320 let json_res = serde_json::to_string_pretty(&res).unwrap();
321 println!();
322 tracing::info!("Final result: {json_res}");
323 }
324 Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
325 if !is_reasoning {
326 is_reasoning = true;
327 println!();
328 println!("Thinking: ");
329 }
330 let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
331
332 print!("{reasoning}");
333 std::io::Write::flush(&mut std::io::stdout())?;
334 }
335 Err(e) => {
336 if e.to_string().contains("aborted") {
337 println!("\nStream cancelled.");
338 break;
339 }
340 eprintln!("Error: {e}");
341 break;
342 }
343 }
344 }
345
346 println!(); Ok(())
349}
350
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal provider response used as the `FinalResponse` payload in tests.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        // Only gives the mock a payload; it is never read.
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a stream of three text chunks (100ms apart) followed by one final
    /// response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // The same expression type-checks against both the wasm and non-wasm
        // `StreamingResult` aliases, so no per-target cfg is needed.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// Draining the whole stream forwards every text fragment, accumulates them in
    /// order, and yields exactly one final response.
    #[tokio::test]
    async fn test_stream_aggregates_text_and_single_final_response() {
        let mut stream = create_mock_stream();

        let mut collected = String::new();
        let mut finals = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => collected.push_str(&text.text),
                Ok(StreamedAssistantContent::Final(_)) => finals += 1,
                Ok(other) => panic!("unexpected chunk: {other:?}"),
                Err(e) => panic!("unexpected error: {e:?}"),
            }
        }

        assert_eq!(collected, "hello 1hello 2hello 3");
        assert_eq!(stream.text, "hello 1hello 2hello 3");
        assert_eq!(finals, 1);
        assert!(stream.response.is_some());
    }
}
439
/// A processed chunk yielded by [`StreamingCompletionResponse`].
///
/// Serialized without an enum tag (`#[serde(untagged)]`). When deserializing,
/// serde tries the variants in declaration order and picks the first that matches.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A fragment of assistant text.
    Text(Text),
    /// A tool call requested by the model.
    ToolCall(ToolCall),
    /// Model reasoning output.
    Reasoning(Reasoning),
    /// The provider's final response object.
    Final(R),
}
449
450impl<R> StreamedAssistantContent<R>
451where
452 R: Clone + Unpin,
453{
454 pub fn text(text: &str) -> Self {
455 Self::Text(Text {
456 text: text.to_string(),
457 })
458 }
459
460 pub fn tool_call(
462 id: impl Into<String>,
463 name: impl Into<String>,
464 arguments: serde_json::Value,
465 ) -> Self {
466 Self::ToolCall(ToolCall {
467 id: id.into(),
468 call_id: None,
469 function: ToolFunction {
470 name: name.into(),
471 arguments,
472 },
473 })
474 }
475
476 pub fn tool_call_with_call_id(
477 id: impl Into<String>,
478 call_id: String,
479 name: impl Into<String>,
480 arguments: serde_json::Value,
481 ) -> Self {
482 Self::ToolCall(ToolCall {
483 id: id.into(),
484 call_id: Some(call_id),
485 function: ToolFunction {
486 name: name.into(),
487 arguments,
488 },
489 })
490 }
491
492 pub fn final_response(res: R) -> Self {
493 Self::Final(res)
494 }
495}