1use super::prompt_request::{self, PromptRequest};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError,
7 },
8 message::ToolChoice,
9 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10 tool::server::ToolServerHandle,
11 vector_store::{VectorStoreError, request::VectorSearchRequest},
12 wasm_compat::WasmCompatSend,
13};
14use futures::{StreamExt, TryStreamExt, stream};
15use std::{collections::HashMap, sync::Arc};
16use tokio::sync::RwLock;
17
/// Placeholder name reported for an agent that was created without an explicit name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
19
20pub type DynamicContextStore = Arc<
21 RwLock<
22 Vec<(
23 usize,
24 Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
25 )>,
26 >,
27>;
28
29#[derive(Clone)]
50#[non_exhaustive]
51pub struct Agent<M>
52where
53 M: CompletionModel,
54{
55 pub name: Option<String>,
57 pub description: Option<String>,
59 pub model: Arc<M>,
61 pub preamble: Option<String>,
63 pub static_context: Vec<Document>,
65 pub temperature: Option<f64>,
67 pub max_tokens: Option<u64>,
69 pub additional_params: Option<serde_json::Value>,
71 pub tool_server_handle: ToolServerHandle,
72 pub dynamic_context: DynamicContextStore,
74 pub tool_choice: Option<ToolChoice>,
76 pub default_max_depth: Option<usize>,
78}
79
80impl<M> Agent<M>
81where
82 M: CompletionModel,
83{
84 pub(crate) fn name(&self) -> &str {
86 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
87 }
88}
89
90impl<M> Completion<M> for Agent<M>
91where
92 M: CompletionModel,
93{
94 async fn completion(
95 &self,
96 prompt: impl Into<Message> + WasmCompatSend,
97 chat_history: Vec<Message>,
98 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
99 let prompt = prompt.into();
100
101 let rag_text = prompt.rag_text();
103 let rag_text = rag_text.or_else(|| {
104 chat_history
105 .iter()
106 .rev()
107 .find_map(|message| message.rag_text())
108 });
109
110 let completion_request = self
111 .model
112 .completion_request(prompt)
113 .messages(chat_history)
114 .temperature_opt(self.temperature)
115 .max_tokens_opt(self.max_tokens)
116 .additional_params_opt(self.additional_params.clone())
117 .documents(self.static_context.clone());
118 let completion_request = if let Some(preamble) = &self.preamble {
119 completion_request.preamble(preamble.to_owned())
120 } else {
121 completion_request
122 };
123 let completion_request = if let Some(tool_choice) = &self.tool_choice {
124 completion_request.tool_choice(tool_choice.clone())
125 } else {
126 completion_request
127 };
128
129 let agent = match &rag_text {
131 Some(text) => {
132 let dynamic_context = stream::iter(self.dynamic_context.read().await.iter())
133 .then(|(num_sample, index)| async {
134 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
135 Ok::<_, VectorStoreError>(
136 index
137 .top_n(req)
138 .await?
139 .into_iter()
140 .map(|(_, id, doc)| {
141 let text = serde_json::to_string_pretty(&doc)
143 .unwrap_or_else(|_| doc.to_string());
144
145 Document {
146 id,
147 text,
148 additional_props: HashMap::new(),
149 }
150 })
151 .collect::<Vec<_>>(),
152 )
153 })
154 .try_fold(vec![], |mut acc, docs| async {
155 acc.extend(docs);
156 Ok(acc)
157 })
158 .await
159 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
160
161 let tooldefs = self
162 .tool_server_handle
163 .get_tool_defs(Some(text.to_string()))
164 .await
165 .map_err(|_| {
166 CompletionError::RequestError("Failed to get tool definitions".into())
167 })?;
168
169 completion_request
170 .documents(dynamic_context)
171 .tools(tooldefs)
172 }
173 None => {
174 let tooldefs = self
175 .tool_server_handle
176 .get_tool_defs(None)
177 .await
178 .map_err(|_| {
179 CompletionError::RequestError("Failed to get tool definitions".into())
180 })?;
181
182 completion_request.tools(tooldefs)
183 }
184 };
185
186 Ok(agent)
187 }
188}
189
190#[allow(refining_impl_trait)]
198impl<M> Prompt for Agent<M>
199where
200 M: CompletionModel,
201{
202 fn prompt(
203 &self,
204 prompt: impl Into<Message> + WasmCompatSend,
205 ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
206 PromptRequest::new(self, prompt)
207 }
208}
209
210#[allow(refining_impl_trait)]
211impl<M> Prompt for &Agent<M>
212where
213 M: CompletionModel,
214{
215 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
216 fn prompt(
217 &self,
218 prompt: impl Into<Message> + WasmCompatSend,
219 ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
220 PromptRequest::new(*self, prompt)
221 }
222}
223
224#[allow(refining_impl_trait)]
225impl<M> Chat for Agent<M>
226where
227 M: CompletionModel,
228{
229 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
230 async fn chat(
231 &self,
232 prompt: impl Into<Message> + WasmCompatSend,
233 mut chat_history: Vec<Message>,
234 ) -> Result<String, PromptError> {
235 PromptRequest::new(self, prompt)
236 .with_history(&mut chat_history)
237 .await
238 }
239}
240
241impl<M> StreamingCompletion<M> for Agent<M>
242where
243 M: CompletionModel,
244{
245 async fn stream_completion(
246 &self,
247 prompt: impl Into<Message> + WasmCompatSend,
248 chat_history: Vec<Message>,
249 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
250 self.completion(prompt, chat_history).await
253 }
254}
255
256impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
257where
258 M: CompletionModel + 'static,
259 M::StreamingResponse: GetTokenUsage,
260{
261 fn stream_prompt(
262 &self,
263 prompt: impl Into<Message> + WasmCompatSend,
264 ) -> StreamingPromptRequest<M, ()> {
265 let arc = Arc::new(self.clone());
266 StreamingPromptRequest::new(arc, prompt)
267 }
268}
269
270impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
271where
272 M: CompletionModel + 'static,
273 M::StreamingResponse: GetTokenUsage,
274{
275 fn stream_chat(
276 &self,
277 prompt: impl Into<Message> + WasmCompatSend,
278 chat_history: Vec<Message>,
279 ) -> StreamingPromptRequest<M, ()> {
280 let arc = Arc::new(self.clone());
281 StreamingPromptRequest::new(arc, prompt).with_history(chat_history)
282 }
283}