1use std::cmp::Ordering;
2use std::collections::HashSet;
3use std::sync::Arc;
4
5use arrow::array::{Array, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array};
6use arrow::array::{Float32Array, Float64Array, UInt32Builder};
7use arrow::array::{StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array};
8use arrow::datatypes::{DataType, Schema};
9use arrow::record_batch::RecordBatch;
10
11use crate::ops::SortOptions;
12use crate::{DataFrameError, Result};
13
14#[derive(Clone)]
15struct RowKey {
16 index: usize,
17 values: Vec<Option<SortValue>>,
18}
19
/// A single non-null cell value, normalised to a small set of comparable
/// representations so that every supported column type can share one
/// comparison routine.
#[derive(Clone, Debug, PartialEq)]
enum SortValue {
    /// Value read from a boolean column.
    Boolean(bool),
    /// Value read from any signed integer column, widened to `i128`.
    Signed(i128),
    /// Value read from any unsigned integer column, widened to `u128`.
    Unsigned(u128),
    /// Value read from a floating-point column (`f32` is widened to `f64`).
    Float64(f64),
    /// Owned copy of a value read from a UTF-8 string column.
    Utf8(String),
}
28
29pub fn sort_batches(input: Vec<RecordBatch>, options: &SortOptions) -> Result<Vec<RecordBatch>> {
30 let batch = concat_batches(&input)?;
31 if batch.num_rows() == 0 {
32 return Ok(vec![batch]);
33 }
34 if options.by.is_empty() {
35 return Err(DataFrameError::invalid_operation("sort requires columns"));
36 }
37 if options.by.len() != options.descending.len() {
38 return Err(DataFrameError::invalid_operation(
39 "descending length must match sort columns",
40 ));
41 }
42
43 let columns = build_sort_columns(&batch, &options.by)?;
44
45 let mut keys = Vec::with_capacity(batch.num_rows());
46 for row in 0..batch.num_rows() {
47 let mut values = Vec::with_capacity(columns.len());
48 for col in &columns {
49 values.push(col.value(row)?);
50 }
51 keys.push(RowKey { index: row, values });
52 }
53
54 keys.sort_by(|a, b| compare_keys(a, b, &options.descending));
55
56 let index_array = build_indices(keys.iter().map(|k| k.index))?;
57 let mut arrays = Vec::with_capacity(batch.num_columns());
58 for col in batch.columns() {
59 let array = arrow::compute::take(col.as_ref(), &index_array, None)
60 .map_err(|source| DataFrameError::Arrow { source })?;
61 arrays.push(array);
62 }
63
64 let batch = RecordBatch::try_new(batch.schema(), arrays).map_err(|e| {
65 DataFrameError::schema_mismatch(format!("failed to build RecordBatch: {e}"))
66 })?;
67 Ok(vec![batch])
68}
69
70pub fn slice_batches(
71 input: Vec<RecordBatch>,
72 offset: usize,
73 len: usize,
74 from_end: bool,
75) -> Result<Vec<RecordBatch>> {
76 let batch = concat_batches(&input)?;
77 let total = batch.num_rows();
78 if total == 0 || len == 0 {
79 return Ok(vec![batch.slice(0, 0)]);
80 }
81
82 let start = if from_end {
83 total.saturating_sub(offset + len)
84 } else {
85 offset
86 };
87 if start >= total {
88 return Ok(vec![batch.slice(0, 0)]);
89 }
90 let end = std::cmp::min(start + len, total);
91 Ok(vec![batch.slice(start, end - start)])
92}
93
94fn concat_batches(batches: &[RecordBatch]) -> Result<RecordBatch> {
95 if batches.is_empty() {
96 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
97 }
98 let schema = batches[0].schema();
99 if batches.len() == 1 {
100 return Ok(batches[0].clone());
101 }
102 arrow::compute::concat_batches(&schema, batches)
103 .map_err(|source| DataFrameError::Arrow { source })
104}
105
106struct SortColumn {
107 name: String,
108 data: SortColumnData,
109}
110
111enum SortColumnData {
112 Boolean(Arc<BooleanArray>),
113 Int8(Arc<Int8Array>),
114 Int16(Arc<Int16Array>),
115 Int32(Arc<Int32Array>),
116 Int64(Arc<Int64Array>),
117 UInt8(Arc<UInt8Array>),
118 UInt16(Arc<UInt16Array>),
119 UInt32(Arc<UInt32Array>),
120 UInt64(Arc<UInt64Array>),
121 Float32(Arc<Float32Array>),
122 Float64(Arc<Float64Array>),
123 Utf8(Arc<StringArray>),
124}
125
126impl SortColumn {
127 fn value(&self, row: usize) -> Result<Option<SortValue>> {
128 match &self.data {
129 SortColumnData::Boolean(array) => {
130 if array.is_null(row) {
131 Ok(None)
132 } else {
133 Ok(Some(SortValue::Boolean(array.value(row))))
134 }
135 }
136 SortColumnData::Int8(array) => {
137 if array.is_null(row) {
138 Ok(None)
139 } else {
140 Ok(Some(SortValue::Signed(array.value(row) as i128)))
141 }
142 }
143 SortColumnData::Int16(array) => {
144 if array.is_null(row) {
145 Ok(None)
146 } else {
147 Ok(Some(SortValue::Signed(array.value(row) as i128)))
148 }
149 }
150 SortColumnData::Int32(array) => {
151 if array.is_null(row) {
152 Ok(None)
153 } else {
154 Ok(Some(SortValue::Signed(array.value(row) as i128)))
155 }
156 }
157 SortColumnData::Int64(array) => {
158 if array.is_null(row) {
159 Ok(None)
160 } else {
161 Ok(Some(SortValue::Signed(array.value(row) as i128)))
162 }
163 }
164 SortColumnData::UInt8(array) => {
165 if array.is_null(row) {
166 Ok(None)
167 } else {
168 Ok(Some(SortValue::Unsigned(array.value(row) as u128)))
169 }
170 }
171 SortColumnData::UInt16(array) => {
172 if array.is_null(row) {
173 Ok(None)
174 } else {
175 Ok(Some(SortValue::Unsigned(array.value(row) as u128)))
176 }
177 }
178 SortColumnData::UInt32(array) => {
179 if array.is_null(row) {
180 Ok(None)
181 } else {
182 Ok(Some(SortValue::Unsigned(array.value(row) as u128)))
183 }
184 }
185 SortColumnData::UInt64(array) => {
186 if array.is_null(row) {
187 Ok(None)
188 } else {
189 Ok(Some(SortValue::Unsigned(array.value(row) as u128)))
190 }
191 }
192 SortColumnData::Float32(array) => {
193 if array.is_null(row) {
194 Ok(None)
195 } else {
196 Ok(Some(SortValue::Float64(array.value(row) as f64)))
197 }
198 }
199 SortColumnData::Float64(array) => {
200 if array.is_null(row) {
201 Ok(None)
202 } else {
203 Ok(Some(SortValue::Float64(array.value(row))))
204 }
205 }
206 SortColumnData::Utf8(array) => {
207 if array.is_null(row) {
208 Ok(None)
209 } else {
210 Ok(Some(SortValue::Utf8(array.value(row).to_string())))
211 }
212 }
213 }
214 }
215}
216
217fn build_sort_columns(batch: &RecordBatch, by: &[String]) -> Result<Vec<SortColumn>> {
218 let mut columns = Vec::with_capacity(by.len());
219
220 for name in by {
221 let idx = batch
222 .schema()
223 .fields()
224 .iter()
225 .position(|f| f.name() == name)
226 .ok_or_else(|| DataFrameError::column_not_found(name.clone()))?;
227 let array = batch.column(idx);
228 let data = match array.data_type() {
229 DataType::Boolean => SortColumnData::Boolean(Arc::new(
230 array
231 .as_any()
232 .downcast_ref::<BooleanArray>()
233 .ok_or_else(|| DataFrameError::invalid_operation("bad BooleanArray"))?
234 .clone(),
235 )),
236 DataType::Int8 => SortColumnData::Int8(Arc::new(
237 array
238 .as_any()
239 .downcast_ref::<Int8Array>()
240 .ok_or_else(|| DataFrameError::invalid_operation("bad Int8Array"))?
241 .clone(),
242 )),
243 DataType::Int16 => SortColumnData::Int16(Arc::new(
244 array
245 .as_any()
246 .downcast_ref::<Int16Array>()
247 .ok_or_else(|| DataFrameError::invalid_operation("bad Int16Array"))?
248 .clone(),
249 )),
250 DataType::Int32 => SortColumnData::Int32(Arc::new(
251 array
252 .as_any()
253 .downcast_ref::<Int32Array>()
254 .ok_or_else(|| DataFrameError::invalid_operation("bad Int32Array"))?
255 .clone(),
256 )),
257 DataType::Int64 => SortColumnData::Int64(Arc::new(
258 array
259 .as_any()
260 .downcast_ref::<Int64Array>()
261 .ok_or_else(|| DataFrameError::invalid_operation("bad Int64Array"))?
262 .clone(),
263 )),
264 DataType::UInt8 => SortColumnData::UInt8(Arc::new(
265 array
266 .as_any()
267 .downcast_ref::<UInt8Array>()
268 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt8Array"))?
269 .clone(),
270 )),
271 DataType::UInt16 => SortColumnData::UInt16(Arc::new(
272 array
273 .as_any()
274 .downcast_ref::<UInt16Array>()
275 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt16Array"))?
276 .clone(),
277 )),
278 DataType::UInt32 => SortColumnData::UInt32(Arc::new(
279 array
280 .as_any()
281 .downcast_ref::<UInt32Array>()
282 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt32Array"))?
283 .clone(),
284 )),
285 DataType::UInt64 => SortColumnData::UInt64(Arc::new(
286 array
287 .as_any()
288 .downcast_ref::<UInt64Array>()
289 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt64Array"))?
290 .clone(),
291 )),
292 DataType::Float32 => SortColumnData::Float32(Arc::new(
293 array
294 .as_any()
295 .downcast_ref::<Float32Array>()
296 .ok_or_else(|| DataFrameError::invalid_operation("bad Float32Array"))?
297 .clone(),
298 )),
299 DataType::Float64 => SortColumnData::Float64(Arc::new(
300 array
301 .as_any()
302 .downcast_ref::<Float64Array>()
303 .ok_or_else(|| DataFrameError::invalid_operation("bad Float64Array"))?
304 .clone(),
305 )),
306 DataType::Utf8 => SortColumnData::Utf8(Arc::new(
307 array
308 .as_any()
309 .downcast_ref::<StringArray>()
310 .ok_or_else(|| DataFrameError::invalid_operation("bad StringArray"))?
311 .clone(),
312 )),
313 other => {
314 return Err(DataFrameError::invalid_operation(format!(
315 "unsupported sort type {other:?}",
316 )))
317 }
318 };
319
320 columns.push(SortColumn {
321 name: name.clone(),
322 data,
323 });
324 }
325
326 let mut seen = HashSet::new();
327 for col in &columns {
328 if !seen.insert(col.name.clone()) {
329 return Err(DataFrameError::invalid_operation("duplicate sort column"));
330 }
331 }
332
333 Ok(columns)
334}
335
336fn compare_keys(a: &RowKey, b: &RowKey, descending: &[bool]) -> Ordering {
337 for (idx, (av, bv)) in a.values.iter().zip(b.values.iter()).enumerate() {
338 match (av, bv) {
339 (None, None) => continue,
340 (None, Some(_)) => return Ordering::Greater,
341 (Some(_), None) => return Ordering::Less,
342 (Some(av), Some(bv)) => {
343 let mut ord = compare_value(av, bv);
344 if descending[idx] {
345 ord = ord.reverse();
346 }
347 if ord != Ordering::Equal {
348 return ord;
349 }
350 }
351 }
352 }
353 Ordering::Equal
354}
355
356fn compare_value(a: &SortValue, b: &SortValue) -> Ordering {
357 match (a, b) {
358 (SortValue::Boolean(a), SortValue::Boolean(b)) => a.cmp(b),
359 (SortValue::Signed(a), SortValue::Signed(b)) => a.cmp(b),
360 (SortValue::Unsigned(a), SortValue::Unsigned(b)) => a.cmp(b),
361 (SortValue::Float64(a), SortValue::Float64(b)) => a.total_cmp(b),
362 (SortValue::Utf8(a), SortValue::Utf8(b)) => a.cmp(b),
363 _ => Ordering::Equal,
364 }
365}
366
367fn build_indices<I>(indices: I) -> Result<arrow::array::UInt32Array>
368where
369 I: IntoIterator<Item = usize>,
370{
371 let iter = indices.into_iter();
372 let (lower, _) = iter.size_hint();
373 let mut builder = UInt32Builder::with_capacity(lower);
374 for idx in iter {
375 let value = u32::try_from(idx)
376 .map_err(|_| DataFrameError::invalid_operation("row index exceeds u32 range"))?;
377 builder.append_value(value);
378 }
379 Ok(builder.finish())
380}