1use super::api::{
7 create_async_deep_agent_from_config, create_deep_agent_from_config, get_default_model,
8};
9use super::config::{DeepAgentConfig, SubAgentConfig, SummarizationConfig};
10use super::runtime::DeepAgent;
11use crate::middleware::{
12 token_tracking::{TokenTrackingConfig, TokenTrackingMiddleware},
13 HitlPolicy,
14};
15use crate::planner::LlmBackedPlanner;
16use agents_core::agent::PlannerHandle;
17use agents_core::llm::LanguageModel;
18use agents_core::persistence::Checkpointer;
19use agents_core::tools::ToolBox;
20use std::collections::{HashMap, HashSet};
21use std::num::NonZeroUsize;
22use std::sync::Arc;
23
24pub struct ConfigurableAgentBuilder {
27 instructions: String,
28 custom_system_prompt: Option<String>,
29 planner: Option<Arc<dyn PlannerHandle>>,
30 tools: Vec<ToolBox>,
31 subagents: Vec<SubAgentConfig>,
32 summarization: Option<SummarizationConfig>,
33 tool_interrupts: HashMap<String, HitlPolicy>,
34 builtin_tools: Option<HashSet<String>>,
35 auto_general_purpose: bool,
36 enable_prompt_caching: bool,
37 checkpointer: Option<Arc<dyn Checkpointer>>,
38 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
39 enable_pii_sanitization: bool,
40 token_tracking_config: Option<TokenTrackingConfig>,
41 max_iterations: NonZeroUsize,
42}
43
44impl ConfigurableAgentBuilder {
45 pub fn new(instructions: impl Into<String>) -> Self {
46 Self {
47 instructions: instructions.into(),
48 custom_system_prompt: None,
49 planner: None,
50 tools: Vec::new(),
51 subagents: Vec::new(),
52 summarization: None,
53 tool_interrupts: HashMap::new(),
54 builtin_tools: None,
55 auto_general_purpose: true,
56 enable_prompt_caching: false,
57 checkpointer: None,
58 event_dispatcher: None,
59 enable_pii_sanitization: true, token_tracking_config: None,
61 max_iterations: NonZeroUsize::new(10).unwrap(),
62 }
63 }
64
65 pub fn with_model(mut self, model: Arc<dyn LanguageModel>) -> Self {
67 let planner: Arc<dyn PlannerHandle> = Arc::new(LlmBackedPlanner::new(model));
68 self.planner = Some(planner);
69 self
70 }
71
72 pub fn with_planner(mut self, planner: Arc<dyn PlannerHandle>) -> Self {
74 self.planner = Some(planner);
75 self
76 }
77
78 pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
107 self.custom_system_prompt = Some(system_prompt.into());
108 self
109 }
110
111 pub fn with_tool(mut self, tool: ToolBox) -> Self {
113 self.tools.push(tool);
114 self
115 }
116
117 pub fn with_tools<I>(mut self, tools: I) -> Self
119 where
120 I: IntoIterator<Item = ToolBox>,
121 {
122 self.tools.extend(tools);
123 self
124 }
125
126 pub fn with_subagent_config<I>(mut self, cfgs: I) -> Self
127 where
128 I: IntoIterator<Item = SubAgentConfig>,
129 {
130 self.subagents.extend(cfgs);
131 self
132 }
133
134 pub fn with_subagent_tools<I>(mut self, tools: I) -> Self
137 where
138 I: IntoIterator<Item = ToolBox>,
139 {
140 for tool in tools {
141 let tool_name = tool.schema().name.clone();
142 let subagent_config = SubAgentConfig::new(
143 format!("{}-agent", tool_name),
144 format!("Specialized agent for {} operations", tool_name),
145 format!(
146 "You are a specialized agent. Use the {} tool to complete tasks efficiently.",
147 tool_name
148 ),
149 )
150 .with_tools(vec![tool]);
151 self.subagents.push(subagent_config);
152 }
153 self
154 }
155
156 pub fn with_summarization(mut self, config: SummarizationConfig) -> Self {
157 self.summarization = Some(config);
158 self
159 }
160
161 pub fn with_tool_interrupt(mut self, tool_name: impl Into<String>, policy: HitlPolicy) -> Self {
162 self.tool_interrupts.insert(tool_name.into(), policy);
163 self
164 }
165
166 pub fn with_builtin_tools<I, S>(mut self, names: I) -> Self
167 where
168 I: IntoIterator<Item = S>,
169 S: Into<String>,
170 {
171 self.builtin_tools = Some(names.into_iter().map(|s| s.into()).collect());
172 self
173 }
174
175 pub fn with_auto_general_purpose(mut self, enabled: bool) -> Self {
176 self.auto_general_purpose = enabled;
177 self
178 }
179
180 pub fn with_prompt_caching(mut self, enabled: bool) -> Self {
181 self.enable_prompt_caching = enabled;
182 self
183 }
184
185 pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
186 self.checkpointer = Some(checkpointer);
187 self
188 }
189
190 pub fn with_event_broadcaster(
197 mut self,
198 broadcaster: Arc<dyn agents_core::events::EventBroadcaster>,
199 ) -> Self {
200 if self.event_dispatcher.is_none() {
202 self.event_dispatcher = Some(Arc::new(agents_core::events::EventDispatcher::new()));
203 }
204
205 if let Some(dispatcher) = &self.event_dispatcher {
207 dispatcher.add_broadcaster(broadcaster);
208 }
209
210 self
211 }
212
213 pub fn with_event_broadcasters(
224 mut self,
225 broadcasters: Vec<Arc<dyn agents_core::events::EventBroadcaster>>,
226 ) -> Self {
227 if self.event_dispatcher.is_none() {
229 self.event_dispatcher = Some(Arc::new(agents_core::events::EventDispatcher::new()));
230 }
231
232 if let Some(dispatcher) = &self.event_dispatcher {
234 for broadcaster in broadcasters {
235 dispatcher.add_broadcaster(broadcaster);
236 }
237 }
238
239 self
240 }
241
242 pub fn with_event_dispatcher(
244 mut self,
245 dispatcher: Arc<agents_core::events::EventDispatcher>,
246 ) -> Self {
247 self.event_dispatcher = Some(dispatcher);
248 self
249 }
250
251 pub fn with_pii_sanitization(mut self, enabled: bool) -> Self {
277 self.enable_pii_sanitization = enabled;
278 self
279 }
280
281 pub fn with_token_tracking(mut self, enabled: bool) -> Self {
308 self.token_tracking_config = Some(TokenTrackingConfig {
309 enabled,
310 emit_events: enabled,
311 log_usage: enabled,
312 custom_costs: None,
313 });
314 self
315 }
316
317 pub fn with_token_tracking_config(mut self, config: TokenTrackingConfig) -> Self {
337 self.token_tracking_config = Some(config);
338 self
339 }
340
341 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
366 self.max_iterations =
367 NonZeroUsize::new(max_iterations).expect("max_iterations must be greater than 0");
368 self
369 }
370
371 pub fn build(self) -> anyhow::Result<DeepAgent> {
372 self.finalize(create_deep_agent_from_config)
373 }
374
375 pub fn build_async(self) -> anyhow::Result<DeepAgent> {
378 self.finalize(create_async_deep_agent_from_config)
379 }
380
381 fn finalize(self, ctor: fn(DeepAgentConfig) -> DeepAgent) -> anyhow::Result<DeepAgent> {
382 let Self {
383 instructions,
384 custom_system_prompt,
385 planner,
386 tools,
387 subagents,
388 summarization,
389 tool_interrupts,
390 builtin_tools,
391 auto_general_purpose,
392 enable_prompt_caching,
393 checkpointer,
394 event_dispatcher,
395 enable_pii_sanitization,
396 token_tracking_config,
397 max_iterations,
398 } = self;
399
400 let planner = planner.unwrap_or_else(|| {
401 let default_model = get_default_model().expect("Failed to get default model");
403 Arc::new(LlmBackedPlanner::new(default_model)) as Arc<dyn PlannerHandle>
404 });
405
406 let final_planner = if let Some(token_config) = token_tracking_config {
408 if token_config.enabled {
409 let planner_any = planner.as_any();
411 if let Some(llm_planner) = planner_any.downcast_ref::<LlmBackedPlanner>() {
412 let model = llm_planner.model().clone();
413 let tracked_model = Arc::new(TokenTrackingMiddleware::new(
414 token_config,
415 model,
416 event_dispatcher.clone(),
417 ));
418 Arc::new(LlmBackedPlanner::new(tracked_model)) as Arc<dyn PlannerHandle>
419 } else {
420 planner
421 }
422 } else {
423 planner
424 }
425 } else {
426 planner
427 };
428
429 let mut cfg = DeepAgentConfig::new(instructions, final_planner)
430 .with_auto_general_purpose(auto_general_purpose)
431 .with_prompt_caching(enable_prompt_caching)
432 .with_pii_sanitization(enable_pii_sanitization)
433 .with_max_iterations(max_iterations.get());
434
435 if let Some(prompt) = custom_system_prompt {
437 cfg = cfg.with_system_prompt(prompt);
438 }
439
440 if let Some(ckpt) = checkpointer {
441 cfg = cfg.with_checkpointer(ckpt);
442 }
443 if let Some(dispatcher) = event_dispatcher {
444 cfg = cfg.with_event_dispatcher(dispatcher);
445 }
446 if let Some(sum) = summarization {
447 cfg = cfg.with_summarization(sum);
448 }
449 if let Some(selected) = builtin_tools {
450 cfg = cfg.with_builtin_tools(selected);
451 }
452 for (name, policy) in tool_interrupts {
453 cfg = cfg.with_tool_interrupt(name, policy);
454 }
455 for tool in tools {
456 cfg = cfg.with_tool(tool);
457 }
458 for sub_cfg in subagents {
459 cfg = cfg.with_subagent_config(sub_cfg);
460 }
461
462 Ok(ctor(cfg))
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_default_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions");
        assert_eq!(builder.max_iterations.get(), 10);
    }

    #[test]
    fn test_builder_custom_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions").with_max_iterations(20);
        assert_eq!(builder.max_iterations.get(), 20);
    }

    #[test]
    #[should_panic(expected = "max_iterations must be greater than 0")]
    fn test_builder_zero_max_iterations_panics() {
        let _builder = ConfigurableAgentBuilder::new("test instructions").with_max_iterations(0);
    }

    #[test]
    fn test_builder_large_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions").with_max_iterations(1000);
        assert_eq!(builder.max_iterations.get(), 1000);
    }

    #[test]
    fn test_builder_chaining_with_max_iterations() {
        let builder = ConfigurableAgentBuilder::new("test instructions")
            .with_max_iterations(15)
            .with_auto_general_purpose(false)
            .with_prompt_caching(true)
            .with_pii_sanitization(false);

        assert_eq!(builder.max_iterations.get(), 15);
        assert!(!builder.auto_general_purpose);
        assert!(builder.enable_prompt_caching);
        assert!(!builder.enable_pii_sanitization);
    }

    #[test]
    fn test_builder_defaults() {
        let builder = ConfigurableAgentBuilder::new("test instructions");
        assert_eq!(builder.instructions, "test instructions");
        assert!(builder.planner.is_none());
        assert!(builder.tools.is_empty());
        assert!(builder.subagents.is_empty());
        assert!(builder.summarization.is_none());
        assert!(builder.tool_interrupts.is_empty());
        assert!(builder.builtin_tools.is_none());
        assert!(builder.auto_general_purpose);
        assert!(!builder.enable_prompt_caching);
        assert!(builder.checkpointer.is_none());
        assert!(builder.event_dispatcher.is_none());
        assert!(builder.enable_pii_sanitization);
        assert!(builder.token_tracking_config.is_none());
    }

    #[test]
    fn test_builder_default_no_custom_system_prompt() {
        let builder = ConfigurableAgentBuilder::new("test instructions");
        assert!(builder.custom_system_prompt.is_none());
    }

    #[test]
    fn test_builder_with_system_prompt() {
        let custom_prompt = "You are a custom assistant.";
        let builder = ConfigurableAgentBuilder::new("ignored").with_system_prompt(custom_prompt);

        assert_eq!(builder.custom_system_prompt.as_deref(), Some(custom_prompt));
    }

    #[test]
    fn test_builder_system_prompt_chaining() {
        let builder = ConfigurableAgentBuilder::new("ignored")
            .with_system_prompt("Custom prompt")
            .with_max_iterations(20)
            .with_pii_sanitization(false);

        assert_eq!(builder.custom_system_prompt.as_deref(), Some("Custom prompt"));
        assert_eq!(builder.max_iterations.get(), 20);
        assert!(!builder.enable_pii_sanitization);
    }

    #[test]
    fn test_builder_builtin_tools_replaces_previous_selection() {
        let builder = ConfigurableAgentBuilder::new("test instructions")
            .with_builtin_tools(vec!["write_todos", "ls"])
            .with_builtin_tools(vec!["read_file"]);

        let selected = builder.builtin_tools.expect("builtin tools were selected");
        assert_eq!(selected.len(), 1);
        assert!(selected.contains("read_file"));
    }

    #[test]
    fn test_builder_token_tracking_flag_enables_everything() {
        let builder = ConfigurableAgentBuilder::new("test instructions").with_token_tracking(true);

        let config = builder.token_tracking_config.expect("token tracking configured");
        assert!(config.enabled);
        assert!(config.emit_events);
        assert!(config.log_usage);
        assert!(config.custom_costs.is_none());
    }

    #[test]
    fn test_builder_token_tracking_flag_disabled_is_recorded() {
        let builder = ConfigurableAgentBuilder::new("test instructions").with_token_tracking(false);

        let config = builder.token_tracking_config.expect("token tracking configured");
        assert!(!config.enabled);
        assert!(!config.emit_events);
        assert!(!config.log_usage);
    }

    #[test]
    fn test_builder_token_tracking_config_overrides_flag() {
        let builder = ConfigurableAgentBuilder::new("test instructions")
            .with_token_tracking(true)
            .with_token_tracking_config(TokenTrackingConfig {
                enabled: true,
                emit_events: false,
                log_usage: true,
                custom_costs: None,
            });

        let config = builder.token_tracking_config.expect("token tracking configured");
        assert!(config.enabled);
        assert!(!config.emit_events);
        assert!(config.log_usage);
    }

    #[test]
    fn test_builder_with_event_dispatcher_keeps_instance() {
        let dispatcher = Arc::new(agents_core::events::EventDispatcher::new());
        let builder =
            ConfigurableAgentBuilder::new("test instructions").with_event_dispatcher(dispatcher.clone());

        let stored = builder.event_dispatcher.expect("dispatcher set");
        assert!(Arc::ptr_eq(&stored, &dispatcher));
    }

    #[test]
    fn test_builder_empty_broadcasters_still_create_dispatcher() {
        let builder =
            ConfigurableAgentBuilder::new("test instructions").with_event_broadcasters(Vec::new());
        assert!(builder.event_dispatcher.is_some());
    }

    #[test]
    fn test_builder_broadcasters_reuse_existing_dispatcher() {
        let dispatcher = Arc::new(agents_core::events::EventDispatcher::new());
        let builder = ConfigurableAgentBuilder::new("test instructions")
            .with_event_dispatcher(dispatcher.clone())
            .with_event_broadcasters(Vec::new());

        let stored = builder.event_dispatcher.expect("dispatcher set");
        assert!(Arc::ptr_eq(&stored, &dispatcher));
    }
}