1use std::collections::BTreeMap;
48use std::path::{Path, PathBuf};
49use std::sync::Arc;
50
51use harn_hostlib::ast::Language;
52use harn_hostlib::tools::permissions::gated_handler;
53use harn_hostlib::{
54 BuiltinRegistry, HostlibCapability, HostlibError, HostlibRegistry, RegisteredBuiltin,
55};
56use harn_vm::{AsyncBuiltinCtx, Vm, VmError, VmValue};
57
58use harn_rules::{
59 data_table, Applicability, CompiledRule, Diagnostic, Rule, RuleMatch, Safety, Severity,
60 SourceFile, Span,
61};
62
// Names under which this module's builtins are registered. The same constants
// are passed to the helper functions below as the `builtin` label carried by
// the errors they produce.
const SEARCH: &str = "hostlib_rules_search";
const REPORT: &str = "hostlib_rules_report";
const DIAGNOSTICS: &str = "hostlib_rules_diagnostics";
// Registered by `install` as an async builtin rather than through
// `RulesCapability::register_builtins`.
const VISIT: &str = "hostlib_rules_visit";
// Registered through `gated_handler` rather than a bare handler.
const APPLY: &str = "hostlib_rules_apply";
68
69#[derive(Default)]
71pub struct RulesCapability;
72
73impl HostlibCapability for RulesCapability {
74 fn module_name(&self) -> &'static str {
75 "rules"
76 }
77
78 fn register_builtins(&self, registry: &mut BuiltinRegistry) {
79 registry.register(RegisteredBuiltin {
80 name: SEARCH,
81 module: "rules",
82 method: "search",
83 handler: Arc::new(search_run),
84 });
85 registry.register(RegisteredBuiltin {
86 name: REPORT,
87 module: "rules",
88 method: "report",
89 handler: Arc::new(report_run),
90 });
91 registry.register(RegisteredBuiltin {
92 name: DIAGNOSTICS,
93 module: "rules",
94 method: "diagnostics",
95 handler: Arc::new(diagnostics_run),
96 });
97 registry.register(RegisteredBuiltin {
99 name: APPLY,
100 module: "rules",
101 method: "apply",
102 handler: gated_handler(APPLY, apply_run),
103 });
104 }
105}
106
/// Installs every `rules` builtin into `vm`.
///
/// The synchronous builtins are registered through a [`HostlibRegistry`]
/// carrying [`RulesCapability`]; the visitor builtin (`VISIT`) is registered
/// directly on the VM as an async builtin.
pub fn install(vm: &mut Vm) {
    HostlibRegistry::new()
        .with(RulesCapability)
        .register_into_vm(vm);
    // `visit_run` is an `async fn`, so it goes through the VM's async
    // registration instead of the capability's builtin registry.
    vm.register_async_builtin(VISIT, visit_run);
}
118
119fn search_run(args: &[VmValue]) -> Result<VmValue, HostlibError> {
124 let dict = first_dict(SEARCH, args)?;
125 let rule = compile_rule(SEARCH, &dict)?;
126 let files = load_files(SEARCH, &dict)?;
127
128 let mut matches = Vec::new();
129 for file in &files {
130 for m in rule.run(&file.source).map_err(|e| backend(SEARCH, &e))? {
131 matches.push(match_to_vm(&file.path, &m));
132 }
133 }
134 Ok(dict_vm([
135 ("result", str_vm("ok")),
136 ("match_count", VmValue::Int(matches.len() as i64)),
137 ("matches", VmValue::List(Arc::new(matches))),
138 ]))
139}
140
141fn report_run(args: &[VmValue]) -> Result<VmValue, HostlibError> {
142 let dict = first_dict(REPORT, args)?;
143 let rule = compile_rule(REPORT, &dict)?;
144 let files = load_files(REPORT, &dict)?;
145 let table = data_table(&rule, &files).map_err(|e| backend(REPORT, &e))?;
146 Ok(json_to_vm(&table.to_json_value()))
147}
148
149fn diagnostics_run(args: &[VmValue]) -> Result<VmValue, HostlibError> {
150 let dict = first_dict(DIAGNOSTICS, args)?;
151 let rule = compile_rule(DIAGNOSTICS, &dict)?;
152 let files = load_files(DIAGNOSTICS, &dict)?;
153
154 let mut diagnostics = Vec::new();
155 for file in &files {
156 for d in rule
157 .diagnostics(&file.source)
158 .map_err(|e| backend(DIAGNOSTICS, &e))?
159 {
160 diagnostics.push(diagnostic_vm(&file.path, &d));
161 }
162 }
163 Ok(dict_vm([
164 ("result", str_vm("ok")),
165 ("diagnostic_count", VmValue::Int(diagnostics.len() as i64)),
166 ("diagnostics", VmValue::List(Arc::new(diagnostics))),
167 ]))
168}
169
170async fn visit_run(ctx: AsyncBuiltinCtx, args: Vec<VmValue>) -> Result<VmValue, VmError> {
175 let dict = first_dict(VISIT, &args).map_err(host_err)?;
176 let rule = compile_rule(VISIT, &dict).map_err(host_err)?;
177 let files = load_files(VISIT, &dict).map_err(host_err)?;
178 let visitor = match dict.get("on_match") {
179 Some(VmValue::Closure(c)) => c.clone(),
180 _ => {
181 return Err(VmError::Runtime(format!(
182 "{VISIT}: `on_match` must be a function `fn(node, ctx)`"
183 )))
184 }
185 };
186
187 let default_severity = rule.severity();
188 let default_safety = rule.safety();
189 let rule_id = rule.id().to_string();
190
191 let mut vm = ctx.child_vm();
192 let mut diagnostics = Vec::new();
193 for file in &files {
194 let matches = rule
195 .run(&file.source)
196 .map_err(|e| host_err(backend(VISIT, &e)))?;
197 let file_ctx = ctx_vm(&file.path, file.language, &file.source, &rule_id);
198 for m in &matches {
199 let node = node_vm(m);
200 let ret = vm
201 .call_closure_pub(&visitor, &[node, file_ctx.clone()])
202 .await?;
203 ctx.forward_output(&vm.take_output());
204 for report in reports_from_return(ret) {
205 diagnostics.push(report_to_diagnostic_vm(
206 &file.path,
207 &rule_id,
208 m.span,
209 report,
210 default_severity,
211 default_safety,
212 ));
213 }
214 }
215 }
216 Ok(dict_vm([
217 ("result", str_vm("ok")),
218 ("diagnostic_count", VmValue::Int(diagnostics.len() as i64)),
219 ("diagnostics", VmValue::List(Arc::new(diagnostics))),
220 ]))
221}
222
223fn apply_run(args: &[VmValue]) -> Result<VmValue, HostlibError> {
224 let dict = first_dict(APPLY, args)?;
225 let rule = compile_rule(APPLY, &dict)?;
226 let dry_run = optional_bool(&dict, "dry_run", true);
227 let allow_unsafe = optional_bool(&dict, "allow_unsafe", false);
228 let files = load_files(APPLY, &dict)?;
229
230 let auto_applicable = rule.safety().is_auto_applicable();
231 let mut entries = Vec::new();
232 for file in &files {
233 let outcome = rule.apply(&file.source).map_err(|e| backend(APPLY, &e))?;
234 let applied = !dry_run && outcome.changed && (auto_applicable || allow_unsafe);
237 if applied {
238 std::fs::write(&file.path, &outcome.rewritten).map_err(|e| HostlibError::Backend {
239 builtin: APPLY,
240 message: format!("write `{}`: {e}", file.path.display()),
241 })?;
242 }
243 entries.push(dict_vm([
244 ("path", str_vm(file.path.display().to_string())),
245 ("changed", VmValue::Bool(outcome.changed)),
246 ("applied", VmValue::Bool(applied)),
247 ("idempotent", VmValue::Bool(outcome.idempotent)),
248 ("safety", str_vm(format!("{:?}", outcome.safety))),
249 ("before", str_vm(&file.source)),
252 ("preview", str_vm(outcome.rewritten)),
253 ]));
254 }
255 Ok(dict_vm([
256 ("result", str_vm("ok")),
257 ("dry_run", VmValue::Bool(dry_run)),
258 ("auto_applicable", VmValue::Bool(auto_applicable)),
259 ("files", VmValue::List(Arc::new(entries))),
260 ]))
261}
262
263fn compile_rule(
268 builtin: &'static str,
269 dict: &BTreeMap<String, VmValue>,
270) -> Result<CompiledRule, HostlibError> {
271 let toml = require_string(builtin, dict, "rule")?;
272 let rule = Rule::from_toml_str(&toml).map_err(|e| HostlibError::InvalidParameter {
273 builtin,
274 param: "rule",
275 message: format!("invalid rule TOML: {e}"),
276 })?;
277 CompiledRule::compile(&rule).map_err(|e| HostlibError::InvalidParameter {
278 builtin,
279 param: "rule",
280 message: e.to_string(),
281 })
282}
283
284fn load_files(
288 builtin: &'static str,
289 dict: &BTreeMap<String, VmValue>,
290) -> Result<Vec<SourceFile>, HostlibError> {
291 if let Some(source) = optional_string(dict, "source") {
292 let language_name = require_string(builtin, dict, "language")?;
293 let language =
294 Language::from_name(&language_name).ok_or_else(|| HostlibError::InvalidParameter {
295 builtin,
296 param: "language",
297 message: format!("unknown language `{language_name}`"),
298 })?;
299 let path = optional_string(dict, "path").unwrap_or_else(|| "<inline>".to_string());
300 return Ok(vec![SourceFile {
301 path: PathBuf::from(path),
302 language,
303 source,
304 }]);
305 }
306
307 let paths = optional_string_list(dict, "paths");
308 if paths.is_empty() {
309 return Err(HostlibError::MissingParameter {
310 builtin,
311 param: "paths",
312 });
313 }
314 let mut files = Vec::new();
315 for path in paths {
316 let contents = std::fs::read_to_string(&path).map_err(|e| HostlibError::Backend {
317 builtin,
318 message: format!("read `{path}`: {e}"),
319 })?;
320 if let Some(file) = SourceFile::detect(&path, contents) {
321 files.push(file);
322 }
323 }
324 Ok(files)
325}
326
327fn match_to_vm(path: &std::path::Path, m: &RuleMatch) -> VmValue {
328 let captures: BTreeMap<String, VmValue> = m
329 .bindings
330 .iter()
331 .map(|(name, b)| (name.clone(), str_vm(&b.text)))
332 .collect();
333 dict_vm([
334 ("path", str_vm(path.display().to_string())),
335 ("text", str_vm(&m.text)),
336 ("start_row", VmValue::Int(m.span.start_row as i64)),
337 ("start_col", VmValue::Int(m.span.start_col as i64)),
338 ("end_row", VmValue::Int(m.span.end_row as i64)),
339 ("end_col", VmValue::Int(m.span.end_col as i64)),
340 ("captures", VmValue::Dict(Arc::new(captures))),
341 ])
342}
343
344fn backend(builtin: &'static str, err: &harn_rules::RulesError) -> HostlibError {
345 HostlibError::Backend {
346 builtin,
347 message: err.to_string(),
348 }
349}
350
351fn host_err(err: HostlibError) -> VmError {
354 VmError::Runtime(err.to_string())
355}
356
/// One report produced from a visitor callback's return value.
///
/// Every field is optional; `report_to_diagnostic_vm` substitutes an empty
/// message and the rule's default severity/safety for missing values.
#[derive(Default)]
struct ReportSpec {
    // Diagnostic message text, read from the `message` key.
    message: Option<String>,
    // Suggested replacement text, read from the `fix` key.
    fix: Option<String>,
    // Safety override, parsed from the `safety` key by `parse_safety`.
    safety: Option<Safety>,
    // Severity override, parsed from the `severity` key by `parse_severity`.
    severity: Option<Severity>,
}
367
368fn node_vm(m: &RuleMatch) -> VmValue {
371 let captures: BTreeMap<String, VmValue> = m
372 .bindings
373 .iter()
374 .map(|(name, b)| (name.clone(), str_vm(&b.text)))
375 .collect();
376 dict_vm([
377 ("text", str_vm(&m.text)),
378 ("captures", VmValue::Dict(Arc::new(captures))),
379 ("start_row", VmValue::Int(m.span.start_row as i64)),
380 ("start_col", VmValue::Int(m.span.start_col as i64)),
381 ("end_row", VmValue::Int(m.span.end_row as i64)),
382 ("end_col", VmValue::Int(m.span.end_col as i64)),
383 ])
384}
385
386fn ctx_vm(path: &Path, language: Language, source: &str, rule_id: &str) -> VmValue {
389 dict_vm([
390 ("path", str_vm(path.display().to_string())),
391 ("language", str_vm(language.name())),
392 ("source", str_vm(source)),
393 ("rule_id", str_vm(rule_id)),
394 ])
395}
396
397fn diagnostic_dict(
401 path: &Path,
402 rule_id: &str,
403 message: &str,
404 severity: Severity,
405 span: Span,
406 fix: Option<String>,
407 applicability: Applicability,
408) -> VmValue {
409 dict_vm([
410 ("path", str_vm(path.display().to_string())),
411 ("rule_id", str_vm(rule_id)),
412 ("message", str_vm(message)),
413 ("severity", str_vm(severity.as_str())),
414 ("start_row", VmValue::Int(span.start_row as i64)),
415 ("start_col", VmValue::Int(span.start_col as i64)),
416 ("end_row", VmValue::Int(span.end_row as i64)),
417 ("end_col", VmValue::Int(span.end_col as i64)),
418 ("applicability", str_vm(applicability.as_str())),
419 ("fix", fix.map(str_vm).unwrap_or(VmValue::Nil)),
420 ])
421}
422
423fn diagnostic_vm(path: &Path, d: &Diagnostic) -> VmValue {
424 diagnostic_dict(
425 path,
426 &d.rule_id,
427 &d.message,
428 d.severity,
429 d.span,
430 d.fix.clone(),
431 d.applicability,
432 )
433}
434
435fn report_to_diagnostic_vm(
438 path: &Path,
439 rule_id: &str,
440 span: Span,
441 report: ReportSpec,
442 default_severity: Severity,
443 default_safety: Safety,
444) -> VmValue {
445 let severity = report.severity.unwrap_or(default_severity);
446 let safety = report.safety.unwrap_or(default_safety);
447 diagnostic_dict(
448 path,
449 rule_id,
450 report.message.as_deref().unwrap_or(""),
451 severity,
452 span,
453 report.fix,
454 safety.applicability(),
455 )
456}
457
458fn reports_from_return(ret: VmValue) -> Vec<ReportSpec> {
462 match ret {
463 VmValue::Nil | VmValue::Bool(false) => Vec::new(),
464 VmValue::Bool(true) => vec![ReportSpec::default()],
465 VmValue::Dict(d) => vec![report_from_dict(&d)],
466 VmValue::List(items) => items.iter().filter_map(report_from_item).collect(),
467 _ => Vec::new(),
468 }
469}
470
471fn report_from_item(v: &VmValue) -> Option<ReportSpec> {
472 match v {
473 VmValue::Nil | VmValue::Bool(false) => None,
474 VmValue::Bool(true) => Some(ReportSpec::default()),
475 VmValue::Dict(d) => Some(report_from_dict(d)),
476 _ => None,
477 }
478}
479
480fn report_from_dict(d: &BTreeMap<String, VmValue>) -> ReportSpec {
481 ReportSpec {
482 message: optional_string(d, "message"),
483 fix: optional_string(d, "fix"),
484 safety: optional_string(d, "safety").and_then(|s| parse_safety(&s)),
485 severity: optional_string(d, "severity").and_then(|s| parse_severity(&s)),
486 }
487}
488
489fn parse_severity(s: &str) -> Option<Severity> {
490 match s {
491 "info" => Some(Severity::Info),
492 "warning" => Some(Severity::Warning),
493 "error" => Some(Severity::Error),
494 _ => None,
495 }
496}
497
498fn parse_safety(s: &str) -> Option<Safety> {
499 match s {
500 "format-only" => Some(Safety::FormatOnly),
501 "behavior-preserving" => Some(Safety::BehaviorPreserving),
502 "scope-local" => Some(Safety::ScopeLocal),
503 "surface-changing" => Some(Safety::SurfaceChanging),
504 "capability-changing" => Some(Safety::CapabilityChanging),
505 "needs-human" => Some(Safety::NeedsHuman),
506 _ => None,
507 }
508}
509
510fn json_to_vm(value: &serde_json::Value) -> VmValue {
511 match value {
512 serde_json::Value::Null => VmValue::Nil,
513 serde_json::Value::Bool(b) => VmValue::Bool(*b),
514 serde_json::Value::Number(n) => n
515 .as_i64()
516 .map(VmValue::Int)
517 .unwrap_or_else(|| VmValue::Float(n.as_f64().unwrap_or(0.0))),
518 serde_json::Value::String(s) => str_vm(s),
519 serde_json::Value::Array(items) => {
520 VmValue::List(Arc::new(items.iter().map(json_to_vm).collect()))
521 }
522 serde_json::Value::Object(map) => VmValue::Dict(Arc::new(
523 map.iter()
524 .map(|(k, v)| (k.clone(), json_to_vm(v)))
525 .collect(),
526 )),
527 }
528}
529
530fn first_dict(
535 builtin: &'static str,
536 args: &[VmValue],
537) -> Result<Arc<BTreeMap<String, VmValue>>, HostlibError> {
538 match args.first() {
539 Some(VmValue::Dict(dict)) => Ok(dict.clone()),
540 Some(VmValue::Nil) | None => Ok(Arc::new(BTreeMap::new())),
541 Some(_) => Err(HostlibError::InvalidParameter {
542 builtin,
543 param: "params",
544 message: "expected a dict argument".into(),
545 }),
546 }
547}
548
549fn require_string(
550 builtin: &'static str,
551 dict: &BTreeMap<String, VmValue>,
552 key: &'static str,
553) -> Result<String, HostlibError> {
554 match dict.get(key) {
555 Some(VmValue::String(s)) => Ok(s.to_string()),
556 _ => Err(HostlibError::MissingParameter {
557 builtin,
558 param: key,
559 }),
560 }
561}
562
563fn optional_string(dict: &BTreeMap<String, VmValue>, key: &str) -> Option<String> {
564 match dict.get(key) {
565 Some(VmValue::String(s)) => Some(s.to_string()),
566 _ => None,
567 }
568}
569
570fn optional_string_list(dict: &BTreeMap<String, VmValue>, key: &str) -> Vec<String> {
571 match dict.get(key) {
572 Some(VmValue::List(items)) => items
573 .iter()
574 .filter_map(|v| match v {
575 VmValue::String(s) => Some(s.to_string()),
576 _ => None,
577 })
578 .collect(),
579 _ => Vec::new(),
580 }
581}
582
583fn optional_bool(dict: &BTreeMap<String, VmValue>, key: &str, default: bool) -> bool {
584 match dict.get(key) {
585 Some(VmValue::Bool(b)) => *b,
586 _ => default,
587 }
588}
589
590fn str_vm(s: impl AsRef<str>) -> VmValue {
591 VmValue::String(Arc::from(s.as_ref()))
592}
593
594fn dict_vm<const N: usize>(entries: [(&str, VmValue); N]) -> VmValue {
595 let map: BTreeMap<String, VmValue> = entries
596 .into_iter()
597 .map(|(k, v)| (k.to_string(), v))
598 .collect();
599 VmValue::Dict(Arc::new(map))
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `VmValue::Dict` from borrowed key/value pairs.
    fn dict(pairs: &[(&str, VmValue)]) -> VmValue {
        let mut map: BTreeMap<String, VmValue> = BTreeMap::new();
        for (key, value) in pairs {
            map.insert(key.to_string(), value.clone());
        }
        VmValue::Dict(Arc::new(map))
    }

    /// Looks up `key` in a dict value, panicking when it is missing or when
    /// `v` is not a dict.
    fn get<'a>(v: &'a VmValue, key: &str) -> &'a VmValue {
        if let VmValue::Dict(d) = v {
            d.get(key).unwrap_or_else(|| panic!("missing {key}"))
        } else {
            panic!("not a dict")
        }
    }

    /// Looks up `key` in a dict value and returns the elements of the list
    /// stored there, panicking when it is not a list.
    fn list<'a>(v: &'a VmValue, key: &str) -> &'a [VmValue] {
        match get(v, key) {
            VmValue::List(items) => items.as_slice(),
            _ => panic!(),
        }
    }

    /// Unwraps an integer value.
    fn int(v: &VmValue) -> i64 {
        if let VmValue::Int(i) = v {
            *i
        } else {
            panic!("not int: {v:?}")
        }
    }

    /// Unwraps a string value into an owned `String`.
    fn s(v: &VmValue) -> String {
        if let VmValue::String(text) = v {
            text.to_string()
        } else {
            panic!("not string: {v:?}")
        }
    }

    /// Unwraps a boolean value.
    fn b(v: &VmValue) -> bool {
        if let VmValue::Bool(flag) = v {
            *flag
        } else {
            panic!("not bool: {v:?}")
        }
    }

    // Rule matching every zero-argument call, capturing the callee as `FN`.
    const SEARCH_RULE: &str = r#"
        id = "find-calls"
        language = "typescript"
        [rule]
        pattern = "$FN()"
    "#;

    #[test]
    fn search_returns_matches_with_captures() {
        let params = dict(&[
            ("rule", str_vm(SEARCH_RULE)),
            ("source", str_vm("foo();\nbar();\n")),
            ("language", str_vm("typescript")),
        ]);
        let result = search_run(&[params]).unwrap();

        assert_eq!(int(get(&result, "match_count")), 2);
        let matches = list(&result, "matches");
        assert_eq!(s(get(get(&matches[0], "captures"), "FN")), "foo");
    }

    #[test]
    fn report_returns_a_data_table() {
        let params = dict(&[
            ("rule", str_vm(SEARCH_RULE)),
            ("source", str_vm("foo();\nbar();\n")),
            ("language", str_vm("typescript")),
            ("path", str_vm("a.ts")),
        ]);
        let result = report_run(&[params]).unwrap();

        assert_eq!(int(get(get(&result, "summary"), "total_rows")), 2);
        assert_eq!(s(get(&result, "rule_id")), "find-calls");
    }

    #[test]
    fn apply_dry_run_previews_without_writing() {
        let rule = r#"
            id = "rename"
            language = "typescript"
            safety = "behavior-preserving"
            fix = "bar()"
            [rule]
            pattern = "foo()"
        "#;
        let params = dict(&[
            ("rule", str_vm(rule)),
            ("source", str_vm("foo();\n")),
            ("language", str_vm("typescript")),
            ("dry_run", VmValue::Bool(true)),
        ]);
        let result = apply_run(&[params]).unwrap();

        let files = list(&result, "files");
        let first = &files[0];
        assert!(b(get(first, "changed")));
        assert!(!b(get(first, "applied")));
        assert_eq!(s(get(first, "preview")), "bar();\n");
    }

    #[test]
    fn diagnostics_returns_lint_findings() {
        let lint = r#"
            id = "calls"
            language = "typescript"
            message = "function call"
            [rule]
            pattern = "$FN()"
        "#;
        let params = dict(&[
            ("rule", str_vm(lint)),
            ("source", str_vm("foo();\nbar();\n")),
            ("language", str_vm("typescript")),
            ("path", str_vm("a.ts")),
        ]);
        let result = diagnostics_run(&[params]).unwrap();

        assert_eq!(int(get(&result, "diagnostic_count")), 2);
        let diags = list(&result, "diagnostics");
        let first = &diags[0];
        assert_eq!(s(get(first, "message")), "function call");
        assert_eq!(s(get(first, "severity")), "warning");
        assert_eq!(s(get(first, "applicability")), "suggestion");
        assert_eq!(int(get(&diags[1], "start_row")), 1);
        assert!(matches!(get(first, "fix"), VmValue::Nil));
    }

    #[test]
    fn report_helpers_round_trip_severity_and_safety() {
        assert_eq!(parse_severity("error"), Some(Severity::Error));
        assert_eq!(parse_severity("bogus"), None);
        assert_eq!(parse_safety("format-only"), Some(Safety::FormatOnly));
        assert_eq!(parse_safety("needs-human"), Some(Safety::NeedsHuman));
        assert_eq!(parse_safety("nope"), None);

        assert_eq!(reports_from_return(VmValue::Bool(true)).len(), 1);
        assert_eq!(reports_from_return(VmValue::Nil).len(), 0);
        assert_eq!(reports_from_return(VmValue::Bool(false)).len(), 0);

        // The nil element in the middle is skipped.
        let items = vec![
            dict(&[("message", str_vm("a"))]),
            VmValue::Nil,
            dict(&[("message", str_vm("b"))]),
        ];
        let list = VmValue::List(Arc::new(items));
        assert_eq!(reports_from_return(list).len(), 2);
    }

    #[test]
    fn capability_does_not_register_the_async_visitor() {
        let mut registry = BuiltinRegistry::new();
        RulesCapability.register_builtins(&mut registry);
        let names: Vec<_> = registry.iter().map(|b| b.name).collect();
        assert!(!names.contains(&VISIT));
        assert!(names.contains(&DIAGNOSTICS));
    }

    #[test]
    fn missing_rule_is_an_error() {
        let params = dict(&[("source", str_vm("x")), ("language", str_vm("rust"))]);
        let err = search_run(&[params]);
        assert!(matches!(
            err,
            Err(HostlibError::MissingParameter { param: "rule", .. })
        ));
    }

    #[test]
    fn capability_registers_the_sync_builtins() {
        let mut registry = BuiltinRegistry::new();
        RulesCapability.register_builtins(&mut registry);
        let names: Vec<_> = registry.iter().map(|b| b.name).collect();
        assert_eq!(names, vec![SEARCH, REPORT, DIAGNOSTICS, APPLY]);
    }
}