1use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
7 },
8 message::ToolChoice,
9 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10 tool::server::ToolServerHandle,
11 vector_store::{VectorStoreError, request::VectorSearchRequest},
12 wasm_compat::WasmCompatSend,
13};
14use futures::{StreamExt, TryStreamExt, stream};
15use std::{collections::HashMap, sync::Arc};
16use tokio::sync::RwLock as TokioRwLock;
17
/// Fallback name returned by `Agent::name` when the agent's `name` field is `None`.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
19
/// Shared handle to the dynamic context sources searched when a request is built.
///
/// The value is an `Arc`-shared, Tokio `RwLock`-guarded list of
/// `(sample count, index)` pairs. `build_completion_request` takes a read lock,
/// and for every pair asks the index for `sample count` results matching the
/// current RAG query text.
pub type DynamicContextStore = Arc<
    TokioRwLock<
        Vec<(
            // Number of samples requested from the paired index per query.
            usize,
            // Type-erased vector store index; `Send + Sync` so the store can be shared.
            Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;
28
29#[allow(clippy::too_many_arguments)]
32pub(crate) async fn build_completion_request<M: CompletionModel>(
33 model: &Arc<M>,
34 prompt: Message,
35 chat_history: &[Message],
36 preamble: Option<&str>,
37 static_context: &[Document],
38 temperature: Option<f64>,
39 max_tokens: Option<u64>,
40 additional_params: Option<&serde_json::Value>,
41 tool_choice: Option<&ToolChoice>,
42 tool_server_handle: &ToolServerHandle,
43 dynamic_context: &DynamicContextStore,
44 output_schema: Option<&schemars::Schema>,
45) -> Result<CompletionRequestBuilder<M>, CompletionError> {
46 let rag_text = prompt.rag_text();
48 let rag_text = rag_text.or_else(|| {
49 chat_history
50 .iter()
51 .rev()
52 .find_map(|message| message.rag_text())
53 });
54
55 let chat_history: Vec<Message> = if let Some(preamble) = preamble {
57 std::iter::once(Message::system(preamble.to_owned()))
58 .chain(chat_history.iter().cloned())
59 .collect()
60 } else {
61 chat_history.to_vec()
62 };
63
64 let completion_request = model
65 .completion_request(prompt)
66 .messages(chat_history)
67 .temperature_opt(temperature)
68 .max_tokens_opt(max_tokens)
69 .additional_params_opt(additional_params.cloned())
70 .output_schema_opt(output_schema.cloned())
71 .documents(static_context.to_vec());
72
73 let completion_request = if let Some(tool_choice) = tool_choice {
74 completion_request.tool_choice(tool_choice.clone())
75 } else {
76 completion_request
77 };
78
79 let result = match &rag_text {
81 Some(text) => {
82 let fetched_context = stream::iter(dynamic_context.read().await.iter())
83 .then(|(num_sample, index)| async {
84 let req = VectorSearchRequest::builder()
85 .query(text)
86 .samples(*num_sample as u64)
87 .build()
88 .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
89 Ok::<_, VectorStoreError>(
90 index
91 .top_n(req)
92 .await?
93 .into_iter()
94 .map(|(_, id, doc)| {
95 let text = serde_json::to_string_pretty(&doc)
97 .unwrap_or_else(|_| doc.to_string());
98
99 Document {
100 id,
101 text,
102 additional_props: HashMap::new(),
103 }
104 })
105 .collect::<Vec<_>>(),
106 )
107 })
108 .try_fold(vec![], |mut acc, docs| async {
109 acc.extend(docs);
110 Ok(acc)
111 })
112 .await
113 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
114
115 let tooldefs = tool_server_handle
116 .get_tool_defs(Some(text.to_string()))
117 .await
118 .map_err(|_| {
119 CompletionError::RequestError("Failed to get tool definitions".into())
120 })?;
121
122 completion_request
123 .documents(fetched_context)
124 .tools(tooldefs)
125 }
126 None => {
127 let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
128 CompletionError::RequestError("Failed to get tool definitions".into())
129 })?;
130
131 completion_request.tools(tooldefs)
132 }
133 };
134
135 Ok(result)
136}
137
/// A completion model bundled with the configuration used to build its requests.
///
/// `M` is the completion model type and `P` is the prompt hook type (defaulting
/// to `()`). The `Completion` implementation in this file forwards most of these
/// fields to `build_completion_request`.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal.
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M, P = ()>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Optional agent name; `Agent::name` falls back to a placeholder when unset.
    pub name: Option<String>,
    /// Optional description of the agent (not read anywhere in this file).
    pub description: Option<String>,
    /// The completion model that requests are built against.
    pub model: Arc<M>,
    /// Optional preamble, prepended to the chat history as a system message.
    pub preamble: Option<String>,
    /// Documents attached to every request.
    pub static_context: Vec<Document>,
    /// Optional temperature forwarded to the request builder.
    pub temperature: Option<f64>,
    /// Optional maximum token count forwarded to the request builder.
    pub max_tokens: Option<u64>,
    /// Optional extra parameters forwarded to the request builder.
    pub additional_params: Option<serde_json::Value>,
    /// Handle used to fetch tool definitions for each request.
    pub tool_server_handle: ToolServerHandle,
    /// Vector store indexes searched with the RAG text of each request.
    pub dynamic_context: DynamicContextStore,
    /// Optional tool choice applied to the request when set.
    pub tool_choice: Option<ToolChoice>,
    /// Presumably the default turn limit for multi-turn prompting; not read in
    /// this file — TODO confirm against the prompt request code.
    pub default_max_turns: Option<usize>,
    /// Optional prompt hook of type `P`; not read in this file.
    pub hook: Option<P>,
    /// Optional output schema forwarded to the request builder.
    pub output_schema: Option<schemars::Schema>,
}
197
198impl<M, P> Agent<M, P>
199where
200 M: CompletionModel,
201 P: PromptHook<M>,
202{
203 pub(crate) fn name(&self) -> &str {
205 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
206 }
207}
208
209impl<M, P> Completion<M> for Agent<M, P>
210where
211 M: CompletionModel,
212 P: PromptHook<M>,
213{
214 async fn completion<I, T>(
215 &self,
216 prompt: impl Into<Message> + WasmCompatSend,
217 chat_history: I,
218 ) -> Result<CompletionRequestBuilder<M>, CompletionError>
219 where
220 I: IntoIterator<Item = T>,
221 T: Into<Message>,
222 {
223 let history: Vec<Message> = chat_history.into_iter().map(Into::into).collect();
224 build_completion_request(
225 &self.model,
226 prompt.into(),
227 &history,
228 self.preamble.as_deref(),
229 &self.static_context,
230 self.temperature,
231 self.max_tokens,
232 self.additional_params.as_ref(),
233 self.tool_choice.as_ref(),
234 &self.tool_server_handle,
235 &self.dynamic_context,
236 self.output_schema.as_ref(),
237 )
238 .await
239 }
240}
241
242#[allow(refining_impl_trait)]
250impl<M, P> Prompt for Agent<M, P>
251where
252 M: CompletionModel + 'static,
253 P: PromptHook<M> + 'static,
254{
255 fn prompt(
256 &self,
257 prompt: impl Into<Message> + WasmCompatSend,
258 ) -> PromptRequest<prompt_request::Standard, M, P> {
259 PromptRequest::from_agent(self, prompt)
260 }
261}
262
263#[allow(refining_impl_trait)]
264impl<M, P> Prompt for &Agent<M, P>
265where
266 M: CompletionModel + 'static,
267 P: PromptHook<M> + 'static,
268{
269 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
270 fn prompt(
271 &self,
272 prompt: impl Into<Message> + WasmCompatSend,
273 ) -> PromptRequest<prompt_request::Standard, M, P> {
274 PromptRequest::from_agent(*self, prompt)
275 }
276}
277
278#[allow(refining_impl_trait)]
279impl<M, P> Chat for Agent<M, P>
280where
281 M: CompletionModel + 'static,
282 P: PromptHook<M> + 'static,
283{
284 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
285 async fn chat<I, T>(
286 &self,
287 prompt: impl Into<Message> + WasmCompatSend,
288 chat_history: I,
289 ) -> Result<String, PromptError>
290 where
291 I: IntoIterator<Item = T>,
292 T: Into<Message>,
293 {
294 PromptRequest::from_agent(self, prompt)
295 .with_history(chat_history)
296 .await
297 }
298}
299
300impl<M, P> StreamingCompletion<M> for Agent<M, P>
301where
302 M: CompletionModel,
303 P: PromptHook<M>,
304{
305 async fn stream_completion<I, T>(
306 &self,
307 prompt: impl Into<Message> + WasmCompatSend,
308 chat_history: I,
309 ) -> Result<CompletionRequestBuilder<M>, CompletionError>
310 where
311 I: IntoIterator<Item = T> + WasmCompatSend,
312 T: Into<Message>,
313 {
314 self.completion(prompt, chat_history).await
317 }
318}
319
320impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
321where
322 M: CompletionModel + 'static,
323 M::StreamingResponse: GetTokenUsage,
324 P: PromptHook<M> + 'static,
325{
326 type Hook = P;
327
328 fn stream_prompt(
329 &self,
330 prompt: impl Into<Message> + WasmCompatSend,
331 ) -> StreamingPromptRequest<M, P> {
332 StreamingPromptRequest::<M, P>::from_agent(self, prompt)
333 }
334}
335
336impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
337where
338 M: CompletionModel + 'static,
339 M::StreamingResponse: GetTokenUsage,
340 P: PromptHook<M> + 'static,
341{
342 type Hook = P;
343
344 fn stream_chat<I, T>(
345 &self,
346 prompt: impl Into<Message> + WasmCompatSend,
347 chat_history: I,
348 ) -> StreamingPromptRequest<M, P>
349 where
350 I: IntoIterator<Item = T>,
351 T: Into<Message>,
352 {
353 StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
354 }
355}
356
357use crate::agent::prompt_request::TypedPromptRequest;
358use schemars::JsonSchema;
359use serde::de::DeserializeOwned;
360
361#[allow(refining_impl_trait)]
362impl<M, P> TypedPrompt for Agent<M, P>
363where
364 M: CompletionModel + 'static,
365 P: PromptHook<M> + 'static,
366{
367 type TypedRequest<T>
368 = TypedPromptRequest<T, prompt_request::Standard, M, P>
369 where
370 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
371
372 fn prompt_typed<T>(
405 &self,
406 prompt: impl Into<Message> + WasmCompatSend,
407 ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
408 where
409 T: JsonSchema + DeserializeOwned + WasmCompatSend,
410 {
411 TypedPromptRequest::from_agent(self, prompt)
412 }
413}
414
415#[allow(refining_impl_trait)]
416impl<M, P> TypedPrompt for &Agent<M, P>
417where
418 M: CompletionModel + 'static,
419 P: PromptHook<M> + 'static,
420{
421 type TypedRequest<T>
422 = TypedPromptRequest<T, prompt_request::Standard, M, P>
423 where
424 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
425
426 fn prompt_typed<T>(
427 &self,
428 prompt: impl Into<Message> + WasmCompatSend,
429 ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
430 where
431 T: JsonSchema + DeserializeOwned + WasmCompatSend,
432 {
433 TypedPromptRequest::from_agent(*self, prompt)
434 }
435}