egglog_bridge/
rule.rs

1//! APIs for building egglog rules.
2//!
3//! Egglog rules are ultimately just (sets of) `core-relations` rules
4//! parameterized by a range of timestamps used as constraints during seminaive
5//! evaluation.
6
7use std::{cmp::Ordering, sync::Arc};
8
9use crate::core_relations;
10use crate::core_relations::{
11    ColumnId, Constraint, CounterId, ExternalFunctionId, PlanStrategy, QueryBuilder,
12    RuleBuilder as CoreRuleBuilder, RuleSetBuilder, TableId, Value, WriteVal,
13};
14use crate::numeric_id::{DenseIdMap, NumericId, define_id};
15use anyhow::Context;
16use hashbrown::HashSet;
17use log::debug;
18use smallvec::SmallVec;
19use thiserror::Error;
20
21use crate::syntax::SourceSyntax;
22use crate::{CachedPlanInfo, NOT_SUBSUMED, RowVals, SUBSUMED, SchemaMath};
23use crate::{
24    ColumnTy, DefaultVal, EGraph, FunctionId, Result, RuleId, RuleInfo, Timestamp,
25    proof_spec::{ProofBuilder, RebuildVars},
26};
27
// Dense numeric identifiers for the variables and atoms of an egglog query.
define_id!(pub Variable, u32, "A variable in an egglog query");
define_id!(pub AtomId, u32, "an atom in an egglog query");
/// An entry in the lower-level `core_relations` rule being built, as opposed
/// to the egglog-level [`QueryEntry`] defined in this module.
pub(crate) type DstVar = core_relations::QueryEntry;
31
/// Validation errors raised while a rule is being assembled.
#[derive(Debug, Error)]
enum RuleBuilderError {
    /// A query entry's type does not line up with the type it is checked
    /// against.
    #[error("type mismatch: expected {expected:?}, got {got:?}")]
    TypeMismatch { expected: ColumnTy, got: ColumnTy },
    /// The number of entries supplied differs from the number required.
    #[error("arity mismatch: expected {expected:?}, got {got:?}")]
    ArityMismatch { expected: usize, got: usize },
}
39
/// Per-variable metadata tracked by a [`Query`].
#[derive(Clone)]
struct VarInfo {
    /// The column type the variable was declared with.
    ty: ColumnTy,
    /// If there is a "term-level" variant of this variable bound elsewhere, it
    /// is stored here. Otherwise, this points back to the variable itself.
    term_var: Variable,
}
47
/// A single argument position in an egglog-level query or action: either a
/// variable or a constant value.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QueryEntry {
    /// A variable reference.
    Var {
        /// The identifier of the variable within its rule.
        id: Variable,
        /// An optional human-readable name, useful when debugging.
        name: Option<Box<str>>,
    },
    /// A constant value.
    Const {
        /// The raw value of the constant.
        val: Value,
        // Constants can have a type plumbed through, particularly if they
        // correspond to a base value constant in egglog.
        ty: ColumnTy,
    },
}
61
62impl QueryEntry {
63    /// Get the variable associated with this entry, panicking if it isn't a
64    /// variable.
65    pub(crate) fn var(&self) -> Variable {
66        match self {
67            QueryEntry::Var { id, .. } => *id,
68            QueryEntry::Const { .. } => panic!("expected variable, found constant"),
69        }
70    }
71}
72
73impl From<Variable> for QueryEntry {
74    fn from(id: Variable) -> Self {
75        QueryEntry::Var { id, name: None }
76    }
77}
78
/// A callable referenced from a rule: either a function backed by a table in
/// the e-graph or an external (primitive) function.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Function {
    /// A function identified by its [`FunctionId`].
    Table(FunctionId),
    /// An external function identified by its [`ExternalFunctionId`].
    Prim(ExternalFunctionId),
}
84
85impl From<FunctionId> for Function {
86    fn from(f: FunctionId) -> Self {
87        Function::Table(f)
88    }
89}
90
91impl From<ExternalFunctionId> for Function {
92    fn from(f: ExternalFunctionId) -> Self {
93        Function::Prim(f)
94    }
95}
96
/// "Build rule callback": a cloneable, thread-safe closure that appends
/// instructions to a low-level [`CoreRuleBuilder`], reading and updating the
/// [`Bindings`] environment as it goes.
///
/// This trait only exists to bundle the `Fn` bound with `DynClone`, `Send` and
/// `Sync` so that boxed callbacks can be cloned.
trait Brc:
    Fn(&mut Bindings, &mut CoreRuleBuilder) -> Result<()> + dyn_clone::DynClone + Send + Sync
{
}
// Blanket impl: any closure with the right signature and auto traits is a `Brc`.
impl<T: Fn(&mut Bindings, &mut CoreRuleBuilder) -> Result<()> + Clone + Send + Sync> Brc for T {}
// Makes `Box<dyn Brc>` implement `Clone`.
dyn_clone::clone_trait_object!(Brc);
/// A boxed, cloneable [`Brc`] callback.
type BuildRuleCallback = Box<dyn Brc>;
104
/// The accumulated state of an egglog rule: the atoms of its query, the
/// variables it binds, and the callbacks that build its right-hand side.
#[derive(Clone)]
pub(crate) struct Query {
    /// The union-find table of the e-graph this rule was created for.
    uf_table: TableId,
    /// The e-graph's id counter, used when a lookup must mint a fresh id.
    id_counter: CounterId,
    /// The e-graph's timestamp counter.
    ts_counter: CounterId,
    /// Whether tracing (proofs) was enabled on the e-graph when the rule was
    /// created.
    tracing: bool,
    /// The slot reserved for this rule in the e-graph's rule table.
    rule_id: RuleId,
    /// Metadata for every variable bound in this rule.
    vars: DenseIdMap<Variable, VarInfo>,
    /// The current proofs that are in scope.
    atom_proofs: Vec<Variable>,
    /// The atoms of the query: the table being queried, the full row of
    /// entries for that table, and the schema layout used to build the row.
    atoms: Vec<(TableId, Vec<QueryEntry>, SchemaMath)>,
    /// An optional callback to wire up proof-related metadata before running the RHS of a rule.
    build_reason: Option<BuildRuleCallback>,
    /// The builders for queries in this module essentially wrap the lower-level
    /// builders from the `core_relations` crate. A single egglog rule can turn
    /// into N core-relations rules. The code is structured by constructing a
    /// series of callbacks that will iteratively build up a low-level rule that
    /// looks like the high-level rule, passing along an environment that keeps
    /// track of the mappings between low and high-level variables.
    add_rule: Vec<BuildRuleCallback>,
    /// If set, execute a single rule (rather than O(atoms.len()) rules) during
    /// seminaive, with the given atom as the focus.
    sole_focus: Option<usize>,
    /// The `seminaive` flag passed to `EGraph::new_rule`.
    seminaive: bool,
    /// The planning strategy to use for this rule's query.
    plan_strategy: PlanStrategy,
}
131
/// A builder for a single egglog rule, created by `EGraph::new_rule`.
///
/// The builder holds a mutable borrow of the [`EGraph`] for its whole
/// lifetime; the rule is registered when one of the `build*` methods is called.
pub struct RuleBuilder<'a> {
    /// The e-graph the rule will be registered with.
    egraph: &'a mut EGraph,
    /// Helper state used to generate proof-related callbacks.
    proof_builder: ProofBuilder,
    /// The rule under construction.
    query: Query,
}
137
138impl EGraph {
139    /// Add a rewrite rule for this [`EGraph`] using a [`RuleBuilder`].
140    /// If you aren't sure, use `egraph.new_rule("", true)`.
141    pub fn new_rule(&mut self, desc: &str, seminaive: bool) -> RuleBuilder<'_> {
142        let uf_table = self.uf_table;
143        let id_counter = self.id_counter;
144        let ts_counter = self.timestamp_counter;
145        let tracing = self.tracing;
146        let rule_id = self.rules.reserve_slot();
147        RuleBuilder {
148            egraph: self,
149            proof_builder: ProofBuilder::new(desc, rule_id),
150            query: Query {
151                uf_table,
152                id_counter,
153                ts_counter,
154                tracing,
155                rule_id,
156                seminaive,
157                build_reason: None,
158                sole_focus: None,
159                atom_proofs: Default::default(),
160                vars: Default::default(),
161                atoms: Default::default(),
162                add_rule: Default::default(),
163                plan_strategy: Default::default(),
164            },
165        }
166    }
167
168    /// Remove a rewrite rule from this [`EGraph`].
169    pub fn free_rule(&mut self, id: RuleId) {
170        self.rules.take(id);
171    }
172}
173
174impl RuleBuilder<'_> {
175    fn add_callback(&mut self, cb: impl Brc + 'static) {
176        self.query.add_rule.push(Box::new(cb));
177    }
178
179    /// Access the underlying egraph within the builder.
180    pub fn egraph(&self) -> &EGraph {
181        self.egraph
182    }
183
184    pub(crate) fn set_plan_strategy(&mut self, strategy: PlanStrategy) {
185        self.query.plan_strategy = strategy;
186    }
187
188    /// Get the canonical value of an id in the union-find. An internal-only
189    /// routine used to implement rebuilding.
190    ///
191    /// Note, calling this with a non-Id entry can cause errors at rule runtime
192    /// (The derived rules will not compile).
193    pub(crate) fn lookup_uf(&mut self, entry: QueryEntry) -> Result<Variable> {
194        let res = self.new_var(ColumnTy::Id);
195        let uf_table = self.query.uf_table;
196        self.assert_has_ty(&entry, ColumnTy::Id)
197            .context("lookup_uf: ")?;
198        self.add_callback(move |inner, rb| {
199            let entry = inner.convert(&entry);
200            let res_inner = rb.lookup_with_default(uf_table, &[entry], entry, ColumnId::new(1))?;
201            inner.mapping.insert(res, res_inner.into());
202            Ok(())
203        });
204        Ok(res)
205    }
206
207    /// A low-level routine used in rebuilding. Halts execution if `lhs` and
208    /// `rhs` are equal (pointwise).
209    ///
210    /// Note, calling this with invalid arguments (e.g. different lengths for
211    /// `lhs` and `rhs`) can cause errors at rule runtime.
212    pub(crate) fn check_for_update(
213        &mut self,
214        lhs: &[QueryEntry],
215        rhs: &[QueryEntry],
216    ) -> Result<()> {
217        let lhs = SmallVec::<[QueryEntry; 4]>::from_iter(lhs.iter().cloned());
218        let rhs = SmallVec::<[QueryEntry; 4]>::from_iter(rhs.iter().cloned());
219        if lhs.len() != rhs.len() {
220            return Err(RuleBuilderError::ArityMismatch {
221                expected: lhs.len(),
222                got: rhs.len(),
223            }
224            .into());
225        }
226        lhs.iter().zip(rhs.iter()).try_for_each(|(l, r)| {
227            self.assert_same_ty(l, r).with_context(|| {
228                format!("check_for_update: {lhs:?} and {rhs:?}, mismatch between {l:?} and {r:?}")
229            })
230        })?;
231
232        self.add_callback(move |inner, rb| {
233            let lhs = inner.convert_all(&lhs);
234            let rhs = inner.convert_all(&rhs);
235            rb.assert_any_ne(&lhs, &rhs).context("check_for_update")
236        });
237        Ok(())
238    }
239
240    fn assert_same_ty(
241        &self,
242        l: &QueryEntry,
243        r: &QueryEntry,
244    ) -> std::result::Result<(), RuleBuilderError> {
245        match (l, r) {
246            (QueryEntry::Var { id: v1, .. }, QueryEntry::Var { id: v2, .. }) => {
247                let ty1 = self.query.vars[*v1].ty;
248                let ty2 = self.query.vars[*v2].ty;
249                if ty1 != ty2 {
250                    return Err(RuleBuilderError::TypeMismatch {
251                        expected: ty1,
252                        got: ty2,
253                    });
254                }
255            }
256            // constants can be untyped
257            (QueryEntry::Const { .. }, QueryEntry::Const { .. })
258            | (QueryEntry::Var { .. }, QueryEntry::Const { .. })
259            | (QueryEntry::Const { .. }, QueryEntry::Var { .. }) => {}
260        }
261        Ok(())
262    }
263
264    fn assert_has_ty(
265        &self,
266        entry: &QueryEntry,
267        ty: ColumnTy,
268    ) -> std::result::Result<(), RuleBuilderError> {
269        if let QueryEntry::Var { id: v, .. } = entry {
270            let var_ty = self.query.vars[*v].ty;
271            if var_ty != ty {
272                return Err(RuleBuilderError::TypeMismatch {
273                    expected: var_ty,
274                    got: ty,
275                });
276            }
277        }
278        Ok(())
279    }
280
281    /// Register the given rule with the egraph.
282    pub fn build(self) -> RuleId {
283        assert!(
284            !self.egraph.tracing,
285            "proofs are enabled: use `build_with_syntax` instead"
286        );
287        self.build_internal(None)
288    }
289
290    pub fn build_with_syntax(self, syntax: SourceSyntax) -> RuleId {
291        self.build_internal(Some(syntax))
292    }
293
294    pub(crate) fn build_internal(mut self, syntax: Option<SourceSyntax>) -> RuleId {
295        if self.query.atoms.len() == 1 {
296            self.query.plan_strategy = PlanStrategy::MinCover;
297        }
298        if let Some(syntax) = &syntax {
299            if self.egraph.tracing {
300                let cb = self
301                    .proof_builder
302                    .create_reason(syntax.clone(), self.egraph);
303                self.query.build_reason = Some(Box::new(move |bndgs, rb| {
304                    let reason = cb(bndgs, rb)?;
305                    bndgs.lhs_reason = Some(reason.into());
306                    Ok(())
307                }));
308            }
309        }
310        let res = self.query.rule_id;
311        let info = RuleInfo {
312            last_run_at: Timestamp::new(0),
313            query: self.query,
314            cached_plan: None,
315            desc: self.proof_builder.rule_description,
316        };
317        debug!("created rule {res:?} / {}", info.desc);
318        self.egraph.rules.insert(res, info);
319        res
320    }
321
322    pub(crate) fn set_focus(&mut self, focus: usize) {
323        self.query.sole_focus = Some(focus);
324    }
325
326    /// Bind a new variable of the given type in the query.
327    pub fn new_var(&mut self, ty: ColumnTy) -> Variable {
328        let res = self.query.vars.next_id();
329        self.query.vars.push(VarInfo { ty, term_var: res })
330    }
331
332    /// Bind a new variable of the given type in the query.
333    ///
334    /// This method attaches the given name to the [`QueryEntry`], which can
335    /// make debugging easier.
336    pub fn new_var_named(&mut self, ty: ColumnTy, name: &str) -> QueryEntry {
337        let id = self.new_var(ty);
338        QueryEntry::Var {
339            id,
340            name: Some(name.into()),
341        }
342    }
343
344    /// A low-level way to add an atom to a query.
345    ///
346    /// The atom is added directly to `table`. If `func` is supplied, then metadata about the
347    /// function is used for schema validation and proof generation. If `subsume_entry` is
348    /// supplied and the supplied function is enabled for subsumption, then the given
349    /// [`QueryEntry`] is used to populated the subsumption column for the table. This allows
350    /// higher-level routines to constrain the subsumption column or use it for other purposes.
351    pub(crate) fn add_atom_with_timestamp_and_func(
352        &mut self,
353        table: TableId,
354        func: Option<FunctionId>,
355        subsume_entry: Option<QueryEntry>,
356        entries: &[QueryEntry],
357    ) -> AtomId {
358        let mut atom = entries.to_vec();
359        let schema_math = if let Some(func) = func {
360            let info = &self.egraph.funcs[func];
361            assert_eq!(info.schema.len(), entries.len());
362            SchemaMath {
363                tracing: self.egraph.tracing,
364                subsume: info.can_subsume,
365                func_cols: info.schema.len(),
366            }
367        } else {
368            SchemaMath {
369                tracing: self.egraph.tracing,
370                subsume: subsume_entry.is_some(),
371                func_cols: entries.len(),
372            }
373        };
374        schema_math.write_table_row(
375            &mut atom,
376            RowVals {
377                timestamp: self.new_var(ColumnTy::Id).into(),
378                proof: self
379                    .egraph
380                    .tracing
381                    .then(|| self.new_var(ColumnTy::Id).into()),
382                subsume: if schema_math.subsume {
383                    Some(subsume_entry.unwrap_or_else(|| self.new_var(ColumnTy::Id).into()))
384                } else {
385                    None
386                },
387                ret_val: None,
388            },
389        );
390        let res = AtomId::from_usize(self.query.atoms.len());
391        if self.egraph.tracing {
392            let proof_var = atom[schema_math.proof_id_col()].var();
393            self.proof_builder.term_vars.insert(res, proof_var.into());
394            self.query.atom_proofs.push(proof_var);
395            if let Some(QueryEntry::Var { id, .. }) = entries.last() {
396                if table != self.egraph.uf_table {
397                    // Don't overwrite "term_var" for uf_table; it stores
398                    // reasons inline / doesn't have terms.
399                    self.query.vars[*id].term_var = proof_var;
400                }
401            }
402        }
403        self.query.atoms.push((table, atom, schema_math));
404        res
405    }
406
407    pub fn call_external_func(
408        &mut self,
409        func: ExternalFunctionId,
410        args: &[QueryEntry],
411        ret_ty: ColumnTy,
412        panic_msg: impl FnOnce() -> String + 'static + Send,
413    ) -> Variable {
414        let args = args.to_vec();
415        let res = self.new_var(ret_ty);
416        // External functions that fail on the RHS of a rule should cause a panic.
417        let panic_fn = self.egraph.new_panic_lazy(panic_msg);
418        self.query.add_rule.push(Box::new(move |inner, rb| {
419            let args = inner.convert_all(&args);
420            let var = rb.call_external_with_fallback(func, &args, panic_fn, &[])?;
421            inner.mapping.insert(res, var.into());
422            Ok(())
423        }));
424        res
425    }
426
427    /// Add the given table atom to query. As elsewhere in the crate, the last
428    /// argument is the "return value" of the function. Can also optionally
429    /// check the subsumption bit.
430    pub fn query_table(
431        &mut self,
432        func: FunctionId,
433        entries: &[QueryEntry],
434        is_subsumed: Option<bool>,
435    ) -> Result<AtomId> {
436        let info = &self.egraph.funcs[func];
437        let schema = &info.schema;
438        if schema.len() != entries.len() {
439            return Err(anyhow::Error::from(RuleBuilderError::ArityMismatch {
440                expected: schema.len(),
441                got: entries.len(),
442            }))
443            .with_context(|| format!("query_table: mismatch between {entries:?} and {schema:?}"));
444        }
445        entries
446            .iter()
447            .zip(schema.iter())
448            .try_for_each(|(entry, ty)| {
449                self.assert_has_ty(entry, *ty)
450                    .with_context(|| format!("query_table: mismatch between {entry:?} and {ty:?}"))
451            })?;
452        Ok(self.add_atom_with_timestamp_and_func(
453            info.table,
454            Some(func),
455            is_subsumed.map(|b| QueryEntry::Const {
456                val: match b {
457                    true => SUBSUMED,
458                    false => NOT_SUBSUMED,
459                },
460                ty: ColumnTy::Id,
461            }),
462            entries,
463        ))
464    }
465
466    /// Add the given primitive atom to query. As elsewhere in the crate, the last
467    /// argument is the "return value" of the function.
468    pub fn query_prim(
469        &mut self,
470        func: ExternalFunctionId,
471        entries: &[QueryEntry],
472        // NB: not clear if we still need this now that proof checker is in a separate crate.
473        _ret_ty: ColumnTy,
474    ) -> Result<()> {
475        let entries = entries.to_vec();
476        self.query.add_rule.push(Box::new(move |inner, rb| {
477            let mut dst_vars = inner.convert_all(&entries);
478            let expected = dst_vars.pop().expect("must specify a return value");
479            let var = rb.call_external(func, &dst_vars)?;
480            match entries.last().unwrap() {
481                QueryEntry::Var { id, .. } if !inner.grounded.contains(id) => {
482                    inner.mapping.insert(*id, var.into());
483                    inner.grounded.insert(*id);
484                }
485                _ => rb.assert_eq(var.into(), expected),
486            }
487            Ok(())
488        }));
489        Ok(())
490    }
491
492    /// Subsume the given entry in `func`.
493    ///
494    /// `entries` should match the number of keys to the function.
495    pub fn subsume(&mut self, func: FunctionId, entries: &[QueryEntry]) {
496        // First, insert a subsumed value if the tuple is new.
497        let ret = self.lookup_with_subsumed(
498            func,
499            entries,
500            QueryEntry::Const {
501                val: SUBSUMED,
502                ty: ColumnTy::Id,
503            },
504            || "subsumed a nonextestent row!".to_string(),
505        );
506        let info = &self.egraph.funcs[func];
507        let schema_math = SchemaMath {
508            tracing: self.egraph.tracing,
509            subsume: info.can_subsume,
510            func_cols: info.schema.len(),
511        };
512        assert!(info.can_subsume);
513        assert_eq!(entries.len() + 1, info.schema.len());
514        let entries = entries.to_vec();
515        let table = info.table;
516        self.add_callback(move |inner, rb| {
517            // Then, add a tuple subsuming the entry, but only if the entry isn't already subsumed.
518            // Look up the current subsume value.
519            let mut dst_entries = inner.convert_all(&entries);
520            let cur_subsume_val = rb.lookup(
521                table,
522                &dst_entries,
523                ColumnId::from_usize(schema_math.subsume_col()),
524            )?;
525            let cur_proof_val = if schema_math.tracing {
526                Some(DstVar::from(rb.lookup(
527                    table,
528                    &dst_entries,
529                    ColumnId::from_usize(schema_math.proof_id_col()),
530                )?))
531            } else {
532                None
533            };
534            schema_math.write_table_row(
535                &mut dst_entries,
536                RowVals {
537                    timestamp: inner.next_ts(),
538                    proof: cur_proof_val,
539                    subsume: Some(SUBSUMED.into()),
540                    ret_val: Some(inner.convert(&ret.into())),
541                },
542            );
543            rb.insert_if_eq(
544                table,
545                cur_subsume_val.into(),
546                NOT_SUBSUMED.into(),
547                &dst_entries,
548            )?;
549            Ok(())
550        });
551    }
552
    /// Look up the value of `func` at the keys `entries`, returning a variable
    /// bound to the result.
    ///
    /// What happens when the row is missing depends on the function's
    /// configured default:
    ///
    /// * `DefaultVal::Const` / `DefaultVal::FreshId`: a new row is inserted,
    ///   with the return value set to the constant or to a freshly minted id
    ///   respectively. The subsumption column of the new row (if the function
    ///   has one) is populated from `subsumed`.
    /// * `DefaultVal::Fail`: the lookup falls back to a panic whose message is
    ///   produced by `panic_msg`.
    ///
    /// When tracing is enabled, the proof column of the row is looked up as
    /// well and bound to an additional variable.
    pub(crate) fn lookup_with_subsumed(
        &mut self,
        func: FunctionId,
        entries: &[QueryEntry],
        subsumed: QueryEntry,
        panic_msg: impl FnOnce() -> String + Send + 'static,
    ) -> Variable {
        let entries = entries.to_vec();
        let info = &self.egraph.funcs[func];
        // The result variable; its term-level variant starts out as the next
        // id in `vars`, i.e. the variable being pushed here.
        let res = self.query.vars.push(VarInfo {
            ty: info.ret_ty(),
            term_var: self.query.vars.next_id(),
        });
        let table = info.table;
        let id_counter = self.query.id_counter;
        let schema_math = SchemaMath {
            tracing: self.egraph.tracing,
            subsume: info.can_subsume,
            func_cols: info.schema.len(),
        };
        let cb: BuildRuleCallback = match info.default_val {
            DefaultVal::Const(_) | DefaultVal::FreshId => {
                // `wv` is written to the return-value column of a new row;
                // `wv_ref` is written to the proof column when tracing.
                let (wv, wv_ref): (WriteVal, WriteVal) = match &info.default_val {
                    DefaultVal::Const(c) => ((*c).into(), (*c).into()),
                    DefaultVal::FreshId => (
                        WriteVal::IncCounter(id_counter),
                        // When we create a new term, we should
                        // simply "reuse" the value we just minted
                        // for the value.
                        WriteVal::CurrentVal(schema_math.ret_val_col()),
                    ),
                    _ => unreachable!(),
                };
                // Builds the values for every non-key column of the row, in
                // column order.
                let get_write_vals = move |inner: &mut Bindings| {
                    let mut write_vals = SmallVec::<[WriteVal; 4]>::new();
                    for i in schema_math.num_keys()..schema_math.table_columns() {
                        if i == schema_math.ts_col() {
                            write_vals.push(inner.next_ts().into());
                        } else if i == schema_math.ret_val_col() {
                            write_vals.push(wv);
                        } else if schema_math.tracing && i == schema_math.proof_id_col() {
                            write_vals.push(wv_ref);
                        } else if schema_math.subsume && i == schema_math.subsume_col() {
                            write_vals.push(inner.convert(&subsumed).into())
                        } else {
                            unreachable!()
                        }
                    }
                    write_vals
                };

                if self.egraph.tracing {
                    // Extra variables for the term (proof) id and timestamp of
                    // the row; the result's term-level variant becomes the
                    // term variable.
                    let term_var = self.new_var(ColumnTy::Id);
                    self.query.vars[res].term_var = term_var;
                    let ts_var = self.new_var(ColumnTy::Id);
                    let mut insert_entries = entries.to_vec();
                    insert_entries.push(res.into());
                    let add_proof =
                        self.proof_builder
                            .new_row(func, insert_entries, term_var, self.egraph);
                    Box::new(move |inner, rb| {
                        let write_vals = get_write_vals(inner);
                        let dst_vars = inner.convert_all(&entries);
                        // NB: having one `lookup_or_insert` call
                        // per projection is pretty inefficient
                        // here, but merging these into a custom
                        // instruction didn't move the needle on a
                        // write-heavy benchmark when I tried it
                        // early on. May be worth revisiting after
                        // more low-hanging fruit has been
                        // optimized.
                        let var = rb.lookup_or_insert(
                            table,
                            &dst_vars,
                            &write_vals,
                            ColumnId::from_usize(schema_math.ret_val_col()),
                        )?;
                        let ts = rb.lookup_or_insert(
                            table,
                            &dst_vars,
                            &write_vals,
                            ColumnId::from_usize(schema_math.ts_col()),
                        )?;
                        let term = rb.lookup_or_insert(
                            table,
                            &dst_vars,
                            &write_vals,
                            ColumnId::from_usize(schema_math.proof_id_col()),
                        )?;
                        inner.mapping.insert(term_var, term.into());
                        inner.mapping.insert(res, var.into());
                        inner.mapping.insert(ts_var, ts.into());
                        rb.assert_eq(var.into(), term.into());
                        // The following bookkeeping is only needed
                        // if the value is new. That only happens if
                        // the main id equals the term id.
                        add_proof(inner, rb)?;
                        Ok(())
                    })
                } else {
                    // Without tracing only the return value is needed.
                    Box::new(move |inner, rb| {
                        let write_vals = get_write_vals(inner);
                        let dst_vars = inner.convert_all(&entries);
                        let var = rb.lookup_or_insert(
                            table,
                            &dst_vars,
                            &write_vals,
                            ColumnId::from_usize(schema_math.ret_val_col()),
                        )?;
                        inner.mapping.insert(res, var.into());
                        Ok(())
                    })
                }
            }
            DefaultVal::Fail => {
                let panic_func = self.egraph.new_panic_lazy(panic_msg);
                if self.egraph.tracing {
                    // NOTE(review): unlike the tracing branch above, this one
                    // does not point `vars[res].term_var` at `term_var` —
                    // confirm that this is intentional.
                    let term_var = self.new_var(ColumnTy::Id);
                    Box::new(move |inner, rb| {
                        let dst_vars = inner.convert_all(&entries);
                        let var = rb.lookup_with_fallback(
                            table,
                            &dst_vars,
                            ColumnId::from_usize(schema_math.ret_val_col()),
                            panic_func,
                            &[],
                        )?;
                        let term = rb.lookup(
                            table,
                            &dst_vars,
                            ColumnId::from_usize(schema_math.proof_id_col()),
                        )?;
                        inner.mapping.insert(res, var.into());
                        inner.mapping.insert(term_var, term.into());
                        Ok(())
                    })
                } else {
                    Box::new(move |inner, rb| {
                        let dst_vars = inner.convert_all(&entries);
                        let var = rb.lookup_with_fallback(
                            table,
                            &dst_vars,
                            ColumnId::from_usize(schema_math.ret_val_col()),
                            panic_func,
                            &[],
                        )?;
                        inner.mapping.insert(res, var.into());
                        Ok(())
                    })
                }
            }
        };
        self.query.add_rule.push(cb);
        res
    }
708
709    /// Look up the value of a function in the database. If the value is not
710    /// present, the configured default for the function is used.
711    ///
712    /// For functions configured with [`DefaultVal::Fail`], failing lookups will use `panic_msg` in
713    /// the panic output.
714    pub fn lookup(
715        &mut self,
716        func: FunctionId,
717        entries: &[QueryEntry],
718        panic_msg: impl FnOnce() -> String + Send + 'static,
719    ) -> Variable {
720        self.lookup_with_subsumed(
721            func,
722            entries,
723            QueryEntry::Const {
724                val: NOT_SUBSUMED,
725                ty: ColumnTy::Id,
726            },
727            panic_msg,
728        )
729    }
730
731    /// Merge the two values in the union-find.
732    pub fn union(&mut self, mut l: QueryEntry, mut r: QueryEntry) {
733        let cb: BuildRuleCallback = if self.query.tracing {
734            // Union proofs should reflect term-level variables rather than the
735            // current leader of the e-class.
736            for entry in [&mut l, &mut r] {
737                if let QueryEntry::Var { id, .. } = entry {
738                    *id = self.query.vars[*id].term_var;
739                }
740            }
741            Box::new(move |inner, rb| {
742                let l = inner.convert(&l);
743                let r = inner.convert(&r);
744                let proof = inner.lhs_reason.expect("reason must be set");
745                rb.insert(inner.uf_table, &[l, r, inner.next_ts(), proof])
746                    .context("union")
747            })
748        } else {
749            Box::new(move |inner, rb| {
750                let l = inner.convert(&l);
751                let r = inner.convert(&r);
752                rb.insert(inner.uf_table, &[l, r, inner.next_ts()])
753                    .context("union")
754            })
755        };
756        self.query.add_rule.push(cb);
757    }
758
759    /// This method is equivalent to `remove(table, before); set(table, after)`
760    /// when tracing/proofs aren't enabled. When proofs are enabled, it
761    /// creates a proof term specialized for equality.
762    ///
763    /// This allows us to reconstruct proofs lazily from the UF, rather than
764    /// running the proof generation algorithm eagerly as we query the table.
765    /// Proof generation is a relatively expensive operation, and we'd prefer to
766    /// avoid doing it on every union-find lookup.
767    pub(crate) fn rebuild_row(
768        &mut self,
769        func: FunctionId,
770        before: &[QueryEntry],
771        after: &[QueryEntry],
772        // If subsumption is enabled for this function, we can optionally propagate it to the next
773        // row.
774        subsume_var: Option<Variable>,
775    ) {
776        assert_eq!(before.len(), after.len());
777        self.remove(func, &before[..before.len() - 1]);
778        if !self.egraph.tracing {
779            if let Some(subsume_var) = subsume_var {
780                self.set_with_subsume(
781                    func,
782                    after,
783                    QueryEntry::Var {
784                        id: subsume_var,
785                        name: None,
786                    },
787                );
788            } else {
789                self.set(func, after);
790            }
791            return;
792        }
793        let table = self.egraph.funcs[func].table;
794        let term_var = self.new_var(ColumnTy::Id);
795        let reason_var = self.new_var(ColumnTy::Id);
796        let before_id = before.last().unwrap().var();
797        let before_term = self.query.vars[before_id].term_var;
798        debug_assert_ne!(before_term, before_id);
799        let add_proof = self.proof_builder.rebuild_proof(
800            func,
801            after,
802            RebuildVars {
803                before_term,
804                new_term: term_var,
805                reason: reason_var,
806            },
807            self.egraph,
808        );
809        let after = SmallVec::<[_; 4]>::from_iter(after.iter().cloned());
810        let uf_table = self.query.uf_table;
811        let info = &self.egraph.funcs[func];
812        let schema_math = SchemaMath {
813            tracing: self.egraph.tracing,
814            subsume: info.can_subsume,
815            func_cols: info.schema.len(),
816        };
817        self.query.add_rule.push(Box::new(move |inner, rb| {
818            add_proof(inner, rb)?;
819            let mut dst_vars = inner.convert_all(&after);
820            schema_math.write_table_row(
821                &mut dst_vars,
822                RowVals {
823                    timestamp: inner.next_ts(),
824                    proof: Some(inner.mapping[term_var]),
825                    subsume: subsume_var.map(|v| inner.mapping[v]),
826                    ret_val: None, // already filled in,
827                },
828            );
829            // This congruence rule will also serve as a proof that the old and
830            // new terms are equal.
831            rb.insert(
832                uf_table,
833                &[
834                    inner.mapping[before_term],
835                    inner.mapping[term_var],
836                    inner.next_ts(),
837                    inner.mapping[reason_var],
838                ],
839            )
840            .context("rebuild_row_uf")?;
841            rb.insert(table, &dst_vars).context("rebuild_row_table")
842        }));
843    }
844
845    /// Set the value of a function in the database.
846    pub fn set(&mut self, func: FunctionId, entries: &[QueryEntry]) {
847        self.set_with_subsume(
848            func,
849            entries,
850            QueryEntry::Const {
851                val: NOT_SUBSUMED,
852                ty: ColumnTy::Id,
853            },
854        );
855    }
856
857    pub(crate) fn set_with_subsume(
858        &mut self,
859        func: FunctionId,
860        entries: &[QueryEntry],
861        subsume_entry: QueryEntry,
862    ) {
863        let info = &self.egraph.funcs[func];
864        let table = info.table;
865        let entries = entries.to_vec();
866        let schema_math = SchemaMath {
867            tracing: self.egraph.tracing,
868            subsume: info.can_subsume,
869            func_cols: info.schema.len(),
870        };
871        if self.egraph.tracing {
872            let res = self.lookup(func, &entries[0..entries.len() - 1], || {
873                "lookup failed during proof-enabled set; this is an internal proofs bug".to_string()
874            });
875            self.union(res.into(), entries.last().unwrap().clone());
876            if schema_math.subsume {
877                // Set the original row but with the passed-in subsumption value.
878                self.add_callback(move |inner, rb| {
879                    let mut dst_vars = inner.convert_all(&entries);
880                    let proof_var = rb.lookup(
881                        table,
882                        &dst_vars[0..schema_math.num_keys()],
883                        ColumnId::from_usize(schema_math.proof_id_col()),
884                    )?;
885                    schema_math.write_table_row(
886                        &mut dst_vars,
887                        RowVals {
888                            timestamp: inner.next_ts(),
889                            proof: Some(proof_var.into()),
890                            subsume: Some(inner.convert(&subsume_entry)),
891                            ret_val: Some(inner.convert(&res.into())),
892                        },
893                    );
894                    rb.insert(table, &dst_vars).context("set")
895                });
896            }
897        } else {
898            self.query.add_rule.push(Box::new(move |inner, rb| {
899                let mut dst_vars = inner.convert_all(&entries);
900                schema_math.write_table_row(
901                    &mut dst_vars,
902                    RowVals {
903                        timestamp: inner.next_ts(),
904                        proof: None, // tracing is off
905                        subsume: schema_math.subsume.then(|| inner.convert(&subsume_entry)),
906                        ret_val: None, // already filled in
907                    },
908                );
909                rb.insert(table, &dst_vars).context("set")
910            }));
911        };
912    }
913
914    /// Remove the value of a function from the database.
915    pub fn remove(&mut self, table: FunctionId, entries: &[QueryEntry]) {
916        let table = self.egraph.funcs[table].table;
917        let entries = entries.to_vec();
918        let cb: BuildRuleCallback = Box::new(move |inner, rb| {
919            let dst_vars = inner.convert_all(&entries);
920            rb.remove(table, &dst_vars).context("remove")
921        });
922        self.query.add_rule.push(cb);
923    }
924
925    /// Panic with a given message.
926    pub fn panic(&mut self, message: String) {
927        let panic = self.egraph.new_panic(message.clone());
928        let ret_ty = ColumnTy::Id;
929        let res = self.new_var(ret_ty);
930        self.query.add_rule.push(Box::new(move |inner, rb| {
931            let var = rb.call_external(panic, &[])?;
932            inner.mapping.insert(res, var.into());
933            Ok(())
934        }));
935    }
936}
937
938impl Query {
939    fn query_state<'a, 'outer>(
940        &self,
941        rsb: &'a mut RuleSetBuilder<'outer>,
942    ) -> (QueryBuilder<'outer, 'a>, Bindings) {
943        let mut qb = rsb.new_rule();
944        qb.set_plan_strategy(self.plan_strategy);
945        let mut inner = Bindings {
946            uf_table: self.uf_table,
947            next_ts: None,
948            lhs_reason: None,
949            mapping: Default::default(),
950            grounded: Default::default(),
951        };
952        for (var, _) in self.vars.iter() {
953            inner.mapping.insert(var, DstVar::Var(qb.new_var()));
954        }
955        (qb, inner)
956    }
957
958    fn run_rules_and_build(
959        &self,
960        qb: QueryBuilder,
961        mut inner: Bindings,
962        desc: &str,
963    ) -> Result<core_relations::RuleId> {
964        let mut rb = qb.build();
965        inner.next_ts = Some(rb.read_counter(self.ts_counter).into());
966        // Set up proof state if it's configured.
967        if let Some(build_reason) = &self.build_reason {
968            build_reason(&mut inner, &mut rb)?;
969        }
970        self.add_rule
971            .iter()
972            .try_for_each(|f| f(&mut inner, &mut rb))?;
973        Ok(rb.build_with_description(desc))
974    }
975
976    pub(crate) fn build_cached_plan(
977        &self,
978        db: &mut core_relations::Database,
979        desc: &str,
980    ) -> Result<CachedPlanInfo> {
981        let mut rsb = RuleSetBuilder::new(db);
982        let (mut qb, mut inner) = self.query_state(&mut rsb);
983        let mut atom_mapping = Vec::with_capacity(self.atoms.len());
984        for (table, entries, _schema_info) in &self.atoms {
985            atom_mapping.push(add_atom(&mut qb, *table, entries, &[], &mut inner)?);
986        }
987        let rule_id = self.run_rules_and_build(qb, inner, desc)?;
988        let rs = rsb.build();
989        let plan = Arc::new(rs.build_cached_plan(rule_id));
990        Ok(CachedPlanInfo { plan, atom_mapping })
991    }
992
993    /// Add rules to the [`RuleSetBuilder`] for the query specified by the [`CachedPlanInfo`].
994    ///
995    /// A [`CachedPlanInfo`] is a compiled RHS and partial LHS for an egglog rules. In order to
996    /// implement seminaive evaluation, we run several variants of this cached plan with different
997    /// constraints on the timestamps for different atoms. This rule handles building these
998    /// variants of the base plan and adding them to `rsb`.
999    pub(crate) fn add_rules_from_cached(
1000        &self,
1001        rsb: &mut RuleSetBuilder,
1002        mid_ts: Timestamp,
1003        cached_plan: &CachedPlanInfo,
1004    ) -> Result<()> {
1005        // For N atoms, we create N queries for seminaive evaluation. We can reuse the cached plan
1006        // directly.
1007        if !self.seminaive || (self.atoms.is_empty() && mid_ts == Timestamp::new(0)) {
1008            rsb.add_rule_from_cached_plan(&cached_plan.plan, &[]);
1009            return Ok(());
1010        }
1011        if let Some(focus_atom) = self.sole_focus {
1012            // There is a single "focus" atom that we will constrain to look at new values.
1013            let (_, _, schema_info) = &self.atoms[focus_atom];
1014            let ts_col = ColumnId::from_usize(schema_info.ts_col());
1015            rsb.add_rule_from_cached_plan(
1016                &cached_plan.plan,
1017                &[(
1018                    cached_plan.atom_mapping[focus_atom],
1019                    Constraint::GeConst {
1020                        col: ts_col,
1021                        val: mid_ts.to_value(),
1022                    },
1023                )],
1024            );
1025            return Ok(());
1026        }
1027        // Use the cached plan atoms.len() times with different constraints on each atom.
1028        let mut constraints: Vec<(core_relations::AtomId, Constraint)> =
1029            Vec::with_capacity(self.atoms.len());
1030        'outer: for focus_atom in 0..self.atoms.len() {
1031            for (i, (_, _, schema_info)) in self.atoms.iter().enumerate() {
1032                let ts_col = ColumnId::from_usize(schema_info.ts_col());
1033                match i.cmp(&focus_atom) {
1034                    Ordering::Less => {
1035                        if mid_ts == Timestamp::new(0) {
1036                            continue 'outer;
1037                        }
1038                        constraints.push((
1039                            cached_plan.atom_mapping[i],
1040                            Constraint::LtConst {
1041                                col: ts_col,
1042                                val: mid_ts.to_value(),
1043                            },
1044                        ));
1045                    }
1046                    Ordering::Equal => constraints.push((
1047                        cached_plan.atom_mapping[i],
1048                        Constraint::GeConst {
1049                            col: ts_col,
1050                            val: mid_ts.to_value(),
1051                        },
1052                    )),
1053                    Ordering::Greater => {}
1054                };
1055            }
1056            rsb.add_rule_from_cached_plan(&cached_plan.plan, &constraints);
1057            constraints.clear();
1058        }
1059        Ok(())
1060    }
1061}
1062
/// State that is used during query execution to translate variables in egglog
/// rules into variables for core-relations rules.
pub(crate) struct Bindings {
    /// Table that union actions insert their rows into.
    uf_table: TableId,
    /// Timestamp entry to write on rows produced by the rule. This is `None` until the
    /// right-hand side of the rule is being built; it is read through [`Bindings::next_ts`],
    /// which panics if it has not been set.
    next_ts: Option<DstVar>,
    /// If proofs are enabled, this variable contains the "reason id" for any union or insertion
    /// that happens on the RHS of a rule.
    pub(crate) lhs_reason: Option<DstVar>,
    /// The core-relations entry that each egglog variable is bound to.
    pub(crate) mapping: DenseIdMap<Variable, DstVar>,
    /// Variables that have appeared as an argument of an atom added through `add_atom`.
    grounded: HashSet<Variable>,
}
1074
1075impl Bindings {
1076    pub(crate) fn next_ts(&self) -> DstVar {
1077        self.next_ts
1078            .expect("ts_var should only be used in RHS of the rule")
1079    }
1080    pub(crate) fn convert(&self, entry: &QueryEntry) -> DstVar {
1081        match entry {
1082            QueryEntry::Var { id: v, .. } => self.mapping[*v],
1083            QueryEntry::Const { val, .. } => DstVar::Const(*val),
1084        }
1085    }
1086    pub(crate) fn convert_all(&self, entries: &[QueryEntry]) -> SmallVec<[DstVar; 4]> {
1087        entries.iter().map(|e| self.convert(e)).collect()
1088    }
1089}
1090
1091fn add_atom(
1092    qb: &mut QueryBuilder,
1093    table: TableId,
1094    entries: &[QueryEntry],
1095    constraints: &[Constraint],
1096    inner: &mut Bindings,
1097) -> Result<core_relations::AtomId> {
1098    for entry in entries {
1099        if let QueryEntry::Var { id, .. } = entry {
1100            inner.grounded.insert(*id);
1101        }
1102    }
1103    let vars = inner.convert_all(entries);
1104    Ok(qb.add_atom(table, &vars, constraints)?)
1105}