Skip to main content

arrow_view_state/
sort_builder.rs

1//! Streaming sort index builder.
2//!
3//! Accepts `RecordBatch`es one at a time via [`SortBuilder::push`], encodes sort keys with
4//! `RowConverter`, and produces a [`PermutationIndex`] on [`SortBuilder::finish`].
5
6use arrow_array::{RecordBatch, UInt32Array};
7use arrow_row::{RowConverter, SortField};
8use tracing::debug;
9
10use crate::error::IndexError;
11use crate::permutation::PermutationIndex;
12use crate::storage::PermutationStorage;
13
14// Re-export `SortField` for caller convenience (it's from `arrow_row`).
15pub use arrow_row::SortField as ArrowSortField;
16
/// Streaming sort index builder.
///
/// Create with [`SortBuilder::new`], feed batches via [`SortBuilder::push`],
/// then call [`SortBuilder::finish`] to produce the sorted [`PermutationIndex`].
///
/// # Example
///
/// ```ignore
/// let mut builder = SortBuilder::new(vec![
///     SortField::new(DataType::Int64),
///     SortField::new_with_options(
///         DataType::Utf8,
///         SortOptions { descending: true, ..Default::default() },
///     ),
/// ])?;
///
/// for batch in record_batch_stream {
///     builder.push(&batch, &[0, 2])?;  // sort by columns 0 and 2
/// }
///
/// let index = builder.finish()?;
/// ```
#[derive(Debug)]
pub struct SortBuilder {
    /// Encodes the sort columns of each pushed batch into byte-comparable row keys.
    converter: RowConverter,
    /// `(encoded_row_bytes, global_physical_id)` pairs accumulated across pushes.
    entries: Vec<(Box<[u8]>, u32)>,
    /// Number of sort fields supplied to [`SortBuilder::new`]; every `push` must
    /// pass exactly this many column indices.
    fields_len: usize,
    /// Total number of rows ingested so far; also the physical id assigned to the
    /// first row of the next pushed batch.
    global_offset: u64,
}
44
45impl SortBuilder {
46    /// Create a new builder for a sort with the given fields.
47    ///
48    /// `fields` defines the sort key schema and direction. One [`SortField`] per
49    /// sort column, in priority order (first = primary, second = tiebreaker, …).
50    ///
51    /// # Errors
52    ///
53    /// - [`IndexError::EmptyColumns`] if `fields` is empty.
54    /// - [`IndexError::RowEncodingFailed`] if the `RowConverter` cannot be created.
55    pub fn new(fields: Vec<SortField>) -> Result<Self, IndexError> {
56        if fields.is_empty() {
57            return Err(IndexError::EmptyColumns);
58        }
59        let fields_len = fields.len();
60        let converter = RowConverter::new(fields)?;
61        Ok(Self {
62            converter,
63            entries: Vec::new(),
64            fields_len,
65            global_offset: 0,
66        })
67    }
68
69    /// Ingest one [`RecordBatch`].
70    ///
71    /// `sort_columns` are the column indices within `batch` that correspond
72    /// to the [`SortField`]s passed to [`SortBuilder::new`]. Must have the same length.
73    ///
74    /// The batch's sort columns are encoded into fixed-width row keys; the
75    /// batch itself is NOT retained.
76    ///
77    /// # Errors
78    ///
79    /// - [`IndexError::LengthMismatch`] if `sort_columns.len() != fields.len()`.
80    /// - [`IndexError::TooManyRows`] if cumulative rows exceed `u32::MAX`.
81    /// - [`IndexError::RowEncodingFailed`] on encoding failure.
82    pub fn push(&mut self, batch: &RecordBatch, sort_columns: &[usize]) -> Result<(), IndexError> {
83        if sort_columns.len() != self.fields_len {
84            return Err(IndexError::LengthMismatch {
85                expected: self.fields_len as u64,
86                actual: sort_columns.len() as u64,
87            });
88        }
89
90        let n = batch.num_rows();
91        if n == 0 {
92            return Ok(());
93        }
94
95        let new_total = self.global_offset + n as u64;
96        if new_total > u64::from(u32::MAX) {
97            return Err(IndexError::TooManyRows(new_total));
98        }
99
100        let columns: Vec<_> = sort_columns
101            .iter()
102            .map(|&idx| batch.column(idx).clone())
103            .collect();
104
105        let rows = self.converter.convert_columns(&columns)?;
106
107        self.entries.reserve(n);
108        for i in 0..n {
109            let row = rows.row(i);
110            let bytes: Box<[u8]> = row.as_ref().into();
111            let global_id = u32::try_from(self.global_offset + i as u64)
112                .map_err(|_| IndexError::TooManyRows(self.global_offset + i as u64))?;
113            self.entries.push((bytes, global_id));
114        }
115
116        self.global_offset = new_total;
117        Ok(())
118    }
119
120    /// Total rows ingested so far.
121    pub fn rows_ingested(&self) -> u64 {
122        self.global_offset
123    }
124
125    /// Consume the builder, run parallel argsort, and produce a [`PermutationIndex`].
126    ///
127    /// # Errors
128    ///
129    /// - [`IndexError::EmptyColumns`] if no rows were pushed.
130    /// - [`IndexError::MmapError`] if mmap storage creation fails (when `mmap` feature is enabled).
131    pub fn finish(self) -> Result<PermutationIndex, IndexError> {
132        if self.entries.is_empty() {
133            return Err(IndexError::EmptyColumns);
134        }
135
136        let n = self.entries.len();
137        debug!(rows = n, "finalising sort index");
138
139        let mut ids: Vec<u32> = self.entries.iter().map(|(_, id)| *id).collect();
140        let entries = &self.entries;
141
142        #[cfg(feature = "parallel")]
143        {
144            use rayon::prelude::*;
145            ids.par_sort_unstable_by(|&a_id, &b_id| {
146                entries[a_id as usize].0.cmp(&entries[b_id as usize].0)
147            });
148        }
149        #[cfg(not(feature = "parallel"))]
150        {
151            ids.sort_unstable_by(|&a_id, &b_id| {
152                entries[a_id as usize].0.cmp(&entries[b_id as usize].0)
153            });
154        }
155
156        let storage = make_storage(ids, n as u64)?;
157        Ok(PermutationIndex::from_storage(storage))
158    }
159}
160
161/// Route to mmap or in-memory storage based on row count and feature flags.
162fn make_storage(ids: Vec<u32>, n: u64) -> Result<PermutationStorage, IndexError> {
163    #[cfg(feature = "mmap")]
164    if n > crate::mmap_builder::MMAP_THRESHOLD {
165        return crate::mmap_builder::write_mmap(&ids);
166    }
167    let _ = n;
168    Ok(PermutationStorage::InMemory(UInt32Array::from(ids)))
169}