1use super::prompt_request::{self, PromptRequest};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError,
7 },
8 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
9 tool::ToolSet,
10 vector_store::{VectorStoreError, request::VectorSearchRequest},
11};
12use futures::{StreamExt, TryStreamExt, stream};
13use std::{collections::HashMap, sync::Arc};
14
/// Fallback name used wherever an agent's `name` is `None` (see `Agent::name`).
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
16
/// A completion model bundled with the preamble, context documents and tools
/// that are attached to every completion request it builds.
///
/// Static context and tools are always sent; dynamic context and tools are
/// looked up in vector store indexes at request time, using the RAG text of
/// the prompt (or of the chat history) as the query.
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M>
where
    M: CompletionModel,
{
    /// Optional agent name; `UNKNOWN_AGENT_NAME` is reported when this is `None`.
    pub name: Option<String>,
    /// The completion model used to create completion requests.
    pub model: Arc<M>,
    /// Preamble set on every completion request.
    pub preamble: String,
    /// Documents attached to every completion request.
    pub static_context: Vec<Document>,
    /// Names of tools (resolved through `tools`) offered on every request.
    pub static_tools: Vec<String>,
    /// Temperature forwarded to the request builder when set.
    pub temperature: Option<f64>,
    /// Maximum token count forwarded to the request builder when set.
    pub max_tokens: Option<u64>,
    /// Extra parameters forwarded to the request builder when set.
    pub additional_params: Option<serde_json::Value>,
    /// `(sample count, index)` pairs; each index is queried for that many
    /// documents, which are then added to the request as context.
    pub dynamic_context: Arc<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>,
    /// `(sample count, index)` pairs; each index is queried for that many ids,
    /// which are then resolved to tools through `tools`.
    pub dynamic_tools: Arc<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>,
    /// The tool set in which static and dynamic tool names are looked up.
    pub tools: Arc<ToolSet>,
}
66
67impl<M> Agent<M>
68where
69 M: CompletionModel,
70{
71 pub(crate) fn name(&self) -> &str {
73 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
74 }
75
76 pub(crate) fn name_owned(&self) -> String {
79 self.name.clone().unwrap_or(UNKNOWN_AGENT_NAME.to_string())
80 }
81}
82
83impl<M> Completion<M> for Agent<M>
84where
85 M: CompletionModel,
86{
87 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
88 async fn completion(
89 &self,
90 prompt: impl Into<Message> + Send,
91 chat_history: Vec<Message>,
92 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
93 let prompt = prompt.into();
94
95 let rag_text = prompt.rag_text();
97 let rag_text = rag_text.or_else(|| {
98 chat_history
99 .iter()
100 .rev()
101 .find_map(|message| message.rag_text())
102 });
103
104 let completion_request = self
105 .model
106 .completion_request(prompt)
107 .preamble(self.preamble.clone())
108 .messages(chat_history)
109 .temperature_opt(self.temperature)
110 .max_tokens_opt(self.max_tokens)
111 .additional_params_opt(self.additional_params.clone())
112 .documents(self.static_context.clone());
113
114 let agent = match &rag_text {
116 Some(text) => {
117 let dynamic_context = stream::iter(self.dynamic_context.iter())
118 .then(|(num_sample, index)| async {
119 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
120 Ok::<_, VectorStoreError>(
121 index
122 .top_n(req)
123 .await?
124 .into_iter()
125 .map(|(_, id, doc)| {
126 let text = serde_json::to_string_pretty(&doc)
128 .unwrap_or_else(|_| doc.to_string());
129
130 Document {
131 id,
132 text,
133 additional_props: HashMap::new(),
134 }
135 })
136 .collect::<Vec<_>>(),
137 )
138 })
139 .try_fold(vec![], |mut acc, docs| async {
140 acc.extend(docs);
141 Ok(acc)
142 })
143 .await
144 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
145
146 let dynamic_tools = stream::iter(self.dynamic_tools.iter())
147 .then(|(num_sample, index)| async {
148 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
149 Ok::<_, VectorStoreError>(
150 index
151 .top_n_ids(req)
152 .await?
153 .into_iter()
154 .map(|(_, id)| id)
155 .collect::<Vec<_>>(),
156 )
157 })
158 .try_fold(vec![], |mut acc, docs| async {
159 for doc in docs {
160 if let Some(tool) = self.tools.get(&doc) {
161 acc.push(tool.definition(text.into()).await)
162 } else {
163 tracing::warn!("Tool implementation not found in toolset: {}", doc);
164 }
165 }
166 Ok(acc)
167 })
168 .await
169 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
170
171 let static_tools = stream::iter(self.static_tools.iter())
172 .filter_map(|toolname| async move {
173 if let Some(tool) = self.tools.get(toolname) {
174 Some(tool.definition(text.into()).await)
175 } else {
176 tracing::warn!(
177 "Tool implementation not found in toolset: {}",
178 toolname
179 );
180 None
181 }
182 })
183 .collect::<Vec<_>>()
184 .await;
185
186 completion_request
187 .documents(dynamic_context)
188 .tools([static_tools.clone(), dynamic_tools].concat())
189 }
190 None => {
191 let static_tools = stream::iter(self.static_tools.iter())
192 .filter_map(|toolname| async move {
193 if let Some(tool) = self.tools.get(toolname) {
194 Some(tool.definition("".into()).await)
196 } else {
197 tracing::warn!(
198 "Tool implementation not found in toolset: {}",
199 toolname
200 );
201 None
202 }
203 })
204 .collect::<Vec<_>>()
205 .await;
206
207 completion_request.tools(static_tools)
208 }
209 };
210
211 Ok(agent)
212 }
213}
214
// NOTE(review): the `refining_impl_trait` lint fires when an impl's signature
// is more specific than the trait's; presumably the concrete `PromptRequest`
// return type is exposed on purpose — confirm against the `Prompt` trait.
#[allow(refining_impl_trait)]
impl<M> Prompt for Agent<M>
where
    M: CompletionModel,
{
    /// Creates a [`PromptRequest`] that borrows this agent and carries `prompt`.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        PromptRequest::new(self, prompt)
    }
}
235
// NOTE(review): the `refining_impl_trait` lint fires when an impl's signature
// is more specific than the trait's; presumably the concrete `PromptRequest`
// return type is exposed on purpose — confirm against the `Prompt` trait.
#[allow(refining_impl_trait)]
impl<M> Prompt for &Agent<M>
where
    M: CompletionModel,
{
    /// Creates a [`PromptRequest`] for the referenced agent, so a `&Agent`
    /// can be prompted the same way as an owned `Agent`.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        // `self` is `&&Agent<M>` here; dereference once to pass `&Agent<M>`.
        PromptRequest::new(*self, prompt)
    }
}
249
250#[allow(refining_impl_trait)]
251impl<M> Chat for Agent<M>
252where
253 M: CompletionModel,
254{
255 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
256 async fn chat(
257 &self,
258 prompt: impl Into<Message> + Send,
259 chat_history: Vec<Message>,
260 ) -> Result<String, PromptError> {
261 let mut cloned_history = chat_history.clone();
262 PromptRequest::new(self, prompt)
263 .with_history(&mut cloned_history)
264 .await
265 }
266}
267
impl<M> StreamingCompletion<M> for Agent<M>
where
    M: CompletionModel,
{
    /// Builds the completion request used for streaming.
    ///
    /// # Errors
    ///
    /// Propagates any [`CompletionError`] returned by [`Completion::completion`].
    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // The request builder is the same for streaming and non-streaming
        // use, so reuse the regular completion path.
        self.completion(prompt, chat_history).await
    }
}
283
284impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
285where
286 M: CompletionModel + 'static,
287 M::StreamingResponse: GetTokenUsage,
288{
289 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
290 fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<M, ()> {
291 let arc = Arc::new(self.clone());
292 StreamingPromptRequest::new(arc, prompt)
293 }
294}
295
296impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
297where
298 M: CompletionModel + 'static,
299 M::StreamingResponse: GetTokenUsage,
300{
301 fn stream_chat(
302 &self,
303 prompt: impl Into<Message> + Send,
304 chat_history: Vec<Message>,
305 ) -> StreamingPromptRequest<M, ()> {
306 let arc = Arc::new(self.clone());
307 StreamingPromptRequest::new(arc, prompt).with_history(chat_history)
308 }
309}