1use super::prompt_request::{self, PromptRequest};
2use crate::{
3 completion::{
4 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
5 Message, Prompt, PromptError,
6 },
7 streaming::{StreamingChat, StreamingCompletion, StreamingCompletionResponse, StreamingPrompt},
8 tool::ToolSet,
9 vector_store::{VectorStoreError, request::VectorSearchRequest},
10};
11use futures::{StreamExt, TryStreamExt, stream};
12use std::collections::HashMap;
13
14pub struct Agent<M: CompletionModel> {
35 pub model: M,
37 pub preamble: String,
39 pub static_context: Vec<Document>,
41 pub static_tools: Vec<String>,
43 pub temperature: Option<f64>,
45 pub max_tokens: Option<u64>,
47 pub additional_params: Option<serde_json::Value>,
49 pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
51 pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
53 pub tools: ToolSet,
55}
56
57impl<M: CompletionModel> Completion<M> for Agent<M> {
58 async fn completion(
59 &self,
60 prompt: impl Into<Message> + Send,
61 chat_history: Vec<Message>,
62 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
63 let prompt = prompt.into();
64
65 let rag_text = prompt.rag_text();
67 let rag_text = rag_text.or_else(|| {
68 chat_history
69 .iter()
70 .rev()
71 .find_map(|message| message.rag_text())
72 });
73
74 let completion_request = self
75 .model
76 .completion_request(prompt)
77 .preamble(self.preamble.clone())
78 .messages(chat_history)
79 .temperature_opt(self.temperature)
80 .max_tokens_opt(self.max_tokens)
81 .additional_params_opt(self.additional_params.clone())
82 .documents(self.static_context.clone());
83
84 let agent = match &rag_text {
86 Some(text) => {
87 let dynamic_context = stream::iter(self.dynamic_context.iter())
88 .then(|(num_sample, index)| async {
89 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
90 Ok::<_, VectorStoreError>(
91 index
92 .top_n(req)
93 .await?
94 .into_iter()
95 .map(|(_, id, doc)| {
96 let text = serde_json::to_string_pretty(&doc)
98 .unwrap_or_else(|_| doc.to_string());
99
100 Document {
101 id,
102 text,
103 additional_props: HashMap::new(),
104 }
105 })
106 .collect::<Vec<_>>(),
107 )
108 })
109 .try_fold(vec![], |mut acc, docs| async {
110 acc.extend(docs);
111 Ok(acc)
112 })
113 .await
114 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
115
116 let dynamic_tools = stream::iter(self.dynamic_tools.iter())
117 .then(|(num_sample, index)| async {
118 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
119 Ok::<_, VectorStoreError>(
120 index
121 .top_n_ids(req)
122 .await?
123 .into_iter()
124 .map(|(_, id)| id)
125 .collect::<Vec<_>>(),
126 )
127 })
128 .try_fold(vec![], |mut acc, docs| async {
129 for doc in docs {
130 if let Some(tool) = self.tools.get(&doc) {
131 acc.push(tool.definition(text.into()).await)
132 } else {
133 tracing::warn!("Tool implementation not found in toolset: {}", doc);
134 }
135 }
136 Ok(acc)
137 })
138 .await
139 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
140
141 let static_tools = stream::iter(self.static_tools.iter())
142 .filter_map(|toolname| async move {
143 if let Some(tool) = self.tools.get(toolname) {
144 Some(tool.definition(text.into()).await)
145 } else {
146 tracing::warn!(
147 "Tool implementation not found in toolset: {}",
148 toolname
149 );
150 None
151 }
152 })
153 .collect::<Vec<_>>()
154 .await;
155
156 completion_request
157 .documents(dynamic_context)
158 .tools([static_tools.clone(), dynamic_tools].concat())
159 }
160 None => {
161 let static_tools = stream::iter(self.static_tools.iter())
162 .filter_map(|toolname| async move {
163 if let Some(tool) = self.tools.get(toolname) {
164 Some(tool.definition("".into()).await)
166 } else {
167 tracing::warn!(
168 "Tool implementation not found in toolset: {}",
169 toolname
170 );
171 None
172 }
173 })
174 .collect::<Vec<_>>()
175 .await;
176
177 completion_request.tools(static_tools)
178 }
179 };
180
181 Ok(agent)
182 }
183}
184
185#[allow(refining_impl_trait)]
193impl<M: CompletionModel> Prompt for Agent<M> {
194 fn prompt(
195 &self,
196 prompt: impl Into<Message> + Send,
197 ) -> PromptRequest<prompt_request::Standard, M> {
198 PromptRequest::new(self, prompt)
199 }
200}
201
202#[allow(refining_impl_trait)]
203impl<M: CompletionModel> Prompt for &Agent<M> {
204 fn prompt(
205 &self,
206 prompt: impl Into<Message> + Send,
207 ) -> PromptRequest<prompt_request::Standard, M> {
208 PromptRequest::new(*self, prompt)
209 }
210}
211
212#[allow(refining_impl_trait)]
213impl<M: CompletionModel> Chat for Agent<M> {
214 async fn chat(
215 &self,
216 prompt: impl Into<Message> + Send,
217 chat_history: Vec<Message>,
218 ) -> Result<String, PromptError> {
219 let mut cloned_history = chat_history.clone();
220 PromptRequest::new(self, prompt)
221 .with_history(&mut cloned_history)
222 .await
223 }
224}
225
226impl<M: CompletionModel> StreamingCompletion<M> for Agent<M> {
227 async fn stream_completion(
228 &self,
229 prompt: impl Into<Message> + Send,
230 chat_history: Vec<Message>,
231 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
232 self.completion(prompt, chat_history).await
235 }
236}
237
238impl<M: CompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
239 async fn stream_prompt(
240 &self,
241 prompt: impl Into<Message> + Send,
242 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
243 self.stream_chat(prompt, vec![]).await
244 }
245}
246
247impl<M: CompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
248 async fn stream_chat(
249 &self,
250 prompt: impl Into<Message> + Send,
251 chat_history: Vec<Message>,
252 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
253 self.stream_completion(prompt, chat_history)
254 .await?
255 .stream()
256 .await
257 }
258}