rig/
streaming.rs

//! This module provides functionality for working with streaming completion models:
//! traits and types for generating streaming completion requests and handling
//! streaming completion responses.
//!
//! The main traits defined in this module are:
//! - [StreamingPrompt]: a high-level streaming LLM one-shot prompt interface
//! - [StreamingChat]: a high-level streaming LLM chat interface with history
//! - [StreamingCompletion]: a low-level streaming LLM completion interface

11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::completion::{
14    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, Message,
15};
16use crate::message::{AssistantContent, ToolCall, ToolFunction};
17use futures::stream::{AbortHandle, Abortable};
18use futures::{Stream, StreamExt};
19use std::boxed::Box;
20use std::future::Future;
21use std::pin::Pin;
22use std::task::{Context, Poll};
23
24/// Enum representing a streaming chunk from the model
25#[derive(Debug, Clone)]
26pub enum RawStreamingChoice<R: Clone> {
27    /// A text chunk from a message response
28    Message(String),
29
30    /// A tool call response chunk
31    ToolCall {
32        id: String,
33        call_id: Option<String>,
34        name: String,
35        arguments: serde_json::Value,
36    },
37
38    /// The final response object, must be yielded if you want the
39    /// `response` field to be populated on the `StreamingCompletionResponse`
40    FinalResponse(R),
41}
42
43#[cfg(not(target_arch = "wasm32"))]
44pub type StreamingResult<R> =
45    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
46
47#[cfg(target_arch = "wasm32")]
48pub type StreamingResult<R> =
49    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
50
51/// The response from a streaming completion request;
52/// message and response are populated at the end of the
53/// `inner` stream.
54pub struct StreamingCompletionResponse<R: Clone + Unpin> {
55    pub(crate) inner: Abortable<StreamingResult<R>>,
56    pub(crate) abort_handle: AbortHandle,
57    text: String,
58    tool_calls: Vec<ToolCall>,
59    /// The final aggregated message from the stream
60    /// contains all text and tool calls generated
61    pub choice: OneOrMany<AssistantContent>,
62    /// The final response from the stream, may be `None`
63    /// if the provider didn't yield it during the stream
64    pub response: Option<R>,
65}
66
67impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
68    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
69        let (abort_handle, abort_registration) = AbortHandle::new_pair();
70        let abortable_stream = Abortable::new(inner, abort_registration);
71        Self {
72            inner: abortable_stream,
73            abort_handle,
74            text: "".to_string(),
75            tool_calls: vec![],
76            choice: OneOrMany::one(AssistantContent::text("")),
77            response: None,
78        }
79    }
80
81    pub fn cancel(&self) {
82        self.abort_handle.abort();
83    }
84}
85
86impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
87    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
88        CompletionResponse {
89            choice: value.choice,
90            raw_response: value.response,
91        }
92    }
93}
94
95impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
96    type Item = Result<AssistantContent, CompletionError>;
97
98    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
99        let stream = self.get_mut();
100
101        match Pin::new(&mut stream.inner).poll_next(cx) {
102            Poll::Pending => Poll::Pending,
103            Poll::Ready(None) => {
104                // This is run at the end of the inner stream to collect all tokens into
105                // a single unified `Message`.
106                let mut choice = vec![];
107
108                stream.tool_calls.iter().for_each(|tc| {
109                    choice.push(AssistantContent::ToolCall(tc.clone()));
110                });
111
112                // This is required to ensure there's always at least one item in the content
113                if choice.is_empty() || !stream.text.is_empty() {
114                    choice.insert(0, AssistantContent::text(stream.text.clone()));
115                }
116
117                stream.choice = OneOrMany::many(choice)
118                    .expect("There should be at least one assistant message");
119
120                Poll::Ready(None)
121            }
122            Poll::Ready(Some(Err(err))) => {
123                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
124                {
125                    return Poll::Ready(None); // Treat cancellation as stream termination
126                }
127                Poll::Ready(Some(Err(err)))
128            }
129            Poll::Ready(Some(Ok(choice))) => match choice {
130                RawStreamingChoice::Message(text) => {
131                    // Forward the streaming tokens to the outer stream
132                    // and concat the text together
133                    stream.text = format!("{}{}", stream.text, text.clone());
134                    Poll::Ready(Some(Ok(AssistantContent::text(text))))
135                }
136                RawStreamingChoice::ToolCall {
137                    id,
138                    name,
139                    arguments,
140                    call_id,
141                } => {
142                    // Keep track of each tool call to aggregate the final message later
143                    // and pass it to the outer stream
144                    stream.tool_calls.push(ToolCall {
145                        id: id.clone(),
146                        call_id,
147                        function: ToolFunction {
148                            name: name.clone(),
149                            arguments: arguments.clone(),
150                        },
151                    });
152                    Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, arguments))))
153                }
154                RawStreamingChoice::FinalResponse(response) => {
155                    // Set the final response field and return the next item in the stream
156                    stream.response = Some(response);
157
158                    stream.poll_next_unpin(cx)
159                }
160            },
161        }
162    }
163}
164
165/// Trait for high-level streaming prompt interface
166pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
167    /// Stream a simple prompt to the model
168    fn stream_prompt(
169        &self,
170        prompt: impl Into<Message> + Send,
171    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
172}
173
174/// Trait for high-level streaming chat interface
175pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
176    /// Stream a chat with history to the model
177    fn stream_chat(
178        &self,
179        prompt: impl Into<Message> + Send,
180        chat_history: Vec<Message>,
181    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
182}
183
184/// Trait for low-level streaming completion interface
185pub trait StreamingCompletion<M: CompletionModel> {
186    /// Generate a streaming completion from a request
187    fn stream_completion(
188        &self,
189        prompt: impl Into<Message> + Send,
190        chat_history: Vec<Message>,
191    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
192}
193
194pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
195    pub(crate) inner: StreamingResult<R>,
196}
197
198impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
199    type Item = Result<RawStreamingChoice<()>, CompletionError>;
200
201    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
202        let stream = self.get_mut();
203
204        match stream.inner.as_mut().poll_next(cx) {
205            Poll::Pending => Poll::Pending,
206            Poll::Ready(None) => Poll::Ready(None),
207            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
208            Poll::Ready(Some(Ok(chunk))) => match chunk {
209                RawStreamingChoice::FinalResponse(_) => {
210                    Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
211                }
212                RawStreamingChoice::Message(m) => {
213                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
214                }
215                RawStreamingChoice::ToolCall {
216                    id,
217                    name,
218                    arguments,
219                    call_id,
220                } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
221                    id,
222                    name,
223                    arguments,
224                    call_id,
225                }))),
226            },
227        }
228    }
229}
230
231/// helper function to stream a completion request to stdout
232pub async fn stream_to_stdout<M: CompletionModel>(
233    agent: &Agent<M>,
234    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
235) -> Result<(), std::io::Error> {
236    print!("Response: ");
237    while let Some(chunk) = stream.next().await {
238        match chunk {
239            Ok(AssistantContent::Text(text)) => {
240                print!("{}", text.text);
241                std::io::Write::flush(&mut std::io::stdout())?;
242            }
243            Ok(AssistantContent::ToolCall(tool_call)) => {
244                let res = agent
245                    .tools
246                    .call(
247                        &tool_call.function.name,
248                        tool_call.function.arguments.to_string(),
249                    )
250                    .await
251                    .map_err(|e| std::io::Error::other(e.to_string()))?;
252                println!("\nResult: {res}");
253            }
254            Err(e) => {
255                if e.to_string().contains("aborted") {
256                    println!("\nStream cancelled.");
257                    break;
258                }
259                eprintln!("Error: {e}");
260                break;
261            }
262        }
263    }
264
265    println!(); // New line after streaming completes
266
267    Ok(())
268}
269
// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    #[derive(Debug, Clone)]
    pub struct MockResponse {
        // Never read; it only gives the mock a provider-like payload.
        #[allow(dead_code)]
        token_count: u32,
    }

    /// Builds a response whose inner stream yields three text chunks 100 ms
    /// apart, followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // `StreamingResult` already resolves to the right (`Send` or not) trait
        // object per target, so a single coercion covers both cfg variants.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// Builds a response whose inner stream yields `chunks` with no delays.
    fn response_from_chunks(
        chunks: Vec<RawStreamingChoice<MockResponse>>,
    ) -> StreamingCompletionResponse<MockResponse> {
        let inner: StreamingResult<MockResponse> = Box::pin(futures::stream::iter(
            chunks.into_iter().map(Ok::<_, CompletionError>),
        ));

        StreamingCompletionResponse::stream(inner)
    }

    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(AssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(AssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        assert_eq!(chunk_count, 2, "expected exactly two chunks before cancelling");

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );

        // Cancelling stops before the final response is reached, but the text
        // received so far has still been accumulated.
        assert_eq!(stream.text, "hello 1hello 2");
        assert!(stream.response.is_none());
    }

    #[tokio::test]
    async fn test_stream_aggregates_text_and_final_response() {
        let mut stream = response_from_chunks(vec![
            RawStreamingChoice::Message("hello ".to_string()),
            RawStreamingChoice::Message("world".to_string()),
            RawStreamingChoice::FinalResponse(MockResponse { token_count: 2 }),
        ]);

        let mut received = Vec::new();
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(AssistantContent::Text(text)) => received.push(text.text.to_string()),
                other => panic!("unexpected chunk: {other:?}"),
            }
        }

        // `FinalResponse` is stored on the response rather than forwarded.
        assert_eq!(received, ["hello ", "world"]);
        assert_eq!(stream.text, "hello world");
        assert!(stream.tool_calls.is_empty());
        assert!(stream.response.is_some());
    }
}