kyu_executor/operators/
sort.rs1use kyu_common::KyuResult;
4use kyu_expression::{evaluate, BoundExpression};
5use kyu_parser::ast::SortOrder;
6use kyu_types::TypedValue;
7
8use crate::context::ExecutionContext;
9use crate::data_chunk::DataChunk;
10use crate::physical_plan::PhysicalOperator;
11
12pub struct OrderByOp {
13 pub child: Box<PhysicalOperator>,
14 pub order_by: Vec<(BoundExpression, SortOrder)>,
15 result: Option<DataChunk>,
16}
17
18impl OrderByOp {
19 pub fn new(child: PhysicalOperator, order_by: Vec<(BoundExpression, SortOrder)>) -> Self {
20 Self {
21 child: Box::new(child),
22 order_by,
23 result: None,
24 }
25 }
26
27 pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
28 if self.result.is_some() {
29 return Ok(None);
30 }
31
32 let mut all_chunks: Vec<DataChunk> = Vec::new();
34 while let Some(chunk) = self.child.next(ctx)? {
35 all_chunks.push(chunk);
36 }
37
38 let total_rows: usize = all_chunks.iter().map(|c| c.num_rows()).sum();
39 if total_rows == 0 {
40 self.result = Some(DataChunk::empty(0));
41 return Ok(None);
42 }
43
44 let num_cols = all_chunks[0].num_columns();
45
46 let num_keys = self.order_by.len();
48 let mut row_locs: Vec<(usize, usize)> = Vec::with_capacity(total_rows);
49 let mut sort_keys: Vec<Vec<TypedValue>> = Vec::with_capacity(total_rows);
50 for (ci, chunk) in all_chunks.iter().enumerate() {
51 for row_idx in 0..chunk.num_rows() {
52 row_locs.push((ci, row_idx));
53 let row_ref = chunk.row_ref(row_idx);
54 let keys: Vec<TypedValue> = self
55 .order_by
56 .iter()
57 .map(|(expr, _)| evaluate(expr, &row_ref))
58 .collect::<KyuResult<_>>()?;
59 sort_keys.push(keys);
60 }
61 }
62
63 let mut indices: Vec<usize> = (0..total_rows).collect();
65 let order_specs: Vec<SortOrder> =
66 self.order_by.iter().map(|(_, order)| *order).collect();
67 indices.sort_by(|&a, &b| {
68 #[allow(clippy::needless_range_loop)]
69 for i in 0..num_keys {
70 let cmp = compare_values(&sort_keys[a][i], &sort_keys[b][i]);
71 let cmp = match order_specs.get(i) {
72 Some(SortOrder::Descending) => cmp.reverse(),
73 _ => cmp,
74 };
75 if cmp != std::cmp::Ordering::Equal {
76 return cmp;
77 }
78 }
79 std::cmp::Ordering::Equal
80 });
81
82 let mut result_chunk = DataChunk::with_capacity(num_cols, total_rows);
84 for &idx in &indices {
85 let (ci, ri) = row_locs[idx];
86 result_chunk.append_row_from_chunk(&all_chunks[ci], ri);
87 }
88
89 self.result = Some(DataChunk::empty(0));
90 Ok(Some(result_chunk))
91 }
92}
93
94fn compare_values(a: &TypedValue, b: &TypedValue) -> std::cmp::Ordering {
95 match (a, b) {
96 (TypedValue::Null, TypedValue::Null) => std::cmp::Ordering::Equal,
97 (TypedValue::Null, _) => std::cmp::Ordering::Greater, (_, TypedValue::Null) => std::cmp::Ordering::Less,
99 (TypedValue::Int64(a), TypedValue::Int64(b)) => a.cmp(b),
100 (TypedValue::Int32(a), TypedValue::Int32(b)) => a.cmp(b),
101 (TypedValue::Double(a), TypedValue::Double(b)) => a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal),
102 (TypedValue::Float(a), TypedValue::Float(b)) => a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal),
103 (TypedValue::String(a), TypedValue::String(b)) => a.cmp(b),
104 (TypedValue::Bool(a), TypedValue::Bool(b)) => a.cmp(b),
105 _ => std::cmp::Ordering::Equal,
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use kyu_types::LogicalType;

    /// Loads `rows` into table 0 of a mock store, and sorts a scan of that table by
    /// `order_by`. Returns the single output chunk after asserting that the operator
    /// reports exhaustion on the following call.
    fn sort_rows(
        rows: Vec<Vec<TypedValue>>,
        order_by: Vec<(BoundExpression, SortOrder)>,
    ) -> DataChunk {
        let mut storage = MockStorage::new();
        storage.insert_table(kyu_common::id::TableId(0), rows);
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);

        let scan = PhysicalOperator::ScanNode(crate::operators::scan::ScanNodeOp::new(
            kyu_common::id::TableId(0),
        ));
        let mut sort = OrderByOp::new(scan, order_by);
        let chunk = sort
            .next(&ctx)
            .unwrap()
            .expect("non-empty input yields one sorted chunk");
        assert!(
            sort.next(&ctx).unwrap().is_none(),
            "sorted output must be emitted only once"
        );
        chunk
    }

    #[test]
    fn sort_ascending() {
        let chunk = sort_rows(
            vec![
                vec![TypedValue::Int64(30)],
                vec![TypedValue::Int64(10)],
                vec![TypedValue::Int64(20)],
            ],
            vec![(
                BoundExpression::Variable {
                    index: 0,
                    result_type: LogicalType::Int64,
                },
                SortOrder::Ascending,
            )],
        );
        assert_eq!(chunk.get_value(0, 0), TypedValue::Int64(10));
        assert_eq!(chunk.get_value(1, 0), TypedValue::Int64(20));
        assert_eq!(chunk.get_value(2, 0), TypedValue::Int64(30));
    }

    #[test]
    fn sort_descending() {
        let chunk = sort_rows(
            vec![
                vec![TypedValue::Int64(10)],
                vec![TypedValue::Int64(30)],
                vec![TypedValue::Int64(20)],
            ],
            vec![(
                BoundExpression::Variable {
                    index: 0,
                    result_type: LogicalType::Int64,
                },
                SortOrder::Descending,
            )],
        );
        assert_eq!(chunk.get_value(0, 0), TypedValue::Int64(30));
        assert_eq!(chunk.get_value(1, 0), TypedValue::Int64(20));
        assert_eq!(chunk.get_value(2, 0), TypedValue::Int64(10));
    }

    #[test]
    fn nulls_sort_last_ascending_and_first_descending() {
        let rows = || {
            vec![
                vec![TypedValue::Int64(2)],
                vec![TypedValue::Null],
                vec![TypedValue::Int64(1)],
            ]
        };
        let key = |order| {
            vec![(
                BoundExpression::Variable {
                    index: 0,
                    result_type: LogicalType::Int64,
                },
                order,
            )]
        };

        let asc = sort_rows(rows(), key(SortOrder::Ascending));
        assert_eq!(asc.get_value(0, 0), TypedValue::Int64(1));
        assert_eq!(asc.get_value(1, 0), TypedValue::Int64(2));
        assert_eq!(asc.get_value(2, 0), TypedValue::Null);

        let desc = sort_rows(rows(), key(SortOrder::Descending));
        assert_eq!(desc.get_value(0, 0), TypedValue::Null);
        assert_eq!(desc.get_value(1, 0), TypedValue::Int64(2));
        assert_eq!(desc.get_value(2, 0), TypedValue::Int64(1));
    }

    #[test]
    fn later_keys_break_ties_of_earlier_keys() {
        let chunk = sort_rows(
            vec![
                vec![TypedValue::Int64(1), TypedValue::Int64(10)],
                vec![TypedValue::Int64(2), TypedValue::Int64(5)],
                vec![TypedValue::Int64(1), TypedValue::Int64(30)],
                vec![TypedValue::Int64(2), TypedValue::Int64(7)],
            ],
            vec![
                (
                    BoundExpression::Variable {
                        index: 0,
                        result_type: LogicalType::Int64,
                    },
                    SortOrder::Ascending,
                ),
                (
                    BoundExpression::Variable {
                        index: 1,
                        result_type: LogicalType::Int64,
                    },
                    SortOrder::Descending,
                ),
            ],
        );
        assert_eq!(chunk.num_rows(), 4);
        let expected = [(1, 30), (1, 10), (2, 7), (2, 5)];
        for (row, (first, second)) in expected.into_iter().enumerate() {
            assert_eq!(chunk.get_value(row, 0), TypedValue::Int64(first));
            assert_eq!(chunk.get_value(row, 1), TypedValue::Int64(second));
        }
    }

    #[test]
    fn equal_keys_keep_input_order() {
        let chunk = sort_rows(
            vec![
                vec![TypedValue::Int64(1), TypedValue::Int64(100)],
                vec![TypedValue::Int64(0), TypedValue::Int64(200)],
                vec![TypedValue::Int64(1), TypedValue::Int64(300)],
            ],
            vec![(
                BoundExpression::Variable {
                    index: 0,
                    result_type: LogicalType::Int64,
                },
                SortOrder::Ascending,
            )],
        );
        assert_eq!(chunk.get_value(0, 1), TypedValue::Int64(200));
        assert_eq!(chunk.get_value(1, 1), TypedValue::Int64(100));
        assert_eq!(chunk.get_value(2, 1), TypedValue::Int64(300));
    }
}