// rig/streaming.rs

1//! This module provides functionality for working with streaming completion models.
2//! It provides traits and types for generating streaming completion requests and
3//! handling streaming completion responses.
4//!
5//! The main traits defined in this module are:
6//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
7//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
8//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
9//!
10
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::client::FinalCompletionResponse;
15use crate::completion::{
16    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17    Message, Usage,
18};
19use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction, ToolResult};
20use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
21use futures::stream::{AbortHandle, Abortable};
22use futures::{Stream, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::atomic::AtomicBool;
27use std::task::{Context, Poll};
28use tokio::sync::watch;
29
/// Control for pausing and resuming a streaming response.
///
/// Wraps a `tokio::sync::watch` channel holding a single `bool` flag; the
/// stream's `poll_next` consults the flag before yielding any items.
pub struct PauseControl {
    // Sender half: `pause()`/`resume()` write the flag here.
    pub(crate) paused_tx: watch::Sender<bool>,
    // Receiver half: `is_paused()` reads the latest flag value.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
35
36impl PauseControl {
37    pub fn new() -> Self {
38        let (paused_tx, paused_rx) = watch::channel(false);
39        Self {
40            paused_tx,
41            paused_rx,
42        }
43    }
44
45    pub fn pause(&self) {
46        self.paused_tx.send(true).unwrap();
47    }
48
49    pub fn resume(&self) {
50        self.paused_tx.send(false).unwrap();
51    }
52
53    pub fn is_paused(&self) -> bool {
54        *self.paused_rx.borrow()
55    }
56}
57
impl Default for PauseControl {
    /// Equivalent to [`PauseControl::new`]: starts in the un-paused state.
    fn default() -> Self {
        Self::new()
    }
}
63
/// The content of a tool call delta - either the tool name or argument data
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum ToolCallDeltaContent {
    /// The name of the tool being called.
    Name(String),
    /// A fragment of the tool call's argument data.
    Delta(String),
}
70
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A text chunk from a message response
    Message(String),

    /// A tool call response (in its entirety)
    ToolCall(RawStreamingToolCall),
    /// A tool call partial/delta
    ToolCallDelta {
        /// Identifier correlating deltas belonging to the same tool call.
        id: String,
        /// Either the tool name or a fragment of the argument data.
        content: ToolCallDeltaContent,
    },
    /// A reasoning (in its entirety)
    Reasoning {
        id: Option<String>,
        reasoning: String,
        signature: Option<String>,
    },
    /// A reasoning partial/delta
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),
}
103
/// Describes a streaming tool call response (in its entirety)
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    /// Identifier of the tool call.
    pub id: String,
    /// Optional provider-assigned call id (set via `with_call_id`).
    pub call_id: Option<String>,
    /// Name of the tool to invoke.
    pub name: String,
    /// JSON arguments to pass to the tool.
    pub arguments: serde_json::Value,
    /// Optional signature (set via `with_signature`).
    pub signature: Option<String>,
    /// Optional provider-specific extra parameters (set via `with_additional_params`).
    pub additional_params: Option<serde_json::Value>,
}
114
115impl RawStreamingToolCall {
116    pub fn empty() -> Self {
117        Self {
118            id: String::new(),
119            call_id: None,
120            name: String::new(),
121            arguments: serde_json::Value::Null,
122            signature: None,
123            additional_params: None,
124        }
125    }
126
127    pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
128        Self {
129            id,
130            call_id: None,
131            name,
132            arguments,
133            signature: None,
134            additional_params: None,
135        }
136    }
137
138    pub fn with_call_id(mut self, call_id: String) -> Self {
139        self.call_id = Some(call_id);
140        self
141    }
142
143    pub fn with_signature(mut self, signature: Option<String>) -> Self {
144        self.signature = signature;
145        self
146    }
147
148    pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
149        self.additional_params = additional_params;
150        self
151    }
152}
153
154impl From<RawStreamingToolCall> for ToolCall {
155    fn from(tool_call: RawStreamingToolCall) -> Self {
156        ToolCall {
157            id: tool_call.id,
158            call_id: tool_call.call_id,
159            function: ToolFunction {
160                name: tool_call.name,
161                arguments: tool_call.arguments,
162            },
163            signature: tool_call.signature,
164            additional_params: tool_call.additional_params,
165        }
166    }
167}
168
/// A pinned, boxed stream of raw streaming choices (with a `Send` bound for
/// multi-threaded executors).
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Same alias as above but without the `Send` bound for wasm32 targets.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
176
/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    // The provider stream, wrapped so `cancel()` can abort it.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    // Handle used by `cancel()` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    // Pause/resume flag consulted on every `poll_next`.
    pub(crate) pause_control: PauseControl,
    // Accumulated text of all `Message` chunks seen so far.
    text: String,
    // Accumulated text of all `ReasoningDelta` chunks seen so far.
    reasoning: String,
    // All complete tool calls seen so far; used to build `choice` at stream end.
    tool_calls: Vec<ToolCall>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
    /// Tracks whether a `FinalResponse` has already been forwarded, so it is
    /// only yielded to the consumer once.
    pub final_response_yielded: AtomicBool,
}
198
199impl<R> StreamingCompletionResponse<R>
200where
201    R: Clone + Unpin + GetTokenUsage,
202{
203    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
204        let (abort_handle, abort_registration) = AbortHandle::new_pair();
205        let abortable_stream = Abortable::new(inner, abort_registration);
206        let pause_control = PauseControl::new();
207        Self {
208            inner: abortable_stream,
209            abort_handle,
210            pause_control,
211            reasoning: String::new(),
212            text: "".to_string(),
213            tool_calls: vec![],
214            choice: OneOrMany::one(AssistantContent::text("")),
215            response: None,
216            final_response_yielded: AtomicBool::new(false),
217        }
218    }
219
220    pub fn cancel(&self) {
221        self.abort_handle.abort();
222    }
223
224    pub fn pause(&self) {
225        self.pause_control.pause();
226    }
227
228    pub fn resume(&self) {
229        self.pause_control.resume();
230    }
231
232    pub fn is_paused(&self) -> bool {
233        self.pause_control.is_paused()
234    }
235}
236
237impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
238where
239    R: Clone + Unpin + GetTokenUsage,
240{
241    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
242        CompletionResponse {
243            choice: value.choice,
244            usage: Usage::new(), // Usage is not tracked in streaming responses
245            raw_response: value.response,
246        }
247    }
248}
249
impl<R> Stream for StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        // While paused, yield nothing. Waking our own waker immediately makes
        // the executor re-poll right away (effectively a busy-wait) rather
        // than suspending until `resume()` is called.
        if stream.is_paused() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        match Pin::new(&mut stream.inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // This is run at the end of the inner stream to collect all tokens into
                // a single unified `Message`.
                let mut choice = vec![];

                stream.tool_calls.iter().for_each(|tc| {
                    choice.push(AssistantContent::ToolCall(tc.clone()));
                });

                // This is required to ensure there's always at least one item in the content
                if choice.is_empty() || !stream.text.is_empty() {
                    choice.insert(0, AssistantContent::text(stream.text.clone()));
                }

                // Safe: the branch above guarantees `choice` is non-empty.
                stream.choice = OneOrMany::many(choice)
                    .expect("There should be at least one assistant message");

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                // Abort errors originate from `cancel()`; surface them as a
                // normal end-of-stream rather than as an error.
                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
                {
                    return Poll::Ready(None); // Treat cancellation as stream termination
                }
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    // Forward the streaming tokens to the outer stream
                    // and concat the text together
                    stream.text = format!("{}{}", stream.text, text);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
                }
                RawStreamingChoice::ToolCallDelta { id, content } => {
                    // Deltas are forwarded as-is; only complete `ToolCall`s
                    // are accumulated into the final `choice`.
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
                        id,
                        content,
                    })))
                }
                RawStreamingChoice::Reasoning {
                    id,
                    reasoning,
                    signature,
                } => Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
                    id,
                    reasoning: vec![reasoning],
                    signature,
                })))),
                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                    // Forward the streaming tokens to the outer stream
                    // and concat the text together
                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
                        id,
                        reasoning,
                    })))
                }
                RawStreamingChoice::ToolCall(tool_call) => {
                    // Keep track of each tool call to aggregate the final message later
                    // and pass it to the outer stream
                    let tool_call: ToolCall = tool_call.into();
                    stream.tool_calls.push(tool_call.clone());
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall(tool_call))))
                }
                RawStreamingChoice::FinalResponse(response) => {
                    // Only the first `FinalResponse` is stored and forwarded;
                    // duplicates are skipped by polling for the next item.
                    if stream
                        .final_response_yielded
                        .load(std::sync::atomic::Ordering::SeqCst)
                    {
                        stream.poll_next_unpin(cx)
                    } else {
                        // Set the final response field and return the next item in the stream
                        stream.response = Some(response.clone());
                        stream
                            .final_response_yielded
                            .store(true, std::sync::atomic::Ordering::SeqCst);
                        let final_response = StreamedAssistantContent::final_response(response);
                        Poll::Ready(Some(Ok(final_response)))
                    }
                }
            },
        }
    }
}
350
/// Trait for high-level streaming prompt interface
// NOTE(review): `R` appears only in the trait bounds, not in any method
// signature — presumably kept for API compatibility; confirm before removing.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a simple prompt to the model
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()>;
}
364
/// Trait for high-level streaming chat interface
// NOTE(review): as with `StreamingPrompt`, `R` appears only in the bounds.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a chat with history to the model
    ///
    /// `chat_history` is the prior conversation; `prompt` is the new turn.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}
379
/// Trait for low-level streaming completion interface
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request
    ///
    /// Returns a [`CompletionRequestBuilder`] so the caller can further
    /// customize the request before sending it.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
389
/// Adapter that erases a provider-specific streaming response type `R`,
/// converting each `FinalResponse(R)` into a [`FinalCompletionResponse`]
/// carrying only the token usage.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    // The underlying provider stream being adapted.
    pub(crate) inner: StreamingResult<R>,
}
393
394impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
395    type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
396
397    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398        let stream = self.get_mut();
399
400        match stream.inner.as_mut().poll_next(cx) {
401            Poll::Pending => Poll::Pending,
402            Poll::Ready(None) => Poll::Ready(None),
403            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
404            Poll::Ready(Some(Ok(chunk))) => match chunk {
405                RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
406                    RawStreamingChoice::FinalResponse(FinalCompletionResponse {
407                        usage: res.token_usage(),
408                    }),
409                ))),
410                RawStreamingChoice::Message(m) => {
411                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
412                }
413                RawStreamingChoice::ToolCallDelta { id, content } => {
414                    Poll::Ready(Some(Ok(RawStreamingChoice::ToolCallDelta { id, content })))
415                }
416                RawStreamingChoice::Reasoning {
417                    id,
418                    reasoning,
419                    signature,
420                } => Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning {
421                    id,
422                    reasoning,
423                    signature,
424                }))),
425                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
426                    Poll::Ready(Some(Ok(RawStreamingChoice::ReasoningDelta {
427                        id,
428                        reasoning,
429                    })))
430                }
431                RawStreamingChoice::ToolCall(tool_call) => {
432                    Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall(tool_call))))
433                }
434            },
435        }
436    }
437}
438
439/// A helper function to stream a completion request to stdout.
440/// Tool call deltas are ignored as tool calls are generally much easier to handle when received in their entirety rather than using deltas.
441pub async fn stream_to_stdout<M>(
442    agent: &'static Agent<M>,
443    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
444) -> Result<(), std::io::Error>
445where
446    M: CompletionModel,
447{
448    let mut is_reasoning = false;
449    print!("Response: ");
450    while let Some(chunk) = stream.next().await {
451        match chunk {
452            Ok(StreamedAssistantContent::Text(text)) => {
453                if is_reasoning {
454                    is_reasoning = false;
455                    println!("\n---\n");
456                }
457                print!("{}", text.text);
458                std::io::Write::flush(&mut std::io::stdout())?;
459            }
460            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
461                let res = agent
462                    .tool_server_handle
463                    .call_tool(
464                        &tool_call.function.name,
465                        &tool_call.function.arguments.to_string(),
466                    )
467                    .await
468                    .map_err(|x| std::io::Error::other(x.to_string()))?;
469                println!("\nResult: {res}");
470            }
471            Ok(StreamedAssistantContent::Final(res)) => {
472                let json_res = serde_json::to_string_pretty(&res).unwrap();
473                println!();
474                tracing::info!("Final result: {json_res}");
475            }
476            Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
477                if !is_reasoning {
478                    is_reasoning = true;
479                    println!();
480                    println!("Thinking: ");
481                }
482                let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
483
484                print!("{reasoning}");
485                std::io::Write::flush(&mut std::io::stdout())?;
486            }
487            Err(e) => {
488                if e.to_string().contains("aborted") {
489                    println!("\nStream cancelled.");
490                    break;
491                }
492                eprintln!("Error: {e}");
493                break;
494            }
495            _ => {}
496        }
497    }
498
499    println!(); // New line after streaming completes
500
501    Ok(())
502}
503
// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal provider response used to drive the streaming machinery.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a stream yielding three text chunks (100ms apart) followed by a
    /// final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // Both cfg flavors of `StreamingResult` box the stream identically, so
        // the previous duplicated `#[cfg]` branches were redundant.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta { id, content }) => {
                    println!("\nTool Call delta: id={id:?}, content={content:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel mid-stream after a couple of chunks have been observed.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        // Test pause
        stream.pause();
        assert!(stream.is_paused());

        // Test resume
        stream.resume();
        assert!(!stream.is_paused());
    }
}
613
/// Describes responses from a streamed provider response which is either text, a tool call or a final usage response.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A chunk of message text.
    Text(Text),
    /// A complete tool call.
    ToolCall(ToolCall),
    /// A partial tool call (name or argument fragment).
    ToolCallDelta {
        /// Identifier correlating deltas belonging to the same tool call.
        id: String,
        /// Either the tool name or a fragment of the argument data.
        content: ToolCallDeltaContent,
    },
    /// A complete reasoning block.
    Reasoning(Reasoning),
    /// A partial reasoning chunk.
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },
    /// The provider-specific final response.
    Final(R),
}
631
632impl<R> StreamedAssistantContent<R>
633where
634    R: Clone + Unpin,
635{
636    pub fn text(text: &str) -> Self {
637        Self::Text(Text {
638            text: text.to_string(),
639        })
640    }
641
642    pub fn final_response(res: R) -> Self {
643        Self::Final(res)
644    }
645}
646
/// Streamed user content. This content is primarily used to represent tool results from tool calls made during a multi-turn/step agent prompt.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// The result of executing a tool call.
    ToolResult(ToolResult),
}
653
impl StreamedUserContent {
    /// Wraps a [`ToolResult`] as streamed user content.
    pub fn tool_result(tool_result: ToolResult) -> Self {
        Self::ToolResult(tool_result)
    }
}