// rig/streaming.rs

1//! This module provides functionality for working with streaming completion models.
2//! It provides traits and types for generating streaming completion requests and
3//! handling streaming completion responses.
4//!
5//! The main traits defined in this module are:
6//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
7//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
8//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
9//!
10
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::completion::{
14    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, Message, Usage,
15};
16use crate::message::{AssistantContent, ToolCall, ToolFunction};
17use futures::stream::{AbortHandle, Abortable};
18use futures::{Stream, StreamExt};
19use std::boxed::Box;
20use std::future::Future;
21use std::pin::Pin;
22use std::task::{Context, Poll};
23
/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R: Clone> {
    /// A text chunk from a message response
    Message(String),

    /// A tool call response chunk
    ToolCall {
        /// Provider-assigned identifier for this tool call
        id: String,
        /// Secondary call identifier; `None` when the provider does not supply one
        call_id: Option<String>,
        /// Name of the tool being invoked
        name: String,
        /// JSON-encoded arguments for the tool invocation
        arguments: serde_json::Value,
    },

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),
}
42
/// A pinned, boxed stream of raw streaming chunks (or completion errors).
/// Off-wasm the stream must be `Send` so it can cross thread boundaries.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// wasm32 variant: single-threaded target, so the `Send` bound is dropped.
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
50
/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
    /// The underlying provider stream, wrapped so it can be aborted via [`Self::cancel`]
    pub(crate) inner: Abortable<StreamingResult<R>>,
    /// Handle used by [`Self::cancel`] to abort `inner`
    pub(crate) abort_handle: AbortHandle,
    // Accumulated text from all `Message` chunks seen so far
    text: String,
    // Accumulated tool calls seen so far
    tool_calls: Vec<ToolCall>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
}
66
67impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
68    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
69        let (abort_handle, abort_registration) = AbortHandle::new_pair();
70        let abortable_stream = Abortable::new(inner, abort_registration);
71        Self {
72            inner: abortable_stream,
73            abort_handle,
74            text: "".to_string(),
75            tool_calls: vec![],
76            choice: OneOrMany::one(AssistantContent::text("")),
77            response: None,
78        }
79    }
80
81    pub fn cancel(&self) {
82        self.abort_handle.abort();
83    }
84}
85
/// Converts an aggregated streaming response into a regular
/// [`CompletionResponse`], carrying over the aggregated `choice` and the
/// (possibly absent) provider final response.
impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
        CompletionResponse {
            choice: value.choice,
            usage: Usage::new(), // Usage is not tracked in streaming responses
            // `None` if the provider never yielded a `FinalResponse` chunk
            raw_response: value.response,
        }
    }
}
95
96impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
97    type Item = Result<AssistantContent, CompletionError>;
98
99    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100        let stream = self.get_mut();
101
102        match Pin::new(&mut stream.inner).poll_next(cx) {
103            Poll::Pending => Poll::Pending,
104            Poll::Ready(None) => {
105                // This is run at the end of the inner stream to collect all tokens into
106                // a single unified `Message`.
107                let mut choice = vec![];
108
109                stream.tool_calls.iter().for_each(|tc| {
110                    choice.push(AssistantContent::ToolCall(tc.clone()));
111                });
112
113                // This is required to ensure there's always at least one item in the content
114                if choice.is_empty() || !stream.text.is_empty() {
115                    choice.insert(0, AssistantContent::text(stream.text.clone()));
116                }
117
118                stream.choice = OneOrMany::many(choice)
119                    .expect("There should be at least one assistant message");
120
121                Poll::Ready(None)
122            }
123            Poll::Ready(Some(Err(err))) => {
124                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
125                {
126                    return Poll::Ready(None); // Treat cancellation as stream termination
127                }
128                Poll::Ready(Some(Err(err)))
129            }
130            Poll::Ready(Some(Ok(choice))) => match choice {
131                RawStreamingChoice::Message(text) => {
132                    // Forward the streaming tokens to the outer stream
133                    // and concat the text together
134                    stream.text = format!("{}{}", stream.text, text.clone());
135                    Poll::Ready(Some(Ok(AssistantContent::text(text))))
136                }
137                RawStreamingChoice::ToolCall {
138                    id,
139                    name,
140                    arguments,
141                    call_id,
142                } => {
143                    // Keep track of each tool call to aggregate the final message later
144                    // and pass it to the outer stream
145                    stream.tool_calls.push(ToolCall {
146                        id: id.clone(),
147                        call_id,
148                        function: ToolFunction {
149                            name: name.clone(),
150                            arguments: arguments.clone(),
151                        },
152                    });
153                    Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, arguments))))
154                }
155                RawStreamingChoice::FinalResponse(response) => {
156                    // Set the final response field and return the next item in the stream
157                    stream.response = Some(response);
158
159                    stream.poll_next_unpin(cx)
160                }
161            },
162        }
163    }
164}
165
/// Trait for high-level streaming prompt interface
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
    /// Stream a simple one-shot prompt to the model, returning a
    /// [`StreamingCompletionResponse`] that yields chunks as they arrive.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
174
/// Trait for high-level streaming chat interface
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
    /// Stream a chat with history to the model; `chat_history` is the prior
    /// conversation and `prompt` the newest user message.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
184
/// Trait for low-level streaming completion interface
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion request builder from a prompt and
    /// history; the caller finalizes and sends the request itself.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
194
/// Adapter that erases the provider-specific final-response type `R`,
/// re-yielding the stream with `FinalResponse(())` instead of `FinalResponse(R)`.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
    /// The typed provider stream being adapted
    pub(crate) inner: StreamingResult<R>,
}
198
199impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
200    type Item = Result<RawStreamingChoice<()>, CompletionError>;
201
202    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
203        let stream = self.get_mut();
204
205        match stream.inner.as_mut().poll_next(cx) {
206            Poll::Pending => Poll::Pending,
207            Poll::Ready(None) => Poll::Ready(None),
208            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
209            Poll::Ready(Some(Ok(chunk))) => match chunk {
210                RawStreamingChoice::FinalResponse(_) => {
211                    Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
212                }
213                RawStreamingChoice::Message(m) => {
214                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
215                }
216                RawStreamingChoice::ToolCall {
217                    id,
218                    name,
219                    arguments,
220                    call_id,
221                } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
222                    id,
223                    name,
224                    arguments,
225                    call_id,
226                }))),
227            },
228        }
229    }
230}
231
232/// helper function to stream a completion request to stdout
233pub async fn stream_to_stdout<M: CompletionModel>(
234    agent: &Agent<M>,
235    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
236) -> Result<(), std::io::Error> {
237    print!("Response: ");
238    while let Some(chunk) = stream.next().await {
239        match chunk {
240            Ok(AssistantContent::Text(text)) => {
241                print!("{}", text.text);
242                std::io::Write::flush(&mut std::io::stdout())?;
243            }
244            Ok(AssistantContent::ToolCall(tool_call)) => {
245                let res = agent
246                    .tools
247                    .call(
248                        &tool_call.function.name,
249                        tool_call.function.arguments.to_string(),
250                    )
251                    .await
252                    .map_err(|e| std::io::Error::other(e.to_string()))?;
253                println!("\nResult: {res}");
254            }
255            Err(e) => {
256                if e.to_string().contains("aborted") {
257                    println!("\nStream cancelled.");
258                    break;
259                }
260                eprintln!("Error: {e}");
261                break;
262            }
263        }
264    }
265
266    println!(); // New line after streaming completes
267
268    Ok(())
269}
270
// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Dummy final-response payload used to exercise `FinalResponse` handling.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    /// Builds a stream yielding three delayed text chunks followed by a
    /// final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // The previous wasm/non-wasm cfg branches were byte-identical
        // (`Box::pin(stream)` in both); a single binding works on every target.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// Verifies that `cancel` terminates the stream: after aborting mid-stream,
    /// the next poll must yield `None`.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(AssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(AssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }
}
342}