rig/agent/
prompt_request.rs

use std::future::IntoFuture;

use futures::{FutureExt, StreamExt, future::BoxFuture, stream};

use crate::{
    OneOrMany,
    completion::{Completion, CompletionError, CompletionModel, Message, PromptError},
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
};

use super::Agent;

/// A builder for creating prompt requests with customizable options.
/// If you're using tools, make sure to call `.multi_turn()` to allow additional
/// turns, since the default is 0 (meaning no tool usage). Otherwise, awaiting
/// the request (which sends it) will return a
/// [`crate::completion::request::PromptError::MaxDepthError`].
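///
/// # Example
///
/// A minimal sketch, assuming `agent` is an already-built [`Agent`] and the
/// call site is in an async context:
///
/// ```rust,ignore
/// let mut history: Vec<Message> = Vec::new();
/// let answer = PromptRequest::new(&agent, "What files are in the current directory?")
///     .with_history(&mut history)
///     .multi_turn(5) // allow up to 5 tool-calling turns
///     .await?;
/// println!("{answer}");
/// ```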
pub struct PromptRequest<'a, M: CompletionModel> {
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
}

impl<'a, M: CompletionModel> PromptRequest<'a, M> {
    /// Create a new PromptRequest with the given agent and prompt
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
        }
    }
}

impl<'a, M: CompletionModel> PromptRequest<'a, M> {
    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of turns an LLM can spend calling tools before it must produce a text response).
    /// If the maximum number of turns is exceeded, the request returns a [`crate::completion::request::PromptError::MaxDepthError`].
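    ///
    /// A minimal sketch, assuming `agent` has tools registered:
    ///
    /// ```rust,ignore
    /// // Allow up to 3 tool-calling turns before a text answer is required.
    /// let reply = PromptRequest::new(&agent, "Add 2 and 2 with the calculator tool")
    ///     .multi_turn(3)
    ///     .await?;
    /// ```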
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
        }
    }

    /// Add chat history to the prompt request.
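    /// The history is borrowed mutably for the duration of the request: the
    /// prompt, assistant replies, and tool results are appended to it in place
    /// while the request runs.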
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
        }
    }
}

/// Due to [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
/// for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
/// directly via the associated type.
impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, M> {
    type Output = Result<String, PromptError>;
    type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}
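
// With this impl in place, a `PromptRequest` can be awaited directly, e.g.
// `PromptRequest::new(&agent, "Hello").await?`; awaiting boxes the `send`
// future below.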

impl<M: CompletionModel> PromptRequest<'_, M> {
    async fn send(self) -> Result<String, PromptError> {
        let agent = self.agent;
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt);
            history
        } else {
            &mut vec![self.prompt]
        };

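        // Each iteration of the loop below is one model round trip: send the
        // latest prompt with the accumulated history, then either return the
        // model's text response or execute its tool calls and loop again.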
        let mut current_max_depth = 0;
        // We need at least two loop iterations for one round trip (the user expects a normal message)
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            let resp = agent
                .completion(prompt, chat_history[..chat_history.len() - 1].to_vec())
                .await?
                .send()
                .await?;

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                // If there are no tool calls, depth is not relevant; we can just return the merged text.
                return Ok(merged_texts);
            }

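            // Execute each requested tool call in order; `.then` awaits the
            // calls sequentially, and each output is wrapped as a tool result
            // for the next user message.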
            let tool_content = stream::iter(tool_calls)
                .then(|choice| async move {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let output = agent
                            .tools
                            .call(
                                &tool_call.function.name,
                                tool_call.function.arguments.to_string(),
                            )
                            .await?;
                        if let Some(call_id) = tool_call.call_id.clone() {
                            Ok(UserContent::tool_result_with_call_id(
                                tool_call.id.clone(),
                                call_id,
                                OneOrMany::one(output.into()),
                            ))
                        } else {
                            Ok(UserContent::tool_result(
                                tool_call.id.clone(),
                                OneOrMany::one(output.into()),
                            ))
                        }
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // If we reach this point, the loop hit the depth limit before the model
        // produced a final text response, so surface a `MaxDepthError` with the
        // accumulated chat history.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt: last_prompt,
        })
    }
}