1use super::prompt_request::{self, PromptRequest};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError,
7 },
8 message::ToolChoice,
9 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10 tool::server::ToolServerHandle,
11 vector_store::{VectorStoreError, request::VectorSearchRequest},
12 wasm_compat::WasmCompatSend,
13};
14use futures::{StreamExt, TryStreamExt, stream};
15use std::{collections::HashMap, sync::Arc};
16use tokio::sync::RwLock;
17
/// Placeholder returned by `Agent::name` when the agent's `name` field is `None`.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
19
/// Shared, lock-guarded list of dynamic context sources.
///
/// Each entry pairs a sample count with a type-erased vector store index.
/// When a completion request is built and RAG text is available, every index
/// is queried for that many results and the hits are attached to the request
/// as documents.
pub type DynamicContextStore = Arc<
    RwLock<
        Vec<(
            // Number of samples requested from the paired index per search.
            usize,
            // The index to search; boxed so different index types can coexist.
            Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;
28
/// A completion model bundled with the configuration used to build its requests.
///
/// Every field is forwarded to the completion request builder by the
/// `Completion` implementation in this file: the preamble, static documents,
/// sampling options, tool definitions and (when RAG text is available)
/// documents retrieved from the dynamic context sources.
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M>
where
    M: CompletionModel,
{
    /// Optional agent name; `name()` falls back to a placeholder when unset.
    pub name: Option<String>,
    /// Optional description of the agent.
    // NOTE(review): not read anywhere in this section of the file — confirm
    // where it is consumed.
    pub description: Option<String>,
    /// The completion model that creates the request builder.
    pub model: Arc<M>,
    /// Preamble set on the request when present.
    pub preamble: Option<String>,
    /// Documents attached to every request.
    pub static_context: Vec<Document>,
    /// Temperature forwarded to the request builder (optional).
    pub temperature: Option<f64>,
    /// Maximum token count forwarded to the request builder (optional).
    pub max_tokens: Option<u64>,
    /// Extra parameters forwarded to the request builder (optional).
    pub additional_params: Option<serde_json::Value>,
    /// Handle queried for tool definitions each time a request is built.
    pub tool_server_handle: ToolServerHandle,
    /// Vector store indexes searched with the RAG text to retrieve documents.
    pub dynamic_context: DynamicContextStore,
    /// Tool choice set on the request when present.
    pub tool_choice: Option<ToolChoice>,
}
77
78impl<M> Agent<M>
79where
80 M: CompletionModel,
81{
82 pub(crate) fn name(&self) -> &str {
84 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
85 }
86}
87
88impl<M> Completion<M> for Agent<M>
89where
90 M: CompletionModel,
91{
92 async fn completion(
93 &self,
94 prompt: impl Into<Message> + WasmCompatSend,
95 chat_history: Vec<Message>,
96 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
97 let prompt = prompt.into();
98
99 let rag_text = prompt.rag_text();
101 let rag_text = rag_text.or_else(|| {
102 chat_history
103 .iter()
104 .rev()
105 .find_map(|message| message.rag_text())
106 });
107
108 let completion_request = self
109 .model
110 .completion_request(prompt)
111 .messages(chat_history)
112 .temperature_opt(self.temperature)
113 .max_tokens_opt(self.max_tokens)
114 .additional_params_opt(self.additional_params.clone())
115 .documents(self.static_context.clone());
116 let completion_request = if let Some(preamble) = &self.preamble {
117 completion_request.preamble(preamble.to_owned())
118 } else {
119 completion_request
120 };
121 let completion_request = if let Some(tool_choice) = &self.tool_choice {
122 completion_request.tool_choice(tool_choice.clone())
123 } else {
124 completion_request
125 };
126
127 let agent = match &rag_text {
129 Some(text) => {
130 let dynamic_context = stream::iter(self.dynamic_context.read().await.iter())
131 .then(|(num_sample, index)| async {
132 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
133 Ok::<_, VectorStoreError>(
134 index
135 .top_n(req)
136 .await?
137 .into_iter()
138 .map(|(_, id, doc)| {
139 let text = serde_json::to_string_pretty(&doc)
141 .unwrap_or_else(|_| doc.to_string());
142
143 Document {
144 id,
145 text,
146 additional_props: HashMap::new(),
147 }
148 })
149 .collect::<Vec<_>>(),
150 )
151 })
152 .try_fold(vec![], |mut acc, docs| async {
153 acc.extend(docs);
154 Ok(acc)
155 })
156 .await
157 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
158
159 let tooldefs = self
160 .tool_server_handle
161 .get_tool_defs(Some(text.to_string()))
162 .await
163 .map_err(|_| {
164 CompletionError::RequestError("Failed to get tool definitions".into())
165 })?;
166
167 completion_request
168 .documents(dynamic_context)
169 .tools(tooldefs)
170 }
171 None => {
172 let tooldefs = self
173 .tool_server_handle
174 .get_tool_defs(None)
175 .await
176 .map_err(|_| {
177 CompletionError::RequestError("Failed to get tool definitions".into())
178 })?;
179
180 completion_request.tools(tooldefs)
181 }
182 };
183
184 Ok(agent)
185 }
186}
187
// `prompt` returns the concrete `PromptRequest` type. The `refining_impl_trait`
// lint fires when an impl's return type is more specific than the trait's
// declared opaque return type, so it is silenced here.
// NOTE(review): the `Prompt` trait definition is not visible in this file —
// confirm it declares an `impl Trait` return.
#[allow(refining_impl_trait)]
impl<M> Prompt for Agent<M>
where
    M: CompletionModel,
{
    /// Creates a `PromptRequest` that borrows this agent and carries `prompt`.
    ///
    /// Nothing is sent here; the call only constructs the request value.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        PromptRequest::new(self, prompt)
    }
}
207
// Same as the `Agent<M>` implementation, but for a borrowed agent. The
// `refining_impl_trait` lint is silenced because `prompt` returns the concrete
// `PromptRequest` type.
// NOTE(review): the `Prompt` trait definition is not visible in this file —
// confirm it declares an `impl Trait` return.
#[allow(refining_impl_trait)]
impl<M> Prompt for &Agent<M>
where
    M: CompletionModel,
{
    /// Creates a `PromptRequest` for the referenced agent and `prompt`.
    ///
    /// The call runs inside a tracing span that records the agent's name;
    /// `self` and `prompt` are skipped from the span's fields.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        // `self` is `&&Agent<M>` here; dereference once to pass `&Agent<M>`.
        PromptRequest::new(*self, prompt)
    }
}
221
222#[allow(refining_impl_trait)]
223impl<M> Chat for Agent<M>
224where
225 M: CompletionModel,
226{
227 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
228 async fn chat(
229 &self,
230 prompt: impl Into<Message> + WasmCompatSend,
231 mut chat_history: Vec<Message>,
232 ) -> Result<String, PromptError> {
233 PromptRequest::new(self, prompt)
234 .with_history(&mut chat_history)
235 .await
236 }
237}
238
impl<M> StreamingCompletion<M> for Agent<M>
where
    M: CompletionModel,
{
    /// Builds the completion request used for streaming.
    ///
    /// Returns exactly what [`Completion::completion`] returns for the same
    /// arguments, including its errors.
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Delegate to the non-streaming implementation so both paths share
        // the same request-building logic.
        self.completion(prompt, chat_history).await
    }
}
253
254impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
255where
256 M: CompletionModel + 'static,
257 M::StreamingResponse: GetTokenUsage,
258{
259 fn stream_prompt(
260 &self,
261 prompt: impl Into<Message> + WasmCompatSend,
262 ) -> StreamingPromptRequest<M, ()> {
263 let arc = Arc::new(self.clone());
264 StreamingPromptRequest::new(arc, prompt)
265 }
266}
267
268impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
269where
270 M: CompletionModel + 'static,
271 M::StreamingResponse: GetTokenUsage,
272{
273 fn stream_chat(
274 &self,
275 prompt: impl Into<Message> + WasmCompatSend,
276 chat_history: Vec<Message>,
277 ) -> StreamingPromptRequest<M, ()> {
278 let arc = Arc::new(self.clone());
279 StreamingPromptRequest::new(arc, prompt).with_history(chat_history)
280 }
281}