1use std::collections::BTreeMap;
12
13use harn_hostlib::ast::{api, Language};
14use streaming_iterator::StreamingIterator;
15use tree_sitter::{Query, QueryCursor};
16
17use crate::error::RulesError;
18use crate::model::{AtomicMatcher, Rule};
19use crate::pattern::{compile_pattern, ROOT_CAPTURE};
20
/// Location of a match (or of a single captured node) inside the scanned source.
///
/// Byte offsets index into the source string; rows and columns are zero-based.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Span {
    /// Byte offset of the first byte covered by the span.
    pub start_byte: usize,
    /// Byte offset one past the last byte covered by the span.
    pub end_byte: usize,
    /// Zero-based row on which the span begins.
    pub start_row: usize,
    /// Zero-based column, within `start_row`, at which the span begins.
    pub start_col: usize,
    /// Zero-based row on which the span ends.
    pub end_row: usize,
    /// Zero-based column, within `end_row`, at which the span ends.
    pub end_col: usize,
}
38
39impl Span {
40 fn of(node: tree_sitter::Node<'_>) -> Self {
41 let start = node.start_position();
42 let end = node.end_position();
43 Span {
44 start_byte: node.start_byte(),
45 end_byte: node.end_byte(),
46 start_row: start.row,
47 start_col: start.column,
48 end_row: end.row,
49 end_col: end.column,
50 }
51 }
52}
53
54#[derive(Debug, Clone)]
56pub struct Binding {
57 pub text: String,
59 pub span: Span,
61}
62
63#[derive(Debug, Clone)]
65pub struct RuleMatch {
66 pub rule_id: String,
68 pub span: Span,
70 pub text: String,
72 pub bindings: BTreeMap<String, Binding>,
75}
76
77pub struct CompiledRule {
79 rule_id: String,
80 language: Language,
81 matcher: CompiledMatcher,
82}
83
84enum CompiledMatcher {
85 Query { query: Query, metavars: Vec<String> },
88 Regex(regex::Regex),
90}
91
92impl CompiledRule {
93 pub fn compile(rule: &Rule) -> Result<Self, RulesError> {
95 let language =
96 Language::from_name(&rule.language).ok_or_else(|| RulesError::UnknownLanguage {
97 rule: rule.id.clone(),
98 language: rule.language.clone(),
99 })?;
100
101 let matcher = match rule
102 .rule
103 .resolve()
104 .map_err(|message| RulesError::PatternCompile {
105 rule: rule.id.clone(),
106 message,
107 })? {
108 AtomicMatcher::Pattern(snippet) => {
109 let ts_language =
110 language
111 .ts_language()
112 .ok_or_else(|| RulesError::GrammarUnavailable {
113 rule: rule.id.clone(),
114 language: language.name().to_string(),
115 })?;
116 let compiled = compile_pattern(&snippet, language).map_err(|message| {
117 RulesError::PatternCompile {
118 rule: rule.id.clone(),
119 message,
120 }
121 })?;
122 let query = Query::new(&ts_language, &compiled.query).map_err(|err| {
123 RulesError::QueryRejected {
124 rule: rule.id.clone(),
125 message: err.to_string(),
126 query: compiled.query.clone(),
127 }
128 })?;
129 CompiledMatcher::Query {
130 query,
131 metavars: compiled.metavars,
132 }
133 }
134 AtomicMatcher::Kind(kind) => {
135 let ts_language =
136 language
137 .ts_language()
138 .ok_or_else(|| RulesError::GrammarUnavailable {
139 rule: rule.id.clone(),
140 language: language.name().to_string(),
141 })?;
142 let query_text = format!("({kind}) @{ROOT_CAPTURE}");
143 let query = Query::new(&ts_language, &query_text).map_err(|err| {
144 RulesError::QueryRejected {
145 rule: rule.id.clone(),
146 message: err.to_string(),
147 query: query_text.clone(),
148 }
149 })?;
150 CompiledMatcher::Query {
151 query,
152 metavars: Vec::new(),
153 }
154 }
155 AtomicMatcher::Regex(pattern) => {
156 let regex =
157 regex::Regex::new(&pattern).map_err(|err| RulesError::PatternCompile {
158 rule: rule.id.clone(),
159 message: format!("invalid regex `{pattern}`: {err}"),
160 })?;
161 CompiledMatcher::Regex(regex)
162 }
163 };
164
165 Ok(CompiledRule {
166 rule_id: rule.id.clone(),
167 language,
168 matcher,
169 })
170 }
171
172 pub fn language(&self) -> Language {
174 self.language
175 }
176
177 pub fn run(&self, source: &str) -> Result<Vec<RuleMatch>, RulesError> {
180 match &self.matcher {
181 CompiledMatcher::Query { query, metavars } => self.run_query(query, metavars, source),
182 CompiledMatcher::Regex(regex) => Ok(self.run_regex(regex, source)),
183 }
184 }
185
186 fn run_query(
187 &self,
188 query: &Query,
189 metavars: &[String],
190 source: &str,
191 ) -> Result<Vec<RuleMatch>, RulesError> {
192 let tree =
193 api::parse_tree(source, self.language).map_err(|err| RulesError::SourceParse {
194 rule: self.rule_id.clone(),
195 message: err.to_string(),
196 })?;
197 let names: Vec<&str> = query.capture_names().to_vec();
198 let bytes = source.as_bytes();
199
200 let mut cursor = QueryCursor::new();
201 let mut it = cursor.matches(query, tree.root_node(), bytes);
202 let mut matches = Vec::new();
203 while let Some(m) = it.next() {
204 let mut root: Option<Span> = None;
205 let mut root_text = String::new();
206 let mut bindings: BTreeMap<String, Binding> = BTreeMap::new();
207 for cap in m.captures {
208 let name = names[cap.index as usize];
209 let span = Span::of(cap.node);
210 let text = source[cap.node.start_byte()..cap.node.end_byte()].to_string();
211 if name == ROOT_CAPTURE {
212 root = Some(span);
213 root_text = text;
214 } else if metavars.iter().any(|m| m == name) {
215 bindings
218 .entry(name.to_string())
219 .or_insert(Binding { text, span });
220 }
221 }
222 if let Some(span) = root {
223 matches.push(RuleMatch {
224 rule_id: self.rule_id.clone(),
225 span,
226 text: root_text,
227 bindings,
228 });
229 }
230 }
231 matches.sort_by_key(|m| (m.span.start_byte, m.span.end_byte));
234 Ok(matches)
235 }
236
237 fn run_regex(&self, regex: ®ex::Regex, source: &str) -> Vec<RuleMatch> {
238 let mut matches = Vec::new();
239 for m in regex.find_iter(source) {
240 let span = byte_span(source, m.start(), m.end());
241 matches.push(RuleMatch {
242 rule_id: self.rule_id.clone(),
243 span,
244 text: m.as_str().to_string(),
245 bindings: BTreeMap::new(),
246 });
247 }
248 matches
249 }
250}
251
252fn byte_span(source: &str, start: usize, end: usize) -> Span {
255 let (start_row, start_col) = row_col(source, start);
256 let (end_row, end_col) = row_col(source, end);
257 Span {
258 start_byte: start,
259 end_byte: end,
260 start_row,
261 start_col,
262 end_row,
263 end_col,
264 }
265}
266
/// Converts a byte offset into a zero-based `(row, column)` pair.
///
/// The row is the number of `\n` bytes before `byte`. The column is the number
/// of bytes between the start of that row and `byte`, the same unit
/// tree-sitter uses for its positions, so spans built from regex matches agree
/// with spans built from syntax nodes even on lines containing multi-byte
/// characters.
///
/// An offset past the end of `source` is clamped to the end of the text.
fn row_col(source: &str, byte: usize) -> (usize, usize) {
    let bytes = source.as_bytes();
    let prefix = &bytes[..byte.min(bytes.len())];
    let row = prefix.iter().filter(|&&b| b == b'\n').count();
    // The current row starts just after the last newline before `byte`.
    let line_start = prefix
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);
    (row, prefix.len() - line_start)
}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Rule;

    /// Parses `toml` as a rule and compiles it, panicking if either step fails.
    fn rule(toml: &str) -> CompiledRule {
        let parsed = Rule::from_toml_str(toml).expect("rule parses");
        CompiledRule::compile(&parsed).expect("rule compiles")
    }

    /// Parses `toml` as a rule and returns the error produced by compiling it,
    /// panicking if parsing fails or compilation unexpectedly succeeds.
    fn compile_error(toml: &str) -> RulesError {
        let parsed = Rule::from_toml_str(toml).unwrap();
        match CompiledRule::compile(&parsed) {
            Ok(_) => panic!("rule should fail to compile"),
            Err(err) => err,
        }
    }

    #[test]
    fn pattern_rule_binds_metavars() {
        let compiled = rule(
            r#"
            id = "destructure-default"
            language = "typescript"
            fix = "{ $KEY: $SRC }"
            [rule]
            pattern = "$SRC?.$KEY ?? $DEFAULT"
            "#,
        );
        let source = "const a = cfg?.timeout ?? 30;\nconst b = opts?.retries ?? 3;\n";
        let matches = compiled.run(source).unwrap();
        assert_eq!(matches.len(), 2);

        let (first, second) = (&matches[0], &matches[1]);
        assert_eq!(first.bindings["SRC"].text, "cfg");
        assert_eq!(first.bindings["KEY"].text, "timeout");
        assert_eq!(first.bindings["DEFAULT"].text, "30");
        assert_eq!(second.bindings["SRC"].text, "opts");
        assert_eq!(first.text, "cfg?.timeout ?? 30");
        // Matches come back in source order, one per line here.
        assert_eq!(first.span.start_row, 0);
        assert_eq!(second.span.start_row, 1);
    }

    #[test]
    fn kind_rule_matches_node_kind() {
        let compiled = rule(
            r#"
            id = "find-calls"
            language = "python"
            [rule]
            kind = "call"
            "#,
        );
        let matches = compiled.run("print(x)\nlog(y)\n").unwrap();
        assert_eq!(matches.len(), 2);

        let first = &matches[0];
        assert_eq!(first.text, "print(x)");
        assert!(first.bindings.is_empty());
    }

    #[test]
    fn regex_rule_matches_text() {
        let compiled = rule(
            r#"
            id = "todo"
            language = "rust"
            message = "Found a TODO"
            [rule]
            regex = "TODO\\(\\w+\\)"
            "#,
        );
        let source = "fn f() {\n // TODO(ken) fix\n // todo lower\n}\n";
        let matches = compiled.run(source).unwrap();
        assert_eq!(matches.len(), 1);

        let only = &matches[0];
        assert_eq!(only.text, "TODO(ken)");
        assert_eq!(only.span.start_row, 1);
    }

    #[test]
    fn unknown_language_is_an_error() {
        let err = compile_error(
            r#"
            id = "x"
            language = "cobol"
            [rule]
            kind = "foo"
            "#,
        );
        assert!(matches!(err, RulesError::UnknownLanguage { .. }));
    }

    #[test]
    fn invalid_pattern_surfaces_compile_error() {
        let err = compile_error(
            r#"
            id = "x"
            language = "typescript"
            [rule]
            pattern = "foo($$$ARGS)"
            "#,
        );
        assert!(matches!(err, RulesError::PatternCompile { .. }));
    }
}