1use std::collections::BTreeMap;
12
13use harn_hostlib::ast::Language;
14
15use crate::constraint::CompiledConstraint;
16use crate::error::RulesError;
17use crate::evaluator::CompiledRuleTree;
18use crate::fix::{interpolate, splice, AppliedEdit};
19use crate::model::{Applicability, Rule, Safety, Severity};
20use crate::transform::CompiledTransform;
21
/// A located region of source text.
///
/// `start_byte`/`end_byte` are offsets into the source string. The end is
/// exclusive, as produced both by tree-sitter's `end_byte()` and by
/// `regex::Match::end()`. Rows and columns are zero-based. For tree-based
/// matches they are copied from tree-sitter's `Point`s. For regex matches they
/// are derived by `row_col`; see that function for the column unit it uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Span {
    /// Byte offset of the first byte of the region.
    pub start_byte: usize,
    /// Byte offset one past the last byte of the region.
    pub end_byte: usize,
    /// Zero-based line of `start_byte`.
    pub start_row: usize,
    /// Zero-based column of `start_byte` within `start_row`.
    pub start_col: usize,
    /// Zero-based line of `end_byte`.
    pub end_row: usize,
    /// Zero-based column of `end_byte` within `end_row`.
    pub end_col: usize,
}
39
40impl Span {
41 pub(crate) fn of(node: tree_sitter::Node<'_>) -> Self {
42 let start = node.start_position();
43 let end = node.end_position();
44 Span {
45 start_byte: node.start_byte(),
46 end_byte: node.end_byte(),
47 start_row: start.row,
48 start_col: start.column,
49 end_row: end.row,
50 end_col: end.column,
51 }
52 }
53}
54
/// The text and location captured for one metavariable (e.g. `$SRC`) in a match.
#[derive(Debug, Clone)]
pub struct Binding {
    /// Source text the metavariable matched.
    pub text: String,
    /// Where in the source that text was found.
    pub span: Span,
}
63
/// A single match produced by running a [`CompiledRule`] over some source.
#[derive(Debug, Clone)]
pub struct RuleMatch {
    /// Id of the rule that produced this match.
    pub rule_id: String,
    /// Location of the whole matched region.
    pub span: Span,
    /// Source text of the whole matched region.
    pub text: String,
    /// Captured metavariables, keyed by name without the leading `$`
    /// (e.g. `"SRC"` for `$SRC`). Empty for regex rules, which never bind
    /// metavariables.
    pub bindings: BTreeMap<String, Binding>,
}
77
/// Outcome of applying a rule's `fix` template to a source file.
#[derive(Debug, Clone)]
pub struct CodemodResult {
    /// The source after every edit has been spliced in.
    pub rewritten: String,
    /// The edits that were applied, one per match.
    pub edits: Vec<AppliedEdit>,
    /// `true` when `rewritten` differs from the input source.
    pub changed: bool,
    /// Safety level copied from the rule.
    pub safety: Safety,
    /// Applicability derived from the rule's safety level.
    pub applicability: Applicability,
    /// `true` when applying the rule again to `rewritten` leaves it unchanged.
    pub idempotent: bool,
}
95
/// A [`Rule`] compiled into an executable matcher plus its fix metadata.
pub struct CompiledRule {
    /// The rule's id, copied onto every match and diagnostic.
    rule_id: String,
    /// Language the rule targets.
    language: Language,
    /// How matches are found: raw regex over the source, or a rule tree.
    execution: Execution,
    /// Compiled `where` constraints; a match is kept only if all are satisfied.
    constraints: Vec<CompiledConstraint>,
    /// Named transforms that derive extra template variables, in compile order.
    transforms: Vec<(String, CompiledTransform)>,
    /// Optional fix template, interpolated with metavariables per match.
    fix: Option<String>,
    /// Safety level governing whether the fix may be applied automatically.
    safety: Safety,
    /// Message reported on each diagnostic.
    message: String,
    /// Severity reported on each diagnostic.
    severity: Severity,
}
114
/// A finding reported for one match of a rule.
#[derive(Debug, Clone)]
pub struct Diagnostic {
    /// Id of the rule that produced the finding.
    pub rule_id: String,
    /// The rule's message.
    pub message: String,
    /// The rule's severity.
    pub severity: Severity,
    /// Location of the matched region.
    pub span: Span,
    /// Applicability derived from the rule's safety level.
    pub applicability: Applicability,
    /// The fix template interpolated for this match, or `None` when the rule
    /// has no fix template.
    pub fix: Option<String>,
}
133
/// Strategy used to find matches for a compiled rule.
enum Execution {
    /// The rule is a pure regex, run directly over the raw source text.
    SourceRegex(regex::Regex),
    /// Any other rule, evaluated against the parsed syntax tree.
    /// NOTE(review): presumably boxed to keep the enum small; confirm.
    Tree(Box<CompiledRuleTree>),
}
141
142impl CompiledRule {
143 pub fn compile(rule: &Rule) -> Result<Self, RulesError> {
145 let language =
146 Language::from_name(&rule.language).ok_or_else(|| RulesError::UnknownLanguage {
147 rule: rule.id.clone(),
148 language: rule.language.clone(),
149 })?;
150
151 let execution = if rule.rule.is_pure_regex() {
155 let pattern = rule.rule.regex.as_ref().expect("pure regex");
156 Execution::SourceRegex(regex::Regex::new(pattern).map_err(|err| {
157 RulesError::PatternCompile {
158 rule: rule.id.clone(),
159 message: format!("invalid regex `{pattern}`: {err}"),
160 }
161 })?)
162 } else {
163 Execution::Tree(Box::new(CompiledRuleTree::compile(
164 &rule.id,
165 language,
166 &rule.rule,
167 &rule.utils,
168 )?))
169 };
170
171 let constraints = rule
172 .where_constraints
173 .iter()
174 .map(|c| CompiledConstraint::compile(&rule.id, language, c))
175 .collect::<Result<Vec<_>, _>>()?;
176
177 let transforms = rule
178 .transform
179 .iter()
180 .map(|(name, t)| {
181 CompiledTransform::compile(&rule.id, name, t).map(|c| (name.clone(), c))
182 })
183 .collect::<Result<Vec<_>, _>>()?;
184
185 Ok(CompiledRule {
186 rule_id: rule.id.clone(),
187 language,
188 execution,
189 constraints,
190 transforms,
191 fix: rule.fix.clone(),
192 safety: rule.safety,
193 message: rule.message.clone(),
194 severity: rule.severity,
195 })
196 }
197
198 pub fn language(&self) -> Language {
200 self.language
201 }
202
203 pub fn safety(&self) -> Safety {
205 self.safety
206 }
207
208 pub fn applicability(&self) -> Applicability {
211 self.safety.applicability()
212 }
213
214 pub fn id(&self) -> &str {
216 &self.rule_id
217 }
218
219 pub fn run(&self, source: &str) -> Result<Vec<RuleMatch>, RulesError> {
222 let mut matches = match &self.execution {
223 Execution::SourceRegex(regex) => self.run_regex(regex, source),
224 Execution::Tree(tree) => tree
225 .find(&self.rule_id, self.language, source)?
226 .into_iter()
227 .map(|m| RuleMatch {
228 rule_id: self.rule_id.clone(),
229 span: m.span,
230 text: m.text,
231 bindings: m.bindings,
232 })
233 .collect(),
234 };
235 if !self.constraints.is_empty() {
236 matches.retain(|m| self.satisfies_constraints(m));
237 }
238 Ok(matches)
239 }
240
241 fn satisfies_constraints(&self, m: &RuleMatch) -> bool {
244 self.constraints.iter().all(|c| {
245 m.bindings
246 .get(&c.metavar)
247 .is_some_and(|b| c.evaluate(&b.text))
248 })
249 }
250
251 pub fn apply(&self, source: &str) -> Result<CodemodResult, RulesError> {
260 let (rewritten, edits) = self.rewrite(source)?;
261 let changed = rewritten != source;
262 let (twice, _) = self.rewrite(&rewritten)?;
265 let idempotent = twice == rewritten;
266 Ok(CodemodResult {
267 rewritten,
268 edits,
269 changed,
270 safety: self.safety,
271 applicability: self.applicability(),
272 idempotent,
273 })
274 }
275
276 pub fn auto_apply(&self, source: &str) -> Result<CodemodResult, RulesError> {
280 if !self.safety.is_auto_applicable() {
281 return Err(RulesError::NotAutoApplicable {
282 rule: self.rule_id.clone(),
283 safety: format!("{:?}", self.safety),
284 });
285 }
286 self.apply(source)
287 }
288
289 pub fn apply_checked(&self, source: &str) -> Result<CodemodResult, RulesError> {
293 let result = self.apply(source)?;
294 if !result.idempotent {
295 return Err(RulesError::NotIdempotent {
296 rule: self.rule_id.clone(),
297 });
298 }
299 Ok(result)
300 }
301
302 pub fn diagnostics(&self, source: &str) -> Result<Vec<Diagnostic>, RulesError> {
307 let applicability = self.applicability();
308 let matches = self.run(source)?;
309 Ok(matches
310 .iter()
311 .map(|m| Diagnostic {
312 rule_id: self.rule_id.clone(),
313 message: self.message.clone(),
314 severity: self.severity,
315 span: m.span,
316 applicability,
317 fix: self.fix.as_ref().map(|template| {
318 let vars = self.metavars_for(m);
319 interpolate(template, &vars)
320 }),
321 })
322 .collect())
323 }
324
325 fn rewrite(&self, source: &str) -> Result<(String, Vec<AppliedEdit>), RulesError> {
328 let template = self
329 .fix
330 .as_ref()
331 .ok_or_else(|| RulesError::PatternCompile {
332 rule: self.rule_id.clone(),
333 message: "apply requires a `fix` template; this rule has none".into(),
334 })?;
335
336 let matches = self.run(source)?;
337 let edits: Vec<AppliedEdit> = matches
338 .iter()
339 .map(|m| {
340 let vars = self.metavars_for(m);
341 AppliedEdit {
342 span: m.span,
343 before: m.text.clone(),
344 replacement: interpolate(template, &vars),
345 }
346 })
347 .collect();
348 Ok((splice(source, &edits), edits))
349 }
350
351 fn metavars_for(&self, m: &RuleMatch) -> BTreeMap<String, String> {
354 let mut vars: BTreeMap<String, String> = m
355 .bindings
356 .iter()
357 .map(|(name, binding)| (name.clone(), binding.text.clone()))
358 .collect();
359 for (name, transform) in &self.transforms {
360 let input = m
361 .bindings
362 .get(&transform.source)
363 .map(|b| b.text.as_str())
364 .unwrap_or("");
365 vars.insert(name.clone(), transform.apply(input));
366 }
367 vars
368 }
369
370 fn run_regex(&self, regex: ®ex::Regex, source: &str) -> Vec<RuleMatch> {
371 let mut matches = Vec::new();
372 for m in regex.find_iter(source) {
373 let span = byte_span(source, m.start(), m.end());
374 matches.push(RuleMatch {
375 rule_id: self.rule_id.clone(),
376 span,
377 text: m.as_str().to_string(),
378 bindings: BTreeMap::new(),
379 });
380 }
381 matches
382 }
383}
384
385fn byte_span(source: &str, start: usize, end: usize) -> Span {
388 let (start_row, start_col) = row_col(source, start);
389 let (end_row, end_col) = row_col(source, end);
390 Span {
391 start_byte: start,
392 end_byte: end,
393 start_row,
394 start_col,
395 end_row,
396 end_col,
397 }
398}
399
/// Converts a byte offset in `source` into a zero-based `(row, column)` pair.
///
/// The column is measured in bytes from the start of the line. This matches
/// the `column` of tree-sitter's `Point`, which `Span::of` copies, so spans
/// from regex rules and tree rules agree even on lines with non-ASCII text.
/// Offsets past the end of `source` are clamped to its length.
fn row_col(source: &str, byte: usize) -> (usize, usize) {
    // Scan bytes rather than chars. `\n` never occurs inside a multi-byte UTF-8
    // sequence, and slicing a byte slice cannot panic on a non-char boundary.
    let prefix = &source.as_bytes()[..byte.min(source.len())];
    let row = prefix.iter().filter(|&&b| b == b'\n').count();
    let line_start = prefix
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);
    (row, prefix.len() - line_start)
}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Rule;

    /// Parses and compiles a rule from inline TOML, panicking on failure.
    fn rule(toml: &str) -> CompiledRule {
        let parsed = Rule::from_toml_str(toml).expect("rule parses");
        CompiledRule::compile(&parsed).expect("rule compiles")
    }

    #[test]
    fn pattern_rule_binds_metavars() {
        let compiled = rule(
            r#"
            id = "destructure-default"
            language = "typescript"
            fix = "{ $KEY: $SRC }"
            [rule]
            pattern = "$SRC?.$KEY ?? $DEFAULT"
            "#,
        );
        let matches = compiled
            .run("const a = cfg?.timeout ?? 30;\nconst b = opts?.retries ?? 3;\n")
            .unwrap();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].bindings["SRC"].text, "cfg");
        assert_eq!(matches[0].bindings["KEY"].text, "timeout");
        assert_eq!(matches[0].bindings["DEFAULT"].text, "30");
        assert_eq!(matches[1].bindings["SRC"].text, "opts");
        assert_eq!(matches[0].text, "cfg?.timeout ?? 30");
        // Each match is reported on its own line.
        assert_eq!(matches[0].span.start_row, 0);
        assert_eq!(matches[1].span.start_row, 1);
    }

    #[test]
    fn kind_rule_matches_node_kind() {
        let compiled = rule(
            r#"
            id = "find-calls"
            language = "python"
            [rule]
            kind = "call"
            "#,
        );
        let matches = compiled.run("print(x)\nlog(y)\n").unwrap();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].text, "print(x)");
        assert!(matches[0].bindings.is_empty());
    }

    #[test]
    fn regex_rule_matches_text() {
        let compiled = rule(
            r#"
            id = "todo"
            language = "rust"
            message = "Found a TODO"
            [rule]
            regex = "TODO\\(\\w+\\)"
            "#,
        );
        let matches = compiled
            .run("fn f() {\n // TODO(ken) fix\n // todo lower\n}\n")
            .unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].text, "TODO(ken)");
        assert_eq!(matches[0].span.start_row, 1);
    }

    #[test]
    fn diagnostics_carry_rule_message_and_span() {
        let compiled = rule(
            r#"
            id = "todo"
            language = "rust"
            message = "Found a TODO"
            [rule]
            regex = "TODO\\(\\w+\\)"
            "#,
        );
        let diags = compiled
            .diagnostics("fn f() {\n // TODO(ken) fix\n}\n")
            .unwrap();
        assert_eq!(diags.len(), 1);
        assert_eq!(diags[0].rule_id, "todo");
        assert_eq!(diags[0].message, "Found a TODO");
        assert_eq!(diags[0].span.start_row, 1);
        // No `fix` template, so no interpolated fix.
        assert!(diags[0].fix.is_none());
    }

    #[test]
    fn apply_without_fix_template_is_an_error() {
        let compiled = rule(
            r#"
            id = "find-calls"
            language = "python"
            [rule]
            kind = "call"
            "#,
        );
        assert!(matches!(
            compiled.apply("print(x)\n"),
            Err(RulesError::PatternCompile { .. })
        ));
    }

    #[test]
    fn unknown_language_is_an_error() {
        let parsed = Rule::from_toml_str(
            r#"
            id = "x"
            language = "cobol"
            [rule]
            kind = "foo"
            "#,
        )
        .unwrap();
        assert!(matches!(
            CompiledRule::compile(&parsed),
            Err(RulesError::UnknownLanguage { .. })
        ));
    }

    #[test]
    fn invalid_pattern_surfaces_compile_error() {
        let parsed = Rule::from_toml_str(
            r#"
            id = "x"
            language = "typescript"
            [rule]
            pattern = "foo($$$ARGS)"
            "#,
        )
        .unwrap();
        assert!(matches!(
            CompiledRule::compile(&parsed),
            Err(RulesError::PatternCompile { .. })
        ));
    }
}