1pub(crate) mod streaming;
2
3use std::{future::IntoFuture, marker::PhantomData};
4
5use futures::{FutureExt, StreamExt, future::BoxFuture, stream};
6
7use crate::{
8 OneOrMany,
9 completion::{Completion, CompletionError, CompletionModel, Message, PromptError, Usage},
10 message::{AssistantContent, UserContent},
11 tool::ToolSetError,
12};
13
14use super::Agent;
15
/// Marker trait for the typestate parameter of [`PromptRequest`] (see its
/// `state: PhantomData<S>` field).
pub trait PromptType {}
/// Typestate marker: awaiting the request yields a plain `String`.
pub struct Standard;
/// Typestate marker: awaiting the request yields a [`PromptResponse`] that
/// also carries aggregated token usage.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
22
/// Builder-style request for prompting an [`Agent`].
///
/// The typestate parameter `S` selects what awaiting the request produces:
/// [`Standard`] resolves to a `String`, [`Extended`] resolves to a
/// [`PromptResponse`] (output text plus accumulated usage).
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model.
    prompt: Message,
    /// Optional caller-owned chat history; when present, the prompt and all
    /// intermediate turns are appended to it in place.
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum number of multi-turn (tool-calling) rounds allowed before the
    /// request fails with a max-depth error. Defaults to 0 in [`Self::new`].
    max_depth: usize,
    /// The agent that performs completions and owns the tool set.
    agent: &'a Agent<M>,
    /// Typestate marker — carries no runtime data.
    state: PhantomData<S>,
    /// Optional hook notified around completion calls and tool calls.
    hook: Option<P>,
}
51
52impl<'a, M> PromptRequest<'a, Standard, M, ()>
53where
54 M: CompletionModel,
55{
56 pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
58 Self {
59 prompt: prompt.into(),
60 chat_history: None,
61 max_depth: 0,
62 agent,
63 state: PhantomData,
64 hook: None,
65 }
66 }
67}
68
69impl<'a, S, M, P> PromptRequest<'a, S, M, P>
70where
71 S: PromptType,
72 M: CompletionModel,
73 P: PromptHook<M>,
74{
75 pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
81 PromptRequest {
82 prompt: self.prompt,
83 chat_history: self.chat_history,
84 max_depth: self.max_depth,
85 agent: self.agent,
86 state: PhantomData,
87 hook: self.hook,
88 }
89 }
90 pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
93 PromptRequest {
94 prompt: self.prompt,
95 chat_history: self.chat_history,
96 max_depth: depth,
97 agent: self.agent,
98 state: PhantomData,
99 hook: self.hook,
100 }
101 }
102
103 pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
105 PromptRequest {
106 prompt: self.prompt,
107 chat_history: Some(history),
108 max_depth: self.max_depth,
109 agent: self.agent,
110 state: PhantomData,
111 hook: self.hook,
112 }
113 }
114
115 pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
117 where
118 P2: PromptHook<M>,
119 {
120 PromptRequest {
121 prompt: self.prompt,
122 chat_history: self.chat_history,
123 max_depth: self.max_depth,
124 agent: self.agent,
125 state: PhantomData,
126 hook: Some(hook),
127 }
128 }
129}
130
/// Observer hooks for the prompt/tool-call lifecycle of a [`PromptRequest`].
///
/// Every method has a no-op default, so implementors only override the events
/// they care about. `Clone + Send + Sync` is required because the hook is
/// cloned into the futures that execute tool calls.
pub trait PromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    /// Called just before a completion request is sent to the model.
    /// `history` excludes the current `prompt`.
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called after a (non-streaming) completion response is received.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called when a streaming completion finishes, with the final
    /// streaming response value.
    #[allow(unused_variables)]
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called just before a tool is invoked with its name and raw argument
    /// string (JSON-serialized).
    #[allow(unused_variables)]
    fn on_tool_call(&self, tool_name: &str, args: &str) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called after a tool invocation completes, with the stringified result.
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}
185
// No-op hook used when no hook is configured (the default `P = ()`).
impl<M> PromptHook<M> for () where M: CompletionModel {}
187
188impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
192where
193 M: CompletionModel,
194 P: PromptHook<M> + 'static,
195{
196 type Output = Result<String, PromptError>;
197 type IntoFuture = BoxFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
200 self.send().boxed()
201 }
202}
203
204impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
205where
206 M: CompletionModel,
207 P: PromptHook<M> + 'static,
208{
209 type Output = Result<PromptResponse, PromptError>;
210 type IntoFuture = BoxFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture {
213 self.send().boxed()
214 }
215}
216
217impl<M, P> PromptRequest<'_, Standard, M, P>
218where
219 M: CompletionModel,
220 P: PromptHook<M>,
221{
222 async fn send(self) -> Result<String, PromptError> {
223 self.extended_details().send().await.map(|resp| resp.output)
224 }
225}
226
/// Result of an [`Extended`] prompt request.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    /// Final text produced by the model (the joined text parts of the first
    /// reply that contained no tool calls).
    pub output: String,
    /// Token usage summed over every completion call made while resolving
    /// the request.
    pub total_usage: Usage,
}
232
233impl PromptResponse {
234 pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
235 Self {
236 output: output.into(),
237 total_usage,
238 }
239 }
240}
241
impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Drives the multi-turn prompt loop to completion.
    ///
    /// Repeatedly sends the conversation to the model; when the model replies
    /// with tool calls, every tool is invoked and the results are pushed back
    /// into the history for another round. Terminates with the joined text of
    /// the first reply that contains no tool calls, or with
    /// [`PromptError::MaxDepthError`] once the depth budget is exhausted.
    #[tracing::instrument(skip(self), fields(agent_name = self.agent.name()))]
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent = self.agent;
        // Use the caller-provided history (appending the prompt in place), or
        // a fresh temporary history containing only the prompt.
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt);
            history
        } else {
            &mut vec![self.prompt]
        };

        let mut current_max_depth = 0;
        // Token usage accumulated across every completion call in the loop.
        let mut usage = Usage::new();

        let last_prompt = loop {
            // The newest message is this round's prompt: either the user's
            // original prompt or the latest batch of tool results.
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            // Depth budget exceeded: break out with the unanswered prompt so
            // it can be reported in the MaxDepthError below.
            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            // Notify the hook before the completion call; the history slice
            // excludes the current prompt, which is passed separately.
            if let Some(ref hook) = self.hook {
                hook.on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
                    .await;
            }

            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp).await;
            }

            // Split the model's reply into tool calls and plain text parts.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Record the assistant reply before acting on it, so the history
            // stays consistent even if a tool call fails below.
            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            // No tool calls -> the model is done; join the text parts and
            // return them along with the accumulated usage.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                return Ok(PromptResponse::new(merged_texts, usage));
            }

            // Execute every requested tool sequentially (`then` preserves
            // order), wrapping each invocation in the hook notifications.
            let hook = self.hook.clone();
            let tool_content = stream::iter(tool_calls)
                .then(|choice| {
                    // Two clones: one for the pre-call hook, one for post-call.
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args = tool_call.function.arguments.to_string();
                            if let Some(hook) = hook1 {
                                hook.on_tool_call(tool_name, &args).await;
                            }
                            let output = agent.tools.call(tool_name, args.clone()).await?;
                            if let Some(hook) = hook2 {
                                hook.on_tool_result(tool_name, &args, &output.to_string())
                                    .await;
                            }
                            // Preserve the provider's call id when one was
                            // supplied so the result can be correlated.
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                // Fail the whole round on the first tool error.
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            // Feed the tool results back as a user message for the next round.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // Loop exhausted its depth budget without a text-only reply.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: last_prompt,
        })
    }
}