1use crate::middleware::{token_tracking::TokenTrackingConfig, AgentMiddleware, HitlPolicy};
7use agents_core::agent::PlannerHandle;
8use agents_core::persistence::Checkpointer;
9use agents_core::tools::ToolBox;
10use std::collections::{HashMap, HashSet};
11use std::num::NonZeroUsize;
12use std::sync::Arc;
13
/// Parameter bundle for creating a deep agent.
///
/// Derives `Default`, so any field can be omitted with `..Default::default()`:
/// the vectors and the map start empty, `instructions` starts as an empty string,
/// and the optional fields start as `None`.
///
/// NOTE(review): the consumer of this struct is not visible in this file; the field
/// descriptions below are inferred from names and types — confirm against the caller.
#[derive(Default)]
pub struct CreateDeepAgentParams {
    /// Tools made available to the agent.
    pub tools: Vec<ToolBox>,
    /// Instruction text for the agent.
    pub instructions: String,
    /// Middleware instances, held as shared trait objects.
    pub middleware: Vec<Arc<dyn AgentMiddleware>>,
    /// Optional language model; `None` by default.
    pub model: Option<Arc<dyn agents_core::llm::LanguageModel>>,
    /// Sub-agent configurations.
    pub subagents: Vec<SubAgentConfig>,
    /// Optional context schema, stored as a string.
    pub context_schema: Option<String>,
    /// Optional checkpointer; `None` by default.
    pub checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Per-tool `HitlPolicy` values.
    /// NOTE(review): presumably keyed by tool name, like `DeepAgentConfig::tool_interrupts`
    /// — confirm.
    pub tool_configs: HashMap<String, HitlPolicy>,
}
40
/// Configuration for a deep agent, built with [`DeepAgentConfig::new`] followed by the
/// consuming `with_*` builder methods.
///
/// The "default" values mentioned on the fields below are the ones assigned by `new`.
pub struct DeepAgentConfig {
    /// Base instructions for the agent, supplied to `new`.
    pub instructions: String,
    /// System prompt set by `with_system_prompt`; `None` by default.
    pub custom_system_prompt: Option<String>,
    /// Planner handle supplied to `new`.
    pub planner: Arc<dyn PlannerHandle>,
    /// Tools appended one at a time by `with_tool`; empty by default.
    pub tools: Vec<ToolBox>,
    /// Sub-agent configurations appended by `with_subagent_config` /
    /// `with_subagent_configs`; empty by default.
    pub subagent_configs: Vec<SubAgentConfig>,
    /// Summarization settings set by `with_summarization`; `None` by default.
    pub summarization: Option<SummarizationConfig>,
    /// `HitlPolicy` values keyed by tool name, inserted by `with_tool_interrupt`;
    /// empty by default.
    pub tool_interrupts: HashMap<String, HitlPolicy>,
    /// Built-in tool names set by `with_builtin_tools`; `None` by default.
    /// NOTE(review): presumably `None` means "no restriction" — confirm where this is consumed.
    pub builtin_tools: Option<HashSet<String>>,
    /// `true` by default.
    /// NOTE(review): presumably controls automatic creation of a general-purpose
    /// sub-agent — confirm.
    pub auto_general_purpose: bool,
    /// Whether prompt caching is enabled; `false` by default.
    pub enable_prompt_caching: bool,
    /// Checkpointer set by `with_checkpointer`; `None` by default.
    pub checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Event dispatcher; `None` by default. Set directly by `with_event_dispatcher`, or
    /// created on demand by `with_event_broadcaster`.
    pub event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
    /// Whether PII sanitization is enabled; `true` by default.
    pub enable_pii_sanitization: bool,
    /// Token-tracking settings set by `with_token_tracking_config`; `None` by default.
    pub token_tracking_config: Option<TokenTrackingConfig>,
    /// Iteration limit; 10 by default. Non-zero by type (`with_max_iterations(0)` panics).
    /// NOTE(review): presumably caps the agent's planning/tool loop — confirm.
    pub max_iterations: NonZeroUsize,
}
62
63impl DeepAgentConfig {
64 pub fn new(instructions: impl Into<String>, planner: Arc<dyn PlannerHandle>) -> Self {
65 Self {
66 instructions: instructions.into(),
67 custom_system_prompt: None,
68 planner,
69 tools: Vec::new(),
70 subagent_configs: Vec::new(),
71 summarization: None,
72 tool_interrupts: HashMap::new(),
73 builtin_tools: None,
74 auto_general_purpose: true,
75 enable_prompt_caching: false,
76 checkpointer: None,
77 event_dispatcher: None,
78 enable_pii_sanitization: true, token_tracking_config: None,
80 max_iterations: NonZeroUsize::new(10).unwrap(),
81 }
82 }
83
84 pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
88 self.custom_system_prompt = Some(system_prompt.into());
89 self
90 }
91
92 pub fn with_tool(mut self, tool: ToolBox) -> Self {
93 self.tools.push(tool);
94 self
95 }
96
97 pub fn with_subagent_config(mut self, config: SubAgentConfig) -> Self {
99 self.subagent_configs.push(config);
100 self
101 }
102
103 pub fn with_subagent_configs<I>(mut self, configs: I) -> Self
105 where
106 I: IntoIterator<Item = SubAgentConfig>,
107 {
108 self.subagent_configs.extend(configs);
109 self
110 }
111
112 pub fn with_summarization(mut self, config: SummarizationConfig) -> Self {
113 self.summarization = Some(config);
114 self
115 }
116
117 pub fn with_tool_interrupt(mut self, tool_name: impl Into<String>, policy: HitlPolicy) -> Self {
118 self.tool_interrupts.insert(tool_name.into(), policy);
119 self
120 }
121
122 pub fn with_builtin_tools<I, S>(mut self, names: I) -> Self
126 where
127 I: IntoIterator<Item = S>,
128 S: Into<String>,
129 {
130 let set: HashSet<String> = names.into_iter().map(|s| s.into()).collect();
131 self.builtin_tools = Some(set);
132 self
133 }
134
135 pub fn with_auto_general_purpose(mut self, enabled: bool) -> Self {
138 self.auto_general_purpose = enabled;
139 self
140 }
141
142 pub fn with_prompt_caching(mut self, enabled: bool) -> Self {
145 self.enable_prompt_caching = enabled;
146 self
147 }
148
149 pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
151 self.checkpointer = Some(checkpointer);
152 self
153 }
154
155 pub fn with_event_broadcaster(
157 mut self,
158 broadcaster: Arc<dyn agents_core::events::EventBroadcaster>,
159 ) -> Self {
160 if self.event_dispatcher.is_none() {
161 self.event_dispatcher = Some(Arc::new(agents_core::events::EventDispatcher::new()));
162 }
163 if let Some(dispatcher) = Arc::get_mut(self.event_dispatcher.as_mut().unwrap()) {
164 dispatcher.add_broadcaster(broadcaster);
165 }
166 self
167 }
168
169 pub fn with_event_dispatcher(
171 mut self,
172 dispatcher: Arc<agents_core::events::EventDispatcher>,
173 ) -> Self {
174 self.event_dispatcher = Some(dispatcher);
175 self
176 }
177
178 pub fn with_pii_sanitization(mut self, enabled: bool) -> Self {
187 self.enable_pii_sanitization = enabled;
188 self
189 }
190
191 pub fn with_token_tracking_config(mut self, config: TokenTrackingConfig) -> Self {
193 self.token_tracking_config = Some(config);
194 self
195 }
196
197 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
209 self.max_iterations =
210 NonZeroUsize::new(max_iterations).expect("max_iterations must be greater than 0");
211 self
212 }
213}
214
/// Configuration for a sub-agent, built with [`SubAgentConfig::new`] and the consuming
/// `with_*` methods.
///
/// `new` leaves `model`, `tools`, and `builtin_tools` as `None` and disables prompt caching.
/// NOTE(review): presumably `None` means "fall back to the parent agent's setting" — confirm.
pub struct SubAgentConfig {
    /// Sub-agent name.
    pub name: String,
    /// Description of the sub-agent.
    pub description: String,
    /// Instructions for the sub-agent.
    pub instructions: String,

    /// Language model for this sub-agent, set by `with_model`.
    pub model: Option<Arc<dyn agents_core::llm::LanguageModel>>,
    /// Tool list for this sub-agent, set by `with_tools`.
    pub tools: Option<Vec<ToolBox>>,
    /// Built-in tool names for this sub-agent, set by `with_builtin_tools`.
    pub builtin_tools: Option<HashSet<String>>,
    /// Whether prompt caching is enabled; `false` when created with `new`.
    pub enable_prompt_caching: bool,
}
245
246impl SubAgentConfig {
247 pub fn new(
249 name: impl Into<String>,
250 description: impl Into<String>,
251 instructions: impl Into<String>,
252 ) -> Self {
253 Self {
254 name: name.into(),
255 description: description.into(),
256 instructions: instructions.into(),
257 model: None,
258 tools: None,
259 builtin_tools: None,
260 enable_prompt_caching: false,
261 }
262 }
263
264 pub fn with_model(mut self, model: Arc<dyn agents_core::llm::LanguageModel>) -> Self {
266 self.model = Some(model);
267 self
268 }
269
270 pub fn with_tools(mut self, tools: Vec<ToolBox>) -> Self {
272 self.tools = Some(tools);
273 self
274 }
275
276 pub fn with_builtin_tools(mut self, tools: HashSet<String>) -> Self {
278 self.builtin_tools = Some(tools);
279 self
280 }
281
282 pub fn with_prompt_caching(mut self, enabled: bool) -> Self {
284 self.enable_prompt_caching = enabled;
285 self
286 }
287}
288
289impl IntoIterator for SubAgentConfig {
290 type Item = SubAgentConfig;
291 type IntoIter = std::iter::Once<SubAgentConfig>;
292
293 fn into_iter(self) -> Self::IntoIter {
294 std::iter::once(self)
295 }
296}
297
/// Summarization settings, supplied to a deep agent through
/// `DeepAgentConfig::with_summarization`.
///
/// NOTE(review): the field meanings below are inferred from their names. `messages_to_keep`
/// presumably counts the most recent messages kept verbatim, and `summary_note` is text
/// attached to the generated summary. Confirm against the summarization consumer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SummarizationConfig {
    /// Number of messages to keep (see the note above).
    pub messages_to_keep: usize,
    /// Note text used by summarization (see the note above).
    pub summary_note: String,
}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planner::LlmBackedPlanner;
    use std::sync::Arc;

    /// Builds a planner backed by an OpenAI chat model configured with a dummy key.
    ///
    /// The tests in this module only inspect configuration fields; none of them invoke
    /// the planner.
    fn create_mock_planner() -> Arc<dyn PlannerHandle> {
        use crate::providers::{OpenAiChatModel, OpenAiConfig};
        use agents_core::llm::LanguageModel;

        let config = OpenAiConfig {
            api_key: "test-key".to_string(),
            model: "gpt-4o-mini".to_string(),
            api_url: None,
            custom_headers: Vec::new(),
        };

        let model: Arc<dyn LanguageModel> =
            Arc::new(OpenAiChatModel::new(config).expect("Failed to create test model"));
        Arc::new(LlmBackedPlanner::new(model))
    }

    #[test]
    fn test_config_default_max_iterations() {
        let planner = create_mock_planner();
        let config = DeepAgentConfig::new("test instructions", planner);
        assert_eq!(config.max_iterations.get(), 10);
    }

    #[test]
    fn test_config_custom_max_iterations() {
        let planner = create_mock_planner();
        let config = DeepAgentConfig::new("test instructions", planner).with_max_iterations(25);
        assert_eq!(config.max_iterations.get(), 25);
    }

    #[test]
    fn test_config_chaining_with_max_iterations() {
        let planner = create_mock_planner();
        let config = DeepAgentConfig::new("test instructions", planner)
            .with_max_iterations(30)
            .with_auto_general_purpose(false)
            .with_prompt_caching(true)
            .with_pii_sanitization(false);

        assert_eq!(config.max_iterations.get(), 30);
        assert!(!config.auto_general_purpose);
        assert!(config.enable_prompt_caching);
        assert!(!config.enable_pii_sanitization);
    }

    #[test]
    fn test_config_max_iterations_persists() {
        let planner = create_mock_planner();
        // Builder calls made after `with_max_iterations` must not reset the limit.
        let config = DeepAgentConfig::new("test instructions", planner)
            .with_max_iterations(42)
            .with_system_prompt("custom prompt")
            .with_prompt_caching(true)
            .with_builtin_tools(vec!["read_file"]);

        assert_eq!(config.max_iterations.get(), 42);
    }

    #[test]
    #[should_panic(expected = "max_iterations must be greater than 0")]
    fn test_config_zero_max_iterations_panics() {
        let planner = create_mock_planner();
        let _config = DeepAgentConfig::new("test instructions", planner).with_max_iterations(0);
    }

    #[test]
    fn test_config_max_iterations_with_other_options() {
        let planner = create_mock_planner();

        let config =
            DeepAgentConfig::new("test instructions", planner.clone()).with_max_iterations(5);
        assert_eq!(config.max_iterations.get(), 5);

        let config2 = DeepAgentConfig::new("test instructions", planner.clone())
            .with_prompt_caching(true)
            .with_max_iterations(15);
        assert_eq!(config2.max_iterations.get(), 15);
        assert!(config2.enable_prompt_caching);

        let config3 = DeepAgentConfig::new("test instructions", planner)
            .with_auto_general_purpose(false)
            .with_max_iterations(100)
            .with_pii_sanitization(true);
        assert_eq!(config3.max_iterations.get(), 100);
        assert!(!config3.auto_general_purpose);
        assert!(config3.enable_pii_sanitization);
    }

    #[test]
    fn test_config_new_defaults() {
        let config = DeepAgentConfig::new("test instructions", create_mock_planner());

        assert_eq!(config.instructions, "test instructions");
        assert!(config.custom_system_prompt.is_none());
        assert!(config.tools.is_empty());
        assert!(config.subagent_configs.is_empty());
        assert!(config.summarization.is_none());
        assert!(config.tool_interrupts.is_empty());
        assert!(config.builtin_tools.is_none());
        assert!(config.auto_general_purpose);
        assert!(!config.enable_prompt_caching);
        assert!(config.checkpointer.is_none());
        assert!(config.event_dispatcher.is_none());
        assert!(config.enable_pii_sanitization);
        assert!(config.token_tracking_config.is_none());
    }

    #[test]
    fn test_config_with_system_prompt() {
        let config = DeepAgentConfig::new("test instructions", create_mock_planner())
            .with_system_prompt("custom prompt");

        assert_eq!(config.custom_system_prompt.as_deref(), Some("custom prompt"));
        // Setting a system prompt leaves the base instructions untouched.
        assert_eq!(config.instructions, "test instructions");
    }

    #[test]
    fn test_config_with_builtin_tools_replaces_previous_set() {
        let config = DeepAgentConfig::new("test instructions", create_mock_planner())
            .with_builtin_tools(vec!["read_file", "write_file"])
            .with_builtin_tools(vec![String::from("ls")]);

        let tools = config.builtin_tools.expect("builtin tools should be set");
        assert_eq!(tools.len(), 1);
        assert!(tools.contains("ls"));
    }

    #[test]
    fn test_config_with_summarization() {
        let config = DeepAgentConfig::new("test instructions", create_mock_planner())
            .with_summarization(SummarizationConfig {
                messages_to_keep: 4,
                summary_note: "summary".to_string(),
            });

        let summarization = config.summarization.expect("summarization should be set");
        assert_eq!(summarization.messages_to_keep, 4);
        assert_eq!(summarization.summary_note, "summary");
    }

    #[test]
    fn test_config_collects_subagent_configs_in_order() {
        let config = DeepAgentConfig::new("test instructions", create_mock_planner())
            .with_subagent_config(SubAgentConfig::new("first", "desc", "instr"))
            .with_subagent_configs(vec![
                SubAgentConfig::new("second", "desc", "instr"),
                SubAgentConfig::new("third", "desc", "instr"),
            ])
            // A lone config is accepted too, through `impl IntoIterator for SubAgentConfig`.
            .with_subagent_configs(SubAgentConfig::new("fourth", "desc", "instr"));

        let names: Vec<&str> = config
            .subagent_configs
            .iter()
            .map(|sub| sub.name.as_str())
            .collect();
        assert_eq!(names, ["first", "second", "third", "fourth"]);
    }

    #[test]
    fn test_subagent_config_builders() {
        let mut tools = HashSet::new();
        tools.insert("read_file".to_string());

        let sub = SubAgentConfig::new("researcher", "Finds things", "Be thorough")
            .with_builtin_tools(tools)
            .with_prompt_caching(true);

        assert_eq!(sub.name, "researcher");
        assert_eq!(sub.description, "Finds things");
        assert_eq!(sub.instructions, "Be thorough");
        assert!(sub.model.is_none());
        assert!(sub.tools.is_none());
        assert!(sub.enable_prompt_caching);
        let builtin = sub.builtin_tools.expect("builtin tools should be set");
        assert!(builtin.contains("read_file"));
    }

    #[test]
    fn test_config_with_event_dispatcher_keeps_same_arc() {
        let dispatcher = Arc::new(agents_core::events::EventDispatcher::new());
        let config = DeepAgentConfig::new("test instructions", create_mock_planner())
            .with_event_dispatcher(Arc::clone(&dispatcher));

        let stored = config
            .event_dispatcher
            .as_ref()
            .expect("dispatcher should be set");
        assert!(Arc::ptr_eq(stored, &dispatcher));
    }
}