1use std::collections::HashMap;
2use std::fs;
3use std::path::{Path, PathBuf};
4
5use thiserror::Error;
6
/// Errors produced while loading configuration files.
///
/// Parse failures are deliberately NOT represented here: both parsers warn
/// and fall back to an empty config instead of failing (see
/// `parse_rigsql_toml`), so only I/O problems surface as errors.
#[derive(Debug, Error)]
pub enum ConfigError {
    /// The config file exists (or was expected) but could not be read.
    #[error("Failed to read config file {path}: {source}")]
    ReadError {
        path: PathBuf,
        source: std::io::Error,
    },
}
15
16#[derive(Debug, Clone, Default)]
18pub struct Config {
19 pub dialect: Option<String>,
21 pub max_line_length: Option<usize>,
23 pub exclude_rules: Vec<String>,
25 pub rules: HashMap<String, HashMap<String, String>>,
27}
28
/// Which on-disk format a discovered config file uses; selects the parser
/// in `Config::load_for_path`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConfigKind {
    // Native `rigsql.toml` (takes precedence within a directory).
    RigsqlToml,
    // sqlfluff-compatible `.sqlfluff` INI file.
    Sqlfluff,
}
35
36impl Config {
37 pub fn load_for_path(path: &Path) -> Self {
43 let search_dir = if path.is_file() {
44 path.parent().unwrap_or(path)
45 } else {
46 path
47 };
48
49 let mut config_files: Vec<(PathBuf, ConfigKind)> = Vec::new();
50 let mut dir = Some(search_dir);
51 while let Some(d) = dir {
52 if let Some(found) = find_config_in_dir(d) {
53 config_files.push(found);
54 }
55 dir = d.parent();
56 }
57
58 if let Some(home) = dirs_home() {
60 if !config_files.iter().any(|(p, _)| p.parent() == Some(&home)) {
61 if let Some(found) = find_config_in_dir(&home) {
62 config_files.push(found);
63 }
64 }
65 }
66
67 config_files.reverse();
69
70 let mut config = Config::default();
71 for (path, kind) in &config_files {
72 let parsed = match kind {
73 ConfigKind::RigsqlToml => parse_rigsql_toml(path),
74 ConfigKind::Sqlfluff => parse_sqlfluff_file(path),
75 };
76 if let Ok(file_config) = parsed {
77 config.merge(file_config);
78 }
79 }
80
81 config
82 }
83
84 fn merge(&mut self, other: Config) {
86 if other.dialect.is_some() {
87 self.dialect = other.dialect;
88 }
89 if other.max_line_length.is_some() {
90 self.max_line_length = other.max_line_length;
91 }
92 if !other.exclude_rules.is_empty() {
93 self.exclude_rules = other.exclude_rules;
94 }
95 for (rule_name, settings) in other.rules {
96 let entry = self.rules.entry(rule_name).or_default();
97 for (k, v) in settings {
98 entry.insert(k, v);
99 }
100 }
101 }
102
103 pub fn rule_setting(&self, rule_name: &str, key: &str) -> Option<&str> {
105 self.rules
106 .get(rule_name)
107 .and_then(|m| m.get(key))
108 .map(|s| s.as_str())
109 }
110}
111
112fn find_config_in_dir(dir: &Path) -> Option<(PathBuf, ConfigKind)> {
114 let toml_path = dir.join("rigsql.toml");
115 if toml_path.is_file() {
116 return Some((toml_path, ConfigKind::RigsqlToml));
117 }
118 let sqlfluff_path = dir.join(".sqlfluff");
119 if sqlfluff_path.is_file() {
120 return Some((sqlfluff_path, ConfigKind::Sqlfluff));
121 }
122 None
123}
124
125fn read_config_file(path: &Path) -> Result<String, ConfigError> {
127 fs::read_to_string(path).map_err(|e| ConfigError::ReadError {
128 path: path.to_path_buf(),
129 source: e,
130 })
131}
132
133fn parse_rigsql_toml(path: &Path) -> Result<Config, ConfigError> {
148 let content = read_config_file(path)?;
149
150 let table: toml::Table = match content.parse() {
151 Ok(t) => t,
152 Err(e) => {
153 eprintln!("Warning: failed to parse {}: {e}", path.display());
154 return Ok(Config::default());
155 }
156 };
157
158 let mut config = Config::default();
159
160 if let Some(core) = table.get("core").and_then(|v| v.as_table()) {
162 if let Some(dialect) = core.get("dialect").and_then(|v| v.as_str()) {
163 config.dialect = Some(dialect.to_string());
164 }
165 if let Some(len) = core.get("max_line_length").and_then(|v| v.as_integer()) {
166 config.max_line_length = Some(len as usize);
167 }
168 if let Some(arr) = core.get("exclude_rules").and_then(|v| v.as_array()) {
169 config.exclude_rules = arr
170 .iter()
171 .filter_map(|v| v.as_str())
172 .map(|s| s.to_string())
173 .collect();
174 }
175 }
176
177 if let Some(rules) = table.get("rules").and_then(|v| v.as_table()) {
179 for (rule_name, rule_value) in rules {
180 if let Some(rule_table) = rule_value.as_table() {
181 let mut settings = HashMap::new();
182 for (k, v) in rule_table {
183 let val = match v {
184 toml::Value::String(s) => s.clone(),
185 toml::Value::Integer(i) => i.to_string(),
186 toml::Value::Float(f) => f.to_string(),
187 toml::Value::Boolean(b) => b.to_string(),
188 _ => continue,
189 };
190 settings.insert(k.clone(), val);
191 }
192 if !settings.is_empty() {
193 config.rules.insert(rule_name.clone(), settings);
194 }
195 }
196 }
197 }
198
199 Ok(config)
200}
201
202fn parse_sqlfluff_file(path: &Path) -> Result<Config, ConfigError> {
206 let content = read_config_file(path)?;
207
208 let mut config = Config::default();
209 let mut current_section = String::new();
210
211 for line in content.lines() {
212 let line = line.trim();
213
214 if line.is_empty() || line.starts_with('#') || line.starts_with(';') {
216 continue;
217 }
218
219 if line.starts_with('[') && line.ends_with(']') {
221 current_section = line[1..line.len() - 1].trim().to_string();
222 continue;
223 }
224
225 if let Some((key, value)) = line.split_once('=') {
227 let key = key.trim().to_lowercase();
228 let value = value.trim().to_string();
229
230 match current_section.as_str() {
231 "sqlfluff" => match key.as_str() {
232 "dialect" => config.dialect = Some(value),
233 "max_line_length" => {
234 config.max_line_length = value.parse().ok();
235 }
236 "exclude_rules" => {
237 config.exclude_rules = value
238 .split(',')
239 .map(|s| s.trim().to_string())
240 .filter(|s| !s.is_empty())
241 .collect();
242 }
243 _ => {}
244 },
245 section if section.starts_with("sqlfluff:rules:") => {
246 let rule_name = section.strip_prefix("sqlfluff:rules:").unwrap();
247 config
248 .rules
249 .entry(rule_name.to_string())
250 .or_default()
251 .insert(key, value);
252 }
253 _ => {}
254 }
255 }
256 }
257
258 Ok(config)
259}
260
/// Resolve the user's home directory from the `HOME` environment variable.
///
/// NOTE(review): `HOME` is normally unset on Windows (`USERPROFILE` is the
/// convention there) — confirm whether Windows fallback matters here.
fn dirs_home() -> Option<PathBuf> {
    std::env::var_os("HOME").map(PathBuf::from)
}
264
265pub fn filter_noqa(source: &str, violations: &mut Vec<rigsql_rules::LintViolation>) {
267 if violations.is_empty() {
268 return;
269 }
270
271 let noqa_lines: HashMap<usize, NoqaSpec> = source
273 .lines()
274 .enumerate()
275 .filter_map(|(i, line)| parse_noqa_comment(line).map(|spec| (i + 1, spec)))
276 .collect();
277
278 if noqa_lines.is_empty() {
279 return;
280 }
281
282 violations.retain(|v| {
283 let (line, _) = v.line_col(source);
284 match noqa_lines.get(&line) {
285 None => true,
286 Some(NoqaSpec::All) => false,
287 Some(NoqaSpec::Rules(codes)) => !codes.iter().any(|c| c == v.rule_code),
288 }
289 });
290}
291
/// A `-- noqa` directive parsed from one source line.
#[derive(Debug)]
enum NoqaSpec {
    // Suppress every violation on the line (bare `-- noqa`).
    All,
    // Suppress only the listed rule codes (`-- noqa: CP01, LT01`),
    // normalized to uppercase at parse time.
    Rules(Vec<String>),
}

/// Scan `line` for a case-insensitive `-- noqa` comment and parse it.
///
/// Returns `None` when no directive is present. A bare `-- noqa` (or one
/// followed only by another `--` comment, or by anything other than a
/// colon) suppresses everything; `-- noqa: A, B` suppresses the listed
/// codes. A colon with no codes degrades to suppress-all.
fn parse_noqa_comment(line: &str) -> Option<NoqaSpec> {
    // Single source of truth for the marker: previously the code sliced
    // with a hard-coded `7`, silently duplicating this pattern's length.
    const PATTERN: &[u8] = b"-- noqa";

    // Byte-wise case-insensitive search; slicing at byte offsets is safe
    // because the pattern is pure ASCII, so the boundary is a char boundary.
    let bytes = line.as_bytes();
    let idx = bytes
        .windows(PATTERN.len())
        .position(|w| w.eq_ignore_ascii_case(PATTERN))?;
    let after = line[idx + PATTERN.len()..].trim_start();

    if after.is_empty() || after.starts_with("--") {
        return Some(NoqaSpec::All);
    }

    if let Some(rest) = after.strip_prefix(':') {
        let codes: Vec<String> = rest
            .split(',')
            .map(|s| s.trim().to_uppercase())
            .filter(|s| !s.is_empty())
            .collect();
        if codes.is_empty() {
            Some(NoqaSpec::All)
        } else {
            Some(NoqaSpec::Rules(codes))
        }
    } else {
        // Trailing prose after the marker still suppresses everything.
        Some(NoqaSpec::All)
    }
}
329
330#[cfg(test)]
331mod tests {
332 use super::*;
333
334 #[test]
335 fn test_parse_noqa_all() {
336 assert!(matches!(
337 parse_noqa_comment("SELECT 1 -- noqa"),
338 Some(NoqaSpec::All)
339 ));
340 }
341
342 #[test]
343 fn test_parse_noqa_specific() {
344 match parse_noqa_comment("SELECT 1 -- noqa: CP01, LT01") {
345 Some(NoqaSpec::Rules(codes)) => {
346 assert_eq!(codes, vec!["CP01", "LT01"]);
347 }
348 _ => panic!("Expected NoqaSpec::Rules"),
349 }
350 }
351
352 #[test]
353 fn test_parse_noqa_none() {
354 assert!(parse_noqa_comment("SELECT 1").is_none());
355 }
356
357 #[test]
358 fn test_parse_sqlfluff_config() {
359 let content = "\
360[sqlfluff]
361dialect = tsql
362max_line_length = 120
363
364[sqlfluff:rules:capitalisation.keywords]
365capitalisation_policy = lower
366";
367 let dir = std::env::temp_dir().join("rigsql_test_sqlfluff_config");
368 let _ = fs::create_dir_all(&dir);
369 let path = dir.join(".sqlfluff");
370 fs::write(&path, content).unwrap();
371
372 let config = parse_sqlfluff_file(&path).unwrap();
373 assert_eq!(config.dialect.as_deref(), Some("tsql"));
374 assert_eq!(config.max_line_length, Some(120));
375 assert_eq!(
376 config.rule_setting("capitalisation.keywords", "capitalisation_policy"),
377 Some("lower")
378 );
379
380 let _ = fs::remove_dir_all(&dir);
381 }
382
383 #[test]
384 fn test_parse_rigsql_toml() {
385 let content = r#"
386[core]
387dialect = "tsql"
388max_line_length = 120
389exclude_rules = ["LT09", "CV06"]
390
391[rules."capitalisation.keywords"]
392capitalisation_policy = "lower"
393"#;
394 let dir = std::env::temp_dir().join("rigsql_test_toml_config");
395 let _ = fs::create_dir_all(&dir);
396 let path = dir.join("rigsql.toml");
397 fs::write(&path, content).unwrap();
398
399 let config = parse_rigsql_toml(&path).unwrap();
400 assert_eq!(config.dialect.as_deref(), Some("tsql"));
401 assert_eq!(config.max_line_length, Some(120));
402 assert_eq!(config.exclude_rules, vec!["LT09", "CV06"]);
403 assert_eq!(
404 config.rule_setting("capitalisation.keywords", "capitalisation_policy"),
405 Some("lower")
406 );
407
408 let _ = fs::remove_dir_all(&dir);
409 }
410
411 #[test]
412 fn test_rigsql_toml_priority_over_sqlfluff() {
413 let dir = std::env::temp_dir().join("rigsql_test_priority");
414 let _ = fs::create_dir_all(&dir);
415
416 fs::write(
418 dir.join(".sqlfluff"),
419 "[sqlfluff]\ndialect = postgres\nmax_line_length = 80\n",
420 )
421 .unwrap();
422 fs::write(
423 dir.join("rigsql.toml"),
424 "[core]\ndialect = \"tsql\"\nmax_line_length = 120\n",
425 )
426 .unwrap();
427
428 let config = Config::load_for_path(&dir);
429 assert_eq!(config.dialect.as_deref(), Some("tsql"));
431 assert_eq!(config.max_line_length, Some(120));
432
433 let _ = fs::remove_dir_all(&dir);
434 }
435}