use std::{iter::once, sync::Arc};
use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id};
use smallvec::SmallVec;
use thiserror::Error;
use crate::{
BaseValueId, CounterId, ExternalFunctionId, PoolSet,
action::{Instr, QueryEntry, WriteVal},
common::HashMap,
free_join::{
ActionId, AtomId, Database, ProcessedConstraints, SubAtom, TableId, TableInfo, VarInfo,
Variable,
plan::{JoinHeader, JoinStages, Plan, PlanStrategy},
},
pool::{Pooled, with_pool_set},
table_spec::{ColumnId, Constraint},
};
// Declares the public `RuleId` identifier type through the crate's `define_id!` macro.
// The string argument is the type's documentation; `u32` is presumably the backing
// integer representation (confirm against the macro definition).
define_id!(pub RuleId, u32, "An identifier for a rule in a rule set");
/// Human-readable names attached to the atoms and variables of a rule.
///
/// In this file the map is produced by `RuleBuilder::build_symbol_map`: an atom is mapped
/// to the name of its table (when the table info carries a name) and a variable is mapped
/// to the name it was given via `QueryBuilder::new_var_named`.
// NOTE(review): the `dead_code` allowance suggests the maps are not read by every build
// configuration — confirm it is still required.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct SymbolMap {
/// Names for atoms, keyed by atom id. Atoms whose table has no name are absent.
pub atoms: HashMap<AtomId, Arc<str>>,
/// Names for variables, keyed by variable. Unnamed variables are absent.
pub vars: HashMap<Variable, Arc<str>>,
}
/// A standalone copy of everything stored for one rule of a `RuleSet`.
///
/// Produced by `RuleSet::build_cached_plan` and consumed by
/// `RuleSetBuilder::add_rule_from_cached_plan`, which re-registers the rule (optionally
/// with extra constraints) in another rule set under construction.
pub struct CachedPlan {
/// The query plan that was stored for the rule.
plan: Plan,
/// The description that was stored for the rule.
desc: Arc<str>,
/// The atom/variable names that were stored for the rule.
symbol_map: SymbolMap,
/// A copy of the rule's action info (instructions and the variables they use).
actions: ActionInfo,
}
/// The action half of a rule: its instruction list and the variables it consumes.
#[derive(Debug, Clone)]
pub(crate) struct ActionInfo {
/// Variables that the instructions use but do not define themselves; computed in
/// `RuleBuilder::build_with_description` as `used_in_rhs && !defined_in_rhs`.
pub(crate) used_vars: SmallVec<[Variable; 4]>,
/// The instructions to run, behind an `Arc` so that cloning an `ActionInfo` shares them.
pub(crate) instrs: Arc<Pooled<Vec<Instr>>>,
}
/// A collection of planned rules together with the actions they reference.
///
/// Instances are assembled with `RuleSetBuilder` and returned by `RuleSetBuilder::build`.
#[derive(Default)]
pub struct RuleSet {
/// One entry per rule: the plan, its description, its symbol map, and the id of the
/// entry in `actions` that the rule runs.
pub(crate) plans: IdVec<RuleId, (Plan, Arc<str> /* description */, SymbolMap, ActionId)>,
/// Action info for every rule, keyed by the `ActionId` stored in `plans`.
pub(crate) actions: DenseIdMap<ActionId, ActionInfo>,
}
impl RuleSet {
pub fn build_cached_plan(&self, rule_id: RuleId) -> CachedPlan {
let (plan, desc, symbol_map, action_id) = self.plans.get(rule_id).expect("rule must exist");
let actions = self
.actions
.get(*action_id)
.expect("action must exist")
.clone();
CachedPlan {
plan: plan.clone(),
desc: desc.clone(),
symbol_map: symbol_map.clone(),
actions,
}
}
}
/// Incrementally builds a `RuleSet` against a particular `Database`.
///
/// Rules are added either from scratch (`new_rule`) or from a previously cached plan
/// (`add_rule_from_cached_plan`); `build` hands back the finished rule set.
pub struct RuleSetBuilder<'outer> {
/// The rule set accumulated so far.
rule_set: RuleSet,
/// The database used for table metadata, constraint processing and query planning.
db: &'outer mut Database,
}
impl<'outer> RuleSetBuilder<'outer> {
pub fn new(db: &'outer mut Database) -> Self {
Self {
rule_set: Default::default(),
db,
}
}
pub fn estimate_size(&self, table: TableId, c: Option<Constraint>) -> usize {
self.db.estimate_size(table, c)
}
/// Begin a new rule, returning a `QueryBuilder` for its query side.
///
/// The builder starts with no variables, no atoms, the default plan strategy and an
/// empty pooled instruction buffer.
pub fn new_rule<'a>(&'a mut self) -> QueryBuilder<'outer, 'a> {
    let instrs = with_pool_set(PoolSet::get);
    // The action id is a placeholder (`u32::MAX`) until
    // `RuleBuilder::build_with_description` registers the real action.
    let query = Query {
        var_info: Default::default(),
        atoms: Default::default(),
        action: ActionId::new(u32::MAX),
        plan_strategy: Default::default(),
    };
    QueryBuilder {
        rsb: self,
        instrs,
        query,
    }
}
/// Add a rule to this rule set from a previously cached plan, optionally tightening it
/// with `extra_constraints`.
///
/// The cached atoms and join instructions are reused as-is. The plan header is rebuilt:
/// first one header entry per extra constraint, then one entry per header of the cached
/// plan, each with its constraints re-processed against the current database.
///
/// # Panics
///
/// Panics if a constraint refers to an atom that is not part of the cached plan, or if
/// processing any constraint yields a "slow" constraint (only fast pushdowns are
/// supported here).
pub fn add_rule_from_cached_plan(
    &mut self,
    cached: &CachedPlan,
    extra_constraints: &[(AtomId, Constraint)],
) -> RuleId {
    // The new rule gets its own copy of the cached action.
    let action_id = self.rule_set.actions.push(cached.actions.clone());
    let atoms = cached.plan.atoms.clone();
    let stages = JoinStages {
        header: Default::default(),
        instrs: cached.plan.stages.instrs.clone(),
        actions: action_id,
    };
    let mut plan = Plan { atoms, stages };

    // Headers for the caller-supplied constraints come first.
    for &(atom_id, ref constraint) in extra_constraints {
        let table = plan
            .atoms
            .get(atom_id)
            .expect("atom must exist in plan")
            .table;
        let processed = self
            .db
            .process_constraints(table, std::slice::from_ref(constraint));
        assert!(
            processed.slow.is_empty(),
            "Cached plans only support constraints with a fast pushdown. Got: {constraint:?} for table {table:?}",
        );
        plan.stages.header.push(JoinHeader {
            atom: atom_id,
            constraints: processed.fast,
            subset: processed.subset,
        });
    }

    // Then the cached headers, re-processed so their subsets reflect the current database.
    for cached_header in &cached.plan.stages.header {
        let atom = cached_header.atom;
        let constraints = &cached_header.constraints;
        let table = plan
            .atoms
            .get(atom)
            .expect("atom must exist in plan")
            .table;
        let processed = self.db.process_constraints(table, constraints);
        assert!(
            processed.slow.is_empty(),
            "Cached plans only support constraints with a fast pushdown. Got: {constraints:?} for table {table:?}",
        );
        plan.stages.header.push(JoinHeader {
            atom,
            constraints: processed.fast,
            subset: processed.subset,
        });
    }

    let entry = (
        plan,
        cached.desc.clone(),
        cached.symbol_map.clone(),
        action_id,
    );
    self.rule_set.plans.push(entry)
}
pub fn build(self) -> RuleSet {
self.rule_set
}
}
/// Builder for the query (left-hand) side of a rule.
///
/// Created by `RuleSetBuilder::new_rule`. Once all variables and atoms are declared,
/// `build` converts it into a `RuleBuilder` for adding the rule's actions.
pub struct QueryBuilder<'outer, 'a> {
/// The rule-set builder this rule will be added to.
rsb: &'a mut RuleSetBuilder<'outer>,
/// The query assembled so far (variables, atoms, plan strategy).
query: Query,
/// Action instructions for the rule; filled in through `RuleBuilder`.
instrs: Pooled<Vec<Instr>>,
}
impl<'outer, 'a> QueryBuilder<'outer, 'a> {
/// Finish the query side and move on to building the rule's actions.
pub fn build(self) -> RuleBuilder<'outer, 'a> {
    let qb = self;
    RuleBuilder { qb }
}
/// Record `strategy` as the plan strategy of the query being built, replacing the
/// previous (initially default) one.
pub fn set_plan_strategy(&mut self, strategy: PlanStrategy) {
    let query = &mut self.query;
    query.plan_strategy = strategy;
}
/// Allocate a fresh, unnamed variable for this query.
///
/// The variable starts with no occurrences and is neither used nor defined by the
/// rule's actions.
pub fn new_var(&mut self) -> Variable {
    let anonymous = VarInfo {
        occurrences: Default::default(),
        used_in_rhs: false,
        defined_in_rhs: false,
        name: None,
    };
    self.query.var_info.push(anonymous)
}
/// Allocate a fresh variable for this query and remember `name` for it.
///
/// Apart from the recorded name (which ends up in the rule's symbol map), the variable
/// is initialised exactly like one created by `new_var`.
pub fn new_var_named(&mut self, name: &str) -> Variable {
    let named = VarInfo {
        occurrences: Default::default(),
        used_in_rhs: false,
        defined_in_rhs: false,
        name: Some(name.into()),
    };
    self.query.var_info.push(named)
}
/// Flag every variable among `entries` as used by the rule's actions.
///
/// Constant entries are ignored.
fn mark_used<'b>(&mut self, entries: impl IntoIterator<Item = &'b QueryEntry>) {
    let var_info = &mut self.query.var_info;
    entries.into_iter().for_each(|entry| match entry {
        QueryEntry::Var(v) => var_info[*v].used_in_rhs = true,
        QueryEntry::Const(_) => {}
    });
}
/// Flag `entry` as defined by the rule's actions, if it is a variable.
///
/// Constant entries are ignored.
fn mark_defined(&mut self, entry: &QueryEntry) {
    match entry {
        QueryEntry::Var(v) => self.query.var_info[*v].defined_in_rhs = true,
        QueryEntry::Const(_) => {}
    }
}
/// Add an atom over `table_id` to the query and return its id.
///
/// `vars` supplies one entry per column of the table. A constant entry becomes an
/// `EqConst` constraint on that column; a variable entry binds the variable to the
/// column. The placeholder variable is skipped entirely. When the same variable appears
/// in several columns, an `Eq` constraint between the repeated columns is added to the
/// atom's slow constraints. `cs` supplies additional constraints on the atom.
///
/// # Errors
///
/// * `QueryError::BadArity` if `vars.len()` differs from the table's arity.
/// * `QueryError::InvalidConstraint` if any constraint references a column index that
///   is not below the table's arity.
///
/// # Panics
///
/// Panics if `vars` mentions a (non-placeholder) variable that was not created through
/// this builder.
pub fn add_atom<'b>(
    &mut self,
    table_id: TableId,
    vars: &[QueryEntry],
    cs: impl IntoIterator<Item = &'b Constraint>,
) -> Result<AtomId, QueryError> {
    let arity = self.rsb.db.tables[table_id].spec.arity();
    if vars.len() != arity {
        return Err(QueryError::BadArity {
            table: table_id,
            expected: arity,
            got: vars.len(),
        });
    }

    // Explicit constraints first, then one equality constraint per constant argument.
    let const_constraints = vars.iter().enumerate().filter_map(|(i, qe)| match qe {
        QueryEntry::Const(c) => Some(Constraint::EqConst {
            col: ColumnId::from_usize(i),
            val: *c,
        }),
        QueryEntry::Var(_) => None,
    });
    let all_constraints: Vec<Constraint> = cs
        .into_iter()
        .cloned()
        .chain(const_constraints)
        .collect();

    // Reject the first constraint (and, within it, the first column) that is out of range.
    let in_bounds = |col: &ColumnId| col.index() < arity;
    for constraint in &all_constraints {
        let offending = match constraint {
            Constraint::Eq { l_col, r_col } => {
                if !in_bounds(l_col) {
                    Some(l_col)
                } else if !in_bounds(r_col) {
                    Some(r_col)
                } else {
                    None
                }
            }
            Constraint::EqConst { col, .. }
            | Constraint::LtConst { col, .. }
            | Constraint::GtConst { col, .. }
            | Constraint::LeConst { col, .. }
            | Constraint::GeConst { col, .. } => {
                if in_bounds(col) {
                    None
                } else {
                    Some(col)
                }
            }
        };
        if let Some(col) = offending {
            return Err(QueryError::InvalidConstraint {
                constraint: constraint.clone(),
                column: col.index(),
                table: table_id,
                arity,
            });
        }
    }

    let constraints = self.rsb.db.process_constraints(table_id, &all_constraints);
    let mut atom = Atom {
        table: table_id,
        var_to_column: Default::default(),
        column_to_var: Default::default(),
        constraints,
    };
    // The id the atom will receive once it is pushed at the end of this function.
    let atom_id = AtomId::from_usize(self.query.atoms.n_ids());
    let mut occurrences = HashMap::<Variable, SubAtom>::default();
    for (i, qe) in vars.iter().enumerate() {
        let var = match qe {
            QueryEntry::Var(v) if *v != Variable::placeholder() => *v,
            QueryEntry::Var(_) | QueryEntry::Const(_) => continue,
        };
        let col = ColumnId::from_usize(i);
        // A repeated variable forces the two columns to agree.
        if let Some(earlier) = atom.var_to_column.insert(var, col) {
            atom.constraints.slow.push(Constraint::Eq {
                l_col: col,
                r_col: earlier,
            });
        }
        atom.column_to_var.insert(col, var);
        occurrences
            .entry(var)
            .or_insert_with(|| SubAtom::new(atom_id))
            .vars
            .push(col);
    }

    // Record, per variable, which columns of this atom it occupies.
    for (var, subatom) in occurrences {
        let info = self
            .query
            .var_info
            .get_mut(var)
            .expect("all variables must be bound in current query");
        info.occurrences.push(subatom);
    }
    Ok(self.query.atoms.push(atom))
}
}
/// Errors reported while building queries and rules.
#[derive(Debug, Error)]
pub enum QueryError {
/// The number of key arguments passed to a lookup/remove-style operation differs from
/// the table's `n_keys` (see `RuleBuilder::validate_keys`).
#[error("table {table:?} has {expected:?} keys but got {got:?}")]
KeyArityMismatch {
table: TableId,
expected: usize,
got: usize,
},
/// The values supplied for a row do not match the table's arity (see
/// `RuleBuilder::validate_row` and `RuleBuilder::validate_vals`).
#[error("table {table:?} has {expected:?} columns but got {got:?}")]
TableArityMismatch {
table: TableId,
expected: usize,
got: usize,
},
/// A counter was used in a column declared as a base value.
/// NOTE(review): not constructed anywhere in this part of the file.
#[error(
"counter used in column {column_id:?} of table {table:?}, which is declared as a base value"
)]
CounterUsedInBaseColumn {
table: TableId,
column_id: ColumnId,
base: BaseValueId,
},
/// The two operand groups of a multi-value comparison have different lengths (see
/// `RuleBuilder::assert_any_ne`).
#[error("attempt to compare two groups of values, one of length {l}, another of length {r}")]
MultiComparisonMismatch { l: usize, r: usize },
/// The number of entries passed to `QueryBuilder::add_atom` differs from the table's
/// arity.
#[error("table {table:?} expected {expected:?} columns but got {got:?}")]
BadArity {
table: TableId,
expected: usize,
got: usize,
},
/// A schema had an unexpected number of columns.
/// NOTE(review): not constructed anywhere in this part of the file.
#[error("expected {expected:?} columns in schema but got {got:?}")]
InvalidSchema { expected: usize, got: usize },
/// A constraint passed to `QueryBuilder::add_atom` references a column index that is
/// not below the table's arity.
#[error(
"constraint {constraint:?} on table {table:?} references column {column:?}, but the table has arity {arity:?}"
)]
InvalidConstraint {
constraint: Constraint,
column: usize,
table: TableId,
arity: usize,
},
}
/// Builder for the action (right-hand) side of a rule.
///
/// Obtained from `QueryBuilder::build`. Its methods append instructions to the rule's
/// action; `build` / `build_with_description` plan the query and register the finished
/// rule in the enclosing rule set.
pub struct RuleBuilder<'outer, 'a> {
/// The query builder this rule was started from; holds the query, the instruction
/// buffer and the link back to the rule-set builder.
qb: QueryBuilder<'outer, 'a>,
}
impl RuleBuilder<'_, '_> {
fn table_info(&self, table: TableId) -> &TableInfo {
self.qb.rsb.db.get_table_info(table)
}
/// Finish the rule with an empty description and return its id.
///
/// Equivalent to calling `build_with_description` with `""`.
pub fn build(self) -> RuleId {
    let no_description: &str = "";
    self.build_with_description(no_description)
}
fn build_symbol_map(&self) -> SymbolMap {
let var_info = &self.qb.query.var_info;
SymbolMap {
atoms: self
.qb
.query
.atoms
.iter()
.filter_map(|(id, atom)| {
let name = self.table_info(atom.table).name.clone();
name.map(|name| (id, name))
})
.collect(),
vars: var_info
.iter()
.filter_map(|(id, info)| info.name.as_ref().map(|name| (id, name.clone())))
.collect(),
}
}
pub fn build_with_description(mut self, desc: impl Into<String>) -> RuleId {
let var_info = &self.qb.query.var_info;
let symbol_map = self.build_symbol_map();
let used_vars = SmallVec::from_iter(var_info.iter().filter_map(|(v, info)| {
if info.used_in_rhs && !info.defined_in_rhs {
Some(v)
} else {
None
}
}));
let action_id = self.qb.rsb.rule_set.actions.push(ActionInfo {
instrs: Arc::new(self.qb.instrs),
used_vars,
});
self.qb.query.action = action_id;
let plan = self.qb.rsb.db.plan_query(self.qb.query);
let desc: String = desc.into();
self.qb
.rsb
.rule_set
.plans
.push((plan, desc.into(), symbol_map, action_id))
}
pub fn read_counter(&mut self, counter: CounterId) -> Variable {
let dst = self.qb.new_var();
self.qb.instrs.push(Instr::ReadCounter { counter, dst });
self.qb.mark_defined(&dst.into());
dst
}
/// Append an `Instr::LookupOrInsertDefault` on `table`, keyed by `args`, with
/// `default_vals` as the values to use for a missing row, and return the fresh variable
/// bound to the result taken from `dst_col`.
///
/// Every variable in `args`, and every query-entry variable in `default_vals`, is marked
/// as used; the returned variable is marked as defined.
///
/// # Errors
///
/// * `QueryError::KeyArityMismatch` if `args.len()` differs from the table's key count.
/// * `QueryError::TableArityMismatch` if `default_vals` does not fit in the table's
///   non-key columns.
pub fn lookup_or_insert(
    &mut self,
    table: TableId,
    args: &[QueryEntry],
    default_vals: &[WriteVal],
    dst_col: ColumnId,
) -> Result<Variable, QueryError> {
    let info = self.table_info(table);
    self.validate_keys(table, info, args)?;
    self.validate_vals(table, info, default_vals.iter())?;
    let dst_var = self.qb.new_var();
    let instr = Instr::LookupOrInsertDefault {
        table,
        args: Vec::from(args),
        default: Vec::from(default_vals),
        dst_col,
        dst_var,
    };
    self.qb.instrs.push(instr);
    self.qb.mark_used(args.iter());
    // Only plain query entries can mention variables; the other kinds carry none.
    let default_entries = default_vals.iter().filter_map(|val| match val {
        WriteVal::QueryEntry(qe) => Some(qe),
        WriteVal::IncCounter(_) | WriteVal::CurrentVal(_) => None,
    });
    self.qb.mark_used(default_entries);
    let defined: QueryEntry = dst_var.into();
    self.qb.mark_defined(&defined);
    Ok(dst_var)
}
pub fn lookup_with_default(
&mut self,
table: TableId,
args: &[QueryEntry],
default: QueryEntry,
dst_col: ColumnId,
) -> Result<Variable, QueryError> {
let table_info = self.table_info(table);
self.validate_keys(table, table_info, args)?;
let res = self.qb.new_var();
self.qb.instrs.push(Instr::LookupWithDefault {
table,
args: args.to_vec(),
dst_col,
dst_var: res,
default,
});
self.qb.mark_used(args);
self.qb.mark_used(&[default]);
self.qb.mark_defined(&res.into());
Ok(res)
}
/// Append an `Instr::Lookup` on `table`, keyed by `args`, and return the fresh variable
/// bound to the result taken from `dst_col`.
///
/// Variables in `args` are marked as used; the returned variable is marked as defined.
///
/// # Errors
///
/// Returns `QueryError::KeyArityMismatch` if `args.len()` differs from the table's key
/// count.
pub fn lookup(
    &mut self,
    table: TableId,
    args: &[QueryEntry],
    dst_col: ColumnId,
) -> Result<Variable, QueryError> {
    let info = self.table_info(table);
    self.validate_keys(table, info, args)?;
    let dst_var = self.qb.new_var();
    let instr = Instr::Lookup {
        table,
        args: Vec::from(args),
        dst_col,
        dst_var,
    };
    self.qb.instrs.push(instr);
    self.qb.mark_used(args.iter());
    let defined: QueryEntry = dst_var.into();
    self.qb.mark_defined(&defined);
    Ok(dst_var)
}
/// Append an `Instr::Insert` that writes the full row `vals` into `table`.
///
/// Variables in `vals` are marked as used.
///
/// # Errors
///
/// Returns `QueryError::TableArityMismatch` if `vals.len()` differs from the table's
/// arity.
pub fn insert(&mut self, table: TableId, vals: &[QueryEntry]) -> Result<(), QueryError> {
    let info = self.table_info(table);
    self.validate_row(table, info, vals)?;
    let row = Vec::from(vals);
    self.qb.instrs.push(Instr::Insert { table, vals: row });
    self.qb.mark_used(vals.iter());
    Ok(())
}
/// Append an `Instr::InsertIfEq` carrying the row `vals` for `table` together with the
/// two entries `l` and `r` that the instruction compares.
///
/// Variables in `vals`, `l` and `r` are marked as used.
///
/// # Errors
///
/// Returns `QueryError::TableArityMismatch` if `vals.len()` differs from the table's
/// arity.
pub fn insert_if_eq(
    &mut self,
    table: TableId,
    l: QueryEntry,
    r: QueryEntry,
    vals: &[QueryEntry],
) -> Result<(), QueryError> {
    let info = self.table_info(table);
    self.validate_row(table, info, vals)?;
    let instr = Instr::InsertIfEq {
        table,
        l,
        r,
        vals: Vec::from(vals),
    };
    self.qb.instrs.push(instr);
    let compared = once(&l).chain(once(&r));
    self.qb.mark_used(vals.iter().chain(compared));
    Ok(())
}
/// Append an `Instr::Remove` on `table` for the key `args`.
///
/// Variables in `args` are marked as used.
///
/// # Errors
///
/// Returns `QueryError::KeyArityMismatch` if `args.len()` differs from the table's key
/// count.
pub fn remove(&mut self, table: TableId, args: &[QueryEntry]) -> Result<(), QueryError> {
    let info = self.table_info(table);
    self.validate_keys(table, info, args)?;
    let key = Vec::from(args);
    self.qb.instrs.push(Instr::Remove { table, args: key });
    self.qb.mark_used(args.iter());
    Ok(())
}
/// Append an `Instr::External` invoking `func` on `args` and return the fresh variable
/// that receives its result.
///
/// Variables in `args` are marked as used; the returned variable is marked as defined.
/// The current implementation performs no validation and always returns `Ok`.
pub fn call_external(
    &mut self,
    func: ExternalFunctionId,
    args: &[QueryEntry],
) -> Result<Variable, QueryError> {
    let dst = self.qb.new_var();
    let call = Instr::External {
        func,
        args: Vec::from(args),
        dst,
    };
    self.qb.instrs.push(call);
    self.qb.mark_used(args.iter());
    let defined: QueryEntry = dst.into();
    self.qb.mark_defined(&defined);
    Ok(dst)
}
/// Append an `Instr::LookupWithFallback` that looks `key` up in `table` (result taken
/// from `dst_col`) and carries the external function `func` with `func_args` as the
/// instruction's fallback. Returns the fresh variable that receives the result.
///
/// Variables in `key` and `func_args` are marked as used; the returned variable is
/// marked as defined.
///
/// # Errors
///
/// Returns `QueryError::KeyArityMismatch` if `key.len()` differs from the table's key
/// count.
pub fn lookup_with_fallback(
    &mut self,
    table: TableId,
    key: &[QueryEntry],
    dst_col: ColumnId,
    func: ExternalFunctionId,
    func_args: &[QueryEntry],
) -> Result<Variable, QueryError> {
    let info = self.table_info(table);
    self.validate_keys(table, info, key)?;
    let dst_var = self.qb.new_var();
    let instr = Instr::LookupWithFallback {
        table,
        table_key: Vec::from(key),
        func,
        func_args: Vec::from(func_args),
        dst_var,
        dst_col,
    };
    self.qb.instrs.push(instr);
    self.qb.mark_used(key.iter().chain(func_args));
    let defined: QueryEntry = dst_var.into();
    self.qb.mark_defined(&defined);
    Ok(dst_var)
}
/// Append an `Instr::ExternalWithFallback` carrying two external calls (`f1` on `args1`
/// and `f2` on `args2`) and return the fresh variable that receives the result.
///
/// Variables in `args1` and `args2` are marked as used; the returned variable is marked
/// as defined. The current implementation performs no validation and always returns
/// `Ok`.
pub fn call_external_with_fallback(
    &mut self,
    f1: ExternalFunctionId,
    args1: &[QueryEntry],
    f2: ExternalFunctionId,
    args2: &[QueryEntry],
) -> Result<Variable, QueryError> {
    let dst = self.qb.new_var();
    let call = Instr::ExternalWithFallback {
        f1,
        args1: Vec::from(args1),
        f2,
        args2: Vec::from(args2),
        dst,
    };
    self.qb.instrs.push(call);
    self.qb.mark_used(args1.iter().chain(args2));
    let defined: QueryEntry = dst.into();
    self.qb.mark_defined(&defined);
    Ok(dst)
}
/// Append an `Instr::AssertEq` over `l` and `r`, marking any variables among them as
/// used.
pub fn assert_eq(&mut self, l: QueryEntry, r: QueryEntry) {
    let operands = [l, r];
    self.qb.instrs.push(Instr::AssertEq(l, r));
    self.qb.mark_used(&operands);
}
/// Append an `Instr::AssertNe` over `l` and `r`, marking any variables among them as
/// used. The current implementation always returns `Ok`.
pub fn assert_ne(&mut self, l: QueryEntry, r: QueryEntry) -> Result<(), QueryError> {
    let operands = [l, r];
    self.qb.instrs.push(Instr::AssertNe(l, r));
    self.qb.mark_used(&operands);
    Ok(())
}
/// Append an `Instr::AssertAnyNe` comparing the group `l` against the group `r`.
///
/// The instruction stores both groups in one vector (`l` followed by `r`) plus the
/// index at which `r` starts. Variables in both groups are marked as used.
///
/// # Errors
///
/// Returns `QueryError::MultiComparisonMismatch` if the two groups differ in length.
pub fn assert_any_ne(&mut self, l: &[QueryEntry], r: &[QueryEntry]) -> Result<(), QueryError> {
    let divider = l.len();
    if divider != r.len() {
        return Err(QueryError::MultiComparisonMismatch {
            l: divider,
            r: r.len(),
        });
    }
    let ops = [l, r].concat();
    self.qb.instrs.push(Instr::AssertAnyNe { ops, divider });
    self.qb.mark_used(l.iter().chain(r));
    Ok(())
}
/// Check that `vals` supplies exactly one entry per column of the table.
///
/// # Errors
///
/// Returns `QueryError::TableArityMismatch` if `vals.len()` differs from the table's
/// arity.
fn validate_row(
    &self,
    table: TableId,
    info: &TableInfo,
    vals: &[QueryEntry],
) -> Result<(), QueryError> {
    let expected = info.spec.arity();
    let got = vals.len();
    if got == expected {
        return Ok(());
    }
    Err(QueryError::TableArityMismatch {
        table,
        expected,
        got,
    })
}
/// Check that `keys` supplies exactly one entry per key column of the table.
///
/// # Errors
///
/// Returns `QueryError::KeyArityMismatch` if `keys.len()` differs from the table's
/// `n_keys`.
fn validate_keys(
    &self,
    table: TableId,
    info: &TableInfo,
    keys: &[QueryEntry],
) -> Result<(), QueryError> {
    let expected = info.spec.n_keys;
    let got = keys.len();
    if got == expected {
        return Ok(());
    }
    Err(QueryError::KeyArityMismatch {
        table,
        expected,
        got,
    })
}
/// Check that `vals`, which are written to the columns following the table's keys, fit
/// within the table's arity.
///
/// Supplying fewer values than there are non-key columns is accepted; only an overflow
/// is rejected.
///
/// # Errors
///
/// Returns `QueryError::TableArityMismatch` when the keys plus the supplied values
/// would occupy more columns than the table has. `got` is the total number of columns
/// that would be written (`n_keys` plus the number of values), so the message shows by
/// how much the arity was exceeded.
fn validate_vals<'b>(
    &self,
    table: TableId,
    info: &TableInfo,
    vals: impl Iterator<Item = &'b WriteVal>,
) -> Result<(), QueryError> {
    let n_keys = info.spec.n_keys;
    let arity = info.spec.arity();
    let n_vals = vals.count();
    // `saturating_sub` means an empty `vals` is always accepted, even for a spec whose
    // key count exceeds its arity.
    if n_vals > arity.saturating_sub(n_keys) {
        return Err(QueryError::TableArityMismatch {
            table,
            expected: arity,
            got: n_keys + n_vals,
        });
    }
    Ok(())
}
}
/// One atom of a query: a reference to a table with query variables bound to (some of)
/// its columns. Built by `QueryBuilder::add_atom`.
#[derive(Debug, Clone)]
pub(crate) struct Atom {
/// The table this atom ranges over.
pub(crate) table: TableId,
/// For each variable in the atom, the column it is bound to. When a variable appears in
/// several columns, this holds the last one and the equality between the columns is
/// recorded in `constraints.slow`.
pub(crate) var_to_column: HashMap<Variable, ColumnId>,
/// For each column holding a (non-placeholder) variable, that variable.
pub(crate) column_to_var: DenseIdMap<ColumnId, Variable>,
/// The atom's constraints as processed by the database, plus any intra-atom column
/// equalities added for repeated variables.
pub(crate) constraints: ProcessedConstraints,
}
/// The query being assembled by a `QueryBuilder`, handed to `Database::plan_query` when
/// the rule is built.
pub(crate) struct Query {
/// Per-variable bookkeeping: occurrences in atoms, use/definition by the action, and
/// the optional name.
pub(crate) var_info: DenseIdMap<Variable, VarInfo>,
/// The atoms added so far.
pub(crate) atoms: DenseIdMap<AtomId, Atom>,
/// The action to run for matches. Initialised to a placeholder (`u32::MAX`) by
/// `RuleSetBuilder::new_rule` and set to the real id in
/// `RuleBuilder::build_with_description`.
pub(crate) action: ActionId,
/// The strategy passed along to the planner.
pub(crate) plan_strategy: PlanStrategy,
}