1use std::{iter, sync::Arc};
8
9use crate::core_relations;
10use crate::core_relations::{
11 ColumnId, CounterId, ExecutionState, ExternalFunctionId, MergeVal, RuleBuilder, TableId, Value,
12 WriteVal, make_external_func,
13};
14use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
15use crate::{EGraph, NOT_SUBSUMED, ProofReason, QueryEntry, ReasonSpecId, Result, SchemaMath};
16use smallvec::SmallVec;
17
18use crate::{
19 ColumnTy, FunctionId, RuleId,
20 proof_spec::ProofBuilder,
21 rule::{AtomId, Bindings, VariableId},
22};
23
// Id type for nodes of a `SourceSyntax` DAG: it indexes `SourceSyntax::backing`.
// The `u32` argument is presumably the id's underlying representation — confirm against
// `define_id!`.
define_id!(pub SyntaxId, u32, "an offset into a Syntax DAG.");
25
26#[derive(Debug, Clone)]
27pub enum TopLevelLhsExpr {
28 Exists(SyntaxId),
30 Eq(SyntaxId, SyntaxId),
32}
33
34#[derive(Debug, Clone)]
37pub enum SourceExpr {
38 Const { ty: ColumnTy, val: Value },
40 Var {
42 id: VariableId,
43 ty: ColumnTy,
44 name: String,
45 },
46 ExternalCall {
48 var: VariableId,
51 ty: ColumnTy,
52 func: ExternalFunctionId,
53 args: Vec<SyntaxId>,
54 },
55 FunctionCall {
57 func: FunctionId,
59 atom: AtomId,
62 args: Vec<SyntaxId>,
64 },
65}
66
67#[derive(Debug, Clone, Default)]
70pub struct SourceSyntax {
71 pub(crate) backing: IdVec<SyntaxId, SourceExpr>,
72 pub(crate) vars: Vec<(VariableId, ColumnTy)>,
73 pub(crate) roots: Vec<TopLevelLhsExpr>,
74}
75
76impl SourceSyntax {
77 pub fn add_expr(&mut self, expr: SourceExpr) -> SyntaxId {
82 match &expr {
83 SourceExpr::Const { .. } | SourceExpr::FunctionCall { .. } => {}
84 SourceExpr::Var { id, ty, .. } => self.vars.push((*id, *ty)),
85 SourceExpr::ExternalCall { var, ty, .. } => self.vars.push((*var, *ty)),
86 };
87 self.backing.push(expr)
88 }
89
90 pub fn add_toplevel_expr(&mut self, expr: TopLevelLhsExpr) {
92 self.roots.push(expr);
93 }
94
95 fn funcs(&self) -> impl Iterator<Item = FunctionId> + '_ {
96 self.backing.iter().filter_map(|(_, v)| {
97 if let SourceExpr::FunctionCall { func, .. } = v {
98 Some(*func)
99 } else {
100 None
101 }
102 })
103 }
104}
105
106#[derive(Debug)]
109pub(crate) struct RuleData {
110 pub(crate) rule_id: RuleId,
111 pub(crate) syntax: SourceSyntax,
112}
113
114impl RuleData {
115 pub(crate) fn n_vars(&self) -> usize {
116 self.syntax.vars.len()
117 }
118}
119
120impl ProofBuilder {
121 pub(crate) fn create_reason(
134 &mut self,
135 syntax: SourceSyntax,
136 egraph: &mut EGraph,
137 ) -> impl Fn(&mut Bindings, &mut RuleBuilder) -> Result<core_relations::Variable> + Clone + use<>
138 {
139 let mut metadata = DenseIdMap::default();
141 for func in syntax.funcs() {
142 metadata.insert(func, self.build_cong_metadata(func, egraph));
143 }
144
145 let reason_spec = Arc::new(ProofReason::Rule(RuleData {
146 rule_id: self.rule_id,
147 syntax: syntax.clone(),
148 }));
149 let reason_table = egraph.reason_table(&reason_spec);
150 let reason_spec_id = egraph.proof_specs.push(reason_spec);
151 let reason_counter = egraph.reason_counter;
152 let atom_mapping = self.term_vars.clone();
153 move |bndgs, rb| {
154 let mut state = TermReconstructionState {
156 syntax: &syntax,
157 syntax_mapping: Default::default(),
158 metadata: metadata.clone(),
159 atom_mapping: atom_mapping.clone(),
160 };
161 for toplevel_expr in &syntax.roots {
162 match toplevel_expr {
163 TopLevelLhsExpr::Exists(id) => {
164 state.justify_query(*id, bndgs, rb)?;
165 }
166 TopLevelLhsExpr::Eq(id1, id2) => {
167 state.justify_query(*id1, bndgs, rb)?;
168 state.justify_query(*id2, bndgs, rb)?;
169 }
170 }
171 }
172 let mut row = SmallVec::<[core_relations::QueryEntry; 8]>::new();
175 row.push(Value::new(reason_spec_id.rep()).into());
176 for (var, _) in &syntax.vars {
177 row.push(bndgs.mapping[*var]);
178 }
179 Ok(rb.lookup_or_insert(
180 reason_table,
181 &row,
182 &[WriteVal::IncCounter(reason_counter)],
183 ColumnId::from_usize(row.len()),
184 )?)
185 }
186 }
187
188 fn build_cong_metadata(&self, func: FunctionId, egraph: &mut EGraph) -> FunctionCongMetadata {
189 let func_info = &egraph.funcs[func];
190 let func_underlying = func_info.table;
191 let schema_math = SchemaMath {
192 subsume: func_info.can_subsume,
193 tracing: true,
194 func_cols: func_info.schema.len(),
195 };
196 let cong_args = CongArgs {
197 func_table: func,
198 func_underlying,
199 schema_math,
200 reason_table: egraph.reason_table(&ProofReason::CongRow),
201 term_table: egraph.term_table(func_underlying),
202 reason_counter: egraph.reason_counter,
203 term_counter: egraph.id_counter,
204 ts_counter: egraph.timestamp_counter,
205 reason_spec_id: egraph.cong_spec,
206 };
207 let build_term =
208 egraph.register_external_func(Box::new(make_external_func(move |es, vals| {
209 cong_term(&cong_args, es, vals)
210 })));
211 FunctionCongMetadata {
212 table: func_underlying,
213 build_term,
214 schema_math,
215 }
216 }
217}
218
219#[derive(Copy, Clone)]
221struct FunctionCongMetadata {
222 table: TableId,
223 build_term: ExternalFunctionId,
224 schema_math: SchemaMath,
225}
226
227struct TermReconstructionState<'a> {
228 syntax: &'a SourceSyntax,
232 syntax_mapping: DenseIdMap<SyntaxId, core_relations::QueryEntry>,
235 atom_mapping: DenseIdMap<AtomId, QueryEntry>,
238 metadata: DenseIdMap<FunctionId, FunctionCongMetadata>,
239}
240
241impl TermReconstructionState<'_> {
242 fn justify_query(
243 &mut self,
244 node: SyntaxId,
245 bndgs: &mut Bindings,
246 rb: &mut RuleBuilder,
247 ) -> Result<core_relations::QueryEntry> {
248 if let Some(entry) = self.syntax_mapping.get(node) {
249 return Ok(*entry);
250 }
251 let syntax = self.syntax;
252 let res = match &syntax.backing[node] {
253 SourceExpr::Const { val, .. } => return Ok(core_relations::QueryEntry::Const(*val)),
254 SourceExpr::Var { id, .. } => bndgs.mapping[*id],
255 SourceExpr::ExternalCall { var, args, .. } => {
256 for arg in args {
257 self.justify_query(*arg, bndgs, rb)?;
258 }
259 bndgs.mapping[*var]
260 }
261 SourceExpr::FunctionCall { func, atom, args } => {
262 let old_term = bndgs.convert(&self.atom_mapping[*atom]);
263 let mut buf: Vec<core_relations::QueryEntry> = vec![old_term];
264
265 for arg in args.iter().map(|s| self.justify_query(*s, bndgs, rb)) {
266 buf.push(arg?);
267 }
268 let FunctionCongMetadata {
269 table,
270 build_term,
271 schema_math,
272 } = &self.metadata[*func];
273 let term_col = ColumnId::from_usize(schema_math.proof_id_col());
274 rb.lookup_with_fallback(*table, &buf[1..], term_col, *build_term, &buf)?
275 .into()
276 }
277 };
278 self.syntax_mapping.insert(node, res);
279 Ok(res)
280 }
281}
282
283#[derive(Clone)]
286struct CongArgs {
287 func_table: FunctionId,
289 func_underlying: TableId,
291 schema_math: SchemaMath,
293 reason_table: TableId,
295 term_table: TableId,
297 reason_counter: CounterId,
299 term_counter: CounterId,
301 ts_counter: CounterId,
303 reason_spec_id: ReasonSpecId,
305}
306
307fn cong_term(args: &CongArgs, es: &mut ExecutionState, vals: &[Value]) -> Option<Value> {
308 let old_term = vals[0];
309 let new_term = &vals[1..];
310 let reason = es.predict_col(
311 args.reason_table,
312 &[Value::new(args.reason_spec_id.rep()), old_term],
313 iter::once(MergeVal::Counter(args.reason_counter)),
314 ColumnId::new(2),
315 );
316 let mut term_row = SmallVec::<[Value; 8]>::default();
317 term_row.push(Value::new(args.func_table.rep()));
318 term_row.extend_from_slice(new_term);
319 let term_val = es.predict_col(
320 args.term_table,
321 &term_row,
322 [
323 MergeVal::Counter(args.term_counter),
324 MergeVal::Constant(reason),
325 ]
326 .into_iter(),
327 ColumnId::from_usize(term_row.len()),
328 );
329
330 let ts = Value::from_usize(es.read_counter(args.ts_counter));
334 term_row.resize(args.schema_math.table_columns(), NOT_SUBSUMED);
335 term_row[args.schema_math.ret_val_col()] = term_val;
336 term_row[args.schema_math.proof_id_col()] = term_val;
337 term_row[args.schema_math.ts_col()] = ts;
338 es.stage_insert(args.func_underlying, &term_row);
339 Some(term_val)
340}