Skip to main content

rigsql_config/
lib.rs

1use std::collections::HashMap;
2use std::fs;
3use std::path::{Path, PathBuf};
4
5use thiserror::Error;
6
/// Errors produced while loading configuration files.
#[derive(Debug, Error)]
pub enum ConfigError {
    /// A config file exists but could not be read into a string.
    ///
    /// thiserror treats the field named `source` as the underlying error
    /// (returned by `std::error::Error::source`).
    #[error("Failed to read config file {path}: {source}")]
    ReadError {
        /// Path of the file that could not be read.
        path: PathBuf,
        /// The I/O error reported by the read.
        source: std::io::Error,
    },
}
15
/// Parsed rigsql / sqlfluff configuration.
///
/// Every field is unset/empty by default. `Config::default()` is the
/// "nothing configured" starting point that config files are merged onto.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Config {
    /// SQL dialect name (e.g. "ansi", "tsql", "postgres").
    pub dialect: Option<String>,
    /// Maximum line length for LT05.
    pub max_line_length: Option<usize>,
    /// Rule codes to exclude (e.g. `"LT09"`), one code per element.
    ///
    /// In `.sqlfluff` these come from a comma-separated value; in
    /// `rigsql.toml` from an array of strings.
    pub exclude_rules: Vec<String>,
    /// Per-rule settings: rule name (e.g. `"capitalisation.keywords"`) -> key -> value.
    ///
    /// Values are kept as strings whatever their type in the source file.
    pub rules: HashMap<String, HashMap<String, String>>,
}
28
/// The kind of configuration file discovered in a directory.
///
/// Selects which parser handles the file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ConfigKind {
    /// A native `rigsql.toml` file (TOML syntax).
    RigsqlToml,
    /// A sqlfluff-compatible `.sqlfluff` file (INI syntax).
    Sqlfluff,
}
35
36impl Config {
37    /// Load config by searching upward from the given file/directory path.
38    ///
39    /// Priority: `rigsql.toml` > `.sqlfluff`.
40    /// At each directory level, if `rigsql.toml` exists it is used; otherwise `.sqlfluff`.
41    /// Files are merged bottom-up (closest file wins).
42    pub fn load_for_path(path: &Path) -> Self {
43        let search_dir = if path.is_file() {
44            path.parent().unwrap_or(path)
45        } else {
46            path
47        };
48
49        let mut config_files: Vec<(PathBuf, ConfigKind)> = Vec::new();
50        let mut dir = Some(search_dir);
51        while let Some(d) = dir {
52            if let Some(found) = find_config_in_dir(d) {
53                config_files.push(found);
54            }
55            dir = d.parent();
56        }
57
58        // Also check home directory (if not already found via traversal)
59        if let Some(home) = dirs_home() {
60            if !config_files.iter().any(|(p, _)| p.parent() == Some(&home)) {
61                if let Some(found) = find_config_in_dir(&home) {
62                    config_files.push(found);
63                }
64            }
65        }
66
67        // Reverse so that furthest (most general) is first, closest (most specific) last
68        config_files.reverse();
69
70        let mut config = Config::default();
71        for (path, kind) in &config_files {
72            let parsed = match kind {
73                ConfigKind::RigsqlToml => parse_rigsql_toml(path),
74                ConfigKind::Sqlfluff => parse_sqlfluff_file(path),
75            };
76            if let Ok(file_config) = parsed {
77                config.merge(file_config);
78            }
79        }
80
81        config
82    }
83
84    /// Merge another config into this one. `other` takes precedence.
85    fn merge(&mut self, other: Config) {
86        if other.dialect.is_some() {
87            self.dialect = other.dialect;
88        }
89        if other.max_line_length.is_some() {
90            self.max_line_length = other.max_line_length;
91        }
92        if !other.exclude_rules.is_empty() {
93            self.exclude_rules = other.exclude_rules;
94        }
95        for (rule_name, settings) in other.rules {
96            let entry = self.rules.entry(rule_name).or_default();
97            for (k, v) in settings {
98                entry.insert(k, v);
99            }
100        }
101    }
102
103    /// Get a rule-specific setting by rule name (e.g. "capitalisation.keywords") and key.
104    pub fn rule_setting(&self, rule_name: &str, key: &str) -> Option<&str> {
105        self.rules
106            .get(rule_name)
107            .and_then(|m| m.get(key))
108            .map(|s| s.as_str())
109    }
110}
111
112/// Check for rigsql.toml or .sqlfluff in a directory (rigsql.toml takes priority).
113fn find_config_in_dir(dir: &Path) -> Option<(PathBuf, ConfigKind)> {
114    let toml_path = dir.join("rigsql.toml");
115    if toml_path.is_file() {
116        return Some((toml_path, ConfigKind::RigsqlToml));
117    }
118    let sqlfluff_path = dir.join(".sqlfluff");
119    if sqlfluff_path.is_file() {
120        return Some((sqlfluff_path, ConfigKind::Sqlfluff));
121    }
122    None
123}
124
125/// Read a config file's content, mapping IO errors to ConfigError.
126fn read_config_file(path: &Path) -> Result<String, ConfigError> {
127    fs::read_to_string(path).map_err(|e| ConfigError::ReadError {
128        path: path.to_path_buf(),
129        source: e,
130    })
131}
132
133// ── rigsql.toml parser ──────────────────────────────────────────────────
134
135/// Parse a `rigsql.toml` configuration file.
136///
137/// Expected format:
138/// ```toml
139/// [core]
140/// dialect = "tsql"
141/// max_line_length = 120
142/// exclude_rules = ["LT09", "CV06"]
143///
144/// [rules."capitalisation.keywords"]
145/// capitalisation_policy = "lower"
146/// ```
147fn parse_rigsql_toml(path: &Path) -> Result<Config, ConfigError> {
148    let content = read_config_file(path)?;
149
150    let table: toml::Table = match content.parse() {
151        Ok(t) => t,
152        Err(e) => {
153            eprintln!("Warning: failed to parse {}: {e}", path.display());
154            return Ok(Config::default());
155        }
156    };
157
158    let mut config = Config::default();
159
160    // [core] section
161    if let Some(core) = table.get("core").and_then(|v| v.as_table()) {
162        if let Some(dialect) = core.get("dialect").and_then(|v| v.as_str()) {
163            config.dialect = Some(dialect.to_string());
164        }
165        if let Some(len) = core.get("max_line_length").and_then(|v| v.as_integer()) {
166            config.max_line_length = Some(len as usize);
167        }
168        if let Some(arr) = core.get("exclude_rules").and_then(|v| v.as_array()) {
169            config.exclude_rules = arr
170                .iter()
171                .filter_map(|v| v.as_str())
172                .map(|s| s.to_string())
173                .collect();
174        }
175    }
176
177    // [rules.*] sections
178    if let Some(rules) = table.get("rules").and_then(|v| v.as_table()) {
179        for (rule_name, rule_value) in rules {
180            if let Some(rule_table) = rule_value.as_table() {
181                let mut settings = HashMap::new();
182                for (k, v) in rule_table {
183                    let val = match v {
184                        toml::Value::String(s) => s.clone(),
185                        toml::Value::Integer(i) => i.to_string(),
186                        toml::Value::Float(f) => f.to_string(),
187                        toml::Value::Boolean(b) => b.to_string(),
188                        _ => continue,
189                    };
190                    settings.insert(k.clone(), val);
191                }
192                if !settings.is_empty() {
193                    config.rules.insert(rule_name.clone(), settings);
194                }
195            }
196        }
197    }
198
199    Ok(config)
200}
201
202// ── .sqlfluff INI parser ────────────────────────────────────────────────
203
204/// Parse a .sqlfluff INI-style config file.
205fn parse_sqlfluff_file(path: &Path) -> Result<Config, ConfigError> {
206    let content = read_config_file(path)?;
207
208    let mut config = Config::default();
209    let mut current_section = String::new();
210
211    for line in content.lines() {
212        let line = line.trim();
213
214        // Skip empty lines and comments
215        if line.is_empty() || line.starts_with('#') || line.starts_with(';') {
216            continue;
217        }
218
219        // Section header
220        if line.starts_with('[') && line.ends_with(']') {
221            current_section = line[1..line.len() - 1].trim().to_string();
222            continue;
223        }
224
225        // Key = value
226        if let Some((key, value)) = line.split_once('=') {
227            let key = key.trim().to_lowercase();
228            let value = value.trim().to_string();
229
230            match current_section.as_str() {
231                "sqlfluff" => match key.as_str() {
232                    "dialect" => config.dialect = Some(value),
233                    "max_line_length" => {
234                        config.max_line_length = value.parse().ok();
235                    }
236                    "exclude_rules" => {
237                        config.exclude_rules = value
238                            .split(',')
239                            .map(|s| s.trim().to_string())
240                            .filter(|s| !s.is_empty())
241                            .collect();
242                    }
243                    _ => {}
244                },
245                section if section.starts_with("sqlfluff:rules:") => {
246                    let rule_name = section.strip_prefix("sqlfluff:rules:").unwrap();
247                    config
248                        .rules
249                        .entry(rule_name.to_string())
250                        .or_default()
251                        .insert(key, value);
252                }
253                _ => {}
254            }
255        }
256    }
257
258    Ok(config)
259}
260
/// Locate the user's home directory from the environment.
///
/// `HOME` is checked first, then `USERPROFILE` (the conventional variable on
/// Windows, where `HOME` is usually unset). Empty values count as unset: an
/// empty path would otherwise make the "home" lookup resolve config files
/// relative to the working directory.
fn dirs_home() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(std::env::var_os)
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
}
264
265/// Filter out violations on lines that have `-- noqa` comments.
266pub fn filter_noqa(source: &str, violations: &mut Vec<rigsql_rules::LintViolation>) {
267    if violations.is_empty() {
268        return;
269    }
270
271    // Build a map of line_number -> noqa spec
272    let noqa_lines: HashMap<usize, NoqaSpec> = source
273        .lines()
274        .enumerate()
275        .filter_map(|(i, line)| parse_noqa_comment(line).map(|spec| (i + 1, spec)))
276        .collect();
277
278    if noqa_lines.is_empty() {
279        return;
280    }
281
282    violations.retain(|v| {
283        let (line, _) = v.line_col(source);
284        match noqa_lines.get(&line) {
285            None => true,
286            Some(NoqaSpec::All) => false,
287            Some(NoqaSpec::Rules(codes)) => !codes.iter().any(|c| c == v.rule_code),
288        }
289    });
290}
291
/// What a `noqa` comment suppresses on its line.
#[derive(Debug)]
enum NoqaSpec {
    /// `-- noqa` — suppress all rules on this line.
    All,
    /// `-- noqa: CP01,LT01` — suppress specific rules (codes stored upper-cased).
    Rules(Vec<String>),
}

/// Parse a noqa comment from a source line.
///
/// Looks for a `--` comment whose text begins with `noqa`
/// (ASCII case-insensitive), after any extra dashes and whitespace.
///
/// - Accepted: `-- noqa`, `--noqa`, `---  NOQA`.
/// - Rejected: `-- noqable`, because the keyword must end at a word boundary.
///
/// Returns:
/// - `None` if the line has no noqa comment.
/// - `Some(NoqaSpec::Rules(codes))` for `noqa: CODE[, CODE…]`; codes are
///   trimmed and upper-cased.
/// - `Some(NoqaSpec::All)` otherwise: bare `noqa`, `noqa -- reason`, an empty
///   code list, or any other trailing text.
///
/// This is a plain textual scan, so a `-- noqa` inside a string literal also matches.
fn parse_noqa_comment(line: &str) -> Option<NoqaSpec> {
    let mut rest = line;
    while let Some(pos) = rest.find("--") {
        // Skip the whole dash run (`--`, `---`, …) and any whitespace after it.
        let comment = rest[pos..].trim_start_matches('-').trim_start();
        if let Some(tail) = strip_noqa_keyword(comment) {
            return Some(noqa_spec_from_tail(tail.trim_start()));
        }
        // `pos + 2` lies just past two ASCII dashes, so it is a char boundary.
        rest = &rest[pos + 2..];
    }
    None
}

/// Return the text following a leading, ASCII case-insensitive `noqa`.
///
/// Returns `None` when `text` does not start with the keyword, or when the
/// keyword runs on into a longer word (e.g. `noqable`).
fn strip_noqa_keyword(text: &str) -> Option<&str> {
    const KEYWORD: &str = "noqa";
    let head = text.get(..KEYWORD.len())?;
    if !head.eq_ignore_ascii_case(KEYWORD) {
        return None;
    }
    let tail = &text[KEYWORD.len()..];
    if tail.starts_with(|c: char| c.is_alphanumeric() || c == '_') {
        None
    } else {
        Some(tail)
    }
}

/// Interpret the text after `noqa`: `: A, b` lists rules; anything else means all rules.
fn noqa_spec_from_tail(tail: &str) -> NoqaSpec {
    match tail.strip_prefix(':') {
        Some(list) => {
            let codes: Vec<String> = list
                .split(',')
                .map(|code| code.trim().to_uppercase())
                .filter(|code| !code.is_empty())
                .collect();
            if codes.is_empty() {
                NoqaSpec::All
            } else {
                NoqaSpec::Rules(codes)
            }
        }
        None => NoqaSpec::All,
    }
}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory under the system temp dir that is removed on drop.
    ///
    /// Cleanup therefore happens even when an assertion panics. The process
    /// id in the name keeps concurrent test runs from colliding.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(name: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
            // Clear leftovers from an earlier run that was killed before cleanup.
            let _ = fs::remove_dir_all(&dir);
            fs::create_dir_all(&dir).expect("failed to create scratch directory");
            ScratchDir(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_parse_noqa_all() {
        assert!(matches!(
            parse_noqa_comment("SELECT 1 -- noqa"),
            Some(NoqaSpec::All)
        ));
    }

    #[test]
    fn test_parse_noqa_specific() {
        match parse_noqa_comment("SELECT 1 -- noqa: CP01, LT01") {
            Some(NoqaSpec::Rules(codes)) => {
                assert_eq!(codes, vec!["CP01", "LT01"]);
            }
            _ => panic!("Expected NoqaSpec::Rules"),
        }
    }

    #[test]
    fn test_parse_noqa_none() {
        assert!(parse_noqa_comment("SELECT 1").is_none());
    }

    #[test]
    fn test_parse_sqlfluff_config() {
        let content = "\
[sqlfluff]
dialect = tsql
max_line_length = 120

[sqlfluff:rules:capitalisation.keywords]
capitalisation_policy = lower
";
        let dir = ScratchDir::new("rigsql_test_sqlfluff_config");
        let path = dir.path().join(".sqlfluff");
        fs::write(&path, content).unwrap();

        let config = parse_sqlfluff_file(&path).unwrap();
        assert_eq!(config.dialect.as_deref(), Some("tsql"));
        assert_eq!(config.max_line_length, Some(120));
        assert_eq!(
            config.rule_setting("capitalisation.keywords", "capitalisation_policy"),
            Some("lower")
        );
    }

    #[test]
    fn test_parse_rigsql_toml() {
        let content = r#"
[core]
dialect = "tsql"
max_line_length = 120
exclude_rules = ["LT09", "CV06"]

[rules."capitalisation.keywords"]
capitalisation_policy = "lower"
"#;
        let dir = ScratchDir::new("rigsql_test_toml_config");
        let path = dir.path().join("rigsql.toml");
        fs::write(&path, content).unwrap();

        let config = parse_rigsql_toml(&path).unwrap();
        assert_eq!(config.dialect.as_deref(), Some("tsql"));
        assert_eq!(config.max_line_length, Some(120));
        assert_eq!(config.exclude_rules, vec!["LT09", "CV06"]);
        assert_eq!(
            config.rule_setting("capitalisation.keywords", "capitalisation_policy"),
            Some("lower")
        );
    }

    #[test]
    fn test_rigsql_toml_priority_over_sqlfluff() {
        let dir = ScratchDir::new("rigsql_test_priority");

        // Write both config files
        fs::write(
            dir.path().join(".sqlfluff"),
            "[sqlfluff]\ndialect = postgres\nmax_line_length = 80\n",
        )
        .unwrap();
        fs::write(
            dir.path().join("rigsql.toml"),
            "[core]\ndialect = \"tsql\"\nmax_line_length = 120\n",
        )
        .unwrap();

        let config = Config::load_for_path(dir.path());
        // rigsql.toml should win
        assert_eq!(config.dialect.as_deref(), Some("tsql"));
        assert_eq!(config.max_line_length, Some(120));
    }
}