1use crate::middleware::{token_tracking::TokenTrackingConfig, AgentMiddleware, HitlPolicy};
7use agents_core::agent::PlannerHandle;
8use agents_core::persistence::Checkpointer;
9use agents_core::tools::ToolBox;
10use std::collections::{HashMap, HashSet};
11use std::num::NonZeroUsize;
12use std::sync::Arc;
13
/// Bundle of inputs for creating a deep agent (purpose inferred from the type
/// name; the consumer of this struct is not visible in this file).
///
/// Derives [`Default`]: every collection starts empty, `instructions` is an
/// empty string and every optional field is `None`.
#[derive(Default)]
pub struct CreateDeepAgentParams {
    /// Tools supplied by the caller.
    pub tools: Vec<ToolBox>,
    /// Instruction text for the agent.
    pub instructions: String,
    /// Middleware instances, kept in the order they were added to the vector.
    pub middleware: Vec<Arc<dyn AgentMiddleware>>,
    /// Language model to use; `None` when the caller does not supply one.
    pub model: Option<Arc<dyn agents_core::llm::LanguageModel>>,
    /// Sub-agent definitions.
    pub subagents: Vec<SubAgentConfig>,
    /// Optional context schema carried as a plain string.
    /// NOTE(review): the expected format of this string is not visible here —
    /// confirm with the code that consumes it.
    pub context_schema: Option<String>,
    /// Optional checkpointer implementation.
    pub checkpointer: Option<Arc<dyn Checkpointer>>,
    /// [`HitlPolicy`] entries keyed by string.
    /// NOTE(review): keys are presumably tool names — confirm with the consumer.
    pub tool_configs: HashMap<String, HitlPolicy>,
}
40
/// Configuration for a deep agent.
///
/// `DeepAgentConfig::new` takes the instructions and the planner and fills
/// every other field with a default; the `with_*` builder methods override
/// individual fields. The defaults quoted on each field are the values `new`
/// assigns.
pub struct DeepAgentConfig {
    /// Instruction text for the agent.
    pub instructions: String,
    /// Planner handle supplied to `new`.
    pub planner: Arc<dyn PlannerHandle>,
    /// Tools registered on the agent. Default: empty.
    pub tools: Vec<ToolBox>,
    /// Sub-agent definitions. Default: empty.
    pub subagent_configs: Vec<SubAgentConfig>,
    /// Summarization settings. Default: `None`.
    pub summarization: Option<SummarizationConfig>,
    /// [`HitlPolicy`] per tool name, as inserted by `with_tool_interrupt`.
    /// Default: empty.
    pub tool_interrupts: HashMap<String, HitlPolicy>,
    /// Set of built-in tool names. Default: `None`.
    /// NOTE(review): `None` presumably means "no restriction on built-in
    /// tools" — confirm with the runtime that reads this field.
    pub builtin_tools: Option<HashSet<String>>,
    /// Flag for the automatic general-purpose sub-agent. Default: `true`.
    /// NOTE(review): exact effect is decided by the runtime, not visible here.
    pub auto_general_purpose: bool,
    /// Prompt-caching flag. Default: `false`.
    pub enable_prompt_caching: bool,
    /// Optional checkpointer implementation. Default: `None`.
    pub checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Optional event dispatcher. Default: `None`.
    pub event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
    /// PII-sanitization flag. Default: `true`.
    pub enable_pii_sanitization: bool,
    /// Optional token-tracking settings. Default: `None`.
    pub token_tracking_config: Option<TokenTrackingConfig>,
    /// Iteration limit; the type guarantees it is never zero. Default: 10.
    pub max_iterations: NonZeroUsize,
}
60
61impl DeepAgentConfig {
62 pub fn new(instructions: impl Into<String>, planner: Arc<dyn PlannerHandle>) -> Self {
63 Self {
64 instructions: instructions.into(),
65 planner,
66 tools: Vec::new(),
67 subagent_configs: Vec::new(),
68 summarization: None,
69 tool_interrupts: HashMap::new(),
70 builtin_tools: None,
71 auto_general_purpose: true,
72 enable_prompt_caching: false,
73 checkpointer: None,
74 event_dispatcher: None,
75 enable_pii_sanitization: true, token_tracking_config: None,
77 max_iterations: NonZeroUsize::new(10).unwrap(),
78 }
79 }
80
81 pub fn with_tool(mut self, tool: ToolBox) -> Self {
82 self.tools.push(tool);
83 self
84 }
85
86 pub fn with_subagent_config(mut self, config: SubAgentConfig) -> Self {
88 self.subagent_configs.push(config);
89 self
90 }
91
92 pub fn with_subagent_configs<I>(mut self, configs: I) -> Self
94 where
95 I: IntoIterator<Item = SubAgentConfig>,
96 {
97 self.subagent_configs.extend(configs);
98 self
99 }
100
101 pub fn with_summarization(mut self, config: SummarizationConfig) -> Self {
102 self.summarization = Some(config);
103 self
104 }
105
106 pub fn with_tool_interrupt(mut self, tool_name: impl Into<String>, policy: HitlPolicy) -> Self {
107 self.tool_interrupts.insert(tool_name.into(), policy);
108 self
109 }
110
111 pub fn with_builtin_tools<I, S>(mut self, names: I) -> Self
115 where
116 I: IntoIterator<Item = S>,
117 S: Into<String>,
118 {
119 let set: HashSet<String> = names.into_iter().map(|s| s.into()).collect();
120 self.builtin_tools = Some(set);
121 self
122 }
123
124 pub fn with_auto_general_purpose(mut self, enabled: bool) -> Self {
127 self.auto_general_purpose = enabled;
128 self
129 }
130
131 pub fn with_prompt_caching(mut self, enabled: bool) -> Self {
134 self.enable_prompt_caching = enabled;
135 self
136 }
137
138 pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
140 self.checkpointer = Some(checkpointer);
141 self
142 }
143
144 pub fn with_event_broadcaster(
146 mut self,
147 broadcaster: Arc<dyn agents_core::events::EventBroadcaster>,
148 ) -> Self {
149 if self.event_dispatcher.is_none() {
150 self.event_dispatcher = Some(Arc::new(agents_core::events::EventDispatcher::new()));
151 }
152 if let Some(dispatcher) = Arc::get_mut(self.event_dispatcher.as_mut().unwrap()) {
153 dispatcher.add_broadcaster(broadcaster);
154 }
155 self
156 }
157
158 pub fn with_event_dispatcher(
160 mut self,
161 dispatcher: Arc<agents_core::events::EventDispatcher>,
162 ) -> Self {
163 self.event_dispatcher = Some(dispatcher);
164 self
165 }
166
167 pub fn with_pii_sanitization(mut self, enabled: bool) -> Self {
176 self.enable_pii_sanitization = enabled;
177 self
178 }
179
180 pub fn with_token_tracking_config(mut self, config: TokenTrackingConfig) -> Self {
182 self.token_tracking_config = Some(config);
183 self
184 }
185
186 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
198 self.max_iterations =
199 NonZeroUsize::new(max_iterations).expect("max_iterations must be greater than 0");
200 self
201 }
202}
203
/// Definition of a sub-agent.
///
/// `SubAgentConfig::new` fills the three text fields and leaves every
/// optional field as `None` with prompt caching off.
pub struct SubAgentConfig {
    /// Name of the sub-agent.
    pub name: String,
    /// Description of the sub-agent.
    /// NOTE(review): presumably shown to the parent agent when it chooses a
    /// sub-agent to delegate to — confirm with the runtime.
    pub description: String,
    /// Instruction text for the sub-agent.
    pub instructions: String,

    /// Language model for this sub-agent; `None` unless set explicitly.
    /// NOTE(review): the fallback used for `None` is decided elsewhere.
    pub model: Option<Arc<dyn agents_core::llm::LanguageModel>>,
    /// Tools for this sub-agent; `None` unless set explicitly.
    pub tools: Option<Vec<ToolBox>>,
    /// Built-in tool names for this sub-agent; `None` unless set explicitly.
    pub builtin_tools: Option<HashSet<String>>,
    /// Prompt-caching flag; `false` unless set explicitly.
    pub enable_prompt_caching: bool,
}
234
235impl SubAgentConfig {
236 pub fn new(
238 name: impl Into<String>,
239 description: impl Into<String>,
240 instructions: impl Into<String>,
241 ) -> Self {
242 Self {
243 name: name.into(),
244 description: description.into(),
245 instructions: instructions.into(),
246 model: None,
247 tools: None,
248 builtin_tools: None,
249 enable_prompt_caching: false,
250 }
251 }
252
253 pub fn with_model(mut self, model: Arc<dyn agents_core::llm::LanguageModel>) -> Self {
255 self.model = Some(model);
256 self
257 }
258
259 pub fn with_tools(mut self, tools: Vec<ToolBox>) -> Self {
261 self.tools = Some(tools);
262 self
263 }
264
265 pub fn with_builtin_tools(mut self, tools: HashSet<String>) -> Self {
267 self.builtin_tools = Some(tools);
268 self
269 }
270
271 pub fn with_prompt_caching(mut self, enabled: bool) -> Self {
273 self.enable_prompt_caching = enabled;
274 self
275 }
276}
277
278impl IntoIterator for SubAgentConfig {
279 type Item = SubAgentConfig;
280 type IntoIter = std::iter::Once<SubAgentConfig>;
281
282 fn into_iter(self) -> Self::IntoIter {
283 std::iter::once(self)
284 }
285}
286
/// Settings for conversation summarization.
///
/// Plain data, so it derives `Debug`, `PartialEq` and `Eq` alongside `Clone`
/// to make it printable in diagnostics and comparable in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SummarizationConfig {
    /// Number of messages to keep.
    /// NOTE(review): presumably the most recent messages left unsummarized —
    /// confirm with the summarization middleware.
    pub messages_to_keep: usize,
    /// Note text associated with the summary.
    /// NOTE(review): presumably inserted where older messages were removed —
    /// confirm with the summarization middleware.
    pub summary_note: String,
}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planner::LlmBackedPlanner;
    use std::sync::Arc;

    // Instruction text shared by every configuration built in these tests.
    const INSTRUCTIONS: &str = "test instructions";

    // Builds a planner backed by an `OpenAiChatModel` configured with a
    // placeholder API key; the tests only inspect configuration fields.
    fn create_mock_planner() -> Arc<dyn PlannerHandle> {
        use crate::providers::{OpenAiChatModel, OpenAiConfig};
        use agents_core::llm::LanguageModel;

        let chat_model = OpenAiChatModel::new(OpenAiConfig {
            api_key: "test-key".to_string(),
            model: "gpt-4o-mini".to_string(),
            api_url: None,
            custom_headers: Vec::new(),
        })
        .expect("Failed to create test model");

        let model: Arc<dyn LanguageModel> = Arc::new(chat_model);
        Arc::new(LlmBackedPlanner::new(model))
    }

    // Fresh configuration with a newly built planner and all defaults.
    fn default_config() -> DeepAgentConfig {
        DeepAgentConfig::new(INSTRUCTIONS, create_mock_planner())
    }

    #[test]
    fn test_config_default_max_iterations() {
        let limit = default_config().max_iterations;
        assert_eq!(limit.get(), 10);
    }

    #[test]
    fn test_config_custom_max_iterations() {
        let limit = default_config().with_max_iterations(25).max_iterations;
        assert_eq!(limit.get(), 25);
    }

    #[test]
    fn test_config_chaining_with_max_iterations() {
        let built = default_config()
            .with_max_iterations(30)
            .with_auto_general_purpose(false)
            .with_prompt_caching(true)
            .with_pii_sanitization(false);

        assert_eq!(built.max_iterations.get(), 30);
        assert!(!built.auto_general_purpose);
        assert!(built.enable_prompt_caching);
        assert!(!built.enable_pii_sanitization);
    }

    #[test]
    fn test_config_max_iterations_persists() {
        let built = default_config().with_max_iterations(42);

        assert_eq!(built.max_iterations.get(), 42);
    }

    #[test]
    #[should_panic(expected = "max_iterations must be greater than 0")]
    fn test_config_zero_max_iterations_panics() {
        let _config = default_config().with_max_iterations(0);
    }

    #[test]
    fn test_config_max_iterations_with_other_options() {
        // One planner shared by all three configurations, as in the original.
        let planner = create_mock_planner();

        let only_limit =
            DeepAgentConfig::new(INSTRUCTIONS, Arc::clone(&planner)).with_max_iterations(5);
        assert_eq!(only_limit.max_iterations.get(), 5);

        let limit_after_caching = DeepAgentConfig::new(INSTRUCTIONS, Arc::clone(&planner))
            .with_prompt_caching(true)
            .with_max_iterations(15);
        assert_eq!(limit_after_caching.max_iterations.get(), 15);
        assert!(limit_after_caching.enable_prompt_caching);

        let limit_in_middle = DeepAgentConfig::new(INSTRUCTIONS, planner)
            .with_auto_general_purpose(false)
            .with_max_iterations(100)
            .with_pii_sanitization(true);
        assert_eq!(limit_in_middle.max_iterations.get(), 100);
        assert!(!limit_in_middle.auto_general_purpose);
        assert!(limit_in_middle.enable_pii_sanitization);
    }
}