Skip to main content

egglog_bridge/
rule.rs

1//! APIs for building egglog rules.
2//!
3//! Egglog rules are ultimately just (sets of) `core-relations` rules
4//! parameterized by a range of timestamps used as constraints during seminaive
5//! evaluation.
6
7use std::{cmp::Ordering, sync::Arc};
8
9use crate::core_relations;
10use crate::core_relations::{
11    ColumnId, Constraint, CounterId, ExternalFunctionId, PlanStrategy, QueryBuilder,
12    RuleBuilder as CoreRuleBuilder, RuleSetBuilder, TableId, Value, WriteVal,
13};
14use crate::numeric_id::{DenseIdMap, NumericId, define_id};
15use anyhow::Context;
16use hashbrown::HashSet;
17use log::debug;
18use smallvec::SmallVec;
19use thiserror::Error;
20
21use crate::syntax::SourceSyntax;
22use crate::{CachedPlanInfo, NOT_SUBSUMED, RowVals, SUBSUMED, SchemaMath};
23use crate::{
24    ColumnTy, DefaultVal, EGraph, FunctionId, Result, RuleId, RuleInfo, Timestamp,
25    proof_spec::{ProofBuilder, RebuildVars},
26};
27
28define_id!(pub VariableId, u32, "A variable in an egglog query");
29define_id!(pub AtomId, u32, "an atom in an egglog query");
30pub(crate) type DstVar = core_relations::QueryEntry;
31
32impl VariableId {
33    fn to_var(self) -> Variable {
34        Variable {
35            id: self,
36            name: None,
37        }
38    }
39}
40
41#[derive(Clone, Debug, PartialEq, Eq, Hash)]
42pub struct Variable {
43    pub id: VariableId,
44    pub name: Option<Box<str>>,
45}
46
47#[derive(Debug, Error)]
48enum RuleBuilderError {
49    #[error("type mismatch: expected {expected:?}, got {got:?}")]
50    TypeMismatch { expected: ColumnTy, got: ColumnTy },
51    #[error("arity mismatch: expected {expected:?}, got {got:?}")]
52    ArityMismatch { expected: usize, got: usize },
53}
54
55#[derive(Clone)]
56struct VarInfo {
57    ty: ColumnTy,
58    name: Option<Box<str>>,
59    /// If there is a "term-level" variant of this variable bound elsewhere, it
60    /// is stored here. Otherwise, this points back to the variable itself.
61    term_var: Variable,
62}
63
64#[derive(Clone, Debug, PartialEq, Eq, Hash)]
65pub enum QueryEntry {
66    Var(Variable),
67    Const {
68        val: Value,
69        // Constants can have a type plumbed through, particularly if they
70        // correspond to a base value constant in egglog.
71        ty: ColumnTy,
72    },
73}
74
75impl QueryEntry {
76    /// Get the variable associated with this entry, panicking if it isn't a
77    /// variable.
78    pub(crate) fn var(&self) -> Variable {
79        match self {
80            QueryEntry::Var(v) => v.clone(),
81            QueryEntry::Const { .. } => panic!("expected variable, found constant"),
82        }
83    }
84}
85
86impl From<Variable> for QueryEntry {
87    fn from(var: Variable) -> Self {
88        QueryEntry::Var(var)
89    }
90}
91
92#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
93pub enum Function {
94    Table(FunctionId),
95    Prim(ExternalFunctionId),
96}
97
98impl From<FunctionId> for Function {
99    fn from(f: FunctionId) -> Self {
100        Function::Table(f)
101    }
102}
103
104impl From<ExternalFunctionId> for Function {
105    fn from(f: ExternalFunctionId) -> Self {
106        Function::Prim(f)
107    }
108}
109
110trait Brc:
111    Fn(&mut Bindings, &mut CoreRuleBuilder) -> Result<()> + dyn_clone::DynClone + Send + Sync
112{
113}
114impl<T: Fn(&mut Bindings, &mut CoreRuleBuilder) -> Result<()> + Clone + Send + Sync> Brc for T {}
115dyn_clone::clone_trait_object!(Brc);
116type BuildRuleCallback = Box<dyn Brc>;
117
118#[derive(Clone)]
119pub(crate) struct Query {
120    uf_table: TableId,
121    id_counter: CounterId,
122    ts_counter: CounterId,
123    tracing: bool,
124    rule_id: RuleId,
125    vars: DenseIdMap<VariableId, VarInfo>,
126    /// The current proofs that are in scope.
127    atom_proofs: Vec<Variable>,
128    atoms: Vec<(TableId, Vec<QueryEntry>, SchemaMath)>,
129    /// An optional callback to wire up proof-related metadata before running the RHS of a rule.
130    build_reason: Option<BuildRuleCallback>,
131    /// The builders for queries in this module essentially wrap the lower-level
132    /// builders from the `core_relations` crate. A single egglog rule can turn
133    /// into N core-relations rules. The code is structured by constructing a
134    /// series of callbacks that will iteratively build up a low-level rule that
135    /// looks like the high-level rule, passing along an environment that keeps
136    /// track of the mappings between low and high-level variables.
137    add_rule: Vec<BuildRuleCallback>,
138    /// If set, execute a single rule (rather than O(atoms.len()) rules) during
139    /// seminaive, with the given atom as the focus.
140    sole_focus: Option<usize>,
141    seminaive: bool,
142    plan_strategy: PlanStrategy,
143}
144
145pub struct RuleBuilder<'a> {
146    egraph: &'a mut EGraph,
147    proof_builder: ProofBuilder,
148    query: Query,
149}
150
151impl EGraph {
152    /// Add a rewrite rule for this [`EGraph`] using a [`RuleBuilder`].
153    /// If you aren't sure, use `egraph.new_rule("", true)`.
154    pub fn new_rule(&mut self, desc: &str, seminaive: bool) -> RuleBuilder<'_> {
155        let uf_table = self.uf_table;
156        let id_counter = self.id_counter;
157        let ts_counter = self.timestamp_counter;
158        let tracing = self.tracing;
159        let rule_id = self.rules.reserve_slot();
160        RuleBuilder {
161            egraph: self,
162            proof_builder: ProofBuilder::new(desc, rule_id),
163            query: Query {
164                uf_table,
165                id_counter,
166                ts_counter,
167                tracing,
168                rule_id,
169                seminaive,
170                build_reason: None,
171                sole_focus: None,
172                atom_proofs: Default::default(),
173                vars: Default::default(),
174                atoms: Default::default(),
175                add_rule: Default::default(),
176                plan_strategy: Default::default(),
177            },
178        }
179    }
180
181    /// Remove a rewrite rule from this [`EGraph`].
182    pub fn free_rule(&mut self, id: RuleId) {
183        self.rules.take(id);
184    }
185}
186
187impl RuleBuilder<'_> {
188    fn add_callback(&mut self, cb: impl Brc + 'static) {
189        self.query.add_rule.push(Box::new(cb));
190    }
191
192    /// Access the underlying egraph within the builder.
193    pub fn egraph(&self) -> &EGraph {
194        self.egraph
195    }
196
197    pub(crate) fn set_plan_strategy(&mut self, strategy: PlanStrategy) {
198        self.query.plan_strategy = strategy;
199    }
200
201    /// Get the canonical value of an id in the union-find. An internal-only
202    /// routine used to implement rebuilding.
203    ///
204    /// Note, calling this with a non-Id entry can cause errors at rule runtime
205    /// (The derived rules will not compile).
206    pub(crate) fn lookup_uf(&mut self, entry: QueryEntry) -> Result<Variable> {
207        let res = self.new_var(ColumnTy::Id);
208        let uf_table = self.query.uf_table;
209        self.assert_has_ty(&entry, ColumnTy::Id)
210            .context("lookup_uf: ")?;
211        self.add_callback(move |inner, rb| {
212            let entry = inner.convert(&entry);
213            let res_inner = rb.lookup_with_default(uf_table, &[entry], entry, ColumnId::new(1))?;
214            inner.mapping.insert(res.id, res_inner.into());
215            Ok(())
216        });
217        Ok(res)
218    }
219
220    /// A low-level routine used in rebuilding. Halts execution if `lhs` and
221    /// `rhs` are equal (pointwise).
222    ///
223    /// Note, calling this with invalid arguments (e.g. different lengths for
224    /// `lhs` and `rhs`) can cause errors at rule runtime.
225    pub(crate) fn check_for_update(
226        &mut self,
227        lhs: &[QueryEntry],
228        rhs: &[QueryEntry],
229    ) -> Result<()> {
230        let lhs = SmallVec::<[QueryEntry; 4]>::from_iter(lhs.iter().cloned());
231        let rhs = SmallVec::<[QueryEntry; 4]>::from_iter(rhs.iter().cloned());
232        if lhs.len() != rhs.len() {
233            return Err(RuleBuilderError::ArityMismatch {
234                expected: lhs.len(),
235                got: rhs.len(),
236            }
237            .into());
238        }
239        lhs.iter().zip(rhs.iter()).try_for_each(|(l, r)| {
240            self.assert_same_ty(l, r).with_context(|| {
241                format!("check_for_update: {lhs:?} and {rhs:?}, mismatch between {l:?} and {r:?}")
242            })
243        })?;
244
245        self.add_callback(move |inner, rb| {
246            let lhs = inner.convert_all(&lhs);
247            let rhs = inner.convert_all(&rhs);
248            rb.assert_any_ne(&lhs, &rhs).context("check_for_update")
249        });
250        Ok(())
251    }
252
253    fn assert_same_ty(
254        &self,
255        l: &QueryEntry,
256        r: &QueryEntry,
257    ) -> std::result::Result<(), RuleBuilderError> {
258        match (l, r) {
259            (
260                QueryEntry::Var(Variable { id: v1, .. }),
261                QueryEntry::Var(Variable { id: v2, .. }),
262            ) => {
263                let ty1 = self.query.vars[*v1].ty;
264                let ty2 = self.query.vars[*v2].ty;
265                if ty1 != ty2 {
266                    return Err(RuleBuilderError::TypeMismatch {
267                        expected: ty1,
268                        got: ty2,
269                    });
270                }
271            }
272            // constants can be untyped
273            (QueryEntry::Const { .. }, QueryEntry::Const { .. })
274            | (QueryEntry::Var { .. }, QueryEntry::Const { .. })
275            | (QueryEntry::Const { .. }, QueryEntry::Var { .. }) => {}
276        }
277        Ok(())
278    }
279
280    fn assert_has_ty(
281        &self,
282        entry: &QueryEntry,
283        ty: ColumnTy,
284    ) -> std::result::Result<(), RuleBuilderError> {
285        if let QueryEntry::Var(Variable { id: v, .. }) = entry {
286            let var_ty = self.query.vars[*v].ty;
287            if var_ty != ty {
288                return Err(RuleBuilderError::TypeMismatch {
289                    expected: var_ty,
290                    got: ty,
291                });
292            }
293        }
294        Ok(())
295    }
296
297    /// Register the given rule with the egraph.
298    pub fn build(self) -> RuleId {
299        assert!(
300            !self.egraph.tracing,
301            "proofs are enabled: use `build_with_syntax` instead"
302        );
303        self.build_internal(None)
304    }
305
306    pub fn build_with_syntax(self, syntax: SourceSyntax) -> RuleId {
307        self.build_internal(Some(syntax))
308    }
309
310    pub(crate) fn build_internal(mut self, syntax: Option<SourceSyntax>) -> RuleId {
311        if self.query.atoms.len() == 1 {
312            self.query.plan_strategy = PlanStrategy::MinCover;
313        }
314        if let Some(syntax) = &syntax {
315            if self.egraph.tracing {
316                let cb = self
317                    .proof_builder
318                    .create_reason(syntax.clone(), self.egraph);
319                self.query.build_reason = Some(Box::new(move |bndgs, rb| {
320                    let reason = cb(bndgs, rb)?;
321                    bndgs.lhs_reason = Some(reason.into());
322                    Ok(())
323                }));
324            }
325        }
326        let res = self.query.rule_id;
327        let info = RuleInfo {
328            last_run_at: Timestamp::new(0),
329            query: self.query,
330            cached_plan: None,
331            desc: self.proof_builder.rule_description,
332        };
333        debug!("created rule {res:?} / {}", info.desc);
334        self.egraph.rules.insert(res, info);
335        res
336    }
337
338    pub(crate) fn set_focus(&mut self, focus: usize) {
339        self.query.sole_focus = Some(focus);
340    }
341
342    /// Bind a new variable of the given type in the query.
343    pub fn new_var(&mut self, ty: ColumnTy) -> Variable {
344        let res = self.query.vars.next_id();
345        let var = Variable {
346            id: res,
347            name: None,
348        };
349        self.query.vars.push(VarInfo {
350            ty,
351            name: None,
352            term_var: var.clone(),
353        });
354        var
355    }
356
357    /// Bind a new variable of the given type in the query.
358    ///
359    /// This method attaches the given name to the [`QueryEntry`], which can
360    /// make debugging easier.
361    pub fn new_var_named(&mut self, ty: ColumnTy, name: &str) -> QueryEntry {
362        let id = self.query.vars.next_id();
363        let var = Variable {
364            id,
365            name: Some(name.into()),
366        };
367        self.query.vars.push(VarInfo {
368            ty,
369            name: Some(name.into()),
370            term_var: var.clone(),
371        });
372        QueryEntry::Var(var)
373    }
374
375    /// A low-level way to add an atom to a query.
376    ///
377    /// The atom is added directly to `table`. If `func` is supplied, then metadata about the
378    /// function is used for schema validation and proof generation. If `subsume_entry` is
379    /// supplied and the supplied function is enabled for subsumption, then the given
380    /// [`QueryEntry`] is used to populated the subsumption column for the table. This allows
381    /// higher-level routines to constrain the subsumption column or use it for other purposes.
382    pub(crate) fn add_atom_with_timestamp_and_func(
383        &mut self,
384        table: TableId,
385        func: Option<FunctionId>,
386        subsume_entry: Option<QueryEntry>,
387        entries: &[QueryEntry],
388    ) -> AtomId {
389        let mut atom = entries.to_vec();
390        let schema_math = if let Some(func) = func {
391            let info = &self.egraph.funcs[func];
392            assert_eq!(info.schema.len(), entries.len());
393            SchemaMath {
394                tracing: self.egraph.tracing,
395                subsume: info.can_subsume,
396                func_cols: info.schema.len(),
397            }
398        } else {
399            SchemaMath {
400                tracing: self.egraph.tracing,
401                subsume: subsume_entry.is_some(),
402                func_cols: entries.len(),
403            }
404        };
405        schema_math.write_table_row(
406            &mut atom,
407            RowVals {
408                timestamp: self.new_var(ColumnTy::Id).into(),
409                proof: self
410                    .egraph
411                    .tracing
412                    .then(|| self.new_var(ColumnTy::Id).into()),
413                subsume: if schema_math.subsume {
414                    Some(subsume_entry.unwrap_or_else(|| self.new_var(ColumnTy::Id).into()))
415                } else {
416                    None
417                },
418                ret_val: None,
419            },
420        );
421        let res = AtomId::from_usize(self.query.atoms.len());
422        if self.egraph.tracing {
423            let proof_var = atom[schema_math.proof_id_col()].var();
424            self.proof_builder
425                .term_vars
426                .insert(res, proof_var.clone().into());
427            if let Some(QueryEntry::Var(Variable { id, .. })) = entries.last() {
428                if table != self.egraph.uf_table {
429                    // Don't overwrite "term_var" for uf_table; it stores
430                    // reasons inline / doesn't have terms.
431                    self.query.vars[*id].term_var = proof_var.clone();
432                }
433            }
434            self.query.atom_proofs.push(proof_var);
435        }
436        self.query.atoms.push((table, atom, schema_math));
437        res
438    }
439
440    pub fn call_external_func(
441        &mut self,
442        func: ExternalFunctionId,
443        args: &[QueryEntry],
444        ret_ty: ColumnTy,
445        panic_msg: impl FnOnce() -> String + 'static + Send,
446    ) -> Variable {
447        let args = args.to_vec();
448        let res = self.new_var(ret_ty);
449        // External functions that fail on the RHS of a rule should cause a panic.
450        let panic_fn = self.egraph.new_panic_lazy(panic_msg);
451        self.query.add_rule.push(Box::new(move |inner, rb| {
452            let args = inner.convert_all(&args);
453            let var = rb.call_external_with_fallback(func, &args, panic_fn, &[])?;
454            inner.mapping.insert(res.id, var.into());
455            Ok(())
456        }));
457        res
458    }
459
460    /// Add the given table atom to query. As elsewhere in the crate, the last
461    /// argument is the "return value" of the function. Can also optionally
462    /// check the subsumption bit.
463    pub fn query_table(
464        &mut self,
465        func: FunctionId,
466        entries: &[QueryEntry],
467        is_subsumed: Option<bool>,
468    ) -> Result<AtomId> {
469        let info = &self.egraph.funcs[func];
470        let schema = &info.schema;
471        if schema.len() != entries.len() {
472            return Err(anyhow::Error::from(RuleBuilderError::ArityMismatch {
473                expected: schema.len(),
474                got: entries.len(),
475            }))
476            .with_context(|| format!("query_table: mismatch between {entries:?} and {schema:?}"));
477        }
478        entries
479            .iter()
480            .zip(schema.iter())
481            .try_for_each(|(entry, ty)| {
482                self.assert_has_ty(entry, *ty)
483                    .with_context(|| format!("query_table: mismatch between {entry:?} and {ty:?}"))
484            })?;
485        Ok(self.add_atom_with_timestamp_and_func(
486            info.table,
487            Some(func),
488            is_subsumed.map(|b| QueryEntry::Const {
489                val: match b {
490                    true => SUBSUMED,
491                    false => NOT_SUBSUMED,
492                },
493                ty: ColumnTy::Id,
494            }),
495            entries,
496        ))
497    }
498
499    /// Add the given primitive atom to query. As elsewhere in the crate, the last
500    /// argument is the "return value" of the function.
501    pub fn query_prim(
502        &mut self,
503        func: ExternalFunctionId,
504        entries: &[QueryEntry],
505        // NB: not clear if we still need this now that proof checker is in a separate crate.
506        _ret_ty: ColumnTy,
507    ) -> Result<()> {
508        let entries = entries.to_vec();
509        self.query.add_rule.push(Box::new(move |inner, rb| {
510            let mut dst_vars = inner.convert_all(&entries);
511            let expected = dst_vars.pop().expect("must specify a return value");
512            let var = rb.call_external(func, &dst_vars)?;
513            match entries.last().unwrap() {
514                QueryEntry::Var(Variable { id, .. }) if !inner.grounded.contains(id) => {
515                    inner.mapping.insert(*id, var.into());
516                    inner.grounded.insert(*id);
517                }
518                _ => rb.assert_eq(var.into(), expected),
519            }
520            Ok(())
521        }));
522        Ok(())
523    }
524
525    /// Subsume the given entry in `func`.
526    ///
527    /// `entries` should match the number of keys to the function.
528    pub fn subsume(&mut self, func: FunctionId, entries: &[QueryEntry]) {
529        // First, insert a subsumed value if the tuple is new.
530        let ret = self.lookup_with_subsumed(
531            func,
532            entries,
533            QueryEntry::Const {
534                val: SUBSUMED,
535                ty: ColumnTy::Id,
536            },
537            || "subsumed a nonextestent row!".to_string(),
538        );
539        let info = &self.egraph.funcs[func];
540        let schema_math = SchemaMath {
541            tracing: self.egraph.tracing,
542            subsume: info.can_subsume,
543            func_cols: info.schema.len(),
544        };
545        assert!(info.can_subsume);
546        assert_eq!(entries.len() + 1, info.schema.len());
547        let entries = entries.to_vec();
548        let table = info.table;
549
550        let ret: QueryEntry = ret.into();
551        self.add_callback(move |inner, rb| {
552            // Then, add a tuple subsuming the entry, but only if the entry isn't already subsumed.
553            // Look up the current subsume value.
554            let mut dst_entries = inner.convert_all(&entries);
555            let cur_subsume_val = rb.lookup(
556                table,
557                &dst_entries,
558                ColumnId::from_usize(schema_math.subsume_col()),
559            )?;
560            let cur_proof_val = if schema_math.tracing {
561                Some(DstVar::from(rb.lookup(
562                    table,
563                    &dst_entries,
564                    ColumnId::from_usize(schema_math.proof_id_col()),
565                )?))
566            } else {
567                None
568            };
569            schema_math.write_table_row(
570                &mut dst_entries,
571                RowVals {
572                    timestamp: inner.next_ts(),
573                    proof: cur_proof_val,
574                    subsume: Some(SUBSUMED.into()),
575                    ret_val: Some(inner.convert(&ret)),
576                },
577            );
578            rb.insert_if_eq(
579                table,
580                cur_subsume_val.into(),
581                NOT_SUBSUMED.into(),
582                &dst_entries,
583            )?;
584            Ok(())
585        });
586    }
587
588    pub(crate) fn lookup_with_subsumed(
589        &mut self,
590        func: FunctionId,
591        entries: &[QueryEntry],
592        subsumed: QueryEntry,
593        panic_msg: impl FnOnce() -> String + Send + 'static,
594    ) -> Variable {
595        let entries = entries.to_vec();
596        let info = &self.egraph.funcs[func];
597        let res = self
598            .query
599            .vars
600            .push(VarInfo {
601                ty: info.ret_ty(),
602                name: None,
603                term_var: self.query.vars.next_id().to_var(),
604            })
605            .to_var();
606        let table = info.table;
607        let id_counter = self.query.id_counter;
608        let schema_math = SchemaMath {
609            tracing: self.egraph.tracing,
610            subsume: info.can_subsume,
611            func_cols: info.schema.len(),
612        };
613        let cb: BuildRuleCallback = match info.default_val {
614            DefaultVal::Const(_) | DefaultVal::FreshId => {
615                let (wv, wv_ref): (WriteVal, WriteVal) = match &info.default_val {
616                    DefaultVal::Const(c) => ((*c).into(), (*c).into()),
617                    DefaultVal::FreshId => (
618                        WriteVal::IncCounter(id_counter),
619                        // When we create a new term, we should
620                        // simply "reuse" the value we just minted
621                        // for the value.
622                        WriteVal::CurrentVal(schema_math.ret_val_col()),
623                    ),
624                    _ => unreachable!(),
625                };
626                let get_write_vals = move |inner: &mut Bindings| {
627                    let mut write_vals = SmallVec::<[WriteVal; 4]>::new();
628                    for i in schema_math.num_keys()..schema_math.table_columns() {
629                        if i == schema_math.ts_col() {
630                            write_vals.push(inner.next_ts().into());
631                        } else if i == schema_math.ret_val_col() {
632                            write_vals.push(wv);
633                        } else if schema_math.tracing && i == schema_math.proof_id_col() {
634                            write_vals.push(wv_ref);
635                        } else if schema_math.subsume && i == schema_math.subsume_col() {
636                            write_vals.push(inner.convert(&subsumed).into())
637                        } else {
638                            unreachable!()
639                        }
640                    }
641                    write_vals
642                };
643
644                if self.egraph.tracing {
645                    let term_var = self.new_var(ColumnTy::Id);
646                    self.query.vars[res.id].term_var = term_var.clone();
647                    let ts_var = self.new_var(ColumnTy::Id);
648                    let mut insert_entries = entries.to_vec();
649                    insert_entries.push(res.clone().into());
650                    let add_proof =
651                        self.proof_builder
652                            .new_row(func, insert_entries, term_var.id, self.egraph);
653                    Box::new(move |inner, rb| {
654                        let write_vals = get_write_vals(inner);
655                        let dst_vars = inner.convert_all(&entries);
656                        // NB: having one `lookup_or_insert` call
657                        // per projection is pretty inefficient
658                        // here, but merging these into a custom
659                        // instruction didn't move the needle on a
660                        // write-heavy benchmark when I tried it
661                        // early on. May be worth revisiting after
662                        // more low-hanging fruit has been
663                        // optimized.
664                        let var = rb.lookup_or_insert(
665                            table,
666                            &dst_vars,
667                            &write_vals,
668                            ColumnId::from_usize(schema_math.ret_val_col()),
669                        )?;
670                        let ts = rb.lookup_or_insert(
671                            table,
672                            &dst_vars,
673                            &write_vals,
674                            ColumnId::from_usize(schema_math.ts_col()),
675                        )?;
676                        let term = rb.lookup_or_insert(
677                            table,
678                            &dst_vars,
679                            &write_vals,
680                            ColumnId::from_usize(schema_math.proof_id_col()),
681                        )?;
682                        inner.mapping.insert(term_var.id, term.into());
683                        inner.mapping.insert(res.id, var.into());
684                        inner.mapping.insert(ts_var.id, ts.into());
685                        rb.assert_eq(var.into(), term.into());
686                        // The following bookeeping is only needed
687                        // if the value is new. That only happens if
688                        // the main id equals the term id.
689                        add_proof(inner, rb)?;
690                        Ok(())
691                    })
692                } else {
693                    Box::new(move |inner, rb| {
694                        let write_vals = get_write_vals(inner);
695                        let dst_vars = inner.convert_all(&entries);
696                        let var = rb.lookup_or_insert(
697                            table,
698                            &dst_vars,
699                            &write_vals,
700                            ColumnId::from_usize(schema_math.ret_val_col()),
701                        )?;
702                        inner.mapping.insert(res.id, var.into());
703                        Ok(())
704                    })
705                }
706            }
707            DefaultVal::Fail => {
708                let panic_func = self.egraph.new_panic_lazy(panic_msg);
709                if self.egraph.tracing {
710                    let term_var = self.new_var(ColumnTy::Id);
711                    Box::new(move |inner, rb| {
712                        let dst_vars = inner.convert_all(&entries);
713                        let var = rb.lookup_with_fallback(
714                            table,
715                            &dst_vars,
716                            ColumnId::from_usize(schema_math.ret_val_col()),
717                            panic_func,
718                            &[],
719                        )?;
720                        let term = rb.lookup(
721                            table,
722                            &dst_vars,
723                            ColumnId::from_usize(schema_math.proof_id_col()),
724                        )?;
725                        inner.mapping.insert(res.id, var.into());
726                        inner.mapping.insert(term_var.id, term.into());
727                        Ok(())
728                    })
729                } else {
730                    Box::new(move |inner, rb| {
731                        let dst_vars = inner.convert_all(&entries);
732                        let var = rb.lookup_with_fallback(
733                            table,
734                            &dst_vars,
735                            ColumnId::from_usize(schema_math.ret_val_col()),
736                            panic_func,
737                            &[],
738                        )?;
739                        inner.mapping.insert(res.id, var.into());
740                        Ok(())
741                    })
742                }
743            }
744        };
745        self.query.add_rule.push(cb);
746        res
747    }
748
749    /// Look up the value of a function in the database. If the value is not
750    /// present, the configured default for the function is used.
751    ///
752    /// For functions configured with [`DefaultVal::Fail`], failing lookups will use `panic_msg` in
753    /// the panic output.
754    pub fn lookup(
755        &mut self,
756        func: FunctionId,
757        entries: &[QueryEntry],
758        panic_msg: impl FnOnce() -> String + Send + 'static,
759    ) -> Variable {
760        self.lookup_with_subsumed(
761            func,
762            entries,
763            QueryEntry::Const {
764                val: NOT_SUBSUMED,
765                ty: ColumnTy::Id,
766            },
767            panic_msg,
768        )
769    }
770
771    /// Merge the two values in the union-find.
772    pub fn union(&mut self, mut l: QueryEntry, mut r: QueryEntry) {
773        let cb: BuildRuleCallback = if self.query.tracing {
774            // Union proofs should reflect term-level variables rather than the
775            // current leader of the e-class.
776            for entry in [&mut l, &mut r] {
777                if let QueryEntry::Var(v) = entry {
778                    *v = self.query.vars[v.id].term_var.clone();
779                }
780            }
781            Box::new(move |inner, rb| {
782                let l = inner.convert(&l);
783                let r = inner.convert(&r);
784                let proof = inner.lhs_reason.expect("reason must be set");
785                rb.insert(inner.uf_table, &[l, r, inner.next_ts(), proof])
786                    .context("union")
787            })
788        } else {
789            Box::new(move |inner, rb| {
790                let l = inner.convert(&l);
791                let r = inner.convert(&r);
792                rb.insert(inner.uf_table, &[l, r, inner.next_ts()])
793                    .context("union")
794            })
795        };
796        self.query.add_rule.push(cb);
797    }
798
799    /// This method is equivalent to `remove(table, before); set(table, after)`
800    /// when tracing/proofs aren't enabled. When proofs are enabled, it
801    /// creates a proof term specialized for equality.
802    ///
803    /// This allows us to reconstruct proofs lazily from the UF, rather than
804    /// running the proof generation algorithm eagerly as we query the table.
805    /// Proof generation is a relatively expensive operation, and we'd prefer to
806    /// avoid doing it on every union-find lookup.
807    pub(crate) fn rebuild_row(
808        &mut self,
809        func: FunctionId,
810        before: &[QueryEntry],
811        after: &[QueryEntry],
812        // If subsumption is enabled for this function, we can optionally propagate it to the next
813        // row.
814        subsume_var: Option<Variable>,
815    ) {
816        assert_eq!(before.len(), after.len());
817        self.remove(func, &before[..before.len() - 1]);
818        if !self.egraph.tracing {
819            if let Some(subsume_var) = subsume_var {
820                self.set_with_subsume(func, after, QueryEntry::Var(subsume_var));
821            } else {
822                self.set(func, after);
823            }
824            return;
825        }
826        let table = self.egraph.funcs[func].table;
827        let term_var = self.new_var(ColumnTy::Id);
828        let reason_var = self.new_var(ColumnTy::Id);
829        let before_id = before.last().unwrap().var();
830        let before_term = &self.query.vars[before_id.id].term_var;
831
832        let before_term_id = before_term.id;
833        let term_var_id = term_var.id;
834        let reason_var_id = reason_var.id;
835
836        debug_assert_ne!(before_term, &before_id);
837        let add_proof = self.proof_builder.rebuild_proof(
838            func,
839            after,
840            RebuildVars {
841                before_term: before_term.clone(),
842                new_term: term_var,
843                reason: reason_var,
844            },
845            self.egraph,
846        );
847        let after = SmallVec::<[_; 4]>::from_iter(after.iter().cloned());
848        let uf_table = self.query.uf_table;
849        let info = &self.egraph.funcs[func];
850        let schema_math = SchemaMath {
851            tracing: self.egraph.tracing,
852            subsume: info.can_subsume,
853            func_cols: info.schema.len(),
854        };
855
856        self.query.add_rule.push(Box::new(move |inner, rb| {
857            add_proof(inner, rb)?;
858            let mut dst_vars = inner.convert_all(&after);
859            schema_math.write_table_row(
860                &mut dst_vars,
861                RowVals {
862                    timestamp: inner.next_ts(),
863                    proof: Some(inner.mapping[term_var_id]),
864                    subsume: subsume_var.clone().map(|v| inner.mapping[v.id]),
865                    ret_val: None, // already filled in,
866                },
867            );
868            // This congruence rule will also serve as a proof that the old and
869            // new terms are equal.
870            rb.insert(
871                uf_table,
872                &[
873                    inner.mapping[before_term_id],
874                    inner.mapping[term_var_id],
875                    inner.next_ts(),
876                    inner.mapping[reason_var_id],
877                ],
878            )
879            .context("rebuild_row_uf")?;
880            rb.insert(table, &dst_vars).context("rebuild_row_table")
881        }));
882    }
883
884    /// Set the value of a function in the database.
885    pub fn set(&mut self, func: FunctionId, entries: &[QueryEntry]) {
886        self.set_with_subsume(
887            func,
888            entries,
889            QueryEntry::Const {
890                val: NOT_SUBSUMED,
891                ty: ColumnTy::Id,
892            },
893        );
894    }
895
896    pub(crate) fn set_with_subsume(
897        &mut self,
898        func: FunctionId,
899        entries: &[QueryEntry],
900        subsume_entry: QueryEntry,
901    ) {
902        let info = &self.egraph.funcs[func];
903        let table = info.table;
904        let entries = entries.to_vec();
905        let schema_math = SchemaMath {
906            tracing: self.egraph.tracing,
907            subsume: info.can_subsume,
908            func_cols: info.schema.len(),
909        };
910        if self.egraph.tracing {
911            let res = self.lookup(func, &entries[0..entries.len() - 1], || {
912                "lookup failed during proof-enabled set; this is an internal proofs bug".to_string()
913            });
914            let res_entry = res.clone().into();
915            self.union(res.into(), entries.last().unwrap().clone());
916            if schema_math.subsume {
917                // Set the original row but with the passed-in subsumption value.
918                self.add_callback(move |inner, rb| {
919                    let mut dst_vars = inner.convert_all(&entries);
920                    let proof_var = rb.lookup(
921                        table,
922                        &dst_vars[0..schema_math.num_keys()],
923                        ColumnId::from_usize(schema_math.proof_id_col()),
924                    )?;
925                    schema_math.write_table_row(
926                        &mut dst_vars,
927                        RowVals {
928                            timestamp: inner.next_ts(),
929                            proof: Some(proof_var.into()),
930                            subsume: Some(inner.convert(&subsume_entry)),
931                            ret_val: Some(inner.convert(&res_entry)),
932                        },
933                    );
934                    rb.insert(table, &dst_vars).context("set")
935                });
936            }
937        } else {
938            self.query.add_rule.push(Box::new(move |inner, rb| {
939                let mut dst_vars = inner.convert_all(&entries);
940                schema_math.write_table_row(
941                    &mut dst_vars,
942                    RowVals {
943                        timestamp: inner.next_ts(),
944                        proof: None, // tracing is off
945                        subsume: schema_math.subsume.then(|| inner.convert(&subsume_entry)),
946                        ret_val: None, // already filled in
947                    },
948                );
949                rb.insert(table, &dst_vars).context("set")
950            }));
951        };
952    }
953
954    /// Remove the value of a function from the database.
955    pub fn remove(&mut self, table: FunctionId, entries: &[QueryEntry]) {
956        let table = self.egraph.funcs[table].table;
957        let entries = entries.to_vec();
958        let cb: BuildRuleCallback = Box::new(move |inner, rb| {
959            let dst_vars = inner.convert_all(&entries);
960            rb.remove(table, &dst_vars).context("remove")
961        });
962        self.query.add_rule.push(cb);
963    }
964
965    /// Panic with a given message.
966    pub fn panic(&mut self, message: String) {
967        let panic = self.egraph.new_panic(message.clone());
968        let ret_ty = ColumnTy::Id;
969        let res = self.new_var(ret_ty);
970        self.query.add_rule.push(Box::new(move |inner, rb| {
971            let var = rb.call_external(panic, &[])?;
972            inner.mapping.insert(res.id, var.into());
973            Ok(())
974        }));
975    }
976}
977
978impl Query {
979    fn query_state<'a, 'outer>(
980        &self,
981        rsb: &'a mut RuleSetBuilder<'outer>,
982    ) -> (QueryBuilder<'outer, 'a>, Bindings) {
983        let mut qb = rsb.new_rule();
984        qb.set_plan_strategy(self.plan_strategy);
985        let mut inner = Bindings {
986            uf_table: self.uf_table,
987            next_ts: None,
988            lhs_reason: None,
989            mapping: Default::default(),
990            grounded: Default::default(),
991        };
992        for (var, info) in self.vars.iter() {
993            let new_var = match info.name.as_ref() {
994                Some(name) => qb.new_var_named(name),
995                None => qb.new_var(),
996            };
997            inner.mapping.insert(var, DstVar::Var(new_var));
998        }
999        (qb, inner)
1000    }
1001
1002    fn run_rules_and_build(
1003        &self,
1004        qb: QueryBuilder,
1005        mut inner: Bindings,
1006        desc: &str,
1007    ) -> Result<core_relations::RuleId> {
1008        let mut rb = qb.build();
1009        inner.next_ts = Some(rb.read_counter(self.ts_counter).into());
1010        // Set up proof state if it's configured.
1011        if let Some(build_reason) = &self.build_reason {
1012            build_reason(&mut inner, &mut rb)?;
1013        }
1014        self.add_rule
1015            .iter()
1016            .try_for_each(|f| f(&mut inner, &mut rb))?;
1017        Ok(rb.build_with_description(desc))
1018    }
1019
1020    pub(crate) fn build_cached_plan(
1021        &self,
1022        db: &mut core_relations::Database,
1023        desc: &str,
1024    ) -> Result<CachedPlanInfo> {
1025        let mut rsb = RuleSetBuilder::new(db);
1026        let (mut qb, mut inner) = self.query_state(&mut rsb);
1027        let mut atom_mapping = Vec::with_capacity(self.atoms.len());
1028        for (table, entries, _schema_info) in &self.atoms {
1029            atom_mapping.push(add_atom(&mut qb, *table, entries, &[], &mut inner)?);
1030        }
1031        let rule_id = self.run_rules_and_build(qb, inner, desc)?;
1032        let rs = rsb.build();
1033        let plan = Arc::new(rs.build_cached_plan(rule_id));
1034        Ok(CachedPlanInfo { plan, atom_mapping })
1035    }
1036
1037    /// Add rules to the [`RuleSetBuilder`] for the query specified by the [`CachedPlanInfo`].
1038    ///
1039    /// A [`CachedPlanInfo`] is a compiled RHS and partial LHS for an egglog rules. In order to
1040    /// implement seminaive evaluation, we run several variants of this cached plan with different
1041    /// constraints on the timestamps for different atoms. This rule handles building these
1042    /// variants of the base plan and adding them to `rsb`.
1043    pub(crate) fn add_rules_from_cached(
1044        &self,
1045        rsb: &mut RuleSetBuilder,
1046        mid_ts: Timestamp,
1047        cached_plan: &CachedPlanInfo,
1048    ) -> Result<()> {
1049        // For N atoms, we create N queries for seminaive evaluation. We can reuse the cached plan
1050        // directly.
1051        if !self.seminaive || (self.atoms.is_empty() && mid_ts == Timestamp::new(0)) {
1052            rsb.add_rule_from_cached_plan(&cached_plan.plan, &[]);
1053            return Ok(());
1054        }
1055        if let Some(focus_atom) = self.sole_focus {
1056            // There is a single "focus" atom that we will constrain to look at new values.
1057            let (_, _, schema_info) = &self.atoms[focus_atom];
1058            let ts_col = ColumnId::from_usize(schema_info.ts_col());
1059            rsb.add_rule_from_cached_plan(
1060                &cached_plan.plan,
1061                &[(
1062                    cached_plan.atom_mapping[focus_atom],
1063                    Constraint::GeConst {
1064                        col: ts_col,
1065                        val: mid_ts.to_value(),
1066                    },
1067                )],
1068            );
1069            return Ok(());
1070        }
1071        // Use the cached plan atoms.len() times with different constraints on each atom.
1072        let mut constraints: Vec<(core_relations::AtomId, Constraint)> =
1073            Vec::with_capacity(self.atoms.len());
1074        'outer: for focus_atom in 0..self.atoms.len() {
1075            for (i, (_, _, schema_info)) in self.atoms.iter().enumerate() {
1076                let ts_col = ColumnId::from_usize(schema_info.ts_col());
1077                match i.cmp(&focus_atom) {
1078                    Ordering::Less => {
1079                        if mid_ts == Timestamp::new(0) {
1080                            continue 'outer;
1081                        }
1082                        constraints.push((
1083                            cached_plan.atom_mapping[i],
1084                            Constraint::LtConst {
1085                                col: ts_col,
1086                                val: mid_ts.to_value(),
1087                            },
1088                        ));
1089                    }
1090                    Ordering::Equal => constraints.push((
1091                        cached_plan.atom_mapping[i],
1092                        Constraint::GeConst {
1093                            col: ts_col,
1094                            val: mid_ts.to_value(),
1095                        },
1096                    )),
1097                    Ordering::Greater => {}
1098                };
1099            }
1100            rsb.add_rule_from_cached_plan(&cached_plan.plan, &constraints);
1101            constraints.clear();
1102        }
1103        Ok(())
1104    }
1105}
1106
1107/// State that is used during query execution to translate variabes in egglog
1108/// rules into variables for core-relations rules.
1109pub(crate) struct Bindings {
1110    uf_table: TableId,
1111    next_ts: Option<DstVar>,
1112    /// If proofs are enabled, this variable contains the "reason id" for any union or insertion
1113    /// that happens on the RHS of a rule.
1114    pub(crate) lhs_reason: Option<DstVar>,
1115    pub(crate) mapping: DenseIdMap<VariableId, DstVar>,
1116    grounded: HashSet<VariableId>,
1117}
1118
1119impl Bindings {
1120    pub(crate) fn next_ts(&self) -> DstVar {
1121        self.next_ts
1122            .expect("ts_var should only be used in RHS of the rule")
1123    }
1124    pub(crate) fn convert(&self, entry: &QueryEntry) -> DstVar {
1125        match entry {
1126            QueryEntry::Var(Variable { id: v, .. }) => self.mapping[*v],
1127            QueryEntry::Const { val, .. } => DstVar::Const(*val),
1128        }
1129    }
1130    pub(crate) fn convert_all(&self, entries: &[QueryEntry]) -> SmallVec<[DstVar; 4]> {
1131        entries.iter().map(|e| self.convert(e)).collect()
1132    }
1133}
1134
1135fn add_atom(
1136    qb: &mut QueryBuilder,
1137    table: TableId,
1138    entries: &[QueryEntry],
1139    constraints: &[Constraint],
1140    inner: &mut Bindings,
1141) -> Result<core_relations::AtomId> {
1142    for entry in entries {
1143        if let QueryEntry::Var(Variable { id, .. }) = entry {
1144            inner.grounded.insert(*id);
1145        }
1146    }
1147    let vars = inner.convert_all(entries);
1148    Ok(qb.add_atom(table, &vars, constraints)?)
1149}