1use std::{cmp::Ordering, sync::Arc};
8
9use crate::core_relations;
10use crate::core_relations::{
11 ColumnId, Constraint, CounterId, ExternalFunctionId, PlanStrategy, QueryBuilder,
12 RuleBuilder as CoreRuleBuilder, RuleSetBuilder, TableId, Value, WriteVal,
13};
14use crate::numeric_id::{DenseIdMap, NumericId, define_id};
15use anyhow::Context;
16use hashbrown::HashSet;
17use log::debug;
18use smallvec::SmallVec;
19use thiserror::Error;
20
21use crate::syntax::SourceSyntax;
22use crate::{CachedPlanInfo, NOT_SUBSUMED, RowVals, SUBSUMED, SchemaMath};
23use crate::{
24 ColumnTy, DefaultVal, EGraph, FunctionId, Result, RuleId, RuleInfo, Timestamp,
25 proof_spec::{ProofBuilder, RebuildVars},
26};
27
// Dense numeric ids for the variables and atoms of a query under construction.
define_id!(pub VariableId, u32, "A variable in an egglog query");
define_id!(pub AtomId, u32, "an atom in an egglog query");
/// The lower-level (`core_relations`) query entry that builder-level `QueryEntry`
/// values are translated into when a rule is compiled (see `Bindings::convert`).
pub(crate) type DstVar = core_relations::QueryEntry;
31
32impl VariableId {
33 fn to_var(self) -> Variable {
34 Variable {
35 id: self,
36 name: None,
37 }
38 }
39}
40
/// A query variable: its id plus an optional human-readable name.
///
/// Equality and hashing are derived over both fields, so two `Variable`s with the
/// same id but different names compare unequal.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Variable {
    /// Index of the variable among the query's variables.
    pub id: VariableId,
    /// Name given via `RuleBuilder::new_var_named`, if any.
    pub name: Option<Box<str>>,
}
46
/// Errors reported while assembling a rule with [`RuleBuilder`].
#[derive(Debug, Error)]
enum RuleBuilderError {
    /// A value's column type did not match the type required where it is used.
    #[error("type mismatch: expected {expected:?}, got {got:?}")]
    TypeMismatch { expected: ColumnTy, got: ColumnTy },
    /// Two rows (or a row and a function schema) had different numbers of columns.
    #[error("arity mismatch: expected {expected:?}, got {got:?}")]
    ArityMismatch { expected: usize, got: usize },
}
54
/// Builder-side bookkeeping for a single query variable.
#[derive(Clone)]
struct VarInfo {
    /// Column type of the variable; consulted by `assert_has_ty` / `assert_same_ty`.
    ty: ColumnTy,
    /// Optional name, forwarded to the core query builder in `Query::query_state`.
    name: Option<Box<str>>,
    /// Variable standing for this value's term when proofs are enabled. It starts out
    /// as the variable itself and may be redirected (e.g. to an atom's proof column).
    term_var: Variable,
}
63
/// An argument appearing in a rule: either a query variable or a typed constant.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QueryEntry {
    /// A variable of the query being built.
    Var(Variable),
    /// A constant value together with its column type.
    Const {
        val: Value,
        ty: ColumnTy,
    },
}
74
75impl QueryEntry {
76 pub(crate) fn var(&self) -> Variable {
79 match self {
80 QueryEntry::Var(v) => v.clone(),
81 QueryEntry::Const { .. } => panic!("expected variable, found constant"),
82 }
83 }
84}
85
86impl From<Variable> for QueryEntry {
87 fn from(var: Variable) -> Self {
88 QueryEntry::Var(var)
89 }
90}
91
/// Something a rule can call: a table-backed function or an external (primitive) function.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Function {
    /// A function identified by a [`FunctionId`].
    Table(FunctionId),
    /// An external function identified by an [`ExternalFunctionId`].
    Prim(ExternalFunctionId),
}
97
98impl From<FunctionId> for Function {
99 fn from(f: FunctionId) -> Self {
100 Function::Table(f)
101 }
102}
103
104impl From<ExternalFunctionId> for Function {
105 fn from(f: ExternalFunctionId) -> Self {
106 Function::Prim(f)
107 }
108}
109
/// A rule-building callback. Given the bindings for the rule's variables and the
/// core rule builder, it contributes to the rule being built (lookups, actions, or
/// proof bookkeeping).
///
/// The `DynClone` bound, together with `clone_trait_object!` in this block, lets
/// boxed callbacks be cloned; `Query`'s `#[derive(Clone)]` relies on that.
trait Brc:
    Fn(&mut Bindings, &mut CoreRuleBuilder) -> Result<()> + dyn_clone::DynClone + Send + Sync
{
}
// Blanket impl: every cloneable, thread-safe closure with the right signature is a `Brc`.
impl<T: Fn(&mut Bindings, &mut CoreRuleBuilder) -> Result<()> + Clone + Send + Sync> Brc for T {}
dyn_clone::clone_trait_object!(Brc);
/// A boxed, type-erased rule-building callback.
type BuildRuleCallback = Box<dyn Brc>;
117
/// A rule's query and actions as accumulated by [`RuleBuilder`]. It is stored in
/// the e-graph's rule table and compiled into core rules via `build_cached_plan` /
/// `add_rules_from_cached`.
#[derive(Clone)]
pub(crate) struct Query {
    /// The e-graph's union-find table.
    uf_table: TableId,
    /// Counter used to mint fresh ids for `DefaultVal::FreshId` lookups.
    id_counter: CounterId,
    /// Counter read when the rule's actions are built, to obtain the write timestamp.
    ts_counter: CounterId,
    /// Whether proofs are enabled for this rule.
    tracing: bool,
    /// The id reserved for this rule in the e-graph's rule table.
    rule_id: RuleId,
    /// Every variable declared while building the rule.
    vars: DenseIdMap<VariableId, VarInfo>,
    /// Proof-column variable of each atom (only populated when tracing).
    atom_proofs: Vec<Variable>,
    /// The query's atoms: target table, full row of entries, and the row layout.
    atoms: Vec<(TableId, Vec<QueryEntry>, SchemaMath)>,
    /// With proofs enabled, computes the reason for the LHS match before the actions run.
    build_reason: Option<BuildRuleCallback>,
    /// Callbacks emitting the rule's right-hand side, run in insertion order.
    add_rule: Vec<BuildRuleCallback>,
    /// If set, only this atom receives the seminaive "recent timestamp" constraint.
    sole_focus: Option<usize>,
    /// Whether the rule is evaluated seminaively.
    seminaive: bool,
    /// Planning strategy handed to the core query builder.
    plan_strategy: PlanStrategy,
}
144
/// Incrementally builds a rule for an [`EGraph`]; created by [`EGraph::new_rule`].
pub struct RuleBuilder<'a> {
    /// The e-graph the rule is being built for.
    egraph: &'a mut EGraph,
    /// Proof-related state (term variables, reason construction) and the rule description.
    proof_builder: ProofBuilder,
    /// The query (atoms, variables, and action callbacks) accumulated so far.
    query: Query,
}
150
151impl EGraph {
152 pub fn new_rule(&mut self, desc: &str, seminaive: bool) -> RuleBuilder<'_> {
155 let uf_table = self.uf_table;
156 let id_counter = self.id_counter;
157 let ts_counter = self.timestamp_counter;
158 let tracing = self.tracing;
159 let rule_id = self.rules.reserve_slot();
160 RuleBuilder {
161 egraph: self,
162 proof_builder: ProofBuilder::new(desc, rule_id),
163 query: Query {
164 uf_table,
165 id_counter,
166 ts_counter,
167 tracing,
168 rule_id,
169 seminaive,
170 build_reason: None,
171 sole_focus: None,
172 atom_proofs: Default::default(),
173 vars: Default::default(),
174 atoms: Default::default(),
175 add_rule: Default::default(),
176 plan_strategy: Default::default(),
177 },
178 }
179 }
180
181 pub fn free_rule(&mut self, id: RuleId) {
183 self.rules.take(id);
184 }
185}
186
187impl RuleBuilder<'_> {
188 fn add_callback(&mut self, cb: impl Brc + 'static) {
189 self.query.add_rule.push(Box::new(cb));
190 }
191
192 pub fn egraph(&self) -> &EGraph {
194 self.egraph
195 }
196
197 pub(crate) fn set_plan_strategy(&mut self, strategy: PlanStrategy) {
198 self.query.plan_strategy = strategy;
199 }
200
201 pub(crate) fn lookup_uf(&mut self, entry: QueryEntry) -> Result<Variable> {
207 let res = self.new_var(ColumnTy::Id);
208 let uf_table = self.query.uf_table;
209 self.assert_has_ty(&entry, ColumnTy::Id)
210 .context("lookup_uf: ")?;
211 self.add_callback(move |inner, rb| {
212 let entry = inner.convert(&entry);
213 let res_inner = rb.lookup_with_default(uf_table, &[entry], entry, ColumnId::new(1))?;
214 inner.mapping.insert(res.id, res_inner.into());
215 Ok(())
216 });
217 Ok(res)
218 }
219
220 pub(crate) fn check_for_update(
226 &mut self,
227 lhs: &[QueryEntry],
228 rhs: &[QueryEntry],
229 ) -> Result<()> {
230 let lhs = SmallVec::<[QueryEntry; 4]>::from_iter(lhs.iter().cloned());
231 let rhs = SmallVec::<[QueryEntry; 4]>::from_iter(rhs.iter().cloned());
232 if lhs.len() != rhs.len() {
233 return Err(RuleBuilderError::ArityMismatch {
234 expected: lhs.len(),
235 got: rhs.len(),
236 }
237 .into());
238 }
239 lhs.iter().zip(rhs.iter()).try_for_each(|(l, r)| {
240 self.assert_same_ty(l, r).with_context(|| {
241 format!("check_for_update: {lhs:?} and {rhs:?}, mismatch between {l:?} and {r:?}")
242 })
243 })?;
244
245 self.add_callback(move |inner, rb| {
246 let lhs = inner.convert_all(&lhs);
247 let rhs = inner.convert_all(&rhs);
248 rb.assert_any_ne(&lhs, &rhs).context("check_for_update")
249 });
250 Ok(())
251 }
252
253 fn assert_same_ty(
254 &self,
255 l: &QueryEntry,
256 r: &QueryEntry,
257 ) -> std::result::Result<(), RuleBuilderError> {
258 match (l, r) {
259 (
260 QueryEntry::Var(Variable { id: v1, .. }),
261 QueryEntry::Var(Variable { id: v2, .. }),
262 ) => {
263 let ty1 = self.query.vars[*v1].ty;
264 let ty2 = self.query.vars[*v2].ty;
265 if ty1 != ty2 {
266 return Err(RuleBuilderError::TypeMismatch {
267 expected: ty1,
268 got: ty2,
269 });
270 }
271 }
272 (QueryEntry::Const { .. }, QueryEntry::Const { .. })
274 | (QueryEntry::Var { .. }, QueryEntry::Const { .. })
275 | (QueryEntry::Const { .. }, QueryEntry::Var { .. }) => {}
276 }
277 Ok(())
278 }
279
280 fn assert_has_ty(
281 &self,
282 entry: &QueryEntry,
283 ty: ColumnTy,
284 ) -> std::result::Result<(), RuleBuilderError> {
285 if let QueryEntry::Var(Variable { id: v, .. }) = entry {
286 let var_ty = self.query.vars[*v].ty;
287 if var_ty != ty {
288 return Err(RuleBuilderError::TypeMismatch {
289 expected: var_ty,
290 got: ty,
291 });
292 }
293 }
294 Ok(())
295 }
296
297 pub fn build(self) -> RuleId {
299 assert!(
300 !self.egraph.tracing,
301 "proofs are enabled: use `build_with_syntax` instead"
302 );
303 self.build_internal(None)
304 }
305
306 pub fn build_with_syntax(self, syntax: SourceSyntax) -> RuleId {
307 self.build_internal(Some(syntax))
308 }
309
310 pub(crate) fn build_internal(mut self, syntax: Option<SourceSyntax>) -> RuleId {
311 if self.query.atoms.len() == 1 {
312 self.query.plan_strategy = PlanStrategy::MinCover;
313 }
314 if let Some(syntax) = &syntax {
315 if self.egraph.tracing {
316 let cb = self
317 .proof_builder
318 .create_reason(syntax.clone(), self.egraph);
319 self.query.build_reason = Some(Box::new(move |bndgs, rb| {
320 let reason = cb(bndgs, rb)?;
321 bndgs.lhs_reason = Some(reason.into());
322 Ok(())
323 }));
324 }
325 }
326 let res = self.query.rule_id;
327 let info = RuleInfo {
328 last_run_at: Timestamp::new(0),
329 query: self.query,
330 cached_plan: None,
331 desc: self.proof_builder.rule_description,
332 };
333 debug!("created rule {res:?} / {}", info.desc);
334 self.egraph.rules.insert(res, info);
335 res
336 }
337
338 pub(crate) fn set_focus(&mut self, focus: usize) {
339 self.query.sole_focus = Some(focus);
340 }
341
342 pub fn new_var(&mut self, ty: ColumnTy) -> Variable {
344 let res = self.query.vars.next_id();
345 let var = Variable {
346 id: res,
347 name: None,
348 };
349 self.query.vars.push(VarInfo {
350 ty,
351 name: None,
352 term_var: var.clone(),
353 });
354 var
355 }
356
357 pub fn new_var_named(&mut self, ty: ColumnTy, name: &str) -> QueryEntry {
362 let id = self.query.vars.next_id();
363 let var = Variable {
364 id,
365 name: Some(name.into()),
366 };
367 self.query.vars.push(VarInfo {
368 ty,
369 name: Some(name.into()),
370 term_var: var.clone(),
371 });
372 QueryEntry::Var(var)
373 }
374
375 pub(crate) fn add_atom_with_timestamp_and_func(
383 &mut self,
384 table: TableId,
385 func: Option<FunctionId>,
386 subsume_entry: Option<QueryEntry>,
387 entries: &[QueryEntry],
388 ) -> AtomId {
389 let mut atom = entries.to_vec();
390 let schema_math = if let Some(func) = func {
391 let info = &self.egraph.funcs[func];
392 assert_eq!(info.schema.len(), entries.len());
393 SchemaMath {
394 tracing: self.egraph.tracing,
395 subsume: info.can_subsume,
396 func_cols: info.schema.len(),
397 }
398 } else {
399 SchemaMath {
400 tracing: self.egraph.tracing,
401 subsume: subsume_entry.is_some(),
402 func_cols: entries.len(),
403 }
404 };
405 schema_math.write_table_row(
406 &mut atom,
407 RowVals {
408 timestamp: self.new_var(ColumnTy::Id).into(),
409 proof: self
410 .egraph
411 .tracing
412 .then(|| self.new_var(ColumnTy::Id).into()),
413 subsume: if schema_math.subsume {
414 Some(subsume_entry.unwrap_or_else(|| self.new_var(ColumnTy::Id).into()))
415 } else {
416 None
417 },
418 ret_val: None,
419 },
420 );
421 let res = AtomId::from_usize(self.query.atoms.len());
422 if self.egraph.tracing {
423 let proof_var = atom[schema_math.proof_id_col()].var();
424 self.proof_builder
425 .term_vars
426 .insert(res, proof_var.clone().into());
427 if let Some(QueryEntry::Var(Variable { id, .. })) = entries.last() {
428 if table != self.egraph.uf_table {
429 self.query.vars[*id].term_var = proof_var.clone();
432 }
433 }
434 self.query.atom_proofs.push(proof_var);
435 }
436 self.query.atoms.push((table, atom, schema_math));
437 res
438 }
439
440 pub fn call_external_func(
441 &mut self,
442 func: ExternalFunctionId,
443 args: &[QueryEntry],
444 ret_ty: ColumnTy,
445 panic_msg: impl FnOnce() -> String + 'static + Send,
446 ) -> Variable {
447 let args = args.to_vec();
448 let res = self.new_var(ret_ty);
449 let panic_fn = self.egraph.new_panic_lazy(panic_msg);
451 self.query.add_rule.push(Box::new(move |inner, rb| {
452 let args = inner.convert_all(&args);
453 let var = rb.call_external_with_fallback(func, &args, panic_fn, &[])?;
454 inner.mapping.insert(res.id, var.into());
455 Ok(())
456 }));
457 res
458 }
459
460 pub fn query_table(
464 &mut self,
465 func: FunctionId,
466 entries: &[QueryEntry],
467 is_subsumed: Option<bool>,
468 ) -> Result<AtomId> {
469 let info = &self.egraph.funcs[func];
470 let schema = &info.schema;
471 if schema.len() != entries.len() {
472 return Err(anyhow::Error::from(RuleBuilderError::ArityMismatch {
473 expected: schema.len(),
474 got: entries.len(),
475 }))
476 .with_context(|| format!("query_table: mismatch between {entries:?} and {schema:?}"));
477 }
478 entries
479 .iter()
480 .zip(schema.iter())
481 .try_for_each(|(entry, ty)| {
482 self.assert_has_ty(entry, *ty)
483 .with_context(|| format!("query_table: mismatch between {entry:?} and {ty:?}"))
484 })?;
485 Ok(self.add_atom_with_timestamp_and_func(
486 info.table,
487 Some(func),
488 is_subsumed.map(|b| QueryEntry::Const {
489 val: match b {
490 true => SUBSUMED,
491 false => NOT_SUBSUMED,
492 },
493 ty: ColumnTy::Id,
494 }),
495 entries,
496 ))
497 }
498
499 pub fn query_prim(
502 &mut self,
503 func: ExternalFunctionId,
504 entries: &[QueryEntry],
505 _ret_ty: ColumnTy,
507 ) -> Result<()> {
508 let entries = entries.to_vec();
509 self.query.add_rule.push(Box::new(move |inner, rb| {
510 let mut dst_vars = inner.convert_all(&entries);
511 let expected = dst_vars.pop().expect("must specify a return value");
512 let var = rb.call_external(func, &dst_vars)?;
513 match entries.last().unwrap() {
514 QueryEntry::Var(Variable { id, .. }) if !inner.grounded.contains(id) => {
515 inner.mapping.insert(*id, var.into());
516 inner.grounded.insert(*id);
517 }
518 _ => rb.assert_eq(var.into(), expected),
519 }
520 Ok(())
521 }));
522 Ok(())
523 }
524
525 pub fn subsume(&mut self, func: FunctionId, entries: &[QueryEntry]) {
529 let ret = self.lookup_with_subsumed(
531 func,
532 entries,
533 QueryEntry::Const {
534 val: SUBSUMED,
535 ty: ColumnTy::Id,
536 },
537 || "subsumed a nonextestent row!".to_string(),
538 );
539 let info = &self.egraph.funcs[func];
540 let schema_math = SchemaMath {
541 tracing: self.egraph.tracing,
542 subsume: info.can_subsume,
543 func_cols: info.schema.len(),
544 };
545 assert!(info.can_subsume);
546 assert_eq!(entries.len() + 1, info.schema.len());
547 let entries = entries.to_vec();
548 let table = info.table;
549
550 let ret: QueryEntry = ret.into();
551 self.add_callback(move |inner, rb| {
552 let mut dst_entries = inner.convert_all(&entries);
555 let cur_subsume_val = rb.lookup(
556 table,
557 &dst_entries,
558 ColumnId::from_usize(schema_math.subsume_col()),
559 )?;
560 let cur_proof_val = if schema_math.tracing {
561 Some(DstVar::from(rb.lookup(
562 table,
563 &dst_entries,
564 ColumnId::from_usize(schema_math.proof_id_col()),
565 )?))
566 } else {
567 None
568 };
569 schema_math.write_table_row(
570 &mut dst_entries,
571 RowVals {
572 timestamp: inner.next_ts(),
573 proof: cur_proof_val,
574 subsume: Some(SUBSUMED.into()),
575 ret_val: Some(inner.convert(&ret)),
576 },
577 );
578 rb.insert_if_eq(
579 table,
580 cur_subsume_val.into(),
581 NOT_SUBSUMED.into(),
582 &dst_entries,
583 )?;
584 Ok(())
585 });
586 }
587
588 pub(crate) fn lookup_with_subsumed(
589 &mut self,
590 func: FunctionId,
591 entries: &[QueryEntry],
592 subsumed: QueryEntry,
593 panic_msg: impl FnOnce() -> String + Send + 'static,
594 ) -> Variable {
595 let entries = entries.to_vec();
596 let info = &self.egraph.funcs[func];
597 let res = self
598 .query
599 .vars
600 .push(VarInfo {
601 ty: info.ret_ty(),
602 name: None,
603 term_var: self.query.vars.next_id().to_var(),
604 })
605 .to_var();
606 let table = info.table;
607 let id_counter = self.query.id_counter;
608 let schema_math = SchemaMath {
609 tracing: self.egraph.tracing,
610 subsume: info.can_subsume,
611 func_cols: info.schema.len(),
612 };
613 let cb: BuildRuleCallback = match info.default_val {
614 DefaultVal::Const(_) | DefaultVal::FreshId => {
615 let (wv, wv_ref): (WriteVal, WriteVal) = match &info.default_val {
616 DefaultVal::Const(c) => ((*c).into(), (*c).into()),
617 DefaultVal::FreshId => (
618 WriteVal::IncCounter(id_counter),
619 WriteVal::CurrentVal(schema_math.ret_val_col()),
623 ),
624 _ => unreachable!(),
625 };
626 let get_write_vals = move |inner: &mut Bindings| {
627 let mut write_vals = SmallVec::<[WriteVal; 4]>::new();
628 for i in schema_math.num_keys()..schema_math.table_columns() {
629 if i == schema_math.ts_col() {
630 write_vals.push(inner.next_ts().into());
631 } else if i == schema_math.ret_val_col() {
632 write_vals.push(wv);
633 } else if schema_math.tracing && i == schema_math.proof_id_col() {
634 write_vals.push(wv_ref);
635 } else if schema_math.subsume && i == schema_math.subsume_col() {
636 write_vals.push(inner.convert(&subsumed).into())
637 } else {
638 unreachable!()
639 }
640 }
641 write_vals
642 };
643
644 if self.egraph.tracing {
645 let term_var = self.new_var(ColumnTy::Id);
646 self.query.vars[res.id].term_var = term_var.clone();
647 let ts_var = self.new_var(ColumnTy::Id);
648 let mut insert_entries = entries.to_vec();
649 insert_entries.push(res.clone().into());
650 let add_proof =
651 self.proof_builder
652 .new_row(func, insert_entries, term_var.id, self.egraph);
653 Box::new(move |inner, rb| {
654 let write_vals = get_write_vals(inner);
655 let dst_vars = inner.convert_all(&entries);
656 let var = rb.lookup_or_insert(
665 table,
666 &dst_vars,
667 &write_vals,
668 ColumnId::from_usize(schema_math.ret_val_col()),
669 )?;
670 let ts = rb.lookup_or_insert(
671 table,
672 &dst_vars,
673 &write_vals,
674 ColumnId::from_usize(schema_math.ts_col()),
675 )?;
676 let term = rb.lookup_or_insert(
677 table,
678 &dst_vars,
679 &write_vals,
680 ColumnId::from_usize(schema_math.proof_id_col()),
681 )?;
682 inner.mapping.insert(term_var.id, term.into());
683 inner.mapping.insert(res.id, var.into());
684 inner.mapping.insert(ts_var.id, ts.into());
685 rb.assert_eq(var.into(), term.into());
686 add_proof(inner, rb)?;
690 Ok(())
691 })
692 } else {
693 Box::new(move |inner, rb| {
694 let write_vals = get_write_vals(inner);
695 let dst_vars = inner.convert_all(&entries);
696 let var = rb.lookup_or_insert(
697 table,
698 &dst_vars,
699 &write_vals,
700 ColumnId::from_usize(schema_math.ret_val_col()),
701 )?;
702 inner.mapping.insert(res.id, var.into());
703 Ok(())
704 })
705 }
706 }
707 DefaultVal::Fail => {
708 let panic_func = self.egraph.new_panic_lazy(panic_msg);
709 if self.egraph.tracing {
710 let term_var = self.new_var(ColumnTy::Id);
711 Box::new(move |inner, rb| {
712 let dst_vars = inner.convert_all(&entries);
713 let var = rb.lookup_with_fallback(
714 table,
715 &dst_vars,
716 ColumnId::from_usize(schema_math.ret_val_col()),
717 panic_func,
718 &[],
719 )?;
720 let term = rb.lookup(
721 table,
722 &dst_vars,
723 ColumnId::from_usize(schema_math.proof_id_col()),
724 )?;
725 inner.mapping.insert(res.id, var.into());
726 inner.mapping.insert(term_var.id, term.into());
727 Ok(())
728 })
729 } else {
730 Box::new(move |inner, rb| {
731 let dst_vars = inner.convert_all(&entries);
732 let var = rb.lookup_with_fallback(
733 table,
734 &dst_vars,
735 ColumnId::from_usize(schema_math.ret_val_col()),
736 panic_func,
737 &[],
738 )?;
739 inner.mapping.insert(res.id, var.into());
740 Ok(())
741 })
742 }
743 }
744 };
745 self.query.add_rule.push(cb);
746 res
747 }
748
749 pub fn lookup(
755 &mut self,
756 func: FunctionId,
757 entries: &[QueryEntry],
758 panic_msg: impl FnOnce() -> String + Send + 'static,
759 ) -> Variable {
760 self.lookup_with_subsumed(
761 func,
762 entries,
763 QueryEntry::Const {
764 val: NOT_SUBSUMED,
765 ty: ColumnTy::Id,
766 },
767 panic_msg,
768 )
769 }
770
771 pub fn union(&mut self, mut l: QueryEntry, mut r: QueryEntry) {
773 let cb: BuildRuleCallback = if self.query.tracing {
774 for entry in [&mut l, &mut r] {
777 if let QueryEntry::Var(v) = entry {
778 *v = self.query.vars[v.id].term_var.clone();
779 }
780 }
781 Box::new(move |inner, rb| {
782 let l = inner.convert(&l);
783 let r = inner.convert(&r);
784 let proof = inner.lhs_reason.expect("reason must be set");
785 rb.insert(inner.uf_table, &[l, r, inner.next_ts(), proof])
786 .context("union")
787 })
788 } else {
789 Box::new(move |inner, rb| {
790 let l = inner.convert(&l);
791 let r = inner.convert(&r);
792 rb.insert(inner.uf_table, &[l, r, inner.next_ts()])
793 .context("union")
794 })
795 };
796 self.query.add_rule.push(cb);
797 }
798
799 pub(crate) fn rebuild_row(
808 &mut self,
809 func: FunctionId,
810 before: &[QueryEntry],
811 after: &[QueryEntry],
812 subsume_var: Option<Variable>,
815 ) {
816 assert_eq!(before.len(), after.len());
817 self.remove(func, &before[..before.len() - 1]);
818 if !self.egraph.tracing {
819 if let Some(subsume_var) = subsume_var {
820 self.set_with_subsume(func, after, QueryEntry::Var(subsume_var));
821 } else {
822 self.set(func, after);
823 }
824 return;
825 }
826 let table = self.egraph.funcs[func].table;
827 let term_var = self.new_var(ColumnTy::Id);
828 let reason_var = self.new_var(ColumnTy::Id);
829 let before_id = before.last().unwrap().var();
830 let before_term = &self.query.vars[before_id.id].term_var;
831
832 let before_term_id = before_term.id;
833 let term_var_id = term_var.id;
834 let reason_var_id = reason_var.id;
835
836 debug_assert_ne!(before_term, &before_id);
837 let add_proof = self.proof_builder.rebuild_proof(
838 func,
839 after,
840 RebuildVars {
841 before_term: before_term.clone(),
842 new_term: term_var,
843 reason: reason_var,
844 },
845 self.egraph,
846 );
847 let after = SmallVec::<[_; 4]>::from_iter(after.iter().cloned());
848 let uf_table = self.query.uf_table;
849 let info = &self.egraph.funcs[func];
850 let schema_math = SchemaMath {
851 tracing: self.egraph.tracing,
852 subsume: info.can_subsume,
853 func_cols: info.schema.len(),
854 };
855
856 self.query.add_rule.push(Box::new(move |inner, rb| {
857 add_proof(inner, rb)?;
858 let mut dst_vars = inner.convert_all(&after);
859 schema_math.write_table_row(
860 &mut dst_vars,
861 RowVals {
862 timestamp: inner.next_ts(),
863 proof: Some(inner.mapping[term_var_id]),
864 subsume: subsume_var.clone().map(|v| inner.mapping[v.id]),
865 ret_val: None, },
867 );
868 rb.insert(
871 uf_table,
872 &[
873 inner.mapping[before_term_id],
874 inner.mapping[term_var_id],
875 inner.next_ts(),
876 inner.mapping[reason_var_id],
877 ],
878 )
879 .context("rebuild_row_uf")?;
880 rb.insert(table, &dst_vars).context("rebuild_row_table")
881 }));
882 }
883
884 pub fn set(&mut self, func: FunctionId, entries: &[QueryEntry]) {
886 self.set_with_subsume(
887 func,
888 entries,
889 QueryEntry::Const {
890 val: NOT_SUBSUMED,
891 ty: ColumnTy::Id,
892 },
893 );
894 }
895
896 pub(crate) fn set_with_subsume(
897 &mut self,
898 func: FunctionId,
899 entries: &[QueryEntry],
900 subsume_entry: QueryEntry,
901 ) {
902 let info = &self.egraph.funcs[func];
903 let table = info.table;
904 let entries = entries.to_vec();
905 let schema_math = SchemaMath {
906 tracing: self.egraph.tracing,
907 subsume: info.can_subsume,
908 func_cols: info.schema.len(),
909 };
910 if self.egraph.tracing {
911 let res = self.lookup(func, &entries[0..entries.len() - 1], || {
912 "lookup failed during proof-enabled set; this is an internal proofs bug".to_string()
913 });
914 let res_entry = res.clone().into();
915 self.union(res.into(), entries.last().unwrap().clone());
916 if schema_math.subsume {
917 self.add_callback(move |inner, rb| {
919 let mut dst_vars = inner.convert_all(&entries);
920 let proof_var = rb.lookup(
921 table,
922 &dst_vars[0..schema_math.num_keys()],
923 ColumnId::from_usize(schema_math.proof_id_col()),
924 )?;
925 schema_math.write_table_row(
926 &mut dst_vars,
927 RowVals {
928 timestamp: inner.next_ts(),
929 proof: Some(proof_var.into()),
930 subsume: Some(inner.convert(&subsume_entry)),
931 ret_val: Some(inner.convert(&res_entry)),
932 },
933 );
934 rb.insert(table, &dst_vars).context("set")
935 });
936 }
937 } else {
938 self.query.add_rule.push(Box::new(move |inner, rb| {
939 let mut dst_vars = inner.convert_all(&entries);
940 schema_math.write_table_row(
941 &mut dst_vars,
942 RowVals {
943 timestamp: inner.next_ts(),
944 proof: None, subsume: schema_math.subsume.then(|| inner.convert(&subsume_entry)),
946 ret_val: None, },
948 );
949 rb.insert(table, &dst_vars).context("set")
950 }));
951 };
952 }
953
954 pub fn remove(&mut self, table: FunctionId, entries: &[QueryEntry]) {
956 let table = self.egraph.funcs[table].table;
957 let entries = entries.to_vec();
958 let cb: BuildRuleCallback = Box::new(move |inner, rb| {
959 let dst_vars = inner.convert_all(&entries);
960 rb.remove(table, &dst_vars).context("remove")
961 });
962 self.query.add_rule.push(cb);
963 }
964
965 pub fn panic(&mut self, message: String) {
967 let panic = self.egraph.new_panic(message.clone());
968 let ret_ty = ColumnTy::Id;
969 let res = self.new_var(ret_ty);
970 self.query.add_rule.push(Box::new(move |inner, rb| {
971 let var = rb.call_external(panic, &[])?;
972 inner.mapping.insert(res.id, var.into());
973 Ok(())
974 }));
975 }
976}
977
978impl Query {
979 fn query_state<'a, 'outer>(
980 &self,
981 rsb: &'a mut RuleSetBuilder<'outer>,
982 ) -> (QueryBuilder<'outer, 'a>, Bindings) {
983 let mut qb = rsb.new_rule();
984 qb.set_plan_strategy(self.plan_strategy);
985 let mut inner = Bindings {
986 uf_table: self.uf_table,
987 next_ts: None,
988 lhs_reason: None,
989 mapping: Default::default(),
990 grounded: Default::default(),
991 };
992 for (var, info) in self.vars.iter() {
993 let new_var = match info.name.as_ref() {
994 Some(name) => qb.new_var_named(name),
995 None => qb.new_var(),
996 };
997 inner.mapping.insert(var, DstVar::Var(new_var));
998 }
999 (qb, inner)
1000 }
1001
1002 fn run_rules_and_build(
1003 &self,
1004 qb: QueryBuilder,
1005 mut inner: Bindings,
1006 desc: &str,
1007 ) -> Result<core_relations::RuleId> {
1008 let mut rb = qb.build();
1009 inner.next_ts = Some(rb.read_counter(self.ts_counter).into());
1010 if let Some(build_reason) = &self.build_reason {
1012 build_reason(&mut inner, &mut rb)?;
1013 }
1014 self.add_rule
1015 .iter()
1016 .try_for_each(|f| f(&mut inner, &mut rb))?;
1017 Ok(rb.build_with_description(desc))
1018 }
1019
1020 pub(crate) fn build_cached_plan(
1021 &self,
1022 db: &mut core_relations::Database,
1023 desc: &str,
1024 ) -> Result<CachedPlanInfo> {
1025 let mut rsb = RuleSetBuilder::new(db);
1026 let (mut qb, mut inner) = self.query_state(&mut rsb);
1027 let mut atom_mapping = Vec::with_capacity(self.atoms.len());
1028 for (table, entries, _schema_info) in &self.atoms {
1029 atom_mapping.push(add_atom(&mut qb, *table, entries, &[], &mut inner)?);
1030 }
1031 let rule_id = self.run_rules_and_build(qb, inner, desc)?;
1032 let rs = rsb.build();
1033 let plan = Arc::new(rs.build_cached_plan(rule_id));
1034 Ok(CachedPlanInfo { plan, atom_mapping })
1035 }
1036
1037 pub(crate) fn add_rules_from_cached(
1044 &self,
1045 rsb: &mut RuleSetBuilder,
1046 mid_ts: Timestamp,
1047 cached_plan: &CachedPlanInfo,
1048 ) -> Result<()> {
1049 if !self.seminaive || (self.atoms.is_empty() && mid_ts == Timestamp::new(0)) {
1052 rsb.add_rule_from_cached_plan(&cached_plan.plan, &[]);
1053 return Ok(());
1054 }
1055 if let Some(focus_atom) = self.sole_focus {
1056 let (_, _, schema_info) = &self.atoms[focus_atom];
1058 let ts_col = ColumnId::from_usize(schema_info.ts_col());
1059 rsb.add_rule_from_cached_plan(
1060 &cached_plan.plan,
1061 &[(
1062 cached_plan.atom_mapping[focus_atom],
1063 Constraint::GeConst {
1064 col: ts_col,
1065 val: mid_ts.to_value(),
1066 },
1067 )],
1068 );
1069 return Ok(());
1070 }
1071 let mut constraints: Vec<(core_relations::AtomId, Constraint)> =
1073 Vec::with_capacity(self.atoms.len());
1074 'outer: for focus_atom in 0..self.atoms.len() {
1075 for (i, (_, _, schema_info)) in self.atoms.iter().enumerate() {
1076 let ts_col = ColumnId::from_usize(schema_info.ts_col());
1077 match i.cmp(&focus_atom) {
1078 Ordering::Less => {
1079 if mid_ts == Timestamp::new(0) {
1080 continue 'outer;
1081 }
1082 constraints.push((
1083 cached_plan.atom_mapping[i],
1084 Constraint::LtConst {
1085 col: ts_col,
1086 val: mid_ts.to_value(),
1087 },
1088 ));
1089 }
1090 Ordering::Equal => constraints.push((
1091 cached_plan.atom_mapping[i],
1092 Constraint::GeConst {
1093 col: ts_col,
1094 val: mid_ts.to_value(),
1095 },
1096 )),
1097 Ordering::Greater => {}
1098 };
1099 }
1100 rsb.add_rule_from_cached_plan(&cached_plan.plan, &constraints);
1101 constraints.clear();
1102 }
1103 Ok(())
1104 }
1105}
1106
/// Per-compilation state mapping builder variables to core query entries.
pub(crate) struct Bindings {
    /// The union-find table; read by the callbacks created in `RuleBuilder::union`.
    uf_table: TableId,
    /// Timestamp for rows written by the rule's actions; set only once the
    /// right-hand side is being built.
    next_ts: Option<DstVar>,
    /// Reason (proof) for the LHS match; set by the `build_reason` callback when
    /// proofs are enabled.
    pub(crate) lhs_reason: Option<DstVar>,
    /// Core counterpart of every builder variable.
    pub(crate) mapping: DenseIdMap<VariableId, DstVar>,
    /// Variables already bound by a query atom or by a `query_prim` result.
    grounded: HashSet<VariableId>,
}
1118
1119impl Bindings {
1120 pub(crate) fn next_ts(&self) -> DstVar {
1121 self.next_ts
1122 .expect("ts_var should only be used in RHS of the rule")
1123 }
1124 pub(crate) fn convert(&self, entry: &QueryEntry) -> DstVar {
1125 match entry {
1126 QueryEntry::Var(Variable { id: v, .. }) => self.mapping[*v],
1127 QueryEntry::Const { val, .. } => DstVar::Const(*val),
1128 }
1129 }
1130 pub(crate) fn convert_all(&self, entries: &[QueryEntry]) -> SmallVec<[DstVar; 4]> {
1131 entries.iter().map(|e| self.convert(e)).collect()
1132 }
1133}
1134
1135fn add_atom(
1136 qb: &mut QueryBuilder,
1137 table: TableId,
1138 entries: &[QueryEntry],
1139 constraints: &[Constraint],
1140 inner: &mut Bindings,
1141) -> Result<core_relations::AtomId> {
1142 for entry in entries {
1143 if let QueryEntry::Var(Variable { id, .. }) = entry {
1144 inner.grounded.insert(*id);
1145 }
1146 }
1147 let vars = inner.convert_all(entries);
1148 Ok(qb.add_atom(table, &vars, constraints)?)
1149}