rig/
streaming.rs

1//! This module provides functionality for working with streaming completion models.
2//! It provides traits and types for generating streaming completion requests and
3//! handling streaming completion responses.
4//!
5//! The main traits defined in this module are:
6//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
7//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
8//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
9//!
10
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::completion::{
15    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
16    Message, Usage,
17};
18use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction};
19use futures::stream::{AbortHandle, Abortable};
20use futures::{Stream, StreamExt};
21use serde::{Deserialize, Serialize};
22use std::boxed::Box;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::atomic::AtomicBool;
26use std::task::{Context, Poll};
27
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R: Clone> {
    /// A text chunk from a message response
    Message(String),

    /// A tool call response chunk
    ToolCall {
        /// Provider-assigned identifier of this tool call.
        id: String,
        /// Secondary call identifier emitted by some providers; `None` when absent.
        call_id: Option<String>,
        /// Name of the tool to invoke.
        name: String,
        /// JSON-encoded arguments for the tool invocation.
        arguments: serde_json::Value,
    },
    /// A reasoning chunk
    Reasoning {
        /// Optional provider-assigned identifier for the reasoning block.
        id: Option<String>,
        /// Reasoning text carried by this chunk.
        reasoning: String,
    },

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),
}
51
/// Boxed, pinned stream of raw streaming chunks produced by a provider.
/// On native targets the stream must be `Send` so it can be moved across
/// task boundaries.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed, pinned stream of raw streaming chunks (wasm32: no `Send` bound,
/// since wasm runs single-threaded).
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
59
/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R: Clone + Unpin + GetTokenUsage> {
    // The provider stream, wrapped so it can be aborted via `abort_handle`.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    // Handle used by `cancel()` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    // Accumulated message text, built up as chunks arrive.
    text: String,
    // Accumulated reasoning text, built up as chunks arrive.
    reasoning: String,
    // Every tool call seen so far; folded into `choice` when the stream ends.
    tool_calls: Vec<ToolCall>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
    /// Set once a `FinalResponse` chunk has been forwarded, so duplicates are skipped.
    pub final_response_yielded: AtomicBool,
}
77
78impl<R> StreamingCompletionResponse<R>
79where
80    R: Clone + Unpin + GetTokenUsage,
81{
82    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
83        let (abort_handle, abort_registration) = AbortHandle::new_pair();
84        let abortable_stream = Abortable::new(inner, abort_registration);
85        Self {
86            inner: abortable_stream,
87            abort_handle,
88            reasoning: String::new(),
89            text: "".to_string(),
90            tool_calls: vec![],
91            choice: OneOrMany::one(AssistantContent::text("")),
92            response: None,
93            final_response_yielded: AtomicBool::new(false),
94        }
95    }
96
97    pub fn cancel(&self) {
98        self.abort_handle.abort();
99    }
100}
101
102impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
103where
104    R: Clone + Unpin + GetTokenUsage,
105{
106    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
107        CompletionResponse {
108            choice: value.choice,
109            usage: Usage::new(), // Usage is not tracked in streaming responses
110            raw_response: value.response,
111        }
112    }
113}
114
/// Adapts the raw provider chunks into [`StreamedAssistantContent`] items while
/// accumulating text, reasoning and tool calls so the aggregated `choice` and
/// `response` fields can be populated once the inner stream ends.
impl<R> Stream for StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        match Pin::new(&mut stream.inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // This is run at the end of the inner stream to collect all tokens into
                // a single unified `Message`.
                let mut choice = vec![];

                stream.tool_calls.iter().for_each(|tc| {
                    choice.push(AssistantContent::ToolCall(tc.clone()));
                });

                // This is required to ensure there's always at least one item in the content
                // (when there were no tool calls and no text, an empty text item is inserted).
                if choice.is_empty() || !stream.text.is_empty() {
                    choice.insert(0, AssistantContent::text(stream.text.clone()));
                }

                // Safe: the branch above guarantees `choice` is non-empty.
                stream.choice = OneOrMany::many(choice)
                    .expect("There should be at least one assistant message");

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                // NOTE(review): cancellation is detected by substring-matching the
                // provider error text; a dedicated error variant would be more robust.
                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
                {
                    return Poll::Ready(None); // Treat cancellation as stream termination
                }
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    // Forward the streaming tokens to the outer stream
                    // and concat the text together
                    stream.text = format!("{}{}", stream.text, text.clone());
                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
                }
                RawStreamingChoice::Reasoning { id, reasoning } => {
                    // Forward the streaming tokens to the outer stream
                    // and concat the text together
                    // NOTE(review): unlike the `Message` branch (which forwards only
                    // the delta), this yields the FULL reasoning accumulated so far
                    // on every chunk — confirm consumers expect cumulative text.
                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning.clone());
                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
                        id,
                        reasoning: vec![stream.reasoning.clone()],
                    }))))
                }
                RawStreamingChoice::ToolCall {
                    id,
                    name,
                    arguments,
                    call_id,
                } => {
                    // Keep track of each tool call to aggregate the final message later
                    // and pass it to the outer stream
                    stream.tool_calls.push(ToolCall {
                        id: id.clone(),
                        call_id: call_id.clone(),
                        function: ToolFunction {
                            name: name.clone(),
                            arguments: arguments.clone(),
                        },
                    });
                    if let Some(call_id) = call_id {
                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call_with_call_id(
                            id, call_id, name, arguments,
                        ))))
                    } else {
                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call(
                            id, name, arguments,
                        ))))
                    }
                }
                RawStreamingChoice::FinalResponse(response) => {
                    if stream
                        .final_response_yielded
                        .load(std::sync::atomic::Ordering::SeqCst)
                    {
                        // A final response was already forwarded; skip this duplicate
                        // by immediately polling for the next item.
                        stream.poll_next_unpin(cx)
                    } else {
                        // Set the final response field and return the next item in the stream
                        stream.response = Some(response.clone());
                        stream
                            .final_response_yielded
                            .store(true, std::sync::atomic::Ordering::SeqCst);
                        let final_response = StreamedAssistantContent::final_response(response);
                        Poll::Ready(Some(Ok(final_response)))
                    }
                }
            },
        }
    }
}
214
/// Trait for high-level streaming prompt interface
// NOTE(review): the `R` type parameter appears only in the trait bounds and is
// not referenced by `stream_prompt`'s signature — verify it is still needed.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: Send,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a simple prompt to the model
    fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<'_, M>;
}
225
/// Trait for high-level streaming chat interface
pub trait StreamingChat<M, R>: Send + Sync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: Send,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Stream a chat with history to the model
    ///
    /// `prompt` is the new user message; `chat_history` is the prior
    /// conversation it is appended to.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<'_, M>;
}
240
/// Trait for low-level streaming completion interface
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request
    ///
    /// Returns a [`CompletionRequestBuilder`] so the caller can further
    /// customise the request before actually streaming it.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
250
/// Wrapper that type-erases a `StreamingResult<R>`'s final-response payload,
/// re-emitting every chunk with the response type replaced by `()`.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
    pub(crate) inner: StreamingResult<R>,
}
254
255impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
256    type Item = Result<RawStreamingChoice<()>, CompletionError>;
257
258    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
259        let stream = self.get_mut();
260
261        match stream.inner.as_mut().poll_next(cx) {
262            Poll::Pending => Poll::Pending,
263            Poll::Ready(None) => Poll::Ready(None),
264            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
265            Poll::Ready(Some(Ok(chunk))) => match chunk {
266                RawStreamingChoice::FinalResponse(_) => {
267                    Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
268                }
269                RawStreamingChoice::Message(m) => {
270                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
271                }
272                RawStreamingChoice::Reasoning { id, reasoning } => {
273                    Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning { id, reasoning })))
274                }
275                RawStreamingChoice::ToolCall {
276                    id,
277                    name,
278                    arguments,
279                    call_id,
280                } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
281                    id,
282                    name,
283                    arguments,
284                    call_id,
285                }))),
286            },
287        }
288    }
289}
290
291/// helper function to stream a completion request to stdout
292pub async fn stream_to_stdout<M: CompletionModel>(
293    agent: &Agent<M>,
294    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
295) -> Result<(), std::io::Error> {
296    let mut is_reasoning = false;
297    print!("Response: ");
298    while let Some(chunk) = stream.next().await {
299        match chunk {
300            Ok(StreamedAssistantContent::Text(text)) => {
301                if is_reasoning {
302                    is_reasoning = false;
303                    println!("\n---\n");
304                }
305                print!("{}", text.text);
306                std::io::Write::flush(&mut std::io::stdout())?;
307            }
308            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
309                let res = agent
310                    .tools
311                    .call(
312                        &tool_call.function.name,
313                        tool_call.function.arguments.to_string(),
314                    )
315                    .await
316                    .map_err(|e| std::io::Error::other(e.to_string()))?;
317                println!("\nResult: {res}");
318            }
319            Ok(StreamedAssistantContent::Final(res)) => {
320                let json_res = serde_json::to_string_pretty(&res).unwrap();
321                println!();
322                tracing::info!("Final result: {json_res}");
323            }
324            Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
325                if !is_reasoning {
326                    is_reasoning = true;
327                    println!();
328                    println!("Thinking: ");
329                }
330                let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
331
332                print!("{reasoning}");
333                std::io::Write::flush(&mut std::io::stdout())?;
334            }
335            Err(e) => {
336                if e.to_string().contains("aborted") {
337                    println!("\nStream cancelled.");
338                    break;
339                }
340                eprintln!("Error: {e}");
341                break;
342            }
343        }
344    }
345
346    println!(); // New line after streaming completes
347
348    Ok(())
349}
350
// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal streaming-response payload used to exercise the stream plumbing.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a stream yielding three text chunks, with small delays so a
    /// cancellation point exists between chunks, followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // `Box::pin(stream)` is valid on every target: the `StreamingResult`
        // alias only differs in its `Send` bound, so no `#[cfg]` split is
        // needed here (the previous native/wasm32 branches were identical).
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// Consumes two chunks, cancels the stream, and asserts no further chunks
    /// are yielded afterwards.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }
}
439
/// Describes responses from a streamed provider response which is either text, a tool call or a final usage response.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A chunk of message text.
    Text(Text),
    /// A tool invocation requested by the model.
    ToolCall(ToolCall),
    /// A chunk of model reasoning.
    Reasoning(Reasoning),
    /// The provider's final (typed) response object.
    Final(R),
}
449
450impl<R> StreamedAssistantContent<R>
451where
452    R: Clone + Unpin,
453{
454    pub fn text(text: &str) -> Self {
455        Self::Text(Text {
456            text: text.to_string(),
457        })
458    }
459
460    /// Helper constructor to make creating assistant tool call content easier.
461    pub fn tool_call(
462        id: impl Into<String>,
463        name: impl Into<String>,
464        arguments: serde_json::Value,
465    ) -> Self {
466        Self::ToolCall(ToolCall {
467            id: id.into(),
468            call_id: None,
469            function: ToolFunction {
470                name: name.into(),
471                arguments,
472            },
473        })
474    }
475
476    pub fn tool_call_with_call_id(
477        id: impl Into<String>,
478        call_id: String,
479        name: impl Into<String>,
480        arguments: serde_json::Value,
481    ) -> Self {
482        Self::ToolCall(ToolCall {
483            id: id.into(),
484            call_id: Some(call_id),
485            function: ToolFunction {
486                name: name.into(),
487                arguments,
488            },
489        })
490    }
491
492    pub fn final_response(res: R) -> Self {
493        Self::Final(res)
494    }
495}