1use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError,
7 },
8 message::ToolChoice,
9 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10 tool::server::ToolServerHandle,
11 vector_store::{VectorStoreError, request::VectorSearchRequest},
12 wasm_compat::WasmCompatSend,
13};
14use futures::{StreamExt, TryStreamExt, stream};
15use std::{collections::HashMap, sync::Arc};
16use tokio::sync::RwLock as TokioRwLock;
17
/// Placeholder reported by `Agent::name` when the agent was built without a name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

/// Shared, lock-protected list of vector-store indexes used for dynamic (RAG) context.
///
/// Each entry pairs the number of results to request from an index with the index itself.
/// The list lives behind a Tokio `RwLock` inside an `Arc`: cloning an agent shares the same
/// list, and request building acquires the read lock asynchronously.
pub type DynamicContextStore = Arc<
    TokioRwLock<
        Vec<(
            usize,
            Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;
28
29#[allow(clippy::too_many_arguments)]
32pub(crate) async fn build_completion_request<M: CompletionModel>(
33 model: &Arc<M>,
34 prompt: Message,
35 chat_history: Vec<Message>,
36 preamble: Option<&str>,
37 static_context: &[Document],
38 temperature: Option<f64>,
39 max_tokens: Option<u64>,
40 additional_params: Option<&serde_json::Value>,
41 tool_choice: Option<&ToolChoice>,
42 tool_server_handle: &ToolServerHandle,
43 dynamic_context: &DynamicContextStore,
44 output_schema: Option<&schemars::Schema>,
45) -> Result<CompletionRequestBuilder<M>, CompletionError> {
46 let rag_text = prompt.rag_text();
48 let rag_text = rag_text.or_else(|| {
49 chat_history
50 .iter()
51 .rev()
52 .find_map(|message| message.rag_text())
53 });
54
55 let completion_request = model
56 .completion_request(prompt)
57 .messages(chat_history)
58 .temperature_opt(temperature)
59 .max_tokens_opt(max_tokens)
60 .additional_params_opt(additional_params.cloned())
61 .output_schema_opt(output_schema.cloned())
62 .documents(static_context.to_vec());
63
64 let completion_request = if let Some(preamble) = preamble {
65 completion_request.preamble(preamble.to_owned())
66 } else {
67 completion_request
68 };
69
70 let completion_request = if let Some(tool_choice) = tool_choice {
71 completion_request.tool_choice(tool_choice.clone())
72 } else {
73 completion_request
74 };
75
76 let result = match &rag_text {
78 Some(text) => {
79 let fetched_context = stream::iter(dynamic_context.read().await.iter())
80 .then(|(num_sample, index)| async {
81 let req = VectorSearchRequest::builder()
82 .query(text)
83 .samples(*num_sample as u64)
84 .build()
85 .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
86 Ok::<_, VectorStoreError>(
87 index
88 .top_n(req)
89 .await?
90 .into_iter()
91 .map(|(_, id, doc)| {
92 let text = serde_json::to_string_pretty(&doc)
94 .unwrap_or_else(|_| doc.to_string());
95
96 Document {
97 id,
98 text,
99 additional_props: HashMap::new(),
100 }
101 })
102 .collect::<Vec<_>>(),
103 )
104 })
105 .try_fold(vec![], |mut acc, docs| async {
106 acc.extend(docs);
107 Ok(acc)
108 })
109 .await
110 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
111
112 let tooldefs = tool_server_handle
113 .get_tool_defs(Some(text.to_string()))
114 .await
115 .map_err(|_| {
116 CompletionError::RequestError("Failed to get tool definitions".into())
117 })?;
118
119 completion_request
120 .documents(fetched_context)
121 .tools(tooldefs)
122 }
123 None => {
124 let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
125 CompletionError::RequestError("Failed to get tool definitions".into())
126 })?;
127
128 completion_request.tools(tooldefs)
129 }
130 };
131
132 Ok(result)
133}
134
135#[derive(Clone)]
159#[non_exhaustive]
160pub struct Agent<M, P = ()>
161where
162 M: CompletionModel,
163 P: PromptHook<M>,
164{
165 pub name: Option<String>,
167 pub description: Option<String>,
169 pub model: Arc<M>,
171 pub preamble: Option<String>,
173 pub static_context: Vec<Document>,
175 pub temperature: Option<f64>,
177 pub max_tokens: Option<u64>,
179 pub additional_params: Option<serde_json::Value>,
181 pub tool_server_handle: ToolServerHandle,
182 pub dynamic_context: DynamicContextStore,
184 pub tool_choice: Option<ToolChoice>,
186 pub default_max_turns: Option<usize>,
188 pub hook: Option<P>,
190 pub output_schema: Option<schemars::Schema>,
193}
194
195impl<M, P> Agent<M, P>
196where
197 M: CompletionModel,
198 P: PromptHook<M>,
199{
200 pub(crate) fn name(&self) -> &str {
202 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
203 }
204}
205
206impl<M, P> Completion<M> for Agent<M, P>
207where
208 M: CompletionModel,
209 P: PromptHook<M>,
210{
211 async fn completion(
212 &self,
213 prompt: impl Into<Message> + WasmCompatSend,
214 chat_history: Vec<Message>,
215 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
216 build_completion_request(
217 &self.model,
218 prompt.into(),
219 chat_history,
220 self.preamble.as_deref(),
221 &self.static_context,
222 self.temperature,
223 self.max_tokens,
224 self.additional_params.as_ref(),
225 self.tool_choice.as_ref(),
226 &self.tool_server_handle,
227 &self.dynamic_context,
228 self.output_schema.as_ref(),
229 )
230 .await
231 }
232}
233
234#[allow(refining_impl_trait)]
242impl<M, P> Prompt for Agent<M, P>
243where
244 M: CompletionModel,
245 P: PromptHook<M> + 'static,
246{
247 fn prompt(
248 &self,
249 prompt: impl Into<Message> + WasmCompatSend,
250 ) -> PromptRequest<'_, prompt_request::Standard, M, P> {
251 PromptRequest::from_agent(self, prompt)
252 }
253}
254
255#[allow(refining_impl_trait)]
256impl<M, P> Prompt for &Agent<M, P>
257where
258 M: CompletionModel,
259 P: PromptHook<M> + 'static,
260{
261 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
262 fn prompt(
263 &self,
264 prompt: impl Into<Message> + WasmCompatSend,
265 ) -> PromptRequest<'_, prompt_request::Standard, M, P> {
266 PromptRequest::from_agent(*self, prompt)
267 }
268}
269
270#[allow(refining_impl_trait)]
271impl<M, P> Chat for Agent<M, P>
272where
273 M: CompletionModel,
274 P: PromptHook<M> + 'static,
275{
276 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
277 async fn chat(
278 &self,
279 prompt: impl Into<Message> + WasmCompatSend,
280 mut chat_history: Vec<Message>,
281 ) -> Result<String, PromptError> {
282 PromptRequest::from_agent(self, prompt)
283 .with_history(&mut chat_history)
284 .await
285 }
286}
287
288impl<M, P> StreamingCompletion<M> for Agent<M, P>
289where
290 M: CompletionModel,
291 P: PromptHook<M>,
292{
293 async fn stream_completion(
294 &self,
295 prompt: impl Into<Message> + WasmCompatSend,
296 chat_history: Vec<Message>,
297 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
298 self.completion(prompt, chat_history).await
301 }
302}
303
304impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
305where
306 M: CompletionModel + 'static,
307 M::StreamingResponse: GetTokenUsage,
308 P: PromptHook<M> + 'static,
309{
310 type Hook = P;
311
312 fn stream_prompt(
313 &self,
314 prompt: impl Into<Message> + WasmCompatSend,
315 ) -> StreamingPromptRequest<M, P> {
316 StreamingPromptRequest::<M, P>::from_agent(self, prompt)
317 }
318}
319
320impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
321where
322 M: CompletionModel + 'static,
323 M::StreamingResponse: GetTokenUsage,
324 P: PromptHook<M> + 'static,
325{
326 type Hook = P;
327
328 fn stream_chat(
329 &self,
330 prompt: impl Into<Message> + WasmCompatSend,
331 chat_history: Vec<Message>,
332 ) -> StreamingPromptRequest<M, P> {
333 StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
334 }
335}
336
337use crate::{agent::prompt_request::TypedPromptRequest, completion::TypedPrompt};
338use schemars::JsonSchema;
339use serde::de::DeserializeOwned;
340
341#[allow(refining_impl_trait)]
342impl<M, P> TypedPrompt for Agent<M, P>
343where
344 M: CompletionModel,
345 P: PromptHook<M> + 'static,
346{
347 type TypedRequest<'a, T>
348 = TypedPromptRequest<'a, T, M, P>
349 where
350 Self: 'a,
351 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a;
352
353 fn prompt_typed<T>(
386 &self,
387 prompt: impl Into<Message> + WasmCompatSend,
388 ) -> TypedPromptRequest<'_, T, M, P>
389 where
390 T: JsonSchema + DeserializeOwned + WasmCompatSend,
391 {
392 TypedPromptRequest::from_agent(self, prompt)
393 }
394}
395
396#[allow(refining_impl_trait)]
397impl<M, P> TypedPrompt for &Agent<M, P>
398where
399 M: CompletionModel,
400 P: PromptHook<M> + 'static,
401{
402 type TypedRequest<'a, T>
403 = TypedPromptRequest<'a, T, M, P>
404 where
405 Self: 'a,
406 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a;
407
408 fn prompt_typed<T>(
409 &self,
410 prompt: impl Into<Message> + WasmCompatSend,
411 ) -> TypedPromptRequest<'_, T, M, P>
412 where
413 T: JsonSchema + DeserializeOwned + WasmCompatSend,
414 {
415 TypedPromptRequest::from_agent(*self, prompt)
416 }
417}