1use crate::middleware::{token_tracking::TokenTrackingConfig, AgentMiddleware, HitlPolicy};
7use crate::prompts::PromptFormat;
8use agents_core::agent::PlannerHandle;
9use agents_core::persistence::Checkpointer;
10use agents_core::tools::ToolBox;
11use std::collections::{HashMap, HashSet};
12use std::num::NonZeroUsize;
13use std::sync::Arc;
14
/// Parameter bundle describing a deep agent to be created.
///
/// The struct derives [`Default`], so every collection starts empty, the
/// instructions start as an empty string and every optional field starts as
/// `None` unless the caller fills it in.
///
/// NOTE(review): the code that consumes these parameters is not visible here;
/// the per-field notes below describe the stored types only.
#[derive(Default)]
pub struct CreateDeepAgentParams {
    /// Tools supplied for the agent.
    pub tools: Vec<ToolBox>,
    /// Instruction text for the agent.
    pub instructions: String,
    /// Middleware instances, shared behind `Arc`.
    pub middleware: Vec<Arc<dyn AgentMiddleware>>,
    /// Optional language model; `None` leaves the choice to the consumer
    /// (presumably a default model - confirm against the caller).
    pub model: Option<Arc<dyn agents_core::llm::LanguageModel>>,
    /// Sub-agent definitions.
    pub subagents: Vec<SubAgentConfig>,
    /// Optional context schema, stored as a plain string.
    pub context_schema: Option<String>,
    /// Optional checkpointer, shared behind `Arc`.
    pub checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Human-in-the-loop policies keyed by a string
    /// (presumably the tool name - confirm against the caller).
    pub tool_configs: HashMap<String, HitlPolicy>,
}
41
/// Configuration for a deep agent.
///
/// Instances are created with [`DeepAgentConfig::new`] and then refined through
/// the chainable `with_*` methods. The defaults quoted on the fields below are
/// the values assigned by `new`.
pub struct DeepAgentConfig {
    /// Instruction text passed to `new`.
    pub instructions: String,
    /// Custom system prompt; `None` by default.
    pub custom_system_prompt: Option<String>,
    /// Prompt format; defaults to `PromptFormat::default()`.
    pub prompt_format: PromptFormat,
    /// Planner handle passed to `new`.
    pub planner: Arc<dyn PlannerHandle>,
    /// Registered tools; empty by default.
    pub tools: Vec<ToolBox>,
    /// Sub-agent configurations; empty by default.
    pub subagent_configs: Vec<SubAgentConfig>,
    /// Summarization settings; `None` by default.
    pub summarization: Option<SummarizationConfig>,
    /// Human-in-the-loop policies keyed by tool name; empty by default.
    pub tool_interrupts: HashMap<String, HitlPolicy>,
    /// Explicit set of built-in tool names; `None` by default
    /// (presumably meaning "no restriction" - confirm against the consumer).
    pub builtin_tools: Option<HashSet<String>>,
    /// Auto general-purpose flag; `true` by default.
    pub auto_general_purpose: bool,
    /// Prompt-caching flag; `false` by default.
    pub enable_prompt_caching: bool,
    /// Checkpointer; `None` by default.
    pub checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Event dispatcher; `None` by default.
    pub event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
    /// PII-sanitization flag; `true` by default.
    pub enable_pii_sanitization: bool,
    /// Token-tracking settings; `None` by default.
    pub token_tracking_config: Option<TokenTrackingConfig>,
    /// Iteration cap, guaranteed non-zero by its type; `10` by default.
    pub max_iterations: NonZeroUsize,
}
65
66impl DeepAgentConfig {
67 pub fn new(instructions: impl Into<String>, planner: Arc<dyn PlannerHandle>) -> Self {
68 Self {
69 instructions: instructions.into(),
70 custom_system_prompt: None,
71 prompt_format: PromptFormat::default(),
72 planner,
73 tools: Vec::new(),
74 subagent_configs: Vec::new(),
75 summarization: None,
76 tool_interrupts: HashMap::new(),
77 builtin_tools: None,
78 auto_general_purpose: true,
79 enable_prompt_caching: false,
80 checkpointer: None,
81 event_dispatcher: None,
82 enable_pii_sanitization: true, token_tracking_config: None,
84 max_iterations: NonZeroUsize::new(10).unwrap(),
85 }
86 }
87
88 pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
92 self.custom_system_prompt = Some(system_prompt.into());
93 self
94 }
95
96 pub fn with_prompt_format(mut self, format: PromptFormat) -> Self {
103 self.prompt_format = format;
104 self
105 }
106
107 pub fn with_tool(mut self, tool: ToolBox) -> Self {
108 self.tools.push(tool);
109 self
110 }
111
112 pub fn with_subagent_config(mut self, config: SubAgentConfig) -> Self {
114 self.subagent_configs.push(config);
115 self
116 }
117
118 pub fn with_subagent_configs<I>(mut self, configs: I) -> Self
120 where
121 I: IntoIterator<Item = SubAgentConfig>,
122 {
123 self.subagent_configs.extend(configs);
124 self
125 }
126
127 pub fn with_summarization(mut self, config: SummarizationConfig) -> Self {
128 self.summarization = Some(config);
129 self
130 }
131
132 pub fn with_tool_interrupt(mut self, tool_name: impl Into<String>, policy: HitlPolicy) -> Self {
133 self.tool_interrupts.insert(tool_name.into(), policy);
134 self
135 }
136
137 pub fn with_builtin_tools<I, S>(mut self, names: I) -> Self
141 where
142 I: IntoIterator<Item = S>,
143 S: Into<String>,
144 {
145 let set: HashSet<String> = names.into_iter().map(|s| s.into()).collect();
146 self.builtin_tools = Some(set);
147 self
148 }
149
150 pub fn with_auto_general_purpose(mut self, enabled: bool) -> Self {
153 self.auto_general_purpose = enabled;
154 self
155 }
156
157 pub fn with_prompt_caching(mut self, enabled: bool) -> Self {
160 self.enable_prompt_caching = enabled;
161 self
162 }
163
164 pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
166 self.checkpointer = Some(checkpointer);
167 self
168 }
169
170 pub fn with_event_broadcaster(
172 mut self,
173 broadcaster: Arc<dyn agents_core::events::EventBroadcaster>,
174 ) -> Self {
175 if self.event_dispatcher.is_none() {
176 self.event_dispatcher = Some(Arc::new(agents_core::events::EventDispatcher::new()));
177 }
178 if let Some(dispatcher) = Arc::get_mut(self.event_dispatcher.as_mut().unwrap()) {
179 dispatcher.add_broadcaster(broadcaster);
180 }
181 self
182 }
183
184 pub fn with_event_dispatcher(
186 mut self,
187 dispatcher: Arc<agents_core::events::EventDispatcher>,
188 ) -> Self {
189 self.event_dispatcher = Some(dispatcher);
190 self
191 }
192
193 pub fn with_pii_sanitization(mut self, enabled: bool) -> Self {
202 self.enable_pii_sanitization = enabled;
203 self
204 }
205
206 pub fn with_token_tracking_config(mut self, config: TokenTrackingConfig) -> Self {
208 self.token_tracking_config = Some(config);
209 self
210 }
211
212 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
224 self.max_iterations =
225 NonZeroUsize::new(max_iterations).expect("max_iterations must be greater than 0");
226 self
227 }
228}
229
/// Configuration for a single sub-agent.
///
/// Built with [`SubAgentConfig::new`], which takes the three required strings
/// and leaves every optional field unset, then refined with the `with_*`
/// methods.
pub struct SubAgentConfig {
    /// Name of the sub-agent.
    pub name: String,
    /// Description of the sub-agent.
    pub description: String,
    /// Instruction text for the sub-agent.
    pub instructions: String,

    /// Optional language model; `None` by default
    /// (presumably falls back to the parent's model - confirm against the consumer).
    pub model: Option<Arc<dyn agents_core::llm::LanguageModel>>,
    /// Optional tool list; `None` by default.
    pub tools: Option<Vec<ToolBox>>,
    /// Optional set of built-in tool names; `None` by default.
    pub builtin_tools: Option<HashSet<String>>,
    /// Prompt-caching flag; `false` by default.
    pub enable_prompt_caching: bool,
}
260
261impl SubAgentConfig {
262 pub fn new(
264 name: impl Into<String>,
265 description: impl Into<String>,
266 instructions: impl Into<String>,
267 ) -> Self {
268 Self {
269 name: name.into(),
270 description: description.into(),
271 instructions: instructions.into(),
272 model: None,
273 tools: None,
274 builtin_tools: None,
275 enable_prompt_caching: false,
276 }
277 }
278
279 pub fn with_model(mut self, model: Arc<dyn agents_core::llm::LanguageModel>) -> Self {
281 self.model = Some(model);
282 self
283 }
284
285 pub fn with_tools(mut self, tools: Vec<ToolBox>) -> Self {
287 self.tools = Some(tools);
288 self
289 }
290
291 pub fn with_builtin_tools(mut self, tools: HashSet<String>) -> Self {
293 self.builtin_tools = Some(tools);
294 self
295 }
296
297 pub fn with_prompt_caching(mut self, enabled: bool) -> Self {
299 self.enable_prompt_caching = enabled;
300 self
301 }
302}
303
304impl IntoIterator for SubAgentConfig {
305 type Item = SubAgentConfig;
306 type IntoIter = std::iter::Once<SubAgentConfig>;
307
308 fn into_iter(self) -> Self::IntoIter {
309 std::iter::once(self)
310 }
311}
312
/// Settings for conversation summarization.
///
/// Besides `Clone`, the type derives `Debug`, `PartialEq` and `Eq` so values can
/// be logged and compared directly; both fields are plain standard-library
/// types, so the derives add no extra requirements.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SummarizationConfig {
    /// Number of messages to keep
    /// (presumably the most recent ones, left unsummarized - confirm against the consumer).
    pub messages_to_keep: usize,
    /// Note text associated with the summary.
    pub summary_note: String,
}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planner::LlmBackedPlanner;
    use std::sync::Arc;

    /// Builds the planner handle shared by the tests below.
    ///
    /// Despite the name this is not a mock: it wraps an `OpenAiChatModel`
    /// configured with a placeholder API key in an `LlmBackedPlanner`. The tests
    /// only inspect configuration fields and never invoke the planner.
    /// NOTE(review): assumes `OpenAiChatModel::new` performs no network I/O - confirm.
    fn create_mock_planner() -> Arc<dyn PlannerHandle> {
        use crate::providers::{OpenAiChatModel, OpenAiConfig};
        use agents_core::llm::LanguageModel;

        let chat_model = OpenAiChatModel::new(OpenAiConfig {
            api_key: String::from("test-key"),
            model: String::from("gpt-4o-mini"),
            api_url: None,
            custom_headers: Vec::new(),
        })
        .expect("Failed to create test model");

        let model: Arc<dyn LanguageModel> = Arc::new(chat_model);
        Arc::new(LlmBackedPlanner::new(model))
    }

    /// Fresh configuration with the standard test instructions and a new planner.
    fn base_config() -> DeepAgentConfig {
        DeepAgentConfig::new("test instructions", create_mock_planner())
    }

    /// `new` applies the default iteration cap of 10.
    #[test]
    fn test_config_default_max_iterations() {
        assert_eq!(base_config().max_iterations.get(), 10);
    }

    /// `with_max_iterations` overrides the default.
    #[test]
    fn test_config_custom_max_iterations() {
        let configured = base_config().with_max_iterations(25);
        assert_eq!(configured.max_iterations.get(), 25);
    }

    /// The iteration cap survives later calls in a builder chain.
    #[test]
    fn test_config_chaining_with_max_iterations() {
        let configured = base_config()
            .with_max_iterations(30)
            .with_auto_general_purpose(false)
            .with_prompt_caching(true)
            .with_pii_sanitization(false);

        assert_eq!(configured.max_iterations.get(), 30);
        assert!(!configured.auto_general_purpose);
        assert!(configured.enable_prompt_caching);
        assert!(!configured.enable_pii_sanitization);
    }

    /// The configured value is readable from the finished configuration.
    #[test]
    fn test_config_max_iterations_persists() {
        let configured = base_config().with_max_iterations(42);

        assert_eq!(configured.max_iterations.get(), 42);
    }

    /// A cap of zero is rejected with a panic.
    #[test]
    #[should_panic(expected = "max_iterations must be greater than 0")]
    fn test_config_zero_max_iterations_panics() {
        let _config = base_config().with_max_iterations(0);
    }

    /// The cap can be set before, between or after other builder calls.
    #[test]
    fn test_config_max_iterations_with_other_options() {
        // One planner is shared by all three configurations, so it is built once.
        let planner = create_mock_planner();

        let first =
            DeepAgentConfig::new("test instructions", Arc::clone(&planner)).with_max_iterations(5);
        assert_eq!(first.max_iterations.get(), 5);

        let second = DeepAgentConfig::new("test instructions", Arc::clone(&planner))
            .with_prompt_caching(true)
            .with_max_iterations(15);
        assert_eq!(second.max_iterations.get(), 15);
        assert!(second.enable_prompt_caching);

        let third = DeepAgentConfig::new("test instructions", planner)
            .with_auto_general_purpose(false)
            .with_max_iterations(100)
            .with_pii_sanitization(true);
        assert_eq!(third.max_iterations.get(), 100);
        assert!(!third.auto_general_purpose);
        assert!(third.enable_pii_sanitization);
    }
}