rig/
streaming.rs

//! This module provides functionality for working with streaming completion models.
//! It provides traits and types for generating streaming completion requests and
//! handling streaming completion responses.
//!
//! The main traits defined in this module are:
//! - [`StreamingPrompt`]: Defines a high-level streaming LLM one-shot prompt interface
//! - [`StreamingChat`]: Defines a high-level streaming LLM chat interface with history
//! - [`StreamingCompletion`]: Defines a low-level streaming LLM completion interface
//!
10
11use crate::agent::Agent;
12use crate::completion::{
13    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, Message,
14};
15use crate::message::{AssistantContent, ToolCall, ToolFunction};
16use crate::OneOrMany;
17use futures::{Stream, StreamExt};
18use std::boxed::Box;
19use std::future::Future;
20use std::pin::Pin;
21use std::task::{Context, Poll};
22
23/// Enum representing a streaming chunk from the model
24#[derive(Debug, Clone)]
25pub enum RawStreamingChoice<R: Clone> {
26    /// A text chunk from a message response
27    Message(String),
28
29    /// A tool call response chunk
30    ToolCall {
31        id: String,
32        name: String,
33        arguments: serde_json::Value,
34    },
35
36    /// The final response object, must be yielded if you want the
37    /// `response` field to be populated on the `StreamingCompletionResponse`
38    FinalResponse(R),
39}
40
41#[cfg(not(target_arch = "wasm32"))]
42pub type StreamingResult<R> =
43    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
44
45#[cfg(target_arch = "wasm32")]
46pub type StreamingResult<R> =
47    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
48
49/// The response from a streaming completion request;
50/// message and response are populated at the end of the
51/// `inner` stream.
52pub struct StreamingCompletionResponse<R: Clone + Unpin> {
53    pub(crate) inner: StreamingResult<R>,
54    text: String,
55    tool_calls: Vec<ToolCall>,
56    /// The final aggregated message from the stream
57    /// contains all text and tool calls generated
58    pub choice: OneOrMany<AssistantContent>,
59    /// The final response from the stream, may be `None`
60    /// if the provider didn't yield it during the stream
61    pub response: Option<R>,
62}
63
64impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
65    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
66        Self {
67            inner,
68            text: "".to_string(),
69            tool_calls: vec![],
70            choice: OneOrMany::one(AssistantContent::text("")),
71            response: None,
72        }
73    }
74}
75
76impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
77    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
78        CompletionResponse {
79            choice: value.choice,
80            raw_response: value.response,
81        }
82    }
83}
84
85impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
86    type Item = Result<AssistantContent, CompletionError>;
87
88    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89        let stream = self.get_mut();
90
91        match stream.inner.as_mut().poll_next(cx) {
92            Poll::Pending => Poll::Pending,
93            Poll::Ready(None) => {
94                // This is run at the end of the inner stream to collect all tokens into
95                // a single unified `Message`.
96                let mut choice = vec![];
97
98                stream.tool_calls.iter().for_each(|tc| {
99                    choice.push(AssistantContent::ToolCall(tc.clone()));
100                });
101
102                // This is required to ensure there's always at least one item in the content
103                if choice.is_empty() || !stream.text.is_empty() {
104                    choice.insert(0, AssistantContent::text(stream.text.clone()));
105                }
106
107                stream.choice = OneOrMany::many(choice)
108                    .expect("There should be at least one assistant message");
109
110                Poll::Ready(None)
111            }
112            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
113            Poll::Ready(Some(Ok(choice))) => match choice {
114                RawStreamingChoice::Message(text) => {
115                    // Forward the streaming tokens to the outer stream
116                    // and concat the text together
117                    stream.text = format!("{}{}", stream.text, text.clone());
118                    Poll::Ready(Some(Ok(AssistantContent::text(text))))
119                }
120                RawStreamingChoice::ToolCall {
121                    id,
122                    name,
123                    arguments,
124                } => {
125                    // Keep track of each tool call to aggregate the final message later
126                    // and pass it to the outer stream
127                    stream.tool_calls.push(ToolCall {
128                        id: id.clone(),
129                        function: ToolFunction {
130                            name: name.clone(),
131                            arguments: arguments.clone(),
132                        },
133                    });
134                    Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, arguments))))
135                }
136                RawStreamingChoice::FinalResponse(response) => {
137                    // Set the final response field and return the next item in the stream
138                    stream.response = Some(response);
139
140                    stream.poll_next_unpin(cx)
141                }
142            },
143        }
144    }
145}
146
147/// Trait for high-level streaming prompt interface
148pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
149    /// Stream a simple prompt to the model
150    fn stream_prompt(
151        &self,
152        prompt: impl Into<Message> + Send,
153    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
154}
155
156/// Trait for high-level streaming chat interface
157pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
158    /// Stream a chat with history to the model
159    fn stream_chat(
160        &self,
161        prompt: impl Into<Message> + Send,
162        chat_history: Vec<Message>,
163    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
164}
165
166/// Trait for low-level streaming completion interface
167pub trait StreamingCompletion<M: CompletionModel> {
168    /// Generate a streaming completion from a request
169    fn stream_completion(
170        &self,
171        prompt: impl Into<Message> + Send,
172        chat_history: Vec<Message>,
173    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
174}
175
176pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
177    pub(crate) inner: StreamingResult<R>,
178}
179
180impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
181    type Item = Result<RawStreamingChoice<()>, CompletionError>;
182
183    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
184        let stream = self.get_mut();
185
186        match stream.inner.as_mut().poll_next(cx) {
187            Poll::Pending => Poll::Pending,
188            Poll::Ready(None) => Poll::Ready(None),
189            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
190            Poll::Ready(Some(Ok(chunk))) => match chunk {
191                RawStreamingChoice::FinalResponse(_) => {
192                    Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
193                }
194                RawStreamingChoice::Message(m) => {
195                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
196                }
197                RawStreamingChoice::ToolCall {
198                    id,
199                    name,
200                    arguments,
201                } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
202                    id,
203                    name,
204                    arguments,
205                }))),
206            },
207        }
208    }
209}
210
211/// helper function to stream a completion request to stdout
212pub async fn stream_to_stdout<M: CompletionModel>(
213    agent: &Agent<M>,
214    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
215) -> Result<(), std::io::Error> {
216    print!("Response: ");
217    while let Some(chunk) = stream.next().await {
218        match chunk {
219            Ok(AssistantContent::Text(text)) => {
220                print!("{}", text.text);
221                std::io::Write::flush(&mut std::io::stdout())?;
222            }
223            Ok(AssistantContent::ToolCall(tool_call)) => {
224                let res = agent
225                    .tools
226                    .call(
227                        &tool_call.function.name,
228                        tool_call.function.arguments.to_string(),
229                    )
230                    .await
231                    .map_err(|e| std::io::Error::other(e.to_string()))?;
232                println!("\nResult: {res}");
233            }
234            Err(e) => {
235                eprintln!("Error: {e}");
236                break;
237            }
238        }
239    }
240
241    println!(); // New line after streaming completes
242
243    Ok(())
244}