// egglog_core_relations/row_buffer/mod.rs
1//! A basic data-structure encapsulating a batch of rows.
2
3use core::slice;
4use std::{cell::Cell, mem, ops::Deref};
5
6use crate::numeric_id::NumericId;
7use egglog_concurrency::{ParallelVecWriter, parallel_writer::write_cell_slice};
8use rayon::iter::ParallelIterator;
9use smallvec::SmallVec;
10
11use crate::{
12    common::Value,
13    offsets::RowId,
14    pool::{Pooled, with_pool_set},
15};
16
17#[cfg(test)]
18mod tests;
19
/// A batch of rows. This is a common enough pattern that it makes sense to make
/// it its own data-structure. The advantage of this abstraction is that it
/// allows us to store multiple rows in a single allocation.
///
/// RowBuffer stores data in row-major order.
pub struct RowBuffer {
    // The arity of each row; always nonzero (enforced in `new`).
    n_columns: usize,
    // The number of rows currently stored; `data.len()` is
    // `total_rows * n_columns`. Stale rows still count until `remove_stale`.
    total_rows: usize,
    // Flat, row-major storage. `Cell` allows `set_stale_shared` to write
    // through a shared reference; see the safety comment below the struct.
    data: Pooled<Vec<Cell<Value>>>,
}
30
// Safety constraints for RowBuffer.
//
// All of the unsafe code in RowBuffer is due to the use of `Cell<Value>` for
// the backing `data`. We do not want to expose raw `Cell`s to users (they
// complicate the API), but every use-case for RowBuffer uses entries in data
// like normal values _but one_: that is the `set_stale_shared` method. See the
// documentation for that method for more context.
//
// That method enables multiple threads to write to exclusive rows in the table
// without performing any additional synchronization, or slowing down future
// readers by requiring atomic operations for every read.
//
// SAFETY: `Cell<Value>` is the only reason these impls are not derived
// automatically. All shared mutation goes through `set_stale_shared`, whose
// contract forbids concurrent access to the same row, so cross-thread use of
// the remaining (read-only) shared API is sound.
unsafe impl Send for RowBuffer {}
unsafe impl Sync for RowBuffer {}
44
45impl Clone for RowBuffer {
46    fn clone(&self) -> Self {
47        RowBuffer {
48            n_columns: self.n_columns,
49            total_rows: self.total_rows,
50            data: Pooled::cloned(&self.data),
51        }
52    }
53}
54
impl RowBuffer {
    /// Create a new RowBuffer with the given arity.
    ///
    /// # Panics
    /// Panics if `n_columns` is zero.
    pub(crate) fn new(n_columns: usize) -> RowBuffer {
        assert_ne!(
            n_columns, 0,
            "attempting to create a row batch with no columns"
        );
        RowBuffer {
            n_columns,
            total_rows: 0,
            // Reuse a pooled allocation rather than allocating fresh.
            data: with_pool_set(|ps| ps.get()),
        }
    }

    /// Move this buffer's backing storage into a [`ParallelRowBufWriter`],
    /// leaving `self` with an empty `data` vector. Call
    /// [`ParallelRowBufWriter::finish`] to get a rebuilt buffer back.
    pub(crate) fn parallel_writer(&mut self) -> ParallelRowBufWriter {
        let data = mem::take(&mut self.data);
        ParallelRowBufWriter {
            buf: RowBuffer {
                n_columns: self.n_columns,
                total_rows: self.total_rows,
                data: Default::default(),
            },
            vec: ParallelVecWriter::new(Pooled::into_inner(data)),
        }
    }

    /// Reserve space for `additional` rows.
    pub(crate) fn reserve(&mut self, additional: usize) {
        self.data.reserve(additional * self.n_columns);
    }

    /// The size of the rows accepted by this buffer.
    pub(crate) fn arity(&self) -> usize {
        self.n_columns
    }

    /// A raw pointer to the start of the row-major data.
    pub(crate) fn raw_rows(&self) -> *const Value {
        // `Cell<Value>` has the same in-memory layout as `Value`, so the
        // pointer cast is sound.
        self.data.as_ptr() as *const Value
    }

    /// Blindly set the length of the RowBuffer to the given number of rows.
    ///
    /// # Safety
    /// `count` must be within the capacity of the RowBuffer and the resized buffer must point to
    /// initialized memory. (Analogous to [`Vec::set_len`]).
    pub(crate) unsafe fn set_len(&mut self, count: usize) {
        unsafe {
            self.data.set_len(count * self.n_columns);
        }
        self.total_rows = count;
    }

    /// Return an iterator over the non-stale rows in the buffer.
    ///
    /// A row is stale iff its first column is stale (see `set_stale`).
    pub(crate) fn non_stale(&self) -> impl Iterator<Item = &[Value]> {
        self.data
            .chunks(self.n_columns)
            .filter(|row| !row[0].get().is_stale())
            // SAFETY: This kind of transmutation is safe so long as no one
            // modifies any of the values behind the `Cell` while this value is
            // borrowed.
            //
            // The only time we modify these values is in safe methods requiring
            // a mutable reference (`set_stale`, `get_row_mut`), or in the
            // unsafe `set_stale_shared` method whose safety requirements imply
            // that no call will overlap with borrowing such a row.
            .map(|row| unsafe { mem::transmute::<&[Cell<Value>], &[Value]>(row) })
    }

    /// Return an iterator over mutable references to the non-stale rows.
    pub(crate) fn non_stale_mut(&mut self) -> impl Iterator<Item = &mut [Value]> {
        self.data
            .chunks_mut(self.n_columns)
            .filter(|row| !row[0].get().is_stale())
            // SAFETY: This kind of transmutation is safe so long as no one
            // modifies any of the values behind the `Cell` while this value is
            // borrowed.
            //
            // The only time we modify these values is in safe methods requiring
            // a mutable reference (`set_stale`, `get_row_mut`), or in the
            // unsafe `set_stale_shared` method whose safety requirements imply
            // that no call will overlap with borrowing such a row.
            .map(|row| unsafe { mem::transmute::<&mut [Cell<Value>], &mut [Value]>(row) })
    }

    /// A parallel version of [`RowBuffer::iter`].
    pub(crate) fn parallel_iter(&self) -> impl ParallelIterator<Item = &[Value]> {
        use rayon::prelude::*;
        // SAFETY: This kind of transmutation is safe so long as no one
        // modifies any of the values behind the `Cell` while this value is
        // borrowed.
        //
        // The only time we modify these values is in safe methods requiring
        // a mutable reference (`set_stale`, `get_row_mut`), or in the
        // unsafe `set_stale_shared` method whose safety requirements imply
        // that no call will overlap with borrowing such a row.
        unsafe { mem::transmute::<&[Cell<Value>], &[Value]>(&self.data) }.par_chunks(self.n_columns)
    }

    /// Return an iterator over all rows in the buffer (stale rows included).
    pub(crate) fn iter(&self) -> impl Iterator<Item = &[Value]> {
        self.data
            .chunks(self.n_columns)
            // SAFETY: see comment in `non_stale`.
            .map(|row| unsafe { mem::transmute::<&[Cell<Value>], &[Value]>(row) })
    }

    /// Clear the contents of the buffer.
    pub(crate) fn clear(&mut self) {
        self.data.clear();
        self.total_rows = 0;
    }

    /// The number of rows in the buffer, including rows marked stale.
    pub(crate) fn len(&self) -> usize {
        self.total_rows
    }

    /// Mark a row as stale in the buffer with shared access to it. Returns
    /// whether the row was already stale.
    ///
    /// # Safety
    /// This method is unsafe because we implement `Send` and `Sync` for the
    /// `RowBuffer` type. That means that you can call `set_stale_shared(row)`
    /// and `get_row(row)` concurrently, which would be a data race.
    ///
    /// To safely use this method, you must ensure that there are no concurrent reads or writes to
    /// `row`. Indeed, that is what this method is for: parallel writes to exclusive rows in a
    /// shared `RowBuffer`. Any other use-case should use the [`RowBuffer::set_stale`] method,
    /// which requires a mutable reference.
    pub(crate) unsafe fn set_stale_shared(&self, row: RowId) -> bool {
        let cells = &self.data[row.index() * self.n_columns..(row.index() + 1) * self.n_columns];
        // By convention, a row's staleness is stored in its first column.
        let was_stale = cells[0].get().is_stale();
        cells[0].set(Value::stale());
        was_stale
    }

    /// Get the row corresponding to the given RowId.
    ///
    /// # Panics
    /// This method panics if `row` is out of bounds.
    pub(crate) fn get_row(&self, row: RowId) -> &[Value] {
        // SAFETY: see the comment in `non_stale`.
        unsafe { get_row(&self.data, self.n_columns, row) }
    }

    /// Get the row corresponding to the given RowId without bounds checking.
    ///
    /// # Safety
    /// `row` must be in bounds for this buffer, and no writes (e.g. via
    /// `set_stale_shared`) may overlap the returned borrow.
    pub(crate) unsafe fn get_row_unchecked(&self, row: RowId) -> &[Value] {
        // SAFETY: the caller guarantees `row` is in bounds; `Cell<Value>`
        // shares `Value`'s layout, so the pointer cast is sound.
        unsafe {
            slice::from_raw_parts(
                self.data.as_ptr().add(row.index() * self.n_columns) as *const Value,
                self.n_columns,
            )
        }
    }

    /// Get a mutable reference to the row corresponding to the given RowId.
    ///
    /// # Panics
    /// This method panics if `row` is out of bounds.
    pub(crate) fn get_row_mut(&mut self, row: RowId) -> &mut [Value] {
        // SAFETY: see the comment in `non_stale`.
        unsafe {
            mem::transmute::<&mut [Cell<Value>], &mut [Value]>(
                &mut self.data[row.index() * self.n_columns..(row.index() + 1) * self.n_columns],
            )
        }
    }

    /// Set the given row to be stale. By convention, this calls `set_stale` on
    /// the first column in the row. Returns whether the row was already stale.
    ///
    /// # Panics
    /// This method panics if `row` is out of bounds.
    pub(crate) fn set_stale(&mut self, row: RowId) -> bool {
        let row = self.get_row_mut(row);
        let res = row[0].is_stale();
        row[0].set_stale();
        res
    }

    /// Insert a row into a buffer, returning the RowId for this row.
    ///
    /// # Panics
    /// This method panics if the length of `row` does not match the arity of
    /// the RowBuffer.
    pub(crate) fn add_row(&mut self, row: &[Value]) -> RowId {
        assert_eq!(
            row.len(),
            self.n_columns,
            "attempting to add a row with mismatched arity to table"
        );
        if self.total_rows == 0 {
            // First insertion into an empty buffer.
            // NOTE(review): presumably `Pooled::refresh` re-validates or swaps
            // the pooled allocation before reuse — confirm in `pool`.
            Pooled::refresh(&mut self.data);
        }
        let res = RowId::from_usize(self.total_rows);
        self.data.extend(row.iter().copied().map(Cell::new));
        self.total_rows += 1;
        res
    }

    /// Remove any stale entries in the buffer. This invalidates existing
    /// RowIds. This method calls `remap` with the old and new RowIds for all
    /// non-stale rows.
    pub(crate) fn remove_stale(&mut self, mut remap: impl FnMut(&[Value], RowId, RowId)) {
        // `retain` visits one cell at a time, so we run a small state machine:
        // `within_row` is the column index inside the current row, `keep_row`
        // is the verdict made at the row's first column.
        let mut within_row = 0;
        let mut row_in = 0; // rows seen so far (source of old RowIds)
        let mut row_out = 0; // rows kept so far (source of new RowIds)
        let mut keep_row = true;
        // Staging area for the current row's values, handed to `remap`.
        let mut scratch = SmallVec::<[Value; 8]>::new();
        self.data.retain(|entry| {
            if within_row == 0 {
                // First column decides staleness for the whole row.
                keep_row = !entry.get().is_stale();
                if keep_row {
                    scratch.push(entry.get());
                    row_out += 1;
                }
                row_in += 1;
            } else if keep_row {
                scratch.push(entry.get());
            }
            within_row += 1;
            if within_row == self.n_columns {
                // End of a row: report the remapping for kept rows.
                within_row = 0;
                if keep_row {
                    remap(&scratch, RowId::new(row_in - 1), RowId::new(row_out - 1));
                    scratch.clear();
                }
            }
            keep_row
        });
        self.total_rows = row_out as usize;
    }
}
287
/// A `TaggedRowBuffer` wraps a `RowBuffer` but also keeps track of a _source_
/// `RowId` for the row it contains. This makes it useful for materializing
/// the contents of a `Subset` of a table.
pub struct TaggedRowBuffer {
    // The wrapped buffer's arity is one greater than the user-visible arity:
    // the extra trailing column of each row stores the source `RowId`.
    inner: RowBuffer,
}
294
295impl TaggedRowBuffer {
296    /// Create a new buffer with the given arity.
297    pub fn new(n_columns: usize) -> TaggedRowBuffer {
298        TaggedRowBuffer {
299            inner: RowBuffer::new(n_columns + 1),
300        }
301    }
302
303    /// Clear the contents of the buffer.
304    pub fn clear(&mut self) {
305        self.inner.clear()
306    }
307
308    /// Whether the buffer is empty.
309    pub fn is_empty(&self) -> bool {
310        self.inner.len() == 0
311    }
312
313    /// The number of rows in the buffer.
314    pub fn len(&self) -> usize {
315        self.inner.len()
316    }
317
318    fn base_arity(&self) -> usize {
319        self.inner.n_columns - 1
320    }
321
322    /// Add the given row and RowId to the buffer, returning the RowId (in
323    /// `self`) for the new row.
324    pub fn add_row(&mut self, row_id: RowId, row: &[Value]) -> RowId {
325        // Variant of `RowBuffer::add_row` that also stores the given `RowId` inline.
326        //
327        // Changes to the implementation of one method should probably also
328        // change the other.
329        assert_eq!(
330            row.len(),
331            self.base_arity(),
332            "attempting to add a row with mismatched arity to table"
333        );
334        if self.inner.total_rows == 0 {
335            Pooled::refresh(&mut self.inner.data);
336        }
337        let res = RowId::from_usize(self.inner.total_rows);
338        self.inner.data.extend(row.iter().copied().map(Cell::new));
339        self.inner.data.push(Cell::new(Value::new(row_id.rep())));
340        self.inner.total_rows += 1;
341        res
342    }
343
344    /// Get the row (and the id it was associated with at insertion time) at the
345    /// offset associated with `row`.
346    pub fn get_row(&self, row: RowId) -> (RowId, &[Value]) {
347        self.unwrap_row(self.inner.get_row(row))
348    }
349
350    pub fn get_row_mut(&mut self, row: RowId) -> (RowId, &mut [Value]) {
351        let base_arity = self.base_arity();
352        let row = self.inner.get_row_mut(row);
353        let row_id = row[base_arity];
354        let row = &mut row[..base_arity];
355        (RowId::new(row_id.rep()), row)
356    }
357
358    /// Iterate over the contents of the buffer.
359    pub fn iter(&self) -> impl Iterator<Item = (RowId, &[Value])> {
360        self.inner.iter().map(|row| self.unwrap_row(row))
361    }
362
363    /// Iterate over the contents of the buffer in parallel.
364    pub fn par_iter(&self) -> impl ParallelIterator<Item = (RowId, &[Value])> {
365        self.inner.parallel_iter().map(|row| self.unwrap_row(row))
366    }
367
368    /// Iterate over all rows in the buffer, except for the stale ones.
369    pub fn non_stale(&self) -> impl Iterator<Item = (RowId, &[Value])> {
370        self.inner.non_stale().map(|row| self.unwrap_row(row))
371    }
372
373    /// Iterate over all rows in the buffer, except for the stale ones.
374    pub fn non_stale_mut(&mut self) -> impl Iterator<Item = (RowId, &mut [Value])> {
375        let base_arity = self.base_arity();
376        self.inner
377            .non_stale_mut()
378            .map(move |row| Self::unwrap_row_mut(base_arity, row))
379    }
380
381    pub fn set_stale(&mut self, row: RowId) -> bool {
382        self.inner.set_stale(row)
383    }
384
385    fn unwrap_row<'a>(&self, row: &'a [Value]) -> (RowId, &'a [Value]) {
386        let row_id = row[self.base_arity()];
387        let row = &row[..self.base_arity()];
388        (RowId::new(row_id.rep()), row)
389    }
390    fn unwrap_row_mut(base_arity: usize, row: &mut [Value]) -> (RowId, &mut [Value]) {
391        let row_id = row[base_arity];
392        let row = &mut row[..base_arity];
393        (RowId::new(row_id.rep()), row)
394    }
395}
396
397/// # Safety
398/// This function is safe so long as there are no concurrent writes to the given
399/// row.
400unsafe fn get_row(data: &[Cell<Value>], n_columns: usize, row: RowId) -> &[Value] {
401    unsafe {
402        mem::transmute::<&[Cell<Value>], &[Value]>(
403            &data[row.index() * n_columns..(row.index() + 1) * n_columns],
404        )
405    }
406}
407
/// A wrapper for a RowBuffer that allows it to be written to in parallel, based
/// on [`ParallelVecWriter`].
///
/// This is a type that is used to speed up parallel `merge` operations on
/// `SortedWritesTable`. It uses a low-level interface that should be avoided in
/// most cases.
pub(crate) struct ParallelRowBufWriter {
    // Holds the arity / row-count metadata; its `data` vector is empty while
    // the storage lives in `vec` (see `RowBuffer::parallel_writer`).
    buf: RowBuffer,
    // The shared, concurrently-appendable backing storage.
    vec: ParallelVecWriter<Cell<Value>>,
}
418
impl ParallelRowBufWriter {
    /// Get a [`ReadHandle`] granting access to data already written through
    /// this writer.
    pub(crate) fn read_handle(&self) -> ReadHandle<'_, impl Deref<Target = [Cell<Value>]> + '_> {
        ReadHandle {
            buf: &self.buf,
            data: self.vec.read_access(),
        }
    }

    /// Append the raw contents of `rows` (all of its cells, verbatim) to this
    /// writer, returning the id of the first row written.
    ///
    /// # Panics
    /// Panics if the arity of `rows` does not match this writer's arity.
    pub(crate) fn append_contents(&self, rows: &RowBuffer) -> RowId {
        assert_eq!(rows.n_columns, self.buf.n_columns);
        let start_off = write_cell_slice(&self.vec, rows.data.as_slice());
        // Writes are whole rows, so every write must start on a row boundary.
        debug_assert_eq!(start_off % self.buf.n_columns, 0);
        RowId::from_usize(start_off / self.buf.n_columns)
    }

    /// Tear down the parallel writer and reassemble the `RowBuffer`,
    /// recomputing its row count from the length of the written data.
    pub(crate) fn finish(mut self) -> RowBuffer {
        self.buf.data = Pooled::new(self.vec.finish());
        self.buf.total_rows = self.buf.data.len() / self.buf.n_columns;
        self.buf
    }
}
440
/// A handle granting read access to a row buffer's contents.
pub(crate) struct ReadHandle<'a, T> {
    // Source of the row-addressing metadata (arity).
    buf: &'a RowBuffer,
    // Shared view of the cell storage (e.g. `ParallelVecWriter::read_access`);
    // rows past its length may exist if writes completed after creation.
    data: T,
}
446
impl<T: Deref<Target = [Cell<Value>]>> ReadHandle<'_, T> {
    /// Get the row corresponding to the given RowId without bounds checking.
    ///
    /// # Safety
    /// The caller must ensure that either `row` is within bounds of the buffer at the creation of
    /// this handle, or that the row was successfully written to the buffer before it was called.
    ///
    /// Furthermore, no calls to `set_stale_shared` may overlap with this call.
    pub(crate) unsafe fn get_row_unchecked(&self, row: RowId) -> &[Value] {
        // SAFETY: ParallelVecWriter guarantees that data within bounds is not
        // being modified concurrently.
        unsafe {
            std::slice::from_raw_parts(
                self.data.as_ptr().add(row.index() * self.buf.n_columns) as *const Value,
                self.buf.n_columns,
            )
        }
    }

    /// See the documentation for [`RowBuffer::set_stale_shared`].
    ///
    /// In addition to the requirements there, `row` is allowed to be out of bounds of the initial
    /// length of the wrapped vector, but any out-of-bounds row must be in bounds of a (previously
    /// completed) write.
    pub(crate) unsafe fn set_stale_shared(&self, row: RowId) -> bool {
        let cells: &[Cell<Value>] = &self.data;
        let cell_ptr: *const Cell<Value> = cells.as_ptr();
        // SAFETY: raw pointer arithmetic (rather than slice indexing) is used
        // because the target row may lie past the slice's current length; per
        // the contract above it is still inside a completed write.
        let to_set: &Cell<Value> = unsafe { &*cell_ptr.add(row.index() * self.buf.n_columns) };
        // Staleness lives in the row's first column, matching
        // `RowBuffer::set_stale_shared`.
        let was_stale = to_set.get().is_stale();
        to_set.set(Value::stale());
        was_stale
    }
}
479}