1use std::{path::Path, sync::OnceLock};
2
3use nu_parser::parse;
4use nu_protocol::{
5 ast::Block,
6 engine::{EngineState, StateWorkingSet},
7};
8
9use crate::{
10 LintError, config::Config, context::LintContext, lint::Violation, rules::RuleRegistry,
11};
12
13fn parse_source<'a>(engine_state: &'a EngineState, source: &[u8]) -> (Block, StateWorkingSet<'a>) {
16 let mut working_set = StateWorkingSet::new(engine_state);
17 let block = parse(&mut working_set, None, source, false);
18
19 ((*block).clone(), working_set)
20}
21
22pub struct LintEngine {
23 registry: RuleRegistry,
24 config: Config,
25 engine_state: &'static EngineState,
26}
27
28impl LintEngine {
29 fn default_engine_state() -> &'static EngineState {
31 static ENGINE: OnceLock<EngineState> = OnceLock::new();
32 ENGINE.get_or_init(|| {
33 let engine_state = nu_cmd_lang::create_default_context();
34 nu_command::add_shell_command_context(engine_state)
35 })
36 }
37
38 #[must_use]
39 pub fn new(config: Config) -> Self {
40 Self {
41 registry: RuleRegistry::with_default_rules(),
42 config,
43 engine_state: Self::default_engine_state(),
44 }
45 }
46
47 pub fn lint_file(&self, path: &Path) -> Result<Vec<Violation>, LintError> {
53 let source = std::fs::read_to_string(path)?;
54 Ok(self.lint_source(&source, Some(path)))
55 }
56
57 #[must_use]
58 pub fn lint_source(&self, source: &str, path: Option<&Path>) -> Vec<Violation> {
59 let (block, working_set) = parse_source(self.engine_state, source.as_bytes());
60
61 let context = LintContext {
62 source,
63 file_path: path,
64 ast: &block,
65 engine_state: self.engine_state,
66 working_set: &working_set,
67 };
68
69 let mut violations = self.collect_violations(&context);
70 Self::attach_file_path(&mut violations, path);
71 Self::sort_violations(&mut violations);
72 violations
73 }
74
75 fn collect_violations(&self, context: &LintContext) -> Vec<Violation> {
77 let enabled_rules = self.get_enabled_rules();
78
79 enabled_rules
80 .flat_map(|rule| (rule.check)(context))
81 .collect()
82 }
83
84 fn get_enabled_rules(&self) -> impl Iterator<Item = &crate::rule::Rule> {
86 self.registry.all_rules().filter(|rule| {
87 !matches!(
90 self.config.rules.get(rule.id),
91 Some(&crate::config::RuleSeverity::Off)
92 )
93 })
94 }
95
96 fn attach_file_path(violations: &mut [Violation], path: Option<&Path>) {
98 if let Some(file_path_str) = path.and_then(|p| p.to_str()) {
99 use std::borrow::Cow;
100 let file_path: Cow<'static, str> = file_path_str.to_owned().into();
101 for violation in violations {
102 violation.file = Some(file_path.clone());
103 }
104 }
105 }
106
107 fn sort_violations(violations: &mut [Violation]) {
109 violations.sort_by(|a, b| {
110 a.span
111 .start
112 .cmp(&b.span.start)
113 .then(a.severity.cmp(&b.severity))
114 });
115 }
116
117 #[must_use]
118 pub fn registry(&self) -> &RuleRegistry {
119 &self.registry
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Lints `source` with a default-configured engine and no file path.
    fn lint(source: &str) -> Vec<Violation> {
        LintEngine::new(Config::default()).lint_source(source, None)
    }

    #[test]
    fn test_lint_valid_code() {
        let violations = lint("let my_variable = 5");
        assert!(violations.is_empty());
    }

    #[test]
    fn test_lint_invalid_snake_case() {
        let violations = lint("let myVariable = 5");
        assert!(!violations.is_empty());
        assert_eq!(violations[0].rule_id, "snake_case_variables");
    }
}