Skip to main content

egglog_bridge/
syntax.rs

1//! Egglog proofs reference the source syntax of the query, but the syntax in `egglog-bridge` is a
2//! lower-level, "desugared" representation of that syntax.
3//!
4//! This module defines the [`SourceSyntax`] and [`SourceExpr`] types, which allow callers to
5//! reflect the syntax of the original egglog query, along with how it maps to the desugared query
6//! language in this crate. The proofs machinery then reconstructs proofs according to this syntax.
7use std::{iter, sync::Arc};
8
9use crate::core_relations;
10use crate::core_relations::{
11    ColumnId, CounterId, ExecutionState, ExternalFunctionId, MergeVal, RuleBuilder, TableId, Value,
12    WriteVal, make_external_func,
13};
14use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
15use crate::{EGraph, NOT_SUBSUMED, ProofReason, QueryEntry, ReasonSpecId, Result, SchemaMath};
16use smallvec::SmallVec;
17
18use crate::{
19    ColumnTy, FunctionId, RuleId,
20    proof_spec::ProofBuilder,
21    rule::{AtomId, Bindings, VariableId},
22};
23
// `SyntaxId` is the key type of `SourceSyntax::backing`: each id names one `SourceExpr` that was
// added to a `SourceSyntax`, and is how expressions refer to their sub-expressions.
24define_id!(pub SyntaxId, u32, "an offset into a Syntax DAG.");
25
/// One top-level entry ("root") of a [`SourceSyntax`], expressed in terms of the [`SyntaxId`]s of
/// expressions that were previously added to that syntax.
26#[derive(Debug, Clone)]
27pub enum TopLevelLhsExpr {
28    /// Simply requires the presence of a term matching the given [`SourceExpr`].
29    Exists(SyntaxId),
30    /// Asserts the equality of two expressions matching the given [`SourceExpr`]s.
31    Eq(SyntaxId, SyntaxId),
32}
33
34/// Representative source syntax for _one line_ of an egglog query, namely, the left-hand-side of
35/// an egglog rule.
///
/// Sub-expressions are not nested inline; they are referenced by the [`SyntaxId`] returned from
/// [`SourceSyntax::add_expr`].
36#[derive(Debug, Clone)]
37pub enum SourceExpr {
38    /// A constant.
39    Const { ty: ColumnTy, val: Value },
40    /// A single variable.
41    Var {
        /// The variable being referenced. [`SourceSyntax::add_expr`] records it, together with
        /// `ty`, in the syntax's list of variables.
42        id: VariableId,
43        ty: ColumnTy,
44        name: String,
45    },
46    /// A call to an external (aka primitive) function.
47    ExternalCall {
48        /// This external function call must be present in the destination query, and bound to this
49        /// variable
        ///
        /// Like `Var::id`, this variable is recorded (with `ty`) by [`SourceSyntax::add_expr`].
50        var: VariableId,
51        ty: ColumnTy,
52        func: ExternalFunctionId,
        /// Arguments to the call, as ids of other expressions in the same syntax.
53        args: Vec<SyntaxId>,
54    },
55    /// A query of an egglog-level function (i.e. a table).
56    FunctionCall {
57        /// The egglog function being bound.
58        func: FunctionId,
59        /// The atom in the _destination_ query (i.e. at the egglog-bridge level) to which this
60        /// call corresponds.
61        atom: AtomId,
62        /// Arguments to the function.
63        args: Vec<SyntaxId>,
64    },
65}
66
67/// A data-structure representing an egglog query. Essentially, multiple [`SourceExpr`]s, one per
68/// line, along with a backing store accounting for subterms indexed by [`SyntaxId`].
69#[derive(Debug, Clone, Default)]
70pub struct SourceSyntax {
71    pub(crate) backing: IdVec<SyntaxId, SourceExpr>,
72    pub(crate) vars: Vec<(VariableId, ColumnTy)>,
73    pub(crate) roots: Vec<TopLevelLhsExpr>,
74}
75
76impl SourceSyntax {
77    /// Add `expr` to the known syntax of the [`SourceSyntax`].
78    ///
79    /// The returned [`SyntaxId`] can be used to construct another [`SourceExpr`] or a
80    /// [`TopLevelLhsExpr`].
81    pub fn add_expr(&mut self, expr: SourceExpr) -> SyntaxId {
82        match &expr {
83            SourceExpr::Const { .. } | SourceExpr::FunctionCall { .. } => {}
84            SourceExpr::Var { id, ty, .. } => self.vars.push((*id, *ty)),
85            SourceExpr::ExternalCall { var, ty, .. } => self.vars.push((*var, *ty)),
86        };
87        self.backing.push(expr)
88    }
89
90    /// Add `expr` to the toplevel representation of the syntax.
91    pub fn add_toplevel_expr(&mut self, expr: TopLevelLhsExpr) {
92        self.roots.push(expr);
93    }
94
95    fn funcs(&self) -> impl Iterator<Item = FunctionId> + '_ {
96        self.backing.iter().filter_map(|(_, v)| {
97            if let SourceExpr::FunctionCall { func, .. } = v {
98                Some(*func)
99            } else {
100                None
101            }
102        })
103    }
104}
105
106/// The data associated with a proof of a given term whose premises are given by a
107/// [`SourceSyntax`].
108#[derive(Debug)]
109pub(crate) struct RuleData {
110    pub(crate) rule_id: RuleId,
111    pub(crate) syntax: SourceSyntax,
112}
113
114impl RuleData {
115    pub(crate) fn n_vars(&self) -> usize {
116        self.syntax.vars.len()
117    }
118}
119
120impl ProofBuilder {
121    /// Given a [`SourceSyntax`] build a callback that returns a variable corresponding to the id
122    /// of the "reason" for a given rule. This callback does two things, both based on the context
123    /// of the syntax being passed in:
124    ///
125    /// 1. It reconstructs any terms specified by the syntax. This is done by applying congruence
126    ///    rules to the `AtomId`s mapped in the syntax.
127    ///
128    /// 2. It writes a reason holding the concrete substitution corersponding to the current match
129    ///    for this syntax.
130    ///
131    /// Like most of the rest of this crate, the return value is a callback that consumes state
132    /// associated with instantiating a rule in the `core-relations` sense.
133    pub(crate) fn create_reason(
134        &mut self,
135        syntax: SourceSyntax,
136        egraph: &mut EGraph,
137    ) -> impl Fn(&mut Bindings, &mut RuleBuilder) -> Result<core_relations::Variable> + Clone + use<>
138    {
139        // first, create all the relevant cong metadata
140        let mut metadata = DenseIdMap::default();
141        for func in syntax.funcs() {
142            metadata.insert(func, self.build_cong_metadata(func, egraph));
143        }
144
145        let reason_spec = Arc::new(ProofReason::Rule(RuleData {
146            rule_id: self.rule_id,
147            syntax: syntax.clone(),
148        }));
149        let reason_table = egraph.reason_table(&reason_spec);
150        let reason_spec_id = egraph.proof_specs.push(reason_spec);
151        let reason_counter = egraph.reason_counter;
152        let atom_mapping = self.term_vars.clone();
153        move |bndgs, rb| {
154            // Now, insert all needed reconstructed terms.
155            let mut state = TermReconstructionState {
156                syntax: &syntax,
157                syntax_mapping: Default::default(),
158                metadata: metadata.clone(),
159                atom_mapping: atom_mapping.clone(),
160            };
161            for toplevel_expr in &syntax.roots {
162                match toplevel_expr {
163                    TopLevelLhsExpr::Exists(id) => {
164                        state.justify_query(*id, bndgs, rb)?;
165                    }
166                    TopLevelLhsExpr::Eq(id1, id2) => {
167                        state.justify_query(*id1, bndgs, rb)?;
168                        state.justify_query(*id2, bndgs, rb)?;
169                    }
170                }
171            }
172            // Once those terms are all guaranteed to be in the e-graph, we only need to write down
173            // the base substitution of variables into a reason table.
174            let mut row = SmallVec::<[core_relations::QueryEntry; 8]>::new();
175            row.push(Value::new(reason_spec_id.rep()).into());
176            for (var, _) in &syntax.vars {
177                row.push(bndgs.mapping[*var]);
178            }
179            Ok(rb.lookup_or_insert(
180                reason_table,
181                &row,
182                &[WriteVal::IncCounter(reason_counter)],
183                ColumnId::from_usize(row.len()),
184            )?)
185        }
186    }
187
188    fn build_cong_metadata(&self, func: FunctionId, egraph: &mut EGraph) -> FunctionCongMetadata {
189        let func_info = &egraph.funcs[func];
190        let func_underlying = func_info.table;
191        let schema_math = SchemaMath {
192            subsume: func_info.can_subsume,
193            tracing: true,
194            func_cols: func_info.schema.len(),
195        };
196        let cong_args = CongArgs {
197            func_table: func,
198            func_underlying,
199            schema_math,
200            reason_table: egraph.reason_table(&ProofReason::CongRow),
201            term_table: egraph.term_table(func_underlying),
202            reason_counter: egraph.reason_counter,
203            term_counter: egraph.id_counter,
204            ts_counter: egraph.timestamp_counter,
205            reason_spec_id: egraph.cong_spec,
206        };
207        let build_term =
208            egraph.register_external_func(Box::new(make_external_func(move |es, vals| {
209                cong_term(&cong_args, es, vals)
210            })));
211        FunctionCongMetadata {
212            table: func_underlying,
213            build_term,
214            schema_math,
215        }
216    }
217}
218
219/// Metadata needed to reconstruct a term whose head corresponds to a particular function.
///
/// Produced by `ProofBuilder::build_cong_metadata` and consumed by
/// `TermReconstructionState::justify_query`.
220#[derive(Copy, Clone)]
221struct FunctionCongMetadata {
    /// The underlying `core_relations` table of the function; this is the table that
    /// `justify_query` performs its lookup against.
222    table: TableId,
    /// The external function, registered with the `EGraph`, that forwards to `cong_term`. It is
    /// passed to `lookup_with_fallback` as the fallback.
223    build_term: ExternalFunctionId,
    /// Column-layout information for `table`; used to locate the proof-id column.
224    schema_math: SchemaMath,
225}
226
227struct TermReconstructionState<'a> {
228    /// The syntax of the LHS of a rule that we are reconstructing.
229    ///
230    /// This is an immutable reference to make it easy to borrow across recursive calls.
231    syntax: &'a SourceSyntax,
232    /// A memo cache from syntax node to the [`core_relations::QueryEntry`] that it corresponds to
233    /// in the reconstructed term.
234    syntax_mapping: DenseIdMap<SyntaxId, core_relations::QueryEntry>,
235    /// The [`QueryEntry`] (in `egglog-bridge`, not `core-relations`) to which the given atom
236    /// corresponds.
237    atom_mapping: DenseIdMap<AtomId, QueryEntry>,
238    metadata: DenseIdMap<FunctionId, FunctionCongMetadata>,
239}
240
241impl TermReconstructionState<'_> {
242    fn justify_query(
243        &mut self,
244        node: SyntaxId,
245        bndgs: &mut Bindings,
246        rb: &mut RuleBuilder,
247    ) -> Result<core_relations::QueryEntry> {
248        if let Some(entry) = self.syntax_mapping.get(node) {
249            return Ok(*entry);
250        }
251        let syntax = self.syntax;
252        let res = match &syntax.backing[node] {
253            SourceExpr::Const { val, .. } => return Ok(core_relations::QueryEntry::Const(*val)),
254            SourceExpr::Var { id, .. } => bndgs.mapping[*id],
255            SourceExpr::ExternalCall { var, args, .. } => {
256                for arg in args {
257                    self.justify_query(*arg, bndgs, rb)?;
258                }
259                bndgs.mapping[*var]
260            }
261            SourceExpr::FunctionCall { func, atom, args } => {
262                let old_term = bndgs.convert(&self.atom_mapping[*atom]);
263                let mut buf: Vec<core_relations::QueryEntry> = vec![old_term];
264
265                for arg in args.iter().map(|s| self.justify_query(*s, bndgs, rb)) {
266                    buf.push(arg?);
267                }
268                let FunctionCongMetadata {
269                    table,
270                    build_term,
271                    schema_math,
272                } = &self.metadata[*func];
273                let term_col = ColumnId::from_usize(schema_math.proof_id_col());
274                rb.lookup_with_fallback(*table, &buf[1..], term_col, *build_term, &buf)?
275                    .into()
276            }
277        };
278        self.syntax_mapping.insert(node, res);
279        Ok(res)
280    }
281}
282
283/// Metadata from the EGraph that we copy into an [`core_relations::ExternalFunction`] closure that
284/// recreates terms justified by congruence.
285#[derive(Clone)]
286struct CongArgs {
287    /// The function that we are applying congruence to.
288    func_table: FunctionId,
289    /// The undcerlying `core_relations` table that this function corresponds to.
290    func_underlying: TableId,
291    /// Schema-related offset information needed for writing to the table.
292    schema_math: SchemaMath,
293    /// The table that will hold the reason justifying the new term, if we need to insert one.
294    reason_table: TableId,
295    /// The table that will hold the new term, if we need to insert one.
296    term_table: TableId,
297    /// The counter that will be incremented when we insert a new reason.
298    reason_counter: CounterId,
299    /// The counter that will be incremented when we insert a new term.
300    term_counter: CounterId,
301    /// The counter that will be used to read the current timestamp for the new row.
302    ts_counter: CounterId,
303    /// The specification (or schema) for the reason we are writing (congruence, in this case).
304    reason_spec_id: ReasonSpecId,
305}
306
307fn cong_term(args: &CongArgs, es: &mut ExecutionState, vals: &[Value]) -> Option<Value> {
308    let old_term = vals[0];
309    let new_term = &vals[1..];
310    let reason = es.predict_col(
311        args.reason_table,
312        &[Value::new(args.reason_spec_id.rep()), old_term],
313        iter::once(MergeVal::Counter(args.reason_counter)),
314        ColumnId::new(2),
315    );
316    let mut term_row = SmallVec::<[Value; 8]>::default();
317    term_row.push(Value::new(args.func_table.rep()));
318    term_row.extend_from_slice(new_term);
319    let term_val = es.predict_col(
320        args.term_table,
321        &term_row,
322        [
323            MergeVal::Counter(args.term_counter),
324            MergeVal::Constant(reason),
325        ]
326        .into_iter(),
327        ColumnId::from_usize(term_row.len()),
328    );
329
330    // We should be able to do a raw insert at this point. All conflicting inserts will have the
331    // same term value, and this function only gets called when a lookup fails.
332
333    let ts = Value::from_usize(es.read_counter(args.ts_counter));
334    term_row.resize(args.schema_math.table_columns(), NOT_SUBSUMED);
335    term_row[args.schema_math.ret_val_col()] = term_val;
336    term_row[args.schema_math.proof_id_col()] = term_val;
337    term_row[args.schema_math.ts_col()] = ts;
338    es.stage_insert(args.func_underlying, &term_row);
339    Some(term_val)
340}