use std::{
ops::Deref,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
};
use crate::{
common::HashMap,
free_join::{invoke_batch, invoke_batch_assign},
numeric_id::{DenseIdMap, NumericId},
};
use egglog_concurrency::NotificationList;
use smallvec::SmallVec;
use crate::{
BaseValues, ContainerValues, ExternalFunctionId, WrappedTable,
common::Value,
free_join::{CounterId, Counters, ExternalFunctions, TableId, TableInfo, Variable},
pool::{Clear, Pooled, with_pool_set},
table_spec::{ColumnId, MutationBuffer},
};
use self::mask::{Mask, MaskIter, ValueSource};
#[macro_use]
pub(crate) mod mask;
#[cfg(test)]
mod tests;
/// An instruction operand: either a query variable (read from the current bindings) or a
/// literal value.
#[derive(Clone, Copy, Debug)]
pub enum QueryEntry {
    /// Use the value bound to this variable in the row being processed.
    Var(Variable),
    /// Use this value directly.
    Const(Value),
}

impl From<Variable> for QueryEntry {
    fn from(variable: Variable) -> Self {
        Self::Var(variable)
    }
}

impl From<Value> for QueryEntry {
    fn from(value: Value) -> Self {
        Self::Const(value)
    }
}
/// A value to place in a column of a row created by `Instr::LookupOrInsertDefault`.
#[derive(Clone, Copy, Debug)]
pub enum WriteVal {
    /// Use this operand: the value bound to a variable, or a constant.
    QueryEntry(QueryEntry),
    /// Increment this counter and use the value returned by `Counters::inc`.
    IncCounter(CounterId),
    /// Copy the value already placed at this column index of the row being built.
    CurrentVal(usize),
}

/// Anything convertible to a `QueryEntry` (a `Variable`, a `Value`, or a `QueryEntry`
/// itself) becomes `WriteVal::QueryEntry`.
impl<T: Into<QueryEntry>> From<T> for WriteVal {
    fn from(entry: T) -> Self {
        Self::QueryEntry(entry.into())
    }
}

impl From<CounterId> for WriteVal {
    fn from(counter: CounterId) -> Self {
        Self::IncCounter(counter)
    }
}
/// A value for a non-key column of a row built by `ExecutionState::predict_val` /
/// `ExecutionState::predict_col` when the key is absent from the table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MergeVal {
    /// Increment this counter and use the value returned by `Counters::inc`.
    Counter(CounterId),
    /// Use this value as-is.
    Constant(Value),
}

impl From<CounterId> for MergeVal {
    fn from(counter: CounterId) -> Self {
        Self::Counter(counter)
    }
}

impl From<Value> for MergeVal {
    fn from(value: Value) -> Self {
        Self::Constant(value)
    }
}
/// Column-oriented storage for a batch of variable bindings.
///
/// Each bound variable owns a run of `max_batch_size` slots in `data` starting at the offset
/// recorded in `var_offsets`; only the first `matches` slots of each run hold live values.
pub(crate) struct Bindings {
    /// Number of live rows in the batch.
    matches: usize,
    /// Number of slots reserved for each variable's column.
    max_batch_size: usize,
    /// Backing storage shared by all variable columns.
    data: Pooled<Vec<Value>>,
    /// Start offset of each bound variable's column within `data`.
    var_offsets: DenseIdMap<Variable, usize>,
}

impl std::ops::Index<Variable> for Bindings {
    type Output = [Value];

    /// Returns the live values bound to `variable`; panics if it is unbound.
    fn index(&self, variable: Variable) -> &[Value] {
        let column = self.get(variable);
        column.unwrap()
    }
}

impl std::fmt::Debug for Bindings {
    /// Formats the bindings as a map from each variable to its live values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let live = self.matches;
        f.debug_map()
            .entries(
                self.var_offsets
                    .iter()
                    .map(|(var, start)| (var, &self.data[*start..*start + live])),
            )
            .finish()
    }
}
impl Bindings {
    /// Creates empty bindings; every variable bound later gets a column of
    /// `max_batch_size` slots.
    pub(crate) fn new(max_batch_size: usize) -> Self {
        Bindings {
            matches: 0,
            max_batch_size,
            data: Default::default(),
            var_offsets: DenseIdMap::new(),
        }
    }

    /// Debug-only check that the batch fits in `max_batch_size` and that every bound
    /// variable's live values lie within `data`. Compiles to nothing in release builds.
    fn assert_invariant(&self) {
        #[cfg(debug_assertions)]
        {
            assert!(self.matches <= self.max_batch_size);
            for (var, start) in self.var_offsets.iter() {
                assert!(
                    start + self.matches <= self.data.len(),
                    "Variable {:?} starts at {}, but data only has {} elements",
                    var,
                    start,
                    self.data.len()
                );
            }
        }
    }

    /// Drops every binding and resets the batch to zero rows.
    pub(crate) fn clear(&mut self) {
        self.matches = 0;
        self.var_offsets.clear();
        self.data.clear();
        self.assert_invariant();
    }

    /// Returns the live values bound to `var`, or `None` if `var` is unbound.
    fn get(&self, var: Variable) -> Option<&[Value]> {
        let start = self.var_offsets.get(var)?;
        Some(&self.data[*start..*start + self.matches])
    }

    /// Appends a new column for `var` holding `vals`, padded with `Value::stale()` up to
    /// `max_batch_size` slots so that `push` can later write into it in place.
    /// Does not update `matches`.
    fn add_mapping(&mut self, var: Variable, vals: &[Value]) {
        let start = self.data.len();
        self.data.extend_from_slice(vals);
        debug_assert!(vals.len() <= self.max_batch_size);
        if vals.len() < self.max_batch_size {
            let target_len = self.data.len() + self.max_batch_size - vals.len();
            self.data.resize(target_len, Value::stale());
        }
        self.var_offsets.insert(var, start);
    }

    /// Binds `var` to `vals`, one value per row.
    ///
    /// When `var_offsets` reports no ids yet, this binding fixes the batch size to
    /// `vals.len()`; later bindings must supply exactly that many values.
    ///
    /// # Panics
    /// Panics if `vals.len()` differs from the established number of rows.
    pub(crate) fn insert(&mut self, var: Variable, vals: &[Value]) {
        if self.var_offsets.n_ids() == 0 {
            self.matches = vals.len();
        } else {
            assert_eq!(self.matches, vals.len());
        }
        self.add_mapping(var, vals);
        self.assert_invariant();
    }

    /// Appends one row, taking each variable's value from `map`.
    ///
    /// If the batch is empty, a column is created for every variable in `map` and
    /// `used_vars` is ignored. Otherwise only the variables in `used_vars` are written,
    /// into slot `matches` of their existing columns.
    ///
    /// # Panics
    /// Panics if the batch is non-empty and already holds `max_batch_size` rows.
    ///
    /// # Safety
    /// When the batch is non-empty, every variable in `used_vars` must already be bound in
    /// `self` and must have a value in `map`: both lookups use unchecked indexing and
    /// `unwrap_unchecked`. Debug builds assert both conditions.
    pub(crate) unsafe fn push(
        &mut self,
        map: &DenseIdMap<Variable, Value>,
        used_vars: &[Variable],
    ) {
        if self.matches != 0 {
            assert!(self.matches < self.max_batch_size);
            #[cfg(debug_assertions)]
            {
                for var in used_vars {
                    assert!(
                        self.var_offsets.get(*var).is_some(),
                        "Variable {:?} not found in bindings {:?}",
                        var,
                        self.var_offsets
                    );
                    // The unchecked read from `map` below is just as dangerous as the one
                    // from `var_offsets`, so validate it too.
                    assert!(
                        map.get(*var).is_some(),
                        "Variable {:?} has no value in the row being pushed",
                        var
                    );
                }
            }
            for var in used_vars {
                let var = var.index();
                // SAFETY: per this function's contract, `var` is bound in both
                // `self.var_offsets` and `map`, so both raw lookups hit a `Some` entry. The
                // variable's column spans `max_batch_size` slots from `start` (see
                // `add_mapping`), and `self.matches < self.max_batch_size` was asserted
                // above, so the write stays inside that column.
                unsafe {
                    let start = self.var_offsets.raw().get_unchecked(var).unwrap_unchecked();
                    *self.data.get_unchecked_mut(start + self.matches) =
                        map.raw().get_unchecked(var).unwrap_unchecked();
                }
            }
        } else {
            for (var, val) in map.iter() {
                self.add_mapping(var, &[*val]);
            }
        }
        self.matches += 1;
        self.assert_invariant();
    }

    /// Unbinds `var`, returning a copy of its live values together with its column offset
    /// so that `replace` can write them back. Returns `None` if `var` is unbound.
    ///
    /// The column's storage in `data` stays in place; only the offset mapping is removed.
    pub(crate) fn take(&mut self, var: Variable) -> Option<ExtractedBinding> {
        // Resolve the column first so no pooled vector is fetched for an unbound variable.
        let live = self.get(var)?;
        let mut vals: Pooled<Vec<Value>> = with_pool_set(|ps| ps.get());
        vals.extend_from_slice(live);
        let start = self.var_offsets.take(var)?;
        Some(ExtractedBinding {
            var,
            offset: start,
            vals,
        })
    }

    /// Writes an `ExtractedBinding` obtained from `take` back into its original column and
    /// re-binds its variable.
    ///
    /// # Panics
    /// Panics if the binding's length differs from the current number of rows.
    pub(crate) fn replace(&mut self, bdg: ExtractedBinding) {
        let ExtractedBinding {
            var,
            offset,
            mut vals,
        } = bdg;
        assert_eq!(vals.len(), self.matches);
        self.data
            .splice(offset..offset + self.matches, vals.drain(..));
        self.var_offsets.insert(var, offset);
        self.assert_invariant();
    }
}
/// A variable's column removed from `Bindings` by `Bindings::take`, to be written back with
/// `Bindings::replace`.
pub(crate) struct ExtractedBinding {
    /// The variable that was unbound.
    var: Variable,
    /// Start of the variable's column in the bindings' backing storage; `replace` writes
    /// the values back starting here.
    offset: usize,
    /// Copy of the column's live values; callers may overwrite entries before the
    /// binding is written back.
    pub(crate) vals: Pooled<Vec<Value>>,
}
/// Cache of rows this execution state has created for keys that were missing from their
/// tables, so that repeated requests for the same `(table, key)` reuse a single row.
#[derive(Default)]
pub(crate) struct PredictedVals {
    /// Maps `(table, key)` to the full predicted row (key columns first).
    // The nested key type trips clippy's complexity heuristic; a type alias would only
    // be used here.
    #[allow(clippy::type_complexity)]
    data: HashMap<(TableId, SmallVec<[Value; 3]>), Pooled<Vec<Value>>>,
}

impl Clear for PredictedVals {
    fn reuse(&self) -> bool {
        self.data.capacity() != 0
    }

    fn clear(&mut self) {
        self.data.clear();
    }

    /// Approximate footprint: allocated entry slots times the size of one key/value pair.
    fn bytes(&self) -> usize {
        let key_size = std::mem::size_of::<(TableId, SmallVec<[Value; 3]>)>();
        let val_size = std::mem::size_of::<Pooled<Vec<Value>>>();
        (key_size + val_size) * self.data.capacity()
    }
}

impl PredictedVals {
    /// Returns the cached row for `key` in `table`, calling `default` to create and cache
    /// one if none exists yet.
    pub(crate) fn get_val(
        &mut self,
        table: TableId,
        key: &[Value],
        default: impl FnOnce() -> Pooled<Vec<Value>>,
    ) -> impl Deref<Target = Pooled<Vec<Value>>> + '_ {
        let cache_key: (TableId, SmallVec<[Value; 3]>) = (table, SmallVec::from_slice(key));
        self.data.entry(cache_key).or_insert_with(default)
    }
}
/// Copyable bundle of shared references to the database pieces that action execution uses.
#[derive(Copy, Clone)]
pub(crate) struct DbView<'a> {
    /// Per-table metadata; each entry exposes the table itself via `.table`.
    pub(crate) table_info: &'a DenseIdMap<TableId, TableInfo>,
    /// Counters that are read (`read`) and incremented (`inc`) during execution.
    pub(crate) counters: &'a Counters,
    /// External functions, indexed by `ExternalFunctionId`.
    pub(crate) external_funcs: &'a ExternalFunctions,
    /// Base values, exposed via `ExecutionState::base_values`.
    pub(crate) bases: &'a BaseValues,
    /// Container values, exposed via `ExecutionState::container_values`.
    pub(crate) containers: &'a ContainerValues,
    /// Notified with a table's id each time a mutation is staged for that table.
    pub(crate) notification_list: &'a NotificationList<TableId>,
}
/// State threaded through action execution: a view of the database, buffers for staged
/// mutations, a cache of predicted rows, and flags for change tracking and early stopping.
pub struct ExecutionState<'a> {
    /// Rows created for keys missing from their tables, reused when the same key recurs.
    pub(crate) predicted: PredictedVals,
    /// The database this state reads from.
    pub(crate) db: DbView<'a>,
    /// Per-table mutation buffers; a table gets a buffer lazily on its first write.
    buffers: MutationBuffers<'a>,
    /// Set to `true` by `stage_insert`, `stage_remove`, and `construct_new_row`.
    pub(crate) changed: bool,
    /// Early-stop flag. `Clone` shares it, so `trigger_early_stop` on any clone is seen by
    /// `should_stop` on all of them.
    stop_match: Arc<AtomicBool>,
}
/// Per-table mutation buffers plus the list that is notified whenever a change is staged.
struct MutationBuffers<'a> {
    notify_list: &'a NotificationList<TableId>,
    buffers: DenseIdMap<TableId, Box<dyn MutationBuffer>>,
}

impl Clone for MutationBuffers<'_> {
    /// Shares the notification list and obtains a fresh handle (`fresh_handle`) for every
    /// table that currently has a buffer.
    fn clone(&self) -> Self {
        let mut copy = Self::new(self.notify_list, Default::default());
        self.buffers.iter().for_each(|(table, buf)| {
            copy.buffers.insert(table, buf.fresh_handle());
        });
        copy
    }
}

impl<'a> MutationBuffers<'a> {
    /// Wraps already-created `buffers`, announcing staged changes on `notify_list`.
    fn new(
        notify_list: &'a NotificationList<TableId>,
        buffers: DenseIdMap<TableId, Box<dyn MutationBuffer>>,
    ) -> MutationBuffers<'a> {
        Self {
            buffers,
            notify_list,
        }
    }

    /// Creates the buffer for `table_id` with `f` unless one already exists.
    fn lazy_init(&mut self, table_id: TableId, f: impl FnOnce() -> Box<dyn MutationBuffer>) {
        self.buffers.get_or_insert(table_id, f);
    }

    /// Stages `row` for insertion, then notifies the list. The buffer must already exist.
    fn stage_insert(&mut self, table_id: TableId, row: &[Value]) {
        self.buffers[table_id].stage_insert(row);
        self.notify_list.notify(table_id);
    }

    /// Stages removal of `key`, then notifies the list. The buffer must already exist.
    fn stage_remove(&mut self, table_id: TableId, key: &[Value]) {
        self.buffers[table_id].stage_remove(key);
        self.notify_list.notify(table_id);
    }
}
impl Clone for ExecutionState<'_> {
fn clone(&self) -> Self {
ExecutionState {
predicted: Default::default(),
db: self.db,
buffers: self.buffers.clone(),
changed: false,
stop_match: Arc::clone(&self.stop_match),
}
}
}
impl<'a> ExecutionState<'a> {
    /// Creates a state that reads from `db` and stages mutations into `buffers`. Tables
    /// without an entry in `buffers` get one lazily, on their first write.
    ///
    /// The new state starts with an empty prediction cache, `changed == false`, and its own
    /// unset early-stop flag.
    pub(crate) fn new(
        db: DbView<'a>,
        buffers: DenseIdMap<TableId, Box<dyn MutationBuffer>>,
    ) -> Self {
        ExecutionState {
            predicted: Default::default(),
            db,
            buffers: MutationBuffers::new(db.notification_list, buffers),
            changed: false,
            stop_match: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Stages `row` for insertion into `table` (creating the table's buffer on first use)
    /// and marks this state as changed.
    pub fn stage_insert(&mut self, table: TableId, row: &[Value]) {
        self.buffers
            .lazy_init(table, || self.db.table_info[table].table.new_buffer());
        self.buffers.stage_insert(table, row);
        self.changed = true;
    }

    /// Stages removal of the row keyed by `key` from `table` (creating the table's buffer
    /// on first use) and marks this state as changed.
    pub fn stage_remove(&mut self, table: TableId, key: &[Value]) {
        self.buffers
            .lazy_init(table, || self.db.table_info[table].table.new_buffer());
        self.buffers.stage_remove(table, key);
        self.changed = true;
    }

    /// Invokes external function `func` on `args` and returns its result.
    ///
    /// This state is passed to the function, so it can itself read the database or stage
    /// changes.
    pub fn call_external_func(
        &mut self,
        func: ExternalFunctionId,
        args: &[Value],
    ) -> Option<Value> {
        self.db.external_funcs[func].invoke(self, args)
    }

    /// Increments counter `ctr` and returns the value reported by `Counters::inc`.
    pub fn inc_counter(&self, ctr: CounterId) -> usize {
        self.db.counters.inc(ctr)
    }

    /// Reads counter `ctr` without modifying it.
    pub fn read_counter(&self, ctr: CounterId) -> usize {
        self.db.counters.read(ctr)
    }

    /// Iterates over the ids of every table in the database view.
    pub fn table_ids(&self) -> impl Iterator<Item = TableId> + '_ {
        self.db.table_info.iter().map(|(id, _)| id)
    }

    /// Returns table `table`. It indexes `table_info` directly, so `table` is presumably
    /// required to be registered.
    pub fn get_table(&self, table: TableId) -> &'a WrappedTable {
        &self.db.table_info[table].table
    }

    /// Returns the name recorded for `table`, if any.
    pub fn table_name(&self, table: TableId) -> Option<&'a str> {
        self.db.table_info[table].name()
    }

    /// Returns the database's base values.
    ///
    /// The reference lives as long as the database view (`'a`), like the other accessors.
    /// Callers may therefore keep it while mutating this state.
    pub fn base_values(&self) -> &'a BaseValues {
        self.db.bases
    }

    /// Returns the database's container values.
    pub fn container_values(&self) -> &'a ContainerValues {
        self.db.containers
    }

    /// Returns the row for `key` in `table`, creating one if needed.
    ///
    /// If the table already stores a row for `key`, its `vals` are returned. Otherwise the
    /// row this state previously predicted for `(table, key)` is reused. Failing that, a
    /// new row is built from `key` followed by `vals` (see `construct_new_row`); it is
    /// staged for insertion and cached. The result is a pooled copy of the cached row.
    pub fn predict_val(
        &mut self,
        table: TableId,
        key: &[Value],
        vals: impl ExactSizeIterator<Item = MergeVal>,
    ) -> Pooled<Vec<Value>> {
        if let Some(row) = self.db.table_info[table].table.get_row(key) {
            return row.vals;
        }
        Pooled::cloned(
            self.predicted
                .get_val(table, key, || {
                    Self::construct_new_row(
                        &self.db,
                        &mut self.buffers,
                        &mut self.changed,
                        table,
                        key,
                        vals,
                    )
                })
                .deref(),
        )
    }

    /// Builds the row `key` followed by `vals` and stages it for insertion into `table`.
    ///
    /// Each `MergeVal::Counter` is resolved by incrementing that counter. The function
    /// also sets `*changed`.
    ///
    /// It takes the state's fields separately rather than `&mut self` so it can run inside
    /// a closure while `self.predicted` is mutably borrowed.
    fn construct_new_row(
        db: &DbView,
        buffers: &mut MutationBuffers,
        changed: &mut bool,
        table: TableId,
        key: &[Value],
        vals: impl ExactSizeIterator<Item = MergeVal>,
    ) -> Pooled<Vec<Value>> {
        with_pool_set(|ps| {
            let mut new = ps.get::<Vec<Value>>();
            new.reserve(key.len() + vals.len());
            new.extend_from_slice(key);
            for val in vals {
                new.push(match val {
                    MergeVal::Counter(ctr) => Value::from_usize(db.counters.inc(ctr)),
                    MergeVal::Constant(c) => c,
                })
            }
            buffers.lazy_init(table, || db.table_info[table].table.new_buffer());
            buffers.stage_insert(table, &new);
            *changed = true;
            new
        })
    }

    /// Like `predict_val`, but returns only column `col` of the row.
    ///
    /// For a key the table already stores, it reads the column with `get_row_column`.
    pub fn predict_col(
        &mut self,
        table: TableId,
        key: &[Value],
        vals: impl ExactSizeIterator<Item = MergeVal>,
        col: ColumnId,
    ) -> Value {
        if let Some(val) = self.db.table_info[table].table.get_row_column(key, col) {
            return val;
        }
        self.predicted.get_val(table, key, || {
            Self::construct_new_row(
                &self.db,
                &mut self.buffers,
                &mut self.changed,
                table,
                key,
                vals,
            )
        })[col.index()]
    }

    /// Sets the early-stop flag shared by this state and all of its clones.
    pub fn trigger_early_stop(&self) {
        self.stop_match.store(true, Ordering::Release);
    }

    /// Reports whether `trigger_early_stop` has been called on this state or any clone.
    pub fn should_stop(&self) -> bool {
        self.stop_match.load(Ordering::Acquire)
    }
}
impl ExecutionState<'_> {
    /// Runs `instrs` in order over the batch in `bindings`. Returns the number of rows still
    /// live in the mask after the last instruction.
    ///
    /// If `bindings` has allocated no variable ids yet (per `var_offsets.next_id()`), the
    /// batch is treated as one row. That way, instructions over constants still run once.
    /// Execution stops and returns 0 as soon as the mask becomes empty.
    pub(crate) fn run_instrs(&mut self, instrs: &[Instr], bindings: &mut Bindings) -> usize {
        if bindings.var_offsets.next_id().rep() == 0 {
            bindings.matches = 1;
        }
        let mut mask = with_pool_set(|ps| Mask::new(0..bindings.matches, ps));
        for instr in instrs {
            if mask.is_empty() {
                return 0;
            }
            self.run_instr(&mut mask, instr, bindings);
        }
        mask.count_ones()
    }

    /// Executes a single instruction over the rows selected by `mask`.
    ///
    /// Filtering instructions narrow `mask`. Lookups and external calls bind destination
    /// variables in `bindings`. Insert/remove instructions stage mutations.
    fn run_instr(&mut self, mask: &mut Mask, inst: &Instr, bindings: &mut Bindings) {
        /// Retains only the rows of `mask` where `op(l, r)` holds. When both sides are
        /// constants, the comparison is made once and either keeps or clears the whole mask.
        fn assert_impl(
            bindings: &mut Bindings,
            mask: &mut Mask,
            l: &QueryEntry,
            r: &QueryEntry,
            op: impl Fn(Value, Value) -> bool,
        ) {
            match (l, r) {
                (QueryEntry::Var(v1), QueryEntry::Var(v2)) => {
                    mask.iter(&bindings[*v1])
                        .zip(&bindings[*v2])
                        .retain(|(v1, v2)| op(*v1, *v2));
                }
                (QueryEntry::Var(v), QueryEntry::Const(c))
                | (QueryEntry::Const(c), QueryEntry::Var(v)) => {
                    mask.iter(&bindings[*v]).retain(|v| op(*v, *c));
                }
                (QueryEntry::Const(c1), QueryEntry::Const(c2)) => {
                    if !op(*c1, *c2) {
                        mask.clear();
                    }
                }
            }
        }
        match inst {
            Instr::LookupOrInsertDefault {
                table: table_id,
                args,
                default,
                dst_col,
                dst_var,
            } => {
                let pool = with_pool_set(|ps| ps.get_pool::<Vec<Value>>().clone());
                self.buffers.lazy_init(*table_id, || {
                    self.db.table_info[*table_id].table.new_buffer()
                });
                let table = &self.db.table_info[*table_id].table;
                // First pass: a vectorized lookup on a copy of the mask. The copy is then
                // turned into the set of rows in `mask` whose key was *not* found.
                // `mask` itself is never narrowed: every row ends up with a value.
                let mut mask_copy = mask.clone();
                table.lookup_row_vectorized(&mut mask_copy, bindings, args, *dst_col, *dst_var);
                mask_copy.symmetric_difference(mask);
                if mask_copy.is_empty() {
                    return;
                }
                // Second pass: fill `dst_var` for the missing rows. Each missing key is
                // resolved through the prediction cache, so a key that repeats within (or
                // across) batches is only built and staged once.
                let mut out = bindings.take(*dst_var).unwrap();
                for_each_binding_with_mask!(mask_copy, args.as_slice(), bindings, |iter| {
                    iter.assign_vec(&mut out.vals, |offset, key| {
                        let prediction_key = (
                            *table_id,
                            SmallVec::<[Value; 3]>::from_slice(key.as_slice()),
                        );
                        // Bind disjoint field borrows up front: the `or_insert_with` closure
                        // is `move` and runs while `self.predicted` is mutably borrowed.
                        let buffers = &mut self.buffers;
                        let changed = &mut self.changed;
                        let ctrs = &self.db.counters;
                        let bindings = &bindings;
                        let pool = pool.clone();
                        let row =
                            self.predicted
                                .data
                                .entry(prediction_key)
                                .or_insert_with(move || {
                                    let mut row = pool.get();
                                    row.extend_from_slice(key.as_slice());
                                    row.reserve(default.len());
                                    for val in default {
                                        let val = match val {
                                            WriteVal::QueryEntry(QueryEntry::Const(c)) => *c,
                                            WriteVal::QueryEntry(QueryEntry::Var(v)) => {
                                                bindings[*v][offset]
                                            }
                                            WriteVal::IncCounter(ctr) => {
                                                Value::from_usize(ctrs.inc(*ctr))
                                            }
                                            WriteVal::CurrentVal(ix) => row[*ix],
                                        };
                                        row.push(val)
                                    }
                                    buffers.stage_insert(*table_id, &row);
                                    // Staging a new row changes the database. Record it here
                                    // exactly as `stage_insert` and `construct_new_row` do,
                                    // since this path bypasses both.
                                    *changed = true;
                                    row
                                });
                        row[dst_col.index()]
                    });
                });
                bindings.replace(out);
            }
            Instr::LookupWithDefault {
                table,
                args,
                dst_col,
                dst_var,
                default,
            } => {
                let table = &self.db.table_info[*table].table;
                table.lookup_with_default_vectorized(
                    mask, bindings, args, *dst_col, *default, *dst_var,
                );
            }
            Instr::Lookup {
                table,
                args,
                dst_col,
                dst_var,
            } => {
                let table = &self.db.table_info[*table].table;
                table.lookup_row_vectorized(mask, bindings, args, *dst_col, *dst_var);
            }
            Instr::LookupWithFallback {
                table: table_id,
                table_key,
                func,
                func_args,
                dst_col,
                dst_var,
            } => {
                let table = &self.db.table_info[*table_id].table;
                let mut lookup_result = mask.clone();
                table.lookup_row_vectorized(
                    &mut lookup_result,
                    bindings,
                    table_key,
                    *dst_col,
                    *dst_var,
                );
                // Rows the lookup did not keep are handed to the fallback function.
                let mut to_call_func = lookup_result.clone();
                to_call_func.symmetric_difference(mask);
                if to_call_func.is_empty() {
                    return;
                }
                invoke_batch_assign(
                    self.db.external_funcs[*func].as_ref(),
                    self,
                    &mut to_call_func,
                    bindings,
                    func_args,
                    *dst_var,
                );
                // Surviving rows: found by the lookup, or kept by the fallback call.
                lookup_result.union(&to_call_func);
                *mask = lookup_result;
            }
            Instr::Insert { table, vals } => {
                for_each_binding_with_mask!(mask, vals.as_slice(), bindings, |iter| {
                    iter.for_each(|vals| {
                        self.stage_insert(*table, vals.as_slice());
                    })
                });
            }
            Instr::InsertIfEq { table, l, r, vals } => match (l, r) {
                (QueryEntry::Var(v1), QueryEntry::Var(v2)) => {
                    for_each_binding_with_mask!(mask, vals.as_slice(), bindings, |iter| {
                        iter.zip(&bindings[*v1])
                            .zip(&bindings[*v2])
                            .for_each(|((vals, v1), v2)| {
                                if v1 == v2 {
                                    self.stage_insert(*table, &vals);
                                }
                            })
                    })
                }
                (QueryEntry::Var(v), QueryEntry::Const(c))
                | (QueryEntry::Const(c), QueryEntry::Var(v)) => {
                    for_each_binding_with_mask!(mask, vals.as_slice(), bindings, |iter| {
                        iter.zip(&bindings[*v]).for_each(|(vals, cond)| {
                            if cond == c {
                                self.stage_insert(*table, &vals);
                            }
                        })
                    })
                }
                (QueryEntry::Const(c1), QueryEntry::Const(c2)) => {
                    if c1 == c2 {
                        for_each_binding_with_mask!(mask, vals.as_slice(), bindings, |iter| iter
                            .for_each(|vals| {
                                self.stage_insert(*table, &vals);
                            }))
                    }
                }
            },
            Instr::Remove { table, args } => {
                for_each_binding_with_mask!(mask, args.as_slice(), bindings, |iter| {
                    iter.for_each(|args| {
                        self.stage_remove(*table, args.as_slice());
                    })
                });
            }
            Instr::External { func, args, dst } => {
                invoke_batch(
                    self.db.external_funcs[*func].as_ref(),
                    self,
                    mask,
                    bindings,
                    args,
                    *dst,
                );
            }
            Instr::ExternalWithFallback {
                f1,
                args1,
                f2,
                args2,
                dst,
            } => {
                let mut f1_result = mask.clone();
                invoke_batch(
                    self.db.external_funcs[*f1].as_ref(),
                    self,
                    &mut f1_result,
                    bindings,
                    args1,
                    *dst,
                );
                // Rows `f1` did not keep get a second chance with `f2`.
                let mut to_call_f2 = f1_result.clone();
                to_call_f2.symmetric_difference(mask);
                if to_call_f2.is_empty() {
                    return;
                }
                invoke_batch_assign(
                    self.db.external_funcs[*f2].as_ref(),
                    self,
                    &mut to_call_f2,
                    bindings,
                    args2,
                    *dst,
                );
                f1_result.union(&to_call_f2);
                *mask = f1_result;
            }
            Instr::AssertAnyNe { ops, divider } => {
                // Pair `ops[i]` with `ops[divider + i]` and keep rows where any pair differs.
                for_each_binding_with_mask!(mask, ops.as_slice(), bindings, |iter| {
                    iter.retain(|vals| {
                        vals[0..*divider]
                            .iter()
                            .zip(&vals[*divider..])
                            .any(|(l, r)| l != r)
                    })
                })
            }
            Instr::AssertEq(l, r) => assert_impl(bindings, mask, l, r, |l, r| l == r),
            Instr::AssertNe(l, r) => assert_impl(bindings, mask, l, r, |l, r| l != r),
            Instr::ReadCounter { counter, dst } => {
                // The counter is read once and the same value is bound for every row.
                let mut vals = with_pool_set(|ps| ps.get::<Vec<Value>>());
                let ctr_val = Value::from_usize(self.read_counter(*counter));
                vals.resize(bindings.matches, ctr_val);
                bindings.insert(*dst, &vals);
            }
        }
    }
}
/// A single vectorized action instruction. `ExecutionState::run_instr` executes it over
/// every row of a batch that is still live in the mask.
#[derive(Debug, Clone)]
pub(crate) enum Instr {
    /// Looks up `args` in `table` and binds column `dst_col` of the matching row to
    /// `dst_var`. For rows whose key is missing, the row `args` followed by `default` is
    /// built, staged for insertion, and cached in the prediction cache. No rows are
    /// filtered out.
    LookupOrInsertDefault {
        table: TableId,
        args: Vec<QueryEntry>,
        /// Values for the remaining columns of a newly built row.
        /// `WriteVal::CurrentVal(i)` copies column `i` of the row being built.
        default: Vec<WriteVal>,
        dst_col: ColumnId,
        dst_var: Variable,
    },
    /// Looks up `args` in `table` and binds column `dst_col` to `dst_var` through
    /// `lookup_with_default_vectorized`. It presumably binds `default` where the key is
    /// absent (TODO confirm).
    LookupWithDefault {
        table: TableId,
        args: Vec<QueryEntry>,
        dst_col: ColumnId,
        dst_var: Variable,
        default: QueryEntry,
    },
    /// Looks up `args` in `table` and binds column `dst_col` to `dst_var` through
    /// `lookup_row_vectorized`. Rows whose key is absent are presumably dropped from the
    /// mask (TODO confirm).
    Lookup {
        table: TableId,
        args: Vec<QueryEntry>,
        dst_col: ColumnId,
        dst_var: Variable,
    },
    /// Like `Lookup` on `table_key`. For rows the lookup did not keep, it instead calls
    /// external function `func` on `func_args` and binds the result to `dst_var`.
    /// Surviving rows are those kept by either step.
    LookupWithFallback {
        table: TableId,
        table_key: Vec<QueryEntry>,
        func: ExternalFunctionId,
        func_args: Vec<QueryEntry>,
        dst_col: ColumnId,
        dst_var: Variable,
    },
    /// Stages `vals` as a row for insertion into `table`, once per live row.
    Insert {
        table: TableId,
        vals: Vec<QueryEntry>,
    },
    /// Like `Insert`, but only for rows where `l` equals `r`. Unlike the assert
    /// instructions, it does not narrow the mask.
    InsertIfEq {
        table: TableId,
        l: QueryEntry,
        r: QueryEntry,
        vals: Vec<QueryEntry>,
    },
    /// Stages removal of the row keyed by `args` from `table`, once per live row.
    Remove {
        table: TableId,
        args: Vec<QueryEntry>,
    },
    /// Calls external function `func` on `args` (via `invoke_batch`) and binds the
    /// result to `dst`.
    External {
        func: ExternalFunctionId,
        args: Vec<QueryEntry>,
        dst: Variable,
    },
    /// Calls `f1` on `args1`. For rows `f1` did not keep, it calls `f2` on `args2`. Both
    /// bind `dst`, and surviving rows are those kept by either call.
    ExternalWithFallback {
        f1: ExternalFunctionId,
        args1: Vec<QueryEntry>,
        f2: ExternalFunctionId,
        args2: Vec<QueryEntry>,
        dst: Variable,
    },
    /// Keeps only rows where the two operands are equal.
    AssertEq(QueryEntry, QueryEntry),
    /// Keeps only rows where the two operands differ.
    AssertNe(QueryEntry, QueryEntry),
    /// Keeps rows where at least one pair differs, pairing `ops[i]` with `ops[divider + i]`.
    AssertAnyNe {
        ops: Vec<QueryEntry>,
        divider: usize,
    },
    /// Reads `counter` once, without incrementing it, and binds that value to `dst` for
    /// every row in the batch.
    ReadCounter {
        counter: CounterId,
        dst: Variable,
    },
}