// rig/agent/prompt_request/mod.rs
1pub(crate) mod streaming;
2
3use std::{future::IntoFuture, marker::PhantomData};
4
5use futures::{FutureExt, StreamExt, future::BoxFuture, stream};
6
7use crate::{
8    OneOrMany,
9    completion::{Completion, CompletionError, CompletionModel, Message, PromptError, Usage},
10    message::{AssistantContent, UserContent},
11    tool::ToolSetError,
12};
13
14use super::Agent;
15
/// Marker trait used to select, at the type level, which kind of response a
/// [`PromptRequest`] resolves to when awaited.
pub trait PromptType {}

/// Marker type: the request resolves to a plain `String`.
pub struct Standard;

/// Marker type: the request resolves to a [`PromptResponse`] that also carries
/// aggregated token usage.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
22
23/// A builder for creating prompt requests with customizable options.
24/// Uses generics to track which options have been set during the build process.
25///
26/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
27/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
28/// attempting to await (which will send the prompt request) can potentially return
29/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call tools
30/// back to back.
31pub struct PromptRequest<'a, S, M, P>
32where
33    S: PromptType,
34    M: CompletionModel,
35    P: PromptHook<M>,
36{
37    /// The prompt message to send to the model
38    prompt: Message,
39    /// Optional chat history to include with the prompt
40    /// Note: chat history needs to outlive the agent as it might be used with other agents
41    chat_history: Option<&'a mut Vec<Message>>,
42    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
43    max_depth: usize,
44    /// The agent to use for execution
45    agent: &'a Agent<M>,
46    /// Phantom data to track the type of the request
47    state: PhantomData<S>,
48    /// Optional per-request hook for events
49    hook: Option<P>,
50}
51
52impl<'a, M> PromptRequest<'a, Standard, M, ()>
53where
54    M: CompletionModel,
55{
56    /// Create a new PromptRequest with the given prompt and model
57    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
58        Self {
59            prompt: prompt.into(),
60            chat_history: None,
61            max_depth: 0,
62            agent,
63            state: PhantomData,
64            hook: None,
65        }
66    }
67}
68
69impl<'a, S, M, P> PromptRequest<'a, S, M, P>
70where
71    S: PromptType,
72    M: CompletionModel,
73    P: PromptHook<M>,
74{
75    /// Enable returning extended details for responses (includes aggregated token usage)
76    ///
77    /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
78    /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
79    /// of conversation.
80    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
81        PromptRequest {
82            prompt: self.prompt,
83            chat_history: self.chat_history,
84            max_depth: self.max_depth,
85            agent: self.agent,
86            state: PhantomData,
87            hook: self.hook,
88        }
89    }
90    /// Set the maximum depth for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
91    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxDepthError`].
92    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
93        PromptRequest {
94            prompt: self.prompt,
95            chat_history: self.chat_history,
96            max_depth: depth,
97            agent: self.agent,
98            state: PhantomData,
99            hook: self.hook,
100        }
101    }
102
103    /// Add chat history to the prompt request
104    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
105        PromptRequest {
106            prompt: self.prompt,
107            chat_history: Some(history),
108            max_depth: self.max_depth,
109            agent: self.agent,
110            state: PhantomData,
111            hook: self.hook,
112        }
113    }
114
115    /// Attach a per-request hook for tool call events
116    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
117    where
118        P2: PromptHook<M>,
119    {
120        PromptRequest {
121            prompt: self.prompt,
122            chat_history: self.chat_history,
123            max_depth: self.max_depth,
124            agent: self.agent,
125            state: PhantomData,
126            hook: Some(hook),
127        }
128    }
129}
130
131// dead code allowed because of functions being left empty to allow for users to not have to implement every single function
132/// Trait for per-request hooks to observe tool call events.
133pub trait PromptHook<M>: Clone + Send + Sync
134where
135    M: CompletionModel,
136{
137    #[allow(unused_variables)]
138    /// Called before the prompt is sent to the model
139    fn on_completion_call(
140        &self,
141        prompt: &Message,
142        history: &[Message],
143    ) -> impl Future<Output = ()> + Send {
144        async {}
145    }
146
147    #[allow(unused_variables)]
148    /// Called after the prompt is sent to the model and a response is received.
149    /// This function is for non-streamed responses. Please refer to `on_stream_completion_response_finish` for streamed responses.
150    fn on_completion_response(
151        &self,
152        prompt: &Message,
153        response: &crate::completion::CompletionResponse<M::Response>,
154    ) -> impl Future<Output = ()> + Send {
155        async {}
156    }
157
158    #[allow(unused_variables)]
159    /// Called after the model provider has finished streaming a text response from their completion API to the client.
160    fn on_stream_completion_response_finish(
161        &self,
162        prompt: &Message,
163        response: &<M as CompletionModel>::StreamingResponse,
164    ) -> impl Future<Output = ()> + Send {
165        async {}
166    }
167
168    #[allow(unused_variables)]
169    /// Called before a tool is invoked.
170    fn on_tool_call(&self, tool_name: &str, args: &str) -> impl Future<Output = ()> + Send {
171        async {}
172    }
173
174    #[allow(unused_variables)]
175    /// Called after a tool is invoked (and a result has been returned).
176    fn on_tool_result(
177        &self,
178        tool_name: &str,
179        args: &str,
180        result: &str,
181    ) -> impl Future<Output = ()> + Send {
182        async {}
183    }
184}
185
186impl<M> PromptHook<M> for () where M: CompletionModel {}
187
188/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
189///  for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
190///  directly via the associated type.
191impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
192where
193    M: CompletionModel,
194    P: PromptHook<M> + 'static,
195{
196    type Output = Result<String, PromptError>;
197    type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent
198
199    fn into_future(self) -> Self::IntoFuture {
200        self.send().boxed()
201    }
202}
203
204impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
205where
206    M: CompletionModel,
207    P: PromptHook<M> + 'static,
208{
209    type Output = Result<PromptResponse, PromptError>;
210    type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent
211
212    fn into_future(self) -> Self::IntoFuture {
213        self.send().boxed()
214    }
215}
216
217impl<M, P> PromptRequest<'_, Standard, M, P>
218where
219    M: CompletionModel,
220    P: PromptHook<M>,
221{
222    async fn send(self) -> Result<String, PromptError> {
223        self.extended_details().send().await.map(|resp| resp.output)
224    }
225}
226
227#[derive(Debug, Clone)]
228pub struct PromptResponse {
229    pub output: String,
230    pub total_usage: Usage,
231}
232
233impl PromptResponse {
234    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
235        Self {
236            output: output.into(),
237            total_usage,
238        }
239    }
240}
241
242impl<M, P> PromptRequest<'_, Extended, M, P>
243where
244    M: CompletionModel,
245    P: PromptHook<M>,
246{
247    #[tracing::instrument(skip(self), fields(agent_name = self.agent.name()))]
248    async fn send(self) -> Result<PromptResponse, PromptError> {
249        let agent = self.agent;
250        let chat_history = if let Some(history) = self.chat_history {
251            history.push(self.prompt);
252            history
253        } else {
254            &mut vec![self.prompt]
255        };
256
257        let mut current_max_depth = 0;
258        let mut usage = Usage::new();
259
260        // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
261        let last_prompt = loop {
262            let prompt = chat_history
263                .last()
264                .cloned()
265                .expect("there should always be at least one message in the chat history");
266
267            if current_max_depth > self.max_depth + 1 {
268                break prompt;
269            }
270
271            current_max_depth += 1;
272
273            if self.max_depth > 1 {
274                tracing::info!(
275                    "Current conversation depth: {}/{}",
276                    current_max_depth,
277                    self.max_depth
278                );
279            }
280
281            if let Some(ref hook) = self.hook {
282                hook.on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
283                    .await;
284            }
285
286            let resp = agent
287                .completion(
288                    prompt.clone(),
289                    chat_history[..chat_history.len() - 1].to_vec(),
290                )
291                .await?
292                .send()
293                .await?;
294
295            usage += resp.usage;
296
297            if let Some(ref hook) = self.hook {
298                hook.on_completion_response(&prompt, &resp).await;
299            }
300
301            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
302                .choice
303                .iter()
304                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
305
306            chat_history.push(Message::Assistant {
307                id: None,
308                content: resp.choice.clone(),
309            });
310
311            if tool_calls.is_empty() {
312                let merged_texts = texts
313                    .into_iter()
314                    .filter_map(|content| {
315                        if let AssistantContent::Text(text) = content {
316                            Some(text.text.clone())
317                        } else {
318                            None
319                        }
320                    })
321                    .collect::<Vec<_>>()
322                    .join("\n");
323
324                if self.max_depth > 1 {
325                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
326                }
327
328                // If there are no tool calls, depth is not relevant, we can just return the merged text response.
329                return Ok(PromptResponse::new(merged_texts, usage));
330            }
331
332            let hook = self.hook.clone();
333            let tool_content = stream::iter(tool_calls)
334                .then(|choice| {
335                    let hook1 = hook.clone();
336                    let hook2 = hook.clone();
337                    async move {
338                        if let AssistantContent::ToolCall(tool_call) = choice {
339                            let tool_name = &tool_call.function.name;
340                            let args = tool_call.function.arguments.to_string();
341                            if let Some(hook) = hook1 {
342                                hook.on_tool_call(tool_name, &args).await;
343                            }
344                            let output = agent.tools.call(tool_name, args.clone()).await?;
345                            if let Some(hook) = hook2 {
346                                hook.on_tool_result(tool_name, &args, &output.to_string())
347                                    .await;
348                            }
349                            if let Some(call_id) = tool_call.call_id.clone() {
350                                Ok(UserContent::tool_result_with_call_id(
351                                    tool_call.id.clone(),
352                                    call_id,
353                                    OneOrMany::one(output.into()),
354                                ))
355                            } else {
356                                Ok(UserContent::tool_result(
357                                    tool_call.id.clone(),
358                                    OneOrMany::one(output.into()),
359                                ))
360                            }
361                        } else {
362                            unreachable!(
363                                "This should never happen as we already filtered for `ToolCall`"
364                            )
365                        }
366                    }
367                })
368                .collect::<Vec<Result<UserContent, ToolSetError>>>()
369                .await
370                .into_iter()
371                .collect::<Result<Vec<_>, _>>()
372                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
373
374            chat_history.push(Message::User {
375                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
376            });
377        };
378
379        // If we reach here, we never resolved the final tool call. We need to do ... something.
380        Err(PromptError::MaxDepthError {
381            max_depth: self.max_depth,
382            chat_history: Box::new(chat_history.clone()),
383            prompt: last_prompt,
384        })
385    }
386}