use crate::mcp_config::McpServerConfig;
use crate::{
budget::BudgetLimits,
hooks::{HookCapability, HookExecutionMode, HookFailurePolicy, HookId, HookPoint},
retry::RetryPolicy,
types::{OutputSchema, SecurityMode},
};
use meerkat_models::ModelTier;
use schemars::JsonSchema;
use serde::de::Deserializer;
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use serde_json::{Map, Value};
use std::collections::{BTreeMap, HashMap};
use std::path::PathBuf;
use std::sync::OnceLock;
use std::time::Duration;
/// Top-level runtime configuration.
///
/// The container is `#[serde(default)]`, so a config document only needs the
/// keys it wants to change; anything missing is taken from `Config::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct Config {
    pub agent: AgentConfig,
    pub provider: ProviderConfig,
    pub storage: StorageConfig,
    pub budget: BudgetConfig,
    pub retry: RetryConfig,
    pub tools: ToolsConfig,
    pub models: ModelDefaults,
    // Top-level `max_tokens`; `validate` rejects 0. The default comes from the
    // bundled template, falling back to `agent.max_tokens_per_turn`.
    pub max_tokens: u32,
    pub shell: ShellDefaults,
    pub store: StoreConfig,
    pub providers: ProviderSettings,
    pub comms: CommsRuntimeConfig,
    pub compaction: CompactionRuntimeConfig,
    pub limits: LimitsConfig,
    pub rest: RestServerConfig,
    pub hooks: HooksConfig,
    pub skills: crate::skills_config::SkillsConfig,
    pub self_hosted: SelfHostedConfig,
    pub provider_tools: ProviderToolsConfig,
}
impl Default for Config {
fn default() -> Self {
let defaults = template_defaults();
let agent = AgentConfig::default();
let max_tokens = defaults
.max_tokens
.filter(|value| *value > 0)
.unwrap_or(agent.max_tokens_per_turn);
Self {
agent,
provider: ProviderConfig::default(),
storage: StorageConfig::default(),
budget: BudgetConfig::default(),
retry: RetryConfig::default(),
tools: ToolsConfig::default(),
models: ModelDefaults::default(),
max_tokens,
shell: ShellDefaults::default(),
store: StoreConfig::default(),
providers: ProviderSettings::default(),
comms: CommsRuntimeConfig::default(),
compaction: CompactionRuntimeConfig::default(),
limits: LimitsConfig::default(),
rest: RestServerConfig::default(),
hooks: HooksConfig::default(),
skills: crate::skills_config::SkillsConfig::default(),
self_hosted: SelfHostedConfig::default(),
provider_tools: ProviderToolsConfig::default(),
}
}
}
impl Config {
    /// Builds the model registry described by this configuration.
    pub fn model_registry(&self) -> Result<crate::ModelRegistry, ConfigError> {
        crate::ModelRegistry::from_config(self)
    }

    /// Returns the raw text of the bundled configuration template.
    pub fn template_toml() -> &'static str {
        CONFIG_TEMPLATE_TOML
    }

    /// Parses the bundled template into a complete `Config`.
    ///
    /// # Errors
    /// Returns `ConfigError::Parse` if the template is not valid TOML for `Config`.
    pub fn template() -> Result<Self, ConfigError> {
        let text = Self::template_toml();
        toml::from_str::<Self>(text).map_err(ConfigError::Parse)
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl Config {
    /// Loads configuration for the current working directory using the real
    /// process environment (see [`Config::load_from_with_env`]).
    ///
    /// # Errors
    /// Fails if the current directory cannot be determined, or if a config file
    /// cannot be read or parsed.
    pub async fn load() -> Result<Self, ConfigError> {
        let cwd = std::env::current_dir()?;
        let home = dirs::home_dir();
        Self::load_from_with_env(&cwd, home.as_deref(), |key| std::env::var(key).ok()).await
    }

    /// Loads configuration starting from `start_dir`.
    ///
    /// Resolution order:
    /// 1. `Config::default()`;
    /// 2. the nearest `.rkat/config.toml` found by walking up from `start_dir`,
    ///    or, only if none is found, `<home_dir>/.rkat/config.toml`;
    /// 3. provider API keys read through `env` (see `apply_env_overrides_from`).
    #[doc(hidden)]
    pub async fn load_from_with_env<F>(
        start_dir: &std::path::Path,
        home_dir: Option<&std::path::Path>,
        env: F,
    ) -> Result<Self, ConfigError>
    where
        F: FnMut(&str) -> Option<String>,
    {
        let mut config = Self::default();
        if let Some(path) = Self::find_project_config_from(start_dir).await {
            config.merge_file(&path).await?;
        } else if let Some(path) = home_dir.map(|home| home.join(".rkat/config.toml"))
            && tokio::fs::try_exists(&path).await.unwrap_or(false)
        {
            config.merge_file(&path).await?;
        }
        config.apply_env_overrides_from(env)?;
        Ok(config)
    }

    /// Same as [`Config::load_from_with_env`] but reads the real process environment.
    #[doc(hidden)]
    pub async fn load_from(
        start_dir: &std::path::Path,
        home_dir: Option<&std::path::Path>,
    ) -> Result<Self, ConfigError> {
        Self::load_from_with_env(start_dir, home_dir, |key| std::env::var(key).ok()).await
    }

    /// Collects hook entries from the global and project config files for the
    /// current working directory (see [`Config::load_layered_hooks_from`]).
    pub async fn load_layered_hooks() -> Result<HooksConfig, ConfigError> {
        let cwd = std::env::current_dir()?;
        let home = dirs::home_dir();
        Self::load_layered_hooks_from(&cwd, home.as_deref()).await
    }

    /// Collects hook entries from the global config (`<home_dir>/.rkat/config.toml`)
    /// followed by the nearest project config, appending both in that order.
    ///
    /// The project search walks up past the home directory, so with no project
    /// config it can return the global file itself. That file is detected by
    /// path (or canonical path) and loaded only once, so its hooks are not
    /// duplicated.
    ///
    /// # Errors
    /// Fails if a located config file cannot be read or parsed.
    pub async fn load_layered_hooks_from(
        start_dir: &std::path::Path,
        home_dir: Option<&std::path::Path>,
    ) -> Result<HooksConfig, ConfigError> {
        let mut hooks = HooksConfig::default();
        let mut loaded_global: Option<PathBuf> = None;
        if let Some(global_path) = home_dir.map(|home| home.join(".rkat/config.toml"))
            && tokio::fs::try_exists(&global_path).await.unwrap_or(false)
        {
            let content = tokio::fs::read_to_string(&global_path).await?;
            let cfg: Config = toml::from_str(&content).map_err(ConfigError::Parse)?;
            hooks.append_entries_from(&cfg.hooks);
            loaded_global = Some(global_path);
        }
        // `find_project_config_from` only returns paths that already exist.
        if let Some(project_path) = Self::find_project_config_from(start_dir).await {
            let is_loaded_global = match loaded_global.as_deref() {
                Some(global) if global == project_path.as_path() => true,
                // Fall back to canonical paths so symlinked home/cwd paths still match.
                Some(global) => matches!(
                    (
                        tokio::fs::canonicalize(global).await,
                        tokio::fs::canonicalize(&project_path).await,
                    ),
                    (Ok(a), Ok(b)) if a == b
                ),
                None => false,
            };
            if !is_loaded_global {
                let content = tokio::fs::read_to_string(&project_path).await?;
                let cfg: Config = toml::from_str(&content).map_err(ConfigError::Parse)?;
                hooks.append_entries_from(&cfg.hooks);
            }
        }
        Ok(hooks)
    }
}
impl Config {
pub fn budget_limits(&self) -> BudgetLimits {
self.limits.to_budget_limits()
}
#[cfg(not(target_arch = "wasm32"))]
pub fn global_config_path() -> Option<PathBuf> {
dirs::home_dir().map(|h| h.join(".rkat/config.toml"))
}
#[cfg(not(target_arch = "wasm32"))]
async fn find_project_config_from(start_dir: &std::path::Path) -> Option<PathBuf> {
let mut current = start_dir.to_path_buf();
loop {
let marker_dir = current.join(".rkat");
let config_path = marker_dir.join("config.toml");
let config_exists = tokio::fs::try_exists(&config_path).await.unwrap_or(false);
if config_exists {
return Some(config_path);
}
if !current.pop() {
return None;
}
}
}
#[cfg(not(target_arch = "wasm32"))]
pub async fn merge_file(&mut self, path: &PathBuf) -> Result<(), ConfigError> {
let content = tokio::fs::read_to_string(path).await?;
self.merge_toml_str(&content)
}
pub fn merge_toml_str(&mut self, content: &str) -> Result<(), ConfigError> {
let file_config: Config = toml::from_str(content).map_err(ConfigError::Parse)?;
let tools_layer = file_config.tools.clone();
let retry_layer = file_config.retry.clone();
let self_hosted_layer = file_config.self_hosted.clone();
let provider_tools_layer = file_config.provider_tools.clone();
self.merge(file_config);
let parsed: toml::Value = toml::from_str(content).map_err(ConfigError::Parse)?;
self.merge_tools_from_toml_presence(&parsed, &tools_layer);
self.merge_retry_from_toml_presence(&parsed, &retry_layer);
self.merge_self_hosted_from_toml_presence(&parsed, &self_hosted_layer);
self.merge_provider_tools_from_toml_presence(&parsed, &provider_tools_layer);
Ok(())
}
fn merge(&mut self, other: Config) {
if other.agent.system_prompt.is_some() {
self.agent.system_prompt = other.agent.system_prompt;
}
if other.agent.tool_instructions.is_some() {
self.agent.tool_instructions = other.agent.tool_instructions;
}
if other.agent.model != AgentConfig::default().model {
self.agent.model = other.agent.model;
}
if other.agent.max_tokens_per_turn != AgentConfig::default().max_tokens_per_turn {
self.agent.max_tokens_per_turn = other.agent.max_tokens_per_turn;
}
if other.agent.extraction_prompt.is_some() {
self.agent.extraction_prompt = other.agent.extraction_prompt;
}
self.provider = other.provider;
if other.storage.directory.is_some() {
self.storage.directory = other.storage.directory;
}
if other.budget.max_tokens.is_some() {
self.budget.max_tokens = other.budget.max_tokens;
}
if other.budget.max_duration.is_some() {
self.budget.max_duration = other.budget.max_duration;
}
if other.budget.max_tool_calls.is_some() {
self.budget.max_tool_calls = other.budget.max_tool_calls;
}
self.merge_retry(&other.retry);
self.merge_tools(&other.tools);
if other.models != ModelDefaults::default() {
self.models = other.models;
}
if other.max_tokens != Config::default().max_tokens {
self.max_tokens = other.max_tokens;
}
if other.shell != ShellDefaults::default() {
self.shell = other.shell;
}
if other.store != StoreConfig::default() {
self.store = other.store;
}
if other.providers != ProviderSettings::default() {
self.providers = other.providers;
}
if other.comms != CommsRuntimeConfig::default() {
self.comms = other.comms;
}
if other.compaction != CompactionRuntimeConfig::default() {
self.compaction = other.compaction;
}
if other.limits != LimitsConfig::default() {
self.limits = other.limits;
}
if other.rest != RestServerConfig::default() {
self.rest = other.rest;
}
if other.hooks != HooksConfig::default() {
let default_hooks = HooksConfig::default();
if other.hooks.default_timeout_ms != default_hooks.default_timeout_ms {
self.hooks.default_timeout_ms = other.hooks.default_timeout_ms;
}
if other.hooks.payload_max_bytes != default_hooks.payload_max_bytes {
self.hooks.payload_max_bytes = other.hooks.payload_max_bytes;
}
if other.hooks.background_max_concurrency != default_hooks.background_max_concurrency {
self.hooks.background_max_concurrency = other.hooks.background_max_concurrency;
}
self.hooks.entries.extend(other.hooks.entries);
}
}
fn merge_retry(&mut self, other: &RetryConfig) {
let defaults = RetryConfig::default();
if other.max_retries != defaults.max_retries {
self.retry.max_retries = other.max_retries;
}
if other.initial_delay != defaults.initial_delay {
self.retry.initial_delay = other.initial_delay;
}
if other.max_delay != defaults.max_delay {
self.retry.max_delay = other.max_delay;
}
if other.multiplier != defaults.multiplier {
self.retry.multiplier = other.multiplier;
}
if other.call_timeout_override != defaults.call_timeout_override {
self.retry.call_timeout_override = other.call_timeout_override.clone();
}
}
fn merge_tools(&mut self, other: &ToolsConfig) {
let defaults = ToolsConfig::default();
if !other.mcp_servers.is_empty() {
self.tools.mcp_servers.clone_from(&other.mcp_servers);
}
if other.default_timeout != defaults.default_timeout {
self.tools.default_timeout = other.default_timeout;
}
if other.tool_timeouts != defaults.tool_timeouts {
self.tools.tool_timeouts.clone_from(&other.tool_timeouts);
}
if other.max_concurrent != defaults.max_concurrent {
self.tools.max_concurrent = other.max_concurrent;
}
if other.builtins_enabled != defaults.builtins_enabled {
self.tools.builtins_enabled = other.builtins_enabled;
}
if other.shell_enabled != defaults.shell_enabled {
self.tools.shell_enabled = other.shell_enabled;
}
if other.comms_enabled != defaults.comms_enabled {
self.tools.comms_enabled = other.comms_enabled;
}
if other.mob_enabled != defaults.mob_enabled {
self.tools.mob_enabled = other.mob_enabled;
}
}
fn merge_tools_from_toml_presence(&mut self, parsed: &toml::Value, layer: &ToolsConfig) {
let Some(tools) = parsed.get("tools").and_then(toml::Value::as_table) else {
return;
};
if tools.contains_key("mcp_servers") {
self.tools.mcp_servers.clone_from(&layer.mcp_servers);
}
if tools.contains_key("default_timeout") {
self.tools.default_timeout = layer.default_timeout;
}
if tools.contains_key("tool_timeouts") {
self.tools.tool_timeouts.clone_from(&layer.tool_timeouts);
}
if tools.contains_key("max_concurrent") {
self.tools.max_concurrent = layer.max_concurrent;
}
if tools.contains_key("builtins_enabled") {
self.tools.builtins_enabled = layer.builtins_enabled;
}
if tools.contains_key("shell_enabled") {
self.tools.shell_enabled = layer.shell_enabled;
}
if tools.contains_key("comms_enabled") {
self.tools.comms_enabled = layer.comms_enabled;
}
if tools.contains_key("mob_enabled") {
self.tools.mob_enabled = layer.mob_enabled;
}
}
fn merge_retry_from_toml_presence(&mut self, parsed: &toml::Value, layer: &RetryConfig) {
let Some(retry) = parsed.get("retry").and_then(toml::Value::as_table) else {
return;
};
if retry.contains_key("max_retries") {
self.retry.max_retries = layer.max_retries;
}
if retry.contains_key("initial_delay") {
self.retry.initial_delay = layer.initial_delay;
}
if retry.contains_key("max_delay") {
self.retry.max_delay = layer.max_delay;
}
if retry.contains_key("multiplier") {
self.retry.multiplier = layer.multiplier;
}
if retry.contains_key("call_timeout") {
self.retry.call_timeout_override = layer.call_timeout_override.clone();
}
}
fn merge_provider_tools_from_toml_presence(
&mut self,
parsed: &toml::Value,
layer: &ProviderToolsConfig,
) {
let Some(pt) = parsed.get("provider_tools").and_then(toml::Value::as_table) else {
return;
};
if let Some(anthropic) = pt.get("anthropic").and_then(toml::Value::as_table)
&& anthropic.contains_key("web_search")
{
self.provider_tools.anthropic.web_search = layer.anthropic.web_search;
}
if let Some(openai) = pt.get("openai").and_then(toml::Value::as_table)
&& openai.contains_key("web_search")
{
self.provider_tools.openai.web_search = layer.openai.web_search;
}
if let Some(gemini) = pt.get("gemini").and_then(toml::Value::as_table)
&& gemini.contains_key("google_search")
{
self.provider_tools.gemini.google_search = layer.gemini.google_search;
}
}
fn merge_self_hosted_from_toml_presence(
&mut self,
parsed: &toml::Value,
layer: &SelfHostedConfig,
) {
let Some(self_hosted) = parsed.get("self_hosted").and_then(toml::Value::as_table) else {
return;
};
if let Some(servers) = self_hosted.get("servers").and_then(toml::Value::as_table) {
if servers.is_empty() {
self.self_hosted.servers.clear();
self.self_hosted.models.clear();
} else {
let mut merged_servers = self.self_hosted.servers.clone();
for (server_id, server_value) in servers {
let Some(server_table) = server_value.as_table() else {
continue;
};
let mut merged = self
.self_hosted
.servers
.get(server_id)
.cloned()
.unwrap_or_default();
let Some(server_layer) = layer.servers.get(server_id) else {
continue;
};
if server_table.contains_key("transport") {
merged.transport = server_layer.transport;
}
if server_table.contains_key("base_url") {
merged.base_url = server_layer.base_url.clone();
}
if server_table.contains_key("api_style") {
merged.api_style = server_layer.api_style;
}
if server_table.contains_key("bearer_token") {
merged.bearer_token = server_layer.bearer_token.clone();
}
if server_table.contains_key("bearer_token_env") {
merged.bearer_token_env = server_layer.bearer_token_env.clone();
}
merged_servers.insert(server_id.clone(), merged);
}
self.self_hosted.servers = merged_servers;
}
}
if let Some(models) = self_hosted.get("models").and_then(toml::Value::as_table) {
if models.is_empty() {
self.self_hosted.models.clear();
} else {
let mut merged_models = self.self_hosted.models.clone();
for (model_id, model_value) in models {
let Some(model_table) = model_value.as_table() else {
continue;
};
let mut merged = self
.self_hosted
.models
.get(model_id)
.cloned()
.unwrap_or_default();
let Some(model_layer) = layer.models.get(model_id) else {
continue;
};
if model_table.contains_key("server") {
merged.server = model_layer.server.clone();
}
if model_table.contains_key("remote_model") {
merged.remote_model = model_layer.remote_model.clone();
}
if model_table.contains_key("display_name") {
merged.display_name = model_layer.display_name.clone();
}
if model_table.contains_key("family") {
merged.family = model_layer.family.clone();
}
if model_table.contains_key("tier") {
merged.tier = model_layer.tier;
}
if model_table.contains_key("context_window") {
merged.context_window = model_layer.context_window;
}
if model_table.contains_key("max_output_tokens") {
merged.max_output_tokens = model_layer.max_output_tokens;
}
if model_table.contains_key("vision") {
merged.vision = model_layer.vision;
}
if model_table.contains_key("image_tool_results") {
merged.image_tool_results = model_layer.image_tool_results;
}
if model_table.contains_key("inline_video") {
merged.inline_video = model_layer.inline_video;
}
if model_table.contains_key("supports_temperature") {
merged.supports_temperature = model_layer.supports_temperature;
}
if model_table.contains_key("supports_thinking") {
merged.supports_thinking = model_layer.supports_thinking;
}
if model_table.contains_key("supports_reasoning") {
merged.supports_reasoning = model_layer.supports_reasoning;
}
if model_table.contains_key("call_timeout_secs") {
merged.call_timeout_secs = model_layer.call_timeout_secs;
}
merged_models.insert(model_id.clone(), merged);
}
self.self_hosted.models = merged_models;
}
}
}
pub fn apply_env_overrides(&mut self) -> Result<(), ConfigError> {
self.apply_env_overrides_from(|key| std::env::var(key).ok())
}
#[doc(hidden)]
pub fn apply_env_overrides_from<F>(&mut self, mut env: F) -> Result<(), ConfigError>
where
F: FnMut(&str) -> Option<String>,
{
match &mut self.provider {
ProviderConfig::Anthropic { api_key, .. } => {
if api_key.is_none() {
let key = env("RKAT_ANTHROPIC_API_KEY").or_else(|| env("ANTHROPIC_API_KEY"));
if let Some(key) = key {
*api_key = Some(key);
}
}
}
ProviderConfig::OpenAI { api_key, .. } => {
if api_key.is_none() {
let key = env("RKAT_OPENAI_API_KEY").or_else(|| env("OPENAI_API_KEY"));
if let Some(key) = key {
*api_key = Some(key);
}
}
}
ProviderConfig::Gemini { api_key } => {
if api_key.is_none() {
let key = env("RKAT_GEMINI_API_KEY")
.or_else(|| env("GEMINI_API_KEY"))
.or_else(|| env("GOOGLE_API_KEY"));
if let Some(key) = key {
*api_key = Some(key);
}
}
}
}
Ok(())
}
#[cfg(not(target_arch = "wasm32"))]
pub fn apply_cli_overrides(&mut self, cli: CliOverrides) {
if let Some(model) = cli.model {
self.agent.model = model;
}
if let Some(tokens) = cli.max_tokens {
self.budget.max_tokens = Some(tokens);
}
if let Some(duration) = cli.max_duration {
self.budget.max_duration = Some(duration);
}
if let Some(calls) = cli.max_tool_calls {
self.budget.max_tool_calls = Some(calls);
}
if let Some(delta) = cli.override_config {
let mut value = serde_json::to_value(&self).unwrap_or_default();
crate::config_store::merge_patch(&mut value, delta.0);
if let Ok(updated) = serde_json::from_value(value) {
*self = updated;
}
}
}
}
impl Config {
pub fn validate(&self) -> Result<(), ConfigError> {
if self.max_tokens == 0 {
return Err(ConfigError::Validation(
"max_tokens must be greater than 0".to_string(),
));
}
if self.agent.max_tokens_per_turn == 0 {
return Err(ConfigError::Validation(
"agent.max_tokens_per_turn must be greater than 0".to_string(),
));
}
if self.budget.max_tokens == Some(0) {
return Err(ConfigError::Validation(
"budget.max_tokens must be greater than 0 when set".to_string(),
));
}
if self.limits.budget == Some(0) {
return Err(ConfigError::Validation(
"limits.budget must be greater than 0 when set".to_string(),
));
}
if self.compaction.auto_compact_threshold == 0 {
return Err(ConfigError::Validation(
"compaction.auto_compact_threshold must be greater than 0".to_string(),
));
}
if self.compaction.recent_turn_budget == 0 {
return Err(ConfigError::Validation(
"compaction.recent_turn_budget must be greater than 0".to_string(),
));
}
if self.compaction.max_summary_tokens == 0 {
return Err(ConfigError::Validation(
"compaction.max_summary_tokens must be greater than 0".to_string(),
));
}
if self.compaction.min_turns_between_compactions == 0 {
return Err(ConfigError::Validation(
"compaction.min_turns_between_compactions must be greater than 0".to_string(),
));
}
if let Some(base_urls) = &self.providers.base_urls {
let maybe_conflict = match &self.provider {
ProviderConfig::Anthropic {
base_url: Some(url),
..
} => base_urls.get("anthropic").filter(|mapped| *mapped != url),
ProviderConfig::OpenAI {
base_url: Some(url),
..
} => base_urls.get("openai").filter(|mapped| *mapped != url),
_ => None,
};
if maybe_conflict.is_some() {
return Err(ConfigError::Validation(
"provider base_url conflicts with providers.base_urls entry".to_string(),
));
}
}
if let Some(api_keys) = &self.providers.api_keys {
let maybe_conflict = match &self.provider {
ProviderConfig::Anthropic {
api_key: Some(key), ..
} => api_keys.get("anthropic").filter(|mapped| *mapped != key),
ProviderConfig::OpenAI {
api_key: Some(key), ..
} => api_keys.get("openai").filter(|mapped| *mapped != key),
ProviderConfig::Gemini { api_key: Some(key) } => {
api_keys.get("gemini").filter(|mapped| *mapped != key)
}
_ => None,
};
if maybe_conflict.is_some() {
return Err(ConfigError::Validation(
"provider api_key conflicts with providers.api_keys entry".to_string(),
));
}
}
crate::model_registry::ModelRegistry::from_config(self)?;
Ok(())
}
}
/// Overrides collected from the command line; applied by `Config::apply_cli_overrides`.
#[derive(Debug, Clone, Default)]
pub struct CliOverrides {
    // Replaces `agent.model` when set.
    pub model: Option<String>,
    // Sets `budget.max_tokens` when set.
    pub max_tokens: Option<u64>,
    // Sets `budget.max_duration` when set.
    pub max_duration: Option<Duration>,
    // Sets `budget.max_tool_calls` when set.
    pub max_tool_calls: Option<usize>,
    // Applied as a JSON merge patch over the serialized config.
    pub override_config: Option<ConfigDelta>,
}
/// Serde default for `AgentConfig::structured_output_retries`.
fn default_structured_output_retries() -> u32 {
    const DEFAULT_RETRIES: u32 = 2;
    DEFAULT_RETRIES
}
/// Per-agent settings (`[agent]` section).
///
/// Missing keys fall back to `AgentConfig::default()`, which reads the model,
/// per-turn token limit and budget warning threshold from the bundled template.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AgentConfig {
    pub system_prompt: Option<String>,
    pub system_prompt_file: Option<PathBuf>,
    pub tool_instructions: Option<String>,
    pub model: String,
    // `Config::validate` rejects 0.
    pub max_tokens_per_turn: u32,
    pub temperature: Option<f32>,
    pub budget_warning_threshold: f32,
    pub max_turns: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider_params: Option<serde_json::Value>,
    // Never read from or written to config documents (`serde(skip)`), so it
    // does not survive a serialize/deserialize round-trip.
    #[serde(skip)]
    pub provider_tool_defaults: Option<serde_json::Value>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_schema: Option<OutputSchema>,
    #[serde(default = "default_structured_output_retries")]
    pub structured_output_retries: u32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub extraction_prompt: Option<String>,
}
impl Default for AgentConfig {
fn default() -> Self {
let defaults = template_defaults();
let agent = defaults.agent.as_ref();
Self {
system_prompt: None,
system_prompt_file: None,
tool_instructions: None,
model: agent.and_then(|cfg| cfg.model.clone()).unwrap_or_default(),
max_tokens_per_turn: agent
.and_then(|cfg| cfg.max_tokens_per_turn)
.unwrap_or_default(),
temperature: None,
budget_warning_threshold: agent
.and_then(|cfg| cfg.budget_warning_threshold)
.unwrap_or_default(),
max_turns: None,
provider_params: None,
provider_tool_defaults: None,
output_schema: None,
structured_output_retries: default_structured_output_retries(),
extraction_prompt: None,
}
}
}
/// Default model identifier per built-in provider (`[models]` section).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct ModelDefaults {
    pub anthropic: String,
    pub openai: String,
    pub gemini: String,
}
impl Default for ModelDefaults {
fn default() -> Self {
Self {
anthropic: meerkat_models::default_model("anthropic")
.unwrap_or_default()
.to_string(),
openai: meerkat_models::default_model("openai")
.unwrap_or_default()
.to_string(),
gemini: meerkat_models::default_model("gemini")
.unwrap_or_default()
.to_string(),
}
}
}
/// Shell program used when neither the template nor the config sets `shell.program`.
pub const DEFAULT_SHELL_PROGRAM: &str = "nu";
/// Fallback for `shell.timeout_secs`.
pub const DEFAULT_SHELL_TIMEOUT_SECS: u64 = 30;
/// Fallback for `shell.security_mode`.
pub const DEFAULT_SHELL_SECURITY_MODE: SecurityMode = SecurityMode::Unrestricted;
/// Shell tool settings (`[shell]` section).
///
/// Deserialization is hand-written (via `ShellDefaultsSeed`) so that a legacy
/// `allowlist` key is still accepted.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(default)]
pub struct ShellDefaults {
    pub program: String,
    pub timeout_secs: u64,
    pub security_mode: SecurityMode,
    pub security_patterns: Vec<String>,
}
/// Raw, all-optional view of `[shell]` used to build `ShellDefaults`.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct ShellDefaultsSeed {
    program: Option<String>,
    timeout_secs: Option<u64>,
    security_mode: Option<SecurityMode>,
    security_patterns: Option<Vec<String>>,
    // Legacy key. NOTE(review): the alias equals the field name and is
    // therefore redundant.
    #[serde(alias = "allowlist")]
    allowlist: Option<Vec<String>>,
}
impl<'de> Deserialize<'de> for ShellDefaults {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let seed = ShellDefaultsSeed::deserialize(deserializer)?;
let mut defaults = ShellDefaults::default();
if let Some(program) = seed.program {
defaults.program = program;
}
if let Some(timeout_secs) = seed.timeout_secs {
defaults.timeout_secs = timeout_secs;
}
if let Some(security_mode) = seed.security_mode {
defaults.security_mode = security_mode;
}
if let Some(security_patterns) = seed.security_patterns.or(seed.allowlist.clone()) {
defaults.security_patterns = security_patterns;
}
if seed.security_mode.is_none() && seed.allowlist.is_some() {
defaults.security_mode = SecurityMode::AllowList;
}
Ok(defaults)
}
}
impl Default for ShellDefaults {
fn default() -> Self {
let defaults = template_defaults();
let shell = defaults.shell.as_ref();
Self {
program: shell
.and_then(|cfg| cfg.program.clone())
.unwrap_or_else(|| DEFAULT_SHELL_PROGRAM.to_string()),
timeout_secs: shell
.and_then(|cfg| cfg.timeout_secs)
.unwrap_or(DEFAULT_SHELL_TIMEOUT_SECS),
security_mode: shell
.and_then(|cfg| cfg.security_mode)
.unwrap_or(DEFAULT_SHELL_SECURITY_MODE),
security_patterns: shell
.and_then(|cfg| cfg.security_patterns.clone())
.unwrap_or_default(),
}
}
}
/// Bundled config template, embedded at compile time. It also supplies several built-in defaults.
const CONFIG_TEMPLATE_TOML: &str = include_str!("config_template.toml");
/// `[agent]` keys read from the bundled template to seed `AgentConfig::default()`.
#[derive(Debug, Deserialize)]
struct TemplateAgentDefaults {
    model: Option<String>,
    max_tokens_per_turn: Option<u32>,
    budget_warning_threshold: Option<f32>,
}
/// `[shell]` keys read from the bundled template to seed `ShellDefaults::default()`.
#[derive(Debug, Deserialize)]
struct TemplateShellDefaults {
    program: Option<String>,
    timeout_secs: Option<u64>,
    security_mode: Option<SecurityMode>,
    security_patterns: Option<Vec<String>>,
}
/// The subset of the bundled template that feeds the `Default` impls.
#[derive(Debug, Deserialize)]
struct TemplateDefaults {
    agent: Option<TemplateAgentDefaults>,
    shell: Option<TemplateShellDefaults>,
    max_tokens: Option<u32>,
}
impl TemplateDefaults {
fn empty() -> Self {
Self {
agent: None,
shell: None,
max_tokens: None,
}
}
}
/// Parses the bundled template once and caches it for the process lifetime.
///
/// A parse failure is logged at error level and treated as "no template defaults".
fn template_defaults() -> &'static TemplateDefaults {
    static DEFAULTS: OnceLock<TemplateDefaults> = OnceLock::new();
    DEFAULTS.get_or_init(|| match toml::from_str::<TemplateDefaults>(CONFIG_TEMPLATE_TOML) {
        Ok(parsed) => parsed,
        Err(e) => {
            tracing::error!("Invalid config template defaults: {}", e);
            TemplateDefaults::empty()
        }
    })
}
/// Optional storage location overrides (`[store]` section); all unset by default.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct StoreConfig {
    pub sessions_path: Option<PathBuf>,
    pub tasks_path: Option<PathBuf>,
    pub database_dir: Option<PathBuf>,
}
/// Per-provider maps keyed by provider name (e.g. "anthropic", "openai", "gemini").
///
/// `Config::validate` rejects entries that disagree with an explicit value in `[provider]`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct ProviderSettings {
    pub base_urls: Option<HashMap<String, String>>,
    pub api_keys: Option<HashMap<String, String>>,
}
/// Wire protocol used to reach a self-hosted model server.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum SelfHostedTransport {
    // NOTE(review): `rename_all = "snake_case"` makes the canonical name
    // "open_ai_compatible"; the alias also accepts "openai_compatible".
    #[default]
    #[serde(alias = "openai_compatible")]
    OpenAiCompatible,
}
/// Which API shape a self-hosted server exposes; defaults to `ChatCompletions`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum SelfHostedApiStyle {
    Responses,
    #[default]
    ChatCompletions,
}
/// One `[self_hosted.servers.<id>]` entry.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(default)]
pub struct SelfHostedServerConfig {
    pub transport: SelfHostedTransport,
    pub base_url: String,
    pub api_style: SelfHostedApiStyle,
    // Accepted on input but never serialized, so secrets are not written back out.
    #[serde(default, skip_serializing)]
    pub bearer_token: Option<String>,
    // Name of an environment variable holding the token (see `resolve_bearer_token`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bearer_token_env: Option<String>,
}
impl Default for SelfHostedServerConfig {
fn default() -> Self {
Self {
transport: SelfHostedTransport::OpenAiCompatible,
base_url: String::new(),
api_style: SelfHostedApiStyle::ChatCompletions,
bearer_token: None,
bearer_token_env: None,
}
}
}
impl SelfHostedServerConfig {
    /// Returns the inline `bearer_token` if set.
    ///
    /// Otherwise returns the value of the environment variable named by
    /// `bearer_token_env`, or `None` when that is unset or unreadable.
    pub fn resolve_bearer_token(&self) -> Option<String> {
        if let Some(token) = &self.bearer_token {
            return Some(token.clone());
        }
        let env_key = self.bearer_token_env.as_deref()?;
        std::env::var(env_key).ok()
    }
}
/// One `[self_hosted.models.<id>]` entry describing a model served by a self-hosted server.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(default)]
pub struct SelfHostedModelConfig {
    // Id of an entry in `SelfHostedConfig::servers` — presumably; confirm in the registry.
    pub server: String,
    // Model name sent to the remote server.
    pub remote_model: String,
    pub display_name: String,
    pub family: String,
    pub tier: ModelTier,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_window: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
    // Capability flags; all false by default except `supports_temperature`.
    pub vision: bool,
    pub image_tool_results: bool,
    pub inline_video: bool,
    pub supports_temperature: bool,
    pub supports_thinking: bool,
    pub supports_reasoning: bool,
    #[serde(default)]
    pub supports_web_search: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub call_timeout_secs: Option<u64>,
}
impl Default for SelfHostedModelConfig {
fn default() -> Self {
Self {
server: String::new(),
remote_model: String::new(),
display_name: String::new(),
family: String::new(),
tier: ModelTier::Supported,
context_window: None,
max_output_tokens: None,
vision: false,
image_tool_results: false,
inline_video: false,
supports_temperature: true,
supports_thinking: false,
supports_reasoning: false,
supports_web_search: false,
call_timeout_secs: None,
}
}
}
/// `[self_hosted]` section: servers and models, each keyed by id (ordered maps).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, JsonSchema)]
#[serde(default)]
pub struct SelfHostedConfig {
    pub servers: BTreeMap<String, SelfHostedServerConfig>,
    pub models: BTreeMap<String, SelfHostedModelConfig>,
}
/// `[provider_tools]` section: provider-native tool switches, all enabled by default.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct ProviderToolsConfig {
    pub anthropic: AnthropicProviderToolsConfig,
    pub openai: OpenAiProviderToolsConfig,
    pub gemini: GeminiProviderToolsConfig,
}
/// `[provider_tools.anthropic]` switches.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct AnthropicProviderToolsConfig {
    pub web_search: bool,
}
impl Default for AnthropicProviderToolsConfig {
fn default() -> Self {
Self { web_search: true }
}
}
/// `[provider_tools.openai]` switches.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct OpenAiProviderToolsConfig {
    pub web_search: bool,
}
impl Default for OpenAiProviderToolsConfig {
fn default() -> Self {
Self { web_search: true }
}
}
/// `[provider_tools.gemini]` switches.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct GeminiProviderToolsConfig {
    pub google_search: bool,
}
impl Default for GeminiProviderToolsConfig {
fn default() -> Self {
Self {
google_search: true,
}
}
}
/// `[limits]` section; converted to `BudgetLimits` by `to_budget_limits`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct LimitsConfig {
    // Token budget; `Config::validate` rejects `Some(0)`.
    pub budget: Option<u64>,
    #[serde(with = "optional_duration_serde")]
    pub max_duration: Option<Duration>,
}
impl LimitsConfig {
pub fn to_budget_limits(&self) -> BudgetLimits {
BudgetLimits {
max_tokens: self.budget,
max_duration: self.max_duration,
max_tool_calls: None,
}
}
}
/// `[rest]` section: bind address for the REST server (default `127.0.0.1:8080`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct RestServerConfig {
    pub host: String,
    pub port: u16,
}
impl Default for RestServerConfig {
fn default() -> Self {
Self {
host: "127.0.0.1".to_string(),
port: 8080,
}
}
}
/// Authentication mode for the comms runtime.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "snake_case")]
pub enum CommsAuthMode {
    // No authentication; written as "none" in config documents.
    #[default]
    #[serde(rename = "none")]
    Open,
    Ed25519,
}
/// Origin channel of a plain event. Serde names match the `Display` output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PlainEventSource {
    Tcp,
    Uds,
    Stdin,
    Webhook,
    Rpc,
}
impl std::fmt::Display for PlainEventSource {
    /// Writes the lowercase source name, e.g. `tcp` or `webhook`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Tcp => "tcp",
            Self::Uds => "uds",
            Self::Stdin => "stdin",
            Self::Webhook => "webhook",
            Self::Rpc => "rpc",
        };
        f.write_str(label)
    }
}
/// `[comms]` section: transport mode, addresses and authentication for the comms runtime.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct CommsRuntimeConfig {
    pub mode: CommsRuntimeMode,
    pub address: Option<String>,
    pub auth: CommsAuthMode,
    // Defaults to true.
    pub require_peer_auth: bool,
    pub event_address: Option<String>,
}
impl Default for CommsRuntimeConfig {
fn default() -> Self {
Self {
mode: CommsRuntimeMode::Inproc,
address: None,
auth: CommsAuthMode::default(),
require_peer_auth: true,
event_address: None,
}
}
}
/// `[compaction]` section; every field must be > 0 (enforced by `Config::validate`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct CompactionRuntimeConfig {
    pub auto_compact_threshold: u64,
    pub recent_turn_budget: usize,
    pub max_summary_tokens: u32,
    pub min_turns_between_compactions: u32,
}
impl Default for CompactionRuntimeConfig {
fn default() -> Self {
Self {
auto_compact_threshold: 100_000,
recent_turn_budget: 4,
max_summary_tokens: 4096,
min_turns_between_compactions: 3,
}
}
}
impl From<CompactionRuntimeConfig> for crate::CompactionConfig {
fn from(value: CompactionRuntimeConfig) -> Self {
Self {
auto_compact_threshold: value.auto_compact_threshold,
recent_turn_budget: value.recent_turn_budget,
max_summary_tokens: value.max_summary_tokens,
min_turns_between_compactions: value.min_turns_between_compactions,
}
}
}
/// Transport used by the comms runtime; defaults to in-process.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum CommsRuntimeMode {
    #[default]
    Inproc,
    Tcp,
    Uds,
}
/// Active LLM provider, internally tagged by `type`
/// (`"anthropic"`, `"openai"` or `"gemini"`).
///
/// A missing `api_key` may be filled from the environment by `Config::apply_env_overrides_from`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProviderConfig {
    Anthropic {
        api_key: Option<String>,
        base_url: Option<String>,
    },
    #[serde(rename = "openai")]
    OpenAI {
        api_key: Option<String>,
        base_url: Option<String>,
    },
    Gemini {
        api_key: Option<String>,
    },
}
impl Default for ProviderConfig {
fn default() -> Self {
Self::Anthropic {
api_key: None,
base_url: None,
}
}
}
/// `[storage]` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct StorageConfig {
    // Defaults to `<data_dir>/sessions` when a data directory is available.
    pub directory: Option<PathBuf>,
}
impl Default for StorageConfig {
fn default() -> Self {
Self {
directory: data_dir().map(|d| d.join("sessions")),
}
}
}
/// Optional budget limits; every limit is unset (`None`) by default.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct BudgetConfig {
pub max_tokens: Option<u64>,
/// Accepts a humantime string (e.g. `"5m"`) or an integer number of milliseconds;
/// always serialized back as a humantime string.
#[serde(with = "optional_duration_serde")]
pub max_duration: Option<Duration>,
pub max_tool_calls: Option<usize>,
}
/// Config-level override for the retry policy's per-call timeout.
///
/// Serialized form: `Inherit` is none (and is skipped entirely inside `RetryConfig`),
/// `Disabled` is the string `"disabled"`, and `Value` is a humantime duration string
/// such as `"45s"`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum CallTimeoutOverride {
/// No override was configured (the default).
#[default]
Inherit,
/// The call timeout is explicitly turned off.
Disabled,
/// Use this call timeout.
Value(Duration),
}
impl Serialize for CallTimeoutOverride {
    /// `Inherit` becomes none, `Disabled` becomes `"disabled"`, and `Value(d)` becomes
    /// a humantime string (e.g. `"45s"`).
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let text = match self {
            Self::Inherit => return serializer.serialize_none(),
            Self::Disabled => "disabled".to_owned(),
            Self::Value(d) => humantime_serde::re::humantime::format_duration(*d).to_string(),
        };
        serializer.serialize_str(&text)
    }
}
impl<'de> Deserialize<'de> for CallTimeoutOverride {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
if s == "disabled" {
return Ok(Self::Disabled);
}
let d: Duration = s
.parse::<humantime_serde::re::humantime::Duration>()
.map(|ht| *ht)
.map_err(serde::de::Error::custom)?;
Ok(Self::Value(d))
}
}
impl CallTimeoutOverride {
    /// Returns `true` only for [`CallTimeoutOverride::Inherit`].
    ///
    /// `RetryConfig` uses this as its `skip_serializing_if` predicate so an un-overridden
    /// call timeout is omitted from serialized config.
    pub fn is_inherit(&self) -> bool {
        *self == Self::Inherit
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RetryConfig {
pub max_retries: u32,
#[serde(with = "humantime_serde")]
pub initial_delay: Duration,
#[serde(with = "humantime_serde")]
pub max_delay: Duration,
pub multiplier: f64,
#[serde(
default,
rename = "call_timeout",
skip_serializing_if = "CallTimeoutOverride::is_inherit"
)]
pub call_timeout_override: CallTimeoutOverride,
}
impl Default for RetryConfig {
fn default() -> Self {
let policy = RetryPolicy::default();
Self {
max_retries: policy.max_retries,
initial_delay: policy.initial_delay,
max_delay: policy.max_delay,
multiplier: policy.multiplier,
call_timeout_override: CallTimeoutOverride::default(),
}
}
}
impl From<RetryConfig> for RetryPolicy {
fn from(config: RetryConfig) -> Self {
let call_timeout = match config.call_timeout_override {
CallTimeoutOverride::Inherit => None,
CallTimeoutOverride::Disabled => None,
CallTimeoutOverride::Value(d) => Some(d),
};
RetryPolicy {
max_retries: config.max_retries,
initial_delay: config.initial_delay,
max_delay: config.max_delay,
multiplier: config.multiplier,
call_timeout,
}
}
}
/// Tool settings: MCP servers, timeouts, concurrency, and which built-in tool groups
/// are enabled (all groups disabled by default).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ToolsConfig {
#[serde(default)]
pub mcp_servers: Vec<McpServerConfig>,
/// Humantime string; defaults to 600 seconds.
#[serde(with = "humantime_serde")]
pub default_timeout: Duration,
// NOTE(review): unlike `default_timeout`, these durations are not `humantime_serde`,
// so they use serde's default `Duration` form ({ secs, nanos }) rather than strings
// like "30s". Confirm this is intended before relying on it; changing it alters the
// accepted config format.
#[serde(default)]
pub tool_timeouts: HashMap<String, Duration>,
pub max_concurrent: usize,
pub builtins_enabled: bool,
pub shell_enabled: bool,
pub comms_enabled: bool,
pub mob_enabled: bool,
}
/// 600 s default timeout, up to 10 concurrent tools, no MCP servers or per-tool
/// timeouts, and every built-in tool group disabled.
impl Default for ToolsConfig {
fn default() -> Self {
Self {
mcp_servers: Vec::new(),
default_timeout: Duration::from_secs(600),
tool_timeouts: HashMap::new(),
max_concurrent: 10,
builtins_enabled: false,
shell_enabled: false,
comms_enabled: false,
mob_enabled: false,
}
}
}
/// Global hook settings plus the list of configured hook entries.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(default)]
pub struct HooksConfig {
    pub default_timeout_ms: u64,
    pub payload_max_bytes: usize,
    pub background_max_concurrency: usize,
    #[serde(default)]
    pub entries: Vec<HookEntryConfig>,
}

impl HooksConfig {
    /// Appends clones of `other`'s entries after the existing entries.
    ///
    /// Only `entries` is touched; the scalar settings of `self` are left as they are.
    pub fn append_entries_from(&mut self, other: &HooksConfig) {
        self.entries.extend_from_slice(&other.entries);
    }
}

impl Default for HooksConfig {
    /// 5000 ms default timeout, 128 KiB payload cap, background concurrency of 32,
    /// and no entries.
    fn default() -> Self {
        const KIB: usize = 1024;
        Self {
            entries: Vec::new(),
            background_max_concurrency: 32,
            payload_max_bytes: 128 * KIB,
            default_timeout_ms: 5_000,
        }
    }
}
/// Hook overrides for a single run: additional `entries` and hook ids to `disable`.
///
/// Both lists default to empty. NOTE(review): how these are applied on top of
/// `HooksConfig` is decided by the caller, which is not visible here — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Default)]
#[serde(default)]
pub struct HookRunOverrides {
#[serde(default)]
pub entries: Vec<HookEntryConfig>,
#[serde(default)]
pub disable: Vec<HookId>,
}
/// A single hook entry in the hooks configuration.
///
/// Omitted fields take the values from [`HookEntryConfig::default`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(default)]
pub struct HookEntryConfig {
    pub id: HookId,
    pub enabled: bool,
    pub point: HookPoint,
    pub mode: HookExecutionMode,
    pub capability: HookCapability,
    pub priority: i32,
    /// Explicit failure policy. When absent, `effective_failure_policy` falls back to
    /// the default for this entry's capability.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub failure_policy: Option<HookFailurePolicy>,
    /// Per-entry timeout in milliseconds (presumably overriding
    /// `HooksConfig::default_timeout_ms` — confirm at the call site).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
    pub runtime: HookRuntimeConfig,
}

impl HookEntryConfig {
    /// Returns the configured failure policy, or the capability's default when unset.
    pub fn effective_failure_policy(&self) -> HookFailurePolicy {
        self.failure_policy
            .unwrap_or_else(|| crate::hooks::default_failure_policy(self.capability))
    }
}

impl Default for HookEntryConfig {
    /// An enabled, foreground, observe-only entry with id `"hook"` at the turn boundary,
    /// priority 100, and the default (no-op, in-process) runtime.
    fn default() -> Self {
        Self {
            id: HookId::new("hook"),
            enabled: true,
            point: HookPoint::TurnBoundary,
            mode: HookExecutionMode::Foreground,
            capability: HookCapability::Observe,
            priority: 100,
            failure_policy: None,
            timeout_ms: None,
            // Reuse the runtime's own default instead of duplicating its construction
            // and fallback here, so the two defaults cannot drift apart.
            runtime: HookRuntimeConfig::default(),
        }
    }
}
#[derive(Debug, Clone)]
pub struct HookRuntimeConfig {
pub kind: String,
#[allow(clippy::box_collection)]
pub config: Option<Box<RawValue>>,
}
impl PartialEq for HookRuntimeConfig {
fn eq(&self, other: &Self) -> bool {
self.kind == other.kind
&& self.config.as_ref().map(|raw| raw.get())
== other.config.as_ref().map(|raw| raw.get())
}
}
impl HookRuntimeConfig {
pub fn new(kind: impl Into<String>, config: Option<Value>) -> Result<Self, serde_json::Error> {
let config = match config {
Some(value) => Some(raw_json_from_value(value)?),
None => None,
};
Ok(Self {
kind: kind.into(),
config,
})
}
pub fn config_value(&self) -> Result<Value, serde_json::Error> {
match &self.config {
Some(raw) => serde_json::from_str(raw.get()),
None => Ok(Value::Null),
}
}
}
impl Default for HookRuntimeConfig {
fn default() -> Self {
Self::new("in_process", Some(serde_json::json!({"name":"noop"}))).unwrap_or_else(|_| Self {
kind: "in_process".to_string(),
config: None,
})
}
}
impl Serialize for HookRuntimeConfig {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = Map::new();
map.insert("type".to_string(), Value::String(self.kind.clone()));
if let Some(raw) = &self.config {
let parsed: Value =
serde_json::from_str(raw.get()).map_err(serde::ser::Error::custom)?;
match parsed {
Value::Object(obj) => {
for (key, value) in obj {
map.insert(key, value);
}
}
other => {
map.insert("config".to_string(), other);
}
}
}
Value::Object(map).serialize(serializer)
}
}
impl<'de> Deserialize<'de> for HookRuntimeConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
let mut obj = value
.as_object()
.cloned()
.ok_or_else(|| serde::de::Error::custom("hook runtime must be an object"))?;
let kind = obj
.remove("type")
.and_then(|value| value.as_str().map(ToOwned::to_owned))
.ok_or_else(|| {
serde::de::Error::custom("hook runtime missing required field 'type'")
})?;
let config_value = if let Some(explicit) = obj.remove("config") {
if obj.is_empty() {
explicit
} else {
obj.insert("config".to_string(), explicit);
Value::Object(obj)
}
} else if obj.is_empty() {
Value::Null
} else {
Value::Object(obj)
};
let config = if config_value.is_null() {
None
} else {
Some(raw_json_from_value(config_value).map_err(serde::de::Error::custom)?)
};
Ok(Self { kind, config })
}
}
/// Hand-written schema matching the custom serde impls.
///
/// It requires a string `type`, allows a `config` of any shape, and permits additional
/// properties because object configs may be flattened beside `type`.
impl JsonSchema for HookRuntimeConfig {
fn schema_name() -> std::borrow::Cow<'static, str> {
"HookRuntimeConfig".into()
}
fn json_schema(_gen: &mut schemars::SchemaGenerator) -> schemars::Schema {
schemars::json_schema!({
"type": "object",
"required": ["type"],
"properties": {
"type": { "type": "string" },
"config": {}
},
"additionalProperties": true
})
}
}
/// Encodes `value` as compact JSON text wrapped in a [`RawValue`].
///
/// # Errors
/// Propagates any serialization or raw-value construction error from `serde_json`.
fn raw_json_from_value(value: Value) -> Result<Box<RawValue>, serde_json::Error> {
    serde_json::to_string(&value).and_then(RawValue::from_string)
}
/// Config scope, serialized as `"global"` or `"project"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ConfigScope {
Global,
Project,
}
/// Arbitrary JSON value, (de)serialized transparently as the inner value.
///
/// Presumably a partial config patch — confirm against its consumers.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct ConfigDelta(pub serde_json::Value);
/// Errors produced by configuration I/O, TOML/JSON (de)serialization, and validation.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// TOML parse failure.
#[error("Parse error: {0}")]
Parse(#[from] toml::de::Error),
#[error("TOML serialization error: {0}")]
TomlSerialize(#[from] toml::ser::Error),
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
#[error("UTF-8 error: {0}")]
Utf8(#[from] std::string::FromUtf8Error),
// NOTE(review): dead_code is allowed here, presumably because nothing constructs this
// variant yet — confirm, then either use it or remove it.
#[allow(dead_code)]
#[error("Invalid value for {0}")]
InvalidValue(String),
#[error("Missing required field: {0}")]
MissingField(String),
#[error("Missing API key: {0}")]
MissingApiKey(&'static str),
#[error("Internal error: {0}")]
InternalError(String),
/// Returned by config validation (e.g. zero token limits, conflicting provider settings).
#[error("Validation error: {0}")]
Validation(String),
}
mod optional_duration_serde {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::time::Duration;
#[allow(clippy::ref_option)]
pub fn serialize<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match duration {
Some(d) => {
let s = humantime_serde::re::humantime::format_duration(*d).to_string();
s.serialize(serializer)
}
None => serializer.serialize_none(),
}
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error;
let value: Option<serde_json::Value> = Option::deserialize(deserializer)?;
match value {
None => Ok(None),
Some(serde_json::Value::String(s)) => {
humantime_serde::re::humantime::parse_duration(&s)
.map(Some)
.map_err(|e| D::Error::custom(e.to_string()))
}
Some(serde_json::Value::Number(n)) => {
let millis = n
.as_u64()
.ok_or_else(|| D::Error::custom("invalid number"))?;
Ok(Some(Duration::from_millis(millis)))
}
_ => Err(D::Error::custom("expected string or number for duration")),
}
}
}
/// Walks up from `start_dir` (inclusive) and returns the first directory containing a
/// `.rkat` subdirectory, or `None` if no ancestor has one.
pub fn find_project_root(start_dir: &std::path::Path) -> Option<PathBuf> {
    start_dir
        .ancestors()
        .find(|candidate| candidate.join(".rkat").is_dir())
        .map(std::path::Path::to_path_buf)
}
/// Returns the `.rkat` data directory.
///
/// Uses the nearest project root found by walking up from the current directory;
/// otherwise falls back to the home directory. Returns `None` when neither can be
/// determined.
pub fn data_dir() -> Option<PathBuf> {
    std::env::current_dir()
        .ok()
        .and_then(|cwd| find_project_root(&cwd))
        .or_else(dirs::home_dir)
        .map(|base| base.join(".rkat"))
}
/// Minimal home-directory lookup used by `data_dir`.
pub mod dirs {
    use std::path::PathBuf;

    /// Returns the user's home directory from `HOME`, falling back to `USERPROFILE`
    /// on Windows (where `HOME` is usually unset).
    ///
    /// Unset or empty variables are treated as absent. This keeps an empty `HOME` from
    /// resolving to `""`, which would silently make the data dir a cwd-relative `.rkat`.
    pub fn home_dir() -> Option<PathBuf> {
        let non_empty = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());
        let home = non_empty("HOME");
        #[cfg(windows)]
        let home = home.or_else(|| non_empty("USERPROFILE"));
        home.map(PathBuf::from)
    }
}
// Unit tests for config defaults, layered merging (env / TOML / CLI), validation, and
// the serde representations of the types defined in this file.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
use crate::Provider;
// Baseline defaults for the agent model, per-turn token limit, and retry count.
#[test]
fn test_config_default() {
let config = Config::default();
assert_eq!(config.agent.model, "claude-opus-4-6");
assert_eq!(config.agent.max_tokens_per_turn, 16384);
assert_eq!(config.retry.max_retries, 3);
}
// Exercises the env-override, file-merge, and CLI-override layers in turn.
// NOTE(review): RKAT_MODEL is supplied to the env layer but only the API key is asserted.
#[test]
fn test_config_layering() {
let config = Config::default();
assert_eq!(config.agent.model, "claude-opus-4-6");
assert_eq!(config.budget.max_tokens, None);
{
let env = std::collections::HashMap::from([
("RKAT_MODEL".to_string(), "env-model".to_string()),
("ANTHROPIC_API_KEY".to_string(), "secret-key".to_string()),
]);
let mut config = Config::default();
config
.apply_env_overrides_from(|key| env.get(key).cloned())
.expect("apply env overrides");
match config.provider {
ProviderConfig::Anthropic { api_key, .. } => {
assert_eq!(api_key.as_deref(), Some("secret-key"));
}
_ => unreachable!("expected anthropic provider"),
}
}
let mut config = Config::default();
let file_config = Config {
agent: AgentConfig {
model: "file-model".to_string(),
..Default::default()
},
..Default::default()
};
config.merge(file_config);
assert_eq!(config.agent.model, "file-model");
let mut config = Config::default();
config.apply_cli_overrides(CliOverrides {
model: Some("cli-model".to_string()),
max_tokens: Some(50000),
..Default::default()
});
assert_eq!(config.agent.model, "cli-model");
assert_eq!(config.budget.max_tokens, Some(50000));
}
// A later TOML layer that omits extraction_prompt must keep the earlier value.
#[test]
fn test_merge_extraction_prompt_survives_layering() {
let mut base = Config::default();
assert!(base.agent.extraction_prompt.is_none());
let toml = r#"
[agent]
extraction_prompt = "Return JSON only."
"#;
base.merge_toml_str(toml).expect("merge toml");
assert_eq!(
base.agent.extraction_prompt.as_deref(),
Some("Return JSON only.")
);
let toml2 = r#"
[agent]
model = "custom-model"
"#;
base.merge_toml_str(toml2).expect("merge toml2");
assert_eq!(
base.agent.extraction_prompt.as_deref(),
Some("Return JSON only."),
"extraction_prompt must survive merge when absent in later layer"
);
assert_eq!(base.agent.model, "custom-model");
}
// Merging appends hook entries in layer order rather than replacing them.
#[test]
fn test_merge_hooks_entries_append() {
let mut base = Config::default();
let base_entry = HookEntryConfig {
id: HookId::new("base"),
..HookEntryConfig::default()
};
base.hooks.entries.push(base_entry);
let mut other = Config::default();
let other_entry = HookEntryConfig {
id: HookId::new("other"),
..HookEntryConfig::default()
};
other.hooks.entries.push(other_entry);
base.merge(other);
let ids = base
.hooks
.entries
.iter()
.map(|entry| entry.id.0.as_str())
.collect::<Vec<_>>();
assert_eq!(ids, vec!["base", "other"]);
}
// Self-hosted servers from one layer and models from a later layer combine into a valid registry.
#[test]
fn test_merge_self_hosted_preserves_lower_layer_servers_and_models() {
let mut base = Config::default();
base.merge_toml_str(
r#"
[self_hosted.servers.local]
base_url = "http://127.0.0.1:11434"
"#,
)
.expect("base self-hosted server");
base.merge_toml_str(
r#"
[self_hosted.models.gemma-4-e2b]
server = "local"
remote_model = "gemma4:e2b"
display_name = "Gemma 4 E2B"
family = "gemma-4"
"#,
)
.expect("overlay self-hosted model");
assert!(base.self_hosted.servers.contains_key("local"));
assert!(base.self_hosted.models.contains_key("gemma-4-e2b"));
let registry = base.model_registry().expect("merged self-hosted registry");
assert_eq!(
registry
.entry("gemma-4-e2b")
.and_then(|entry| entry.self_hosted.as_ref())
.map(|server| server.server_id.as_str()),
Some("local")
);
}
// A partial overlay of one server merges field-by-field, keeping fields it omits.
#[test]
fn test_merge_self_hosted_partial_server_override_preserves_existing_fields() {
let mut config = Config::default();
config
.merge_toml_str(
r#"
[self_hosted.servers.local]
base_url = "http://127.0.0.1:11434"
api_style = "responses"
"#,
)
.expect("base server");
config
.merge_toml_str(
r#"
[self_hosted.servers.local]
bearer_token_env = "OLLAMA_TOKEN"
"#,
)
.expect("overlay server");
let server = config
.self_hosted
.servers
.get("local")
.expect("merged server");
assert_eq!(server.base_url, "http://127.0.0.1:11434");
assert_eq!(server.api_style, SelfHostedApiStyle::Responses);
assert_eq!(server.bearer_token_env.as_deref(), Some("OLLAMA_TOKEN"));
}
// Overlaying one server leaves unrelated inherited servers and models intact.
#[test]
fn test_merge_self_hosted_partial_override_preserves_unrelated_inherited_entries() {
let mut config = Config::default();
config
.merge_toml_str(
r#"
[self_hosted.servers.local]
base_url = "http://127.0.0.1:11434"
[self_hosted.servers.backup]
base_url = "http://127.0.0.1:11435"
[self_hosted.models.gemma-4-e2b]
server = "local"
remote_model = "gemma4:e2b"
display_name = "Gemma 4 E2B"
family = "gemma-4"
[self_hosted.models.gemma-4-e4b]
server = "backup"
remote_model = "gemma4:e4b"
display_name = "Gemma 4 E4B"
family = "gemma-4"
"#,
)
.expect("base self-hosted config");
config
.merge_toml_str(
r#"
[self_hosted.servers.local]
bearer_token_env = "OLLAMA_TOKEN"
"#,
)
.expect("overlay self-hosted config");
assert!(config.self_hosted.servers.contains_key("backup"));
assert!(config.self_hosted.models.contains_key("gemma-4-e4b"));
let registry = config
.model_registry()
.expect("registry should remain valid");
assert_eq!(
registry
.entry("gemma-4-e4b")
.and_then(|entry| entry.self_hosted.as_ref())
.map(|server| server.server_id.as_str()),
Some("backup")
);
}
// Explicitly empty self_hosted tables clear the inherited servers and models.
#[test]
fn test_merge_self_hosted_empty_table_clears_inherited_entries() {
let mut config = Config::default();
config
.merge_toml_str(
r#"
[self_hosted.servers.local]
base_url = "http://127.0.0.1:11434"
[self_hosted.models.gemma-4-e2b]
server = "local"
remote_model = "gemma4:e2b"
display_name = "Gemma 4 E2B"
family = "gemma-4"
"#,
)
.expect("base self-hosted config");
config
.merge_toml_str(
r"
[self_hosted.servers]
[self_hosted.models]
",
)
.expect("clear self-hosted config");
assert!(config.self_hosted.servers.is_empty());
assert!(config.self_hosted.models.is_empty());
}
// Literal bearer tokens are accepted on input but never written back out.
#[test]
fn test_self_hosted_bearer_token_is_not_serialized() {
let config: Config = toml::from_str(
r#"
[self_hosted.servers.local]
base_url = "http://127.0.0.1:11434"
bearer_token = "secret-token"
"#,
)
.expect("config");
let value = serde_json::to_value(&config).expect("serialize config");
let server = &value["self_hosted"]["servers"]["local"];
assert!(
server.get("bearer_token").is_none(),
"literal bearer tokens must be redacted from serialized config"
);
}
// providers.base_urls from a later layer replaces the whole map (no per-key merge).
#[test]
fn test_merge_providers_section_replaces_non_default() {
let mut base = Config::default();
base.providers.base_urls = Some(HashMap::from([
("anthropic".to_string(), "https://a.example".to_string()),
("openai".to_string(), "https://o.example".to_string()),
]));
let mut other = Config::default();
other.providers.base_urls = Some(HashMap::from([(
"openai".to_string(),
"https://override.example".to_string(),
)]));
base.merge(other);
let urls = base
.providers
.base_urls
.expect("providers.base_urls missing");
assert_eq!(urls.len(), 1);
assert_eq!(
urls.get("openai").map(String::as_str),
Some("https://override.example")
);
assert!(!urls.contains_key("anthropic"));
}
// [tools] fields omitted from a TOML layer keep their lower-layer values.
#[test]
fn test_merge_toml_tools_omitted_fields_preserve_lower_layer() {
let mut config = Config::default();
config.tools.mob_enabled = true;
config.tools.shell_enabled = true;
config
.merge_toml_str(
r"
[tools]
shell_enabled = false
",
)
.expect("merge should succeed");
assert!(config.tools.mob_enabled);
assert!(!config.tools.shell_enabled);
}
// An explicitly written default value still overrides the lower layer.
#[test]
fn test_merge_toml_tools_explicit_default_overrides_lower_layer() {
let mut config = Config::default();
config.tools.mob_enabled = true;
config
.merge_toml_str(
r"
[tools]
mob_enabled = false
",
)
.expect("merge should succeed");
assert!(!config.tools.mob_enabled);
}
// [retry] fields omitted from a TOML layer keep their lower-layer values.
#[test]
fn test_merge_toml_retry_omitted_fields_preserve_lower_layer() {
let mut config = Config::default();
config.retry.max_retries = 9;
config
.merge_toml_str(
r#"
[retry]
initial_delay = "750ms"
"#,
)
.expect("merge should succeed");
assert_eq!(config.retry.max_retries, 9);
assert_eq!(config.retry.initial_delay, Duration::from_millis(750));
}
// Validation rejects compaction.min_turns_between_compactions = 0.
#[test]
fn test_validate_rejects_zero_min_turns_between_compactions() {
let config = Config {
compaction: CompactionRuntimeConfig {
min_turns_between_compactions: 0,
..CompactionRuntimeConfig::default()
},
..Config::default()
};
let err = config
.validate()
.expect_err("min_turns_between_compactions=0 should be invalid");
assert!(
err.to_string()
.contains("compaction.min_turns_between_compactions")
);
}
// ProviderConfig is tagged by `type` with the expected provider names.
#[test]
fn test_provider_config_serialization() {
let anthropic = ProviderConfig::Anthropic {
api_key: Some("sk-test".to_string()),
base_url: None,
};
let json = serde_json::to_value(&anthropic).unwrap();
assert_eq!(json["type"], "anthropic");
assert_eq!(json["api_key"], "sk-test");
let openai = ProviderConfig::OpenAI {
api_key: Some("sk-openai".to_string()),
base_url: Some("https://custom.openai.com".to_string()),
};
let json = serde_json::to_value(&openai).unwrap();
assert_eq!(json["type"], "openai");
let gemini = ProviderConfig::Gemini {
api_key: Some("gemini-key".to_string()),
};
let json = serde_json::to_value(&gemini).unwrap();
assert_eq!(json["type"], "gemini");
}
// BudgetConfig round-trips through JSON, including the humantime-encoded duration.
#[test]
fn test_budget_config_serialization() {
let budget = BudgetConfig {
max_tokens: Some(100_000),
max_duration: Some(Duration::from_secs(300)),
max_tool_calls: Some(50),
};
let json = serde_json::to_string(&budget).unwrap();
let parsed: BudgetConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.max_tokens, Some(100_000));
assert_eq!(parsed.max_duration, Some(Duration::from_secs(300)));
assert_eq!(parsed.max_tool_calls, Some(50));
}
// The "openai_compatible" transport spelling is accepted for self-hosted servers.
#[test]
fn test_self_hosted_transport_accepts_openai_compatible_alias() {
let mut config = Config::default();
config
.merge_toml_str(
r#"
[self_hosted.servers.ollama]
transport = "openai_compatible"
base_url = "http://127.0.0.1:11434"
api_style = "chat_completions"
"#,
)
.expect("alias should parse");
assert_eq!(
config
.self_hosted
.servers
.get("ollama")
.expect("server should exist")
.transport,
SelfHostedTransport::OpenAiCompatible
);
}
#[test]
fn test_self_hosted_server_config_defaults_to_chat_completions() {
assert_eq!(
SelfHostedServerConfig::default().api_style,
SelfHostedApiStyle::ChatCompletions
);
}
// The default RetryConfig converts to a policy with 3 retries and a 500ms initial delay.
#[test]
fn test_retry_config_to_policy() {
let config = RetryConfig::default();
let policy: RetryPolicy = config.into();
assert_eq!(policy.max_retries, 3);
assert_eq!(policy.initial_delay, Duration::from_millis(500));
}
// Regression: loading must succeed when `.rkat/` exists but has no config.toml.
#[tokio::test]
async fn test_regression_load_succeeds_without_config_toml() {
use tempfile::TempDir;
let temp_dir = TempDir::new().unwrap();
let rkat_dir = temp_dir.path().join(".rkat");
std::fs::create_dir(&rkat_dir).unwrap();
assert!(rkat_dir.exists());
assert!(!rkat_dir.join("config.toml").exists());
let result =
Config::load_from_with_env(temp_dir.path(), Some(temp_dir.path()), |_| None).await;
assert!(
result.is_ok(),
"Config::load() should succeed when .rkat/ exists without config.toml: {:?}",
result.err()
);
}
// Validation rejects zero token limits and conflicting provider settings.
#[test]
fn test_validate_rejects_zero_max_tokens() {
let config = Config {
max_tokens: 0,
..Config::default()
};
let err = config
.validate()
.expect_err("max_tokens=0 should be invalid");
assert!(
err.to_string()
.contains("max_tokens must be greater than 0")
);
}
#[test]
fn test_validate_rejects_zero_agent_max_tokens_per_turn() {
let mut config = Config::default();
config.agent.max_tokens_per_turn = 0;
let err = config
.validate()
.expect_err("agent.max_tokens_per_turn=0 should be invalid");
assert!(err.to_string().contains("agent.max_tokens_per_turn"));
}
#[test]
fn test_validate_rejects_provider_base_url_conflict() {
let mut config = Config {
provider: ProviderConfig::OpenAI {
api_key: None,
base_url: Some("https://one.example".to_string()),
},
..Config::default()
};
config.providers.base_urls = Some(std::collections::HashMap::from([(
"openai".to_string(),
"https://two.example".to_string(),
)]));
let err = config
.validate()
.expect_err("conflicting base_url settings should be invalid");
assert!(err.to_string().contains("provider base_url conflicts"));
}
#[test]
fn test_validate_rejects_provider_api_key_conflict() {
let mut config = Config {
provider: ProviderConfig::Gemini {
api_key: Some("one".to_string()),
},
..Config::default()
};
config.providers.api_keys = Some(std::collections::HashMap::from([(
"gemini".to_string(),
"two".to_string(),
)]));
let err = config
.validate()
.expect_err("conflicting api_key settings should be invalid");
assert!(err.to_string().contains("provider api_key conflicts"));
}
// Provider name parsing is exact-match only; model-name inference uses known prefixes.
#[test]
fn test_provider_parse_strict() {
assert_eq!(
Provider::parse_strict("anthropic"),
Some(Provider::Anthropic)
);
assert_eq!(Provider::parse_strict("openai"), Some(Provider::OpenAI));
assert_eq!(Provider::parse_strict("gemini"), Some(Provider::Gemini));
assert_eq!(Provider::parse_strict("other"), None);
assert_eq!(Provider::parse_strict("claude"), None);
assert_eq!(Provider::parse_strict(""), None);
}
#[test]
fn test_provider_infer_from_model() {
assert_eq!(
Provider::infer_from_model("claude-opus-4-6"),
Some(Provider::Anthropic)
);
assert_eq!(
Provider::infer_from_model("gpt-5.2"),
Some(Provider::OpenAI)
);
assert_eq!(
Provider::infer_from_model("gemini-3-flash-preview"),
Some(Provider::Gemini)
);
assert_eq!(Provider::infer_from_model("llama-3"), None);
assert_eq!(Provider::infer_from_model(""), None);
}
// Comms auth defaults to Open, which is serialized as "none".
#[test]
fn test_comms_auth_mode_default_is_open() {
assert_eq!(CommsAuthMode::default(), CommsAuthMode::Open);
}
#[test]
fn test_comms_auth_mode_serde_roundtrip() {
let json = serde_json::to_string(&CommsAuthMode::Open).unwrap();
assert_eq!(json, r#""none""#);
let parsed: CommsAuthMode = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, CommsAuthMode::Open);
let json = serde_json::to_string(&CommsAuthMode::Ed25519).unwrap();
assert_eq!(json, r#""ed25519""#);
let parsed: CommsAuthMode = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, CommsAuthMode::Ed25519);
}
#[test]
fn test_comms_auth_mode_toml_roundtrip() {
let config = CommsRuntimeConfig::default();
let toml_str = toml::to_string(&config).unwrap();
let parsed: CommsRuntimeConfig = toml::from_str(&toml_str).unwrap();
assert_eq!(parsed.auth, CommsAuthMode::Open);
assert!(parsed.require_peer_auth);
let toml_str = r#"
mode = "inproc"
auth = "ed25519"
"#;
let parsed: CommsRuntimeConfig = toml::from_str(toml_str).unwrap();
assert_eq!(parsed.auth, CommsAuthMode::Ed25519);
assert!(parsed.require_peer_auth);
}
#[test]
fn test_comms_runtime_config_default_has_open_auth() {
let config = CommsRuntimeConfig::default();
assert_eq!(config.auth, CommsAuthMode::Open);
assert!(config.require_peer_auth);
}
// PlainEventSource serde names and Display strings agree.
#[test]
fn test_plain_event_source_serde_roundtrip() {
let cases = [
(PlainEventSource::Tcp, r#""tcp""#),
(PlainEventSource::Uds, r#""uds""#),
(PlainEventSource::Stdin, r#""stdin""#),
(PlainEventSource::Webhook, r#""webhook""#),
(PlainEventSource::Rpc, r#""rpc""#),
];
for (variant, expected_json) in cases {
let json = serde_json::to_string(&variant).unwrap();
assert_eq!(json, expected_json, "serialize {variant:?}");
let parsed: PlainEventSource = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, variant, "deserialize {variant:?}");
}
}
#[test]
fn test_plain_event_source_display() {
assert_eq!(PlainEventSource::Tcp.to_string(), "tcp");
assert_eq!(PlainEventSource::Uds.to_string(), "uds");
assert_eq!(PlainEventSource::Stdin.to_string(), "stdin");
assert_eq!(PlainEventSource::Webhook.to_string(), "webhook");
assert_eq!(PlainEventSource::Rpc.to_string(), "rpc");
}
// comms.event_address parses from TOML and defaults to None.
#[test]
fn test_comms_config_event_address_toml_roundtrip() {
let toml_str = r#"
mode = "tcp"
address = "127.0.0.1:4200"
auth = "none"
require_peer_auth = false
event_address = "127.0.0.1:4201"
"#;
let parsed: CommsRuntimeConfig = toml::from_str(toml_str).unwrap();
assert_eq!(parsed.event_address.as_deref(), Some("127.0.0.1:4201"));
assert_eq!(parsed.auth, CommsAuthMode::Open);
assert!(!parsed.require_peer_auth);
}
#[test]
fn test_comms_config_event_address_defaults_none() {
let config = CommsRuntimeConfig::default();
assert!(config.event_address.is_none());
}
// CallTimeoutOverride: default, is_inherit, and TOML string parsing.
#[test]
fn call_timeout_override_default_is_inherit() {
assert_eq!(CallTimeoutOverride::default(), CallTimeoutOverride::Inherit);
assert!(CallTimeoutOverride::default().is_inherit());
}
#[test]
fn call_timeout_override_disabled_is_not_inherit() {
assert!(!CallTimeoutOverride::Disabled.is_inherit());
}
#[test]
fn call_timeout_override_value_is_not_inherit() {
assert!(!CallTimeoutOverride::Value(Duration::from_secs(45)).is_inherit());
}
#[test]
fn call_timeout_override_toml_deserialize_disabled() {
let toml_str = r#"call_timeout = "disabled""#;
#[derive(Deserialize)]
struct Wrapper {
call_timeout: CallTimeoutOverride,
}
let w: Wrapper = toml::from_str(toml_str).unwrap();
assert_eq!(w.call_timeout, CallTimeoutOverride::Disabled);
}
#[test]
fn call_timeout_override_toml_deserialize_duration() {
let toml_str = r#"call_timeout = "45s""#;
#[derive(Deserialize)]
struct Wrapper {
call_timeout: CallTimeoutOverride,
}
let w: Wrapper = toml::from_str(toml_str).unwrap();
assert_eq!(
w.call_timeout,
CallTimeoutOverride::Value(Duration::from_secs(45))
);
}
#[test]
fn call_timeout_override_toml_deserialize_complex_duration() {
let toml_str = r#"call_timeout = "2m 30s""#;
#[derive(Deserialize)]
struct Wrapper {
call_timeout: CallTimeoutOverride,
}
let w: Wrapper = toml::from_str(toml_str).unwrap();
assert_eq!(
w.call_timeout,
CallTimeoutOverride::Value(Duration::from_secs(150))
);
}
// [retry].call_timeout in full Config TOML: value, "disabled", and omitted (Inherit).
#[test]
fn retry_config_default_has_inherit_call_timeout() {
let config = RetryConfig::default();
assert_eq!(config.call_timeout_override, CallTimeoutOverride::Inherit);
}
#[test]
fn retry_config_from_toml_with_call_timeout_value() {
let toml_str = r#"
[retry]
max_retries = 5
call_timeout = "60s"
"#;
let config: Config = toml::from_str(toml_str).unwrap();
assert_eq!(config.retry.max_retries, 5);
assert_eq!(
config.retry.call_timeout_override,
CallTimeoutOverride::Value(Duration::from_secs(60))
);
}
#[test]
fn retry_config_from_toml_with_call_timeout_disabled() {
let toml_str = r#"
[retry]
call_timeout = "disabled"
"#;
let config: Config = toml::from_str(toml_str).unwrap();
assert_eq!(
config.retry.call_timeout_override,
CallTimeoutOverride::Disabled
);
}
#[test]
fn retry_config_from_toml_omitted_is_inherit() {
let toml_str = r"
[retry]
max_retries = 2
";
let config: Config = toml::from_str(toml_str).unwrap();
assert_eq!(
config.retry.call_timeout_override,
CallTimeoutOverride::Inherit
);
}
// RetryConfig -> RetryPolicy: only Value sets call_timeout; Disabled and Inherit give None.
#[test]
fn retry_policy_from_config_with_value_override() {
let config = RetryConfig {
call_timeout_override: CallTimeoutOverride::Value(Duration::from_secs(90)),
..RetryConfig::default()
};
let policy: crate::retry::RetryPolicy = config.into();
assert_eq!(policy.call_timeout, Some(Duration::from_secs(90)));
}
#[test]
fn retry_policy_from_config_with_disabled_override() {
let config = RetryConfig {
call_timeout_override: CallTimeoutOverride::Disabled,
..RetryConfig::default()
};
let policy: crate::retry::RetryPolicy = config.into();
assert_eq!(policy.call_timeout, None);
}
#[test]
fn retry_policy_from_config_with_inherit_override() {
let config = RetryConfig {
call_timeout_override: CallTimeoutOverride::Inherit,
..RetryConfig::default()
};
let policy: crate::retry::RetryPolicy = config.into();
assert_eq!(policy.call_timeout, None);
}
// Presence-aware retry merge keeps max_retries from the base and takes call_timeout from the overlay.
#[test]
fn config_merge_preserves_call_timeout_override() {
let toml_base = r"
[retry]
max_retries = 2
";
let toml_overlay = r#"
[retry]
call_timeout = "30s"
"#;
let mut config: Config = toml::from_str(toml_base).unwrap();
let overlay: Config = toml::from_str(toml_overlay).unwrap();
let overlay_parsed: toml::Value = toml::from_str(toml_overlay).unwrap();
config.merge_retry_from_toml_presence(&overlay_parsed, &overlay.retry);
assert_eq!(config.retry.max_retries, 2); assert_eq!(
config.retry.call_timeout_override,
CallTimeoutOverride::Value(Duration::from_secs(30))
);
}
// Provider tools: all enabled by default, TOML round-trip, and per-provider merge.
#[test]
fn test_provider_tools_defaults_all_enabled() {
let config = Config::default();
assert!(config.provider_tools.anthropic.web_search);
assert!(config.provider_tools.openai.web_search);
assert!(config.provider_tools.gemini.google_search);
}
#[test]
fn test_provider_tools_roundtrip_toml() {
let config = Config::default();
let toml_str = toml::to_string(&config.provider_tools).unwrap();
let parsed: ProviderToolsConfig = toml::from_str(&toml_str).unwrap();
assert_eq!(parsed, config.provider_tools);
}
#[test]
fn test_provider_tools_merge_preserves_when_absent() {
let mut config = Config::default();
config
.merge_toml_str(
r#"[agent]
model = "custom-model"
"#,
)
.unwrap();
assert!(config.provider_tools.anthropic.web_search);
assert!(config.provider_tools.openai.web_search);
assert!(config.provider_tools.gemini.google_search);
}
#[test]
fn test_provider_tools_merge_overrides_single_provider() {
let mut config = Config::default();
config
.merge_toml_str("[provider_tools.anthropic]\nweb_search = false\n")
.unwrap();
assert!(!config.provider_tools.anthropic.web_search);
assert!(config.provider_tools.openai.web_search);
assert!(config.provider_tools.gemini.google_search);
}
// AgentConfig.provider_tool_defaults is never written out when serializing.
#[test]
fn test_provider_tool_defaults_not_serialized() {
let agent_config = AgentConfig {
provider_tool_defaults: Some(
serde_json::json!({"web_search": {"type": "web_search_20250305"}}),
),
..Default::default()
};
let json = serde_json::to_value(&agent_config).unwrap();
assert!(
json.get("provider_tool_defaults").is_none(),
"provider_tool_defaults must not be serialized: {json}"
);
}
}