rig/agent/prompt_request/streaming.rs

1use crate::{
2    OneOrMany,
3    completion::GetTokenUsage,
4    message::{AssistantContent, ToolResultContent, UserContent},
5    streaming::{StreamedAssistantContent, StreamingCompletion},
6};
7use futures::{Stream, StreamExt};
8use serde::{Deserialize, Serialize};
9use std::{pin::Pin, sync::Arc};
10use tokio::sync::RwLock;
11
12use crate::{
13    agent::Agent,
14    completion::{CompletionError, CompletionModel, PromptError},
15    message::{Message, Text},
16    tool::ToolSetError,
17};
18
/// Boxed stream of items (or errors) produced by a streaming prompt request.
type StreamingResult<'a> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>> + 'a>>;
21
/// A single item yielded by a multi-turn prompt stream.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MultiTurnStreamItem {
    /// An incremental chunk of streamed assistant text.
    Text(Text),
    /// Terminal item: the final response text plus aggregated token usage.
    FinalResponse(FinalResponse),
}
28
/// The fully-aggregated result of a streaming prompt request.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    /// Complete text of the last assistant response.
    response: String,
    /// Token usage summed across every completion round of the conversation.
    aggregated_usage: crate::completion::Usage,
}
35
36impl FinalResponse {
37    pub fn empty() -> Self {
38        Self {
39            response: String::new(),
40            aggregated_usage: crate::completion::Usage::new(),
41        }
42    }
43
44    pub fn response(&self) -> &str {
45        &self.response
46    }
47
48    pub fn usage(&self) -> crate::completion::Usage {
49        self.aggregated_usage
50    }
51}
52
53impl MultiTurnStreamItem {
54    pub(crate) fn text(text: &str) -> Self {
55        Self::Text(Text {
56            text: text.to_string(),
57        })
58    }
59
60    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
61        Self::FinalResponse(FinalResponse {
62            response: response.to_string(),
63            aggregated_usage,
64        })
65    }
66}
67
/// Errors that can surface while driving a streaming prompt request.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The underlying completion request failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// The prompt layer rejected the request (e.g. max depth exceeded).
    #[error("PromptError: {0}")]
    Prompt(#[from] PromptError),
    /// A tool invocation failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
77
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
/// attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call tools
/// back to back.
pub struct StreamingPromptRequest<'a, M>
where
    M: CompletionModel,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
    #[cfg_attr(docsrs, doc(cfg(feature = "hooks")))]
    #[cfg(feature = "hooks")]
    /// Optional per-request hook invoked around completion calls and tool calls
    hook: Option<&'a dyn crate::agent::PromptHook<M>>,
}
104
impl<'a, M> StreamingPromptRequest<'a, M>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: Send + GetTokenUsage,
{
    /// Create a new PromptRequest with the given prompt and model
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            // 0 allows only a single tool round-trip; see `multi_turn`.
            max_depth: 0,
            agent,
            #[cfg_attr(docsrs, doc(cfg(feature = "hooks")))]
            #[cfg(feature = "hooks")]
            hook: None,
        }
    }

    /// Set the maximum depth for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxDepthError`].
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Add chat history to the prompt request
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Attach a per-request hook for tool call events
    #[cfg_attr(docsrs, doc(cfg(feature = "hooks")))]
    #[cfg(feature = "hooks")]
    pub fn with_hook(
        mut self,
        hook: &'a dyn crate::agent::PromptHook<M>,
    ) -> StreamingPromptRequest<'a, M> {
        self.hook = Some(hook);
        self
    }

    /// Consume the request and return the stream of [`MultiTurnStreamItem`]s.
    ///
    /// The stream yields text chunks as they arrive, executes tool calls between
    /// completion rounds, and terminates with either a
    /// [`MultiTurnStreamItem::FinalResponse`] or a `MaxDepthError` item.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn send(self) -> StreamingResult<'a> {
        let agent_name = self.agent.name_owned();

        // Inner fn exists so the tracing span can capture the agent name as a field.
        #[tracing::instrument(skip_all, fields(agent_name = agent_name))]
        fn inner<'a, M>(
            req: StreamingPromptRequest<'a, M>,
            agent_name: String,
        ) -> StreamingResult<'a>
        where
            M: CompletionModel + 'static,
            // NOTE(review): the outer impl requires `Send + GetTokenUsage`, but this
            // bound omits `GetTokenUsage` even though `final_resp.token_usage()` below
            // appears to rely on it — confirm this compiles as written.
            <M as CompletionModel>::StreamingResponse: Send,
        {
            let prompt = req.prompt;
            let agent = req.agent;

            // Shared mutable chat history; the pending prompt is appended up front.
            let chat_history = if let Some(mut history) = req.chat_history {
                history.push(prompt.clone());
                Arc::new(RwLock::new(history))
            } else {
                Arc::new(RwLock::new(vec![prompt.clone()]))
            };

            let mut current_max_depth = 0;
            let mut last_prompt_error = String::new();

            let mut last_text_response = String::new();
            let mut is_text_response = false;
            let mut max_depth_reached = false;

            // Token usage summed across every completion round.
            let mut aggregated_usage = crate::completion::Usage::new();

            Box::pin(async_stream::stream! {
                let mut current_prompt = prompt.clone();
                let mut did_call_tool = false;

                // Each iteration is one completion round (model call + any tool calls).
                'outer: loop {
                    // `max_depth` counts tool round-trips; the `+ 1` leaves slack for
                    // the final text-only turn.
                    if current_max_depth > req.max_depth + 1 {
                        last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                        max_depth_reached = true;
                        break;
                    }

                    current_max_depth += 1;

                    if req.max_depth > 1 {
                        tracing::info!(
                            "Current conversation depth: {}/{}",
                            current_max_depth,
                            req.max_depth
                        );
                    }

                    // Fire the pre-completion hook with the last message split off
                    // from the rest of the history.
                    #[cfg(feature = "hooks")]
                    if let Some(hook) = req.hook.as_ref() {
                        let reader = chat_history.read().await;
                        let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
                        let chat_history_except_last = reader[..reader.len() - 1].to_vec();

                        hook.on_completion_call(&prompt, &chat_history_except_last)
                            .await;
                    }


                    let mut stream = agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream()
                        .await?;

                    // Record the prompt only after the completion request succeeds.
                    chat_history.write().await.push(current_prompt.clone());

                    let mut tool_calls = vec![];
                    let mut tool_results = vec![];

                    // Drain this round's stream, forwarding text and executing tools.
                    while let Some(content) = stream.next().await {
                        match content {
                            Ok(StreamedAssistantContent::Text(text)) => {
                                if !is_text_response {
                                    last_text_response = String::new();
                                    is_text_response = true;
                                }
                                last_text_response.push_str(&text.text);
                                yield Ok(MultiTurnStreamItem::text(&text.text));
                                did_call_tool = false;
                            },
                            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                                #[cfg(feature = "hooks")]
                                if let Some(hook) = req.hook.as_ref() {
                                    hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string()).await;
                                }
                                // Execute the tool immediately; its result is queued
                                // for the next round's chat history.
                                let tool_result =
                                    agent.tools.call(&tool_call.function.name, tool_call.function.arguments.to_string()).await?;

                                #[cfg(feature = "hooks")]
                                if let Some(hook) = req.hook.as_ref() {
                                    hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string())
                                        .await;
                                }
                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id, tool_call.call_id, tool_result));

                                did_call_tool = true;
                                // break;
                            },
                            // NOTE(review): spelled `rig::message::Reasoning` while the rest
                            // of the file uses `crate::` paths — presumably the crate aliases
                            // itself as `rig`; confirm.
                            Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, .. })) => {
                                let text = reasoning.into_iter().collect::<Vec<String>>().join("");
                                yield Ok(MultiTurnStreamItem::text(&text));
                                did_call_tool = false;
                            },
                            Ok(StreamedAssistantContent::Final(final_resp)) => {
                                if is_text_response {
                                    #[cfg(feature = "hooks")]
                                    if let Some(hook) = req.hook.as_ref() {
                                        hook.on_stream_completion_response_finish(&prompt, &final_resp).await;
                                    }
                                    // Terminate the visible text block with a newline.
                                    yield Ok(MultiTurnStreamItem::text("\n"));
                                    is_text_response = false;
                                }
                                if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
                                // Do nothing here, since at the moment the final generic is actually unreachable.
                                // We need to implement a trait that aggregates token usage.
                                // TODO: Add a way to aggregate token responses from the generic variant
                            }
                            Err(e) => {
                                // A transport/model error aborts the whole request.
                                yield Err(e.into());
                                break 'outer;
                            }
                        }
                    }

                    // Add (parallel) tool calls to chat history
                    if !tool_calls.is_empty() {
                        chat_history.write().await.push(Message::Assistant {
                            id: None,
                            content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                        });
                    }

                    // Add tool results to chat history
                    for (id, call_id, tool_result) in tool_results {
                        if let Some(call_id) = call_id {
                            chat_history.write().await.push(Message::User {
                                content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                    &id,
                                    call_id.clone(),
                                    OneOrMany::one(ToolResultContent::text(&tool_result)),
                                )),
                            });
                        } else {
                            chat_history.write().await.push(Message::User {
                                content: OneOrMany::one(UserContent::tool_result(
                                    &id,
                                    OneOrMany::one(ToolResultContent::text(&tool_result)),
                                )),
                            });
                        }

                    }

                    // Set the current prompt to the last message in the chat history
                    current_prompt = match chat_history.write().await.pop() {
                        Some(prompt) => prompt,
                        None => unreachable!("Chat history should never be empty at this point"),
                    };

                    // A round that produced no tool calls is the model's final answer.
                    if !did_call_tool {
                        yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                        break;
                    }
                }

                // Surface the depth overrun as a terminal error item.
                if max_depth_reached {
                    yield Err(PromptError::MaxDepthError {
                        max_depth: req.max_depth,
                        chat_history: (*chat_history.read().await).clone(),
                        prompt: last_prompt_error.into(),
                    }.into());
                }

            })
        }

        inner(self, agent_name)
    }
}
335
336impl<'a, M> IntoFuture for StreamingPromptRequest<'a, M>
337where
338    M: CompletionModel + 'static,
339    <M as CompletionModel>::StreamingResponse: Send,
340{
341    type Output = StreamingResult<'a>; // what `.await` returns
342    type IntoFuture = Pin<Box<dyn futures::Future<Output = Self::Output> + 'a>>;
343
344    fn into_future(self) -> Self::IntoFuture {
345        // Wrap send() in a future, because send() returns a stream immediately
346        Box::pin(async move { self.send().await })
347    }
348}
349
350/// helper function to stream a completion request to stdout
351pub async fn stream_to_stdout(
352    stream: &mut StreamingResult<'_>,
353) -> Result<FinalResponse, std::io::Error> {
354    let mut final_res = FinalResponse::empty();
355    print!("Response: ");
356    while let Some(content) = stream.next().await {
357        match content {
358            Ok(MultiTurnStreamItem::Text(Text { text })) => {
359                print!("{text}");
360                std::io::Write::flush(&mut std::io::stdout())?;
361            }
362            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
363                final_res = res;
364            }
365            Err(err) => {
366                eprintln!("Error: {err}");
367            }
368        }
369    }
370
371    Ok(final_res)
372}