// rig/streaming.rs

//! This module provides functionality for working with streaming completion models.
//! It provides traits and types for generating streaming completion requests and
//! handling streaming completion responses.
//!
//! The main traits defined in this module are:
//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
//!
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::client::builder::FinalCompletionResponse;
15use crate::completion::{
16    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17    Message, Usage,
18};
19use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction, ToolResult};
20use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
21use futures::stream::{AbortHandle, Abortable};
22use futures::{Stream, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::atomic::AtomicBool;
27use std::task::{Context, Poll};
28use tokio::sync::watch;
29
/// Control for pausing and resuming a streaming response
pub struct PauseControl {
    // Broadcasts pause-state changes; written by `pause`/`resume`.
    pub(crate) paused_tx: watch::Sender<bool>,
    // Read by `is_paused`. Kept alive for the lifetime of the control, which
    // guarantees `paused_tx.send` always has at least one subscriber.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
35
36impl PauseControl {
37    pub fn new() -> Self {
38        let (paused_tx, paused_rx) = watch::channel(false);
39        Self {
40            paused_tx,
41            paused_rx,
42        }
43    }
44
45    pub fn pause(&self) {
46        self.paused_tx.send(true).unwrap();
47    }
48
49    pub fn resume(&self) {
50        self.paused_tx.send(false).unwrap();
51    }
52
53    pub fn is_paused(&self) -> bool {
54        *self.paused_rx.borrow()
55    }
56}
57
58impl Default for PauseControl {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A text chunk from a message response
    Message(String),

    /// A tool call response (in its entirety)
    ToolCall {
        /// Provider-assigned identifier for this tool call.
        id: String,
        /// Secondary call id some providers emit alongside `id`, if any.
        call_id: Option<String>,
        /// Name of the tool being invoked.
        name: String,
        /// JSON arguments for the tool invocation.
        arguments: serde_json::Value,
    },
    /// A tool call partial/delta
    ToolCallDelta { id: String, delta: String },
    /// A reasoning chunk
    Reasoning {
        /// Provider-assigned id for the reasoning block, if any.
        id: Option<String>,
        /// The reasoning text fragment.
        reasoning: String,
        /// Provider signature for the reasoning block, if any.
        signature: Option<String>,
    },

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),
}
93
/// A pinned, boxed stream of raw streaming chunks. On native targets the
/// stream must be `Send` so it can be moved across threads.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// A pinned, boxed stream of raw streaming chunks. On wasm32 there is no
/// `Send` bound, as execution is single-threaded.
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
101
/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    // The provider stream, wrapped so it can be aborted via `abort_handle`.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    // Handle used by `cancel` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    // Pause/resume switch consulted on every poll of the outer stream.
    pub(crate) pause_control: PauseControl,
    // Message text accumulated from `Message` chunks.
    text: String,
    // Reasoning text accumulated from `Reasoning` chunks.
    reasoning: String,
    // Tool calls collected so far; merged into `choice` when the stream ends.
    tool_calls: Vec<ToolCall>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
    /// Set once a `FinalResponse` chunk has been forwarded, so that any
    /// duplicate final responses from the provider are skipped.
    pub final_response_yielded: AtomicBool,
}
123
124impl<R> StreamingCompletionResponse<R>
125where
126    R: Clone + Unpin + GetTokenUsage,
127{
128    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
129        let (abort_handle, abort_registration) = AbortHandle::new_pair();
130        let abortable_stream = Abortable::new(inner, abort_registration);
131        let pause_control = PauseControl::new();
132        Self {
133            inner: abortable_stream,
134            abort_handle,
135            pause_control,
136            reasoning: String::new(),
137            text: "".to_string(),
138            tool_calls: vec![],
139            choice: OneOrMany::one(AssistantContent::text("")),
140            response: None,
141            final_response_yielded: AtomicBool::new(false),
142        }
143    }
144
145    pub fn cancel(&self) {
146        self.abort_handle.abort();
147    }
148
149    pub fn pause(&self) {
150        self.pause_control.pause();
151    }
152
153    pub fn resume(&self) {
154        self.pause_control.resume();
155    }
156
157    pub fn is_paused(&self) -> bool {
158        self.pause_control.is_paused()
159    }
160}
161
162impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
163where
164    R: Clone + Unpin + GetTokenUsage,
165{
166    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
167        CompletionResponse {
168            choice: value.choice,
169            usage: Usage::new(), // Usage is not tracked in streaming responses
170            raw_response: value.response,
171        }
172    }
173}
174
impl<R> Stream for StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    // Polls the inner provider stream, forwarding chunks to the caller while
    // accumulating text, reasoning, and tool calls. When the inner stream
    // ends, the accumulated state is collapsed into `self.choice`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        // While paused, yield `Pending` without consuming the inner stream.
        // NOTE(review): the immediate self-wake makes this a busy-wait while
        // paused — a wait on the watch channel's change notification would
        // avoid spinning; confirm before changing.
        if stream.is_paused() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        match Pin::new(&mut stream.inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // This is run at the end of the inner stream to collect all tokens into
                // a single unified `Message`.
                let mut choice = vec![];

                stream.tool_calls.iter().for_each(|tc| {
                    choice.push(AssistantContent::ToolCall(tc.clone()));
                });

                // This is required to ensure there's always at least one item in the content
                if choice.is_empty() || !stream.text.is_empty() {
                    choice.insert(0, AssistantContent::text(stream.text.clone()));
                }

                stream.choice = OneOrMany::many(choice)
                    .expect("There should be at least one assistant message");

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                // Abort-triggered provider errors are swallowed rather than
                // surfaced: cancellation is a normal end of stream.
                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
                {
                    return Poll::Ready(None); // Treat cancellation as stream termination
                }
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    // Forward the streaming tokens to the outer stream
                    // and concat the text together
                    stream.text = format!("{}{}", stream.text, text);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
                }
                RawStreamingChoice::ToolCallDelta { id, delta } => {
                    // Deltas are forwarded as-is; aggregation into `tool_calls`
                    // happens only on the complete `ToolCall` arm below.
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
                        id,
                        delta,
                    })))
                }
                RawStreamingChoice::Reasoning {
                    id,
                    reasoning,
                    signature,
                } => {
                    // Forward the streaming tokens to the outer stream
                    // and concat the text together
                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
                        id,
                        reasoning: vec![reasoning],
                        signature,
                    }))))
                }
                RawStreamingChoice::ToolCall {
                    id,
                    name,
                    arguments,
                    call_id,
                } => {
                    // Keep track of each tool call to aggregate the final message later
                    // and pass it to the outer stream
                    stream.tool_calls.push(ToolCall {
                        id: id.clone(),
                        call_id: call_id.clone(),
                        function: ToolFunction {
                            name: name.clone(),
                            arguments: arguments.clone(),
                        },
                    });
                    if let Some(call_id) = call_id {
                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call_with_call_id(
                            id, call_id, name, arguments,
                        ))))
                    } else {
                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call(
                            id, name, arguments,
                        ))))
                    }
                }
                RawStreamingChoice::FinalResponse(response) => {
                    // Only the first final response is forwarded; duplicates
                    // are skipped by immediately polling for the next item.
                    if stream
                        .final_response_yielded
                        .load(std::sync::atomic::Ordering::SeqCst)
                    {
                        stream.poll_next_unpin(cx)
                    } else {
                        // Set the final response field and return the next item in the stream
                        stream.response = Some(response.clone());
                        stream
                            .final_response_yielded
                            .store(true, std::sync::atomic::Ordering::SeqCst);
                        let final_response = StreamedAssistantContent::final_response(response);
                        Poll::Ready(Some(Ok(final_response)))
                    }
                }
            },
        }
    }
}
290
/// Trait for high-level streaming prompt interface
///
/// `M` is the completion model; `R` is a provider response type carried at
/// the trait level (it does not appear in the method signature).
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a simple prompt to the model
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()>;
}
304
/// Trait for high-level streaming chat interface
///
/// Like [StreamingPrompt], but also accepts prior chat history.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a chat with history to the model
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}
319
/// Trait for low-level streaming completion interface
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request
    ///
    /// Resolves to a request builder so callers can further customize the
    /// completion before sending it.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
329
// Adapter that erases the provider-specific response type `R`, re-emitting
// chunks with only the token usage retained in the final response.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    // The underlying provider stream being adapted.
    pub(crate) inner: StreamingResult<R>,
}
333
334impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
335    type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
336
337    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
338        let stream = self.get_mut();
339
340        match stream.inner.as_mut().poll_next(cx) {
341            Poll::Pending => Poll::Pending,
342            Poll::Ready(None) => Poll::Ready(None),
343            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
344            Poll::Ready(Some(Ok(chunk))) => match chunk {
345                RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
346                    RawStreamingChoice::FinalResponse(FinalCompletionResponse {
347                        usage: res.token_usage(),
348                    }),
349                ))),
350                RawStreamingChoice::Message(m) => {
351                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
352                }
353                RawStreamingChoice::ToolCallDelta { id, delta } => {
354                    Poll::Ready(Some(Ok(RawStreamingChoice::ToolCallDelta { id, delta })))
355                }
356                RawStreamingChoice::Reasoning {
357                    id,
358                    reasoning,
359                    signature,
360                } => Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning {
361                    id,
362                    reasoning,
363                    signature,
364                }))),
365                RawStreamingChoice::ToolCall {
366                    id,
367                    name,
368                    arguments,
369                    call_id,
370                } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
371                    id,
372                    name,
373                    arguments,
374                    call_id,
375                }))),
376            },
377        }
378    }
379}
380
381/// A helper function to stream a completion request to stdout.
382/// Tool call deltas are ignored as tool calls are generally much easier to handle when received in their entirety rather than using deltas.
383pub async fn stream_to_stdout<M>(
384    agent: &'static Agent<M>,
385    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
386) -> Result<(), std::io::Error>
387where
388    M: CompletionModel,
389{
390    let mut is_reasoning = false;
391    print!("Response: ");
392    while let Some(chunk) = stream.next().await {
393        match chunk {
394            Ok(StreamedAssistantContent::Text(text)) => {
395                if is_reasoning {
396                    is_reasoning = false;
397                    println!("\n---\n");
398                }
399                print!("{}", text.text);
400                std::io::Write::flush(&mut std::io::stdout())?;
401            }
402            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
403                let res = agent
404                    .tool_server_handle
405                    .call_tool(
406                        &tool_call.function.name,
407                        &tool_call.function.arguments.to_string(),
408                    )
409                    .await
410                    .map_err(|x| std::io::Error::other(x.to_string()))?;
411                println!("\nResult: {res}");
412            }
413            Ok(StreamedAssistantContent::Final(res)) => {
414                let json_res = serde_json::to_string_pretty(&res).unwrap();
415                println!();
416                tracing::info!("Final result: {json_res}");
417            }
418            Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
419                if !is_reasoning {
420                    is_reasoning = true;
421                    println!();
422                    println!("Thinking: ");
423                }
424                let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
425
426                print!("{reasoning}");
427                std::io::Write::flush(&mut std::io::stdout())?;
428            }
429            Err(e) => {
430                if e.to_string().contains("aborted") {
431                    println!("\nStream cancelled.");
432                    break;
433                }
434                eprintln!("Error: {e}");
435                break;
436            }
437            _ => {}
438        }
439    }
440
441    println!(); // New line after streaming completes
442
443    Ok(())
444}
445
// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal streaming response type used to drive the tests.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a mock stream that yields three text chunks (with small
    /// delays between them) followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // `Box::pin(stream)` satisfies the `StreamingResult` alias on both
        // native and wasm targets, so the previous cfg-gated duplication
        // (two identical branches) is unnecessary.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta { delta, .. }) => {
                    println!("\nTool Call delta: {delta:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel after two chunks to exercise mid-stream abortion.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        // A cancelled stream must be fully terminated.
        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        // Test pause
        stream.pause();
        assert!(stream.is_paused());

        // Test resume
        stream.resume();
        assert!(!stream.is_paused());
    }
}
551
/// Describes responses from a streamed provider response which is either text, a tool call or a final usage response.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A chunk of message text.
    Text(Text),
    /// A complete tool call.
    ToolCall(ToolCall),
    /// A partial tool call argument delta.
    ToolCallDelta { id: String, delta: String },
    /// A chunk of model reasoning.
    Reasoning(Reasoning),
    /// The provider's final response payload.
    Final(R),
}
562
563impl<R> StreamedAssistantContent<R>
564where
565    R: Clone + Unpin,
566{
567    pub fn text(text: &str) -> Self {
568        Self::Text(Text {
569            text: text.to_string(),
570        })
571    }
572
573    /// Helper constructor to make creating assistant tool call content easier.
574    pub fn tool_call(
575        id: impl Into<String>,
576        name: impl Into<String>,
577        arguments: serde_json::Value,
578    ) -> Self {
579        Self::ToolCall(ToolCall {
580            id: id.into(),
581            call_id: None,
582            function: ToolFunction {
583                name: name.into(),
584                arguments,
585            },
586        })
587    }
588
589    pub fn tool_call_with_call_id(
590        id: impl Into<String>,
591        call_id: String,
592        name: impl Into<String>,
593        arguments: serde_json::Value,
594    ) -> Self {
595        Self::ToolCall(ToolCall {
596            id: id.into(),
597            call_id: Some(call_id),
598            function: ToolFunction {
599                name: name.into(),
600                arguments,
601            },
602        })
603    }
604
605    pub fn final_response(res: R) -> Self {
606        Self::Final(res)
607    }
608}
609
/// Streamed user content. This content is primarily used to represent tool results from tool calls made during a multi-turn/step agent prompt.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// The result of executing a tool call.
    ToolResult(ToolResult),
}
616
617impl StreamedUserContent {
618    pub fn tool_result(tool_result: ToolResult) -> Self {
619        Self::ToolResult(tool_result)
620    }
621}