1use std::{iter, sync::Arc};
8
9use crate::core_relations;
10use crate::core_relations::{
11    ColumnId, CounterId, ExecutionState, ExternalFunctionId, MergeVal, RuleBuilder, TableId, Value,
12    WriteVal, make_external_func,
13};
14use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
15use crate::{EGraph, NOT_SUBSUMED, ProofReason, QueryEntry, ReasonSpecId, Result, SchemaMath};
16use smallvec::SmallVec;
17
18use crate::{
19    ColumnTy, FunctionId, RuleId,
20    proof_spec::ProofBuilder,
21    rule::{AtomId, Bindings, Variable},
22};
23
// Identifier of a node in a `SourceSyntax` DAG: an index into its `backing` vector
// (`SourceSyntax::add_expr` returns the id produced by pushing onto `backing`).
define_id!(pub SyntaxId, u32, "an offset into a Syntax DAG.");
25
/// A root of a rule's left-hand-side syntax (stored in `SourceSyntax::roots`).
///
/// When the closure built by `ProofBuilder::create_reason` runs, every expression referenced
/// by a root is justified, i.e. its proof term is reconstructed.
#[derive(Debug, Clone)]
pub enum TopLevelLhsExpr {
    /// A single expression to justify.
    Exists(SyntaxId),
    /// Two expressions; both are justified (first, then second). Presumably this encodes an
    /// equality between them in the rule — NOTE(review): confirm against the rule front-end.
    Eq(SyntaxId, SyntaxId),
}
33
/// A node in a rule's source-syntax DAG. Children are referenced by [`SyntaxId`].
#[derive(Debug, Clone)]
pub enum SourceExpr {
    /// A constant `val` of column type `ty`. Reconstruction returns it directly as a
    /// constant query entry.
    Const { ty: ColumnTy, val: Value },
    /// A reference to a rule variable.
    Var {
        /// The variable; adding this node records `(id, ty)` in `SourceSyntax::vars`.
        id: Variable,
        /// The variable's column type.
        ty: ColumnTy,
        /// Not read in this file; presumably the source-level name of the variable.
        name: String,
    },
    /// A call to an external function whose result is bound to a rule variable.
    ExternalCall {
        /// The variable bound to the call's result. Reconstruction returns this variable's
        /// binding, and adding the node records `(var, ty)` in `SourceSyntax::vars`.
        var: Variable,
        /// The result's column type.
        ty: ColumnTy,
        /// The external function being called (not read in this file).
        func: ExternalFunctionId,
        /// Argument expressions; they are justified before the result is used.
        args: Vec<SyntaxId>,
    },
    /// A call to a table-backed function.
    FunctionCall {
        /// The function being called.
        func: FunctionId,
        /// Index into the atom-to-term mapping; the term found there is used as the "old"
        /// term when the call is rebuilt via congruence.
        atom: AtomId,
        /// Argument expressions.
        args: Vec<SyntaxId>,
    },
}
66
/// The left-hand side of a rule as a DAG of [`SourceExpr`] nodes, kept so that proofs for
/// the rule's matches can be reconstructed.
#[derive(Debug, Clone, Default)]
pub struct SourceSyntax {
    /// All nodes, indexed by [`SyntaxId`] in insertion order.
    pub(crate) backing: IdVec<SyntaxId, SourceExpr>,
    /// `(variable, type)` for every `Var` and `ExternalCall` node, in insertion order. There
    /// is one entry per node added, so a variable added through two nodes appears twice.
    /// These are the columns (after the spec id) of the rule's reason row.
    pub(crate) vars: Vec<(Variable, ColumnTy)>,
    /// Top-level expressions justified when the rule's reason is built.
    pub(crate) roots: Vec<TopLevelLhsExpr>,
}
75
76impl SourceSyntax {
77    pub fn add_expr(&mut self, expr: SourceExpr) -> SyntaxId {
82        match &expr {
83            SourceExpr::Const { .. } | SourceExpr::FunctionCall { .. } => {}
84            SourceExpr::Var { id, ty, .. } => self.vars.push((*id, *ty)),
85            SourceExpr::ExternalCall { var, ty, .. } => self.vars.push((*var, *ty)),
86        };
87        self.backing.push(expr)
88    }
89
90    pub fn add_toplevel_expr(&mut self, expr: TopLevelLhsExpr) {
92        self.roots.push(expr);
93    }
94
95    fn funcs(&self) -> impl Iterator<Item = FunctionId> + '_ {
96        self.backing.iter().filter_map(|(_, v)| {
97            if let SourceExpr::FunctionCall { func, .. } = v {
98                Some(*func)
99            } else {
100                None
101            }
102        })
103    }
104}
105
/// Proof metadata for one rule. `ProofBuilder::create_reason` stores it inside a
/// `ProofReason::Rule` spec.
#[derive(Debug)]
pub(crate) struct RuleData {
    /// The rule this data describes.
    pub(crate) rule_id: RuleId,
    /// The rule's left-hand-side syntax; its `vars` determine the reason-row layout.
    pub(crate) syntax: SourceSyntax,
}
113
114impl RuleData {
115    pub(crate) fn n_vars(&self) -> usize {
116        self.syntax.vars.len()
117    }
118}
119
120impl ProofBuilder {
121    pub(crate) fn create_reason(
134        &mut self,
135        syntax: SourceSyntax,
136        egraph: &mut EGraph,
137    ) -> impl Fn(&mut Bindings, &mut RuleBuilder) -> Result<core_relations::Variable> + Clone + use<>
138    {
139        let mut metadata = DenseIdMap::default();
141        for func in syntax.funcs() {
142            metadata.insert(func, self.build_cong_metadata(func, egraph));
143        }
144
145        let reason_spec = Arc::new(ProofReason::Rule(RuleData {
146            rule_id: self.rule_id,
147            syntax: syntax.clone(),
148        }));
149        let reason_table = egraph.reason_table(&reason_spec);
150        let reason_spec_id = egraph.proof_specs.push(reason_spec);
151        let reason_counter = egraph.reason_counter;
152        let atom_mapping = self.term_vars.clone();
153        move |bndgs, rb| {
154            let mut state = TermReconstructionState {
156                syntax: &syntax,
157                syntax_mapping: Default::default(),
158                metadata: metadata.clone(),
159                atom_mapping: atom_mapping.clone(),
160            };
161            for toplevel_expr in &syntax.roots {
162                match toplevel_expr {
163                    TopLevelLhsExpr::Exists(id) => {
164                        state.justify_query(*id, bndgs, rb)?;
165                    }
166                    TopLevelLhsExpr::Eq(id1, id2) => {
167                        state.justify_query(*id1, bndgs, rb)?;
168                        state.justify_query(*id2, bndgs, rb)?;
169                    }
170                }
171            }
172            let mut row = SmallVec::<[core_relations::QueryEntry; 8]>::new();
175            row.push(Value::new(reason_spec_id.rep()).into());
176            for (var, _) in &syntax.vars {
177                row.push(bndgs.mapping[*var]);
178            }
179            Ok(rb.lookup_or_insert(
180                reason_table,
181                &row,
182                &[WriteVal::IncCounter(reason_counter)],
183                ColumnId::from_usize(row.len()),
184            )?)
185        }
186    }
187
188    fn build_cong_metadata(&self, func: FunctionId, egraph: &mut EGraph) -> FunctionCongMetadata {
189        let func_info = &egraph.funcs[func];
190        let func_underlying = func_info.table;
191        let schema_math = SchemaMath {
192            subsume: func_info.can_subsume,
193            tracing: true,
194            func_cols: func_info.schema.len(),
195        };
196        let cong_args = CongArgs {
197            func_table: func,
198            func_underlying,
199            schema_math,
200            reason_table: egraph.reason_table(&ProofReason::CongRow),
201            term_table: egraph.term_table(func_underlying),
202            reason_counter: egraph.reason_counter,
203            term_counter: egraph.id_counter,
204            ts_counter: egraph.timestamp_counter,
205            reason_spec_id: egraph.cong_spec,
206        };
207        let build_term = egraph.register_external_func(make_external_func(move |es, vals| {
208            cong_term(&cong_args, es, vals)
209        }));
210        FunctionCongMetadata {
211            table: func_underlying,
212            build_term,
213            schema_math,
214        }
215    }
216}
217
/// Per-function data used to rebuild a call's proof term by congruence. It is built by
/// `ProofBuilder::build_cong_metadata`.
#[derive(Copy, Clone)]
struct FunctionCongMetadata {
    /// The function's underlying table (looked up when rebuilding a call).
    table: TableId,
    /// The registered external function wrapping `cong_term`, used as the lookup fallback.
    build_term: ExternalFunctionId,
    /// Column layout of `table` (used to locate the proof-id column).
    schema_math: SchemaMath,
}
225
/// Working state for one invocation of the closure returned by
/// `ProofBuilder::create_reason`.
struct TermReconstructionState<'a> {
    /// The rule syntax being reconstructed.
    syntax: &'a SourceSyntax,
    /// Memo of the query entry already produced for each justified non-constant node.
    syntax_mapping: DenseIdMap<SyntaxId, core_relations::QueryEntry>,
    /// For each query atom, an entry for (presumably) the term that atom matched. It is
    /// converted through the current bindings before use.
    atom_mapping: DenseIdMap<AtomId, QueryEntry>,
    /// Congruence metadata for each function called in `syntax`.
    metadata: DenseIdMap<FunctionId, FunctionCongMetadata>,
}
239
240impl TermReconstructionState<'_> {
241    fn justify_query(
242        &mut self,
243        node: SyntaxId,
244        bndgs: &mut Bindings,
245        rb: &mut RuleBuilder,
246    ) -> Result<core_relations::QueryEntry> {
247        if let Some(entry) = self.syntax_mapping.get(node) {
248            return Ok(*entry);
249        }
250        let syntax = self.syntax;
251        let res = match &syntax.backing[node] {
252            SourceExpr::Const { val, .. } => return Ok(core_relations::QueryEntry::Const(*val)),
253            SourceExpr::Var { id, .. } => bndgs.mapping[*id],
254            SourceExpr::ExternalCall { var, args, .. } => {
255                for arg in args {
256                    self.justify_query(*arg, bndgs, rb)?;
257                }
258                bndgs.mapping[*var]
259            }
260            SourceExpr::FunctionCall { func, atom, args } => {
261                let old_term = bndgs.convert(&self.atom_mapping[*atom]);
262                let mut buf: Vec<core_relations::QueryEntry> = vec![old_term];
263
264                for arg in args.iter().map(|s| self.justify_query(*s, bndgs, rb)) {
265                    buf.push(arg?);
266                }
267                let FunctionCongMetadata {
268                    table,
269                    build_term,
270                    schema_math,
271                } = &self.metadata[*func];
272                let term_col = ColumnId::from_usize(schema_math.proof_id_col());
273                rb.lookup_with_fallback(*table, &buf[1..], term_col, *build_term, &buf)?
274                    .into()
275            }
276        };
277        self.syntax_mapping.insert(node, res);
278        Ok(res)
279    }
280}
281
/// Everything `cong_term` needs. It is captured by the external function registered in
/// `ProofBuilder::build_cong_metadata`.
#[derive(Clone)]
struct CongArgs {
    /// The function whose call is being rebuilt; its id is the first column of the term row.
    func_table: FunctionId,
    /// The function's underlying table, into which the rebuilt row is staged.
    func_underlying: TableId,
    /// Column layout of `func_underlying`.
    schema_math: SchemaMath,
    /// Table holding congruence reason rows.
    reason_table: TableId,
    /// Term table used for this function's terms.
    term_table: TableId,
    /// Counter supplied (as `MergeVal::Counter`) for new reason rows.
    reason_counter: CounterId,
    /// Counter supplied (as `MergeVal::Counter`) for new term rows.
    term_counter: CounterId,
    /// Counter read to obtain the timestamp written into the staged row.
    ts_counter: CounterId,
    /// Spec id written as the first column of congruence reason rows.
    reason_spec_id: ReasonSpecId,
}
305
306fn cong_term(args: &CongArgs, es: &mut ExecutionState, vals: &[Value]) -> Option<Value> {
307    let old_term = vals[0];
308    let new_term = &vals[1..];
309    let reason = es.predict_col(
310        args.reason_table,
311        &[Value::new(args.reason_spec_id.rep()), old_term],
312        iter::once(MergeVal::Counter(args.reason_counter)),
313        ColumnId::new(2),
314    );
315    let mut term_row = SmallVec::<[Value; 8]>::default();
316    term_row.push(Value::new(args.func_table.rep()));
317    term_row.extend_from_slice(new_term);
318    let term_val = es.predict_col(
319        args.term_table,
320        &term_row,
321        [
322            MergeVal::Counter(args.term_counter),
323            MergeVal::Constant(reason),
324        ]
325        .into_iter(),
326        ColumnId::from_usize(term_row.len()),
327    );
328
329    let ts = Value::from_usize(es.read_counter(args.ts_counter));
333    term_row.resize(args.schema_math.table_columns(), NOT_SUBSUMED);
334    term_row[args.schema_math.ret_val_col()] = term_val;
335    term_row[args.schema_math.proof_id_col()] = term_val;
336    term_row[args.schema_math.ts_col()] = ts;
337    es.stage_insert(args.func_underlying, &term_row);
338    Some(term_val)
339}