1use std::{cmp::Ordering, sync::Arc};
8
9use crate::core_relations;
10use crate::core_relations::{
11 ColumnId, Constraint, CounterId, ExternalFunctionId, PlanStrategy, QueryBuilder,
12 RuleBuilder as CoreRuleBuilder, RuleSetBuilder, TableId, Value, WriteVal,
13};
14use crate::numeric_id::{DenseIdMap, NumericId, define_id};
15use anyhow::Context;
16use hashbrown::HashSet;
17use log::debug;
18use smallvec::SmallVec;
19use thiserror::Error;
20
21use crate::syntax::SourceSyntax;
22use crate::{CachedPlanInfo, NOT_SUBSUMED, RowVals, SUBSUMED, SchemaMath};
23use crate::{
24 ColumnTy, DefaultVal, EGraph, FunctionId, Result, RuleId, RuleInfo, Timestamp,
25 proof_spec::{ProofBuilder, RebuildVars},
26};
27
// Identifier for a variable bound within a single rule under construction.
define_id!(pub Variable, u32, "A variable in an egglog query");
// Identifier for an atom (table query) added to a rule; ids are dense and assigned from
// the number of atoms already added.
define_id!(pub AtomId, u32, "an atom in an egglog query");
// The lowered form of a `QueryEntry`, used when emitting rules into `core_relations`.
pub(crate) type DstVar = core_relations::QueryEntry;
31
32#[derive(Debug, Error)]
33enum RuleBuilderError {
34 #[error("type mismatch: expected {expected:?}, got {got:?}")]
35 TypeMismatch { expected: ColumnTy, got: ColumnTy },
36 #[error("arity mismatch: expected {expected:?}, got {got:?}")]
37 ArityMismatch { expected: usize, got: usize },
38}
39
40#[derive(Clone)]
41struct VarInfo {
42 ty: ColumnTy,
43 term_var: Variable,
46}
47
/// An argument appearing in a rule: either a rule variable or a typed constant.
///
/// Note: the derived `PartialEq`/`Hash` include a variable's optional `name`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QueryEntry {
    /// A reference to a rule variable.
    Var {
        id: Variable,
        /// Optional human-readable name; it is ignored when the entry is lowered to a
        /// `DstVar`.
        name: Option<Box<str>>,
    },
    /// A constant value together with its column type.
    Const {
        val: Value,
        ty: ColumnTy,
    },
}
61
62impl QueryEntry {
63 pub(crate) fn var(&self) -> Variable {
66 match self {
67 QueryEntry::Var { id, .. } => *id,
68 QueryEntry::Const { .. } => panic!("expected variable, found constant"),
69 }
70 }
71}
72
73impl From<Variable> for QueryEntry {
74 fn from(id: Variable) -> Self {
75 QueryEntry::Var { id, name: None }
76 }
77}
78
/// Something callable from a rule: either a table-backed function or an external
/// primitive.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Function {
    /// A function stored in an e-graph table.
    Table(FunctionId),
    /// An external (primitive) function.
    Prim(ExternalFunctionId),
}
84
85impl From<FunctionId> for Function {
86 fn from(f: FunctionId) -> Self {
87 Function::Table(f)
88 }
89}
90
91impl From<ExternalFunctionId> for Function {
92 fn from(f: ExternalFunctionId) -> Self {
93 Function::Prim(f)
94 }
95}
96
97trait Brc:
98 Fn(&mut Bindings, &mut CoreRuleBuilder) -> Result<()> + dyn_clone::DynClone + Send + Sync
99{
100}
101impl<T: Fn(&mut Bindings, &mut CoreRuleBuilder) -> Result<()> + Clone + Send + Sync> Brc for T {}
102dyn_clone::clone_trait_object!(Brc);
103type BuildRuleCallback = Box<dyn Brc>;
104
/// A rule under construction, kept in a form that can be lowered into `core_relations`
/// rules repeatedly (see `build_cached_plan` and `add_rules_from_cached`). It is `Clone`
/// because every stored callback is a cloneable `BuildRuleCallback`.
#[derive(Clone)]
pub(crate) struct Query {
    /// The union-find table; `union` and proof-enabled rebuilds insert rows into it.
    uf_table: TableId,
    /// Counter supplying fresh ids for lookups whose default is `DefaultVal::FreshId`.
    id_counter: CounterId,
    /// Counter read when the rule's actions start, to get the timestamp for new rows.
    ts_counter: CounterId,
    /// Whether proofs are enabled (copied from the e-graph when the rule is created).
    tracing: bool,
    /// The id reserved for this rule in the e-graph's rule table.
    rule_id: RuleId,
    /// Type and term-variable information for every variable in the rule.
    vars: DenseIdMap<Variable, VarInfo>,
    /// Proof variable of each atom, in atom order (only populated when proofs are enabled).
    atom_proofs: Vec<Variable>,
    /// The query atoms: target table, full row of entries (including the timestamp and any
    /// proof/subsumption columns), and the table's column layout.
    atoms: Vec<(TableId, Vec<QueryEntry>, SchemaMath)>,
    /// Optional callback run before `add_rule`; it sets `Bindings::lhs_reason` when proofs
    /// are enabled.
    build_reason: Option<BuildRuleCallback>,
    /// Callbacks that emit the rule body after the atoms are added, run in insertion order.
    add_rule: Vec<BuildRuleCallback>,
    /// If set, semi-naive evaluation restricts only this atom (by index) to new rows instead
    /// of generating one variant per atom.
    sole_focus: Option<usize>,
    /// Whether to use semi-naive evaluation; if false the rule always runs over all rows.
    seminaive: bool,
    /// Join-planning strategy handed to the query builder.
    plan_strategy: PlanStrategy,
}
131
/// Incrementally builds a rule against an `EGraph`.
///
/// Obtained from `EGraph::new_rule`. It is finished with `build` or `build_with_syntax`,
/// which store the rule in the e-graph and return its id.
pub struct RuleBuilder<'a> {
    /// The e-graph the rule is built against and eventually registered in.
    egraph: &'a mut EGraph,
    /// Proof-related state: the rule description, per-atom term variables, and proof
    /// callback factories.
    proof_builder: ProofBuilder,
    /// The rule being assembled.
    query: Query,
}
137
138impl EGraph {
139 pub fn new_rule(&mut self, desc: &str, seminaive: bool) -> RuleBuilder<'_> {
142 let uf_table = self.uf_table;
143 let id_counter = self.id_counter;
144 let ts_counter = self.timestamp_counter;
145 let tracing = self.tracing;
146 let rule_id = self.rules.reserve_slot();
147 RuleBuilder {
148 egraph: self,
149 proof_builder: ProofBuilder::new(desc, rule_id),
150 query: Query {
151 uf_table,
152 id_counter,
153 ts_counter,
154 tracing,
155 rule_id,
156 seminaive,
157 build_reason: None,
158 sole_focus: None,
159 atom_proofs: Default::default(),
160 vars: Default::default(),
161 atoms: Default::default(),
162 add_rule: Default::default(),
163 plan_strategy: Default::default(),
164 },
165 }
166 }
167
168 pub fn free_rule(&mut self, id: RuleId) {
170 self.rules.take(id);
171 }
172}
173
174impl RuleBuilder<'_> {
175 fn add_callback(&mut self, cb: impl Brc + 'static) {
176 self.query.add_rule.push(Box::new(cb));
177 }
178
179 pub fn egraph(&self) -> &EGraph {
181 self.egraph
182 }
183
184 pub(crate) fn set_plan_strategy(&mut self, strategy: PlanStrategy) {
185 self.query.plan_strategy = strategy;
186 }
187
188 pub(crate) fn lookup_uf(&mut self, entry: QueryEntry) -> Result<Variable> {
194 let res = self.new_var(ColumnTy::Id);
195 let uf_table = self.query.uf_table;
196 self.assert_has_ty(&entry, ColumnTy::Id)
197 .context("lookup_uf: ")?;
198 self.add_callback(move |inner, rb| {
199 let entry = inner.convert(&entry);
200 let res_inner = rb.lookup_with_default(uf_table, &[entry], entry, ColumnId::new(1))?;
201 inner.mapping.insert(res, res_inner.into());
202 Ok(())
203 });
204 Ok(res)
205 }
206
207 pub(crate) fn check_for_update(
213 &mut self,
214 lhs: &[QueryEntry],
215 rhs: &[QueryEntry],
216 ) -> Result<()> {
217 let lhs = SmallVec::<[QueryEntry; 4]>::from_iter(lhs.iter().cloned());
218 let rhs = SmallVec::<[QueryEntry; 4]>::from_iter(rhs.iter().cloned());
219 if lhs.len() != rhs.len() {
220 return Err(RuleBuilderError::ArityMismatch {
221 expected: lhs.len(),
222 got: rhs.len(),
223 }
224 .into());
225 }
226 lhs.iter().zip(rhs.iter()).try_for_each(|(l, r)| {
227 self.assert_same_ty(l, r).with_context(|| {
228 format!("check_for_update: {lhs:?} and {rhs:?}, mismatch between {l:?} and {r:?}")
229 })
230 })?;
231
232 self.add_callback(move |inner, rb| {
233 let lhs = inner.convert_all(&lhs);
234 let rhs = inner.convert_all(&rhs);
235 rb.assert_any_ne(&lhs, &rhs).context("check_for_update")
236 });
237 Ok(())
238 }
239
240 fn assert_same_ty(
241 &self,
242 l: &QueryEntry,
243 r: &QueryEntry,
244 ) -> std::result::Result<(), RuleBuilderError> {
245 match (l, r) {
246 (QueryEntry::Var { id: v1, .. }, QueryEntry::Var { id: v2, .. }) => {
247 let ty1 = self.query.vars[*v1].ty;
248 let ty2 = self.query.vars[*v2].ty;
249 if ty1 != ty2 {
250 return Err(RuleBuilderError::TypeMismatch {
251 expected: ty1,
252 got: ty2,
253 });
254 }
255 }
256 (QueryEntry::Const { .. }, QueryEntry::Const { .. })
258 | (QueryEntry::Var { .. }, QueryEntry::Const { .. })
259 | (QueryEntry::Const { .. }, QueryEntry::Var { .. }) => {}
260 }
261 Ok(())
262 }
263
264 fn assert_has_ty(
265 &self,
266 entry: &QueryEntry,
267 ty: ColumnTy,
268 ) -> std::result::Result<(), RuleBuilderError> {
269 if let QueryEntry::Var { id: v, .. } = entry {
270 let var_ty = self.query.vars[*v].ty;
271 if var_ty != ty {
272 return Err(RuleBuilderError::TypeMismatch {
273 expected: var_ty,
274 got: ty,
275 });
276 }
277 }
278 Ok(())
279 }
280
281 pub fn build(self) -> RuleId {
283 assert!(
284 !self.egraph.tracing,
285 "proofs are enabled: use `build_with_syntax` instead"
286 );
287 self.build_internal(None)
288 }
289
290 pub fn build_with_syntax(self, syntax: SourceSyntax) -> RuleId {
291 self.build_internal(Some(syntax))
292 }
293
294 pub(crate) fn build_internal(mut self, syntax: Option<SourceSyntax>) -> RuleId {
295 if self.query.atoms.len() == 1 {
296 self.query.plan_strategy = PlanStrategy::MinCover;
297 }
298 if let Some(syntax) = &syntax {
299 if self.egraph.tracing {
300 let cb = self
301 .proof_builder
302 .create_reason(syntax.clone(), self.egraph);
303 self.query.build_reason = Some(Box::new(move |bndgs, rb| {
304 let reason = cb(bndgs, rb)?;
305 bndgs.lhs_reason = Some(reason.into());
306 Ok(())
307 }));
308 }
309 }
310 let res = self.query.rule_id;
311 let info = RuleInfo {
312 last_run_at: Timestamp::new(0),
313 query: self.query,
314 cached_plan: None,
315 desc: self.proof_builder.rule_description,
316 };
317 debug!("created rule {res:?} / {}", info.desc);
318 self.egraph.rules.insert(res, info);
319 res
320 }
321
322 pub(crate) fn set_focus(&mut self, focus: usize) {
323 self.query.sole_focus = Some(focus);
324 }
325
326 pub fn new_var(&mut self, ty: ColumnTy) -> Variable {
328 let res = self.query.vars.next_id();
329 self.query.vars.push(VarInfo { ty, term_var: res })
330 }
331
332 pub fn new_var_named(&mut self, ty: ColumnTy, name: &str) -> QueryEntry {
337 let id = self.new_var(ty);
338 QueryEntry::Var {
339 id,
340 name: Some(name.into()),
341 }
342 }
343
/// Adds an atom querying `table` with `entries` and returns its `AtomId`.
///
/// Column layout:
/// - If `func` is given, the layout comes from that function's schema, whose length must
///   equal `entries.len()`.
/// - Otherwise `entries` is treated as the full set of function columns, and a
///   subsumption column exists only if `subsume_entry` is `Some`.
///
/// `SchemaMath::write_table_row` completes the row with:
/// - a fresh `Id` variable for the timestamp;
/// - when proofs are enabled, a fresh `Id` variable for the proof;
/// - when the layout has a subsumption column, `subsume_entry` or a fresh `Id` variable.
///
/// `subsume_entry` is ignored when the layout has no subsumption column.
///
/// When proofs are enabled, the atom's proof variable is recorded in
/// `proof_builder.term_vars` and `atom_proofs`. Unless `table` is the union-find table, the
/// last entry (if it is a variable) has its `term_var` redirected to that proof variable.
///
/// # Panics
/// Panics if `func` is `Some` and its schema length differs from `entries.len()`, or, with
/// proofs enabled, if the proof column does not hold a variable.
pub(crate) fn add_atom_with_timestamp_and_func(
    &mut self,
    table: TableId,
    func: Option<FunctionId>,
    subsume_entry: Option<QueryEntry>,
    entries: &[QueryEntry],
) -> AtomId {
    let mut atom = entries.to_vec();
    let schema_math = if let Some(func) = func {
        let info = &self.egraph.funcs[func];
        assert_eq!(info.schema.len(), entries.len());
        SchemaMath {
            tracing: self.egraph.tracing,
            subsume: info.can_subsume,
            func_cols: info.schema.len(),
        }
    } else {
        SchemaMath {
            tracing: self.egraph.tracing,
            subsume: subsume_entry.is_some(),
            func_cols: entries.len(),
        }
    };
    // Struct fields are evaluated in order, so fresh variables are allocated in the order
    // timestamp, proof, subsumption flag.
    schema_math.write_table_row(
        &mut atom,
        RowVals {
            timestamp: self.new_var(ColumnTy::Id).into(),
            proof: self
                .egraph
                .tracing
                .then(|| self.new_var(ColumnTy::Id).into()),
            subsume: if schema_math.subsume {
                Some(subsume_entry.unwrap_or_else(|| self.new_var(ColumnTy::Id).into()))
            } else {
                None
            },
            ret_val: None,
        },
    );
    // Atom ids are dense indices into `self.query.atoms`.
    let res = AtomId::from_usize(self.query.atoms.len());
    if self.egraph.tracing {
        let proof_var = atom[schema_math.proof_id_col()].var();
        self.proof_builder.term_vars.insert(res, proof_var.into());
        self.query.atom_proofs.push(proof_var);
        if let Some(QueryEntry::Var { id, .. }) = entries.last() {
            // The last entry's term is the atom's proof variable, except for union-find
            // atoms.
            if table != self.egraph.uf_table {
                self.query.vars[*id].term_var = proof_var;
            }
        }
    }
    self.query.atoms.push((table, atom, schema_math));
    res
}
406
407 pub fn call_external_func(
408 &mut self,
409 func: ExternalFunctionId,
410 args: &[QueryEntry],
411 ret_ty: ColumnTy,
412 panic_msg: impl FnOnce() -> String + 'static + Send,
413 ) -> Variable {
414 let args = args.to_vec();
415 let res = self.new_var(ret_ty);
416 let panic_fn = self.egraph.new_panic_lazy(panic_msg);
418 self.query.add_rule.push(Box::new(move |inner, rb| {
419 let args = inner.convert_all(&args);
420 let var = rb.call_external_with_fallback(func, &args, panic_fn, &[])?;
421 inner.mapping.insert(res, var.into());
422 Ok(())
423 }));
424 res
425 }
426
427 pub fn query_table(
431 &mut self,
432 func: FunctionId,
433 entries: &[QueryEntry],
434 is_subsumed: Option<bool>,
435 ) -> Result<AtomId> {
436 let info = &self.egraph.funcs[func];
437 let schema = &info.schema;
438 if schema.len() != entries.len() {
439 return Err(anyhow::Error::from(RuleBuilderError::ArityMismatch {
440 expected: schema.len(),
441 got: entries.len(),
442 }))
443 .with_context(|| format!("query_table: mismatch between {entries:?} and {schema:?}"));
444 }
445 entries
446 .iter()
447 .zip(schema.iter())
448 .try_for_each(|(entry, ty)| {
449 self.assert_has_ty(entry, *ty)
450 .with_context(|| format!("query_table: mismatch between {entry:?} and {ty:?}"))
451 })?;
452 Ok(self.add_atom_with_timestamp_and_func(
453 info.table,
454 Some(func),
455 is_subsumed.map(|b| QueryEntry::Const {
456 val: match b {
457 true => SUBSUMED,
458 false => NOT_SUBSUMED,
459 },
460 ty: ColumnTy::Id,
461 }),
462 entries,
463 ))
464 }
465
466 pub fn query_prim(
469 &mut self,
470 func: ExternalFunctionId,
471 entries: &[QueryEntry],
472 _ret_ty: ColumnTy,
474 ) -> Result<()> {
475 let entries = entries.to_vec();
476 self.query.add_rule.push(Box::new(move |inner, rb| {
477 let mut dst_vars = inner.convert_all(&entries);
478 let expected = dst_vars.pop().expect("must specify a return value");
479 let var = rb.call_external(func, &dst_vars)?;
480 match entries.last().unwrap() {
481 QueryEntry::Var { id, .. } if !inner.grounded.contains(id) => {
482 inner.mapping.insert(*id, var.into());
483 inner.grounded.insert(*id);
484 }
485 _ => rb.assert_eq(var.into(), expected),
486 }
487 Ok(())
488 }));
489 Ok(())
490 }
491
492 pub fn subsume(&mut self, func: FunctionId, entries: &[QueryEntry]) {
496 let ret = self.lookup_with_subsumed(
498 func,
499 entries,
500 QueryEntry::Const {
501 val: SUBSUMED,
502 ty: ColumnTy::Id,
503 },
504 || "subsumed a nonextestent row!".to_string(),
505 );
506 let info = &self.egraph.funcs[func];
507 let schema_math = SchemaMath {
508 tracing: self.egraph.tracing,
509 subsume: info.can_subsume,
510 func_cols: info.schema.len(),
511 };
512 assert!(info.can_subsume);
513 assert_eq!(entries.len() + 1, info.schema.len());
514 let entries = entries.to_vec();
515 let table = info.table;
516 self.add_callback(move |inner, rb| {
517 let mut dst_entries = inner.convert_all(&entries);
520 let cur_subsume_val = rb.lookup(
521 table,
522 &dst_entries,
523 ColumnId::from_usize(schema_math.subsume_col()),
524 )?;
525 let cur_proof_val = if schema_math.tracing {
526 Some(DstVar::from(rb.lookup(
527 table,
528 &dst_entries,
529 ColumnId::from_usize(schema_math.proof_id_col()),
530 )?))
531 } else {
532 None
533 };
534 schema_math.write_table_row(
535 &mut dst_entries,
536 RowVals {
537 timestamp: inner.next_ts(),
538 proof: cur_proof_val,
539 subsume: Some(SUBSUMED.into()),
540 ret_val: Some(inner.convert(&ret.into())),
541 },
542 );
543 rb.insert_if_eq(
544 table,
545 cur_subsume_val.into(),
546 NOT_SUBSUMED.into(),
547 &dst_entries,
548 )?;
549 Ok(())
550 });
551 }
552
/// Looks up the output of `func` for the key entries `entries`.
///
/// Returns a new variable, typed with the function's return type, that is bound to that
/// output when the rule's actions run.
///
/// What happens when no row exists depends on the function's `DefaultVal`:
/// - `Const(c)` / `FreshId`: the row is inserted via `lookup_or_insert`.
///   - Its output is `c`, or a new value from the id counter.
///   - Its timestamp is the RHS timestamp.
///   - Its subsumption column (if the table has one) is `subsumed`.
///   - Its proof column (if proofs are enabled) is `c` for `Const`, and `CurrentVal` of
///     the output column for `FreshId`.
/// - `Fail`: the lookup falls back to a panic function built from `panic_msg`.
///
/// `subsumed` only feeds the values written when a row is inserted. It does not filter
/// existing rows.
pub(crate) fn lookup_with_subsumed(
    &mut self,
    func: FunctionId,
    entries: &[QueryEntry],
    subsumed: QueryEntry,
    panic_msg: impl FnOnce() -> String + Send + 'static,
) -> Variable {
    let entries = entries.to_vec();
    let info = &self.egraph.funcs[func];
    // The result variable starts out as its own term variable; the proof-enabled
    // `Const`/`FreshId` branch below redirects it.
    let res = self.query.vars.push(VarInfo {
        ty: info.ret_ty(),
        term_var: self.query.vars.next_id(),
    });
    let table = info.table;
    let id_counter = self.query.id_counter;
    let schema_math = SchemaMath {
        tracing: self.egraph.tracing,
        subsume: info.can_subsume,
        func_cols: info.schema.len(),
    };
    let cb: BuildRuleCallback = match info.default_val {
        DefaultVal::Const(_) | DefaultVal::FreshId => {
            // `wv` is written to the output column and `wv_ref` to the proof column.
            let (wv, wv_ref): (WriteVal, WriteVal) = match &info.default_val {
                DefaultVal::Const(c) => ((*c).into(), (*c).into()),
                DefaultVal::FreshId => (
                    WriteVal::IncCounter(id_counter),
                    WriteVal::CurrentVal(schema_math.ret_val_col()),
                ),
                _ => unreachable!(),
            };
            // Builds the values for every non-key column of a row that has to be inserted.
            let get_write_vals = move |inner: &mut Bindings| {
                let mut write_vals = SmallVec::<[WriteVal; 4]>::new();
                for i in schema_math.num_keys()..schema_math.table_columns() {
                    if i == schema_math.ts_col() {
                        write_vals.push(inner.next_ts().into());
                    } else if i == schema_math.ret_val_col() {
                        write_vals.push(wv);
                    } else if schema_math.tracing && i == schema_math.proof_id_col() {
                        write_vals.push(wv_ref);
                    } else if schema_math.subsume && i == schema_math.subsume_col() {
                        write_vals.push(inner.convert(&subsumed).into())
                    } else {
                        unreachable!()
                    }
                }
                write_vals
            };

            if self.egraph.tracing {
                let term_var = self.new_var(ColumnTy::Id);
                self.query.vars[res].term_var = term_var;
                // NOTE(review): `ts_var` is bound in the callback below but not otherwise
                // referenced in this function — confirm whether it is still needed.
                let ts_var = self.new_var(ColumnTy::Id);
                let mut insert_entries = entries.to_vec();
                insert_entries.push(res.into());
                let add_proof =
                    self.proof_builder
                        .new_row(func, insert_entries, term_var, self.egraph);
                Box::new(move |inner, rb| {
                    let write_vals = get_write_vals(inner);
                    let dst_vars = inner.convert_all(&entries);
                    // All three lookups use the same key and write values, so they
                    // presumably all observe the same (possibly just-inserted) row.
                    let var = rb.lookup_or_insert(
                        table,
                        &dst_vars,
                        &write_vals,
                        ColumnId::from_usize(schema_math.ret_val_col()),
                    )?;
                    let ts = rb.lookup_or_insert(
                        table,
                        &dst_vars,
                        &write_vals,
                        ColumnId::from_usize(schema_math.ts_col()),
                    )?;
                    let term = rb.lookup_or_insert(
                        table,
                        &dst_vars,
                        &write_vals,
                        ColumnId::from_usize(schema_math.proof_id_col()),
                    )?;
                    inner.mapping.insert(term_var, term.into());
                    inner.mapping.insert(res, var.into());
                    inner.mapping.insert(ts_var, ts.into());
                    // Constrains the row's output to equal its proof term.
                    rb.assert_eq(var.into(), term.into());
                    add_proof(inner, rb)?;
                    Ok(())
                })
            } else {
                Box::new(move |inner, rb| {
                    let write_vals = get_write_vals(inner);
                    let dst_vars = inner.convert_all(&entries);
                    let var = rb.lookup_or_insert(
                        table,
                        &dst_vars,
                        &write_vals,
                        ColumnId::from_usize(schema_math.ret_val_col()),
                    )?;
                    inner.mapping.insert(res, var.into());
                    Ok(())
                })
            }
        }
        DefaultVal::Fail => {
            let panic_func = self.egraph.new_panic_lazy(panic_msg);
            if self.egraph.tracing {
                // NOTE(review): unlike the branch above, `term_var` is not stored in
                // `self.query.vars[res].term_var`, so nothing outside this callback refers
                // to it. Confirm whether that is intended.
                let term_var = self.new_var(ColumnTy::Id);
                Box::new(move |inner, rb| {
                    let dst_vars = inner.convert_all(&entries);
                    let var = rb.lookup_with_fallback(
                        table,
                        &dst_vars,
                        ColumnId::from_usize(schema_math.ret_val_col()),
                        panic_func,
                        &[],
                    )?;
                    let term = rb.lookup(
                        table,
                        &dst_vars,
                        ColumnId::from_usize(schema_math.proof_id_col()),
                    )?;
                    inner.mapping.insert(res, var.into());
                    inner.mapping.insert(term_var, term.into());
                    Ok(())
                })
            } else {
                Box::new(move |inner, rb| {
                    let dst_vars = inner.convert_all(&entries);
                    let var = rb.lookup_with_fallback(
                        table,
                        &dst_vars,
                        ColumnId::from_usize(schema_math.ret_val_col()),
                        panic_func,
                        &[],
                    )?;
                    inner.mapping.insert(res, var.into());
                    Ok(())
                })
            }
        }
    };
    self.query.add_rule.push(cb);
    res
}
708
709 pub fn lookup(
715 &mut self,
716 func: FunctionId,
717 entries: &[QueryEntry],
718 panic_msg: impl FnOnce() -> String + Send + 'static,
719 ) -> Variable {
720 self.lookup_with_subsumed(
721 func,
722 entries,
723 QueryEntry::Const {
724 val: NOT_SUBSUMED,
725 ty: ColumnTy::Id,
726 },
727 panic_msg,
728 )
729 }
730
731 pub fn union(&mut self, mut l: QueryEntry, mut r: QueryEntry) {
733 let cb: BuildRuleCallback = if self.query.tracing {
734 for entry in [&mut l, &mut r] {
737 if let QueryEntry::Var { id, .. } = entry {
738 *id = self.query.vars[*id].term_var;
739 }
740 }
741 Box::new(move |inner, rb| {
742 let l = inner.convert(&l);
743 let r = inner.convert(&r);
744 let proof = inner.lhs_reason.expect("reason must be set");
745 rb.insert(inner.uf_table, &[l, r, inner.next_ts(), proof])
746 .context("union")
747 })
748 } else {
749 Box::new(move |inner, rb| {
750 let l = inner.convert(&l);
751 let r = inner.convert(&r);
752 rb.insert(inner.uf_table, &[l, r, inner.next_ts()])
753 .context("union")
754 })
755 };
756 self.query.add_rule.push(cb);
757 }
758
/// Replaces the row of `func` described by `before` with the row `after`.
///
/// Both slices appear to be full rows (keys followed by the output; TODO confirm). The old
/// row is always removed by its keys (`before` minus its last entry).
///
/// Without proofs, `after` is written with `set`, or with `set_with_subsume` when
/// `subsume_var` is given.
///
/// With proofs enabled, a callback:
/// 1. runs the proof builder's rebuild proof (presumably binding `term_var` and
///    `reason_var`, which are read right after);
/// 2. inserts a union-find row `[before_term, new term, timestamp, reason]`;
/// 3. inserts `after` with the new term as its proof id.
///
/// # Panics
/// Panics if `before.len() != after.len()`, if the slices are empty, or (with proofs
/// enabled) if the last entry of `before` is a constant.
pub(crate) fn rebuild_row(
    &mut self,
    func: FunctionId,
    before: &[QueryEntry],
    after: &[QueryEntry],
    subsume_var: Option<Variable>,
) {
    assert_eq!(before.len(), after.len());
    self.remove(func, &before[..before.len() - 1]);
    if !self.egraph.tracing {
        if let Some(subsume_var) = subsume_var {
            self.set_with_subsume(
                func,
                after,
                QueryEntry::Var {
                    id: subsume_var,
                    name: None,
                },
            );
        } else {
            self.set(func, after);
        }
        return;
    }
    let table = self.egraph.funcs[func].table;
    let term_var = self.new_var(ColumnTy::Id);
    let reason_var = self.new_var(ColumnTy::Id);
    let before_id = before.last().unwrap().var();
    // The old output must already have been linked to a distinct term variable.
    let before_term = self.query.vars[before_id].term_var;
    debug_assert_ne!(before_term, before_id);
    let add_proof = self.proof_builder.rebuild_proof(
        func,
        after,
        RebuildVars {
            before_term,
            new_term: term_var,
            reason: reason_var,
        },
        self.egraph,
    );
    let after = SmallVec::<[_; 4]>::from_iter(after.iter().cloned());
    let uf_table = self.query.uf_table;
    let info = &self.egraph.funcs[func];
    let schema_math = SchemaMath {
        tracing: self.egraph.tracing,
        subsume: info.can_subsume,
        func_cols: info.schema.len(),
    };
    self.query.add_rule.push(Box::new(move |inner, rb| {
        add_proof(inner, rb)?;
        let mut dst_vars = inner.convert_all(&after);
        schema_math.write_table_row(
            &mut dst_vars,
            RowVals {
                timestamp: inner.next_ts(),
                proof: Some(inner.mapping[term_var]),
                subsume: subsume_var.map(|v| inner.mapping[v]),
                ret_val: None,
            },
        );
        rb.insert(
            uf_table,
            &[
                inner.mapping[before_term],
                inner.mapping[term_var],
                inner.next_ts(),
                inner.mapping[reason_var],
            ],
        )
        .context("rebuild_row_uf")?;
        rb.insert(table, &dst_vars).context("rebuild_row_table")
    }));
}
844
845 pub fn set(&mut self, func: FunctionId, entries: &[QueryEntry]) {
847 self.set_with_subsume(
848 func,
849 entries,
850 QueryEntry::Const {
851 val: NOT_SUBSUMED,
852 ty: ColumnTy::Id,
853 },
854 );
855 }
856
/// Writes the row `entries` (keys followed by the output) into `func`'s table.
///
/// Without proofs, a callback inserts the row directly. If the function can be subsumed,
/// `subsume_entry` is used as the subsumption flag.
///
/// With proofs enabled, the row is not written as given:
/// 1. the current output for the keys is obtained with `lookup` (see `lookup` for what
///    happens when the row is missing);
/// 2. that output is unioned with the requested output;
/// 3. if the function can be subsumed, a further callback re-inserts the row with its
///    existing proof id, the looked-up output, and `subsume_entry` as the flag.
///
/// # Panics
/// Panics if `entries` is empty.
pub(crate) fn set_with_subsume(
    &mut self,
    func: FunctionId,
    entries: &[QueryEntry],
    subsume_entry: QueryEntry,
) {
    let info = &self.egraph.funcs[func];
    let table = info.table;
    let entries = entries.to_vec();
    let schema_math = SchemaMath {
        tracing: self.egraph.tracing,
        subsume: info.can_subsume,
        func_cols: info.schema.len(),
    };
    if self.egraph.tracing {
        // Callback order matters: the lookup and union are registered before the
        // subsumption re-insert below.
        let res = self.lookup(func, &entries[0..entries.len() - 1], || {
            "lookup failed during proof-enabled set; this is an internal proofs bug".to_string()
        });
        self.union(res.into(), entries.last().unwrap().clone());
        if schema_math.subsume {
            self.add_callback(move |inner, rb| {
                let mut dst_vars = inner.convert_all(&entries);
                let proof_var = rb.lookup(
                    table,
                    &dst_vars[0..schema_math.num_keys()],
                    ColumnId::from_usize(schema_math.proof_id_col()),
                )?;
                schema_math.write_table_row(
                    &mut dst_vars,
                    RowVals {
                        timestamp: inner.next_ts(),
                        proof: Some(proof_var.into()),
                        subsume: Some(inner.convert(&subsume_entry)),
                        ret_val: Some(inner.convert(&res.into())),
                    },
                );
                rb.insert(table, &dst_vars).context("set")
            });
        }
    } else {
        self.query.add_rule.push(Box::new(move |inner, rb| {
            let mut dst_vars = inner.convert_all(&entries);
            schema_math.write_table_row(
                &mut dst_vars,
                RowVals {
                    timestamp: inner.next_ts(),
                    proof: None,
                    subsume: schema_math.subsume.then(|| inner.convert(&subsume_entry)),
                    ret_val: None,
                },
            );
            rb.insert(table, &dst_vars).context("set")
        }));
    };
}
913
914 pub fn remove(&mut self, table: FunctionId, entries: &[QueryEntry]) {
916 let table = self.egraph.funcs[table].table;
917 let entries = entries.to_vec();
918 let cb: BuildRuleCallback = Box::new(move |inner, rb| {
919 let dst_vars = inner.convert_all(&entries);
920 rb.remove(table, &dst_vars).context("remove")
921 });
922 self.query.add_rule.push(cb);
923 }
924
925 pub fn panic(&mut self, message: String) {
927 let panic = self.egraph.new_panic(message.clone());
928 let ret_ty = ColumnTy::Id;
929 let res = self.new_var(ret_ty);
930 self.query.add_rule.push(Box::new(move |inner, rb| {
931 let var = rb.call_external(panic, &[])?;
932 inner.mapping.insert(res, var.into());
933 Ok(())
934 }));
935 }
936}
937
938impl Query {
939 fn query_state<'a, 'outer>(
940 &self,
941 rsb: &'a mut RuleSetBuilder<'outer>,
942 ) -> (QueryBuilder<'outer, 'a>, Bindings) {
943 let mut qb = rsb.new_rule();
944 qb.set_plan_strategy(self.plan_strategy);
945 let mut inner = Bindings {
946 uf_table: self.uf_table,
947 next_ts: None,
948 lhs_reason: None,
949 mapping: Default::default(),
950 grounded: Default::default(),
951 };
952 for (var, _) in self.vars.iter() {
953 inner.mapping.insert(var, DstVar::Var(qb.new_var()));
954 }
955 (qb, inner)
956 }
957
958 fn run_rules_and_build(
959 &self,
960 qb: QueryBuilder,
961 mut inner: Bindings,
962 desc: &str,
963 ) -> Result<core_relations::RuleId> {
964 let mut rb = qb.build();
965 inner.next_ts = Some(rb.read_counter(self.ts_counter).into());
966 if let Some(build_reason) = &self.build_reason {
968 build_reason(&mut inner, &mut rb)?;
969 }
970 self.add_rule
971 .iter()
972 .try_for_each(|f| f(&mut inner, &mut rb))?;
973 Ok(rb.build_with_description(desc))
974 }
975
976 pub(crate) fn build_cached_plan(
977 &self,
978 db: &mut core_relations::Database,
979 desc: &str,
980 ) -> Result<CachedPlanInfo> {
981 let mut rsb = RuleSetBuilder::new(db);
982 let (mut qb, mut inner) = self.query_state(&mut rsb);
983 let mut atom_mapping = Vec::with_capacity(self.atoms.len());
984 for (table, entries, _schema_info) in &self.atoms {
985 atom_mapping.push(add_atom(&mut qb, *table, entries, &[], &mut inner)?);
986 }
987 let rule_id = self.run_rules_and_build(qb, inner, desc)?;
988 let rs = rsb.build();
989 let plan = Arc::new(rs.build_cached_plan(rule_id));
990 Ok(CachedPlanInfo { plan, atom_mapping })
991 }
992
993 pub(crate) fn add_rules_from_cached(
1000 &self,
1001 rsb: &mut RuleSetBuilder,
1002 mid_ts: Timestamp,
1003 cached_plan: &CachedPlanInfo,
1004 ) -> Result<()> {
1005 if !self.seminaive || (self.atoms.is_empty() && mid_ts == Timestamp::new(0)) {
1008 rsb.add_rule_from_cached_plan(&cached_plan.plan, &[]);
1009 return Ok(());
1010 }
1011 if let Some(focus_atom) = self.sole_focus {
1012 let (_, _, schema_info) = &self.atoms[focus_atom];
1014 let ts_col = ColumnId::from_usize(schema_info.ts_col());
1015 rsb.add_rule_from_cached_plan(
1016 &cached_plan.plan,
1017 &[(
1018 cached_plan.atom_mapping[focus_atom],
1019 Constraint::GeConst {
1020 col: ts_col,
1021 val: mid_ts.to_value(),
1022 },
1023 )],
1024 );
1025 return Ok(());
1026 }
1027 let mut constraints: Vec<(core_relations::AtomId, Constraint)> =
1029 Vec::with_capacity(self.atoms.len());
1030 'outer: for focus_atom in 0..self.atoms.len() {
1031 for (i, (_, _, schema_info)) in self.atoms.iter().enumerate() {
1032 let ts_col = ColumnId::from_usize(schema_info.ts_col());
1033 match i.cmp(&focus_atom) {
1034 Ordering::Less => {
1035 if mid_ts == Timestamp::new(0) {
1036 continue 'outer;
1037 }
1038 constraints.push((
1039 cached_plan.atom_mapping[i],
1040 Constraint::LtConst {
1041 col: ts_col,
1042 val: mid_ts.to_value(),
1043 },
1044 ));
1045 }
1046 Ordering::Equal => constraints.push((
1047 cached_plan.atom_mapping[i],
1048 Constraint::GeConst {
1049 col: ts_col,
1050 val: mid_ts.to_value(),
1051 },
1052 )),
1053 Ordering::Greater => {}
1054 };
1055 }
1056 rsb.add_rule_from_cached_plan(&cached_plan.plan, &constraints);
1057 constraints.clear();
1058 }
1059 Ok(())
1060 }
1061}
1062
1063pub(crate) struct Bindings {
1066 uf_table: TableId,
1067 next_ts: Option<DstVar>,
1068 pub(crate) lhs_reason: Option<DstVar>,
1071 pub(crate) mapping: DenseIdMap<Variable, DstVar>,
1072 grounded: HashSet<Variable>,
1073}
1074
1075impl Bindings {
1076 pub(crate) fn next_ts(&self) -> DstVar {
1077 self.next_ts
1078 .expect("ts_var should only be used in RHS of the rule")
1079 }
1080 pub(crate) fn convert(&self, entry: &QueryEntry) -> DstVar {
1081 match entry {
1082 QueryEntry::Var { id: v, .. } => self.mapping[*v],
1083 QueryEntry::Const { val, .. } => DstVar::Const(*val),
1084 }
1085 }
1086 pub(crate) fn convert_all(&self, entries: &[QueryEntry]) -> SmallVec<[DstVar; 4]> {
1087 entries.iter().map(|e| self.convert(e)).collect()
1088 }
1089}
1090
1091fn add_atom(
1092 qb: &mut QueryBuilder,
1093 table: TableId,
1094 entries: &[QueryEntry],
1095 constraints: &[Constraint],
1096 inner: &mut Bindings,
1097) -> Result<core_relations::AtomId> {
1098 for entry in entries {
1099 if let QueryEntry::Var { id, .. } = entry {
1100 inner.grounded.insert(*id);
1101 }
1102 }
1103 let vars = inner.convert_all(entries);
1104 Ok(qb.add_atom(table, &vars, constraints)?)
1105}