use crate::mcp_config::McpServerConfig;
#[cfg(target_arch = "wasm32")]
use crate::tokio;
use crate::{
budget::BudgetLimits,
hooks::{HookCapability, HookExecutionMode, HookFailurePolicy, HookId, HookPoint},
provider::Provider,
retry::RetryPolicy,
types::{OutputSchema, SecurityMode},
};
use schemars::JsonSchema;
use serde::de::Deserializer;
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use serde_json::{Map, Value};
use std::collections::{BTreeMap, HashMap};
use std::path::PathBuf;
use std::sync::OnceLock;
use std::time::Duration;
/// Top-level runtime configuration.
///
/// It is assembled from built-in defaults, `config.toml` layers, environment
/// variables and CLI overrides. With `#[serde(default)]`, any key missing from
/// a TOML layer takes its value from `Config::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct Config {
    /// Agent settings: prompts, model, per-turn token limit, etc.
    pub agent: AgentConfig,
    /// Active LLM provider with its optional API key / base URL.
    pub provider: ProviderConfig,
    /// Session storage location.
    pub storage: StorageConfig,
    /// Optional token / duration / tool-call budget.
    pub budget: BudgetConfig,
    /// Retry settings (convertible into `RetryPolicy`).
    pub retry: RetryConfig,
    /// Tool timeouts, concurrency, MCP servers and builtin-tool toggles.
    pub tools: ToolsConfig,
    /// Default model name per provider.
    pub models: ModelDefaults,
    /// Top-level token limit; `validate` requires it to be > 0. Defaults to the
    /// template value, falling back to `agent.max_tokens_per_turn`.
    pub max_tokens: u32,
    /// Shell tool defaults.
    pub shell: ShellDefaults,
    /// Optional paths for sessions, tasks and the database directory.
    pub store: StoreConfig,
    /// Per-provider base URL / API key maps.
    pub providers: ProviderSettings,
    /// Comms runtime settings (mode, addresses, auth).
    pub comms: CommsRuntimeConfig,
    /// Compaction settings.
    pub compaction: CompactionRuntimeConfig,
    /// Limits converted by `budget_limits()`.
    pub limits: LimitsConfig,
    /// REST server host and port.
    pub rest: RestServerConfig,
    /// Sub-agent provider/model policy.
    pub sub_agents: SubAgentsConfig,
    /// Hook settings and hook entries.
    pub hooks: HooksConfig,
    /// Skills configuration.
    pub skills: crate::skills_config::SkillsConfig,
}
impl Default for Config {
fn default() -> Self {
let defaults = template_defaults();
let agent = AgentConfig::default();
let max_tokens = defaults
.max_tokens
.filter(|value| *value > 0)
.unwrap_or(agent.max_tokens_per_turn);
Self {
agent,
provider: ProviderConfig::default(),
storage: StorageConfig::default(),
budget: BudgetConfig::default(),
retry: RetryConfig::default(),
tools: ToolsConfig::default(),
models: ModelDefaults::default(),
max_tokens,
shell: ShellDefaults::default(),
store: StoreConfig::default(),
providers: ProviderSettings::default(),
comms: CommsRuntimeConfig::default(),
compaction: CompactionRuntimeConfig::default(),
limits: LimitsConfig::default(),
rest: RestServerConfig::default(),
sub_agents: SubAgentsConfig::default(),
hooks: HooksConfig::default(),
skills: crate::skills_config::SkillsConfig::default(),
}
}
}
impl Config {
    /// Raw TOML text of the bundled configuration template.
    pub fn template_toml() -> &'static str {
        CONFIG_TEMPLATE_TOML
    }

    /// Parse the bundled template into a full [`Config`].
    ///
    /// # Errors
    /// Returns [`ConfigError::Parse`] if the template is not valid config TOML.
    pub fn template() -> Result<Self, ConfigError> {
        let parsed = toml::from_str::<Self>(Self::template_toml());
        parsed.map_err(ConfigError::Parse)
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl Config {
    /// Load configuration for the current working directory.
    ///
    /// Starts from defaults and merges the nearest project `.rkat/config.toml`.
    /// If there is none, it merges `~/.rkat/config.toml` instead. Environment
    /// overrides are applied last.
    ///
    /// # Errors
    /// Returns [`ConfigError`] if the working directory cannot be determined,
    /// a config file cannot be read or parsed, or env overrides fail.
    pub async fn load() -> Result<Self, ConfigError> {
        let cwd = std::env::current_dir()?;
        let home = dirs::home_dir();
        Self::load_from_with_env(&cwd, home.as_deref(), |key| std::env::var(key).ok()).await
    }

    /// Like [`Config::load`], but with explicit start/home directories and an
    /// injectable environment lookup (used for testing).
    ///
    /// Only one file layer is merged: the project config if found, otherwise
    /// the global one.
    #[doc(hidden)]
    pub async fn load_from_with_env<F>(
        start_dir: &std::path::Path,
        home_dir: Option<&std::path::Path>,
        env: F,
    ) -> Result<Self, ConfigError>
    where
        F: FnMut(&str) -> Option<String>,
    {
        let mut config = Self::default();
        if let Some(path) = Self::find_project_config_from(start_dir).await {
            config.merge_file(&path).await?;
        } else if let Some(path) = home_dir.map(|home| home.join(".rkat/config.toml"))
            && tokio::fs::try_exists(&path).await.unwrap_or(false)
        {
            config.merge_file(&path).await?;
        }
        config.apply_env_overrides_from(env)?;
        Ok(config)
    }

    /// [`Config::load_from_with_env`] using the real process environment.
    #[doc(hidden)]
    pub async fn load_from(
        start_dir: &std::path::Path,
        home_dir: Option<&std::path::Path>,
    ) -> Result<Self, ConfigError> {
        Self::load_from_with_env(start_dir, home_dir, |key| std::env::var(key).ok()).await
    }

    /// Collect hook entries from the global config and then the project config.
    pub async fn load_layered_hooks() -> Result<HooksConfig, ConfigError> {
        let cwd = std::env::current_dir()?;
        let home = dirs::home_dir();
        Self::load_layered_hooks_from(&cwd, home.as_deref()).await
    }

    /// Build a [`HooksConfig`] whose entries are the global config's hooks
    /// followed by the project config's hooks.
    ///
    /// Only `entries` are layered; the remaining hook settings stay at their
    /// defaults.
    ///
    /// # Errors
    /// Returns [`ConfigError`] if a present config file cannot be read or parsed.
    pub async fn load_layered_hooks_from(
        start_dir: &std::path::Path,
        home_dir: Option<&std::path::Path>,
    ) -> Result<HooksConfig, ConfigError> {
        let mut hooks = HooksConfig::default();
        let global_path = home_dir.map(|home| home.join(".rkat/config.toml"));
        if let Some(path) = &global_path
            && tokio::fs::try_exists(path).await.unwrap_or(false)
        {
            hooks.append_entries_from(&Self::read_hooks_layer(path).await?);
        }
        // `find_project_config_from` only returns existing files. When the
        // upward search reaches the home directory it finds the global file
        // itself; skip it so those hooks are not registered twice.
        if let Some(project_path) = Self::find_project_config_from(start_dir).await
            && global_path.as_ref() != Some(&project_path)
        {
            hooks.append_entries_from(&Self::read_hooks_layer(&project_path).await?);
        }
        Ok(hooks)
    }

    /// Read and parse the config file at `path`, returning only its hooks section.
    async fn read_hooks_layer(path: &std::path::Path) -> Result<HooksConfig, ConfigError> {
        let content = tokio::fs::read_to_string(path).await?;
        let cfg: Config = toml::from_str(&content).map_err(ConfigError::Parse)?;
        Ok(cfg.hooks)
    }
}
impl Config {
pub fn budget_limits(&self) -> BudgetLimits {
self.limits.to_budget_limits()
}
#[cfg(not(target_arch = "wasm32"))]
pub fn global_config_path() -> Option<PathBuf> {
dirs::home_dir().map(|h| h.join(".rkat/config.toml"))
}
#[cfg(not(target_arch = "wasm32"))]
async fn find_project_config_from(start_dir: &std::path::Path) -> Option<PathBuf> {
let mut current = start_dir.to_path_buf();
loop {
let marker_dir = current.join(".rkat");
let config_path = marker_dir.join("config.toml");
let config_exists = tokio::fs::try_exists(&config_path).await.unwrap_or(false);
if config_exists {
return Some(config_path);
}
if !current.pop() {
return None;
}
}
}
#[cfg(not(target_arch = "wasm32"))]
pub async fn merge_file(&mut self, path: &PathBuf) -> Result<(), ConfigError> {
let content = tokio::fs::read_to_string(path).await?;
self.merge_toml_str(&content)
}
pub fn merge_toml_str(&mut self, content: &str) -> Result<(), ConfigError> {
let file_config: Config = toml::from_str(content).map_err(ConfigError::Parse)?;
let tools_layer = file_config.tools.clone();
let retry_layer = file_config.retry.clone();
self.merge(file_config);
let parsed: toml::Value = toml::from_str(content).map_err(ConfigError::Parse)?;
self.merge_tools_from_toml_presence(&parsed, &tools_layer);
self.merge_retry_from_toml_presence(&parsed, &retry_layer);
Ok(())
}
fn merge(&mut self, other: Config) {
if other.agent.system_prompt.is_some() {
self.agent.system_prompt = other.agent.system_prompt;
}
if other.agent.tool_instructions.is_some() {
self.agent.tool_instructions = other.agent.tool_instructions;
}
if other.agent.model != AgentConfig::default().model {
self.agent.model = other.agent.model;
}
if other.agent.max_tokens_per_turn != AgentConfig::default().max_tokens_per_turn {
self.agent.max_tokens_per_turn = other.agent.max_tokens_per_turn;
}
if other.agent.extraction_prompt.is_some() {
self.agent.extraction_prompt = other.agent.extraction_prompt;
}
self.provider = other.provider;
if other.storage.directory.is_some() {
self.storage.directory = other.storage.directory;
}
if other.budget.max_tokens.is_some() {
self.budget.max_tokens = other.budget.max_tokens;
}
if other.budget.max_duration.is_some() {
self.budget.max_duration = other.budget.max_duration;
}
if other.budget.max_tool_calls.is_some() {
self.budget.max_tool_calls = other.budget.max_tool_calls;
}
self.merge_retry(&other.retry);
self.merge_tools(&other.tools);
if other.models != ModelDefaults::default() {
self.models = other.models;
}
if other.max_tokens != Config::default().max_tokens {
self.max_tokens = other.max_tokens;
}
if other.shell != ShellDefaults::default() {
self.shell = other.shell;
}
if other.store != StoreConfig::default() {
self.store = other.store;
}
if other.providers != ProviderSettings::default() {
self.providers = other.providers;
}
if other.comms != CommsRuntimeConfig::default() {
self.comms = other.comms;
}
if other.compaction != CompactionRuntimeConfig::default() {
self.compaction = other.compaction;
}
if other.limits != LimitsConfig::default() {
self.limits = other.limits;
}
if other.rest != RestServerConfig::default() {
self.rest = other.rest;
}
if other.sub_agents != SubAgentsConfig::default() {
self.sub_agents = other.sub_agents;
}
if other.hooks != HooksConfig::default() {
let default_hooks = HooksConfig::default();
if other.hooks.default_timeout_ms != default_hooks.default_timeout_ms {
self.hooks.default_timeout_ms = other.hooks.default_timeout_ms;
}
if other.hooks.payload_max_bytes != default_hooks.payload_max_bytes {
self.hooks.payload_max_bytes = other.hooks.payload_max_bytes;
}
if other.hooks.background_max_concurrency != default_hooks.background_max_concurrency {
self.hooks.background_max_concurrency = other.hooks.background_max_concurrency;
}
self.hooks.entries.extend(other.hooks.entries);
}
}
fn merge_retry(&mut self, other: &RetryConfig) {
let defaults = RetryConfig::default();
if other.max_retries != defaults.max_retries {
self.retry.max_retries = other.max_retries;
}
if other.initial_delay != defaults.initial_delay {
self.retry.initial_delay = other.initial_delay;
}
if other.max_delay != defaults.max_delay {
self.retry.max_delay = other.max_delay;
}
if other.multiplier != defaults.multiplier {
self.retry.multiplier = other.multiplier;
}
}
fn merge_tools(&mut self, other: &ToolsConfig) {
let defaults = ToolsConfig::default();
if !other.mcp_servers.is_empty() {
self.tools.mcp_servers.clone_from(&other.mcp_servers);
}
if other.default_timeout != defaults.default_timeout {
self.tools.default_timeout = other.default_timeout;
}
if other.tool_timeouts != defaults.tool_timeouts {
self.tools.tool_timeouts.clone_from(&other.tool_timeouts);
}
if other.max_concurrent != defaults.max_concurrent {
self.tools.max_concurrent = other.max_concurrent;
}
if other.builtins_enabled != defaults.builtins_enabled {
self.tools.builtins_enabled = other.builtins_enabled;
}
if other.shell_enabled != defaults.shell_enabled {
self.tools.shell_enabled = other.shell_enabled;
}
if other.comms_enabled != defaults.comms_enabled {
self.tools.comms_enabled = other.comms_enabled;
}
if other.subagents_enabled != defaults.subagents_enabled {
self.tools.subagents_enabled = other.subagents_enabled;
}
if other.mob_enabled != defaults.mob_enabled {
self.tools.mob_enabled = other.mob_enabled;
}
}
fn merge_tools_from_toml_presence(&mut self, parsed: &toml::Value, layer: &ToolsConfig) {
let Some(tools) = parsed.get("tools").and_then(toml::Value::as_table) else {
return;
};
if tools.contains_key("mcp_servers") {
self.tools.mcp_servers.clone_from(&layer.mcp_servers);
}
if tools.contains_key("default_timeout") {
self.tools.default_timeout = layer.default_timeout;
}
if tools.contains_key("tool_timeouts") {
self.tools.tool_timeouts.clone_from(&layer.tool_timeouts);
}
if tools.contains_key("max_concurrent") {
self.tools.max_concurrent = layer.max_concurrent;
}
if tools.contains_key("builtins_enabled") {
self.tools.builtins_enabled = layer.builtins_enabled;
}
if tools.contains_key("shell_enabled") {
self.tools.shell_enabled = layer.shell_enabled;
}
if tools.contains_key("comms_enabled") {
self.tools.comms_enabled = layer.comms_enabled;
}
if tools.contains_key("subagents_enabled") {
self.tools.subagents_enabled = layer.subagents_enabled;
}
if tools.contains_key("mob_enabled") {
self.tools.mob_enabled = layer.mob_enabled;
}
}
fn merge_retry_from_toml_presence(&mut self, parsed: &toml::Value, layer: &RetryConfig) {
let Some(retry) = parsed.get("retry").and_then(toml::Value::as_table) else {
return;
};
if retry.contains_key("max_retries") {
self.retry.max_retries = layer.max_retries;
}
if retry.contains_key("initial_delay") {
self.retry.initial_delay = layer.initial_delay;
}
if retry.contains_key("max_delay") {
self.retry.max_delay = layer.max_delay;
}
if retry.contains_key("multiplier") {
self.retry.multiplier = layer.multiplier;
}
}
pub fn apply_env_overrides(&mut self) -> Result<(), ConfigError> {
self.apply_env_overrides_from(|key| std::env::var(key).ok())
}
#[doc(hidden)]
pub fn apply_env_overrides_from<F>(&mut self, mut env: F) -> Result<(), ConfigError>
where
F: FnMut(&str) -> Option<String>,
{
match &mut self.provider {
ProviderConfig::Anthropic { api_key, .. } => {
if api_key.is_none() {
let key = env("RKAT_ANTHROPIC_API_KEY").or_else(|| env("ANTHROPIC_API_KEY"));
if let Some(key) = key {
*api_key = Some(key);
}
}
}
ProviderConfig::OpenAI { api_key, .. } => {
if api_key.is_none() {
let key = env("RKAT_OPENAI_API_KEY").or_else(|| env("OPENAI_API_KEY"));
if let Some(key) = key {
*api_key = Some(key);
}
}
}
ProviderConfig::Gemini { api_key } => {
if api_key.is_none() {
let key = env("RKAT_GEMINI_API_KEY")
.or_else(|| env("GEMINI_API_KEY"))
.or_else(|| env("GOOGLE_API_KEY"));
if let Some(key) = key {
*api_key = Some(key);
}
}
}
}
Ok(())
}
#[cfg(not(target_arch = "wasm32"))]
pub fn apply_cli_overrides(&mut self, cli: CliOverrides) {
if let Some(model) = cli.model {
self.agent.model = model;
}
if let Some(tokens) = cli.max_tokens {
self.budget.max_tokens = Some(tokens);
}
if let Some(duration) = cli.max_duration {
self.budget.max_duration = Some(duration);
}
if let Some(calls) = cli.max_tool_calls {
self.budget.max_tool_calls = Some(calls);
}
if let Some(delta) = cli.override_config {
let mut value = serde_json::to_value(&self).unwrap_or_default();
crate::config_store::merge_patch(&mut value, delta.0);
if let Ok(updated) = serde_json::from_value(value) {
*self = updated;
}
}
}
}
/// Policy for which provider/model sub-agents may use.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct SubAgentsConfig {
    /// Provider name, or `"inherit"` to use the parent agent's provider.
    pub default_provider: String,
    /// Model name, or `"inherit"` to use the parent agent's model.
    pub default_model: String,
    /// Allowed model names keyed by provider name. `validate` requires an
    /// entry for every concrete provider and rejects empty lists and `"*"`.
    pub allowed_models: BTreeMap<String, Vec<String>>,
}
impl Default for SubAgentsConfig {
fn default() -> Self {
Self {
default_provider: "inherit".to_string(),
default_model: "inherit".to_string(),
allowed_models: default_allowed_models(),
}
}
}
/// Built-in sub-agent model allow-list, keyed by provider name.
fn default_allowed_models() -> BTreeMap<String, Vec<String>> {
    const DEFAULTS: [(&str, &[&str]); 3] = [
        (
            "anthropic",
            &["claude-opus-4-6", "claude-sonnet-4-5", "claude-opus-4-5"],
        ),
        ("openai", &["gpt-5.2", "gpt-5.2-pro"]),
        (
            "gemini",
            &["gemini-3-flash-preview", "gemini-3-pro-preview"],
        ),
    ];
    DEFAULTS
        .iter()
        .map(|(provider, models)| {
            let owned: Vec<String> = models.iter().copied().map(String::from).collect();
            (String::from(*provider), owned)
        })
        .collect()
}
/// Sub-agent policy with `"inherit"` placeholders resolved against the parent
/// agent (produced by `Config::resolve_sub_agent_config`).
#[derive(Debug, Clone, PartialEq)]
pub struct ResolvedSubAgentConfig {
    /// Concrete provider sub-agents default to.
    pub default_provider: Provider,
    /// Concrete model sub-agents default to.
    pub default_model: String,
    /// Allowed model names keyed by provider name (`Provider::as_str`).
    pub allowed_models: BTreeMap<String, Vec<String>>,
}
impl ResolvedSubAgentConfig {
    /// Whether `model` appears in the allow-list for `provider`.
    ///
    /// Providers with no allow-list entry allow nothing.
    pub fn is_model_allowed(&self, provider: Provider, model: &str) -> bool {
        match self.allowed_models.get(provider.as_str()) {
            Some(candidates) => candidates.iter().any(|candidate| candidate == model),
            None => false,
        }
    }

    /// Human-readable summary of the allow-list, e.g.
    /// `Allowed models - Anthropic: a, b; OpenAI: c`.
    pub fn allowed_models_description(&self) -> String {
        let mut sections = Vec::with_capacity(self.allowed_models.len());
        for (provider, models) in &self.allowed_models {
            // Well-known providers get a display name; others are shown as-is.
            let label = match provider.as_str() {
                "anthropic" => "Anthropic",
                "openai" => "OpenAI",
                "gemini" => "Gemini",
                unknown => unknown,
            };
            sections.push(format!("{}: {}", label, models.join(", ")));
        }
        format!("Allowed models - {}", sections.join("; "))
    }
}
impl Config {
    /// Check cross-field invariants of a fully merged configuration.
    ///
    /// Rules:
    /// - the various token/compaction limits must be non-zero;
    /// - the provider's own `base_url` / `api_key` must not disagree with the
    ///   `providers.*` maps;
    /// - sub-agent allow-lists must cover every concrete provider, be non-empty
    ///   and contain no `"*"`;
    /// - a non-inherited sub-agent default provider/model must be valid and
    ///   allowed.
    ///
    /// # Errors
    /// Returns [`ConfigError::Validation`] describing the first violated rule.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if self.max_tokens == 0 {
            return Err(ConfigError::Validation(
                "max_tokens must be greater than 0".to_string(),
            ));
        }
        if self.agent.max_tokens_per_turn == 0 {
            return Err(ConfigError::Validation(
                "agent.max_tokens_per_turn must be greater than 0".to_string(),
            ));
        }
        if self.budget.max_tokens == Some(0) {
            return Err(ConfigError::Validation(
                "budget.max_tokens must be greater than 0 when set".to_string(),
            ));
        }
        if self.limits.budget == Some(0) {
            return Err(ConfigError::Validation(
                "limits.budget must be greater than 0 when set".to_string(),
            ));
        }
        if self.compaction.auto_compact_threshold == 0 {
            return Err(ConfigError::Validation(
                "compaction.auto_compact_threshold must be greater than 0".to_string(),
            ));
        }
        if self.compaction.recent_turn_budget == 0 {
            return Err(ConfigError::Validation(
                "compaction.recent_turn_budget must be greater than 0".to_string(),
            ));
        }
        if self.compaction.max_summary_tokens == 0 {
            return Err(ConfigError::Validation(
                "compaction.max_summary_tokens must be greater than 0".to_string(),
            ));
        }
        if self.compaction.min_turns_between_compactions == 0 {
            return Err(ConfigError::Validation(
                "compaction.min_turns_between_compactions must be greater than 0".to_string(),
            ));
        }
        if let Some(base_urls) = &self.providers.base_urls {
            // Only a *different* mapped value is a conflict; equal values are fine.
            let maybe_conflict = match &self.provider {
                ProviderConfig::Anthropic {
                    base_url: Some(url),
                    ..
                } => base_urls.get("anthropic").filter(|mapped| *mapped != url),
                ProviderConfig::OpenAI {
                    base_url: Some(url),
                    ..
                } => base_urls.get("openai").filter(|mapped| *mapped != url),
                _ => None,
            };
            if maybe_conflict.is_some() {
                return Err(ConfigError::Validation(
                    "provider base_url conflicts with providers.base_urls entry".to_string(),
                ));
            }
        }
        if let Some(api_keys) = &self.providers.api_keys {
            let maybe_conflict = match &self.provider {
                ProviderConfig::Anthropic {
                    api_key: Some(key), ..
                } => api_keys.get("anthropic").filter(|mapped| *mapped != key),
                ProviderConfig::OpenAI {
                    api_key: Some(key), ..
                } => api_keys.get("openai").filter(|mapped| *mapped != key),
                ProviderConfig::Gemini { api_key: Some(key) } => {
                    api_keys.get("gemini").filter(|mapped| *mapped != key)
                }
                _ => None,
            };
            if maybe_conflict.is_some() {
                return Err(ConfigError::Validation(
                    "provider api_key conflicts with providers.api_keys entry".to_string(),
                ));
            }
        }
        let sa = &self.sub_agents;
        for provider in Provider::ALL_CONCRETE {
            let key = provider.as_str();
            let models = sa.allowed_models.get(key).ok_or_else(|| {
                ConfigError::Validation(format!(
                    "sub_agents.allowed_models missing provider key '{key}'"
                ))
            })?;
            if models.is_empty() {
                return Err(ConfigError::Validation(format!(
                    "sub_agents.allowed_models['{key}'] must not be empty"
                )));
            }
            if models.iter().any(|model| model == "*") {
                return Err(ConfigError::Validation(format!(
                    "sub_agents.allowed_models['{key}']: wildcards ('*') are not allowed"
                )));
            }
        }
        if sa.default_provider != "inherit"
            && Provider::parse_strict(&sa.default_provider).is_none()
        {
            return Err(ConfigError::Validation(format!(
                "sub_agents.default_provider '{}' is not a valid provider name",
                sa.default_provider
            )));
        }
        if sa.default_model != "inherit"
            && sa.default_provider != "inherit"
            && let Some(provider) = Provider::parse_strict(&sa.default_provider)
        {
            // Check membership in place instead of cloning the model list.
            let allowed = sa
                .allowed_models
                .get(provider.as_str())
                .is_some_and(|models| models.iter().any(|m| m == &sa.default_model));
            if !allowed {
                return Err(ConfigError::Validation(format!(
                    "sub_agents.default_model '{}' is not in allowed_models for provider '{}'",
                    sa.default_model,
                    provider.as_str()
                )));
            }
        }
        Ok(())
    }

    /// Resolve the sub-agent policy against the parent agent.
    ///
    /// `"inherit"` provider resolution works as follows: use `parent_provider`
    /// if given (rejecting `Provider::Other`), otherwise infer it from
    /// `parent_model`. An `"inherit"` model becomes `parent_model`.
    ///
    /// # Errors
    /// Returns [`ConfigError::Validation`] if the provider cannot be resolved
    /// or the resolved default model is not allowed for it.
    pub fn resolve_sub_agent_config(
        &self,
        parent_provider: Option<Provider>,
        parent_model: &str,
    ) -> Result<ResolvedSubAgentConfig, ConfigError> {
        let sa = &self.sub_agents;
        let provider = if sa.default_provider == "inherit" {
            match parent_provider {
                Some(Provider::Other) => {
                    return Err(ConfigError::Validation(
                        "Cannot inherit sub-agent provider: parent provider is 'other'"
                            .to_string(),
                    ));
                }
                Some(p) => p,
                None => Provider::infer_from_model(parent_model).ok_or_else(|| {
                    ConfigError::Validation(format!(
                        "Cannot resolve sub-agent provider: parent provider unknown and model '{parent_model}' is ambiguous"
                    ))
                })?,
            }
        } else {
            Provider::parse_strict(&sa.default_provider).ok_or_else(|| {
                ConfigError::Validation(format!(
                    "sub_agents.default_provider '{}' is not a valid provider name",
                    sa.default_provider
                ))
            })?
        };
        let model = if sa.default_model == "inherit" {
            parent_model.to_string()
        } else {
            sa.default_model.clone()
        };
        let resolved = ResolvedSubAgentConfig {
            default_provider: provider,
            default_model: model,
            allowed_models: sa.allowed_models.clone(),
        };
        if !resolved.is_model_allowed(resolved.default_provider, &resolved.default_model) {
            return Err(ConfigError::Validation(format!(
                "Resolved sub-agent default model '{}' is not in allowed_models for provider '{}'",
                resolved.default_model,
                resolved.default_provider.as_str()
            )));
        }
        Ok(resolved)
    }
}
/// Command-line overrides applied by `Config::apply_cli_overrides`.
#[derive(Debug, Clone, Default)]
pub struct CliOverrides {
    /// Replaces `agent.model`.
    pub model: Option<String>,
    /// Sets `budget.max_tokens`.
    pub max_tokens: Option<u64>,
    /// Sets `budget.max_duration`.
    pub max_duration: Option<Duration>,
    /// Sets `budget.max_tool_calls`.
    pub max_tool_calls: Option<usize>,
    /// Arbitrary patch merged into the serialized config (via
    /// `config_store::merge_patch`).
    pub override_config: Option<ConfigDelta>,
}
/// Serde default for `AgentConfig::structured_output_retries`.
fn default_structured_output_retries() -> u32 {
    const RETRIES: u32 = 2;
    RETRIES
}
/// Agent behaviour settings (`[agent]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AgentConfig {
    /// Inline system prompt.
    pub system_prompt: Option<String>,
    /// Path to a system prompt file.
    pub system_prompt_file: Option<PathBuf>,
    /// Extra instructions about tool usage.
    pub tool_instructions: Option<String>,
    /// Model name; the default comes from the config template.
    pub model: String,
    /// Per-turn output token limit; `validate` requires it to be > 0.
    pub max_tokens_per_turn: u32,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    // NOTE(review): presumably a fraction of the budget at which a warning is
    // raised; the consumer is not visible here — confirm.
    pub budget_warning_threshold: f32,
    /// Optional cap on agent turns.
    pub max_turns: Option<u32>,
    /// Provider-specific request parameters, passed as raw JSON.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider_params: Option<serde_json::Value>,
    /// Schema for structured output, if requested.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_schema: Option<OutputSchema>,
    /// Retries for structured output (default 2).
    #[serde(default = "default_structured_output_retries")]
    pub structured_output_retries: u32,
    /// Optional extraction prompt.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub extraction_prompt: Option<String>,
}
impl Default for AgentConfig {
fn default() -> Self {
let defaults = template_defaults();
let agent = defaults.agent.as_ref();
Self {
system_prompt: None,
system_prompt_file: None,
tool_instructions: None,
model: agent.and_then(|cfg| cfg.model.clone()).unwrap_or_default(),
max_tokens_per_turn: agent
.and_then(|cfg| cfg.max_tokens_per_turn)
.unwrap_or_default(),
temperature: None,
budget_warning_threshold: agent
.and_then(|cfg| cfg.budget_warning_threshold)
.unwrap_or_default(),
max_turns: None,
provider_params: None,
output_schema: None,
structured_output_retries: default_structured_output_retries(),
extraction_prompt: None,
}
}
}
/// Default model name per provider (`[models]` table).
///
/// Values come from the config template; a value missing there is an empty
/// string.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct ModelDefaults {
    pub anthropic: String,
    pub openai: String,
    pub gemini: String,
}
impl Default for ModelDefaults {
fn default() -> Self {
let defaults = template_defaults();
let models = defaults.models.as_ref();
Self {
anthropic: models
.and_then(|cfg| cfg.anthropic.clone())
.unwrap_or_default(),
openai: models
.and_then(|cfg| cfg.openai.clone())
.unwrap_or_default(),
gemini: models
.and_then(|cfg| cfg.gemini.clone())
.unwrap_or_default(),
}
}
}
/// Shell program used when the template does not specify one.
pub const DEFAULT_SHELL_PROGRAM: &str = "nu";
/// Shell timeout (seconds) used when the template does not specify one.
pub const DEFAULT_SHELL_TIMEOUT_SECS: u64 = 30;
/// Shell security mode used when the template does not specify one.
pub const DEFAULT_SHELL_SECURITY_MODE: SecurityMode = SecurityMode::Unrestricted;
/// Shell tool defaults (`[shell]` table).
///
/// `Deserialize` is implemented by hand (via `ShellDefaultsSeed`) to support
/// the legacy `allowlist` key.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(default)]
pub struct ShellDefaults {
    /// Shell program to run.
    pub program: String,
    /// Command timeout in seconds.
    pub timeout_secs: u64,
    /// Command security mode.
    pub security_mode: SecurityMode,
    /// Patterns interpreted according to `security_mode`.
    pub security_patterns: Vec<String>,
}
/// Raw `[shell]` table as written by the user.
///
/// Every key is optional, so unspecified keys can fall back to
/// `ShellDefaults::default()`.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct ShellDefaultsSeed {
    program: Option<String>,
    timeout_secs: Option<u64>,
    security_mode: Option<SecurityMode>,
    security_patterns: Option<Vec<String>>,
    // Legacy key: used as `security_patterns` when that key is absent.
    // NOTE(review): this alias equals the field name and so has no effect —
    // confirm whether a different legacy spelling was intended.
    #[serde(alias = "allowlist")]
    allowlist: Option<Vec<String>>,
}
impl<'de> Deserialize<'de> for ShellDefaults {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let seed = ShellDefaultsSeed::deserialize(deserializer)?;
let mut defaults = ShellDefaults::default();
if let Some(program) = seed.program {
defaults.program = program;
}
if let Some(timeout_secs) = seed.timeout_secs {
defaults.timeout_secs = timeout_secs;
}
if let Some(security_mode) = seed.security_mode {
defaults.security_mode = security_mode;
}
if let Some(security_patterns) = seed.security_patterns.or(seed.allowlist.clone()) {
defaults.security_patterns = security_patterns;
}
if seed.security_mode.is_none() && seed.allowlist.is_some() {
defaults.security_mode = SecurityMode::AllowList;
}
Ok(defaults)
}
}
impl Default for ShellDefaults {
fn default() -> Self {
let defaults = template_defaults();
let shell = defaults.shell.as_ref();
Self {
program: shell
.and_then(|cfg| cfg.program.clone())
.unwrap_or_else(|| DEFAULT_SHELL_PROGRAM.to_string()),
timeout_secs: shell
.and_then(|cfg| cfg.timeout_secs)
.unwrap_or(DEFAULT_SHELL_TIMEOUT_SECS),
security_mode: shell
.and_then(|cfg| cfg.security_mode)
.unwrap_or(DEFAULT_SHELL_SECURITY_MODE),
security_patterns: shell
.and_then(|cfg| cfg.security_patterns.clone())
.unwrap_or_default(),
}
}
}
/// Bundled config template, embedded at compile time; also the source of the
/// template-driven defaults (`template_defaults`).
const CONFIG_TEMPLATE_TOML: &str = include_str!("config_template.toml");
/// `[agent]` keys read from the template to seed `AgentConfig::default`.
#[derive(Debug, Deserialize)]
struct TemplateAgentDefaults {
    model: Option<String>,
    max_tokens_per_turn: Option<u32>,
    budget_warning_threshold: Option<f32>,
}
/// `[models]` keys read from the template to seed `ModelDefaults::default`.
#[derive(Debug, Deserialize)]
struct TemplateModelDefaults {
    anthropic: Option<String>,
    openai: Option<String>,
    gemini: Option<String>,
}
/// `[shell]` keys read from the template to seed `ShellDefaults::default`.
#[derive(Debug, Deserialize)]
struct TemplateShellDefaults {
    program: Option<String>,
    timeout_secs: Option<u64>,
    security_mode: Option<SecurityMode>,
    security_patterns: Option<Vec<String>>,
}
/// Subset of the config template used to derive `Default` impls.
///
/// Every section is optional, so a partial template still parses.
#[derive(Debug, Deserialize)]
struct TemplateDefaults {
    agent: Option<TemplateAgentDefaults>,
    models: Option<TemplateModelDefaults>,
    shell: Option<TemplateShellDefaults>,
    /// Top-level `max_tokens`; `Config::default` ignores a zero value.
    max_tokens: Option<u32>,
}
impl TemplateDefaults {
fn empty() -> Self {
Self {
agent: None,
models: None,
shell: None,
max_tokens: None,
}
}
}
/// Parsed template defaults, computed once per process.
///
/// A template that fails to parse is logged and replaced by an empty set of
/// defaults rather than panicking.
fn template_defaults() -> &'static TemplateDefaults {
    static DEFAULTS: OnceLock<TemplateDefaults> = OnceLock::new();
    DEFAULTS.get_or_init(|| match toml::from_str::<TemplateDefaults>(CONFIG_TEMPLATE_TOML) {
        Ok(parsed) => parsed,
        Err(e) => {
            tracing::error!("Invalid config template defaults: {}", e);
            TemplateDefaults::empty()
        }
    })
}
/// Optional storage locations (`[store]` table); all unset by default.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct StoreConfig {
    pub sessions_path: Option<PathBuf>,
    pub tasks_path: Option<PathBuf>,
    pub database_dir: Option<PathBuf>,
}
/// Per-provider maps keyed by provider name (e.g. `"anthropic"`).
///
/// `Config::validate` rejects entries that disagree with the active provider's
/// own `base_url` / `api_key`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct ProviderSettings {
    pub base_urls: Option<HashMap<String, String>>,
    pub api_keys: Option<HashMap<String, String>>,
}
/// Run limits (`[limits]` table), converted to `BudgetLimits` by
/// `to_budget_limits`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct LimitsConfig {
    /// Token budget; must be > 0 when set (see `Config::validate`).
    pub budget: Option<u64>,
    /// Maximum run duration.
    #[serde(with = "optional_duration_serde")]
    pub max_duration: Option<Duration>,
}
impl LimitsConfig {
    /// Convert to [`BudgetLimits`].
    ///
    /// `budget` becomes `max_tokens`; there is no tool-call limit here.
    pub fn to_budget_limits(&self) -> BudgetLimits {
        let Self {
            budget,
            max_duration,
        } = self;
        BudgetLimits {
            max_tool_calls: None,
            max_duration: *max_duration,
            max_tokens: *budget,
        }
    }
}
/// REST server host and port (`[rest]` table); defaults to `127.0.0.1:8080`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct RestServerConfig {
    pub host: String,
    pub port: u16,
}
impl Default for RestServerConfig {
fn default() -> Self {
Self {
host: "127.0.0.1".to_string(),
port: 8080,
}
}
}
/// Comms authentication mode.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "snake_case")]
pub enum CommsAuthMode {
    /// No authentication; serialized as `"none"`. The default.
    #[default]
    #[serde(rename = "none")]
    Open,
    /// Ed25519-based authentication; serialized as `"ed25519"`.
    Ed25519,
}
/// Transport an event arrived on.
///
/// Serialized (and displayed) as the lowercase variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PlainEventSource {
    Tcp,
    Uds,
    Stdin,
    Webhook,
    Rpc,
}
impl std::fmt::Display for PlainEventSource {
    /// Write the lowercase source name (same spelling as the serde form).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Tcp => "tcp",
            Self::Uds => "uds",
            Self::Stdin => "stdin",
            Self::Webhook => "webhook",
            Self::Rpc => "rpc",
        };
        f.write_str(label)
    }
}
/// Comms runtime settings (`[comms]` table).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct CommsRuntimeConfig {
    /// Transport mode; defaults to in-process.
    pub mode: CommsRuntimeMode,
    /// Optional listen/connect address for the selected transport.
    pub address: Option<String>,
    /// Authentication mode; defaults to `CommsAuthMode::Open`.
    pub auth: CommsAuthMode,
    /// Defaults to `true`.
    pub require_peer_auth: bool,
    /// Optional address for plain events.
    pub event_address: Option<String>,
    /// Defaults to `false`.
    pub auto_enable_for_subagents: bool,
}
impl Default for CommsRuntimeConfig {
fn default() -> Self {
Self {
mode: CommsRuntimeMode::Inproc,
address: None,
auth: CommsAuthMode::default(),
require_peer_auth: true,
event_address: None,
auto_enable_for_subagents: false,
}
}
}
/// Compaction settings (`[compaction]` table).
///
/// `Config::validate` requires every field to be > 0. Converts into
/// `crate::CompactionConfig`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct CompactionRuntimeConfig {
    pub auto_compact_threshold: u64,
    pub recent_turn_budget: usize,
    pub max_summary_tokens: u32,
    pub min_turns_between_compactions: u32,
}
impl Default for CompactionRuntimeConfig {
fn default() -> Self {
Self {
auto_compact_threshold: 100_000,
recent_turn_budget: 4,
max_summary_tokens: 4096,
min_turns_between_compactions: 3,
}
}
}
impl From<CompactionRuntimeConfig> for crate::CompactionConfig {
fn from(value: CompactionRuntimeConfig) -> Self {
Self {
auto_compact_threshold: value.auto_compact_threshold,
recent_turn_budget: value.recent_turn_budget,
max_summary_tokens: value.max_summary_tokens,
min_turns_between_compactions: value.min_turns_between_compactions,
}
}
}
/// Comms transport, serialized as `"inproc"`, `"tcp"` or `"uds"`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum CommsRuntimeMode {
    /// In-process transport; the default.
    #[default]
    Inproc,
    Tcp,
    Uds,
}
/// Active LLM provider (`[provider]` table).
///
/// The variant is selected by the `type` key: `"anthropic"`, `"openai"` or
/// `"gemini"`. A missing `api_key` may be filled from the environment
/// (`Config::apply_env_overrides_from`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProviderConfig {
    Anthropic {
        api_key: Option<String>,
        base_url: Option<String>,
    },
    // Explicit rename: snake_case would otherwise produce "open_a_i".
    #[serde(rename = "openai")]
    OpenAI {
        api_key: Option<String>,
        base_url: Option<String>,
    },
    Gemini {
        api_key: Option<String>,
    },
}
impl Default for ProviderConfig {
fn default() -> Self {
Self::Anthropic {
api_key: None,
base_url: None,
}
}
}
/// Session storage location (`[storage]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct StorageConfig {
    /// Defaults to `<data_dir>/sessions` when a data directory is available.
    pub directory: Option<PathBuf>,
}
impl Default for StorageConfig {
fn default() -> Self {
Self {
directory: data_dir().map(|d| d.join("sessions")),
}
}
}
/// Optional run budget (`[budget]` table); CLI overrides can set any field.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct BudgetConfig {
    /// Token budget; must be > 0 when set (see `Config::validate`).
    pub max_tokens: Option<u64>,
    #[serde(with = "optional_duration_serde")]
    pub max_duration: Option<Duration>,
    pub max_tool_calls: Option<usize>,
}
/// Retry settings (`[retry]` table).
///
/// Defaults mirror `RetryPolicy::default()`; durations use humantime strings
/// (e.g. `"500ms"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RetryConfig {
    pub max_retries: u32,
    #[serde(with = "humantime_serde")]
    pub initial_delay: Duration,
    #[serde(with = "humantime_serde")]
    pub max_delay: Duration,
    pub multiplier: f64,
}
impl Default for RetryConfig {
fn default() -> Self {
let policy = RetryPolicy::default();
Self {
max_retries: policy.max_retries,
initial_delay: policy.initial_delay,
max_delay: policy.max_delay,
multiplier: policy.multiplier,
}
}
}
impl From<RetryConfig> for RetryPolicy {
fn from(config: RetryConfig) -> Self {
RetryPolicy {
max_retries: config.max_retries,
initial_delay: config.initial_delay,
max_delay: config.max_delay,
multiplier: config.multiplier,
}
}
}
/// Tool settings (`[tools]` table).
///
/// Keys present in a TOML layer always override, even when set to default
/// values (see `Config::merge_tools_from_toml_presence`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ToolsConfig {
    /// MCP servers to connect to.
    #[serde(default)]
    pub mcp_servers: Vec<McpServerConfig>,
    /// Default per-tool timeout as a humantime string; defaults to 10 minutes.
    #[serde(with = "humantime_serde")]
    pub default_timeout: Duration,
    // NOTE(review): unlike `default_timeout`, these durations use serde's
    // default `Duration` encoding rather than humantime — confirm intended.
    #[serde(default)]
    pub tool_timeouts: HashMap<String, Duration>,
    /// Defaults to 10.
    pub max_concurrent: usize,
    // Builtin tool-family toggles; all default to `false`.
    pub builtins_enabled: bool,
    pub shell_enabled: bool,
    pub comms_enabled: bool,
    pub subagents_enabled: bool,
    pub mob_enabled: bool,
}
impl Default for ToolsConfig {
fn default() -> Self {
Self {
mcp_servers: Vec::new(),
default_timeout: Duration::from_secs(600),
tool_timeouts: HashMap::new(),
max_concurrent: 10,
builtins_enabled: false,
shell_enabled: false,
comms_enabled: false,
subagents_enabled: false,
mob_enabled: false,
}
}
}
/// Hook settings and registered hook entries (`[hooks]` table).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(default)]
pub struct HooksConfig {
    /// Defaults to 5000 ms.
    pub default_timeout_ms: u64,
    /// Defaults to 128 KiB.
    pub payload_max_bytes: usize,
    /// Defaults to 32.
    pub background_max_concurrency: usize,
    /// Hook entries; layers append to this list rather than replacing it.
    #[serde(default)]
    pub entries: Vec<HookEntryConfig>,
}
impl HooksConfig {
    /// Append copies of `other`'s entries after this config's own entries.
    ///
    /// Other hook settings are left untouched.
    pub fn append_entries_from(&mut self, other: &HooksConfig) {
        self.entries.extend_from_slice(&other.entries);
    }
}
impl Default for HooksConfig {
fn default() -> Self {
Self {
default_timeout_ms: 5_000,
payload_max_bytes: 128 * 1024,
background_max_concurrency: 32,
entries: Vec::new(),
}
}
}
/// Hook adjustments: extra hook entries plus hook ids to disable.
// NOTE(review): presumably applied per run on top of `HooksConfig`; the
// consumer is not visible here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Default)]
#[serde(default)]
pub struct HookRunOverrides {
    /// Additional hook entries.
    #[serde(default)]
    pub entries: Vec<HookEntryConfig>,
    /// Ids of hooks to disable.
    #[serde(default)]
    pub disable: Vec<HookId>,
}
/// One configured hook: where it runs, how, and with which runtime.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(default)]
pub struct HookEntryConfig {
    pub id: HookId,
    pub enabled: bool,
    pub point: HookPoint,
    pub mode: HookExecutionMode,
    pub capability: HookCapability,
    /// Defaults to 100.
    pub priority: i32,
    /// When `None`, `effective_failure_policy` derives it from `capability`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub failure_policy: Option<HookFailurePolicy>,
    /// Per-hook timeout override in milliseconds.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
    pub runtime: HookRuntimeConfig,
}
impl HookEntryConfig {
    /// The explicit `failure_policy`, or the capability's default policy when
    /// none is configured.
    pub fn effective_failure_policy(&self) -> HookFailurePolicy {
        match self.failure_policy {
            Some(policy) => policy,
            None => crate::hooks::default_failure_policy(self.capability),
        }
    }
}
impl Default for HookEntryConfig {
    /// An enabled, foreground, observe-only hook at the turn boundary.
    ///
    /// It has priority 100, no failure-policy or timeout override, and the
    /// default in-process runtime ([`HookRuntimeConfig::default`]).
    fn default() -> Self {
        Self {
            id: HookId::new("hook"),
            enabled: true,
            point: HookPoint::TurnBoundary,
            mode: HookExecutionMode::Foreground,
            capability: HookCapability::Observe,
            priority: 100,
            failure_policy: None,
            timeout_ms: None,
            // Delegate so the noop runtime and its fallback live in one place.
            runtime: HookRuntimeConfig::default(),
        }
    }
}
/// Runtime selection for a hook: a `kind` string plus an optional config.
///
/// (De)serialized as a single object whose `type` key holds `kind`; the
/// `Serialize`/`Deserialize` impls are hand-written. The config is stored as
/// raw JSON text, and `config_value` parses it back into a `Value`.
#[derive(Debug, Clone)]
pub struct HookRuntimeConfig {
    /// Runtime kind, written under the `type` key (the default is `"in_process"`).
    pub kind: String,
    /// Raw JSON config; `None` when no config was provided.
    // NOTE(review): `clippy::box_collection` targets boxed std collections such
    // as `Box<Vec<_>>`; `Box<RawValue>` is the usual owned form of the unsized
    // `RawValue`, so this allow looks unnecessary — confirm and drop.
    #[allow(clippy::box_collection)]
    pub config: Option<Box<RawValue>>,
}
impl PartialEq for HookRuntimeConfig {
    /// Two runtimes are equal when their kinds match and their raw config JSON
    /// text is identical, or both configs are absent.
    ///
    /// This compares text, not JSON meaning: the same object written with
    /// different key order or spacing compares unequal.
    fn eq(&self, other: &Self) -> bool {
        if self.kind != other.kind {
            return false;
        }
        match (&self.config, &other.config) {
            (None, None) => true,
            (Some(lhs), Some(rhs)) => lhs.get() == rhs.get(),
            _ => false,
        }
    }
}
impl HookRuntimeConfig {
pub fn new(kind: impl Into<String>, config: Option<Value>) -> Result<Self, serde_json::Error> {
let config = match config {
Some(value) => Some(raw_json_from_value(value)?),
None => None,
};
Ok(Self {
kind: kind.into(),
config,
})
}
pub fn config_value(&self) -> Result<Value, serde_json::Error> {
match &self.config {
Some(raw) => serde_json::from_str(raw.get()),
None => Ok(Value::Null),
}
}
}
impl Default for HookRuntimeConfig {
    /// The `in_process` runtime with config `{"name":"noop"}`.
    ///
    /// If that config ever failed to encode, this falls back to `in_process`
    /// with no config rather than panicking.
    fn default() -> Self {
        const KIND: &str = "in_process";
        match Self::new(KIND, Some(serde_json::json!({"name":"noop"}))) {
            Ok(runtime) => runtime,
            Err(_) => Self {
                kind: KIND.to_string(),
                config: None,
            },
        }
    }
}
impl Serialize for HookRuntimeConfig {
    /// Serializes as one object: `{"type": <kind>, ...}`.
    ///
    /// An object config is flattened next to `type` only when the matching
    /// `Deserialize` impl can read it back unchanged. Otherwise it is nested
    /// under a `config` key. Flattening would lose information when the object:
    /// - is empty (it would read back as "no config");
    /// - has its own `type` key (it would overwrite the runtime kind);
    /// - consists solely of a `config` key (it would be unwrapped on read).
    ///
    /// Non-object configs are always nested under `config`.
    ///
    /// # Errors
    /// Fails if the stored raw JSON does not parse.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut map = Map::new();
        map.insert("type".to_string(), Value::String(self.kind.clone()));
        if let Some(raw) = &self.config {
            let parsed: Value =
                serde_json::from_str(raw.get()).map_err(serde::ser::Error::custom)?;
            match parsed {
                Value::Object(obj)
                    if !obj.is_empty()
                        && !obj.contains_key("type")
                        && !(obj.len() == 1 && obj.contains_key("config")) =>
                {
                    map.extend(obj);
                }
                other => {
                    map.insert("config".to_string(), other);
                }
            }
        }
        Value::Object(map).serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for HookRuntimeConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
let mut obj = value
.as_object()
.cloned()
.ok_or_else(|| serde::de::Error::custom("hook runtime must be an object"))?;
let kind = obj
.remove("type")
.and_then(|value| value.as_str().map(ToOwned::to_owned))
.ok_or_else(|| {
serde::de::Error::custom("hook runtime missing required field 'type'")
})?;
let config_value = if let Some(explicit) = obj.remove("config") {
if obj.is_empty() {
explicit
} else {
obj.insert("config".to_string(), explicit);
Value::Object(obj)
}
} else if obj.is_empty() {
Value::Null
} else {
Value::Object(obj)
};
let config = if config_value.is_null() {
None
} else {
Some(raw_json_from_value(config_value).map_err(serde::de::Error::custom)?)
};
Ok(Self { kind, config })
}
}
/// Hand-written JSON schema matching the custom (de)serialization.
///
/// The schema is an object with a required string `type` and an optional
/// free-form `config`. Any other properties are allowed, because
/// `Deserialize` folds extra keys into the runtime config.
impl JsonSchema for HookRuntimeConfig {
    fn schema_name() -> std::borrow::Cow<'static, str> {
        "HookRuntimeConfig".into()
    }
    // The generator is unused: the schema is a fixed literal with no references
    // to other schemas.
    fn json_schema(_gen: &mut schemars::SchemaGenerator) -> schemars::Schema {
        schemars::json_schema!({
            "type": "object",
            "required": ["type"],
            "properties": {
                "type": { "type": "string" },
                "config": {}
            },
            "additionalProperties": true
        })
    }
}
/// Encodes a `serde_json::Value` as an owned [`RawValue`] holding its compact
/// JSON text.
///
/// # Errors
/// Propagates any serialization error reported by `serde_json`.
fn raw_json_from_value(value: Value) -> Result<Box<RawValue>, serde_json::Error> {
    serde_json::value::to_raw_value(&value)
}
/// Which configuration layer a setting targets.
///
/// Serialized in snake_case (`"global"` / `"project"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ConfigScope {
    /// Global (non-project) configuration layer.
    Global,
    /// Project-level configuration layer.
    Project,
}
/// An arbitrary JSON value wrapped as a configuration delta.
///
/// It (de)serializes transparently as the inner value, and the default is
/// `Value::Null`. NOTE(review): presumably a partial config patch; confirm how
/// callers apply it.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct ConfigDelta(pub serde_json::Value);
/// Errors produced while loading, parsing, serializing, or validating config.
///
/// The I/O, TOML, JSON and UTF-8 variants convert automatically via `#[from]`,
/// so `?` works on those underlying errors.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// Filesystem or other I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// TOML deserialization failure.
    #[error("Parse error: {0}")]
    Parse(#[from] toml::de::Error),
    /// TOML serialization failure.
    #[error("TOML serialization error: {0}")]
    TomlSerialize(#[from] toml::ser::Error),
    /// JSON (de)serialization failure.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// Bytes that were not valid UTF-8.
    #[error("UTF-8 error: {0}")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// An invalid value for the named setting.
    // NOTE(review): marked `dead_code`, so it appears unconstructed in this
    // crate — confirm before removing (it is part of the public enum).
    #[allow(dead_code)]
    #[error("Invalid value for {0}")]
    InvalidValue(String),
    /// A required field was absent.
    #[error("Missing required field: {0}")]
    MissingField(String),
    /// A required API key was absent; the payload is a static label.
    #[error("Missing API key: {0}")]
    MissingApiKey(&'static str),
    /// An unexpected internal failure.
    #[error("Internal error: {0}")]
    InternalError(String),
    /// A semantic validation failure with a human-readable reason.
    #[error("Validation error: {0}")]
    Validation(String),
}
/// Serde adapter for `Option<Duration>` fields.
///
/// - Serializing: `Some` becomes a humantime string and `None` is written as none.
/// - Deserializing: accepts a humantime string, or a non-negative integer
///   number of milliseconds. An absent/none value yields `None`.
mod optional_duration_serde {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::time::Duration;

    /// Writes `Some(d)` as a humantime string and `None` as none.
    // The serde `with` adapter contract requires `&Option<T>` here.
    #[allow(clippy::ref_option)]
    pub fn serialize<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Some(d) = duration else {
            return serializer.serialize_none();
        };
        humantime_serde::re::humantime::format_duration(*d)
            .to_string()
            .serialize(serializer)
    }

    /// Reads a humantime string, or an integer number of milliseconds.
    ///
    /// # Errors
    /// Fails on:
    /// - an unparsable humantime string;
    /// - a number that is not a non-negative integer;
    /// - any other JSON type.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
    where
        D: Deserializer<'de>,
    {
        use serde::de::Error;
        let Some(raw) = Option::<serde_json::Value>::deserialize(deserializer)? else {
            return Ok(None);
        };
        let parsed = match raw {
            serde_json::Value::String(text) => {
                humantime_serde::re::humantime::parse_duration(&text)
                    .map_err(|err| D::Error::custom(err.to_string()))?
            }
            serde_json::Value::Number(num) => {
                let millis = num
                    .as_u64()
                    .ok_or_else(|| D::Error::custom("invalid number"))?;
                Duration::from_millis(millis)
            }
            _ => return Err(D::Error::custom("expected string or number for duration")),
        };
        Ok(Some(parsed))
    }
}
/// Walks upward from `start_dir` (inclusive) and returns the first directory
/// that contains a `.rkat` subdirectory.
///
/// Returns `None` once there are no more parent components to try. For an
/// absolute path that is the filesystem root; for a relative path, the walk
/// ends after its leading component. The path is not canonicalized.
pub fn find_project_root(start_dir: &std::path::Path) -> Option<PathBuf> {
    start_dir
        .ancestors()
        .find(|candidate| candidate.join(".rkat").is_dir())
        .map(std::path::Path::to_path_buf)
}
/// Returns the `.rkat` data directory to use.
///
/// The `.rkat` of the nearest project root at or above the current working
/// directory is preferred. Otherwise the `.rkat` under the home directory
/// reported by `dirs::home_dir` is used. The result is `None` only when no
/// project root is found (or the cwd is unavailable) and there is no home
/// directory.
pub fn data_dir() -> Option<PathBuf> {
    const DATA_DIR_NAME: &str = ".rkat";
    std::env::current_dir()
        .ok()
        .and_then(|cwd| find_project_root(&cwd))
        .or_else(dirs::home_dir)
        .map(|base| base.join(DATA_DIR_NAME))
}
/// Minimal home-directory lookup used by `super::data_dir`.
pub mod dirs {
    use std::path::PathBuf;

    /// Returns the current user's home directory, read from the environment.
    ///
    /// The `HOME` variable is read first. On Windows, where `HOME` is usually
    /// unset, it falls back to `USERPROFILE`. An empty variable is treated as
    /// unset, so callers never receive an empty path: joining onto an empty
    /// path would resolve against the current directory.
    pub fn home_dir() -> Option<PathBuf> {
        let from_env = |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());
        let home = from_env("HOME");
        #[cfg(windows)]
        let home = home.or_else(|| from_env("USERPROFILE"));
        home.map(PathBuf::from)
    }
}
// Unit tests for configuration defaults, layering/merging, validation rules,
// sub-agent resolution, and serde round-trips of config types.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
// Spot-checks a few documented defaults of `Config::default()`.
#[test]
fn test_config_default() {
let config = Config::default();
assert_eq!(config.agent.model, "claude-opus-4-6");
assert_eq!(config.agent.max_tokens_per_turn, 16384);
assert_eq!(config.retry.max_retries, 3);
}
// Exercises each layer separately: defaults, env overrides via an injected
// lookup (no real process env), file-config merge, and CLI overrides.
#[test]
fn test_config_layering() {
let config = Config::default();
assert_eq!(config.agent.model, "claude-opus-4-6");
assert_eq!(config.budget.max_tokens, None);
{
let env = std::collections::HashMap::from([
("RKAT_MODEL".to_string(), "env-model".to_string()),
("ANTHROPIC_API_KEY".to_string(), "secret-key".to_string()),
]);
let mut config = Config::default();
config
.apply_env_overrides_from(|key| env.get(key).cloned())
.expect("apply env overrides");
match config.provider {
ProviderConfig::Anthropic { api_key, .. } => {
assert_eq!(api_key.as_deref(), Some("secret-key"));
}
_ => unreachable!("expected anthropic provider"),
}
}
let mut config = Config::default();
let file_config = Config {
agent: AgentConfig {
model: "file-model".to_string(),
..Default::default()
},
..Default::default()
};
config.merge(file_config);
assert_eq!(config.agent.model, "file-model");
let mut config = Config::default();
// The CLI `max_tokens` override lands in the budget, not the per-turn limit.
config.apply_cli_overrides(CliOverrides {
model: Some("cli-model".to_string()),
max_tokens: Some(50000),
..Default::default()
});
assert_eq!(config.agent.model, "cli-model");
assert_eq!(config.budget.max_tokens, Some(50000));
}
// A later TOML layer that omits `agent.extraction_prompt` must not reset it.
#[test]
fn test_merge_extraction_prompt_survives_layering() {
let mut base = Config::default();
assert!(base.agent.extraction_prompt.is_none());
let toml = r#"
[agent]
extraction_prompt = "Return JSON only."
"#;
base.merge_toml_str(toml).expect("merge toml");
assert_eq!(
base.agent.extraction_prompt.as_deref(),
Some("Return JSON only.")
);
let toml2 = r#"
[agent]
model = "custom-model"
"#;
base.merge_toml_str(toml2).expect("merge toml2");
assert_eq!(
base.agent.extraction_prompt.as_deref(),
Some("Return JSON only."),
"extraction_prompt must survive merge when absent in later layer"
);
assert_eq!(base.agent.model, "custom-model");
}
// Hook entries from successive layers accumulate in order instead of replacing.
#[test]
fn test_merge_hooks_entries_append() {
let mut base = Config::default();
let base_entry = HookEntryConfig {
id: HookId::new("base"),
..HookEntryConfig::default()
};
base.hooks.entries.push(base_entry);
let mut other = Config::default();
let other_entry = HookEntryConfig {
id: HookId::new("other"),
..HookEntryConfig::default()
};
other.hooks.entries.push(other_entry);
base.merge(other);
let ids = base
.hooks
.entries
.iter()
.map(|entry| entry.id.0.as_str())
.collect::<Vec<_>>();
assert_eq!(ids, vec!["base", "other"]);
}
// Unlike hooks, a non-default `providers.base_urls` map replaces the lower
// layer wholesale: keys absent from the upper layer are dropped.
#[test]
fn test_merge_providers_section_replaces_non_default() {
let mut base = Config::default();
base.providers.base_urls = Some(HashMap::from([
("anthropic".to_string(), "https://a.example".to_string()),
("openai".to_string(), "https://o.example".to_string()),
]));
let mut other = Config::default();
other.providers.base_urls = Some(HashMap::from([(
"openai".to_string(),
"https://override.example".to_string(),
)]));
base.merge(other);
let urls = base
.providers
.base_urls
.expect("providers.base_urls missing");
assert_eq!(urls.len(), 1);
assert_eq!(
urls.get("openai").map(String::as_str),
Some("https://override.example")
);
assert!(!urls.contains_key("anthropic"));
}
// Keys absent from a TOML `[tools]` table keep their lower-layer values.
#[test]
fn test_merge_toml_tools_omitted_fields_preserve_lower_layer() {
let mut config = Config::default();
config.tools.mob_enabled = true;
config.tools.shell_enabled = true;
config
.merge_toml_str(
r"
[tools]
shell_enabled = false
",
)
.expect("merge should succeed");
assert!(config.tools.mob_enabled);
assert!(!config.tools.shell_enabled);
}
// An explicitly written value equal to the default still overrides the lower layer.
#[test]
fn test_merge_toml_tools_explicit_default_overrides_lower_layer() {
let mut config = Config::default();
config.tools.mob_enabled = true;
config
.merge_toml_str(
r"
[tools]
mob_enabled = false
",
)
.expect("merge should succeed");
assert!(!config.tools.mob_enabled);
}
// Same omitted-key preservation for `[retry]`, including humantime parsing.
#[test]
fn test_merge_toml_retry_omitted_fields_preserve_lower_layer() {
let mut config = Config::default();
config.retry.max_retries = 9;
config
.merge_toml_str(
r#"
[retry]
initial_delay = "750ms"
"#,
)
.expect("merge should succeed");
assert_eq!(config.retry.max_retries, 9);
assert_eq!(config.retry.initial_delay, Duration::from_millis(750));
}
#[test]
fn test_validate_rejects_zero_min_turns_between_compactions() {
let config = Config {
compaction: CompactionRuntimeConfig {
min_turns_between_compactions: 0,
..CompactionRuntimeConfig::default()
},
..Config::default()
};
let err = config
.validate()
.expect_err("min_turns_between_compactions=0 should be invalid");
assert!(
err.to_string()
.contains("compaction.min_turns_between_compactions")
);
}
// Provider variants serialize with a lowercase `type` tag alongside their fields.
#[test]
fn test_provider_config_serialization() {
let anthropic = ProviderConfig::Anthropic {
api_key: Some("sk-test".to_string()),
base_url: None,
};
let json = serde_json::to_value(&anthropic).unwrap();
assert_eq!(json["type"], "anthropic");
assert_eq!(json["api_key"], "sk-test");
let openai = ProviderConfig::OpenAI {
api_key: Some("sk-openai".to_string()),
base_url: Some("https://custom.openai.com".to_string()),
};
let json = serde_json::to_value(&openai).unwrap();
assert_eq!(json["type"], "openai");
let gemini = ProviderConfig::Gemini {
api_key: Some("gemini-key".to_string()),
};
let json = serde_json::to_value(&gemini).unwrap();
assert_eq!(json["type"], "gemini");
}
// All budget limits survive a JSON round-trip unchanged.
#[test]
fn test_budget_config_serialization() {
let budget = BudgetConfig {
max_tokens: Some(100_000),
max_duration: Some(Duration::from_secs(300)),
max_tool_calls: Some(50),
};
let json = serde_json::to_string(&budget).unwrap();
let parsed: BudgetConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.max_tokens, Some(100_000));
assert_eq!(parsed.max_duration, Some(Duration::from_secs(300)));
assert_eq!(parsed.max_tool_calls, Some(50));
}
#[test]
fn test_retry_config_to_policy() {
let config = RetryConfig::default();
let policy: RetryPolicy = config.into();
assert_eq!(policy.max_retries, 3);
assert_eq!(policy.initial_delay, Duration::from_millis(500));
}
// Regression: loading must succeed when `.rkat/` exists but has no
// `config.toml`. The env lookup is stubbed to return nothing.
#[tokio::test]
async fn test_regression_load_succeeds_without_config_toml() {
use tempfile::TempDir;
let temp_dir = TempDir::new().unwrap();
let rkat_dir = temp_dir.path().join(".rkat");
std::fs::create_dir(&rkat_dir).unwrap();
assert!(rkat_dir.exists());
assert!(!rkat_dir.join("config.toml").exists());
let result =
Config::load_from_with_env(temp_dir.path(), Some(temp_dir.path()), |_| None).await;
assert!(
result.is_ok(),
"Config::load() should succeed when .rkat/ exists without config.toml: {:?}",
result.err()
);
}
// NOTE: validates the whole default `Config`, not only the sub_agents section.
#[test]
fn test_sub_agents_config_default_validates() {
let config = Config::default();
config.validate().expect("Default config should validate");
}
#[test]
fn test_validate_rejects_zero_max_tokens() {
let config = Config {
max_tokens: 0,
..Config::default()
};
let err = config
.validate()
.expect_err("max_tokens=0 should be invalid");
assert!(
err.to_string()
.contains("max_tokens must be greater than 0")
);
}
#[test]
fn test_validate_rejects_zero_agent_max_tokens_per_turn() {
let mut config = Config::default();
config.agent.max_tokens_per_turn = 0;
let err = config
.validate()
.expect_err("agent.max_tokens_per_turn=0 should be invalid");
assert!(err.to_string().contains("agent.max_tokens_per_turn"));
}
// A base_url set both on `provider` and in `providers.base_urls` must be rejected.
#[test]
fn test_validate_rejects_provider_base_url_conflict() {
let mut config = Config {
provider: ProviderConfig::OpenAI {
api_key: None,
base_url: Some("https://one.example".to_string()),
},
..Config::default()
};
config.providers.base_urls = Some(std::collections::HashMap::from([(
"openai".to_string(),
"https://two.example".to_string(),
)]));
let err = config
.validate()
.expect_err("conflicting base_url settings should be invalid");
assert!(err.to_string().contains("provider base_url conflicts"));
}
// Likewise for an api_key set in both places with different values.
#[test]
fn test_validate_rejects_provider_api_key_conflict() {
let mut config = Config {
provider: ProviderConfig::Gemini {
api_key: Some("one".to_string()),
},
..Config::default()
};
config.providers.api_keys = Some(std::collections::HashMap::from([(
"gemini".to_string(),
"two".to_string(),
)]));
let err = config
.validate()
.expect_err("conflicting api_key settings should be invalid");
assert!(err.to_string().contains("provider api_key conflicts"));
}
#[test]
fn test_sub_agents_config_default_has_all_providers() {
let sa = SubAgentsConfig::default();
assert!(sa.allowed_models.contains_key("anthropic"));
assert!(sa.allowed_models.contains_key("openai"));
assert!(sa.allowed_models.contains_key("gemini"));
}
#[test]
fn test_sub_agents_config_toml_roundtrip() {
let toml_str = r#"
[sub_agents]
default_provider = "openai"
default_model = "gpt-5.2"
[sub_agents.allowed_models]
anthropic = ["claude-opus-4-6"]
openai = ["gpt-5.2", "gpt-5.2-pro"]
gemini = ["gemini-3-flash-preview"]
"#;
let config: Config = toml::from_str(toml_str).expect("should parse");
assert_eq!(config.sub_agents.default_provider, "openai");
assert_eq!(config.sub_agents.default_model, "gpt-5.2");
assert_eq!(
config.sub_agents.allowed_models.get("openai").unwrap(),
&vec!["gpt-5.2".to_string(), "gpt-5.2-pro".to_string()]
);
}
// With default sub_agents settings, the parent's provider/model are inherited.
#[test]
fn test_sub_agents_inherit_resolves_from_parent() {
let config = Config::default();
let resolved = config
.resolve_sub_agent_config(Some(Provider::OpenAI), "gpt-5.2")
.expect("should resolve");
assert_eq!(resolved.default_provider, Provider::OpenAI);
assert_eq!(resolved.default_model, "gpt-5.2");
}
// With no parent provider, it is inferred from the model name.
#[test]
fn test_sub_agents_inherit_resolves_provider_from_model() {
let config = Config::default();
let resolved = config
.resolve_sub_agent_config(None, "claude-opus-4-6")
.expect("should resolve");
assert_eq!(resolved.default_provider, Provider::Anthropic);
assert_eq!(resolved.default_model, "claude-opus-4-6");
}
#[test]
fn test_sub_agents_wildcard_rejected() {
let mut config = Config::default();
config
.sub_agents
.allowed_models
.get_mut("openai")
.unwrap()
.push("*".to_string());
let err = config.validate().unwrap_err();
assert!(err.to_string().contains("wildcards"));
}
#[test]
fn test_sub_agents_missing_provider_rejected() {
let mut config = Config::default();
config.sub_agents.allowed_models.remove("gemini");
let err = config.validate().unwrap_err();
assert!(err.to_string().contains("gemini"));
}
#[test]
fn test_sub_agents_empty_list_rejected() {
let mut config = Config::default();
config
.sub_agents
.allowed_models
.insert("openai".to_string(), vec![]);
let err = config.validate().unwrap_err();
assert!(err.to_string().contains("must not be empty"));
}
#[test]
fn test_sub_agents_default_not_in_allowlist_rejected() {
let mut config = Config::default();
config.sub_agents.default_provider = "openai".to_string();
config.sub_agents.default_model = "nonexistent-model".to_string();
let err = config.validate().unwrap_err();
assert!(err.to_string().contains("not in allowed_models"));
}
// A model with no inferable provider and no parent provider cannot be resolved.
#[test]
fn test_sub_agents_resolve_ambiguous_fails() {
let config = Config::default();
let err = config
.resolve_sub_agent_config(None, "custom-model")
.unwrap_err();
assert!(err.to_string().contains("ambiguous"));
}
// Explicit sub_agents defaults take precedence over the parent's provider/model.
#[test]
fn test_sub_agents_resolve_explicit_config() {
let mut config = Config::default();
config.sub_agents.default_provider = "gemini".to_string();
config.sub_agents.default_model = "gemini-3-flash-preview".to_string();
config.validate().expect("should validate");
let resolved = config
.resolve_sub_agent_config(Some(Provider::Anthropic), "claude-opus-4-6")
.expect("should resolve");
assert_eq!(resolved.default_provider, Provider::Gemini);
assert_eq!(resolved.default_model, "gemini-3-flash-preview");
}
// An inherited parent model must itself be in the allowlist.
#[test]
fn test_sub_agents_resolved_default_not_in_allowlist() {
let config = Config::default();
let err = config
.resolve_sub_agent_config(Some(Provider::Anthropic), "claude-old-model-42")
.unwrap_err();
assert!(err.to_string().contains("not in allowed_models"));
}
// Strict parsing accepts only the exact provider names.
#[test]
fn test_provider_parse_strict() {
assert_eq!(
Provider::parse_strict("anthropic"),
Some(Provider::Anthropic)
);
assert_eq!(Provider::parse_strict("openai"), Some(Provider::OpenAI));
assert_eq!(Provider::parse_strict("gemini"), Some(Provider::Gemini));
assert_eq!(Provider::parse_strict("other"), None);
assert_eq!(Provider::parse_strict("claude"), None);
assert_eq!(Provider::parse_strict(""), None);
}
#[test]
fn test_provider_infer_from_model() {
assert_eq!(
Provider::infer_from_model("claude-opus-4-6"),
Some(Provider::Anthropic)
);
assert_eq!(
Provider::infer_from_model("gpt-5.2"),
Some(Provider::OpenAI)
);
assert_eq!(
Provider::infer_from_model("gemini-3-flash-preview"),
Some(Provider::Gemini)
);
assert_eq!(Provider::infer_from_model("llama-3"), None);
assert_eq!(Provider::infer_from_model(""), None);
}
#[test]
fn test_sub_agents_config_merge() {
let mut base = Config::default();
let mut other = Config::default();
other.sub_agents.default_provider = "openai".to_string();
other.sub_agents.default_model = "gpt-5.2".to_string();
base.merge(other);
assert_eq!(base.sub_agents.default_provider, "openai");
assert_eq!(base.sub_agents.default_model, "gpt-5.2");
}
#[test]
fn test_resolved_sub_agent_config_is_model_allowed() {
let config = Config::default();
let resolved = config
.resolve_sub_agent_config(Some(Provider::Anthropic), "claude-opus-4-6")
.unwrap();
assert!(resolved.is_model_allowed(Provider::Anthropic, "claude-opus-4-6"));
assert!(resolved.is_model_allowed(Provider::OpenAI, "gpt-5.2"));
assert!(!resolved.is_model_allowed(Provider::OpenAI, "gpt-4o"));
}
#[test]
fn test_resolved_sub_agent_config_description() {
let config = Config::default();
let resolved = config
.resolve_sub_agent_config(Some(Provider::Anthropic), "claude-opus-4-6")
.unwrap();
let desc = resolved.allowed_models_description();
assert!(desc.contains("Anthropic"));
assert!(desc.contains("OpenAI"));
assert!(desc.contains("Gemini"));
assert!(desc.contains("gpt-5.2"));
assert!(desc.contains("claude-opus-4-6"));
}
#[test]
fn test_comms_auth_mode_default_is_open() {
assert_eq!(CommsAuthMode::default(), CommsAuthMode::Open);
}
// `CommsAuthMode::Open` uses the wire name "none".
#[test]
fn test_comms_auth_mode_serde_roundtrip() {
let json = serde_json::to_string(&CommsAuthMode::Open).unwrap();
assert_eq!(json, r#""none""#);
let parsed: CommsAuthMode = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, CommsAuthMode::Open);
let json = serde_json::to_string(&CommsAuthMode::Ed25519).unwrap();
assert_eq!(json, r#""ed25519""#);
let parsed: CommsAuthMode = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, CommsAuthMode::Ed25519);
}
// When `require_peer_auth` is omitted from TOML, it still comes back `true`.
#[test]
fn test_comms_auth_mode_toml_roundtrip() {
let config = CommsRuntimeConfig::default();
let toml_str = toml::to_string(&config).unwrap();
let parsed: CommsRuntimeConfig = toml::from_str(&toml_str).unwrap();
assert_eq!(parsed.auth, CommsAuthMode::Open);
assert!(parsed.require_peer_auth);
let toml_str = r#"
mode = "inproc"
auth = "ed25519"
auto_enable_for_subagents = false
"#;
let parsed: CommsRuntimeConfig = toml::from_str(toml_str).unwrap();
assert_eq!(parsed.auth, CommsAuthMode::Ed25519);
assert!(parsed.require_peer_auth);
}
#[test]
fn test_comms_runtime_config_default_has_open_auth() {
let config = CommsRuntimeConfig::default();
assert_eq!(config.auth, CommsAuthMode::Open);
assert!(config.require_peer_auth);
}
// Every `PlainEventSource` variant round-trips through its lowercase JSON name.
#[test]
fn test_plain_event_source_serde_roundtrip() {
let cases = [
(PlainEventSource::Tcp, r#""tcp""#),
(PlainEventSource::Uds, r#""uds""#),
(PlainEventSource::Stdin, r#""stdin""#),
(PlainEventSource::Webhook, r#""webhook""#),
(PlainEventSource::Rpc, r#""rpc""#),
];
for (variant, expected_json) in cases {
let json = serde_json::to_string(&variant).unwrap();
assert_eq!(json, expected_json, "serialize {variant:?}");
let parsed: PlainEventSource = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, variant, "deserialize {variant:?}");
}
}
// `Display` output matches the serde names.
#[test]
fn test_plain_event_source_display() {
assert_eq!(PlainEventSource::Tcp.to_string(), "tcp");
assert_eq!(PlainEventSource::Uds.to_string(), "uds");
assert_eq!(PlainEventSource::Stdin.to_string(), "stdin");
assert_eq!(PlainEventSource::Webhook.to_string(), "webhook");
assert_eq!(PlainEventSource::Rpc.to_string(), "rpc");
}
#[test]
fn test_comms_config_event_address_toml_roundtrip() {
let toml_str = r#"
mode = "tcp"
address = "127.0.0.1:4200"
auth = "none"
require_peer_auth = false
event_address = "127.0.0.1:4201"
auto_enable_for_subagents = false
"#;
let parsed: CommsRuntimeConfig = toml::from_str(toml_str).unwrap();
assert_eq!(parsed.event_address.as_deref(), Some("127.0.0.1:4201"));
assert_eq!(parsed.auth, CommsAuthMode::Open);
assert!(!parsed.require_peer_auth);
}
#[test]
fn test_comms_config_event_address_defaults_none() {
let config = CommsRuntimeConfig::default();
assert!(config.event_address.is_none());
}
}