Skip to main content

diffguard_lsp/
config.rs

1use std::collections::{BTreeSet, HashSet};
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, Result, bail};
5use diffguard_domain::DirectoryRuleOverride;
6use diffguard_types::{ConfigFile, DirectoryOverrideConfig, MatchMode, RuleConfig, Severity};
7use lsp_types::{Diagnostic, NumberOrString};
8use regex::Regex;
9
/// File name looked up in each ancestor directory of a file when collecting
/// per-directory rule overrides.
const DIRECTORY_OVERRIDE_NAME: &str = ".diffguard.toml";
/// Deepest `includes` nesting level accepted while loading a config file;
/// going past it is reported as an error.
const MAX_INCLUDE_DEPTH: usize = 10;
12
13pub fn load_effective_config(path: Option<&Path>, no_default_rules: bool) -> Result<ConfigFile> {
14    let Some(path) = path else {
15        return Ok(ConfigFile::built_in());
16    };
17
18    let parsed = load_config_with_includes(path)?;
19    if no_default_rules {
20        Ok(parsed)
21    } else {
22        Ok(merge_with_built_in(parsed))
23    }
24}
25
/// Decides which config file path to use.
///
/// An explicit `override_path` always wins and is returned without checking
/// that it exists: absolute paths are returned as given, relative ones are
/// joined onto `workspace_root` when a root is known.
///
/// Without an override, `default_name` is looked up under `workspace_root`
/// (or relative to the current directory when there is no root) and is only
/// returned when it refers to an existing file.
pub fn resolve_config_path(
    workspace_root: Option<&Path>,
    override_path: Option<String>,
    default_name: &str,
) -> Option<PathBuf> {
    match override_path {
        Some(raw) => {
            let requested = PathBuf::from(raw);
            match workspace_root {
                Some(root) if !requested.is_absolute() => Some(root.join(requested)),
                _ => Some(requested),
            }
        }
        None => {
            let fallback = match workspace_root {
                Some(root) => root.join(default_name),
                None => PathBuf::from(default_name),
            };
            // Only report the default location when a file is really there.
            fallback.is_file().then_some(fallback)
        }
    }
}
58
59pub fn paths_match(left: &Path, right: &Path) -> bool {
60    let left_canonical = left.canonicalize().ok();
61    let right_canonical = right.canonicalize().ok();
62    if let (Some(left), Some(right)) = (left_canonical, right_canonical) {
63        return left == right;
64    }
65    normalize_path(left) == normalize_path(right)
66}
67
/// Renders `path` as a string (lossily for non-UTF-8 data) with every
/// backslash replaced by a forward slash.
pub fn normalize_path(path: &Path) -> String {
    path.to_string_lossy()
        .chars()
        .map(|ch| if ch == '\\' { '/' } else { ch })
        .collect()
}
71
72pub fn to_workspace_relative_path(workspace_root: Option<&Path>, file_path: &Path) -> String {
73    let normalized = if let Some(root) = workspace_root {
74        if let Ok(stripped) = file_path.strip_prefix(root) {
75            normalize_path(stripped)
76        } else {
77            normalize_path(file_path)
78        }
79    } else {
80        normalize_path(file_path)
81    };
82
83    normalized.trim_start_matches("./").to_string()
84}
85
86pub fn extract_rule_id(diagnostic: &Diagnostic) -> Option<String> {
87    if let Some(NumberOrString::String(rule_id)) = diagnostic.code.as_ref() {
88        return Some(rule_id.clone());
89    }
90
91    diagnostic
92        .data
93        .as_ref()
94        .and_then(|value| value.get("ruleId"))
95        .and_then(|value| value.as_str())
96        .map(|s| s.to_string())
97}
98
99pub fn find_rule<'a>(config: &'a ConfigFile, rule_id: &str) -> Option<&'a RuleConfig> {
100    config.rule.iter().find(|rule| rule.id == rule_id)
101}
102
103pub fn format_rule_explanation(rule: &RuleConfig) -> String {
104    let mut output = String::new();
105    output.push_str(&format!("Rule: {}\n", rule.id));
106    output.push_str(&format!("Severity: {}\n", rule.severity.as_str()));
107    output.push_str(&format!("Message: {}\n", rule.message));
108    output.push_str("Patterns:\n");
109    for pattern in &rule.patterns {
110        output.push_str(&format!("- {}\n", pattern));
111    }
112    output.push_str("Semantics:\n");
113    let match_mode = match rule.match_mode {
114        MatchMode::Any => "any",
115        MatchMode::Absent => "absent",
116    };
117    output.push_str(&format!("- Match mode: {}\n", match_mode));
118    output.push_str(&format!(
119        "- Multiline: {}{}\n",
120        if rule.multiline { "yes" } else { "no" },
121        rule.multiline_window
122            .map(|window| format!(" (window={})", window))
123            .unwrap_or_default()
124    ));
125    if !rule.context_patterns.is_empty() {
126        output.push_str(&format!(
127            "- Context patterns (window={}): {}\n",
128            rule.context_window.unwrap_or(3),
129            rule.context_patterns.join(", ")
130        ));
131    }
132    if !rule.escalate_patterns.is_empty() {
133        output.push_str(&format!(
134            "- Escalate to {} (window={}): {}\n",
135            rule.escalate_to.unwrap_or(Severity::Error).as_str(),
136            rule.escalate_window.unwrap_or(0),
137            rule.escalate_patterns.join(", ")
138        ));
139    }
140    if !rule.depends_on.is_empty() {
141        output.push_str(&format!("- Depends on: {}\n", rule.depends_on.join(", ")));
142    }
143    if !rule.languages.is_empty() {
144        output.push_str(&format!("Languages: {}\n", rule.languages.join(", ")));
145    }
146    if !rule.paths.is_empty() {
147        output.push_str(&format!("Paths: {}\n", rule.paths.join(", ")));
148    }
149    if !rule.exclude_paths.is_empty() {
150        output.push_str(&format!("Excludes: {}\n", rule.exclude_paths.join(", ")));
151    }
152    output.push_str(&format!(
153        "Ignore comments: {}\n",
154        if rule.ignore_comments { "yes" } else { "no" }
155    ));
156    output.push_str(&format!(
157        "Ignore strings: {}\n",
158        if rule.ignore_strings { "yes" } else { "no" }
159    ));
160    if let Some(help) = &rule.help {
161        output.push_str("Help:\n");
162        for line in help.lines() {
163            output.push_str(&format!("{}\n", line));
164        }
165    }
166    if let Some(url) = &rule.url {
167        output.push_str(&format!("URL: {}\n", url));
168    }
169    output
170}
171
172pub fn find_similar_rules(rule_id: &str, rules: &[RuleConfig]) -> Vec<String> {
173    let rule_id_lower = rule_id.to_lowercase();
174    let mut candidates: Vec<(String, usize)> = Vec::new();
175
176    for rule in rules {
177        let id_lower = rule.id.to_lowercase();
178        if id_lower.starts_with(&rule_id_lower) || rule_id_lower.starts_with(&id_lower) {
179            candidates.push((rule.id.clone(), 0));
180            continue;
181        }
182        if id_lower.contains(&rule_id_lower) || rule_id_lower.contains(&id_lower) {
183            candidates.push((rule.id.clone(), 1));
184            continue;
185        }
186        let distance = simple_edit_distance(&rule_id_lower, &id_lower);
187        if distance <= 3 {
188            candidates.push((rule.id.clone(), distance + 2));
189        }
190    }
191
192    candidates.sort_by_key(|(_, score)| *score);
193    candidates.truncate(5);
194    candidates.into_iter().map(|(id, _)| id).collect()
195}
196
197pub fn load_directory_overrides_for_file(
198    workspace_root: &Path,
199    relative_file_path: &str,
200) -> Result<Vec<DirectoryRuleOverride>> {
201    let mut candidates = BTreeSet::<PathBuf>::new();
202    collect_override_candidates_for_path(relative_file_path, &mut candidates);
203
204    let mut ordered_candidates: Vec<PathBuf> = candidates.into_iter().collect();
205    ordered_candidates.sort_by(|left, right| {
206        let left_parent = left.parent().unwrap_or_else(|| Path::new(""));
207        let right_parent = right.parent().unwrap_or_else(|| Path::new(""));
208        directory_depth(left_parent)
209            .cmp(&directory_depth(right_parent))
210            .then_with(|| left.to_string_lossy().cmp(&right.to_string_lossy()))
211    });
212
213    let mut overrides = Vec::new();
214    for candidate in ordered_candidates {
215        let full_path = workspace_root.join(&candidate);
216        if !full_path.is_file() {
217            continue;
218        }
219
220        let content = std::fs::read_to_string(&full_path)
221            .with_context(|| format!("read directory override '{}'", full_path.display()))?;
222        let expanded = expand_env_vars(&content).with_context(|| {
223            format!(
224                "expand env vars in directory override '{}'",
225                full_path.display()
226            )
227        })?;
228
229        let parsed: DirectoryOverrideConfig = toml::from_str(&expanded)
230            .with_context(|| format!("parse directory override '{}'", full_path.display()))?;
231
232        let directory =
233            normalize_override_directory(candidate.parent().unwrap_or_else(|| Path::new("")));
234        for rule in parsed.rules {
235            overrides.push(DirectoryRuleOverride {
236                directory: directory.clone(),
237                rule_id: rule.id,
238                enabled: rule.enabled,
239                severity: rule.severity,
240                exclude_paths: rule.exclude_paths,
241            });
242        }
243    }
244
245    Ok(overrides)
246}
247
248fn load_config_with_includes(path: &Path) -> Result<ConfigFile> {
249    let mut visited = HashSet::new();
250    load_config_recursive(path, &mut visited, 0)
251}
252
253fn load_config_recursive(
254    path: &Path,
255    visited: &mut HashSet<PathBuf>,
256    depth: usize,
257) -> Result<ConfigFile> {
258    if depth > MAX_INCLUDE_DEPTH {
259        bail!(
260            "include depth exceeded maximum of {} at '{}'",
261            MAX_INCLUDE_DEPTH,
262            path.display()
263        );
264    }
265
266    let canonical = path
267        .canonicalize()
268        .with_context(|| format!("canonicalize config path '{}'", path.display()))?;
269    if !visited.insert(canonical.clone()) {
270        bail!("circular include detected at '{}'", path.display());
271    }
272
273    let content = std::fs::read_to_string(path)
274        .with_context(|| format!("read config '{}'", path.display()))?;
275    let expanded = expand_env_vars(&content)?;
276    let parsed: ConfigFile =
277        toml::from_str(&expanded).with_context(|| format!("parse config '{}'", path.display()))?;
278
279    if parsed.includes.is_empty() {
280        return Ok(parsed);
281    }
282
283    let base_dir = path.parent().unwrap_or_else(|| Path::new("."));
284    let mut merged = ConfigFile {
285        includes: vec![],
286        defaults: diffguard_types::Defaults::default(),
287        rule: vec![],
288    };
289
290    for include in &parsed.includes {
291        let include_path = base_dir.join(include);
292        if !include_path.exists() {
293            bail!(
294                "included config file not found: '{}' (from '{}')",
295                include_path.display(),
296                include
297            );
298        }
299
300        let included = load_config_recursive(&include_path, visited, depth + 1)?;
301        merged = merge_configs(merged, included);
302    }
303
304    let current = ConfigFile {
305        includes: vec![],
306        defaults: parsed.defaults,
307        rule: parsed.rule,
308    };
309    Ok(merge_configs(merged, current))
310}
311
312fn merge_configs(base: ConfigFile, other: ConfigFile) -> ConfigFile {
313    let defaults = if other.defaults != diffguard_types::Defaults::default() {
314        other.defaults
315    } else {
316        base.defaults
317    };
318
319    let mut rules = std::collections::BTreeMap::new();
320    for rule in base.rule {
321        rules.insert(rule.id.clone(), rule);
322    }
323    for rule in other.rule {
324        rules.insert(rule.id.clone(), rule);
325    }
326
327    ConfigFile {
328        includes: vec![],
329        defaults,
330        rule: rules.into_values().collect(),
331    }
332}
333
334fn merge_with_built_in(user: ConfigFile) -> ConfigFile {
335    let mut built_in = ConfigFile::built_in();
336    built_in.defaults = user.defaults;
337
338    let mut rules = std::collections::BTreeMap::<String, RuleConfig>::new();
339    for rule in built_in.rule {
340        rules.insert(rule.id.clone(), rule);
341    }
342    for rule in user.rule {
343        rules.insert(rule.id.clone(), rule);
344    }
345
346    built_in.rule = rules.into_values().collect();
347    built_in
348}
349
350fn expand_env_vars(content: &str) -> Result<String> {
351    let regex = Regex::new(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::-([^}]*))?\}")
352        .expect("env var regex must compile");
353    let mut result = String::with_capacity(content.len());
354    let mut last_end = 0usize;
355
356    for capture in regex.captures_iter(content) {
357        let full = capture
358            .get(0)
359            .expect("full regex match should always be present");
360        let variable = capture
361            .get(1)
362            .expect("variable capture should always be present")
363            .as_str();
364        let default = capture.get(2).map(|m| m.as_str());
365
366        result.push_str(&content[last_end..full.start()]);
367        match std::env::var(variable) {
368            Ok(value) => result.push_str(&value),
369            Err(_) => {
370                if let Some(default) = default {
371                    result.push_str(default);
372                } else {
373                    bail!(
374                        "environment variable '{}' is not set and no default was provided",
375                        variable
376                    );
377                }
378            }
379        }
380        last_end = full.end();
381    }
382
383    result.push_str(&content[last_end..]);
384    Ok(result)
385}
386
387fn collect_override_candidates_for_path(file_path: &str, output: &mut BTreeSet<PathBuf>) {
388    let path = Path::new(file_path);
389    let mut current = path.parent();
390
391    if current.is_none() {
392        output.insert(PathBuf::from(DIRECTORY_OVERRIDE_NAME));
393        return;
394    }
395
396    while let Some(directory) = current {
397        let mut candidate = PathBuf::new();
398        if !directory.as_os_str().is_empty() {
399            candidate.push(directory);
400        }
401        candidate.push(DIRECTORY_OVERRIDE_NAME);
402        output.insert(candidate);
403
404        if directory.as_os_str().is_empty() {
405            break;
406        }
407        current = directory.parent();
408    }
409}
410
411fn normalize_override_directory(path: &Path) -> String {
412    let normalized = normalize_path(path);
413    let trimmed = normalized.trim_matches('/');
414    if trimmed.is_empty() || trimmed == "." {
415        String::new()
416    } else {
417        trimmed.to_string()
418    }
419}
420
/// Counts the components of `path`; the empty path has depth 0.
fn directory_depth(path: &Path) -> usize {
    path.components().fold(0, |depth, _| depth + 1)
}
424
/// Computes the Levenshtein distance between `left` and `right`, counted in
/// `char`s: the minimum number of single-character insertions, deletions and
/// substitutions needed to turn one string into the other.
fn simple_edit_distance(left: &str, right: &str) -> usize {
    let source: Vec<char> = left.chars().collect();
    let target: Vec<char> = right.chars().collect();

    // Turning an empty string into the other one takes one edit per char.
    match (source.is_empty(), target.is_empty()) {
        (true, _) => return target.len(),
        (_, true) => return source.len(),
        _ => {}
    }

    // `row[j]` is the distance between the processed prefix of `source` and
    // the first `j` chars of `target`. The row is updated in place.
    let mut row: Vec<usize> = (0..=target.len()).collect();
    for (i, &source_char) in source.iter().enumerate() {
        // Value of the previous row one column to the left.
        let mut diagonal = row[0];
        row[0] = i + 1;

        for (j, &target_char) in target.iter().enumerate() {
            let above = row[j + 1];
            let deletion = above + 1;
            let insertion = row[j] + 1;
            let substitution = diagonal + usize::from(source_char != target_char);

            row[j + 1] = deletion.min(insertion).min(substitution);
            diagonal = above;
        }
    }

    row[target.len()]
}
456
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a minimal warn-level rule with the given id and every optional
    /// setting left empty or disabled.
    fn minimal_rule(id: &str) -> RuleConfig {
        RuleConfig {
            id: id.to_string(),
            severity: Severity::Warn,
            message: "msg".to_string(),
            languages: vec![],
            patterns: vec!["a".to_string()],
            paths: vec![],
            exclude_paths: vec![],
            ignore_comments: false,
            ignore_strings: false,
            match_mode: MatchMode::Any,
            multiline: false,
            multiline_window: None,
            context_patterns: vec![],
            context_window: None,
            escalate_patterns: vec![],
            escalate_window: None,
            escalate_to: None,
            depends_on: vec![],
            help: None,
            url: None,
            tags: vec![],
            test_cases: vec![],
        }
    }

    #[test]
    fn extract_rule_id_from_code_or_data() {
        // A string code is used directly.
        let from_code = Diagnostic {
            code: Some(NumberOrString::String("rust.no_unwrap".to_string())),
            ..Diagnostic::default()
        };
        assert_eq!(
            extract_rule_id(&from_code).as_deref(),
            Some("rust.no_unwrap")
        );

        // Without a code, the `ruleId` entry of the data payload is used.
        let from_data = Diagnostic {
            data: Some(serde_json::json!({ "ruleId": "security.no_eval" })),
            ..Diagnostic::default()
        };
        assert_eq!(
            extract_rule_id(&from_data).as_deref(),
            Some("security.no_eval")
        );
    }

    #[test]
    fn format_rule_explanation_contains_semantics() {
        let rule = RuleConfig {
            severity: Severity::Error,
            message: "Avoid unwrap".to_string(),
            languages: vec!["rust".to_string()],
            patterns: vec![r"\.unwrap\(".to_string()],
            paths: vec!["**/*.rs".to_string()],
            exclude_paths: vec!["**/tests/**".to_string()],
            ignore_comments: true,
            ignore_strings: true,
            help: Some("Use pattern matching instead.".to_string()),
            url: Some("https://example.com/rules/no_unwrap".to_string()),
            tags: vec!["safety".to_string()],
            ..minimal_rule("rust.no_unwrap")
        };

        let explanation = format_rule_explanation(&rule);
        for expected in [
            "Rule: rust.no_unwrap",
            "Severity: error",
            "Use pattern matching instead.",
            "URL: https://example.com/rules/no_unwrap",
        ] {
            assert!(explanation.contains(expected));
        }
    }

    #[test]
    fn find_similar_rules_prefers_close_matches() {
        let rules = vec![
            minimal_rule("rust.no_unwrap"),
            minimal_rule("security.no_eval"),
        ];

        let suggestions = find_similar_rules("rust.no_unwra", &rules);
        assert!(suggestions.iter().any(|id| id == "rust.no_unwrap"));
    }

    #[test]
    fn load_config_with_includes_merges_rules() {
        let temp = TempDir::new().expect("temp dir");
        let base = temp.path().join("base.toml");
        let main = temp.path().join("main.toml");

        std::fs::write(
            &base,
            r#"
[[rule]]
id = "base.rule"
severity = "warn"
message = "Base rule"
patterns = ["base"]
"#,
        )
        .expect("write base");

        std::fs::write(
            &main,
            r#"
includes = ["base.toml"]
[[rule]]
id = "main.rule"
severity = "error"
message = "Main rule"
patterns = ["main"]
"#,
        )
        .expect("write main");

        // Built-in rules are disabled so only the two files contribute rules.
        let loaded = load_effective_config(Some(&main), true).expect("load config");
        assert!(loaded.rule.iter().any(|rule| rule.id == "base.rule"));
        assert!(loaded.rule.iter().any(|rule| rule.id == "main.rule"));
    }
}