1use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
7 },
8 message::ToolChoice,
9 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10 tool::server::ToolServerHandle,
11 vector_store::{VectorStoreError, request::VectorSearchRequest},
12 wasm_compat::WasmCompatSend,
13};
14use futures::{StreamExt, TryStreamExt, stream};
15use std::{collections::HashMap, sync::Arc};
16use tokio::sync::RwLock as TokioRwLock;
17
/// Fallback label returned by `Agent::name` when the agent's `name` field is `None`.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
19
/// Shared, async-lock-guarded list of dynamic context sources.
///
/// Each entry pairs a sample count (`usize`) with a boxed vector store index.
/// `build_completion_request` takes the read lock and asks every index for that many results.
pub type DynamicContextStore = Arc<
    TokioRwLock<
        Vec<(
            usize,
            Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;
28
29#[allow(clippy::too_many_arguments)]
32pub(crate) async fn build_completion_request<M: CompletionModel>(
33 model: &Arc<M>,
34 prompt: Message,
35 chat_history: Vec<Message>,
36 preamble: Option<&str>,
37 static_context: &[Document],
38 temperature: Option<f64>,
39 max_tokens: Option<u64>,
40 additional_params: Option<&serde_json::Value>,
41 tool_choice: Option<&ToolChoice>,
42 tool_server_handle: &ToolServerHandle,
43 dynamic_context: &DynamicContextStore,
44 output_schema: Option<&schemars::Schema>,
45) -> Result<CompletionRequestBuilder<M>, CompletionError> {
46 let rag_text = prompt.rag_text();
48 let rag_text = rag_text.or_else(|| {
49 chat_history
50 .iter()
51 .rev()
52 .find_map(|message| message.rag_text())
53 });
54
55 let chat_history = if let Some(preamble) = preamble {
56 let mut with_system = Vec::with_capacity(chat_history.len() + 1);
57 with_system.push(Message::system(preamble.to_owned()));
58 with_system.extend(chat_history);
59 with_system
60 } else {
61 chat_history
62 };
63
64 let completion_request = model
65 .completion_request(prompt)
66 .messages(chat_history)
67 .temperature_opt(temperature)
68 .max_tokens_opt(max_tokens)
69 .additional_params_opt(additional_params.cloned())
70 .output_schema_opt(output_schema.cloned())
71 .documents(static_context.to_vec());
72
73 let completion_request = if let Some(tool_choice) = tool_choice {
74 completion_request.tool_choice(tool_choice.clone())
75 } else {
76 completion_request
77 };
78
79 let result = match &rag_text {
81 Some(text) => {
82 let fetched_context = stream::iter(dynamic_context.read().await.iter())
83 .then(|(num_sample, index)| async {
84 let req = VectorSearchRequest::builder()
85 .query(text)
86 .samples(*num_sample as u64)
87 .build()
88 .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
89 Ok::<_, VectorStoreError>(
90 index
91 .top_n(req)
92 .await?
93 .into_iter()
94 .map(|(_, id, doc)| {
95 let text = serde_json::to_string_pretty(&doc)
97 .unwrap_or_else(|_| doc.to_string());
98
99 Document {
100 id,
101 text,
102 additional_props: HashMap::new(),
103 }
104 })
105 .collect::<Vec<_>>(),
106 )
107 })
108 .try_fold(vec![], |mut acc, docs| async {
109 acc.extend(docs);
110 Ok(acc)
111 })
112 .await
113 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
114
115 let tooldefs = tool_server_handle
116 .get_tool_defs(Some(text.to_string()))
117 .await
118 .map_err(|_| {
119 CompletionError::RequestError("Failed to get tool definitions".into())
120 })?;
121
122 completion_request
123 .documents(fetched_context)
124 .tools(tooldefs)
125 }
126 None => {
127 let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
128 CompletionError::RequestError("Failed to get tool definitions".into())
129 })?;
130
131 completion_request.tools(tooldefs)
132 }
133 };
134
135 Ok(result)
136}
137
/// A completion model bundled with the settings used to build its completion requests.
///
/// `M` is the completion model type and `P` the prompt hook type (defaulting to `()`).
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M, P = ()>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Optional agent name; `Agent::name` falls back to a fixed placeholder when this is `None`.
    pub name: Option<String>,
    /// Optional description of the agent (not read in this part of the file).
    pub description: Option<String>,
    /// The completion model that request builders are created from.
    pub model: Arc<M>,
    /// Optional preamble; when set it is placed in front of the chat history as a system message.
    pub preamble: Option<String>,
    /// Documents passed to every completion request this agent builds.
    pub static_context: Vec<Document>,
    /// Optional temperature forwarded to the request builder.
    pub temperature: Option<f64>,
    /// Optional token limit forwarded to the request builder.
    pub max_tokens: Option<u64>,
    /// Optional extra parameters (JSON) forwarded to the request builder.
    pub additional_params: Option<serde_json::Value>,
    /// Handle queried for the tool definitions attached to each request.
    pub tool_server_handle: ToolServerHandle,
    /// Vector store indexes (with per-index sample counts) searched when RAG text is available.
    pub dynamic_context: DynamicContextStore,
    /// Optional tool choice; only set on the request when present.
    pub tool_choice: Option<ToolChoice>,
    /// Optional default turn limit (not read in this part of the file — presumably consumed by
    /// the prompt request types; verify there).
    pub default_max_turns: Option<usize>,
    /// Optional prompt hook (not read in this part of the file).
    pub hook: Option<P>,
    /// Optional output schema forwarded to the request builder.
    pub output_schema: Option<schemars::Schema>,
}
197
198impl<M, P> Agent<M, P>
199where
200 M: CompletionModel,
201 P: PromptHook<M>,
202{
203 pub(crate) fn name(&self) -> &str {
205 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
206 }
207}
208
209impl<M, P> Completion<M> for Agent<M, P>
210where
211 M: CompletionModel,
212 P: PromptHook<M>,
213{
214 async fn completion(
215 &self,
216 prompt: impl Into<Message> + WasmCompatSend,
217 chat_history: Vec<Message>,
218 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
219 build_completion_request(
220 &self.model,
221 prompt.into(),
222 chat_history,
223 self.preamble.as_deref(),
224 &self.static_context,
225 self.temperature,
226 self.max_tokens,
227 self.additional_params.as_ref(),
228 self.tool_choice.as_ref(),
229 &self.tool_server_handle,
230 &self.dynamic_context,
231 self.output_schema.as_ref(),
232 )
233 .await
234 }
235}
236
// `prompt` names its concrete return type here; presumably the `Prompt` trait declares a less
// specific one, which is what `refining_impl_trait` would flag — TODO confirm against the trait.
#[allow(refining_impl_trait)]
impl<M, P> Prompt for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Creates a [`PromptRequest`] for `prompt` that borrows this agent.
    ///
    /// This method is synchronous: it only constructs the request via
    /// `PromptRequest::from_agent` and returns it to the caller.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, P> {
        PromptRequest::from_agent(self, prompt)
    }
}
257
// `prompt` names its concrete return type here; presumably the `Prompt` trait declares a less
// specific one, which is what `refining_impl_trait` would flag — TODO confirm against the trait.
#[allow(refining_impl_trait)]
impl<M, P> Prompt for &Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Creates a [`PromptRequest`] for `prompt` from a borrowed agent.
    ///
    /// `self` is a reference to `&Agent`, so it is dereferenced once before being passed to
    /// `PromptRequest::from_agent`. The call is wrapped in a tracing span that records the
    /// agent's name as `agent_name`; `self` and `prompt` are excluded from the span fields.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, P> {
        PromptRequest::from_agent(*self, prompt)
    }
}
272
// Presumably needed because the method signature here is more specific than the one the `Chat`
// trait declares — TODO confirm against the trait.
#[allow(refining_impl_trait)]
impl<M, P> Chat for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Sends `prompt` together with `chat_history` and returns the response text.
    ///
    /// A [`PromptRequest`] is created from this agent, given a mutable borrow of the history,
    /// and awaited. `chat_history` is taken by value, so whatever the request does to it through
    /// that borrow is not visible to the caller.
    ///
    /// The call runs inside a tracing span that records the agent's name as `agent_name`.
    ///
    /// # Errors
    /// Returns the [`PromptError`] produced by awaiting the request.
    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
    async fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        mut chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        PromptRequest::from_agent(self, prompt)
            .with_history(&mut chat_history)
            .await
    }
}
290
impl<M, P> StreamingCompletion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Builds the completion request used for streaming.
    ///
    /// This delegates directly to [`Completion::completion`], so the resulting builder is the
    /// same one the non-streaming path produces.
    ///
    /// # Errors
    /// Propagates any [`CompletionError`] returned by `completion`.
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        self.completion(prompt, chat_history).await
    }
}
306
impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    /// The hook type of the streaming request is the agent's own hook type.
    type Hook = P;

    /// Creates a [`StreamingPromptRequest`] for `prompt` from this agent.
    ///
    /// This method is synchronous: it only constructs the request via
    /// `StreamingPromptRequest::from_agent` and returns it.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, P> {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt)
    }
}
322
impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    /// The hook type of the streaming request is the agent's own hook type.
    type Hook = P;

    /// Creates a [`StreamingPromptRequest`] for `prompt` and attaches `chat_history` to it.
    ///
    /// This method is synchronous: the request is constructed with
    /// `StreamingPromptRequest::from_agent`, given the history via `with_history`, and returned.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, P> {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
    }
}
339
340use crate::agent::prompt_request::TypedPromptRequest;
341use schemars::JsonSchema;
342use serde::de::DeserializeOwned;
343
// `prompt_typed` names its concrete return type; presumably the `TypedPrompt` trait declares a
// less specific one, which is what `refining_impl_trait` would flag — TODO confirm.
#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Request type returned by `prompt_typed`, parameterised by the target type `T`.
    type TypedRequest<'a, T>
        = TypedPromptRequest<'a, T, prompt_request::Standard, M, P>
    where
        Self: 'a,
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a;

    /// Creates a [`TypedPromptRequest`] for `prompt` with target type `T`, borrowing this agent.
    ///
    /// `T` must implement `JsonSchema` and `DeserializeOwned`; presumably the response is
    /// deserialized into `T` by the request type (not visible here).
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<'_, T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(self, prompt)
    }
}
398
// `prompt_typed` names its concrete return type; presumably the `TypedPrompt` trait declares a
// less specific one, which is what `refining_impl_trait` would flag — TODO confirm.
#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for &Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Request type returned by `prompt_typed`, parameterised by the target type `T`.
    type TypedRequest<'a, T>
        = TypedPromptRequest<'a, T, prompt_request::Standard, M, P>
    where
        Self: 'a,
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a;

    /// Creates a [`TypedPromptRequest`] for `prompt` with target type `T` from a borrowed agent.
    ///
    /// `self` is a reference to `&Agent`, so it is dereferenced once before being passed to
    /// `TypedPromptRequest::from_agent`.
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<'_, T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(*self, prompt)
    }
}