Skip to main content

rigsql_config/
lib.rs

1use std::collections::HashMap;
2use std::fs;
3use std::path::{Path, PathBuf};
4
5use thiserror::Error;
6
7#[derive(Debug, Error)]
8pub enum ConfigError {
9    #[error("Failed to read config file {path}: {source}")]
10    ReadError {
11        path: PathBuf,
12        source: std::io::Error,
13    },
14}
15
/// Effective rigsql configuration, assembled from `rigsql.toml` and/or `.sqlfluff` files.
///
/// Every field starts out unset: `None` or an empty collection means
/// "not configured by any file".
#[derive(Default, Debug, Clone)]
pub struct Config {
    /// SQL dialect name, e.g. `"ansi"`, `"tsql"` or `"postgres"`.
    pub dialect: Option<String>,
    /// Locale used for output messages, e.g. `"en"` or `"ja"`.
    pub locale: Option<String>,
    /// Maximum allowed line length (consumed by rule LT05).
    pub max_line_length: Option<usize>,
    /// Rule codes to skip, one code per element. Config files may spell this
    /// as a comma-separated list; it is stored here already split.
    pub exclude_rules: Vec<String>,
    /// Per-rule settings: rule name -> setting name -> raw string value.
    pub rules: HashMap<String, HashMap<String, String>>,
}
30
/// Kind of configuration file discovered in a directory.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConfigKind {
    /// A native `rigsql.toml` file (TOML syntax).
    RigsqlToml,
    /// A sqlfluff-compatible `.sqlfluff` file (INI syntax).
    Sqlfluff,
}
37
38impl Config {
39    /// Load config by searching upward from the given file/directory path.
40    ///
41    /// Priority: `rigsql.toml` > `.sqlfluff`.
42    /// At each directory level, if `rigsql.toml` exists it is used; otherwise `.sqlfluff`.
43    /// Files are merged bottom-up (closest file wins).
44    pub fn load_for_path(path: &Path) -> Self {
45        let search_dir = if path.is_file() {
46            path.parent().unwrap_or(path)
47        } else {
48            path
49        };
50
51        let mut config_files: Vec<(PathBuf, ConfigKind)> = Vec::new();
52        let mut dir = Some(search_dir);
53        while let Some(d) = dir {
54            if let Some(found) = find_config_in_dir(d) {
55                config_files.push(found);
56            }
57            dir = d.parent();
58        }
59
60        // Also check home directory (if not already found via traversal)
61        if let Some(home) = dirs_home() {
62            if !config_files.iter().any(|(p, _)| p.parent() == Some(&home)) {
63                if let Some(found) = find_config_in_dir(&home) {
64                    config_files.push(found);
65                }
66            }
67        }
68
69        // Reverse so that furthest (most general) is first, closest (most specific) last
70        config_files.reverse();
71
72        let mut config = Config::default();
73        for (path, kind) in &config_files {
74            let parsed = match kind {
75                ConfigKind::RigsqlToml => parse_rigsql_toml(path),
76                ConfigKind::Sqlfluff => parse_sqlfluff_file(path),
77            };
78            if let Ok(file_config) = parsed {
79                config.merge(file_config);
80            }
81        }
82
83        config
84    }
85
86    /// Merge another config into this one. `other` takes precedence.
87    fn merge(&mut self, other: Config) {
88        if other.dialect.is_some() {
89            self.dialect = other.dialect;
90        }
91        if other.locale.is_some() {
92            self.locale = other.locale;
93        }
94        if other.max_line_length.is_some() {
95            self.max_line_length = other.max_line_length;
96        }
97        if !other.exclude_rules.is_empty() {
98            self.exclude_rules = other.exclude_rules;
99        }
100        for (rule_name, settings) in other.rules {
101            let entry = self.rules.entry(rule_name).or_default();
102            for (k, v) in settings {
103                entry.insert(k, v);
104            }
105        }
106    }
107
108    /// Get a rule-specific setting by rule name (e.g. "capitalisation.keywords") and key.
109    pub fn rule_setting(&self, rule_name: &str, key: &str) -> Option<&str> {
110        self.rules
111            .get(rule_name)
112            .and_then(|m| m.get(key))
113            .map(|s| s.as_str())
114    }
115}
116
117/// Check for rigsql.toml or .sqlfluff in a directory (rigsql.toml takes priority).
118fn find_config_in_dir(dir: &Path) -> Option<(PathBuf, ConfigKind)> {
119    let toml_path = dir.join("rigsql.toml");
120    if toml_path.is_file() {
121        return Some((toml_path, ConfigKind::RigsqlToml));
122    }
123    let sqlfluff_path = dir.join(".sqlfluff");
124    if sqlfluff_path.is_file() {
125        return Some((sqlfluff_path, ConfigKind::Sqlfluff));
126    }
127    None
128}
129
130/// Read a config file's content, mapping IO errors to ConfigError.
131fn read_config_file(path: &Path) -> Result<String, ConfigError> {
132    fs::read_to_string(path).map_err(|e| ConfigError::ReadError {
133        path: path.to_path_buf(),
134        source: e,
135    })
136}
137
138// ── rigsql.toml parser ──────────────────────────────────────────────────
139
140/// Parse a `rigsql.toml` configuration file.
141///
142/// Expected format:
143/// ```toml
144/// [core]
145/// dialect = "tsql"
146/// max_line_length = 120
147/// exclude_rules = ["LT09", "CV06"]
148///
149/// [rules."capitalisation.keywords"]
150/// capitalisation_policy = "lower"
151/// ```
152fn parse_rigsql_toml(path: &Path) -> Result<Config, ConfigError> {
153    let content = read_config_file(path)?;
154
155    let table: toml::Table = match content.parse() {
156        Ok(t) => t,
157        Err(e) => {
158            eprintln!("Warning: failed to parse {}: {e}", path.display());
159            return Ok(Config::default());
160        }
161    };
162
163    let mut config = Config::default();
164
165    // [core] section
166    if let Some(core) = table.get("core").and_then(|v| v.as_table()) {
167        if let Some(dialect) = core.get("dialect").and_then(|v| v.as_str()) {
168            config.dialect = Some(dialect.to_string());
169        }
170        if let Some(locale) = core.get("locale").and_then(|v| v.as_str()) {
171            config.locale = Some(locale.to_string());
172        }
173        if let Some(len) = core.get("max_line_length").and_then(|v| v.as_integer()) {
174            config.max_line_length = Some(len as usize);
175        }
176        if let Some(arr) = core.get("exclude_rules").and_then(|v| v.as_array()) {
177            config.exclude_rules = arr
178                .iter()
179                .filter_map(|v| v.as_str())
180                .map(|s| s.to_string())
181                .collect();
182        }
183    }
184
185    // [rules.*] sections
186    if let Some(rules) = table.get("rules").and_then(|v| v.as_table()) {
187        for (rule_name, rule_value) in rules {
188            if let Some(rule_table) = rule_value.as_table() {
189                let mut settings = HashMap::new();
190                for (k, v) in rule_table {
191                    let val = match v {
192                        toml::Value::String(s) => s.clone(),
193                        toml::Value::Integer(i) => i.to_string(),
194                        toml::Value::Float(f) => f.to_string(),
195                        toml::Value::Boolean(b) => b.to_string(),
196                        _ => continue,
197                    };
198                    settings.insert(k.clone(), val);
199                }
200                if !settings.is_empty() {
201                    config.rules.insert(rule_name.clone(), settings);
202                }
203            }
204        }
205    }
206
207    Ok(config)
208}
209
210// ── .sqlfluff INI parser ────────────────────────────────────────────────
211
212/// Parse a .sqlfluff INI-style config file.
213fn parse_sqlfluff_file(path: &Path) -> Result<Config, ConfigError> {
214    let content = read_config_file(path)?;
215
216    let mut config = Config::default();
217    let mut current_section = String::new();
218
219    for line in content.lines() {
220        let line = line.trim();
221
222        // Skip empty lines and comments
223        if line.is_empty() || line.starts_with('#') || line.starts_with(';') {
224            continue;
225        }
226
227        // Section header
228        if line.starts_with('[') && line.ends_with(']') {
229            current_section = line[1..line.len() - 1].trim().to_string();
230            continue;
231        }
232
233        // Key = value
234        if let Some((key, value)) = line.split_once('=') {
235            let key = key.trim().to_lowercase();
236            let value = value.trim().to_string();
237
238            match current_section.as_str() {
239                "sqlfluff" => match key.as_str() {
240                    "dialect" => config.dialect = Some(value),
241                    "locale" => config.locale = Some(value),
242                    "max_line_length" => {
243                        config.max_line_length = value.parse().ok();
244                    }
245                    "exclude_rules" => {
246                        config.exclude_rules = value
247                            .split(',')
248                            .map(|s| s.trim().to_string())
249                            .filter(|s| !s.is_empty())
250                            .collect();
251                    }
252                    _ => {}
253                },
254                section if section.starts_with("sqlfluff:rules:") => {
255                    let rule_name = section.strip_prefix("sqlfluff:rules:").unwrap();
256                    config
257                        .rules
258                        .entry(rule_name.to_string())
259                        .or_default()
260                        .insert(key, value);
261                }
262                _ => {}
263            }
264        }
265    }
266
267    Ok(config)
268}
269
/// Locate the current user's home directory from the environment.
///
/// Checks `HOME` first (Unix, and Windows shells such as MSYS/Git Bash), then
/// `USERPROFILE` (the native Windows variable). Unset or empty variables are
/// skipped, so an empty `HOME` is not mistaken for the current directory.
fn dirs_home() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|name| std::env::var_os(name))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
}
273
274/// Filter out violations on lines that have `-- noqa` comments.
275pub fn filter_noqa(source: &str, violations: &mut Vec<rigsql_rules::LintViolation>) {
276    if violations.is_empty() {
277        return;
278    }
279
280    // Build a map of line_number -> noqa spec
281    let noqa_lines: HashMap<usize, NoqaSpec> = source
282        .lines()
283        .enumerate()
284        .filter_map(|(i, line)| parse_noqa_comment(line).map(|spec| (i + 1, spec)))
285        .collect();
286
287    if noqa_lines.is_empty() {
288        return;
289    }
290
291    violations.retain(|v| {
292        let (line, _) = v.line_col(source);
293        match noqa_lines.get(&line) {
294            None => true,
295            Some(NoqaSpec::All) => false,
296            Some(NoqaSpec::Rules(codes)) => !codes.iter().any(|c| c == v.rule_code),
297        }
298    });
299}
300
/// How a `-- noqa` comment restricts linting on its line.
#[derive(Debug)]
enum NoqaSpec {
    /// `-- noqa` — suppress all rules on this line.
    All,
    /// `-- noqa: CP01,LT01` — suppress only the listed rule codes (stored upper-cased).
    Rules(Vec<String>),
}

/// Parse a noqa directive from a single source line.
///
/// A directive is a `--` comment marker, optional whitespace, then the word `noqa`
/// (ASCII case-insensitive), e.g. `-- noqa`, `--noqa`, `-- NOQA`. `noqa` must be a
/// whole word: `-- noqable` is not a directive.
///
/// Returns:
/// - `Some(NoqaSpec::Rules(codes))` for `-- noqa: CP01, lt01`, with codes trimmed and
///   upper-cased.
/// - `Some(NoqaSpec::All)` for any of these forms:
///   - a bare `-- noqa`;
///   - one followed by another `--` comment;
///   - an empty code list (`-- noqa:`);
///   - any other trailing text.
/// - `None` when the line contains no directive.
///
/// The scan is purely textual, so a `-- noqa` inside a string literal also counts.
fn parse_noqa_comment(line: &str) -> Option<NoqaSpec> {
    let bytes = line.as_bytes();
    // Try every `--` position (overlapping, so `---noqa` is found) until one introduces noqa.
    (0..bytes.len().saturating_sub(1))
        .filter(|&i| bytes[i] == b'-' && bytes[i + 1] == b'-')
        // `i + 2` is a char boundary: the two preceding bytes are ASCII '-'.
        .find_map(|i| parse_noqa_after_marker(&line[i + 2..]))
}

/// Parse the text following a `--` marker; `None` unless it starts with the word `noqa`.
fn parse_noqa_after_marker(comment: &str) -> Option<NoqaSpec> {
    const KEYWORD: &[u8] = b"noqa";

    let comment = comment.trim_start();
    let head = comment.as_bytes().get(..KEYWORD.len())?;
    if !head.eq_ignore_ascii_case(KEYWORD) {
        return None;
    }
    // Slicing is safe: the bytes just matched are ASCII, so this is a char boundary.
    let tail = &comment[KEYWORD.len()..];
    // Require a word boundary so `-- noqable` is not treated as a directive.
    if matches!(tail.chars().next(), Some(c) if c.is_ascii_alphanumeric() || c == '_') {
        return None;
    }

    let after = tail.trim_start();
    if after.is_empty() || after.starts_with("--") {
        return Some(NoqaSpec::All);
    }

    match after.strip_prefix(':') {
        Some(list) => {
            let codes: Vec<String> = list
                .split(',')
                .map(|s| s.trim().to_uppercase())
                .filter(|s| !s.is_empty())
                .collect();
            if codes.is_empty() {
                Some(NoqaSpec::All)
            } else {
                Some(NoqaSpec::Rules(codes))
            }
        }
        // Lenient: unrecognised trailing text still suppresses everything.
        None => Some(NoqaSpec::All),
    }
}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// A scratch directory unique to one test and one process, removed on drop.
    ///
    /// The process id keeps concurrent `cargo test` runs from sharing a directory,
    /// and `Drop` cleans up even when an assertion fails mid-test.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(name: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
            // Start clean in case a previous run was killed before cleanup.
            let _ = fs::remove_dir_all(&dir);
            fs::create_dir_all(&dir).expect("create scratch dir");
            ScratchDir(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }

        /// Write `content` to `file_name` inside the scratch directory and return its path.
        fn write(&self, file_name: &str, content: &str) -> PathBuf {
            let path = self.0.join(file_name);
            fs::write(&path, content).expect("write scratch file");
            path
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_parse_noqa_all() {
        assert!(matches!(
            parse_noqa_comment("SELECT 1 -- noqa"),
            Some(NoqaSpec::All)
        ));
    }

    #[test]
    fn test_parse_noqa_specific() {
        match parse_noqa_comment("SELECT 1 -- noqa: CP01, LT01") {
            Some(NoqaSpec::Rules(codes)) => {
                assert_eq!(codes, vec!["CP01", "LT01"]);
            }
            _ => panic!("Expected NoqaSpec::Rules"),
        }
    }

    #[test]
    fn test_parse_noqa_none() {
        assert!(parse_noqa_comment("SELECT 1").is_none());
    }

    #[test]
    fn test_parse_noqa_variants() {
        assert!(matches!(
            parse_noqa_comment("SELECT 1 -- NOQA"),
            Some(NoqaSpec::All)
        ));
        assert!(matches!(
            parse_noqa_comment("SELECT 1 -- noqa:"),
            Some(NoqaSpec::All)
        ));
        assert!(matches!(
            parse_noqa_comment("SELECT 1 -- noqa -- reason"),
            Some(NoqaSpec::All)
        ));
        match parse_noqa_comment("SELECT 1 -- noqa: cp01") {
            Some(NoqaSpec::Rules(codes)) => assert_eq!(codes, vec!["CP01"]),
            other => panic!("Expected NoqaSpec::Rules, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_sqlfluff_config() {
        let content = "\
[sqlfluff]
dialect = tsql
max_line_length = 120

[sqlfluff:rules:capitalisation.keywords]
capitalisation_policy = lower
";
        let scratch = ScratchDir::new("rigsql_test_sqlfluff_config");
        let path = scratch.write(".sqlfluff", content);

        let config = parse_sqlfluff_file(&path).unwrap();
        assert_eq!(config.dialect.as_deref(), Some("tsql"));
        assert_eq!(config.max_line_length, Some(120));
        assert_eq!(
            config.rule_setting("capitalisation.keywords", "capitalisation_policy"),
            Some("lower")
        );
    }

    #[test]
    fn test_parse_sqlfluff_exclude_rules_and_comments() {
        let scratch = ScratchDir::new("rigsql_test_sqlfluff_exclude");
        let path = scratch.write(
            ".sqlfluff",
            "# comment\n; another\n[sqlfluff]\nexclude_rules = LT01, CV06,\n\n[other]\ndialect = ansi\n",
        );

        let config = parse_sqlfluff_file(&path).unwrap();
        assert_eq!(config.exclude_rules, vec!["LT01", "CV06"]);
        // Keys outside [sqlfluff] must not leak into the core settings.
        assert_eq!(config.dialect, None);
    }

    #[test]
    fn test_parse_rigsql_toml() {
        let content = r#"
[core]
dialect = "tsql"
max_line_length = 120
exclude_rules = ["LT09", "CV06"]

[rules."capitalisation.keywords"]
capitalisation_policy = "lower"
"#;
        let scratch = ScratchDir::new("rigsql_test_toml_config");
        let path = scratch.write("rigsql.toml", content);

        let config = parse_rigsql_toml(&path).unwrap();
        assert_eq!(config.dialect.as_deref(), Some("tsql"));
        assert_eq!(config.max_line_length, Some(120));
        assert_eq!(config.exclude_rules, vec!["LT09", "CV06"]);
        assert_eq!(
            config.rule_setting("capitalisation.keywords", "capitalisation_policy"),
            Some("lower")
        );
    }

    #[test]
    fn test_rigsql_toml_priority_over_sqlfluff() {
        let scratch = ScratchDir::new("rigsql_test_priority");

        // Write both config files
        scratch.write(
            ".sqlfluff",
            "[sqlfluff]\ndialect = postgres\nmax_line_length = 80\n",
        );
        scratch.write(
            "rigsql.toml",
            "[core]\ndialect = \"tsql\"\nmax_line_length = 120\n",
        );

        let config = Config::load_for_path(scratch.path());
        // rigsql.toml should win
        assert_eq!(config.dialect.as_deref(), Some("tsql"));
        assert_eq!(config.max_line_length, Some(120));
    }

    #[test]
    fn test_closer_config_overrides_parent() {
        let scratch = ScratchDir::new("rigsql_test_nested");
        scratch.write(
            ".sqlfluff",
            "[sqlfluff]\ndialect = postgres\nmax_line_length = 80\n",
        );
        let child = scratch.path().join("child");
        fs::create_dir_all(&child).unwrap();
        fs::write(child.join("rigsql.toml"), "[core]\nmax_line_length = 120\n").unwrap();
        let sql_file = child.join("query.sql");
        fs::write(&sql_file, "SELECT 1\n").unwrap();

        for start in [child.clone(), sql_file] {
            let config = Config::load_for_path(&start);
            // The parent still supplies the dialect; the closer file overrides the length.
            assert_eq!(config.dialect.as_deref(), Some("postgres"));
            assert_eq!(config.max_line_length, Some(120));
        }
    }

    #[test]
    fn test_merge_prefers_other_and_merges_rule_settings() {
        let mut base = Config {
            dialect: Some("postgres".to_string()),
            max_line_length: Some(80),
            exclude_rules: vec!["LT01".to_string()],
            ..Default::default()
        };
        base.rules
            .entry("r".to_string())
            .or_default()
            .insert("a".to_string(), "1".to_string());

        let mut over = Config {
            max_line_length: Some(120),
            ..Default::default()
        };
        over.rules
            .entry("r".to_string())
            .or_default()
            .insert("b".to_string(), "2".to_string());

        base.merge(over);
        assert_eq!(base.dialect.as_deref(), Some("postgres"));
        assert_eq!(base.max_line_length, Some(120));
        assert_eq!(base.exclude_rules, vec!["LT01"]);
        assert_eq!(base.rule_setting("r", "a"), Some("1"));
        assert_eq!(base.rule_setting("r", "b"), Some("2"));
        assert_eq!(base.rule_setting("r", "c"), None);
        assert_eq!(base.rule_setting("missing", "a"), None);
    }
}