1use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
7 },
8 message::ToolChoice,
9 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10 tool::server::ToolServerHandle,
11 vector_store::{VectorStoreError, request::VectorSearchRequest},
12 wasm_compat::WasmCompatSend,
13};
14use std::{collections::HashMap, sync::Arc};
15
/// Placeholder returned by `Agent::name` when an agent was not given a name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
17
18pub type DynamicContextStore = Arc<
19 Vec<(
20 usize,
21 Arc<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
22 )>,
23>;
24
25#[allow(clippy::too_many_arguments)]
28pub(crate) async fn build_completion_request<M: CompletionModel>(
29 model: &Arc<M>,
30 prompt: Message,
31 chat_history: &[Message],
32 preamble: Option<&str>,
33 static_context: &[Document],
34 temperature: Option<f64>,
35 max_tokens: Option<u64>,
36 additional_params: Option<&serde_json::Value>,
37 tool_choice: Option<&ToolChoice>,
38 tool_server_handle: &ToolServerHandle,
39 dynamic_context: &DynamicContextStore,
40 output_schema: Option<&schemars::Schema>,
41) -> Result<CompletionRequestBuilder<M>, CompletionError> {
42 let rag_text = prompt.rag_text();
44 let rag_text = rag_text.or_else(|| {
45 chat_history
46 .iter()
47 .rev()
48 .find_map(|message| message.rag_text())
49 });
50
51 let chat_history: Vec<Message> = if let Some(preamble) = preamble {
53 std::iter::once(Message::system(preamble.to_owned()))
54 .chain(chat_history.iter().cloned())
55 .collect()
56 } else {
57 chat_history.to_vec()
58 };
59
60 let completion_request = model
61 .completion_request(prompt)
62 .messages(chat_history)
63 .temperature_opt(temperature)
64 .max_tokens_opt(max_tokens)
65 .additional_params_opt(additional_params.cloned())
66 .output_schema_opt(output_schema.cloned())
67 .documents(static_context.to_vec());
68
69 let completion_request = if let Some(tool_choice) = tool_choice {
70 completion_request.tool_choice(tool_choice.clone())
71 } else {
72 completion_request
73 };
74
75 let result = match &rag_text {
77 Some(text) => {
78 let search_futures = dynamic_context.iter().map(|(num_sample, index)| {
80 let text = text.clone();
82 let num_sample = *num_sample;
83 let index = index.clone();
84
85 async move {
86 let req = VectorSearchRequest::builder()
87 .query(text)
88 .samples(num_sample as u64)
89 .build();
90
91 let docs = index
92 .top_n(req)
93 .await?
94 .into_iter()
95 .map(|(_, id, doc)| {
96 let text = serde_json::to_string_pretty(&doc)
98 .unwrap_or_else(|_| doc.to_string());
99
100 Document {
101 id,
102 text,
103 additional_props: HashMap::new(),
104 }
105 })
106 .collect::<Vec<_>>();
107
108 Ok::<_, VectorStoreError>(docs)
109 }
110 });
111
112 let fetched_context: Vec<Document> = futures::future::try_join_all(search_futures)
114 .await
115 .map_err(|e| CompletionError::RequestError(Box::new(e)))?
116 .into_iter()
117 .flatten() .collect();
119
120 let tooldefs = tool_server_handle
121 .get_tool_defs(Some(text.to_string()))
122 .await
123 .map_err(|_| {
124 CompletionError::RequestError("Failed to get tool definitions".into())
125 })?;
126
127 completion_request
128 .documents(fetched_context)
129 .tools(tooldefs)
130 }
131 None => {
132 let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
133 CompletionError::RequestError("Failed to get tool definitions".into())
134 })?;
135
136 completion_request.tools(tooldefs)
137 }
138 };
139
140 Ok(result)
141}
142
143#[derive(Clone)]
172#[non_exhaustive]
173pub struct Agent<M, P = ()>
174where
175 M: CompletionModel,
176 P: PromptHook<M>,
177{
178 pub name: Option<String>,
180 pub description: Option<String>,
182 pub model: Arc<M>,
184 pub preamble: Option<String>,
186 pub static_context: Vec<Document>,
188 pub temperature: Option<f64>,
190 pub max_tokens: Option<u64>,
192 pub additional_params: Option<serde_json::Value>,
194 pub tool_server_handle: ToolServerHandle,
195 pub dynamic_context: DynamicContextStore,
197 pub tool_choice: Option<ToolChoice>,
199 pub default_max_turns: Option<usize>,
201 pub hook: Option<P>,
203 pub output_schema: Option<schemars::Schema>,
206 pub memory: Option<Arc<dyn crate::memory::ConversationMemory>>,
208 pub default_conversation_id: Option<String>,
210}
211
212impl<M, P> Agent<M, P>
213where
214 M: CompletionModel,
215 P: PromptHook<M>,
216{
217 pub(crate) fn name(&self) -> &str {
219 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
220 }
221}
222
223impl<M, P> Completion<M> for Agent<M, P>
224where
225 M: CompletionModel,
226 P: PromptHook<M>,
227{
228 async fn completion<I, T>(
229 &self,
230 prompt: impl Into<Message> + WasmCompatSend,
231 chat_history: I,
232 ) -> Result<CompletionRequestBuilder<M>, CompletionError>
233 where
234 I: IntoIterator<Item = T>,
235 T: Into<Message>,
236 {
237 let history: Vec<Message> = chat_history.into_iter().map(Into::into).collect();
238 build_completion_request(
239 &self.model,
240 prompt.into(),
241 &history,
242 self.preamble.as_deref(),
243 &self.static_context,
244 self.temperature,
245 self.max_tokens,
246 self.additional_params.as_ref(),
247 self.tool_choice.as_ref(),
248 &self.tool_server_handle,
249 &self.dynamic_context,
250 self.output_schema.as_ref(),
251 )
252 .await
253 }
254}
255
256#[allow(refining_impl_trait)]
264impl<M, P> Prompt for Agent<M, P>
265where
266 M: CompletionModel + 'static,
267 P: PromptHook<M> + 'static,
268{
269 fn prompt(
270 &self,
271 prompt: impl Into<Message> + WasmCompatSend,
272 ) -> PromptRequest<prompt_request::Standard, M, P> {
273 PromptRequest::from_agent(self, prompt)
274 }
275}
276
277#[allow(refining_impl_trait)]
278impl<M, P> Prompt for &Agent<M, P>
279where
280 M: CompletionModel + 'static,
281 P: PromptHook<M> + 'static,
282{
283 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
284 fn prompt(
285 &self,
286 prompt: impl Into<Message> + WasmCompatSend,
287 ) -> PromptRequest<prompt_request::Standard, M, P> {
288 PromptRequest::from_agent(*self, prompt)
289 }
290}
291
292#[allow(refining_impl_trait)]
293impl<M, P> Chat for Agent<M, P>
294where
295 M: CompletionModel + 'static,
296 P: PromptHook<M> + 'static,
297{
298 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
299 async fn chat(
300 &self,
301 prompt: impl Into<Message> + WasmCompatSend,
302 chat_history: &mut Vec<Message>,
303 ) -> Result<String, PromptError> {
304 let response = PromptRequest::from_agent(self, prompt)
305 .with_history(chat_history.clone())
306 .extended_details()
307 .await?;
308
309 if let Some(messages) = response.messages {
310 chat_history.extend(messages);
311 }
312
313 Ok(response.output)
314 }
315}
316
317impl<M, P> StreamingCompletion<M> for Agent<M, P>
318where
319 M: CompletionModel,
320 P: PromptHook<M>,
321{
322 async fn stream_completion<I, T>(
323 &self,
324 prompt: impl Into<Message> + WasmCompatSend,
325 chat_history: I,
326 ) -> Result<CompletionRequestBuilder<M>, CompletionError>
327 where
328 I: IntoIterator<Item = T> + WasmCompatSend,
329 T: Into<Message>,
330 {
331 self.completion(prompt, chat_history).await
334 }
335}
336
337impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
338where
339 M: CompletionModel + 'static,
340 M::StreamingResponse: GetTokenUsage,
341 P: PromptHook<M> + 'static,
342{
343 type Hook = P;
344
345 fn stream_prompt(
346 &self,
347 prompt: impl Into<Message> + WasmCompatSend,
348 ) -> StreamingPromptRequest<M, P> {
349 StreamingPromptRequest::<M, P>::from_agent(self, prompt)
350 }
351}
352
353impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
354where
355 M: CompletionModel + 'static,
356 M::StreamingResponse: GetTokenUsage,
357 P: PromptHook<M> + 'static,
358{
359 type Hook = P;
360
361 fn stream_chat<I, T>(
362 &self,
363 prompt: impl Into<Message> + WasmCompatSend,
364 chat_history: I,
365 ) -> StreamingPromptRequest<M, P>
366 where
367 I: IntoIterator<Item = T>,
368 T: Into<Message>,
369 {
370 StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
371 }
372}
373
374use crate::agent::prompt_request::TypedPromptRequest;
375use schemars::JsonSchema;
376use serde::de::DeserializeOwned;
377
378#[allow(refining_impl_trait)]
379impl<M, P> TypedPrompt for Agent<M, P>
380where
381 M: CompletionModel + 'static,
382 P: PromptHook<M> + 'static,
383{
384 type TypedRequest<T>
385 = TypedPromptRequest<T, prompt_request::Standard, M, P>
386 where
387 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
388
389 fn prompt_typed<T>(
422 &self,
423 prompt: impl Into<Message> + WasmCompatSend,
424 ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
425 where
426 T: JsonSchema + DeserializeOwned + WasmCompatSend,
427 {
428 TypedPromptRequest::from_agent(self, prompt)
429 }
430}
431
432#[allow(refining_impl_trait)]
433impl<M, P> TypedPrompt for &Agent<M, P>
434where
435 M: CompletionModel + 'static,
436 P: PromptHook<M> + 'static,
437{
438 type TypedRequest<T>
439 = TypedPromptRequest<T, prompt_request::Standard, M, P>
440 where
441 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
442
443 fn prompt_typed<T>(
444 &self,
445 prompt: impl Into<Message> + WasmCompatSend,
446 ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
447 where
448 T: JsonSchema + DeserializeOwned + WasmCompatSend,
449 {
450 TypedPromptRequest::from_agent(*self, prompt)
451 }
452}