1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::completion::{
15 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
16 Message, Usage,
17};
18use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction};
19use futures::stream::{AbortHandle, Abortable};
20use futures::{Stream, StreamExt};
21use serde::{Deserialize, Serialize};
22use std::boxed::Box;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::atomic::AtomicBool;
26use std::task::{Context, Poll};
27use tokio::sync::watch;
28
/// Shared pause/resume flag for a streaming response.
///
/// Backed by a tokio `watch` channel: the sender flips the flag and the
/// receiver lets pollers observe the most recent value synchronously.
pub struct PauseControl {
    // Writer side: `pause()`/`resume()` publish the new state here.
    pub(crate) paused_tx: watch::Sender<bool>,
    // Reader side: `is_paused()` borrows the latest published value.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
34
35impl PauseControl {
36 pub fn new() -> Self {
37 let (paused_tx, paused_rx) = watch::channel(false);
38 Self {
39 paused_tx,
40 paused_rx,
41 }
42 }
43
44 pub fn pause(&self) {
45 self.paused_tx.send(true).unwrap();
46 }
47
48 pub fn resume(&self) {
49 self.paused_tx.send(false).unwrap();
50 }
51
52 pub fn is_paused(&self) -> bool {
53 *self.paused_rx.borrow()
54 }
55}
56
57impl Default for PauseControl {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
/// A single raw chunk emitted by a provider-specific streaming
/// implementation, before normalisation into [`StreamedAssistantContent`].
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A plain text delta.
    Message(String),

    /// A tool invocation requested by the model.
    ToolCall {
        id: String,
        // Provider-specific secondary identifier, when the provider has one.
        call_id: Option<String>,
        name: String,
        arguments: serde_json::Value,
    },
    /// A reasoning ("thinking") delta.
    Reasoning {
        id: Option<String>,
        reasoning: String,
    },

    /// The provider's final aggregated response payload, yielded at the end
    /// of the stream.
    FinalResponse(R),
}
89
/// Boxed, pinned stream of raw streaming chunks. Off-wasm the stream must be
/// `Send` so it can be driven from any thread.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed, pinned stream of raw streaming chunks (wasm is single-threaded, so
/// no `Send` bound).
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
97
/// Wrapper around a [`StreamingResult`] that accumulates text, reasoning and
/// tool calls as chunks arrive, and supports cancellation (via an abort
/// handle) and pause/resume (via [`PauseControl`]).
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    // The raw provider stream, wrapped so `cancel()` can abort it.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    pub(crate) abort_handle: AbortHandle,
    pub(crate) pause_control: PauseControl,
    // Running concatenation of all text deltas seen so far.
    text: String,
    // Running concatenation of all reasoning deltas seen so far.
    reasoning: String,
    // Every tool call collected from the stream.
    tool_calls: Vec<ToolCall>,
    /// Aggregated assistant content; rebuilt when the stream finishes.
    pub choice: OneOrMany<AssistantContent>,
    /// The provider's final response payload, if one was received.
    pub response: Option<R>,
    /// Guards against yielding the `Final` item to consumers more than once.
    pub final_response_yielded: AtomicBool,
}
119
120impl<R> StreamingCompletionResponse<R>
121where
122 R: Clone + Unpin + GetTokenUsage,
123{
124 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
125 let (abort_handle, abort_registration) = AbortHandle::new_pair();
126 let abortable_stream = Abortable::new(inner, abort_registration);
127 let pause_control = PauseControl::new();
128 Self {
129 inner: abortable_stream,
130 abort_handle,
131 pause_control,
132 reasoning: String::new(),
133 text: "".to_string(),
134 tool_calls: vec![],
135 choice: OneOrMany::one(AssistantContent::text("")),
136 response: None,
137 final_response_yielded: AtomicBool::new(false),
138 }
139 }
140
141 pub fn cancel(&self) {
142 self.abort_handle.abort();
143 }
144
145 pub fn pause(&self) {
146 self.pause_control.pause();
147 }
148
149 pub fn resume(&self) {
150 self.pause_control.resume();
151 }
152
153 pub fn is_paused(&self) -> bool {
154 self.pause_control.is_paused()
155 }
156}
157
158impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
159where
160 R: Clone + Unpin + GetTokenUsage,
161{
162 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
163 CompletionResponse {
164 choice: value.choice,
165 usage: Usage::new(), raw_response: value.response,
167 }
168 }
169}
170
impl<R> Stream for StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    // Drives the inner provider stream, accumulating text / reasoning / tool
    // calls into `self` while forwarding normalised chunks to the consumer.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        // While paused, return Pending but immediately re-wake ourselves.
        // NOTE(review): this is effectively a busy-wait loop while paused —
        // consider waiting on the watch channel instead; confirm intent.
        if stream.is_paused() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        match Pin::new(&mut stream.inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // Stream exhausted: rebuild `choice` from everything we
                // accumulated while polling.
                let mut choice = vec![];

                stream.tool_calls.iter().for_each(|tc| {
                    choice.push(AssistantContent::ToolCall(tc.clone()));
                });

                // Prepend the text content when there is any, or when there
                // were no tool calls at all (so `choice` is never empty).
                if choice.is_empty() || !stream.text.is_empty() {
                    choice.insert(0, AssistantContent::text(stream.text.clone()));
                }

                stream.choice = OneOrMany::many(choice)
                    .expect("There should be at least one assistant message");

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                // Treat provider "aborted" errors as a clean end-of-stream
                // (this is how cancellation surfaces from `Abortable`).
                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
                {
                    return Poll::Ready(None); }
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    // Accumulate the full text, but forward only the delta.
                    stream.text = format!("{}{}", stream.text, text.clone());
                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
                }
                RawStreamingChoice::Reasoning { id, reasoning } => {
                    // Accumulate reasoning; note the forwarded chunk carries
                    // the full accumulated reasoning so far, not the delta.
                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning.clone());
                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
                        id,
                        reasoning: vec![stream.reasoning.clone()],
                    }))))
                }
                RawStreamingChoice::ToolCall {
                    id,
                    name,
                    arguments,
                    call_id,
                } => {
                    // Record the call for final aggregation, then forward it.
                    stream.tool_calls.push(ToolCall {
                        id: id.clone(),
                        call_id: call_id.clone(),
                        function: ToolFunction {
                            name: name.clone(),
                            arguments: arguments.clone(),
                        },
                    });
                    if let Some(call_id) = call_id {
                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call_with_call_id(
                            id, call_id, name, arguments,
                        ))))
                    } else {
                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call(
                            id, name, arguments,
                        ))))
                    }
                }
                RawStreamingChoice::FinalResponse(response) => {
                    // Only the first FinalResponse is surfaced; duplicates
                    // are skipped by polling again for the next chunk.
                    if stream
                        .final_response_yielded
                        .load(std::sync::atomic::Ordering::SeqCst)
                    {
                        stream.poll_next_unpin(cx)
                    } else {
                        stream.response = Some(response.clone());
                        stream
                            .final_response_yielded
                            .store(true, std::sync::atomic::Ordering::SeqCst);
                        let final_response = StreamedAssistantContent::final_response(response);
                        Poll::Ready(Some(Ok(final_response)))
                    }
                }
            },
        }
    }
}
275
/// High-level trait for starting a streaming prompt against a model with no
/// prior chat history.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: Send,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Builds a streaming request for `prompt`.
    fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<M, ()>;
}
286
/// High-level trait for starting a streaming chat turn that includes prior
/// conversation history.
pub trait StreamingChat<M, R>: Send + Sync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: Send,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Builds a streaming request for `prompt` preceded by `chat_history`.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}
301
/// Low-level trait for assembling a streaming completion request.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Builds (but does not execute) a completion request from `prompt` and
    /// `chat_history`; the caller decides when to stream it.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
311
/// Adapter that erases the final-response type of a [`StreamingResult`]:
/// its `Stream` impl maps `FinalResponse(R)` chunks to `FinalResponse(())`
/// while passing every other chunk through unchanged.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
    pub(crate) inner: StreamingResult<R>,
}
315
316impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
317 type Item = Result<RawStreamingChoice<()>, CompletionError>;
318
319 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
320 let stream = self.get_mut();
321
322 match stream.inner.as_mut().poll_next(cx) {
323 Poll::Pending => Poll::Pending,
324 Poll::Ready(None) => Poll::Ready(None),
325 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
326 Poll::Ready(Some(Ok(chunk))) => match chunk {
327 RawStreamingChoice::FinalResponse(_) => {
328 Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
329 }
330 RawStreamingChoice::Message(m) => {
331 Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
332 }
333 RawStreamingChoice::Reasoning { id, reasoning } => {
334 Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning { id, reasoning })))
335 }
336 RawStreamingChoice::ToolCall {
337 id,
338 name,
339 arguments,
340 call_id,
341 } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
342 id,
343 name,
344 arguments,
345 call_id,
346 }))),
347 },
348 }
349 }
350}
351
352pub async fn stream_to_stdout<M>(
354 agent: &Agent<M>,
355 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
356) -> Result<(), std::io::Error>
357where
358 M: CompletionModel,
359{
360 let mut is_reasoning = false;
361 print!("Response: ");
362 while let Some(chunk) = stream.next().await {
363 match chunk {
364 Ok(StreamedAssistantContent::Text(text)) => {
365 if is_reasoning {
366 is_reasoning = false;
367 println!("\n---\n");
368 }
369 print!("{}", text.text);
370 std::io::Write::flush(&mut std::io::stdout())?;
371 }
372 Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
373 let res = agent
374 .tools
375 .call(
376 &tool_call.function.name,
377 tool_call.function.arguments.to_string(),
378 )
379 .await
380 .map_err(|e| std::io::Error::other(e.to_string()))?;
381 println!("\nResult: {res}");
382 }
383 Ok(StreamedAssistantContent::Final(res)) => {
384 let json_res = serde_json::to_string_pretty(&res).unwrap();
385 println!();
386 tracing::info!("Final result: {json_res}");
387 }
388 Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
389 if !is_reasoning {
390 is_reasoning = true;
391 println!();
392 println!("Thinking: ");
393 }
394 let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
395
396 print!("{reasoning}");
397 std::io::Write::flush(&mut std::io::stdout())?;
398 }
399 Err(e) => {
400 if e.to_string().contains("aborted") {
401 println!("\nStream cancelled.");
402 break;
403 }
404 eprintln!("Error: {e}");
405 break;
406 }
407 }
408 }
409
410 println!(); Ok(())
413}
414
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal final-response payload used to exercise the streaming wrapper.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a wrapped stream that yields three text chunks (with small
    /// delays between them) followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // `Box::pin(stream)` is identical on every target, so the previous
        // duplicated `#[cfg]` wasm/non-wasm branches were redundant.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// After `cancel()`, polling the stream must yield no further chunks.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel after two content chunks to exercise mid-stream abort.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// `pause()`/`resume()` must be reflected by `is_paused()`.
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        stream.pause();
        assert!(stream.is_paused());

        stream.resume();
        assert!(!stream.is_paused());
    }
}
516
/// A single normalised item yielded to consumers of a
/// [`StreamingCompletionResponse`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A text delta.
    Text(Text),
    /// A tool call requested by the model.
    ToolCall(ToolCall),
    /// Reasoning ("thinking") content.
    Reasoning(Reasoning),
    /// The provider's final response payload.
    Final(R),
}
526
527impl<R> StreamedAssistantContent<R>
528where
529 R: Clone + Unpin,
530{
531 pub fn text(text: &str) -> Self {
532 Self::Text(Text {
533 text: text.to_string(),
534 })
535 }
536
537 pub fn tool_call(
539 id: impl Into<String>,
540 name: impl Into<String>,
541 arguments: serde_json::Value,
542 ) -> Self {
543 Self::ToolCall(ToolCall {
544 id: id.into(),
545 call_id: None,
546 function: ToolFunction {
547 name: name.into(),
548 arguments,
549 },
550 })
551 }
552
553 pub fn tool_call_with_call_id(
554 id: impl Into<String>,
555 call_id: String,
556 name: impl Into<String>,
557 arguments: serde_json::Value,
558 ) -> Self {
559 Self::ToolCall(ToolCall {
560 id: id.into(),
561 call_id: Some(call_id),
562 function: ToolFunction {
563 name: name.into(),
564 arguments,
565 },
566 })
567 }
568
569 pub fn final_response(res: R) -> Self {
570 Self::Final(res)
571 }
572}