use crate::spec_ai_config::config::agent::AgentProfile;
use anyhow::{Context, Result};
use directories::BaseDirs;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Default configuration file contents, embedded at compile time from
/// `spec-ai.config.toml` in the crate root (`CARGO_MANIFEST_DIR`).
/// Written to disk and parsed when no configuration file is found.
const DEFAULT_CONFIG: &str =
include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/spec-ai.config.toml"));
/// Config file name looked up in the working directory and in `~/.spec-ai/`.
const CONFIG_FILE_NAME: &str = "spec-ai.config.toml";
/// Top-level application configuration, deserialized from a TOML file.
///
/// Every field carries `#[serde(default)]`, so any section may be omitted from the
/// file; an omitted section is filled from its type's `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AppConfig {
/// `[database]` section.
#[serde(default)]
pub database: DatabaseConfig,
/// `[model]` section.
#[serde(default)]
pub model: ModelConfig,
/// `[ui]` section.
#[serde(default)]
pub ui: UiConfig,
/// `[logging]` section.
#[serde(default)]
pub logging: LoggingConfig,
/// `[audio]` section.
#[serde(default)]
pub audio: AudioConfig,
/// `[mesh]` section.
#[serde(default)]
pub mesh: MeshConfig,
/// `[plugins]` section.
#[serde(default)]
pub plugins: PluginConfig,
/// `[sync]` section.
#[serde(default)]
pub sync: SyncConfig,
/// `[skills]` section.
#[serde(default)]
pub skills: SkillsConfig,
/// `[mcp]` section.
#[serde(default)]
pub mcp: McpConfig,
/// `[auth]` section.
#[serde(default)]
pub auth: AuthConfig,
/// `[safety]` section: per-run limits.
#[serde(default)]
pub safety: SafetyConfig,
/// `[approval]` section: tool approval policy.
#[serde(default)]
pub approval: ApprovalConfig,
/// Agent profiles keyed by agent name.
#[serde(default)]
pub agents: HashMap<String, AgentProfile>,
/// Name of the default agent; `validate` requires it to be a key of `agents`.
#[serde(default)]
pub default_agent: Option<String>,
}
impl AppConfig {
    /// Loads the application configuration.
    ///
    /// Locations are tried in this order, and the first file that exists wins:
    ///
    /// 1. `spec-ai.config.toml` in the current working directory;
    /// 2. `~/.spec-ai/spec-ai.config.toml` (skipped if no home directory is known);
    /// 3. the path named by the `CONFIG_PATH` environment variable.
    ///
    /// If none exists, the embedded default configuration is written to
    /// `spec-ai.config.toml` in the working directory on a best-effort basis. It
    /// never overwrites an existing file. The embedded default is then parsed and
    /// returned.
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    ///
    /// - A candidate file exists but cannot be read, for example because of
    ///   permissions or invalid UTF-8. Such a file is reported rather than treated
    ///   as missing and replaced.
    /// - The selected file is not valid TOML for [`AppConfig`].
    pub fn load() -> Result<Self> {
        let local_config = std::path::Path::new(CONFIG_FILE_NAME);
        if let Some(content) = Self::read_if_exists(local_config)? {
            return toml::from_str(&content)
                .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", CONFIG_FILE_NAME, e));
        }

        // A missing home directory is not fatal: skip this location and keep searching.
        if let Some(base_dirs) = BaseDirs::new() {
            let home_config = base_dirs.home_dir().join(".spec-ai").join(CONFIG_FILE_NAME);
            if let Some(content) = Self::read_if_exists(&home_config)? {
                return toml::from_str(&content).map_err(|e| {
                    anyhow::anyhow!("Failed to parse {}: {}", home_config.display(), e)
                });
            }
        }

        // `var_os` also accepts paths that are not valid UTF-8.
        if let Some(config_path) = std::env::var_os("CONFIG_PATH") {
            let config_path = PathBuf::from(config_path);
            if let Some(content) = Self::read_if_exists(&config_path)? {
                return toml::from_str(&content).map_err(|e| {
                    anyhow::anyhow!("Failed to parse config {}: {}", config_path.display(), e)
                });
            }
        }

        eprintln!(
            "No configuration file found. Creating {} with default settings...",
            CONFIG_FILE_NAME
        );
        match Self::write_default_config(local_config) {
            Ok(()) => eprintln!(
                "Created {}. You can edit this file to customize your settings.",
                CONFIG_FILE_NAME
            ),
            Err(e) => {
                eprintln!("Warning: Could not create {}: {}", CONFIG_FILE_NAME, e);
                eprintln!("Continuing with default configuration in memory.");
            }
        }
        Self::parse_default_config()
    }

    /// Loads the configuration from `path`.
    ///
    /// If `path` does not exist, its parent directories are created, the embedded
    /// default configuration is written there, and that default is returned.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    ///
    /// - The file exists but cannot be read or parsed.
    /// - The parent directory cannot be created.
    /// - The default file cannot be written, including when it appeared
    ///   concurrently, since it is never overwritten.
    pub fn load_from_file(path: &std::path::Path) -> Result<Self> {
        match Self::read_if_exists(path)? {
            Some(content) => toml::from_str(&content).map_err(|e| {
                anyhow::anyhow!("Failed to parse config file {}: {}", path.display(), e)
            }),
            None => {
                eprintln!(
                    "Configuration file not found at {}. Creating with default settings...",
                    path.display()
                );
                if let Some(parent) = path.parent() {
                    std::fs::create_dir_all(parent).with_context(|| {
                        format!("Failed to create directory {}", parent.display())
                    })?;
                }
                Self::write_default_config(path).with_context(|| {
                    format!("Failed to create config file at {}", path.display())
                })?;
                eprintln!(
                    "Created {}. You can edit this file to customize your settings.",
                    path.display()
                );
                Self::parse_default_config()
            }
        }
    }

    /// Reads `path` into a string and returns `Ok(None)` if the file does not exist.
    ///
    /// Any other I/O failure (permission denied, invalid UTF-8, a directory, ...)
    /// becomes an error. The callers only create a default file when this returns
    /// `None`, so an unreadable user file is never mistaken for a missing one.
    fn read_if_exists(path: &std::path::Path) -> Result<Option<String>> {
        match std::fs::read_to_string(path) {
            Ok(content) => Ok(Some(content)),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
            Err(e) => Err(anyhow::anyhow!(
                "Failed to read config file {}: {}",
                path.display(),
                e
            )),
        }
    }

    /// Writes the embedded default configuration to `path`.
    ///
    /// Uses `create_new`, so the call fails with `AlreadyExists` instead of
    /// truncating a file that already exists.
    fn write_default_config(path: &std::path::Path) -> std::io::Result<()> {
        use std::io::Write;
        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(path)?;
        file.write_all(DEFAULT_CONFIG.as_bytes())
    }

    /// Parses the configuration embedded in the binary at build time.
    fn parse_default_config() -> Result<Self> {
        toml::from_str(DEFAULT_CONFIG)
            .map_err(|e| anyhow::anyhow!("Failed to parse embedded default config: {}", e))
    }

    /// Checks the configuration for invalid values.
    ///
    /// # Errors
    ///
    /// Returns the first problem found, in this order:
    ///
    /// 1. The model provider is empty or not a known provider (matched
    ///    case-insensitively).
    /// 2. The temperature lies outside `0.0..=2.0` or is NaN.
    /// 3. The safety or approval section fails its own validation.
    /// 4. The log level is not one of `trace`, `debug`, `info`, `warn`, `error`.
    /// 5. `default_agent` names an agent missing from `agents`.
    pub fn validate(&self) -> Result<()> {
        if self.model.provider.is_empty() {
            return Err(anyhow::anyhow!("Model provider cannot be empty"));
        }
        const KNOWN_PROVIDERS: [&str; 6] =
            ["mock", "openai", "anthropic", "ollama", "mlx", "lmstudio"];
        let provider = self.model.provider.to_lowercase();
        if !KNOWN_PROVIDERS.contains(&provider.as_str()) {
            return Err(anyhow::anyhow!(
                "Invalid model provider: {}",
                self.model.provider
            ));
        }
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected too.
        if !(0.0..=2.0).contains(&self.model.temperature) {
            return Err(anyhow::anyhow!(
                "Temperature must be between 0.0 and 2.0, got {}",
                self.model.temperature
            ));
        }
        self.safety.validate()?;
        self.approval.validate()?;
        match self.logging.level.as_str() {
            "trace" | "debug" | "info" | "warn" | "error" => {}
            _ => return Err(anyhow::anyhow!("Invalid log level: {}", self.logging.level)),
        }
        if let Some(default_agent) = &self.default_agent {
            if !self.agents.contains_key(default_agent) {
                return Err(anyhow::anyhow!(
                    "Default agent '{}' not found in agents map",
                    default_agent
                ));
            }
        }
        Ok(())
    }

    /// Overrides configuration values from environment variables.
    ///
    /// Each setting reads an `AGENT_*` variable first and falls back to the
    /// matching `SPEC_AI_*` variable. A numeric value that fails to parse is
    /// reported on stderr and ignored, leaving the current value unchanged. The
    /// fallback variable is not consulted in that case.
    pub fn apply_env_overrides(&mut self) {
        /// Returns `primary`'s value, or `fallback`'s if `primary` is unset or not Unicode.
        fn first(primary: &str, fallback: &str) -> Option<String> {
            std::env::var(primary)
                .ok()
                .or_else(|| std::env::var(fallback).ok())
        }

        /// Like `first`, but parses the value as `T`, warning on parse failure.
        fn first_parsed<T: std::str::FromStr>(primary: &str, fallback: &str) -> Option<T> {
            let raw = first(primary, fallback)?;
            match raw.parse::<T>() {
                Ok(value) => Some(value),
                Err(_) => {
                    eprintln!(
                        "Warning: ignoring invalid value {:?} for {} / {}",
                        raw, primary, fallback
                    );
                    None
                }
            }
        }

        if let Some(provider) = first("AGENT_MODEL_PROVIDER", "SPEC_AI_PROVIDER") {
            self.model.provider = provider;
        }
        if let Some(model_name) = first("AGENT_MODEL_NAME", "SPEC_AI_MODEL") {
            self.model.model_name = Some(model_name);
        }
        if let Some(code_model) = first("AGENT_CODE_MODEL", "SPEC_AI_CODE_MODEL") {
            self.model.code_model = Some(code_model);
        }
        if let Some(api_key_source) = first("AGENT_API_KEY_SOURCE", "SPEC_AI_API_KEY_SOURCE") {
            self.model.api_key_source = Some(api_key_source);
        }
        // Range checking (including NaN) is left to `validate`.
        if let Some(temp) = first_parsed::<f32>("AGENT_MODEL_TEMPERATURE", "SPEC_AI_TEMPERATURE")
        {
            self.model.temperature = temp;
        }
        if let Some(level) = first("AGENT_LOG_LEVEL", "SPEC_AI_LOG_LEVEL") {
            self.logging.level = level;
        }
        if let Some(db_path) = first("AGENT_DB_PATH", "SPEC_AI_DB_PATH") {
            self.database.path = PathBuf::from(db_path);
        }
        if let Some(theme) = first("AGENT_UI_THEME", "SPEC_AI_UI_THEME") {
            self.ui.theme = theme;
        }
        if let Some(default_agent) = first("AGENT_DEFAULT_AGENT", "SPEC_AI_DEFAULT_AGENT") {
            self.default_agent = Some(default_agent);
        }
        if let Some(parsed) = first_parsed::<usize>(
            "AGENT_MAX_MODEL_CALLS_PER_RUN",
            "SPEC_AI_MAX_MODEL_CALLS_PER_RUN",
        ) {
            self.safety.max_model_calls_per_run = parsed;
        }
        if let Some(parsed) = first_parsed::<usize>(
            "AGENT_MAX_TOOL_CALLS_PER_RUN",
            "SPEC_AI_MAX_TOOL_CALLS_PER_RUN",
        ) {
            self.safety.max_tool_calls_per_run = parsed;
        }
        if let Some(parsed) = first_parsed::<usize>(
            "AGENT_MAX_TOOL_LOOP_ITERATIONS",
            "SPEC_AI_MAX_TOOL_LOOP_ITERATIONS",
        ) {
            self.safety.max_tool_loop_iterations = parsed;
        }
        if let Some(parsed) = first_parsed::<u64>(
            "AGENT_MAX_TOTAL_TOKENS_PER_RUN",
            "SPEC_AI_MAX_TOTAL_TOKENS_PER_RUN",
        ) {
            self.safety.max_total_tokens_per_run = parsed;
        }
        if let Some(parsed) = first_parsed::<u32>(
            "AGENT_MAX_OUTPUT_TOKENS_PER_CALL",
            "SPEC_AI_MAX_OUTPUT_TOKENS_PER_CALL",
        ) {
            self.safety.max_output_tokens_per_call = parsed;
        }
    }

    /// Returns a multi-line, human-readable summary of the key settings.
    ///
    /// Optional settings (model name, code model, default agent) are listed only when set.
    pub fn summary(&self) -> String {
        let mut summary = String::new();
        summary.push_str("Configuration loaded:\n");
        summary.push_str(&format!("Database: {}\n", self.database.path.display()));
        summary.push_str(&format!("Model Provider: {}\n", self.model.provider));
        if let Some(model) = &self.model.model_name {
            summary.push_str(&format!("Model Name: {}\n", model));
        }
        if let Some(code_model) = &self.model.code_model {
            summary.push_str(&format!("Code Model: {}\n", code_model));
        }
        summary.push_str(&format!("Temperature: {}\n", self.model.temperature));
        summary.push_str(&format!("Logging Level: {}\n", self.logging.level));
        summary.push_str(&format!("Approval Mode: {}\n", self.approval.mode.as_str()));
        summary.push_str(&format!("UI Theme: {}\n", self.ui.theme));
        summary.push_str(&format!("Available Agents: {}\n", self.agents.len()));
        if let Some(default) = &self.default_agent {
            summary.push_str(&format!("Default Agent: {}\n", default));
        }
        summary
    }
}
/// Tool approval policy: a global mode plus optional per-tool overrides.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalConfig {
    /// Mode used for any tool without an entry in `tools`; defaults to `ask`.
    #[serde(default = "default_approval_mode")]
    pub mode: ApprovalMode,
    /// Per-tool overrides keyed by tool name.
    #[serde(default)]
    pub tools: HashMap<String, ApprovalMode>,
}

impl ApprovalConfig {
    /// Checks that no per-tool override is keyed by a blank name.
    ///
    /// A name is blank if it is empty or whitespace-only.
    ///
    /// # Errors
    ///
    /// Returns an error if any key of `tools` is blank.
    pub fn validate(&self) -> Result<()> {
        let has_blank_name = self.tools.keys().any(|name| name.trim().is_empty());
        if has_blank_name {
            Err(anyhow::anyhow!(
                "approval.tools contains an empty tool name"
            ))
        } else {
            Ok(())
        }
    }

    /// Returns the effective mode for `tool_name`: its override if present, else `mode`.
    pub fn mode_for_tool(&self, tool_name: &str) -> ApprovalMode {
        match self.tools.get(tool_name) {
            Some(&overridden) => overridden,
            None => self.mode,
        }
    }
}

impl Default for ApprovalConfig {
    fn default() -> Self {
        let tools = HashMap::new();
        ApprovalConfig {
            mode: default_approval_mode(),
            tools,
        }
    }
}
/// How a tool invocation is approved.
///
/// Serialized in `snake_case`: `ask`, `auto`, `allow`, `allow_all`, `deny`.
/// These are the same strings that [`ApprovalMode::as_str`] returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApprovalMode {
    Ask,
    Auto,
    Allow,
    AllowAll,
    Deny,
}

impl ApprovalMode {
    /// Returns the mode's `snake_case` name, matching its serialized form.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Ask => "ask",
            Self::Auto => "auto",
            Self::Allow => "allow",
            Self::AllowAll => "allow_all",
            Self::Deny => "deny",
        }
    }

    /// Returns `true` only for [`ApprovalMode::Allow`] and [`ApprovalMode::AllowAll`].
    pub fn is_allowing(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match *self {
            Self::Allow | Self::AllowAll => true,
            Self::Ask | Self::Auto | Self::Deny => false,
        }
    }
}

/// Serde default for `ApprovalConfig::mode`.
fn default_approval_mode() -> ApprovalMode {
    ApprovalMode::Ask
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
pub path: PathBuf,
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
path: PathBuf::from("spec-ai.duckdb"),
}
}
}
/// `[model]` section: model provider and model names.
///
/// `provider` has no serde default, so a `[model]` section must set it; every
/// other field may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
/// Provider name. `AppConfig::validate` accepts, case-insensitively, `mock`,
/// `openai`, `anthropic`, `ollama`, `mlx` or `lmstudio`.
pub provider: String,
/// Optional model name.
#[serde(default)]
pub model_name: Option<String>,
/// Optional separate model name for code (how it is used is not defined in this file).
#[serde(default)]
pub code_model: Option<String>,
/// Optional embeddings model name.
#[serde(default)]
pub embeddings_model: Option<String>,
/// Optional API key source; its accepted format is not defined in this file.
#[serde(default)]
pub api_key_source: Option<String>,
/// Temperature; defaults to 0.7. `AppConfig::validate` rejects values outside 0.0..=2.0.
#[serde(default = "default_temperature")]
pub temperature: f32,
}
/// `[safety]` section: per-run resource limits.
///
/// Every field has a serde default, so the section may set only the limits it
/// changes. `SafetyConfig::validate` rejects zero for every limit except
/// `max_delegation_depth`, but only while `enabled` is true. Enforcement of the
/// limits happens outside this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetyConfig {
/// Whether the limits are active; defaults to `true`.
#[serde(default = "default_safety_enabled")]
pub enabled: bool,
/// Maximum model calls per run; default 6.
#[serde(default = "default_max_model_calls_per_run")]
pub max_model_calls_per_run: usize,
/// Maximum tool calls per run; default 12.
#[serde(default = "default_max_tool_calls_per_run")]
pub max_tool_calls_per_run: usize,
/// Maximum tool-loop iterations; default 5.
#[serde(default = "default_max_tool_loop_iterations")]
pub max_tool_loop_iterations: usize,
/// Maximum repeated tool calls; default 3.
#[serde(default = "default_max_repeated_tool_calls")]
pub max_repeated_tool_calls: usize,
/// Maximum total tokens per run; default 50,000.
#[serde(default = "default_max_total_tokens_per_run")]
pub max_total_tokens_per_run: u64,
/// Maximum output tokens per model call; default 4,096.
#[serde(default = "default_max_output_tokens_per_call")]
pub max_output_tokens_per_call: u32,
/// Maximum prompt bytes per run; default 200,000.
#[serde(default = "default_max_prompt_bytes_per_run")]
pub max_prompt_bytes_per_run: usize,
/// Maximum bytes of tool output; default 64,000.
#[serde(default = "default_max_tool_output_bytes")]
pub max_tool_output_bytes: usize,
/// Maximum delegation depth; default 3. Not checked by `validate`, so 0 is accepted.
#[serde(default = "default_max_delegation_depth")]
pub max_delegation_depth: usize,
}
/// Serde/`Default` value for `SafetyConfig::enabled`.
fn default_safety_enabled() -> bool {
true
}
/// Serde/`Default` value for `SafetyConfig::max_model_calls_per_run`.
fn default_max_model_calls_per_run() -> usize {
6
}
/// Serde/`Default` value for `SafetyConfig::max_tool_calls_per_run`.
fn default_max_tool_calls_per_run() -> usize {
12
}
/// Serde/`Default` value for `SafetyConfig::max_tool_loop_iterations`.
fn default_max_tool_loop_iterations() -> usize {
5
}
/// Serde/`Default` value for `SafetyConfig::max_repeated_tool_calls`.
fn default_max_repeated_tool_calls() -> usize {
3
}
/// Serde/`Default` value for `SafetyConfig::max_total_tokens_per_run`.
fn default_max_total_tokens_per_run() -> u64 {
50_000
}
/// Serde/`Default` value for `SafetyConfig::max_output_tokens_per_call`.
fn default_max_output_tokens_per_call() -> u32 {
4_096
}
/// Serde/`Default` value for `SafetyConfig::max_prompt_bytes_per_run`.
fn default_max_prompt_bytes_per_run() -> usize {
200_000
}
/// Serde/`Default` value for `SafetyConfig::max_tool_output_bytes`.
fn default_max_tool_output_bytes() -> usize {
64_000
}
/// Serde/`Default` value for `SafetyConfig::max_delegation_depth`.
fn default_max_delegation_depth() -> usize {
3
}
impl SafetyConfig {
pub fn validate(&self) -> Result<()> {
if !self.enabled {
return Ok(());
}
if self.max_model_calls_per_run == 0 {
return Err(anyhow::anyhow!(
"safety.max_model_calls_per_run must be greater than 0"
));
}
if self.max_tool_calls_per_run == 0 {
return Err(anyhow::anyhow!(
"safety.max_tool_calls_per_run must be greater than 0"
));
}
if self.max_tool_loop_iterations == 0 {
return Err(anyhow::anyhow!(
"safety.max_tool_loop_iterations must be greater than 0"
));
}
if self.max_repeated_tool_calls == 0 {
return Err(anyhow::anyhow!(
"safety.max_repeated_tool_calls must be greater than 0"
));
}
if self.max_total_tokens_per_run == 0 {
return Err(anyhow::anyhow!(
"safety.max_total_tokens_per_run must be greater than 0"
));
}
if self.max_output_tokens_per_call == 0 {
return Err(anyhow::anyhow!(
"safety.max_output_tokens_per_call must be greater than 0"
));
}
if self.max_prompt_bytes_per_run == 0 {
return Err(anyhow::anyhow!(
"safety.max_prompt_bytes_per_run must be greater than 0"
));
}
if self.max_tool_output_bytes == 0 {
return Err(anyhow::anyhow!(
"safety.max_tool_output_bytes must be greater than 0"
));
}
Ok(())
}
}
impl Default for SafetyConfig {
fn default() -> Self {
Self {
enabled: default_safety_enabled(),
max_model_calls_per_run: default_max_model_calls_per_run(),
max_tool_calls_per_run: default_max_tool_calls_per_run(),
max_tool_loop_iterations: default_max_tool_loop_iterations(),
max_repeated_tool_calls: default_max_repeated_tool_calls(),
max_total_tokens_per_run: default_max_total_tokens_per_run(),
max_output_tokens_per_call: default_max_output_tokens_per_call(),
max_prompt_bytes_per_run: default_max_prompt_bytes_per_run(),
max_tool_output_bytes: default_max_tool_output_bytes(),
max_delegation_depth: default_max_delegation_depth(),
}
}
}
fn default_temperature() -> f32 {
0.7
}
impl Default for ModelConfig {
fn default() -> Self {
Self {
provider: "mock".to_string(),
model_name: None,
code_model: None,
embeddings_model: None,
api_key_source: None,
temperature: default_temperature(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UiConfig {
pub prompt: String,
pub theme: String,
}
impl Default for UiConfig {
fn default() -> Self {
Self {
prompt: "> ".to_string(),
theme: "default".to_string(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
pub level: String,
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
level: "info".to_string(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeshConfig {
#[serde(default)]
pub enabled: bool,
#[serde(default = "default_registry_port")]
pub registry_port: u16,
#[serde(default = "default_heartbeat_interval")]
pub heartbeat_interval_secs: u64,
#[serde(default = "default_leader_timeout")]
pub leader_timeout_secs: u64,
#[serde(default = "default_replication_factor")]
pub replication_factor: usize,
#[serde(default)]
pub auto_join: bool,
}
fn default_registry_port() -> u16 {
3000
}
fn default_heartbeat_interval() -> u64 {
5
}
fn default_leader_timeout() -> u64 {
15
}
fn default_replication_factor() -> usize {
2
}
impl Default for MeshConfig {
fn default() -> Self {
Self {
enabled: false,
registry_port: default_registry_port(),
heartbeat_interval_secs: default_heartbeat_interval(),
leader_timeout_secs: default_leader_timeout(),
replication_factor: default_replication_factor(),
auto_join: true,
}
}
}
/// `[audio]` section.
///
/// Every field has a serde default, so the section may set only the keys it changes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioConfig {
/// Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Transcription provider name; defaults to `vttrs`.
#[serde(default = "default_transcription_provider")]
pub provider: String,
/// Optional model name.
/// NOTE(review): the serde default here is `None`, but `AudioConfig::default` uses
/// `Some("whisper-1")`. The value therefore differs depending on whether `[audio]`
/// is present. Confirm which is intended for the `vttrs` provider.
#[serde(default)]
pub model: Option<String>,
/// Optional API key source.
#[serde(default)]
pub api_key_source: Option<String>,
/// Defaults to `false`.
#[serde(default)]
pub on_device: bool,
/// Optional endpoint.
#[serde(default)]
pub endpoint: Option<String>,
/// Chunk duration in seconds; defaults to 5.0.
#[serde(default = "default_chunk_duration")]
pub chunk_duration_secs: f64,
/// Defaults to 30.
#[serde(default = "default_duration")]
pub default_duration_secs: u64,
/// Defaults to 30.
/// NOTE(review): duplicates `default_duration_secs`; which field consumers read is
/// not visible here. Possibly a legacy alias, to be confirmed.
#[serde(default = "default_duration")]
pub default_duration: u64,
/// Optional output file.
#[serde(default)]
pub out_file: Option<String>,
/// Optional language.
#[serde(default)]
pub language: Option<String>,
/// Defaults to `false`.
#[serde(default)]
pub auto_respond: bool,
/// Mock scenario name; defaults to `simple_conversation`.
#[serde(default = "default_mock_scenario")]
pub mock_scenario: String,
/// Event delay in milliseconds; defaults to 500.
#[serde(default = "default_event_delay_ms")]
pub event_delay_ms: u64,
/// Defaults to `false`.
#[serde(default)]
pub speak_responses: bool,
}
/// Serde/`Default` value for `AudioConfig::provider`.
fn default_transcription_provider() -> String {
"vttrs".to_string()
}
/// Serde/`Default` value for `AudioConfig::chunk_duration_secs`.
fn default_chunk_duration() -> f64 {
5.0
}
/// Serde/`Default` value for both `default_duration_secs` and `default_duration`.
fn default_duration() -> u64 {
30
}
/// Serde/`Default` value for `AudioConfig::mock_scenario`.
fn default_mock_scenario() -> String {
"simple_conversation".to_string()
}
/// Serde/`Default` value for `AudioConfig::event_delay_ms`.
fn default_event_delay_ms() -> u64 {
500
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
enabled: false,
provider: default_transcription_provider(),
model: Some("whisper-1".to_string()),
api_key_source: None,
on_device: false,
endpoint: None,
chunk_duration_secs: default_chunk_duration(),
default_duration_secs: default_duration(),
default_duration: default_duration(),
out_file: None,
language: None,
auto_respond: false,
mock_scenario: default_mock_scenario(),
event_delay_ms: default_event_delay_ms(),
speak_responses: false,
}
}
}
/// `[plugins]` section: custom tool loading.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginConfig {
/// Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Directory for custom tools; defaults to `~/.spec-ai/tools`.
/// NOTE(review): `PathBuf` does not expand `~`, so the leading tilde is stored
/// literally. Expansion presumably happens at the use site; confirm.
#[serde(default = "default_plugins_dir")]
pub custom_tools_dir: PathBuf,
/// Defaults to `true`.
#[serde(default = "default_continue_on_error")]
pub continue_on_error: bool,
/// Defaults to `false`.
#[serde(default)]
pub allow_override_builtin: bool,
}
/// Serde/`Default` value for `PluginConfig::custom_tools_dir` (tilde not expanded).
fn default_plugins_dir() -> PathBuf {
PathBuf::from("~/.spec-ai/tools")
}
/// Serde/`Default` value for `PluginConfig::continue_on_error`.
fn default_continue_on_error() -> bool {
true
}
impl Default for PluginConfig {
fn default() -> Self {
Self {
enabled: false,
custom_tools_dir: default_plugins_dir(),
continue_on_error: true,
allow_override_builtin: false,
}
}
}
/// `[auth]` section.
/// NOTE(review): `Debug` is derived, so `{:?}` on this struct (or on `AppConfig`)
/// prints `token_secret` in clear text. Consider redacting it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
/// Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Optional credentials file path.
#[serde(default)]
pub credentials_file: Option<PathBuf>,
/// Token lifetime in seconds; defaults to 86400 (one day).
#[serde(default = "default_token_expiry")]
pub token_expiry_secs: u64,
/// Optional token secret.
#[serde(default)]
pub token_secret: Option<String>,
}
/// Serde/`Default` value for `AuthConfig::token_expiry_secs`: one day in seconds.
fn default_token_expiry() -> u64 {
    24 * 60 * 60
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
enabled: false,
credentials_file: None,
token_expiry_secs: default_token_expiry(),
token_secret: None,
}
}
}
// Unit tests for safety defaults, safety validation and `[safety]` parsing.
#[cfg(test)]
mod tests {
use super::*;
/// `SafetyConfig::default` yields the expected limits and passes validation.
#[test]
fn safety_config_defaults_are_strict() {
let safety = SafetyConfig::default();
assert!(safety.enabled);
assert_eq!(safety.max_model_calls_per_run, 6);
assert_eq!(safety.max_tool_calls_per_run, 12);
assert_eq!(safety.max_tool_loop_iterations, 5);
assert_eq!(safety.max_repeated_tool_calls, 3);
assert_eq!(safety.max_total_tokens_per_run, 50_000);
assert_eq!(safety.max_output_tokens_per_call, 4_096);
assert!(safety.validate().is_ok());
}
/// A zero limit is rejected while safety is enabled.
#[test]
fn safety_config_rejects_zero_limits_when_enabled() {
let safety = SafetyConfig {
max_model_calls_per_run: 0,
..SafetyConfig::default()
};
assert!(safety.validate().is_err());
}
/// A partial `[safety]` section overrides only the listed keys; others keep serde defaults.
#[test]
fn app_config_parses_safety_section() {
let toml = r#"
[model]
provider = "mock"
[safety]
max_model_calls_per_run = 9
max_tool_calls_per_run = 20
"#;
let config: AppConfig = toml::from_str(toml).unwrap();
assert_eq!(config.safety.max_model_calls_per_run, 9);
assert_eq!(config.safety.max_tool_calls_per_run, 20);
assert!(config.safety.enabled);
}
}
/// `[sync]` section.
///
/// Every field has a serde default, so the section may set only the keys it changes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncConfig {
/// Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Sync interval in seconds; defaults to 60.
#[serde(default = "default_sync_interval")]
pub interval_secs: u64,
/// Maximum concurrent syncs; defaults to 3.
#[serde(default = "default_max_concurrent_syncs")]
pub max_concurrent_syncs: usize,
/// Retry interval in seconds; defaults to 300.
#[serde(default = "default_retry_interval")]
pub retry_interval_secs: u64,
/// Maximum retries; defaults to 3.
#[serde(default = "default_max_retries")]
pub max_retries: usize,
/// Namespaces to sync; defaults to empty.
#[serde(default)]
pub namespaces: Vec<SyncNamespace>,
}
/// One sync namespace entry: a session id plus a graph name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncNamespace {
/// Required; has no serde default.
pub session_id: String,
/// Graph name; defaults to `default`.
#[serde(default = "default_graph_name")]
pub graph_name: String,
}
/// Serde/`Default` value for `SyncConfig::interval_secs`.
fn default_sync_interval() -> u64 {
60
}
/// Serde/`Default` value for `SyncConfig::max_concurrent_syncs`.
fn default_max_concurrent_syncs() -> usize {
3
}
/// Serde/`Default` value for `SyncConfig::retry_interval_secs`.
fn default_retry_interval() -> u64 {
300
}
/// Serde/`Default` value for `SyncConfig::max_retries`.
fn default_max_retries() -> usize {
3
}
/// Serde default for `SyncNamespace::graph_name`.
fn default_graph_name() -> String {
"default".to_string()
}
impl Default for SyncConfig {
fn default() -> Self {
Self {
enabled: false,
interval_secs: default_sync_interval(),
max_concurrent_syncs: default_max_concurrent_syncs(),
retry_interval_secs: default_retry_interval(),
max_retries: default_max_retries(),
namespaces: Vec::new(),
}
}
}
/// `[skills]` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillsConfig {
/// Defaults to `true`.
#[serde(default = "default_true")]
pub enabled: bool,
/// Skill directories; defaults to `["~/.agents/skills"]` (tilde stored literally).
#[serde(default = "default_skills_dirs")]
pub skills_dirs: Vec<PathBuf>,
}
/// Serde/`Default` value for `SkillsConfig::skills_dirs`: a single `~/.agents/skills`
/// entry. `PathBuf` does not expand the tilde.
fn default_skills_dirs() -> Vec<PathBuf> {
    let dir: PathBuf = "~/.agents/skills".into();
    vec![dir]
}
impl Default for SkillsConfig {
fn default() -> Self {
Self {
enabled: true,
skills_dirs: default_skills_dirs(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpConfig {
#[serde(default = "default_true")]
pub enabled: bool,
#[serde(default)]
pub servers: HashMap<String, McpServerConfig>,
}
/// One MCP server entry: the command to run, plus optional arguments and environment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
/// Command to execute; required (no serde default).
pub command: String,
/// Command-line arguments; defaults to empty.
#[serde(default)]
pub args: Vec<String>,
/// Extra environment variables; defaults to empty.
#[serde(default)]
pub env: HashMap<String, String>,
}
/// Shared serde default helper returning `true` (used by `SkillsConfig` and `McpConfig`).
fn default_true() -> bool {
true
}