Skip to main content

nu_lint/
context.rs

1use std::{collections::BTreeSet, ops::ControlFlow, str::from_utf8, vec::Vec};
2
3use nu_protocol::{
4    Span,
5    ast::{Block, Expr, Expression, Traverse},
6    engine::{EngineState, StateWorkingSet},
7};
8
9#[cfg(test)]
10use crate::violation;
11use crate::{
12    Config,
13    ast::{call::CallExt, declaration::CustomCommandDef, string::StringFormat},
14    span::FileSpan,
15    violation::Detection,
16};
17
18/// Fix data for external command alternatives
19pub struct ExternalCmdFixData<'a> {
20    /// Argument expressions from the external call
21    pub args: Box<[&'a Expression]>,
22    pub expr_span: Span,
23}
24
25impl ExternalCmdFixData<'_> {
26    /// Get argument text content for each argument.
27    ///
28    /// For string literals, returns the unquoted content.
29    /// For other expressions (variables, subexpressions), returns the source
30    /// text.
31    ///
32    /// This is the primary API for parsing command arguments.
33    pub fn arg_texts<'b>(&'b self, context: &'b LintContext<'b>) -> impl Iterator<Item = &'b str> {
34        self.args.iter().map(move |expr| match &expr.expr {
35            Expr::String(s) | Expr::RawString(s) => s.as_str(),
36            _ => context.expr_text(expr),
37        })
38    }
39
40    /// Get string format information for arguments that need quote
41    /// preservation.
42    ///
43    /// Returns `Some(StringFormat)` for string literals (with quote type info).
44    /// Returns `None` for non-string expressions (variables, subexpressions,
45    /// etc.).
46    ///
47    /// Use this when generating replacement text that must preserve quote
48    /// styles.
49    pub fn arg_formats(&self, context: &LintContext) -> Vec<Option<StringFormat>> {
50        self.args
51            .iter()
52            .map(|expr| StringFormat::from_expression(expr, context))
53            .collect()
54    }
55
56    /// Check if an argument is a string literal (safe to extract unquoted
57    /// content).
58    pub fn arg_is_string(&self, index: usize) -> bool {
59        self.args.get(index).is_some_and(|expr| {
60            matches!(
61                &expr.expr,
62                Expr::String(_) | Expr::RawString(_) | Expr::StringInterpolation(_)
63            )
64        })
65    }
66}
67
68/// Context containing all lint information (source, AST, and engine state)
69pub struct LintContext<'a> {
70    /// Raw source string of the file being linted (file-relative coordinates)
71    source: &'a str,
72    pub ast: &'a Block,
73    pub engine_state: &'a EngineState,
74    pub working_set: &'a StateWorkingSet<'a>,
75    /// Byte offset where this file starts in the global span space
76    file_offset: usize,
77    pub config: &'a Config,
78}
79
80impl<'a> LintContext<'a> {
81    /// Create a new `LintContext`
82    pub(crate) const fn new(
83        source: &'a str,
84        ast: &'a Block,
85        engine_state: &'a EngineState,
86        working_set: &'a StateWorkingSet<'a>,
87        file_offset: usize,
88        config: &'a Config,
89    ) -> Self {
90        Self {
91            source,
92            ast,
93            engine_state,
94            working_set,
95            file_offset,
96            config,
97        }
98    }
99
100    /// Create a new `LintContext` using the default configuration.
101    #[cfg(test)]
102    pub(crate) fn with_default_config(
103        source: &'a str,
104        ast: &'a Block,
105        engine_state: &'a EngineState,
106        working_set: &'a StateWorkingSet<'a>,
107        file_offset: usize,
108    ) -> Self {
109        Self::new(
110            source,
111            ast,
112            engine_state,
113            working_set,
114            file_offset,
115            Config::default_static(),
116        )
117    }
118
119    #[must_use]
120    pub const unsafe fn source(&self) -> &str {
121        self.source
122    }
123
124    /// Check if a global span is within the user's file bounds
125    #[must_use]
126    pub const fn span_in_user_file(&self, span: Span) -> bool {
127        let file_end = self.file_offset + self.source.len();
128        span.start >= self.file_offset && span.end <= file_end
129    }
130
131    /// Get the source length of the user's file
132    #[must_use]
133    pub const fn source_len(&self) -> usize {
134        self.source.len()
135    }
136
137    /// Get text for an AST span
138    #[must_use]
139    pub fn span_text(&self, span: Span) -> &str {
140        from_utf8(self.working_set.get_span_contents(span))
141            .expect("span contents should be valid UTF-8")
142    }
143
144    #[must_use]
145    pub fn expr_text(&self, expr: &Expression) -> &str {
146        self.span_text(expr.span)
147    }
148
149    /// Get source text before an AST span
150    #[must_use]
151    pub fn source_before_span(&self, span: Span) -> &str {
152        let file_pos = span.start.saturating_sub(self.file_offset);
153        self.source
154            .get(..file_pos)
155            .expect("file position should be within source bounds")
156    }
157
158    /// Get source text after an AST span
159    #[must_use]
160    pub fn source_after_span(&self, span: Span) -> &str {
161        let file_pos = span.end.saturating_sub(self.file_offset);
162        self.source
163            .get(file_pos..)
164            .expect("file position should be within source bounds")
165    }
166
167    /// Get source text between two span endpoints (from end of first to start
168    /// of second) Returns empty string if the range is invalid
169    #[must_use]
170    pub fn source_between_span_ends(&self, end_span: Span, start_span: Span) -> &str {
171        let file_start = end_span.end.saturating_sub(self.file_offset);
172        let file_end = start_span.start.saturating_sub(self.file_offset);
173
174        if file_start >= file_end || file_end > self.source.len() {
175            return "";
176        }
177
178        &self.source[file_start..file_end]
179    }
180
181    /// Count newlines up to a file-relative offset
182    #[must_use]
183    pub fn count_newlines_before(&self, offset: usize) -> usize {
184        let safe_offset = offset.min(self.source.len());
185        self.source[..safe_offset]
186            .bytes()
187            .filter(|&b| b == b'\n')
188            .count()
189    }
190
191    /// Convert an AST span to file-relative positions for `Replacement` spans
192    #[must_use]
193    pub const fn normalize_span(&self, span: Span) -> FileSpan {
194        FileSpan::new(
195            span.start.saturating_sub(self.file_offset),
196            span.end.saturating_sub(self.file_offset),
197        )
198    }
199
200    #[must_use]
201    pub fn source_contains(&self, pattern: &str) -> bool {
202        self.source.contains(pattern)
203    }
204
205    /// Get the format name for a file extension based on available `from`
206    /// commands.
207    ///
208    /// This dynamically queries the engine state for `from <format>` commands
209    /// and maps file extensions to their corresponding format names.
210    ///
211    /// Returns `None` if the extension doesn't have a corresponding `from`
212    /// command.
213    #[must_use]
214    pub fn format_for_extension(&self, filename: &str) -> Option<String> {
215        let lower = filename.to_lowercase();
216
217        // Extract extension from filename
218        let ext = lower.rsplit('.').next()?;
219
220        // Handle .yml -> yaml alias
221        let format = if ext == "yml" { "yaml" } else { ext };
222
223        // Check if `from <format>` command exists
224        let from_cmd_name = format!("from {format}");
225        self.working_set
226            .find_decl(from_cmd_name.as_bytes())
227            .is_some()
228            .then(|| format.to_string())
229    }
230
231    /// Byte offset where this file starts in the global span space
232    #[must_use]
233    pub const fn file_offset(&self) -> usize {
234        self.file_offset
235    }
236
237    /// Collect spans of all calls to the specified commands
238    #[must_use]
239    pub fn collect_command_spans(&self, commands: &[&str]) -> Vec<Span> {
240        let mut spans = Vec::new();
241        self.ast.flat_map(
242            self.working_set,
243            &|expr| {
244                if let Expr::Call(call) = &expr.expr {
245                    let cmd_name = call.get_call_name(self);
246                    if commands.iter().any(|&cmd| cmd == cmd_name) {
247                        return vec![expr.span];
248                    }
249                }
250                vec![]
251            },
252            &mut spans,
253        );
254        spans
255    }
256
257    /// Expand a span to include the full line(s) it occupies
258    /// Takes a global AST span and returns a global span
259    #[must_use]
260    pub fn expand_span_to_full_lines(&self, span: Span) -> Span {
261        let bytes = self.source.as_bytes();
262
263        let file_start = span.start.saturating_sub(self.file_offset);
264        let file_end = span.end.saturating_sub(self.file_offset);
265
266        let start = bytes[..file_start]
267            .iter()
268            .rposition(|&b| b == b'\n')
269            .map_or(0, |pos| pos + 1);
270
271        let end = bytes[file_end..]
272            .iter()
273            .position(|&b| b == b'\n')
274            .map_or(self.source.len(), |pos| file_end + pos + 1);
275
276        Span::new(start + self.file_offset, end + self.file_offset)
277    }
278
279    /// Expand a statement span to include its separator (semicolon or newline).
280    /// Uses AST pipeline boundaries - no string parsing needed.
281    ///
282    /// - If there's a next pipeline: span extends to next pipeline's start
283    /// - If last pipeline but has previous: span starts from previous
284    ///   pipeline's end
285    /// - If only pipeline: expand to full line
286    #[must_use]
287    pub fn expand_span_to_statement(&self, span: Span) -> Span {
288        let pipelines = &self.ast.pipelines;
289
290        // Find which pipeline contains this span
291        let Some(idx) = pipelines.iter().position(|p| {
292            p.elements
293                .first()
294                .is_some_and(|e| e.expr.span.start <= span.start)
295                && p.elements
296                    .last()
297                    .is_some_and(|e| e.expr.span.end >= span.end)
298        }) else {
299            return self.expand_span_to_full_lines(span);
300        };
301
302        // If there's a next pipeline, remove from span.start to next pipeline's
303        // start
304        if let Some(next) = pipelines.get(idx + 1)
305            && let Some(first_elem) = next.elements.first()
306        {
307            return Span::new(span.start, first_elem.expr.span.start);
308        }
309
310        // If there's a previous pipeline, remove from previous pipeline's end to
311        // span.end
312        if idx > 0
313            && let Some(prev) = pipelines.get(idx - 1)
314            && let Some(last_elem) = prev.elements.last()
315        {
316            return Span::new(last_elem.expr.span.end, span.end);
317        }
318
319        // Only pipeline - expand to full line
320        self.expand_span_to_full_lines(span)
321    }
322
323    /// Collect detected violations with associated fix data using a closure
324    /// over expressions
325    pub(crate) fn detect_with_fix_data<F, D>(&self, collector: F) -> Vec<(Detection, D)>
326    where
327        F: Fn(&Expression, &Self) -> Vec<(Detection, D)>,
328        D: 'a,
329    {
330        let mut results = Vec::new();
331        let f = |expr: &Expression| collector(expr, self);
332        self.ast.flat_map(self.working_set, &f, &mut results);
333        results
334    }
335
336    /// Collect detected violations without fix data (convenience for rules with
337    /// `FixData = ()`)
338    pub(crate) fn detect<F>(&self, fix_data_collector: F) -> Vec<Detection>
339    where
340        F: Fn(&Expression, &Self) -> Vec<Detection>,
341    {
342        let mut violations = Vec::new();
343        let f = |expr: &Expression| fix_data_collector(expr, self);
344        self.ast.flat_map(self.working_set, &f, &mut violations);
345        violations
346    }
347
348    pub(crate) fn detect_single<F>(&self, detector: F) -> Vec<Detection>
349    where
350        F: Fn(&Expression, &Self) -> Option<Detection>,
351    {
352        let mut violations = Vec::new();
353        let f = |expr: &Expression| {
354            detector(expr, self).map_or_else(Vec::new, |detection| vec![detection])
355        };
356        self.ast.flat_map(self.working_set, &f, &mut violations);
357        violations
358    }
359
360    /// Traverse the AST with parent context, calling the callback for each
361    /// expression with its parent expression (if any).
362    ///
363    /// This builds on top of the `Traverse` trait but adds parent tracking,
364    /// which is useful for rules that need to know the context of an
365    /// expression (e.g., whether a string is in command position).
366    ///
367    /// The callback returns `ControlFlow::Continue(())` to recurse into
368    /// children, or `ControlFlow::Break(())` to skip this expression's
369    /// children.
370    pub(crate) fn traverse_with_parent<F>(&self, mut callback: F)
371    where
372        F: FnMut(&Expression, Option<&Expression>) -> ControlFlow<()>,
373    {
374        use crate::ast::block::BlockExt;
375
376        self.ast.traverse_with_parent(self, None, &mut callback);
377    }
378
379    /// Range of declaration IDs added during parsing: `base..total`
380    #[must_use]
381    pub fn new_decl_range(&self) -> (usize, usize) {
382        let base_count = self.engine_state.num_decls();
383        let total_count = self.working_set.num_decls();
384        (base_count, total_count)
385    }
386
387    /// Collect all function definitions
388    #[must_use]
389    pub fn custom_commands(&self) -> BTreeSet<CustomCommandDef> {
390        let mut functions = Vec::new();
391        self.ast.flat_map(
392            self.working_set,
393            &|expr| {
394                let Expr::Call(call) = &expr.expr else {
395                    return vec![];
396                };
397                call.custom_command_def(self).into_iter().collect()
398            },
399            &mut functions,
400        );
401        functions.into_iter().collect()
402    }
403
404    /// Detect external command invocations with custom validation.
405    /// This allows rules to check if the arguments can be reliably translated
406    /// before reporting a violation.
407    ///
408    /// The validator function receives the command name, fix data, and context,
409    /// and should return `Some(note)` if the invocation should be reported,
410    /// or `None` if it should be ignored.
411    #[must_use]
412    pub fn detect_external_with_validation<'context, F>(
413        &'context self,
414        external_cmd: &'static str,
415        validator: F,
416    ) -> Vec<(Detection, ExternalCmdFixData<'context>)>
417    where
418        F: Fn(&str, &ExternalCmdFixData<'context>, &'context Self) -> Option<&'static str>,
419    {
420        use nu_protocol::ast::{Expr, ExternalArgument, Traverse};
421
422        let mut results = Vec::new();
423
424        self.ast.flat_map(
425            self.working_set,
426            &|expr| {
427                let Expr::ExternalCall(head, args) = &expr.expr else {
428                    return vec![];
429                };
430
431                let cmd_text = self.span_text(head.span);
432                if cmd_text != external_cmd {
433                    return vec![];
434                }
435
436                let arg_exprs: Vec<&Expression> = args
437                    .iter()
438                    .map(|arg| match arg {
439                        ExternalArgument::Regular(expr) | ExternalArgument::Spread(expr) => expr,
440                    })
441                    .collect();
442
443                let fix_data = ExternalCmdFixData {
444                    args: arg_exprs.into_boxed_slice(),
445                    expr_span: expr.span,
446                };
447
448                // Validate if this invocation should be reported
449                let Some(note) = validator(cmd_text, &fix_data, self) else {
450                    return vec![];
451                };
452
453                let detected = Detection::from_global_span(note, expr.span)
454                    .with_primary_label(format!("external '{cmd_text}'"));
455
456                vec![(detected, fix_data)]
457            },
458            &mut results,
459        );
460
461        results
462    }
463}
464
#[cfg(test)]
impl LintContext<'_> {
    /// Parse `source` with stdlib commands loaded and pass the resulting
    /// context to `f`, returning whatever `f` returns.
    ///
    /// The context is always built with the default configuration, so
    /// user-level overrides (e.g. `~/.nu-lint.toml`) cannot leak into test
    /// results.
    #[track_caller]
    pub fn test_with_parsed_source<F, R>(source: &str, f: F) -> R
    where
        F: for<'b> FnOnce(LintContext<'b>) -> R,
    {
        use crate::engine::{LintEngine, parse_source};

        let engine_state = LintEngine::new_state();
        let (block, working_set, file_offset) =
            parse_source(engine_state, source.as_bytes(), None);

        f(LintContext::with_default_config(
            source,
            &block,
            engine_state,
            &working_set,
            file_offset,
        ))
    }

    /// Run `f` over the parsed `source` and return its violations with spans
    /// normalized to file-relative coordinates (matches production behavior).
    #[track_caller]
    pub fn test_get_violations<F>(source: &str, f: F) -> Vec<violation::Violation>
    where
        F: for<'b> FnOnce(&LintContext<'b>) -> Vec<violation::Violation>,
    {
        Self::test_with_parsed_source(source, |context| {
            let offset = context.file_offset();
            let mut found = f(&context);
            found
                .iter_mut()
                .for_each(|item| item.normalize_spans(offset));
            found
        })
    }
}