1use std::{iter::once, sync::Arc};
4
5use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
6use smallvec::SmallVec;
7use thiserror::Error;
8
9use crate::{
10 BaseValueId, CounterId, ExternalFunctionId, PoolSet,
11 action::{Instr, QueryEntry, WriteVal},
12 common::HashMap,
13 free_join::{
14 ActionId, AtomId, Database, ProcessedConstraints, SubAtom, TableId, TableInfo, VarInfo,
15 Variable,
16 plan::{JoinHeader, JoinStages, Plan, PlanStrategy},
17 },
18 pool::{Pooled, with_pool_set},
19 table_spec::{ColumnId, Constraint},
20};
21
// `RuleId` indexes `RuleSet::plans`. It is returned by `RuleBuilder::build_with_description`
// and `RuleSetBuilder::add_rule_from_cached_plan`, and accepted by `RuleSet::build_cached_plan`.
define_id!(pub RuleId, u32, "An identifier for a rule in a rule set");
23
/// A snapshot of an already-planned rule: its query plan, its description, and a copy of its
/// action.
///
/// Produced by [`RuleSet::build_cached_plan`] and consumed by
/// [`RuleSetBuilder::add_rule_from_cached_plan`], which re-adds the rule without calling the
/// query planner again (only the plan's header constraints are re-processed).
pub struct CachedPlan {
    /// The plan computed for the rule's query.
    plan: Plan,
    /// The description the rule was registered with.
    desc: String,
    /// A clone of the rule's action (its instructions and the variables it reads).
    actions: ActionInfo,
}
30
/// The right-hand side ("action") of a rule.
#[derive(Debug, Clone)]
pub(crate) struct ActionInfo {
    /// Variables the action reads but does not define itself (computed in
    /// `RuleBuilder::build_with_description`); presumably these must be bound by the rule's
    /// query.
    pub(crate) used_vars: SmallVec<[Variable; 4]>,
    /// The action's instructions. They sit behind an `Arc`, so cloning an `ActionInfo` shares
    /// them instead of copying the vector.
    pub(crate) instrs: Arc<Pooled<Vec<Instr>>>,
}
36
/// A collection of rules, each consisting of a query plan, a description, and an action.
///
/// Built with a [`RuleSetBuilder`].
#[derive(Default)]
pub struct RuleSet {
    /// One entry per rule: its plan, its description, and the id of its action in `actions`.
    pub(crate) plans: IdVec<RuleId, (Plan, String, ActionId)>,
    /// The actions referenced by `plans`, keyed by `ActionId`.
    pub(crate) actions: DenseIdMap<ActionId, ActionInfo>,
}
52
53impl RuleSet {
54 pub fn build_cached_plan(&self, rule_id: RuleId) -> CachedPlan {
55 let (plan, desc, action_id) = self.plans.get(rule_id).expect("rule must exist");
56 let actions = self
57 .actions
58 .get(*action_id)
59 .expect("action must exist")
60 .clone();
61 CachedPlan {
62 plan: plan.clone(),
63 desc: desc.clone(),
64 actions,
65 }
66 }
67}
68
/// Incrementally builds a [`RuleSet`].
///
/// Holds a mutable borrow of the [`Database`], which is used to estimate table sizes, process
/// constraints, and plan queries.
pub struct RuleSetBuilder<'outer> {
    /// The rule set accumulated so far.
    rule_set: RuleSet,
    /// The database that rules are processed and planned against.
    db: &'outer mut Database,
}
76
77impl<'outer> RuleSetBuilder<'outer> {
78 pub fn new(db: &'outer mut Database) -> Self {
79 Self {
80 rule_set: Default::default(),
81 db,
82 }
83 }
84
85 pub fn estimate_size(&self, table: TableId, c: Option<Constraint>) -> usize {
90 self.db.estimate_size(table, c)
91 }
92
93 pub fn new_rule<'a>(&'a mut self) -> QueryBuilder<'outer, 'a> {
95 let instrs = with_pool_set(PoolSet::get);
96 QueryBuilder {
97 rsb: self,
98 instrs,
99 query: Query {
100 var_info: Default::default(),
101 atoms: Default::default(),
102 action: ActionId::new(u32::MAX),
104 plan_strategy: Default::default(),
105 },
106 }
107 }
108
109 pub fn add_rule_from_cached_plan(
110 &mut self,
111 cached: &CachedPlan,
112 extra_constraints: &[(AtomId, Constraint)],
113 ) -> RuleId {
114 let action_id = self.rule_set.actions.push(cached.actions.clone());
116 let mut plan = Plan {
117 atoms: cached.plan.atoms.clone(),
118 stages: JoinStages {
119 header: Default::default(),
120 instrs: cached.plan.stages.instrs.clone(),
121 actions: action_id,
122 },
123 };
124
125 for (atom_id, constraint) in extra_constraints {
127 let atom_info = plan.atoms.get(*atom_id).expect("atom must exist in plan");
128 let table = atom_info.table;
129 let processed = self
130 .db
131 .process_constraints(table, std::slice::from_ref(constraint));
132 if !processed.slow.is_empty() {
133 panic!(
134 "Cached plans only support constraints with a fast pushdown. Got: {constraint:?} for table {table:?}",
135 );
136 }
137 plan.stages.header.push(JoinHeader {
138 atom: *atom_id,
139 constraints: processed.fast,
140 subset: processed.subset,
141 });
142 }
143
144 for JoinHeader {
147 atom, constraints, ..
148 } in &cached.plan.stages.header
149 {
150 let atom_info = plan.atoms.get(*atom).expect("atom must exist in plan");
151 let table = atom_info.table;
152 let processed = self.db.process_constraints(table, constraints);
153 if !processed.slow.is_empty() {
154 panic!(
155 "Cached plans only support constraints with a fast pushdown. Got: {constraints:?} for table {table:?}",
156 );
157 }
158 plan.stages.header.push(JoinHeader {
159 atom: *atom,
160 constraints: processed.fast,
161 subset: processed.subset,
162 });
163 }
164
165 self.rule_set
166 .plans
167 .push((plan, cached.desc.clone(), action_id))
168 }
169
170 pub fn build(self) -> RuleSet {
172 self.rule_set
173 }
174}
175
/// Builder for the query (left-hand side) of a new rule.
///
/// Obtained from [`RuleSetBuilder::new_rule`]; call [`QueryBuilder::build`] to move on to the
/// rule's actions.
pub struct QueryBuilder<'outer, 'a> {
    /// The rule-set builder that the finished rule is added to.
    rsb: &'a mut RuleSetBuilder<'outer>,
    /// The query being assembled.
    query: Query,
    /// Action instructions, appended by the `RuleBuilder` returned from `build`.
    instrs: Pooled<Vec<Instr>>,
}
185
186impl<'outer, 'a> QueryBuilder<'outer, 'a> {
187 pub fn build(self) -> RuleBuilder<'outer, 'a> {
189 RuleBuilder { qb: self }
190 }
191
192 pub fn set_plan_strategy(&mut self, strategy: PlanStrategy) {
194 self.query.plan_strategy = strategy;
195 }
196
197 pub fn new_var(&mut self) -> Variable {
199 self.query.var_info.push(VarInfo {
200 occurrences: Default::default(),
201 used_in_rhs: false,
202 defined_in_rhs: false,
203 })
204 }
205
206 fn mark_used<'b>(&mut self, entries: impl IntoIterator<Item = &'b QueryEntry>) {
207 for entry in entries {
208 if let QueryEntry::Var(v) = entry {
209 self.query.var_info[*v].used_in_rhs = true;
210 }
211 }
212 }
213
214 fn mark_defined(&mut self, entry: &QueryEntry) {
215 if let QueryEntry::Var(v) = entry {
217 self.query.var_info[*v].defined_in_rhs = true;
218 }
219 }
220
221 pub fn add_atom<'b>(
234 &mut self,
235 table_id: TableId,
236 vars: &[QueryEntry],
237 cs: impl IntoIterator<Item = &'b Constraint>,
238 ) -> Result<AtomId, QueryError> {
239 let info = &self.rsb.db.tables[table_id];
240 let arity = info.spec.arity();
241 let check_constraint = |c: &Constraint| {
242 let process_col = |col: &ColumnId| -> Result<(), QueryError> {
243 if col.index() >= arity {
244 Err(QueryError::InvalidConstraint {
245 constraint: c.clone(),
246 column: col.index(),
247 table: table_id,
248 arity,
249 })
250 } else {
251 Ok(())
252 }
253 };
254 match c {
255 Constraint::Eq { l_col, r_col } => {
256 process_col(l_col)?;
257 process_col(r_col)
258 }
259 Constraint::EqConst { col, .. }
260 | Constraint::LtConst { col, .. }
261 | Constraint::GtConst { col, .. }
262 | Constraint::LeConst { col, .. }
263 | Constraint::GeConst { col, .. } => process_col(col),
264 }
265 };
266 if arity != vars.len() {
267 return Err(QueryError::BadArity {
268 table: table_id,
269 expected: arity,
270 got: vars.len(),
271 });
272 }
273 let cs = Vec::from_iter(
274 cs.into_iter()
275 .cloned()
276 .chain(vars.iter().enumerate().filter_map(|(i, qe)| match qe {
277 QueryEntry::Var(_) => None,
278 QueryEntry::Const(c) => Some(Constraint::EqConst {
279 col: ColumnId::from_usize(i),
280 val: *c,
281 }),
282 })),
283 );
284 cs.iter().try_fold((), |_, c| check_constraint(c))?;
285 let processed = self.rsb.db.process_constraints(table_id, &cs);
286 let mut atom = Atom {
287 table: table_id,
288 var_to_column: Default::default(),
289 column_to_var: Default::default(),
290 constraints: processed,
291 };
292 let next_atom = AtomId::from_usize(self.query.atoms.n_ids());
293 let mut subatoms = HashMap::<Variable, SubAtom>::default();
294 for (i, qe) in vars.iter().enumerate() {
295 let var = match qe {
296 QueryEntry::Var(var) => *var,
297 QueryEntry::Const(_) => {
298 continue;
299 }
300 };
301 if var == Variable::placeholder() {
302 continue;
303 }
304 let col = ColumnId::from_usize(i);
305 if let Some(prev) = atom.var_to_column.insert(var, col) {
306 atom.constraints.slow.push(Constraint::Eq {
307 l_col: col,
308 r_col: prev,
309 })
310 };
311 atom.column_to_var.insert(col, var);
312 subatoms
313 .entry(var)
314 .or_insert_with(|| SubAtom::new(next_atom))
315 .vars
316 .push(col);
317 }
318 for (var, subatom) in subatoms {
319 self.query
320 .var_info
321 .get_mut(var)
322 .expect("all variables must be bound in current query")
323 .occurrences
324 .push(subatom);
325 }
326 Ok(self.query.atoms.push(atom))
327 }
328}
329
/// Errors that can arise while building a rule's query or actions.
#[derive(Debug, Error)]
pub enum QueryError {
    /// The number of key entries supplied for `table` differs from its number of key columns.
    #[error("table {table:?} has {expected:?} keys but got {got:?}")]
    KeyArityMismatch {
        table: TableId,
        expected: usize,
        got: usize,
    },
    /// The values supplied for `table` do not match its number of columns.
    #[error("table {table:?} has {expected:?} columns but got {got:?}")]
    TableArityMismatch {
        table: TableId,
        expected: usize,
        got: usize,
    },

    /// A counter was used in column `column_id` of `table`, which is declared as a base-value
    /// column (`base`).
    #[error(
        "counter used in column {column_id:?} of table {table:?}, which is declared as a base value"
    )]
    CounterUsedInBaseColumn {
        table: TableId,
        column_id: ColumnId,
        base: BaseValueId,
    },

    /// Two groups of values compared element-wise have different lengths (`l` and `r`).
    #[error("attempt to compare two groups of values, one of length {l}, another of length {r}")]
    MultiComparisonMismatch { l: usize, r: usize },

    /// An atom over `table` supplied `got` entries, but the table has `expected` columns.
    #[error("table {table:?} expected {expected:?} columns but got {got:?}")]
    BadArity {
        table: TableId,
        expected: usize,
        got: usize,
    },

    /// A schema had `got` columns where `expected` were required.
    #[error("expected {expected:?} columns in schema but got {got:?}")]
    InvalidSchema { expected: usize, got: usize },

    /// `constraint` references column index `column`, which is out of range for `table`,
    /// whose arity is `arity`.
    #[error(
        "constraint {constraint:?} on table {table:?} references column {column:?}, but the table has arity {arity:?}"
    )]
    InvalidConstraint {
        constraint: Constraint,
        column: usize,
        table: TableId,
        arity: usize,
    },
}
377
/// Builder for the actions (right-hand side) of a rule whose query has been built.
///
/// Obtained from [`QueryBuilder::build`]; finish it with [`RuleBuilder::build`] or
/// [`RuleBuilder::build_with_description`].
pub struct RuleBuilder<'outer, 'a> {
    /// The query builder holding the query, the instructions so far, and the rule-set builder.
    qb: QueryBuilder<'outer, 'a>,
}
384
385impl RuleBuilder<'_, '_> {
386 pub fn build(self) -> RuleId {
388 self.build_with_description("")
389 }
390 pub fn build_with_description(mut self, desc: impl Into<String>) -> RuleId {
391 let used_vars =
393 SmallVec::from_iter(self.qb.query.var_info.iter().filter_map(|(v, info)| {
394 if info.used_in_rhs && !info.defined_in_rhs {
395 Some(v)
396 } else {
397 None
398 }
399 }));
400 let action_id = self.qb.rsb.rule_set.actions.push(ActionInfo {
401 instrs: Arc::new(self.qb.instrs),
402 used_vars,
403 });
404 self.qb.query.action = action_id;
405 let plan = self.qb.rsb.db.plan_query(self.qb.query);
407 self.qb
409 .rsb
410 .rule_set
411 .plans
412 .push((plan, desc.into(), action_id))
413 }
414
415 pub fn read_counter(&mut self, counter: CounterId) -> Variable {
417 let dst = self.qb.new_var();
418 self.qb.instrs.push(Instr::ReadCounter { counter, dst });
419 self.qb.mark_defined(&dst.into());
420 dst
421 }
422
423 pub fn lookup_or_insert(
430 &mut self,
431 table: TableId,
432 args: &[QueryEntry],
433 default_vals: &[WriteVal],
434 dst_col: ColumnId,
435 ) -> Result<Variable, QueryError> {
436 let table_info = self
437 .qb
438 .rsb
439 .db
440 .tables
441 .get(table)
442 .expect("table must be declared in the current database");
443 self.validate_keys(table, table_info, args)?;
444 self.validate_vals(table, table_info, default_vals.iter())?;
445 let res = self.qb.new_var();
446 self.qb.instrs.push(Instr::LookupOrInsertDefault {
447 table,
448 args: args.to_vec(),
449 default: default_vals.to_vec(),
450 dst_col,
451 dst_var: res,
452 });
453 self.qb.mark_used(args);
454 self.qb
455 .mark_used(default_vals.iter().filter_map(|x| match x {
456 WriteVal::QueryEntry(qe) => Some(qe),
457 WriteVal::IncCounter(_) | WriteVal::CurrentVal(_) => None,
458 }));
459 self.qb.mark_defined(&res.into());
460 Ok(res)
461 }
462
463 pub fn lookup_with_default(
470 &mut self,
471 table: TableId,
472 args: &[QueryEntry],
473 default: QueryEntry,
474 dst_col: ColumnId,
475 ) -> Result<Variable, QueryError> {
476 let table_info = self
477 .qb
478 .rsb
479 .db
480 .tables
481 .get(table)
482 .expect("table must be declared in the current database");
483 self.validate_keys(table, table_info, args)?;
484 let res = self.qb.new_var();
485 self.qb.instrs.push(Instr::LookupWithDefault {
486 table,
487 args: args.to_vec(),
488 dst_col,
489 dst_var: res,
490 default,
491 });
492 self.qb.mark_used(args);
493 self.qb.mark_used(&[default]);
494 self.qb.mark_defined(&res.into());
495 Ok(res)
496 }
497
498 pub fn lookup(
505 &mut self,
506 table: TableId,
507 args: &[QueryEntry],
508 dst_col: ColumnId,
509 ) -> Result<Variable, QueryError> {
510 let table_info = self
511 .qb
512 .rsb
513 .db
514 .tables
515 .get(table)
516 .expect("table must be declared in the current database");
517 self.validate_keys(table, table_info, args)?;
518 let res = self.qb.new_var();
519 self.qb.instrs.push(Instr::Lookup {
520 table,
521 args: args.to_vec(),
522 dst_col,
523 dst_var: res,
524 });
525 self.qb.mark_used(args);
526 self.qb.mark_defined(&res.into());
527 Ok(res)
528 }
529
530 pub fn insert(&mut self, table: TableId, vals: &[QueryEntry]) -> Result<(), QueryError> {
532 let table_info = self
533 .qb
534 .rsb
535 .db
536 .tables
537 .get(table)
538 .expect("table must be declared in the current database");
539 self.validate_row(table, table_info, vals)?;
540 self.qb.instrs.push(Instr::Insert {
541 table,
542 vals: vals.to_vec(),
543 });
544 self.qb.mark_used(vals);
545 Ok(())
546 }
547
548 pub fn insert_if_eq(
550 &mut self,
551 table: TableId,
552 l: QueryEntry,
553 r: QueryEntry,
554 vals: &[QueryEntry],
555 ) -> Result<(), QueryError> {
556 let table_info = self
557 .qb
558 .rsb
559 .db
560 .tables
561 .get(table)
562 .expect("table must be declared in the current database");
563 self.validate_row(table, table_info, vals)?;
564 self.qb.instrs.push(Instr::InsertIfEq {
565 table,
566 l,
567 r,
568 vals: vals.to_vec(),
569 });
570 self.qb
571 .mark_used(vals.iter().chain(once(&l)).chain(once(&r)));
572 Ok(())
573 }
574
575 pub fn remove(&mut self, table: TableId, args: &[QueryEntry]) -> Result<(), QueryError> {
577 let table_info = self
578 .qb
579 .rsb
580 .db
581 .tables
582 .get(table)
583 .expect("table must be declared in the current database");
584 self.validate_keys(table, table_info, args)?;
585 self.qb.instrs.push(Instr::Remove {
586 table,
587 args: args.to_vec(),
588 });
589 self.qb.mark_used(args);
590 Ok(())
591 }
592
593 pub fn call_external(
595 &mut self,
596 func: ExternalFunctionId,
597 args: &[QueryEntry],
598 ) -> Result<Variable, QueryError> {
599 let res = self.qb.new_var();
600 self.qb.instrs.push(Instr::External {
601 func,
602 args: args.to_vec(),
603 dst: res,
604 });
605 self.qb.mark_used(args);
606 self.qb.mark_defined(&res.into());
607 Ok(res)
608 }
609
610 pub fn lookup_with_fallback(
614 &mut self,
615 table: TableId,
616 key: &[QueryEntry],
617 dst_col: ColumnId,
618 func: ExternalFunctionId,
619 func_args: &[QueryEntry],
620 ) -> Result<Variable, QueryError> {
621 let table_info = self
622 .qb
623 .rsb
624 .db
625 .tables
626 .get(table)
627 .expect("table must be declared in the current database");
628 self.validate_keys(table, table_info, key)?;
629 let res = self.qb.new_var();
630 self.qb.instrs.push(Instr::LookupWithFallback {
631 table,
632 table_key: key.to_vec(),
633 func,
634 func_args: func_args.to_vec(),
635 dst_var: res,
636 dst_col,
637 });
638 self.qb.mark_used(key);
639 self.qb.mark_used(func_args);
640 self.qb.mark_defined(&res.into());
641 Ok(res)
642 }
643
644 pub fn call_external_with_fallback(
645 &mut self,
646 f1: ExternalFunctionId,
647 args1: &[QueryEntry],
648 f2: ExternalFunctionId,
649 args2: &[QueryEntry],
650 ) -> Result<Variable, QueryError> {
651 let res = self.qb.new_var();
652 self.qb.instrs.push(Instr::ExternalWithFallback {
653 f1,
654 args1: args1.to_vec(),
655 f2,
656 args2: args2.to_vec(),
657 dst: res,
658 });
659 self.qb.mark_used(args1);
660 self.qb.mark_used(args2);
661 self.qb.mark_defined(&res.into());
662 Ok(res)
663 }
664
665 pub fn assert_eq(&mut self, l: QueryEntry, r: QueryEntry) {
667 self.qb.instrs.push(Instr::AssertEq(l, r));
668 self.qb.mark_used(&[l, r]);
669 }
670
671 pub fn assert_ne(&mut self, l: QueryEntry, r: QueryEntry) -> Result<(), QueryError> {
673 self.qb.instrs.push(Instr::AssertNe(l, r));
674 self.qb.mark_used(&[l, r]);
675 Ok(())
676 }
677
678 pub fn assert_any_ne(&mut self, l: &[QueryEntry], r: &[QueryEntry]) -> Result<(), QueryError> {
682 if l.len() != r.len() {
683 return Err(QueryError::MultiComparisonMismatch {
684 l: l.len(),
685 r: r.len(),
686 });
687 }
688
689 let mut ops = Vec::with_capacity(l.len() + r.len());
690 ops.extend_from_slice(l);
691 ops.extend_from_slice(r);
692 self.qb.instrs.push(Instr::AssertAnyNe {
693 ops,
694 divider: l.len(),
695 });
696 self.qb.mark_used(l);
697 self.qb.mark_used(r);
698 Ok(())
699 }
700
701 fn validate_row(
702 &self,
703 table: TableId,
704 info: &TableInfo,
705 vals: &[QueryEntry],
706 ) -> Result<(), QueryError> {
707 if vals.len() != info.spec.arity() {
708 Err(QueryError::TableArityMismatch {
709 table,
710 expected: info.spec.arity(),
711 got: vals.len(),
712 })
713 } else {
714 Ok(())
715 }
716 }
717
718 fn validate_keys(
719 &self,
720 table: TableId,
721 info: &TableInfo,
722 keys: &[QueryEntry],
723 ) -> Result<(), QueryError> {
724 if keys.len() != info.spec.n_keys {
725 Err(QueryError::KeyArityMismatch {
726 table,
727 expected: info.spec.n_keys,
728 got: keys.len(),
729 })
730 } else {
731 Ok(())
732 }
733 }
734
735 fn validate_vals<'b>(
736 &self,
737 table: TableId,
738 info: &TableInfo,
739 vals: impl Iterator<Item = &'b WriteVal>,
740 ) -> Result<(), QueryError> {
741 for (i, _) in vals.enumerate() {
742 let col = i + info.spec.n_keys;
743 if col >= info.spec.arity() {
744 return Err(QueryError::TableArityMismatch {
745 table,
746 expected: info.spec.arity(),
747 got: col,
748 });
749 }
750 }
751 Ok(())
752 }
753}
754
/// A single atom of a query: a table plus the variables bound to its columns.
#[derive(Debug, Clone)]
pub(crate) struct Atom {
    /// The table the atom ranges over.
    pub(crate) table: TableId,
    /// For each variable in the atom, a column bound to it. When a variable appears in several
    /// columns, `QueryBuilder::add_atom` leaves the last one here.
    pub(crate) var_to_column: HashMap<Variable, ColumnId>,
    /// The variable bound to each column. Constant and placeholder columns have no entry.
    pub(crate) column_to_var: DenseIdMap<ColumnId, Variable>,
    /// The atom's constraints as processed by the database, plus `Eq` slow constraints added
    /// for repeated variables.
    pub(crate) constraints: ProcessedConstraints,
}
767
/// The query (left-hand side) of a rule under construction.
///
/// Handed to `Database::plan_query` when the rule is built.
pub(crate) struct Query {
    /// Per-variable bookkeeping: where each variable occurs, and whether the action uses or
    /// defines it.
    pub(crate) var_info: DenseIdMap<Variable, VarInfo>,
    /// The atoms added so far, with ids assigned in the order they were added.
    pub(crate) atoms: DenseIdMap<AtomId, Atom>,
    /// The rule's action id. It is the placeholder `ActionId::new(u32::MAX)` until the rule is
    /// built.
    pub(crate) action: ActionId,
    /// The strategy to use when planning this query.
    pub(crate) plan_strategy: PlanStrategy,
}