// rig/agent/prompt_request/streaming.rs

1use crate::{
2    OneOrMany,
3    agent::prompt_request::PromptHook,
4    completion::GetTokenUsage,
5    message::{AssistantContent, Reasoning, ToolResultContent, UserContent},
6    streaming::{StreamedAssistantContent, StreamingCompletion},
7};
8use futures::{Stream, StreamExt};
9use serde::{Deserialize, Serialize};
10use std::{pin::Pin, sync::Arc};
11use tokio::sync::RwLock;
12
13use crate::{
14    agent::Agent,
15    completion::{CompletionError, CompletionModel, PromptError},
16    message::{Message, Text},
17    tool::ToolSetError,
18};
19
/// Boxed stream of [`MultiTurnStreamItem`]s (or [`StreamingError`]s) returned
/// by a streaming prompt request. On non-wasm targets the stream is `Send` so
/// it can be driven from another thread.
#[cfg(not(target_arch = "wasm32"))]
type StreamingResult =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>> + Send>>;

/// Same alias without the `Send` bound for wasm32 targets, where the futures
/// involved are not `Send`.
#[cfg(target_arch = "wasm32")]
type StreamingResult = Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>>>>;
26
/// A single item yielded by the multi-turn stream: either an incremental
/// chunk of assistant text or the final aggregated response.
/// Internally tagged on serialization via a camelCase `type` field.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MultiTurnStreamItem {
    /// A streamed fragment of the assistant's text output.
    Text(Text),
    /// The final response; emitted once, just before the stream ends successfully.
    FinalResponse(FinalResponse),
}
33
/// The terminal payload of a streaming prompt request: the last text response
/// produced by the model plus token usage aggregated over every completion
/// round of the multi-turn loop.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    // Full text of the model's last (non-tool) response.
    response: String,
    // Sum of the usage reported by each streamed completion in this request.
    aggregated_usage: crate::completion::Usage,
}
40
41impl FinalResponse {
42    pub fn empty() -> Self {
43        Self {
44            response: String::new(),
45            aggregated_usage: crate::completion::Usage::new(),
46        }
47    }
48
49    pub fn response(&self) -> &str {
50        &self.response
51    }
52
53    pub fn usage(&self) -> crate::completion::Usage {
54        self.aggregated_usage
55    }
56}
57
58impl MultiTurnStreamItem {
59    pub(crate) fn text(text: &str) -> Self {
60        Self::Text(Text {
61            text: text.to_string(),
62        })
63    }
64
65    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
66        Self::FinalResponse(FinalResponse {
67            response: response.to_string(),
68            aggregated_usage,
69        })
70    }
71}
72
/// Errors that can surface while driving a streaming prompt request.
/// Each variant wraps, and converts from, the corresponding underlying error.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// Failure reported by the completion model/provider.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// Prompt-level failure (e.g. exceeding the maximum multi-turn depth).
    #[error("PromptError: {0}")]
    Prompt(#[from] PromptError),
    /// Failure while invoking a tool from the agent's tool set.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
82
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
/// attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call tools
/// back to back.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: Arc<Agent<M>>,
    /// Optional per-request hook for events (completion calls, tool calls/results)
    hook: Option<P>,
}
108
109impl<M, P> StreamingPromptRequest<M, P>
110where
111    M: CompletionModel + 'static,
112    <M as CompletionModel>::StreamingResponse: Send + GetTokenUsage,
113    P: PromptHook<M>,
114{
115    /// Create a new PromptRequest with the given prompt and model
116    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
117        Self {
118            prompt: prompt.into(),
119            chat_history: None,
120            max_depth: 0,
121            agent,
122            hook: None,
123        }
124    }
125
126    /// Set the maximum depth for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
127    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxDepthError`].
128    pub fn multi_turn(mut self, depth: usize) -> Self {
129        self.max_depth = depth;
130        self
131    }
132
133    /// Add chat history to the prompt request
134    pub fn with_history(mut self, history: Vec<Message>) -> Self {
135        self.chat_history = Some(history);
136        self
137    }
138
139    /// Attach a per-request hook for tool call events
140    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
141    where
142        P2: PromptHook<M>,
143    {
144        StreamingPromptRequest {
145            prompt: self.prompt,
146            chat_history: self.chat_history,
147            max_depth: self.max_depth,
148            agent: self.agent,
149            hook: Some(hook),
150        }
151    }
152
153    #[cfg_attr(feature = "worker", worker::send)]
154    async fn send(self) -> StreamingResult {
155        let agent_name = self.agent.name_owned();
156
157        #[tracing::instrument(skip_all, fields(agent_name = agent_name))]
158        fn inner<M, P>(req: StreamingPromptRequest<M, P>, agent_name: String) -> StreamingResult
159        where
160            M: CompletionModel + 'static,
161            <M as CompletionModel>::StreamingResponse: Send,
162            P: PromptHook<M> + 'static,
163        {
164            let prompt = req.prompt;
165            let agent = req.agent;
166
167            let chat_history = if let Some(mut history) = req.chat_history {
168                history.push(prompt.clone());
169                Arc::new(RwLock::new(history))
170            } else {
171                Arc::new(RwLock::new(vec![prompt.clone()]))
172            };
173
174            let mut current_max_depth = 0;
175            let mut last_prompt_error = String::new();
176
177            let mut last_text_response = String::new();
178            let mut is_text_response = false;
179            let mut max_depth_reached = false;
180
181            let mut aggregated_usage = crate::completion::Usage::new();
182
183            Box::pin(async_stream::stream! {
184                let mut current_prompt = prompt.clone();
185                let mut did_call_tool = false;
186
187                'outer: loop {
188                    if current_max_depth > req.max_depth + 1 {
189                        last_prompt_error = current_prompt.rag_text().unwrap_or_default();
190                        max_depth_reached = true;
191                        break;
192                    }
193
194                    current_max_depth += 1;
195
196                    if req.max_depth > 1 {
197                        tracing::info!(
198                            "Current conversation depth: {}/{}",
199                            current_max_depth,
200                            req.max_depth
201                        );
202                    }
203
204                    if let Some(ref hook) = req.hook {
205                        let reader = chat_history.read().await;
206                        let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
207                        let chat_history_except_last = reader[..reader.len() - 1].to_vec();
208
209                        hook.on_completion_call(&prompt, &chat_history_except_last)
210                            .await;
211                    }
212
213
214                    let mut stream = agent
215                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
216                        .await?
217                        .stream()
218                        .await?;
219
220                    chat_history.write().await.push(current_prompt.clone());
221
222                    let mut tool_calls = vec![];
223                    let mut tool_results = vec![];
224
225                    while let Some(content) = stream.next().await {
226                        match content {
227                            Ok(StreamedAssistantContent::Text(text)) => {
228                                if !is_text_response {
229                                    last_text_response = String::new();
230                                    is_text_response = true;
231                                }
232                                last_text_response.push_str(&text.text);
233                                yield Ok(MultiTurnStreamItem::text(&text.text));
234                                did_call_tool = false;
235                            },
236                            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
237                                if let Some(ref hook) = req.hook {
238                                    hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string()).await;
239                                }
240                                let tool_result =
241                                    agent.tools.call(&tool_call.function.name, tool_call.function.arguments.to_string()).await?;
242
243                                if let Some(ref hook) = req.hook {
244                                    hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string())
245                                        .await;
246                                }
247                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
248
249                                tool_calls.push(tool_call_msg);
250                                tool_results.push((tool_call.id, tool_call.call_id, tool_result));
251
252                                did_call_tool = true;
253                                // break;
254                            },
255                            Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })) => {
256                                chat_history.write().await.push(rig::message::Message::Assistant {
257                                    id: None,
258                                    content: OneOrMany::one(AssistantContent::Reasoning(Reasoning {
259                                        reasoning: reasoning.clone(), id
260                                    }))
261                                });
262                                let text = reasoning.into_iter().collect::<Vec<String>>().join("");
263                                yield Ok(MultiTurnStreamItem::text(&text));
264                                did_call_tool = false;
265                            },
266                            Ok(StreamedAssistantContent::Final(final_resp)) => {
267                                if is_text_response {
268                                    if let Some(ref hook) = req.hook {
269                                        hook.on_stream_completion_response_finish(&prompt, &final_resp).await;
270                                    }
271                                    yield Ok(MultiTurnStreamItem::text("\n"));
272                                    is_text_response = false;
273                                }
274                                if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
275                                // Do nothing here, since at the moment the final generic is actually unreachable.
276                                // We need to implement a trait that aggregates token usage.
277                                // TODO: Add a way to aggregate token responses from the generic variant
278                            }
279                            Err(e) => {
280                                yield Err(e.into());
281                                break 'outer;
282                            }
283                        }
284                    }
285
286                    // Add (parallel) tool calls to chat history
287                    if !tool_calls.is_empty() {
288                        chat_history.write().await.push(Message::Assistant {
289                            id: None,
290                            content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
291                        });
292                    }
293
294                    // Add tool results to chat history
295                    for (id, call_id, tool_result) in tool_results {
296                        if let Some(call_id) = call_id {
297                            chat_history.write().await.push(Message::User {
298                                content: OneOrMany::one(UserContent::tool_result_with_call_id(
299                                    &id,
300                                    call_id.clone(),
301                                    OneOrMany::one(ToolResultContent::text(&tool_result)),
302                                )),
303                            });
304                        } else {
305                            chat_history.write().await.push(Message::User {
306                                content: OneOrMany::one(UserContent::tool_result(
307                                    &id,
308                                    OneOrMany::one(ToolResultContent::text(&tool_result)),
309                                )),
310                            });
311                        }
312
313                    }
314
315                    // Set the current prompt to the last message in the chat history
316                    current_prompt = match chat_history.write().await.pop() {
317                        Some(prompt) => prompt,
318                        None => unreachable!("Chat history should never be empty at this point"),
319                    };
320
321                    if !did_call_tool {
322                        yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
323                        break;
324                    }
325                }
326
327                    if max_depth_reached {
328                        yield Err(PromptError::MaxDepthError {
329                            max_depth: req.max_depth,
330                            chat_history: (*chat_history.read().await).clone(),
331                            prompt: last_prompt_error.into(),
332                        }.into());
333                    }
334
335            })
336        }
337
338        inner(self, agent_name)
339    }
340}
341
342impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
343where
344    M: CompletionModel + 'static,
345    <M as CompletionModel>::StreamingResponse: Send,
346    P: PromptHook<M> + 'static,
347{
348    type Output = StreamingResult; // what `.await` returns
349    type IntoFuture = Pin<Box<dyn futures::Future<Output = Self::Output> + Send>>;
350
351    fn into_future(self) -> Self::IntoFuture {
352        // Wrap send() in a future, because send() returns a stream immediately
353        Box::pin(async move { self.send().await })
354    }
355}
356
357/// helper function to stream a completion request to stdout
358pub async fn stream_to_stdout(
359    stream: &mut StreamingResult,
360) -> Result<FinalResponse, std::io::Error> {
361    let mut final_res = FinalResponse::empty();
362    print!("Response: ");
363    while let Some(content) = stream.next().await {
364        match content {
365            Ok(MultiTurnStreamItem::Text(Text { text })) => {
366                print!("{text}");
367                std::io::Write::flush(&mut std::io::stdout())?;
368            }
369            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
370                final_res = res;
371            }
372            Err(err) => {
373                eprintln!("Error: {err}");
374            }
375        }
376    }
377
378    Ok(final_res)
379}