rig/
streaming.rs

1//! This module provides functionality for working with streaming completion models.
2//! It provides traits and types for generating streaming completion requests and
3//! handling streaming completion responses.
4//!
5//! The main traits defined in this module are:
6//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
7//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
8//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
9//!
10
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::completion::{
14    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, Message, Usage,
15};
16use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction};
17use futures::stream::{AbortHandle, Abortable};
18use futures::{Stream, StreamExt};
19use serde::{Deserialize, Serialize};
20use std::boxed::Box;
21use std::future::Future;
22use std::pin::Pin;
23use std::sync::atomic::AtomicBool;
24use std::task::{Context, Poll};
25
26/// Enum representing a streaming chunk from the model
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R: Clone> {
    /// A text chunk from a message response
    Message(String),

    /// A tool call response chunk
    ToolCall {
        // Provider-assigned identifier for this tool call
        id: String,
        // Secondary call identifier; some providers supply one, others don't
        call_id: Option<String>,
        // Name of the tool/function being invoked
        name: String,
        // JSON value holding the tool call arguments
        arguments: serde_json::Value,
    },
    /// A reasoning chunk
    Reasoning { reasoning: String },

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),
}
46
/// A pinned, boxed stream of raw streaming choices (or completion errors).
/// Off-wasm the stream must be `Send` so it can cross thread boundaries.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Same as the non-wasm `StreamingResult`, but without the `Send` bound.
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
54
/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
    // The provider stream, wrapped so it can be aborted via `abort_handle`
    pub(crate) inner: Abortable<StreamingResult<R>>,
    // Handle used by `cancel()` to abort the inner stream
    pub(crate) abort_handle: AbortHandle,
    // Accumulated text of all `Message` chunks seen so far
    text: String,
    // Accumulated text of all `Reasoning` chunks seen so far
    reasoning: String,
    // Tool calls seen so far; folded into `choice` when the stream ends
    tool_calls: Vec<ToolCall>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
    /// Whether a `FinalResponse` chunk has already been forwarded
    /// (used to deduplicate repeated final responses)
    pub final_response_yielded: AtomicBool,
}
72
73impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
74    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
75        let (abort_handle, abort_registration) = AbortHandle::new_pair();
76        let abortable_stream = Abortable::new(inner, abort_registration);
77        Self {
78            inner: abortable_stream,
79            abort_handle,
80            reasoning: String::new(),
81            text: "".to_string(),
82            tool_calls: vec![],
83            choice: OneOrMany::one(AssistantContent::text("")),
84            response: None,
85            final_response_yielded: AtomicBool::new(false),
86        }
87    }
88
89    pub fn cancel(&self) {
90        self.abort_handle.abort();
91    }
92}
93
94impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
95    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
96        CompletionResponse {
97            choice: value.choice,
98            usage: Usage::new(), // Usage is not tracked in streaming responses
99            raw_response: value.response,
100        }
101    }
102}
103
104impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
105    type Item = Result<StreamedAssistantContent<R>, CompletionError>;
106
107    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108        let stream = self.get_mut();
109
110        match Pin::new(&mut stream.inner).poll_next(cx) {
111            Poll::Pending => Poll::Pending,
112            Poll::Ready(None) => {
113                // This is run at the end of the inner stream to collect all tokens into
114                // a single unified `Message`.
115                let mut choice = vec![];
116
117                stream.tool_calls.iter().for_each(|tc| {
118                    choice.push(AssistantContent::ToolCall(tc.clone()));
119                });
120
121                // This is required to ensure there's always at least one item in the content
122                if choice.is_empty() || !stream.text.is_empty() {
123                    choice.insert(0, AssistantContent::text(stream.text.clone()));
124                }
125
126                stream.choice = OneOrMany::many(choice)
127                    .expect("There should be at least one assistant message");
128
129                Poll::Ready(None)
130            }
131            Poll::Ready(Some(Err(err))) => {
132                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
133                {
134                    return Poll::Ready(None); // Treat cancellation as stream termination
135                }
136                Poll::Ready(Some(Err(err)))
137            }
138            Poll::Ready(Some(Ok(choice))) => match choice {
139                RawStreamingChoice::Message(text) => {
140                    // Forward the streaming tokens to the outer stream
141                    // and concat the text together
142                    stream.text = format!("{}{}", stream.text, text.clone());
143                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
144                }
145                RawStreamingChoice::Reasoning { reasoning } => {
146                    // Forward the streaming tokens to the outer stream
147                    // and concat the text together
148                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning.clone());
149                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
150                        reasoning,
151                    }))))
152                }
153                RawStreamingChoice::ToolCall {
154                    id,
155                    name,
156                    arguments,
157                    call_id,
158                } => {
159                    // Keep track of each tool call to aggregate the final message later
160                    // and pass it to the outer stream
161                    stream.tool_calls.push(ToolCall {
162                        id: id.clone(),
163                        call_id: call_id.clone(),
164                        function: ToolFunction {
165                            name: name.clone(),
166                            arguments: arguments.clone(),
167                        },
168                    });
169                    if let Some(call_id) = call_id {
170                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call_with_call_id(
171                            id, call_id, name, arguments,
172                        ))))
173                    } else {
174                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call(
175                            id, name, arguments,
176                        ))))
177                    }
178                }
179                RawStreamingChoice::FinalResponse(response) => {
180                    if stream
181                        .final_response_yielded
182                        .load(std::sync::atomic::Ordering::SeqCst)
183                    {
184                        stream.poll_next_unpin(cx)
185                    } else {
186                        // Set the final response field and return the next item in the stream
187                        stream.response = Some(response.clone());
188                        stream
189                            .final_response_yielded
190                            .store(true, std::sync::atomic::Ordering::SeqCst);
191                        let final_response = StreamedAssistantContent::final_response(response);
192                        Poll::Ready(Some(Ok(final_response)))
193                    }
194                }
195            },
196        }
197    }
198}
199
/// Trait for high-level streaming prompt interface
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
    /// Stream a simple one-shot prompt to the model.
    ///
    /// Resolves to a [`StreamingCompletionResponse`] that yields chunks as
    /// they arrive, or a [`CompletionError`] if the request could not be made.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
208
/// Trait for high-level streaming chat interface
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
    /// Stream a chat prompt, together with its preceding chat history, to the
    /// model.
    ///
    /// Resolves to a [`StreamingCompletionResponse`] that yields chunks as
    /// they arrive, or a [`CompletionError`] if the request could not be made.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
218
/// Trait for low-level streaming completion interface
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion request builder for the given prompt
    /// and chat history.
    ///
    /// Note that this resolves to a [`CompletionRequestBuilder`], not a
    /// response: the caller finishes configuring and sends the request to
    /// actually start the stream.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
228
/// Adapter stream over an inner [`StreamingResult<R>`] that forwards every
/// chunk unchanged while erasing the provider-specific final response type
/// to `()` (see its `Stream` impl below).
pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
    pub(crate) inner: StreamingResult<R>,
}
232
233impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
234    type Item = Result<RawStreamingChoice<()>, CompletionError>;
235
236    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237        let stream = self.get_mut();
238
239        match stream.inner.as_mut().poll_next(cx) {
240            Poll::Pending => Poll::Pending,
241            Poll::Ready(None) => Poll::Ready(None),
242            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
243            Poll::Ready(Some(Ok(chunk))) => match chunk {
244                RawStreamingChoice::FinalResponse(_) => {
245                    Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
246                }
247                RawStreamingChoice::Message(m) => {
248                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
249                }
250                RawStreamingChoice::Reasoning { reasoning } => {
251                    Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning { reasoning })))
252                }
253                RawStreamingChoice::ToolCall {
254                    id,
255                    name,
256                    arguments,
257                    call_id,
258                } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
259                    id,
260                    name,
261                    arguments,
262                    call_id,
263                }))),
264            },
265        }
266    }
267}
268
269/// helper function to stream a completion request to stdout
270pub async fn stream_to_stdout<M: CompletionModel>(
271    agent: &Agent<M>,
272    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
273) -> Result<(), std::io::Error> {
274    let mut is_reasoning = false;
275    print!("Response: ");
276    while let Some(chunk) = stream.next().await {
277        match chunk {
278            Ok(StreamedAssistantContent::Text(text)) => {
279                if is_reasoning {
280                    is_reasoning = false;
281                    println!("\n---\n");
282                }
283                print!("{}", text.text);
284                std::io::Write::flush(&mut std::io::stdout())?;
285            }
286            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
287                let res = agent
288                    .tools
289                    .call(
290                        &tool_call.function.name,
291                        tool_call.function.arguments.to_string(),
292                    )
293                    .await
294                    .map_err(|e| std::io::Error::other(e.to_string()))?;
295                println!("\nResult: {res}");
296            }
297            Ok(StreamedAssistantContent::Final(res)) => {
298                let json_res = serde_json::to_string_pretty(&res).unwrap();
299                println!();
300                tracing::info!("Final result: {json_res}");
301            }
302            Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning })) => {
303                if !is_reasoning {
304                    is_reasoning = true;
305                    println!();
306                    println!("Thinking: ");
307                }
308                print!("{reasoning}");
309                std::io::Write::flush(&mut std::io::stdout())?;
310            }
311            Err(e) => {
312                if e.to_string().contains("aborted") {
313                    println!("\nStream cancelled.");
314                    break;
315                }
316                eprintln!("Error: {e}");
317                break;
318            }
319        }
320    }
321
322    println!(); // New line after streaming completes
323
324    Ok(())
325}
326
// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    // Minimal stand-in for a provider's final response type.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    /// Builds a streaming response yielding three text chunks (100ms apart)
    /// followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // NOTE(review): both cfg arms are identical; the split only mirrors the
        // two `StreamingResult` alias definitions and could be collapsed.
        #[cfg(not(target_arch = "wasm32"))]
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);
        #[cfg(target_arch = "wasm32")]
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// Verifies that `cancel()` terminates the stream: once cancelled, the
    /// next poll must yield `None` rather than further chunks.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning })) => {
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel after two content chunks have been received.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }
}
406
/// Describes responses from a streamed provider response which is either text, a tool call or a final usage response.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A text chunk
    Text(Text),
    /// A tool call (with optional provider call id)
    ToolCall(ToolCall),
    /// A reasoning chunk
    Reasoning(Reasoning),
    /// The provider's final response object
    Final(R),
}
416
417impl<R> StreamedAssistantContent<R>
418where
419    R: Clone + Unpin,
420{
421    pub fn text(text: &str) -> Self {
422        Self::Text(Text {
423            text: text.to_string(),
424        })
425    }
426
427    /// Helper constructor to make creating assistant tool call content easier.
428    pub fn tool_call(
429        id: impl Into<String>,
430        name: impl Into<String>,
431        arguments: serde_json::Value,
432    ) -> Self {
433        Self::ToolCall(ToolCall {
434            id: id.into(),
435            call_id: None,
436            function: ToolFunction {
437                name: name.into(),
438                arguments,
439            },
440        })
441    }
442
443    pub fn tool_call_with_call_id(
444        id: impl Into<String>,
445        call_id: String,
446        name: impl Into<String>,
447        arguments: serde_json::Value,
448    ) -> Self {
449        Self::ToolCall(ToolCall {
450            id: id.into(),
451            call_id: Some(call_id),
452            function: ToolFunction {
453                name: name.into(),
454                arguments,
455            },
456        })
457    }
458
459    pub fn final_response(res: R) -> Self {
460        Self::Final(res)
461    }
462}