// rig/agent/prompt_request/streaming.rs

1use crate::{
2    OneOrMany,
3    agent::prompt_request::PromptHook,
4    completion::GetTokenUsage,
5    message::{AssistantContent, Reasoning, ToolResultContent, UserContent},
6    streaming::{StreamedAssistantContent, StreamingCompletion},
7};
8use futures::{Stream, StreamExt};
9use serde::{Deserialize, Serialize};
10use std::{pin::Pin, sync::Arc};
11use tokio::sync::RwLock;
12
13use crate::{
14    agent::Agent,
15    completion::{CompletionError, CompletionModel, PromptError},
16    message::{Message, Text},
17    tool::ToolSetError,
18};
19
/// Boxed stream of multi-turn items produced by a streaming prompt request.
///
/// On native targets the stream must be `Send` so it can be driven from a
/// multi-threaded executor.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

/// Boxed stream of multi-turn items produced by a streaming prompt request.
///
/// wasm32 is single-threaded, so no `Send` bound is required there.
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
27
/// A single item yielded by a multi-turn streaming prompt request.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// An incremental piece of assistant output: text, a tool call, reasoning,
    /// or the provider's final per-call streaming response payload.
    StreamItem(StreamedAssistantContent<R>),
    /// Emitted exactly once, after the last turn, with the aggregated final
    /// text and the token usage summed over all completion calls.
    FinalResponse(FinalResponse),
}
35
/// The final aggregated result of a multi-turn streaming prompt request.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    // The full text of the agent's last text response.
    response: String,
    // Token usage summed across every completion call made during the request.
    aggregated_usage: crate::completion::Usage,
}
42
43impl FinalResponse {
44    pub fn empty() -> Self {
45        Self {
46            response: String::new(),
47            aggregated_usage: crate::completion::Usage::new(),
48        }
49    }
50
51    pub fn response(&self) -> &str {
52        &self.response
53    }
54
55    pub fn usage(&self) -> crate::completion::Usage {
56        self.aggregated_usage
57    }
58}
59
60impl<R> MultiTurnStreamItem<R> {
61    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
62        Self::StreamItem(item)
63    }
64
65    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
66        Self::FinalResponse(FinalResponse {
67            response: response.to_string(),
68            aggregated_usage,
69        })
70    }
71}
72
/// Errors that can surface while driving a streaming prompt request.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The underlying completion request failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// Prompting failed (e.g. the multi-turn depth limit was exceeded).
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool invocation failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
82
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
/// attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call tools
/// back to back.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations
    // NOTE(review): the struct doc says 0 still permits one tool round-trip,
    // while this comment originally said "0 means no multi-turn" — the loop in
    // `send` compares against `max_depth + 1`, which matches the former.
    max_depth: usize,
    /// The agent to use for execution
    agent: Arc<Agent<M>>,
    /// Optional per-request hook for events
    hook: Option<P>,
}
108
109impl<M, P> StreamingPromptRequest<M, P>
110where
111    M: CompletionModel + 'static,
112    <M as CompletionModel>::StreamingResponse: Send + GetTokenUsage,
113    P: PromptHook<M>,
114{
115    /// Create a new PromptRequest with the given prompt and model
116    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
117        Self {
118            prompt: prompt.into(),
119            chat_history: None,
120            max_depth: 0,
121            agent,
122            hook: None,
123        }
124    }
125
126    /// Set the maximum depth for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
127    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxDepthError`].
128    pub fn multi_turn(mut self, depth: usize) -> Self {
129        self.max_depth = depth;
130        self
131    }
132
133    /// Add chat history to the prompt request
134    pub fn with_history(mut self, history: Vec<Message>) -> Self {
135        self.chat_history = Some(history);
136        self
137    }
138
139    /// Attach a per-request hook for tool call events
140    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
141    where
142        P2: PromptHook<M>,
143    {
144        StreamingPromptRequest {
145            prompt: self.prompt,
146            chat_history: self.chat_history,
147            max_depth: self.max_depth,
148            agent: self.agent,
149            hook: Some(hook),
150        }
151    }
152
153    #[cfg_attr(feature = "worker", worker::send)]
154    async fn send(self) -> StreamingResult<M::StreamingResponse> {
155        let agent_name = self.agent.name_owned();
156
157        #[tracing::instrument(skip_all, fields(agent_name = agent_name))]
158        fn inner<M, P>(
159            req: StreamingPromptRequest<M, P>,
160            agent_name: String,
161        ) -> StreamingResult<M::StreamingResponse>
162        where
163            M: CompletionModel + 'static,
164            <M as CompletionModel>::StreamingResponse: Send,
165            P: PromptHook<M> + 'static,
166        {
167            let prompt = req.prompt;
168            let agent = req.agent;
169
170            let chat_history = if let Some(mut history) = req.chat_history {
171                history.push(prompt.clone());
172                Arc::new(RwLock::new(history))
173            } else {
174                Arc::new(RwLock::new(vec![prompt.clone()]))
175            };
176
177            let mut current_max_depth = 0;
178            let mut last_prompt_error = String::new();
179
180            let mut last_text_response = String::new();
181            let mut is_text_response = false;
182            let mut max_depth_reached = false;
183
184            let mut aggregated_usage = crate::completion::Usage::new();
185
186            Box::pin(async_stream::stream! {
187                let mut current_prompt = prompt.clone();
188                let mut did_call_tool = false;
189
190                'outer: loop {
191                    if current_max_depth > req.max_depth + 1 {
192                        last_prompt_error = current_prompt.rag_text().unwrap_or_default();
193                        max_depth_reached = true;
194                        break;
195                    }
196
197                    current_max_depth += 1;
198
199                    if req.max_depth > 1 {
200                        tracing::info!(
201                            "Current conversation depth: {}/{}",
202                            current_max_depth,
203                            req.max_depth
204                        );
205                    }
206
207                    if let Some(ref hook) = req.hook {
208                        let reader = chat_history.read().await;
209                        let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
210                        let chat_history_except_last = reader[..reader.len() - 1].to_vec();
211
212                        hook.on_completion_call(&prompt, &chat_history_except_last)
213                            .await;
214                    }
215
216
217                    let mut stream = agent
218                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
219                        .await?
220                        .stream()
221                        .await?;
222
223                    chat_history.write().await.push(current_prompt.clone());
224
225                    let mut tool_calls = vec![];
226                    let mut tool_results = vec![];
227
228                    while let Some(content) = stream.next().await {
229                        match content {
230                            Ok(StreamedAssistantContent::Text(text)) => {
231                                if !is_text_response {
232                                    last_text_response = String::new();
233                                    is_text_response = true;
234                                }
235                                last_text_response.push_str(&text.text);
236                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
237                                did_call_tool = false;
238                            },
239                            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
240                                if let Some(ref hook) = req.hook {
241                                    hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string()).await;
242                                }
243                                let tool_result =
244                                    agent.tools.call(&tool_call.function.name, tool_call.function.arguments.to_string()).await?;
245
246                                if let Some(ref hook) = req.hook {
247                                    hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string())
248                                        .await;
249                                }
250                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
251
252                                tool_calls.push(tool_call_msg);
253                                tool_results.push((tool_call.id, tool_call.call_id, tool_result));
254
255                                did_call_tool = true;
256                                // break;
257                            },
258                            Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })) => {
259                                chat_history.write().await.push(rig::message::Message::Assistant {
260                                    id: None,
261                                    content: OneOrMany::one(AssistantContent::Reasoning(Reasoning {
262                                        reasoning: reasoning.clone(), id: id.clone()
263                                    }))
264                                });
265                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })));
266                                did_call_tool = false;
267                            },
268                            Ok(StreamedAssistantContent::Final(final_resp)) => {
269                                if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
270                                if is_text_response {
271                                    if let Some(ref hook) = req.hook {
272                                        hook.on_stream_completion_response_finish(&prompt, &final_resp).await;
273                                    }
274                                    yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
275                                    is_text_response = false;
276                                }
277                            }
278                            Err(e) => {
279                                yield Err(e.into());
280                                break 'outer;
281                            }
282                        }
283                    }
284
285                    // Add (parallel) tool calls to chat history
286                    if !tool_calls.is_empty() {
287                        chat_history.write().await.push(Message::Assistant {
288                            id: None,
289                            content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
290                        });
291                    }
292
293                    // Add tool results to chat history
294                    for (id, call_id, tool_result) in tool_results {
295                        if let Some(call_id) = call_id {
296                            chat_history.write().await.push(Message::User {
297                                content: OneOrMany::one(UserContent::tool_result_with_call_id(
298                                    &id,
299                                    call_id.clone(),
300                                    OneOrMany::one(ToolResultContent::text(&tool_result)),
301                                )),
302                            });
303                        } else {
304                            chat_history.write().await.push(Message::User {
305                                content: OneOrMany::one(UserContent::tool_result(
306                                    &id,
307                                    OneOrMany::one(ToolResultContent::text(&tool_result)),
308                                )),
309                            });
310                        }
311
312                    }
313
314                    // Set the current prompt to the last message in the chat history
315                    current_prompt = match chat_history.write().await.pop() {
316                        Some(prompt) => prompt,
317                        None => unreachable!("Chat history should never be empty at this point"),
318                    };
319
320                    if !did_call_tool {
321                        yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
322                        break;
323                    }
324                }
325
326                    if max_depth_reached {
327                        yield Err(Box::new(PromptError::MaxDepthError {
328                            max_depth: req.max_depth,
329                            chat_history: Box::new((*chat_history.read().await).clone()),
330                            prompt: last_prompt_error.into(),
331                        }).into());
332                    }
333
334            })
335        }
336
337        inner(self, agent_name)
338    }
339}
340
341impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
342where
343    M: CompletionModel + 'static,
344    <M as CompletionModel>::StreamingResponse: Send,
345    P: PromptHook<M> + 'static,
346{
347    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
348    type IntoFuture = Pin<Box<dyn futures::Future<Output = Self::Output> + Send>>;
349
350    fn into_future(self) -> Self::IntoFuture {
351        // Wrap send() in a future, because send() returns a stream immediately
352        Box::pin(async move { self.send().await })
353    }
354}
355
356/// helper function to stream a completion request to stdout
357pub async fn stream_to_stdout<R>(
358    stream: &mut StreamingResult<R>,
359) -> Result<FinalResponse, std::io::Error> {
360    let mut final_res = FinalResponse::empty();
361    print!("Response: ");
362    while let Some(content) = stream.next().await {
363        match content {
364            Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text { text }))) => {
365                print!("{text}");
366                std::io::Write::flush(&mut std::io::stdout()).unwrap();
367            }
368            Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Reasoning(
369                Reasoning { reasoning, .. },
370            ))) => {
371                let reasoning = reasoning.join("\n");
372                print!("{reasoning}");
373                std::io::Write::flush(&mut std::io::stdout()).unwrap();
374            }
375            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
376                final_res = res;
377            }
378            Err(err) => {
379                eprintln!("Error: {err}");
380            }
381            _ => {}
382        }
383    }
384
385    Ok(final_res)
386}