rig/
streaming.rs

//! This module provides functionality for working with streaming completion models.
//! It provides traits and types for generating streaming completion requests and
//! handling streaming completion responses.
//!
//! The main traits defined in this module are:
//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
//!
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::client::FinalCompletionResponse;
15use crate::completion::{
16    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17    Message, Usage,
18};
19use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction, ToolResult};
20use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
21use futures::stream::{AbortHandle, Abortable};
22use futures::{Stream, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::atomic::AtomicBool;
27use std::task::{Context, Poll};
28use tokio::sync::watch;
29
30/// Control for pausing and resuming a streaming response
/// Control for pausing and resuming a streaming response
pub struct PauseControl {
    // Sender half of the watch channel; `pause()`/`resume()` write here.
    pub(crate) paused_tx: watch::Sender<bool>,
    // Receiver half; `is_paused()` reads the latest value. Keeping it inside
    // the struct also guarantees the channel stays open while `self` exists.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
35
36impl PauseControl {
37    pub fn new() -> Self {
38        let (paused_tx, paused_rx) = watch::channel(false);
39        Self {
40            paused_tx,
41            paused_rx,
42        }
43    }
44
45    pub fn pause(&self) {
46        self.paused_tx.send(true).unwrap();
47    }
48
49    pub fn resume(&self) {
50        self.paused_tx.send(false).unwrap();
51    }
52
53    pub fn is_paused(&self) -> bool {
54        *self.paused_rx.borrow()
55    }
56}
57
58impl Default for PauseControl {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A text chunk from a message response
    Message(String),

    /// A tool call response (in its entirety)
    ToolCall(RawStreamingToolCall),
    /// A tool call partial/delta: `delta` is a fragment of the arguments
    /// belonging to the tool call identified by `id`.
    ToolCallDelta { id: String, delta: String },
    /// A reasoning (in its entirety)
    Reasoning {
        /// Provider-assigned id for this reasoning block, when available.
        id: Option<String>,
        /// The full reasoning text.
        reasoning: String,
        /// Provider signature for the reasoning, when available.
        signature: Option<String>,
    },
    /// A reasoning partial/delta
    ReasoningDelta {
        /// Provider-assigned id for this reasoning block, when available.
        id: Option<String>,
        /// A fragment of the reasoning text.
        reasoning: String,
    },

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),
}
93
/// Describes a streaming tool call response (in its entirety)
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    /// Identifier of the tool call.
    pub id: String,
    /// Provider-specific call id, when the provider distinguishes it from `id`.
    pub call_id: Option<String>,
    /// Name of the tool being invoked.
    pub name: String,
    /// JSON arguments passed to the tool.
    pub arguments: serde_json::Value,
    /// Provider signature, when available.
    pub signature: Option<String>,
    /// Extra provider-specific parameters, when available.
    pub additional_params: Option<serde_json::Value>,
}
104
105impl RawStreamingToolCall {
106    pub fn empty() -> Self {
107        Self {
108            id: String::new(),
109            call_id: None,
110            name: String::new(),
111            arguments: serde_json::Value::Null,
112            signature: None,
113            additional_params: None,
114        }
115    }
116
117    pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
118        Self {
119            id,
120            call_id: None,
121            name,
122            arguments,
123            signature: None,
124            additional_params: None,
125        }
126    }
127
128    pub fn with_call_id(mut self, call_id: String) -> Self {
129        self.call_id = Some(call_id);
130        self
131    }
132
133    pub fn with_signature(mut self, signature: Option<String>) -> Self {
134        self.signature = signature;
135        self
136    }
137
138    pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
139        self.additional_params = additional_params;
140        self
141    }
142}
143
144impl From<RawStreamingToolCall> for ToolCall {
145    fn from(tool_call: RawStreamingToolCall) -> Self {
146        ToolCall {
147            id: tool_call.id,
148            call_id: tool_call.call_id,
149            function: ToolFunction {
150                name: tool_call.name,
151                arguments: tool_call.arguments,
152            },
153            signature: tool_call.signature,
154            additional_params: tool_call.additional_params,
155        }
156    }
157}
158
/// The boxed stream of raw streaming chunks produced by a provider.
/// On native targets the stream must be `Send` so it can cross threads.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// The boxed stream of raw streaming chunks produced by a provider.
/// On wasm32 the `Send` bound is dropped (no cross-thread requirement).
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
166
/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    /// The provider stream, wrapped so `cancel()` can abort it.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    /// Handle used by `cancel()` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    /// Pause/resume switch consulted on every poll.
    pub(crate) pause_control: PauseControl,
    // Running concatenation of every text chunk seen so far.
    text: String,
    // Running concatenation of every reasoning delta seen so far.
    reasoning: String,
    // Every complete tool call received during the stream.
    tool_calls: Vec<ToolCall>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
    /// Set once a `FinalResponse` chunk has been forwarded, so it is yielded
    /// to the consumer at most once.
    pub final_response_yielded: AtomicBool,
}
188
189impl<R> StreamingCompletionResponse<R>
190where
191    R: Clone + Unpin + GetTokenUsage,
192{
193    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
194        let (abort_handle, abort_registration) = AbortHandle::new_pair();
195        let abortable_stream = Abortable::new(inner, abort_registration);
196        let pause_control = PauseControl::new();
197        Self {
198            inner: abortable_stream,
199            abort_handle,
200            pause_control,
201            reasoning: String::new(),
202            text: "".to_string(),
203            tool_calls: vec![],
204            choice: OneOrMany::one(AssistantContent::text("")),
205            response: None,
206            final_response_yielded: AtomicBool::new(false),
207        }
208    }
209
210    pub fn cancel(&self) {
211        self.abort_handle.abort();
212    }
213
214    pub fn pause(&self) {
215        self.pause_control.pause();
216    }
217
218    pub fn resume(&self) {
219        self.pause_control.resume();
220    }
221
222    pub fn is_paused(&self) -> bool {
223        self.pause_control.is_paused()
224    }
225}
226
227impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
228where
229    R: Clone + Unpin + GetTokenUsage,
230{
231    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
232        CompletionResponse {
233            choice: value.choice,
234            usage: Usage::new(), // Usage is not tracked in streaming responses
235            raw_response: value.response,
236        }
237    }
238}
239
240impl<R> Stream for StreamingCompletionResponse<R>
241where
242    R: Clone + Unpin + GetTokenUsage,
243{
244    type Item = Result<StreamedAssistantContent<R>, CompletionError>;
245
246    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
247        let stream = self.get_mut();
248
249        if stream.is_paused() {
250            cx.waker().wake_by_ref();
251            return Poll::Pending;
252        }
253
254        match Pin::new(&mut stream.inner).poll_next(cx) {
255            Poll::Pending => Poll::Pending,
256            Poll::Ready(None) => {
257                // This is run at the end of the inner stream to collect all tokens into
258                // a single unified `Message`.
259                let mut choice = vec![];
260
261                stream.tool_calls.iter().for_each(|tc| {
262                    choice.push(AssistantContent::ToolCall(tc.clone()));
263                });
264
265                // This is required to ensure there's always at least one item in the content
266                if choice.is_empty() || !stream.text.is_empty() {
267                    choice.insert(0, AssistantContent::text(stream.text.clone()));
268                }
269
270                stream.choice = OneOrMany::many(choice)
271                    .expect("There should be at least one assistant message");
272
273                Poll::Ready(None)
274            }
275            Poll::Ready(Some(Err(err))) => {
276                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
277                {
278                    return Poll::Ready(None); // Treat cancellation as stream termination
279                }
280                Poll::Ready(Some(Err(err)))
281            }
282            Poll::Ready(Some(Ok(choice))) => match choice {
283                RawStreamingChoice::Message(text) => {
284                    // Forward the streaming tokens to the outer stream
285                    // and concat the text together
286                    stream.text = format!("{}{}", stream.text, text);
287                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
288                }
289                RawStreamingChoice::ToolCallDelta { id, delta } => {
290                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
291                        id,
292                        delta,
293                    })))
294                }
295                RawStreamingChoice::Reasoning {
296                    id,
297                    reasoning,
298                    signature,
299                } => Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
300                    id,
301                    reasoning: vec![reasoning],
302                    signature,
303                })))),
304                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
305                    // Forward the streaming tokens to the outer stream
306                    // and concat the text together
307                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning);
308                    Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
309                        id,
310                        reasoning,
311                    })))
312                }
313                RawStreamingChoice::ToolCall(tool_call) => {
314                    // Keep track of each tool call to aggregate the final message later
315                    // and pass it to the outer stream
316                    let tool_call: ToolCall = tool_call.into();
317                    stream.tool_calls.push(tool_call.clone());
318                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall(tool_call))))
319                }
320                RawStreamingChoice::FinalResponse(response) => {
321                    if stream
322                        .final_response_yielded
323                        .load(std::sync::atomic::Ordering::SeqCst)
324                    {
325                        stream.poll_next_unpin(cx)
326                    } else {
327                        // Set the final response field and return the next item in the stream
328                        stream.response = Some(response.clone());
329                        stream
330                            .final_response_yielded
331                            .store(true, std::sync::atomic::Ordering::SeqCst);
332                        let final_response = StreamedAssistantContent::final_response(response);
333                        Poll::Ready(Some(Ok(final_response)))
334                    }
335                }
336            },
337        }
338    }
339}
340
/// Trait for high-level streaming prompt interface
// NOTE(review): the `R` parameter appears only in the trait bounds, not in the
// method signature — possibly vestigial; confirm against implementors.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a simple prompt to the model
    ///
    /// Returns a [`StreamingPromptRequest`] that drives the streamed exchange.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()>;
}
354
/// Trait for high-level streaming chat interface
// NOTE(review): like `StreamingPrompt`, `R` appears only in the bounds.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a chat with history to the model
    ///
    /// `prompt` is the new message; `chat_history` carries the prior
    /// conversation messages.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}
369
/// Trait for low-level streaming completion interface
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request
    ///
    /// Resolves to a [`CompletionRequestBuilder`] so the caller can further
    /// customize the request before sending it, or to a [`CompletionError`].
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
379
/// Adapter that erases a provider-specific final-response type `R`, mapping
/// `FinalResponse(R)` chunks to the generic [`FinalCompletionResponse`]
/// (token usage only) while forwarding all other chunks unchanged.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    pub(crate) inner: StreamingResult<R>,
}
383
384impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
385    type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
386
387    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
388        let stream = self.get_mut();
389
390        match stream.inner.as_mut().poll_next(cx) {
391            Poll::Pending => Poll::Pending,
392            Poll::Ready(None) => Poll::Ready(None),
393            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
394            Poll::Ready(Some(Ok(chunk))) => match chunk {
395                RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
396                    RawStreamingChoice::FinalResponse(FinalCompletionResponse {
397                        usage: res.token_usage(),
398                    }),
399                ))),
400                RawStreamingChoice::Message(m) => {
401                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
402                }
403                RawStreamingChoice::ToolCallDelta { id, delta } => {
404                    Poll::Ready(Some(Ok(RawStreamingChoice::ToolCallDelta { id, delta })))
405                }
406                RawStreamingChoice::Reasoning {
407                    id,
408                    reasoning,
409                    signature,
410                } => Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning {
411                    id,
412                    reasoning,
413                    signature,
414                }))),
415                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
416                    Poll::Ready(Some(Ok(RawStreamingChoice::ReasoningDelta {
417                        id,
418                        reasoning,
419                    })))
420                }
421                RawStreamingChoice::ToolCall(tool_call) => {
422                    Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall(tool_call))))
423                }
424            },
425        }
426    }
427}
428
/// A helper function to stream a completion request to stdout.
/// Tool call deltas are ignored as tool calls are generally much easier to handle when received in their entirety rather than using deltas.
pub async fn stream_to_stdout<M>(
    agent: &'static Agent<M>,
    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), std::io::Error>
where
    M: CompletionModel,
{
    // Tracks whether we are currently inside a "Thinking" (reasoning) section,
    // so a separator is printed when switching back to regular text.
    let mut is_reasoning = false;
    print!("Response: ");
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(StreamedAssistantContent::Text(text)) => {
                if is_reasoning {
                    is_reasoning = false;
                    println!("\n---\n");
                }
                print!("{}", text.text);
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                // Execute the tool via the agent's tool server and print the result.
                let res = agent
                    .tool_server_handle
                    .call_tool(
                        &tool_call.function.name,
                        &tool_call.function.arguments.to_string(),
                    )
                    .await
                    .map_err(|x| std::io::Error::other(x.to_string()))?;
                println!("\nResult: {res}");
            }
            Ok(StreamedAssistantContent::Final(res)) => {
                // The final response is logged (not printed) as pretty JSON.
                let json_res = serde_json::to_string_pretty(&res).unwrap();
                println!();
                tracing::info!("Final result: {json_res}");
            }
            Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                if !is_reasoning {
                    is_reasoning = true;
                    println!();
                    println!("Thinking: ");
                }
                // `reasoning` holds multiple fragments; join them into one piece.
                let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");

                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Err(e) => {
                // A provider "aborted" error means the stream was cancelled.
                if e.to_string().contains("aborted") {
                    println!("\nStream cancelled.");
                    break;
                }
                eprintln!("Error: {e}");
                break;
            }
            // Tool-call and reasoning deltas are intentionally ignored (see doc).
            _ => {}
        }
    }

    println!(); // New line after streaming completes

    Ok(())
}
493
// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal stand-in for a provider's final streaming response.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        // Reports a fixed total of 15 tokens, matching the mock stream below.
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a response stream that yields three text chunks 100ms apart,
    /// followed by a final `MockResponse`.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // NOTE(review): both cfg branches are identical; the duplication only
        // mirrors the two `StreamingResult` alias definitions.
        #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);
        #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// After `cancel()`, the stream must yield no further chunks.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta { delta, .. }) => {
                    println!("\nTool Call delta: {delta:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel after the second counted chunk.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// `pause()`/`resume()` must be reflected by `is_paused()`.
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        // Test pause
        stream.pause();
        assert!(stream.is_paused());

        // Test resume
        stream.resume();
        assert!(!stream.is_paused());
    }
}
603
/// Describes responses from a streamed provider response which is either text, a tool call or a final usage response.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A chunk of response text.
    Text(Text),
    /// A complete tool call.
    ToolCall(ToolCall),
    /// A partial tool call: `delta` is an argument fragment for call `id`.
    ToolCallDelta {
        id: String,
        delta: String,
    },
    /// A complete reasoning block.
    Reasoning(Reasoning),
    /// A fragment of a reasoning block.
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },
    /// The provider's final response, yielded at most once per stream.
    Final(R),
}
621
622impl<R> StreamedAssistantContent<R>
623where
624    R: Clone + Unpin,
625{
626    pub fn text(text: &str) -> Self {
627        Self::Text(Text {
628            text: text.to_string(),
629        })
630    }
631
632    pub fn final_response(res: R) -> Self {
633        Self::Final(res)
634    }
635}
636
/// Streamed user content. This content is primarily used to represent tool results from tool calls made during a multi-turn/step agent prompt.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// The result of a tool call.
    ToolResult(ToolResult),
}
643
644impl StreamedUserContent {
645    pub fn tool_result(tool_result: ToolResult) -> Self {
646        Self::ToolResult(tool_result)
647    }
648}