1use std::collections::BTreeMap;
48use std::path::{Path, PathBuf};
49use std::sync::Arc;
50
51use harn_hostlib::ast::Language;
52use harn_hostlib::tools::permissions::gated_handler;
53use harn_hostlib::{
54 BuiltinRegistry, HostlibCapability, HostlibError, HostlibRegistry, RegisteredBuiltin,
55};
56use harn_vm::{AsyncBuiltinCtx, Vm, VmError, VmValue};
57
58use harn_rules::{
59 data_table, Applicability, BindingMetadata, CompiledRule, Diagnostic, ResolvedBinding, Rule,
60 RuleMatch, Safety, Severity, SourceFile, Span,
61};
62
/// Builtin name registered for the `rules` module's `search` method.
const SEARCH: &str = "hostlib_rules_search";
/// Builtin name registered for the `rules` module's `report` method.
const REPORT: &str = "hostlib_rules_report";
/// Builtin name registered for the `rules` module's `diagnostics` method.
const DIAGNOSTICS: &str = "hostlib_rules_diagnostics";
/// Builtin name of the async visitor. `install` registers it directly on the
/// VM; `RulesCapability` does not register it.
const VISIT: &str = "hostlib_rules_visit";
/// Builtin name registered for the `rules` module's `apply` method.
const APPLY: &str = "hostlib_rules_apply";
/// Builtin name registered for the `rules` module's `fold` method.
const FOLD: &str = "hostlib_rules_fold";
/// Builtin name registered for the `lint` module's `run` method.
const LINT_RUN: &str = "hostlib_lint_run";
70
71#[derive(Default)]
73pub struct RulesCapability;
74
75impl HostlibCapability for RulesCapability {
76 fn module_name(&self) -> &'static str {
77 "rules"
78 }
79
80 fn register_builtins(&self, registry: &mut BuiltinRegistry) {
81 registry.register(RegisteredBuiltin {
82 name: SEARCH,
83 module: "rules",
84 method: "search",
85 handler: Arc::new(search_run),
86 });
87 registry.register(RegisteredBuiltin {
88 name: REPORT,
89 module: "rules",
90 method: "report",
91 handler: Arc::new(report_run),
92 });
93 registry.register(RegisteredBuiltin {
94 name: DIAGNOSTICS,
95 module: "rules",
96 method: "diagnostics",
97 handler: Arc::new(diagnostics_run),
98 });
99 registry.register(RegisteredBuiltin {
101 name: APPLY,
102 module: "rules",
103 method: "apply",
104 handler: gated_handler(APPLY, apply_run),
105 });
106 registry.register(RegisteredBuiltin {
108 name: FOLD,
109 module: "rules",
110 method: "fold",
111 handler: gated_handler(FOLD, fold_run),
112 });
113 }
114}
115
116#[derive(Default)]
119pub struct LintCapability;
120
121impl HostlibCapability for LintCapability {
122 fn module_name(&self) -> &'static str {
123 "lint"
124 }
125
126 fn register_builtins(&self, registry: &mut BuiltinRegistry) {
127 registry.register(RegisteredBuiltin {
129 name: LINT_RUN,
130 module: "lint",
131 method: "run",
132 handler: Arc::new(lint_run),
133 });
134 }
135}
136
/// Installs the rules and lint builtins into `vm`.
///
/// The synchronous builtins of `RulesCapability` and `LintCapability` are
/// registered through a `HostlibRegistry`; the visitor builtin is registered
/// separately because it is an async builtin.
pub fn install(vm: &mut Vm) {
    HostlibRegistry::new()
        .with(RulesCapability)
        .with(LintCapability)
        .register_into_vm(vm);
    // `visit_run` is an `async fn`, so it uses the VM's async registration path.
    vm.register_async_builtin(VISIT, visit_run);
}
149
150fn search_run(args: &[VmValue]) -> Result<VmValue, HostlibError> {
155 let dict = first_dict(SEARCH, args)?;
156 let rule = compile_rule(SEARCH, &dict)?;
157 let files = load_files(SEARCH, &dict)?;
158
159 let mut matches = Vec::new();
160 for file in &files {
161 for m in rule.run(&file.source).map_err(|e| backend(SEARCH, &e))? {
162 matches.push(match_to_vm(&file.path, &m));
163 }
164 }
165 Ok(dict_vm([
166 ("result", str_vm("ok")),
167 ("match_count", VmValue::Int(matches.len() as i64)),
168 ("matches", VmValue::List(Arc::new(matches))),
169 ]))
170}
171
172fn report_run(args: &[VmValue]) -> Result<VmValue, HostlibError> {
173 let dict = first_dict(REPORT, args)?;
174 let rule = compile_rule(REPORT, &dict)?;
175 let files = load_files(REPORT, &dict)?;
176 let table = data_table(&rule, &files).map_err(|e| backend(REPORT, &e))?;
177 Ok(json_to_vm(&table.to_json_value()))
178}
179
180fn diagnostics_run(args: &[VmValue]) -> Result<VmValue, HostlibError> {
181 let dict = first_dict(DIAGNOSTICS, args)?;
182 let rule = compile_rule(DIAGNOSTICS, &dict)?;
183 let files = load_files(DIAGNOSTICS, &dict)?;
184
185 let mut diagnostics = Vec::new();
186 for file in &files {
187 for d in rule
188 .diagnostics(&file.source)
189 .map_err(|e| backend(DIAGNOSTICS, &e))?
190 {
191 diagnostics.push(diagnostic_vm(&file.path, &d));
192 }
193 }
194 Ok(dict_vm([
195 ("result", str_vm("ok")),
196 ("diagnostic_count", VmValue::Int(diagnostics.len() as i64)),
197 ("diagnostics", VmValue::List(Arc::new(diagnostics))),
198 ]))
199}
200
201async fn visit_run(ctx: AsyncBuiltinCtx, args: Vec<VmValue>) -> Result<VmValue, VmError> {
206 let dict = first_dict(VISIT, &args).map_err(host_err)?;
207 let rule = compile_rule(VISIT, &dict).map_err(host_err)?;
208 let files = load_files(VISIT, &dict).map_err(host_err)?;
209 let visitor = match dict.get("on_match") {
210 Some(VmValue::Closure(c)) => c.clone(),
211 _ => {
212 return Err(VmError::Runtime(format!(
213 "{VISIT}: `on_match` must be a function `fn(node, ctx)`"
214 )))
215 }
216 };
217
218 let default_severity = rule.severity();
219 let default_safety = rule.safety();
220 let rule_id = rule.id().to_string();
221
222 let mut vm = ctx.child_vm();
223 let mut diagnostics = Vec::new();
224 for file in &files {
225 let matches = rule
226 .run(&file.source)
227 .map_err(|e| host_err(backend(VISIT, &e)))?;
228 let file_ctx = ctx_vm(&file.path, file.language, &file.source, &rule_id);
229 for m in &matches {
230 let node = node_vm(m);
231 let ret = vm
232 .call_closure_pub(&visitor, &[node, file_ctx.clone()])
233 .await?;
234 ctx.forward_output(&vm.take_output());
235 for report in reports_from_return(ret) {
236 diagnostics.push(report_to_diagnostic_vm(
237 &file.path,
238 &rule_id,
239 m.span,
240 report,
241 default_severity,
242 default_safety,
243 ));
244 }
245 }
246 }
247 Ok(dict_vm([
248 ("result", str_vm("ok")),
249 ("diagnostic_count", VmValue::Int(diagnostics.len() as i64)),
250 ("diagnostics", VmValue::List(Arc::new(diagnostics))),
251 ]))
252}
253
254fn apply_run(args: &[VmValue]) -> Result<VmValue, HostlibError> {
255 let dict = first_dict(APPLY, args)?;
256 let rule = compile_rule(APPLY, &dict)?;
257 let dry_run = optional_bool(&dict, "dry_run", true);
258 let allow_unsafe = optional_bool(&dict, "allow_unsafe", false);
259 let format = optional_bool(&dict, "format", true);
262 let files = load_files(APPLY, &dict)?;
263
264 let auto_applicable = rule.safety().is_auto_applicable();
265 let mut entries = Vec::new();
266 for file in &files {
267 let outcome = rule.apply(&file.source).map_err(|e| backend(APPLY, &e))?;
268 let formatted = format && outcome.changed && file.language == Language::Harn;
272 let rewritten = if formatted {
273 match harn_fmt::format_source(&outcome.rewritten) {
274 Ok(canonical) => canonical,
275 Err(_) => outcome.rewritten,
276 }
277 } else {
278 outcome.rewritten
279 };
280 let applied = !dry_run && outcome.changed && (auto_applicable || allow_unsafe);
283 if applied {
284 std::fs::write(&file.path, &rewritten).map_err(|e| HostlibError::Backend {
285 builtin: APPLY,
286 message: format!("write `{}`: {e}", file.path.display()),
287 })?;
288 }
289 entries.push(dict_vm([
290 ("path", str_vm(file.path.display().to_string())),
291 ("changed", VmValue::Bool(outcome.changed)),
292 ("applied", VmValue::Bool(applied)),
293 ("idempotent", VmValue::Bool(outcome.idempotent)),
294 ("formatted", VmValue::Bool(formatted)),
295 ("safety", str_vm(format!("{:?}", outcome.safety))),
296 ("before", str_vm(&file.source)),
299 ("preview", str_vm(rewritten)),
300 ]));
301 }
302 Ok(dict_vm([
303 ("result", str_vm("ok")),
304 ("dry_run", VmValue::Bool(dry_run)),
305 ("auto_applicable", VmValue::Bool(auto_applicable)),
306 ("files", VmValue::List(Arc::new(entries))),
307 ]))
308}
309
310fn fold_run(args: &[VmValue]) -> Result<VmValue, HostlibError> {
315 let dict = first_dict(FOLD, args)?;
316 let dry_run = optional_bool(&dict, "dry_run", true);
317 let files = load_files(FOLD, &dict)?;
318
319 let mut entries = Vec::new();
320 for file in &files {
321 let raw_folded =
322 harn_rules::fold::fold_destructure_defaults(&file.source, file.language.name())
323 .map_err(|e| backend(FOLD, &e))?;
324 let raw_changed = raw_folded != file.source;
325 let formatted = raw_changed && file.language == Language::Harn;
326 let folded = if formatted {
327 match harn_fmt::format_source(&raw_folded) {
328 Ok(canonical) => canonical,
329 Err(_) => raw_folded,
330 }
331 } else {
332 raw_folded
333 };
334 let changed = folded != file.source;
335 let idempotent = harn_rules::fold::fold_destructure_defaults(&folded, file.language.name())
336 .map(|again| again == folded)
337 .unwrap_or(false);
338 let applied = !dry_run && changed;
339 if applied {
340 std::fs::write(&file.path, &folded).map_err(|e| HostlibError::Backend {
341 builtin: FOLD,
342 message: format!("write `{}`: {e}", file.path.display()),
343 })?;
344 }
345 entries.push(dict_vm([
346 ("path", str_vm(file.path.display().to_string())),
347 ("changed", VmValue::Bool(changed)),
348 ("applied", VmValue::Bool(applied)),
349 ("idempotent", VmValue::Bool(idempotent)),
350 ("formatted", VmValue::Bool(formatted)),
351 ("safety", str_vm("BehaviorPreserving")),
352 ("before", str_vm(&file.source)),
353 ("preview", str_vm(folded)),
354 ]));
355 }
356 Ok(dict_vm([
357 ("result", str_vm("ok")),
358 ("dry_run", VmValue::Bool(dry_run)),
359 ("files", VmValue::List(Arc::new(entries))),
360 ]))
361}
362
363fn lint_run(args: &[VmValue]) -> Result<VmValue, HostlibError> {
368 let dict = first_dict(LINT_RUN, args)?;
369 let source = require_string(LINT_RUN, &dict, "source")?;
370 let disabled = optional_string_list(&dict, "disabled");
371 let severity_overrides = parse_severity_overrides(&dict);
372
373 let program = harn_parser::parse_source(&source).map_err(|e| HostlibError::Backend {
374 builtin: LINT_RUN,
375 message: format!("parse error: {e}"),
376 })?;
377 let options = harn_lint::LintOptions {
378 severity_overrides,
379 ..Default::default()
380 };
381 let diagnostics = harn_lint::lint_with_options(
382 &program,
383 &disabled,
384 Some(&source),
385 &std::collections::HashSet::new(),
386 &options,
387 );
388 let items: Vec<VmValue> = diagnostics.iter().map(lint_diagnostic_vm).collect();
389 Ok(dict_vm([
390 ("result", str_vm("ok")),
391 ("diagnostic_count", VmValue::Int(items.len() as i64)),
392 ("diagnostics", VmValue::List(Arc::new(items))),
393 ]))
394}
395
396fn parse_severity_overrides(
399 dict: &BTreeMap<String, VmValue>,
400) -> std::collections::HashMap<String, harn_lint::LintSeverity> {
401 let mut out = std::collections::HashMap::new();
402 if let Some(VmValue::Dict(map)) = dict.get("severity") {
403 for (rule, value) in map.iter() {
404 if let VmValue::String(s) = value {
405 let severity = match s.to_ascii_lowercase().as_str() {
406 "error" => Some(harn_lint::LintSeverity::Error),
407 "warning" | "warn" => Some(harn_lint::LintSeverity::Warning),
408 "info" => Some(harn_lint::LintSeverity::Info),
409 _ => None,
410 };
411 if let Some(severity) = severity {
412 out.insert(rule.clone(), severity);
413 }
414 }
415 }
416 }
417 out
418}
419
420fn lint_diagnostic_vm(diag: &harn_lint::LintDiagnostic) -> VmValue {
423 let severity = match diag.severity {
424 harn_lint::LintSeverity::Error => "error",
425 harn_lint::LintSeverity::Warning => "warning",
426 harn_lint::LintSeverity::Info => "info",
427 };
428 dict_vm([
429 ("code", str_vm(diag.code.as_str())),
430 ("rule", str_vm(diag.rule.as_ref())),
431 ("message", str_vm(&diag.message)),
432 ("severity", str_vm(severity)),
433 ("start_byte", VmValue::Int(diag.span.start as i64)),
434 ("end_byte", VmValue::Int(diag.span.end as i64)),
435 ("line", VmValue::Int(diag.span.line as i64)),
436 ("column", VmValue::Int(diag.span.column as i64)),
437 ])
438}
439
440fn compile_rule(
445 builtin: &'static str,
446 dict: &BTreeMap<String, VmValue>,
447) -> Result<CompiledRule, HostlibError> {
448 let toml = require_string(builtin, dict, "rule")?;
449 let rule = Rule::from_toml_str(&toml).map_err(|e| HostlibError::InvalidParameter {
450 builtin,
451 param: "rule",
452 message: format!("invalid rule TOML: {e}"),
453 })?;
454 CompiledRule::compile(&rule).map_err(|e| HostlibError::InvalidParameter {
455 builtin,
456 param: "rule",
457 message: e.to_string(),
458 })
459}
460
461fn load_files(
465 builtin: &'static str,
466 dict: &BTreeMap<String, VmValue>,
467) -> Result<Vec<SourceFile>, HostlibError> {
468 if let Some(source) = optional_string(dict, "source") {
469 let language_name = require_string(builtin, dict, "language")?;
470 let language =
471 Language::from_name(&language_name).ok_or_else(|| HostlibError::InvalidParameter {
472 builtin,
473 param: "language",
474 message: format!("unknown language `{language_name}`"),
475 })?;
476 let path = optional_string(dict, "path").unwrap_or_else(|| "<inline>".to_string());
477 return Ok(vec![SourceFile {
478 path: PathBuf::from(path),
479 language,
480 source,
481 }]);
482 }
483
484 let paths = optional_string_list(dict, "paths");
485 if paths.is_empty() {
486 return Err(HostlibError::MissingParameter {
487 builtin,
488 param: "paths",
489 });
490 }
491 let mut files = Vec::new();
492 for path in paths {
493 let contents = std::fs::read_to_string(&path).map_err(|e| HostlibError::Backend {
494 builtin,
495 message: format!("read `{path}`: {e}"),
496 })?;
497 if let Some(file) = SourceFile::detect(&path, contents) {
498 files.push(file);
499 }
500 }
501 Ok(files)
502}
503
504fn match_to_vm(path: &std::path::Path, m: &RuleMatch) -> VmValue {
505 let captures: BTreeMap<String, VmValue> = m
506 .bindings
507 .iter()
508 .map(|(name, b)| (name.clone(), str_vm(&b.text)))
509 .collect();
510 let capture_metadata = capture_metadata_vm(m);
511 dict_vm([
512 ("path", str_vm(path.display().to_string())),
513 ("text", str_vm(&m.text)),
514 ("start_row", VmValue::Int(m.span.start_row as i64)),
515 ("start_col", VmValue::Int(m.span.start_col as i64)),
516 ("end_row", VmValue::Int(m.span.end_row as i64)),
517 ("end_col", VmValue::Int(m.span.end_col as i64)),
518 ("captures", VmValue::Dict(Arc::new(captures))),
519 ("capture_metadata", capture_metadata),
520 ])
521}
522
523fn backend(builtin: &'static str, err: &harn_rules::RulesError) -> HostlibError {
524 HostlibError::Backend {
525 builtin,
526 message: err.to_string(),
527 }
528}
529
530fn host_err(err: HostlibError) -> VmError {
533 VmError::Runtime(err.to_string())
534}
535
/// One report produced by a visitor closure, as parsed by `report_from_dict`.
///
/// Every field is optional; `report_to_diagnostic_vm` fills the gaps with an
/// empty message and the rule's default severity and safety.
#[derive(Default)]
struct ReportSpec {
    /// Diagnostic message; an empty string is used when absent.
    message: Option<String>,
    /// Suggested fix text; reported as `nil` when absent.
    fix: Option<String>,
    /// Safety override; the rule's safety is used when absent.
    safety: Option<Safety>,
    /// Severity override; the rule's severity is used when absent.
    severity: Option<Severity>,
}
546
547fn node_vm(m: &RuleMatch) -> VmValue {
550 let captures: BTreeMap<String, VmValue> = m
551 .bindings
552 .iter()
553 .map(|(name, b)| (name.clone(), str_vm(&b.text)))
554 .collect();
555 let capture_metadata = capture_metadata_vm(m);
556 dict_vm([
557 ("text", str_vm(&m.text)),
558 ("captures", VmValue::Dict(Arc::new(captures))),
559 ("capture_metadata", capture_metadata),
560 ("start_row", VmValue::Int(m.span.start_row as i64)),
561 ("start_col", VmValue::Int(m.span.start_col as i64)),
562 ("end_row", VmValue::Int(m.span.end_row as i64)),
563 ("end_col", VmValue::Int(m.span.end_col as i64)),
564 ])
565}
566
567fn capture_metadata_vm(m: &RuleMatch) -> VmValue {
568 let metadata: BTreeMap<String, VmValue> = m
569 .bindings
570 .iter()
571 .filter(|(_, binding)| !binding.metadata.is_empty())
572 .map(|(name, binding)| (name.clone(), binding_metadata_vm(&binding.metadata)))
573 .collect();
574 VmValue::Dict(Arc::new(metadata))
575}
576
577fn binding_metadata_vm(metadata: &BindingMetadata) -> VmValue {
578 let mut entries = BTreeMap::new();
579 if let Some(ty) = &metadata.ty {
580 entries.insert("type".into(), str_vm(ty));
581 }
582 if let Some(resolved) = &metadata.resolved {
583 entries.insert("resolved".into(), resolved_binding_vm(resolved));
584 }
585 VmValue::Dict(Arc::new(entries))
586}
587
588fn resolved_binding_vm(resolved: &ResolvedBinding) -> VmValue {
589 dict_vm([
590 ("id", str_vm(&resolved.id)),
591 ("name", str_vm(&resolved.name)),
592 ("kind", str_vm(&resolved.kind)),
593 ("start_row", VmValue::Int(resolved.span.start_row as i64)),
594 ("start_col", VmValue::Int(resolved.span.start_col as i64)),
595 ("end_row", VmValue::Int(resolved.span.end_row as i64)),
596 ("end_col", VmValue::Int(resolved.span.end_col as i64)),
597 ])
598}
599
600fn ctx_vm(path: &Path, language: Language, source: &str, rule_id: &str) -> VmValue {
603 dict_vm([
604 ("path", str_vm(path.display().to_string())),
605 ("language", str_vm(language.name())),
606 ("source", str_vm(source)),
607 ("rule_id", str_vm(rule_id)),
608 ])
609}
610
611fn diagnostic_dict(
615 path: &Path,
616 rule_id: &str,
617 message: &str,
618 severity: Severity,
619 span: Span,
620 fix: Option<String>,
621 applicability: Applicability,
622) -> VmValue {
623 dict_vm([
624 ("path", str_vm(path.display().to_string())),
625 ("rule_id", str_vm(rule_id)),
626 ("message", str_vm(message)),
627 ("severity", str_vm(severity.as_str())),
628 ("start_row", VmValue::Int(span.start_row as i64)),
629 ("start_col", VmValue::Int(span.start_col as i64)),
630 ("end_row", VmValue::Int(span.end_row as i64)),
631 ("end_col", VmValue::Int(span.end_col as i64)),
632 ("applicability", str_vm(applicability.as_str())),
633 ("fix", fix.map(str_vm).unwrap_or(VmValue::Nil)),
634 ])
635}
636
637fn diagnostic_vm(path: &Path, d: &Diagnostic) -> VmValue {
638 diagnostic_dict(
639 path,
640 &d.rule_id,
641 &d.message,
642 d.severity,
643 d.span,
644 d.fix.clone(),
645 d.applicability,
646 )
647}
648
649fn report_to_diagnostic_vm(
652 path: &Path,
653 rule_id: &str,
654 span: Span,
655 report: ReportSpec,
656 default_severity: Severity,
657 default_safety: Safety,
658) -> VmValue {
659 let severity = report.severity.unwrap_or(default_severity);
660 let safety = report.safety.unwrap_or(default_safety);
661 diagnostic_dict(
662 path,
663 rule_id,
664 report.message.as_deref().unwrap_or(""),
665 severity,
666 span,
667 report.fix,
668 safety.applicability(),
669 )
670}
671
672fn reports_from_return(ret: VmValue) -> Vec<ReportSpec> {
676 match ret {
677 VmValue::Nil | VmValue::Bool(false) => Vec::new(),
678 VmValue::Bool(true) => vec![ReportSpec::default()],
679 VmValue::Dict(d) => vec![report_from_dict(&d)],
680 VmValue::List(items) => items.iter().filter_map(report_from_item).collect(),
681 _ => Vec::new(),
682 }
683}
684
685fn report_from_item(v: &VmValue) -> Option<ReportSpec> {
686 match v {
687 VmValue::Nil | VmValue::Bool(false) => None,
688 VmValue::Bool(true) => Some(ReportSpec::default()),
689 VmValue::Dict(d) => Some(report_from_dict(d)),
690 _ => None,
691 }
692}
693
694fn report_from_dict(d: &BTreeMap<String, VmValue>) -> ReportSpec {
695 ReportSpec {
696 message: optional_string(d, "message"),
697 fix: optional_string(d, "fix"),
698 safety: optional_string(d, "safety").and_then(|s| parse_safety(&s)),
699 severity: optional_string(d, "severity").and_then(|s| parse_severity(&s)),
700 }
701}
702
703fn parse_severity(s: &str) -> Option<Severity> {
704 match s {
705 "info" => Some(Severity::Info),
706 "warning" => Some(Severity::Warning),
707 "error" => Some(Severity::Error),
708 _ => None,
709 }
710}
711
712fn parse_safety(s: &str) -> Option<Safety> {
713 match s {
714 "format-only" => Some(Safety::FormatOnly),
715 "behavior-preserving" => Some(Safety::BehaviorPreserving),
716 "scope-local" => Some(Safety::ScopeLocal),
717 "surface-changing" => Some(Safety::SurfaceChanging),
718 "capability-changing" => Some(Safety::CapabilityChanging),
719 "needs-human" => Some(Safety::NeedsHuman),
720 _ => None,
721 }
722}
723
724fn json_to_vm(value: &serde_json::Value) -> VmValue {
725 match value {
726 serde_json::Value::Null => VmValue::Nil,
727 serde_json::Value::Bool(b) => VmValue::Bool(*b),
728 serde_json::Value::Number(n) => n
729 .as_i64()
730 .map(VmValue::Int)
731 .unwrap_or_else(|| VmValue::Float(n.as_f64().unwrap_or(0.0))),
732 serde_json::Value::String(s) => str_vm(s),
733 serde_json::Value::Array(items) => {
734 VmValue::List(Arc::new(items.iter().map(json_to_vm).collect()))
735 }
736 serde_json::Value::Object(map) => VmValue::Dict(Arc::new(
737 map.iter()
738 .map(|(k, v)| (k.clone(), json_to_vm(v)))
739 .collect(),
740 )),
741 }
742}
743
744fn first_dict(
749 builtin: &'static str,
750 args: &[VmValue],
751) -> Result<Arc<BTreeMap<String, VmValue>>, HostlibError> {
752 match args.first() {
753 Some(VmValue::Dict(dict)) => Ok(dict.clone()),
754 Some(VmValue::Nil) | None => Ok(Arc::new(BTreeMap::new())),
755 Some(_) => Err(HostlibError::InvalidParameter {
756 builtin,
757 param: "params",
758 message: "expected a dict argument".into(),
759 }),
760 }
761}
762
763fn require_string(
764 builtin: &'static str,
765 dict: &BTreeMap<String, VmValue>,
766 key: &'static str,
767) -> Result<String, HostlibError> {
768 match dict.get(key) {
769 Some(VmValue::String(s)) => Ok(s.to_string()),
770 _ => Err(HostlibError::MissingParameter {
771 builtin,
772 param: key,
773 }),
774 }
775}
776
777fn optional_string(dict: &BTreeMap<String, VmValue>, key: &str) -> Option<String> {
778 match dict.get(key) {
779 Some(VmValue::String(s)) => Some(s.to_string()),
780 _ => None,
781 }
782}
783
784fn optional_string_list(dict: &BTreeMap<String, VmValue>, key: &str) -> Vec<String> {
785 match dict.get(key) {
786 Some(VmValue::List(items)) => items
787 .iter()
788 .filter_map(|v| match v {
789 VmValue::String(s) => Some(s.to_string()),
790 _ => None,
791 })
792 .collect(),
793 _ => Vec::new(),
794 }
795}
796
797fn optional_bool(dict: &BTreeMap<String, VmValue>, key: &str, default: bool) -> bool {
798 match dict.get(key) {
799 Some(VmValue::Bool(b)) => *b,
800 _ => default,
801 }
802}
803
804fn str_vm(s: impl AsRef<str>) -> VmValue {
805 VmValue::String(Arc::from(s.as_ref()))
806}
807
808fn dict_vm<const N: usize>(entries: [(&str, VmValue); N]) -> VmValue {
809 let map: BTreeMap<String, VmValue> = entries
810 .into_iter()
811 .map(|(k, v)| (k.to_string(), v))
812 .collect();
813 VmValue::Dict(Arc::new(map))
814}
815
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `VmValue::Dict` from borrowed key/value pairs.
    fn dict(pairs: &[(&str, VmValue)]) -> VmValue {
        let map: BTreeMap<String, VmValue> = pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();
        VmValue::Dict(Arc::new(map))
    }

    /// Looks up `key` in a dict value, panicking with context otherwise.
    fn get<'a>(v: &'a VmValue, key: &str) -> &'a VmValue {
        match v {
            VmValue::Dict(d) => d.get(key).unwrap_or_else(|| panic!("missing {key}")),
            other => panic!("not a dict: {other:?}"),
        }
    }

    /// Unwraps a list value, panicking with the offending value otherwise.
    fn list(v: &VmValue) -> Arc<Vec<VmValue>> {
        match v {
            VmValue::List(items) => Arc::clone(items),
            other => panic!("not list: {other:?}"),
        }
    }

    /// Unwraps an int value.
    fn int(v: &VmValue) -> i64 {
        match v {
            VmValue::Int(i) => *i,
            other => panic!("not int: {other:?}"),
        }
    }

    /// Unwraps a string value.
    fn s(v: &VmValue) -> String {
        match v {
            VmValue::String(s) => s.to_string(),
            other => panic!("not string: {other:?}"),
        }
    }

    /// Unwraps a bool value.
    fn b(v: &VmValue) -> bool {
        match v {
            VmValue::Bool(b) => *b,
            other => panic!("not bool: {other:?}"),
        }
    }

    /// TypeScript rule matching any zero-argument call.
    const SEARCH_RULE: &str = r#"
        id = "find-calls"
        language = "typescript"
        [rule]
        pattern = "$FN()"
    "#;

    #[test]
    fn search_returns_matches_with_captures() {
        let result = search_run(&[dict(&[
            ("rule", str_vm(SEARCH_RULE)),
            ("source", str_vm("foo();\nbar();\n")),
            ("language", str_vm("typescript")),
        ])])
        .unwrap();
        assert_eq!(int(get(&result, "match_count")), 2);
        let matches = list(get(&result, "matches"));
        assert_eq!(s(get(get(&matches[0], "captures"), "FN")), "foo");
    }

    #[test]
    fn search_returns_harn_capture_metadata() {
        let rule = r#"
            id = "int-logs"
            language = "harn"
            [rule]
            pattern = "log($VALUE)"
        "#;
        let result = search_run(&[dict(&[
            ("rule", str_vm(rule)),
            (
                "source",
                str_vm("fn main() {\n let count: int = 1\n log(count)\n}\n"),
            ),
            ("language", str_vm("harn")),
        ])])
        .unwrap();
        let matches = list(get(&result, "matches"));
        let metadata = get(get(&matches[0], "capture_metadata"), "VALUE");
        assert_eq!(s(get(metadata, "type")), "int");
        assert_eq!(s(get(get(metadata, "resolved"), "name")), "count");
        assert_eq!(s(get(get(metadata, "resolved"), "kind")), "let");
    }

    #[test]
    fn report_returns_a_data_table() {
        let result = report_run(&[dict(&[
            ("rule", str_vm(SEARCH_RULE)),
            ("source", str_vm("foo();\nbar();\n")),
            ("language", str_vm("typescript")),
            ("path", str_vm("a.ts")),
        ])])
        .unwrap();
        assert_eq!(int(get(get(&result, "summary"), "total_rows")), 2);
        assert_eq!(s(get(&result, "rule_id")), "find-calls");
    }

    #[test]
    fn apply_dry_run_previews_without_writing() {
        let rule = r#"
            id = "rename"
            language = "typescript"
            safety = "behavior-preserving"
            fix = "bar()"
            [rule]
            pattern = "foo()"
        "#;
        let result = apply_run(&[dict(&[
            ("rule", str_vm(rule)),
            ("source", str_vm("foo();\n")),
            ("language", str_vm("typescript")),
            ("dry_run", VmValue::Bool(true)),
        ])])
        .unwrap();
        let files = list(get(&result, "files"));
        assert!(b(get(&files[0], "changed")));
        assert!(!b(get(&files[0], "applied")));
        assert_eq!(s(get(&files[0], "preview")), "bar();\n");
    }

    /// Harn codemod whose fix text is deliberately unformatted.
    const UGLY_HARN_CODEMOD: &str = r#"
        id = "dd"
        language = "harn"
        safety = "scope-local"
        fix = "let {$K=$D}=$X"
        [rule]
        pattern = "let $K = $X?.$K ?? $D"
    "#;

    #[test]
    fn apply_formats_harn_output_by_default() {
        let result = apply_run(&[dict(&[
            ("rule", str_vm(UGLY_HARN_CODEMOD)),
            (
                "source",
                str_vm("fn main() {\n let timeout = cfg?.timeout ?? 30\n}\n"),
            ),
            ("language", str_vm("harn")),
            ("dry_run", VmValue::Bool(true)),
        ])])
        .unwrap();
        let files = list(get(&result, "files"));
        assert!(b(get(&files[0], "changed")));
        assert!(b(get(&files[0], "formatted")));
        let preview = s(get(&files[0], "preview"));
        assert!(preview.contains("= 30"), "preview not formatted: {preview}");
    }

    #[test]
    fn apply_format_false_leaves_raw_output() {
        let result = apply_run(&[dict(&[
            ("rule", str_vm(UGLY_HARN_CODEMOD)),
            (
                "source",
                str_vm("fn main() {\n let timeout = cfg?.timeout ?? 30\n}\n"),
            ),
            ("language", str_vm("harn")),
            ("dry_run", VmValue::Bool(true)),
            ("format", VmValue::Bool(false)),
        ])])
        .unwrap();
        let files = list(get(&result, "files"));
        assert!(!b(get(&files[0], "formatted")));
        let preview = s(get(&files[0], "preview"));
        assert!(preview.contains("{timeout=30}"), "expected raw: {preview}");
    }

    #[test]
    fn diagnostics_returns_lint_findings() {
        let lint = r#"
            id = "calls"
            language = "typescript"
            message = "function call"
            [rule]
            pattern = "$FN()"
        "#;
        let result = diagnostics_run(&[dict(&[
            ("rule", str_vm(lint)),
            ("source", str_vm("foo();\nbar();\n")),
            ("language", str_vm("typescript")),
            ("path", str_vm("a.ts")),
        ])])
        .unwrap();
        assert_eq!(int(get(&result, "diagnostic_count")), 2);
        let diags = list(get(&result, "diagnostics"));
        assert_eq!(s(get(&diags[0], "message")), "function call");
        assert_eq!(s(get(&diags[0], "severity")), "warning");
        assert_eq!(s(get(&diags[0], "applicability")), "suggestion");
        assert_eq!(int(get(&diags[1], "start_row")), 1);
        assert!(matches!(get(&diags[0], "fix"), VmValue::Nil));
    }

    #[test]
    fn report_helpers_round_trip_severity_and_safety() {
        assert_eq!(parse_severity("error"), Some(Severity::Error));
        assert_eq!(parse_severity("bogus"), None);
        assert_eq!(parse_safety("format-only"), Some(Safety::FormatOnly));
        assert_eq!(parse_safety("needs-human"), Some(Safety::NeedsHuman));
        assert_eq!(parse_safety("nope"), None);
        assert_eq!(reports_from_return(VmValue::Bool(true)).len(), 1);
        assert_eq!(reports_from_return(VmValue::Nil).len(), 0);
        assert_eq!(reports_from_return(VmValue::Bool(false)).len(), 0);
        // The `nil` element in the middle of the list must be skipped.
        let reports = VmValue::List(Arc::new(vec![
            dict(&[("message", str_vm("a"))]),
            VmValue::Nil,
            dict(&[("message", str_vm("b"))]),
        ]));
        assert_eq!(reports_from_return(reports).len(), 2);
    }

    #[test]
    fn capability_does_not_register_the_async_visitor() {
        let mut registry = BuiltinRegistry::new();
        RulesCapability.register_builtins(&mut registry);
        let names: Vec<_> = registry.iter().map(|b| b.name).collect();
        assert!(!names.contains(&VISIT));
        assert!(names.contains(&DIAGNOSTICS));
    }

    #[test]
    fn missing_rule_is_an_error() {
        let err = search_run(&[dict(&[
            ("source", str_vm("x")),
            ("language", str_vm("rust")),
        ])]);
        assert!(matches!(
            err,
            Err(HostlibError::MissingParameter { param: "rule", .. })
        ));
    }

    #[test]
    fn capability_registers_the_sync_builtins() {
        let mut registry = BuiltinRegistry::new();
        RulesCapability.register_builtins(&mut registry);
        let names: Vec<_> = registry.iter().map(|b| b.name).collect();
        assert_eq!(names, vec![SEARCH, REPORT, DIAGNOSTICS, APPLY, FOLD]);
    }

    #[test]
    fn lint_capability_registers_run() {
        let mut registry = BuiltinRegistry::new();
        LintCapability.register_builtins(&mut registry);
        let names: Vec<_> = registry.iter().map(|b| b.name).collect();
        assert_eq!(names, vec![LINT_RUN]);
    }

    #[test]
    fn lint_run_returns_the_linter_findings() {
        let result =
            lint_run(&[dict(&[("source", str_vm("fn f() {\n let x = (1)\n}\n"))])]).unwrap();
        assert_eq!(s(get(&result, "result")), "ok");
        let diags = list(get(&result, "diagnostics"));
        assert!(
            diags
                .iter()
                .any(|d| s(get(d, "rule")) == "unnecessary-parentheses"),
            "expected unnecessary-parentheses, got {diags:?}"
        );
    }

    #[test]
    fn lint_run_applies_a_severity_override() {
        let result = lint_run(&[dict(&[
            ("source", str_vm("fn f() {\n let x = (1)\n}\n")),
            (
                "severity",
                dict(&[("unnecessary-parentheses", str_vm("error"))]),
            ),
        ])])
        .unwrap();
        let diags = list(get(&result, "diagnostics"));
        let d = diags
            .iter()
            .find(|d| s(get(d, "rule")) == "unnecessary-parentheses")
            .expect("rule present");
        assert_eq!(s(get(d, "severity")), "error");
    }
}