1use std::{
12 cmp,
13 fmt::Debug,
14 hash::Hash,
15 iter, mem,
16 ops::{Index, IndexMut},
17 sync::{Arc, Mutex},
18};
19
20use crate::core_relations::{
21 BaseValue, BaseValueId, BaseValues, ColumnId, Constraint, ContainerValue, ContainerValues,
22 CounterId, Database, DisplacedTable, DisplacedTableWithProvenance, ExecutionState,
23 ExternalFunction, ExternalFunctionId, MergeVal, Offset, PlanStrategy, SortedWritesTable,
24 TableId, TaggedRowBuffer, Value, WrappedTable,
25};
26use crate::numeric_id::{DenseIdMap, DenseIdMapWithReuse, IdVec, NumericId, define_id};
27use egglog_core_relations as core_relations;
28use egglog_numeric_id as numeric_id;
29use egglog_reports::{IterationReport, ReportLevel, RuleSetReport};
30use hashbrown::HashMap;
31use indexmap::{IndexMap, IndexSet, map::Entry};
32use log::info;
33use once_cell::sync::Lazy;
34pub use proof_format::{EqProofId, ProofStore, TermProofId};
35use proof_spec::{ProofReason, ProofReconstructionState, ReasonSpecId};
36use smallvec::SmallVec;
37use web_time::{Duration, Instant};
38
39pub mod macros;
40pub mod proof_format;
41pub(crate) mod proof_spec;
42pub(crate) mod rule;
43pub mod syntax;
44#[cfg(test)]
45mod tests;
46
47pub use rule::{Function, QueryEntry, RuleBuilder};
48pub use syntax::{SourceExpr, SourceSyntax, TopLevelLhsExpr};
49use thiserror::Error;
50
/// The type of a single column in a function's schema.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ColumnTy {
    /// An id column. `add_table` marks these columns for rebuilding, and
    /// `EGraph::get_canon_repr` canonicalizes them through the union-find.
    Id,
    /// A column holding a base value of the given registered base type.
    Base(BaseValueId),
}
56
// Dense numeric identifiers used by the egraph (each macro argument string is the id's doc).
define_id!(pub RuleId, u32, "An egglog-style rule");
define_id!(pub FunctionId, u32, "An id representing an egglog function");
define_id!(pub(crate) Timestamp, u32, "An abstract timestamp used to track execution of egglog rules");
impl Timestamp {
    /// Encodes this timestamp as a raw `Value` so it can be stored in (or compared against)
    /// a table column, e.g. the timestamp column of the union-find table.
    fn to_value(self) -> Value {
        Value::new(self.rep())
    }
}
65
/// An e-graph built on top of a `core_relations` [`Database`].
///
/// Each egglog function is stored as its own table in `db`; equalities between ids are
/// recorded in a dedicated union-find table (`uf_table`). When `tracing` is enabled, the
/// egraph additionally maintains term and reason tables used for proof reconstruction.
#[derive(Clone)]
pub struct EGraph {
    /// The underlying database holding every table and counter used by the egraph.
    db: Database,
    /// The union-find table (named `$uf`): a `DisplacedTable` normally, or a
    /// `DisplacedTableWithProvenance` when tracing is enabled.
    uf_table: TableId,
    /// Counter used to mint fresh ids (see `fresh_id` and `get_term`).
    id_counter: CounterId,
    /// Counter used to mint fresh reason ids (see `get_fiat_reason`).
    reason_counter: CounterId,
    /// Counter backing the egraph's logical timestamps (see `next_ts` / `inc_ts`).
    timestamp_counter: CounterId,
    /// All registered rules, including the rebuild rules generated for each function.
    rules: DenseIdMapWithReuse<RuleId, RuleInfo>,
    /// Metadata for every function added via `add_table`.
    funcs: DenseIdMap<FunctionId, FunctionInfo>,
    /// Slot for a panic message; `run_rules` takes it out and reports it as a `PanicError`.
    panic_message: SideChannel<String>,
    /// External functions used to report panics, keyed by a string (presumably the panic
    /// message, maintained by `new_panic`, which is defined outside this view — TODO confirm).
    panic_funcs: HashMap<String, ExternalFunctionId>,
    /// Every proof reason spec registered so far, indexed by `ReasonSpecId`.
    proof_specs: IdVec<ReasonSpecId, Arc<ProofReason>>,
    /// Spec id of the congruence reason (`ProofReason::CongRow`), registered at construction.
    cong_spec: ReasonSpecId,
    /// Reason tables, keyed by the arity of the reason specs stored in them.
    reason_tables: IndexMap<usize, TableId>,
    /// Term tables, keyed by the number of key columns of the functions they serve.
    term_tables: IndexMap<usize, TableId>,
    /// Displaced table mapping a term id to the smaller term id it was merged into.
    term_consistency_table: TableId,
    /// Whether proof tracing is enabled.
    tracing: bool,
    /// Level of detail collected in reports when running rules.
    report_level: ReportLevel,
}
136
/// Result type returned by fallible `EGraph` operations; errors are carried as `anyhow::Error`.
pub type Result<T> = std::result::Result<T, anyhow::Error>;
138
139impl Default for EGraph {
140 fn default() -> Self {
141 let mut db = Database::new();
142 let uf_table = db.add_table_named(
143 DisplacedTable::default(),
144 "$uf".into(),
145 iter::empty(),
146 iter::empty(),
147 );
148 EGraph::create_internal(db, uf_table, false)
149 }
150}
151
/// Configuration for a new function table, consumed by [`EGraph::add_table`].
pub struct FunctionConfig {
    /// Column types: the key columns followed by the return column (the last entry is the
    /// return type, see `FunctionInfo::ret_ty`). Must not be empty.
    pub schema: Vec<ColumnTy>,
    /// Default-value policy, stored as `FunctionInfo::default_val`.
    pub default: DefaultVal,
    /// How two rows with the same key are combined.
    pub merge: MergeFn,
    /// Function name; used as the table name and in merge-failure messages.
    pub name: String,
    /// Whether rows of this function carry a subsumption flag column.
    pub can_subsume: bool,
}
165
166impl EGraph {
167 pub fn with_tracing() -> EGraph {
173 let mut db = Database::new();
174 let uf_table = db.add_table_named(
175 DisplacedTableWithProvenance::default(),
176 "$uf".into(),
177 iter::empty(),
178 iter::empty(),
179 );
180 EGraph::create_internal(db, uf_table, true)
181 }
182
183 fn create_internal(mut db: Database, uf_table: TableId, tracing: bool) -> EGraph {
184 let id_counter = db.add_counter();
185 let trace_counter = db.add_counter();
186 let ts_counter = db.add_counter();
187 db.inc_counter(ts_counter);
189 let mut proof_specs = IdVec::default();
190 let cong_spec = proof_specs.push(Arc::new(ProofReason::CongRow));
191 let term_consistency_table =
192 db.add_table(DisplacedTable::default(), iter::empty(), iter::empty());
193
194 Self {
195 db,
196 uf_table,
197 id_counter,
198 reason_counter: trace_counter,
199 timestamp_counter: ts_counter,
200 rules: Default::default(),
201 funcs: Default::default(),
202 panic_message: Default::default(),
203 panic_funcs: Default::default(),
204 proof_specs,
205 cong_spec,
206 reason_tables: Default::default(),
207 term_tables: Default::default(),
208 term_consistency_table,
209 report_level: Default::default(),
210 tracing,
211 }
212 }
213
214 fn next_ts(&self) -> Timestamp {
215 Timestamp::from_usize(self.db.read_counter(self.timestamp_counter))
216 }
217
218 fn inc_ts(&mut self) {
219 self.db.inc_counter(self.timestamp_counter);
220 }
221
222 pub fn base_values_mut(&mut self) -> &mut BaseValues {
225 self.db.base_values_mut()
226 }
227
228 pub fn container_values_mut(&mut self) -> &mut ContainerValues {
231 self.db.container_values_mut()
232 }
233
234 pub fn container_values(&self) -> &ContainerValues {
236 self.db.container_values()
237 }
238
239 pub fn get_container_value<C: ContainerValue>(&mut self, val: C) -> Value {
241 self.register_container_ty::<C>();
242 self.db
243 .with_execution_state(|state| state.clone().container_values().register_val(val, state))
244 }
245
246 pub fn register_container_ty<C: ContainerValue>(&mut self) {
251 let uf_table = self.uf_table;
252 let ts_counter = self.timestamp_counter;
253 self.db.container_values_mut().register_type::<C>(
254 self.id_counter,
255 move |state, old, new| {
256 if old != new {
257 let next_ts = Value::from_usize(state.read_counter(ts_counter));
258 state.stage_insert(uf_table, &[old, new, next_ts]);
259 std::cmp::min(old, new)
260 } else {
261 old
262 }
263 },
264 );
265 }
266
267 pub fn base_values(&self) -> &BaseValues {
269 self.db.base_values()
270 }
271
272 pub fn base_value_constant<T>(&self, x: T) -> QueryEntry
274 where
275 T: BaseValue,
276 {
277 QueryEntry::Const {
278 val: self.base_values().get(x),
279 ty: ColumnTy::Base(self.base_values().get_ty::<T>()),
280 }
281 }
282
283 pub fn register_external_func(
284 &mut self,
285 func: Box<dyn ExternalFunction + 'static>,
286 ) -> ExternalFunctionId {
287 self.db.add_external_function(func)
288 }
289
290 pub fn free_external_func(&mut self, func: ExternalFunctionId) {
291 self.db.free_external_function(func)
292 }
293
294 pub fn fresh_id(&mut self) -> Value {
296 Value::from_usize(self.db.inc_counter(self.id_counter))
297 }
298
299 fn get_canon_in_uf(&self, val: Value) -> Value {
303 let table = self.db.get_table(self.uf_table);
304 let row = table.get_row(&[val]);
305 row.map(|row| row.vals[1]).unwrap_or(val)
306 }
307
308 pub fn get_canon_repr(&self, val: Value, ty: ColumnTy) -> Value {
313 match ty {
314 ColumnTy::Id => self.get_canon_in_uf(val),
315 ColumnTy::Base(_) => val,
316 }
317 }
318
319 fn record_term_consistency(
320 state: &mut ExecutionState,
321 table: TableId,
322 ts_counter: CounterId,
323 from: Value,
324 to: Value,
325 ) {
326 if from == to {
327 return;
328 }
329 let ts = Value::from_usize(state.read_counter(ts_counter));
330 state.stage_insert(table, &[from, to, ts]);
331 }
332
333 fn canonicalize_term_id(&mut self, term_id: Value) -> Value {
334 let table = self.db.get_table(self.term_consistency_table);
335 table
336 .get_row(&[term_id])
337 .map(|row| row.vals[1])
338 .unwrap_or(term_id)
339 }
340
341 fn term_table(&mut self, table: TableId) -> TableId {
342 let info = self.db.get_table_info(table);
343 let spec = info.spec();
344 match self.term_tables.entry(spec.n_keys) {
345 Entry::Occupied(o) => *o.get(),
346 Entry::Vacant(v) => {
347 let term_index = spec.n_keys + 1;
348 let term_consistency_table = self.term_consistency_table;
349 let ts_counter = self.timestamp_counter;
350 let table = SortedWritesTable::new(
351 spec.n_keys + 1, spec.n_keys + 1 + 2, None,
354 vec![], Box::new(move |state, old, new, out| {
356 let l_term_id = old[term_index];
358 let r_term_id = new[term_index];
359 if r_term_id < l_term_id {
363 EGraph::record_term_consistency(
364 state,
365 term_consistency_table,
366 ts_counter,
367 l_term_id,
368 r_term_id,
369 );
370 out.extend(new);
371 true
372 } else {
373 false
374 }
375 }),
376 );
377 let table_id =
378 self.db
379 .add_table(table, iter::empty(), iter::once(term_consistency_table));
380 *v.insert(table_id)
381 }
382 }
383 }
384
385 fn reason_table(&mut self, spec: &ProofReason) -> TableId {
386 let arity = spec.arity();
387 match self.reason_tables.entry(arity) {
388 Entry::Occupied(o) => *o.get(),
389 Entry::Vacant(v) => {
390 let table = SortedWritesTable::new(
391 arity,
392 arity + 1, None,
394 vec![], Box::new(|_, _, _, _| false),
396 );
397 let table_id = self.db.add_table(table, iter::empty(), iter::empty());
398 *v.insert(table_id)
399 }
400 }
401 }
402
403 pub fn add_values(&mut self, values: impl IntoIterator<Item = (FunctionId, Vec<Value>)>) {
412 self.add_values_with_desc("", values)
413 }
414
415 pub fn add_term(&mut self, func: FunctionId, inputs: &[Value], desc: &str) -> Value {
422 let info = &self.funcs[func];
423 let schema_math = SchemaMath {
424 tracing: self.tracing,
425 subsume: info.can_subsume,
426 func_cols: info.schema.len(),
427 };
428 let mut extended_row = Vec::new();
429 extended_row.extend_from_slice(inputs);
430 let term = self.tracing.then(|| {
431 let reason = self.get_fiat_reason(desc);
432 self.get_term(func, inputs, reason)
433 });
434 let res = term.unwrap_or_else(|| self.fresh_id());
435 schema_math.write_table_row(
436 &mut extended_row,
437 RowVals {
438 timestamp: self.next_ts().to_value(),
439 ret_val: Some(res),
440 proof: term,
441 subsume: schema_math.subsume.then_some(NOT_SUBSUMED),
442 },
443 );
444 extended_row[schema_math.ret_val_col()] = res;
445 let table_id = self.funcs[func].table;
446 self.db.new_buffer(table_id).stage_insert(&extended_row);
447 self.flush_updates();
448 self.get_canon_in_uf(res)
449 }
450
451 fn get_term(&mut self, func: FunctionId, key: &[Value], reason: Value) -> Value {
456 let table_id = self.funcs[func].table;
457 let term_table_id = self.term_table(table_id);
458 let table = self.db.get_table(term_table_id);
459 let mut term_key = Vec::with_capacity(key.len() + 1);
460 term_key.push(Value::new(func.rep()));
461 term_key.extend(key);
462 if let Some(row) = table.get_row(&term_key) {
463 row.vals[row.vals.len() - 2]
464 } else {
465 let result = Value::from_usize(self.db.inc_counter(self.id_counter));
466 term_key.push(result);
467 term_key.push(reason);
468 self.db.new_buffer(term_table_id).stage_insert(&term_key);
469 self.db.merge_table(term_table_id);
470 result
471 }
472 }
473
474 pub fn lookup_id(&self, func: FunctionId, key: &[Value]) -> Option<Value> {
477 let info = &self.funcs[func];
478 let schema_math = SchemaMath {
479 tracing: self.tracing,
480 subsume: info.can_subsume,
481 func_cols: info.schema.len(),
482 };
483 let table_id = info.table;
484 let table = self.db.get_table(table_id);
485 let row = table.get_row(key)?;
486 Some(row.vals[schema_math.ret_val_col()])
487 }
488
489 fn get_fiat_reason(&mut self, desc: &str) -> Value {
490 let reason = Arc::new(ProofReason::Fiat { desc: desc.into() });
491 let reason_table = self.reason_table(&reason);
492 let reason_spec_id = self.proof_specs.push(reason);
493 let reason_id = Value::from_usize(self.db.inc_counter(self.reason_counter));
494 self.db
495 .new_buffer(reason_table)
496 .stage_insert(&[Value::new(reason_spec_id.rep()), reason_id]);
497 self.db.merge_table(reason_table);
498 reason_id
499 }
500
501 pub fn add_values_with_desc(
511 &mut self,
512 desc: &str,
513 values: impl IntoIterator<Item = (FunctionId, Vec<Value>)>,
514 ) {
515 let mut extended_row = Vec::<Value>::new();
516 let reason_id = self.tracing.then(|| self.get_fiat_reason(desc));
517 let mut bufs = DenseIdMap::default();
518 for (func, row) in values.into_iter() {
519 let table_info = &self.funcs[func];
520 let schema_math = SchemaMath {
521 tracing: self.tracing,
522 subsume: table_info.can_subsume,
523 func_cols: table_info.schema.len(),
524 };
525 let table_id = table_info.table;
526 let term_id = reason_id.map(|reason| {
527 let term_id = self.get_term(func, &row[0..schema_math.num_keys()], reason);
529 let buf = bufs.get_or_insert(self.uf_table, || self.db.new_buffer(self.uf_table));
530 buf.stage_insert(&[
532 *row.last().unwrap(),
533 term_id,
534 self.next_ts().to_value(),
535 reason,
536 ]);
537 term_id
538 });
539 extended_row.extend_from_slice(&row);
540 schema_math.write_table_row(
541 &mut extended_row,
542 RowVals {
543 timestamp: self.next_ts().to_value(),
544 proof: term_id,
545 subsume: schema_math.subsume.then_some(NOT_SUBSUMED),
546 ret_val: None, },
548 );
549 let buf = bufs.get_or_insert(table_id, || self.db.new_buffer(table_id));
550 buf.stage_insert(&extended_row);
551 extended_row.clear();
552 }
553 mem::drop(bufs);
555 self.flush_updates();
556 }
557
558 pub fn approx_table_size(&self, table: FunctionId) -> usize {
559 self.db.estimate_size(self.funcs[table].table, None)
560 }
561
562 pub fn table_size(&self, table: FunctionId) -> usize {
563 self.db.get_table(self.funcs[table].table).len()
564 }
565
566 pub fn explain_term(&mut self, id: Value, store: &mut ProofStore) -> Result<TermProofId> {
575 if !self.tracing {
576 return Err(ProofReconstructionError::TracingNotEnabled.into());
577 }
578 let mut state = ProofReconstructionState::new(store);
579 Ok(self.explain_term_inner(id, &mut state))
580 }
581
582 pub fn explain_terms_equal(
589 &mut self,
590 id1: Value,
591 id2: Value,
592 store: &mut ProofStore,
593 ) -> Result<EqProofId> {
594 if !self.tracing {
595 return Err(ProofReconstructionError::TracingNotEnabled.into());
596 }
597 let mut state = ProofReconstructionState::new(store);
598 if self.get_canon_in_uf(id1) != self.get_canon_in_uf(id2) {
599 let mut buf = Vec::<u8>::new();
602 let term_id_1 = self.reconstruct_term(id1, ColumnTy::Id, &mut state);
603 let term_id_2 = self.reconstruct_term(id2, ColumnTy::Id, &mut state);
604 store.termdag.print_term(term_id_1, &mut buf).unwrap();
605 let term1 = String::from_utf8(buf).unwrap();
606 let mut buf = Vec::<u8>::new();
607 store.termdag.print_term(term_id_2, &mut buf).unwrap();
608 let term2 = String::from_utf8(buf).unwrap();
609 return Err(
610 ProofReconstructionError::EqualityExplanationOfUnequalTerms { term1, term2 }.into(),
611 );
612 }
613 Ok(self.explain_terms_equal_inner(id1, id2, &mut state))
614 }
615
616 pub fn for_each(&self, table: FunctionId, mut f: impl FnMut(FunctionRow<'_>)) {
620 self.for_each_while(table, |row| {
621 f(row);
622 true
623 });
624 }
625
626 pub fn for_each_while(&self, table: FunctionId, mut f: impl FnMut(FunctionRow<'_>) -> bool) {
629 let info = &self.funcs[table];
630 let table = self.funcs[table].table;
631 let schema_math = SchemaMath {
632 tracing: self.tracing,
633 subsume: info.can_subsume,
634 func_cols: info.schema.len(),
635 };
636 let imp = self.db.get_table(table);
637 let all = imp.all();
638 let mut cur = Offset::new(0);
639 let mut buf = TaggedRowBuffer::new(imp.spec().arity());
640 macro_rules! drain_buf {
645 ($buf:expr) => {
646 for (_, row) in $buf.non_stale() {
647 let subsumed =
648 schema_math.subsume && row[schema_math.subsume_col()] == SUBSUMED;
649 if !f(FunctionRow {
650 vals: &row[0..schema_math.func_cols],
651 subsumed,
652 }) {
653 return;
654 }
655 }
656 $buf.clear();
657 };
658 }
659 while let Some(next) = imp.scan_bounded(all.as_ref(), cur, 32, &mut buf) {
660 drain_buf!(buf);
661 cur = next;
662 }
663 drain_buf!(buf);
664 }
665
666 pub fn dump_debug_info(&self) {
670 info!("=== View Tables ===");
671 for (id, info) in self.funcs.iter() {
672 let table = self.db.get_table(info.table);
673 self.scan_table(table, |row| {
674 info!(
675 "View Table {name} / {id:?} / {table:?}: {row:?}",
676 name = info.name,
677 table = info.table
678 )
679 });
680 }
681
682 info!("=== Term Tables ===");
683 for (_, table_id) in &self.term_tables {
684 let table = self.db.get_table(*table_id);
685 self.scan_table(table, |row| {
686 let name = &self.funcs[FunctionId::new(row[0].rep())].name;
687 let row = &row[1..];
688 info!("Term Table {table_id:?}: {name}, {row:?}")
689 });
690 }
691
692 info!("=== Reason Tables ===");
693 for (_, table_id) in &self.reason_tables {
694 let table = self.db.get_table(*table_id);
695 self.scan_table(table, |row| {
696 let spec = self.proof_specs[ReasonSpecId::new(row[0].rep())].as_ref();
697 let row = &row[1..];
698 info!("Reason Table {table_id:?}: {spec:?}, {row:?}")
699 });
700 }
701 }
702
703 fn scan_table(&self, table: &WrappedTable, mut f: impl FnMut(&[Value])) {
705 const BATCH_SIZE: usize = 128;
706 let all = table.all();
707 let mut cur = Offset::new(0);
708 let mut out = TaggedRowBuffer::new(table.spec().arity());
709 while let Some(next) = table.scan_bounded(all.as_ref(), cur, BATCH_SIZE, &mut out) {
710 out.non_stale().for_each(|(_, row)| f(row));
711 out.clear();
712 cur = next;
713 }
714 out.non_stale().for_each(|(_, row)| f(row));
715 }
716
717 pub fn add_table(&mut self, config: FunctionConfig) -> FunctionId {
719 let FunctionConfig {
720 schema,
721 default,
722 merge,
723 name,
724 can_subsume,
725 } = config;
726 assert!(
727 !schema.is_empty(),
728 "must have at least one column in schema"
729 );
730 let to_rebuild: Vec<ColumnId> = schema
731 .iter()
732 .enumerate()
733 .filter(|(_, ty)| matches!(ty, ColumnTy::Id))
734 .map(|(i, _)| ColumnId::from_usize(i))
735 .collect();
736 let schema_math = SchemaMath {
737 tracing: self.tracing,
738 subsume: can_subsume,
739 func_cols: schema.len(),
740 };
741 let n_args = schema_math.num_keys();
742 let n_cols = schema_math.table_columns();
743 let next_func_id = self.funcs.next_id();
744 let mut read_deps = IndexSet::<TableId>::new();
745 let mut write_deps = IndexSet::<TableId>::new();
746 merge.fill_deps(self, &mut read_deps, &mut write_deps);
747 let merge_fn = merge.to_callback(schema_math, &name, self);
748 let table = SortedWritesTable::new(
749 n_args,
750 n_cols,
751 Some(ColumnId::from_usize(schema.len())),
752 to_rebuild,
753 merge_fn,
754 );
755 let name: Arc<str> = name.into();
756 let table_id = self.db.add_table_named(
757 table,
758 name.clone(),
759 read_deps.iter().copied(),
760 write_deps.iter().copied(),
761 );
762
763 let res = self.funcs.push(FunctionInfo {
764 table: table_id,
765 schema: schema.clone(),
766 incremental_rebuild_rules: Default::default(),
767 nonincremental_rebuild_rule: RuleId::new(!0),
768 default_val: default,
769 can_subsume,
770 name,
771 });
772 debug_assert_eq!(res, next_func_id);
773 let incremental_rebuild_rules = self.incremental_rebuild_rules(res, &schema);
774 let nonincremental_rebuild_rule = self.nonincremental_rebuild(res, &schema);
775 let info = &mut self.funcs[res];
776 info.incremental_rebuild_rules = incremental_rebuild_rules;
777 info.nonincremental_rebuild_rule = nonincremental_rebuild_rule;
778 res
779 }
780
781 pub fn run_rules(&mut self, rules: &[RuleId]) -> Result<IterationReport> {
785 let ts = self.next_ts();
786
787 let rule_set_report =
788 run_rules_impl(&mut self.db, &mut self.rules, rules, ts, self.report_level)?;
789 if let Some(message) = self.panic_message.lock().unwrap().take() {
790 return Err(PanicError(message).into());
791 }
792
793 let mut iteration_report = IterationReport {
794 rule_set_report,
795 rebuild_time: Duration::ZERO,
796 };
797 if !iteration_report.changed() {
798 return Ok(iteration_report);
799 }
800
801 let rebuild_timer = Instant::now();
802 self.rebuild()?;
803 iteration_report.rebuild_time = rebuild_timer.elapsed();
804
805 if let Some(message) = self.panic_message.lock().unwrap().take() {
806 return Err(PanicError(message).into());
807 }
808
809 Ok(iteration_report)
810 }
811
/// Restores canonical ids in all function tables after the union-find has changed.
///
/// Two strategies are used:
/// * If the union-find table supplies a rebuilder (`rebuilder(&[])` is `Some`), every
///   function table is rebuilt through `Database::apply_rebuild` (together with
///   `rebuild_containers`), repeating until neither reports a change. The timestamp is
///   advanced after each round.
/// * Otherwise the rebuild rules generated in `add_table` are run to a fixpoint, either by
///   `rebuild_parallel` or by the sequential loop below.
///
/// # Errors
/// Propagates errors from running the rebuild rules.
fn rebuild(&mut self) -> Result<()> {
    // Chooses between `rebuild_parallel` and the sequential loop. Under `cfg(test)` the choice
    // is random so both paths get exercised; otherwise it depends on rayon's thread count.
    fn do_parallel() -> bool {
        #[cfg(test)]
        {
            use rand::Rng;
            rand::rng().random_bool(0.5)
        }
        #[cfg(not(test))]
        {
            rayon::current_num_threads() > 1
        }
    }
    if self.db.get_table(self.uf_table).rebuilder(&[]).is_some() {
        let mut tables = Vec::with_capacity(self.funcs.next_id().index());
        for (_, func) in self.funcs.iter() {
            tables.push(func.table);
        }
        loop {
            let container_rebuild = self.db.rebuild_containers(self.uf_table);
            let table_rebuild =
                self.db
                    .apply_rebuild(self.uf_table, &tables, self.next_ts().to_value());
            self.inc_ts();
            if !table_rebuild && !container_rebuild {
                break;
            }
        }
        return Ok(());
    }
    if do_parallel() {
        return self.rebuild_parallel();
    }
    let start = Instant::now();

    let mut changed = true;
    while changed {
        changed = false;
        // Each round runs at a fresh timestamp.
        self.inc_ts();
        let ts = self.next_ts();
        for (_, info) in self.funcs.iter_mut() {
            // Compare the number of union-find rows with timestamp (column 2) at or after the
            // function's last nonincremental rebuild against the table's size; the
            // `incremental_rebuild` heuristic picks the cheaper strategy.
            let last_rebuilt_at = self.rules[info.nonincremental_rebuild_rule].last_run_at;
            let table_size = self.db.estimate_size(info.table, None);
            let uf_size = self.db.estimate_size(
                self.uf_table,
                Some(Constraint::GeConst {
                    col: ColumnId::new(2),
                    val: last_rebuilt_at.to_value(),
                }),
            );
            if incremental_rebuild(uf_size, table_size, false) {
                marker_incremental_rebuild(|| -> Result<()> {
                    for rule in &info.incremental_rebuild_rules {
                        changed |= run_rules_impl(
                            &mut self.db,
                            &mut self.rules,
                            &[*rule],
                            ts,
                            ReportLevel::TimeOnly,
                        )?
                        .changed;
                    }
                    // The incremental rules covered this round, so the nonincremental rule
                    // is treated as having run too.
                    self.rules[info.nonincremental_rebuild_rule].last_run_at = ts;
                    Ok(())
                })?;
            } else {
                marker_nonincremental_rebuild(|| -> Result<()> {
                    changed |= run_rules_impl(
                        &mut self.db,
                        &mut self.rules,
                        &[info.nonincremental_rebuild_rule],
                        ts,
                        ReportLevel::TimeOnly,
                    )?
                    .changed;
                    // A full rebuild subsumes the incremental rules for this round.
                    for rule in &info.incremental_rebuild_rules {
                        self.rules[*rule].last_run_at = ts;
                    }
                    Ok(())
                })?;
            }
        }
    }
    log::info!("rebuild took {:?}", start.elapsed());
    Ok(())
}
937
938 fn rebuild_parallel(&mut self) -> Result<()> {
943 let start = Instant::now();
944 #[derive(Default)]
945 struct RebuildState {
946 nonincremental: Vec<FunctionId>,
947 incremental: DenseIdMap<usize, SmallVec<[FunctionId; 2]>>,
948 }
949
950 impl RebuildState {
951 fn clear(&mut self) {
952 self.nonincremental.clear();
953 self.incremental.iter_mut().for_each(|(_, v)| v.clear());
954 }
955 }
956
957 let mut changed = true;
958 let mut state = RebuildState::default();
959 let mut scratch = Vec::new();
960 while changed {
961 changed = false;
962 state.clear();
963 self.inc_ts();
964 for (func, info) in self.funcs.iter_mut() {
967 let last_rebuilt_at = self.rules[info.nonincremental_rebuild_rule].last_run_at;
968 let table_size = self.db.estimate_size(info.table, None);
969 let uf_size = self.db.estimate_size(
970 self.uf_table,
971 Some(Constraint::GeConst {
972 col: ColumnId::new(2),
973 val: last_rebuilt_at.to_value(),
974 }),
975 );
976 if incremental_rebuild(uf_size, table_size, true) {
977 for (i, _) in info.incremental_rebuild_rules.iter().enumerate() {
978 state.incremental.get_or_default(i).push(func);
979 }
980 } else {
981 state.nonincremental.push(func);
982 }
983 }
984 let ts = self.next_ts();
985 for func in state.nonincremental.iter().copied() {
986 scratch.push(self.funcs[func].nonincremental_rebuild_rule);
987 for rule in &self.funcs[func].incremental_rebuild_rules {
988 self.rules[*rule].last_run_at = ts;
989 }
990 }
991 changed |= run_rules_impl(
992 &mut self.db,
993 &mut self.rules,
994 &scratch,
995 ts,
996 ReportLevel::TimeOnly,
997 )?
998 .changed;
999 scratch.clear();
1000 let ts = self.next_ts();
1001 for (i, funcs) in state.incremental.iter() {
1002 for func in funcs.iter().copied() {
1003 let info = &mut self.funcs[func];
1004 scratch.push(info.incremental_rebuild_rules[i]);
1005 self.rules[info.nonincremental_rebuild_rule].last_run_at = ts;
1006 }
1007 changed |= run_rules_impl(
1008 &mut self.db,
1009 &mut self.rules,
1010 &scratch,
1011 ts,
1012 ReportLevel::TimeOnly,
1013 )?
1014 .changed;
1015 scratch.clear();
1016 }
1017 }
1018 log::info!("rebuild took {:?}", start.elapsed());
1019 Ok(())
1020 }
1021
1022 fn incremental_rebuild_rules(&mut self, table: FunctionId, schema: &[ColumnTy]) -> Vec<RuleId> {
1023 schema
1024 .iter()
1025 .enumerate()
1026 .filter_map(|(i, ty)| match ty {
1027 ColumnTy::Id => {
1028 Some(self.incremental_rebuild_rule(table, schema, ColumnId::from_usize(i)))
1029 }
1030 ColumnTy::Base(_) => None,
1031 })
1032 .collect()
1033 }
1034
1035 fn incremental_rebuild_rule(
1036 &mut self,
1037 table: FunctionId,
1038 schema: &[ColumnTy],
1039 col: ColumnId,
1040 ) -> RuleId {
1041 let subsume = self.funcs[table].can_subsume;
1042 let table_id = self.funcs[table].table;
1043 let uf_table = self.uf_table;
1044 let mut rb = self.new_rule(&format!("incremental rebuild {table:?}, {col:?}"), true);
1046 rb.set_plan_strategy(PlanStrategy::MinCover);
1047 let mut vars = Vec::<QueryEntry>::with_capacity(schema.len());
1048 for ty in schema {
1049 vars.push(rb.new_var(*ty).into());
1050 }
1051 let canon_val: QueryEntry = rb.new_var(ColumnTy::Id).into();
1052 let subsume_var = subsume.then(|| rb.new_var(ColumnTy::Id));
1053 rb.add_atom_with_timestamp_and_func(
1054 table_id,
1055 Some(table),
1056 subsume_var.clone().map(QueryEntry::from),
1057 &vars,
1058 );
1059 rb.add_atom_with_timestamp_and_func(
1060 uf_table,
1061 None,
1062 None,
1063 &[vars[col.index()].clone(), canon_val.clone()],
1064 );
1065 rb.set_focus(1); let mut canon = Vec::<QueryEntry>::with_capacity(schema.len());
1069 for (i, (var, ty)) in vars.iter().zip(schema.iter()).enumerate() {
1070 canon.push(if i == col.index() {
1071 canon_val.clone()
1072 } else if let ColumnTy::Id = ty {
1073 rb.lookup_uf(var.clone()).unwrap().into()
1074 } else {
1075 var.clone()
1076 })
1077 }
1078
1079 rb.rebuild_row(table, &vars, &canon, subsume_var);
1081 rb.build_internal(None)
1082 }
1083
1084 fn nonincremental_rebuild(&mut self, table: FunctionId, schema: &[ColumnTy]) -> RuleId {
1085 let can_subsume = self.funcs[table].can_subsume;
1086 let table_id = self.funcs[table].table;
1087 let mut rb = self.new_rule(&format!("nonincremental rebuild {table:?}"), false);
1088 rb.set_plan_strategy(PlanStrategy::MinCover);
1089 let mut vars = Vec::<QueryEntry>::with_capacity(schema.len());
1090 for ty in schema {
1091 vars.push(rb.new_var(*ty).into());
1092 }
1093 let subsume_var = can_subsume.then(|| rb.new_var(ColumnTy::Id));
1094 rb.add_atom_with_timestamp_and_func(
1095 table_id,
1096 Some(table),
1097 subsume_var.clone().map(QueryEntry::from),
1098 &vars,
1099 );
1100 let mut lhs = SmallVec::<[QueryEntry; 4]>::new();
1101 let mut rhs = SmallVec::<[QueryEntry; 4]>::new();
1102 let mut canon = Vec::<QueryEntry>::with_capacity(schema.len());
1103 for (var, ty) in vars.iter().zip(schema.iter()) {
1104 canon.push(if let ColumnTy::Id = ty {
1105 lhs.push(var.clone());
1106 let canon_var = QueryEntry::from(rb.lookup_uf(var.clone()).unwrap());
1107 rhs.push(canon_var.clone());
1108 canon_var
1109 } else {
1110 var.clone()
1111 })
1112 }
1113 rb.check_for_update(&lhs, &rhs).unwrap();
1114 rb.rebuild_row(table, &vars, &canon, subsume_var);
1115 rb.build_internal(None) }
1117
1118 pub fn with_execution_state<R>(&self, f: impl FnOnce(&mut ExecutionState<'_>) -> R) -> R {
1124 self.db.with_execution_state(f)
1125 }
1126
1127 pub fn flush_updates(&mut self) -> bool {
1130 let updated = self.db.merge_all();
1131 self.inc_ts();
1132 self.rebuild().unwrap();
1133 updated
1134 }
1135
1136 pub fn set_report_level(&mut self, level: ReportLevel) {
1137 self.report_level = level;
1138 }
1139}
1140
/// Bookkeeping for a registered rule.
#[derive(Clone)]
struct RuleInfo {
    /// Timestamp at which the rule last ran; the rebuild logic also advances this directly
    /// for rules whose work is covered by another rebuild rule.
    last_run_at: Timestamp,
    /// The rule's query.
    query: rule::Query,
    /// A previously computed plan for the query, if any.
    cached_plan: Option<CachedPlanInfo>,
    /// Human-readable description of the rule.
    desc: Arc<str>,
}
1148
/// A cached query plan together with its atom mapping.
#[derive(Clone)]
struct CachedPlanInfo {
    /// The compiled plan.
    plan: Arc<core_relations::CachedPlan>,
    /// Atom ids of the plan (presumably indexed by the rule's own atom order — TODO confirm).
    atom_mapping: Vec<core_relations::AtomId>,
}
1156
/// Metadata for a function registered with `EGraph::add_table`.
#[derive(Clone)]
struct FunctionInfo {
    /// The backing table.
    table: TableId,
    /// Column types: keys followed by the return column.
    schema: Vec<ColumnTy>,
    /// One incremental rebuild rule per `ColumnTy::Id` column, in column order.
    incremental_rebuild_rules: Vec<RuleId>,
    /// Rule that rebuilds the whole table at once.
    nonincremental_rebuild_rule: RuleId,
    /// Default-value policy from the function's config.
    default_val: DefaultVal,
    /// Whether rows carry a subsumption flag.
    can_subsume: bool,
    /// The function's name (also the backing table's name).
    name: Arc<str>,
}
1167
1168impl FunctionInfo {
1169 fn ret_ty(&self) -> ColumnTy {
1170 self.schema.last().copied().unwrap()
1171 }
1172}
1173
/// Default-value policy of a function (stored as `FunctionInfo::default_val`).
#[derive(Copy, Clone)]
pub enum DefaultVal {
    /// Presumably: allocate a fresh id when a lookup misses — TODO confirm at use site.
    FreshId,
    /// Presumably: fail when a lookup misses — TODO confirm at use site.
    Fail,
    /// Presumably: use this constant when a lookup misses — TODO confirm at use site.
    Const(Value),
}
1184
/// How to combine the existing (`old`) and incoming (`new`) return values written for the
/// same key of a function.
pub enum MergeFn {
    /// Keep the old value; if the two values differ, report "Illegal merge attempted".
    AssertEq,
    /// Union the two ids in the union-find (when not tracing) and keep the smaller one.
    UnionId,
    /// Evaluate the nested merge functions and pass their results to this external function.
    Primitive(ExternalFunctionId, Vec<MergeFn>),
    /// Evaluate the nested merge functions and look their results up as a key in this
    /// function's table. The nested list must have one entry per key column.
    Function(FunctionId, Vec<MergeFn>),
    /// Keep the old value.
    Old,
    /// Take the new value.
    New,
    /// Always produce this value.
    Const(Value),
}
1206
1207impl MergeFn {
1208 fn fill_deps(
1209 &self,
1210 egraph: &EGraph,
1211 read_deps: &mut IndexSet<TableId>,
1212 write_deps: &mut IndexSet<TableId>,
1213 ) {
1214 use MergeFn::*;
1215 match self {
1216 Primitive(_, args) => {
1217 args.iter()
1218 .for_each(|arg| arg.fill_deps(egraph, read_deps, write_deps));
1219 }
1220 Function(func, args) => {
1221 read_deps.insert(egraph.funcs[*func].table);
1222 write_deps.insert(egraph.funcs[*func].table);
1223 args.iter()
1224 .for_each(|arg| arg.fill_deps(egraph, read_deps, write_deps));
1225 }
1226 UnionId if !egraph.tracing => {
1227 write_deps.insert(egraph.uf_table);
1228 }
1229 UnionId | AssertEq | Old | New | Const(..) => {}
1230 }
1231 }
1232
1233 fn to_callback(
1234 &self,
1235 schema_math: SchemaMath,
1236 function_name: &str,
1237 egraph: &mut EGraph,
1238 ) -> Box<core_relations::MergeFn> {
1239 assert!(
1240 !egraph.tracing || matches!(self, MergeFn::UnionId),
1241 "proofs aren't supported for non-union merge functions"
1242 );
1243
1244 let resolved = self.resolve(function_name, egraph);
1245
1246 Box::new(move |state, cur, new, out| {
1247 let timestamp = new[schema_math.ts_col()];
1248
1249 let mut changed = false;
1250
1251 let ret_val = {
1252 let cur = cur[schema_math.ret_val_col()];
1253 let new = new[schema_math.ret_val_col()];
1254 let out = resolved.run(state, cur, new, timestamp);
1255 changed |= cur != out;
1256 out
1257 };
1258
1259 let subsume = schema_math.subsume.then(|| {
1260 let cur = cur[schema_math.subsume_col()];
1261 let new = new[schema_math.subsume_col()];
1262 let out = combine_subsumed(cur, new);
1263 changed |= cur != out;
1264 out
1265 });
1266 let mut proof = None;
1267 if schema_math.tracing {
1268 let old_term = cur[schema_math.proof_id_col()];
1269 let new_term = new[schema_math.proof_id_col()];
1270 proof = Some(cmp::min(old_term, new_term));
1271 changed |= new_term < old_term;
1272 }
1273
1274 if changed {
1275 out.extend_from_slice(new);
1276 schema_math.write_table_row(
1277 out,
1278 RowVals {
1279 timestamp,
1280 proof,
1281 subsume,
1282 ret_val: Some(ret_val),
1283 },
1284 );
1285 }
1286
1287 changed
1288 })
1289 }
1290
1291 fn resolve(&self, function_name: &str, egraph: &mut EGraph) -> ResolvedMergeFn {
1292 match self {
1293 MergeFn::Const(v) => ResolvedMergeFn::Const(*v),
1294 MergeFn::Old => ResolvedMergeFn::Old,
1295 MergeFn::New => ResolvedMergeFn::New,
1296 MergeFn::AssertEq => ResolvedMergeFn::AssertEq {
1297 panic: egraph.new_panic(format!(
1298 "Illegal merge attempted for function {function_name}"
1299 )),
1300 },
1301 MergeFn::UnionId => ResolvedMergeFn::UnionId {
1302 uf_table: egraph.uf_table,
1303 tracing: egraph.tracing,
1304 },
1305 MergeFn::Primitive(prim, args) => ResolvedMergeFn::Primitive {
1310 prim: *prim,
1311 args: args
1312 .iter()
1313 .map(|arg| arg.resolve(function_name, egraph))
1314 .collect::<Vec<_>>(),
1315 panic: egraph.new_panic(format!(
1316 "Merge function for {function_name} primitive call failed"
1317 )),
1318 },
1319 MergeFn::Function(func, args) => {
1320 let func_info = &egraph.funcs[*func];
1321 assert_eq!(
1322 func_info.schema.len(),
1323 args.len() + 1,
1324 "Merge function for {function_name} must match function arity for {}",
1325 func_info.name
1326 );
1327 ResolvedMergeFn::Function {
1328 func: TableAction::new(egraph, *func),
1329 panic: egraph.new_panic(format!(
1330 "Lookup on {} failed in the merge function for {function_name}",
1331 func_info.name
1332 )),
1333 args: args
1334 .iter()
1335 .map(|arg| arg.resolve(function_name, egraph))
1336 .collect::<Vec<_>>(),
1337 }
1338 }
1339 }
1340 }
1341}
1342
/// A [`MergeFn`] whose ids have been resolved against an egraph, ready to run inside a
/// table's merge callback.
enum ResolvedMergeFn {
    /// Always produce this value.
    Const(Value),
    /// Keep the current value.
    Old,
    /// Take the incoming value.
    New,
    /// Keep the current value; call `panic` if the two values differ.
    AssertEq {
        panic: ExternalFunctionId,
    },
    /// When not tracing and the ids differ, stage a union into `uf_table` and keep the
    /// smaller id; otherwise keep the current id.
    UnionId {
        uf_table: TableId,
        tracing: bool,
    },
    /// Call `prim` on the results of `args`; call `panic` if it returns nothing.
    Primitive {
        prim: ExternalFunctionId,
        args: Vec<ResolvedMergeFn>,
        panic: ExternalFunctionId,
    },
    /// Look the results of `args` up through `func`; call `panic` if the lookup fails.
    Function {
        func: TableAction,
        args: Vec<ResolvedMergeFn>,
        panic: ExternalFunctionId,
    },
}
1369
1370impl ResolvedMergeFn {
1371 fn run(&self, state: &mut ExecutionState, cur: Value, new: Value, ts: Value) -> Value {
1372 match self {
1373 ResolvedMergeFn::Const(v) => *v,
1374 ResolvedMergeFn::Old => cur,
1375 ResolvedMergeFn::New => new,
1376 ResolvedMergeFn::AssertEq { panic } => {
1377 if cur != new {
1378 let res = state.call_external_func(*panic, &[]);
1379 assert_eq!(res, None);
1380 }
1381 cur
1382 }
1383 ResolvedMergeFn::UnionId { uf_table, tracing } => {
1384 if cur != new && !tracing {
1385 state.stage_insert(*uf_table, &[cur, new, ts]);
1388 std::cmp::min(cur, new)
1391 } else {
1392 cur
1393 }
1394 }
1395 ResolvedMergeFn::Primitive { prim, args, panic } => {
1400 let args = args
1401 .iter()
1402 .map(|arg| arg.run(state, cur, new, ts))
1403 .collect::<Vec<_>>();
1404
1405 match state.call_external_func(*prim, &args) {
1406 Some(result) => result,
1407 None => {
1408 let res = state.call_external_func(*panic, &[]);
1409 assert_eq!(res, None);
1410 cur
1411 }
1412 }
1413 }
1414 ResolvedMergeFn::Function { func, args, panic } => {
1415 if cur == new {
1417 return cur;
1418 }
1419
1420 let args = args
1421 .iter()
1422 .map(|arg| arg.run(state, cur, new, ts))
1423 .collect::<Vec<_>>();
1424
1425 func.lookup(state, &args).unwrap_or_else(|| {
1426 let res = state.call_external_func(*panic, &[]);
1427 assert_eq!(res, None);
1428 cur
1429 })
1430 }
1431 }
1432 }
1433}
1434
1435#[derive(Debug, PartialEq, Eq, Hash)]
1439pub struct TableAction {
1440 table: TableId,
1441 table_math: SchemaMath,
1442 default: Option<MergeVal>,
1443 timestamp: CounterId,
1444 scratch: Vec<Value>,
1445}
1446
1447impl Clone for TableAction {
1448 fn clone(&self) -> Self {
1449 Self {
1450 table: self.table,
1451 table_math: self.table_math,
1452 default: self.default,
1453 timestamp: self.timestamp,
1454 scratch: Vec::new(),
1455 }
1456 }
1457}
1458
1459impl TableAction {
1460 pub fn new(egraph: &EGraph, func: FunctionId) -> TableAction {
1463 assert!(!egraph.tracing, "proofs not supported yet");
1464
1465 let func_info = &egraph.funcs[func];
1466 TableAction {
1467 table: func_info.table,
1468 table_math: SchemaMath {
1469 func_cols: func_info.schema.len(),
1470 subsume: func_info.can_subsume,
1471 tracing: egraph.tracing,
1472 },
1473 default: match &func_info.default_val {
1474 DefaultVal::FreshId => Some(MergeVal::Counter(egraph.id_counter)),
1475 DefaultVal::Fail => None,
1476 DefaultVal::Const(val) => Some(MergeVal::Constant(*val)),
1477 },
1478 timestamp: egraph.timestamp_counter,
1479 scratch: Vec::new(),
1480 }
1481 }
1482
1483 pub fn lookup(&self, state: &mut ExecutionState, key: &[Value]) -> Option<Value> {
1487 match self.default {
1488 Some(default) => {
1489 let timestamp =
1490 MergeVal::Constant(Value::from_usize(state.read_counter(self.timestamp)));
1491 let mut merge_vals = SmallVec::<[MergeVal; 3]>::new();
1492 SchemaMath {
1493 func_cols: 1,
1494 ..self.table_math
1495 }
1496 .write_table_row(
1497 &mut merge_vals,
1498 RowVals {
1499 timestamp,
1500 proof: None,
1501 subsume: self
1502 .table_math
1503 .subsume
1504 .then_some(MergeVal::Constant(NOT_SUBSUMED)),
1505 ret_val: Some(default),
1506 },
1507 );
1508 Some(
1509 state.predict_val(self.table, key, merge_vals.iter().copied())
1510 [self.table_math.ret_val_col()],
1511 )
1512 }
1513 None => state
1514 .get_table(self.table)
1515 .get_row(key)
1516 .map(|row| row.vals[self.table_math.ret_val_col()]),
1517 }
1518 }
1519
1520 pub fn insert(&mut self, state: &mut ExecutionState, row: impl Iterator<Item = Value>) {
1522 let ts = Value::from_usize(state.read_counter(self.timestamp));
1523 self.scratch.clear();
1524 self.scratch.extend(row);
1525 self.table_math.write_table_row(
1526 &mut self.scratch,
1527 RowVals {
1528 timestamp: ts,
1529 proof: None,
1530 subsume: self.table_math.subsume.then_some(NOT_SUBSUMED),
1531 ret_val: None,
1532 },
1533 );
1534 state.stage_insert(self.table, &self.scratch);
1535 }
1536
1537 pub fn remove(&self, state: &mut ExecutionState, key: &[Value]) {
1539 state.stage_remove(self.table, key);
1540 }
1541
1542 pub fn subsume(&mut self, state: &mut ExecutionState, key: impl Iterator<Item = Value>) {
1544 let ts = Value::from_usize(state.read_counter(self.timestamp));
1545 self.scratch.clear();
1546 self.scratch.extend(key);
1547
1548 let ret_val = self
1549 .lookup(state, &self.scratch)
1550 .expect("subsume lookup failed");
1551
1552 self.table_math.write_table_row(
1553 &mut self.scratch,
1554 RowVals {
1555 timestamp: ts,
1556 proof: None,
1557 subsume: Some(SUBSUMED),
1558 ret_val: Some(ret_val),
1559 },
1560 );
1561 state.stage_insert(self.table, &self.scratch);
1562 }
1563}
1564
1565#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1567pub struct UnionAction {
1568 table: TableId,
1569 timestamp: CounterId,
1570}
1571
1572impl UnionAction {
1573 pub fn new(egraph: &EGraph) -> UnionAction {
1576 assert!(!egraph.tracing, "proofs not supported yet");
1577 UnionAction {
1578 table: egraph.uf_table,
1579 timestamp: egraph.timestamp_counter,
1580 }
1581 }
1582
1583 pub fn union(&self, state: &mut ExecutionState, x: Value, y: Value) {
1585 let ts = Value::from_usize(state.read_counter(self.timestamp));
1586 state.stage_insert(self.table, &[x, y, ts]);
1587 }
1588}
1589
1590fn run_rules_impl(
1591 db: &mut Database,
1592 rule_info: &mut DenseIdMapWithReuse<RuleId, RuleInfo>,
1593 rules: &[RuleId],
1594 next_ts: Timestamp,
1595 report_level: ReportLevel,
1596) -> Result<RuleSetReport> {
1597 for rule in rules {
1598 let info = &mut rule_info[*rule];
1599 if info.cached_plan.is_none() {
1600 info.cached_plan = Some(info.query.build_cached_plan(db, &info.desc)?);
1601 }
1602 }
1603 let mut rsb = db.new_rule_set();
1604 for rule in rules {
1605 let info = &mut rule_info[*rule];
1606 let cached_plan = info.cached_plan.as_ref().unwrap();
1607 info.query
1608 .add_rules_from_cached(&mut rsb, info.last_run_at, cached_plan)?;
1609 info.last_run_at = next_ts;
1610 }
1611 let ruleset = rsb.build();
1612 Ok(db.run_rule_set(&ruleset, report_level))
1613}
1614
/// Calls `body` and returns its result from inside a never-inlined function,
/// so the incremental-rebuild path keeps a distinct, named stack frame
/// (presumably to make it identifiable in profiles — TODO confirm).
#[inline(never)]
fn marker_incremental_rebuild<T>(body: impl FnOnce() -> T) -> T {
    body()
}
1622
/// Calls `body` and returns its result from inside a never-inlined function,
/// so the non-incremental-rebuild path keeps a distinct, named stack frame
/// (presumably to make it identifiable in profiles — TODO confirm).
#[inline(never)]
fn marker_nonincremental_rebuild<T>(body: impl FnOnce() -> T) -> T {
    body()
}
1627
/// A shared, lockable slot holding at most one value. External functions in
/// this module use it to hand a result (a captured match, a panic message)
/// back to the caller.
pub type SideChannel<T> = std::sync::Arc<std::sync::Mutex<Option<T>>>;
1631
1632#[derive(Clone)]
1638struct GetFirstMatch(SideChannel<Vec<Value>>);
1639
1640impl ExternalFunction for GetFirstMatch {
1641 fn invoke(&self, _: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
1642 let mut guard = self.0.lock().unwrap();
1643 if guard.is_some() {
1644 return None;
1645 }
1646 *guard = Some(args.to_vec());
1647 Some(Value::new(0))
1648 }
1649}
1650
1651struct LazyPanic<F>(Arc<Lazy<String, F>>, SideChannel<String>);
1662
1663impl<F: FnOnce() -> String + Send> ExternalFunction for LazyPanic<F> {
1664 fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
1665 assert!(args.is_empty());
1666 state.trigger_early_stop();
1667 let mut guard = self.1.lock().unwrap();
1668 if guard.is_none() {
1669 *guard = Some(Lazy::force(&self.0).clone());
1670 }
1671 None
1672 }
1673}
1674
1675impl<F> Clone for LazyPanic<F> {
1676 fn clone(&self) -> Self {
1677 LazyPanic(self.0.clone(), self.1.clone())
1678 }
1679}
1680
1681#[derive(Clone)]
1686struct Panic(String, SideChannel<String>);
1687
1688impl EGraph {
1689 pub fn new_panic(&mut self, message: String) -> ExternalFunctionId {
1691 *self
1692 .panic_funcs
1693 .entry(message.to_string())
1694 .or_insert_with(|| {
1695 let panic = Panic(message, self.panic_message.clone());
1696 self.db.add_external_function(Box::new(panic))
1697 })
1698 }
1699
1700 pub fn new_panic_lazy(
1701 &mut self,
1702 message: impl FnOnce() -> String + Send + 'static,
1703 ) -> ExternalFunctionId {
1704 let lazy = Lazy::new(message);
1705 let panic = LazyPanic(Arc::new(lazy), self.panic_message.clone());
1706 self.db.add_external_function(Box::new(panic))
1707 }
1708}
1709
1710impl ExternalFunction for Panic {
1711 fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
1712 assert!(args.is_empty());
1714
1715 state.trigger_early_stop();
1716 let mut guard = self.1.lock().unwrap();
1717 if guard.is_none() {
1718 *guard = Some(self.0.clone());
1719 }
1720 None
1721 }
1722}
1723
1724#[derive(Error, Debug)]
1725enum ProofReconstructionError {
1726 #[error(
1727 "attempting to explain a row without tracing enabled. Try constructing with `EGraph::with_tracing`"
1728 )]
1729 TracingNotEnabled,
1730 #[error("attempting to construct a proof that {term1} = {term2}, but they are not equal")]
1731 EqualityExplanationOfUnequalTerms { term1: String, term2: String },
1732}
1733
/// Heuristic for choosing the rebuild strategy.
///
/// Returns `true` when `uf_size` is at most `table_size / 16` (parallel) or
/// `table_size / 8` (serial), using integer division. Presumably `true` selects
/// the incremental rebuild — confirm at the call site.
fn incremental_rebuild(uf_size: usize, table_size: usize, parallel: bool) -> bool {
    let divisor = if parallel { 16 } else { 8 };
    uf_size <= table_size / divisor
}
1743
1744pub(crate) const SUBSUMED: Value = Value::new_const(1);
1745pub(crate) const NOT_SUBSUMED: Value = Value::new_const(0);
1746fn combine_subsumed(v1: Value, v2: Value) -> Value {
1747 std::cmp::max(v1, v2)
1748}
1749
/// Column-layout arithmetic for a function's backing table.
///
/// Columns are ordered as follows:
/// - the function columns (keys followed by the return value);
/// - a timestamp;
/// - a proof id, present only when `tracing`;
/// - a subsumption flag, present only when `subsume`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct SchemaMath {
    /// Whether rows carry a proof-id column.
    tracing: bool,
    /// Whether rows carry a subsumption-flag column.
    subsume: bool,
    /// Number of function columns, including the return value.
    func_cols: usize,
}
1769
/// The non-key values that `SchemaMath::write_table_row` writes into a row.
///
/// A `None` field is not written. For `proof` and `subsume`, `None` is only
/// allowed when the layout has no such column.
struct RowVals<T> {
    /// Value for the timestamp column; also used to fill any newly added slots.
    timestamp: T,
    /// Value for the proof-id column (required when tracing).
    proof: Option<T>,
    /// Value for the subsumption column (required when subsumption is enabled).
    subsume: Option<T>,
    /// Value for the return-value column, if it should be overwritten.
    ret_val: Option<T>,
}
1785
1786#[derive(Clone, Debug)]
1788pub struct FunctionRow<'a> {
1789 pub vals: &'a [Value],
1790 pub subsumed: bool,
1791}
1792
1793impl SchemaMath {
1794 fn write_table_row<T: Clone>(
1795 &self,
1796 row: &mut impl HasResizeWith<T>,
1797 RowVals {
1798 timestamp,
1799 proof,
1800 subsume,
1801 ret_val,
1802 }: RowVals<T>,
1803 ) {
1804 row.resize_with(self.table_columns(), || timestamp.clone());
1805 row[self.ts_col()] = timestamp;
1806 if let Some(ret_val) = ret_val {
1807 row[self.ret_val_col()] = ret_val;
1808 }
1809 if let Some(proof_id) = proof {
1810 row[self.proof_id_col()] = proof_id;
1811 } else {
1812 assert!(
1813 !self.tracing,
1814 "proof_id must be provided if tracing is enabled"
1815 );
1816 }
1817 if let Some(subsume) = subsume {
1818 row[self.subsume_col()] = subsume;
1819 } else {
1820 assert!(
1821 !self.subsume,
1822 "subsume flag must be provided if subsumption is enabled"
1823 );
1824 }
1825 }
1826
1827 fn num_keys(&self) -> usize {
1828 self.func_cols - 1
1829 }
1830
1831 fn table_columns(&self) -> usize {
1832 self.func_cols + 1 + if self.tracing { 1 } else { 0 } + if self.subsume { 1 } else { 0 }
1833 }
1834
1835 #[track_caller]
1836 fn proof_id_col(&self) -> usize {
1837 assert!(self.tracing);
1838 self.func_cols + 1
1839 }
1840
1841 fn ret_val_col(&self) -> usize {
1842 self.func_cols - 1
1843 }
1844
1845 fn ts_col(&self) -> usize {
1846 self.func_cols
1847 }
1848
1849 #[track_caller]
1850 fn subsume_col(&self) -> usize {
1851 assert!(self.subsume);
1852 if self.tracing {
1853 self.func_cols + 2
1854 } else {
1855 self.func_cols + 1
1856 }
1857 }
1858}
1859
1860#[derive(Error, Debug)]
1861#[error("Panic: {0}")]
1862struct PanicError(String);
1863
/// An indexable, growable row buffer that `SchemaMath::write_table_row` can
/// resize and then write into by column index.
trait HasResizeWith<T>:
    AsMut<[T]> + AsRef<[T]> + Index<usize, Output = T> + IndexMut<usize, Output = T>
{
    /// Resizes the buffer to `new_size` elements, producing any new elements
    /// by calling `f`.
    fn resize_with<F: FnMut() -> T>(&mut self, new_size: usize, f: F);
}
1873
1874impl<T> HasResizeWith<T> for Vec<T> {
1875 fn resize_with<F>(&mut self, new_size: usize, f: F)
1876 where
1877 F: FnMut() -> T,
1878 {
1879 self.resize_with(new_size, f);
1880 }
1881}
1882
1883impl<T, A: smallvec::Array<Item = T>> HasResizeWith<T> for SmallVec<A> {
1884 fn resize_with<F>(&mut self, new_size: usize, f: F)
1885 where
1886 F: FnMut() -> T,
1887 {
1888 self.resize_with(new_size, f);
1889 }
1890}