// rig/agent/prompt_request/mod.rs

1pub(crate) mod streaming;
2
3use std::{future::IntoFuture, marker::PhantomData};
4
5use futures::{FutureExt, StreamExt, future::BoxFuture, stream};
6
7use crate::{
8    OneOrMany,
9    completion::{Completion, CompletionError, CompletionModel, Message, PromptError, Usage},
10    message::{AssistantContent, UserContent},
11    tool::ToolSetError,
12};
13
14use super::Agent;
15
/// Sealed-style marker trait for the type-state of a [`PromptRequest`].
/// The state parameter only selects the output type of awaiting the request;
/// it carries no data (see the `state: PhantomData<S>` field).
pub trait PromptType {}
/// Type-state marker: awaiting the request yields a plain `String`.
pub struct Standard;
/// Type-state marker: awaiting the request yields a [`PromptResponse`]
/// (final text plus aggregated token usage).
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
22
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
/// attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call tools
/// back to back.
pub struct PromptRequest<'a, S: PromptType, M: CompletionModel> {
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt; mutated in place as turns are added.
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
    /// Phantom data to track the type-state (`Standard` vs `Extended`) of the request
    state: PhantomData<S>,
    #[cfg(feature = "hooks")]
    /// Optional per-request hook for completion/tool-call events
    hook: Option<&'a dyn crate::agent::PromptHook<M>>,
}
47
48impl<'a, M: CompletionModel> PromptRequest<'a, Standard, M> {
49    /// Create a new PromptRequest with the given prompt and model
50    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
51        Self {
52            prompt: prompt.into(),
53            chat_history: None,
54            max_depth: 0,
55            agent,
56            state: PhantomData,
57            #[cfg(feature = "hooks")]
58            hook: None,
59        }
60    }
61
62    /// Enable returning extended details for responses (includes aggregated token usage)
63    ///
64    /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
65    /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
66    /// of conversation.
67    pub fn extended_details(self) -> PromptRequest<'a, Extended, M> {
68        PromptRequest {
69            prompt: self.prompt,
70            chat_history: self.chat_history,
71            max_depth: self.max_depth,
72            agent: self.agent,
73            state: PhantomData,
74            #[cfg(feature = "hooks")]
75            hook: self.hook,
76        }
77    }
78}
79
80impl<'a, S: PromptType, M: CompletionModel> PromptRequest<'a, S, M> {
81    /// Set the maximum depth for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
82    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxDepthError`].
83    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M> {
84        PromptRequest {
85            prompt: self.prompt,
86            chat_history: self.chat_history,
87            max_depth: depth,
88            agent: self.agent,
89            state: PhantomData,
90            #[cfg(feature = "hooks")]
91            hook: self.hook,
92        }
93    }
94
95    /// Add chat history to the prompt request
96    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M> {
97        PromptRequest {
98            prompt: self.prompt,
99            chat_history: Some(history),
100            max_depth: self.max_depth,
101            agent: self.agent,
102            state: PhantomData,
103            #[cfg(feature = "hooks")]
104            hook: self.hook,
105        }
106    }
107
108    #[cfg_attr(docsrs, doc(cfg(feature = "hooks")))]
109    #[cfg(feature = "hooks")]
110    /// Attach a per-request hook for tool call events
111    pub fn with_hook(self, hook: &'a dyn crate::agent::PromptHook<M>) -> PromptRequest<'a, S, M> {
112        PromptRequest {
113            prompt: self.prompt,
114            chat_history: self.chat_history,
115            max_depth: self.max_depth,
116            agent: self.agent,
117            state: PhantomData,
118            #[cfg(feature = "hooks")]
119            hook: Some(hook),
120        }
121    }
122}
123
// The trait methods below have intentionally empty default bodies (hence the
// `#[allow(unused_variables)]` attributes) so implementors only need to
// override the events they actually care about.
125/// Trait for per-request hooks to observe tool call events.
126/// Usage:
127/// ```rust
128///
129/// use std::env;
130///
131/// use rig::agent::PromptHook;
132/// use rig::client::CompletionClient;
133/// use rig::completion::{CompletionModel, CompletionResponse, Message, Prompt};
134/// use rig::message::{AssistantContent, UserContent};
135/// use rig::providers;
136///
137/// struct SessionIdHook<'a> {
138///     session_id: &'a str,
139/// }
140///
141/// #[async_trait::async_trait]
142/// impl<'a, M: CompletionModel> PromptHook<M> for SessionIdHook<'a> {
143///     async fn on_tool_call(&self, tool_name: &str, args: &str) {
144///         println!(
145///             "[Session {}] Calling tool: {} with args: {}",
146///             self.session_id, tool_name, args
147///         );
148///     }
149///     async fn on_tool_result(&self, tool_name: &str, args: &str, result: &str) {
150///         println!(
151///             "[Session {}] Tool result for {} (args: {}): {}",
152///             self.session_id, tool_name, args, result
153///         );
154///     }
155///
156///     async fn on_completion_call(&self, prompt: &Message, _history: &[Message]) {
157///         println!(
158///             "[Session {}] Sending prompt: {}",
159///             self.session_id,
160///             match prompt {
161///                 Message::User { content } => content
162///                     .iter()
163///                     .filter_map(|c| {
164///                         if let UserContent::Text(text_content) = c {
165///                             Some(text_content.text.clone())
166///                         } else {
167///                             None
168///                         }
169///                     })
170///                     .collect::<Vec<_>>()
171///                     .join("\n"),
172///                 Message::Assistant { content, .. } => content
173///                     .iter()
174///                     .filter_map(|c| if let AssistantContent::Text(text_content) = c {
175///                         Some(text_content.text.clone())
176///                     } else {
177///                         None
178///                     })
179///                     .collect::<Vec<_>>()
180///                     .join("\n"),
181///             }
182///         );
183///     }
184///
185///     async fn on_completion_response(
186///         &self,
187///         _prompt: &Message,
188///         response: &CompletionResponse<M::Response>,
189///     ) {
190///         if let Ok(resp) = serde_json::to_string(&response.raw_response) {
191///             println!("[Session {}] Received response: {}", self.session_id, resp);
192///         } else {
193///             println!(
194///                 "[Session {}] Received response: <non-serializable>",
195///                 self.session_id
196///             );
197///         }
198///     }
199/// }
200///
201/// // Example main function (pseudo-code, as actual Agent/CompletionModel setup is project-specific)
202/// #[tokio::main]
203/// async fn main() -> Result<(), anyhow::Error> {
204///     let client = providers::openai::Client::new(
205///         &env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"),
206///     );
207///
208///     // Create agent with a single context prompt
209///     let comedian_agent = client
210///         .agent("gpt-4o")
211///         .preamble("You are a comedian here to entertain the user using humour and jokes.")
212///         .build();
213///
214///     let session_id = "abc123";
215///     let hook = SessionIdHook { session_id };
216///
217///     // Prompt the agent and print the response
218///     comedian_agent
219///         .prompt("Entertain me!")
220///         .with_hook(&hook)
221///         .await?;
222///
223///     Ok(())
224/// }
225/// ```
#[cfg_attr(docsrs, doc(cfg(feature = "hooks")))]
#[cfg(feature = "hooks")]
#[async_trait::async_trait]
pub trait PromptHook<M: CompletionModel>: Send + Sync {
    #[allow(unused_variables)]
    /// Called before the prompt is sent to the model.
    /// `history` contains the conversation so far, excluding `prompt` itself.
    async fn on_completion_call(&self, prompt: &Message, history: &[Message]) {}

    #[allow(unused_variables)]
    /// Called after the prompt is sent to the model and a response is received.
    /// This function is for non-streamed responses. Please refer to `on_stream_completion_response_finish` for streamed responses.
    async fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
    ) {
    }

    #[allow(unused_variables)]
    /// Called after the model provider has finished streaming a text response from their completion API to the client.
    async fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
    ) {
    }

    #[allow(unused_variables)]
    /// Called before a tool is invoked, with the tool name and its raw JSON argument string.
    async fn on_tool_call(&self, tool_name: &str, args: &str) {}

    #[allow(unused_variables)]
    /// Called after a tool is invoked (and a result has been returned).
    async fn on_tool_result(&self, tool_name: &str, args: &str, result: &str) {}
}
261
262/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
263///  for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
264///  directly via the associated type.
265impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, Standard, M> {
266    type Output = Result<String, PromptError>;
267    type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent
268
269    fn into_future(self) -> Self::IntoFuture {
270        self.send().boxed()
271    }
272}
273
274impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, Extended, M> {
275    type Output = Result<PromptResponse, PromptError>;
276    type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent
277
278    fn into_future(self) -> Self::IntoFuture {
279        self.send().boxed()
280    }
281}
282
283impl<M: CompletionModel> PromptRequest<'_, Standard, M> {
284    async fn send(self) -> Result<String, PromptError> {
285        self.extended_details().send().await.map(|resp| resp.output)
286    }
287}
288
/// The result of an extended prompt request: the model's final text answer
/// together with the token usage aggregated over every completion call made
/// while resolving the request (including tool-call round-trips).
#[derive(Debug, Clone)]
pub struct PromptResponse {
    /// Final text response produced by the model.
    pub output: String,
    /// Token usage summed across all turns of the request.
    pub total_usage: Usage,
}

impl PromptResponse {
    /// Create a response from the final output text and the aggregated usage.
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}
303
impl<M: CompletionModel> PromptRequest<'_, Extended, M> {
    /// Drive the full (possibly multi-turn) prompt loop.
    ///
    /// Each iteration sends the latest message to the model. If the model
    /// replies with plain text, that text is returned together with the
    /// aggregated usage; if it replies with tool calls, the tools are invoked
    /// and their results are appended to the history for the next turn.
    /// Returns [`PromptError::MaxDepthError`] when the depth budget is
    /// exhausted while the model is still calling tools.
    #[tracing::instrument(skip(self), fields(agent_name = self.agent.name()))]
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent = self.agent;
        // Use the caller-provided history (mutated in place so the caller sees
        // all intermediate turns); otherwise fall back to a temporary history
        // local to this call.
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt);
            history
        } else {
            &mut vec![self.prompt]
        };

        // Counts completion round-trips performed so far (despite the name, it
        // is the *current* depth, compared against `self.max_depth` below).
        let mut current_max_depth = 0;
        // Token usage accumulated across every completion call in this request.
        let mut usage = Usage::new();

        // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
        let last_prompt = loop {
            // The message to send this turn is always the newest history entry
            // (the original prompt on turn one, tool results afterwards).
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            // Depth budget exhausted: break with the unsent prompt so the
            // MaxDepthError below can report where the conversation stopped.
            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            // Notify the hook before the call; history excludes the prompt
            // itself (it was pushed as the last element above).
            #[cfg(feature = "hooks")]
            if let Some(hook) = self.hook.as_ref() {
                hook.on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
                    .await;
            }

            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .await?;

            usage += resp.usage;

            #[cfg(feature = "hooks")]
            if let Some(hook) = self.hook.as_ref() {
                hook.on_completion_response(&prompt, &resp).await;
            }

            // Split the model's reply into tool calls vs text content.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Record the assistant turn (both text and tool calls) in history.
            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                // Plain text answer: join all text fragments and finish.
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                // If there are no tool calls, depth is not relevant, we can just return the merged text response.
                return Ok(PromptResponse::new(merged_texts, usage));
            }

            // Invoke each requested tool sequentially (`then` preserves order),
            // surrounding each call with the optional hook notifications.
            let tool_content = stream::iter(tool_calls)
                .then(|choice| async move {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let tool_name = &tool_call.function.name;
                        let args = tool_call.function.arguments.to_string();
                        #[cfg(feature = "hooks")]
                        if let Some(hook) = self.hook.as_ref() {
                            hook.on_tool_call(tool_name, &args).await;
                        }
                        let output = agent.tools.call(tool_name, args.clone()).await?;
                        #[cfg(feature = "hooks")]
                        if let Some(hook) = self.hook.as_ref() {
                            hook.on_tool_result(tool_name, &args, &output.to_string())
                                .await;
                        }
                        // Preserve the provider's call id when present so the
                        // tool result can be correlated with its call.
                        if let Some(call_id) = tool_call.call_id.clone() {
                            Ok(UserContent::tool_result_with_call_id(
                                tool_call.id.clone(),
                                call_id,
                                OneOrMany::one(output.into()),
                            ))
                        } else {
                            Ok(UserContent::tool_result(
                                tool_call.id.clone(),
                                OneOrMany::one(output.into()),
                            ))
                        }
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                // Short-circuit on the first failing tool call.
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            // Feed the tool results back to the model as a user turn.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // The loop exited without the model producing a plain-text answer: the
        // depth budget ran out while tools were still being called.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt: last_prompt,
        })
    }
}
443}