egglog_core_relations/table/mod.rs

//! A generic table implementation supporting sorted writes.
//!
//! The primary difference between this table and the `Function` implementation
//! in egglog is that high level concepts like "timestamp" and "merge function"
//! are abstracted away from the core functionality of the table.

use std::{
    any::Any,
    cmp,
    hash::Hasher,
    mem,
    sync::{
        Arc, Weak,
        atomic::{AtomicUsize, Ordering},
    },
};

use crate::numeric_id::{DenseIdMap, NumericId};
use crossbeam_queue::SegQueue;
use hashbrown::HashTable;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use rustc_hash::FxHasher;
use sharded_hash_table::ShardedHashTable;

use crate::{
    Pooled, TableChange, TableId,
    action::ExecutionState,
    common::{HashMap, ShardData, ShardId, SubsetTracker, Value},
    hash_index::{ColumnIndex, Index},
    offsets::{OffsetRange, Offsets, RowId, Subset, SubsetRef},
    parallel_heuristics::parallelize_table_op,
    pool::with_pool_set,
    row_buffer::{ParallelRowBufWriter, RowBuffer},
    table_spec::{
        ColumnId, Constraint, Generation, MutationBuffer, Offset, Row, Table, TableSpec,
        TableVersion,
    },
};

mod rebuild;
mod sharded_hash_table;
#[cfg(test)]
mod tests;

// NB: Having this type def lets us switch between 64 and 32 bits of hashcode.
//
// We should consider just using u64 everywhere though. Hashbrown doesn't play nicely with 32-bit
// hashcodes because it uses both the high and low bits of a 64-bit code.

type HashCode = u64;

/// A pointer to a row in the table.
#[derive(Clone, Debug)]
pub(crate) struct TableEntry {
    hashcode: HashCode,
    row: RowId,
}

impl TableEntry {
    fn hashcode(&self) -> u64 {
        // We keep the cast here to make it easy to switch to HashCode=u32.
        #[allow(clippy::unnecessary_cast)]
        {
            self.hashcode as u64
        }
    }
}

/// The core data for a table.
///
/// This type is a thin wrapper around `RowBuffer`. The big difference is that
/// it keeps track of how many stale rows are present.
#[derive(Clone)]
struct Rows {
    data: RowBuffer,
    scratch: RowBuffer,
    stale_rows: usize,
}

impl Rows {
    fn new(data: RowBuffer) -> Rows {
        let arity = data.arity();
        Rows {
            data,
            scratch: RowBuffer::new(arity),
            stale_rows: 0,
        }
    }
    fn clear(&mut self) {
        self.data.clear();
        self.stale_rows = 0;
    }
    fn next_row(&self) -> RowId {
        RowId::from_usize(self.data.len())
    }
    fn set_stale(&mut self, row: RowId) {
        if !self.data.set_stale(row) {
            self.stale_rows += 1;
        }
    }

    fn get_row(&self, row: RowId) -> Option<&[Value]> {
        let row = self.data.get_row(row);
        if row[0].is_stale() { None } else { Some(row) }
    }

    /// A variant of `get_row` without bounds-checking on `row`.
    unsafe fn get_row_unchecked(&self, row: RowId) -> Option<&[Value]> {
        let row = unsafe { self.data.get_row_unchecked(row) };
        if row[0].is_stale() { None } else { Some(row) }
    }

    fn add_row(&mut self, row: &[Value]) -> RowId {
        if row[0].is_stale() {
            self.stale_rows += 1;
        }
        self.data.add_row(row)
    }

    fn remove_stale(&mut self, remap: impl FnMut(&[Value], RowId, RowId)) {
        self.data.remove_stale(remap);
        self.stale_rows = 0;
    }
}
/// The type of closures that are used to merge values in a [`SortedWritesTable`].
///
/// The first argument grants access to the database via an [`ExecutionState`], the second
/// argument is the current value of the tuple, and the third argument is the new, or "incoming",
/// value of the tuple. The fourth argument is a mutable reference to a vector used to store the
/// output of the merge function _if_ it changes the value of the tuple. If it does not, the merge
/// function should return `false`.
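///
/// As a rough sketch (not compiled as a doc-test), a merge closure that keeps the larger of the
/// two values in the last column might look like this; it assumes `Value`'s ordering is
/// meaningful for that column:
///
/// ```ignore
/// let merge: Box<MergeFn> = Box::new(|_exec_state, cur, new, out| {
///     if new.last() > cur.last() {
///         // The incoming row wins: fill `out` with it and report a change.
///         out.extend_from_slice(new);
///         true
///     } else {
///         // Keep the existing row untouched.
///         false
///     }
/// });
/// ```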
pub type MergeFn =
    dyn Fn(&mut ExecutionState, &[Value], &[Value], &mut Vec<Value>) -> bool + Send + Sync;

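/// A table keyed on the first `n_keys` columns of each row, with writes optionally ordered by a
/// designated sort column and conflicting inserts resolved by a user-supplied [`MergeFn`].
///
/// See [`SortedWritesTable::new`] for how tables are constructed.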
pub struct SortedWritesTable {
    generation: Generation,
    data: Rows,
    hash: ShardedHashTable<TableEntry>,

    n_keys: usize,
    n_columns: usize,
    sort_by: Option<ColumnId>,
    offsets: Vec<(Value, RowId)>,

    pending_state: Arc<PendingState>,
    merge: Arc<MergeFn>,
    to_rebuild: Vec<ColumnId>,
    rebuild_index: Index<ColumnIndex>,
    // Used to manage incremental rebuilds.
    subset_tracker: SubsetTracker,
}

impl Clone for SortedWritesTable {
    fn clone(&self) -> SortedWritesTable {
        SortedWritesTable {
            generation: self.generation,
            data: self.data.clone(),
            hash: self.hash.clone(),
            n_keys: self.n_keys,
            n_columns: self.n_columns,
            sort_by: self.sort_by,
            offsets: self.offsets.clone(),
            pending_state: Arc::new(self.pending_state.deep_copy()),
            merge: self.merge.clone(),
            to_rebuild: self.to_rebuild.clone(),
            rebuild_index: Index::new(self.to_rebuild.clone(), ColumnIndex::new()),
            subset_tracker: Default::default(),
        }
    }
}

/// A variant of [`RowBuffer`] that can handle arity 0.
///
/// We use this to handle empty keys, where the deletion API needs to handle "row buffers of empty
/// rows". The goal here is to keep most of the API RowBuffer-centric and avoid complicating the
/// code too much: actual code that was optimized to handle arity 0 would look a bit different.
#[derive(Clone)]
enum ArbitraryRowBuffer {
    NonEmpty(RowBuffer),
    Empty { rows: usize },
}

impl ArbitraryRowBuffer {
    fn new(arity: usize) -> ArbitraryRowBuffer {
        if arity == 0 {
            ArbitraryRowBuffer::Empty { rows: 0 }
        } else {
            ArbitraryRowBuffer::NonEmpty(RowBuffer::new(arity))
        }
    }

    fn add_row(&mut self, row: &[Value]) {
        match self {
            ArbitraryRowBuffer::NonEmpty(buf) => {
                buf.add_row(row);
            }
            ArbitraryRowBuffer::Empty { rows } => {
                *rows += 1;
            }
        }
    }

    fn len(&self) -> usize {
        match self {
            ArbitraryRowBuffer::NonEmpty(buf) => buf.len(),
            ArbitraryRowBuffer::Empty { rows } => *rows,
        }
    }

    fn for_each(&self, mut f: impl FnMut(&[Value])) {
        match self {
            ArbitraryRowBuffer::NonEmpty(buf) => {
                for row in buf.iter() {
                    f(row);
                }
            }
            ArbitraryRowBuffer::Empty { rows } => {
                for _ in 0..*rows {
                    f(&[]);
                }
            }
        }
    }
}

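/// A per-handle staging area for mutations to a [`SortedWritesTable`].
///
/// Rows and removals are bucketed by the shard that owns their key; the buffered batches are
/// handed off to the table's shared [`PendingState`] when the buffer is dropped, and are applied
/// during the next call to `merge`.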
struct Buffer {
    pending_rows: DenseIdMap<ShardId, RowBuffer>,
    pending_removals: DenseIdMap<ShardId, ArbitraryRowBuffer>,
    state: Weak<PendingState>,
    n_cols: u32,
    n_keys: u32,
    shard_data: ShardData,
}

impl MutationBuffer for Buffer {
    fn stage_insert(&mut self, row: &[Value]) {
        let (shard, _) = hash_code(self.shard_data, row, self.n_keys as _);
        self.pending_rows
            .get_or_insert(shard, || RowBuffer::new(self.n_cols as _))
            .add_row(row);
    }
    fn stage_remove(&mut self, key: &[Value]) {
        let (shard, _) = hash_code(self.shard_data, key, self.n_keys as _);
        self.pending_removals
            .get_or_insert(shard, || ArbitraryRowBuffer::new(self.n_keys as _))
            .add_row(key);
    }
    fn fresh_handle(&self) -> Box<dyn MutationBuffer> {
        Box::new(Buffer {
            pending_rows: Default::default(),
            pending_removals: Default::default(),
            state: self.state.clone(),
            n_cols: self.n_cols,
            n_keys: self.n_keys,
            shard_data: self.shard_data,
        })
    }
}

impl Drop for Buffer {
    fn drop(&mut self) {
        if let Some(state) = self.state.upgrade() {
            let mut rows = 0;
            for shard_id in 0..self.pending_rows.n_ids() {
                let shard = ShardId::from_usize(shard_id);
                let Some(buf) = self.pending_rows.take(shard) else {
                    continue;
                };
                rows += buf.len();
                state.pending_rows[shard].push(buf);
            }
            state.total_rows.fetch_add(rows, Ordering::Relaxed);

            let mut rows = 0;
            for shard_id in 0..self.pending_removals.n_ids() {
                let shard = ShardId::from_usize(shard_id);
                let Some(buf) = self.pending_removals.take(shard) else {
                    continue;
                };
                rows += buf.len();
                state.pending_removals[shard].push(buf);
            }
            state.total_removals.fetch_add(rows, Ordering::Relaxed);
        }
    }
}

impl Table for SortedWritesTable {
    fn dyn_clone(&self) -> Box<dyn Table> {
        Box::new(self.clone())
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn clear(&mut self) {
        self.pending_state.clear();
        if self.data.data.len() == 0 {
            return;
        }
        self.offsets.clear();
        self.data.clear();
        self.hash.clear();
        self.generation = Generation::from_usize(self.version().major.index() + 1);
    }

    fn spec(&self) -> TableSpec {
        TableSpec {
            n_keys: self.n_keys,
            n_vals: self.n_columns - self.n_keys,
            uncacheable_columns: Default::default(),
            allows_delete: true,
        }
    }

    fn apply_rebuild(
        &mut self,
        table_id: TableId,
        table: &crate::WrappedTable,
        next_ts: Value,
        exec_state: &mut ExecutionState,
    ) {
        self.do_rebuild(table_id, table, next_ts, exec_state);
    }

    fn version(&self) -> TableVersion {
        TableVersion {
            major: self.generation,
            minor: Offset::from_usize(self.data.next_row().index()),
        }
    }

    fn updates_since(&self, offset: Offset) -> Subset {
        Subset::Dense(OffsetRange::new(
            RowId::from_usize(offset.index()),
            self.data.next_row(),
        ))
    }

    fn all(&self) -> Subset {
        Subset::Dense(OffsetRange::new(RowId::new(0), self.data.next_row()))
    }

    fn len(&self) -> usize {
        self.data.data.len() - self.data.stale_rows
    }

    fn scan_generic(&self, subset: SubsetRef, mut f: impl FnMut(RowId, &[Value]))
    where
        Self: Sized,
    {
        let Some((_low, hi)) = subset.bounds() else {
            // Empty subset
            return;
        };
        assert!(
            hi.index() <= self.data.data.len(),
            "{} vs. {}",
            hi.index(),
            self.data.data.len()
        );
        // SAFETY: subsets are sorted, low must be at most hi, and hi is less
        // than the length of the table.
        subset.offsets(|row| unsafe {
            if let Some(vals) = self.data.get_row_unchecked(row) {
                f(row, vals)
            }
        })
    }

    fn scan_generic_bounded(
        &self,
        subset: SubsetRef,
        start: Offset,
        n: usize,
        cs: &[Constraint],
        mut f: impl FnMut(RowId, &[Value]),
    ) -> Option<Offset>
    where
        Self: Sized,
    {
        if cs.is_empty() {
            subset
                .iter_bounded(start.index(), start.index() + n, |row| {
                    let Some(entry) = self.data.get_row(row) else {
                        return;
                    };
                    f(row, entry);
                })
                .map(Offset::from_usize)
        } else {
            subset
                .iter_bounded(start.index(), start.index() + n, |row| {
                    let Some(entry) = self.get_if(cs, row) else {
                        return;
                    };
                    f(row, entry);
                })
                .map(Offset::from_usize)
        }
    }

    fn fast_subset(&self, constraint: &Constraint) -> Option<Subset> {
        let sort_by = self.sort_by?;
        match constraint {
            Constraint::Eq { .. } => None,
            Constraint::EqConst { col, val } => {
                if col == &sort_by {
                    match self.binary_search_sort_val(*val) {
                        Ok((found, bound)) => Some(Subset::Dense(OffsetRange::new(found, bound))),
                        Err(_) => Some(Subset::empty()),
                    }
                } else {
                    None
                }
            }
            Constraint::LtConst { col, val } => {
                if col == &sort_by {
                    match self.binary_search_sort_val(*val) {
                        Ok((found, _)) => {
                            Some(Subset::Dense(OffsetRange::new(RowId::new(0), found)))
                        }
                        Err(next) => Some(Subset::Dense(OffsetRange::new(RowId::new(0), next))),
                    }
                } else {
                    None
                }
            }
            Constraint::GtConst { col, val } => {
                if col == &sort_by {
                    match self.binary_search_sort_val(*val) {
                        Ok((_, bound)) => {
                            Some(Subset::Dense(OffsetRange::new(bound, self.data.next_row())))
                        }
                        Err(next) => {
                            Some(Subset::Dense(OffsetRange::new(next, self.data.next_row())))
                        }
                    }
                } else {
                    None
                }
            }
            Constraint::LeConst { col, val } => {
                if col == &sort_by {
                    match self.binary_search_sort_val(*val) {
                        Ok((_, bound)) => {
                            Some(Subset::Dense(OffsetRange::new(RowId::new(0), bound)))
                        }
                        Err(next) => Some(Subset::Dense(OffsetRange::new(RowId::new(0), next))),
                    }
                } else {
                    None
                }
            }
            Constraint::GeConst { col, val } => {
                if col == &sort_by {
                    match self.binary_search_sort_val(*val) {
                        Ok((found, _)) => {
                            Some(Subset::Dense(OffsetRange::new(found, self.data.next_row())))
                        }
                        Err(next) => {
                            Some(Subset::Dense(OffsetRange::new(next, self.data.next_row())))
                        }
                    }
                } else {
                    None
                }
            }
        }
    }

    fn refine_one(&self, mut subset: Subset, c: &Constraint) -> Subset {
        // NB: we aren't using any of the `fast_subset` tricks here. We may want
        // to if the higher-level implementations end up using it directly.
        subset.retain(|row| self.eval(std::slice::from_ref(c), row));
        subset
    }

    fn new_buffer(&self) -> Box<dyn MutationBuffer> {
        let n_shards = self.hash.shard_data().n_shards();
        Box::new(Buffer {
            pending_rows: DenseIdMap::with_capacity(n_shards),
            pending_removals: DenseIdMap::with_capacity(n_shards),
            state: Arc::downgrade(&self.pending_state),
            n_keys: u32::try_from(self.n_keys).expect("n_keys should fit in u32"),
            n_cols: u32::try_from(self.n_columns).expect("n_columns should fit in u32"),
            shard_data: self.hash.shard_data(),
        })
    }

    fn merge(&mut self, exec_state: &mut ExecutionState) -> TableChange {
        let removed = self.do_delete();
        let added = self.do_insert(exec_state);
        self.maybe_rehash();
        TableChange { removed, added }
    }

    fn get_row(&self, key: &[Value]) -> Option<Row> {
        let id = get_entry(key, self.n_keys, &self.hash, |row| {
            &self.data.get_row(row).unwrap()[0..self.n_keys] == key
        })?;
        let mut vals = with_pool_set(|ps| ps.get::<Vec<Value>>());
        vals.extend_from_slice(self.data.get_row(id).unwrap());
        Some(Row { id, vals })
    }

    fn get_row_column(&self, key: &[Value], col: ColumnId) -> Option<Value> {
        let id = get_entry(key, self.n_keys, &self.hash, |row| {
            &self.data.get_row(row).unwrap()[0..self.n_keys] == key
        })?;
        Some(self.data.get_row(id).unwrap()[col.index()])
    }
}

impl SortedWritesTable {
    /// Create a new [`SortedWritesTable`] with the given number of keys,
    /// columns, and an optional sort column.
    ///
    /// The `merge_fn` is used to evaluate conflicts when more than one row is
    /// inserted with the same primary key. The old and new proposed values are
    /// passed as the second and third arguments, respectively, with the
    /// function filling the final argument with the contents of the new row.
    /// The return value indicates whether or not the contents of the vector
    /// should be used.
    ///
    /// Merge functions can access the database via [`ExecutionState`].
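    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doc-test) of a two-column table keyed on its first
    /// column and sorted by its second, timestamp-like column, where conflicting inserts keep the
    /// existing row; `key` and `timestamp` stand in for `Value`s built elsewhere:
    ///
    /// ```ignore
    /// let mut table = SortedWritesTable::new(
    ///     1,                      // one key column
    ///     2,                      // key column plus one value column
    ///     Some(ColumnId::new(1)), // rows are written in order of column 1
    ///     vec![],                 // no columns participate in rebuilds
    ///     Box::new(|_exec_state, _cur, _new, _out| false), // never overwrite
    /// );
    /// let mut buf = table.new_buffer();
    /// buf.stage_insert(&[key, timestamp]);
    /// drop(buf); // staged rows are handed to the table on drop
    /// // `table.merge(exec_state)` then applies the staged mutations.
    /// ```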
    pub fn new(
        n_keys: usize,
        n_columns: usize,
        sort_by: Option<ColumnId>,
        to_rebuild: Vec<ColumnId>,
        merge_fn: Box<MergeFn>,
    ) -> Self {
        let hash = ShardedHashTable::<TableEntry>::default();
        let shard_data = hash.shard_data();
        let rebuild_index = Index::new(to_rebuild.clone(), ColumnIndex::new());
        SortedWritesTable {
            generation: Generation::new(0),
            data: Rows::new(RowBuffer::new(n_columns)),
            hash,
            n_keys,
            n_columns,
            sort_by,
            offsets: Default::default(),
            pending_state: Arc::new(PendingState::new(shard_data)),
            merge: merge_fn.into(),
            to_rebuild,
            rebuild_index,
            subset_tracker: Default::default(),
        }
    }

    /// Flush all pending removals, in parallel.
    fn parallel_delete(&mut self) -> bool {
        let shard_data = self.hash.shard_data();
        let stale_delta: usize = self
            .hash
            .mut_shards()
            .par_iter_mut()
            .enumerate()
            .filter_map(|(shard_id, shard)| {
                let shard_id = ShardId::from_usize(shard_id);
                if self.pending_state.pending_removals[shard_id].is_empty() {
                    return None;
                }
                Some((shard_id, shard))
            })
            .map(|(shard_id, shard)| {
                let queue = &self.pending_state.pending_removals[shard_id];
                let mut marked_stale = 0;
                while let Some(buf) = queue.pop() {
                    buf.for_each(|to_remove| {
                        let (actual_shard, hc) = hash_code(shard_data, to_remove, self.n_keys);
                        assert_eq!(actual_shard, shard_id);
                        if let Ok(entry) = shard.find_entry(hc, |entry| {
                            entry.hashcode == (hc as _)
                                && &self.data.get_row(entry.row).unwrap()[0..self.n_keys]
                                    == to_remove
                        }) {
                            let (ent, _) = entry.remove();
                            // SAFETY: The safety requirements of
                            // `set_stale_shared` are that there are no
                            // concurrent accesses to `row`. No other threads
                            // can access this row within this method because
                            // different `shards` partition the space
                            // (guaranteed by the assertion above), and we
                            // launch at most one thread per shard.
                            marked_stale +=
                                unsafe { !self.data.data.set_stale_shared(ent.row) } as usize;
                        }
                    });
                }
                marked_stale
            })
            .sum();
        // Update the stale count with the total marked stale.
        self.data.stale_rows += stale_delta;
        stale_delta > 0
    }
    fn serial_delete(&mut self) -> bool {
        let shard_data = self.hash.shard_data();
        let mut changed = false;
        self.hash
            .mut_shards()
            .iter_mut()
            .enumerate()
            .for_each(|(shard_id, shard)| {
                let shard_id = ShardId::from_usize(shard_id);
                let queue = &self.pending_state.pending_removals[shard_id];
                while let Some(buf) = queue.pop() {
                    buf.for_each(|to_remove| {
                        let (actual_shard, hc) = hash_code(shard_data, to_remove, self.n_keys);
                        assert_eq!(actual_shard, shard_id);
                        if let Ok(entry) = shard.find_entry(hc, |entry| {
                            entry.hashcode == (hc as _)
                                && &self.data.get_row(entry.row).unwrap()[0..self.n_keys]
                                    == to_remove
                        }) {
                            let (ent, _) = entry.remove();
                            self.data.set_stale(ent.row);
                            changed = true;
                        }
                    })
                }
            });
        changed
    }

    fn do_delete(&mut self) -> bool {
        let total = self.pending_state.total_removals.swap(0, Ordering::Relaxed);

        if parallelize_table_op(total) {
            self.parallel_delete()
        } else {
            self.serial_delete()
        }
    }

    fn do_insert(&mut self, exec_state: &mut ExecutionState) -> bool {
        let total = self.pending_state.total_rows.swap(0, Ordering::Relaxed);
        self.data.data.reserve(total);
        if parallelize_table_op(total) {
            if let Some(col) = self.sort_by {
                self.parallel_insert(
                    exec_state,
                    SortChecker {
                        col,
                        current: None,
                        baseline: self.offsets.last().map(|(v, _)| *v),
                    },
                )
            } else {
                self.parallel_insert(exec_state, ())
            }
        } else {
            self.serial_insert(exec_state)
        }
    }

    fn serial_insert(&mut self, exec_state: &mut ExecutionState) -> bool {
        let mut changed = false;
        let n_keys = self.n_keys;
        let mut scratch = with_pool_set(|ps| ps.get::<Vec<Value>>());
        for (_outer_shard, queue) in self.pending_state.pending_rows.iter() {
            if let Some(sort_by) = self.sort_by {
                while let Some(buf) = queue.pop() {
                    for query in buf.non_stale() {
                        let key = &query[0..n_keys];
                        let entry = get_entry_mut(query, n_keys, &mut self.hash, |row| {
                            let Some(row) = self.data.get_row(row) else {
                                return false;
                            };
                            &row[0..n_keys] == key
                        });

                        if let Some(row) = entry {
                            // First case: overwriting an existing value. Apply merge
                            // function. Insert new row and update hash table if merge
                            // changes anything.
                            let cur = self
                                .data
                                .get_row(*row)
                                .expect("table should not point to stale entry");
                            if (self.merge)(exec_state, cur, query, &mut scratch) {
                                let sort_val = query[sort_by.index()];
                                let new = self.data.add_row(&scratch);
                                if let Some(largest) = self.offsets.last().map(|(v, _)| *v) {
                                    assert!(
                                        sort_val >= largest,
                                        "inserting row that violates sort order ({sort_val:?} vs. {largest:?})"
                                    );
                                    if sort_val > largest {
                                        self.offsets.push((sort_val, new));
                                    }
                                } else {
                                    self.offsets.push((sort_val, new));
                                }
                                self.data.set_stale(*row);
                                *row = new;
                                changed = true;
                            }
                            scratch.clear();
                        } else {
                            let sort_val = query[sort_by.index()];
                            // New value: update invariants.
                            let new = self.data.add_row(query);
                            if let Some(largest) = self.offsets.last().map(|(v, _)| *v) {
                                assert!(
                                    sort_val >= largest,
                                    "inserting row that violates sort order {sort_val:?} vs. {largest:?}"
                                );
                                if sort_val > largest {
                                    self.offsets.push((sort_val, new));
                                }
                            } else {
                                self.offsets.push((sort_val, new));
                            }
                            let (shard, hc) = hash_code(self.hash.shard_data(), query, self.n_keys);
                            debug_assert_eq!(shard, _outer_shard);
                            self.hash.mut_shards()[shard.index()].insert_unique(
                                hc as _,
                                TableEntry {
                                    hashcode: hc as _,
                                    row: new,
                                },
                                TableEntry::hashcode,
                            );
                            changed = true;
                        }
                    }
                }
            } else {
                // Simplified variant without the sorting constraint.
                while let Some(buf) = queue.pop() {
                    for query in buf.non_stale() {
                        let key = &query[0..n_keys];
                        let entry = get_entry_mut(query, n_keys, &mut self.hash, |row| {
                            let Some(row) = self.data.get_row(row) else {
                                return false;
                            };
                            &row[0..n_keys] == key
                        });

                        if let Some(row) = entry {
                            let cur = self
                                .data
                                .get_row(*row)
                                .expect("table should not point to stale entry");
                            if (self.merge)(exec_state, cur, query, &mut scratch) {
                                let new = self.data.add_row(&scratch);
                                self.data.set_stale(*row);
                                *row = new;
                                changed = true;
                            }
                            scratch.clear();
                        } else {
                            // New value: update invariants.
                            let new = self.data.add_row(query);
                            let (shard, hc) = hash_code(self.hash.shard_data(), query, self.n_keys);
                            debug_assert_eq!(shard, _outer_shard);
                            self.hash.mut_shards()[shard.index()].insert_unique(
                                hc as _,
                                TableEntry {
                                    hashcode: hc as _,
                                    row: new,
                                },
                                TableEntry::hashcode,
                            );
                            changed = true;
                        }
                    }
                }
            };
        }
        changed
    }

    fn parallel_insert<C: OrderingChecker>(
        &mut self,
        exec_state: &ExecutionState,
        checker: C,
    ) -> bool {
        const BATCH_SIZE: usize = 1 << 18;
        // Parallel insert uses one giant parallel foreach. We have updates
        // pre-sharded, and one logical thread can process updates for each
        // shard independently. Updates happen in three phases, which comments
        // describe below.
        let shard_data = self.hash.shard_data();
        let n_keys = self.n_keys;
        let n_cols = self.n_columns;
        let next_offset = RowId::from_usize(self.data.data.len());
        let row_writer = self.data.data.parallel_writer();
        let pending_adds = self
            .hash
            .mut_shards()
            .par_iter_mut()
            .enumerate()
            .map(|(shard_id, shard)| {
                let shard_id = ShardId::from_usize(shard_id);
                let mut checker = checker.clone();
                let mut exec_state = exec_state.clone();
                let mut scratch = with_pool_set(|ps| ps.get::<Vec<Value>>());
                let queue = &self.pending_state.pending_rows[shard_id];
                let mut marked_stale = 0usize;
                let mut staged = StagedOutputs::new(n_keys, n_cols, BATCH_SIZE);
                let mut changed = false;
                // The core flush loop: we call this once `staged` reaches `BATCH_SIZE`, or
                // when we're done.
                macro_rules! flush_staged_outputs {
                    () => {{
                        // Phase 2: Write the staged rows to the row writer. This only
                        // works due to the `ParallelRowBufWriter` machinery.
                        let start_row = staged.write_output(&row_writer);
                        // Phase 3: With the values buffered in the row buffer, we can
                        // write them back to the shard, pointing at the correct rows.

                        // In the serial implementation, we do phases 2 and 3 inline with
                        // processing the incoming mutation, but separating them out
                        // this way allows us to do a single write to the shared row
                        // buffer, rather than one per row, which would cause
                        // contention.
                        let mut cur_row = start_row;
                        let read_handle = row_writer.read_handle();
                        for row in staged.rows() {
                            use hashbrown::hash_table::Entry;
                            checker.check_local(row);
                            changed = true;
                            let key = &row[0..n_keys];
                            let (_actual_shard, hc) = hash_code(shard_data, row, n_keys);
                            #[cfg(any(debug_assertions, test))]
                            {
                                unsafe {
                                    // read the value we wrote at this row and
                                    // check that it matches.
                                    assert_eq!(read_handle.get_row_unchecked(cur_row), row);
                                }
                            }
                            debug_assert_eq!(_actual_shard, shard_id);
                            match shard.entry(
                                hc,
                                // SAFETY: `ent` must point to a valid row
                                |ent| unsafe {
                                    ent.hashcode == hc as HashCode
                                        && &read_handle.get_row_unchecked(ent.row)[0..n_keys] == key
                                },
                                TableEntry::hashcode,
                            ) {
                                Entry::Occupied(mut occ) => {
                                    // SAFETY: `occ` must point to a valid row: we only insert valid rows
                                    // into the map.
                                    let cur = unsafe { read_handle.get_row_unchecked(occ.get().row) };

                                    // SAFETY: The safety requirements of
                                    // `set_stale_shared` are that there are no
                                    // concurrent accesses to `row`. We have
                                    // exclusive access to any row whose hash matches this
                                    // shard.
                                    if (self.merge)(&mut exec_state, cur, row, &mut scratch) {
                                        unsafe {
                                            let _was_stale = read_handle.set_stale_shared(occ.get().row);
                                            debug_assert!(!_was_stale);
                                        }
                                        occ.get_mut().row = cur_row;
                                        changed = true;
                                    } else {
                                        // Mark the new row as stale: we didn't end up needing it.
                                        unsafe {
                                            let _was_stale = read_handle.set_stale_shared(cur_row);
                                            debug_assert!(!_was_stale);
                                        }
                                    }
                                    marked_stale += 1;
                                    scratch.clear();
                                }
                                Entry::Vacant(v) => {
                                    changed = true;
                                    v.insert(TableEntry {
                                        hashcode: hc as HashCode,
                                        row: cur_row,
                                    });
                                }
                            }

                            cur_row = cur_row.inc();
                        }
                        staged.clear();
                    }};
                }
                // Phase 1: process all incoming updates:
                // * Add new values to `staged`
                // * Remove entries in `shard` and mark them as stale in
                // `data` if they will be overwritten.
                while let Some(buf) = queue.pop() {
                    // We create a read_handle once per batch to avoid blocking
                    // too many threads if someone needs to resize the row
                    // writer.
                    for row in buf.non_stale() {
                        staged.insert(row, |cur, new, out| {
                            (self.merge)(&mut exec_state, cur, new, out)
                        });
                        if staged.len() >= BATCH_SIZE {
                            flush_staged_outputs!();
                        }
                    }
                }
                flush_staged_outputs!();
                (checker, marked_stale, changed)
            })
            .collect_vec_list();
        self.data.data = row_writer.finish();
        // Now we just need to reset our invariants.

        // Confirm none of the writes violated sort order and update the
        // `offsets` vector.
        let checker = C::check_global(pending_adds.iter().flatten().map(|(checker, _, _)| checker));
        checker.update_offsets(next_offset, &mut self.offsets);

        // Update the staleness counters.
        self.data.stale_rows += pending_adds
            .iter()
            .flatten()
            .map(|(_, stale, _)| *stale)
            .sum::<usize>();

        // Register any changes.
        pending_adds
            .iter()
            .flatten()
            .any(|(_, _, changed)| *changed)
    }

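    /// Map `val` to the dense range of row ids whose sort column holds it.
    ///
    /// `offsets` stores one `(sort value, first row id)` pair per distinct sort value, in
    /// increasing order. As a worked example (with made-up ids): if `offsets` is
    /// `[(t0, 0), (t1, 10), (t2, 25)]` and the table has 40 rows, searching for `t1` returns
    /// `Ok((10, 25))`, while searching for a value strictly between `t1` and `t2` returns
    /// `Err(25)`.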
    fn binary_search_sort_val(&self, val: Value) -> Result<(RowId, RowId), RowId> {
        debug_assert!(
            self.offsets.windows(2).all(|x| x[0].1 < x[1].1),
            "{:?}",
            self.offsets
        );

        debug_assert!(
            self.offsets.windows(2).all(|x| x[0].0 < x[1].0),
            "{:?}",
            self.offsets
        );
        match self.offsets.binary_search_by_key(&val, |(v, _)| *v) {
            Ok(got) => Ok((
                self.offsets[got].1,
                self.offsets
                    .get(got + 1)
                    .map(|(_, r)| *r)
                    .unwrap_or(self.data.next_row()),
            )),
            Err(next) => Err(self
                .offsets
                .get(next)
                .map(|(_, id)| *id)
                .unwrap_or(self.data.next_row())),
        }
    }
    fn eval(&self, cs: &[Constraint], row: RowId) -> bool {
        self.get_if(cs, row).is_some()
    }

    fn get_if(&self, cs: &[Constraint], row: RowId) -> Option<&[Value]> {
        let row = self.data.get_row(row)?;
        let mut res = true;
        for constraint in cs {
            match constraint {
                Constraint::Eq { l_col, r_col } => res &= row[l_col.index()] == row[r_col.index()],
                Constraint::EqConst { col, val } => res &= row[col.index()] == *val,
                Constraint::LtConst { col, val } => res &= row[col.index()] < *val,
                Constraint::GtConst { col, val } => res &= row[col.index()] > *val,
                Constraint::LeConst { col, val } => res &= row[col.index()] <= *val,
                Constraint::GeConst { col, val } => res &= row[col.index()] >= *val,
            }
        }
        if res { Some(row) } else { None }
    }

    fn maybe_rehash(&mut self) {
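        // Compaction heuristic: only rehash once stale rows outnumber max(16, half of the
        // physical rows in the buffer).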
        if self.data.stale_rows <= cmp::max(16, self.data.data.len() / 2) {
            return;
        }

        if parallelize_table_op(self.data.data.len()) {
            self.parallel_rehash();
        } else {
            self.rehash();
        }
    }
    fn parallel_rehash(&mut self) {
        use rayon::prelude::*;
        // Parallel rehashes go "hash-first" rather than "rows-first".
        //
        // We iterate over each shard and then write out new contents to a fresh row, in parallel.
        let Some(sort_by) = self.sort_by else {
            // Just do a serial rehash for now. We currently do not have a use-case for parallel
            // compaction of unsorted tables.
            //
            // Implementing parallel compaction for an unsorted table is much easier: each shard
            // can write to a contiguous chunk of the `scratch` buffer, with the offsets being
            // pre-chunked based on the size of each shard.
            self.rehash();
            return;
        };
        self.generation = self.generation.inc();
        assert!(!self.offsets.is_empty());
        struct TimestampStats {
            value: Value,
            count: usize,
            histogram: Pooled<DenseIdMap<ShardId, usize>>,
        }
        impl Default for TimestampStats {
            fn default() -> TimestampStats {
                TimestampStats {
                    value: Value::stale(),
                    count: 0,
                    histogram: with_pool_set(|ps| ps.get()),
                }
            }
        }
        let mut results = Vec::<TimestampStats>::with_capacity(self.offsets.len());
        results.resize_with(self.offsets.len() - 1, Default::default);
        // Use a macro rather than a lambda to avoid borrow issues.
        macro_rules! compute_hist {
            ($start_val: expr, $start_row: expr, $end_row: expr) => {{
                let mut histogram: Pooled<DenseIdMap<ShardId, usize>> =
                    with_pool_set(|ps| ps.get());
                let mut cur_row = $start_row;
                let mut count = 0;
                while cur_row < $end_row {
                    if let Some(row) = self.data.get_row(cur_row) {
                        count += 1;
                        let (shard, _) = hash_code(self.hash.shard_data(), row, self.n_keys);
                        *histogram.get_or_default(shard) += 1;
                    }
                    cur_row = cur_row.inc();
                }
                TimestampStats {
                    value: $start_val,
                    count,
                    histogram,
                }
            }};
        }
        let mut last: TimestampStats = Default::default();
        rayon::join(
            || {
                // This closure handles computing all timestamps but the last one.
                self.offsets
                    .windows(2)
                    .zip(results.iter_mut())
                    .par_bridge()
                    .for_each(|(xs, res)| {
                        let [(start_val, start_row), (_, end_row)] = xs else {
                            unreachable!()
                        };
                        *res = compute_hist!(*start_val, *start_row, *end_row);
                    })
            },
            || {
                // And here we handle the final one.
                let (start_val, start_row) = self.offsets.last().unwrap();
                let end_row = self.data.next_row();
                last = compute_hist!(*start_val, *start_row, end_row);
            },
        );
        results.push(last);
        // Now we need to compute cumulative statistics on the row layouts here.
        // We do this serially as we currently don't have much use for cases with thousands of
        // timestamps or more. There are well-known algorithms for computing these cumulative
        // statistics in parallel, but they aren't currently all that well-suited to rayon.
        let mut prev_count = 0;
        self.offsets.clear();
        for stats in results.iter_mut() {
            if stats.count == 0 {
                continue;
            }
            self.offsets
                .push((stats.value, RowId::from_usize(prev_count)));
            let mut inner = prev_count;
            for (_, count) in stats.histogram.iter_mut() {
                // Each entry in the histogram now points to the start row for that shard's
                // rows for a given timestamp.
                let tmp = *count;
                *count = inner;
                inner += tmp;
            }
            prev_count += stats.count;
            debug_assert_eq!(inner, prev_count)
        }

        // Now the part with some unsafe code.
        // We will iterate over each shard and use the statistics in `results` to guide where
        // each row will go.
        //
        // This involves doing unsynchronized writes to the table (ptr::copy_nonoverlapping)
        // followed by a set_len. The safety of these operations relies on the fact that:
        // * No one grabs a reference to the interior of `scratch` until these operations have
        //   finished.
        // * `scratch` does not overlap `data`.
        // * The sharding function completely partitions the set of objects in the table: one
        //   shard's writes will never stomp on those of another.

        self.data.scratch.clear();
        self.data.scratch.reserve(prev_count);
        self.hash
            .mut_shards()
            .par_iter_mut()
            .with_max_len(1)
            .enumerate()
            .for_each(|(shard_id, shard)| {
                let shard_id = ShardId::from_usize(shard_id);
                let scratch_ptr = self.data.scratch.raw_rows();
                let mut progress =
                    HashMap::<Value /* timestamp */, RowId /* next row */>::default();
                progress.reserve(results.len());
                for stats in &results {
                    let Some(start) = stats.histogram.get(shard_id) else {
                        continue;
                    };
                    progress.insert(stats.value, RowId::from_usize(*start));
                }
                for TableEntry { row: row_id, .. } in shard.iter_mut() {
                    let row = self
                        .data
                        .get_row(*row_id)
                        .expect("shard should not map to a stale value");
                    let val = row[sort_by.index()];
                    let next = progress[&val];
                    // SAFETY: see above longer comment.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            row.as_ptr(),
                            scratch_ptr.add(next.index() * self.n_columns) as *mut Value,
                            self.n_columns,
                        )
                    }
                    *row_id = next;
                    progress.insert(val, next.inc());
                }
            });
        // SAFETY: see above longer comment.
        unsafe { self.data.scratch.set_len(prev_count) };
        mem::swap(&mut self.data.data, &mut self.data.scratch);
        self.data.stale_rows = 0;
    }
    fn rehash_impl(
        sort_by: Option<ColumnId>,
        n_keys: usize,
        rows: &mut Rows,
        offsets: &mut Vec<(Value, RowId)>,
        hash: &mut ShardedHashTable<TableEntry>,
    ) {
        if let Some(sort_by) = sort_by {
            offsets.clear();
            rows.remove_stale(|row, old, new| {
                let stale_entry = get_entry_mut(row, n_keys, hash, |x| x == old)
                    .expect("non-stale entry not mapped in hash");
                *stale_entry = new;
                let sort_col = row[sort_by.index()];
                if let Some((max, _)) = offsets.last() {
                    if sort_col > *max {
                        offsets.push((sort_col, new));
                    }
                } else {
                    offsets.push((sort_col, new));
                }
            })
        } else {
            rows.remove_stale(|row, old, new| {
                let stale_entry = get_entry_mut(row, n_keys, hash, |x| x == old)
                    .expect("non-stale entry not mapped in hash");
                *stale_entry = new;
            })
        }
    }

    fn rehash(&mut self) {
        self.generation = self.generation.inc();
        Self::rehash_impl(
            self.sort_by,
            self.n_keys,
            &mut self.data,
            &mut self.offsets,
            &mut self.hash,
        )
    }
}

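/// Look up the [`RowId`] stored for `row`'s key, using `test` to confirm that a candidate entry
/// with a matching hash code really refers to the same key.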
fn get_entry(
    row: &[Value],
    n_keys: usize,
    table: &ShardedHashTable<TableEntry>,
    test: impl Fn(RowId) -> bool,
) -> Option<RowId> {
    let (shard, hash) = hash_code(table.shard_data(), row, n_keys);
    table
        .get_shard(shard)
        .find(hash, |ent| {
            ent.hashcode == hash as HashCode && test(ent.row)
        })
        .map(|ent| ent.row)
}

fn get_entry_mut<'a>(
    row: &[Value],
    n_keys: usize,
    table: &'a mut ShardedHashTable<TableEntry>,
    test: impl Fn(RowId) -> bool,
) -> Option<&'a mut RowId> {
    let (shard, hash) = hash_code(table.shard_data(), row, n_keys);
    table.mut_shards()[shard.index()]
        .find_mut(hash, |ent| {
            ent.hashcode == hash as HashCode && test(ent.row)
        })
        .map(|ent| &mut ent.row)
}

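/// Hash the first `n_keys` columns of `row`, returning both the shard that owns the key and the
/// full 64-bit hash code.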
fn hash_code(shard_data: ShardData, row: &[Value], n_keys: usize) -> (ShardId, u64) {
    let mut hasher = FxHasher::default();
    for val in &row[0..n_keys] {
        hasher.write_usize(val.index());
    }
    let full_code = hasher.finish();
    // We keep this cast here to allow for experimenting with HashCode=u32.
    #[allow(clippy::unnecessary_cast)]
    (shard_data.shard_id(full_code), full_code as HashCode as u64)
}

/// A simple struct for packaging up pending mutations to a `SortedWritesTable`.
struct PendingState {
    pending_rows: DenseIdMap<ShardId, SegQueue<RowBuffer>>,
    pending_removals: DenseIdMap<ShardId, SegQueue<ArbitraryRowBuffer>>,
    total_removals: AtomicUsize,
    total_rows: AtomicUsize,
}

impl PendingState {
    fn new(shard_data: ShardData) -> PendingState {
        let n_shards = shard_data.n_shards();
        let mut pending_rows = DenseIdMap::with_capacity(n_shards);
        let mut pending_removals = DenseIdMap::with_capacity(n_shards);
        for i in 0..n_shards {
            pending_rows.insert(ShardId::from_usize(i), SegQueue::default());
            pending_removals.insert(ShardId::from_usize(i), SegQueue::default());
        }

        PendingState {
            pending_rows,
            pending_removals,
            total_removals: AtomicUsize::new(0),
            total_rows: AtomicUsize::new(0),
        }
    }
    fn clear(&self) {
        for (_, queue) in self.pending_rows.iter() {
            while queue.pop().is_some() {}
        }

        for (_, queue) in self.pending_removals.iter() {
            while queue.pop().is_some() {}
        }
    }

    /// This is mostly useful for debugging, but it is annoying enough to write that it helps to
    /// have it around.
    ///
    /// We also use it in the `Clone` impl (which should only be called when the pending state is
    /// empty).
    fn deep_copy(&self) -> PendingState {
        let mut pending_rows = DenseIdMap::new();
        let mut pending_removals = DenseIdMap::new();
        fn drain_queue<T>(queue: &SegQueue<T>) -> Vec<T> {
            let mut res = Vec::new();
            while let Some(x) = queue.pop() {
                res.push(x);
            }
            res
        }
        for (shard, queue) in self.pending_rows.iter() {
            let contents = drain_queue(queue);
            let new_queue = SegQueue::default();
            for x in contents {
                new_queue.push(x.clone());
                queue.push(x);
            }
            pending_rows.insert(shard, new_queue);
        }

        for (shard, queue) in self.pending_removals.iter() {
            let contents = drain_queue(queue);
            let new_queue = SegQueue::default();
            for x in contents {
                new_queue.push(x.clone());
                queue.push(x);
            }
            pending_removals.insert(shard, new_queue);
        }

        PendingState {
            pending_rows,
            pending_removals,
            total_removals: AtomicUsize::new(self.total_removals.load(Ordering::Acquire)),
            total_rows: AtomicUsize::new(self.total_rows.load(Ordering::Acquire)),
        }
    }
}

/// A trait that encapsulates the logic of potentially checking that written
/// columns appear in sorted order.
///
/// For tables sorted by a column, an OrderingChecker asserts that all new rows have the same
/// value in that column, and that this value is greater than or equal to the largest sort value
/// already in the table. For unsorted tables, these checks become no-ops.
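///
/// During a parallel insert, each shard's worker clones the checker and calls `check_local` on
/// every row it writes; the per-worker states are then folded together with `check_global`, and
/// `update_offsets` records the (single) new sort value in the table's `offsets` vector.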
trait OrderingChecker: Clone + Send + Sync {
    /// Check any invariants locally, updating the state of the checker when
    /// doing so.
    fn check_local(&mut self, row: &[Value]);
    /// Combine the states of multiple checkers, returning a new checker with
    /// all information assimilated. This is the checker that is suitable for
    /// calling `update_offsets` with.
    fn check_global<'a>(checkers: impl Iterator<Item = &'a Self>) -> Self
    where
        Self: 'a;
    /// Update the sorted offset vector with the current state of the checker.
    fn update_offsets(&self, start: RowId, offsets: &mut Vec<(Value, RowId)>);
}

impl OrderingChecker for () {
    fn check_local(&mut self, _: &[Value]) {}
    fn check_global<'a>(_: impl Iterator<Item = &'a ()>) {}
    fn update_offsets(&self, _: RowId, _: &mut Vec<(Value, RowId)>) {}
}

#[derive(Copy, Clone)]
struct SortChecker {
    col: ColumnId,
    baseline: Option<Value>,
    current: Option<Value>,
}

impl OrderingChecker for SortChecker {
    fn check_local(&mut self, row: &[Value]) {
        let val = row[self.col.index()];
        if let Some(cur) = self.current {
            assert_eq!(
                cur, val,
                "concurrently inserting rows with different sort keys"
            );
        } else {
            self.current = Some(val);
            if let Some(baseline) = self.baseline {
                assert!(val >= baseline, "inserted row violates sort order");
            }
        }
    }

    fn check_global<'a>(mut checkers: impl Iterator<Item = &'a Self>) -> Self {
        let Some(start) = checkers.next() else {
            return SortChecker {
                col: ColumnId::new(!0),
                baseline: None,
                current: None,
            };
        };
        let mut expected = start.current;
        for checker in checkers {
            assert_eq!(checker.baseline, start.baseline);
            match (&mut expected, checker.current) {
                (None, None) => {}
                (cur @ None, Some(x)) => {
                    *cur = Some(x);
                }
                (Some(_), None) => {}
                (Some(x), Some(y)) => {
                    assert_eq!(
                        *x, y,
                        "concurrently inserting rows with different sort keys"
                    );
                }
            }
        }
        SortChecker {
            col: start.col,
            baseline: start.baseline,
            current: expected,
        }
    }

    fn update_offsets(&self, start: RowId, offsets: &mut Vec<(Value, RowId)>) {
        if let Some(cur) = self.current {
            if let Some((max, _)) = offsets.last() {
                if cur > *max {
                    offsets.push((cur, start));
                }
            } else {
                offsets.push((cur, start));
            }
        }
    }
}

/// A type similar to a SortedWritesTable used to buffer outputs. The main thing
/// that StagedOutputs handles is running the merge function for a table on
/// multiple updates to the same key that show up in the same round of
/// insertions.
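///
/// For example (a sketch): staging `[k, 1]` and then `[k, 2]` under a "keep the larger value"
/// merge function leaves a single live staged row `[k, 2]`; `write_output` then copies the
/// surviving rows into the shared row buffer with one bulk write.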
struct StagedOutputs {
    shard_data: ShardData,
    n_keys: usize,
    hash: Pooled<HashTable<TableEntry>>,
    rows: RowBuffer,
    n_stale: usize,
    scratch: Pooled<Vec<Value>>,
}

impl StagedOutputs {
    fn rows(&self) -> impl Iterator<Item = &[Value]> {
        self.rows.non_stale()
    }
    fn new(n_keys: usize, n_cols: usize, capacity: usize) -> Self {
        let mut res = with_pool_set(|ps| StagedOutputs {
            shard_data: ShardData::new(1),
            n_keys,
            n_stale: 0,
            hash: ps.get(),
            rows: RowBuffer::new(n_cols),
            scratch: ps.get(),
        });
        res.hash.reserve(capacity, TableEntry::hashcode);
        res.rows.reserve(capacity);
        res
    }
    fn clear(&mut self) {
        self.hash.clear();
        self.rows.clear();
        self.n_stale = 0;
    }
    fn len(&self) -> usize {
        self.rows.len() - self.n_stale
    }

    fn insert(
        &mut self,
        row: &[Value],
        mut merge_fn: impl FnMut(&[Value], &[Value], &mut Vec<Value>) -> bool,
    ) {
        if row[0].is_stale() {
            return;
        }
        use hashbrown::hash_table::Entry;
        let (_, hc) = hash_code(self.shard_data, row, self.n_keys);
        let entry = self.hash.entry(
            hc,
            |te| {
                te.hashcode() == hc
                    && self.rows.get_row(te.row)[0..self.n_keys] == row[0..self.n_keys]
            },
            TableEntry::hashcode,
        );
        match entry {
            Entry::Occupied(mut occupied_entry) => {
                let cur = self.rows.get_row(occupied_entry.get().row);
                if merge_fn(cur, row, &mut self.scratch) {
                    let new = self.rows.add_row(&self.scratch);
                    self.rows.set_stale(occupied_entry.get().row);
                    self.n_stale += 1;
                    occupied_entry.get_mut().row = new;
                }
                self.scratch.clear();
            }
            Entry::Vacant(vacant_entry) => {
                let next = self.rows.add_row(row);
                vacant_entry.insert(TableEntry {
                    hashcode: hc as _,
                    row: next,
                });
            }
        }
    }

    /// Write the contents of the staged outputs to the given writer, returning
    /// the initial RowId of the new output.
    fn write_output(&self, output: &ParallelRowBufWriter) -> RowId {
        let n_rows = self.rows.len() - self.n_stale;
        let n_vals = n_rows * self.rows.arity();
        output.write_raw_values(
            WithExactSize {
                iter: self.rows.non_stale().flatten().copied(),
                size: n_vals,
            },
            n_rows,
        )
    }
}

/// A simple type used to attach a known size to an arbitrary iterator.
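///
/// The flattened row iterator built in [`StagedOutputs::write_output`] does not report an exact
/// length on its own, so this wrapper attaches the precomputed `n_rows * arity` value to it.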
struct WithExactSize<I> {
    iter: I,
    size: usize,
}

impl<I: Iterator> Iterator for WithExactSize<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

impl<I: Iterator> ExactSizeIterator for WithExactSize<I> {
    fn len(&self) -> usize {
        self.size
    }
}