nu_lint/
engine.rs

1use std::{path::Path, sync::OnceLock};
2
3use nu_parser::parse;
4use nu_protocol::{
5    ast::Block,
6    engine::{EngineState, StateWorkingSet},
7};
8
9use crate::{
10    LintError, config::Config, context::LintContext, lint::Violation, rules::RuleRegistry,
11};
12
13/// Parse Nushell source code into an AST and return both the Block and
14/// `StateWorkingSet`.
15fn parse_source<'a>(engine_state: &'a EngineState, source: &[u8]) -> (Block, StateWorkingSet<'a>) {
16    let mut working_set = StateWorkingSet::new(engine_state);
17    let block = parse(&mut working_set, None, source, false);
18
19    ((*block).clone(), working_set)
20}
21
22pub struct LintEngine {
23    registry: RuleRegistry,
24    config: Config,
25    engine_state: &'static EngineState,
26}
27
28impl LintEngine {
29    /// Get or initialize the default engine state
30    fn default_engine_state() -> &'static EngineState {
31        static ENGINE: OnceLock<EngineState> = OnceLock::new();
32        ENGINE.get_or_init(|| {
33            let engine_state = nu_cmd_lang::create_default_context();
34            nu_command::add_shell_command_context(engine_state)
35        })
36    }
37
38    #[must_use]
39    pub fn new(config: Config) -> Self {
40        Self {
41            registry: RuleRegistry::with_default_rules(),
42            config,
43            engine_state: Self::default_engine_state(),
44        }
45    }
46
47    /// Lint a file at the given path.
48    ///
49    /// # Errors
50    ///
51    /// Returns an error if the file cannot be read.
52    pub fn lint_file(&self, path: &Path) -> Result<Vec<Violation>, LintError> {
53        let source = std::fs::read_to_string(path)?;
54        Ok(self.lint_source(&source, Some(path)))
55    }
56
57    #[must_use]
58    pub fn lint_source(&self, source: &str, path: Option<&Path>) -> Vec<Violation> {
59        let (block, working_set) = parse_source(self.engine_state, source.as_bytes());
60
61        let context = LintContext {
62            source,
63            file_path: path,
64            ast: &block,
65            engine_state: self.engine_state,
66            working_set: &working_set,
67        };
68
69        let mut violations = self.collect_violations(&context);
70        Self::attach_file_path(&mut violations, path);
71        Self::sort_violations(&mut violations);
72        violations
73    }
74
75    /// Collect violations from all enabled rules
76    fn collect_violations(&self, context: &LintContext) -> Vec<Violation> {
77        let enabled_rules = self.get_enabled_rules();
78
79        enabled_rules
80            .flat_map(|rule| (rule.check)(context))
81            .collect()
82    }
83
84    /// Get all rules that are enabled according to the configuration
85    fn get_enabled_rules(&self) -> impl Iterator<Item = &crate::rule::Rule> {
86        self.registry.all_rules().filter(|rule| {
87            // If not in config, use default (enabled). If in config, check if it's not
88            // turned off.
89            !matches!(
90                self.config.rules.get(rule.id),
91                Some(&crate::config::RuleSeverity::Off)
92            )
93        })
94    }
95
96    /// Attach file path to all violations
97    fn attach_file_path(violations: &mut [Violation], path: Option<&Path>) {
98        if let Some(file_path_str) = path.and_then(|p| p.to_str()) {
99            use std::borrow::Cow;
100            let file_path: Cow<'static, str> = file_path_str.to_owned().into();
101            for violation in violations {
102                violation.file = Some(file_path.clone());
103            }
104        }
105    }
106
107    /// Sort violations by span start position, then by severity
108    fn sort_violations(violations: &mut [Violation]) {
109        violations.sort_by(|a, b| {
110            a.span
111                .start
112                .cmp(&b.span.start)
113                .then(a.severity.cmp(&b.severity))
114        });
115    }
116
117    #[must_use]
118    pub fn registry(&self) -> &RuleRegistry {
119        &self.registry
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Lint `source` with a default-configured engine and no file path.
    fn lint_with_defaults(source: &str) -> Vec<Violation> {
        LintEngine::new(Config::default()).lint_source(source, None)
    }

    #[test]
    fn test_lint_valid_code() {
        let found = lint_with_defaults("let my_variable = 5");
        assert_eq!(found.len(), 0);
    }

    #[test]
    fn test_lint_invalid_snake_case() {
        let found = lint_with_defaults("let myVariable = 5");
        assert!(!found.is_empty());

        let first = &found[0];
        assert_eq!(first.rule_id, "snake_case_variables");
    }
}