1use std::collections::BTreeMap;
12
13use harn_hostlib::ast::Language;
14
15use crate::constraint::CompiledConstraint;
16use crate::error::RulesError;
17use crate::evaluator::CompiledRuleTree;
18use crate::fix::{interpolate, splice, AppliedEdit};
19use crate::model::{Applicability, Rule, Safety, Severity};
20use crate::transform::CompiledTransform;
21
/// Location of a region of source text, expressed both as byte offsets and as
/// row/column positions.
///
/// Spans are produced either from a tree-sitter node (`Span::of`) or from a
/// regex match (`byte_span`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Span {
    /// Byte offset at which the region starts.
    pub start_byte: usize,
    /// Byte offset at which the region ends; a following region may start at
    /// exactly this offset without being treated as overlapping.
    pub end_byte: usize,
    /// Row of the start position (the first line of the source is row 0).
    pub start_row: usize,
    /// Column of the start position within `start_row`.
    pub start_col: usize,
    /// Row of the end position.
    pub end_row: usize,
    /// Column of the end position within `end_row`.
    pub end_col: usize,
}
39
40impl Span {
41 pub(crate) fn of(node: tree_sitter::Node<'_>) -> Self {
42 let start = node.start_position();
43 let end = node.end_position();
44 Span {
45 start_byte: node.start_byte(),
46 end_byte: node.end_byte(),
47 start_row: start.row,
48 start_col: start.column,
49 end_row: end.row,
50 end_col: end.column,
51 }
52 }
53}
54
55#[derive(Debug, Clone)]
57pub struct Binding {
58 pub text: String,
60 pub span: Span,
62}
63
64#[derive(Debug, Clone)]
66pub struct RuleMatch {
67 pub rule_id: String,
69 pub span: Span,
71 pub text: String,
73 pub bindings: BTreeMap<String, Binding>,
76}
77
78#[derive(Debug, Clone)]
80pub struct CodemodResult {
81 pub rewritten: String,
83 pub edits: Vec<AppliedEdit>,
85 pub changed: bool,
87 pub safety: Safety,
89 pub applicability: Applicability,
91 pub idempotent: bool,
94}
95
96pub struct CompiledRule {
98 rule_id: String,
99 language: Language,
100 execution: Execution,
101 constraints: Vec<CompiledConstraint>,
103 transforms: Vec<(String, CompiledTransform)>,
105 fix: Option<String>,
107 safety: Safety,
109 message: String,
111 severity: Severity,
113}
114
115#[derive(Debug, Clone)]
118pub struct Diagnostic {
119 pub rule_id: String,
121 pub message: String,
123 pub severity: Severity,
125 pub span: Span,
127 pub applicability: Applicability,
129 pub fix: Option<String>,
132}
133
134enum Execution {
135 SourceRegex(regex::Regex),
138 Tree(Box<CompiledRuleTree>),
140}
141
142impl CompiledRule {
143 pub fn compile(rule: &Rule) -> Result<Self, RulesError> {
145 let language =
146 Language::from_name(&rule.language).ok_or_else(|| RulesError::UnknownLanguage {
147 rule: rule.id.clone(),
148 language: rule.language.clone(),
149 })?;
150
151 let execution = if rule.rule.is_pure_regex() {
155 let pattern = rule.rule.regex.as_ref().expect("pure regex");
156 Execution::SourceRegex(regex::Regex::new(pattern).map_err(|err| {
157 RulesError::PatternCompile {
158 rule: rule.id.clone(),
159 message: format!("invalid regex `{pattern}`: {err}"),
160 }
161 })?)
162 } else {
163 Execution::Tree(Box::new(CompiledRuleTree::compile(
164 &rule.id,
165 language,
166 &rule.rule,
167 &rule.utils,
168 )?))
169 };
170
171 let constraints = rule
172 .where_constraints
173 .iter()
174 .map(|c| CompiledConstraint::compile(&rule.id, language, c))
175 .collect::<Result<Vec<_>, _>>()?;
176
177 let transforms = rule
178 .transform
179 .iter()
180 .map(|(name, t)| {
181 CompiledTransform::compile(&rule.id, name, t).map(|c| (name.clone(), c))
182 })
183 .collect::<Result<Vec<_>, _>>()?;
184
185 Ok(CompiledRule {
186 rule_id: rule.id.clone(),
187 language,
188 execution,
189 constraints,
190 transforms,
191 fix: rule.fix.clone(),
192 safety: rule.safety,
193 message: rule.message.clone(),
194 severity: rule.severity,
195 })
196 }
197
198 pub fn language(&self) -> Language {
200 self.language
201 }
202
203 pub fn safety(&self) -> Safety {
205 self.safety
206 }
207
208 pub fn applicability(&self) -> Applicability {
211 self.safety.applicability()
212 }
213
214 pub fn id(&self) -> &str {
216 &self.rule_id
217 }
218
219 pub fn severity(&self) -> Severity {
222 self.severity
223 }
224
225 pub fn message(&self) -> &str {
227 &self.message
228 }
229
230 pub fn run(&self, source: &str) -> Result<Vec<RuleMatch>, RulesError> {
233 let mut matches = match &self.execution {
234 Execution::SourceRegex(regex) => self.run_regex(regex, source),
235 Execution::Tree(tree) => tree
236 .find(&self.rule_id, self.language, source)?
237 .into_iter()
238 .map(|m| RuleMatch {
239 rule_id: self.rule_id.clone(),
240 span: m.span,
241 text: m.text,
242 bindings: m.bindings,
243 })
244 .collect(),
245 };
246 if !self.constraints.is_empty() {
247 matches.retain(|m| self.satisfies_constraints(m));
248 }
249 Ok(matches)
250 }
251
252 fn satisfies_constraints(&self, m: &RuleMatch) -> bool {
255 self.constraints.iter().all(|c| {
256 m.bindings
257 .get(&c.metavar)
258 .is_some_and(|b| c.evaluate(&b.text))
259 })
260 }
261
262 pub fn apply(&self, source: &str) -> Result<CodemodResult, RulesError> {
271 let (rewritten, edits) = self.rewrite(source)?;
272 let changed = rewritten != source;
273 let (twice, _) = self.rewrite(&rewritten)?;
276 let idempotent = twice == rewritten;
277 Ok(CodemodResult {
278 rewritten,
279 edits,
280 changed,
281 safety: self.safety,
282 applicability: self.applicability(),
283 idempotent,
284 })
285 }
286
287 pub fn auto_apply(&self, source: &str) -> Result<CodemodResult, RulesError> {
291 if !self.safety.is_auto_applicable() {
292 return Err(RulesError::NotAutoApplicable {
293 rule: self.rule_id.clone(),
294 safety: format!("{:?}", self.safety),
295 });
296 }
297 self.apply(source)
298 }
299
300 pub fn apply_checked(&self, source: &str) -> Result<CodemodResult, RulesError> {
304 let result = self.apply(source)?;
305 if !result.idempotent {
306 return Err(RulesError::NotIdempotent {
307 rule: self.rule_id.clone(),
308 });
309 }
310 Ok(result)
311 }
312
313 pub fn diagnostics(&self, source: &str) -> Result<Vec<Diagnostic>, RulesError> {
318 let applicability = self.applicability();
319 let matches = self.run(source)?;
320 Ok(matches
321 .iter()
322 .map(|m| Diagnostic {
323 rule_id: self.rule_id.clone(),
324 message: self.message.clone(),
325 severity: self.severity,
326 span: m.span,
327 applicability,
328 fix: self.fix.as_ref().map(|template| {
329 let vars = self.metavars_for(m);
330 interpolate(template, &vars)
331 }),
332 })
333 .collect())
334 }
335
336 fn rewrite(&self, source: &str) -> Result<(String, Vec<AppliedEdit>), RulesError> {
339 let template = self
340 .fix
341 .as_ref()
342 .ok_or_else(|| RulesError::PatternCompile {
343 rule: self.rule_id.clone(),
344 message: "apply requires a `fix` template; this rule has none".into(),
345 })?;
346
347 let matches = dedupe_overlapping(self.run(source)?);
348 let edits: Vec<AppliedEdit> = matches
349 .iter()
350 .map(|m| {
351 let vars = self.metavars_for(m);
352 AppliedEdit {
353 span: m.span,
354 before: m.text.clone(),
355 replacement: interpolate(template, &vars),
356 }
357 })
358 .collect();
359 Ok((splice(source, &edits), edits))
360 }
361
362 fn metavars_for(&self, m: &RuleMatch) -> BTreeMap<String, String> {
365 let mut vars: BTreeMap<String, String> = m
366 .bindings
367 .iter()
368 .map(|(name, binding)| (name.clone(), binding.text.clone()))
369 .collect();
370 for (name, transform) in &self.transforms {
371 let input = m
372 .bindings
373 .get(&transform.source)
374 .map(|b| b.text.as_str())
375 .unwrap_or("");
376 vars.insert(name.clone(), transform.apply(input));
377 }
378 vars
379 }
380
381 fn run_regex(&self, regex: ®ex::Regex, source: &str) -> Vec<RuleMatch> {
382 let mut matches = Vec::new();
383 for m in regex.find_iter(source) {
384 let span = byte_span(source, m.start(), m.end());
385 matches.push(RuleMatch {
386 rule_id: self.rule_id.clone(),
387 span,
388 text: m.as_str().to_string(),
389 bindings: BTreeMap::new(),
390 });
391 }
392 matches
393 }
394}
395
396fn dedupe_overlapping(mut matches: Vec<RuleMatch>) -> Vec<RuleMatch> {
404 matches.sort_by(|a, b| {
407 a.span
408 .start_byte
409 .cmp(&b.span.start_byte)
410 .then(b.span.end_byte.cmp(&a.span.end_byte))
411 });
412 let mut kept: Vec<RuleMatch> = Vec::with_capacity(matches.len());
413 let mut covered_to = 0usize; for m in matches {
415 if m.span.start_byte >= covered_to {
419 covered_to = m.span.end_byte.max(covered_to);
420 kept.push(m);
421 }
422 }
423 kept
424}
425
426fn byte_span(source: &str, start: usize, end: usize) -> Span {
429 let (start_row, start_col) = row_col(source, start);
430 let (end_row, end_col) = row_col(source, end);
431 Span {
432 start_byte: start,
433 end_byte: end,
434 start_row,
435 start_col,
436 end_row,
437 end_col,
438 }
439}
440
/// Converts a byte offset into a zero-based `(row, column)` pair.
///
/// The row is the number of `\n` bytes that precede `byte`. The column is the
/// number of bytes between the start of that row and `byte`.
///
/// Columns are measured in bytes, not characters, so that spans built from
/// regex matches use the same unit as spans built by `Span::of`, which copies
/// tree-sitter's byte-based `Point::column`. For ASCII-only lines the two
/// units coincide.
///
/// Offsets past the end of `source` are clamped to `source.len()`.
fn row_col(source: &str, byte: usize) -> (usize, usize) {
    let end = byte.min(source.len());
    // Work on raw bytes: `\n` is never part of a multi-byte UTF-8 sequence, and
    // this avoids panicking if `byte` is not on a character boundary.
    let prefix = &source.as_bytes()[..end];
    let row = prefix.iter().filter(|&&b| b == b'\n').count();
    let line_start = prefix
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);
    (row, end - line_start)
}
457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Rule;

    /// Parses `toml` as a rule and compiles it, panicking if either step fails.
    fn rule(toml: &str) -> CompiledRule {
        let model = Rule::from_toml_str(toml).expect("rule parses");
        CompiledRule::compile(&model).expect("rule compiles")
    }

    #[test]
    fn pattern_rule_binds_metavars() {
        let subject = rule(
            r#"
            id = "destructure-default"
            language = "typescript"
            fix = "{ $KEY: $SRC }"
            [rule]
            pattern = "$SRC?.$KEY ?? $DEFAULT"
            "#,
        );
        let found = subject
            .run("const a = cfg?.timeout ?? 30;\nconst b = opts?.retries ?? 3;\n")
            .unwrap();
        assert_eq!(found.len(), 2);

        // First occurrence: every metavariable is bound to its source text.
        let first = &found[0];
        assert_eq!(first.bindings["SRC"].text, "cfg");
        assert_eq!(first.bindings["KEY"].text, "timeout");
        assert_eq!(first.bindings["DEFAULT"].text, "30");

        let second = &found[1];
        assert_eq!(second.bindings["SRC"].text, "opts");

        assert_eq!(first.text, "cfg?.timeout ?? 30");
        // Each statement sits on its own line.
        assert_eq!(first.span.start_row, 0);
        assert_eq!(second.span.start_row, 1);
    }

    #[test]
    fn nested_matches_do_not_corrupt_or_panic_on_apply() {
        let subject = rule(
            r#"
            id = "sum-binop"
            language = "typescript"
            fix = "sum($X, $Y)"
            [rule]
            pattern = "$X + $Y"
            "#,
        );
        let input = "const z = a + b + c;\n";

        // `a + b + c` contains a nested `a + b`, so matching reports both.
        assert!(subject.run(input).unwrap().len() >= 2);

        // Applying the fix must rewrite only the outermost match.
        let outcome = subject.apply(input).unwrap();
        assert_eq!(outcome.rewritten, "const z = sum(a + b, c);\n");
        assert_eq!(outcome.edits.len(), 1);
        assert!(outcome.changed);
    }

    #[test]
    fn dedupe_overlapping_keeps_outermost_in_document_order() {
        // Builds a binding-free match covering bytes `s..e` on a single row.
        let at = |s: usize, e: usize| RuleMatch {
            rule_id: "r".into(),
            span: Span {
                start_byte: s,
                end_byte: e,
                start_row: 0,
                start_col: s,
                end_row: 0,
                end_col: e,
            },
            text: String::new(),
            bindings: BTreeMap::new(),
        };

        let survivors = dedupe_overlapping(vec![at(0, 5), at(0, 9), at(10, 14)]);
        let ranges: Vec<_> = survivors
            .iter()
            .map(|kept| (kept.span.start_byte, kept.span.end_byte))
            .collect();
        assert_eq!(ranges, vec![(0, 9), (10, 14)]);
    }

    #[test]
    fn kind_rule_matches_node_kind() {
        let subject = rule(
            r#"
            id = "find-calls"
            language = "python"
            [rule]
            kind = "call"
            "#,
        );
        let found = subject.run("print(x)\nlog(y)\n").unwrap();
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].text, "print(x)");
        // A `kind` rule has no metavariables to bind.
        assert!(found[0].bindings.is_empty());
    }

    #[test]
    fn regex_rule_matches_text() {
        let subject = rule(
            r#"
            id = "todo"
            language = "rust"
            message = "Found a TODO"
            [rule]
            regex = "TODO\\(\\w+\\)"
            "#,
        );
        let found = subject
            .run("fn f() {\n // TODO(ken) fix\n // todo lower\n}\n")
            .unwrap();
        // The lowercase `todo` without an owner must not match.
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].text, "TODO(ken)");
        assert_eq!(found[0].span.start_row, 1);
    }

    #[test]
    fn unknown_language_is_an_error() {
        let model = Rule::from_toml_str(
            r#"
            id = "x"
            language = "cobol"
            [rule]
            kind = "foo"
            "#,
        )
        .unwrap();
        let outcome = CompiledRule::compile(&model);
        assert!(matches!(outcome, Err(RulesError::UnknownLanguage { .. })));
    }

    #[test]
    fn invalid_pattern_surfaces_compile_error() {
        let model = Rule::from_toml_str(
            r#"
            id = "x"
            language = "typescript"
            [rule]
            pattern = "foo($$$ARGS)"
            "#,
        )
        .unwrap();
        let outcome = CompiledRule::compile(&model);
        assert!(matches!(outcome, Err(RulesError::PatternCompile { .. })));
    }
}