1use std::collections::HashMap;
2use std::fs;
3use std::path::{Path, PathBuf};
4
5use thiserror::Error;
6
7#[derive(Debug, Error)]
8pub enum ConfigError {
9 #[error("Failed to read config file {path}: {source}")]
10 ReadError {
11 path: PathBuf,
12 source: std::io::Error,
13 },
14}
15
/// Resolved linter configuration, merged from `rigsql.toml` and/or
/// `.sqlfluff` files discovered on disk (see `Config::load_for_path`).
#[derive(Debug, Clone, Default)]
pub struct Config {
    // SQL dialect name, e.g. "tsql" or "postgres"; `None` = not configured.
    pub dialect: Option<String>,
    // Locale identifier; `None` = not configured.
    pub locale: Option<String>,
    // Maximum allowed line length; `None` = not configured.
    pub max_line_length: Option<usize>,
    // Rule codes to exclude from linting; empty = no exclusions.
    pub exclude_rules: Vec<String>,
    // Per-rule settings: rule name -> (setting key -> value as string).
    pub rules: HashMap<String, HashMap<String, String>>,
}
30
/// Which on-disk format a discovered config file uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConfigKind {
    // Native `rigsql.toml` file.
    RigsqlToml,
    // sqlfluff-compatible `.sqlfluff` INI file.
    Sqlfluff,
}
37
38impl Config {
39 pub fn load_for_path(path: &Path) -> Self {
45 let search_dir = if path.is_file() {
46 path.parent().unwrap_or(path)
47 } else {
48 path
49 };
50
51 let mut config_files: Vec<(PathBuf, ConfigKind)> = Vec::new();
52 let mut dir = Some(search_dir);
53 while let Some(d) = dir {
54 if let Some(found) = find_config_in_dir(d) {
55 config_files.push(found);
56 }
57 dir = d.parent();
58 }
59
60 if let Some(home) = dirs_home() {
62 if !config_files.iter().any(|(p, _)| p.parent() == Some(&home)) {
63 if let Some(found) = find_config_in_dir(&home) {
64 config_files.push(found);
65 }
66 }
67 }
68
69 config_files.reverse();
71
72 let mut config = Config::default();
73 for (path, kind) in &config_files {
74 let parsed = match kind {
75 ConfigKind::RigsqlToml => parse_rigsql_toml(path),
76 ConfigKind::Sqlfluff => parse_sqlfluff_file(path),
77 };
78 if let Ok(file_config) = parsed {
79 config.merge(file_config);
80 }
81 }
82
83 config
84 }
85
86 fn merge(&mut self, other: Config) {
88 if other.dialect.is_some() {
89 self.dialect = other.dialect;
90 }
91 if other.locale.is_some() {
92 self.locale = other.locale;
93 }
94 if other.max_line_length.is_some() {
95 self.max_line_length = other.max_line_length;
96 }
97 if !other.exclude_rules.is_empty() {
98 self.exclude_rules = other.exclude_rules;
99 }
100 for (rule_name, settings) in other.rules {
101 let entry = self.rules.entry(rule_name).or_default();
102 for (k, v) in settings {
103 entry.insert(k, v);
104 }
105 }
106 }
107
108 pub fn rule_setting(&self, rule_name: &str, key: &str) -> Option<&str> {
110 self.rules
111 .get(rule_name)
112 .and_then(|m| m.get(key))
113 .map(|s| s.as_str())
114 }
115}
116
117fn find_config_in_dir(dir: &Path) -> Option<(PathBuf, ConfigKind)> {
119 let toml_path = dir.join("rigsql.toml");
120 if toml_path.is_file() {
121 return Some((toml_path, ConfigKind::RigsqlToml));
122 }
123 let sqlfluff_path = dir.join(".sqlfluff");
124 if sqlfluff_path.is_file() {
125 return Some((sqlfluff_path, ConfigKind::Sqlfluff));
126 }
127 None
128}
129
130fn read_config_file(path: &Path) -> Result<String, ConfigError> {
132 fs::read_to_string(path).map_err(|e| ConfigError::ReadError {
133 path: path.to_path_buf(),
134 source: e,
135 })
136}
137
138fn parse_rigsql_toml(path: &Path) -> Result<Config, ConfigError> {
153 let content = read_config_file(path)?;
154
155 let table: toml::Table = match content.parse() {
156 Ok(t) => t,
157 Err(e) => {
158 eprintln!("Warning: failed to parse {}: {e}", path.display());
159 return Ok(Config::default());
160 }
161 };
162
163 let mut config = Config::default();
164
165 if let Some(core) = table.get("core").and_then(|v| v.as_table()) {
167 if let Some(dialect) = core.get("dialect").and_then(|v| v.as_str()) {
168 config.dialect = Some(dialect.to_string());
169 }
170 if let Some(locale) = core.get("locale").and_then(|v| v.as_str()) {
171 config.locale = Some(locale.to_string());
172 }
173 if let Some(len) = core.get("max_line_length").and_then(|v| v.as_integer()) {
174 config.max_line_length = Some(len as usize);
175 }
176 if let Some(arr) = core.get("exclude_rules").and_then(|v| v.as_array()) {
177 config.exclude_rules = arr
178 .iter()
179 .filter_map(|v| v.as_str())
180 .map(|s| s.to_string())
181 .collect();
182 }
183 }
184
185 if let Some(rules) = table.get("rules").and_then(|v| v.as_table()) {
187 for (rule_name, rule_value) in rules {
188 if let Some(rule_table) = rule_value.as_table() {
189 let mut settings = HashMap::new();
190 for (k, v) in rule_table {
191 let val = match v {
192 toml::Value::String(s) => s.clone(),
193 toml::Value::Integer(i) => i.to_string(),
194 toml::Value::Float(f) => f.to_string(),
195 toml::Value::Boolean(b) => b.to_string(),
196 _ => continue,
197 };
198 settings.insert(k.clone(), val);
199 }
200 if !settings.is_empty() {
201 config.rules.insert(rule_name.clone(), settings);
202 }
203 }
204 }
205 }
206
207 Ok(config)
208}
209
210fn parse_sqlfluff_file(path: &Path) -> Result<Config, ConfigError> {
214 let content = read_config_file(path)?;
215
216 let mut config = Config::default();
217 let mut current_section = String::new();
218
219 for line in content.lines() {
220 let line = line.trim();
221
222 if line.is_empty() || line.starts_with('#') || line.starts_with(';') {
224 continue;
225 }
226
227 if line.starts_with('[') && line.ends_with(']') {
229 current_section = line[1..line.len() - 1].trim().to_string();
230 continue;
231 }
232
233 if let Some((key, value)) = line.split_once('=') {
235 let key = key.trim().to_lowercase();
236 let value = value.trim().to_string();
237
238 match current_section.as_str() {
239 "sqlfluff" => match key.as_str() {
240 "dialect" => config.dialect = Some(value),
241 "locale" => config.locale = Some(value),
242 "max_line_length" => {
243 config.max_line_length = value.parse().ok();
244 }
245 "exclude_rules" => {
246 config.exclude_rules = value
247 .split(',')
248 .map(|s| s.trim().to_string())
249 .filter(|s| !s.is_empty())
250 .collect();
251 }
252 _ => {}
253 },
254 section if section.starts_with("sqlfluff:rules:") => {
255 let rule_name = section.strip_prefix("sqlfluff:rules:").unwrap();
256 config
257 .rules
258 .entry(rule_name.to_string())
259 .or_default()
260 .insert(key, value);
261 }
262 _ => {}
263 }
264 }
265 }
266
267 Ok(config)
268}
269
/// Resolve the user's home directory from the `HOME` environment
/// variable (Unix convention); `None` when `HOME` is unset.
fn dirs_home() -> Option<PathBuf> {
    let home = std::env::var_os("HOME")?;
    Some(PathBuf::from(home))
}
273
274pub fn filter_noqa(source: &str, violations: &mut Vec<rigsql_rules::LintViolation>) {
276 if violations.is_empty() {
277 return;
278 }
279
280 let noqa_lines: HashMap<usize, NoqaSpec> = source
282 .lines()
283 .enumerate()
284 .filter_map(|(i, line)| parse_noqa_comment(line).map(|spec| (i + 1, spec)))
285 .collect();
286
287 if noqa_lines.is_empty() {
288 return;
289 }
290
291 violations.retain(|v| {
292 let (line, _) = v.line_col(source);
293 match noqa_lines.get(&line) {
294 None => true,
295 Some(NoqaSpec::All) => false,
296 Some(NoqaSpec::Rules(codes)) => !codes.iter().any(|c| c == v.rule_code),
297 }
298 });
299}
300
/// Suppression directive parsed from an inline `-- noqa` comment.
#[derive(Debug)]
enum NoqaSpec {
    // Suppress every violation on the line.
    All,
    // Suppress only the listed rule codes (stored uppercased).
    Rules(Vec<String>),
}

/// Parse a `-- noqa` directive from a source line, case-insensitively.
///
/// Returns `None` when the line carries no directive, `All` for a bare
/// `-- noqa` (or one followed by another comment), and `Rules` for
/// `-- noqa: CODE[, CODE...]`. An empty code list degrades to `All`.
fn parse_noqa_comment(line: &str) -> Option<NoqaSpec> {
    const MARKER: &[u8] = b"-- noqa";

    // Case-insensitive byte search; MARKER is ASCII, so the byte offset
    // after it is always a valid char boundary.
    let start = line
        .as_bytes()
        .windows(MARKER.len())
        .position(|w| w.eq_ignore_ascii_case(MARKER))?;
    let tail = line[start + MARKER.len()..].trim_start();

    if tail.is_empty() || tail.starts_with("--") {
        return Some(NoqaSpec::All);
    }

    match tail.strip_prefix(':') {
        None => Some(NoqaSpec::All),
        Some(list) => {
            let codes: Vec<String> = list
                .split(',')
                .map(|code| code.trim().to_uppercase())
                .filter(|code| !code.is_empty())
                .collect();
            if codes.is_empty() {
                Some(NoqaSpec::All)
            } else {
                Some(NoqaSpec::Rules(codes))
            }
        }
    }
}
338
#[cfg(test)]
mod tests {
    use super::*;

    // Bare `-- noqa` suppresses everything on the line.
    #[test]
    fn test_parse_noqa_all() {
        assert!(matches!(
            parse_noqa_comment("SELECT 1 -- noqa"),
            Some(NoqaSpec::All)
        ));
    }

    // `-- noqa: A, B` yields the listed codes, trimmed and uppercased.
    #[test]
    fn test_parse_noqa_specific() {
        match parse_noqa_comment("SELECT 1 -- noqa: CP01, LT01") {
            Some(NoqaSpec::Rules(codes)) => {
                assert_eq!(codes, vec!["CP01", "LT01"]);
            }
            _ => panic!("Expected NoqaSpec::Rules"),
        }
    }

    // Lines without a directive produce no suppression.
    #[test]
    fn test_parse_noqa_none() {
        assert!(parse_noqa_comment("SELECT 1").is_none());
    }

    // Round-trip a .sqlfluff INI fixture through the parser.
    // Uses a dedicated temp dir name so parallel tests don't collide.
    #[test]
    fn test_parse_sqlfluff_config() {
        let content = "\
[sqlfluff]
dialect = tsql
max_line_length = 120

[sqlfluff:rules:capitalisation.keywords]
capitalisation_policy = lower
";
        let dir = std::env::temp_dir().join("rigsql_test_sqlfluff_config");
        let _ = fs::create_dir_all(&dir);
        let path = dir.join(".sqlfluff");
        fs::write(&path, content).unwrap();

        let config = parse_sqlfluff_file(&path).unwrap();
        assert_eq!(config.dialect.as_deref(), Some("tsql"));
        assert_eq!(config.max_line_length, Some(120));
        assert_eq!(
            config.rule_setting("capitalisation.keywords", "capitalisation_policy"),
            Some("lower")
        );

        let _ = fs::remove_dir_all(&dir);
    }

    // Round-trip a rigsql.toml fixture, including exclude_rules and a
    // quoted dotted rule-table name.
    #[test]
    fn test_parse_rigsql_toml() {
        let content = r#"
[core]
dialect = "tsql"
max_line_length = 120
exclude_rules = ["LT09", "CV06"]

[rules."capitalisation.keywords"]
capitalisation_policy = "lower"
"#;
        let dir = std::env::temp_dir().join("rigsql_test_toml_config");
        let _ = fs::create_dir_all(&dir);
        let path = dir.join("rigsql.toml");
        fs::write(&path, content).unwrap();

        let config = parse_rigsql_toml(&path).unwrap();
        assert_eq!(config.dialect.as_deref(), Some("tsql"));
        assert_eq!(config.max_line_length, Some(120));
        assert_eq!(config.exclude_rules, vec!["LT09", "CV06"]);
        assert_eq!(
            config.rule_setting("capitalisation.keywords", "capitalisation_policy"),
            Some("lower")
        );

        let _ = fs::remove_dir_all(&dir);
    }

    // When both formats exist in one directory, rigsql.toml wins
    // (find_config_in_dir checks it first).
    #[test]
    fn test_rigsql_toml_priority_over_sqlfluff() {
        let dir = std::env::temp_dir().join("rigsql_test_priority");
        let _ = fs::create_dir_all(&dir);

        fs::write(
            dir.join(".sqlfluff"),
            "[sqlfluff]\ndialect = postgres\nmax_line_length = 80\n",
        )
        .unwrap();
        fs::write(
            dir.join("rigsql.toml"),
            "[core]\ndialect = \"tsql\"\nmax_line_length = 120\n",
        )
        .unwrap();

        let config = Config::load_for_path(&dir);
        assert_eq!(config.dialect.as_deref(), Some("tsql"));
        assert_eq!(config.max_line_length, Some(120));

        let _ = fs::remove_dir_all(&dir);
    }
}