use crate::spec_ai_config::config::agent::AgentProfile;
use anyhow::{Context, Result};
use directories::BaseDirs;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Default configuration file contents, embedded at compile time from
/// `spec-ai.config.toml` in the crate root (`CARGO_MANIFEST_DIR`).
///
/// Written to disk and parsed when no configuration file can be found.
const DEFAULT_CONFIG: &str =
    include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/spec-ai.config.toml"));
/// Configuration file name looked up in the working directory and in `~/.spec-ai/`.
const CONFIG_FILE_NAME: &str = "spec-ai.config.toml";
/// Top-level application configuration, deserialized from `spec-ai.config.toml`.
///
/// Every section carries `#[serde(default)]`, so an omitted TOML table falls back to
/// that section's `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AppConfig {
    /// Database settings (`[database]`).
    #[serde(default)]
    pub database: DatabaseConfig,
    /// Language-model settings (`[model]`).
    #[serde(default)]
    pub model: ModelConfig,
    /// Interactive UI settings (`[ui]`).
    #[serde(default)]
    pub ui: UiConfig,
    /// Logging settings (`[logging]`).
    #[serde(default)]
    pub logging: LoggingConfig,
    /// Audio / transcription settings (`[audio]`).
    #[serde(default)]
    pub audio: AudioConfig,
    /// Mesh networking settings (`[mesh]`).
    #[serde(default)]
    pub mesh: MeshConfig,
    /// Custom tool plugin settings (`[plugins]`).
    #[serde(default)]
    pub plugins: PluginConfig,
    /// Synchronization settings (`[sync]`).
    #[serde(default)]
    pub sync: SyncConfig,
    /// Authentication settings (`[auth]`).
    #[serde(default)]
    pub auth: AuthConfig,
    /// Agent profiles keyed by agent name.
    #[serde(default)]
    pub agents: HashMap<String, AgentProfile>,
    /// Name of the agent to use by default; `validate` requires it to be a key of `agents`.
    #[serde(default)]
    pub default_agent: Option<String>,
}
impl AppConfig {
    /// Loads the configuration by searching the standard locations.
    ///
    /// Candidates are tried in this order; the first file that exists wins:
    /// 1. `spec-ai.config.toml` in the current working directory;
    /// 2. `.spec-ai/spec-ai.config.toml` under the user's home directory (skipped when
    ///    no home directory can be determined);
    /// 3. the path named by the `CONFIG_PATH` environment variable, when it is set.
    ///
    /// When none exists, the embedded default configuration is written to
    /// `spec-ai.config.toml` in the working directory and the defaults are returned.
    /// The write is best effort: a failure is reported on stderr and is not fatal.
    ///
    /// # Errors
    ///
    /// Fails if a candidate file exists but cannot be read (e.g. permission denied or
    /// non-UTF-8 contents), or if the selected file does not parse as this schema.
    /// Only a *missing* file moves the search on. An unreadable config is therefore
    /// never silently shadowed by the defaults, nor overwritten with them.
    pub fn load() -> Result<Self> {
        /// Reads `path`: a missing file yields `Ok(None)`; any other I/O failure is an error.
        fn read_if_exists(path: &std::path::Path) -> Result<Option<String>> {
            match std::fs::read_to_string(path) {
                Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
                other => other
                    .map(Some)
                    .with_context(|| format!("Failed to read config file {}", path.display())),
            }
        }

        if let Some(content) = read_if_exists(std::path::Path::new(CONFIG_FILE_NAME))? {
            return toml::from_str(&content)
                .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", CONFIG_FILE_NAME, e));
        }

        if let Some(base_dirs) = BaseDirs::new() {
            let home_config = base_dirs.home_dir().join(".spec-ai").join(CONFIG_FILE_NAME);
            if let Some(content) = read_if_exists(&home_config)? {
                return toml::from_str(&content).map_err(|e| {
                    anyhow::anyhow!("Failed to parse {}: {}", home_config.display(), e)
                });
            }
        }

        // NOTE(review): CONFIG_PATH is consulted last, so a config in the working or home
        // directory shadows it — confirm this precedence is intended.
        if let Ok(config_path) = std::env::var("CONFIG_PATH") {
            if let Some(content) = read_if_exists(std::path::Path::new(&config_path))? {
                return toml::from_str(&content)
                    .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", config_path, e));
            }
        }

        eprintln!(
            "No configuration file found. Creating {} with default settings...",
            CONFIG_FILE_NAME
        );
        if let Err(e) = std::fs::write(CONFIG_FILE_NAME, DEFAULT_CONFIG) {
            eprintln!("Warning: Could not create {}: {}", CONFIG_FILE_NAME, e);
            eprintln!("Continuing with default configuration in memory.");
        } else {
            eprintln!(
                "Created {}. You can edit this file to customize your settings.",
                CONFIG_FILE_NAME
            );
        }
        toml::from_str(DEFAULT_CONFIG)
            .map_err(|e| anyhow::anyhow!("Failed to parse embedded default config: {}", e))
    }

    /// Loads the configuration from an explicit `path`.
    ///
    /// If `path` does not exist, its parent directories are created, the embedded
    /// default configuration is written to `path`, and the defaults are returned.
    ///
    /// # Errors
    ///
    /// Fails in these cases:
    /// - the file cannot be read for a reason other than not existing;
    /// - the file does not parse as this schema;
    /// - the directory or default file cannot be created.
    pub fn load_from_file(path: &std::path::Path) -> Result<Self> {
        match std::fs::read_to_string(path) {
            Ok(content) => toml::from_str(&content).map_err(|e| {
                anyhow::anyhow!("Failed to parse config file {}: {}", path.display(), e)
            }),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                eprintln!(
                    "Configuration file not found at {}. Creating with default settings...",
                    path.display()
                );
                if let Some(parent) = path.parent() {
                    std::fs::create_dir_all(parent).with_context(|| {
                        format!("Failed to create directory {}", parent.display())
                    })?;
                }
                std::fs::write(path, DEFAULT_CONFIG).with_context(|| {
                    format!("Failed to create config file at {}", path.display())
                })?;
                eprintln!(
                    "Created {}. You can edit this file to customize your settings.",
                    path.display()
                );
                toml::from_str(DEFAULT_CONFIG)
                    .map_err(|e| anyhow::anyhow!("Failed to parse embedded default config: {}", e))
            }
            Err(e) => Err(anyhow::anyhow!(
                "Failed to read config file {}: {}",
                path.display(),
                e
            )),
        }
    }

    /// Checks the configuration for values that cannot work at runtime.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the model provider is empty, or not one of `mock`, `openai`, `anthropic`,
    ///   `ollama`, `mlx`, `lmstudio` (compared case-insensitively);
    /// - the temperature is outside `0.0..=2.0` (NaN included);
    /// - the log level is not one of `trace`, `debug`, `info`, `warn`, `error`
    ///   (case-sensitive);
    /// - `default_agent` names an agent missing from `agents`.
    pub fn validate(&self) -> Result<()> {
        if self.model.provider.is_empty() {
            return Err(anyhow::anyhow!("Model provider cannot be empty"));
        }
        {
            let p = self.model.provider.to_lowercase();
            let known = ["mock", "openai", "anthropic", "ollama", "mlx", "lmstudio"];
            if !known.contains(&p.as_str()) {
                return Err(anyhow::anyhow!(
                    "Invalid model provider: {}",
                    self.model.provider
                ));
            }
        }
        // `RangeInclusive::contains` is false for NaN. Plain `<`/`>` comparisons are
        // also false for NaN, which used to let a NaN temperature through.
        if !(0.0_f32..=2.0).contains(&self.model.temperature) {
            return Err(anyhow::anyhow!(
                "Temperature must be between 0.0 and 2.0, got {}",
                self.model.temperature
            ));
        }
        match self.logging.level.as_str() {
            "trace" | "debug" | "info" | "warn" | "error" => {}
            _ => return Err(anyhow::anyhow!("Invalid log level: {}", self.logging.level)),
        }
        if let Some(default_agent) = &self.default_agent {
            if !self.agents.contains_key(default_agent) {
                return Err(anyhow::anyhow!(
                    "Default agent '{}' not found in agents map",
                    default_agent
                ));
            }
        }
        Ok(())
    }

    /// Overrides settings from environment variables.
    ///
    /// Each setting is read from its `AGENT_*` variable first, falling back to the
    /// matching `SPEC_AI_*` variable. Unset or non-UTF-8 variables leave the setting
    /// unchanged. An unparsable temperature is ignored with a warning on stderr.
    pub fn apply_env_overrides(&mut self) {
        fn first(a: &str, b: &str) -> Option<String> {
            std::env::var(a).ok().or_else(|| std::env::var(b).ok())
        }
        if let Some(provider) = first("AGENT_MODEL_PROVIDER", "SPEC_AI_PROVIDER") {
            self.model.provider = provider;
        }
        if let Some(model_name) = first("AGENT_MODEL_NAME", "SPEC_AI_MODEL") {
            self.model.model_name = Some(model_name);
        }
        if let Some(code_model) = first("AGENT_CODE_MODEL", "SPEC_AI_CODE_MODEL") {
            self.model.code_model = Some(code_model);
        }
        if let Some(api_key_source) = first("AGENT_API_KEY_SOURCE", "SPEC_AI_API_KEY_SOURCE") {
            self.model.api_key_source = Some(api_key_source);
        }
        if let Some(temp_str) = first("AGENT_MODEL_TEMPERATURE", "SPEC_AI_TEMPERATURE") {
            match temp_str.parse::<f32>() {
                Ok(temp) => self.model.temperature = temp,
                // Best effort: keep the configured value, but say why the override was dropped.
                Err(e) => eprintln!(
                    "Warning: ignoring invalid temperature override '{}': {}",
                    temp_str, e
                ),
            }
        }
        if let Some(level) = first("AGENT_LOG_LEVEL", "SPEC_AI_LOG_LEVEL") {
            self.logging.level = level;
        }
        if let Some(db_path) = first("AGENT_DB_PATH", "SPEC_AI_DB_PATH") {
            self.database.path = PathBuf::from(db_path);
        }
        if let Some(theme) = first("AGENT_UI_THEME", "SPEC_AI_UI_THEME") {
            self.ui.theme = theme;
        }
        if let Some(default_agent) = first("AGENT_DEFAULT_AGENT", "SPEC_AI_DEFAULT_AGENT") {
            self.default_agent = Some(default_agent);
        }
    }

    /// Returns a multi-line, human-readable summary of the main settings.
    ///
    /// Optional settings (model name, code model, default agent) are listed only when set.
    pub fn summary(&self) -> String {
        let mut summary = String::new();
        summary.push_str("Configuration loaded:\n");
        summary.push_str(&format!("Database: {}\n", self.database.path.display()));
        summary.push_str(&format!("Model Provider: {}\n", self.model.provider));
        if let Some(model) = &self.model.model_name {
            summary.push_str(&format!("Model Name: {}\n", model));
        }
        if let Some(code_model) = &self.model.code_model {
            summary.push_str(&format!("Code Model: {}\n", code_model));
        }
        summary.push_str(&format!("Temperature: {}\n", self.model.temperature));
        summary.push_str(&format!("Logging Level: {}\n", self.logging.level));
        summary.push_str(&format!("UI Theme: {}\n", self.ui.theme));
        summary.push_str(&format!("Available Agents: {}\n", self.agents.len()));
        if let Some(default) = &self.default_agent {
            summary.push_str(&format!("Default Agent: {}\n", default));
        }
        summary
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
pub path: PathBuf,
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
path: PathBuf::from("spec-ai.duckdb"),
}
}
}
/// Language-model settings (`[model]` table).
///
/// `provider` has no serde default, so a `[model]` table must set it explicitly. An
/// entirely missing `[model]` table still falls back to `ModelConfig::default()`
/// through `AppConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Provider name. `AppConfig::validate` accepts (case-insensitively) `mock`,
    /// `openai`, `anthropic`, `ollama`, `mlx` or `lmstudio`.
    pub provider: String,
    /// Optional model identifier.
    #[serde(default)]
    pub model_name: Option<String>,
    /// Optional second model identifier, presumably used for code tasks — confirm with callers.
    #[serde(default)]
    pub code_model: Option<String>,
    /// Optional embeddings model identifier.
    #[serde(default)]
    pub embeddings_model: Option<String>,
    /// Where the API key comes from; the expected format is not visible here — TODO confirm.
    #[serde(default)]
    pub api_key_source: Option<String>,
    /// Sampling temperature (default 0.7); `AppConfig::validate` requires `0.0..=2.0`.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
}
/// Serde default for `ModelConfig::temperature`.
fn default_temperature() -> f32 {
    const DEFAULT_TEMPERATURE: f32 = 0.7;
    DEFAULT_TEMPERATURE
}
impl Default for ModelConfig {
fn default() -> Self {
Self {
provider: "mock".to_string(),
model_name: None,
code_model: None,
embeddings_model: None,
api_key_source: None,
temperature: default_temperature(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UiConfig {
pub prompt: String,
pub theme: String,
}
impl Default for UiConfig {
fn default() -> Self {
Self {
prompt: "> ".to_string(),
theme: "default".to_string(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
pub level: String,
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
level: "info".to_string(),
}
}
}
/// Mesh networking settings (`[mesh]` table).
///
/// Each field has its own serde default, so a partial `[mesh]` table is accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeshConfig {
    /// Whether mesh networking is enabled (default `false`).
    #[serde(default)]
    pub enabled: bool,
    /// Registry port (default 3000).
    #[serde(default = "default_registry_port")]
    pub registry_port: u16,
    /// Heartbeat interval in seconds (default 5).
    #[serde(default = "default_heartbeat_interval")]
    pub heartbeat_interval_secs: u64,
    /// Leader timeout in seconds (default 15).
    #[serde(default = "default_leader_timeout")]
    pub leader_timeout_secs: u64,
    /// Replication factor (default 2).
    #[serde(default = "default_replication_factor")]
    pub replication_factor: usize,
    /// Whether to join automatically (default `true`).
    ///
    /// Previously this used `#[serde(default)]` (i.e. `false`) while
    /// `MeshConfig::default()` used `true`. The effective value then depended on whether
    /// a `[mesh]` table was present at all. Both paths now agree on `true`.
    #[serde(default = "default_auto_join")]
    pub auto_join: bool,
}
/// Serde default for `MeshConfig::auto_join`; matches `MeshConfig::default()`.
fn default_auto_join() -> bool {
    true
}
/// Serde default for `MeshConfig::registry_port`.
fn default_registry_port() -> u16 {
    3_000
}
/// Serde default for `MeshConfig::heartbeat_interval_secs`.
fn default_heartbeat_interval() -> u64 {
    const SECS: u64 = 5;
    SECS
}
/// Serde default for `MeshConfig::leader_timeout_secs` (three heartbeat intervals).
fn default_leader_timeout() -> u64 {
    3 * 5
}
/// Serde default for `MeshConfig::replication_factor`.
fn default_replication_factor() -> usize {
    2usize
}
impl Default for MeshConfig {
fn default() -> Self {
Self {
enabled: false,
registry_port: default_registry_port(),
heartbeat_interval_secs: default_heartbeat_interval(),
leader_timeout_secs: default_leader_timeout(),
replication_factor: default_replication_factor(),
auto_join: true,
}
}
}
/// Audio / transcription settings (`[audio]` table).
///
/// Every field has a serde default, so a partial `[audio]` table is accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioConfig {
    /// Whether audio is enabled (default `false`).
    #[serde(default)]
    pub enabled: bool,
    /// Transcription provider name (default `"vttrs"`).
    #[serde(default = "default_transcription_provider")]
    pub provider: String,
    /// Optional model identifier.
    // NOTE(review): the serde default is `None`, but `AudioConfig::default()` uses
    // `Some("whisper-1")`. The value depends on whether an `[audio]` table exists —
    // confirm which is intended.
    #[serde(default)]
    pub model: Option<String>,
    /// Where the API key comes from; expected format not visible here — TODO confirm.
    #[serde(default)]
    pub api_key_source: Option<String>,
    /// On-device processing flag (default `false`).
    #[serde(default)]
    pub on_device: bool,
    /// Optional endpoint override.
    #[serde(default)]
    pub endpoint: Option<String>,
    /// Chunk duration in seconds (default 5.0).
    #[serde(default = "default_chunk_duration")]
    pub chunk_duration_secs: f64,
    /// Default duration in seconds (default 30).
    #[serde(default = "default_duration")]
    pub default_duration_secs: u64,
    // NOTE(review): same default as `default_duration_secs`; looks like a legacy alias —
    // confirm which field readers actually use.
    #[serde(default = "default_duration")]
    pub default_duration: u64,
    /// Optional output file.
    #[serde(default)]
    pub out_file: Option<String>,
    /// Optional language hint.
    #[serde(default)]
    pub language: Option<String>,
    /// Auto-respond flag (default `false`).
    #[serde(default)]
    pub auto_respond: bool,
    /// Mock scenario name (default `"simple_conversation"`); presumably used by a mock provider.
    #[serde(default = "default_mock_scenario")]
    pub mock_scenario: String,
    /// Event delay in milliseconds (default 500).
    #[serde(default = "default_event_delay_ms")]
    pub event_delay_ms: u64,
    /// Speak-responses flag (default `false`).
    #[serde(default)]
    pub speak_responses: bool,
}
/// Serde default for `AudioConfig::provider`.
fn default_transcription_provider() -> String {
    String::from("vttrs")
}
/// Serde default for `AudioConfig::chunk_duration_secs`.
fn default_chunk_duration() -> f64 {
    5f64
}
/// Serde default shared by `AudioConfig::default_duration_secs` and `default_duration`.
fn default_duration() -> u64 {
    const SECS: u64 = 30;
    SECS
}
/// Serde default for `AudioConfig::mock_scenario`.
fn default_mock_scenario() -> String {
    "simple_conversation".to_owned()
}
/// Serde default for `AudioConfig::event_delay_ms`.
fn default_event_delay_ms() -> u64 {
    5 * 100
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
enabled: false,
provider: default_transcription_provider(),
model: Some("whisper-1".to_string()),
api_key_source: None,
on_device: false,
endpoint: None,
chunk_duration_secs: default_chunk_duration(),
default_duration_secs: default_duration(),
default_duration: default_duration(),
out_file: None,
language: None,
auto_respond: false,
mock_scenario: default_mock_scenario(),
event_delay_ms: default_event_delay_ms(),
speak_responses: false,
}
}
}
/// Custom tool plugin settings (`[plugins]` table).
///
/// Every field has a serde default, so a partial `[plugins]` table is accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginConfig {
    /// Whether plugins are enabled (default `false`).
    #[serde(default)]
    pub enabled: bool,
    /// Directory searched for custom tools (default `~/.spec-ai/tools`).
    // NOTE(review): `PathBuf` does not expand `~`; presumably the consumer does — confirm.
    #[serde(default = "default_plugins_dir")]
    pub custom_tools_dir: PathBuf,
    /// Continue-on-error flag (default `true`); presumably lets loading proceed past a
    /// failing plugin — confirm with the loader.
    #[serde(default = "default_continue_on_error")]
    pub continue_on_error: bool,
    /// Whether plugins may replace built-in tools (default `false`).
    #[serde(default)]
    pub allow_override_builtin: bool,
}
/// Serde default for `PluginConfig::custom_tools_dir`.
///
/// NOTE(review): the leading `~` is kept literally; `PathBuf` performs no tilde expansion.
fn default_plugins_dir() -> PathBuf {
    "~/.spec-ai/tools".into()
}
/// Serde default for `PluginConfig::continue_on_error`.
fn default_continue_on_error() -> bool {
    const CONTINUE: bool = true;
    CONTINUE
}
impl Default for PluginConfig {
fn default() -> Self {
Self {
enabled: false,
custom_tools_dir: default_plugins_dir(),
continue_on_error: true,
allow_override_builtin: false,
}
}
}
/// Authentication settings (`[auth]` table).
///
/// `Debug` is implemented by hand so that `token_secret` is never printed verbatim.
/// The derived impl would have exposed it through any `{:?}` of this struct or of
/// [`AppConfig`].
#[derive(Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Whether authentication is enabled (default `false`).
    #[serde(default)]
    pub enabled: bool,
    /// Optional path to a credentials file; its format is not visible here — TODO confirm.
    #[serde(default)]
    pub credentials_file: Option<PathBuf>,
    /// Token lifetime in seconds (default 86400, i.e. one day).
    #[serde(default = "default_token_expiry")]
    pub token_expiry_secs: u64,
    /// Token secret, presumably used to sign/verify tokens; redacted in `Debug` output.
    #[serde(default)]
    pub token_secret: Option<String>,
}

impl std::fmt::Debug for AuthConfig {
    /// Formats like the derived impl, except `token_secret` shows `Some("<redacted>")`
    /// instead of its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            .field("enabled", &self.enabled)
            .field("credentials_file", &self.credentials_file)
            .field("token_expiry_secs", &self.token_expiry_secs)
            .field(
                "token_secret",
                &self.token_secret.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
/// Serde default for `AuthConfig::token_expiry_secs`: one day, in seconds.
fn default_token_expiry() -> u64 {
    24 * 60 * 60
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
enabled: false,
credentials_file: None,
token_expiry_secs: default_token_expiry(),
token_secret: None,
}
}
}
/// Synchronization settings (`[sync]` table).
///
/// Every field has a serde default, so a partial `[sync]` table is accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncConfig {
    /// Whether sync is enabled (default `false`).
    #[serde(default)]
    pub enabled: bool,
    /// Sync interval in seconds (default 60).
    #[serde(default = "default_sync_interval")]
    pub interval_secs: u64,
    /// Maximum concurrent syncs (default 3).
    #[serde(default = "default_max_concurrent_syncs")]
    pub max_concurrent_syncs: usize,
    /// Retry interval in seconds (default 300).
    #[serde(default = "default_retry_interval")]
    pub retry_interval_secs: u64,
    /// Maximum retries (default 3).
    #[serde(default = "default_max_retries")]
    pub max_retries: usize,
    /// Namespaces to synchronize (default empty).
    #[serde(default)]
    pub namespaces: Vec<SyncNamespace>,
}
/// One entry of `SyncConfig::namespaces`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncNamespace {
    /// Session identifier; required (no serde default).
    pub session_id: String,
    /// Graph name within the session (default `"default"`).
    #[serde(default = "default_graph_name")]
    pub graph_name: String,
}
/// Serde default for `SyncConfig::interval_secs`: one minute.
fn default_sync_interval() -> u64 {
    60u64
}
/// Serde default for `SyncConfig::max_concurrent_syncs`.
fn default_max_concurrent_syncs() -> usize {
    const LIMIT: usize = 3;
    LIMIT
}
/// Serde default for `SyncConfig::retry_interval_secs`: five minutes.
fn default_retry_interval() -> u64 {
    5 * 60
}
/// Serde default for `SyncConfig::max_retries`.
fn default_max_retries() -> usize {
    3usize
}
/// Serde default for `SyncNamespace::graph_name`.
fn default_graph_name() -> String {
    String::from("default")
}
impl Default for SyncConfig {
fn default() -> Self {
Self {
enabled: false,
interval_secs: default_sync_interval(),
max_concurrent_syncs: default_max_concurrent_syncs(),
retry_interval_secs: default_retry_interval(),
max_retries: default_max_retries(),
namespaces: Vec::new(),
}
}
}