1use super::api::{
7 create_async_deep_agent_from_config, create_deep_agent_from_config, get_default_model,
8};
9use super::config::{DeepAgentConfig, SubAgentConfig, SummarizationConfig};
10use super::runtime::DeepAgent;
11use crate::middleware::{
12 token_tracking::{TokenTrackingConfig, TokenTrackingMiddleware},
13 HitlPolicy,
14};
15use crate::planner::LlmBackedPlanner;
16use agents_core::agent::PlannerHandle;
17use agents_core::llm::LanguageModel;
18use agents_core::persistence::Checkpointer;
19use agents_core::tools::ToolBox;
20use std::collections::{HashMap, HashSet};
21use std::num::NonZeroUsize;
22use std::sync::Arc;
23
24pub struct ConfigurableAgentBuilder {
27 instructions: String,
28 planner: Option<Arc<dyn PlannerHandle>>,
29 tools: Vec<ToolBox>,
30 subagents: Vec<SubAgentConfig>,
31 summarization: Option<SummarizationConfig>,
32 tool_interrupts: HashMap<String, HitlPolicy>,
33 builtin_tools: Option<HashSet<String>>,
34 auto_general_purpose: bool,
35 enable_prompt_caching: bool,
36 checkpointer: Option<Arc<dyn Checkpointer>>,
37 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
38 enable_pii_sanitization: bool,
39 token_tracking_config: Option<TokenTrackingConfig>,
40 max_iterations: NonZeroUsize,
41}
42
43impl ConfigurableAgentBuilder {
44 pub fn new(instructions: impl Into<String>) -> Self {
45 Self {
46 instructions: instructions.into(),
47 planner: None,
48 tools: Vec::new(),
49 subagents: Vec::new(),
50 summarization: None,
51 tool_interrupts: HashMap::new(),
52 builtin_tools: None,
53 auto_general_purpose: true,
54 enable_prompt_caching: false,
55 checkpointer: None,
56 event_dispatcher: None,
57 enable_pii_sanitization: true, token_tracking_config: None,
59 max_iterations: NonZeroUsize::new(10).unwrap(),
60 }
61 }
62
63 pub fn with_model(mut self, model: Arc<dyn LanguageModel>) -> Self {
65 let planner: Arc<dyn PlannerHandle> = Arc::new(LlmBackedPlanner::new(model));
66 self.planner = Some(planner);
67 self
68 }
69
70 pub fn with_planner(mut self, planner: Arc<dyn PlannerHandle>) -> Self {
72 self.planner = Some(planner);
73 self
74 }
75
76 pub fn with_tool(mut self, tool: ToolBox) -> Self {
78 self.tools.push(tool);
79 self
80 }
81
82 pub fn with_tools<I>(mut self, tools: I) -> Self
84 where
85 I: IntoIterator<Item = ToolBox>,
86 {
87 self.tools.extend(tools);
88 self
89 }
90
91 pub fn with_subagent_config<I>(mut self, cfgs: I) -> Self
92 where
93 I: IntoIterator<Item = SubAgentConfig>,
94 {
95 self.subagents.extend(cfgs);
96 self
97 }
98
99 pub fn with_subagent_tools<I>(mut self, tools: I) -> Self
102 where
103 I: IntoIterator<Item = ToolBox>,
104 {
105 for tool in tools {
106 let tool_name = tool.schema().name.clone();
107 let subagent_config = SubAgentConfig::new(
108 format!("{}-agent", tool_name),
109 format!("Specialized agent for {} operations", tool_name),
110 format!(
111 "You are a specialized agent. Use the {} tool to complete tasks efficiently.",
112 tool_name
113 ),
114 )
115 .with_tools(vec![tool]);
116 self.subagents.push(subagent_config);
117 }
118 self
119 }
120
121 pub fn with_summarization(mut self, config: SummarizationConfig) -> Self {
122 self.summarization = Some(config);
123 self
124 }
125
126 pub fn with_tool_interrupt(mut self, tool_name: impl Into<String>, policy: HitlPolicy) -> Self {
127 self.tool_interrupts.insert(tool_name.into(), policy);
128 self
129 }
130
131 pub fn with_builtin_tools<I, S>(mut self, names: I) -> Self
132 where
133 I: IntoIterator<Item = S>,
134 S: Into<String>,
135 {
136 self.builtin_tools = Some(names.into_iter().map(|s| s.into()).collect());
137 self
138 }
139
140 pub fn with_auto_general_purpose(mut self, enabled: bool) -> Self {
141 self.auto_general_purpose = enabled;
142 self
143 }
144
145 pub fn with_prompt_caching(mut self, enabled: bool) -> Self {
146 self.enable_prompt_caching = enabled;
147 self
148 }
149
150 pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
151 self.checkpointer = Some(checkpointer);
152 self
153 }
154
155 pub fn with_event_broadcaster(
162 mut self,
163 broadcaster: Arc<dyn agents_core::events::EventBroadcaster>,
164 ) -> Self {
165 if self.event_dispatcher.is_none() {
167 self.event_dispatcher = Some(Arc::new(agents_core::events::EventDispatcher::new()));
168 }
169
170 if let Some(dispatcher) = &self.event_dispatcher {
172 dispatcher.add_broadcaster(broadcaster);
173 }
174
175 self
176 }
177
178 pub fn with_event_broadcasters(
189 mut self,
190 broadcasters: Vec<Arc<dyn agents_core::events::EventBroadcaster>>,
191 ) -> Self {
192 if self.event_dispatcher.is_none() {
194 self.event_dispatcher = Some(Arc::new(agents_core::events::EventDispatcher::new()));
195 }
196
197 if let Some(dispatcher) = &self.event_dispatcher {
199 for broadcaster in broadcasters {
200 dispatcher.add_broadcaster(broadcaster);
201 }
202 }
203
204 self
205 }
206
207 pub fn with_event_dispatcher(
209 mut self,
210 dispatcher: Arc<agents_core::events::EventDispatcher>,
211 ) -> Self {
212 self.event_dispatcher = Some(dispatcher);
213 self
214 }
215
216 pub fn with_pii_sanitization(mut self, enabled: bool) -> Self {
242 self.enable_pii_sanitization = enabled;
243 self
244 }
245
246 pub fn with_token_tracking(mut self, enabled: bool) -> Self {
273 self.token_tracking_config = Some(TokenTrackingConfig {
274 enabled,
275 emit_events: enabled,
276 log_usage: enabled,
277 custom_costs: None,
278 });
279 self
280 }
281
282 pub fn with_token_tracking_config(mut self, config: TokenTrackingConfig) -> Self {
302 self.token_tracking_config = Some(config);
303 self
304 }
305
306 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
331 self.max_iterations =
332 NonZeroUsize::new(max_iterations).expect("max_iterations must be greater than 0");
333 self
334 }
335
336 pub fn build(self) -> anyhow::Result<DeepAgent> {
337 self.finalize(create_deep_agent_from_config)
338 }
339
340 pub fn build_async(self) -> anyhow::Result<DeepAgent> {
343 self.finalize(create_async_deep_agent_from_config)
344 }
345
346 fn finalize(self, ctor: fn(DeepAgentConfig) -> DeepAgent) -> anyhow::Result<DeepAgent> {
347 let Self {
348 instructions,
349 planner,
350 tools,
351 subagents,
352 summarization,
353 tool_interrupts,
354 builtin_tools,
355 auto_general_purpose,
356 enable_prompt_caching,
357 checkpointer,
358 event_dispatcher,
359 enable_pii_sanitization,
360 token_tracking_config,
361 max_iterations,
362 } = self;
363
364 let planner = planner.unwrap_or_else(|| {
365 let default_model = get_default_model().expect("Failed to get default model");
367 Arc::new(LlmBackedPlanner::new(default_model)) as Arc<dyn PlannerHandle>
368 });
369
370 let final_planner = if let Some(token_config) = token_tracking_config {
372 if token_config.enabled {
373 let planner_any = planner.as_any();
375 if let Some(llm_planner) = planner_any.downcast_ref::<LlmBackedPlanner>() {
376 let model = llm_planner.model().clone();
377 let tracked_model = Arc::new(TokenTrackingMiddleware::new(
378 token_config,
379 model,
380 event_dispatcher.clone(),
381 ));
382 Arc::new(LlmBackedPlanner::new(tracked_model)) as Arc<dyn PlannerHandle>
383 } else {
384 planner
385 }
386 } else {
387 planner
388 }
389 } else {
390 planner
391 };
392
393 let mut cfg = DeepAgentConfig::new(instructions, final_planner)
394 .with_auto_general_purpose(auto_general_purpose)
395 .with_prompt_caching(enable_prompt_caching)
396 .with_pii_sanitization(enable_pii_sanitization)
397 .with_max_iterations(max_iterations.get());
398
399 if let Some(ckpt) = checkpointer {
400 cfg = cfg.with_checkpointer(ckpt);
401 }
402 if let Some(dispatcher) = event_dispatcher {
403 cfg = cfg.with_event_dispatcher(dispatcher);
404 }
405 if let Some(sum) = summarization {
406 cfg = cfg.with_summarization(sum);
407 }
408 if let Some(selected) = builtin_tools {
409 cfg = cfg.with_builtin_tools(selected);
410 }
411 for (name, policy) in tool_interrupts {
412 cfg = cfg.with_tool_interrupt(name, policy);
413 }
414 for tool in tools {
415 cfg = cfg.with_tool(tool);
416 }
417 for sub_cfg in subagents {
418 cfg = cfg.with_subagent_config(sub_cfg);
419 }
420
421 Ok(ctor(cfg))
422 }
423}
424
#[cfg(test)]
mod tests {
    //! Unit tests for `ConfigurableAgentBuilder` state.
    //!
    //! These tests inspect builder fields directly and never build an agent.

    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_builder_default_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions");
        assert_eq!(builder.max_iterations.get(), 10);
    }

    #[test]
    fn test_builder_custom_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions").with_max_iterations(20);
        assert_eq!(builder.max_iterations.get(), 20);
    }

    #[test]
    #[should_panic(expected = "max_iterations must be greater than 0")]
    fn test_builder_zero_max_iterations_panics() {
        let _builder = ConfigurableAgentBuilder::new("test instructions").with_max_iterations(0);
    }

    #[test]
    fn test_builder_large_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions").with_max_iterations(1000);
        assert_eq!(builder.max_iterations.get(), 1000);
    }

    #[test]
    fn test_builder_chaining_with_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions")
            .with_max_iterations(15)
            .with_auto_general_purpose(false)
            .with_prompt_caching(true)
            .with_pii_sanitization(false);

        assert_eq!(builder.max_iterations.get(), 15);
        assert!(!builder.auto_general_purpose);
        assert!(builder.enable_prompt_caching);
        assert!(!builder.enable_pii_sanitization);
    }

    /// Pins every default chosen by `ConfigurableAgentBuilder::new`.
    #[test]
    fn test_builder_defaults() {
        let builder = ConfigurableAgentBuilder::new("test instructions");

        assert_eq!(builder.instructions, "test instructions");
        assert!(builder.planner.is_none());
        assert!(builder.tools.is_empty());
        assert!(builder.subagents.is_empty());
        assert!(builder.summarization.is_none());
        assert!(builder.tool_interrupts.is_empty());
        assert!(builder.builtin_tools.is_none());
        assert!(builder.auto_general_purpose);
        assert!(!builder.enable_prompt_caching);
        assert!(builder.checkpointer.is_none());
        assert!(builder.event_dispatcher.is_none());
        assert!(builder.enable_pii_sanitization);
        assert!(builder.token_tracking_config.is_none());
    }

    #[test]
    fn test_builder_token_tracking_switch_sets_all_flags() {
        let enabled = ConfigurableAgentBuilder::new("test instructions")
            .with_token_tracking(true)
            .token_tracking_config
            .expect("token tracking config should be set");
        assert!(enabled.enabled && enabled.emit_events && enabled.log_usage);
        assert!(enabled.custom_costs.is_none());

        let disabled = ConfigurableAgentBuilder::new("test instructions")
            .with_token_tracking(false)
            .token_tracking_config
            .expect("token tracking config should be set");
        assert!(!disabled.enabled && !disabled.emit_events && !disabled.log_usage);
    }

    #[test]
    fn test_builder_builtin_tools_are_deduplicated() {
        let tools = ConfigurableAgentBuilder::new("test instructions")
            .with_builtin_tools(vec!["ls", "write_todos", "ls"])
            .builtin_tools
            .expect("builtin tools should be set");
        assert_eq!(tools.len(), 2);
        assert!(tools.contains("ls"));
        assert!(tools.contains("write_todos"));
    }

    #[test]
    fn test_builder_event_broadcasters_create_dispatcher_even_when_empty() {
        let builder =
            ConfigurableAgentBuilder::new("test instructions").with_event_broadcasters(Vec::new());
        assert!(builder.event_dispatcher.is_some());
    }

    #[test]
    fn test_builder_event_broadcasters_reuse_existing_dispatcher() {
        let dispatcher = Arc::new(agents_core::events::EventDispatcher::new());
        let builder = ConfigurableAgentBuilder::new("test instructions")
            .with_event_dispatcher(Arc::clone(&dispatcher))
            .with_event_broadcasters(Vec::new());

        let stored = builder
            .event_dispatcher
            .as_ref()
            .expect("dispatcher should be set");
        assert!(Arc::ptr_eq(&dispatcher, stored));
    }
}