use std::{future::IntoFuture, marker::PhantomData};

use futures::{FutureExt, StreamExt, future::BoxFuture, stream};

use crate::{
    OneOrMany,
    completion::{Completion, CompletionError, CompletionModel, Message, PromptError, Usage},
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
};

use super::Agent;

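// Typestate markers for `PromptRequest`: a `Standard` request resolves to the
// final output `String`, while an `Extended` request also reports aggregated
// token usage (see the two `IntoFuture` impls below).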
pub trait PromptType {}
pub struct Standard;
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

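/// Builder for a single prompt sent to an [`Agent`].
///
/// The `S` parameter is a typestate marker: a `Standard` request resolves to a
/// plain `String`, while an `Extended` request (see
/// [`PromptRequest::extended_details`]) resolves to a [`PromptResponse`] that
/// also carries the aggregated token [`Usage`].
///
/// A rough usage sketch (the `agent` value and error handling are assumed to
/// exist in the surrounding code; marked `ignore` so it is not compiled):
///
/// ```ignore
/// let mut history: Vec<Message> = Vec::new();
///
/// // Standard request: awaiting yields `Result<String, PromptError>`.
/// let answer: String = PromptRequest::new(&agent, "What is the weather in Paris?")
///     .multi_turn(5)
///     .with_history(&mut history)
///     .await?;
///
/// // Extended request: awaiting also returns the aggregated token usage.
/// let detailed = PromptRequest::new(&agent, "And in Rome?")
///     .extended_details()
///     .await?;
/// println!("{} ({:?})", detailed.output, detailed.total_usage);
/// ```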
pub struct PromptRequest<'a, S: PromptType, M: CompletionModel> {
    prompt: Message,
    chat_history: Option<&'a mut Vec<Message>>,
    max_depth: usize,
    agent: &'a Agent<M>,
    state: PhantomData<S>,
}

impl<'a, M: CompletionModel> PromptRequest<'a, Standard, M> {
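    /// Create a new prompt request for `agent`, using `prompt` as the user message.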
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            state: PhantomData,
        }
    }

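    /// Switch into the `Extended` typestate so that awaiting the request
    /// returns a [`PromptResponse`] (output text plus aggregated [`Usage`])
    /// instead of only the output `String`.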
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
        }
    }
}

impl<'a, S: PromptType, M: CompletionModel> PromptRequest<'a, S, M> {
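    /// Set the maximum conversation depth: how many model round trips are
    /// allowed while the model keeps requesting tool calls. Exceeding it fails
    /// the request with [`PromptError::MaxDepthError`].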
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            state: PhantomData,
        }
    }

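    /// Attach an external chat history. The prompt, every assistant reply, and
    /// all tool results are appended to `history` as the request runs.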
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
        }
    }
}

impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, Standard, M> {
    type Output = Result<String, PromptError>;
    type IntoFuture = BoxFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}

impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, Extended, M> {
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = BoxFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}

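// A `Standard` request runs the same pipeline as an `Extended` one and simply
// discards everything except the final output text.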
impl<M: CompletionModel> PromptRequest<'_, Standard, M> {
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

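/// The result of an `Extended` prompt request: the final model output together
/// with the token [`Usage`] accumulated across every turn of the request.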
#[derive(Debug, Clone)]
pub struct PromptResponse {
    pub output: String,
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

impl<M: CompletionModel> PromptRequest<'_, Extended, M> {
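    /// Drive the multi-turn prompt loop to completion, returning the final
    /// text along with the token usage accumulated across all turns.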
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent = self.agent;
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt);
            history
        } else {
            &mut vec![self.prompt]
        };

        let mut current_max_depth = 0;
        let mut usage = Usage::new();

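        // Each iteration sends the latest message (plus the prior history) to
        // the model. Tool calls are executed and fed back as a new user
        // message; a plain text reply ends the loop. Exceeding the depth limit
        // breaks out with the last prompt so it can be reported below.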
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            let resp = agent
                .completion(prompt, chat_history[..chat_history.len() - 1].to_vec())
                .await?
                .send()
                .await?;

            usage += resp.usage;

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                return Ok(PromptResponse::new(merged_texts, usage));
            }

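            // Run each requested tool in order and turn its output into
            // tool-result content for the follow-up user message.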
            let tool_content = stream::iter(tool_calls)
                .then(|choice| async move {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let output = agent
                            .tools
                            .call(
                                &tool_call.function.name,
                                tool_call.function.arguments.to_string(),
                            )
                            .await?;
                        if let Some(call_id) = tool_call.call_id.clone() {
                            Ok(UserContent::tool_result_with_call_id(
                                tool_call.id.clone(),
                                call_id,
                                OneOrMany::one(output.into()),
                            ))
                        } else {
                            Ok(UserContent::tool_result(
                                tool_call.id.clone(),
                                OneOrMany::one(output.into()),
                            ))
                        }
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

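            // Hand the tool outputs back to the model on the next iteration.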
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

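        // Only reached when the model kept requesting tools past `max_depth`.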
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt: last_prompt,
        })
    }
}