use serde::Deserialize;
use std::fs;
use std::path::{Path, PathBuf};
/// Parsed contents of a `sqrust.toml` file.
///
/// Both the `[sqrust]` and `[rules]` tables are optional. A table that is absent from the
/// file takes its `Default` value. Unknown top-level keys make deserialization fail.
#[derive(Debug, Default, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Config {
    /// Settings from the `[sqrust]` table.
    pub sqrust: SqrustConfig,
    /// Settings from the `[rules]` table.
    pub rules: RulesConfig,
}
/// Settings read from the `[sqrust]` table.
///
/// Every key is optional. Missing keys take the struct's `Default` value (`None` or an
/// empty list), and unrecognised keys are rejected.
#[derive(Debug, Default, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct SqrustConfig {
    /// Raw value of the `dialect` key, if given; it is not validated here.
    pub dialect: Option<String>,
    /// Path patterns listed under `include`; empty when the key is absent.
    pub include: Vec<String>,
    /// Path patterns listed under `exclude` (e.g. `"dbt_packages/**"`); empty when absent.
    pub exclude: Vec<String>,
}
/// Settings read from the `[rules]` table.
///
/// Only `disable` is currently understood. Any other key (for example `select`) is
/// rejected when the file is parsed.
#[derive(Debug, Default, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct RulesConfig {
    /// Names of rules to switch off. `Config::rule_enabled` compares them by exact string
    /// equality.
    pub disable: Vec<String>,
}
impl Config {
    /// Loads the configuration that applies to `start`.
    ///
    /// `start` may name a file or a directory. The nearest `sqrust.toml` is located with
    /// `find_config`. When no such file exists, `Config::default()` is returned.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the located file cannot be read, or when its
    /// contents fail to deserialize. Failure cases include malformed TOML, mistyped values,
    /// and unknown keys (every config table uses `deny_unknown_fields`).
    pub fn load(start: &Path) -> Result<Self, String> {
        let path = match find_config(start) {
            Some(found) => found,
            None => return Ok(Config::default()),
        };
        let text = fs::read_to_string(&path)
            .map_err(|err| format!("Cannot read {}: {}", path.display(), err))?;
        toml::from_str(&text).map_err(|err| format!("Invalid sqrust.toml: {}", err))
    }

    /// Reports whether the rule called `name` should run.
    ///
    /// A rule is enabled unless its name appears in `[rules].disable`. The comparison is an
    /// exact, case-sensitive string comparison.
    pub fn rule_enabled(&self, name: &str) -> bool {
        self.rules
            .disable
            .iter()
            .all(|disabled| disabled.as_str() != name)
    }
}
/// Locates the nearest `sqrust.toml`, checking `start` first and then each of its ancestors.
///
/// If `start` is an existing file, the search begins in the file's parent directory.
/// Otherwise `start` itself is the first directory checked; it need not exist.
///
/// A relative `start` is first resolved against the current working directory. This lets
/// the search continue above the working directory, so `sqrust check .` run from a project
/// subdirectory still finds the project-root config. In that case the returned path is
/// absolute. If the working directory cannot be determined, the relative path is searched
/// as given.
///
/// Only regular files match. A directory named `sqrust.toml` is skipped rather than
/// returned.
///
/// Returns `None` when no directory on the way to the filesystem root holds the file.
fn find_config(start: &Path) -> Option<PathBuf> {
    // On a relative path, `Path::parent` bottoms out at "" (the cwd itself). Anchoring to
    // the cwd is what allows the walk to reach the real filesystem root.
    let start: PathBuf = if start.is_absolute() {
        start.to_path_buf()
    } else {
        std::env::current_dir()
            .map(|cwd| cwd.join(start))
            .unwrap_or_else(|_| start.to_path_buf())
    };
    let first_dir: &Path = if start.is_file() {
        start.parent()?
    } else {
        start.as_path()
    };
    first_dir
        .ancestors()
        .map(|dir| dir.join("sqrust.toml"))
        .find(|candidate| candidate.is_file())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `src` into a `Config`, failing the test if it does not parse.
    fn cfg_from(src: &str) -> Config {
        toml::from_str(src).expect("valid toml")
    }

    #[test]
    fn empty_config_is_default() {
        // An empty document must fall back to every table's default.
        let cfg = cfg_from("");
        assert!(cfg.sqrust.exclude.is_empty());
        assert!(cfg.rules.disable.is_empty());
    }

    #[test]
    fn disable_list_parsed() {
        let src = r#"
[rules]
disable = ["Convention/SelectStar", "Layout/LongLines"]
"#;
        let disabled = cfg_from(src).rules.disable;
        assert_eq!(disabled.len(), 2);
        assert!(disabled.iter().any(|rule| rule == "Convention/SelectStar"));
    }

    #[test]
    fn rule_enabled_respects_disable() {
        let src = r#"
[rules]
disable = ["Convention/SelectStar"]
"#;
        let cfg = cfg_from(src);
        let select_star_on = cfg.rule_enabled("Convention/SelectStar");
        let long_lines_on = cfg.rule_enabled("Layout/LongLines");
        assert!(!select_star_on);
        assert!(long_lines_on);
    }

    #[test]
    fn exclude_patterns_parsed() {
        let src = r#"
[sqrust]
exclude = ["dbt_packages/**", "target/**"]
"#;
        let excluded = cfg_from(src).sqrust.exclude;
        assert_eq!(excluded.len(), 2);
    }

    #[test]
    fn dialect_parsed() {
        let src = r#"
[sqrust]
dialect = "bigquery"
"#;
        let dialect = cfg_from(src).sqrust.dialect;
        assert_eq!(dialect, Some(String::from("bigquery")));
    }

    #[test]
    fn unknown_field_rejected() {
        // `select` is not a RulesConfig field, so `deny_unknown_fields` must reject it.
        let src = r#"
[rules]
select = ["Convention"]
"#;
        let outcome: Result<Config, _> = toml::from_str(src);
        assert!(outcome.is_err(), "select is not yet supported and should be rejected");
    }
}