use super::api::{
    create_async_deep_agent_from_config, create_deep_agent_from_config, get_default_model,
};
use super::config::{DeepAgentConfig, SubAgentConfig, SummarizationConfig};
use super::runtime::DeepAgent;
use crate::middleware::{
    token_tracking::{TokenTrackingConfig, TokenTrackingMiddleware},
    HitlPolicy,
};
use crate::planner::LlmBackedPlanner;
use crate::prompts::PromptFormat;
use agents_core::agent::PlannerHandle;
use agents_core::llm::LanguageModel;
use agents_core::persistence::Checkpointer;
use agents_core::tools::ToolBox;
use anyhow::Context as _;
use std::collections::{HashMap, HashSet};
use std::num::NonZeroUsize;
use std::sync::Arc;
24
25pub struct ConfigurableAgentBuilder {
28 instructions: String,
29 custom_system_prompt: Option<String>,
30 prompt_format: PromptFormat,
31 planner: Option<Arc<dyn PlannerHandle>>,
32 tools: Vec<ToolBox>,
33 subagents: Vec<SubAgentConfig>,
34 summarization: Option<SummarizationConfig>,
35 tool_interrupts: HashMap<String, HitlPolicy>,
36 builtin_tools: Option<HashSet<String>>,
37 auto_general_purpose: bool,
38 enable_prompt_caching: bool,
39 checkpointer: Option<Arc<dyn Checkpointer>>,
40 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
41 enable_pii_sanitization: bool,
42 token_tracking_config: Option<TokenTrackingConfig>,
43 max_iterations: NonZeroUsize,
44}
45
46impl ConfigurableAgentBuilder {
47 pub fn new(instructions: impl Into<String>) -> Self {
48 Self {
49 instructions: instructions.into(),
50 custom_system_prompt: None,
51 prompt_format: PromptFormat::default(),
52 planner: None,
53 tools: Vec::new(),
54 subagents: Vec::new(),
55 summarization: None,
56 tool_interrupts: HashMap::new(),
57 builtin_tools: None,
58 auto_general_purpose: true,
59 enable_prompt_caching: false,
60 checkpointer: None,
61 event_dispatcher: None,
62 enable_pii_sanitization: true, token_tracking_config: None,
64 max_iterations: NonZeroUsize::new(10).unwrap(),
65 }
66 }
67
68 pub fn with_model(mut self, model: Arc<dyn LanguageModel>) -> Self {
70 let planner: Arc<dyn PlannerHandle> = Arc::new(LlmBackedPlanner::new(model));
71 self.planner = Some(planner);
72 self
73 }
74
75 pub fn with_planner(mut self, planner: Arc<dyn PlannerHandle>) -> Self {
77 self.planner = Some(planner);
78 self
79 }
80
81 pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
110 self.custom_system_prompt = Some(system_prompt.into());
111 self
112 }
113
114 pub fn with_prompt_format(mut self, format: PromptFormat) -> Self {
139 self.prompt_format = format;
140 self
141 }
142
143 pub fn with_tool(mut self, tool: ToolBox) -> Self {
145 self.tools.push(tool);
146 self
147 }
148
149 pub fn with_tools<I>(mut self, tools: I) -> Self
151 where
152 I: IntoIterator<Item = ToolBox>,
153 {
154 self.tools.extend(tools);
155 self
156 }
157
158 pub fn with_subagent_config<I>(mut self, cfgs: I) -> Self
159 where
160 I: IntoIterator<Item = SubAgentConfig>,
161 {
162 self.subagents.extend(cfgs);
163 self
164 }
165
166 pub fn with_subagent_tools<I>(mut self, tools: I) -> Self
169 where
170 I: IntoIterator<Item = ToolBox>,
171 {
172 for tool in tools {
173 let tool_name = tool.schema().name.clone();
174 let subagent_config = SubAgentConfig::new(
175 format!("{}-agent", tool_name),
176 format!("Specialized agent for {} operations", tool_name),
177 format!(
178 "You are a specialized agent. Use the {} tool to complete tasks efficiently.",
179 tool_name
180 ),
181 )
182 .with_tools(vec![tool]);
183 self.subagents.push(subagent_config);
184 }
185 self
186 }
187
188 pub fn with_summarization(mut self, config: SummarizationConfig) -> Self {
189 self.summarization = Some(config);
190 self
191 }
192
193 pub fn with_tool_interrupt(mut self, tool_name: impl Into<String>, policy: HitlPolicy) -> Self {
194 self.tool_interrupts.insert(tool_name.into(), policy);
195 self
196 }
197
198 pub fn with_builtin_tools<I, S>(mut self, names: I) -> Self
199 where
200 I: IntoIterator<Item = S>,
201 S: Into<String>,
202 {
203 self.builtin_tools = Some(names.into_iter().map(|s| s.into()).collect());
204 self
205 }
206
207 pub fn with_auto_general_purpose(mut self, enabled: bool) -> Self {
208 self.auto_general_purpose = enabled;
209 self
210 }
211
212 pub fn with_prompt_caching(mut self, enabled: bool) -> Self {
213 self.enable_prompt_caching = enabled;
214 self
215 }
216
217 pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
218 self.checkpointer = Some(checkpointer);
219 self
220 }
221
222 pub fn with_event_broadcaster(
229 mut self,
230 broadcaster: Arc<dyn agents_core::events::EventBroadcaster>,
231 ) -> Self {
232 if self.event_dispatcher.is_none() {
234 self.event_dispatcher = Some(Arc::new(agents_core::events::EventDispatcher::new()));
235 }
236
237 if let Some(dispatcher) = &self.event_dispatcher {
239 dispatcher.add_broadcaster(broadcaster);
240 }
241
242 self
243 }
244
245 pub fn with_event_broadcasters(
256 mut self,
257 broadcasters: Vec<Arc<dyn agents_core::events::EventBroadcaster>>,
258 ) -> Self {
259 if self.event_dispatcher.is_none() {
261 self.event_dispatcher = Some(Arc::new(agents_core::events::EventDispatcher::new()));
262 }
263
264 if let Some(dispatcher) = &self.event_dispatcher {
266 for broadcaster in broadcasters {
267 dispatcher.add_broadcaster(broadcaster);
268 }
269 }
270
271 self
272 }
273
274 pub fn with_event_dispatcher(
276 mut self,
277 dispatcher: Arc<agents_core::events::EventDispatcher>,
278 ) -> Self {
279 self.event_dispatcher = Some(dispatcher);
280 self
281 }
282
283 pub fn with_pii_sanitization(mut self, enabled: bool) -> Self {
309 self.enable_pii_sanitization = enabled;
310 self
311 }
312
313 pub fn with_token_tracking(mut self, enabled: bool) -> Self {
340 self.token_tracking_config = Some(TokenTrackingConfig {
341 enabled,
342 emit_events: enabled,
343 log_usage: enabled,
344 custom_costs: None,
345 });
346 self
347 }
348
349 pub fn with_token_tracking_config(mut self, config: TokenTrackingConfig) -> Self {
369 self.token_tracking_config = Some(config);
370 self
371 }
372
373 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
398 self.max_iterations =
399 NonZeroUsize::new(max_iterations).expect("max_iterations must be greater than 0");
400 self
401 }
402
403 pub fn build(self) -> anyhow::Result<DeepAgent> {
404 self.finalize(create_deep_agent_from_config)
405 }
406
407 pub fn build_async(self) -> anyhow::Result<DeepAgent> {
410 self.finalize(create_async_deep_agent_from_config)
411 }
412
413 fn finalize(self, ctor: fn(DeepAgentConfig) -> DeepAgent) -> anyhow::Result<DeepAgent> {
414 let Self {
415 instructions,
416 custom_system_prompt,
417 prompt_format,
418 planner,
419 tools,
420 subagents,
421 summarization,
422 tool_interrupts,
423 builtin_tools,
424 auto_general_purpose,
425 enable_prompt_caching,
426 checkpointer,
427 event_dispatcher,
428 enable_pii_sanitization,
429 token_tracking_config,
430 max_iterations,
431 } = self;
432
433 let planner = planner.unwrap_or_else(|| {
434 let default_model = get_default_model().expect("Failed to get default model");
436 Arc::new(LlmBackedPlanner::new(default_model)) as Arc<dyn PlannerHandle>
437 });
438
439 let final_planner = if let Some(token_config) = token_tracking_config {
441 if token_config.enabled {
442 let planner_any = planner.as_any();
444 if let Some(llm_planner) = planner_any.downcast_ref::<LlmBackedPlanner>() {
445 let model = llm_planner.model().clone();
446 let tracked_model = Arc::new(TokenTrackingMiddleware::new(
447 token_config,
448 model,
449 event_dispatcher.clone(),
450 ));
451 Arc::new(LlmBackedPlanner::new(tracked_model)) as Arc<dyn PlannerHandle>
452 } else {
453 planner
454 }
455 } else {
456 planner
457 }
458 } else {
459 planner
460 };
461
462 let mut cfg = DeepAgentConfig::new(instructions, final_planner)
463 .with_auto_general_purpose(auto_general_purpose)
464 .with_prompt_caching(enable_prompt_caching)
465 .with_pii_sanitization(enable_pii_sanitization)
466 .with_max_iterations(max_iterations.get())
467 .with_prompt_format(prompt_format);
468
469 if let Some(prompt) = custom_system_prompt {
471 cfg = cfg.with_system_prompt(prompt);
472 }
473
474 if let Some(ckpt) = checkpointer {
475 cfg = cfg.with_checkpointer(ckpt);
476 }
477 if let Some(dispatcher) = event_dispatcher {
478 cfg = cfg.with_event_dispatcher(dispatcher);
479 }
480 if let Some(sum) = summarization {
481 cfg = cfg.with_summarization(sum);
482 }
483 if let Some(selected) = builtin_tools {
484 cfg = cfg.with_builtin_tools(selected);
485 }
486 for (name, policy) in tool_interrupts {
487 cfg = cfg.with_tool_interrupt(name, policy);
488 }
489 for tool in tools {
490 cfg = cfg.with_tool(tool);
491 }
492 for sub_cfg in subagents {
493 cfg = cfg.with_subagent_config(sub_cfg);
494 }
495
496 Ok(ctor(cfg))
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_default_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions");
        assert_eq!(builder.max_iterations.get(), 10);
    }

    #[test]
    fn test_builder_custom_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions").with_max_iterations(20);
        assert_eq!(builder.max_iterations.get(), 20);
    }

    #[test]
    #[should_panic(expected = "max_iterations must be greater than 0")]
    fn test_builder_zero_max_iterations_panics() {
        let _builder = ConfigurableAgentBuilder::new("test instructions").with_max_iterations(0);
    }

    #[test]
    fn test_builder_large_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions").with_max_iterations(1000);
        assert_eq!(builder.max_iterations.get(), 1000);
    }

    #[test]
    fn test_builder_chaining_with_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions")
            .with_max_iterations(15)
            .with_auto_general_purpose(false)
            .with_prompt_caching(true)
            .with_pii_sanitization(false);

        assert_eq!(builder.max_iterations.get(), 15);
        assert!(!builder.auto_general_purpose);
        assert!(builder.enable_prompt_caching);
        assert!(!builder.enable_pii_sanitization);
    }

    #[test]
    fn test_builder_default_no_custom_system_prompt() {
        let builder = ConfigurableAgentBuilder::new("test instructions");
        assert!(builder.custom_system_prompt.is_none());
    }

    #[test]
    fn test_builder_with_system_prompt() {
        let custom_prompt = "You are a custom assistant.";
        let builder = ConfigurableAgentBuilder::new("ignored").with_system_prompt(custom_prompt);

        assert_eq!(builder.custom_system_prompt.as_deref(), Some(custom_prompt));
    }

    #[test]
    fn test_builder_system_prompt_chaining() {
        let builder = ConfigurableAgentBuilder::new("ignored")
            .with_system_prompt("Custom prompt")
            .with_max_iterations(20)
            .with_pii_sanitization(false);

        assert_eq!(builder.custom_system_prompt.as_deref(), Some("Custom prompt"));
        assert_eq!(builder.max_iterations.get(), 20);
        assert!(!builder.enable_pii_sanitization);
    }

    #[test]
    fn test_builder_default_prompt_format_is_json() {
        let builder = ConfigurableAgentBuilder::new("test instructions");
        assert_eq!(builder.prompt_format, PromptFormat::Json);
    }

    #[test]
    fn test_builder_with_toon_prompt_format() {
        let builder = ConfigurableAgentBuilder::new("test instructions")
            .with_prompt_format(PromptFormat::Toon);
        assert_eq!(builder.prompt_format, PromptFormat::Toon);
    }

    #[test]
    fn test_builder_prompt_format_chaining() {
        let builder = ConfigurableAgentBuilder::new("test instructions")
            .with_prompt_format(PromptFormat::Toon)
            .with_max_iterations(15)
            .with_pii_sanitization(false);

        assert_eq!(builder.prompt_format, PromptFormat::Toon);
        assert_eq!(builder.max_iterations.get(), 15);
        assert!(!builder.enable_pii_sanitization);
    }

    #[test]
    fn test_builder_event_broadcasters_create_dispatcher_when_missing() {
        let builder =
            ConfigurableAgentBuilder::new("test instructions").with_event_broadcasters(Vec::new());
        assert!(builder.event_dispatcher.is_some());
    }

    #[test]
    fn test_builder_event_broadcasters_keep_explicit_dispatcher() {
        let dispatcher = Arc::new(agents_core::events::EventDispatcher::new());
        let builder = ConfigurableAgentBuilder::new("test instructions")
            .with_event_dispatcher(Arc::clone(&dispatcher))
            .with_event_broadcasters(Vec::new());

        let stored = builder
            .event_dispatcher
            .as_ref()
            .expect("dispatcher should still be set");
        assert!(Arc::ptr_eq(stored, &dispatcher));
    }

    #[test]
    fn test_builder_with_token_tracking_sets_all_flags() {
        let builder = ConfigurableAgentBuilder::new("test instructions").with_token_tracking(true);
        let config = builder
            .token_tracking_config
            .expect("token tracking config should be set");

        assert!(config.enabled);
        assert!(config.emit_events);
        assert!(config.log_usage);
        assert!(config.custom_costs.is_none());
    }
}