1use std::collections::HashMap;
6use std::path::PathBuf;
7use std::time::Duration;
8
9use rust_decimal::Decimal;
10
11use crate::client::messages::DEFAULT_MAX_TOKENS;
12use crate::output_style::OutputStyle;
13use crate::permissions::PermissionPolicy;
14use crate::tools::ToolAccess;
15
16#[derive(Debug, Clone)]
18pub struct AgentModelConfig {
19 pub primary: String,
21 pub small: String,
23 pub max_tokens: u32,
25 pub extended_context: bool,
27}
28
29impl Default for AgentModelConfig {
30 fn default() -> Self {
31 Self {
32 primary: crate::client::DEFAULT_MODEL.to_string(),
33 small: crate::client::DEFAULT_SMALL_MODEL.to_string(),
34 max_tokens: DEFAULT_MAX_TOKENS,
35 extended_context: false,
36 }
37 }
38}
39
40impl AgentModelConfig {
41 pub fn new(primary: impl Into<String>) -> Self {
42 Self {
43 primary: primary.into(),
44 ..Default::default()
45 }
46 }
47
48 pub fn small(mut self, small: impl Into<String>) -> Self {
49 self.small = small.into();
50 self
51 }
52
53 pub fn max_tokens(mut self, tokens: u32) -> Self {
54 self.max_tokens = tokens;
55 self
56 }
57
58 pub fn extended_context(mut self, enabled: bool) -> Self {
59 self.extended_context = enabled;
60 self
61 }
62}
63
64#[derive(Debug, Clone)]
66pub struct ExecutionConfig {
67 pub max_iterations: usize,
69 pub timeout: Option<Duration>,
71 pub chunk_timeout: Duration,
73 pub auto_compact: bool,
75 pub compact_threshold: f32,
77 pub compact_keep_messages: usize,
79}
80
81impl Default for ExecutionConfig {
82 fn default() -> Self {
83 Self {
84 max_iterations: 100,
85 timeout: Some(Duration::from_secs(300)),
86 chunk_timeout: Duration::from_secs(60),
87 auto_compact: true,
88 compact_threshold: crate::session::compact::DEFAULT_COMPACT_THRESHOLD,
89 compact_keep_messages: 4,
90 }
91 }
92}
93
94impl ExecutionConfig {
95 pub fn max_iterations(mut self, max: usize) -> Self {
96 self.max_iterations = max;
97 self
98 }
99
100 pub fn timeout(mut self, timeout: Duration) -> Self {
101 self.timeout = Some(timeout);
102 self
103 }
104
105 pub fn without_timeout(mut self) -> Self {
106 self.timeout = None;
107 self
108 }
109
110 pub fn chunk_timeout(mut self, timeout: Duration) -> Self {
111 self.chunk_timeout = timeout;
112 self
113 }
114
115 pub fn auto_compact(mut self, enabled: bool) -> Self {
116 self.auto_compact = enabled;
117 self
118 }
119
120 pub fn compact_threshold(mut self, threshold: f32) -> Self {
121 self.compact_threshold = threshold.clamp(0.0, 1.0);
122 self
123 }
124
125 pub fn compact_keep_messages(mut self, count: usize) -> Self {
126 self.compact_keep_messages = count;
127 self
128 }
129}
130
131#[derive(Debug, Clone, Default)]
133pub struct SecurityConfig {
134 pub permission_policy: PermissionPolicy,
136 pub tool_access: ToolAccess,
138 pub env: HashMap<String, String>,
140}
141
142impl SecurityConfig {
143 pub fn permissive() -> Self {
144 Self {
145 permission_policy: PermissionPolicy::permissive(),
146 tool_access: ToolAccess::All,
147 ..Default::default()
148 }
149 }
150
151 pub fn read_only() -> Self {
152 Self {
153 permission_policy: PermissionPolicy::read_only(),
154 tool_access: ToolAccess::only(["Read", "Glob", "Grep", "Task", "TaskOutput"]),
155 ..Default::default()
156 }
157 }
158
159 pub fn permission_policy(mut self, policy: PermissionPolicy) -> Self {
160 self.permission_policy = policy;
161 self
162 }
163
164 pub fn tool_access(mut self, access: ToolAccess) -> Self {
165 self.tool_access = access;
166 self
167 }
168
169 pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
170 self.env.insert(key.into(), value.into());
171 self
172 }
173
174 pub fn envs(
175 mut self,
176 vars: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
177 ) -> Self {
178 for (k, v) in vars {
179 self.env.insert(k.into(), v.into());
180 }
181 self
182 }
183}
184
185#[derive(Debug, Clone, Default)]
187pub struct BudgetConfig {
188 pub max_cost_usd: Option<Decimal>,
190 pub tenant_id: Option<String>,
192 pub fallback_model: Option<String>,
194}
195
196impl BudgetConfig {
197 pub fn unlimited() -> Self {
198 Self::default()
199 }
200
201 pub fn max_cost(mut self, usd: Decimal) -> Self {
202 self.max_cost_usd = Some(usd);
203 self
204 }
205
206 pub fn tenant(mut self, tenant_id: impl Into<String>) -> Self {
207 self.tenant_id = Some(tenant_id.into());
208 self
209 }
210
211 pub fn fallback(mut self, model: impl Into<String>) -> Self {
212 self.fallback_model = Some(model.into());
213 self
214 }
215}
216
217#[derive(Debug, Clone, Default)]
219pub struct PromptConfig {
220 pub system_prompt: Option<String>,
222 pub system_prompt_mode: SystemPromptMode,
224 pub output_style: Option<OutputStyle>,
226 pub output_schema: Option<serde_json::Value>,
228}
229
/// How a configured system prompt is applied.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SystemPromptMode {
    /// Replace mode; this is the default.
    #[default]
    Replace,
    /// Append mode, selected by `PromptConfig::append_mode`.
    Append,
}
238
239impl PromptConfig {
240 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
241 self.system_prompt = Some(prompt.into());
242 self
243 }
244
245 pub fn append_mode(mut self) -> Self {
246 self.system_prompt_mode = SystemPromptMode::Append;
247 self
248 }
249
250 pub fn output_style(mut self, style: OutputStyle) -> Self {
251 self.output_style = Some(style);
252 self
253 }
254
255 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
256 self.output_schema = Some(schema);
257 self
258 }
259
260 pub fn structured_output<T: schemars::JsonSchema>(mut self) -> Self {
261 let schema = schemars::schema_for!(T);
262 self.output_schema = serde_json::to_value(schema).ok();
263 self
264 }
265}
266
/// Selects which parts of a request are cached.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CacheStrategy {
    /// Nothing is cached.
    Disabled,
    /// Only the system portion is cached.
    SystemOnly,
    /// Only messages are cached.
    MessagesOnly,
    /// Both the system portion and messages are cached; this is the default.
    #[default]
    Full,
}

impl CacheStrategy {
    /// `true` for `SystemOnly` and `Full`.
    pub fn cache_system(&self) -> bool {
        match self {
            Self::SystemOnly | Self::Full => true,
            Self::Disabled | Self::MessagesOnly => false,
        }
    }

    /// `true` for `MessagesOnly` and `Full`.
    pub fn cache_messages(&self) -> bool {
        match self {
            Self::MessagesOnly | Self::Full => true,
            Self::Disabled | Self::SystemOnly => false,
        }
    }

    /// `true` for every strategy except `Disabled`.
    pub fn is_enabled(&self) -> bool {
        *self != Self::Disabled
    }
}
300
301#[derive(Debug, Clone)]
308pub struct CacheConfig {
309 pub strategy: CacheStrategy,
311 pub static_ttl: crate::types::CacheTtl,
313 pub message_ttl: crate::types::CacheTtl,
315}
316
317impl Default for CacheConfig {
318 fn default() -> Self {
319 Self {
320 strategy: CacheStrategy::Full,
321 static_ttl: crate::types::CacheTtl::OneHour,
322 message_ttl: crate::types::CacheTtl::FiveMinutes,
323 }
324 }
325}
326
327impl CacheConfig {
328 pub fn disabled() -> Self {
330 Self {
331 strategy: CacheStrategy::Disabled,
332 ..Default::default()
333 }
334 }
335
336 pub fn system_only() -> Self {
338 Self {
339 strategy: CacheStrategy::SystemOnly,
340 ..Default::default()
341 }
342 }
343
344 pub fn messages_only() -> Self {
346 Self {
347 strategy: CacheStrategy::MessagesOnly,
348 ..Default::default()
349 }
350 }
351
352 pub fn strategy(mut self, strategy: CacheStrategy) -> Self {
354 self.strategy = strategy;
355 self
356 }
357
358 pub fn static_ttl(mut self, ttl: crate::types::CacheTtl) -> Self {
360 self.static_ttl = ttl;
361 self
362 }
363
364 pub fn message_ttl(mut self, ttl: crate::types::CacheTtl) -> Self {
366 self.message_ttl = ttl;
367 self
368 }
369
370 pub fn message_ttl_option(&self) -> Option<crate::types::CacheTtl> {
375 if self.strategy.cache_messages() {
376 Some(self.message_ttl)
377 } else {
378 None
379 }
380 }
381}
382
383#[derive(Debug, Clone, Default)]
388pub struct ServerToolsConfig {
389 pub web_search: Option<crate::types::WebSearchTool>,
390 pub web_fetch: Option<crate::types::WebFetchTool>,
391}
392
393impl ServerToolsConfig {
394 pub fn all() -> Self {
395 Self {
396 web_search: Some(crate::types::WebSearchTool::default()),
397 web_fetch: Some(crate::types::WebFetchTool::default()),
398 }
399 }
400
401 pub fn web_search(mut self, config: crate::types::WebSearchTool) -> Self {
402 self.web_search = Some(config);
403 self
404 }
405
406 pub fn web_fetch(mut self, config: crate::types::WebFetchTool) -> Self {
407 self.web_fetch = Some(config);
408 self
409 }
410}
411
412#[derive(Debug, Clone, Default)]
414pub struct AgentConfig {
415 pub model: AgentModelConfig,
416 pub execution: ExecutionConfig,
417 pub security: SecurityConfig,
418 pub budget: BudgetConfig,
419 pub prompt: PromptConfig,
420 pub cache: CacheConfig,
421 pub working_dir: Option<PathBuf>,
422 pub server_tools: ServerToolsConfig,
423 pub coding_mode: bool,
424}
425
426impl AgentConfig {
427 pub fn new() -> Self {
428 Self::default()
429 }
430
431 pub fn model(mut self, config: AgentModelConfig) -> Self {
432 self.model = config;
433 self
434 }
435
436 pub fn execution(mut self, config: ExecutionConfig) -> Self {
437 self.execution = config;
438 self
439 }
440
441 pub fn security(mut self, config: SecurityConfig) -> Self {
442 self.security = config;
443 self
444 }
445
446 pub fn budget(mut self, config: BudgetConfig) -> Self {
447 self.budget = config;
448 self
449 }
450
451 pub fn prompt(mut self, config: PromptConfig) -> Self {
452 self.prompt = config;
453 self
454 }
455
456 pub fn cache(mut self, config: CacheConfig) -> Self {
457 self.cache = config;
458 self
459 }
460
461 pub fn working_dir(mut self, dir: impl Into<PathBuf>) -> Self {
462 self.working_dir = Some(dir.into());
463 self
464 }
465
466 pub fn server_tools(mut self, config: ServerToolsConfig) -> Self {
467 self.server_tools = config;
468 self
469 }
470
471 pub fn coding_mode(mut self, enabled: bool) -> Self {
472 self.coding_mode = enabled;
473 self
474 }
475}
476
#[cfg(test)]
mod tests {
    use rust_decimal_macros::dec;

    use super::*;

    #[test]
    fn test_model_config() {
        let config = AgentModelConfig::new("claude-opus-4-6")
            .small("claude-haiku")
            .max_tokens(4096);

        assert_eq!(config.primary, "claude-opus-4-6");
        assert_eq!(config.small, "claude-haiku");
        assert_eq!(config.max_tokens, 4096);
    }

    #[test]
    fn test_execution_config() {
        let config = ExecutionConfig::default()
            .max_iterations(50)
            .timeout(Duration::from_secs(600))
            .auto_compact(false);

        assert_eq!(config.max_iterations, 50);
        assert_eq!(config.timeout, Some(Duration::from_secs(600)));
        assert!(!config.auto_compact);
    }

    #[test]
    fn test_execution_without_timeout() {
        let config = ExecutionConfig::default()
            .timeout(Duration::from_secs(1))
            .without_timeout();

        assert_eq!(config.timeout, None);
    }

    #[test]
    fn test_compact_threshold_is_clamped() {
        // Out-of-range values are pinned to the nearest bound; in-range values pass through.
        let high = ExecutionConfig::default().compact_threshold(1.5);
        assert_eq!(high.compact_threshold, 1.0);

        let low = ExecutionConfig::default().compact_threshold(-0.25);
        assert_eq!(low.compact_threshold, 0.0);

        let mid = ExecutionConfig::default().compact_threshold(0.5);
        assert_eq!(mid.compact_threshold, 0.5);
    }

    #[test]
    fn test_security_config() {
        let config = SecurityConfig::permissive().env("API_KEY", "secret");

        assert_eq!(config.env.get("API_KEY"), Some(&"secret".to_string()));
    }

    #[test]
    fn test_security_envs_later_values_win() {
        let config = SecurityConfig::default()
            .env("A", "old")
            .envs(vec![("A", "new"), ("B", "2")]);

        assert_eq!(config.env.len(), 2);
        assert_eq!(config.env.get("A"), Some(&"new".to_string()));
        assert_eq!(config.env.get("B"), Some(&"2".to_string()));
    }

    #[test]
    fn test_budget_config() {
        let config = BudgetConfig::unlimited()
            .max_cost(dec!(10))
            .tenant("org-123")
            .fallback("claude-haiku");

        assert_eq!(config.max_cost_usd, Some(dec!(10)));
        assert_eq!(config.tenant_id, Some("org-123".to_string()));
        assert_eq!(config.fallback_model, Some("claude-haiku".to_string()));
    }

    #[test]
    fn test_prompt_config_modes() {
        let base = PromptConfig::default();
        assert_eq!(base.system_prompt_mode, SystemPromptMode::Replace);
        assert!(base.system_prompt.is_none());

        let config = PromptConfig::default()
            .system_prompt("be brief")
            .append_mode();
        assert_eq!(config.system_prompt.as_deref(), Some("be brief"));
        assert_eq!(config.system_prompt_mode, SystemPromptMode::Append);
    }

    #[test]
    fn test_agent_config() {
        let config = AgentConfig::new()
            .model(AgentModelConfig::new("claude-opus-4-6"))
            .budget(BudgetConfig::unlimited().max_cost(dec!(5)))
            .working_dir("/project");

        assert_eq!(config.model.primary, "claude-opus-4-6");
        assert_eq!(config.budget.max_cost_usd, Some(dec!(5)));
        assert_eq!(config.working_dir, Some(PathBuf::from("/project")));
    }

    #[test]
    fn test_agent_config_defaults_and_coding_mode() {
        let config = AgentConfig::new();
        assert!(!config.coding_mode);
        assert!(config.working_dir.is_none());

        assert!(AgentConfig::new().coding_mode(true).coding_mode);
    }

    #[test]
    fn test_server_tools_config() {
        let none = ServerToolsConfig::default();
        assert!(none.web_search.is_none());
        assert!(none.web_fetch.is_none());

        let all = ServerToolsConfig::all();
        assert!(all.web_search.is_some());
        assert!(all.web_fetch.is_some());
    }

    #[test]
    fn test_cache_strategy_default_is_full() {
        let config = CacheConfig::default();
        assert_eq!(config.strategy, CacheStrategy::Full);
        assert_eq!(config.static_ttl, crate::types::CacheTtl::OneHour);
        assert_eq!(config.message_ttl, crate::types::CacheTtl::FiveMinutes);
    }

    #[test]
    fn test_cache_strategy_disabled() {
        let config = CacheConfig::disabled();
        assert_eq!(config.strategy, CacheStrategy::Disabled);
        assert!(!config.strategy.is_enabled());
        assert!(!config.strategy.cache_system());
        assert!(!config.strategy.cache_messages());
    }

    #[test]
    fn test_cache_strategy_system_only() {
        let config = CacheConfig::system_only();
        assert_eq!(config.strategy, CacheStrategy::SystemOnly);
        assert!(config.strategy.is_enabled());
        assert!(config.strategy.cache_system());
        assert!(!config.strategy.cache_messages());
    }

    #[test]
    fn test_cache_strategy_messages_only() {
        let config = CacheConfig::messages_only();
        assert_eq!(config.strategy, CacheStrategy::MessagesOnly);
        assert!(config.strategy.is_enabled());
        assert!(!config.strategy.cache_system());
        assert!(config.strategy.cache_messages());
    }

    #[test]
    fn test_cache_strategy_full() {
        let config = CacheConfig::default();
        assert!(config.strategy.is_enabled());
        assert!(config.strategy.cache_system());
        assert!(config.strategy.cache_messages());
    }

    #[test]
    fn test_cache_config_with_ttl() {
        let config = CacheConfig::default()
            .static_ttl(crate::types::CacheTtl::FiveMinutes)
            .message_ttl(crate::types::CacheTtl::OneHour);

        assert_eq!(config.static_ttl, crate::types::CacheTtl::FiveMinutes);
        assert_eq!(config.message_ttl, crate::types::CacheTtl::OneHour);
    }

    #[test]
    fn test_message_ttl_option_follows_strategy() {
        // The message TTL is only reported when the strategy caches messages.
        assert_eq!(
            CacheConfig::default().message_ttl_option(),
            Some(crate::types::CacheTtl::FiveMinutes)
        );
        assert_eq!(
            CacheConfig::messages_only().message_ttl_option(),
            Some(crate::types::CacheTtl::FiveMinutes)
        );
        assert_eq!(CacheConfig::system_only().message_ttl_option(), None);
        assert_eq!(CacheConfig::disabled().message_ttl_option(), None);
    }
}