// rig — src/streaming.rs
1//! This module provides functionality for working with streaming completion models.
2//! It provides traits and types for generating streaming completion requests and
3//! handling streaming completion responses.
4//!
5//! The main traits defined in this module are:
6//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
7//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
8//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
9//!
10
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::client::FinalCompletionResponse;
15use crate::completion::{
16    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17    Message, Usage,
18};
19use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction, ToolResult};
20use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
21use futures::stream::{AbortHandle, Abortable};
22use futures::{Stream, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::atomic::AtomicBool;
27use std::task::{Context, Poll};
28use tokio::sync::watch;
29
/// Control for pausing and resuming a streaming response
pub struct PauseControl {
    /// Write side of the pause flag; `pause`/`resume` publish through this.
    pub(crate) paused_tx: watch::Sender<bool>,
    /// Read side of the pause flag; `is_paused` borrows the latest value.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
35
36impl PauseControl {
37    pub fn new() -> Self {
38        let (paused_tx, paused_rx) = watch::channel(false);
39        Self {
40            paused_tx,
41            paused_rx,
42        }
43    }
44
45    pub fn pause(&self) {
46        self.paused_tx.send(true).unwrap();
47    }
48
49    pub fn resume(&self) {
50        self.paused_tx.send(false).unwrap();
51    }
52
53    pub fn is_paused(&self) -> bool {
54        *self.paused_rx.borrow()
55    }
56}
57
impl Default for PauseControl {
    /// Equivalent to [`PauseControl::new`]: starts un-paused.
    fn default() -> Self {
        Self::new()
    }
}
63
/// The content of a tool call delta - either the tool name or argument data
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum ToolCallDeltaContent {
    /// The name of the tool being called.
    Name(String),
    /// A fragment of the JSON-encoded tool arguments.
    Delta(String),
}
70
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A text chunk from a message response
    Message(String),

    /// A tool call response (in its entirety)
    ToolCall(RawStreamingToolCall),
    /// A tool call partial/delta
    ToolCallDelta {
        /// Provider-supplied tool call ID.
        id: String,
        /// Rig-generated unique identifier for this tool call.
        internal_call_id: String,
        /// Either the tool name or a fragment of the argument JSON.
        content: ToolCallDeltaContent,
    },
    /// A reasoning (in its entirety)
    Reasoning {
        /// Provider-supplied reasoning ID, if any.
        id: Option<String>,
        /// The full reasoning text.
        reasoning: String,
        /// Provider signature for the reasoning block, if any.
        signature: Option<String>,
    },
    /// A reasoning partial/delta
    ReasoningDelta {
        /// Provider-supplied reasoning ID, if any.
        id: Option<String>,
        /// A fragment of the reasoning text.
        reasoning: String,
    },

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),
}
106
/// Describes a streaming tool call response (in its entirety)
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    /// Provider-supplied tool call ID.
    pub id: String,
    /// Rig-generated unique identifier for this tool call.
    pub internal_call_id: String,
    /// Optional additional call identifier, carried through to
    /// [`ToolCall::call_id`] on conversion.
    pub call_id: Option<String>,
    /// Name of the tool to invoke.
    pub name: String,
    /// JSON arguments for the tool invocation.
    pub arguments: serde_json::Value,
    /// Optional provider signature attached to the call.
    pub signature: Option<String>,
    /// Provider-specific extra data, if any.
    pub additional_params: Option<serde_json::Value>,
}
120
121impl RawStreamingToolCall {
122    pub fn empty() -> Self {
123        Self {
124            id: String::new(),
125            internal_call_id: nanoid::nanoid!(),
126            call_id: None,
127            name: String::new(),
128            arguments: serde_json::Value::Null,
129            signature: None,
130            additional_params: None,
131        }
132    }
133
134    pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
135        Self {
136            id,
137            internal_call_id: nanoid::nanoid!(),
138            call_id: None,
139            name,
140            arguments,
141            signature: None,
142            additional_params: None,
143        }
144    }
145
146    pub fn with_internal_call_id(mut self, internal_call_id: String) -> Self {
147        self.internal_call_id = internal_call_id;
148        self
149    }
150
151    pub fn with_call_id(mut self, call_id: String) -> Self {
152        self.call_id = Some(call_id);
153        self
154    }
155
156    pub fn with_signature(mut self, signature: Option<String>) -> Self {
157        self.signature = signature;
158        self
159    }
160
161    pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
162        self.additional_params = additional_params;
163        self
164    }
165}
166
167impl From<RawStreamingToolCall> for ToolCall {
168    fn from(tool_call: RawStreamingToolCall) -> Self {
169        ToolCall {
170            id: tool_call.id,
171            call_id: tool_call.call_id,
172            function: ToolFunction {
173                name: tool_call.name,
174                arguments: tool_call.arguments,
175            },
176            signature: tool_call.signature,
177            additional_params: tool_call.additional_params,
178        }
179    }
180}
181
/// Boxed dynamic stream of raw streaming choices. On non-wasm targets the
/// stream must be `Send` so it can cross thread boundaries.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed dynamic stream of raw streaming choices. The wasm variant drops the
/// `Send` bound (wasm32 here is single-threaded).
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
189
/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    /// The underlying provider stream, wrapped so `cancel` can abort it.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    /// Handle used by `cancel` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    /// Pause/resume flag consulted on every poll.
    pub(crate) pause_control: PauseControl,
    // Running concatenation of all text chunks seen so far.
    text: String,
    // Running concatenation of all reasoning delta chunks seen so far.
    reasoning: String,
    // Every complete tool call observed on the stream.
    tool_calls: Vec<ToolCall>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
    /// Guards against emitting more than one `Final` item downstream.
    pub final_response_yielded: AtomicBool,
}
211
212impl<R> StreamingCompletionResponse<R>
213where
214    R: Clone + Unpin + GetTokenUsage,
215{
216    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
217        let (abort_handle, abort_registration) = AbortHandle::new_pair();
218        let abortable_stream = Abortable::new(inner, abort_registration);
219        let pause_control = PauseControl::new();
220        Self {
221            inner: abortable_stream,
222            abort_handle,
223            pause_control,
224            reasoning: String::new(),
225            text: "".to_string(),
226            tool_calls: vec![],
227            choice: OneOrMany::one(AssistantContent::text("")),
228            response: None,
229            final_response_yielded: AtomicBool::new(false),
230        }
231    }
232
233    pub fn cancel(&self) {
234        self.abort_handle.abort();
235    }
236
237    pub fn pause(&self) {
238        self.pause_control.pause();
239    }
240
241    pub fn resume(&self) {
242        self.pause_control.resume();
243    }
244
245    pub fn is_paused(&self) -> bool {
246        self.pause_control.is_paused()
247    }
248}
249
250impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
251where
252    R: Clone + Unpin + GetTokenUsage,
253{
254    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
255        CompletionResponse {
256            choice: value.choice,
257            usage: Usage::new(), // Usage is not tracked in streaming responses
258            raw_response: value.response,
259        }
260    }
261}
262
263impl<R> Stream for StreamingCompletionResponse<R>
264where
265    R: Clone + Unpin + GetTokenUsage,
266{
267    type Item = Result<StreamedAssistantContent<R>, CompletionError>;
268
269    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
270        let stream = self.get_mut();
271
272        if stream.is_paused() {
273            cx.waker().wake_by_ref();
274            return Poll::Pending;
275        }
276
277        match Pin::new(&mut stream.inner).poll_next(cx) {
278            Poll::Pending => Poll::Pending,
279            Poll::Ready(None) => {
280                // This is run at the end of the inner stream to collect all tokens into
281                // a single unified `Message`.
282                let mut choice = vec![];
283
284                stream.tool_calls.iter().for_each(|tc| {
285                    choice.push(AssistantContent::ToolCall(tc.clone()));
286                });
287
288                // This is required to ensure there's always at least one item in the content
289                if choice.is_empty() || !stream.text.is_empty() {
290                    choice.insert(0, AssistantContent::text(stream.text.clone()));
291                }
292
293                stream.choice = OneOrMany::many(choice)
294                    .expect("There should be at least one assistant message");
295
296                Poll::Ready(None)
297            }
298            Poll::Ready(Some(Err(err))) => {
299                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
300                {
301                    return Poll::Ready(None); // Treat cancellation as stream termination
302                }
303                Poll::Ready(Some(Err(err)))
304            }
305            Poll::Ready(Some(Ok(choice))) => match choice {
306                RawStreamingChoice::Message(text) => {
307                    // Forward the streaming tokens to the outer stream
308                    // and concat the text together
309                    stream.text = format!("{}{}", stream.text, text);
310                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
311                }
312                RawStreamingChoice::ToolCallDelta {
313                    id,
314                    internal_call_id,
315                    content,
316                } => Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
317                    id,
318                    internal_call_id,
319                    content,
320                }))),
321                RawStreamingChoice::Reasoning {
322                    id,
323                    reasoning,
324                    signature,
325                } => Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
326                    id,
327                    reasoning: vec![reasoning],
328                    signature,
329                })))),
330                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
331                    // Forward the streaming tokens to the outer stream
332                    // and concat the text together
333                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning);
334                    Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
335                        id,
336                        reasoning,
337                    })))
338                }
339                RawStreamingChoice::ToolCall(raw_tool_call) => {
340                    // Keep track of each tool call to aggregate the final message later
341                    // and pass it to the outer stream
342                    let internal_call_id = raw_tool_call.internal_call_id.clone();
343                    let tool_call: ToolCall = raw_tool_call.into();
344                    stream.tool_calls.push(tool_call.clone());
345                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall {
346                        tool_call,
347                        internal_call_id,
348                    })))
349                }
350                RawStreamingChoice::FinalResponse(response) => {
351                    if stream
352                        .final_response_yielded
353                        .load(std::sync::atomic::Ordering::SeqCst)
354                    {
355                        stream.poll_next_unpin(cx)
356                    } else {
357                        // Set the final response field and return the next item in the stream
358                        stream.response = Some(response.clone());
359                        stream
360                            .final_response_yielded
361                            .store(true, std::sync::atomic::Ordering::SeqCst);
362                        let final_response = StreamedAssistantContent::final_response(response);
363                        Poll::Ready(Some(Ok(final_response)))
364                    }
365                }
366            },
367        }
368    }
369}
370
/// Trait for high-level streaming prompt interface
// NOTE(review): the `R` type parameter is not referenced by `stream_prompt`'s
// signature — possibly legacy; confirm before removing.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a simple one-shot prompt to the model, returning a
    /// [`StreamingPromptRequest`] to drive the streamed response.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()>;
}
384
/// Trait for high-level streaming chat interface
// NOTE(review): the `R` type parameter is not referenced by `stream_chat`'s
// signature — possibly legacy; confirm before removing.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a chat with history to the model, returning a
    /// [`StreamingPromptRequest`] to drive the streamed response.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}
399
/// Trait for low-level streaming completion interface
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request. Resolves to a
    /// [`CompletionRequestBuilder`] so callers can further customize the
    /// request before sending it.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
409
/// Adapter stream that erases a provider-specific final response type `R`,
/// mapping it down to a usage-only [`FinalCompletionResponse`] (see the
/// `Stream` impl below).
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    /// The provider stream being adapted.
    pub(crate) inner: StreamingResult<R>,
}
413
414impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
415    type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
416
417    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
418        let stream = self.get_mut();
419
420        match stream.inner.as_mut().poll_next(cx) {
421            Poll::Pending => Poll::Pending,
422            Poll::Ready(None) => Poll::Ready(None),
423            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
424            Poll::Ready(Some(Ok(chunk))) => match chunk {
425                RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
426                    RawStreamingChoice::FinalResponse(FinalCompletionResponse {
427                        usage: res.token_usage(),
428                    }),
429                ))),
430                RawStreamingChoice::Message(m) => {
431                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
432                }
433                RawStreamingChoice::ToolCallDelta {
434                    id,
435                    internal_call_id,
436                    content,
437                } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCallDelta {
438                    id,
439                    internal_call_id,
440                    content,
441                }))),
442                RawStreamingChoice::Reasoning {
443                    id,
444                    reasoning,
445                    signature,
446                } => Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning {
447                    id,
448                    reasoning,
449                    signature,
450                }))),
451                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
452                    Poll::Ready(Some(Ok(RawStreamingChoice::ReasoningDelta {
453                        id,
454                        reasoning,
455                    })))
456                }
457                RawStreamingChoice::ToolCall(tool_call) => {
458                    Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall(tool_call))))
459                }
460            },
461        }
462    }
463}
464
/// A helper function to stream a completion request to stdout.
/// Tool call deltas are ignored as tool calls are generally much easier to handle when received in their entirety rather than using deltas.
///
/// Text and reasoning chunks are printed as they arrive; complete tool calls
/// are executed via the agent's tool server and their results printed.
/// Returns an I/O error if stdout flushing or a tool invocation fails.
pub async fn stream_to_stdout<M>(
    agent: &'static Agent<M>,
    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), std::io::Error>
where
    M: CompletionModel,
{
    // Tracks whether we are currently inside a reasoning section, so a
    // separator can be printed when switching back to normal text.
    let mut is_reasoning = false;
    print!("Response: ");
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(StreamedAssistantContent::Text(text)) => {
                if is_reasoning {
                    is_reasoning = false;
                    println!("\n---\n");
                }
                print!("{}", text.text);
                // Flush so partial lines appear immediately.
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            }) => {
                // Execute the tool immediately and surface its result; tool
                // errors are converted to I/O errors to fit the return type.
                let res = agent
                    .tool_server_handle
                    .call_tool(
                        &tool_call.function.name,
                        &tool_call.function.arguments.to_string(),
                    )
                    .await
                    .map_err(|x| std::io::Error::other(x.to_string()))?;
                println!("\nResult: {res}");
            }
            Ok(StreamedAssistantContent::Final(res)) => {
                let json_res = serde_json::to_string_pretty(&res).unwrap();
                println!();
                tracing::info!("Final result: {json_res}");
            }
            Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                if !is_reasoning {
                    is_reasoning = true;
                    println!();
                    println!("Thinking: ");
                }
                let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");

                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Err(e) => {
                // Cancellation surfaces as an "aborted" provider error; treat
                // it as a clean stop rather than a failure.
                if e.to_string().contains("aborted") {
                    println!("\nStream cancelled.");
                    break;
                }
                eprintln!("Error: {e}");
                break;
            }
            // Tool-call and reasoning deltas are intentionally ignored
            // (see the doc comment above).
            _ => {}
        }
    }

    println!(); // New line after streaming completes

    Ok(())
}
532
// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal final-response type for exercising the streaming wrapper.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a stream yielding three text chunks (100ms apart) followed by a
    /// final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // Both #[cfg] branches of the original were byte-identical
        // (`Box::pin(stream)`), so the duplicated conditional compilation was
        // redundant and has been removed.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// After `cancel()`, the stream must yield no further chunks.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall {
                    tool_call,
                    internal_call_id,
                }) => {
                    println!("\nTool Call: {tool_call:?}, internal_call_id={internal_call_id:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }) => {
                    println!(
                        "\nTool Call delta: id={id:?}, internal_call_id={internal_call_id:?}, content={content:?}"
                    );
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel mid-stream after two chunks to verify termination.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// `pause`/`resume` must be reflected by `is_paused`.
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        // Test pause
        stream.pause();
        assert!(stream.is_paused());

        // Test resume
        stream.resume();
        assert!(!stream.is_paused());
    }
}
651
/// Describes responses from a streamed provider response which is either text, a tool call or a final usage response.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A text chunk.
    Text(Text),
    /// A complete tool call.
    ToolCall {
        tool_call: ToolCall,
        /// Rig-generated unique identifier for this tool call.
        /// Use this to correlate with ToolCallDelta events.
        internal_call_id: String,
    },
    /// A partial tool call (name or argument fragment).
    ToolCallDelta {
        /// Provider-supplied tool call ID.
        id: String,
        /// Rig-generated unique identifier for this tool call.
        internal_call_id: String,
        content: ToolCallDeltaContent,
    },
    /// A complete reasoning block.
    Reasoning(Reasoning),
    /// A partial reasoning chunk.
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },
    /// The provider's final response object.
    Final(R),
}
677
678impl<R> StreamedAssistantContent<R>
679where
680    R: Clone + Unpin,
681{
682    pub fn text(text: &str) -> Self {
683        Self::Text(Text {
684            text: text.to_string(),
685        })
686    }
687
688    pub fn final_response(res: R) -> Self {
689        Self::Final(res)
690    }
691}
692
/// Streamed user content. This content is primarily used to represent tool results from tool calls made during a multi-turn/step agent prompt.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// The result of executing a tool call.
    ToolResult {
        tool_result: ToolResult,
        /// Rig-generated unique identifier for the tool call this result
        /// belongs to. Use this to correlate with the originating
        /// [`StreamedAssistantContent::ToolCall::internal_call_id`].
        internal_call_id: String,
    },
}
705
706impl StreamedUserContent {
707    pub fn tool_result(tool_result: ToolResult, internal_call_id: String) -> Self {
708        Self::ToolResult {
709            tool_result,
710            internal_call_id,
711        }
712    }
713}