1use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
7 },
8 message::ToolChoice,
9 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10 tool::server::ToolServerHandle,
11 vector_store::{VectorStoreError, request::VectorSearchRequest},
12 wasm_compat::WasmCompatSend,
13};
14use std::{collections::HashMap, sync::Arc};
15
/// Placeholder name reported for an agent whose `name` field is `None`.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
17
18pub type DynamicContextStore = Arc<
19 Vec<(
20 usize,
21 Arc<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
22 )>,
23>;
24
25#[allow(clippy::too_many_arguments)]
28pub(crate) async fn build_completion_request<M: CompletionModel>(
29 model: &Arc<M>,
30 prompt: Message,
31 chat_history: &[Message],
32 preamble: Option<&str>,
33 static_context: &[Document],
34 temperature: Option<f64>,
35 max_tokens: Option<u64>,
36 additional_params: Option<&serde_json::Value>,
37 tool_choice: Option<&ToolChoice>,
38 tool_server_handle: &ToolServerHandle,
39 dynamic_context: &DynamicContextStore,
40 output_schema: Option<&schemars::Schema>,
41) -> Result<CompletionRequestBuilder<M>, CompletionError> {
42 let rag_text = prompt.rag_text();
44 let rag_text = rag_text.or_else(|| {
45 chat_history
46 .iter()
47 .rev()
48 .find_map(|message| message.rag_text())
49 });
50
51 let chat_history: Vec<Message> = if let Some(preamble) = preamble {
53 std::iter::once(Message::system(preamble.to_owned()))
54 .chain(chat_history.iter().cloned())
55 .collect()
56 } else {
57 chat_history.to_vec()
58 };
59
60 let completion_request = model
61 .completion_request(prompt)
62 .messages(chat_history)
63 .temperature_opt(temperature)
64 .max_tokens_opt(max_tokens)
65 .additional_params_opt(additional_params.cloned())
66 .output_schema_opt(output_schema.cloned())
67 .documents(static_context.to_vec());
68
69 let completion_request = if let Some(tool_choice) = tool_choice {
70 completion_request.tool_choice(tool_choice.clone())
71 } else {
72 completion_request
73 };
74
75 let result = match &rag_text {
77 Some(text) => {
78 let search_futures = dynamic_context.iter().map(|(num_sample, index)| {
80 let text = text.clone();
82 let num_sample = *num_sample;
83 let index = index.clone();
84
85 async move {
86 let req = VectorSearchRequest::builder()
87 .query(text)
88 .samples(num_sample as u64)
89 .build();
90
91 let docs = index
92 .top_n(req)
93 .await?
94 .into_iter()
95 .map(|(_, id, doc)| {
96 let text = serde_json::to_string_pretty(&doc)
98 .unwrap_or_else(|_| doc.to_string());
99
100 Document {
101 id,
102 text,
103 additional_props: HashMap::new(),
104 }
105 })
106 .collect::<Vec<_>>();
107
108 Ok::<_, VectorStoreError>(docs)
109 }
110 });
111
112 let fetched_context: Vec<Document> = futures::future::try_join_all(search_futures)
114 .await
115 .map_err(|e| CompletionError::RequestError(Box::new(e)))?
116 .into_iter()
117 .flatten() .collect();
119
120 let tooldefs = tool_server_handle
121 .get_tool_defs(Some(text.to_string()))
122 .await
123 .map_err(|_| {
124 CompletionError::RequestError("Failed to get tool definitions".into())
125 })?;
126
127 completion_request
128 .documents(fetched_context)
129 .tools(tooldefs)
130 }
131 None => {
132 let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
133 CompletionError::RequestError("Failed to get tool definitions".into())
134 })?;
135
136 completion_request.tools(tooldefs)
137 }
138 };
139
140 Ok(result)
141}
142
143#[derive(Clone)]
167#[non_exhaustive]
168pub struct Agent<M, P = ()>
169where
170 M: CompletionModel,
171 P: PromptHook<M>,
172{
173 pub name: Option<String>,
175 pub description: Option<String>,
177 pub model: Arc<M>,
179 pub preamble: Option<String>,
181 pub static_context: Vec<Document>,
183 pub temperature: Option<f64>,
185 pub max_tokens: Option<u64>,
187 pub additional_params: Option<serde_json::Value>,
189 pub tool_server_handle: ToolServerHandle,
190 pub dynamic_context: DynamicContextStore,
192 pub tool_choice: Option<ToolChoice>,
194 pub default_max_turns: Option<usize>,
196 pub hook: Option<P>,
198 pub output_schema: Option<schemars::Schema>,
201}
202
203impl<M, P> Agent<M, P>
204where
205 M: CompletionModel,
206 P: PromptHook<M>,
207{
208 pub(crate) fn name(&self) -> &str {
210 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
211 }
212}
213
214impl<M, P> Completion<M> for Agent<M, P>
215where
216 M: CompletionModel,
217 P: PromptHook<M>,
218{
219 async fn completion<I, T>(
220 &self,
221 prompt: impl Into<Message> + WasmCompatSend,
222 chat_history: I,
223 ) -> Result<CompletionRequestBuilder<M>, CompletionError>
224 where
225 I: IntoIterator<Item = T>,
226 T: Into<Message>,
227 {
228 let history: Vec<Message> = chat_history.into_iter().map(Into::into).collect();
229 build_completion_request(
230 &self.model,
231 prompt.into(),
232 &history,
233 self.preamble.as_deref(),
234 &self.static_context,
235 self.temperature,
236 self.max_tokens,
237 self.additional_params.as_ref(),
238 self.tool_choice.as_ref(),
239 &self.tool_server_handle,
240 &self.dynamic_context,
241 self.output_schema.as_ref(),
242 )
243 .await
244 }
245}
246
247#[allow(refining_impl_trait)]
255impl<M, P> Prompt for Agent<M, P>
256where
257 M: CompletionModel + 'static,
258 P: PromptHook<M> + 'static,
259{
260 fn prompt(
261 &self,
262 prompt: impl Into<Message> + WasmCompatSend,
263 ) -> PromptRequest<prompt_request::Standard, M, P> {
264 PromptRequest::from_agent(self, prompt)
265 }
266}
267
268#[allow(refining_impl_trait)]
269impl<M, P> Prompt for &Agent<M, P>
270where
271 M: CompletionModel + 'static,
272 P: PromptHook<M> + 'static,
273{
274 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
275 fn prompt(
276 &self,
277 prompt: impl Into<Message> + WasmCompatSend,
278 ) -> PromptRequest<prompt_request::Standard, M, P> {
279 PromptRequest::from_agent(*self, prompt)
280 }
281}
282
283#[allow(refining_impl_trait)]
284impl<M, P> Chat for Agent<M, P>
285where
286 M: CompletionModel + 'static,
287 P: PromptHook<M> + 'static,
288{
289 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
290 async fn chat<I, T>(
291 &self,
292 prompt: impl Into<Message> + WasmCompatSend,
293 chat_history: I,
294 ) -> Result<String, PromptError>
295 where
296 I: IntoIterator<Item = T>,
297 T: Into<Message>,
298 {
299 PromptRequest::from_agent(self, prompt)
300 .with_history(chat_history)
301 .await
302 }
303}
304
305impl<M, P> StreamingCompletion<M> for Agent<M, P>
306where
307 M: CompletionModel,
308 P: PromptHook<M>,
309{
310 async fn stream_completion<I, T>(
311 &self,
312 prompt: impl Into<Message> + WasmCompatSend,
313 chat_history: I,
314 ) -> Result<CompletionRequestBuilder<M>, CompletionError>
315 where
316 I: IntoIterator<Item = T> + WasmCompatSend,
317 T: Into<Message>,
318 {
319 self.completion(prompt, chat_history).await
322 }
323}
324
325impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
326where
327 M: CompletionModel + 'static,
328 M::StreamingResponse: GetTokenUsage,
329 P: PromptHook<M> + 'static,
330{
331 type Hook = P;
332
333 fn stream_prompt(
334 &self,
335 prompt: impl Into<Message> + WasmCompatSend,
336 ) -> StreamingPromptRequest<M, P> {
337 StreamingPromptRequest::<M, P>::from_agent(self, prompt)
338 }
339}
340
341impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
342where
343 M: CompletionModel + 'static,
344 M::StreamingResponse: GetTokenUsage,
345 P: PromptHook<M> + 'static,
346{
347 type Hook = P;
348
349 fn stream_chat<I, T>(
350 &self,
351 prompt: impl Into<Message> + WasmCompatSend,
352 chat_history: I,
353 ) -> StreamingPromptRequest<M, P>
354 where
355 I: IntoIterator<Item = T>,
356 T: Into<Message>,
357 {
358 StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
359 }
360}
361
362use crate::agent::prompt_request::TypedPromptRequest;
363use schemars::JsonSchema;
364use serde::de::DeserializeOwned;
365
366#[allow(refining_impl_trait)]
367impl<M, P> TypedPrompt for Agent<M, P>
368where
369 M: CompletionModel + 'static,
370 P: PromptHook<M> + 'static,
371{
372 type TypedRequest<T>
373 = TypedPromptRequest<T, prompt_request::Standard, M, P>
374 where
375 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
376
377 fn prompt_typed<T>(
410 &self,
411 prompt: impl Into<Message> + WasmCompatSend,
412 ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
413 where
414 T: JsonSchema + DeserializeOwned + WasmCompatSend,
415 {
416 TypedPromptRequest::from_agent(self, prompt)
417 }
418}
419
420#[allow(refining_impl_trait)]
421impl<M, P> TypedPrompt for &Agent<M, P>
422where
423 M: CompletionModel + 'static,
424 P: PromptHook<M> + 'static,
425{
426 type TypedRequest<T>
427 = TypedPromptRequest<T, prompt_request::Standard, M, P>
428 where
429 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
430
431 fn prompt_typed<T>(
432 &self,
433 prompt: impl Into<Message> + WasmCompatSend,
434 ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
435 where
436 T: JsonSchema + DeserializeOwned + WasmCompatSend,
437 {
438 TypedPromptRequest::from_agent(*self, prompt)
439 }
440}