use std::path::PathBuf;
use config::{Config, Environment, File};
use serde::{Deserialize, Serialize};
use crate::error::AptuError;
/// Model identifier paired with the `openrouter` provider in [`AiConfig::default`].
pub const DEFAULT_OPENROUTER_MODEL: &str = "mistralai/mistral-small-2603";

/// Default model identifier for the Gemini provider.
pub const DEFAULT_GEMINI_MODEL: &str = "gemini-3.1-flash-lite-preview";
/// The kinds of AI task that may carry their own provider/model override
/// (configured under `[ai.tasks.triage]`, `[ai.tasks.review]` and
/// `[ai.tasks.create]`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskType {
    /// Selects the `ai.tasks.triage` override.
    Triage,
    /// Selects the `ai.tasks.review` override.
    Review,
    /// Selects the `ai.tasks.create` override.
    Create,
}
/// Root application configuration, as produced by [`load_config`].
///
/// Because of `#[serde(default)]`, any section absent from the sources is
/// filled in from that section's `Default` implementation.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct AppConfig {
    /// `[user]` section.
    pub user: UserConfig,
    /// `[ai]` section.
    pub ai: AiConfig,
    /// `[github]` section.
    pub github: GitHubConfig,
    /// `[ui]` section.
    pub ui: UiConfig,
    /// `[cache]` section.
    pub cache: CacheConfig,
    /// `[repos]` section.
    pub repos: ReposConfig,
}
/// User-level settings (`[user]`).
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct UserConfig {
    /// Optional default repository; `None` when not configured.
    pub default_repo: Option<String>,
}
/// Provider/model override for one task (`[ai.tasks.<task>]`).
///
/// Both fields are optional. [`AiConfig::resolve_for_task`] falls back to the
/// top-level `ai.provider` / `ai.model` for each missing field independently.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct TaskOverride {
    /// Provider to use for this task, if overridden.
    pub provider: Option<String>,
    /// Model to use for this task, if overridden.
    pub model: Option<String>,
}
/// Optional per-task overrides (`[ai.tasks]`), one slot per [`TaskType`].
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct TasksConfig {
    /// `[ai.tasks.triage]`
    pub triage: Option<TaskOverride>,
    /// `[ai.tasks.review]`
    pub review: Option<TaskOverride>,
    /// `[ai.tasks.create]`
    pub create: Option<TaskOverride>,
}
/// One entry in the provider fallback chain.
///
/// In config files an entry may be written either as a bare provider name
/// (`"openrouter"`, giving `model: None`) or as a table with a `provider` key
/// and an optional `model` key.
#[derive(Clone, Debug, Serialize)]
pub struct FallbackEntry {
    /// Provider name.
    pub provider: String,
    /// Optional model for this provider; `None` for the bare-string form or
    /// when the table omits `model`.
    pub model: Option<String>,
}

impl<'de> Deserialize<'de> for FallbackEntry {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Untagged: serde tries each shape in declaration order. The enum's
        // *name* must stay `EntryVariant` because serde embeds it in the
        // "did not match any variant" error; variant names are not surfaced.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum EntryVariant {
            Name(String),
            Table {
                provider: String,
                model: Option<String>,
            },
        }

        let (provider, model) = match EntryVariant::deserialize(deserializer)? {
            EntryVariant::Name(name) => (name, None),
            EntryVariant::Table { provider, model } => (provider, model),
        };
        Ok(Self { provider, model })
    }
}
/// Provider fallback settings (`[ai.fallback]`).
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct FallbackConfig {
    /// Fallback entries in the order they appear in the config. An omitted
    /// `chain` key deserializes as an empty list.
    pub chain: Vec<FallbackEntry>,
}
/// Default for `ai.retry_max_attempts`. It is a function rather than a const
/// because `#[serde(default = "...")]` needs a function path.
fn default_retry_max_attempts() -> u32 {
    const DEFAULT_ATTEMPTS: u32 = 3;
    DEFAULT_ATTEMPTS
}
/// AI provider settings (`[ai]`).
///
/// Keys missing from the config are taken from [`AiConfig::default`].
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(default)]
pub struct AiConfig {
    /// Provider used when no per-task override sets one.
    pub provider: String,
    /// Model used when no per-task override sets one.
    pub model: String,
    /// Timeout, in seconds.
    pub timeout_seconds: u64,
    pub allow_paid_models: bool,
    pub max_tokens: u32,
    pub temperature: f32,
    pub circuit_breaker_threshold: u32,
    pub circuit_breaker_reset_seconds: u64,
    /// Falls back to `default_retry_max_attempts()` when omitted.
    #[serde(default = "default_retry_max_attempts")]
    pub retry_max_attempts: u32,
    /// Per-task provider/model overrides (`[ai.tasks.*]`), if any.
    pub tasks: Option<TasksConfig>,
    /// Provider fallback chain (`[ai.fallback]`), if any.
    pub fallback: Option<FallbackConfig>,
    pub custom_guidance: Option<String>,
    pub validation_enabled: bool,
}
impl Default for AiConfig {
    /// The default settings are:
    /// - provider `openrouter` with [`DEFAULT_OPENROUTER_MODEL`];
    /// - a 30-second timeout;
    /// - 4096 max tokens and temperature 0.3;
    /// - paid models and validation enabled;
    /// - no task overrides, fallback chain or custom guidance.
    fn default() -> Self {
        let retry_max_attempts = default_retry_max_attempts();
        Self {
            provider: String::from("openrouter"),
            model: String::from(DEFAULT_OPENROUTER_MODEL),
            timeout_seconds: 30,
            allow_paid_models: true,
            max_tokens: 4096,
            temperature: 0.3,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }
}
impl AiConfig {
    /// Returns the `(provider, model)` pair to use for `task`.
    ///
    /// Each half comes from the matching `[ai.tasks.<task>]` override when it
    /// is set there, and otherwise from the top-level `ai.provider` /
    /// `ai.model`. The halves fall back independently, so an override that only
    /// sets `provider` is paired with the global `model`.
    #[must_use]
    pub fn resolve_for_task(&self, task: TaskType) -> (String, String) {
        let selected = self.tasks.as_ref().and_then(|tasks| match task {
            TaskType::Triage => tasks.triage.as_ref(),
            TaskType::Review => tasks.review.as_ref(),
            TaskType::Create => tasks.create.as_ref(),
        });

        let provider = selected
            .and_then(|entry| entry.provider.as_deref())
            .unwrap_or(&self.provider)
            .to_owned();
        let model = selected
            .and_then(|entry| entry.model.as_deref())
            .unwrap_or(&self.model)
            .to_owned();

        (provider, model)
    }
}
/// GitHub settings (`[github]`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(default)]
pub struct GitHubConfig {
    /// API timeout in seconds; defaults to `10`.
    pub api_timeout_seconds: u64,
}

impl Default for GitHubConfig {
    fn default() -> Self {
        GitHubConfig {
            api_timeout_seconds: 10,
        }
    }
}
/// Terminal UI settings (`[ui]`). Every flag defaults to `true`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(default)]
pub struct UiConfig {
    pub color: bool,
    pub progress_bars: bool,
    pub confirm_before_post: bool,
}

impl Default for UiConfig {
    fn default() -> Self {
        let enabled = true;
        UiConfig {
            color: enabled,
            progress_bars: enabled,
            confirm_before_post: enabled,
        }
    }
}
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(default)]
pub struct CacheConfig {
pub issue_ttl_minutes: i64,
pub repo_ttl_hours: i64,
pub curated_repos_url: String,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
issue_ttl_minutes: crate::cache::DEFAULT_ISSUE_TTL_MINS,
repo_ttl_hours: crate::cache::DEFAULT_REPO_TTL_HOURS,
curated_repos_url:
"https://raw.githubusercontent.com/clouatre-labs/aptu/main/data/curated-repos.json"
.to_string(),
}
}
}
/// Repository discovery settings (`[repos]`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(default)]
pub struct ReposConfig {
    /// Defaults to `true`.
    pub curated: bool,
}

impl Default for ReposConfig {
    fn default() -> Self {
        ReposConfig { curated: true }
    }
}
#[must_use]
pub fn config_dir() -> PathBuf {
if let Ok(xdg_config) = std::env::var("XDG_CONFIG_HOME")
&& !xdg_config.is_empty()
{
return PathBuf::from(xdg_config).join("aptu");
}
dirs::home_dir()
.expect("Could not determine home directory - is HOME set?")
.join(".config")
.join("aptu")
}
#[must_use]
pub fn data_dir() -> PathBuf {
if let Ok(xdg_data) = std::env::var("XDG_DATA_HOME")
&& !xdg_data.is_empty()
{
return PathBuf::from(xdg_data).join("aptu");
}
dirs::home_dir()
.expect("Could not determine home directory - is HOME set?")
.join(".local")
.join("share")
.join("aptu")
}
/// The `prompts` subdirectory of [`config_dir`].
#[must_use]
pub fn prompts_dir() -> PathBuf {
    let mut dir = config_dir();
    dir.push("prompts");
    dir
}
/// Path of the main config file: `config.toml` inside [`config_dir`].
#[must_use]
pub fn config_file_path() -> PathBuf {
    let mut path = config_dir();
    path.push("config.toml");
    path
}
/// Loads the application configuration.
///
/// Sources are merged from lowest to highest precedence:
/// 1. built-in defaults (every section is `#[serde(default)]`);
/// 2. the file at [`config_file_path`], if it exists;
/// 3. environment variables prefixed `APTU_`, with `__` separating nested keys
///    (e.g. `APTU_AI__MODEL` sets `ai.model`).
///
/// # Errors
///
/// Returns an error if an existing config file cannot be read or parsed, or
/// if the merged values cannot be deserialized into [`AppConfig`].
pub fn load_config() -> Result<AppConfig, AptuError> {
    let config_path = config_file_path();
    let config = Config::builder()
        // Hand the path over as-is. Going through `to_string_lossy` would
        // mangle non-UTF-8 directories, and the file would never be found.
        .add_source(File::from(config_path.as_path()).required(false))
        .add_source(
            Environment::with_prefix("APTU")
                .prefix_separator("_")
                .separator("__")
                // Parse values such as "true" or "30" into typed values so
                // they deserialize into bool/integer fields.
                .try_parsing(true),
        )
        .build()?;
    let app_config: AppConfig = config.try_deserialize()?;
    Ok(app_config)
}
#[cfg(test)]
mod tests {
    // Some tests mutate process-wide environment variables, which requires
    // `unsafe` (`std::env::set_var` / `remove_var`). Every test in this module
    // that reads or writes the environment is `#[serial]`.
    #![allow(unsafe_code)]

    use super::*;
    use serial_test::serial;

    /// Temporarily overrides environment variables and restores each one's
    /// previous value (or removes it) on drop. A panicking `load_config` call
    /// therefore cannot leak state into later tests, and a developer's own
    /// settings are not clobbered.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
    }

    impl EnvGuard {
        fn new() -> Self {
            Self { saved: Vec::new() }
        }

        fn set(&mut self, key: &'static str, value: impl AsRef<std::ffi::OsStr>) {
            self.saved.push((key, std::env::var_os(key)));
            // SAFETY: callers are `#[serial]` tests, and every environment-
            // touching test in this module is `#[serial]`, so no other test
            // thread here accesses the environment concurrently.
            unsafe { std::env::set_var(key, value) };
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Restore newest-first so a key set twice ends at its original value.
            for (key, previous) in self.saved.drain(..).rev() {
                // SAFETY: same serialization argument as in `EnvGuard::set`.
                unsafe {
                    match previous {
                        Some(value) => std::env::set_var(key, value),
                        None => std::env::remove_var(key),
                    }
                }
            }
        }
    }

    /// Deserializes an [`AppConfig`] from inline TOML, without touching the
    /// filesystem or the environment.
    fn parse_toml(config_str: &str) -> AppConfig {
        Config::builder()
            .add_source(config::File::from_str(config_str, config::FileFormat::Toml))
            .build()
            .expect("should build config")
            .try_deserialize()
            .expect("should deserialize")
    }

    /// Asserts that `ai` resolves `task` to exactly `(provider, model)`.
    #[track_caller]
    fn assert_resolves(ai: &AiConfig, task: TaskType, provider: &str, model: &str) {
        let (got_provider, got_model) = ai.resolve_for_task(task);
        assert_eq!(
            (got_provider.as_str(), got_model.as_str()),
            (provider, model),
            "unexpected resolution for {task:?}"
        );
    }

    #[test]
    #[serial]
    fn test_load_config_defaults() {
        let tmp_dir = std::env::temp_dir().join("aptu_test_defaults_no_config");
        std::fs::create_dir_all(&tmp_dir).expect("create tmp dir");
        let config = {
            let mut env = EnvGuard::new();
            env.set("XDG_CONFIG_HOME", &tmp_dir);
            load_config().expect("should load with defaults")
        };
        assert_eq!(config.ai.provider, "openrouter");
        assert_eq!(config.ai.model, DEFAULT_OPENROUTER_MODEL);
        assert_eq!(config.ai.timeout_seconds, 30);
        assert_eq!(config.ai.max_tokens, 4096);
        assert!(config.ai.allow_paid_models);
        assert!((config.ai.temperature - 0.3).abs() < f32::EPSILON);
        assert_eq!(config.github.api_timeout_seconds, 10);
        assert!(config.ui.color);
        assert!(config.ui.confirm_before_post);
        assert_eq!(config.cache.issue_ttl_minutes, 60);
    }

    // The three path tests read XDG_* variables, so they must not run while a
    // `#[serial]` test is mutating the environment.
    #[test]
    #[serial]
    fn test_config_dir_exists() {
        assert!(config_dir().ends_with("aptu"));
    }

    #[test]
    #[serial]
    fn test_data_dir_exists() {
        assert!(data_dir().ends_with("aptu"));
    }

    #[test]
    #[serial]
    fn test_config_file_path() {
        assert!(config_file_path().ends_with("config.toml"));
    }

    #[test]
    fn test_config_with_task_triage_override() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            [ai.tasks.triage]
            model = "gemini-3.1-flash-lite-preview"
            "#,
        );
        assert_eq!(app_config.ai.provider, "gemini");
        assert_eq!(app_config.ai.model, DEFAULT_GEMINI_MODEL);
        let tasks = app_config.ai.tasks.expect("tasks should exist");
        assert!(tasks.review.is_none());
        assert!(tasks.create.is_none());
        let triage = tasks.triage.expect("triage should exist");
        assert_eq!(triage.provider, None);
        assert_eq!(triage.model, Some(DEFAULT_GEMINI_MODEL.to_string()));
    }

    #[test]
    fn test_config_with_multiple_task_overrides() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "openrouter"
            model = "mistralai/mistral-small-2603"
            [ai.tasks.triage]
            model = "mistralai/mistral-small-2603"
            [ai.tasks.review]
            provider = "openrouter"
            model = "anthropic/claude-haiku-4.5"
            [ai.tasks.create]
            model = "anthropic/claude-sonnet-4.6"
            "#,
        );
        let tasks = app_config.ai.tasks.expect("tasks should exist");
        let triage = tasks.triage.expect("triage should exist");
        assert_eq!(triage.provider, None);
        assert_eq!(triage.model, Some(DEFAULT_OPENROUTER_MODEL.to_string()));
        let review = tasks.review.expect("review should exist");
        assert_eq!(review.provider, Some("openrouter".to_string()));
        assert_eq!(review.model, Some("anthropic/claude-haiku-4.5".to_string()));
        let create = tasks.create.expect("create should exist");
        assert_eq!(create.provider, None);
        assert_eq!(create.model, Some("anthropic/claude-sonnet-4.6".to_string()));
    }

    #[test]
    fn test_config_with_partial_task_overrides() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            [ai.tasks.triage]
            provider = "gemini"
            [ai.tasks.review]
            model = "gemini-3.1-flash-lite-preview"
            "#,
        );
        let tasks = app_config.ai.tasks.expect("tasks should exist");
        let triage = tasks.triage.expect("triage should exist");
        assert_eq!(triage.provider, Some("gemini".to_string()));
        assert_eq!(triage.model, None);
        let review = tasks.review.expect("review should exist");
        assert_eq!(review.provider, None);
        assert_eq!(review.model, Some(DEFAULT_GEMINI_MODEL.to_string()));
    }

    #[test]
    fn test_config_without_tasks_section() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            "#,
        );
        assert_eq!(app_config.ai.provider, "gemini");
        assert_eq!(app_config.ai.model, DEFAULT_GEMINI_MODEL);
        assert!(app_config.ai.tasks.is_none());
    }

    #[test]
    fn test_resolve_for_task_with_defaults() {
        let ai_config = AiConfig::default();
        for task in [TaskType::Triage, TaskType::Review, TaskType::Create] {
            assert_resolves(&ai_config, task, "openrouter", DEFAULT_OPENROUTER_MODEL);
        }
        assert_eq!(DEFAULT_OPENROUTER_MODEL, "mistralai/mistral-small-2603");
        assert!(ai_config.allow_paid_models);
    }

    #[test]
    fn test_resolve_for_task_with_triage_override() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            [ai.tasks.triage]
            model = "gemini-3.1-flash-lite-preview"
            "#,
        );
        for task in [TaskType::Triage, TaskType::Review, TaskType::Create] {
            assert_resolves(&app_config.ai, task, "gemini", DEFAULT_GEMINI_MODEL);
        }
    }

    #[test]
    fn test_resolve_for_task_with_provider_override() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            [ai.tasks.review]
            provider = "openrouter"
            "#,
        );
        // Only the provider is overridden; the model still falls back globally.
        assert_resolves(&app_config.ai, TaskType::Review, "openrouter", DEFAULT_GEMINI_MODEL);
        assert_resolves(&app_config.ai, TaskType::Triage, "gemini", DEFAULT_GEMINI_MODEL);
        assert_resolves(&app_config.ai, TaskType::Create, "gemini", DEFAULT_GEMINI_MODEL);
    }

    #[test]
    fn test_resolve_for_task_with_full_overrides() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            [ai.tasks.triage]
            provider = "openrouter"
            model = "mistralai/mistral-small-2603"
            [ai.tasks.review]
            provider = "openrouter"
            model = "anthropic/claude-haiku-4.5"
            [ai.tasks.create]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            "#,
        );
        assert_resolves(
            &app_config.ai,
            TaskType::Triage,
            "openrouter",
            DEFAULT_OPENROUTER_MODEL,
        );
        assert_resolves(
            &app_config.ai,
            TaskType::Review,
            "openrouter",
            "anthropic/claude-haiku-4.5",
        );
        assert_resolves(&app_config.ai, TaskType::Create, "gemini", DEFAULT_GEMINI_MODEL);
    }

    #[test]
    fn test_resolve_for_task_partial_overrides() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "openrouter"
            model = "mistralai/mistral-small-2603"
            [ai.tasks.triage]
            model = "mistralai/mistral-small-2603"
            [ai.tasks.review]
            provider = "openrouter"
            [ai.tasks.create]
            "#,
        );
        for task in [TaskType::Triage, TaskType::Review, TaskType::Create] {
            assert_resolves(&app_config.ai, task, "openrouter", DEFAULT_OPENROUTER_MODEL);
        }
    }

    #[test]
    fn test_fallback_config_toml_parsing() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            [ai.fallback]
            chain = ["openrouter", "anthropic"]
            "#,
        );
        assert_eq!(app_config.ai.provider, "gemini");
        assert_eq!(app_config.ai.model, "gemini-3.1-flash-lite-preview");
        let fallback = app_config.ai.fallback.expect("fallback should exist");
        assert_eq!(fallback.chain.len(), 2);
        assert_eq!(fallback.chain[0].provider, "openrouter");
        assert_eq!(fallback.chain[1].provider, "anthropic");
    }

    #[test]
    fn test_fallback_config_empty_chain() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            [ai.fallback]
            chain = []
            "#,
        );
        let fallback = app_config.ai.fallback.expect("fallback should exist");
        assert!(fallback.chain.is_empty());
    }

    #[test]
    fn test_fallback_config_single_provider() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            [ai.fallback]
            chain = ["openrouter"]
            "#,
        );
        let fallback = app_config.ai.fallback.expect("fallback should exist");
        assert_eq!(fallback.chain.len(), 1);
        assert_eq!(fallback.chain[0].provider, "openrouter");
    }

    #[test]
    fn test_fallback_config_without_fallback_section() {
        let app_config = parse_toml(
            r#"
            [ai]
            provider = "gemini"
            model = "gemini-3.1-flash-lite-preview"
            "#,
        );
        assert!(app_config.ai.fallback.is_none());
    }

    #[test]
    fn test_fallback_config_default() {
        assert!(AiConfig::default().fallback.is_none());
    }

    #[test]
    #[serial]
    fn test_load_config_env_var_override() {
        let tmp_dir = std::env::temp_dir().join("aptu_test_env_override");
        std::fs::create_dir_all(&tmp_dir).expect("create tmp dir");
        let config = {
            let mut env = EnvGuard::new();
            env.set("XDG_CONFIG_HOME", &tmp_dir);
            env.set("APTU_AI__MODEL", "test-model-override");
            env.set("APTU_AI__PROVIDER", "openrouter");
            load_config().expect("should load with env overrides")
        };
        assert_eq!(config.ai.model, "test-model-override");
        assert_eq!(config.ai.provider, "openrouter");
    }
}