egglog_bridge/
syntax.rs

1//! Egglog proofs reference the source syntax of the query, but the syntax in `egglog-bridge` is a
2//! lower-level, "desugared" representation of that syntax.
3//!
4//! This module defines the [`SourceSyntax`] and [`SourceExpr`] types, which allow callers to
5//! reflect the syntax of the original egglog query, along with how it maps to the desugared query
6//! language in this crate. The proofs machinery then reconstructs proofs according to this syntax.
7use std::{iter, sync::Arc};
8
9use crate::core_relations;
10use crate::core_relations::{
11    ColumnId, CounterId, ExecutionState, ExternalFunctionId, MergeVal, RuleBuilder, TableId, Value,
12    WriteVal, make_external_func,
13};
14use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
15use crate::{EGraph, NOT_SUBSUMED, ProofReason, QueryEntry, ReasonSpecId, Result, SchemaMath};
16use smallvec::SmallVec;
17
18use crate::{
19    ColumnTy, FunctionId, RuleId,
20    proof_spec::ProofBuilder,
21    rule::{AtomId, Bindings, Variable},
22};
23
24define_id!(pub SyntaxId, u32, "an offset into a Syntax DAG.");
25
26#[derive(Debug, Clone)]
27pub enum TopLevelLhsExpr {
28    /// Simply requires the presence of a term matching the given [`SourceExpr`].
29    Exists(SyntaxId),
30    /// Asserts the equality of two expressions matching the given [`SourceExpr`]s.
31    Eq(SyntaxId, SyntaxId),
32}
33
34/// Representative source syntax for _one line_ of an egglog query, namely, the left-hand-side of
35/// an egglog rule.
36#[derive(Debug, Clone)]
37pub enum SourceExpr {
38    /// A constant.
39    Const { ty: ColumnTy, val: Value },
40    /// A single variable.
41    Var {
42        id: Variable,
43        ty: ColumnTy,
44        name: String,
45    },
46    /// A call to an external (aka primitive) function.
47    ExternalCall {
48        /// This external function call must be present in the destination query, and bound to this
49        /// variable
50        var: Variable,
51        ty: ColumnTy,
52        func: ExternalFunctionId,
53        args: Vec<SyntaxId>,
54    },
55    /// A query of an egglog-level function (i.e. a table).
56    FunctionCall {
57        /// The egglog function being bound.
58        func: FunctionId,
59        /// The atom in the _destination_ query (i.e. at the egglog-bridge level) to which this
60        /// call corresponds.
61        atom: AtomId,
62        /// Arguments to the function.
63        args: Vec<SyntaxId>,
64    },
65}
66
67/// A data-structure representing an egglog query. Essentially, multiple [`SourceExpr`]s, one per
68/// line, along with a backing store accounting for subterms indexed by [`SyntaxId].
69#[derive(Debug, Clone, Default)]
70pub struct SourceSyntax {
71    pub(crate) backing: IdVec<SyntaxId, SourceExpr>,
72    pub(crate) vars: Vec<(Variable, ColumnTy)>,
73    pub(crate) roots: Vec<TopLevelLhsExpr>,
74}
75
76impl SourceSyntax {
77    /// Add `expr` to the known syntax of the [`SourceSyntax`].
78    ///
79    /// The returned [`SyntaxId`] can be used to construct another [`SourceExpr`] or a
80    /// [`TopLevelLhsExpr`].
81    pub fn add_expr(&mut self, expr: SourceExpr) -> SyntaxId {
82        match &expr {
83            SourceExpr::Const { .. } | SourceExpr::FunctionCall { .. } => {}
84            SourceExpr::Var { id, ty, .. } => self.vars.push((*id, *ty)),
85            SourceExpr::ExternalCall { var, ty, .. } => self.vars.push((*var, *ty)),
86        };
87        self.backing.push(expr)
88    }
89
90    /// Add `expr` to the toplevel representation of the syntax.
91    pub fn add_toplevel_expr(&mut self, expr: TopLevelLhsExpr) {
92        self.roots.push(expr);
93    }
94
95    fn funcs(&self) -> impl Iterator<Item = FunctionId> + '_ {
96        self.backing.iter().filter_map(|(_, v)| {
97            if let SourceExpr::FunctionCall { func, .. } = v {
98                Some(*func)
99            } else {
100                None
101            }
102        })
103    }
104}
105
106/// The data associated with a proof of a given term whose premises are given by a
107/// [`SourceSyntax`].
108#[derive(Debug)]
109pub(crate) struct RuleData {
110    pub(crate) rule_id: RuleId,
111    pub(crate) syntax: SourceSyntax,
112}
113
114impl RuleData {
115    pub(crate) fn n_vars(&self) -> usize {
116        self.syntax.vars.len()
117    }
118}
119
120impl ProofBuilder {
121    /// Given a [`SourceSyntax`] build a callback that returns a variable corresponding to the id
122    /// of the "reason" for a given rule. This callback does two things, both based on the context
123    /// of the syntax being passed in:
124    ///
125    /// 1. It reconstructs any terms specified by the syntax. This is done by applying congruence
126    ///    rules to the `AtomId`s mapped in the syntax.
127    ///
128    /// 2. It writes a reason holding the concrete substitution corersponding to the current match
129    ///    for this syntax.
130    ///
131    /// Like most of the rest of this crate, the return value is a callback that consumes state
132    /// associated with instantiating a rule in the `core-relations` sense.
133    pub(crate) fn create_reason(
134        &mut self,
135        syntax: SourceSyntax,
136        egraph: &mut EGraph,
137    ) -> impl Fn(&mut Bindings, &mut RuleBuilder) -> Result<core_relations::Variable> + Clone + use<>
138    {
139        // first, create all the relevant cong metadata
140        let mut metadata = DenseIdMap::default();
141        for func in syntax.funcs() {
142            metadata.insert(func, self.build_cong_metadata(func, egraph));
143        }
144
145        let reason_spec = Arc::new(ProofReason::Rule(RuleData {
146            rule_id: self.rule_id,
147            syntax: syntax.clone(),
148        }));
149        let reason_table = egraph.reason_table(&reason_spec);
150        let reason_spec_id = egraph.proof_specs.push(reason_spec);
151        let reason_counter = egraph.reason_counter;
152        let atom_mapping = self.term_vars.clone();
153        move |bndgs, rb| {
154            // Now, insert all needed reconstructed terms.
155            let mut state = TermReconstructionState {
156                syntax: &syntax,
157                syntax_mapping: Default::default(),
158                metadata: metadata.clone(),
159                atom_mapping: atom_mapping.clone(),
160            };
161            for toplevel_expr in &syntax.roots {
162                match toplevel_expr {
163                    TopLevelLhsExpr::Exists(id) => {
164                        state.justify_query(*id, bndgs, rb)?;
165                    }
166                    TopLevelLhsExpr::Eq(id1, id2) => {
167                        state.justify_query(*id1, bndgs, rb)?;
168                        state.justify_query(*id2, bndgs, rb)?;
169                    }
170                }
171            }
172            // Once those terms are all guaranteed to be in the e-graph, we only need to write down
173            // the base substitution of variables into a reason table.
174            let mut row = SmallVec::<[core_relations::QueryEntry; 8]>::new();
175            row.push(Value::new(reason_spec_id.rep()).into());
176            for (var, _) in &syntax.vars {
177                row.push(bndgs.mapping[*var]);
178            }
179            Ok(rb.lookup_or_insert(
180                reason_table,
181                &row,
182                &[WriteVal::IncCounter(reason_counter)],
183                ColumnId::from_usize(row.len()),
184            )?)
185        }
186    }
187
188    fn build_cong_metadata(&self, func: FunctionId, egraph: &mut EGraph) -> FunctionCongMetadata {
189        let func_info = &egraph.funcs[func];
190        let func_underlying = func_info.table;
191        let schema_math = SchemaMath {
192            subsume: func_info.can_subsume,
193            tracing: true,
194            func_cols: func_info.schema.len(),
195        };
196        let cong_args = CongArgs {
197            func_table: func,
198            func_underlying,
199            schema_math,
200            reason_table: egraph.reason_table(&ProofReason::CongRow),
201            term_table: egraph.term_table(func_underlying),
202            reason_counter: egraph.reason_counter,
203            term_counter: egraph.id_counter,
204            ts_counter: egraph.timestamp_counter,
205            reason_spec_id: egraph.cong_spec,
206        };
207        let build_term = egraph.register_external_func(make_external_func(move |es, vals| {
208            cong_term(&cong_args, es, vals)
209        }));
210        FunctionCongMetadata {
211            table: func_underlying,
212            build_term,
213            schema_math,
214        }
215    }
216}
217
218/// Metadata needed to reconstruct a term whose head corresponds to a particular function.
219#[derive(Copy, Clone)]
220struct FunctionCongMetadata {
221    table: TableId,
222    build_term: ExternalFunctionId,
223    schema_math: SchemaMath,
224}
225
226struct TermReconstructionState<'a> {
227    /// The syntax of the LHS of a rule that we are reconstructing.
228    ///
229    /// This is an immutable reference to make it easy to borrow across recursive calls.
230    syntax: &'a SourceSyntax,
231    /// A memo cache from syntax node to the [`core_relations::QueryEntry`] that it corresponds to
232    /// in the reconstructed term.
233    syntax_mapping: DenseIdMap<SyntaxId, core_relations::QueryEntry>,
234    /// The [`QueryEntry`] (in `egglog-bridge`, not `core-relations`) to which the given atom
235    /// corresponds.
236    atom_mapping: DenseIdMap<AtomId, QueryEntry>,
237    metadata: DenseIdMap<FunctionId, FunctionCongMetadata>,
238}
239
240impl TermReconstructionState<'_> {
241    fn justify_query(
242        &mut self,
243        node: SyntaxId,
244        bndgs: &mut Bindings,
245        rb: &mut RuleBuilder,
246    ) -> Result<core_relations::QueryEntry> {
247        if let Some(entry) = self.syntax_mapping.get(node) {
248            return Ok(*entry);
249        }
250        let syntax = self.syntax;
251        let res = match &syntax.backing[node] {
252            SourceExpr::Const { val, .. } => return Ok(core_relations::QueryEntry::Const(*val)),
253            SourceExpr::Var { id, .. } => bndgs.mapping[*id],
254            SourceExpr::ExternalCall { var, args, .. } => {
255                for arg in args {
256                    self.justify_query(*arg, bndgs, rb)?;
257                }
258                bndgs.mapping[*var]
259            }
260            SourceExpr::FunctionCall { func, atom, args } => {
261                let old_term = bndgs.convert(&self.atom_mapping[*atom]);
262                let mut buf: Vec<core_relations::QueryEntry> = vec![old_term];
263
264                for arg in args.iter().map(|s| self.justify_query(*s, bndgs, rb)) {
265                    buf.push(arg?);
266                }
267                let FunctionCongMetadata {
268                    table,
269                    build_term,
270                    schema_math,
271                } = &self.metadata[*func];
272                let term_col = ColumnId::from_usize(schema_math.proof_id_col());
273                rb.lookup_with_fallback(*table, &buf[1..], term_col, *build_term, &buf)?
274                    .into()
275            }
276        };
277        self.syntax_mapping.insert(node, res);
278        Ok(res)
279    }
280}
281
282/// Metadata from the EGraph that we copy into an [`core_relations::ExternalFunction`] closure that
283/// recreates terms justified by congruence.
284#[derive(Clone)]
285struct CongArgs {
286    /// The function that we are applying congruence to.
287    func_table: FunctionId,
288    /// The undcerlying `core_relations` table that this function corresponds to.
289    func_underlying: TableId,
290    /// Schema-related offset information needed for writing to the table.
291    schema_math: SchemaMath,
292    /// The table that will hold the reason justifying the new term, if we need to insert one.
293    reason_table: TableId,
294    /// The table that will hold the new term, if we need to insert one.
295    term_table: TableId,
296    /// The counter that will be incremented when we insert a new reason.
297    reason_counter: CounterId,
298    /// The counter that will be incremented when we insert a new term.
299    term_counter: CounterId,
300    /// The counter that will be used to read the current timestamp for the new row.
301    ts_counter: CounterId,
302    /// The specification (or schema) for the reason we are writing (congruence, in this case).
303    reason_spec_id: ReasonSpecId,
304}
305
306fn cong_term(args: &CongArgs, es: &mut ExecutionState, vals: &[Value]) -> Option<Value> {
307    let old_term = vals[0];
308    let new_term = &vals[1..];
309    let reason = es.predict_col(
310        args.reason_table,
311        &[Value::new(args.reason_spec_id.rep()), old_term],
312        iter::once(MergeVal::Counter(args.reason_counter)),
313        ColumnId::new(2),
314    );
315    let mut term_row = SmallVec::<[Value; 8]>::default();
316    term_row.push(Value::new(args.func_table.rep()));
317    term_row.extend_from_slice(new_term);
318    let term_val = es.predict_col(
319        args.term_table,
320        &term_row,
321        [
322            MergeVal::Counter(args.term_counter),
323            MergeVal::Constant(reason),
324        ]
325        .into_iter(),
326        ColumnId::from_usize(term_row.len()),
327    );
328
329    // We should be able to do a raw insert at this point. All conflicting inserts will have the
330    // same term value, and this function only gets called when a lookup fails.
331
332    let ts = Value::from_usize(es.read_counter(args.ts_counter));
333    term_row.resize(args.schema_math.table_columns(), NOT_SUBSUMED);
334    term_row[args.schema_math.ret_val_col()] = term_val;
335    term_row[args.schema_math.proof_id_col()] = term_val;
336    term_row[args.schema_math.ts_col()] = ts;
337    es.stage_insert(args.func_underlying, &term_row);
338    Some(term_val)
339}