1use super::prompt_request::{self, PromptRequest};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError,
7 },
8 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
9 tool::ToolSet,
10 vector_store::{VectorStoreError, request::VectorSearchRequest},
11};
12use futures::{StreamExt, TryStreamExt, stream};
13use std::{collections::HashMap, sync::Arc};
14
15const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
16
17#[derive(Clone)]
38#[non_exhaustive]
39pub struct Agent<M>
40where
41 M: CompletionModel,
42{
43 pub name: Option<String>,
45 pub model: Arc<M>,
47 pub preamble: Option<String>,
49 pub static_context: Vec<Document>,
51 pub static_tools: Vec<String>,
53 pub temperature: Option<f64>,
55 pub max_tokens: Option<u64>,
57 pub additional_params: Option<serde_json::Value>,
59 pub dynamic_context: Arc<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>,
61 pub dynamic_tools: Arc<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>,
63 pub tools: Arc<ToolSet>,
65}
66
67impl<M> Agent<M>
68where
69 M: CompletionModel,
70{
71 pub(crate) fn name(&self) -> &str {
73 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
74 }
75
76 pub(crate) fn name_owned(&self) -> String {
79 self.name.clone().unwrap_or(UNKNOWN_AGENT_NAME.to_string())
80 }
81}
82
83impl<M> Completion<M> for Agent<M>
84where
85 M: CompletionModel,
86{
87 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
88 async fn completion(
89 &self,
90 prompt: impl Into<Message> + Send,
91 chat_history: Vec<Message>,
92 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
93 let prompt = prompt.into();
94
95 let rag_text = prompt.rag_text();
97 let rag_text = rag_text.or_else(|| {
98 chat_history
99 .iter()
100 .rev()
101 .find_map(|message| message.rag_text())
102 });
103
104 let completion_request = self
105 .model
106 .completion_request(prompt)
107 .messages(chat_history)
108 .temperature_opt(self.temperature)
109 .max_tokens_opt(self.max_tokens)
110 .additional_params_opt(self.additional_params.clone())
111 .documents(self.static_context.clone());
112 let completion_request = if let Some(preamble) = &self.preamble {
113 completion_request.preamble(preamble.to_owned())
114 } else {
115 completion_request
116 };
117
118 let agent = match &rag_text {
120 Some(text) => {
121 let dynamic_context = stream::iter(self.dynamic_context.iter())
122 .then(|(num_sample, index)| async {
123 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
124 Ok::<_, VectorStoreError>(
125 index
126 .top_n(req)
127 .await?
128 .into_iter()
129 .map(|(_, id, doc)| {
130 let text = serde_json::to_string_pretty(&doc)
132 .unwrap_or_else(|_| doc.to_string());
133
134 Document {
135 id,
136 text,
137 additional_props: HashMap::new(),
138 }
139 })
140 .collect::<Vec<_>>(),
141 )
142 })
143 .try_fold(vec![], |mut acc, docs| async {
144 acc.extend(docs);
145 Ok(acc)
146 })
147 .await
148 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
149
150 let dynamic_tools = stream::iter(self.dynamic_tools.iter())
151 .then(|(num_sample, index)| async {
152 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
153 Ok::<_, VectorStoreError>(
154 index
155 .top_n_ids(req)
156 .await?
157 .into_iter()
158 .map(|(_, id)| id)
159 .collect::<Vec<_>>(),
160 )
161 })
162 .try_fold(vec![], |mut acc, docs| async {
163 for doc in docs {
164 if let Some(tool) = self.tools.get(&doc) {
165 acc.push(tool.definition(text.into()).await)
166 } else {
167 tracing::warn!("Tool implementation not found in toolset: {}", doc);
168 }
169 }
170 Ok(acc)
171 })
172 .await
173 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
174
175 let static_tools = stream::iter(self.static_tools.iter())
176 .filter_map(|toolname| async move {
177 if let Some(tool) = self.tools.get(toolname) {
178 Some(tool.definition(text.into()).await)
179 } else {
180 tracing::warn!(
181 "Tool implementation not found in toolset: {}",
182 toolname
183 );
184 None
185 }
186 })
187 .collect::<Vec<_>>()
188 .await;
189
190 completion_request
191 .documents(dynamic_context)
192 .tools([static_tools.clone(), dynamic_tools].concat())
193 }
194 None => {
195 let static_tools = stream::iter(self.static_tools.iter())
196 .filter_map(|toolname| async move {
197 if let Some(tool) = self.tools.get(toolname) {
198 Some(tool.definition("".into()).await)
200 } else {
201 tracing::warn!(
202 "Tool implementation not found in toolset: {}",
203 toolname
204 );
205 None
206 }
207 })
208 .collect::<Vec<_>>()
209 .await;
210
211 completion_request.tools(static_tools)
212 }
213 };
214
215 Ok(agent)
216 }
217}
218
219#[allow(refining_impl_trait)]
227impl<M> Prompt for Agent<M>
228where
229 M: CompletionModel,
230{
231 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
232 fn prompt(
233 &self,
234 prompt: impl Into<Message> + Send,
235 ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
236 PromptRequest::new(self, prompt)
237 }
238}
239
240#[allow(refining_impl_trait)]
241impl<M> Prompt for &Agent<M>
242where
243 M: CompletionModel,
244{
245 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
246 fn prompt(
247 &self,
248 prompt: impl Into<Message> + Send,
249 ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
250 PromptRequest::new(*self, prompt)
251 }
252}
253
254#[allow(refining_impl_trait)]
255impl<M> Chat for Agent<M>
256where
257 M: CompletionModel,
258{
259 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
260 async fn chat(
261 &self,
262 prompt: impl Into<Message> + Send,
263 mut chat_history: Vec<Message>,
264 ) -> Result<String, PromptError> {
265 PromptRequest::new(self, prompt)
266 .with_history(&mut chat_history)
267 .await
268 }
269}
270
271impl<M> StreamingCompletion<M> for Agent<M>
272where
273 M: CompletionModel,
274{
275 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
276 async fn stream_completion(
277 &self,
278 prompt: impl Into<Message> + Send,
279 chat_history: Vec<Message>,
280 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
281 self.completion(prompt, chat_history).await
284 }
285}
286
287impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
288where
289 M: CompletionModel + 'static,
290 M::StreamingResponse: GetTokenUsage,
291{
292 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
293 fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<M, ()> {
294 let arc = Arc::new(self.clone());
295 StreamingPromptRequest::new(arc, prompt)
296 }
297}
298
299impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
300where
301 M: CompletionModel + 'static,
302 M::StreamingResponse: GetTokenUsage,
303{
304 fn stream_chat(
305 &self,
306 prompt: impl Into<Message> + Send,
307 chat_history: Vec<Message>,
308 ) -> StreamingPromptRequest<M, ()> {
309 let arc = Arc::new(self.clone());
310 StreamingPromptRequest::new(arc, prompt).with_history(chat_history)
311 }
312}