egglog_core_relations/uf/
mod.rs

1//! A table implementation backed by a union-find.
2
3use std::{
4    any::Any,
5    mem,
6    sync::{Arc, Weak},
7};
8
9use crate::numeric_id::{DenseIdMap, NumericId};
10use crossbeam_queue::SegQueue;
11use indexmap::IndexMap;
12use petgraph::{Direction, Graph, algo::dijkstra, graph::NodeIndex, visit::EdgeRef};
13
14use crate::{
15    TableChange, TaggedRowBuffer,
16    action::ExecutionState,
17    common::{HashMap, IndexSet, Value},
18    offsets::{OffsetRange, RowId, Subset, SubsetRef},
19    pool::with_pool_set,
20    row_buffer::RowBuffer,
21    table_spec::{
22        ColumnId, Constraint, Generation, MutationBuffer, Offset, Rebuilder, Row, Table, TableSpec,
23        TableVersion, WrappedTableRef,
24    },
25};
26
27#[cfg(test)]
28mod tests;
29
/// The crate's union-find, specialized to [`Value`] ids; used by the tables in
/// this module to compute canonical values.
type UnionFind = crate::union_find::UnionFind<Value>;
31
/// A special table backed by a union-find used to efficiently implement
/// egglog-style canonicalization.
///
/// To canonicalize columns, we need to efficiently discover values that have
/// ceased to be canonical. To do that we keep a table of _displaced_ values.
///
/// This table has three columns:
/// 1. (the only key): a value that is _no longer canonical_ in the equivalence relation.
/// 2. The canonical value of the equivalence class.
/// 3. The timestamp at which the key stopped being canonical.
///
/// We do not store the second value explicitly: instead, we compute it
/// on-the-fly using a union-find data-structure.
///
/// This is related to the 'Leader' encoding in some versions of egglog:
/// Displaced is a version of Leader that _only_ stores ids when they cease to
/// be canonical. Rows are also "automatically updated" with the current leader,
/// rather than requiring the DB to replay history or canonicalize redundant
/// values in the table.
///
/// To union new ids `l`, and `r`, stage an update `Displaced(l, r, ts)` where
/// `ts` is the current timestamp. Note that all tie-breaks and other encoding
/// decisions are made internally, so there may not literally be a row added
/// with this value.
pub struct DisplacedTable {
    /// The union-find from which the canonical-value column is computed.
    uf: UnionFind,
    /// One `(displaced value, timestamp)` pair per row, in insertion order.
    /// Inserts assert that timestamps never decrease, so this is sorted by
    /// timestamp.
    displaced: Vec<(Value, Value)>,
    /// Set when an insert changes the table; consumed (and reset) by `merge`.
    changed: bool,
    /// Maps each displaced value to the row in `displaced` that records it.
    lookup_table: HashMap<Value, RowId>,
    /// Row buffers handed over by dropped mutation buffers, waiting to be
    /// applied by the next `merge`.
    buffered_writes: Arc<SegQueue<RowBuffer>>,
}
63
/// A [`Rebuilder`] that rewrites values to their current representative in a
/// [`DisplacedTable`]'s union-find.
struct Canonicalizer<'a> {
    /// The columns whose values get canonicalized when rebuilding rows.
    cols: Vec<ColumnId>,
    /// The table whose union-find supplies canonical values.
    table: &'a DisplacedTable,
}
68
69impl Rebuilder for Canonicalizer<'_> {
70    fn hint_col(&self) -> Option<ColumnId> {
71        Some(ColumnId::new(0))
72    }
73    fn rebuild_val(&self, val: Value) -> Value {
74        self.table.uf.find_naive(val)
75    }
76    fn rebuild_buf(
77        &self,
78        buf: &RowBuffer,
79        start: RowId,
80        end: RowId,
81        out: &mut TaggedRowBuffer,
82        _exec_state: &mut ExecutionState,
83    ) {
84        if start >= end {
85            return;
86        }
87        assert!(end.index() <= buf.len());
88        let mut cur = start;
89        let mut scratch = with_pool_set(|ps| ps.get::<Vec<Value>>());
90        // SAFETY: `cur` is always in-bounds, guaranteed by the above assertion.
91        // Special-case small columns: this gives us a modest speedup on rebuilding-heavy
92        // workloads.
93        match self.cols.as_slice() {
94            [c] => {
95                while cur < end {
96                    let row = unsafe { buf.get_row_unchecked(cur) };
97                    let to_canon = row[c.index()];
98                    let canon = self.table.uf.find_naive(to_canon);
99                    if canon != to_canon {
100                        scratch.extend_from_slice(row);
101                        scratch[c.index()] = canon;
102                        out.add_row(cur, &scratch);
103                        scratch.clear();
104                    }
105                    cur = cur.inc();
106                }
107            }
108            [c1, c2] => {
109                while cur < end {
110                    let row = unsafe { buf.get_row_unchecked(cur) };
111                    let v1 = row[c1.index()];
112                    let v2 = row[c2.index()];
113                    let ca1 = self.table.uf.find_naive(v1);
114                    let ca2 = self.table.uf.find_naive(v2);
115                    if ca1 != v1 || ca2 != v2 {
116                        scratch.extend_from_slice(row);
117                        scratch[c1.index()] = ca1;
118                        scratch[c2.index()] = ca2;
119                        out.add_row(cur, &scratch);
120                        scratch.clear();
121                    }
122                    cur = cur.inc();
123                }
124            }
125            [c1, c2, c3] => {
126                while cur < end {
127                    let row = unsafe { buf.get_row_unchecked(cur) };
128                    let v1 = row[c1.index()];
129                    let v2 = row[c2.index()];
130                    let v3 = row[c3.index()];
131                    let ca1 = self.table.uf.find_naive(v1);
132                    let ca2 = self.table.uf.find_naive(v2);
133                    let ca3 = self.table.uf.find_naive(v3);
134                    if ca1 != v1 || ca2 != v2 || ca3 != v3 {
135                        scratch.extend_from_slice(row);
136                        scratch[c1.index()] = ca1;
137                        scratch[c2.index()] = ca2;
138                        scratch[c3.index()] = ca3;
139                        out.add_row(cur, &scratch);
140                        scratch.clear();
141                    }
142                    cur = cur.inc();
143                }
144            }
145            cs => {
146                while cur < end {
147                    scratch.extend_from_slice(unsafe { buf.get_row_unchecked(cur) });
148                    let mut changed = false;
149                    for c in cs {
150                        let to_canon = scratch[c.index()];
151                        let canon = self.table.uf.find_naive(to_canon);
152                        scratch[c.index()] = canon;
153                        changed |= canon != to_canon;
154                    }
155                    if changed {
156                        out.add_row(cur, &scratch);
157                    }
158                    scratch.clear();
159                    cur = cur.inc();
160                }
161            }
162        }
163    }
164    fn rebuild_subset(
165        &self,
166        other: WrappedTableRef,
167        subset: SubsetRef,
168        out: &mut TaggedRowBuffer,
169        _exec_state: &mut ExecutionState,
170    ) {
171        let _next = other.scan_bounded(subset, Offset::new(0), usize::MAX, out);
172        debug_assert!(_next.is_none());
173        for i in 0..u32::try_from(out.len()).expect("row buffer sizes should fit in a u32") {
174            let i = RowId::new(i);
175            let (_id, row) = out.get_row_mut(i);
176            let mut changed = false;
177            for col in &self.cols {
178                let to_canon = row[col.index()];
179                let canon = self.table.uf.find_naive(to_canon);
180                changed |= canon != to_canon;
181                row[col.index()] = canon;
182            }
183            if !changed {
184                out.set_stale(i);
185            }
186        }
187    }
188    fn rebuild_slice(&self, vals: &mut [Value]) -> bool {
189        let mut changed = false;
190        for val in vals {
191            let canon = self.table.uf.find_naive(*val);
192            changed |= canon != *val;
193            *val = canon;
194        }
195        changed
196    }
197}
198
199impl Default for DisplacedTable {
200    fn default() -> Self {
201        Self {
202            uf: UnionFind::default(),
203            displaced: Vec::new(),
204            changed: false,
205            lookup_table: HashMap::default(),
206            buffered_writes: Arc::new(SegQueue::new()),
207        }
208    }
209}
210
211impl Clone for DisplacedTable {
212    fn clone(&self) -> Self {
213        DisplacedTable {
214            uf: self.uf.clone(),
215            displaced: self.displaced.clone(),
216            changed: self.changed,
217            lookup_table: self.lookup_table.clone(),
218            buffered_writes: Default::default(),
219        }
220    }
221}
222
/// The [`MutationBuffer`] handed out by the union-find tables in this module.
///
/// Rows are staged locally and handed to the owning table's queue when the
/// buffer is dropped.
struct UfBuffer {
    /// Rows staged so far through this handle.
    to_insert: RowBuffer,
    /// Weak handle to the queue that receives `to_insert` on drop. If the
    /// queue no longer exists at that point, the staged rows are discarded.
    buffered_writes: Weak<SegQueue<RowBuffer>>,
}
227
228impl Drop for UfBuffer {
229    fn drop(&mut self) {
230        let Some(buffered_writes) = self.buffered_writes.upgrade() else {
231            return;
232        };
233        let arity = self.to_insert.arity();
234        buffered_writes.push(mem::replace(&mut self.to_insert, RowBuffer::new(arity)));
235    }
236}
237
238impl MutationBuffer for UfBuffer {
239    fn stage_insert(&mut self, row: &[Value]) {
240        self.to_insert.add_row(row);
241    }
242    fn stage_remove(&mut self, _: &[Value]) {
243        panic!("attempting to remove data from a DisplacedTable")
244    }
245    fn fresh_handle(&self) -> Box<dyn MutationBuffer> {
246        Box::new(UfBuffer {
247            to_insert: RowBuffer::new(self.to_insert.arity()),
248            buffered_writes: self.buffered_writes.clone(),
249        })
250    }
251}
252
253impl Table for DisplacedTable {
254    fn dyn_clone(&self) -> Box<dyn Table> {
255        Box::new(self.clone())
256    }
257    fn as_any(&self) -> &dyn Any {
258        self
259    }
260    fn spec(&self) -> TableSpec {
261        let mut uncacheable_columns = DenseIdMap::default();
262        // The second column of this table is determined dynamically by the union-find.
263        uncacheable_columns.insert(ColumnId::new(1), true);
264        TableSpec {
265            n_keys: 1,
266            n_vals: 2,
267            uncacheable_columns,
268            allows_delete: false,
269        }
270    }
271
272    fn rebuilder<'a>(&'a self, cols: &[ColumnId]) -> Option<Box<dyn Rebuilder + 'a>> {
273        Some(Box::new(Canonicalizer {
274            cols: cols.to_vec(),
275            table: self,
276        }))
277    }
278
279    fn clear(&mut self) {
280        self.uf.reset();
281        self.displaced.clear();
282    }
283
284    fn all(&self) -> Subset {
285        Subset::Dense(OffsetRange::new(
286            RowId::new(0),
287            RowId::from_usize(self.displaced.len()),
288        ))
289    }
290
291    fn len(&self) -> usize {
292        self.displaced.len()
293    }
294
295    fn version(&self) -> TableVersion {
296        TableVersion {
297            major: Generation::new(0),
298            minor: Offset::from_usize(self.displaced.len()),
299        }
300    }
301
302    fn updates_since(&self, offset: Offset) -> Subset {
303        Subset::Dense(OffsetRange::new(
304            RowId::from_usize(offset.index()),
305            RowId::from_usize(self.displaced.len()),
306        ))
307    }
308
309    fn scan_generic_bounded(
310        &self,
311        subset: SubsetRef,
312        start: Offset,
313        n: usize,
314        cs: &[Constraint],
315        mut f: impl FnMut(RowId, &[Value]),
316    ) -> Option<Offset>
317    where
318        Self: Sized,
319    {
320        if cs.is_empty() {
321            let start = start.index();
322            subset
323                .iter_bounded(start, start + n, |row| {
324                    f(row, self.expand(row).as_slice());
325                })
326                .map(Offset::from_usize)
327        } else {
328            let start = start.index();
329            subset
330                .iter_bounded(start, start + n, |row| {
331                    if cs.iter().all(|c| self.eval(c, row)) {
332                        f(row, self.expand(row).as_slice());
333                    }
334                })
335                .map(Offset::from_usize)
336        }
337    }
338
339    fn refine_one(&self, mut subset: Subset, c: &Constraint) -> Subset {
340        subset.retain(|row| self.eval(c, row));
341        subset
342    }
343
344    fn fast_subset(&self, constraint: &Constraint) -> Option<Subset> {
345        let ts = ColumnId::new(2);
346        match constraint {
347            Constraint::Eq { .. } => None,
348            Constraint::EqConst { col, val } => {
349                if *col == ColumnId::new(1) {
350                    return None;
351                }
352                if *col == ColumnId::new(0) {
353                    return Some(match self.lookup_table.get(val) {
354                        Some(row) => Subset::Dense(OffsetRange::new(
355                            *row,
356                            RowId::from_usize(row.index() + 1),
357                        )),
358                        None => Subset::empty(),
359                    });
360                }
361                match self.timestamp_bounds(*val) {
362                    Ok((start, end)) => Some(Subset::Dense(OffsetRange::new(start, end))),
363                    Err(_) => None,
364                }
365            }
366            Constraint::LtConst { col, val } => {
367                if *col != ts {
368                    return None;
369                }
370                match self.timestamp_bounds(*val) {
371                    Err(bound) | Ok((bound, _)) => {
372                        Some(Subset::Dense(OffsetRange::new(RowId::new(0), bound)))
373                    }
374                }
375            }
376            Constraint::GtConst { col, val } => {
377                if *col != ts {
378                    return None;
379                }
380
381                match self.timestamp_bounds(*val) {
382                    Err(bound) | Ok((_, bound)) => Some(Subset::Dense(OffsetRange::new(
383                        bound,
384                        RowId::from_usize(self.displaced.len()),
385                    ))),
386                }
387            }
388            Constraint::LeConst { col, val } => {
389                if *col != ts {
390                    return None;
391                }
392
393                match self.timestamp_bounds(*val) {
394                    Err(bound) | Ok((_, bound)) => {
395                        Some(Subset::Dense(OffsetRange::new(RowId::new(0), bound)))
396                    }
397                }
398            }
399            Constraint::GeConst { col, val } => {
400                if *col != ts {
401                    return None;
402                }
403
404                match self.timestamp_bounds(*val) {
405                    Err(bound) | Ok((bound, _)) => Some(Subset::Dense(OffsetRange::new(
406                        bound,
407                        RowId::from_usize(self.displaced.len()),
408                    ))),
409                }
410            }
411        }
412    }
413
414    fn get_row(&self, key: &[Value]) -> Option<Row> {
415        assert_eq!(key.len(), 1, "attempt to lookup a row with the wrong key");
416        let row_id = *self.lookup_table.get(&key[0])?;
417        let mut vals = with_pool_set(|ps| ps.get::<Vec<Value>>());
418        vals.extend_from_slice(self.expand(row_id).as_slice());
419        Some(Row { id: row_id, vals })
420    }
421
422    fn get_row_column(&self, key: &[Value], col: ColumnId) -> Option<Value> {
423        assert_eq!(key.len(), 1, "attempt to lookup a row with the wrong key");
424        if col == ColumnId::new(1) {
425            Some(self.uf.find_naive(key[0]))
426        } else {
427            let row_id = *self.lookup_table.get(&key[0])?;
428            Some(self.expand(row_id)[col.index()])
429        }
430    }
431
432    fn new_buffer(&self) -> Box<dyn MutationBuffer> {
433        Box::new(UfBuffer {
434            to_insert: RowBuffer::new(3),
435            buffered_writes: Arc::downgrade(&self.buffered_writes),
436        })
437    }
438
439    fn merge(&mut self, _: &mut ExecutionState) -> TableChange {
440        while let Some(rowbuf) = self.buffered_writes.pop() {
441            for row in rowbuf.iter() {
442                self.changed |= self.insert_impl(row).is_some();
443            }
444        }
445        let changed = mem::take(&mut self.changed);
446        // UF table rows can be updated "in place", we count both added and removed as changed in
447        // this case.
448        TableChange {
449            added: changed,
450            removed: changed,
451        }
452    }
453}
454
455impl DisplacedTable {
456    pub fn underlying_uf(&self) -> &UnionFind {
457        &self.uf
458    }
459    fn expand(&self, row: RowId) -> [Value; 3] {
460        let (child, ts) = self.displaced[row.index()];
461        [child, self.uf.find_naive(child), ts]
462    }
463    fn timestamp_bounds(&self, val: Value) -> Result<(RowId, RowId), RowId> {
464        match self.displaced.binary_search_by_key(&val, |(_, ts)| *ts) {
465            Ok(mut off) => {
466                let mut next = off;
467                while off > 0 && self.displaced[off - 1].1 == val {
468                    off -= 1;
469                }
470                while next < self.displaced.len() && self.displaced[next].1 == val {
471                    next += 1;
472                }
473                Ok((RowId::from_usize(off), RowId::from_usize(next)))
474            }
475            Err(off) => Err(RowId::from_usize(off)),
476        }
477    }
478    fn eval(&self, constraint: &Constraint, row: RowId) -> bool {
479        let vals = self.expand(row);
480        eval_constraint(&vals, constraint)
481    }
482    fn insert_impl(&mut self, row: &[Value]) -> Option<(Value, Value)> {
483        assert_eq!(row.len(), 3, "attempt to insert a row with the wrong arity");
484        if self.uf.find(row[0]) == self.uf.find(row[1]) {
485            return None;
486        }
487        let (parent, child) = self.uf.union(row[0], row[1]);
488
489        // Compress paths somewhat, given that we perform naive finds everywhere else.
490        let _ = self.uf.find(parent);
491        let _ = self.uf.find(child);
492        let ts = row[2];
493        if let Some((_, highest)) = self.displaced.last() {
494            assert!(
495                *highest <= ts,
496                "must insert rows with increasing timestamps"
497            );
498        }
499        let next = RowId::from_usize(self.displaced.len());
500        self.displaced.push((child, ts));
501        self.lookup_table.insert(child, next);
502        Some((parent, child))
503    }
504}
505
506/// A variant of `DisplacedTable` that also stores "provenance" information that
507/// can be used to generate proofs of equality.
508///
509/// This table expects a fourth "proof" column, though the values it hands back
510/// _are not_ the proofs that come in and generally should not be used directly.
511/// To generate a proof that two values are equal, this table exports a separate
512/// `get_proof` method.
513#[derive(Clone, Default)]
514pub struct DisplacedTableWithProvenance {
515    base: DisplacedTable,
516    /// Added context for a given "displaced" row. We use this to store "proofs
517    /// that x = y".
518    ///
519    /// N.B. We currently only use the first proof that we find. The remaining
520    /// proofs are used for debugging. With some further refactoring we should
521    /// be able to remove this field entirely, as complete proof information is
522    /// now available through `proof_graph`.
523    context: HashMap<(Value, Value), IndexSet<Value>>,
524    proof_graph: Graph<Value, ProofEdge>,
525    node_map: HashMap<Value, NodeIndex>,
526    /// The value that was displaced, the value _immediately_ displacing it.
527    /// NB: this is different from the 'displaced' table in 'base', which holds
528    /// a timestamp.
529    displaced: Vec<(Value, Value)>,
530    buffered_writes: Arc<SegQueue<RowBuffer>>,
531}
532
/// The weight stored on each edge of the proof graph.
#[derive(Copy, Clone, Eq, PartialEq)]
struct ProofEdge {
    /// The justification for the union, tagged with the edge's direction.
    reason: ProofReason,
    /// The timestamp of the union that added this edge.
    ts: Value,
}
538
/// A single step in a proof returned by
/// `DisplacedTableWithProvenance::get_proof`: `lhs` equals `rhs` because of
/// `reason`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProofStep {
    /// The value at the source of the proof-graph edge.
    pub lhs: Value,
    /// The value at the target of the proof-graph edge.
    pub rhs: Value,
    /// The justification carried by the edge.
    pub reason: ProofReason,
}
545
/// The justification attached to a proof-graph edge. The payload is the value
/// from the fourth ("proof") column of the inserted row.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ProofReason {
    /// The edge runs from the row's first value to its second.
    Forward(Value),
    /// The edge runs from the row's second value to its first.
    Backward(Value),
}
551
552impl DisplacedTableWithProvenance {
553    fn expand(&self, row: RowId) -> [Value; 4] {
554        let [v1, v2, v3] = self.base.expand(row);
555        let (child, parent) = self.displaced[row.index()];
556        debug_assert_eq!(child, v1);
557        let proof = *self.context[&(child, parent)].get_index(0).unwrap();
558        [v1, v2, v3, proof]
559    }
560
561    fn eval(&self, constraint: &Constraint, row: RowId) -> bool {
562        eval_constraint(&self.expand(row), constraint)
563    }
564
565    /// Return the timestamp when `l` and `r` became equal.
566    ///
567    /// This is used to filter possible paths in the proof graph. The algorithm
568    /// we use here is a variant of the classic algorithm in "Proof-Producing
569    /// Congruence Closure" by Nieuwenhuis and Oliveras for reconstructing a
570    /// proof.
571    fn timestamp_when_equal(&self, l: Value, r: Value) -> Option<u32> {
572        if l == r {
573            return Some(0);
574        }
575        let mut l_proofs = IndexMap::new();
576        let mut r_proofs = IndexMap::new();
577        if self.base.uf.find_naive(l) != self.base.uf.find_naive(r) {
578            // The two values aren't equal.
579            return None;
580        }
581        let canon = self.base.uf.find_naive(l);
582
583        // General case: collect individual equality proofs that point from `l`
584        // (sim. `r`) and move towards canon. We stop early and don't always go
585        // to `canon`. To see why consider the following sequences of unions.
586        // For simplicity, we'll assume that the "leader" (or new canonical id)
587        // is always the second argument to `union`.
588        // * left:  A: union(0,2), B: union(2,4), C: union(4,6)
589        // * right: D: union(1,3), E: union(3,5), F: union(5,4), C: union(4,6)
590        // Where `l` `r` are 0 and 1, and their canonical value is `6`.
591        // A simple approach here would be to simply glue the proofs that `l=6`
592        // and `r=6` together, something like:
593        //
594        //    [A;B;C;rev(C);rev(F);rev(E);rev(D)]
595        //
596        // The code below avoids the redundant common suffix (i.e. `C;rev(C)`)
597        // and just uses A,B,D,E, and F.
598        //
599        // In addition to allowing us to generate smaller proofs, this sort of
600        // algorithm also ensures that we are returning the first proof of `l =
601        // r` that we learned about, which is important for avoiding cycles when
602        // reconstructing a proof.
603
604        // General case: create a proof  that l = canon, then compose it with
605        // the proof that r = canon, reversed.
606        for (mut cur, steps) in [(l, &mut l_proofs), (r, &mut r_proofs)] {
607            while cur != canon {
608                // Find where cur became non-canonical.
609                let row = *self.base.lookup_table.get(&cur).unwrap();
610                let (_, ts) = self.base.displaced[row.index()];
611                let (child, parent) = self.displaced[row.index()];
612                debug_assert_eq!(child, cur);
613                steps.insert(parent, ts);
614                cur = parent;
615            }
616        }
617
618        let mut l_end = None;
619        let mut r_start = None;
620
621        if let Some(i) = r_proofs.get_index_of(&l) {
622            r_start = Some(i);
623        } else {
624            for (i, (next_id, _)) in l_proofs.iter().enumerate() {
625                if *next_id == r {
626                    l_end = Some(i);
627                    break;
628                }
629                if let Some(j) = r_proofs.get_index_of(next_id) {
630                    l_end = Some(i);
631                    r_start = Some(j);
632                    break;
633                }
634            }
635        }
636        match (l_end, r_start) {
637            (None, Some(start)) => r_proofs.as_slice()[..=start]
638                .iter()
639                .map(|(_, ts)| ts.rep())
640                .max(),
641            (Some(end), None) => l_proofs.as_slice()[..=end]
642                .iter()
643                .map(|(_, ts)| ts.rep())
644                .max(),
645            (Some(end), Some(start)) => l_proofs.as_slice()[..=end]
646                .iter()
647                .map(|(_, ts)| ts.rep())
648                .chain(r_proofs.as_slice()[..=start].iter().map(|(_, ts)| ts.rep()))
649                .max(),
650            (None, None) => {
651                panic!(
652                    "did not find common id, despite the values being equivalent {l:?} / {r:?}, l_proofs={l_proofs:?}, r_proofs={r_proofs:?}"
653                )
654            }
655        }
656    }
657
658    /// A simple proof generation algorithm that searches for the shortest path
659    /// in the proof graph between `l` and `r`.
660    ///
661    /// The path in the graph is restricted to the timestamps at or before `l`
662    /// and `r` first became equal. This is to avoid cycles during proof
663    /// reconstruction.
664    pub fn get_proof(&self, l: Value, r: Value) -> Option<Vec<ProofStep>> {
665        let ts = self.timestamp_when_equal(l, r)?;
666        let start = self.node_map[&l];
667        let goal = self.node_map[&r];
668        let costs = dijkstra(&self.proof_graph, self.node_map[&l], Some(goal), |edge| {
669            if edge.weight().ts.rep() > ts {
670                // avoid edges added after the two became equal.
671                f64::INFINITY
672            } else {
673                1.0f64
674            }
675        });
676        // Reconstruct the proof steps from the cost map returned from petgraph.
677        // Start at the end and then work backwards along the shortest path.
678        let mut path = Vec::new();
679        let mut cur = goal;
680        while cur != start {
681            let (_, step, next) = self
682                .proof_graph
683                .edges_directed(cur, Direction::Incoming)
684                .filter_map(|edge| {
685                    let source = edge.source();
686                    let cost = costs.get(&source)?;
687                    let step = ProofStep {
688                        lhs: *self.proof_graph.node_weight(source).unwrap(),
689                        rhs: *self.proof_graph.node_weight(edge.target()).unwrap(),
690                        reason: edge.weight().reason,
691                    };
692                    Some((cost, step, source))
693                })
694                .fold(None, |acc, cur| {
695                    // Manually implement 'min' because we are using f64 for costs.
696                    // We should probably switch these edge costs over to NotNan
697                    // or a custom type.
698                    let Some(acc) = acc else {
699                        return Some(cur);
700                    };
701                    Some(if acc.0 > cur.0 { cur } else { acc })
702                })
703                .unwrap();
704            path.push(step);
705            cur = next;
706        }
707        path.reverse();
708        Some(path)
709    }
710    fn get_or_create_node(&mut self, val: Value) -> NodeIndex {
711        *self
712            .node_map
713            .entry(val)
714            .or_insert_with(|| self.proof_graph.add_node(val))
715    }
716
717    fn insert_impl(&mut self, row: &[Value]) {
718        let [a, b, ts, reason] = row else {
719            panic!("attempt to insert a row with the wrong arity ({row:?})");
720        };
721        match self.base.insert_impl(&[*a, *b, *ts]) {
722            Some((parent, child)) => {
723                self.displaced.push((child, parent));
724                self.context
725                    .entry((child, parent))
726                    .or_default()
727                    .insert(*reason);
728                self.base.changed = true;
729
730                let a_node = self.get_or_create_node(*a);
731                let b_node = self.get_or_create_node(*b);
732                self.proof_graph.add_edge(
733                    a_node,
734                    b_node,
735                    ProofEdge {
736                        reason: ProofReason::Forward(*reason),
737                        ts: *ts,
738                    },
739                );
740                self.proof_graph.add_edge(
741                    b_node,
742                    a_node,
743                    ProofEdge {
744                        reason: ProofReason::Backward(*reason),
745                        ts: *ts,
746                    },
747                );
748            }
749            None => {
750                self.context.entry((*a, *b)).or_default().insert(*reason);
751                // We don't register a change, even if we learned a new proof.
752                // We may want to change this behavior in order to search for
753                // smaller proofs.
754            }
755        }
756    }
757}
758
759impl Table for DisplacedTableWithProvenance {
760    fn refine_one(&self, mut subset: Subset, c: &Constraint) -> Subset {
761        subset.retain(|row| self.eval(c, row));
762        subset
763    }
764    fn scan_generic_bounded(
765        &self,
766        subset: SubsetRef,
767        start: Offset,
768        n: usize,
769        cs: &[Constraint],
770        mut f: impl FnMut(RowId, &[Value]),
771    ) -> Option<Offset>
772    where
773        Self: Sized,
774    {
775        if cs.is_empty() {
776            let start = start.index();
777            subset
778                .iter_bounded(start, start + n, |row| {
779                    f(row, self.expand(row).as_slice());
780                })
781                .map(Offset::from_usize)
782        } else {
783            let start = start.index();
784            subset
785                .iter_bounded(start, start + n, |row| {
786                    if cs.iter().all(|c| self.eval(c, row)) {
787                        f(row, self.expand(row).as_slice());
788                    }
789                })
790                .map(Offset::from_usize)
791        }
792    }
793
794    fn spec(&self) -> TableSpec {
795        TableSpec {
796            n_vals: 3,
797            ..self.base.spec()
798        }
799    }
800
801    fn merge(&mut self, exec_state: &mut ExecutionState) -> TableChange {
802        while let Some(rowbuf) = self.buffered_writes.pop() {
803            for row in rowbuf.iter() {
804                self.insert_impl(row);
805            }
806        }
807
808        self.base.merge(exec_state)
809    }
810
811    fn get_row(&self, key: &[Value]) -> Option<Row> {
812        let mut inner = self.base.get_row(key)?;
813        let (child, parent) = self.displaced[inner.id.index()];
814        debug_assert_eq!(child, inner.vals[0]);
815        let proof = *self.context[&(child, parent)].get_index(0).unwrap();
816        inner.vals.push(proof);
817        Some(inner)
818    }
819
820    fn get_row_column(&self, key: &[Value], col: ColumnId) -> Option<Value> {
821        if col == ColumnId::new(3) {
822            let row = *self.base.lookup_table.get(&key[0])?;
823            Some(self.expand(row)[3])
824        } else {
825            self.base.get_row_column(key, col)
826        }
827    }
828
829    fn new_buffer(&self) -> Box<dyn MutationBuffer> {
830        Box::new(UfBuffer {
831            to_insert: RowBuffer::new(4),
832            buffered_writes: Arc::downgrade(&self.buffered_writes),
833        })
834    }
835
836    // Many of these methods just delgate to `base`:
837
838    fn dyn_clone(&self) -> Box<dyn Table> {
839        Box::new(self.clone())
840    }
841    fn as_any(&self) -> &dyn Any {
842        self
843    }
844    fn clear(&mut self) {
845        self.base.clear()
846    }
847    fn all(&self) -> Subset {
848        self.base.all()
849    }
850    fn len(&self) -> usize {
851        self.base.len()
852    }
853    fn updates_since(&self, offset: Offset) -> Subset {
854        self.base.updates_since(offset)
855    }
856    fn version(&self) -> TableVersion {
857        self.base.version()
858    }
859    fn fast_subset(&self, c: &Constraint) -> Option<Subset> {
860        self.base.fast_subset(c)
861    }
862}
863
864fn eval_constraint<const N: usize>(vals: &[Value; N], constraint: &Constraint) -> bool {
865    match constraint {
866        Constraint::Eq { l_col, r_col } => vals[l_col.index()] == vals[r_col.index()],
867        Constraint::EqConst { col, val } => vals[col.index()] == *val,
868        Constraint::LtConst { col, val } => vals[col.index()] < *val,
869        Constraint::GtConst { col, val } => vals[col.index()] > *val,
870        Constraint::LeConst { col, val } => vals[col.index()] <= *val,
871        Constraint::GeConst { col, val } => vals[col.index()] >= *val,
872    }
873}