kyu_executor/operators/
sort.rs1use kyu_common::KyuResult;
4use kyu_expression::{BoundExpression, evaluate};
5use kyu_parser::ast::SortOrder;
6use kyu_types::TypedValue;
7
8use crate::context::ExecutionContext;
9use crate::data_chunk::DataChunk;
10use crate::physical_plan::PhysicalOperator;
11
12pub struct OrderByOp {
13 pub child: Box<PhysicalOperator>,
14 pub order_by: Vec<(BoundExpression, SortOrder)>,
15 result: Option<DataChunk>,
16}
17
18impl OrderByOp {
19 pub fn new(child: PhysicalOperator, order_by: Vec<(BoundExpression, SortOrder)>) -> Self {
20 Self {
21 child: Box::new(child),
22 order_by,
23 result: None,
24 }
25 }
26
27 pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
28 if self.result.is_some() {
29 return Ok(None);
30 }
31
32 let mut all_chunks: Vec<DataChunk> = Vec::new();
34 while let Some(chunk) = self.child.next(ctx)? {
35 all_chunks.push(chunk);
36 }
37
38 let total_rows: usize = all_chunks.iter().map(|c| c.num_rows()).sum();
39 if total_rows == 0 {
40 self.result = Some(DataChunk::empty(0));
41 return Ok(None);
42 }
43
44 let num_cols = all_chunks[0].num_columns();
45
46 let num_keys = self.order_by.len();
48 let mut row_locs: Vec<(usize, usize)> = Vec::with_capacity(total_rows);
49 let mut sort_keys: Vec<Vec<TypedValue>> = Vec::with_capacity(total_rows);
50 for (ci, chunk) in all_chunks.iter().enumerate() {
51 for row_idx in 0..chunk.num_rows() {
52 row_locs.push((ci, row_idx));
53 let row_ref = chunk.row_ref(row_idx);
54 let keys: Vec<TypedValue> = self
55 .order_by
56 .iter()
57 .map(|(expr, _)| evaluate(expr, &row_ref))
58 .collect::<KyuResult<_>>()?;
59 sort_keys.push(keys);
60 }
61 }
62
63 let mut indices: Vec<usize> = (0..total_rows).collect();
65 let order_specs: Vec<SortOrder> = self.order_by.iter().map(|(_, order)| *order).collect();
66 indices.sort_by(|&a, &b| {
67 #[allow(clippy::needless_range_loop)]
68 for i in 0..num_keys {
69 let cmp = compare_values(&sort_keys[a][i], &sort_keys[b][i]);
70 let cmp = match order_specs.get(i) {
71 Some(SortOrder::Descending) => cmp.reverse(),
72 _ => cmp,
73 };
74 if cmp != std::cmp::Ordering::Equal {
75 return cmp;
76 }
77 }
78 std::cmp::Ordering::Equal
79 });
80
81 let mut result_chunk = DataChunk::with_capacity(num_cols, total_rows);
83 for &idx in &indices {
84 let (ci, ri) = row_locs[idx];
85 result_chunk.append_row_from_chunk(&all_chunks[ci], ri);
86 }
87
88 self.result = Some(DataChunk::empty(0));
89 Ok(Some(result_chunk))
90 }
91}
92
93fn compare_values(a: &TypedValue, b: &TypedValue) -> std::cmp::Ordering {
94 match (a, b) {
95 (TypedValue::Null, TypedValue::Null) => std::cmp::Ordering::Equal,
96 (TypedValue::Null, _) => std::cmp::Ordering::Greater, (_, TypedValue::Null) => std::cmp::Ordering::Less,
98 (TypedValue::Int64(a), TypedValue::Int64(b)) => a.cmp(b),
99 (TypedValue::Int32(a), TypedValue::Int32(b)) => a.cmp(b),
100 (TypedValue::Double(a), TypedValue::Double(b)) => {
101 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
102 }
103 (TypedValue::Float(a), TypedValue::Float(b)) => {
104 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
105 }
106 (TypedValue::String(a), TypedValue::String(b)) => a.cmp(b),
107 (TypedValue::Bool(a), TypedValue::Bool(b)) => a.cmp(b),
108 _ => std::cmp::Ordering::Equal,
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use kyu_types::LogicalType;

    /// Scans a one-column `Int64` table holding `values`, sorts it on that
    /// column in `order`, and returns the output column from top to bottom.
    fn sort_single_column(values: &[i64], order: SortOrder) -> Vec<TypedValue> {
        let rows: Vec<Vec<TypedValue>> = values
            .iter()
            .map(|&v| vec![TypedValue::Int64(v)])
            .collect();

        let mut storage = MockStorage::new();
        storage.insert_table(kyu_common::id::TableId(0), rows);
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);

        let scan = PhysicalOperator::ScanNode(crate::operators::scan::ScanNodeOp::new(
            kyu_common::id::TableId(0),
        ));
        let key = BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Int64,
        };
        let mut sort = OrderByOp::new(scan, vec![(key, order)]);

        let chunk = sort.next(&ctx).unwrap().unwrap();
        (0..values.len()).map(|row| chunk.get_value(row, 0)).collect()
    }

    #[test]
    fn sort_ascending() {
        assert_eq!(
            sort_single_column(&[30, 10, 20], SortOrder::Ascending),
            vec![
                TypedValue::Int64(10),
                TypedValue::Int64(20),
                TypedValue::Int64(30),
            ],
        );
    }

    #[test]
    fn sort_descending() {
        assert_eq!(
            sort_single_column(&[10, 30, 20], SortOrder::Descending),
            vec![
                TypedValue::Int64(30),
                TypedValue::Int64(20),
                TypedValue::Int64(10),
            ],
        );
    }
}