Skip to main content

alopex_dataframe/ops/
sort.rs

1use std::cmp::Ordering;
2use std::collections::HashSet;
3use std::sync::Arc;
4
5use arrow::array::{Array, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array};
6use arrow::array::{Float32Array, Float64Array, UInt32Builder};
7use arrow::array::{StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array};
8use arrow::datatypes::{DataType, Schema};
9use arrow::record_batch::RecordBatch;
10
11use crate::ops::SortOptions;
12use crate::{DataFrameError, Result};
13
14#[derive(Clone)]
15struct RowKey {
16    index: usize,
17    values: Vec<Option<SortValue>>,
18}
19
/// A single non-null cell value, widened so that every supported column type
/// is represented by one of five variants.
#[derive(Debug, Clone, PartialEq)]
enum SortValue {
    /// Value read from a boolean column.
    Boolean(bool),
    /// Value read from a signed integer column, widened to `i128`.
    Signed(i128),
    /// Value read from an unsigned integer column, widened to `u128`.
    Unsigned(u128),
    /// Value read from a floating-point column, widened to `f64`.
    Float64(f64),
    /// Owned copy of a value read from a UTF-8 string column.
    Utf8(String),
}
28
29pub fn sort_batches(input: Vec<RecordBatch>, options: &SortOptions) -> Result<Vec<RecordBatch>> {
30    let batch = concat_batches(&input)?;
31    if batch.num_rows() == 0 {
32        return Ok(vec![batch]);
33    }
34    if options.by.is_empty() {
35        return Err(DataFrameError::invalid_operation("sort requires columns"));
36    }
37    if options.by.len() != options.descending.len() {
38        return Err(DataFrameError::invalid_operation(
39            "descending length must match sort columns",
40        ));
41    }
42
43    let columns = build_sort_columns(&batch, &options.by)?;
44
45    let mut keys = Vec::with_capacity(batch.num_rows());
46    for row in 0..batch.num_rows() {
47        let mut values = Vec::with_capacity(columns.len());
48        for col in &columns {
49            values.push(col.value(row)?);
50        }
51        keys.push(RowKey { index: row, values });
52    }
53
54    keys.sort_by(|a, b| compare_keys(a, b, &options.descending));
55
56    let index_array = build_indices(keys.iter().map(|k| k.index))?;
57    let mut arrays = Vec::with_capacity(batch.num_columns());
58    for col in batch.columns() {
59        let array = arrow::compute::take(col.as_ref(), &index_array, None)
60            .map_err(|source| DataFrameError::Arrow { source })?;
61        arrays.push(array);
62    }
63
64    let batch = RecordBatch::try_new(batch.schema(), arrays).map_err(|e| {
65        DataFrameError::schema_mismatch(format!("failed to build RecordBatch: {e}"))
66    })?;
67    Ok(vec![batch])
68}
69
70pub fn slice_batches(
71    input: Vec<RecordBatch>,
72    offset: usize,
73    len: usize,
74    from_end: bool,
75) -> Result<Vec<RecordBatch>> {
76    let batch = concat_batches(&input)?;
77    let total = batch.num_rows();
78    if total == 0 || len == 0 {
79        return Ok(vec![batch.slice(0, 0)]);
80    }
81
82    let start = if from_end {
83        total.saturating_sub(offset + len)
84    } else {
85        offset
86    };
87    if start >= total {
88        return Ok(vec![batch.slice(0, 0)]);
89    }
90    let end = std::cmp::min(start + len, total);
91    Ok(vec![batch.slice(start, end - start)])
92}
93
94fn concat_batches(batches: &[RecordBatch]) -> Result<RecordBatch> {
95    if batches.is_empty() {
96        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
97    }
98    let schema = batches[0].schema();
99    if batches.len() == 1 {
100        return Ok(batches[0].clone());
101    }
102    arrow::compute::concat_batches(&schema, batches)
103        .map_err(|source| DataFrameError::Arrow { source })
104}
105
106struct SortColumn {
107    name: String,
108    data: SortColumnData,
109}
110
111enum SortColumnData {
112    Boolean(Arc<BooleanArray>),
113    Int8(Arc<Int8Array>),
114    Int16(Arc<Int16Array>),
115    Int32(Arc<Int32Array>),
116    Int64(Arc<Int64Array>),
117    UInt8(Arc<UInt8Array>),
118    UInt16(Arc<UInt16Array>),
119    UInt32(Arc<UInt32Array>),
120    UInt64(Arc<UInt64Array>),
121    Float32(Arc<Float32Array>),
122    Float64(Arc<Float64Array>),
123    Utf8(Arc<StringArray>),
124}
125
126impl SortColumn {
127    fn value(&self, row: usize) -> Result<Option<SortValue>> {
128        match &self.data {
129            SortColumnData::Boolean(array) => {
130                if array.is_null(row) {
131                    Ok(None)
132                } else {
133                    Ok(Some(SortValue::Boolean(array.value(row))))
134                }
135            }
136            SortColumnData::Int8(array) => {
137                if array.is_null(row) {
138                    Ok(None)
139                } else {
140                    Ok(Some(SortValue::Signed(array.value(row) as i128)))
141                }
142            }
143            SortColumnData::Int16(array) => {
144                if array.is_null(row) {
145                    Ok(None)
146                } else {
147                    Ok(Some(SortValue::Signed(array.value(row) as i128)))
148                }
149            }
150            SortColumnData::Int32(array) => {
151                if array.is_null(row) {
152                    Ok(None)
153                } else {
154                    Ok(Some(SortValue::Signed(array.value(row) as i128)))
155                }
156            }
157            SortColumnData::Int64(array) => {
158                if array.is_null(row) {
159                    Ok(None)
160                } else {
161                    Ok(Some(SortValue::Signed(array.value(row) as i128)))
162                }
163            }
164            SortColumnData::UInt8(array) => {
165                if array.is_null(row) {
166                    Ok(None)
167                } else {
168                    Ok(Some(SortValue::Unsigned(array.value(row) as u128)))
169                }
170            }
171            SortColumnData::UInt16(array) => {
172                if array.is_null(row) {
173                    Ok(None)
174                } else {
175                    Ok(Some(SortValue::Unsigned(array.value(row) as u128)))
176                }
177            }
178            SortColumnData::UInt32(array) => {
179                if array.is_null(row) {
180                    Ok(None)
181                } else {
182                    Ok(Some(SortValue::Unsigned(array.value(row) as u128)))
183                }
184            }
185            SortColumnData::UInt64(array) => {
186                if array.is_null(row) {
187                    Ok(None)
188                } else {
189                    Ok(Some(SortValue::Unsigned(array.value(row) as u128)))
190                }
191            }
192            SortColumnData::Float32(array) => {
193                if array.is_null(row) {
194                    Ok(None)
195                } else {
196                    Ok(Some(SortValue::Float64(array.value(row) as f64)))
197                }
198            }
199            SortColumnData::Float64(array) => {
200                if array.is_null(row) {
201                    Ok(None)
202                } else {
203                    Ok(Some(SortValue::Float64(array.value(row))))
204                }
205            }
206            SortColumnData::Utf8(array) => {
207                if array.is_null(row) {
208                    Ok(None)
209                } else {
210                    Ok(Some(SortValue::Utf8(array.value(row).to_string())))
211                }
212            }
213        }
214    }
215}
216
217fn build_sort_columns(batch: &RecordBatch, by: &[String]) -> Result<Vec<SortColumn>> {
218    let mut columns = Vec::with_capacity(by.len());
219
220    for name in by {
221        let idx = batch
222            .schema()
223            .fields()
224            .iter()
225            .position(|f| f.name() == name)
226            .ok_or_else(|| DataFrameError::column_not_found(name.clone()))?;
227        let array = batch.column(idx);
228        let data = match array.data_type() {
229            DataType::Boolean => SortColumnData::Boolean(Arc::new(
230                array
231                    .as_any()
232                    .downcast_ref::<BooleanArray>()
233                    .ok_or_else(|| DataFrameError::invalid_operation("bad BooleanArray"))?
234                    .clone(),
235            )),
236            DataType::Int8 => SortColumnData::Int8(Arc::new(
237                array
238                    .as_any()
239                    .downcast_ref::<Int8Array>()
240                    .ok_or_else(|| DataFrameError::invalid_operation("bad Int8Array"))?
241                    .clone(),
242            )),
243            DataType::Int16 => SortColumnData::Int16(Arc::new(
244                array
245                    .as_any()
246                    .downcast_ref::<Int16Array>()
247                    .ok_or_else(|| DataFrameError::invalid_operation("bad Int16Array"))?
248                    .clone(),
249            )),
250            DataType::Int32 => SortColumnData::Int32(Arc::new(
251                array
252                    .as_any()
253                    .downcast_ref::<Int32Array>()
254                    .ok_or_else(|| DataFrameError::invalid_operation("bad Int32Array"))?
255                    .clone(),
256            )),
257            DataType::Int64 => SortColumnData::Int64(Arc::new(
258                array
259                    .as_any()
260                    .downcast_ref::<Int64Array>()
261                    .ok_or_else(|| DataFrameError::invalid_operation("bad Int64Array"))?
262                    .clone(),
263            )),
264            DataType::UInt8 => SortColumnData::UInt8(Arc::new(
265                array
266                    .as_any()
267                    .downcast_ref::<UInt8Array>()
268                    .ok_or_else(|| DataFrameError::invalid_operation("bad UInt8Array"))?
269                    .clone(),
270            )),
271            DataType::UInt16 => SortColumnData::UInt16(Arc::new(
272                array
273                    .as_any()
274                    .downcast_ref::<UInt16Array>()
275                    .ok_or_else(|| DataFrameError::invalid_operation("bad UInt16Array"))?
276                    .clone(),
277            )),
278            DataType::UInt32 => SortColumnData::UInt32(Arc::new(
279                array
280                    .as_any()
281                    .downcast_ref::<UInt32Array>()
282                    .ok_or_else(|| DataFrameError::invalid_operation("bad UInt32Array"))?
283                    .clone(),
284            )),
285            DataType::UInt64 => SortColumnData::UInt64(Arc::new(
286                array
287                    .as_any()
288                    .downcast_ref::<UInt64Array>()
289                    .ok_or_else(|| DataFrameError::invalid_operation("bad UInt64Array"))?
290                    .clone(),
291            )),
292            DataType::Float32 => SortColumnData::Float32(Arc::new(
293                array
294                    .as_any()
295                    .downcast_ref::<Float32Array>()
296                    .ok_or_else(|| DataFrameError::invalid_operation("bad Float32Array"))?
297                    .clone(),
298            )),
299            DataType::Float64 => SortColumnData::Float64(Arc::new(
300                array
301                    .as_any()
302                    .downcast_ref::<Float64Array>()
303                    .ok_or_else(|| DataFrameError::invalid_operation("bad Float64Array"))?
304                    .clone(),
305            )),
306            DataType::Utf8 => SortColumnData::Utf8(Arc::new(
307                array
308                    .as_any()
309                    .downcast_ref::<StringArray>()
310                    .ok_or_else(|| DataFrameError::invalid_operation("bad StringArray"))?
311                    .clone(),
312            )),
313            other => {
314                return Err(DataFrameError::invalid_operation(format!(
315                    "unsupported sort type {other:?}",
316                )))
317            }
318        };
319
320        columns.push(SortColumn {
321            name: name.clone(),
322            data,
323        });
324    }
325
326    let mut seen = HashSet::new();
327    for col in &columns {
328        if !seen.insert(col.name.clone()) {
329            return Err(DataFrameError::invalid_operation("duplicate sort column"));
330        }
331    }
332
333    Ok(columns)
334}
335
336fn compare_keys(a: &RowKey, b: &RowKey, descending: &[bool]) -> Ordering {
337    for (idx, (av, bv)) in a.values.iter().zip(b.values.iter()).enumerate() {
338        match (av, bv) {
339            (None, None) => continue,
340            (None, Some(_)) => return Ordering::Greater,
341            (Some(_), None) => return Ordering::Less,
342            (Some(av), Some(bv)) => {
343                let mut ord = compare_value(av, bv);
344                if descending[idx] {
345                    ord = ord.reverse();
346                }
347                if ord != Ordering::Equal {
348                    return ord;
349                }
350            }
351        }
352    }
353    Ordering::Equal
354}
355
356fn compare_value(a: &SortValue, b: &SortValue) -> Ordering {
357    match (a, b) {
358        (SortValue::Boolean(a), SortValue::Boolean(b)) => a.cmp(b),
359        (SortValue::Signed(a), SortValue::Signed(b)) => a.cmp(b),
360        (SortValue::Unsigned(a), SortValue::Unsigned(b)) => a.cmp(b),
361        (SortValue::Float64(a), SortValue::Float64(b)) => a.total_cmp(b),
362        (SortValue::Utf8(a), SortValue::Utf8(b)) => a.cmp(b),
363        _ => Ordering::Equal,
364    }
365}
366
367fn build_indices<I>(indices: I) -> Result<arrow::array::UInt32Array>
368where
369    I: IntoIterator<Item = usize>,
370{
371    let iter = indices.into_iter();
372    let (lower, _) = iter.size_hint();
373    let mut builder = UInt32Builder::with_capacity(lower);
374    for idx in iter {
375        let value = u32::try_from(idx)
376            .map_err(|_| DataFrameError::invalid_operation("row index exceeds u32 range"))?;
377        builder.append_value(value);
378    }
379    Ok(builder.finish())
380}