use std::collections::HashMap;
use camino::{Utf8Path, Utf8PathBuf};
use figment::Figment;
use figment::providers::{Env, Format, Json, Serialized, Toml, Yaml};
use serde::{Deserialize, Serialize};
use crate::error::{ConfigError, ConfigResult};
/// English dialect selectable in configuration files and, when the `clap`
/// feature is enabled, as a command-line value.
///
/// Serde uses kebab-case names (`"en-us"`, `"en-gb"`, `"en-ca"`, `"en-au"`),
/// and the clap value names below are spelled identically, so the same string
/// works in both places.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
pub enum Dialect {
    /// American English (`en-us`).
    #[cfg_attr(feature = "clap", value(name = "en-us"))]
    EnUs,
    /// British English (`en-gb`).
    #[cfg_attr(feature = "clap", value(name = "en-gb"))]
    EnGb,
    /// Canadian English (`en-ca`).
    #[cfg_attr(feature = "clap", value(name = "en-ca"))]
    EnCa,
    /// Australian English (`en-au`).
    #[cfg_attr(feature = "clap", value(name = "en-au"))]
    EnAu,
}
impl Dialect {
    /// Returns the canonical lowercase tag for this dialect, e.g. `"en-gb"`.
    ///
    /// The returned strings are the same spellings serde and clap accept for
    /// this enum, so the output round-trips through configuration.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::EnUs => "en-us",
            Self::EnGb => "en-gb",
            Self::EnCa => "en-ca",
            Self::EnAu => "en-au",
        }
    }
}
impl std::fmt::Display for Dialect {
    /// Writes the dialect's canonical tag (see [`Dialect::as_str`]).
    ///
    /// Width, fill, alignment and precision flags from the format string are
    /// honored, exactly as they are when formatting a plain `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the caller's formatting flags; `write_str` would
        // silently ignore them, so `format!("{:<6}|", dialect)` wouldn't align.
        f.pad(self.as_str())
    }
}
/// Effective bito-lint configuration after all sources have been merged.
///
/// Because of `#[serde(default)]`, any field missing from a configuration
/// source falls back to `Config::default()`. That means `log_level` is
/// `LogLevel::Info` and every `Option` field is `None`.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
#[serde(default)]
pub struct Config {
    /// Logging level; defaults to `info`. How it is applied is up to the
    /// caller of the loader.
    pub log_level: LogLevel,
    /// Optional log directory (presumably where log files are written —
    /// confirm against the logging setup).
    pub log_dir: Option<Utf8PathBuf>,
    /// Optional token budget. NOTE(review): the unit and the consumer of this
    /// value are not visible in this file — confirm.
    pub token_budget: Option<usize>,
    /// Optional maximum grade, presumably a readability grade-level ceiling.
    /// The check that reads it is not in this file.
    pub max_grade: Option<f64>,
    /// Optional passive-voice ceiling. NOTE(review): confirm whether the scale
    /// is 0–100 (as the name suggests) or 0–1.
    pub passive_max_percent: Option<f64>,
    /// Optional minimum style score; the scoring scale is not visible here.
    pub style_min_score: Option<i32>,
    /// Optional English dialect (e.g. `en-gb`); `None` when not configured.
    pub dialect: Option<Dialect>,
    /// Optional map from a string key to a list of strings. NOTE(review):
    /// presumably template name → template entries; verify with consumers.
    pub templates: Option<HashMap<String, Vec<String>>>,
}
/// Logging level accepted in configuration.
///
/// Variants are (de)serialized in lowercase (`"debug"`, `"info"`, `"warn"`,
/// `"error"`). The default is [`LogLevel::Info`].
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
    /// Serialized as `"debug"`.
    Debug,
    /// Serialized as `"info"`; the default level.
    #[default]
    Info,
    /// Serialized as `"warn"`.
    Warn,
    /// Serialized as `"error"`.
    Error,
}
impl LogLevel {
    /// Returns the lowercase name of this level, e.g. `"warn"`.
    ///
    /// The strings match the serde spelling declared on the enum, so they can
    /// be written back into a config file or attached to structured logs.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Debug => "debug",
            Self::Info => "info",
            Self::Warn => "warn",
            Self::Error => "error",
        }
    }
}
/// File extensions probed during config discovery, in priority order.
/// Within a single directory, the first existing candidate wins.
const CONFIG_EXTENSIONS: &[&str] = &["toml", "yaml", "yml", "json"];
/// Application name used for config file names (`.bito-lint.*` and
/// `bito-lint.*`) and for resolving per-user directories.
const APP_NAME: &str = "bito-lint";
#[derive(Debug, Default)]
pub struct ConfigLoader {
project_search_root: Option<Utf8PathBuf>,
include_user_config: bool,
boundary_marker: Option<String>,
explicit_files: Vec<Utf8PathBuf>,
}
impl ConfigLoader {
    /// Creates a loader with the standard discovery settings.
    ///
    /// The user config file is included, and project discovery is bounded by
    /// a `.git` marker. No project search root or explicit files are set.
    pub fn new() -> Self {
        Self {
            project_search_root: None,
            include_user_config: true,
            boundary_marker: Some(".git".to_string()),
            explicit_files: Vec::new(),
        }
    }

    /// Enables project-config discovery, walking upward from `path`.
    pub fn with_project_search<P: AsRef<Utf8Path>>(mut self, path: P) -> Self {
        self.project_search_root = Some(path.as_ref().to_path_buf());
        self
    }

    /// Controls whether the per-user config file is merged.
    pub const fn with_user_config(mut self, include: bool) -> Self {
        self.include_user_config = include;
        self
    }

    /// Sets the file or directory name that marks the project root.
    ///
    /// The directory containing the marker is still searched. The upward
    /// search then stops there.
    pub fn with_boundary_marker<S: Into<String>>(mut self, marker: S) -> Self {
        self.boundary_marker = Some(marker.into());
        self
    }

    /// Removes the boundary marker, so project discovery may walk all the way
    /// up to the filesystem root.
    pub fn without_boundary_marker(mut self) -> Self {
        self.boundary_marker = None;
        self
    }

    /// Appends an explicit config file; files added later override earlier ones.
    pub fn with_file<P: AsRef<Utf8Path>>(mut self, path: P) -> Self {
        self.explicit_files.push(path.as_ref().to_path_buf());
        self
    }

    /// Merges all configured sources into a [`Config`].
    ///
    /// Precedence, lowest to highest:
    /// 1. `Config::default()`
    /// 2. the per-user config file (if enabled and present)
    /// 3. the nearest project config file (if a search root is set)
    /// 4. explicit files, in the order they were added
    /// 5. `BITO_LINT_*` environment variables
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Deserialize`] if the merged sources cannot be
    /// extracted into a `Config` (e.g. malformed files or wrong value types).
    #[tracing::instrument(skip(self), fields(search_root = ?self.project_search_root))]
    pub fn load(self) -> ConfigResult<Config> {
        tracing::debug!("loading configuration");
        let mut figment = Figment::new().merge(Serialized::defaults(Config::default()));
        if self.include_user_config
            && let Some(user_config) = self.find_user_config()
        {
            figment = Self::merge_file(figment, &user_config);
        }
        if let Some(ref root) = self.project_search_root
            && let Some(project_config) = self.find_project_config(root)
        {
            figment = Self::merge_file(figment, &project_config);
        }
        for file in &self.explicit_files {
            figment = Self::merge_file(figment, file);
        }
        // Merged last so environment variables override every file source.
        figment = figment.merge(Env::prefixed("BITO_LINT_").lowercase(true));
        let config: Config = figment
            .extract()
            .map_err(|e| ConfigError::Deserialize(Box::new(e)))?;
        tracing::info!(
            log_level = config.log_level.as_str(),
            "configuration loaded"
        );
        Ok(config)
    }

    /// Like [`ConfigLoader::load`], but fails when no config file source exists.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::NotFound`] when there is no user config file
    /// (or it is disabled), no discoverable project config, and no explicit
    /// file. Environment variables alone do not count as a source. Otherwise
    /// it returns whatever `load` returns.
    pub fn load_or_error(self) -> ConfigResult<Config> {
        let has_user = self.include_user_config && self.find_user_config().is_some();
        let has_project = self
            .project_search_root
            .as_deref()
            .and_then(|root| self.find_project_config(root))
            .is_some();
        let has_explicit = !self.explicit_files.is_empty();
        if !has_user && !has_project && !has_explicit {
            return Err(ConfigError::NotFound);
        }
        self.load()
    }

    /// Walks from `start` toward the filesystem root and returns the first
    /// project config file found.
    ///
    /// Every visited directory is searched, including one that contains the
    /// boundary marker, since that is usually the project root where the
    /// config lives. The walk stops after the marker directory so it never
    /// escapes into unrelated parent directories.
    fn find_project_config(&self, start: &Utf8Path) -> Option<Utf8PathBuf> {
        for dir in start.ancestors() {
            if let Some(found) = Self::config_file_in(dir) {
                return Some(found);
            }
            let at_boundary = self
                .boundary_marker
                .as_deref()
                .is_some_and(|marker| dir.join(marker).exists());
            if at_boundary {
                break;
            }
        }
        None
    }

    /// Returns the first existing project config file directly inside `dir`.
    ///
    /// For each extension in `CONFIG_EXTENSIONS` order, the dotfile
    /// (`.bito-lint.<ext>`) is tried before the plain name (`bito-lint.<ext>`).
    fn config_file_in(dir: &Utf8Path) -> Option<Utf8PathBuf> {
        for ext in CONFIG_EXTENSIONS {
            for name in [format!(".{APP_NAME}.{ext}"), format!("{APP_NAME}.{ext}")] {
                let candidate = dir.join(name);
                if candidate.is_file() {
                    return Some(candidate);
                }
            }
        }
        None
    }

    /// Returns the first existing `config.<ext>` file in the per-user config
    /// directory.
    ///
    /// Returns `None` if there is no such file, the platform directories
    /// cannot be resolved, or the found path is not valid UTF-8.
    fn find_user_config(&self) -> Option<Utf8PathBuf> {
        let proj_dirs = project_dirs()?;
        let config_dir = proj_dirs.config_dir();
        CONFIG_EXTENSIONS
            .iter()
            .map(|ext| config_dir.join(format!("config.{ext}")))
            .find(|path| path.is_file())
            .and_then(|path| Utf8PathBuf::from_path_buf(path).ok())
    }

    /// Merges `path` into `figment`, choosing the parser from the file extension.
    ///
    /// The extension is compared case-insensitively. Unknown or missing
    /// extensions are parsed as TOML.
    fn merge_file(figment: Figment, path: &Utf8Path) -> Figment {
        let file = path.as_str();
        match path.extension().map(str::to_ascii_lowercase).as_deref() {
            Some("yaml" | "yml") => figment.merge(Yaml::file_exact(file)),
            Some("json") => figment.merge(Json::file_exact(file)),
            _ => figment.merge(Toml::file_exact(file)),
        }
    }
}
/// Locates the nearest project config file at or above `start`.
///
/// Candidates are `.bito-lint.<ext>` and `bito-lint.<ext>` for each supported
/// extension. This search uses no boundary marker, so it may walk all the way
/// up to the filesystem root. Returns `None` when nothing is found.
pub fn find_project_config<P: AsRef<Utf8Path>>(start: P) -> Option<Utf8PathBuf> {
    let origin = start.as_ref();
    let unbounded = ConfigLoader::new().without_boundary_marker();
    unbounded.find_project_config(origin)
}
/// Resolves the platform-specific per-user directories for `APP_NAME`.
///
/// The qualifier and organization are left empty, so only the application
/// name shapes the resulting paths. Per the `directories` crate, this is
/// `None` when no valid home directory can be determined.
fn project_dirs() -> Option<directories::ProjectDirs> {
    let (qualifier, organization) = ("", "");
    directories::ProjectDirs::from(qualifier, organization, APP_NAME)
}
pub fn user_config_dir() -> Option<Utf8PathBuf> {
let proj_dirs = project_dirs()?;
Utf8PathBuf::from_path_buf(proj_dirs.config_dir().to_path_buf()).ok()
}
pub fn user_cache_dir() -> Option<Utf8PathBuf> {
let proj_dirs = project_dirs()?;
Utf8PathBuf::from_path_buf(proj_dirs.cache_dir().to_path_buf()).ok()
}
pub fn user_data_dir() -> Option<Utf8PathBuf> {
let proj_dirs = project_dirs()?;
Utf8PathBuf::from_path_buf(proj_dirs.data_dir().to_path_buf()).ok()
}
pub fn user_data_local_dir() -> Option<Utf8PathBuf> {
let proj_dirs = project_dirs()?;
Utf8PathBuf::from_path_buf(proj_dirs.data_local_dir().to_path_buf()).ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use tempfile::TempDir;

    // `cargo test` runs tests on parallel threads, and `ConfigLoader::load`
    // merges `BITO_LINT_*` environment variables over every file source.
    // Every test that calls the loader therefore holds `ENV_LOCK`. Otherwise
    // a variable set by an env test could leak into a concurrently running
    // file-based test and change its result.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering from poisoning so that one failed test
    /// does not make every later test that needs the lock fail as well.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets an environment variable and removes it again when dropped, so the
    /// variable is cleaned up even if an assertion panics mid-test.
    ///
    /// Only create one while holding the guard returned by `env_lock`, and
    /// declare it after that guard so it is dropped first.
    struct EnvVarGuard {
        key: &'static str,
    }

    // `set_var` / `remove_var` are `unsafe` in edition 2024.
    #[allow(unsafe_code)]
    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            // SAFETY: the caller holds ENV_LOCK, so no other test in this
            // module reads or writes `BITO_LINT_*` concurrently; remaining
            // environment reads are assumed to go through `std::env`, which
            // synchronizes access internally.
            unsafe {
                std::env::set_var(key, value);
            }
            Self { key }
        }
    }

    #[allow(unsafe_code)]
    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same invariant as `EnvVarGuard::set`; this guard is
            // dropped before the ENV_LOCK guard declared ahead of it.
            unsafe {
                std::env::remove_var(self.key);
            }
        }
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.log_level, LogLevel::Info);
        assert!(config.log_dir.is_none());
    }

    #[test]
    fn test_loader_builds_with_defaults() {
        let _lock = env_lock();
        let loader = ConfigLoader::new()
            .with_user_config(false)
            .without_boundary_marker();
        let config = loader.load().unwrap();
        assert_eq!(config.log_level, LogLevel::Info);
    }

    #[test]
    fn test_single_file_overrides_default() {
        let _lock = env_lock();
        let tmp = TempDir::new().unwrap();
        let config_path = tmp.path().join("config.toml");
        fs::write(
            &config_path,
            r#"log_level = "debug"
log_dir = "/tmp/bito-lint"
"#,
        )
        .unwrap();
        let config_path = Utf8PathBuf::try_from(config_path).unwrap();
        let config = ConfigLoader::new()
            .with_user_config(false)
            .with_file(&config_path)
            .load()
            .unwrap();
        assert_eq!(config.log_level, LogLevel::Debug);
        assert_eq!(
            config.log_dir.as_ref().map(|dir| dir.as_str()),
            Some("/tmp/bito-lint")
        );
    }

    #[test]
    fn test_later_file_overrides_earlier() {
        let _lock = env_lock();
        let tmp = TempDir::new().unwrap();
        let base_config = tmp.path().join("base.toml");
        fs::write(&base_config, r#"log_level = "warn""#).unwrap();
        let override_config = tmp.path().join("override.toml");
        fs::write(&override_config, r#"log_level = "error""#).unwrap();
        let base_config = Utf8PathBuf::try_from(base_config).unwrap();
        let override_config = Utf8PathBuf::try_from(override_config).unwrap();
        let config = ConfigLoader::new()
            .with_user_config(false)
            .with_file(&base_config)
            .with_file(&override_config)
            .load()
            .unwrap();
        assert_eq!(config.log_level, LogLevel::Error);
    }

    #[test]
    fn test_project_config_discovery() {
        let _lock = env_lock();
        let tmp = TempDir::new().unwrap();
        let project_dir = tmp.path().join("project");
        let sub_dir = project_dir.join("src").join("deep");
        fs::create_dir_all(&sub_dir).unwrap();
        let config_path = project_dir.join(".bito-lint.toml");
        fs::write(&config_path, r#"log_level = "debug""#).unwrap();
        let sub_dir = Utf8PathBuf::try_from(sub_dir).unwrap();
        let config = ConfigLoader::new()
            .with_user_config(false)
            .without_boundary_marker()
            .with_project_search(&sub_dir)
            .load()
            .unwrap();
        assert_eq!(config.log_level, LogLevel::Debug);
    }

    #[test]
    fn test_boundary_marker_stops_search() {
        let _lock = env_lock();
        let tmp = TempDir::new().unwrap();
        let parent = tmp.path().join("parent");
        let child = parent.join("child");
        let work = child.join("work");
        fs::create_dir_all(&work).unwrap();
        fs::write(parent.join(".bito-lint.toml"), r#"log_level = "warn""#).unwrap();
        fs::create_dir(child.join(".git")).unwrap();
        let work = Utf8PathBuf::try_from(work).unwrap();
        let config = ConfigLoader::new()
            .with_user_config(false)
            .with_boundary_marker(".git")
            .with_project_search(&work)
            .load()
            .unwrap();
        assert_eq!(config.log_level, LogLevel::Info);
    }

    #[test]
    fn test_explicit_file_overrides_project_config() {
        let _lock = env_lock();
        let tmp = TempDir::new().unwrap();
        let project_config = tmp.path().join(".bito-lint.toml");
        fs::write(&project_config, r#"log_level = "warn""#).unwrap();
        let override_config = tmp.path().join("override.toml");
        fs::write(&override_config, r#"log_level = "error""#).unwrap();
        let tmp_path = Utf8PathBuf::try_from(tmp.path().to_path_buf()).unwrap();
        let override_config = Utf8PathBuf::try_from(override_config).unwrap();
        let config = ConfigLoader::new()
            .with_user_config(false)
            .without_boundary_marker()
            .with_project_search(&tmp_path)
            .with_file(&override_config)
            .load()
            .unwrap();
        assert_eq!(config.log_level, LogLevel::Error);
    }

    #[test]
    fn test_load_or_error_fails_when_no_config() {
        let _lock = env_lock();
        let result = ConfigLoader::new()
            .with_user_config(false)
            .without_boundary_marker()
            .load_or_error();
        assert!(matches!(result, Err(ConfigError::NotFound)));
    }

    #[test]
    fn test_load_or_error_succeeds_with_explicit_file() {
        let _lock = env_lock();
        let tmp = TempDir::new().unwrap();
        let config_path = tmp.path().join("config.toml");
        fs::write(&config_path, r#"log_level = "debug""#).unwrap();
        let config_path = Utf8PathBuf::try_from(config_path).unwrap();
        let config = ConfigLoader::new()
            .with_user_config(false)
            .with_file(&config_path)
            .load_or_error()
            .unwrap();
        assert_eq!(config.log_level, LogLevel::Debug);
    }

    #[test]
    fn test_user_config_dir() {
        let dir = user_config_dir();
        if let Some(path) = dir {
            assert!(path.as_str().contains("bito-lint"));
        }
    }

    #[test]
    fn test_dialect_deserialization_toml() {
        let _lock = env_lock();
        let tmp = TempDir::new().unwrap();
        let config_path = tmp.path().join("config.toml");
        fs::write(&config_path, "dialect = \"en-gb\"\n").unwrap();
        let config_path = Utf8PathBuf::try_from(config_path).unwrap();
        let config = ConfigLoader::new()
            .with_user_config(false)
            .with_file(&config_path)
            .load()
            .unwrap();
        assert_eq!(config.dialect, Some(Dialect::EnGb));
    }

    #[test]
    fn test_dialect_deserialization_all_variants() {
        let _lock = env_lock();
        for (input, expected) in [
            ("en-us", Dialect::EnUs),
            ("en-gb", Dialect::EnGb),
            ("en-ca", Dialect::EnCa),
            ("en-au", Dialect::EnAu),
        ] {
            let tmp = TempDir::new().unwrap();
            let config_path = tmp.path().join("config.toml");
            fs::write(&config_path, format!("dialect = \"{input}\"\n")).unwrap();
            let config_path = Utf8PathBuf::try_from(config_path).unwrap();
            let config = ConfigLoader::new()
                .with_user_config(false)
                .with_file(&config_path)
                .load()
                .unwrap();
            assert_eq!(config.dialect, Some(expected), "failed for {input}");
        }
    }

    #[test]
    fn test_dialect_default_is_none() {
        let config = Config::default();
        assert!(config.dialect.is_none());
    }

    #[test]
    fn test_dialect_as_str() {
        assert_eq!(Dialect::EnUs.as_str(), "en-us");
        assert_eq!(Dialect::EnGb.as_str(), "en-gb");
        assert_eq!(Dialect::EnCa.as_str(), "en-ca");
        assert_eq!(Dialect::EnAu.as_str(), "en-au");
    }

    #[test]
    fn test_env_var_override_dialect() {
        let _lock = env_lock();
        // Declared after `_lock`, so it is dropped (and the var removed) first.
        let _dialect = EnvVarGuard::set("BITO_LINT_DIALECT", "en-ca");
        let config = ConfigLoader::new()
            .with_user_config(false)
            .without_boundary_marker()
            .load()
            .unwrap();
        assert_eq!(config.dialect, Some(Dialect::EnCa));
    }

    #[test]
    fn test_env_var_overrides_file_config() {
        let _lock = env_lock();
        let tmp = TempDir::new().unwrap();
        let config_path = tmp.path().join("config.toml");
        fs::write(&config_path, "dialect = \"en-us\"\n").unwrap();
        let config_path = Utf8PathBuf::try_from(config_path).unwrap();
        let _dialect = EnvVarGuard::set("BITO_LINT_DIALECT", "en-au");
        let config = ConfigLoader::new()
            .with_user_config(false)
            .with_file(&config_path)
            .load()
            .unwrap();
        assert_eq!(config.dialect, Some(Dialect::EnAu));
    }
}