use crate::git::GitRepo;
use crate::instruction_presets::get_instruction_preset_library;
use crate::log_debug;
use crate::providers::{Provider, ProviderConfig};
use anyhow::{Context, Result, anyhow};
use dirs::{config_dir, home_dir};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
/// File name of the per-repository configuration file, located at the Git
/// repository root (see `Config::get_project_config_path`).
pub const PROJECT_CONFIG_FILENAME: &str = ".irisconfig";
/// Git-Iris configuration: the personal config file, optionally overlaid with
/// the repository's [`PROJECT_CONFIG_FILENAME`] file, plus runtime-only overrides.
///
/// Fields marked `#[serde(skip)]` are never read from or written to disk.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct Config {
    /// Name of the provider used by default. Omitted from the file when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub default_provider: String,
    /// Per-provider settings keyed by provider name. Omitted when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub providers: HashMap<String, ProviderConfig>,
    /// Whether gitmoji is enabled. Defaults to `true` and is only written
    /// to the file when `false`.
    #[serde(default = "default_true", skip_serializing_if = "is_true")]
    pub use_gitmoji: bool,
    /// Custom instructions, appended after the preset's instructions by
    /// `get_effective_instructions`. Omitted when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub instructions: String,
    /// Name of the instruction preset. Defaults to `"default"`; omitted from
    /// the file when empty or `"default"`.
    #[serde(default = "default_preset", skip_serializing_if = "is_default_preset")]
    pub instruction_preset: String,
    /// Theme name (not read in this file; presumably a UI theme identifier —
    /// TODO confirm). Omitted when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub theme: String,
    /// Subagent timeout in seconds. Defaults to 120; omitted from the file
    /// when it equals the default.
    #[serde(
        default = "default_subagent_timeout",
        skip_serializing_if = "is_default_subagent_timeout"
    )]
    pub subagent_timeout_secs: u64,
    /// Runtime-only replacement for `instructions`, preferred by
    /// `get_effective_instructions` when set.
    #[serde(skip)]
    pub temp_instructions: Option<String>,
    /// Runtime-only replacement for `instruction_preset`, preferred by
    /// `get_effective_preset_name` and `get_effective_instructions` when set.
    #[serde(skip)]
    pub temp_preset: Option<String>,
    /// `true` when this value was loaded from the project config file;
    /// `save` is a no-op for such configs.
    #[serde(skip)]
    pub is_project_config: bool,
    /// Runtime-only gitmoji flag. Not read in this file — presumably overrides
    /// `use_gitmoji` elsewhere (NOTE(review): confirm against callers).
    #[serde(skip)]
    pub gitmoji_override: Option<bool>,
}
/// Serde default for `use_gitmoji`: gitmoji is on unless the file says otherwise.
fn default_true() -> bool {
    true
}

/// `skip_serializing_if` predicate for `use_gitmoji`: the default value
/// (`true`) is left out of the written file.
// serde hands the predicate a reference to the field, so `&bool` is mandated.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_true(&val: &bool) -> bool {
    val
}
/// Instruction preset used when none is configured.
///
/// Shared by the serde default and the skip predicate so the two can never
/// disagree (a mismatch would drop a non-default preset on save).
const DEFAULT_PRESET: &str = "default";

/// Serde default for `instruction_preset`.
fn default_preset() -> String {
    DEFAULT_PRESET.to_string()
}

/// `skip_serializing_if` predicate for `instruction_preset`.
///
/// An empty preset is treated like [`DEFAULT_PRESET`]: both are left out of the
/// serialized file and therefore load back as the default.
fn is_default_preset(val: &str) -> bool {
    val.is_empty() || val == DEFAULT_PRESET
}
/// Default subagent timeout, in seconds.
///
/// Shared by the serde default and the skip predicate so the two can never
/// disagree (a mismatch would drop a user-chosen timeout on save).
const DEFAULT_SUBAGENT_TIMEOUT_SECS: u64 = 120;

/// Serde default for `subagent_timeout_secs`.
fn default_subagent_timeout() -> u64 {
    DEFAULT_SUBAGENT_TIMEOUT_SECS
}

/// `skip_serializing_if` predicate for `subagent_timeout_secs`: the default
/// timeout is left out of the written file.
// serde hands the predicate a reference to the field, so `&u64` is mandated.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_subagent_timeout(val: &u64) -> bool {
    *val == DEFAULT_SUBAGENT_TIMEOUT_SECS
}
impl Default for Config {
fn default() -> Self {
let mut providers = HashMap::new();
for provider in Provider::ALL {
providers.insert(
provider.name().to_string(),
ProviderConfig::with_defaults(*provider),
);
}
Self {
default_provider: Provider::default().name().to_string(),
providers,
use_gitmoji: true,
instructions: String::new(),
instruction_preset: default_preset(),
theme: String::new(),
subagent_timeout_secs: default_subagent_timeout(),
temp_instructions: None,
temp_preset: None,
is_project_config: false,
gitmoji_override: None,
}
}
}
impl Config {
    /// Loads the effective configuration.
    ///
    /// Reads the personal config file (see [`Config::get_personal_config_path`]),
    /// falling back to [`Config::default`] when it does not exist. Legacy provider
    /// names are migrated; if anything changed, the migrated config is written
    /// back, and a failure to write is only logged. Finally, the project config
    /// ([`PROJECT_CONFIG_FILENAME`] at the repository root) is merged on top when
    /// it can be loaded. Only keys explicitly present in that file override
    /// personal values.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the personal config directory cannot be determined or created;
    /// - the personal config file exists but cannot be read or parsed.
    ///
    /// Failures to load the project config are logged and otherwise ignored.
    pub fn load() -> Result<Self> {
        let config_path = Self::get_personal_config_path()?;
        let mut config = if config_path.exists() {
            let content = fs::read_to_string(&config_path)
                .with_context(|| format!("Failed to read {}", config_path.display()))?;
            let parsed: Self = toml::from_str(&content)
                .with_context(|| format!("Failed to parse {}", config_path.display()))?;
            let (migrated, needs_save) = Self::migrate_if_needed(parsed);
            if needs_save && let Err(e) = migrated.save() {
                log_debug!("Failed to save migrated config: {}", e);
            }
            migrated
        } else {
            Self::default()
        };

        // The project config is optional. Outside a repository, or without a
        // project file, loading it fails and the personal config is used as-is.
        // The reason is logged so a malformed project file is not ignored silently.
        match Self::load_project_config_with_source() {
            Ok((project_config, project_source)) => {
                config.merge_loaded_project_config(project_config, &project_source);
            }
            Err(e) => {
                log_debug!("No project configuration applied: {:#}", e);
            }
        }

        log_debug!(
            "Configuration loaded (provider: {}, gitmoji: {})",
            config.default_provider,
            config.use_gitmoji
        );
        Ok(config)
    }

    /// Loads only the project configuration ([`PROJECT_CONFIG_FILENAME`] at the
    /// repository root), without merging it into the personal config.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the repository root cannot be determined;
    /// - the file does not exist;
    /// - the file cannot be read or parsed.
    pub fn load_project_config() -> Result<Self> {
        Self::load_project_config_with_source().map(|(config, _)| config)
    }

    /// Loads the project config together with its raw TOML tree.
    ///
    /// The raw tree lets callers see which keys were actually written in the
    /// file, as opposed to filled in by serde defaults. The returned config has
    /// `is_project_config` set.
    fn load_project_config_with_source() -> Result<(Self, toml::Value)> {
        let config_path = Self::get_project_config_path()?;
        if !config_path.exists() {
            return Err(anyhow!("Project configuration file not found"));
        }
        let content = fs::read_to_string(&config_path)
            .with_context(|| format!("Failed to read {}", config_path.display()))?;

        let invalid_format = || {
            format!(
                "Invalid {} format. Check for syntax errors.",
                PROJECT_CONFIG_FILENAME
            )
        };
        let project_source: toml::Value =
            toml::from_str(&content).with_context(invalid_format)?;
        let mut config: Self = toml::from_str(&content).with_context(invalid_format)?;
        config.is_project_config = true;
        Ok((config, project_source))
    }

    /// Returns the path of the project config file:
    /// [`PROJECT_CONFIG_FILENAME`] inside the repository root.
    ///
    /// # Errors
    ///
    /// Propagates the error from `GitRepo::get_repo_root`.
    pub fn get_project_config_path() -> Result<PathBuf> {
        let repo_root = GitRepo::get_repo_root()?;
        Ok(repo_root.join(PROJECT_CONFIG_FILENAME))
    }

    /// Merges `project_config` into `self`, treating default-valued project
    /// fields as "not set".
    ///
    /// A deserialized [`Config`] cannot tell whether a field was written or
    /// filled in by a serde default. Therefore a project value is applied only
    /// when it differs from the default:
    /// - a non-empty provider other than [`Provider::default`];
    /// - non-empty `instructions` / `theme`;
    /// - a preset other than `"default"`;
    /// - a non-default subagent timeout;
    /// - `use_gitmoji = false`.
    ///
    /// Provider sections are merged as in `merge_project_provider_config`.
    /// `Config::load` uses explicit key tracking instead, which can also express
    /// "reset to the default".
    pub fn merge_with_project_config(&mut self, project_config: Self) {
        log_debug!("Merging with project configuration");
        self.merge_project_provider_config(&project_config);

        let Self {
            default_provider,
            use_gitmoji,
            instructions,
            instruction_preset,
            theme,
            subagent_timeout_secs,
            ..
        } = project_config;

        if !default_provider.is_empty() && default_provider != Provider::default().name() {
            self.default_provider = default_provider;
        }
        // `true` is both the serde default and never serialized (`is_true`), so
        // only `false` is a meaningful override. Previously an absent key
        // re-enabled gitmoji over a personal `false`.
        if !use_gitmoji {
            self.use_gitmoji = false;
        }
        // An absent `instructions` key deserializes to "". It must not wipe the
        // personal instructions (consistent with how `theme` is handled).
        if !instructions.is_empty() {
            self.instructions = instructions;
        }
        if instruction_preset != default_preset() {
            self.instruction_preset = instruction_preset;
        }
        if !theme.is_empty() {
            self.theme = theme;
        }
        if subagent_timeout_secs != default_subagent_timeout() {
            self.subagent_timeout_secs = subagent_timeout_secs;
        }
    }

    /// Merges a project config into `self`, overriding a top-level field only
    /// when its key is present in the raw project TOML (`project_source`).
    ///
    /// Provider sections are merged first via `merge_project_provider_config`.
    fn merge_loaded_project_config(&mut self, project_config: Self, project_source: &toml::Value) {
        log_debug!("Merging loaded project configuration with explicit field tracking");
        self.merge_project_provider_config(&project_config);

        let has_key = |key: &str| Self::project_config_has_key(project_source, key);
        let Self {
            default_provider,
            use_gitmoji,
            instructions,
            instruction_preset,
            theme,
            subagent_timeout_secs,
            ..
        } = project_config;

        if has_key("default_provider") {
            self.default_provider = default_provider;
        }
        if has_key("use_gitmoji") {
            self.use_gitmoji = use_gitmoji;
        }
        if has_key("instructions") {
            self.instructions = instructions;
        }
        if has_key("instruction_preset") {
            self.instruction_preset = instruction_preset;
        }
        if has_key("theme") {
            self.theme = theme;
        }
        if has_key("subagent_timeout_secs") {
            self.subagent_timeout_secs = subagent_timeout_secs;
        }
    }

    /// Merges each provider section of `project_config` into `self.providers`.
    ///
    /// Missing entries are created with `ProviderConfig::default()`. For each
    /// section:
    /// - a non-empty `model`, a `Some` `fast_model`, and a `Some` `token_limit`
    ///   replace the personal values;
    /// - `additional_params` are added, with project values winning on key
    ///   conflicts.
    fn merge_project_provider_config(&mut self, project_config: &Self) {
        for (provider_name, proj_config) in &project_config.providers {
            let entry = self.providers.entry(provider_name.clone()).or_default();
            if !proj_config.model.is_empty() {
                proj_config.model.clone_into(&mut entry.model);
            }
            if proj_config.fast_model.is_some() {
                entry.fast_model.clone_from(&proj_config.fast_model);
            }
            if proj_config.token_limit.is_some() {
                entry.token_limit = proj_config.token_limit;
            }
            entry
                .additional_params
                .extend(proj_config.additional_params.clone());
        }
    }

    /// Returns `true` if `project_source` is a TOML table containing the
    /// top-level `key`.
    fn project_config_has_key(project_source: &toml::Value, key: &str) -> bool {
        project_source
            .as_table()
            .is_some_and(|table| table.contains_key(key))
    }

    /// Rewrites legacy provider names: `claude` becomes `anthropic` and
    /// `gemini` becomes `google`.
    ///
    /// Provider-map keys are matched exactly. When the canonical key already
    /// exists, the legacy entry is discarded. `default_provider` is matched
    /// case-insensitively.
    ///
    /// Returns the updated config and whether anything changed.
    fn migrate_if_needed(mut config: Self) -> (Self, bool) {
        let mut migrated = false;
        for (legacy, canonical) in [("claude", "anthropic"), ("gemini", "google")] {
            if let Some(legacy_config) = config.providers.remove(legacy) {
                log_debug!("Migrating '{legacy}' provider to '{canonical}'");
                if config.providers.contains_key(canonical) {
                    log_debug!(
                        "Keeping existing '{canonical}' config and dropping legacy '{legacy}' entry"
                    );
                } else {
                    config
                        .providers
                        .insert(canonical.to_string(), legacy_config);
                }
                migrated = true;
            }
            if config.default_provider.eq_ignore_ascii_case(legacy) {
                config.default_provider = canonical.to_string();
                migrated = true;
            }
        }
        (config, migrated)
    }

    /// Writes this config to the personal config file.
    ///
    /// Does nothing, and returns `Ok(())`, when `is_project_config` is set.
    ///
    /// # Errors
    ///
    /// Returns an error if the path cannot be resolved, serialization fails,
    /// or the file cannot be written.
    pub fn save(&self) -> Result<()> {
        if self.is_project_config {
            return Ok(());
        }
        let config_path = Self::get_personal_config_path()?;
        let content = toml::to_string_pretty(self)?;
        Self::write_config_file(&config_path, &content)?;
        log_debug!("Configuration saved");
        Ok(())
    }

    /// Writes a copy of this config to the project config file, with every
    /// provider's `api_key` cleared so no secrets land in the repository root.
    ///
    /// # Errors
    ///
    /// Returns an error if the repository root cannot be determined,
    /// serialization fails, or the file cannot be written.
    pub fn save_as_project_config(&self) -> Result<()> {
        let config_path = Self::get_project_config_path()?;
        let mut project_config = self.clone();
        project_config.is_project_config = true;
        for provider_config in project_config.providers.values_mut() {
            provider_config.api_key.clear();
        }
        let content = toml::to_string_pretty(&project_config)?;
        Self::write_config_file(&config_path, &content)
    }

    /// Writes `content` to `path`, replacing any existing file.
    ///
    /// On Unix:
    /// - the data goes to a sibling temporary file created with mode `0o600`
    ///   before any content is written, so the file (which may hold API keys)
    ///   is never readable by other users;
    /// - the temporary file is synced to disk and then renamed over `path`, so
    ///   readers never see a partially written config;
    /// - if any step fails, the temporary file is removed.
    ///
    /// On other platforms the file is written directly.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error, with the target path as context.
    fn write_config_file(path: &Path, content: &str) -> Result<()> {
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

            // Creates (or truncates) `tmp_path` with owner-only permissions and
            // durably writes `content` into it.
            fn write_private(tmp_path: &Path, content: &str) -> std::io::Result<()> {
                let mut file = fs::OpenOptions::new()
                    .write(true)
                    .create(true)
                    .truncate(true)
                    .mode(0o600)
                    .open(tmp_path)?;
                // `mode` only applies to newly created files. A stale temp file
                // left by an interrupted run keeps its old permissions, so tighten
                // them before any content is written.
                if let Err(e) = file.set_permissions(fs::Permissions::from_mode(0o600)) {
                    eprintln!(
                        "Warning: Could not restrict config permissions on {}: {e}",
                        tmp_path.display()
                    );
                }
                file.write_all(content.as_bytes())?;
                file.sync_all()
            }

            let tmp_path = path.with_extension("tmp");
            let outcome =
                write_private(&tmp_path, content).and_then(|()| fs::rename(&tmp_path, path));
            if let Err(err) = outcome {
                // Best-effort cleanup; the original error is the one to report.
                let _ = fs::remove_file(&tmp_path);
                return Err(err).with_context(|| format!("Failed to write {}", path.display()));
            }
        }
        #[cfg(not(unix))]
        {
            fs::write(path, content)
                .with_context(|| format!("Failed to write {}", path.display()))?;
        }
        Ok(())
    }

    /// Picks the personal config directory. The first applicable rule wins:
    ///
    /// 1. a non-empty `xdg_config_home` → `<xdg>/git-iris`;
    /// 2. an existing legacy macOS config with a known platform dir →
    ///    `<platform>/git-iris`;
    /// 3. a known home dir → `<home>/.config/git-iris`;
    /// 4. a known platform dir → `<platform>/git-iris`.
    ///
    /// # Errors
    ///
    /// Returns an error when no rule applies.
    fn resolve_personal_config_dir(
        xdg_config_home: Option<PathBuf>,
        home_dir: Option<PathBuf>,
        platform_config_dir: Option<PathBuf>,
        legacy_macos_config_exists: bool,
    ) -> Result<PathBuf> {
        if let Some(xdg) = xdg_config_home.filter(|path| !path.as_os_str().is_empty()) {
            return Ok(xdg.join("git-iris"));
        }
        if legacy_macos_config_exists && let Some(platform) = platform_config_dir.clone() {
            return Ok(platform.join("git-iris"));
        }
        if let Some(home) = home_dir {
            return Ok(home.join(".config").join("git-iris"));
        }
        platform_config_dir
            .map(|p| p.join("git-iris"))
            .ok_or_else(|| anyhow!("Unable to determine config directory"))
    }

    /// Returns the path of the personal `config.toml`, creating its directory
    /// if needed.
    ///
    /// Uses `XDG_CONFIG_HOME`, the home directory, and the platform config
    /// directory as described on `resolve_personal_config_dir`. The legacy
    /// platform location is only considered on macOS.
    ///
    /// # Errors
    ///
    /// Returns an error if no directory can be determined or it cannot be created.
    pub fn get_personal_config_path() -> Result<PathBuf> {
        let platform_dir = config_dir();
        let legacy_macos_config_exists = cfg!(target_os = "macos")
            && platform_dir
                .as_ref()
                .is_some_and(|dir| dir.join("git-iris").join("config.toml").exists());
        let mut path = Self::resolve_personal_config_dir(
            std::env::var_os("XDG_CONFIG_HOME").map(PathBuf::from),
            home_dir(),
            platform_dir,
            legacy_macos_config_exists,
        )?;
        fs::create_dir_all(&path)?;
        path.push("config.toml");
        Ok(path)
    }

    /// Checks that the current directory is inside a Git work tree.
    ///
    /// # Errors
    ///
    /// Returns an error when not inside a work tree, or propagates the error
    /// from `GitRepo::is_inside_work_tree`.
    pub fn check_environment(&self) -> Result<()> {
        if !GitRepo::is_inside_work_tree()? {
            return Err(anyhow!(
                "Not in a Git repository. Please run this command from within a Git repository."
            ));
        }
        Ok(())
    }

    /// Sets (or clears with `None`) the runtime-only instructions override.
    pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
        self.temp_instructions = instructions;
    }

    /// Sets (or clears with `None`) the runtime-only preset override.
    pub fn set_temp_preset(&mut self, preset: Option<String>) {
        self.temp_preset = preset;
    }

    /// Returns the preset in effect: the temporary override if set, otherwise
    /// `instruction_preset`.
    #[must_use]
    pub fn get_effective_preset_name(&self) -> &str {
        self.temp_preset
            .as_deref()
            .unwrap_or(&self.instruction_preset)
    }

    /// Returns the effective instructions.
    ///
    /// The result is the effective preset's instructions (empty if the preset
    /// is unknown), a blank line, then the custom instructions (the temporary
    /// override if set, otherwise `instructions`). Surrounding whitespace is
    /// trimmed.
    #[must_use]
    pub fn get_effective_instructions(&self) -> String {
        let preset_library = get_instruction_preset_library();
        let preset_name = self
            .temp_preset
            .as_ref()
            .unwrap_or(&self.instruction_preset);
        let preset_instructions = preset_library
            .get_preset(preset_name)
            .map(|p| p.instructions.clone())
            .unwrap_or_default();
        let custom = self
            .temp_instructions
            .as_ref()
            .unwrap_or(&self.instructions);
        format!("{preset_instructions}\n\n{custom}")
            .trim()
            .to_string()
    }

    /// Applies the given settings; `None` leaves a setting unchanged.
    ///
    /// When `provider` is given, it must parse as a known provider. It becomes
    /// the default provider, and a defaults entry is created for it if missing.
    /// Provider-specific settings (`api_key`, `model`, `fast_model`,
    /// `additional_params`, `token_limit`) apply to the default provider.
    /// `additional_params` are merged into the existing ones.
    ///
    /// # Errors
    ///
    /// Returns an error if `provider` is unknown, or if the default provider has
    /// no entry in `providers`.
    // Flat list of optional arguments kept for existing callers.
    #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
    pub fn update(
        &mut self,
        provider: Option<String>,
        api_key: Option<String>,
        model: Option<String>,
        fast_model: Option<String>,
        additional_params: Option<HashMap<String, String>>,
        use_gitmoji: Option<bool>,
        instructions: Option<String>,
        token_limit: Option<usize>,
    ) -> Result<()> {
        if let Some(ref provider_name) = provider {
            let parsed: Provider = provider_name.parse().with_context(|| {
                format!(
                    "Unknown provider '{}'. Supported: {}",
                    provider_name,
                    Provider::all_names().join(", ")
                )
            })?;
            self.default_provider = parsed.name().to_string();
            self.providers
                .entry(parsed.name().to_string())
                .or_insert_with(|| ProviderConfig::with_defaults(parsed));
        }
        let provider_config = self
            .providers
            .get_mut(&self.default_provider)
            .context("Could not get default provider config")?;
        if let Some(key) = api_key {
            provider_config.api_key = key;
        }
        if let Some(m) = model {
            provider_config.model = m;
        }
        if let Some(fm) = fast_model {
            provider_config.fast_model = Some(fm);
        }
        if let Some(params) = additional_params {
            provider_config.additional_params.extend(params);
        }
        if let Some(limit) = token_limit {
            provider_config.token_limit = Some(limit);
        }
        if let Some(gitmoji) = use_gitmoji {
            self.use_gitmoji = gitmoji;
        }
        if let Some(instr) = instructions {
            self.instructions = instr;
        }
        log_debug!("Configuration updated");
        Ok(())
    }

    /// Looks up a provider's settings.
    ///
    /// The legacy names `claude` and `gemini` (any ASCII case) are mapped to
    /// `anthropic` and `google`. The exact name is tried first, then its
    /// lowercase form.
    #[must_use]
    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
        let name = if provider.eq_ignore_ascii_case("claude") {
            "anthropic"
        } else if provider.eq_ignore_ascii_case("gemini") {
            "google"
        } else {
            provider
        };
        self.providers
            .get(name)
            .or_else(|| self.providers.get(&name.to_lowercase()))
    }

    /// Returns the default provider, or `None` if its name does not parse.
    #[must_use]
    pub fn provider(&self) -> Option<Provider> {
        self.default_provider.parse().ok()
    }

    /// Checks that the default provider is known, has a config entry, and has
    /// an API key, either in the config or in its environment variable.
    ///
    /// # Errors
    ///
    /// Returns an error describing the first check that fails.
    pub fn validate(&self) -> Result<()> {
        let provider: Provider = self
            .default_provider
            .parse()
            .with_context(|| format!("Invalid provider: {}", self.default_provider))?;
        let config = self
            .get_provider_config(provider.name())
            .ok_or_else(|| anyhow!("No configuration found for provider: {}", provider.name()))?;
        // The environment is only consulted when the config has no key.
        if !config.has_api_key() && std::env::var(provider.api_key_env()).is_err() {
            return Err(anyhow!(
                "API key required for {}. Set {} or configure in ~/.config/git-iris/config.toml",
                provider.name(),
                provider.api_key_env()
            ));
        }
        Ok(())
    }
}
// Unit tests for this module live in a separate `tests` module file and are
// compiled only for test builds.
#[cfg(test)]
mod tests;