// rig/agent/prompt_request.rs

use std::future::IntoFuture;

use futures::{FutureExt, StreamExt, future::BoxFuture, stream};

use crate::{
    OneOrMany,
    completion::{Completion, CompletionError, CompletionModel, Message, PromptError},
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
};

use super::Agent;

/// A builder for a single prompt request against an [`Agent`].
///
/// Configure it with [`PromptRequest::multi_turn`] and
/// [`PromptRequest::with_history`], then `.await` it — the `IntoFuture`
/// impl below drives the request via `send`.
pub struct PromptRequest<'a, M: CompletionModel> {
    /// The prompt message sent to the model.
    prompt: Message,
    /// Optional caller-owned chat history; when present, `send` mutates it
    /// in place as the conversation proceeds.
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum number of multi-turn (tool-call) rounds; defaults to 0.
    max_depth: usize,
    /// The agent whose model and tool set serve the request.
    agent: &'a Agent<M>,
}
29
30impl<'a, M: CompletionModel> PromptRequest<'a, M> {
31 pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
33 Self {
34 prompt: prompt.into(),
35 chat_history: None,
36 max_depth: 0,
37 agent,
38 }
39 }
40}
41
42impl<'a, M: CompletionModel> PromptRequest<'a, M> {
43 pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, M> {
46 PromptRequest {
47 prompt: self.prompt,
48 chat_history: self.chat_history,
49 max_depth: depth,
50 agent: self.agent,
51 }
52 }
53
54 pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, M> {
56 PromptRequest {
57 prompt: self.prompt,
58 chat_history: Some(history),
59 max_depth: self.max_depth,
60 agent: self.agent,
61 }
62 }
63}
64
/// Lets a `PromptRequest` be `.await`ed directly; awaiting delegates to
/// the private `send` driver below.
impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, M> {
    type Output = Result<String, PromptError>;
    // Boxed because `async fn send` returns an unnameable future type.
    type IntoFuture = BoxFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}
76
impl<M: CompletionModel> PromptRequest<'_, M> {
    /// Run the prompt through the agent, resolving tool calls in a
    /// multi-turn loop.
    ///
    /// The prompt is appended to the chat history (the caller-supplied one
    /// from `with_history`, or a fresh temporary `Vec` otherwise). Each
    /// round: send the last history entry (with all earlier entries as
    /// context) to the model, record the assistant reply, then either
    /// return the reply's text (no tool calls) or execute every tool call
    /// and push the results as a user message before looping again.
    ///
    /// # Errors
    /// Propagates completion errors from the model; tool failures are
    /// wrapped as `CompletionError::RequestError`. Returns
    /// `PromptError::MaxDepthError` when the turn budget is exhausted
    /// before the model produces a plain-text answer.
    async fn send(self) -> Result<String, PromptError> {
        let agent = self.agent;
        // Use the caller's history if supplied, otherwise a temporary local
        // one; either way the prompt becomes the last entry.
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt);
            history
        } else {
            &mut vec![self.prompt]
        };

        let mut current_max_depth = 0;
        // The loop breaks with the prompt that exceeded the depth budget so
        // it can be reported in the MaxDepthError below.
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            // NOTE(review): `> max_depth + 1` permits `max_depth + 2` rounds
            // counting from zero — confirm this off-by-one is intentional.
            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            // Only log progress for genuinely multi-turn requests.
            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            // Everything before the last entry is context; the last entry
            // itself is sent as the prompt.
            let resp = agent
                .completion(prompt, chat_history[..chat_history.len() - 1].to_vec())
                .await?
                .send()
                .await?;

            // Split the model's reply into tool calls and plain-text parts.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Record the full assistant reply (text AND tool calls) in the
            // history before deciding what to do with it.
            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            // No tool calls: the conversation is done — join the text parts.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                return Ok(merged_texts);
            }

            // Execute every tool call sequentially (`then` awaits each future
            // in order) and fail fast on the first tool error.
            let tool_content = stream::iter(tool_calls)
                .then(|choice| async move {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let output = agent
                            .tools
                            .call(
                                &tool_call.function.name,
                                tool_call.function.arguments.to_string(),
                            )
                            .await?;
                        // Preserve the provider's call id when one was given.
                        if let Some(call_id) = tool_call.call_id.clone() {
                            Ok(UserContent::tool_result_with_call_id(
                                tool_call.id.clone(),
                                call_id,
                                OneOrMany::one(output.into()),
                            ))
                        } else {
                            Ok(UserContent::tool_result(
                                tool_call.id.clone(),
                                OneOrMany::one(output.into()),
                            ))
                        }
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            // Feed the tool results back to the model as a user message for
            // the next round.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // Depth budget exhausted without a plain-text answer.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt: last_prompt,
        })
    }
}