Skip to main content

egglog_bridge/
lib.rs

1//! An implementation of egglog-style queries on top of core-relations.
2//!
3//! This module translates a well-typed egglog-esque query into the abstractions
4//! from the `core-relations` crate. The main higher-level functionality that it
5//! implements are seminaive evaluation, default values, and merge functions.
6//!
7//! This crate is essentially involved in desugaring: it elaborates the encoding
8//! of core egglog functionality, but it does not implement algorithms for
9//! joins, union-finds, etc.
10
11use std::{
12    cmp,
13    fmt::Debug,
14    hash::Hash,
15    iter, mem,
16    ops::{Index, IndexMut},
17    sync::{Arc, Mutex},
18};
19
20use crate::core_relations::{
21    BaseValue, BaseValueId, BaseValues, ColumnId, Constraint, ContainerValue, ContainerValues,
22    CounterId, Database, DisplacedTable, DisplacedTableWithProvenance, ExecutionState,
23    ExternalFunction, ExternalFunctionId, MergeVal, Offset, PlanStrategy, SortedWritesTable,
24    TableId, TaggedRowBuffer, Value, WrappedTable,
25};
26use crate::numeric_id::{DenseIdMap, DenseIdMapWithReuse, IdVec, NumericId, define_id};
27use egglog_core_relations as core_relations;
28use egglog_numeric_id as numeric_id;
29use egglog_reports::{IterationReport, ReportLevel, RuleSetReport};
30use hashbrown::HashMap;
31use indexmap::{IndexMap, IndexSet, map::Entry};
32use log::info;
33use once_cell::sync::Lazy;
34pub use proof_format::{EqProofId, ProofStore, TermProofId};
35use proof_spec::{ProofReason, ProofReconstructionState, ReasonSpecId};
36use smallvec::SmallVec;
37use web_time::{Duration, Instant};
38
39pub mod macros;
40pub mod proof_format;
41pub(crate) mod proof_spec;
42pub(crate) mod rule;
43pub mod syntax;
44#[cfg(test)]
45mod tests;
46
47pub use rule::{Function, QueryEntry, RuleBuilder};
48pub use syntax::{SourceExpr, SourceSyntax, TopLevelLhsExpr};
49use thiserror::Error;
50
/// The type of a column in an egglog function's schema: either an e-class id
/// managed by the union-find, or a base (primitive) value.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ColumnTy {
    /// An e-class id; values of this type are canonicalized via the union-find.
    Id,
    /// A base value of the given registered base-value type.
    Base(BaseValueId),
}
56
define_id!(pub RuleId, u32, "An egglog-style rule");
define_id!(pub FunctionId, u32, "An id representing an egglog function");
define_id!(pub(crate) Timestamp, u32, "An abstract timestamp used to track execution of egglog rules");
impl Timestamp {
    /// Convert this timestamp into a raw `Value` suitable for storing in a
    /// table column.
    fn to_value(self) -> Value {
        Value::new(self.rep())
    }
}
65
/// The state associated with an egglog program.
#[derive(Clone)]
pub struct EGraph {
    /// The underlying `core-relations` database holding all tables.
    db: Database,
    /// The id of the `$uf` union-find table (a `DisplacedTable`, or a
    /// `DisplacedTableWithProvenance` when tracing is enabled).
    uf_table: TableId,
    /// Counter used to mint fresh e-class / term ids.
    id_counter: CounterId,
    /// Counter used to mint ids for proof "reasons".
    reason_counter: CounterId,
    /// Counter backing the logical timestamps used for seminaive evaluation.
    timestamp_counter: CounterId,
    /// Metadata for every registered rule.
    rules: DenseIdMapWithReuse<RuleId, RuleInfo>,
    /// Metadata for every registered function (table, schema, rebuild rules, ...).
    funcs: DenseIdMap<FunctionId, FunctionInfo>,
    /// Side channel used to surface a panic message raised during rule
    /// execution; checked after running a rule set.
    panic_message: SideChannel<String>,
    /// This is a cache of all the different panic messages that we may use while executing rules
    /// against the EGraph. Oftentimes, these messages are generated dynamically: keeping this map
    /// around allows us to cache external function ids with repeat panic messages and they can
    /// also serve as a debugging tool in the case that the number of panic messages grows without
    /// bound.
    panic_funcs: HashMap<String, ExternalFunctionId>,
    /// The specs of all registered proof reasons, indexed by id.
    proof_specs: IdVec<ReasonSpecId, Arc<ProofReason>>,
    /// The reason-spec id for congruence (`ProofReason::CongRow`), registered
    /// eagerly at construction time.
    cong_spec: ReasonSpecId,
    /// Side tables used to store proof information. We initialize these lazily
    /// as a proof object with a given number of parameters is added.
    reason_tables: IndexMap<usize /* arity */, TableId>,
    term_tables: IndexMap<usize /* arity */, TableId>,
    /// Union-find table that records the stable term id for each provisional term id.
    ///
    /// This table is only used when proofs are enabled. `core-relations` has a concept of a
    /// "predicted value" to handle lookups for rows on tables that haven't been created yet. For
    /// example, in the associativity rewrite:
    ///
    /// > (rewrite (add (add x y) z) (add x (add y z)))
    ///
    /// The RHS of the rule will check if `(add y z)` is present in the database, and if it isn't
    /// it will create a new id and insert `(add y z) => id`. This `id` is cached so subsequent
    /// uses of `(add y z)` will still see `id` within the same rule.
    ///
    /// This cache is _local_: so different threads can potentially see different values for `id`.
    /// In other words, each thread will have its unique id for the new row being added. This is
    /// fine as all provisional values for `id` will get unioned together and the database will be
    /// consistent after congruence closure runs. This same logic, however, does not hold true for
    /// proofs.
    ///
    /// When proofs mint a new `id` for a row it is used both as the canonical e-class id in the main
    /// e-graph and also the *term* id for that specific row. Term ids are used to reconstruct
    /// proofs and never change. e-class ids can change depending on the result of a UF `union`
    /// operation. For example, consider these three terms:
    ///
    /// > x => term: id0, canon: id0
    /// > (add x 0) => term: id1, canon: id1
    /// > (add 0 x) => term: id2, canon: id2
    ///
    /// If we have rules like `(add x 0) => x` and `(add x y) => (add y x)` then our e-graph
    /// may look like this:
    ///
    /// > x => term: id0, canon: id0
    /// > (add x 0) => term: id1, canon: id0
    /// > (add 0 x) => term: id2, canon: id0
    ///
    /// In other words, while canonical ids change over time and need not be unique to a term, term
    /// ids are unique to the specific shape of the (head of the) term when it was first
    /// instantiated. This is a problem for local caching because it means that we can have
    /// multiple term values for the same row. How do we pick the real one?
    ///
    /// We have a separate union-find. Duplicate insertions into the term table cause `union`s on
    /// this union-find, and future lookups of this table canonicalize ids with respect to this
    /// union-find. In this way, it's a subset of the full union-find in the `uf_table` row, only
    /// used to resolve temporary inconsistencies in cached term values.
    term_consistency_table: TableId,
    /// Whether tracing (proofs) is enabled for this EGraph.
    tracing: bool,
    /// How much detail to gather when reporting on rule execution.
    report_level: ReportLevel,
}
136
/// The result type used throughout this crate; errors are `anyhow::Error`.
pub type Result<T> = std::result::Result<T, anyhow::Error>;
138
139impl Default for EGraph {
140    fn default() -> Self {
141        let mut db = Database::new();
142        let uf_table = db.add_table_named(
143            DisplacedTable::default(),
144            "$uf".into(),
145            iter::empty(),
146            iter::empty(),
147        );
148        EGraph::create_internal(db, uf_table, false)
149    }
150}
151
/// Properties of a function added to an [`EGraph`].
pub struct FunctionConfig {
    /// The function's schema. The last column in the schema is the return type.
    pub schema: Vec<ColumnTy>,
    /// The behavior of the function when lookups are made on keys not currently present.
    pub default: DefaultVal,
    /// How to resolve FD conflicts for the function.
    pub merge: MergeFn,
    /// The function's name.
    pub name: String,
    /// Whether or not subsumption is enabled for this function.
    pub can_subsume: bool,
}
165
166impl EGraph {
167    /// Create a new EGraph with tracing (aka 'proofs') enabled.
168    ///
169    /// Execution of queries against a tracing-enabled EGraph will be slower,
170    /// but will annotate the egraph with annotations that can explain how rows
171    /// came to appear.
172    pub fn with_tracing() -> EGraph {
173        let mut db = Database::new();
174        let uf_table = db.add_table_named(
175            DisplacedTableWithProvenance::default(),
176            "$uf".into(),
177            iter::empty(),
178            iter::empty(),
179        );
180        EGraph::create_internal(db, uf_table, true)
181    }
182
    /// Shared constructor used by [`EGraph::default`] and
    /// [`EGraph::with_tracing`]: wires up counters, proof bookkeeping, and the
    /// term-consistency table around an already-created database and
    /// union-find table.
    fn create_internal(mut db: Database, uf_table: TableId, tracing: bool) -> EGraph {
        let id_counter = db.add_counter();
        let trace_counter = db.add_counter();
        let ts_counter = db.add_counter();
        // Start the timestamp counter at 1.
        // NOTE(review): presumably so timestamp 0 strictly precedes any write;
        // confirm against how timestamps are compared during seminaive
        // evaluation.
        db.inc_counter(ts_counter);
        let mut proof_specs = IdVec::default();
        // Register the congruence reason eagerly so its spec id is always
        // available.
        let cong_spec = proof_specs.push(Arc::new(ProofReason::CongRow));
        let term_consistency_table =
            db.add_table(DisplacedTable::default(), iter::empty(), iter::empty());

        Self {
            db,
            uf_table,
            id_counter,
            reason_counter: trace_counter,
            timestamp_counter: ts_counter,
            rules: Default::default(),
            funcs: Default::default(),
            panic_message: Default::default(),
            panic_funcs: Default::default(),
            proof_specs,
            cong_spec,
            reason_tables: Default::default(),
            term_tables: Default::default(),
            term_consistency_table,
            report_level: Default::default(),
            tracing,
        }
    }
213
    /// The timestamp that will be used for the next round of writes: the
    /// current value of the timestamp counter, read without incrementing.
    fn next_ts(&self) -> Timestamp {
        Timestamp::from_usize(self.db.read_counter(self.timestamp_counter))
    }
217
    /// Advance the logical timestamp by one.
    fn inc_ts(&mut self) {
        self.db.inc_counter(self.timestamp_counter);
    }
221
    /// Get a mutable reference to the underlying table of base values for this
    /// `EGraph`.
    pub fn base_values_mut(&mut self) -> &mut BaseValues {
        self.db.base_values_mut()
    }
227
    /// Get a mutable reference to the underlying table of containers for this
    /// `EGraph`.
    pub fn container_values_mut(&mut self) -> &mut ContainerValues {
        self.db.container_values_mut()
    }
233
    /// Get a reference to the underlying table of containers for this `EGraph`.
    pub fn container_values(&self) -> &ContainerValues {
        self.db.container_values()
    }
238
    /// Intern the given container value into the EGraph, returning its id.
    pub fn get_container_value<C: ContainerValue>(&mut self, val: C) -> Value {
        // Ensure `C` is registered before interning a value of that type.
        self.register_container_ty::<C>();
        // NOTE(review): `state.clone()` appears to sidestep borrowing `state`
        // both for `container_values()` and as the `register_val` argument —
        // confirm that cloning `ExecutionState` here is intended and cheap.
        self.db
            .with_execution_state(|state| state.clone().container_values().register_val(val, state))
    }
245
    /// Register the given [`ContainerValue`] type with this EGraph.
    ///
    /// The given container will use the EGraph's union-find to manage rebuilding and the merging
    /// of containers with a common id.
    pub fn register_container_ty<C: ContainerValue>(&mut self) {
        // Copy these out of `self` so the merge closure can be `move` without
        // capturing `self`.
        let uf_table = self.uf_table;
        let ts_counter = self.timestamp_counter;
        self.db.container_values_mut().register_type::<C>(
            self.id_counter,
            // Merge callback: when two distinct ids collide, stage a union of
            // the pair (tagged with the current timestamp) into the union-find
            // table and keep the smaller id as the winner.
            move |state, old, new| {
                if old != new {
                    let next_ts = Value::from_usize(state.read_counter(ts_counter));
                    state.stage_insert(uf_table, &[old, new, next_ts]);
                    std::cmp::min(old, new)
                } else {
                    // Identical ids: nothing to union.
                    old
                }
            },
        );
    }
266
    /// Get a reference to the underlying table of base values for this `EGraph`.
    pub fn base_values(&self) -> &BaseValues {
        self.db.base_values()
    }
271
272    /// Create a [`QueryEntry`] for a base value.
273    pub fn base_value_constant<T>(&self, x: T) -> QueryEntry
274    where
275        T: BaseValue,
276    {
277        QueryEntry::Const {
278            val: self.base_values().get(x),
279            ty: ColumnTy::Base(self.base_values().get_ty::<T>()),
280        }
281    }
282
    /// Register an external (native) function with the underlying database,
    /// returning the id used to reference it from rules.
    pub fn register_external_func(
        &mut self,
        func: Box<dyn ExternalFunction + 'static>,
    ) -> ExternalFunctionId {
        self.db.add_external_function(func)
    }
289
    /// Release a previously registered external function.
    pub fn free_external_func(&mut self, func: ExternalFunctionId) {
        self.db.free_external_function(func)
    }
293
    /// Generate a fresh id by incrementing the id counter.
    pub fn fresh_id(&mut self) -> Value {
        Value::from_usize(self.db.inc_counter(self.id_counter))
    }
298
299    /// Look up the canonical value for `val` in the union-find.
300    ///
301    /// If the value has never been inserted into the union-find, `val` is returned.
302    fn get_canon_in_uf(&self, val: Value) -> Value {
303        let table = self.db.get_table(self.uf_table);
304        let row = table.get_row(&[val]);
305        row.map(|row| row.vals[1]).unwrap_or(val)
306    }
307
308    /// Get the canonical representation for `val` based on type.
309    ///
310    /// For [`ColumnTy::Id`], it looks up the union find; otherwise,
311    /// it returns the value itself.
312    pub fn get_canon_repr(&self, val: Value, ty: ColumnTy) -> Value {
313        match ty {
314            ColumnTy::Id => self.get_canon_in_uf(val),
315            ColumnTy::Base(_) => val,
316        }
317    }
318
319    fn record_term_consistency(
320        state: &mut ExecutionState,
321        table: TableId,
322        ts_counter: CounterId,
323        from: Value,
324        to: Value,
325    ) {
326        if from == to {
327            return;
328        }
329        let ts = Value::from_usize(state.read_counter(ts_counter));
330        state.stage_insert(table, &[from, to, ts]);
331    }
332
333    fn canonicalize_term_id(&mut self, term_id: Value) -> Value {
334        let table = self.db.get_table(self.term_consistency_table);
335        table
336            .get_row(&[term_id])
337            .map(|row| row.vals[1])
338            .unwrap_or(term_id)
339    }
340
    /// Get (creating lazily, keyed by key arity) the term table matching the
    /// schema of `table`.
    ///
    /// Term-table rows are `[function id, keys..., term id, reason]`; the
    /// merge function below keeps the minimum term id and records the loser in
    /// the term-consistency union-find.
    fn term_table(&mut self, table: TableId) -> TableId {
        let info = self.db.get_table_info(table);
        let spec = info.spec();
        // One term table is shared by all functions with the same key arity.
        match self.term_tables.entry(spec.n_keys) {
            Entry::Occupied(o) => *o.get(),
            Entry::Vacant(v) => {
                // Column index of the term id: after the function-id column
                // and the `n_keys` key columns.
                let term_index = spec.n_keys + 1;
                let term_consistency_table = self.term_consistency_table;
                let ts_counter = self.timestamp_counter;
                let table = SortedWritesTable::new(
                    spec.n_keys + 1,     // added entry for the tableid
                    spec.n_keys + 1 + 2, // one value for the term id, one for the reason,
                    None,
                    vec![], // no rebuilding needed for term table
                    Box::new(move |state, old, new, out| {
                        // We want to pick the minimum term value.
                        let l_term_id = old[term_index];
                        let r_term_id = new[term_index];
                        // NB: we should only need this merge function when we are executing
                        // rules in parallel. We could consider a simpler merge function if
                        // parallelism is disabled.
                        if r_term_id < l_term_id {
                            // The new row wins; remember that the old term id
                            // is now an alias of the new one.
                            EGraph::record_term_consistency(
                                state,
                                term_consistency_table,
                                ts_counter,
                                l_term_id,
                                r_term_id,
                            );
                            out.extend(new);
                            true
                        } else {
                            // Keep the existing (smaller or equal) term id.
                            false
                        }
                    }),
                );
                // The term table writes to the consistency table from its
                // merge function, hence the write dependency.
                let table_id =
                    self.db
                        .add_table(table, iter::empty(), iter::once(term_consistency_table));
                *v.insert(table_id)
            }
        }
    }
384
    /// Get (creating lazily, keyed by arity) the reason table able to hold
    /// reasons of `spec`'s arity.
    fn reason_table(&mut self, spec: &ProofReason) -> TableId {
        let arity = spec.arity();
        match self.reason_tables.entry(arity) {
            Entry::Occupied(o) => *o.get(),
            Entry::Vacant(v) => {
                let table = SortedWritesTable::new(
                    arity,
                    arity + 1, // one value for the reason id
                    None,
                    vec![], // no rebuilding needed for reason tables
                    // Merge callback always declines the update, so the first
                    // reason recorded for a key is the one that sticks.
                    Box::new(|_, _, _, _| false),
                );
                let table_id = self.db.add_table(table, iter::empty(), iter::empty());
                *v.insert(table_id)
            }
        }
    }
402
    /// Load the given values into the database.
    ///
    /// Equivalent to [`EGraph::add_values_with_desc`] with an empty
    /// description.
    ///
    /// # Panics
    /// This method panics if the values do not match the arity of the function.
    ///
    /// NB: this is not an efficient interface for bulk loading. We should add
    /// one that allows us to pass through a series of RowBuffers before
    /// incrementing the timestamp.
    pub fn add_values(&mut self, values: impl IntoIterator<Item = (FunctionId, Vec<Value>)>) {
        self.add_values_with_desc("", values)
    }
414
    /// A term-oriented means of adding data to the database: hand back a "term
    /// id" for the given function and keys for the function. Proofs for this
    /// term will include `desc`.
    ///
    /// Returns the canonical id for the term after updates have been flushed.
    ///
    /// # Panics
    /// This method panics if the values do not match the arity of the function.
    pub fn add_term(&mut self, func: FunctionId, inputs: &[Value], desc: &str) -> Value {
        let info = &self.funcs[func];
        // Layout helper for the function's physical table (keys, return value,
        // timestamp, optional proof/subsume columns).
        let schema_math = SchemaMath {
            tracing: self.tracing,
            subsume: info.can_subsume,
            func_cols: info.schema.len(),
        };
        let mut extended_row = Vec::new();
        extended_row.extend_from_slice(inputs);
        // With tracing on, the row's id is its term id (minted with a "fiat"
        // reason carrying `desc`); otherwise we just mint a fresh id.
        let term = self.tracing.then(|| {
            let reason = self.get_fiat_reason(desc);
            self.get_term(func, inputs, reason)
        });
        let res = term.unwrap_or_else(|| self.fresh_id());
        schema_math.write_table_row(
            &mut extended_row,
            RowVals {
                timestamp: self.next_ts().to_value(),
                ret_val: Some(res),
                proof: term,
                subsume: schema_math.subsume.then_some(NOT_SUBSUMED),
            },
        );
        extended_row[schema_math.ret_val_col()] = res;
        let table_id = self.funcs[func].table;
        self.db.new_buffer(table_id).stage_insert(&extended_row);
        self.flush_updates();
        // The flush may have unioned `res` with an existing id; return the
        // canonical one.
        self.get_canon_in_uf(res)
    }
450
    /// Get an id corresponding to the given term, inserting the value into the
    /// corresponding terms table if it isn't there.
    ///
    /// This method is really only relevant when tracing is enabled.
    fn get_term(&mut self, func: FunctionId, key: &[Value], reason: Value) -> Value {
        let table_id = self.funcs[func].table;
        let term_table_id = self.term_table(table_id);
        let table = self.db.get_table(term_table_id);
        // Term-table keys are `[function id, key...]`.
        let mut term_key = Vec::with_capacity(key.len() + 1);
        term_key.push(Value::new(func.rep()));
        term_key.extend(key);
        if let Some(row) = table.get_row(&term_key) {
            // Row layout ends with `[..., term id, reason]`; the term id is
            // second from the end.
            row.vals[row.vals.len() - 2]
        } else {
            // Mint a new term id and insert `[func, key..., term id, reason]`,
            // merging immediately so later lookups in this call chain see it.
            let result = Value::from_usize(self.db.inc_counter(self.id_counter));
            term_key.push(result);
            term_key.push(reason);
            self.db.new_buffer(term_table_id).stage_insert(&term_key);
            self.db.merge_table(term_table_id);
            result
        }
    }
473
474    /// Lookup the id associated with a function `func` and the given arguments
475    /// (`key`).
476    pub fn lookup_id(&self, func: FunctionId, key: &[Value]) -> Option<Value> {
477        let info = &self.funcs[func];
478        let schema_math = SchemaMath {
479            tracing: self.tracing,
480            subsume: info.can_subsume,
481            func_cols: info.schema.len(),
482        };
483        let table_id = info.table;
484        let table = self.db.get_table(table_id);
485        let row = table.get_row(key)?;
486        Some(row.vals[schema_math.ret_val_col()])
487    }
488
    /// Mint a new "fiat" proof reason carrying `desc`, registering its spec
    /// and recording it in the appropriate reason table. Returns the fresh
    /// reason id.
    fn get_fiat_reason(&mut self, desc: &str) -> Value {
        let reason = Arc::new(ProofReason::Fiat { desc: desc.into() });
        let reason_table = self.reason_table(&reason);
        let reason_spec_id = self.proof_specs.push(reason);
        let reason_id = Value::from_usize(self.db.inc_counter(self.reason_counter));
        self.db
            .new_buffer(reason_table)
            .stage_insert(&[Value::new(reason_spec_id.rep()), reason_id]);
        // Merge immediately so the reason is visible to subsequent lookups.
        self.db.merge_table(reason_table);
        reason_id
    }
500
    /// Load the given values into the database. If tracing is enabled, the
    /// proof rows will be tagged with "desc" as their proof.
    ///
    /// # Panics
    /// This method panics if the values do not match the arity of the function.
    ///
    /// NB: this is not an efficient interface for bulk loading. We should add
    /// one that allows us to pass through a series of RowBuffers before
    /// incrementing the timestamp.
    pub fn add_values_with_desc(
        &mut self,
        desc: &str,
        values: impl IntoIterator<Item = (FunctionId, Vec<Value>)>,
    ) {
        // Scratch row, reused across iterations to avoid reallocating.
        let mut extended_row = Vec::<Value>::new();
        // One shared fiat reason for the whole batch (tracing only).
        let reason_id = self.tracing.then(|| self.get_fiat_reason(desc));
        // Write buffers, one per destination table, created on demand.
        let mut bufs = DenseIdMap::default();
        for (func, row) in values.into_iter() {
            let table_info = &self.funcs[func];
            let schema_math = SchemaMath {
                tracing: self.tracing,
                subsume: table_info.can_subsume,
                func_cols: table_info.schema.len(),
            };
            let table_id = table_info.table;
            let term_id = reason_id.map(|reason| {
                // Get the term id itself
                let term_id = self.get_term(func, &row[0..schema_math.num_keys()], reason);
                let buf = bufs.get_or_insert(self.uf_table, || self.db.new_buffer(self.uf_table));
                // Then union it with the value being set for this term.
                buf.stage_insert(&[
                    *row.last().unwrap(),
                    term_id,
                    self.next_ts().to_value(),
                    reason,
                ]);
                term_id
            });
            extended_row.extend_from_slice(&row);
            // Append the bookkeeping columns (timestamp, proof, subsume).
            schema_math.write_table_row(
                &mut extended_row,
                RowVals {
                    timestamp: self.next_ts().to_value(),
                    proof: term_id,
                    subsume: schema_math.subsume.then_some(NOT_SUBSUMED),
                    ret_val: None, // already filled in.
                },
            );
            let buf = bufs.get_or_insert(table_id, || self.db.new_buffer(table_id));
            buf.stage_insert(&extended_row);
            extended_row.clear();
        }
        // Flush the buffers.
        mem::drop(bufs);
        self.flush_updates();
    }
557
    /// An estimate of the number of rows in `table`, as reported by the
    /// database's size estimator.
    pub fn approx_table_size(&self, table: FunctionId) -> usize {
        self.db.estimate_size(self.funcs[table].table, None)
    }
561
    /// The exact number of rows currently in `table`.
    pub fn table_size(&self, table: FunctionId) -> usize {
        self.db.get_table(self.funcs[table].table).len()
    }
565
    /// Generate a proof explaining why a given term is in the database.
    ///
    /// # Errors
    /// This method will return an error if tracing is not enabled, or if the row is not in the database.
    ///
    /// # Panics
    /// This method may panic if `key` does not match the arity of the function,
    /// or is otherwise malformed.
    pub fn explain_term(&mut self, id: Value, store: &mut ProofStore) -> Result<TermProofId> {
        // Proof reconstruction only works with tracing-time annotations.
        if !self.tracing {
            return Err(ProofReconstructionError::TracingNotEnabled.into());
        }
        let mut state = ProofReconstructionState::new(store);
        Ok(self.explain_term_inner(id, &mut state))
    }
581
    /// Generate a proof explaining why the term corresponding to `id1`
    /// is equal to that corresponding to `id2`.
    ///
    /// # Errors
    /// This method will return an error if tracing is not enabled, if the row
    /// is not in the database, or if the terms themselves are not equal.
    pub fn explain_terms_equal(
        &mut self,
        id1: Value,
        id2: Value,
        store: &mut ProofStore,
    ) -> Result<EqProofId> {
        if !self.tracing {
            return Err(ProofReconstructionError::TracingNotEnabled.into());
        }
        let mut state = ProofReconstructionState::new(store);
        // Two terms are equal iff they share a canonical id in the union-find.
        if self.get_canon_in_uf(id1) != self.get_canon_in_uf(id2) {
            // These terms aren't equal. Reconstruct the relevant terms so as to
            // get a nicer error message on the way out.
            let mut buf = Vec::<u8>::new();
            let term_id_1 = self.reconstruct_term(id1, ColumnTy::Id, &mut state);
            let term_id_2 = self.reconstruct_term(id2, ColumnTy::Id, &mut state);
            store.termdag.print_term(term_id_1, &mut buf).unwrap();
            let term1 = String::from_utf8(buf).unwrap();
            let mut buf = Vec::<u8>::new();
            store.termdag.print_term(term_id_2, &mut buf).unwrap();
            let term2 = String::from_utf8(buf).unwrap();
            return Err(
                ProofReconstructionError::EqualityExplanationOfUnequalTerms { term1, term2 }.into(),
            );
        }
        Ok(self.explain_terms_equal_inner(id1, id2, &mut state))
    }
615
    /// Read the contents of the given function.
    ///
    /// The callback `f` is called with each row and its subsumption status.
    /// This is the non-early-exit variant of [`EGraph::for_each_while`].
    pub fn for_each(&self, table: FunctionId, mut f: impl FnMut(FunctionRow<'_>)) {
        self.for_each_while(table, |row| {
            f(row);
            true // never stop early
        });
    }
625
    /// Iterate over the rows of a function table, calling `f` on each row. If `f` returns `false`
    /// the function returns early and stops reading rows from the table.
    pub fn for_each_while(&self, table: FunctionId, mut f: impl FnMut(FunctionRow<'_>) -> bool) {
        let info = &self.funcs[table];
        let table = self.funcs[table].table;
        // Layout helper used to slice off bookkeeping columns and find the
        // subsumption column.
        let schema_math = SchemaMath {
            tracing: self.tracing,
            subsume: info.can_subsume,
            func_cols: info.schema.len(),
        };
        let imp = self.db.get_table(table);
        let all = imp.all();
        let mut cur = Offset::new(0);
        let mut buf = TaggedRowBuffer::new(imp.spec().arity());
        // This somewhat awkward iteration strategy is forced on us by the `scan_bounded` API. We
        // should look into ways to avoid this cludge where the loop body effectively must be
        // repeated at the end. The obvious and idiomatic ways to do this all require
        // `dyn`-compatibility on `Table` or dynamic dispatch per row.
        macro_rules! drain_buf {
            ($buf:expr) => {
                for (_, row) in $buf.non_stale() {
                    // A row is subsumed when the table supports subsumption
                    // and its subsume column holds the SUBSUMED marker.
                    let subsumed =
                        schema_math.subsume && row[schema_math.subsume_col()] == SUBSUMED;
                    if !f(FunctionRow {
                        // Expose only the logical function columns, not the
                        // bookkeeping suffix.
                        vals: &row[0..schema_math.func_cols],
                        subsumed,
                    }) {
                        return;
                    }
                }
                $buf.clear();
            };
        }
        // Scan in batches of 32 rows; `scan_bounded` returns `None` once the
        // final batch has been written into `buf`.
        while let Some(next) = imp.scan_bounded(all.as_ref(), cur, 32, &mut buf) {
            drain_buf!(buf);
            cur = next;
        }
        // Drain whatever the final `scan_bounded` call left in the buffer.
        drain_buf!(buf);
    }
665
    /// A basic method for dumping the state of the database to `log::info!`.
    ///
    /// For large tables, this is unlikely to give particularly useful output.
    pub fn dump_debug_info(&self) {
        // Function (view) tables: one line per row.
        info!("=== View Tables ===");
        for (id, info) in self.funcs.iter() {
            let table = self.db.get_table(info.table);
            self.scan_table(table, |row| {
                info!(
                    "View Table {name} / {id:?} / {table:?}: {row:?}",
                    name = info.name,
                    table = info.table
                )
            });
        }

        // Term tables: the first column is the function id, which we resolve
        // back to the function's name for readability.
        info!("=== Term Tables ===");
        for (_, table_id) in &self.term_tables {
            let table = self.db.get_table(*table_id);
            self.scan_table(table, |row| {
                let name = &self.funcs[FunctionId::new(row[0].rep())].name;
                let row = &row[1..];
                info!("Term Table {table_id:?}: {name}, {row:?}")
            });
        }

        // Reason tables: the first column is the reason-spec id, resolved to
        // the registered `ProofReason`.
        info!("=== Reason Tables ===");
        for (_, table_id) in &self.reason_tables {
            let table = self.db.get_table(*table_id);
            self.scan_table(table, |row| {
                let spec = self.proof_specs[ReasonSpecId::new(row[0].rep())].as_ref();
                let row = &row[1..];
                info!("Reason Table {table_id:?}: {spec:?}, {row:?}")
            });
        }
    }
702
    /// A helper for scanning the entries in a table, invoking `f` on each
    /// non-stale row.
    fn scan_table(&self, table: &WrappedTable, mut f: impl FnMut(&[Value])) {
        const BATCH_SIZE: usize = 128;
        let all = table.all();
        let mut cur = Offset::new(0);
        let mut out = TaggedRowBuffer::new(table.spec().arity());
        while let Some(next) = table.scan_bounded(all.as_ref(), cur, BATCH_SIZE, &mut out) {
            out.non_stale().for_each(|(_, row)| f(row));
            out.clear();
            cur = next;
        }
        // `scan_bounded` returns `None` after filling `out` with the final
        // batch, so drain it once more here.
        out.non_stale().for_each(|(_, row)| f(row));
    }
716
    /// Register a function in this EGraph, returning its id.
    ///
    /// Builds the backing `SortedWritesTable` (sized per the schema plus
    /// bookkeeping columns), then installs the function's rebuild rules.
    pub fn add_table(&mut self, config: FunctionConfig) -> FunctionId {
        let FunctionConfig {
            schema,
            default,
            merge,
            name,
            can_subsume,
        } = config;
        assert!(
            !schema.is_empty(),
            "must have at least one column in schema"
        );
        // Only id-typed columns participate in rebuilding.
        let to_rebuild: Vec<ColumnId> = schema
            .iter()
            .enumerate()
            .filter(|(_, ty)| matches!(ty, ColumnTy::Id))
            .map(|(i, _)| ColumnId::from_usize(i))
            .collect();
        let schema_math = SchemaMath {
            tracing: self.tracing,
            subsume: can_subsume,
            func_cols: schema.len(),
        };
        let n_args = schema_math.num_keys();
        let n_cols = schema_math.table_columns();
        let next_func_id = self.funcs.next_id();
        // The merge function may read/write other tables; collect those
        // dependencies before building the table.
        let mut read_deps = IndexSet::<TableId>::new();
        let mut write_deps = IndexSet::<TableId>::new();
        merge.fill_deps(self, &mut read_deps, &mut write_deps);
        let merge_fn = merge.to_callback(schema_math, &name, self);
        let table = SortedWritesTable::new(
            n_args,
            n_cols,
            Some(ColumnId::from_usize(schema.len())),
            to_rebuild,
            merge_fn,
        );
        let name: Arc<str> = name.into();
        let table_id = self.db.add_table_named(
            table,
            name.clone(),
            read_deps.iter().copied(),
            write_deps.iter().copied(),
        );

        // Register the function first with placeholder rebuild rules: the
        // rebuild-rule constructors below need the function to already exist.
        let res = self.funcs.push(FunctionInfo {
            table: table_id,
            schema: schema.clone(),
            incremental_rebuild_rules: Default::default(),
            nonincremental_rebuild_rule: RuleId::new(!0), // placeholder, patched below
            default_val: default,
            can_subsume,
            name,
        });
        debug_assert_eq!(res, next_func_id);
        let incremental_rebuild_rules = self.incremental_rebuild_rules(res, &schema);
        let nonincremental_rebuild_rule = self.nonincremental_rebuild(res, &schema);
        let info = &mut self.funcs[res];
        info.incremental_rebuild_rules = incremental_rebuild_rules;
        info.nonincremental_rebuild_rule = nonincremental_rebuild_rule;
        res
    }
780
781    /// Run the given rules, returning whether the database changed.
782    ///
783    /// If the given rules are malformed, this method can return an error.
784    pub fn run_rules(&mut self, rules: &[RuleId]) -> Result<IterationReport> {
785        let ts = self.next_ts();
786
787        let rule_set_report =
788            run_rules_impl(&mut self.db, &mut self.rules, rules, ts, self.report_level)?;
789        if let Some(message) = self.panic_message.lock().unwrap().take() {
790            return Err(PanicError(message).into());
791        }
792
793        let mut iteration_report = IterationReport {
794            rule_set_report,
795            rebuild_time: Duration::ZERO,
796        };
797        if !iteration_report.changed() {
798            return Ok(iteration_report);
799        }
800
801        let rebuild_timer = Instant::now();
802        self.rebuild()?;
803        iteration_report.rebuild_time = rebuild_timer.elapsed();
804
805        if let Some(message) = self.panic_message.lock().unwrap().take() {
806            return Err(PanicError(message).into());
807        }
808
809        Ok(iteration_report)
810    }
811
    /// Restore congruence: canonicalize the rows of every function table
    /// against the union-find, iterating until a fixed point is reached.
    fn rebuild(&mut self) -> Result<()> {
        // Whether to use the parallel rebuild strategy. Under `cfg(test)` we
        // flip a coin so both code paths get exercised.
        fn do_parallel() -> bool {
            #[cfg(test)]
            {
                use rand::Rng;
                rand::rng().random_bool(0.5)
            }
            #[cfg(not(test))]
            {
                rayon::current_num_threads() > 1
            }
        }
        if self.db.get_table(self.uf_table).rebuilder(&[]).is_some() {
            // The UF implementation supports "native" rebuilding: hand it the
            // full list of function tables and let it canonicalize them
            // directly.
            let mut tables = Vec::with_capacity(self.funcs.next_id().index());
            for (_, func) in self.funcs.iter() {
                tables.push(func.table);
            }
            loop {
                // Order matters here: we need to rebuild containers first and then rebuild the
                // tables. Why?
                //
                // Say we have a sort that can map to and from a vector containing only itself:
                // (sort X)
                // (function to-vec (X) (Vec X) :no-merge)
                // (constructor from-vec (Vec X) X)
                // (constructor Num (i64) X)
                // (constructor Add (X X) X)
                //
                // Along with rules:
                // (rule ((= x (Num i))) ((set (to-vec x) (vec-of x))))
                // (rule ((= x (Add i j))) ((set (to-vec x) (vec-of x))))
                // (rule ((= x (from-vec v))) ((set (to-vec x) v))
                // (rewrite (Add (Num i) (Num j)) (Num (+ i j)))
                //
                // These rules, while redundant, should be safe. However, if we rebuild tables
                // before containers some schedules can cause us to violate the `:no-merge`
                // directive, which asserts that all values written for a key are equal.
                //
                // Suppose we start off with x1=(Num 1), x2=(Num 3), and x3=(Add (Num 1) (Num 2)) as
                // expressions, with `to-vec` and `from-vec` entries for all three expressions.
                // We'll call (to-vec xi) vi for all i.
                //
                // Now suppose we run the `rewrite` above: now, x3 = x2. But v3 will only equal v2
                // _after_ we rebuild the `Vec` container. That means that if we rebuild `to-vec`
                // we will collapse the rows for x3 and x2, but then fail to merge v3 and v2
                // because they are not (yet) equal.
                //
                // Rebuilding containers first will find that v3 and v2 are equal, and the rest of
                // the rules can proceed.
                let container_rebuild = self.db.rebuild_containers(self.uf_table);
                let table_rebuild =
                    self.db
                        .apply_rebuild(self.uf_table, &tables, self.next_ts().to_value());
                self.inc_ts();
                if !table_rebuild && !container_rebuild {
                    break;
                }
            }
            return Ok(());
        }
        if do_parallel() {
            return self.rebuild_parallel();
        }
        let start = Instant::now();

        // The database changed. Rebuild. New entries should land after the given rules.
        let mut changed = true;
        while changed {
            changed = false;
            // We need to iterate rebuilding to a fixed point. Future scans
            // should look only at the latest updates.
            self.inc_ts();
            let ts = self.next_ts();
            for (_, info) in self.funcs.iter_mut() {
                let last_rebuilt_at = self.rules[info.nonincremental_rebuild_rule].last_run_at;
                // Heuristic inputs: the incremental rules scan only the
                // union-find entries newer than the last rebuild (column 2 of
                // the union-find holds the timestamp), while the
                // nonincremental rule rescans the whole table.
                let table_size = self.db.estimate_size(info.table, None);
                let uf_size = self.db.estimate_size(
                    self.uf_table,
                    Some(Constraint::GeConst {
                        col: ColumnId::new(2),
                        val: last_rebuilt_at.to_value(),
                    }),
                );
                if incremental_rebuild(uf_size, table_size, false) {
                    marker_incremental_rebuild(|| -> Result<()> {
                        // Run each of the incremental rules serially.
                        //
                        // This is to avoid recanonicalizing the same row multiple
                        // times.
                        for rule in &info.incremental_rebuild_rules {
                            changed |= run_rules_impl(
                                &mut self.db,
                                &mut self.rules,
                                &[*rule],
                                ts,
                                ReportLevel::TimeOnly,
                            )?
                            .changed;
                        }
                        // Reset the rule we did not run. These two should be equivalent.
                        self.rules[info.nonincremental_rebuild_rule].last_run_at = ts;
                        Ok(())
                    })?;
                } else {
                    marker_nonincremental_rebuild(|| -> Result<()> {
                        changed |= run_rules_impl(
                            &mut self.db,
                            &mut self.rules,
                            &[info.nonincremental_rebuild_rule],
                            ts,
                            ReportLevel::TimeOnly,
                        )?
                        .changed;
                        // The nonincremental rule covers the same work as the
                        // incremental ones, so mark those as up to date too.
                        for rule in &info.incremental_rebuild_rules {
                            self.rules[*rule].last_run_at = ts;
                        }
                        Ok(())
                    })?;
                }
            }
        }
        log::info!("rebuild took {:?}", start.elapsed());
        Ok(())
    }
937
    /// A variant of `rebuild` that attempts to combine rebuild rules into
    /// larger rulesets to increase parallelism. This kind of preprocessing can
    /// slow processing down in a single-threaded setting, so it is only used
    /// when the number of active threads is greater than 1.
    fn rebuild_parallel(&mut self) -> Result<()> {
        let start = Instant::now();
        // Per-pass buckets: functions to rebuild all at once
        // (nonincrementally) vs. per-column (incrementally), the latter
        // grouped by the index of the incremental rebuild rule so the i-th
        // rules across functions run together as one ruleset.
        #[derive(Default)]
        struct RebuildState {
            nonincremental: Vec<FunctionId>,
            incremental: DenseIdMap<usize, SmallVec<[FunctionId; 2]>>,
        }

        impl RebuildState {
            // Reset the buckets between passes while keeping allocations.
            fn clear(&mut self) {
                self.nonincremental.clear();
                self.incremental.iter_mut().for_each(|(_, v)| v.clear());
            }
        }

        let mut changed = true;
        let mut state = RebuildState::default();
        let mut scratch = Vec::new();
        while changed {
            changed = false;
            state.clear();
            self.inc_ts();
            // First, figure out which functions will be rebuilt nonincrementally,
            // vs. incrementally. Group them together.
            for (func, info) in self.funcs.iter_mut() {
                let last_rebuilt_at = self.rules[info.nonincremental_rebuild_rule].last_run_at;
                let table_size = self.db.estimate_size(info.table, None);
                // Size of the union-find delta since this function was last
                // rebuilt (column 2 of the union-find holds the timestamp).
                let uf_size = self.db.estimate_size(
                    self.uf_table,
                    Some(Constraint::GeConst {
                        col: ColumnId::new(2),
                        val: last_rebuilt_at.to_value(),
                    }),
                );
                if incremental_rebuild(uf_size, table_size, true) {
                    for (i, _) in info.incremental_rebuild_rules.iter().enumerate() {
                        state.incremental.get_or_default(i).push(func);
                    }
                } else {
                    state.nonincremental.push(func);
                }
            }
            let ts = self.next_ts();
            // Run all nonincremental rebuild rules as a single ruleset. The
            // incremental rules they cover are marked up to date.
            for func in state.nonincremental.iter().copied() {
                scratch.push(self.funcs[func].nonincremental_rebuild_rule);
                for rule in &self.funcs[func].incremental_rebuild_rules {
                    self.rules[*rule].last_run_at = ts;
                }
            }
            changed |= run_rules_impl(
                &mut self.db,
                &mut self.rules,
                &scratch,
                ts,
                ReportLevel::TimeOnly,
            )?
            .changed;
            scratch.clear();
            let ts = self.next_ts();
            // Then run each group of incremental rules as its own ruleset.
            for (i, funcs) in state.incremental.iter() {
                for func in funcs.iter().copied() {
                    let info = &mut self.funcs[func];
                    scratch.push(info.incremental_rebuild_rules[i]);
                    self.rules[info.nonincremental_rebuild_rule].last_run_at = ts;
                }
                changed |= run_rules_impl(
                    &mut self.db,
                    &mut self.rules,
                    &scratch,
                    ts,
                    ReportLevel::TimeOnly,
                )?
                .changed;
                scratch.clear();
            }
        }
        log::info!("rebuild took {:?}", start.elapsed());
        Ok(())
    }
1021
1022    fn incremental_rebuild_rules(&mut self, table: FunctionId, schema: &[ColumnTy]) -> Vec<RuleId> {
1023        schema
1024            .iter()
1025            .enumerate()
1026            .filter_map(|(i, ty)| match ty {
1027                ColumnTy::Id => {
1028                    Some(self.incremental_rebuild_rule(table, schema, ColumnId::from_usize(i)))
1029                }
1030                ColumnTy::Base(_) => None,
1031            })
1032            .collect()
1033    }
1034
    /// Build a rule that recanonicalizes rows of `table` whose entry in
    /// column `col` has been displaced in the union-find: the rule joins the
    /// function's table against the union-find, canonicalizes the whole row,
    /// and swaps the stale row for the canonical one.
    fn incremental_rebuild_rule(
        &mut self,
        table: FunctionId,
        schema: &[ColumnTy],
        col: ColumnId,
    ) -> RuleId {
        let subsume = self.funcs[table].can_subsume;
        let table_id = self.funcs[table].table;
        let uf_table = self.uf_table;
        // Two atoms, one binding a whole tuple, one binding a displaced column
        let mut rb = self.new_rule(&format!("incremental rebuild {table:?}, {col:?}"), true);
        rb.set_plan_strategy(PlanStrategy::MinCover);
        let mut vars = Vec::<QueryEntry>::with_capacity(schema.len());
        for ty in schema {
            vars.push(rb.new_var(*ty).into());
        }
        let canon_val: QueryEntry = rb.new_var(ColumnTy::Id).into();
        let subsume_var = subsume.then(|| rb.new_var(ColumnTy::Id));
        rb.add_atom_with_timestamp_and_func(
            table_id,
            Some(table),
            subsume_var.clone().map(QueryEntry::from),
            &vars,
        );
        // Union-find atom: `vars[col]` is the stale id, `canon_val` the value
        // it was displaced to.
        rb.add_atom_with_timestamp_and_func(
            uf_table,
            None,
            None,
            &[vars[col.index()].clone(), canon_val.clone()],
        );
        rb.set_focus(1); // Set the uf atom as the sole focus.

        // Now canonicalize the entire row.
        let mut canon = Vec::<QueryEntry>::with_capacity(schema.len());
        for (i, (var, ty)) in vars.iter().zip(schema.iter()).enumerate() {
            canon.push(if i == col.index() {
                // This column's canonical value came from the union-find atom.
                canon_val.clone()
            } else if let ColumnTy::Id = ty {
                // Other id columns are canonicalized via a union-find lookup.
                rb.lookup_uf(var.clone()).unwrap().into()
            } else {
                // Base values are never displaced.
                var.clone()
            })
        }

        // Remove the old row and insert the new one.
        rb.rebuild_row(table, &vars, &canon, subsume_var);
        rb.build_internal(None)
    }
1083
    /// Build a rule that rescans all of `table`, canonicalizing every row
    /// that mentions a non-canonical id in any column.
    fn nonincremental_rebuild(&mut self, table: FunctionId, schema: &[ColumnTy]) -> RuleId {
        let can_subsume = self.funcs[table].can_subsume;
        let table_id = self.funcs[table].table;
        let mut rb = self.new_rule(&format!("nonincremental rebuild {table:?}"), false);
        rb.set_plan_strategy(PlanStrategy::MinCover);
        let mut vars = Vec::<QueryEntry>::with_capacity(schema.len());
        for ty in schema {
            vars.push(rb.new_var(*ty).into());
        }
        let subsume_var = can_subsume.then(|| rb.new_var(ColumnTy::Id));
        rb.add_atom_with_timestamp_and_func(
            table_id,
            Some(table),
            subsume_var.clone().map(QueryEntry::from),
            &vars,
        );
        // `lhs` collects the original id columns, `rhs` their canonical
        // values; the pair is handed to `check_for_update` below.
        let mut lhs = SmallVec::<[QueryEntry; 4]>::new();
        let mut rhs = SmallVec::<[QueryEntry; 4]>::new();
        let mut canon = Vec::<QueryEntry>::with_capacity(schema.len());
        for (var, ty) in vars.iter().zip(schema.iter()) {
            canon.push(if let ColumnTy::Id = ty {
                lhs.push(var.clone());
                let canon_var = QueryEntry::from(rb.lookup_uf(var.clone()).unwrap());
                rhs.push(canon_var.clone());
                canon_var
            } else {
                // Base values are never displaced.
                var.clone()
            })
        }
        // NOTE(review): presumably this gates the rewrite to rows where some
        // id actually changed — confirm against `check_for_update`'s contract.
        rb.check_for_update(&lhs, &rhs).unwrap();
        rb.rebuild_row(table, &vars, &canon, subsume_var);
        rb.build_internal(None) // skip the syntax check
    }
1117
    /// Gives the user a handle to the underlying ExecutionState. Useful for staging updates
    /// to the database.
    ///
    /// The staged updates are not immediately reflected in the EGraph, so you may want to
    /// manually flush the updates using [`EGraph::flush_updates`].
    ///
    /// Returns whatever the callback `f` returns.
    pub fn with_execution_state<R>(&self, f: impl FnOnce(&mut ExecutionState<'_>) -> R) -> R {
        self.db.with_execution_state(f)
    }
1126
1127    /// Flush the pending update buffers to the EGraph.
1128    /// Returns `true` if the database is updated.
1129    pub fn flush_updates(&mut self) -> bool {
1130        let updated = self.db.merge_all();
1131        self.inc_ts();
1132        self.rebuild().unwrap();
1133        updated
1134    }
1135
    /// Set the level of detail collected in reports for future rule runs
    /// (see [`EGraph::run_rules`]).
    pub fn set_report_level(&mut self, level: ReportLevel) {
        self.report_level = level;
    }
1139}
1140
#[derive(Clone)]
struct RuleInfo {
    /// Timestamp of this rule's most recent execution; compared against
    /// table timestamps so reruns only need to consider newer data.
    last_run_at: Timestamp,
    /// The query this rule executes.
    query: rule::Query,
    /// A cached execution plan for the query, if one has been built.
    cached_plan: Option<CachedPlanInfo>,
    /// Human-readable description of the rule.
    desc: Arc<str>,
}
1148
#[derive(Clone)]
struct CachedPlanInfo {
    /// The cached, shareable execution plan for the rule's query.
    plan: Arc<core_relations::CachedPlan>,
    /// A mapping from index into a [`rule::Query`]'s atoms to the atoms in the underlying cached
    /// plan.
    atom_mapping: Vec<core_relations::AtomId>,
}
1156
#[derive(Clone)]
struct FunctionInfo {
    /// The table in the underlying database that stores this function.
    table: TableId,
    /// Declared column types; the final column is the return value.
    schema: Vec<ColumnTy>,
    /// One rebuild rule per id-typed column, each canonicalizing rows whose
    /// entry in that column was displaced in the union-find.
    incremental_rebuild_rules: Vec<RuleId>,
    /// A rule that rescans the whole table, canonicalizing every row.
    nonincremental_rebuild_rule: RuleId,
    /// How lookups that miss produce a value for this function.
    default_val: DefaultVal,
    /// Whether rows of this function can be marked as subsumed.
    can_subsume: bool,
    /// Human-readable name of the function.
    name: Arc<str>,
}
1167
1168impl FunctionInfo {
1169    fn ret_ty(&self) -> ColumnTy {
1170        self.schema.last().copied().unwrap()
1171    }
1172}
1173
/// How defaults are computed for the given function.
///
/// The default is consulted when a lookup misses (see [`TableAction::lookup`]).
#[derive(Copy, Clone)]
pub enum DefaultVal {
    /// Generate a fresh UF id.
    FreshId,
    /// Cause an egglog-level panic if a lookup fails.
    Fail,
    /// Insert a constant of some kind.
    Const(Value),
}
1184
/// How to resolve FD conflicts for a table.
///
/// Merge functions can nest (via [`MergeFn::Primitive`] and
/// [`MergeFn::Function`]), forming an expression tree evaluated over the old
/// and new values.
pub enum MergeFn {
    /// Panic if the old and new values don't match.
    AssertEq,
    /// Use congruence to resolve FD conflicts.
    UnionId,
    /// The output of a merge is determined by applying the given ExternalFunction to the result
    /// of the argument merge functions.
    Primitive(ExternalFunctionId, Vec<MergeFn>),
    /// The output of a merge is determined by looking up the value for the given function and the
    /// given arguments in the egraph.
    Function(FunctionId, Vec<MergeFn>),
    /// Always return the old value for the given function.
    Old,
    /// Always return the new value for the given function.
    New,
    /// Always overwrite the new value for the given function with a constant. This is more useful
    /// as a "base case" in a more complicated merge function (e.g. one that clamps a value between
    /// 1 and 100) than it is as a standalone merge function.
    Const(Value),
}
1206
impl MergeFn {
    /// Record the tables this merge function reads and writes into
    /// `read_deps`/`write_deps` so the owning table can declare them as
    /// dependencies.
    fn fill_deps(
        &self,
        egraph: &EGraph,
        read_deps: &mut IndexSet<TableId>,
        write_deps: &mut IndexSet<TableId>,
    ) {
        use MergeFn::*;
        match self {
            Primitive(_, args) => {
                args.iter()
                    .for_each(|arg| arg.fill_deps(egraph, read_deps, write_deps));
            }
            Function(func, args) => {
                // A function-based merge both reads the target table and may
                // insert into it (lookups with a default insert on a miss).
                read_deps.insert(egraph.funcs[*func].table);
                write_deps.insert(egraph.funcs[*func].table);
                args.iter()
                    .for_each(|arg| arg.fill_deps(egraph, read_deps, write_deps));
            }
            UnionId if !egraph.tracing => {
                // Unions stage inserts into the union-find table. With
                // tracing enabled the merge never writes (see
                // `ResolvedMergeFn::run`), so no dependency is needed.
                write_deps.insert(egraph.uf_table);
            }
            UnionId | AssertEq | Old | New | Const(..) => {}
        }
    }

    /// Compile this merge function into the callback form used by
    /// `core_relations`: given the current and incoming rows, write the
    /// merged row into `out` and return whether anything changed.
    fn to_callback(
        &self,
        schema_math: SchemaMath,
        function_name: &str,
        egraph: &mut EGraph,
    ) -> Box<core_relations::MergeFn> {
        assert!(
            !egraph.tracing || matches!(self, MergeFn::UnionId),
            "proofs aren't supported for non-union merge functions"
        );

        // Resolve table handles and panic messages up front so the callback
        // can run without access to the EGraph.
        let resolved = self.resolve(function_name, egraph);

        Box::new(move |state, cur, new, out| {
            let timestamp = new[schema_math.ts_col()];

            let mut changed = false;

            // Merge the return-value column.
            let ret_val = {
                let cur = cur[schema_math.ret_val_col()];
                let new = new[schema_math.ret_val_col()];
                let out = resolved.run(state, cur, new, timestamp);
                changed |= cur != out;
                out
            };

            // Merge the subsumption flag, if this table tracks one.
            let subsume = schema_math.subsume.then(|| {
                let cur = cur[schema_math.subsume_col()];
                let new = new[schema_math.subsume_col()];
                let out = combine_subsumed(cur, new);
                changed |= cur != out;
                out
            });
            // With proofs enabled, keep the smaller of the two proof ids.
            let mut proof = None;
            if schema_math.tracing {
                let old_term = cur[schema_math.proof_id_col()];
                let new_term = new[schema_math.proof_id_col()];
                proof = Some(cmp::min(old_term, new_term));
                changed |= new_term < old_term;
            }

            if changed {
                // Start from the incoming row, then overwrite the bookkeeping
                // columns with the merged values.
                out.extend_from_slice(new);
                schema_math.write_table_row(
                    out,
                    RowVals {
                        timestamp,
                        proof,
                        subsume,
                        ret_val: Some(ret_val),
                    },
                );
            }

            changed
        })
    }

    /// Translate this merge-function tree into a [`ResolvedMergeFn`],
    /// interning panic messages and table handles so the result can run
    /// without holding a reference to the `EGraph`.
    fn resolve(&self, function_name: &str, egraph: &mut EGraph) -> ResolvedMergeFn {
        match self {
            MergeFn::Const(v) => ResolvedMergeFn::Const(*v),
            MergeFn::Old => ResolvedMergeFn::Old,
            MergeFn::New => ResolvedMergeFn::New,
            MergeFn::AssertEq => ResolvedMergeFn::AssertEq {
                panic: egraph.new_panic(format!(
                    "Illegal merge attempted for function {function_name}"
                )),
            },
            MergeFn::UnionId => ResolvedMergeFn::UnionId {
                uf_table: egraph.uf_table,
                tracing: egraph.tracing,
            },
            // NB: The primitive and function-based merge functions heap allocate a single callback
            // for each layer of nesting. This introduces a bit of overhead, particularly for cases
            // that look like `(f old new)` or `(f new old)`. We could special-case common cases in
            // this function if that overhead shows up.
            MergeFn::Primitive(prim, args) => ResolvedMergeFn::Primitive {
                prim: *prim,
                args: args
                    .iter()
                    .map(|arg| arg.resolve(function_name, egraph))
                    .collect::<Vec<_>>(),
                panic: egraph.new_panic(format!(
                    "Merge function for {function_name} primitive call failed"
                )),
            },
            MergeFn::Function(func, args) => {
                let func_info = &egraph.funcs[*func];
                // The target function takes `args.len()` keys plus a return
                // value, hence the `+ 1`.
                assert_eq!(
                    func_info.schema.len(),
                    args.len() + 1,
                    "Merge function for {function_name} must match function arity for {}",
                    func_info.name
                );
                ResolvedMergeFn::Function {
                    func: TableAction::new(egraph, *func),
                    panic: egraph.new_panic(format!(
                        "Lookup on {} failed in the merge function for {function_name}",
                        func_info.name
                    )),
                    args: args
                        .iter()
                        .map(|arg| arg.resolve(function_name, egraph))
                        .collect::<Vec<_>>(),
                }
            }
        }
    }
}
1342
/// This enum is taking the place of a
/// `Box<dyn Fn(&mut ExecutionState, Value, Value, Value) -> Value + Send + Sync>`
/// to avoid extra boxes. It stores the data needed to run a `MergeFn` without
/// holding onto any references, so it can be `move`d inside the `core_relations::MergeFn`.
enum ResolvedMergeFn {
    /// Always produce the given constant.
    Const(Value),
    /// Keep the current (old) value.
    Old,
    /// Take the incoming (new) value.
    New,
    /// Raise the interned panic when the old and new values differ.
    AssertEq {
        panic: ExternalFunctionId,
    },
    /// Union old and new in the union-find, producing the smaller id.
    UnionId {
        uf_table: TableId,
        tracing: bool,
    },
    /// Apply an external function to the results of the argument merges.
    Primitive {
        prim: ExternalFunctionId,
        args: Vec<ResolvedMergeFn>,
        panic: ExternalFunctionId,
    },
    /// Look up the results of the argument merges in another function's table.
    Function {
        func: TableAction,
        args: Vec<ResolvedMergeFn>,
        panic: ExternalFunctionId,
    },
}
1369
impl ResolvedMergeFn {
    /// Evaluate this merge function over the current (`cur`) and incoming
    /// (`new`) values, returning the merged value. `ts` is the timestamp
    /// attached to any union staged as a side effect.
    fn run(&self, state: &mut ExecutionState, cur: Value, new: Value, ts: Value) -> Value {
        match self {
            ResolvedMergeFn::Const(v) => *v,
            ResolvedMergeFn::Old => cur,
            ResolvedMergeFn::New => new,
            ResolvedMergeFn::AssertEq { panic } => {
                if cur != new {
                    // The panic external function reports the egglog-level
                    // panic out-of-band and returns no value.
                    let res = state.call_external_func(*panic, &[]);
                    assert_eq!(res, None);
                }
                cur
            }
            ResolvedMergeFn::UnionId { uf_table, tracing } => {
                if cur != new && !tracing {
                    // When proofs are enabled, these are the same term. They are already
                    // equal and we can just do nothing.
                    state.stage_insert(*uf_table, &[cur, new, ts]);
                    // We pick the minimum when unioning. This matches the original egglog
                    // behavior. THIS MUST MATCH THE UNION-FIND IMPLEMENTATION!
                    std::cmp::min(cur, new)
                } else {
                    cur
                }
            }
            // NB: The primitive and function-based merge functions heap allocate a single callback
            // for each layer of nesting. This introduces a bit of overhead, particularly for cases
            // that look like `(f old new)` or `(f new old)`. We could special-case common cases in
            // this function if that overhead shows up.
            ResolvedMergeFn::Primitive { prim, args, panic } => {
                let args = args
                    .iter()
                    .map(|arg| arg.run(state, cur, new, ts))
                    .collect::<Vec<_>>();

                // A `None` result means the primitive failed: raise the
                // interned panic and keep the old value.
                match state.call_external_func(*prim, &args) {
                    Some(result) => result,
                    None => {
                        let res = state.call_external_func(*panic, &[]);
                        assert_eq!(res, None);
                        cur
                    }
                }
            }
            ResolvedMergeFn::Function { func, args, panic } => {
                // see github.com/egraphs-good/egglog/pull/287
                if cur == new {
                    return cur;
                }

                let args = args
                    .iter()
                    .map(|arg| arg.run(state, cur, new, ts))
                    .collect::<Vec<_>>();

                // A failed lookup raises the interned panic and keeps the
                // old value.
                func.lookup(state, &args).unwrap_or_else(|| {
                    let res = state.call_external_func(*panic, &[]);
                    assert_eq!(res, None);
                    cur
                })
            }
        }
    }
}
1434
/// This is an intern-able struct that holds all the data needed
/// to do table operations with an [`ExecutionState`], assuming
/// that the [`FunctionId`] for the table is known ahead of time.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TableAction {
    /// The backing table for the function.
    table: TableId,
    /// Column-layout helper (declared columns plus bookkeeping columns).
    table_math: SchemaMath,
    /// Merge value used on lookup misses; `None` means lookups simply fail.
    default: Option<MergeVal>,
    /// Counter read to timestamp staged writes.
    timestamp: CounterId,
    /// Reusable buffer for assembling rows; cleared before each use and
    /// reset to empty on clone.
    scratch: Vec<Value>,
}
1446
1447impl Clone for TableAction {
1448    fn clone(&self) -> Self {
1449        Self {
1450            table: self.table,
1451            table_math: self.table_math,
1452            default: self.default,
1453            timestamp: self.timestamp,
1454            scratch: Vec::new(),
1455        }
1456    }
1457}
1458
1459impl TableAction {
1460    /// Create a new `TableAction` to be used later.
1461    /// This requires access to the `egglog_bridge::EGraph`.
1462    pub fn new(egraph: &EGraph, func: FunctionId) -> TableAction {
1463        assert!(!egraph.tracing, "proofs not supported yet");
1464
1465        let func_info = &egraph.funcs[func];
1466        TableAction {
1467            table: func_info.table,
1468            table_math: SchemaMath {
1469                func_cols: func_info.schema.len(),
1470                subsume: func_info.can_subsume,
1471                tracing: egraph.tracing,
1472            },
1473            default: match &func_info.default_val {
1474                DefaultVal::FreshId => Some(MergeVal::Counter(egraph.id_counter)),
1475                DefaultVal::Fail => None,
1476                DefaultVal::Const(val) => Some(MergeVal::Constant(*val)),
1477            },
1478            timestamp: egraph.timestamp_counter,
1479            scratch: Vec::new(),
1480        }
1481    }
1482
    /// A "table lookup" is not a read-only operation. It will insert a row when
    /// the [`DefaultVal`] for the table is not [`DefaultVal::Fail`] and
    /// the `key` is not already present in the table.
    ///
    /// Returns the function's value for `key`, or `None` when the table has
    /// no default value and `key` is absent.
    pub fn lookup(&self, state: &mut ExecutionState, key: &[Value]) -> Option<Value> {
        match self.default {
            // A default exists: ask the table for the current value of `key`,
            // supplying the row to insert if the key is missing.
            Some(default) => {
                let timestamp =
                    MergeVal::Constant(Value::from_usize(state.read_counter(self.timestamp)));
                let mut merge_vals = SmallVec::<[MergeVal; 3]>::new();
                // `merge_vals` holds only the non-key columns, so pretend the
                // function has a single column: the schema math then lays the
                // suffix out as `[ret_val, timestamp, proof?, subsume?]`
                // starting at index 0.
                SchemaMath {
                    func_cols: 1,
                    ..self.table_math
                }
                .write_table_row(
                    &mut merge_vals,
                    RowVals {
                        timestamp,
                        proof: None,
                        // Freshly defaulted rows start out not subsumed (when
                        // the table tracks subsumption at all).
                        subsume: self
                            .table_math
                            .subsume
                            .then_some(MergeVal::Constant(NOT_SUBSUMED)),
                        ret_val: Some(default),
                    },
                );
                Some(
                    // NOTE(review): indexing with the full-schema
                    // `ret_val_col()` assumes `predict_val` returns the whole
                    // row (keys included), not just the suffix built above —
                    // confirm against `ExecutionState::predict_val`.
                    state.predict_val(self.table, key, merge_vals.iter().copied())
                        [self.table_math.ret_val_col()],
                )
            }
            // No default (`DefaultVal::Fail`): a plain read-only probe.
            None => state
                .get_table(self.table)
                .get_row(key)
                .map(|row| row.vals[self.table_math.ret_val_col()]),
        }
    }
1519
1520    /// Insert a row into this table.
1521    pub fn insert(&mut self, state: &mut ExecutionState, row: impl Iterator<Item = Value>) {
1522        let ts = Value::from_usize(state.read_counter(self.timestamp));
1523        self.scratch.clear();
1524        self.scratch.extend(row);
1525        self.table_math.write_table_row(
1526            &mut self.scratch,
1527            RowVals {
1528                timestamp: ts,
1529                proof: None,
1530                subsume: self.table_math.subsume.then_some(NOT_SUBSUMED),
1531                ret_val: None,
1532            },
1533        );
1534        state.stage_insert(self.table, &self.scratch);
1535    }
1536
    /// Delete a row from this table.
    ///
    /// Delegates to [`ExecutionState::stage_remove`], so the removal is
    /// staged rather than applied immediately.
    pub fn remove(&self, state: &mut ExecutionState, key: &[Value]) {
        state.stage_remove(self.table, key);
    }
1541
1542    /// Subsume a row in this table.
1543    pub fn subsume(&mut self, state: &mut ExecutionState, key: impl Iterator<Item = Value>) {
1544        let ts = Value::from_usize(state.read_counter(self.timestamp));
1545        self.scratch.clear();
1546        self.scratch.extend(key);
1547
1548        let ret_val = self
1549            .lookup(state, &self.scratch)
1550            .expect("subsume lookup failed");
1551
1552        self.table_math.write_table_row(
1553            &mut self.scratch,
1554            RowVals {
1555                timestamp: ts,
1556                proof: None,
1557                subsume: Some(SUBSUMED),
1558                ret_val: Some(ret_val),
1559            },
1560        );
1561        state.stage_insert(self.table, &self.scratch);
1562    }
1563}
1564
/// A variant of `TableAction` for the union-find.
///
/// Captures just the ids needed to stage unions later, without holding a
/// borrow of the `EGraph` itself.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct UnionAction {
    // The union-find table that staged unions are inserted into.
    table: TableId,
    // Counter read to timestamp each staged union.
    timestamp: CounterId,
}
1571
1572impl UnionAction {
1573    /// Create a new `UnionAction` to be used later.
1574    /// This requires access to the `egglog_bridge::EGraph`.
1575    pub fn new(egraph: &EGraph) -> UnionAction {
1576        assert!(!egraph.tracing, "proofs not supported yet");
1577        UnionAction {
1578            table: egraph.uf_table,
1579            timestamp: egraph.timestamp_counter,
1580        }
1581    }
1582
1583    /// Union two values.
1584    pub fn union(&self, state: &mut ExecutionState, x: Value, y: Value) {
1585        let ts = Value::from_usize(state.read_counter(self.timestamp));
1586        state.stage_insert(self.table, &[x, y, ts]);
1587    }
1588}
1589
1590fn run_rules_impl(
1591    db: &mut Database,
1592    rule_info: &mut DenseIdMapWithReuse<RuleId, RuleInfo>,
1593    rules: &[RuleId],
1594    next_ts: Timestamp,
1595    report_level: ReportLevel,
1596) -> Result<RuleSetReport> {
1597    for rule in rules {
1598        let info = &mut rule_info[*rule];
1599        if info.cached_plan.is_none() {
1600            info.cached_plan = Some(info.query.build_cached_plan(db, &info.desc)?);
1601        }
1602    }
1603    let mut rsb = db.new_rule_set();
1604    for rule in rules {
1605        let info = &mut rule_info[*rule];
1606        let cached_plan = info.cached_plan.as_ref().unwrap();
1607        info.query
1608            .add_rules_from_cached(&mut rsb, info.last_run_at, cached_plan)?;
1609        info.last_run_at = next_ts;
1610    }
1611    let ruleset = rsb.build();
1612    Ok(db.run_rule_set(&ruleset, report_level))
1613}
1614
1615// These markers are just used to make it easy to distinguish time spent in
1616// incremental vs. nonincremental rebuilds in time-based profiles.
1617
/// Never-inlined wrapper so time spent in incremental rebuilds shows up as
/// its own frame in time-based profiles.
#[inline(never)]
fn marker_incremental_rebuild<R>(body: impl FnOnce() -> R) -> R {
    body()
}
1622
/// Never-inlined wrapper so time spent in nonincremental rebuilds shows up
/// as its own frame in time-based profiles.
#[inline(never)]
fn marker_nonincremental_rebuild<R>(body: impl FnOnce() -> R) -> R {
    body()
}
1627
/// A useful type definition for external functions that need to pass data
/// to outside code, such as `Panic`.
///
/// The slot starts out `None`; the external functions in this module only
/// fill it in when it is still empty, so the first writer wins.
pub type SideChannel<T> = Arc<Mutex<Option<T>>>;
1631
/// An external function used to grab a value out of the database matching a
/// particular query.
///
/// Wraps a side channel that receives the matched row's values; only the
/// first match is kept (see the `ExternalFunction` impl).
//
// TODO: once we have parallelism wired in, we'll want to replace this with a
// more efficient solution (e.g. one based on crossbeam or arcswap).
#[derive(Clone)]
struct GetFirstMatch(SideChannel<Vec<Value>>);
1639
1640impl ExternalFunction for GetFirstMatch {
1641    fn invoke(&self, _: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
1642        let mut guard = self.0.lock().unwrap();
1643        if guard.is_some() {
1644            return None;
1645        }
1646        *guard = Some(args.to_vec());
1647        Some(Value::new(0))
1648    }
1649}
1650
/// This is a variant on [`Panic`] that avoids eager construction of the panic message.
///
/// The main thing this is used for is to avoid constructing the panic message ahead of time during
/// a call to [`RuleBuilder::call_external_func`]; these panic messages are often quite rare and
/// may never need to be constructed at all. Furthermore, a closure to produce the panic message in
/// most cases need only close over a few cheap-to-clone values.
///
/// The downside of this, and why we do not use it everywhere, is that there's no natural "key"
/// that we can use to cache duplicate panic messages. We would need a more complex API to support
/// both and fully replace our use of `Panic`.
///
/// Fields: the lazily-rendered message, and the side channel the rendered
/// message is written to when the panic fires.
struct LazyPanic<F>(Arc<Lazy<String, F>>, SideChannel<String>);
1662
1663impl<F: FnOnce() -> String + Send> ExternalFunction for LazyPanic<F> {
1664    fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
1665        assert!(args.is_empty());
1666        state.trigger_early_stop();
1667        let mut guard = self.1.lock().unwrap();
1668        if guard.is_none() {
1669            *guard = Some(Lazy::force(&self.0).clone());
1670        }
1671        None
1672    }
1673}
1674
1675impl<F> Clone for LazyPanic<F> {
1676    fn clone(&self) -> Self {
1677        LazyPanic(self.0.clone(), self.1.clone())
1678    }
1679}
1680
/// An external function used to store a message when a panic occurs.
///
/// Holds the (eagerly constructed) panic message and the side channel the
/// message is reported through when the panic fires.
//
// TODO: once we have parallelism wired in, we'll want to replace this with a
// more efficient solution (e.g. one based on crossbeam or arcswap).
#[derive(Clone)]
struct Panic(String, SideChannel<String>);
1687
1688impl EGraph {
1689    /// Create a new `ExternalFunction` that panics with the given message.
1690    pub fn new_panic(&mut self, message: String) -> ExternalFunctionId {
1691        *self
1692            .panic_funcs
1693            .entry(message.to_string())
1694            .or_insert_with(|| {
1695                let panic = Panic(message, self.panic_message.clone());
1696                self.db.add_external_function(Box::new(panic))
1697            })
1698    }
1699
1700    pub fn new_panic_lazy(
1701        &mut self,
1702        message: impl FnOnce() -> String + Send + 'static,
1703    ) -> ExternalFunctionId {
1704        let lazy = Lazy::new(message);
1705        let panic = LazyPanic(Arc::new(lazy), self.panic_message.clone());
1706        self.db.add_external_function(Box::new(panic))
1707    }
1708}
1709
1710impl ExternalFunction for Panic {
1711    fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
1712        // TODO (egglog feature): change this to support interpolating panic messages
1713        assert!(args.is_empty());
1714
1715        state.trigger_early_stop();
1716        let mut guard = self.1.lock().unwrap();
1717        if guard.is_none() {
1718            *guard = Some(self.0.clone());
1719        }
1720        None
1721    }
1722}
1723
/// Errors that can arise while reconstructing a proof.
#[derive(Error, Debug)]
enum ProofReconstructionError {
    // Proof reconstruction requires the trace data recorded when tracing is
    // enabled at e-graph construction time.
    #[error(
        "attempting to explain a row without tracing enabled. Try constructing with `EGraph::with_tracing`"
    )]
    TracingNotEnabled,
    // Requested an equality proof between two terms that are not equal.
    #[error("attempting to construct a proof that {term1} = {term2}, but they are not equal")]
    EqualityExplanationOfUnequalTerms { term1: String, term2: String },
}
1733
/// Heuristic for deciding whether to do an incremental or nonincremental
/// rebuild for a given table.
///
/// Incremental rebuilds pay off when the union-find is small relative to
/// the table; the parallel path demands a larger ratio before choosing the
/// incremental strategy.
fn incremental_rebuild(uf_size: usize, table_size: usize, parallel: bool) -> bool {
    let shrink_factor = if parallel { 16 } else { 8 };
    uf_size <= table_size / shrink_factor
}
1743
// Sentinel stored in the subsumption column for rows that have been subsumed.
pub(crate) const SUBSUMED: Value = Value::new_const(1);
// Sentinel stored in the subsumption column for live (non-subsumed) rows.
pub(crate) const NOT_SUBSUMED: Value = Value::new_const(0);
1746fn combine_subsumed(v1: Value, v2: Value) -> Value {
1747    std::cmp::max(v1, v2)
1748}
1749
/// A struct helping with some calculations of where some information is stored at the
/// core-relations Table level for a given function.
///
/// Functions can have multiple "output columns" in the underlying core-relations layer depending
/// on whether different features are enabled. Roughly, tables are laid out as:
///
/// > `[key0, ..., keyn, return value, timestamp, proof_id?, subsume?]`
///
/// Where there are `n+1` key columns and columns marked with a question mark are optional,
/// depending on the egraph and table-level configuration.
///
/// All column offsets are derived from the three fields below; see the
/// accessor methods on the `impl` for the exact arithmetic.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct SchemaMath {
    /// Whether or not proofs are enabled.
    tracing: bool,
    /// Whether or not the table is enabled for subsumption.
    subsume: bool,
    /// The number of columns in the function (including the return value).
    func_cols: usize,
}
1769
/// A struct containing possible non-key portions of a table row. To be used with
/// [`SchemaMath::write_table_row`].
///
/// `T` varies by call site: callers in this file instantiate it with both
/// `Value` and `MergeVal`.
///
/// This is not to be confused with [`FunctionRow`], which is higher-level and for public uses.
struct RowVals<T> {
    /// The timestamp for the row.
    timestamp: T,
    /// The proof id (or term id) for the row. Only relevant if tracing is enabled.
    proof: Option<T>,
    /// The subsumption tag for the row. Only relevant if the table has subsumption enabled.
    subsume: Option<T>,
    /// The return value of the row. Return values are mandatory but callers may have already
    /// filled it in.
    ret_val: Option<T>,
}
1785
/// A struct representing the content of a row in a function table
#[derive(Clone, Debug)]
pub struct FunctionRow<'a> {
    /// The values stored in the row.
    pub vals: &'a [Value],
    /// Whether the row has been marked as subsumed.
    pub subsumed: bool,
}
1792
1793impl SchemaMath {
1794    fn write_table_row<T: Clone>(
1795        &self,
1796        row: &mut impl HasResizeWith<T>,
1797        RowVals {
1798            timestamp,
1799            proof,
1800            subsume,
1801            ret_val,
1802        }: RowVals<T>,
1803    ) {
1804        row.resize_with(self.table_columns(), || timestamp.clone());
1805        row[self.ts_col()] = timestamp;
1806        if let Some(ret_val) = ret_val {
1807            row[self.ret_val_col()] = ret_val;
1808        }
1809        if let Some(proof_id) = proof {
1810            row[self.proof_id_col()] = proof_id;
1811        } else {
1812            assert!(
1813                !self.tracing,
1814                "proof_id must be provided if tracing is enabled"
1815            );
1816        }
1817        if let Some(subsume) = subsume {
1818            row[self.subsume_col()] = subsume;
1819        } else {
1820            assert!(
1821                !self.subsume,
1822                "subsume flag must be provided if subsumption is enabled"
1823            );
1824        }
1825    }
1826
1827    fn num_keys(&self) -> usize {
1828        self.func_cols - 1
1829    }
1830
1831    fn table_columns(&self) -> usize {
1832        self.func_cols + 1 /* timestamp */ + if self.tracing { 1 } else { 0 } + if self.subsume { 1 } else { 0 }
1833    }
1834
1835    #[track_caller]
1836    fn proof_id_col(&self) -> usize {
1837        assert!(self.tracing);
1838        self.func_cols + 1
1839    }
1840
1841    fn ret_val_col(&self) -> usize {
1842        self.func_cols - 1
1843    }
1844
1845    fn ts_col(&self) -> usize {
1846        self.func_cols
1847    }
1848
1849    #[track_caller]
1850    fn subsume_col(&self) -> usize {
1851        assert!(self.subsume);
1852        if self.tracing {
1853            self.func_cols + 2
1854        } else {
1855            self.func_cols + 1
1856        }
1857    }
1858}
1859
/// Error wrapping a recorded panic message (see `Panic` / `LazyPanic`,
/// which store the message into a side channel when they fire).
#[derive(Error, Debug)]
#[error("Panic: {0}")]
struct PanicError(String);
1863
/// Basic ad-hoc polymorphism around `resize_with` in order to get [`SchemaMath::write_table_row`]
/// to work with both `Vec` and `SmallVec`.
trait HasResizeWith<T>:
    AsMut<[T]> + AsRef<[T]> + Index<usize, Output = T> + IndexMut<usize, Output = T>
{
    /// Resize the container to `new_size` elements, filling any newly-added
    /// slots with values produced by `f`.
    fn resize_with<F>(&mut self, new_size: usize, f: F)
    where
        F: FnMut() -> T;
}
1873
1874impl<T> HasResizeWith<T> for Vec<T> {
1875    fn resize_with<F>(&mut self, new_size: usize, f: F)
1876    where
1877        F: FnMut() -> T,
1878    {
1879        self.resize_with(new_size, f);
1880    }
1881}
1882
impl<T, A: smallvec::Array<Item = T>> HasResizeWith<T> for SmallVec<A> {
    fn resize_with<F>(&mut self, new_size: usize, f: F)
    where
        F: FnMut() -> T,
    {
        // Inherent methods take priority over trait methods in resolution,
        // so this forwards to SmallVec's own `resize_with` rather than
        // recursing. NOTE(review): this assumes the smallvec version in use
        // provides an inherent `resize_with` — confirm, since without it
        // this call would recurse infinitely.
        self.resize_with(new_size, f);
    }
}