1use super::prompt_request::{self, PromptRequest};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError,
7 },
8 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
9 tool::ToolSet,
10 vector_store::{VectorStoreError, request::VectorSearchRequest},
11};
12use futures::{StreamExt, TryStreamExt, stream};
13use std::collections::HashMap;
14
/// Display name reported for an [`Agent`] whose `name` field is `None`.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
16
/// A completion model bundled with the configuration used to build its completion
/// requests: preamble, sampling parameters, static/dynamic context documents and tools.
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot construct it with a
/// struct literal. NOTE(review): presumably built through a builder elsewhere in the
/// crate — confirm.
#[non_exhaustive]
pub struct Agent<M: CompletionModel> {
    /// Optional display name. When `None`, the agent reports itself as "Unnamed Agent"
    /// (e.g. in the `agent_name` tracing field).
    pub name: Option<String>,
    /// Completion model that every completion request is created from.
    pub model: M,
    /// Preamble set on every completion request built by this agent.
    pub preamble: String,
    /// Documents attached to every completion request.
    pub static_context: Vec<Document>,
    /// Names of tools, looked up in `tools`, whose definitions are attached to every
    /// completion request. Names missing from `tools` are skipped with a warning.
    pub static_tools: Vec<String>,
    /// Temperature forwarded to the completion request, if set.
    pub temperature: Option<f64>,
    /// Maximum token count forwarded to the completion request, if set.
    pub max_tokens: Option<u64>,
    /// Additional parameters forwarded to the completion request, if set.
    pub additional_params: Option<serde_json::Value>,
    /// `(sample count, index)` pairs. When a prompt carries RAG text, each index is queried
    /// with it for that many samples and the hits are attached as documents.
    pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// `(sample count, index)` pairs. When a prompt carries RAG text, each index is queried
    /// for that many ids, which are looked up in `tools`; matching tool definitions are
    /// attached to the request.
    pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Tool implementations, looked up by name from `static_tools` and `dynamic_tools`.
    pub tools: ToolSet,
}
62
63impl<M: CompletionModel> Completion<M> for Agent<M> {
64 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
65 async fn completion(
66 &self,
67 prompt: impl Into<Message> + Send,
68 chat_history: Vec<Message>,
69 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
70 let prompt = prompt.into();
71
72 let rag_text = prompt.rag_text();
74 let rag_text = rag_text.or_else(|| {
75 chat_history
76 .iter()
77 .rev()
78 .find_map(|message| message.rag_text())
79 });
80
81 let completion_request = self
82 .model
83 .completion_request(prompt)
84 .preamble(self.preamble.clone())
85 .messages(chat_history)
86 .temperature_opt(self.temperature)
87 .max_tokens_opt(self.max_tokens)
88 .additional_params_opt(self.additional_params.clone())
89 .documents(self.static_context.clone());
90
91 let agent = match &rag_text {
93 Some(text) => {
94 let dynamic_context = stream::iter(self.dynamic_context.iter())
95 .then(|(num_sample, index)| async {
96 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
97 Ok::<_, VectorStoreError>(
98 index
99 .top_n(req)
100 .await?
101 .into_iter()
102 .map(|(_, id, doc)| {
103 let text = serde_json::to_string_pretty(&doc)
105 .unwrap_or_else(|_| doc.to_string());
106
107 Document {
108 id,
109 text,
110 additional_props: HashMap::new(),
111 }
112 })
113 .collect::<Vec<_>>(),
114 )
115 })
116 .try_fold(vec![], |mut acc, docs| async {
117 acc.extend(docs);
118 Ok(acc)
119 })
120 .await
121 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
122
123 let dynamic_tools = stream::iter(self.dynamic_tools.iter())
124 .then(|(num_sample, index)| async {
125 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
126 Ok::<_, VectorStoreError>(
127 index
128 .top_n_ids(req)
129 .await?
130 .into_iter()
131 .map(|(_, id)| id)
132 .collect::<Vec<_>>(),
133 )
134 })
135 .try_fold(vec![], |mut acc, docs| async {
136 for doc in docs {
137 if let Some(tool) = self.tools.get(&doc) {
138 acc.push(tool.definition(text.into()).await)
139 } else {
140 tracing::warn!("Tool implementation not found in toolset: {}", doc);
141 }
142 }
143 Ok(acc)
144 })
145 .await
146 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
147
148 let static_tools = stream::iter(self.static_tools.iter())
149 .filter_map(|toolname| async move {
150 if let Some(tool) = self.tools.get(toolname) {
151 Some(tool.definition(text.into()).await)
152 } else {
153 tracing::warn!(
154 "Tool implementation not found in toolset: {}",
155 toolname
156 );
157 None
158 }
159 })
160 .collect::<Vec<_>>()
161 .await;
162
163 completion_request
164 .documents(dynamic_context)
165 .tools([static_tools.clone(), dynamic_tools].concat())
166 }
167 None => {
168 let static_tools = stream::iter(self.static_tools.iter())
169 .filter_map(|toolname| async move {
170 if let Some(tool) = self.tools.get(toolname) {
171 Some(tool.definition("".into()).await)
173 } else {
174 tracing::warn!(
175 "Tool implementation not found in toolset: {}",
176 toolname
177 );
178 None
179 }
180 })
181 .collect::<Vec<_>>()
182 .await;
183
184 completion_request.tools(static_tools)
185 }
186 };
187
188 Ok(agent)
189 }
190}
191
192#[allow(refining_impl_trait)]
200impl<M: CompletionModel> Prompt for Agent<M> {
201 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
202 fn prompt(
203 &self,
204 prompt: impl Into<Message> + Send,
205 ) -> PromptRequest<'_, prompt_request::Standard, M> {
206 PromptRequest::new(self, prompt)
207 }
208}
209
210#[allow(refining_impl_trait)]
211impl<M: CompletionModel> Prompt for &Agent<M> {
212 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
213 fn prompt(
214 &self,
215 prompt: impl Into<Message> + Send,
216 ) -> PromptRequest<'_, prompt_request::Standard, M> {
217 PromptRequest::new(*self, prompt)
218 }
219}
220
221#[allow(refining_impl_trait)]
222impl<M: CompletionModel> Chat for Agent<M> {
223 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
224 async fn chat(
225 &self,
226 prompt: impl Into<Message> + Send,
227 chat_history: Vec<Message>,
228 ) -> Result<String, PromptError> {
229 let mut cloned_history = chat_history.clone();
230 PromptRequest::new(self, prompt)
231 .with_history(&mut cloned_history)
232 .await
233 }
234}
235
236impl<M: CompletionModel> StreamingCompletion<M> for Agent<M> {
237 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
238 async fn stream_completion(
239 &self,
240 prompt: impl Into<Message> + Send,
241 chat_history: Vec<Message>,
242 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
243 self.completion(prompt, chat_history).await
246 }
247}
248
249impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
250where
251 M: CompletionModel + 'static,
252 M::StreamingResponse: GetTokenUsage,
253{
254 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
255 fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<'_, M> {
256 StreamingPromptRequest::new(self, prompt)
257 }
258}
259
260impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
261where
262 M: CompletionModel + 'static,
263 M::StreamingResponse: GetTokenUsage,
264{
265 fn stream_chat(
266 &self,
267 prompt: impl Into<Message> + Send,
268 chat_history: Vec<Message>,
269 ) -> StreamingPromptRequest<'_, M> {
270 StreamingPromptRequest::new(self, prompt).with_history(chat_history)
271 }
272}
273
274impl<M: CompletionModel> Agent<M> {
275 pub(crate) fn name(&self) -> &str {
277 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
278 }
279
280 pub(crate) fn name_owned(&self) -> String {
283 self.name.clone().unwrap_or(UNKNOWN_AGENT_NAME.to_string())
284 }
285}