1use std::path::Path;
2
3use nu_protocol::{
4 DeclId, Span,
5 ast::Block,
6 engine::{Command, EngineState, StateWorkingSet},
7};
8
9use crate::{
10 lint::{Fix, Replacement, Severity, Violation},
11 visitor::{AstVisitor, VisitContext},
12};
13
14pub struct LintContext<'a> {
15 pub source: &'a str,
16 pub ast: &'a Block,
17 pub engine_state: &'a EngineState,
18 pub working_set: &'a StateWorkingSet<'a>,
19 pub file_path: Option<&'a Path>,
20}
21
22impl LintContext<'_> {
23 #[must_use]
27 pub fn new_decl_range(&self) -> (usize, usize) {
28 let base_count = self.engine_state.num_decls();
29 let total_count = self.working_set.num_decls();
30 (base_count, total_count)
31 }
32
33 pub fn new_user_functions(&self) -> impl Iterator<Item = (usize, &dyn Command)> + '_ {
36 let (base_count, total_count) = self.new_decl_range();
37 (base_count..total_count)
38 .map(|decl_id| (decl_id, self.working_set.get_decl(DeclId::new(decl_id))))
39 .filter(|(_, decl)| {
40 let name = &decl.signature().name;
41 !name.contains(' ') && !name.starts_with('_')
42 })
43 }
44
45 pub fn find_declaration_span(&self, name: &str) -> Span {
49 if let Some(name_pos) = self.source.find(name) {
50 Span::new(name_pos, name_pos + name.len())
51 } else {
52 self.ast.span.unwrap_or_else(Span::unknown)
53 }
54 }
55
56 pub fn violations_from_regex_if<F>(
70 &self,
71 pattern: ®ex::Regex,
72 rule_id: &str,
73 severity: Severity,
74 predicate: F,
75 ) -> Vec<Violation>
76 where
77 F: Fn(regex::Match) -> Option<(String, Option<String>)>,
78 {
79 pattern
80 .find_iter(self.source)
81 .filter_map(|mat| {
82 predicate(mat).map(|(message, suggestion)| Violation {
83 rule_id: rule_id.to_string(),
84 severity,
85 message,
86 span: Span::new(mat.start(), mat.end()),
87 suggestion,
88 fix: None,
89 file: None,
90 })
91 })
92 .collect()
93 }
94
95 pub fn walk_ast<V: AstVisitor>(&self, visitor: &mut V) {
101 let visit_context = VisitContext::new(self.working_set, self.source);
102
103 visitor.visit_block(self.ast, &visit_context);
105
106 for (_decl_id, decl) in self.new_user_functions() {
108 if let Some(block_id) = decl.block_id() {
109 let block = self.working_set.get_block(block_id);
110 visitor.visit_block(block, &visit_context);
111 }
112 }
113 }
114
115 #[must_use]
117 pub fn get_span_contents(&self, span: Span) -> &str {
118 let start = span.start.min(self.source.len());
119 let end = span.end.min(self.source.len());
120 &self.source[start..end]
121 }
122
123 pub fn create_simple_fix(
127 &self,
128 description: impl Into<String>,
129 span: Span,
130 new_text: impl Into<String>,
131 ) -> Fix {
132 Fix {
133 description: description.into(),
134 replacements: vec![Replacement {
135 span,
136 new_text: new_text.into(),
137 }],
138 }
139 }
140}
141
#[cfg(test)]
impl LintContext<'_> {
    /// Test-only helper that parses `source` and passes a [`LintContext`] to `f`.
    ///
    /// The engine has the default language context plus the shell command set
    /// registered. No file path is attached to the context.
    ///
    /// The engine state, working set and parsed block all live inside this
    /// function. `f` must therefore accept a context of any lifetime, and its
    /// result type `R` cannot borrow from the context.
    pub fn test_with_parsed_source<F, R>(source: &str, f: F) -> R
    where
        F: for<'any> FnOnce(LintContext<'any>) -> R,
    {
        let engine_state =
            nu_command::add_shell_command_context(nu_cmd_lang::create_default_context());
        let mut working_set = nu_protocol::engine::StateWorkingSet::new(&engine_state);
        let block = nu_parser::parse(&mut working_set, None, source.as_bytes(), false);

        f(LintContext {
            source,
            ast: &block,
            engine_state: &engine_state,
            working_set: &working_set,
            file_path: None,
        })
    }
}
171}