use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use rust_decimal::Decimal;
use crate::client::messages::DEFAULT_MAX_TOKENS;
use crate::output_style::OutputStyle;
use crate::permissions::PermissionPolicy;
use crate::tools::ToolAccess;
#[derive(Debug, Clone)]
pub struct AgentModelConfig {
pub primary: String,
pub small: String,
pub max_tokens: u32,
pub extended_context: bool,
}
impl Default for AgentModelConfig {
fn default() -> Self {
Self {
primary: crate::client::DEFAULT_MODEL.to_string(),
small: crate::client::DEFAULT_SMALL_MODEL.to_string(),
max_tokens: DEFAULT_MAX_TOKENS,
extended_context: false,
}
}
}
impl AgentModelConfig {
pub fn new(primary: impl Into<String>) -> Self {
Self {
primary: primary.into(),
..Default::default()
}
}
pub fn small(mut self, small: impl Into<String>) -> Self {
self.small = small.into();
self
}
pub fn max_tokens(mut self, tokens: u32) -> Self {
self.max_tokens = tokens;
self
}
pub fn extended_context(mut self, enabled: bool) -> Self {
self.extended_context = enabled;
self
}
}
#[derive(Debug, Clone)]
pub struct ExecutionConfig {
pub max_iterations: usize,
pub timeout: Option<Duration>,
pub chunk_timeout: Duration,
pub auto_compact: bool,
pub compact_threshold: f32,
pub compact_keep_messages: usize,
}
impl Default for ExecutionConfig {
fn default() -> Self {
Self {
max_iterations: 100,
timeout: Some(Duration::from_secs(300)),
chunk_timeout: Duration::from_secs(60),
auto_compact: true,
compact_threshold: crate::session::compact::DEFAULT_COMPACT_THRESHOLD,
compact_keep_messages: 4,
}
}
}
impl ExecutionConfig {
pub fn max_iterations(mut self, max: usize) -> Self {
self.max_iterations = max;
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn without_timeout(mut self) -> Self {
self.timeout = None;
self
}
pub fn chunk_timeout(mut self, timeout: Duration) -> Self {
self.chunk_timeout = timeout;
self
}
pub fn auto_compact(mut self, enabled: bool) -> Self {
self.auto_compact = enabled;
self
}
pub fn compact_threshold(mut self, threshold: f32) -> Self {
self.compact_threshold = threshold.clamp(0.0, 1.0);
self
}
pub fn compact_keep_messages(mut self, count: usize) -> Self {
self.compact_keep_messages = count;
self
}
}
#[derive(Debug, Clone, Default)]
pub struct SecurityConfig {
pub permission_policy: PermissionPolicy,
pub tool_access: ToolAccess,
pub env: HashMap<String, String>,
}
impl SecurityConfig {
pub fn permissive() -> Self {
Self {
permission_policy: PermissionPolicy::permissive(),
tool_access: ToolAccess::All,
..Default::default()
}
}
pub fn read_only() -> Self {
Self {
permission_policy: PermissionPolicy::read_only(),
tool_access: ToolAccess::only(["Read", "Glob", "Grep", "Task", "TaskOutput"]),
..Default::default()
}
}
pub fn permission_policy(mut self, policy: PermissionPolicy) -> Self {
self.permission_policy = policy;
self
}
pub fn tool_access(mut self, access: ToolAccess) -> Self {
self.tool_access = access;
self
}
pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.env.insert(key.into(), value.into());
self
}
pub fn envs(
mut self,
vars: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
) -> Self {
for (k, v) in vars {
self.env.insert(k.into(), v.into());
}
self
}
}
#[derive(Debug, Clone, Default)]
pub struct BudgetConfig {
pub max_cost_usd: Option<Decimal>,
pub tenant_id: Option<String>,
pub fallback_model: Option<String>,
}
impl BudgetConfig {
pub fn unlimited() -> Self {
Self::default()
}
pub fn max_cost(mut self, usd: Decimal) -> Self {
self.max_cost_usd = Some(usd);
self
}
pub fn tenant(mut self, tenant_id: impl Into<String>) -> Self {
self.tenant_id = Some(tenant_id.into());
self
}
pub fn fallback(mut self, model: impl Into<String>) -> Self {
self.fallback_model = Some(model.into());
self
}
}
#[derive(Debug, Clone, Default)]
pub struct PromptConfig {
pub system_prompt: Option<String>,
pub system_prompt_mode: SystemPromptMode,
pub output_style: Option<OutputStyle>,
pub output_schema: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SystemPromptMode {
#[default]
Replace,
Append,
}
impl PromptConfig {
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
pub fn append_mode(mut self) -> Self {
self.system_prompt_mode = SystemPromptMode::Append;
self
}
pub fn output_style(mut self, style: OutputStyle) -> Self {
self.output_style = Some(style);
self
}
pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
self.output_schema = Some(schema);
self
}
pub fn structured_output<T: schemars::JsonSchema>(mut self) -> Self {
let schema = schemars::schema_for!(T);
self.output_schema = serde_json::to_value(schema).ok();
self
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CacheStrategy {
Disabled,
SystemOnly,
MessagesOnly,
#[default]
Full,
}
impl CacheStrategy {
pub fn cache_system(&self) -> bool {
matches!(self, Self::SystemOnly | Self::Full)
}
pub fn cache_messages(&self) -> bool {
matches!(self, Self::MessagesOnly | Self::Full)
}
pub fn is_enabled(&self) -> bool {
!matches!(self, Self::Disabled)
}
}
#[derive(Debug, Clone)]
pub struct CacheConfig {
pub strategy: CacheStrategy,
pub static_ttl: crate::types::CacheTtl,
pub message_ttl: crate::types::CacheTtl,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
strategy: CacheStrategy::Full,
static_ttl: crate::types::CacheTtl::OneHour,
message_ttl: crate::types::CacheTtl::FiveMinutes,
}
}
}
impl CacheConfig {
pub fn disabled() -> Self {
Self {
strategy: CacheStrategy::Disabled,
..Default::default()
}
}
pub fn system_only() -> Self {
Self {
strategy: CacheStrategy::SystemOnly,
..Default::default()
}
}
pub fn messages_only() -> Self {
Self {
strategy: CacheStrategy::MessagesOnly,
..Default::default()
}
}
pub fn strategy(mut self, strategy: CacheStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn static_ttl(mut self, ttl: crate::types::CacheTtl) -> Self {
self.static_ttl = ttl;
self
}
pub fn message_ttl(mut self, ttl: crate::types::CacheTtl) -> Self {
self.message_ttl = ttl;
self
}
pub fn message_ttl_option(&self) -> Option<crate::types::CacheTtl> {
if self.strategy.cache_messages() {
Some(self.message_ttl)
} else {
None
}
}
}
#[derive(Debug, Clone, Default)]
pub struct ServerToolsConfig {
pub web_search: Option<crate::types::WebSearchTool>,
pub web_fetch: Option<crate::types::WebFetchTool>,
}
impl ServerToolsConfig {
pub fn all() -> Self {
Self {
web_search: Some(crate::types::WebSearchTool::default()),
web_fetch: Some(crate::types::WebFetchTool::default()),
}
}
pub fn web_search(mut self, config: crate::types::WebSearchTool) -> Self {
self.web_search = Some(config);
self
}
pub fn web_fetch(mut self, config: crate::types::WebFetchTool) -> Self {
self.web_fetch = Some(config);
self
}
}
#[derive(Debug, Clone, Default)]
pub struct AgentConfig {
pub model: AgentModelConfig,
pub execution: ExecutionConfig,
pub security: SecurityConfig,
pub budget: BudgetConfig,
pub prompt: PromptConfig,
pub cache: CacheConfig,
pub working_dir: Option<PathBuf>,
pub server_tools: ServerToolsConfig,
pub coding_mode: bool,
}
impl AgentConfig {
pub fn new() -> Self {
Self::default()
}
pub fn model(mut self, config: AgentModelConfig) -> Self {
self.model = config;
self
}
pub fn execution(mut self, config: ExecutionConfig) -> Self {
self.execution = config;
self
}
pub fn security(mut self, config: SecurityConfig) -> Self {
self.security = config;
self
}
pub fn budget(mut self, config: BudgetConfig) -> Self {
self.budget = config;
self
}
pub fn prompt(mut self, config: PromptConfig) -> Self {
self.prompt = config;
self
}
pub fn cache(mut self, config: CacheConfig) -> Self {
self.cache = config;
self
}
pub fn working_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.working_dir = Some(dir.into());
self
}
pub fn server_tools(mut self, config: ServerToolsConfig) -> Self {
self.server_tools = config;
self
}
pub fn coding_mode(mut self, enabled: bool) -> Self {
self.coding_mode = enabled;
self
}
}
#[cfg(test)]
mod tests {
use rust_decimal_macros::dec;
use super::*;
#[test]
fn test_model_config() {
let config = AgentModelConfig::new("claude-opus-4-6")
.small("claude-haiku")
.max_tokens(4096);
assert_eq!(config.primary, "claude-opus-4-6");
assert_eq!(config.small, "claude-haiku");
assert_eq!(config.max_tokens, 4096);
}
#[test]
fn test_execution_config() {
let config = ExecutionConfig::default()
.max_iterations(50)
.timeout(Duration::from_secs(600))
.auto_compact(false);
assert_eq!(config.max_iterations, 50);
assert_eq!(config.timeout, Some(Duration::from_secs(600)));
assert!(!config.auto_compact);
}
#[test]
fn test_security_config() {
let config = SecurityConfig::permissive().env("API_KEY", "secret");
assert_eq!(config.env.get("API_KEY"), Some(&"secret".to_string()));
}
#[test]
fn test_budget_config() {
let config = BudgetConfig::unlimited()
.max_cost(dec!(10))
.tenant("org-123")
.fallback("claude-haiku");
assert_eq!(config.max_cost_usd, Some(dec!(10)));
assert_eq!(config.tenant_id, Some("org-123".to_string()));
assert_eq!(config.fallback_model, Some("claude-haiku".to_string()));
}
#[test]
fn test_agent_config() {
let config = AgentConfig::new()
.model(AgentModelConfig::new("claude-opus-4-6"))
.budget(BudgetConfig::unlimited().max_cost(dec!(5)))
.working_dir("/project");
assert_eq!(config.model.primary, "claude-opus-4-6");
assert_eq!(config.budget.max_cost_usd, Some(dec!(5)));
assert_eq!(config.working_dir, Some(PathBuf::from("/project")));
}
#[test]
fn test_cache_strategy_default_is_full() {
let config = CacheConfig::default();
assert_eq!(config.strategy, CacheStrategy::Full);
assert_eq!(config.static_ttl, crate::types::CacheTtl::OneHour);
assert_eq!(config.message_ttl, crate::types::CacheTtl::FiveMinutes);
}
#[test]
fn test_cache_strategy_disabled() {
let config = CacheConfig::disabled();
assert_eq!(config.strategy, CacheStrategy::Disabled);
assert!(!config.strategy.is_enabled());
assert!(!config.strategy.cache_system());
assert!(!config.strategy.cache_messages());
}
#[test]
fn test_cache_strategy_system_only() {
let config = CacheConfig::system_only();
assert_eq!(config.strategy, CacheStrategy::SystemOnly);
assert!(config.strategy.is_enabled());
assert!(config.strategy.cache_system());
assert!(!config.strategy.cache_messages());
}
#[test]
fn test_cache_strategy_messages_only() {
let config = CacheConfig::messages_only();
assert_eq!(config.strategy, CacheStrategy::MessagesOnly);
assert!(config.strategy.is_enabled());
assert!(!config.strategy.cache_system());
assert!(config.strategy.cache_messages());
}
#[test]
fn test_cache_strategy_full() {
let config = CacheConfig::default();
assert!(config.strategy.is_enabled());
assert!(config.strategy.cache_system());
assert!(config.strategy.cache_messages());
}
#[test]
fn test_cache_config_with_ttl() {
let config = CacheConfig::default()
.static_ttl(crate::types::CacheTtl::FiveMinutes)
.message_ttl(crate::types::CacheTtl::OneHour);
assert_eq!(config.static_ttl, crate::types::CacheTtl::FiveMinutes);
assert_eq!(config.message_ttl, crate::types::CacheTtl::OneHour);
}
}