1use crate::error::Result;
4use crate::executor::scan::{ColumnData, RecordBatch};
5use crate::parser::ast::{Expr, OrderByExpr};
6use std::cmp::Ordering;
7
8pub struct Sort {
10 pub order_by: Vec<OrderByExpr>,
12}
13
14impl Sort {
15 pub fn new(order_by: Vec<OrderByExpr>) -> Self {
17 Self { order_by }
18 }
19
20 pub fn execute(&self, batch: &RecordBatch) -> Result<RecordBatch> {
22 if self.order_by.is_empty() {
23 return Ok(batch.clone());
24 }
25
26 let mut indices: Vec<usize> = (0..batch.num_rows).collect();
28
29 indices.sort_by(|&a, &b| self.compare_rows(batch, a, b));
31
32 let mut sorted_columns = Vec::new();
34 for column in &batch.columns {
35 sorted_columns.push(self.reorder_column(column, &indices));
36 }
37
38 RecordBatch::new(batch.schema.clone(), sorted_columns, batch.num_rows)
39 }
40
41 fn compare_rows(&self, batch: &RecordBatch, a: usize, b: usize) -> Ordering {
43 for order in &self.order_by {
44 let ordering = self.compare_values(batch, &order.expr, a, b);
45 let ordering = if order.asc {
46 ordering
47 } else {
48 ordering.reverse()
49 };
50
51 if ordering != Ordering::Equal {
52 return ordering;
53 }
54 }
55 Ordering::Equal
56 }
57
58 fn compare_values(&self, batch: &RecordBatch, expr: &Expr, a: usize, b: usize) -> Ordering {
60 if let Expr::Column { name, .. } = expr {
62 if let Some(column) = batch.column_by_name(name) {
63 return self.compare_column_values(column, a, b);
64 }
65 }
66 Ordering::Equal
67 }
68
69 fn compare_column_values(&self, column: &ColumnData, a: usize, b: usize) -> Ordering {
71 match column {
72 ColumnData::Boolean(data) => Self::compare_optional(&data[a], &data[b]),
73 ColumnData::Int32(data) => Self::compare_optional(&data[a], &data[b]),
74 ColumnData::Int64(data) => Self::compare_optional(&data[a], &data[b]),
75 ColumnData::Float32(data) => {
76 let val_a = data[a];
77 let val_b = data[b];
78 match (val_a, val_b) {
79 (Some(a), Some(b)) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
80 (Some(_), None) => Ordering::Less,
81 (None, Some(_)) => Ordering::Greater,
82 (None, None) => Ordering::Equal,
83 }
84 }
85 ColumnData::Float64(data) => {
86 let val_a = data[a];
87 let val_b = data[b];
88 match (val_a, val_b) {
89 (Some(a), Some(b)) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
90 (Some(_), None) => Ordering::Less,
91 (None, Some(_)) => Ordering::Greater,
92 (None, None) => Ordering::Equal,
93 }
94 }
95 ColumnData::String(data) => Self::compare_optional(&data[a], &data[b]),
96 ColumnData::Binary(_) => Ordering::Equal,
97 }
98 }
99
100 fn compare_optional<T: Ord>(a: &Option<T>, b: &Option<T>) -> Ordering {
102 match (a, b) {
103 (Some(a), Some(b)) => a.cmp(b),
104 (Some(_), None) => Ordering::Less,
105 (None, Some(_)) => Ordering::Greater,
106 (None, None) => Ordering::Equal,
107 }
108 }
109
110 fn reorder_column(&self, column: &ColumnData, indices: &[usize]) -> ColumnData {
112 match column {
113 ColumnData::Boolean(data) => {
114 let reordered = indices.iter().map(|&i| data[i]).collect();
115 ColumnData::Boolean(reordered)
116 }
117 ColumnData::Int32(data) => {
118 let reordered = indices.iter().map(|&i| data[i]).collect();
119 ColumnData::Int32(reordered)
120 }
121 ColumnData::Int64(data) => {
122 let reordered = indices.iter().map(|&i| data[i]).collect();
123 ColumnData::Int64(reordered)
124 }
125 ColumnData::Float32(data) => {
126 let reordered = indices.iter().map(|&i| data[i]).collect();
127 ColumnData::Float32(reordered)
128 }
129 ColumnData::Float64(data) => {
130 let reordered = indices.iter().map(|&i| data[i]).collect();
131 ColumnData::Float64(reordered)
132 }
133 ColumnData::String(data) => {
134 let reordered = indices.iter().map(|&i| data[i].clone()).collect();
135 ColumnData::String(reordered)
136 }
137 ColumnData::Binary(data) => {
138 let reordered = indices.iter().map(|&i| data[i].clone()).collect();
139 ColumnData::Binary(reordered)
140 }
141 }
142 }
143}
144
#[cfg(test)]
// The tests use `panic!` to fail loudly when a column has an unexpected type.
#[allow(clippy::panic)]
mod tests {
    use super::*;
    use crate::executor::scan::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a five-row batch with Int64 columns `id` = [3, 1, 4, 1, 5] and
    /// `value` = [30, 10, 40, 15, 50]; the duplicated `id` 1 exercises tie handling.
    fn sample_batch() -> Result<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("value".to_string(), DataType::Int64, false),
        ]));
        let columns = vec![
            ColumnData::Int64(vec![Some(3), Some(1), Some(4), Some(1), Some(5)]),
            ColumnData::Int64(vec![Some(30), Some(10), Some(40), Some(15), Some(50)]),
        ];
        RecordBatch::new(schema, columns, 5)
    }

    /// An ORDER BY key on the unqualified column `name`, with `nulls_first` set to false.
    fn column_key(name: &str, asc: bool) -> OrderByExpr {
        OrderByExpr {
            expr: Expr::Column {
                table: None,
                name: name.to_string(),
            },
            asc,
            nulls_first: false,
        }
    }

    /// Returns the values of the Int64 column at position `index`, panicking on any other type.
    fn int64_column(batch: &RecordBatch, index: usize) -> &[Option<i64>] {
        let ColumnData::Int64(data) = &batch.columns[index] else {
            panic!("Expected Int64 column");
        };
        data
    }

    #[test]
    fn test_sort_execution() -> Result<()> {
        let sorted = Sort::new(vec![column_key("id", true)]).execute(&sample_batch()?)?;

        assert_eq!(int64_column(&sorted, 0), [Some(1), Some(1), Some(3), Some(4), Some(5)]);
        // Stable sort: the two rows with id 1 keep their input order.
        assert_eq!(int64_column(&sorted, 1), [Some(10), Some(15), Some(30), Some(40), Some(50)]);
        Ok(())
    }

    #[test]
    fn test_sort_descending() -> Result<()> {
        let sorted = Sort::new(vec![column_key("id", false)]).execute(&sample_batch()?)?;

        assert_eq!(int64_column(&sorted, 0), [Some(5), Some(4), Some(3), Some(1), Some(1)]);
        assert_eq!(int64_column(&sorted, 1), [Some(50), Some(40), Some(30), Some(10), Some(15)]);
        Ok(())
    }

    #[test]
    fn test_sort_secondary_key_breaks_ties() -> Result<()> {
        let order_by = vec![column_key("id", true), column_key("value", false)];
        let sorted = Sort::new(order_by).execute(&sample_batch()?)?;

        assert_eq!(int64_column(&sorted, 0), [Some(1), Some(1), Some(3), Some(4), Some(5)]);
        assert_eq!(int64_column(&sorted, 1), [Some(15), Some(10), Some(30), Some(40), Some(50)]);
        Ok(())
    }

    #[test]
    fn test_sort_without_keys_returns_input_order() -> Result<()> {
        let batch = sample_batch()?;
        let sorted = Sort::new(Vec::new()).execute(&batch)?;

        assert_eq!(int64_column(&sorted, 0), int64_column(&batch, 0));
        assert_eq!(int64_column(&sorted, 1), int64_column(&batch, 1));
        Ok(())
    }
}