1use anyhow::{Result, anyhow};
30use mangle_ir::physical::{Aggregate, CmpOp, Condition, Constant, DataSource, Expr, Op, Operand};
31use mangle_ir::{Ir, NameId};
32use std::collections::HashMap;
33
34pub use mangle_common::{Store, Value};
35
/// In-memory tuple store.
///
/// Each relation's tuples live in one of three sets. `insert` appends to
/// `next_delta`; `merge_deltas` moves `delta` into `stable` and then promotes
/// `next_delta` to `delta`. `stable` and `delta` additionally carry per-column
/// hash indexes; `next_delta` is not indexed.
#[derive(Default)]
pub struct MemStore {
    /// Relation name -> rows added via `add_fact` or merged out of `delta`.
    stable: HashMap<String, Vec<Vec<Value>>>,
    /// Relation name -> rows promoted from `next_delta` by the last `merge_deltas`.
    delta: HashMap<String, Vec<Vec<Value>>>,
    /// Relation name -> rows accepted by `insert` since the last `merge_deltas`.
    next_delta: HashMap<String, Vec<Vec<Value>>>,

    /// `(relation, column)` -> column value -> positions of matching rows in `stable`.
    stable_indexes: HashMap<(String, usize), HashMap<Value, Vec<usize>>>,
    /// `(relation, column)` -> column value -> positions of matching rows in `delta`.
    delta_indexes: HashMap<(String, usize), HashMap<Value, Vec<usize>>>,
}
53
54impl MemStore {
55 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn create_relation(&mut self, relation: &str) {
61 self.stable.entry(relation.to_string()).or_default();
62 }
63
64 pub fn add_fact(&mut self, relation: &str, args: Vec<Value>) {
66 let table = self.stable.entry(relation.to_string()).or_default();
67 if !table.contains(&args) {
68 let row_idx = table.len();
69 table.push(args.clone());
70 for (col_idx, val) in args.into_iter().enumerate() {
72 self.stable_indexes
73 .entry((relation.to_string(), col_idx))
74 .or_default()
75 .entry(val)
76 .or_default()
77 .push(row_idx);
78 }
79 }
80 }
81
82 fn rebuild_indexes_for(&mut self, relation: &str) {
84 self.stable_indexes.retain(|(rel, _), _| rel != relation);
86 self.delta_indexes.retain(|(rel, _), _| rel != relation);
87
88 if let Some(table) = self.stable.get(relation) {
90 for (row_idx, tuple) in table.iter().enumerate() {
91 for (col_idx, val) in tuple.iter().enumerate() {
92 self.stable_indexes
93 .entry((relation.to_string(), col_idx))
94 .or_default()
95 .entry(val.clone())
96 .or_default()
97 .push(row_idx);
98 }
99 }
100 }
101
102 if let Some(table) = self.delta.get(relation) {
104 for (row_idx, tuple) in table.iter().enumerate() {
105 for (col_idx, val) in tuple.iter().enumerate() {
106 self.delta_indexes
107 .entry((relation.to_string(), col_idx))
108 .or_default()
109 .entry(val.clone())
110 .or_default()
111 .push(row_idx);
112 }
113 }
114 }
115 }
116
117 pub fn get_facts(&self, relation: &str) -> Vec<Vec<Value>> {
118 let mut all = self.stable.get(relation).cloned().unwrap_or_default();
119 if let Some(d) = self.delta.get(relation) {
120 all.extend(d.iter().cloned());
121 }
122 all
123 }
124}
125
126impl Store for MemStore {
127 fn scan(&self, relation: &str) -> Result<Box<dyn Iterator<Item = Vec<Value>> + '_>> {
128 let s = self.stable.get(relation).into_iter().flatten().cloned();
129 let d = self.delta.get(relation).into_iter().flatten().cloned();
130 Ok(Box::new(s.chain(d)))
131 }
132
133 fn scan_delta(&self, relation: &str) -> Result<Box<dyn Iterator<Item = Vec<Value>> + '_>> {
134 match self.delta.get(relation) {
135 Some(tuples) => Ok(Box::new(tuples.iter().cloned())),
136 None => Ok(Box::new(std::iter::empty())),
137 }
138 }
139
140 fn scan_next_delta(&self, relation: &str) -> Result<Box<dyn Iterator<Item = Vec<Value>> + '_>> {
141 match self.next_delta.get(relation) {
142 Some(tuples) => Ok(Box::new(tuples.iter().cloned())),
143 None => Ok(Box::new(std::iter::empty())),
144 }
145 }
146
147 fn scan_index(
148 &self,
149 relation: &str,
150 col_idx: usize,
151 key: &Value,
152 ) -> Result<Box<dyn Iterator<Item = Vec<Value>> + '_>> {
153 let mut results: Vec<Vec<Value>> = Vec::new();
154
155 if let Some(idx_map) = self.stable_indexes.get(&(relation.to_string(), col_idx))
156 && let Some(row_indices) = idx_map.get(key)
157 && let Some(table) = self.stable.get(relation)
158 {
159 for &i in row_indices {
160 results.push(table[i].clone());
161 }
162 }
163
164 if let Some(idx_map) = self.delta_indexes.get(&(relation.to_string(), col_idx))
165 && let Some(row_indices) = idx_map.get(key)
166 && let Some(table) = self.delta.get(relation)
167 {
168 for &i in row_indices {
169 results.push(table[i].clone());
170 }
171 }
172
173 Ok(Box::new(results.into_iter()))
174 }
175
176 fn scan_delta_index(
177 &self,
178 relation: &str,
179 col_idx: usize,
180 key: &Value,
181 ) -> Result<Box<dyn Iterator<Item = Vec<Value>> + '_>> {
182 let mut results: Vec<Vec<Value>> = Vec::new();
183
184 if let Some(idx_map) = self.delta_indexes.get(&(relation.to_string(), col_idx))
185 && let Some(row_indices) = idx_map.get(key)
186 && let Some(table) = self.delta.get(relation)
187 {
188 for &i in row_indices {
189 results.push(table[i].clone());
190 }
191 }
192
193 Ok(Box::new(results.into_iter()))
194 }
195
196 fn insert(&mut self, relation: &str, tuple: Vec<Value>) -> Result<bool> {
197 if self
199 .stable
200 .get(relation)
201 .is_some_and(|v| v.contains(&tuple))
202 || self.delta.get(relation).is_some_and(|v| v.contains(&tuple))
203 || self
204 .next_delta
205 .get(relation)
206 .is_some_and(|v| v.contains(&tuple))
207 {
208 return Ok(false);
209 }
210
211 self.next_delta
212 .entry(relation.to_string())
213 .or_default()
214 .push(tuple);
215 Ok(true)
216 }
217
218 fn merge_deltas(&mut self) {
219 for (rel_name, mut tuples) in self.delta.drain() {
221 let table = self.stable.entry(rel_name.clone()).or_default();
222 for tuple in tuples.drain(..) {
223 let row_idx = table.len();
224 for (col_idx, val) in tuple.iter().enumerate() {
226 self.stable_indexes
227 .entry((rel_name.clone(), col_idx))
228 .or_default()
229 .entry(val.clone())
230 .or_default()
231 .push(row_idx);
232 }
233 table.push(tuple);
234 }
235 }
236 self.delta_indexes.clear();
237
238 self.delta = std::mem::take(&mut self.next_delta);
240 for (rel_name, tuples) in &self.delta {
241 for (row_idx, tuple) in tuples.iter().enumerate() {
242 for (col_idx, val) in tuple.iter().enumerate() {
243 self.delta_indexes
244 .entry((rel_name.clone(), col_idx))
245 .or_default()
246 .entry(val.clone())
247 .or_default()
248 .push(row_idx);
249 }
250 }
251 }
252 }
253
254 fn create_relation(&mut self, relation: &str) {
255 self.stable.entry(relation.to_string()).or_default();
256 }
257
258 fn retract(&mut self, relation: &str, tuple: &[Value]) -> Result<bool> {
259 let removed = if let Some(table) = self.stable.get_mut(relation) {
260 if let Some(pos) = table.iter().position(|t| t.as_slice() == tuple) {
261 table.swap_remove(pos);
262 true
263 } else {
264 false
265 }
266 } else {
267 false
268 };
269
270 if let Some(table) = self.delta.get_mut(relation) {
272 if let Some(pos) = table.iter().position(|t| t.as_slice() == tuple) {
273 table.swap_remove(pos);
274 }
275 }
276 if let Some(table) = self.next_delta.get_mut(relation) {
277 if let Some(pos) = table.iter().position(|t| t.as_slice() == tuple) {
278 table.swap_remove(pos);
279 }
280 }
281
282 if removed {
283 self.rebuild_indexes_for(relation);
284 }
285 Ok(removed)
286 }
287
288 fn clear(&mut self, relation: &str) {
289 if let Some(table) = self.stable.get_mut(relation) {
290 table.clear();
291 }
292 if let Some(table) = self.delta.get_mut(relation) {
293 table.clear();
294 }
295 if let Some(table) = self.next_delta.get_mut(relation) {
296 table.clear();
297 }
298 self.stable_indexes.retain(|(rel, _), _| rel != relation);
300 self.delta_indexes.retain(|(rel, _), _| rel != relation);
301 }
302
303 fn relation_names(&self) -> Vec<String> {
304 self.stable.keys().cloned().collect()
305 }
306}
307
/// One recorded derivation: a newly inserted tuple together with the tuples
/// that were being iterated over when it was inserted.
#[derive(Debug, Clone)]
pub struct ProvenanceEntry {
    /// Relation name and tuple of the fact that was inserted.
    pub derived: (String, Vec<Value>),
    /// Relation name and tuple of each enclosing iteration at insertion time,
    /// outermost first.
    pub premises: Vec<(String, Vec<Value>)>,
}
316
/// Collects [`ProvenanceEntry`] values while the interpreter runs.
#[derive(Default)]
pub struct ProvenanceRecorder {
    /// Derivations recorded so far, in insertion order.
    pub entries: Vec<ProvenanceEntry>,
    /// Stack of the tuples currently being iterated over; the interpreter
    /// pushes one entry per iterated tuple and pops it after running the body.
    active_premises: Vec<(String, Vec<Value>)>,
}
329
/// Executes physical-plan operations against a [`Store`].
pub struct Interpreter<'a> {
    /// IR used to resolve interned names and strings.
    ir: &'a Ir,
    /// Store that is scanned and inserted into during execution.
    store: Box<dyn Store + 'a>,
    /// Provenance recorder; `None` unless enabled via `with_provenance`.
    provenance: Option<ProvenanceRecorder>,
}
336
/// Variable bindings in effect while an operation tree is executed.
struct Env {
    /// Current value of each bound variable, keyed by the variable's name id.
    vars: HashMap<NameId, Value>,
}
340
341impl Env {
342 fn new() -> Self {
343 Self {
344 vars: HashMap::new(),
345 }
346 }
347}
348
349impl<'a> Interpreter<'a> {
350 pub fn new(ir: &'a Ir, store: Box<dyn Store + 'a>) -> Self {
351 Self {
352 ir,
353 store,
354 provenance: None,
355 }
356 }
357
358 pub fn with_provenance(mut self) -> Self {
361 self.provenance = Some(ProvenanceRecorder::default());
362 self
363 }
364
365 pub fn store(&self) -> &dyn Store {
367 &*self.store
368 }
369
370 pub fn store_mut(&mut self) -> &mut dyn Store {
372 &mut *self.store
373 }
374
375 pub fn into_store(self) -> Box<dyn Store + 'a> {
377 self.store
378 }
379
380 pub fn into_provenance(self) -> Option<ProvenanceRecorder> {
382 self.provenance
383 }
384
385 pub fn into_parts(self) -> (Box<dyn Store + 'a>, Option<ProvenanceRecorder>) {
387 (self.store, self.provenance)
388 }
389
390 pub fn execute(&mut self, op: &Op) -> Result<usize> {
392 let mut env = Env::new();
393 self.exec_op(op, &mut env)
394 }
395
396 fn exec_op(&mut self, op: &Op, env: &mut Env) -> Result<usize> {
397 match op {
398 Op::Nop => Ok(0),
399 Op::Seq(ops) => {
400 let mut count = 0;
401 for o in ops {
402 count += self.exec_op(o, env)?;
403 }
404 Ok(count)
405 }
406 Op::Iterate { source, body } => {
407 let mut count = 0;
408 match source {
409 DataSource::Scan { relation, vars } => {
410 let rel_name = self.ir.resolve_name(*relation);
411 let iter = self.store.scan(rel_name)?;
412 let tuples: Vec<_> = iter.collect();
413
414 for tuple in tuples {
415 if tuple.len() != vars.len() {
416 continue;
417 }
418 for (i, var) in vars.iter().enumerate() {
419 env.vars.insert(*var, tuple[i].clone());
420 }
421 if let Some(ref mut prov) = self.provenance {
422 prov.active_premises
423 .push((rel_name.to_string(), tuple.clone()));
424 }
425 count += self.exec_op(body, env)?;
426 if self.provenance.is_some() {
427 self.provenance.as_mut().unwrap().active_premises.pop();
428 }
429 }
430 }
431 DataSource::ScanDelta { relation, vars } => {
432 let rel_name = self.ir.resolve_name(*relation);
433 let iter = self.store.scan_delta(rel_name)?;
434 let tuples: Vec<_> = iter.collect();
435
436 for tuple in tuples {
437 if tuple.len() != vars.len() {
438 continue;
439 }
440 for (i, var) in vars.iter().enumerate() {
441 env.vars.insert(*var, tuple[i].clone());
442 }
443 if let Some(ref mut prov) = self.provenance {
444 prov.active_premises
445 .push((rel_name.to_string(), tuple.clone()));
446 }
447 count += self.exec_op(body, env)?;
448 if self.provenance.is_some() {
449 self.provenance.as_mut().unwrap().active_premises.pop();
450 }
451 }
452 }
453 DataSource::IndexLookup {
454 relation,
455 col_idx,
456 key,
457 vars,
458 } => {
459 let rel_name = self.ir.resolve_name(*relation);
460 let key_val = self.eval_operand(key, env)?;
461
462 let iter = self.store.scan_index(rel_name, *col_idx, &key_val)?;
463 let tuples: Vec<_> = iter.collect();
464
465 for tuple in tuples {
466 if tuple.len() != vars.len() {
467 continue;
468 }
469 for (i, var) in vars.iter().enumerate() {
470 env.vars.insert(*var, tuple[i].clone());
471 }
472 if let Some(ref mut prov) = self.provenance {
473 prov.active_premises
474 .push((rel_name.to_string(), tuple.clone()));
475 }
476 count += self.exec_op(body, env)?;
477 if self.provenance.is_some() {
478 self.provenance.as_mut().unwrap().active_premises.pop();
479 }
480 }
481 }
482 }
483 Ok(count)
484 }
485 Op::Filter { cond, body } => {
486 if self.eval_cond(cond, env)? {
487 self.exec_op(body, env)
488 } else {
489 Ok(0)
490 }
491 }
492 Op::Insert { relation, args } => {
493 let rel_name = self.ir.resolve_name(*relation);
494 let mut tuple = Vec::new();
495 for arg in args {
496 tuple.push(self.eval_operand(arg, env)?);
497 }
498 let is_new = self.store.insert(rel_name, tuple.clone())?;
499 if is_new {
500 if let Some(ref mut prov) = self.provenance {
501 prov.entries.push(ProvenanceEntry {
502 derived: (rel_name.to_string(), tuple),
503 premises: prov.active_premises.clone(),
504 });
505 }
506 Ok(1)
507 } else {
508 Ok(0)
509 }
510 }
511 Op::Let { var, expr, body } => {
512 let val = self.eval_expr(expr, env)?;
513 env.vars.insert(*var, val);
514 self.exec_op(body, env)
515 }
516 Op::GroupBy {
517 source,
518 vars,
519 keys,
520 aggregates,
521 body,
522 } => {
523 let rel_name = self.ir.resolve_name(*source);
524
525 let iter = self.store.scan(rel_name)?;
528 let mut tuples: Vec<_> = iter.collect();
529
530 if let Ok(nd_iter) = self.store.scan_next_delta(rel_name) {
532 tuples.extend(nd_iter);
533 }
534
535 let mut groups: HashMap<Vec<Value>, Vec<Vec<Value>>> = HashMap::new();
536
537 for tuple in tuples {
538 if tuple.len() != vars.len() {
539 continue;
540 }
541 for (i, var) in vars.iter().enumerate() {
543 env.vars.insert(*var, tuple[i].clone());
544 }
545
546 let mut key = Vec::new();
547 for k in keys {
548 if let Some(val) = env.vars.get(k) {
549 key.push(val.clone());
550 } else {
551 key.push(Value::Null);
553 }
554 }
555 groups.entry(key).or_default().push(tuple);
556 }
557
558 let mut count = 0;
559 for (key, group_tuples) in groups {
560 for (i, k) in keys.iter().enumerate() {
562 env.vars.insert(*k, key[i].clone());
563 }
564
565 for agg in aggregates {
567 let val = self.eval_aggregate(agg, &group_tuples, vars, env)?;
568 env.vars.insert(agg.var, val);
569 }
570
571 count += self.exec_op(body, env)?;
572 }
573 Ok(count)
574 }
575 }
576 }
577
578 fn eval_aggregate(
579 &self,
580 agg: &Aggregate,
581 group: &[Vec<Value>],
582 vars: &[NameId],
583 env: &mut Env,
584 ) -> Result<Value> {
585 let fn_name = self.ir.resolve_name(agg.func);
586 match fn_name {
587 "fn:count" => Ok(Value::Number(group.len() as i64)),
588 "fn:sum" => {
589 let mut sum = 0;
590 let arg = agg
592 .args
593 .first()
594 .ok_or_else(|| anyhow!("fn:sum requires 1 argument"))?;
595
596 for tuple in group {
597 for (i, var) in vars.iter().enumerate() {
599 env.vars.insert(*var, tuple[i].clone());
600 }
601 let val = self.eval_operand(arg, env)?;
602 if let Value::Number(n) = val {
603 sum += n;
604 }
605 }
606 Ok(Value::Number(sum))
607 }
608 "fn:max" => {
609 let mut max_val = None;
610 let arg = agg
611 .args
612 .first()
613 .ok_or_else(|| anyhow!("fn:max requires 1 argument"))?;
614
615 for tuple in group {
616 for (i, var) in vars.iter().enumerate() {
617 env.vars.insert(*var, tuple[i].clone());
618 }
619 let val = self.eval_operand(arg, env)?;
620 match max_val {
621 None => max_val = Some(val),
622 Some(ref m) => {
623 if val > *m {
624 max_val = Some(val);
625 }
626 }
627 }
628 }
629 max_val.ok_or_else(|| anyhow!("fn:max on empty group"))
630 }
631 "fn:min" => {
632 let mut min_val = None;
633 let arg = agg
634 .args
635 .first()
636 .ok_or_else(|| anyhow!("fn:min requires 1 argument"))?;
637
638 for tuple in group {
639 for (i, var) in vars.iter().enumerate() {
640 env.vars.insert(*var, tuple[i].clone());
641 }
642 let val = self.eval_operand(arg, env)?;
643 match min_val {
644 None => min_val = Some(val),
645 Some(ref m) => {
646 if val < *m {
647 min_val = Some(val);
648 }
649 }
650 }
651 }
652 min_val.ok_or_else(|| anyhow!("fn:min on empty group"))
653 }
654 _ => Err(anyhow!("Unknown aggregation function: {fn_name}")),
655 }
656 }
657
658 fn eval_cond(&self, cond: &Condition, env: &Env) -> Result<bool> {
659 match cond {
660 Condition::Cmp { op, left, right } => {
661 let l = self.eval_operand(left, env)?;
662 let r = self.eval_operand(right, env)?;
663 match op {
664 CmpOp::Eq => Ok(l == r),
665 CmpOp::Neq => Ok(l != r),
666 CmpOp::Lt => Ok(l < r),
667 CmpOp::Le => Ok(l <= r),
668 CmpOp::Gt => Ok(l > r),
669 CmpOp::Ge => Ok(l >= r),
670 }
671 }
672 Condition::Negation { relation, args } => {
673 let rel_name = self.ir.resolve_name(*relation);
674 let iter = self.store.scan(rel_name)?;
675 for tuple in iter {
676 let mut mat = true;
677 if tuple.len() != args.len() {
678 continue;
679 }
680 for (i, arg) in args.iter().enumerate() {
681 let val = self.eval_operand(arg, env)?;
682 if tuple[i] != val {
683 mat = false;
684 break;
685 }
686 }
687 if mat {
688 return Ok(false); }
690 }
691 Ok(true) }
693 Condition::Call { .. } => {
694 Ok(true)
696 }
697 }
698 }
699
700 fn eval_expr(&self, expr: &Expr, env: &Env) -> Result<Value> {
701 match expr {
702 Expr::Value(op) => self.eval_operand(op, env),
703 Expr::Call { function, args } => {
704 let fn_name = self.ir.resolve_name(*function);
705 let mut vals = Vec::new();
706 for arg in args {
707 vals.push(self.eval_operand(arg, env)?);
708 }
709 match fn_name {
710 "fn:plus" => {
711 if let (Value::Number(a), Value::Number(b)) = (&vals[0], &vals[1]) {
712 Ok(Value::Number(a + b))
713 } else {
714 Err(anyhow!("Type mismatch for fn:plus"))
715 }
716 }
717 "fn:minus" => {
718 if let (Value::Number(a), Value::Number(b)) = (&vals[0], &vals[1]) {
719 Ok(Value::Number(a - b))
720 } else {
721 Err(anyhow!("Type mismatch for fn:minus"))
722 }
723 }
724 _ => Err(anyhow!("Unknown function: {fn_name}")),
725 }
726 }
727 }
728 }
729
730 fn eval_operand(&self, op: &Operand, env: &Env) -> Result<Value> {
731 match op {
732 Operand::Var(v) => env
733 .vars
734 .get(v)
735 .cloned()
736 .ok_or_else(|| anyhow!("Variable not found")),
737 Operand::Const(c) => match c {
738 Constant::Number(n) => Ok(Value::Number(*n)),
739 Constant::String(sid) => {
740 Ok(Value::String(self.ir.resolve_string(*sid).to_string()))
741 }
742 Constant::Name(nid) => Ok(Value::String(self.ir.resolve_name(*nid).to_string())),
743 },
744 }
745 }
746}
747
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a numeric value.
    fn num(n: i64) -> Value {
        Value::Number(n)
    }

    /// Retracting a stored tuple reports success and leaves the other row.
    #[test]
    fn test_retract_existing() {
        let mut store = MemStore::new();
        store.add_fact("r", vec![num(1), num(2)]);
        store.add_fact("r", vec![num(3), num(4)]);

        let was_removed = store.retract("r", &[num(1), num(2)]).unwrap();
        assert!(was_removed);

        let remaining = store.get_facts("r");
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0], vec![num(3), num(4)]);
    }

    /// Retracting an unknown tuple reports failure and changes nothing.
    #[test]
    fn test_retract_nonexistent() {
        let mut store = MemStore::new();
        store.add_fact("r", vec![num(1)]);

        let was_removed = store.retract("r", &[num(99)]).unwrap();
        assert!(!was_removed);

        let remaining = store.get_facts("r");
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0], vec![num(1)]);
    }

    /// Clearing one relation empties it and leaves other relations alone.
    #[test]
    fn test_clear() {
        let mut store = MemStore::new();
        for n in [1, 2] {
            store.add_fact("r", vec![num(n)]);
        }
        store.add_fact("s", vec![num(10)]);

        store.clear("r");

        assert!(store.get_facts("r").is_empty());
        assert_eq!(store.get_facts("s").len(), 1);
    }

    /// Relations created explicitly or implicitly via `add_fact` are listed.
    #[test]
    fn test_relation_names() {
        let mut store = MemStore::new();
        store.create_relation("alpha");
        store.create_relation("beta");
        store.add_fact("gamma", vec![num(1)]);

        let mut listed = store.relation_names();
        listed.sort();
        assert_eq!(listed, vec!["alpha", "beta", "gamma"]);
    }

    /// Copying `base(X)` into `derived(X)` records one provenance entry per
    /// derived tuple, each with the scanned `base` tuple as its only premise.
    #[test]
    fn test_provenance_recording() {
        use mangle_ir::physical::{DataSource, Operand};

        let mut ir = mangle_ir::Ir::new();
        let base_name = ir.intern_name("base");
        let derived_name = ir.intern_name("derived");
        let var_x = ir.intern_name("X");

        let insert_derived = Op::Insert {
            relation: derived_name,
            args: vec![Operand::Var(var_x)],
        };
        let op = Op::Iterate {
            source: DataSource::Scan {
                relation: base_name,
                vars: vec![var_x],
            },
            body: Box::new(insert_derived),
        };

        let mut mem = MemStore::new();
        mem.add_fact("base", vec![num(10)]);
        mem.add_fact("base", vec![num(20)]);
        mem.create_relation("derived");
        let boxed: Box<dyn Store> = Box::new(mem);

        let mut interpreter = Interpreter::new(&ir, boxed).with_provenance();

        let inserted = interpreter.execute(&op).unwrap();
        assert_eq!(inserted, 2);

        let prov = interpreter.provenance.as_ref().unwrap();
        assert_eq!(prov.entries.len(), 2);

        for entry in &prov.entries {
            assert_eq!(entry.derived.0, "derived");
            assert_eq!(entry.premises.len(), 1);
            assert_eq!(entry.premises[0].0, "base");
        }

        let mut derived_vals: Vec<i64> = Vec::new();
        for entry in &prov.entries {
            match &entry.derived.1[0] {
                Value::Number(n) => derived_vals.push(*n),
                _ => panic!("expected number"),
            }
        }
        derived_vals.sort();
        assert_eq!(derived_vals, vec![10, 20]);
    }
}