// rig/agent/prompt_request.rs
use std::future::IntoFuture;
2
3use futures::{future::BoxFuture, stream, FutureExt, StreamExt};
4
5use crate::{
6 completion::{Completion, CompletionError, CompletionModel, Message, PromptError},
7 message::{AssistantContent, UserContent},
8 tool::ToolSetError,
9 OneOrMany,
10};
11
12use super::Agent;
13
/// A builder describing a single prompt sent to an [`Agent`].
///
/// Configure it with [`PromptRequest::multi_turn`] and
/// [`PromptRequest::with_history`], then `.await` it directly — the
/// `IntoFuture` impl below drives the request via `send`.
pub struct PromptRequest<'a, M: CompletionModel> {
    /// The initial user message delivered to the model.
    prompt: Message,
    /// Optional caller-owned history; read for context and appended to
    /// in place as the conversation progresses.
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum number of multi-turn (tool-call) rounds; `0` by default.
    max_depth: usize,
    /// The agent whose completion model and tool set service this request.
    agent: &'a Agent<M>,
}
27
28impl<'a, M: CompletionModel> PromptRequest<'a, M> {
29 pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
31 Self {
32 prompt: prompt.into(),
33 chat_history: None,
34 max_depth: 0,
35 agent,
36 }
37 }
38}
39
40impl<'a, M: CompletionModel> PromptRequest<'a, M> {
41 pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, M> {
43 PromptRequest {
44 prompt: self.prompt,
45 chat_history: self.chat_history,
46 max_depth: depth,
47 agent: self.agent,
48 }
49 }
50
51 pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, M> {
53 PromptRequest {
54 prompt: self.prompt,
55 chat_history: Some(history),
56 max_depth: self.max_depth,
57 agent: self.agent,
58 }
59 }
60}
61
/// Lets a `PromptRequest` be `.await`ed directly: awaiting the builder
/// boxes and runs [`PromptRequest::send`].
impl<'a, M: CompletionModel> IntoFuture for PromptRequest<'a, M> {
    type Output = Result<String, PromptError>;
    // Boxed because `send` is an `async fn` whose concrete future type
    // cannot be named here.
    type IntoFuture = BoxFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        self.send().boxed()
    }
}
73
impl<M: CompletionModel> PromptRequest<'_, M> {
    /// Drive the prompt to completion, resolving tool calls in a loop.
    ///
    /// Each iteration:
    /// 1. sends the current prompt plus accumulated history to the model,
    /// 2. appends the prompt and the assistant response to the history,
    /// 3. if the response contains no tool calls, joins all text parts
    ///    with `\n` and returns them,
    /// 4. otherwise invokes every requested tool and feeds the results
    ///    back as the next user prompt.
    ///
    /// Returns [`PromptError::MaxDepthError`] when the loop runs out of
    /// turns while the model is still requesting tools.
    async fn send(self) -> Result<String, PromptError> {
        let agent = self.agent;
        let mut prompt = self.prompt;
        // Use the caller-provided history if any; otherwise operate on a
        // scratch Vec (the `&mut Vec::new()` temporary is kept alive by
        // temporary lifetime extension of the `if/else` tail expression).
        let chat_history = if let Some(history) = self.chat_history {
            history
        } else {
            &mut Vec::new()
        };

        let mut current_max_depth = 0;
        // NOTE(review): counter starts at 0 and the bound is inclusive of
        // `max_depth + 1`, so the loop body runs up to `max_depth + 2`
        // times — confirm this off-by-one is intentional.
        while current_max_depth <= self.max_depth + 1 {
            current_max_depth += 1;

            // NOTE(review): depth logging is gated on `max_depth > 1`, so a
            // `multi_turn(1)` request is never logged — verify intended.
            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            // One round-trip to the model with the full history snapshot.
            let resp = agent
                .completion(prompt.clone(), chat_history.to_vec())
                .await?
                .send()
                .await?;

            // Record the prompt we just sent (moves `prompt`; it is
            // reassigned below before the next iteration needs it).
            chat_history.push(prompt);

            // Split the assistant response into tool-call parts and
            // everything else (text).
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                content: resp.choice.clone(),
            });

            // No tool calls: the model answered directly — merge the text
            // fragments and finish.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                return Ok(merged_texts);
            }

            // Execute every requested tool sequentially (`then` preserves
            // order) and collect the results, failing fast on the first
            // tool error.
            let tool_content = stream::iter(tool_calls)
                .then(|choice| async move {
                    if let AssistantContent::ToolCall(tool_call) = choice {
                        let output = agent
                            .tools
                            .call(
                                &tool_call.function.name,
                                tool_call.function.arguments.to_string(),
                            )
                            .await?;
                        Ok(UserContent::tool_result(
                            tool_call.id.clone(),
                            OneOrMany::one(output.into()),
                        ))
                    } else {
                        unreachable!(
                            "This should never happen as we already filtered for `ToolCall`"
                        )
                    }
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            // Tool outputs become the next user message; `tool_content` is
            // non-empty here because `tool_calls` was non-empty.
            prompt = Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            };
        }

        // Depth budget exhausted while the model was still calling tools.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: chat_history.clone(),
            prompt,
        })
    }
}