rig/
streaming.rs

1//! This module provides functionality for working with streaming completion models.
2//! It provides traits and types for generating streaming completion requests and
3//! handling streaming completion responses.
4//!
5//! The main traits defined in this module are:
6//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
7//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
8//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
9//!
10
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::client::builder::FinalCompletionResponse;
15use crate::completion::{
16    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17    Message, Usage,
18};
19use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction};
20use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
21use futures::stream::{AbortHandle, Abortable};
22use futures::{Stream, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::atomic::AtomicBool;
27use std::task::{Context, Poll};
28use tokio::sync::watch;
29
/// Control for pausing and resuming a streaming response
pub struct PauseControl {
    // Sender half: `pause()`/`resume()` publish the new paused flag here.
    pub(crate) paused_tx: watch::Sender<bool>,
    // Receiver half, kept alive inside the struct so `send` cannot fail
    // (the channel always has at least one receiver) and `is_paused` can
    // read the latest value.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
35
36impl PauseControl {
37    pub fn new() -> Self {
38        let (paused_tx, paused_rx) = watch::channel(false);
39        Self {
40            paused_tx,
41            paused_rx,
42        }
43    }
44
45    pub fn pause(&self) {
46        self.paused_tx.send(true).unwrap();
47    }
48
49    pub fn resume(&self) {
50        self.paused_tx.send(false).unwrap();
51    }
52
53    pub fn is_paused(&self) -> bool {
54        *self.paused_rx.borrow()
55    }
56}
57
58impl Default for PauseControl {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A text chunk from a message response
    Message(String),

    /// A tool call response chunk
    ToolCall {
        // Provider-assigned identifier for this tool call.
        id: String,
        // Secondary identifier used by some providers; `None` when absent.
        call_id: Option<String>,
        // Name of the tool being invoked.
        name: String,
        // JSON arguments the tool should be invoked with.
        arguments: serde_json::Value,
    },
    /// A reasoning chunk
    Reasoning {
        // Optional identifier tying this chunk to a reasoning block.
        id: Option<String>,
        // The reasoning text delta carried by this chunk.
        reasoning: String,
    },

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),
}
90
/// Boxed dynamic stream of raw streaming chunks. On native targets the
/// stream must be `Send` so it can be moved across threads.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed dynamic stream of raw streaming chunks. On wasm32 there is no
/// `Send` bound (single-threaded environment).
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
98
/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    // Provider stream, wrapped so `cancel()` can abort it mid-flight.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    // Handle used by `cancel()` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    // Pause/resume switch consulted on every poll.
    pub(crate) pause_control: PauseControl,
    // Accumulated message text from all `Message` chunks seen so far.
    text: String,
    // Accumulated reasoning text from all `Reasoning` chunks seen so far.
    reasoning: String,
    // Every tool call observed during the stream, in arrival order.
    tool_calls: Vec<ToolCall>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
    /// Set once the provider's final response has been forwarded, so
    /// duplicate `FinalResponse` chunks are skipped.
    pub final_response_yielded: AtomicBool,
}
120
121impl<R> StreamingCompletionResponse<R>
122where
123    R: Clone + Unpin + GetTokenUsage,
124{
125    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
126        let (abort_handle, abort_registration) = AbortHandle::new_pair();
127        let abortable_stream = Abortable::new(inner, abort_registration);
128        let pause_control = PauseControl::new();
129        Self {
130            inner: abortable_stream,
131            abort_handle,
132            pause_control,
133            reasoning: String::new(),
134            text: "".to_string(),
135            tool_calls: vec![],
136            choice: OneOrMany::one(AssistantContent::text("")),
137            response: None,
138            final_response_yielded: AtomicBool::new(false),
139        }
140    }
141
142    pub fn cancel(&self) {
143        self.abort_handle.abort();
144    }
145
146    pub fn pause(&self) {
147        self.pause_control.pause();
148    }
149
150    pub fn resume(&self) {
151        self.pause_control.resume();
152    }
153
154    pub fn is_paused(&self) -> bool {
155        self.pause_control.is_paused()
156    }
157}
158
159impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
160where
161    R: Clone + Unpin + GetTokenUsage,
162{
163    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
164        CompletionResponse {
165            choice: value.choice,
166            usage: Usage::new(), // Usage is not tracked in streaming responses
167            raw_response: value.response,
168        }
169    }
170}
171
172impl<R> Stream for StreamingCompletionResponse<R>
173where
174    R: Clone + Unpin + GetTokenUsage,
175{
176    type Item = Result<StreamedAssistantContent<R>, CompletionError>;
177
178    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
179        let stream = self.get_mut();
180
181        if stream.is_paused() {
182            cx.waker().wake_by_ref();
183            return Poll::Pending;
184        }
185
186        match Pin::new(&mut stream.inner).poll_next(cx) {
187            Poll::Pending => Poll::Pending,
188            Poll::Ready(None) => {
189                // This is run at the end of the inner stream to collect all tokens into
190                // a single unified `Message`.
191                let mut choice = vec![];
192
193                stream.tool_calls.iter().for_each(|tc| {
194                    choice.push(AssistantContent::ToolCall(tc.clone()));
195                });
196
197                // This is required to ensure there's always at least one item in the content
198                if choice.is_empty() || !stream.text.is_empty() {
199                    choice.insert(0, AssistantContent::text(stream.text.clone()));
200                }
201
202                stream.choice = OneOrMany::many(choice)
203                    .expect("There should be at least one assistant message");
204
205                Poll::Ready(None)
206            }
207            Poll::Ready(Some(Err(err))) => {
208                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
209                {
210                    return Poll::Ready(None); // Treat cancellation as stream termination
211                }
212                Poll::Ready(Some(Err(err)))
213            }
214            Poll::Ready(Some(Ok(choice))) => match choice {
215                RawStreamingChoice::Message(text) => {
216                    // Forward the streaming tokens to the outer stream
217                    // and concat the text together
218                    stream.text = format!("{}{}", stream.text, text.clone());
219                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
220                }
221                RawStreamingChoice::Reasoning { id, reasoning } => {
222                    // Forward the streaming tokens to the outer stream
223                    // and concat the text together
224                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning.clone());
225                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
226                        id,
227                        reasoning: vec![stream.reasoning.clone()],
228                    }))))
229                }
230                RawStreamingChoice::ToolCall {
231                    id,
232                    name,
233                    arguments,
234                    call_id,
235                } => {
236                    // Keep track of each tool call to aggregate the final message later
237                    // and pass it to the outer stream
238                    stream.tool_calls.push(ToolCall {
239                        id: id.clone(),
240                        call_id: call_id.clone(),
241                        function: ToolFunction {
242                            name: name.clone(),
243                            arguments: arguments.clone(),
244                        },
245                    });
246                    if let Some(call_id) = call_id {
247                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call_with_call_id(
248                            id, call_id, name, arguments,
249                        ))))
250                    } else {
251                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call(
252                            id, name, arguments,
253                        ))))
254                    }
255                }
256                RawStreamingChoice::FinalResponse(response) => {
257                    if stream
258                        .final_response_yielded
259                        .load(std::sync::atomic::Ordering::SeqCst)
260                    {
261                        stream.poll_next_unpin(cx)
262                    } else {
263                        // Set the final response field and return the next item in the stream
264                        stream.response = Some(response.clone());
265                        stream
266                            .final_response_yielded
267                            .store(true, std::sync::atomic::Ordering::SeqCst);
268                        let final_response = StreamedAssistantContent::final_response(response);
269                        Poll::Ready(Some(Ok(final_response)))
270                    }
271                }
272            },
273        }
274    }
275}
276
/// Trait for high-level streaming prompt interface
///
/// Implementors turn a single prompt into a [`StreamingPromptRequest`]
/// that can be awaited/streamed by the caller.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a simple prompt to the model
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()>;
}
290
/// Trait for high-level streaming chat interface
///
/// Like [`StreamingPrompt`], but the request also carries prior
/// conversation history.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a chat with history to the model
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}
305
/// Trait for low-level streaming completion interface
///
/// Produces a [`CompletionRequestBuilder`] so the caller can further
/// customize the request before streaming it.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
315
/// Adapter that erases the concrete final-response type of a raw stream,
/// replacing it with [`FinalCompletionResponse`] (token usage only) in its
/// [`Stream`] impl below.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    // The wrapped provider stream whose chunks are forwarded/mapped.
    pub(crate) inner: StreamingResult<R>,
}
319
320impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
321    type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
322
323    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
324        let stream = self.get_mut();
325
326        match stream.inner.as_mut().poll_next(cx) {
327            Poll::Pending => Poll::Pending,
328            Poll::Ready(None) => Poll::Ready(None),
329            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
330            Poll::Ready(Some(Ok(chunk))) => match chunk {
331                RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
332                    RawStreamingChoice::FinalResponse(FinalCompletionResponse {
333                        usage: res.token_usage(),
334                    }),
335                ))),
336                RawStreamingChoice::Message(m) => {
337                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
338                }
339                RawStreamingChoice::Reasoning { id, reasoning } => {
340                    Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning { id, reasoning })))
341                }
342                RawStreamingChoice::ToolCall {
343                    id,
344                    name,
345                    arguments,
346                    call_id,
347                } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
348                    id,
349                    name,
350                    arguments,
351                    call_id,
352                }))),
353            },
354        }
355    }
356}
357
358/// helper function to stream a completion request to stdout
359pub async fn stream_to_stdout<M>(
360    agent: &'static Agent<M>,
361    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
362) -> Result<(), std::io::Error>
363where
364    M: CompletionModel,
365{
366    let mut is_reasoning = false;
367    print!("Response: ");
368    while let Some(chunk) = stream.next().await {
369        match chunk {
370            Ok(StreamedAssistantContent::Text(text)) => {
371                if is_reasoning {
372                    is_reasoning = false;
373                    println!("\n---\n");
374                }
375                print!("{}", text.text);
376                std::io::Write::flush(&mut std::io::stdout())?;
377            }
378            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
379                let res = agent
380                    .tool_server_handle
381                    .call_tool(
382                        &tool_call.function.name,
383                        &tool_call.function.arguments.to_string(),
384                    )
385                    .await
386                    .map_err(|x| std::io::Error::other(x.to_string()))?;
387                println!("\nResult: {res}");
388            }
389            Ok(StreamedAssistantContent::Final(res)) => {
390                let json_res = serde_json::to_string_pretty(&res).unwrap();
391                println!();
392                tracing::info!("Final result: {json_res}");
393            }
394            Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
395                if !is_reasoning {
396                    is_reasoning = true;
397                    println!();
398                    println!("Thinking: ");
399                }
400                let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
401
402                print!("{reasoning}");
403                std::io::Write::flush(&mut std::io::stdout())?;
404            }
405            Err(e) => {
406                if e.to_string().contains("aborted") {
407                    println!("\nStream cancelled.");
408                    break;
409                }
410                eprintln!("Error: {e}");
411                break;
412            }
413        }
414    }
415
416    println!(); // New line after streaming completes
417
418    Ok(())
419}
420
// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal final-response type for exercising the streaming wrapper.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a streaming response that yields three text chunks (100 ms
    /// apart) followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // The previous cfg(target_arch) split had two byte-identical arms;
        // a single binding covers every target.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// After `cancel()`, the stream must yield no further chunks.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel mid-stream after the second observed chunk.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// `pause()`/`resume()` must be reflected by `is_paused()`.
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        // Test pause
        stream.pause();
        assert!(stream.is_paused());

        // Test resume
        stream.resume();
        assert!(!stream.is_paused());
    }
}
522
/// Describes responses from a streamed provider response which is either text, a tool call or a final usage response.
///
/// NOTE: serialized with `#[serde(untagged)]`, so deserialization tries the
/// variants in declaration order — do not reorder them.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A chunk of plain assistant text.
    Text(Text),
    /// A tool invocation requested by the model.
    ToolCall(ToolCall),
    /// A chunk of model reasoning output.
    Reasoning(Reasoning),
    /// The provider's final response object.
    Final(R),
}
532
533impl<R> StreamedAssistantContent<R>
534where
535    R: Clone + Unpin,
536{
537    pub fn text(text: &str) -> Self {
538        Self::Text(Text {
539            text: text.to_string(),
540        })
541    }
542
543    /// Helper constructor to make creating assistant tool call content easier.
544    pub fn tool_call(
545        id: impl Into<String>,
546        name: impl Into<String>,
547        arguments: serde_json::Value,
548    ) -> Self {
549        Self::ToolCall(ToolCall {
550            id: id.into(),
551            call_id: None,
552            function: ToolFunction {
553                name: name.into(),
554                arguments,
555            },
556        })
557    }
558
559    pub fn tool_call_with_call_id(
560        id: impl Into<String>,
561        call_id: String,
562        name: impl Into<String>,
563        arguments: serde_json::Value,
564    ) -> Self {
565        Self::ToolCall(ToolCall {
566            id: id.into(),
567            call_id: Some(call_id),
568            function: ToolFunction {
569                name: name.into(),
570                arguments,
571            },
572        })
573    }
574
575    pub fn final_response(res: R) -> Self {
576        Self::Final(res)
577    }
578}