// rig/streaming.rs

1//! This module provides functionality for working with streaming completion models.
2//! It provides traits and types for generating streaming completion requests and
3//! handling streaming completion responses.
4//!
5//! The main traits defined in this module are:
6//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
7//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
8//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
9//!
10
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::completion::{
15    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
16    Message, Usage,
17};
18use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction};
19use futures::stream::{AbortHandle, Abortable};
20use futures::{Stream, StreamExt};
21use serde::{Deserialize, Serialize};
22use std::boxed::Box;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::atomic::AtomicBool;
26use std::task::{Context, Poll};
27use tokio::sync::watch;
28
/// Control for pausing and resuming a streaming response.
///
/// Wraps a `tokio::sync::watch` channel holding the current paused flag:
/// `true` means the stream should stop yielding items until resumed.
pub struct PauseControl {
    // Sender half: `pause`/`resume` write the flag here.
    pub(crate) paused_tx: watch::Sender<bool>,
    // Receiver half: `is_paused` reads the latest flag value.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
34
35impl PauseControl {
36    pub fn new() -> Self {
37        let (paused_tx, paused_rx) = watch::channel(false);
38        Self {
39            paused_tx,
40            paused_rx,
41        }
42    }
43
44    pub fn pause(&self) {
45        self.paused_tx.send(true).unwrap();
46    }
47
48    pub fn resume(&self) {
49        self.paused_tx.send(false).unwrap();
50    }
51
52    pub fn is_paused(&self) -> bool {
53        *self.paused_rx.borrow()
54    }
55}
56
57impl Default for PauseControl {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A text chunk from a message response
    Message(String),

    /// A tool call response chunk
    ToolCall {
        // Provider-assigned identifier for this tool call.
        id: String,
        // Secondary id used by some providers; `None` when not supplied.
        call_id: Option<String>,
        // Name of the tool/function being invoked.
        name: String,
        // JSON arguments for the tool invocation.
        arguments: serde_json::Value,
    },
    /// A reasoning chunk
    Reasoning {
        // Optional id associating this chunk with a reasoning block.
        id: Option<String>,
        // The reasoning text fragment.
        reasoning: String,
    },

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),
}
89
/// Boxed, pinned stream of raw streaming chunks (or errors).
///
/// The non-wasm variant requires `Send` so the stream can be moved across
/// threads; the `wasm32` variant omits that bound.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed, pinned stream of raw streaming chunks (or errors), without `Send`.
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
97
/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    // The provider stream, wrapped so `cancel` can abort it mid-flight.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    // Handle used by `cancel` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    // Pause/resume flag consulted on every poll.
    pub(crate) pause_control: PauseControl,
    // Accumulated message text from all `Message` chunks seen so far.
    text: String,
    // Accumulated reasoning text from all `Reasoning` chunks seen so far.
    reasoning: String,
    // Every tool call observed, in arrival order.
    tool_calls: Vec<ToolCall>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
    /// Whether a `FinalResponse` chunk has already been forwarded;
    /// duplicates are swallowed by the `Stream` impl.
    pub final_response_yielded: AtomicBool,
}
119
120impl<R> StreamingCompletionResponse<R>
121where
122    R: Clone + Unpin + GetTokenUsage,
123{
124    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
125        let (abort_handle, abort_registration) = AbortHandle::new_pair();
126        let abortable_stream = Abortable::new(inner, abort_registration);
127        let pause_control = PauseControl::new();
128        Self {
129            inner: abortable_stream,
130            abort_handle,
131            pause_control,
132            reasoning: String::new(),
133            text: "".to_string(),
134            tool_calls: vec![],
135            choice: OneOrMany::one(AssistantContent::text("")),
136            response: None,
137            final_response_yielded: AtomicBool::new(false),
138        }
139    }
140
141    pub fn cancel(&self) {
142        self.abort_handle.abort();
143    }
144
145    pub fn pause(&self) {
146        self.pause_control.pause();
147    }
148
149    pub fn resume(&self) {
150        self.pause_control.resume();
151    }
152
153    pub fn is_paused(&self) -> bool {
154        self.pause_control.is_paused()
155    }
156}
157
158impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
159where
160    R: Clone + Unpin + GetTokenUsage,
161{
162    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
163        CompletionResponse {
164            choice: value.choice,
165            usage: Usage::new(), // Usage is not tracked in streaming responses
166            raw_response: value.response,
167        }
168    }
169}
170
impl<R> Stream for StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    /// Polls the inner stream, forwarding each chunk while aggregating text,
    /// reasoning and tool calls so `choice` can be built when the stream ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        // While paused, report Pending but immediately re-wake the task so the
        // pause flag is re-checked on the next poll.
        // NOTE(review): this is effectively a busy-wait; a watch-notification
        // based wakeup would avoid spinning the executor — confirm whether
        // that matters for the intended workloads.
        if stream.is_paused() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        match Pin::new(&mut stream.inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // This is run at the end of the inner stream to collect all tokens into
                // a single unified `Message`.
                let mut choice = vec![];

                stream.tool_calls.iter().for_each(|tc| {
                    choice.push(AssistantContent::ToolCall(tc.clone()));
                });

                // This is required to ensure there's always at least one item in the content:
                // an (possibly empty) text entry is prepended when there is text or no tool calls.
                if choice.is_empty() || !stream.text.is_empty() {
                    choice.insert(0, AssistantContent::text(stream.text.clone()));
                }

                // NOTE(review): the accumulated `reasoning` buffer is not
                // included in the final `choice` — confirm this is intentional.
                stream.choice = OneOrMany::many(choice)
                    .expect("There should be at least one assistant message");

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                // Abort-driven provider errors are folded into a clean end of stream.
                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
                {
                    return Poll::Ready(None); // Treat cancellation as stream termination
                }
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    // Forward the streaming tokens to the outer stream
                    // and concat the text together
                    stream.text = format!("{}{}", stream.text, text.clone());
                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
                }
                RawStreamingChoice::Reasoning { id, reasoning } => {
                    // Forward the streaming tokens to the outer stream
                    // and concat the text together. Note: the forwarded chunk
                    // carries the *cumulative* reasoning so far, not just this
                    // fragment.
                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning.clone());
                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
                        id,
                        reasoning: vec![stream.reasoning.clone()],
                    }))))
                }
                RawStreamingChoice::ToolCall {
                    id,
                    name,
                    arguments,
                    call_id,
                } => {
                    // Keep track of each tool call to aggregate the final message later
                    // and pass it to the outer stream
                    stream.tool_calls.push(ToolCall {
                        id: id.clone(),
                        call_id: call_id.clone(),
                        function: ToolFunction {
                            name: name.clone(),
                            arguments: arguments.clone(),
                        },
                    });
                    if let Some(call_id) = call_id {
                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call_with_call_id(
                            id, call_id, name, arguments,
                        ))))
                    } else {
                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call(
                            id, name, arguments,
                        ))))
                    }
                }
                RawStreamingChoice::FinalResponse(response) => {
                    // A final response is only forwarded once; duplicates are
                    // skipped by polling for the next chunk instead.
                    if stream
                        .final_response_yielded
                        .load(std::sync::atomic::Ordering::SeqCst)
                    {
                        stream.poll_next_unpin(cx)
                    } else {
                        // Set the final response field and return the next item in the stream
                        stream.response = Some(response.clone());
                        stream
                            .final_response_yielded
                            .store(true, std::sync::atomic::Ordering::SeqCst);
                        let final_response = StreamedAssistantContent::final_response(response);
                        Poll::Ready(Some(Ok(final_response)))
                    }
                }
            },
        }
    }
}
275
/// Trait for high-level streaming prompt interface
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: Send,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a simple prompt to the model.
    ///
    /// Returns a [`StreamingPromptRequest`] for the given prompt.
    // NOTE(review): the `R` type parameter does not appear in this method's
    // signature — confirm whether implementors actually need it.
    fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<M, ()>;
}
286
/// Trait for high-level streaming chat interface
pub trait StreamingChat<M, R>: Send + Sync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: Send,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a chat with history to the model.
    ///
    /// `chat_history` is the prior conversation; `prompt` is the new message.
    /// Returns a [`StreamingPromptRequest`] for the combined conversation.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}
301
/// Trait for low-level streaming completion interface
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request.
    ///
    /// Resolves to a [`CompletionRequestBuilder`] seeded with the prompt and
    /// chat history, or a [`CompletionError`] if the request can't be built.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
311
/// Adapter that erases the provider-specific final-response type `R`,
/// re-emitting every chunk with the `FinalResponse` payload replaced by `()`.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
    // The underlying typed provider stream being adapted.
    pub(crate) inner: StreamingResult<R>,
}
315
316impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
317    type Item = Result<RawStreamingChoice<()>, CompletionError>;
318
319    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
320        let stream = self.get_mut();
321
322        match stream.inner.as_mut().poll_next(cx) {
323            Poll::Pending => Poll::Pending,
324            Poll::Ready(None) => Poll::Ready(None),
325            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
326            Poll::Ready(Some(Ok(chunk))) => match chunk {
327                RawStreamingChoice::FinalResponse(_) => {
328                    Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
329                }
330                RawStreamingChoice::Message(m) => {
331                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
332                }
333                RawStreamingChoice::Reasoning { id, reasoning } => {
334                    Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning { id, reasoning })))
335                }
336                RawStreamingChoice::ToolCall {
337                    id,
338                    name,
339                    arguments,
340                    call_id,
341                } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
342                    id,
343                    name,
344                    arguments,
345                    call_id,
346                }))),
347            },
348        }
349    }
350}
351
352/// helper function to stream a completion request to stdout
353pub async fn stream_to_stdout<M>(
354    agent: &Agent<M>,
355    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
356) -> Result<(), std::io::Error>
357where
358    M: CompletionModel,
359{
360    let mut is_reasoning = false;
361    print!("Response: ");
362    while let Some(chunk) = stream.next().await {
363        match chunk {
364            Ok(StreamedAssistantContent::Text(text)) => {
365                if is_reasoning {
366                    is_reasoning = false;
367                    println!("\n---\n");
368                }
369                print!("{}", text.text);
370                std::io::Write::flush(&mut std::io::stdout())?;
371            }
372            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
373                let res = agent
374                    .tools
375                    .call(
376                        &tool_call.function.name,
377                        tool_call.function.arguments.to_string(),
378                    )
379                    .await
380                    .map_err(|e| std::io::Error::other(e.to_string()))?;
381                println!("\nResult: {res}");
382            }
383            Ok(StreamedAssistantContent::Final(res)) => {
384                let json_res = serde_json::to_string_pretty(&res).unwrap();
385                println!();
386                tracing::info!("Final result: {json_res}");
387            }
388            Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
389                if !is_reasoning {
390                    is_reasoning = true;
391                    println!();
392                    println!("Thinking: ");
393                }
394                let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
395
396                print!("{reasoning}");
397                std::io::Write::flush(&mut std::io::stdout())?;
398            }
399            Err(e) => {
400                if e.to_string().contains("aborted") {
401                    println!("\nStream cancelled.");
402                    break;
403                }
404                eprintln!("Error: {e}");
405                break;
406            }
407        }
408    }
409
410    println!(); // New line after streaming completes
411
412    Ok(())
413}
414
// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal stand-in for a provider's final response type.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a streaming response that yields three text chunks (with short
    /// delays between them) followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // Fix: the previous `#[cfg(target_arch = "wasm32")]` split had two
        // byte-identical branches (`Box::pin(stream)`); the per-target
        // difference lives in the `StreamingResult` alias, not here.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// Cancelling mid-stream must terminate it: no further chunks are yielded.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel after two content chunks have been observed.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// Pause/resume toggles the pause flag as reported by `is_paused`.
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        // Test pause
        stream.pause();
        assert!(stream.is_paused());

        // Test resume
        stream.resume();
        assert!(!stream.is_paused());
    }
}
516
/// Describes responses from a streamed provider response which is either text, a tool call or a final usage response.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A chunk of assistant message text.
    Text(Text),
    /// A complete tool call emitted by the model.
    ToolCall(ToolCall),
    /// A reasoning ("thinking") chunk.
    Reasoning(Reasoning),
    /// The provider-specific final response payload.
    Final(R),
}
526
527impl<R> StreamedAssistantContent<R>
528where
529    R: Clone + Unpin,
530{
531    pub fn text(text: &str) -> Self {
532        Self::Text(Text {
533            text: text.to_string(),
534        })
535    }
536
537    /// Helper constructor to make creating assistant tool call content easier.
538    pub fn tool_call(
539        id: impl Into<String>,
540        name: impl Into<String>,
541        arguments: serde_json::Value,
542    ) -> Self {
543        Self::ToolCall(ToolCall {
544            id: id.into(),
545            call_id: None,
546            function: ToolFunction {
547                name: name.into(),
548                arguments,
549            },
550        })
551    }
552
553    pub fn tool_call_with_call_id(
554        id: impl Into<String>,
555        call_id: String,
556        name: impl Into<String>,
557        arguments: serde_json::Value,
558    ) -> Self {
559        Self::ToolCall(ToolCall {
560            id: id.into(),
561            call_id: Some(call_id),
562            function: ToolFunction {
563                name: name.into(),
564                arguments,
565            },
566        })
567    }
568
569    pub fn final_response(res: R) -> Self {
570        Self::Final(res)
571    }
572}