1use std::{iter, sync::Arc};
8
9use crate::core_relations;
10use crate::core_relations::{
11 ColumnId, CounterId, ExecutionState, ExternalFunctionId, MergeVal, RuleBuilder, TableId, Value,
12 WriteVal, make_external_func,
13};
14use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
15use crate::{EGraph, NOT_SUBSUMED, ProofReason, QueryEntry, ReasonSpecId, Result, SchemaMath};
16use smallvec::SmallVec;
17
18use crate::{
19 ColumnTy, FunctionId, RuleId,
20 proof_spec::ProofBuilder,
21 rule::{AtomId, Bindings, Variable},
22};
23
// Index of a node in `SourceSyntax::backing` (an `IdVec<SyntaxId, SourceExpr>`).
// The string argument presumably becomes the generated type's doc text.
define_id!(pub SyntaxId, u32, "an offset into a Syntax DAG.");
25
/// A top-level fact on the left-hand side of a rule, expressed as references
/// to nodes of a [`SourceSyntax`] DAG.
///
/// When a rule's reason is built (see `ProofBuilder::create_reason`), every
/// node mentioned by a root is justified.
#[derive(Debug, Clone)]
pub enum TopLevelLhsExpr {
    /// A single expression on its own (presumably "this term must match").
    /// Only this node is justified.
    Exists(SyntaxId),
    /// An equality between two expressions; both sides are justified.
    Eq(SyntaxId, SyntaxId),
}
33
/// One node of a rule's left-hand-side syntax DAG.
///
/// Child nodes are referenced by [`SyntaxId`], i.e. by their index in the
/// owning [`SourceSyntax`].
#[derive(Debug, Clone)]
pub enum SourceExpr {
    /// A constant `val` of type `ty`. Constants need no justification.
    Const { ty: ColumnTy, val: Value },
    /// A variable bound by the rule's query.
    Var {
        /// The rule variable this node refers to; its binding is the node's value.
        id: Variable,
        /// The variable's column type.
        ty: ColumnTy,
        /// Human-readable name of the variable. Not read in this file;
        /// presumably used when rendering proofs — TODO confirm.
        name: String,
    },
    /// A call to an external function.
    ExternalCall {
        /// The rule variable whose binding is this node's value (the call's result).
        var: Variable,
        /// The type of the call's result.
        ty: ColumnTy,
        /// The external function being called.
        func: ExternalFunctionId,
        /// The argument nodes, in call order.
        args: Vec<SyntaxId>,
    },
    /// A call to a table-backed function that was matched by a query atom.
    FunctionCall {
        /// The function being called.
        func: FunctionId,
        /// The query atom that matched this call; used to recover the term that
        /// the atom matched.
        atom: AtomId,
        /// The argument nodes; once justified they form the lookup key into
        /// the function's table, in this order.
        args: Vec<SyntaxId>,
    },
}
66
/// The syntax DAG describing a rule's left-hand side, used to reconstruct
/// proof terms for the facts the rule matched.
#[derive(Debug, Clone, Default)]
pub struct SourceSyntax {
    /// Every node, indexed by [`SyntaxId`]. A node refers to its children by id,
    /// so children are added (via `add_expr`) before their parents.
    pub(crate) backing: IdVec<SyntaxId, SourceExpr>,
    /// `(variable, type)` for every `Var` and `ExternalCall` node, in insertion
    /// order. This order fixes the column order of the rule's reason row.
    pub(crate) vars: Vec<(Variable, ColumnTy)>,
    /// The top-level facts of the rule's left-hand side.
    pub(crate) roots: Vec<TopLevelLhsExpr>,
}
75
76impl SourceSyntax {
77 pub fn add_expr(&mut self, expr: SourceExpr) -> SyntaxId {
82 match &expr {
83 SourceExpr::Const { .. } | SourceExpr::FunctionCall { .. } => {}
84 SourceExpr::Var { id, ty, .. } => self.vars.push((*id, *ty)),
85 SourceExpr::ExternalCall { var, ty, .. } => self.vars.push((*var, *ty)),
86 };
87 self.backing.push(expr)
88 }
89
90 pub fn add_toplevel_expr(&mut self, expr: TopLevelLhsExpr) {
92 self.roots.push(expr);
93 }
94
95 fn funcs(&self) -> impl Iterator<Item = FunctionId> + '_ {
96 self.backing.iter().filter_map(|(_, v)| {
97 if let SourceExpr::FunctionCall { func, .. } = v {
98 Some(*func)
99 } else {
100 None
101 }
102 })
103 }
104}
105
/// Proof metadata for a single rule: the rule's id and its left-hand-side
/// syntax. Stored as `ProofReason::Rule` when a rule's reason spec is created.
#[derive(Debug)]
pub(crate) struct RuleData {
    /// The rule this data describes.
    pub(crate) rule_id: RuleId,
    /// The rule's left-hand-side syntax DAG.
    pub(crate) syntax: SourceSyntax,
}
113
114impl RuleData {
115 pub(crate) fn n_vars(&self) -> usize {
116 self.syntax.vars.len()
117 }
118}
119
120impl ProofBuilder {
121 pub(crate) fn create_reason(
134 &mut self,
135 syntax: SourceSyntax,
136 egraph: &mut EGraph,
137 ) -> impl Fn(&mut Bindings, &mut RuleBuilder) -> Result<core_relations::Variable> + Clone + use<>
138 {
139 let mut metadata = DenseIdMap::default();
141 for func in syntax.funcs() {
142 metadata.insert(func, self.build_cong_metadata(func, egraph));
143 }
144
145 let reason_spec = Arc::new(ProofReason::Rule(RuleData {
146 rule_id: self.rule_id,
147 syntax: syntax.clone(),
148 }));
149 let reason_table = egraph.reason_table(&reason_spec);
150 let reason_spec_id = egraph.proof_specs.push(reason_spec);
151 let reason_counter = egraph.reason_counter;
152 let atom_mapping = self.term_vars.clone();
153 move |bndgs, rb| {
154 let mut state = TermReconstructionState {
156 syntax: &syntax,
157 syntax_mapping: Default::default(),
158 metadata: metadata.clone(),
159 atom_mapping: atom_mapping.clone(),
160 };
161 for toplevel_expr in &syntax.roots {
162 match toplevel_expr {
163 TopLevelLhsExpr::Exists(id) => {
164 state.justify_query(*id, bndgs, rb)?;
165 }
166 TopLevelLhsExpr::Eq(id1, id2) => {
167 state.justify_query(*id1, bndgs, rb)?;
168 state.justify_query(*id2, bndgs, rb)?;
169 }
170 }
171 }
172 let mut row = SmallVec::<[core_relations::QueryEntry; 8]>::new();
175 row.push(Value::new(reason_spec_id.rep()).into());
176 for (var, _) in &syntax.vars {
177 row.push(bndgs.mapping[*var]);
178 }
179 Ok(rb.lookup_or_insert(
180 reason_table,
181 &row,
182 &[WriteVal::IncCounter(reason_counter)],
183 ColumnId::from_usize(row.len()),
184 )?)
185 }
186 }
187
188 fn build_cong_metadata(&self, func: FunctionId, egraph: &mut EGraph) -> FunctionCongMetadata {
189 let func_info = &egraph.funcs[func];
190 let func_underlying = func_info.table;
191 let schema_math = SchemaMath {
192 subsume: func_info.can_subsume,
193 tracing: true,
194 func_cols: func_info.schema.len(),
195 };
196 let cong_args = CongArgs {
197 func_table: func,
198 func_underlying,
199 schema_math,
200 reason_table: egraph.reason_table(&ProofReason::CongRow),
201 term_table: egraph.term_table(func_underlying),
202 reason_counter: egraph.reason_counter,
203 term_counter: egraph.id_counter,
204 ts_counter: egraph.timestamp_counter,
205 reason_spec_id: egraph.cong_spec,
206 };
207 let build_term = egraph.register_external_func(make_external_func(move |es, vals| {
208 cong_term(&cong_args, es, vals)
209 }));
210 FunctionCongMetadata {
211 table: func_underlying,
212 build_term,
213 schema_math,
214 }
215 }
216}
217
/// Per-function data needed to justify a `FunctionCall` node while building a
/// rule's reason.
#[derive(Copy, Clone)]
struct FunctionCongMetadata {
    /// The function's underlying table, which is looked up by argument key.
    table: TableId,
    /// External function (wrapping `cong_term`) used as the lookup fallback
    /// when no row exists for the justified arguments.
    build_term: ExternalFunctionId,
    /// Column layout of `table`; used to locate the proof-id column.
    schema_math: SchemaMath,
}
225
/// Working state for justifying the nodes of one rule's [`SourceSyntax`]
/// during a single invocation of the reason-building callback.
struct TermReconstructionState<'a> {
    /// The syntax DAG being justified.
    syntax: &'a SourceSyntax,
    /// Memo of already-justified nodes, so that shared sub-DAGs are justified
    /// only once. Constants are never inserted.
    syntax_mapping: DenseIdMap<SyntaxId, core_relations::QueryEntry>,
    /// For each query atom, the entry holding the term that atom matched
    /// (a copy of `ProofBuilder::term_vars`). Converted via `Bindings::convert`.
    atom_mapping: DenseIdMap<AtomId, QueryEntry>,
    /// Congruence metadata for every function called in `syntax`.
    metadata: DenseIdMap<FunctionId, FunctionCongMetadata>,
}
239
240impl TermReconstructionState<'_> {
241 fn justify_query(
242 &mut self,
243 node: SyntaxId,
244 bndgs: &mut Bindings,
245 rb: &mut RuleBuilder,
246 ) -> Result<core_relations::QueryEntry> {
247 if let Some(entry) = self.syntax_mapping.get(node) {
248 return Ok(*entry);
249 }
250 let syntax = self.syntax;
251 let res = match &syntax.backing[node] {
252 SourceExpr::Const { val, .. } => return Ok(core_relations::QueryEntry::Const(*val)),
253 SourceExpr::Var { id, .. } => bndgs.mapping[*id],
254 SourceExpr::ExternalCall { var, args, .. } => {
255 for arg in args {
256 self.justify_query(*arg, bndgs, rb)?;
257 }
258 bndgs.mapping[*var]
259 }
260 SourceExpr::FunctionCall { func, atom, args } => {
261 let old_term = bndgs.convert(&self.atom_mapping[*atom]);
262 let mut buf: Vec<core_relations::QueryEntry> = vec![old_term];
263
264 for arg in args.iter().map(|s| self.justify_query(*s, bndgs, rb)) {
265 buf.push(arg?);
266 }
267 let FunctionCongMetadata {
268 table,
269 build_term,
270 schema_math,
271 } = &self.metadata[*func];
272 let term_col = ColumnId::from_usize(schema_math.proof_id_col());
273 rb.lookup_with_fallback(*table, &buf[1..], term_col, *build_term, &buf)?
274 .into()
275 }
276 };
277 self.syntax_mapping.insert(node, res);
278 Ok(res)
279 }
280}
281
/// Everything `cong_term` needs, captured by the external function registered
/// in `ProofBuilder::build_cong_metadata`.
#[derive(Clone)]
struct CongArgs {
    /// The function whose congruent term is built; its id tags term-table rows.
    func_table: FunctionId,
    /// The function's underlying table, which receives the new row.
    func_underlying: TableId,
    /// Column layout of `func_underlying`.
    schema_math: SchemaMath,
    /// Table holding `ProofReason::CongRow` reasons.
    reason_table: TableId,
    /// Term table for `func_underlying`.
    term_table: TableId,
    /// Counter used to mint fresh reason ids.
    reason_counter: CounterId,
    /// Counter used to mint fresh term ids.
    term_counter: CounterId,
    /// Counter read to timestamp the inserted function row.
    ts_counter: CounterId,
    /// Proof-spec id of the congruence reason (`egraph.cong_spec`).
    reason_spec_id: ReasonSpecId,
}
305
306fn cong_term(args: &CongArgs, es: &mut ExecutionState, vals: &[Value]) -> Option<Value> {
307 let old_term = vals[0];
308 let new_term = &vals[1..];
309 let reason = es.predict_col(
310 args.reason_table,
311 &[Value::new(args.reason_spec_id.rep()), old_term],
312 iter::once(MergeVal::Counter(args.reason_counter)),
313 ColumnId::new(2),
314 );
315 let mut term_row = SmallVec::<[Value; 8]>::default();
316 term_row.push(Value::new(args.func_table.rep()));
317 term_row.extend_from_slice(new_term);
318 let term_val = es.predict_col(
319 args.term_table,
320 &term_row,
321 [
322 MergeVal::Counter(args.term_counter),
323 MergeVal::Constant(reason),
324 ]
325 .into_iter(),
326 ColumnId::from_usize(term_row.len()),
327 );
328
329 let ts = Value::from_usize(es.read_counter(args.ts_counter));
333 term_row.resize(args.schema_math.table_columns(), NOT_SUBSUMED);
334 term_row[args.schema_math.ret_val_col()] = term_val;
335 term_row[args.schema_math.proof_id_col()] = term_val;
336 term_row[args.schema_math.ts_col()] = ts;
337 es.stage_insert(args.func_underlying, &term_row);
338 Some(term_val)
339}