nu_lint/
context.rs

1use std::path::Path;
2
3use nu_protocol::{
4    DeclId, Span,
5    ast::{Block, Expression, FindMapResult, Traverse},
6    engine::{Command, EngineState, StateWorkingSet},
7};
8
9use crate::lint::{RuleViolation, Violation};
10
11/// Context containing all lint information (source, AST, and engine state)
12/// Rules can use whatever they need from this context
13pub struct LintContext<'a> {
14    pub source: &'a str,
15    pub file_path: Option<&'a Path>,
16    pub ast: &'a Block,
17    pub engine_state: &'a EngineState,
18    pub working_set: &'a StateWorkingSet<'a>,
19}
20
21impl LintContext<'_> {
22    /// Find violations by applying a conditional predicate to regex matches
23    pub fn violations_from_regex<MatchPredicate>(
24        &self,
25        pattern: &regex::Regex,
26        rule_id: &'static str,
27        predicate: MatchPredicate,
28    ) -> Vec<RuleViolation>
29    where
30        MatchPredicate: Fn(regex::Match) -> Option<(String, Option<String>)>,
31    {
32        pattern
33            .find_iter(self.source)
34            .filter_map(|mat| {
35                predicate(mat).map(|(message, suggestion)| {
36                    let violation = RuleViolation::new_dynamic(
37                        rule_id,
38                        message,
39                        Span::new(mat.start(), mat.end()),
40                    );
41                    match suggestion {
42                        Some(sug) => violation.with_suggestion_dynamic(sug),
43                        None => violation,
44                    }
45                })
46            })
47            .collect()
48    }
49
50    /// Collect all violations using a closure over expressions (Traverse-based)
51    ///
52    /// This method uses Nushell's upstream `Traverse` trait to walk the AST
53    /// and collect violations. The collector function is called for each
54    /// expression in the AST and should return a vector of violations.
55    pub fn collect_violations<F>(&self, collector: F) -> Vec<Violation>
56    where
57        F: Fn(&Expression, &Self) -> Vec<Violation>,
58    {
59        let mut violations = Vec::new();
60
61        let f = |expr: &Expression| collector(expr, self);
62
63        // Visit main AST
64        self.ast.flat_map(self.working_set, &f, &mut violations);
65
66        violations
67    }
68
69    /// Collect all rule violations using a closure over expressions
70    /// (Traverse-based)
71    ///
72    /// This method uses Nushell's upstream `Traverse` trait to walk the AST
73    /// and collect rule violations. The collector function is called for each
74    /// expression in the AST and should return a vector of rule violations.
75    pub fn collect_rule_violations<F>(&self, collector: F) -> Vec<RuleViolation>
76    where
77        F: Fn(&Expression, &Self) -> Vec<RuleViolation>,
78    {
79        let mut violations = Vec::new();
80
81        let f = |expr: &Expression| collector(expr, self);
82
83        // Visit main AST
84        self.ast.flat_map(self.working_set, &f, &mut violations);
85
86        violations
87    }
88
89    /// Find first match using `find_map` (Traverse-based)
90    ///
91    /// This method uses Nushell's upstream `Traverse` trait to search the AST
92    /// for the first matching expression. The finder function should return
93    /// `FindMapResult::Found(value)` to return a value, `FindMapResult::Stop`
94    /// to stop searching, or `FindMapResult::Continue` to continue searching.
95    pub fn find_match<T, F>(&self, finder: F) -> Option<T>
96    where
97        F: Fn(&Expression) -> FindMapResult<T>,
98    {
99        self.ast.find_map(self.working_set, &finder)
100    }
101
102    /// Iterator over newly added user-defined function declarations
103    /// Filters out built-in functions (those with spaces or starting with '_')
104    pub fn new_user_functions(&self) -> impl Iterator<Item = (usize, &dyn Command)> + '_ {
105        let (base_count, total_count) = self.new_decl_range();
106        (base_count..total_count)
107            .map(|decl_id| (decl_id, self.working_set.get_decl(DeclId::new(decl_id))))
108            .filter(|(_, decl)| {
109                let name = &decl.signature().name;
110                !name.contains(' ') && !name.starts_with('_')
111            })
112    }
113
114    /// Find the span of a function/declaration name in the source code
115    /// Returns a span pointing to the first occurrence of the name, or a
116    /// fallback span
117    #[must_use]
118    pub fn find_declaration_span(&self, name: &str) -> Span {
119        // Use more efficient string search for function declarations
120        // Look for function declarations starting with "def " or "export def "
121
122        // Try "def <name>" first (most common case)
123        if let Some(pos) = self.source.find(&format!("def {name}")) {
124            let name_start = pos + 4; // "def ".len() == 4
125            return Span::new(name_start, name_start + name.len());
126        }
127
128        // Try "export def <name>"
129        if let Some(pos) = self.source.find(&format!("export def {name}")) {
130            let name_start = pos + 11; // "export def ".len() == 11
131            return Span::new(name_start, name_start + name.len());
132        }
133
134        // Fallback to simple name search
135        self.source.find(name).map_or_else(
136            || self.ast.span.unwrap_or_else(Span::unknown),
137            |name_pos| Span::new(name_pos, name_pos + name.len()),
138        )
139    }
140
141    /// Get the range of declaration IDs that were added during parsing (the
142    /// delta) Returns (`base_count`, `total_count`) for iterating:
143    /// `base_count..total_count`
144    #[must_use]
145    pub fn new_decl_range(&self) -> (usize, usize) {
146        let base_count = self.engine_state.num_decls();
147        let total_count = self.working_set.num_decls();
148        (base_count, total_count)
149    }
150}
151
#[cfg(test)]
impl LintContext<'_> {
    /// Test helper: parse `source` with a freshly built engine, hand the
    /// resulting `LintContext` to `f`, and return whatever `f` produces.
    ///
    /// The engine comes from `nu_cmd_lang::create_default_context`, extended
    /// with `nu_command::add_shell_command_context`. The source is parsed with
    /// no file name, and the context has no `file_path`. Parse errors are not
    /// inspected; the parsed block is passed on as-is.
    #[track_caller]
    pub fn test_with_parsed_source<F, R>(source: &str, f: F) -> R
    where
        F: for<'b> FnOnce(LintContext<'b>) -> R,
    {
        let engine_state: EngineState =
            nu_command::add_shell_command_context(nu_cmd_lang::create_default_context());

        let mut working_set = StateWorkingSet::new(&engine_state);
        let parsed = nu_parser::parse(&mut working_set, None, source.as_bytes(), false);

        f(LintContext {
            source,
            file_path: None,
            ast: &parsed,
            engine_state: &engine_state,
            working_set: &working_set,
        })
    }
}
182}