rig/agent/
prompt_request.rs

use std::{future::IntoFuture, marker::PhantomData};

use futures::{FutureExt, StreamExt, future::BoxFuture, stream};

use crate::{
    OneOrMany,
    completion::{Completion, CompletionError, CompletionModel, Message, PromptError, Usage},
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
};

use super::Agent;

pub trait PromptType {}
pub struct Standard;
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, use the `.multi_turn()` method to
/// allow more turns; the default is 0, which permits only a single tool round-trip.
/// Otherwise, awaiting the request (which sends the prompt) can return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to
/// call tools back to back.
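///
/// # Example
///
/// A minimal usage sketch (it assumes an `agent` built elsewhere; errors propagate
/// with `?`):
///
/// ```ignore
/// let mut history = Vec::new();
/// let answer: String = PromptRequest::new(&agent, "What is the weather in Berlin?")
///     .with_history(&mut history)
///     .multi_turn(5) // allow up to 5 tool-calling turns before a text response is required
///     .await?;
/// ```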
pub struct PromptRequest<'a, S: PromptType, M: CompletionModel> {
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
    /// Phantom data to track the type of the request
    state: PhantomData<S>,
}

impl<'a, M: CompletionModel> PromptRequest<'a, Standard, M> {
    /// Create a new PromptRequest with the given prompt and model
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            state: PhantomData,
        }
    }

    /// Enable returning extended details for responses (includes aggregated token usage)
    ///
    /// Note: This changes the return type of `.send` from a plain `String` to a
    /// [`PromptResponse`] struct, which is useful for tracking token usage across
    /// multiple turns of conversation.
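    ///
    /// A sketch of the extended flow (an `agent` value is assumed to be built elsewhere):
    ///
    /// ```ignore
    /// let resp = PromptRequest::new(&agent, "Summarize the report")
    ///     .extended_details()
    ///     .await?;
    /// let PromptResponse { output, total_usage } = resp;
    /// ```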
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
        }
    }
}

impl<'a, S: PromptType, M: CompletionModel> PromptRequest<'a, S, M> {
    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of
    /// turns an LLM can spend calling tools before writing a text response).
    /// If the maximum number of turns is exceeded, awaiting the request returns a
    /// [`crate::completion::request::PromptError::MaxDepthError`].
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            state: PhantomData,
        }
    }

    /// Add chat history to the prompt request
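    ///
    /// A sketch of reusing one history buffer across requests (an `agent` value is
    /// assumed). The prompt and the assistant reply are pushed into `history`, so a
    /// follow-up request sees the earlier exchange:
    ///
    /// ```ignore
    /// let mut history = Vec::new();
    /// let first = PromptRequest::new(&agent, "My name is Ada.")
    ///     .with_history(&mut history)
    ///     .await?;
    /// let second = PromptRequest::new(&agent, "What is my name?")
    ///     .with_history(&mut history)
    ///     .await?;
    /// ```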
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
        }
    }
}

/// Due to [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a
/// `BoxFuture` for the `IntoFuture` implementation. In the future, we should be able to
/// use `impl Future<...>` directly via the associated type.
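///
/// Thanks to `IntoFuture`, a request can be awaited directly (a sketch; `agent` is
/// assumed to be built elsewhere):
///
/// ```ignore
/// let text: String = PromptRequest::new(&agent, "ping").await?;
/// ```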
impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, Standard, M> {
    type Output = Result<String, PromptError>;
    type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}

impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, Extended, M> {
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}

impl<M: CompletionModel> PromptRequest<'_, Standard, M> {
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

/// Response returned when extended details are enabled via
/// [`PromptRequest::extended_details`], containing the final text output and token
/// usage aggregated across all turns of the conversation.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    /// The final text response from the model
    pub output: String,
    /// Token usage aggregated across all turns
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

impl<M: CompletionModel> PromptRequest<'_, Extended, M> {
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent = self.agent;
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt);
            history
        } else {
            &mut vec![self.prompt]
        };

        let mut current_max_depth = 0;
        let mut usage = Usage::new();

        // We need at least 2 iterations for 1 round-trip (the user ultimately expects
        // a normal text message, not a tool call).
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

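            // Bail out with the last prompt once the permitted number of turns is
            // exhausted; the `+ 1` leaves room for a final text-only response after
            // the last tool round-trip.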
            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            let resp = agent
                .completion(prompt, chat_history[..chat_history.len() - 1].to_vec())
                .await?
                .send()
                .await?;

            usage += resp.usage;

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                // If there are no tool calls, depth is not relevant; we can just return
                // the merged text response.
                return Ok(PromptResponse::new(merged_texts, usage));
            }

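            // Execute every tool call the model requested and wrap each output as
            // tool-result content; a failing tool surfaces as a `CompletionError`.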
            let tool_content = stream::iter(tool_calls)
                .then(|choice| async move {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let output = agent
                            .tools
                            .call(
                                &tool_call.function.name,
                                tool_call.function.arguments.to_string(),
                            )
                            .await?;
                        if let Some(call_id) = tool_call.call_id.clone() {
                            Ok(UserContent::tool_result_with_call_id(
                                tool_call.id.clone(),
                                call_id,
                                OneOrMany::one(output.into()),
                            ))
                        } else {
                            Ok(UserContent::tool_result(
                                tool_call.id.clone(),
                                OneOrMany::one(output.into()),
                            ))
                        }
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // If we reach here, the model kept calling tools until the maximum depth was
        // exceeded, so surface that as a `MaxDepthError` with the accumulated history.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt: last_prompt,
        })
    }
}