1use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
7 },
8 message::ToolChoice,
9 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10 tool::server::ToolServerHandle,
11 vector_store::{VectorStoreError, request::VectorSearchRequest},
12 wasm_compat::WasmCompatSend,
13};
14use std::{
15 collections::{BTreeSet, HashMap},
16 sync::Arc,
17};
18
/// Name reported by `Agent::name` for agents that were not given an explicit name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
20
/// Reference-counted list of dynamic context sources.
///
/// Each entry pairs a sample count with a vector store index. When a request is
/// prepared with a retrieval query, every index is asked for that many results and
/// the hits are attached to the request as documents.
pub type DynamicContextStore = Arc<
    Vec<(
        usize,
        Arc<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
    )>,
>;
27
/// A configured completion request builder together with the tool names that were
/// resolved while preparing it.
///
/// Produced by `build_prepared_completion_request`.
pub(crate) struct PreparedCompletionRequest<M: CompletionModel> {
    /// The fully configured request builder.
    pub(crate) builder: CompletionRequestBuilder<M>,
    /// Names of every tool definition the tool server returned for this request.
    pub(crate) executable_tool_names: BTreeSet<String>,
    /// Tool names permitted by the request's tool choice, as computed by
    /// `allowed_tool_names_for_choice`.
    pub(crate) allowed_tool_names: BTreeSet<String>,
}
35
36pub(crate) fn allowed_tool_names_for_choice(
37 executable_tool_names: &BTreeSet<String>,
38 tool_choice: Option<&ToolChoice>,
39) -> Result<BTreeSet<String>, CompletionError> {
40 let allowed = match tool_choice {
41 None | Some(ToolChoice::Auto | ToolChoice::Required) => executable_tool_names.clone(),
42 Some(ToolChoice::None) => BTreeSet::new(),
43 Some(ToolChoice::Specific { function_names }) => {
44 if function_names.is_empty() {
45 return Err(CompletionError::RequestError(
46 "ToolChoice::Specific requires at least one function name".into(),
47 ));
48 }
49
50 let requested = function_names.iter().cloned().collect::<BTreeSet<String>>();
51 let missing = requested
52 .difference(executable_tool_names)
53 .cloned()
54 .collect::<Vec<_>>();
55
56 if !missing.is_empty() {
57 return Err(CompletionError::RequestError(
58 format!(
59 "ToolChoice::Specific requested unknown tool names: {missing:?}. Available tools: {:?}",
60 executable_tool_names.iter().collect::<Vec<_>>()
61 )
62 .into(),
63 ));
64 }
65
66 requested
67 }
68 };
69
70 Ok(allowed)
71}
72
73#[allow(clippy::too_many_arguments)]
76pub(crate) async fn build_completion_request<M: CompletionModel>(
77 model: &Arc<M>,
78 prompt: Message,
79 chat_history: &[Message],
80 preamble: Option<&str>,
81 static_context: &[Document],
82 temperature: Option<f64>,
83 max_tokens: Option<u64>,
84 additional_params: Option<&serde_json::Value>,
85 tool_choice: Option<&ToolChoice>,
86 tool_server_handle: &ToolServerHandle,
87 dynamic_context: &DynamicContextStore,
88 output_schema: Option<&schemars::Schema>,
89) -> Result<CompletionRequestBuilder<M>, CompletionError> {
90 Ok(build_prepared_completion_request(
91 model,
92 prompt,
93 chat_history,
94 preamble,
95 static_context,
96 temperature,
97 max_tokens,
98 additional_params,
99 tool_choice,
100 tool_server_handle,
101 dynamic_context,
102 output_schema,
103 )
104 .await?
105 .builder)
106}
107
108#[allow(clippy::too_many_arguments)]
111pub(crate) async fn build_prepared_completion_request<M: CompletionModel>(
112 model: &Arc<M>,
113 prompt: Message,
114 chat_history: &[Message],
115 preamble: Option<&str>,
116 static_context: &[Document],
117 temperature: Option<f64>,
118 max_tokens: Option<u64>,
119 additional_params: Option<&serde_json::Value>,
120 tool_choice: Option<&ToolChoice>,
121 tool_server_handle: &ToolServerHandle,
122 dynamic_context: &DynamicContextStore,
123 output_schema: Option<&schemars::Schema>,
124) -> Result<PreparedCompletionRequest<M>, CompletionError> {
125 let rag_text = prompt.rag_text();
127 let rag_text = rag_text.or_else(|| {
128 chat_history
129 .iter()
130 .rev()
131 .find_map(|message| message.rag_text())
132 });
133
134 let chat_history: Vec<Message> = if let Some(preamble) = preamble {
136 std::iter::once(Message::system(preamble.to_owned()))
137 .chain(chat_history.iter().cloned())
138 .collect()
139 } else {
140 chat_history.to_vec()
141 };
142
143 let completion_request = model
144 .completion_request(prompt)
145 .messages(chat_history)
146 .temperature_opt(temperature)
147 .max_tokens_opt(max_tokens)
148 .additional_params_opt(additional_params.cloned())
149 .output_schema_opt(output_schema.cloned())
150 .documents(static_context.to_vec());
151
152 let completion_request = if let Some(tool_choice) = tool_choice {
153 completion_request.tool_choice(tool_choice.clone())
154 } else {
155 completion_request
156 };
157
158 let (builder, executable_tool_names) = match &rag_text {
160 Some(text) => {
161 let search_futures = dynamic_context.iter().map(|(num_sample, index)| {
163 let text = text.clone();
165 let num_sample = *num_sample;
166 let index = index.clone();
167
168 async move {
169 let req = VectorSearchRequest::builder()
170 .query(text)
171 .samples(num_sample as u64)
172 .build();
173
174 let docs = index
175 .top_n(req)
176 .await?
177 .into_iter()
178 .map(|(_, id, doc)| {
179 let text = serde_json::to_string_pretty(&doc)
181 .unwrap_or_else(|_| doc.to_string());
182
183 Document {
184 id,
185 text,
186 additional_props: HashMap::new(),
187 }
188 })
189 .collect::<Vec<_>>();
190
191 Ok::<_, VectorStoreError>(docs)
192 }
193 });
194
195 let fetched_context: Vec<Document> = futures::future::try_join_all(search_futures)
197 .await
198 .map_err(|e| CompletionError::RequestError(Box::new(e)))?
199 .into_iter()
200 .flatten() .collect();
202
203 let tooldefs = tool_server_handle
204 .get_tool_defs(Some(text.to_string()))
205 .await
206 .map_err(|_| {
207 CompletionError::RequestError("Failed to get tool definitions".into())
208 })?;
209 let executable_tool_names = tooldefs.iter().map(|tool| tool.name.clone()).collect();
210
211 (
212 completion_request
213 .documents(fetched_context)
214 .tools(tooldefs),
215 executable_tool_names,
216 )
217 }
218 None => {
219 let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
220 CompletionError::RequestError("Failed to get tool definitions".into())
221 })?;
222 let executable_tool_names = tooldefs.iter().map(|tool| tool.name.clone()).collect();
223
224 (completion_request.tools(tooldefs), executable_tool_names)
225 }
226 };
227 let allowed_tool_names = allowed_tool_names_for_choice(&executable_tool_names, tool_choice)?;
228
229 Ok(PreparedCompletionRequest {
230 builder,
231 executable_tool_names,
232 allowed_tool_names,
233 })
234}
235
/// An LLM agent: a completion model bundled with the configuration used to build
/// every request sent to it.
///
/// `M` is the completion model and `P` the prompt hook type (defaulting to `()`).
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal.
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M, P = ()>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Optional display name; `name()` falls back to "Unnamed Agent" when unset.
    pub name: Option<String>,
    /// Optional description of the agent.
    // NOTE(review): not read in this file; confirm where it is consumed.
    pub description: Option<String>,
    /// Completion model that requests are built against.
    pub model: Arc<M>,
    /// System prompt; when set it is placed ahead of the chat history as a system
    /// message.
    pub preamble: Option<String>,
    /// Documents attached to every request.
    pub static_context: Vec<Document>,
    /// Sampling temperature forwarded to the request builder.
    pub temperature: Option<f64>,
    /// Token limit forwarded to the request builder.
    pub max_tokens: Option<u64>,
    /// Extra JSON parameters forwarded to the request builder.
    pub additional_params: Option<serde_json::Value>,
    /// Handle used to fetch tool definitions when a request is built.
    pub tool_server_handle: ToolServerHandle,
    /// Vector store indices, each with a sample count, searched for request-specific
    /// documents.
    pub dynamic_context: DynamicContextStore,
    /// Tool choice forwarded to the request; it also determines which tool names are
    /// allowed for that request.
    pub tool_choice: Option<ToolChoice>,
    /// Default turn limit.
    // NOTE(review): not read in this file; presumably applied by the prompt request
    // types. Confirm there.
    pub default_max_turns: Option<usize>,
    /// Optional prompt hook.
    // NOTE(review): not invoked in this file; presumably used by the prompt request
    // types. Confirm there.
    pub hook: Option<P>,
    /// Optional JSON schema forwarded to the request builder as the output schema.
    pub output_schema: Option<schemars::Schema>,
    /// Optional conversation memory backend.
    // NOTE(review): not read in this file; confirm how it is used.
    pub memory: Option<Arc<dyn crate::memory::ConversationMemory>>,
    /// Optional default conversation identifier.
    // NOTE(review): presumably paired with `memory`; confirm against its consumer.
    pub default_conversation_id: Option<String>,
}
304
305impl<M, P> Agent<M, P>
306where
307 M: CompletionModel,
308 P: PromptHook<M>,
309{
310 pub(crate) fn name(&self) -> &str {
312 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
313 }
314}
315
316impl<M, P> Completion<M> for Agent<M, P>
317where
318 M: CompletionModel,
319 P: PromptHook<M>,
320{
321 async fn completion<I, T>(
322 &self,
323 prompt: impl Into<Message> + WasmCompatSend,
324 chat_history: I,
325 ) -> Result<CompletionRequestBuilder<M>, CompletionError>
326 where
327 I: IntoIterator<Item = T>,
328 T: Into<Message>,
329 {
330 let history: Vec<Message> = chat_history.into_iter().map(Into::into).collect();
331 build_completion_request(
332 &self.model,
333 prompt.into(),
334 &history,
335 self.preamble.as_deref(),
336 &self.static_context,
337 self.temperature,
338 self.max_tokens,
339 self.additional_params.as_ref(),
340 self.tool_choice.as_ref(),
341 &self.tool_server_handle,
342 &self.dynamic_context,
343 self.output_schema.as_ref(),
344 )
345 .await
346 }
347}
348
// `prompt` returns the concrete `PromptRequest` type, which the lint reports as a
// refinement of the trait's declared return type.
#[allow(refining_impl_trait)]
impl<M, P> Prompt for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Creates a `PromptRequest` for this agent and `prompt` via
    /// `PromptRequest::from_agent`.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<prompt_request::Standard, M, P> {
        PromptRequest::from_agent(self, prompt)
    }
}
369
// `prompt` returns the concrete `PromptRequest` type, which the lint reports as a
// refinement of the trait's declared return type.
#[allow(refining_impl_trait)]
impl<M, P> Prompt for &Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Creates a `PromptRequest` for the referenced agent and `prompt`.
    ///
    /// The call is wrapped in a tracing span that records the agent's name.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<prompt_request::Standard, M, P> {
        // `self` is `&&Agent` here; dereference once to pass `&Agent`.
        PromptRequest::from_agent(*self, prompt)
    }
}
384
385#[allow(refining_impl_trait)]
386impl<M, P> Chat for Agent<M, P>
387where
388 M: CompletionModel + 'static,
389 P: PromptHook<M> + 'static,
390{
391 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
392 async fn chat(
393 &self,
394 prompt: impl Into<Message> + WasmCompatSend,
395 chat_history: &mut Vec<Message>,
396 ) -> Result<String, PromptError> {
397 let response = PromptRequest::from_agent(self, prompt)
398 .with_history(chat_history.clone())
399 .extended_details()
400 .await?;
401
402 if let Some(messages) = response.messages {
403 chat_history.extend(messages);
404 }
405
406 Ok(response.output)
407 }
408}
409
impl<M, P> StreamingCompletion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Builds the request for a streaming completion.
    ///
    /// Delegates to `completion`, so the builder is configured exactly as for a
    /// non-streaming request.
    ///
    /// # Errors
    ///
    /// Propagates any `CompletionError` from `completion`.
    async fn stream_completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>,
    {
        self.completion(prompt, chat_history).await
    }
}
429
impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    /// The agent's own hook type is used for streaming requests.
    type Hook = P;

    /// Creates a `StreamingPromptRequest` for this agent and `prompt` via
    /// `StreamingPromptRequest::from_agent`.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, P> {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt)
    }
}
445
impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    /// The agent's own hook type is used for streaming requests.
    type Hook = P;

    /// Creates a `StreamingPromptRequest` for this agent and `prompt`, with
    /// `chat_history` attached through `with_history`.
    fn stream_chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> StreamingPromptRequest<M, P>
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
    }
}
466
467use crate::agent::prompt_request::TypedPromptRequest;
468use schemars::JsonSchema;
469use serde::de::DeserializeOwned;
470
// `prompt_typed` returns the concrete `TypedPromptRequest` type, which the lint
// reports as a refinement of the trait's declared return type.
#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Request type produced by `prompt_typed` for a target type `T`.
    type TypedRequest<T>
        = TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Creates a `TypedPromptRequest` targeting `T` for this agent and `prompt` via
    /// `TypedPromptRequest::from_agent`.
    // NOTE(review): presumably the response is deserialized into `T` using its JSON
    // schema; that logic lives in `TypedPromptRequest`, not here. Confirm there.
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(self, prompt)
    }
}
524
// `prompt_typed` returns the concrete `TypedPromptRequest` type, which the lint
// reports as a refinement of the trait's declared return type.
#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for &Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Request type produced by `prompt_typed` for a target type `T`.
    type TypedRequest<T>
        = TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Creates a `TypedPromptRequest` targeting `T` for the referenced agent and
    /// `prompt`.
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        // `self` is `&&Agent` here; dereference once to pass `&Agent`.
        TypedPromptRequest::from_agent(*self, prompt)
    }
}
546
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned, ordered set of tool names from string literals.
    fn tool_names(names: &[&str]) -> BTreeSet<String> {
        names.iter().copied().map(String::from).collect()
    }

    /// Builds a `ToolChoice::Specific` naming the given functions.
    fn specific(names: &[&str]) -> ToolChoice {
        ToolChoice::Specific {
            function_names: names.iter().map(|name| name.to_string()).collect(),
        }
    }

    /// Extracts the message of a `CompletionError::RequestError`, failing the test for
    /// any other variant.
    fn request_error_message(err: CompletionError) -> String {
        match err {
            CompletionError::RequestError(source) => source.to_string(),
            other => panic!("expected CompletionError::RequestError, got {other:?}"),
        }
    }

    #[test]
    fn allowed_tool_names_defaults_to_all_executable_tools() {
        let executable = tool_names(&["add", "subtract"]);

        let allowed = allowed_tool_names_for_choice(&executable, None).unwrap();

        assert_eq!(allowed, executable);
    }

    #[test]
    fn allowed_tool_names_auto_and_required_allow_all_executable_tools() {
        let executable = tool_names(&["add", "subtract"]);

        let with_auto =
            allowed_tool_names_for_choice(&executable, Some(&ToolChoice::Auto)).unwrap();
        assert_eq!(with_auto, executable);

        let with_required =
            allowed_tool_names_for_choice(&executable, Some(&ToolChoice::Required)).unwrap();
        assert_eq!(with_required, executable);
    }

    #[test]
    fn allowed_tool_names_none_allows_no_tools() {
        let executable = tool_names(&["add", "subtract"]);

        let allowed =
            allowed_tool_names_for_choice(&executable, Some(&ToolChoice::None)).unwrap();

        assert!(allowed.is_empty());
    }

    #[test]
    fn allowed_tool_names_specific_allows_requested_executable_tools() {
        let executable = tool_names(&["add", "subtract"]);
        let choice = specific(&["add"]);

        let allowed = allowed_tool_names_for_choice(&executable, Some(&choice)).unwrap();

        assert_eq!(allowed, tool_names(&["add"]));
    }

    #[test]
    fn allowed_tool_names_specific_rejects_missing_tools() {
        let executable = tool_names(&["add"]);
        let choice = specific(&["missing"]);

        let err = allowed_tool_names_for_choice(&executable, Some(&choice))
            .expect_err("missing specific tool should fail before provider request");

        // The message must name both the unknown tool and the available one.
        let message = request_error_message(err);
        assert!(message.contains("missing"));
        assert!(message.contains("add"));
    }

    #[test]
    fn allowed_tool_names_specific_rejects_empty_names() {
        let executable = tool_names(&["add"]);
        let choice = specific(&[]);

        let err = allowed_tool_names_for_choice(&executable, Some(&choice))
            .expect_err("empty specific tool choice should fail before provider request");

        let message = request_error_message(err);
        assert!(message.contains("requires at least one function name"));
    }
}