1use std::{future::IntoFuture, marker::PhantomData};
2
3use futures::{FutureExt, StreamExt, future::BoxFuture, stream};
4
5use crate::{
6 OneOrMany,
7 completion::{Completion, CompletionError, CompletionModel, Message, PromptError, Usage},
8 message::{AssistantContent, UserContent},
9 tool::ToolSetError,
10};
11
12use super::Agent;
13
/// Marker trait for the type-state of a [`PromptRequest`]: implemented only by
/// [`Standard`] and [`Extended`], which select what the request resolves to.
pub trait PromptType {}
/// Type-state marker: awaiting the request yields only the final response text (`String`).
pub struct Standard;
/// Type-state marker: awaiting the request yields a [`PromptResponse`] (text + usage).
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
20
/// Builder for a prompt sent to an [`Agent`].
///
/// The type-state parameter `S` selects the output: [`Standard`] resolves to a
/// plain `String`, [`Extended`] resolves to a [`PromptResponse`] that also
/// carries aggregated token usage.
pub struct PromptRequest<'a, S: PromptType, M: CompletionModel> {
    /// The prompt message to send to the model.
    prompt: Message,
    /// Optional caller-owned chat history; the prompt and every intermediate
    /// turn are appended to it while the request runs.
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum number of multi-turn (tool-calling) round trips allowed.
    max_depth: usize,
    /// The agent that executes the completion calls and tool invocations.
    agent: &'a Agent<M>,
    /// Zero-sized marker carrying the `S` type-state.
    state: PhantomData<S>,
}
42
43impl<'a, M: CompletionModel> PromptRequest<'a, Standard, M> {
44 pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
46 Self {
47 prompt: prompt.into(),
48 chat_history: None,
49 max_depth: 0,
50 agent,
51 state: PhantomData,
52 }
53 }
54
55 pub fn extended_details(self) -> PromptRequest<'a, Extended, M> {
61 PromptRequest {
62 prompt: self.prompt,
63 chat_history: self.chat_history,
64 max_depth: self.max_depth,
65 agent: self.agent,
66 state: PhantomData,
67 }
68 }
69}
70
71impl<'a, S: PromptType, M: CompletionModel> PromptRequest<'a, S, M> {
72 pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M> {
75 PromptRequest {
76 prompt: self.prompt,
77 chat_history: self.chat_history,
78 max_depth: depth,
79 agent: self.agent,
80 state: PhantomData,
81 }
82 }
83
84 pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M> {
86 PromptRequest {
87 prompt: self.prompt,
88 chat_history: Some(history),
89 max_depth: self.max_depth,
90 agent: self.agent,
91 state: PhantomData,
92 }
93 }
94}
95
96impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, Standard, M> {
100 type Output = Result<String, PromptError>;
101 type IntoFuture = BoxFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
104 self.send().boxed()
105 }
106}
107
108impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, Extended, M> {
109 type Output = Result<PromptResponse, PromptError>;
110 type IntoFuture = BoxFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
113 self.send().boxed()
114 }
115}
116
117impl<M: CompletionModel> PromptRequest<'_, Standard, M> {
118 async fn send(self) -> Result<String, PromptError> {
119 self.extended_details().send().await.map(|resp| resp.output)
120 }
121}
122
/// The result of an extended prompt request: the final model output together
/// with the token usage aggregated across every completion call made while
/// the request ran.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    /// Final response text (the text parts of the last assistant message, joined).
    pub output: String,
    /// Token usage summed over all turns of the request.
    pub total_usage: Usage,
}
128
129impl PromptResponse {
130 pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
131 Self {
132 output: output.into(),
133 total_usage,
134 }
135 }
136}
137
impl<M: CompletionModel> PromptRequest<'_, Extended, M> {
    /// Drive the multi-turn prompt loop to completion.
    ///
    /// Appends the prompt to the chat history (caller-provided via
    /// `with_history`, or a fresh local one), then repeatedly calls the agent's
    /// completion endpoint. Each turn either returns the model's text (when no
    /// tool calls are present) or executes every requested tool and feeds the
    /// results back as the next user message.
    ///
    /// # Errors
    /// - Propagates `CompletionError`s from the model and tool-call failures
    ///   (wrapped as `CompletionError::RequestError`).
    /// - Returns `PromptError::MaxDepthError` when the model is still asking
    ///   for tools after the allowed number of turns.
    #[tracing::instrument(skip(self), fields(agent_name = self.agent.name()))]
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent = self.agent;
        // Use the caller's history if one was attached; otherwise run against a
        // private, temporary history. Either way the prompt becomes the last entry.
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt);
            history
        } else {
            &mut vec![self.prompt]
        };

        // Counts completed turns; despite the name, this is the *current* depth,
        // not a maximum.
        let mut current_max_depth = 0;
        // Token usage accumulated across every completion call in the loop.
        let mut usage = Usage::new();

        // The loop breaks with the last pending prompt only when the depth
        // budget is exhausted; a text-only response returns early instead.
        let last_prompt = loop {
            // The next prompt is always the most recent history entry (either the
            // original prompt or the tool results pushed at the end of last turn).
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            // NOTE(review): the `+ 1` grants one extra turn beyond `max_depth`, so
            // even `max_depth == 0` performs a first completion — presumably
            // intentional, but confirm against the documented multi_turn contract.
            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            // Progress logging only makes sense for genuinely multi-turn requests.
            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            // Send the completion with all *prior* messages as history; the last
            // entry is passed separately as the prompt, so it is sliced off.
            let resp = agent
                .completion(prompt, chat_history[..chat_history.len() - 1].to_vec())
                .await?
                .send()
                .await?;

            usage += resp.usage;

            // Split the model's choice into tool-call parts and everything else
            // (text parts) without consuming `resp.choice`.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Record the full assistant message (text and tool calls) in history.
            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            // No tool calls means the model produced a final answer: merge the
            // text parts and return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                return Ok(PromptResponse::new(merged_texts, usage));
            }

            // Execute each requested tool sequentially (`then` preserves order)
            // and turn every output into a `tool_result` user-content part,
            // keeping the provider's call id when one was supplied.
            let tool_content = stream::iter(tool_calls)
                .then(|choice| async move {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let output = agent
                            .tools
                            .call(
                                &tool_call.function.name,
                                tool_call.function.arguments.to_string(),
                            )
                            .await?;
                        if let Some(call_id) = tool_call.call_id.clone() {
                            Ok(UserContent::tool_result_with_call_id(
                                tool_call.id.clone(),
                                call_id,
                                OneOrMany::one(output.into()),
                            ))
                        } else {
                            Ok(UserContent::tool_result(
                                tool_call.id.clone(),
                                OneOrMany::one(output.into()),
                            ))
                        }
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                // Short-circuit on the first failed tool call.
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            // Push the tool results as the next user message; the next iteration
            // picks this up as its prompt. Safe: `tool_calls` was non-empty here.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // Depth budget exhausted while the model still wanted tools.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt: last_prompt,
        })
    }
}