1use std::fs;
20use std::io;
21use std::path::{Path, PathBuf};
22
23use bock_air::{AIRNode, NodeKind};
24use chrono::{DateTime, Utc};
25use serde::{Deserialize, Serialize};
26
27use crate::cache::compute_key;
28use crate::request::CandidateRule;
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
34#[serde(rename_all = "snake_case")]
35pub enum Provenance {
36 Builtin,
38 Extracted,
40 Manual,
42}
43
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
48pub struct Rule {
49 pub id: String,
51 pub target_id: String,
53 pub node_kind: String,
55 pub pattern: String,
59 pub template: String,
61 pub provenance: Provenance,
63 pub pinned: bool,
65 pub confidence: f64,
67 pub priority: i32,
69 pub created: DateTime<Utc>,
71}
72
73impl Rule {
74 #[must_use]
82 pub fn from_candidate(candidate: &CandidateRule, node_kind: &str, confidence: f64) -> Self {
83 let id = compute_rule_id(&candidate.target_id, node_kind, &candidate.template);
84 Self {
85 id,
86 target_id: candidate.target_id.clone(),
87 node_kind: node_kind.into(),
88 pattern: candidate.pattern.clone(),
89 template: candidate.template.clone(),
90 provenance: Provenance::Extracted,
91 pinned: false,
92 confidence,
93 priority: candidate.priority,
94 created: Utc::now(),
95 }
96 }
97}
98
99#[must_use]
105pub fn compute_rule_id(target_id: &str, node_kind: &str, template: &str) -> String {
106 #[derive(Serialize)]
107 struct Keyed<'a> {
108 target: &'a str,
109 kind: &'a str,
110 template: &'a str,
111 }
112 let keyed = Keyed {
113 target: target_id,
114 kind: node_kind,
115 template,
116 };
117 compute_key(&keyed)
118 .unwrap_or_else(|_| format!("fallback-{target_id}-{node_kind}"))
119}
120
121#[derive(Debug, thiserror::Error)]
125pub enum RuleCacheError {
126 #[error("rule cache I/O error: {0}")]
128 Io(#[from] io::Error),
129 #[error("rule parse error in {path}: {source}")]
131 Parse {
132 path: PathBuf,
134 #[source]
136 source: serde_json::Error,
137 },
138 #[error("rule serialize error: {0}")]
140 Serialize(#[from] serde_json::Error),
141}
142
/// Handle to a directory-backed store of [`Rule`]s.
///
/// The handle only records the root directory; rules live on disk as
/// `<root>/<target_id>/<rule_id>.json`.
#[derive(Debug, Clone)]
pub struct RuleCache {
    /// Directory under which one sub-directory per target is kept.
    root: PathBuf,
}
148
149impl RuleCache {
150 #[must_use]
155 pub fn new(project_root: &Path) -> Self {
156 Self {
157 root: project_root.join(".bock").join("rules"),
158 }
159 }
160
161 #[must_use]
163 pub fn with_root(root: PathBuf) -> Self {
164 Self { root }
165 }
166
167 #[must_use]
169 pub fn root(&self) -> &Path {
170 &self.root
171 }
172
173 #[must_use]
175 pub fn target_dir(&self, target_id: &str) -> PathBuf {
176 self.root.join(target_id)
177 }
178
179 pub fn insert(&self, rule: &Rule) -> Result<(), RuleCacheError> {
186 let dir = self.target_dir(&rule.target_id);
187 fs::create_dir_all(&dir)?;
188 let path = dir.join(format!("{}.json", rule.id));
189 let bytes = serde_json::to_vec_pretty(rule)?;
190 fs::write(&path, bytes)?;
191 Ok(())
192 }
193
194 pub fn load_for_target(&self, target_id: &str) -> Result<Vec<Rule>, RuleCacheError> {
199 let dir = self.target_dir(target_id);
200 if !dir.exists() {
201 return Ok(Vec::new());
202 }
203 let mut out = Vec::new();
204 for entry in fs::read_dir(&dir)? {
205 let entry = entry?;
206 let path = entry.path();
207 if path.extension().and_then(|e| e.to_str()) != Some("json") {
208 continue;
209 }
210 let bytes = fs::read(&path)?;
211 let rule: Rule =
212 serde_json::from_slice(&bytes).map_err(|source| RuleCacheError::Parse {
213 path: path.clone(),
214 source,
215 })?;
216 out.push(rule);
217 }
218 Ok(out)
219 }
220
221 pub fn lookup(
232 &self,
233 target_id: &str,
234 node: &AIRNode,
235 production_only_pinned: bool,
236 ) -> Result<Option<Rule>, RuleCacheError> {
237 let kind = node_kind_name(&node.kind);
238 let rules = self.load_for_target(target_id)?;
239 let best = rules
240 .into_iter()
241 .filter(|r| r.node_kind == kind)
242 .filter(|r| !production_only_pinned || r.pinned)
243 .max_by(|a, b| {
244 a.priority
245 .cmp(&b.priority)
246 .then(a.pinned.cmp(&b.pinned))
247 .then(a.created.cmp(&b.created))
248 });
249 Ok(best)
250 }
251}
252
253#[must_use]
258pub fn node_kind_name(kind: &NodeKind) -> &'static str {
259 match kind {
260 NodeKind::Module { .. } => "Module",
261 NodeKind::ImportDecl { .. } => "ImportDecl",
262 NodeKind::FnDecl { .. } => "FnDecl",
263 NodeKind::RecordDecl { .. } => "RecordDecl",
264 NodeKind::EnumDecl { .. } => "EnumDecl",
265 NodeKind::ClassDecl { .. } => "ClassDecl",
266 NodeKind::TraitDecl { .. } => "TraitDecl",
267 NodeKind::ImplBlock { .. } => "ImplBlock",
268 NodeKind::EffectDecl { .. } => "EffectDecl",
269 NodeKind::ConstDecl { .. } => "ConstDecl",
270 NodeKind::TypeAlias { .. } => "TypeAlias",
271 NodeKind::Param { .. } => "Param",
272 NodeKind::Block { .. } => "Block",
273 NodeKind::If { .. } => "If",
274 NodeKind::For { .. } => "For",
275 NodeKind::While { .. } => "While",
276 NodeKind::Loop { .. } => "Loop",
277 NodeKind::Match { .. } => "Match",
278 NodeKind::MatchArm { .. } => "MatchArm",
279 NodeKind::Guard { .. } => "Guard",
280 NodeKind::HandlingBlock { .. } => "HandlingBlock",
281 NodeKind::LetBinding { .. } => "LetBinding",
282 NodeKind::Return { .. } => "Return",
283 NodeKind::Break { .. } => "Break",
284 NodeKind::Assign { .. } => "Assign",
285 NodeKind::BinaryOp { .. } => "BinaryOp",
286 NodeKind::UnaryOp { .. } => "UnaryOp",
287 NodeKind::Call { .. } => "Call",
288 NodeKind::MethodCall { .. } => "MethodCall",
289 NodeKind::Lambda { .. } => "Lambda",
290 NodeKind::FieldAccess { .. } => "FieldAccess",
291 NodeKind::Index { .. } => "Index",
292 NodeKind::Pipe { .. } => "Pipe",
293 NodeKind::Compose { .. } => "Compose",
294 NodeKind::Await { .. } => "Await",
295 NodeKind::Propagate { .. } => "Propagate",
296 NodeKind::Move { .. } => "Move",
297 NodeKind::Borrow { .. } => "Borrow",
298 NodeKind::MutableBorrow { .. } => "MutableBorrow",
299 NodeKind::ListLiteral { .. } => "ListLiteral",
300 NodeKind::SetLiteral { .. } => "SetLiteral",
301 NodeKind::TupleLiteral { .. } => "TupleLiteral",
302 NodeKind::MapLiteral { .. } => "MapLiteral",
303 NodeKind::RecordConstruct { .. } => "RecordConstruct",
304 NodeKind::Range { .. } => "Range",
305 NodeKind::ResultConstruct { .. } => "ResultConstruct",
306 NodeKind::TypeNamed { .. } => "TypeNamed",
307 NodeKind::TypeTuple { .. } => "TypeTuple",
308 NodeKind::TypeFunction { .. } => "TypeFunction",
309 NodeKind::TypeOptional { .. } => "TypeOptional",
310 NodeKind::ModuleHandle { .. } => "ModuleHandle",
311 NodeKind::PropertyTest { .. } => "PropertyTest",
312 NodeKind::ConstructorPat { .. } => "ConstructorPat",
313 NodeKind::RecordPat { .. } => "RecordPat",
314 NodeKind::TuplePat { .. } => "TuplePat",
315 NodeKind::ListPat { .. } => "ListPat",
316 NodeKind::OrPat { .. } => "OrPat",
317 NodeKind::GuardPat { .. } => "GuardPat",
318 NodeKind::RangePat { .. } => "RangePat",
319 _ => "Other",
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use bock_air::{NodeIdGen, NodeKind};
    use bock_errors::Span;

    /// Builds a `Match` node whose scrutinee is an empty block and which has
    /// no arms.
    fn match_node() -> AIRNode {
        let ids = NodeIdGen::new();
        let scrutinee = AIRNode::new(
            ids.next(),
            Span::dummy(),
            NodeKind::Block {
                stmts: Vec::new(),
                tail: None,
            },
        );
        AIRNode::new(
            ids.next(),
            Span::dummy(),
            NodeKind::Match {
                scrutinee: Box::new(scrutinee),
                arms: Vec::new(),
            },
        )
    }

    /// A candidate rule for the `js` target, shared by most tests.
    fn candidate() -> CandidateRule {
        CandidateRule {
            target_id: "js".into(),
            pattern: "match on string scrutinee".into(),
            template: "switch ({{ scrutinee }}) { {{ arms }} }".into(),
            priority: 10,
        }
    }

    /// Lifts the shared candidate into a rule for `kind` with confidence 0.9.
    fn extracted(kind: &str) -> Rule {
        Rule::from_candidate(&candidate(), kind, 0.9)
    }

    /// Creates a cache inside a fresh temporary project directory.
    ///
    /// The directory guard is returned alongside the cache and must be kept
    /// alive for as long as the cache is used.
    fn scratch_cache() -> (tempfile::TempDir, RuleCache) {
        let dir = tempfile::tempdir().unwrap();
        let cache = RuleCache::new(dir.path());
        (dir, cache)
    }

    #[test]
    fn candidate_lifts_to_extracted_rule() {
        let rule = Rule::from_candidate(&candidate(), "Match", 0.88);

        assert_eq!(rule.provenance, Provenance::Extracted);
        assert_eq!(rule.node_kind, "Match");
        assert_eq!(rule.target_id, "js");
        assert!(!rule.pinned);
        assert!((rule.confidence - 0.88).abs() < f64::EPSILON);
    }

    #[test]
    fn rule_id_is_stable_across_calls() {
        let first = compute_rule_id("js", "Match", "switch x {}");
        let second = compute_rule_id("js", "Match", "switch x {}");
        assert_eq!(first, second);

        let other_template = compute_rule_id("js", "Match", "switch y {}");
        assert_ne!(first, other_template);
    }

    #[test]
    fn insert_then_load() {
        let (_dir, cache) = scratch_cache();
        let rule = extracted("Match");
        cache.insert(&rule).unwrap();

        let loaded = cache.load_for_target("js").unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].id, rule.id);
        assert_eq!(loaded[0].node_kind, "Match");
    }

    #[test]
    fn lookup_matches_by_node_kind() {
        let (_dir, cache) = scratch_cache();
        cache.insert(&extracted("Match")).unwrap();

        let hit = cache.lookup("js", &match_node(), false).unwrap();
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().node_kind, "Match");
    }

    #[test]
    fn lookup_misses_on_different_kind() {
        let (_dir, cache) = scratch_cache();
        cache.insert(&extracted("Call")).unwrap();

        let hit = cache.lookup("js", &match_node(), false).unwrap();
        assert!(hit.is_none());
    }

    #[test]
    fn production_mode_ignores_unpinned_rules() {
        let (_dir, cache) = scratch_cache();
        let unpinned = extracted("Match");
        cache.insert(&unpinned).unwrap();

        // Only an unpinned rule exists, so production lookup finds nothing.
        assert!(cache.lookup("js", &match_node(), true).unwrap().is_none());

        // Store a pinned twin under a different id so both files coexist.
        let mut pinned = unpinned.clone();
        pinned.pinned = true;
        pinned.id = format!("{}-pinned", unpinned.id);
        cache.insert(&pinned).unwrap();

        let hit = cache.lookup("js", &match_node(), true).unwrap().unwrap();
        assert!(hit.pinned);
    }

    #[test]
    fn lookup_misses_on_empty_directory() {
        // Nothing is inserted, so the target directory is never created.
        let (_dir, cache) = scratch_cache();

        let hit = cache.lookup("js", &match_node(), false).unwrap();
        assert!(hit.is_none());
    }

    #[test]
    fn load_skips_non_json_files() {
        let (_dir, cache) = scratch_cache();
        let target_dir = cache.target_dir("js");
        fs::create_dir_all(&target_dir).unwrap();
        fs::write(target_dir.join("junk.txt"), "not json").unwrap();

        let rules = cache.load_for_target("js").unwrap();
        assert!(rules.is_empty());
    }

    #[test]
    fn priority_breaks_ties() {
        let (_dir, cache) = scratch_cache();

        let low = Rule {
            id: "low".into(),
            target_id: "js".into(),
            node_kind: "Match".into(),
            pattern: "low".into(),
            template: "low()".into(),
            provenance: Provenance::Extracted,
            pinned: false,
            confidence: 0.5,
            priority: 1,
            created: Utc::now(),
        };
        // Identical to `low` (including the timestamp) except for id,
        // template, and a higher priority.
        let high = Rule {
            id: "high".into(),
            priority: 99,
            template: "high()".into(),
            ..low.clone()
        };
        for rule in [&low, &high] {
            cache.insert(rule).unwrap();
        }

        let hit = cache.lookup("js", &match_node(), false).unwrap().unwrap();
        assert_eq!(hit.id, "high");
    }

    #[test]
    fn node_kind_name_covers_common_variants() {
        let ids = NodeIdGen::new();
        let block = AIRNode::new(
            ids.next(),
            Span::dummy(),
            NodeKind::Block {
                stmts: Vec::new(),
                tail: None,
            },
        );

        assert_eq!(node_kind_name(&block.kind), "Block");
    }
}