1use std::collections::{HashMap, HashSet};
19
20use mangle_common::Value;
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct FactKey {
25 pub relation: String,
26 pub tuple: Vec<Value>,
27}
28
29#[derive(Debug, Clone)]
31pub struct Derivation {
32 pub rule_id: usize,
33 pub premises: Vec<FactKey>,
34}
35
36pub struct ProvenanceIndex {
38 derivations: HashMap<FactKey, Vec<Derivation>>,
40 dependents: HashMap<FactKey, HashSet<FactKey>>,
42}
43
44impl ProvenanceIndex {
45 pub fn new() -> Self {
46 Self {
47 derivations: HashMap::new(),
48 dependents: HashMap::new(),
49 }
50 }
51
52 pub fn record(&mut self, derived: FactKey, rule_id: usize, premises: Vec<FactKey>) {
55 for premise in &premises {
57 self.dependents
58 .entry(premise.clone())
59 .or_default()
60 .insert(derived.clone());
61 }
62
63 self.derivations
65 .entry(derived)
66 .or_default()
67 .push(Derivation { rule_id, premises });
68 }
69
70 pub fn get_derivations(&self, fact: &FactKey) -> Option<&[Derivation]> {
72 self.derivations.get(fact).map(|v| v.as_slice())
73 }
74
75 pub fn get_dependents(&self, fact: &FactKey) -> Option<&HashSet<FactKey>> {
77 self.dependents.get(fact)
78 }
79
80 pub fn delete_phase(&mut self, retracted: &FactKey) -> HashSet<FactKey> {
83 let mut to_delete = HashSet::new();
84 let mut worklist = vec![retracted.clone()];
85
86 while let Some(fact) = worklist.pop() {
87 let dependents = self.dependents.get(&fact).cloned().unwrap_or_default();
89
90 for dependent in dependents {
91 if to_delete.contains(&dependent) {
92 continue;
93 }
94
95 if let Some(derivations) = self.derivations.get_mut(&dependent) {
97 derivations.retain(|d| !d.premises.contains(&fact));
98
99 if derivations.is_empty() {
100 to_delete.insert(dependent.clone());
102 worklist.push(dependent);
103 }
104 }
105 }
106 }
107
108 for fact in &to_delete {
110 self.derivations.remove(fact);
111 self.dependents.remove(fact);
112 }
113 for deps in self.dependents.values_mut() {
115 deps.retain(|d| !to_delete.contains(d));
116 }
117
118 to_delete
119 }
120}
121
122impl Default for ProvenanceIndex {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the key for a fact of relation `rel` holding `tuple`.
    fn fact(rel: &str, tuple: Vec<Value>) -> FactKey {
        let relation = rel.to_string();
        FactKey { relation, tuple }
    }

    /// A single recorded derivation is visible from both the derived fact and
    /// its premise.
    #[test]
    fn test_provenance_basic() {
        let edge12 = fact("edge", vec![Value::Number(1), Value::Number(2)]);
        let reach12 = fact("reachable", vec![Value::Number(1), Value::Number(2)]);

        let mut index = ProvenanceIndex::new();
        index.record(reach12.clone(), 0, vec![edge12.clone()]);

        let recorded = index.get_derivations(&reach12).unwrap();
        assert_eq!(recorded.len(), 1);

        let downstream = index.get_dependents(&edge12).unwrap();
        assert!(downstream.contains(&reach12));
    }

    /// Retracting edge(1,2) removes reachable(1,2) and, through it,
    /// reachable(1,3); reachable(2,3) rests on edge(2,3) alone and stays.
    #[test]
    fn test_dred_simple() {
        let edge12 = fact("edge", vec![Value::Number(1), Value::Number(2)]);
        let edge23 = fact("edge", vec![Value::Number(2), Value::Number(3)]);
        let reach12 = fact("reachable", vec![Value::Number(1), Value::Number(2)]);
        let reach23 = fact("reachable", vec![Value::Number(2), Value::Number(3)]);
        let reach13 = fact("reachable", vec![Value::Number(1), Value::Number(3)]);

        let mut index = ProvenanceIndex::new();
        index.record(reach12.clone(), 0, vec![edge12.clone()]);
        index.record(reach23.clone(), 0, vec![edge23.clone()]);
        index.record(reach13.clone(), 1, vec![reach12.clone(), edge23.clone()]);

        let removed = index.delete_phase(&edge12);

        assert!(removed.contains(&reach12));
        assert!(removed.contains(&reach13));
        assert!(!removed.contains(&reach23));
    }

    /// A fact with a second, independent derivation survives the loss of the
    /// first and keeps exactly that remaining derivation.
    #[test]
    fn test_dred_multiple_derivations() {
        let a = fact("a", vec![Value::Number(1)]);
        let b = fact("b", vec![Value::Number(1)]);
        let c = fact("c", vec![Value::Number(1)]);

        let mut index = ProvenanceIndex::new();
        index.record(c.clone(), 0, vec![a.clone()]);
        index.record(c.clone(), 1, vec![b.clone()]);

        let removed = index.delete_phase(&a);

        assert!(removed.is_empty());
        let remaining = index.get_derivations(&c).unwrap();
        assert_eq!(remaining.len(), 1);
    }
}