use core::slice;
use std::{cell::Cell, mem, ops::Deref};
use crate::numeric_id::NumericId;
use egglog_concurrency::{ParallelVecWriter, parallel_writer::write_cell_slice};
use rayon::iter::ParallelIterator;
use smallvec::SmallVec;
use crate::{
common::Value,
offsets::RowId,
pool::{Pooled, with_pool_set},
};
#[cfg(test)]
mod tests;
/// A batch of fixed-arity rows stored back to back in a single flat vector.
///
/// Row `i` occupies cells `i * n_columns .. (i + 1) * n_columns` of `data`.
/// A row counts as stale when the value in its first column reports
/// `is_stale()`.
pub struct RowBuffer {
/// Number of values in every row; `RowBuffer::new` rejects zero.
n_columns: usize,
/// Number of rows the buffer reports through `len`.
total_rows: usize,
/// Flat cell storage for all rows. Using cells lets the first column of a
/// row be overwritten through a shared reference (see `set_stale_shared`).
data: Pooled<Vec<Cell<Value>>>,
}
// `Cell` is not `Sync`, so these impls are asserted by hand rather than derived.
// NOTE(review): soundness presumably relies on callers of the `unsafe`
// shared-mutation methods (e.g. `set_stale_shared`) never racing with another
// access to the same cell — confirm against the call sites.
unsafe impl Send for RowBuffer {}
unsafe impl Sync for RowBuffer {}
impl Clone for RowBuffer {
    /// Copies the column count and row count and duplicates the cell storage
    /// through `Pooled::cloned`.
    fn clone(&self) -> Self {
        let RowBuffer {
            n_columns,
            total_rows,
            data,
        } = self;
        Self {
            n_columns: *n_columns,
            total_rows: *total_rows,
            data: Pooled::cloned(data),
        }
    }
}
impl RowBuffer {
    /// Creates an empty buffer whose rows each hold `n_columns` values.
    ///
    /// The backing storage is requested from the pool set via `with_pool_set`.
    ///
    /// # Panics
    /// Panics if `n_columns` is zero.
    pub(crate) fn new(n_columns: usize) -> RowBuffer {
        assert_ne!(
            n_columns, 0,
            "attempting to create a row batch with no columns"
        );
        Self {
            data: with_pool_set(|ps| ps.get()),
            total_rows: 0,
            n_columns,
        }
    }

    /// Moves this buffer's cells into a `ParallelRowBufWriter`.
    ///
    /// `self` is left holding default storage, but its row count is *not*
    /// reset. NOTE(review): presumably callers overwrite `self` with the
    /// result of `ParallelRowBufWriter::finish` before reading it again —
    /// confirm.
    pub(crate) fn parallel_writer(&mut self) -> ParallelRowBufWriter {
        let taken = mem::take(&mut self.data);
        let buf = RowBuffer {
            n_columns: self.n_columns,
            total_rows: self.total_rows,
            data: Default::default(),
        };
        let vec = ParallelVecWriter::new(Pooled::into_inner(taken));
        ParallelRowBufWriter { buf, vec }
    }

    /// Reserves storage for `additional` more rows (that many rows' worth of
    /// cells).
    pub(crate) fn reserve(&mut self, additional: usize) {
        let extra_cells = additional * self.n_columns;
        self.data.reserve(extra_cells);
    }

    /// Number of values in each row.
    pub(crate) fn arity(&self) -> usize {
        self.n_columns
    }

    /// Raw pointer to the first value of the flat row storage.
    pub(crate) fn raw_rows(&self) -> *const Value {
        self.data.as_ptr().cast::<Value>()
    }

    /// Forces the buffer to report exactly `count` rows.
    ///
    /// # Safety
    /// This forwards `count * self.arity()` to `set_len` on the backing
    /// storage, so that call's requirements apply: the storage must already
    /// hold that many initialized cells within its capacity.
    pub(crate) unsafe fn set_len(&mut self, count: usize) {
        let cells = count * self.n_columns;
        // SAFETY: the caller upholds the `set_len` contract (see `# Safety`).
        unsafe {
            self.data.set_len(cells);
        }
        self.total_rows = count;
    }

    /// Iterates over the rows whose first column is not stale.
    pub(crate) fn non_stale(&self) -> impl Iterator<Item = &[Value]> {
        let width = self.n_columns;
        self.data
            .chunks(width)
            .filter(|cells| !cells[0].get().is_stale())
            .map(|cells| {
                // SAFETY: `Cell<Value>` has the same in-memory representation
                // as `Value`, so the cast only changes the static type.
                // NOTE(review): the resulting `&[Value]` must not overlap with
                // a concurrent `set_stale_shared` on the same row — confirm.
                unsafe { &*(cells as *const [Cell<Value>] as *const [Value]) }
            })
    }

    /// Iterates mutably over the rows whose first column is not stale.
    pub(crate) fn non_stale_mut(&mut self) -> impl Iterator<Item = &mut [Value]> {
        let width = self.n_columns;
        self.data
            .chunks_mut(width)
            .filter(|cells| !cells[0].get().is_stale())
            .map(|cells| {
                // SAFETY: `Cell<Value>` and `Value` share a representation and
                // the chunk is exclusively borrowed for the returned lifetime.
                unsafe { &mut *(cells as *mut [Cell<Value>] as *mut [Value]) }
            })
    }

    /// Parallel iterator over every row, stale or not.
    pub(crate) fn parallel_iter(&self) -> impl ParallelIterator<Item = &[Value]> {
        use rayon::prelude::*;
        let cells: &[Cell<Value>] = &self.data;
        // SAFETY: `Cell<Value>` has the same in-memory representation as
        // `Value`; only the static type of the slice changes.
        let values = unsafe { &*(cells as *const [Cell<Value>] as *const [Value]) };
        values.par_chunks(self.n_columns)
    }

    /// Iterates over every row, stale or not.
    pub(crate) fn iter(&self) -> impl Iterator<Item = &[Value]> {
        let width = self.n_columns;
        self.data.chunks(width).map(|cells| {
            // SAFETY: `Cell<Value>` has the same in-memory representation as
            // `Value`; only the static type of the slice changes.
            unsafe { &*(cells as *const [Cell<Value>] as *const [Value]) }
        })
    }

    /// Removes every row and resets the row count to zero.
    pub(crate) fn clear(&mut self) {
        self.total_rows = 0;
        self.data.clear();
    }

    /// Number of rows the buffer reports.
    pub(crate) fn len(&self) -> usize {
        self.total_rows
    }

    /// Marks `row` stale through a shared reference and reports whether it
    /// was already stale.
    ///
    /// # Panics
    /// Panics if the row's cell range lies outside the storage.
    ///
    /// # Safety
    /// NOTE(review): because `RowBuffer` is declared `Sync`, the caller
    /// presumably must guarantee that nothing else reads or writes this row's
    /// first cell concurrently — confirm.
    pub(crate) unsafe fn set_stale_shared(&self, row: RowId) -> bool {
        let start = row.index() * self.n_columns;
        let end = (row.index() + 1) * self.n_columns;
        // Slice the whole row first so an out-of-range row is rejected, then
        // swap the stale marker into the first column.
        let first = &self.data[start..end][0];
        first.replace(Value::stale()).is_stale()
    }

    /// Returns the values of `row`.
    ///
    /// # Panics
    /// Panics if the row's cell range lies outside the storage.
    pub(crate) fn get_row(&self, row: RowId) -> &[Value] {
        let cells: &[Cell<Value>] = &self.data;
        // SAFETY: NOTE(review): the free `get_row` helper bounds-checks the
        // range; its `unsafe` marker presumably only covers viewing cells as
        // plain values — confirm.
        unsafe { get_row(cells, self.n_columns, row) }
    }

    /// Returns the values of `row` without any bounds check.
    ///
    /// # Safety
    /// `row` must lie within the rows held by the storage; the slice is built
    /// from raw pointer arithmetic with no validation.
    pub(crate) unsafe fn get_row_unchecked(&self, row: RowId) -> &[Value] {
        let width = self.n_columns;
        // SAFETY: the caller guarantees `row` is in bounds, and `Cell<Value>`
        // has the same in-memory representation as `Value`.
        unsafe {
            let first = self.data.as_ptr().add(row.index() * width);
            slice::from_raw_parts(first.cast::<Value>(), width)
        }
    }

    /// Returns the values of `row` for in-place modification.
    ///
    /// # Panics
    /// Panics if the row's cell range lies outside the storage.
    pub(crate) fn get_row_mut(&mut self, row: RowId) -> &mut [Value] {
        let start = row.index() * self.n_columns;
        let end = (row.index() + 1) * self.n_columns;
        let cells: &mut [Cell<Value>] = &mut self.data[start..end];
        // SAFETY: `Cell<Value>` and `Value` share a representation and the
        // row is exclusively borrowed through `&mut self`.
        unsafe { &mut *(cells as *mut [Cell<Value>] as *mut [Value]) }
    }

    /// Marks `row` stale and reports whether it was already stale.
    pub(crate) fn set_stale(&mut self, row: RowId) -> bool {
        let first = &mut self.get_row_mut(row)[0];
        let was_stale = first.is_stale();
        first.set_stale();
        was_stale
    }

    /// Appends `row` and returns the id assigned to it (the previous row
    /// count).
    ///
    /// When the buffer currently reports zero rows, `Pooled::refresh` is
    /// invoked on the storage first. NOTE(review): presumably this
    /// re-associates the allocation with the current pool — confirm.
    ///
    /// # Panics
    /// Panics if `row.len()` differs from the buffer's arity.
    pub(crate) fn add_row(&mut self, row: &[Value]) -> RowId {
        assert_eq!(
            row.len(),
            self.n_columns,
            "attempting to add a row with mismatched arity to table"
        );
        if self.total_rows == 0 {
            Pooled::refresh(&mut self.data);
        }
        let assigned = RowId::from_usize(self.total_rows);
        self.data.extend(row.iter().map(|value| Cell::new(*value)));
        self.total_rows += 1;
        assigned
    }

    /// Drops every stale row, compacting the survivors in order.
    ///
    /// For each surviving row, `remap` receives the row's values, its id
    /// before compaction, and its id after compaction.
    pub(crate) fn remove_stale(&mut self, mut remap: impl FnMut(&[Value], RowId, RowId)) {
        let width = self.n_columns;
        // Column position inside the row currently being scanned.
        let mut column = 0;
        // Rows scanned so far / rows kept so far (both counted from 1).
        let mut rows_seen = 0;
        let mut rows_kept = 0;
        let mut keeping = true;
        // Values of the row being kept, collected so `remap` sees a slice.
        let mut current = SmallVec::<[Value; 8]>::new();
        self.data.retain(|cell| {
            let value = cell.get();
            if column == 0 {
                // The first column decides the fate of the whole row.
                keeping = !value.is_stale();
                rows_seen += 1;
                if keeping {
                    rows_kept += 1;
                }
            }
            if keeping {
                current.push(value);
            }
            column += 1;
            if column == width {
                column = 0;
                if keeping {
                    remap(
                        &current,
                        RowId::new(rows_seen - 1),
                        RowId::new(rows_kept - 1),
                    );
                    current.clear();
                }
            }
            keeping
        });
        self.total_rows = rows_kept as usize;
    }
}
/// A row buffer in which every row is stored together with a `RowId` tag.
///
/// The tag lives in one extra trailing column of the wrapped `RowBuffer`,
/// which is why `inner` is created with `n_columns + 1` columns.
pub struct TaggedRowBuffer {
/// Underlying storage; its last column holds the tag of each row.
inner: RowBuffer,
}
impl TaggedRowBuffer {
    /// Creates an empty buffer for rows of `n_columns` values plus a tag.
    pub fn new(n_columns: usize) -> TaggedRowBuffer {
        let inner = RowBuffer::new(n_columns + 1);
        Self { inner }
    }

    /// Removes every row.
    pub fn clear(&mut self) {
        self.inner.clear();
    }

    /// Whether the buffer reports zero rows.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of rows the buffer reports.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Number of caller-visible columns, i.e. excluding the tag column.
    fn base_arity(&self) -> usize {
        self.inner.arity() - 1
    }

    /// Appends `row` tagged with `row_id` and returns the position of the new
    /// row within this buffer.
    ///
    /// # Panics
    /// Panics if `row.len()` differs from the buffer's untagged arity.
    pub fn add_row(&mut self, row_id: RowId, row: &[Value]) -> RowId {
        assert_eq!(
            row.len(),
            self.base_arity(),
            "attempting to add a row with mismatched arity to table"
        );
        if self.inner.total_rows == 0 {
            Pooled::refresh(&mut self.inner.data);
        }
        let position = RowId::from_usize(self.inner.total_rows);
        // Payload columns first, then the tag as the trailing column.
        self.inner
            .data
            .extend(row.iter().map(|value| Cell::new(*value)));
        let tag = Value::new(row_id.rep());
        self.inner.data.push(Cell::new(tag));
        self.inner.total_rows += 1;
        position
    }

    /// Returns the tag and values stored at position `row`.
    pub fn get_row(&self, row: RowId) -> (RowId, &[Value]) {
        let tagged = self.inner.get_row(row);
        self.unwrap_row(tagged)
    }

    /// Returns the tag and mutable values stored at position `row`.
    pub fn get_row_mut(&mut self, row: RowId) -> (RowId, &mut [Value]) {
        let base_arity = self.base_arity();
        Self::unwrap_row_mut(base_arity, self.inner.get_row_mut(row))
    }

    /// Iterates over every row as a `(tag, values)` pair.
    pub fn iter(&self) -> impl Iterator<Item = (RowId, &[Value])> {
        self.inner.iter().map(|tagged| self.unwrap_row(tagged))
    }

    /// Parallel counterpart of `iter`.
    pub fn par_iter(&self) -> impl ParallelIterator<Item = (RowId, &[Value])> {
        self.inner
            .parallel_iter()
            .map(|tagged| self.unwrap_row(tagged))
    }

    /// Iterates over the non-stale rows as `(tag, values)` pairs.
    pub fn non_stale(&self) -> impl Iterator<Item = (RowId, &[Value])> {
        self.inner.non_stale().map(|tagged| self.unwrap_row(tagged))
    }

    /// Iterates over the non-stale rows as `(tag, mutable values)` pairs.
    pub fn non_stale_mut(&mut self) -> impl Iterator<Item = (RowId, &mut [Value])> {
        // Captured by value so the closure does not borrow `self`.
        let base_arity = self.base_arity();
        self.inner
            .non_stale_mut()
            .map(move |tagged| Self::unwrap_row_mut(base_arity, tagged))
    }

    /// Marks the row at position `row` stale and reports whether it already
    /// was.
    pub fn set_stale(&mut self, row: RowId) -> bool {
        self.inner.set_stale(row)
    }

    /// Splits a stored row into its tag (the trailing column) and its values.
    fn unwrap_row<'a>(&self, row: &'a [Value]) -> (RowId, &'a [Value]) {
        let tag_column = self.base_arity();
        let tag = RowId::new(row[tag_column].rep());
        (tag, &row[..tag_column])
    }

    /// Mutable variant of `unwrap_row`; takes the arity explicitly so it can
    /// be used while `self` is mutably borrowed.
    fn unwrap_row_mut(base_arity: usize, row: &mut [Value]) -> (RowId, &mut [Value]) {
        let tag = RowId::new(row[base_arity].rep());
        (tag, &mut row[..base_arity])
    }
}
unsafe fn get_row(data: &[Cell<Value>], n_columns: usize, row: RowId) -> &[Value] {
unsafe {
mem::transmute::<&[Cell<Value>], &[Value]>(
&data[row.index() * n_columns..(row.index() + 1) * n_columns],
)
}
}
/// Writer produced by `RowBuffer::parallel_writer`; `finish` turns it back
/// into a `RowBuffer`.
pub(crate) struct ParallelRowBufWriter {
/// Supplies the column count; its `data` stays at the default value until
/// `finish` installs the written cells.
buf: RowBuffer,
/// Holds the cells taken from the originating buffer and receives the rows
/// passed to `append_contents`.
vec: ParallelVecWriter<Cell<Value>>,
}
impl ParallelRowBufWriter {
pub(crate) fn read_handle(&self) -> ReadHandle<'_, impl Deref<Target = [Cell<Value>]> + '_> {
ReadHandle {
buf: &self.buf,
data: self.vec.read_access(),
}
}
pub(crate) fn append_contents(&self, rows: &RowBuffer) -> RowId {
assert_eq!(rows.n_columns, self.buf.n_columns);
let start_off = write_cell_slice(&self.vec, rows.data.as_slice());
debug_assert_eq!(start_off % self.buf.n_columns, 0);
RowId::from_usize(start_off / self.buf.n_columns)
}
pub(crate) fn finish(mut self) -> RowBuffer {
self.buf.data = Pooled::new(self.vec.finish());
self.buf.total_rows = self.buf.data.len() / self.buf.n_columns;
self.buf
}
}
/// Read view handed out by `ParallelRowBufWriter::read_handle`.
pub(crate) struct ReadHandle<'a, T> {
/// Source of the column count used to locate rows inside `data`.
buf: &'a RowBuffer,
/// Cell storage being read; the methods require
/// `T: Deref<Target = [Cell<Value>]>`.
data: T,
}
impl<T: Deref<Target = [Cell<Value>]>> ReadHandle<'_, T> {
    /// Returns the values of `row` without any bounds check.
    ///
    /// # Safety
    /// `row` must lie within the rows visible through this handle; the slice
    /// is built from raw pointer arithmetic with no validation.
    pub(crate) unsafe fn get_row_unchecked(&self, row: RowId) -> &[Value] {
        let width = self.buf.n_columns;
        // SAFETY: the caller guarantees `row` is in bounds, and `Cell<Value>`
        // has the same in-memory representation as `Value`.
        unsafe {
            let first = self.data.as_ptr().add(row.index() * width);
            std::slice::from_raw_parts(first.cast::<Value>(), width)
        }
    }

    /// Marks `row` stale through a shared reference and reports whether it
    /// was already stale. No bounds check is performed.
    ///
    /// # Safety
    /// `row` must lie within the rows visible through this handle.
    /// NOTE(review): presumably the caller must also ensure no other thread
    /// touches this row's first cell concurrently — confirm.
    pub(crate) unsafe fn set_stale_shared(&self, row: RowId) -> bool {
        let offset = row.index() * self.buf.n_columns;
        // SAFETY: the caller guarantees `offset` addresses a live cell.
        let first: &Cell<Value> = unsafe { &*self.data.as_ptr().add(offset) };
        first.replace(Value::stale()).is_stale()
    }
}