// rig/agent/prompt_request.rs

1use std::future::IntoFuture;
2
3use futures::{future::BoxFuture, stream, FutureExt, StreamExt};
4
5use crate::{
6    completion::{Completion, CompletionError, CompletionModel, Message, PromptError},
7    message::{AssistantContent, UserContent},
8    tool::ToolSetError,
9    OneOrMany,
10};
11
12use super::Agent;
13
14/// A builder for creating prompt requests with customizable options.
15/// Uses generics to track which options have been set during the build process.
16pub struct PromptRequest<'a, M: CompletionModel> {
17    /// The prompt message to send to the model
18    prompt: Message,
19    /// Optional chat history to include with the prompt
20    /// Note: chat history needs to outlive the agent as it might be used with other agents
21    chat_history: Option<&'a mut Vec<Message>>,
22    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
23    max_depth: usize,
24    /// The agent to use for execution
25    agent: &'a Agent<M>,
26}
27
28impl<'a, M: CompletionModel> PromptRequest<'a, M> {
29    /// Create a new PromptRequest with the given prompt and model
30    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
31        Self {
32            prompt: prompt.into(),
33            chat_history: None,
34            max_depth: 0,
35            agent,
36        }
37    }
38}
39
40impl<'a, M: CompletionModel> PromptRequest<'a, M> {
41    /// Set the maximum depth for multi-turn conversations
42    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, M> {
43        PromptRequest {
44            prompt: self.prompt,
45            chat_history: self.chat_history,
46            max_depth: depth,
47            agent: self.agent,
48        }
49    }
50
51    /// Add chat history to the prompt request
52    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, M> {
53        PromptRequest {
54            prompt: self.prompt,
55            chat_history: Some(history),
56            max_depth: self.max_depth,
57            agent: self.agent,
58        }
59    }
60}
61
62/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
63///  for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
64///  directly via the associated type.
65impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, M> {
66    type Output = Result<String, PromptError>;
67    type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent
68
69    fn into_future(self) -> Self::IntoFuture {
70        self.send().boxed()
71    }
72}
73
74impl<M: CompletionModel> PromptRequest<'_, M> {
75    async fn send(self) -> Result<String, PromptError> {
76        let agent = self.agent;
77        let mut prompt = self.prompt;
78        let chat_history = if let Some(history) = self.chat_history {
79            history
80        } else {
81            &mut Vec::new()
82        };
83
84        let mut current_max_depth = 0;
85        // We need to do atleast 2 loops for 1 roundtrip (user expects normal message)
86        while current_max_depth <= self.max_depth + 1 {
87            current_max_depth += 1;
88
89            if self.max_depth > 1 {
90                tracing::info!(
91                    "Current conversation depth: {}/{}",
92                    current_max_depth,
93                    self.max_depth
94                );
95            }
96
97            let resp = agent
98                .completion(prompt.clone(), chat_history.to_vec())
99                .await?
100                .send()
101                .await?;
102
103            chat_history.push(prompt);
104
105            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
106                .choice
107                .iter()
108                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
109
110            chat_history.push(Message::Assistant {
111                content: resp.choice.clone(),
112            });
113
114            if tool_calls.is_empty() {
115                let merged_texts = texts
116                    .into_iter()
117                    .filter_map(|content| {
118                        if let AssistantContent::Text(text) = content {
119                            Some(text.text.clone())
120                        } else {
121                            None
122                        }
123                    })
124                    .collect::<Vec<_>>()
125                    .join("\n");
126
127                if self.max_depth > 1 {
128                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
129                }
130
131                // If there are no tool calls, depth is not relevant, we can just return the merged text.
132                return Ok(merged_texts);
133            }
134
135            let tool_content = stream::iter(tool_calls)
136                .then(|choice| async move {
137                    if let AssistantContent::ToolCall(tool_call) = choice {
138                        let output = agent
139                            .tools
140                            .call(
141                                &tool_call.function.name,
142                                tool_call.function.arguments.to_string(),
143                            )
144                            .await?;
145                        Ok(UserContent::tool_result(
146                            tool_call.id.clone(),
147                            OneOrMany::one(output.into()),
148                        ))
149                    } else {
150                        unreachable!(
151                            "This should never happen as we already filtered for `ToolCall`"
152                        )
153                    }
154                })
155                .collect::<Vec<Result<UserContent, ToolSetError>>>()
156                .await
157                .into_iter()
158                .collect::<Result<Vec<_>, _>>()
159                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
160
161            prompt = Message::User {
162                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
163            };
164        }
165
166        // If we reach here, we never resolved the final tool call. We need to do ... something.
167        Err(PromptError::MaxDepthError {
168            max_depth: self.max_depth,
169            chat_history: chat_history.clone(),
170            prompt,
171        })
172    }
173}