Skip to main content

kyu_executor/operators/
sort.rs

1//! OrderBy operator — materializes all rows, sorts, then emits.
2
3use kyu_common::KyuResult;
4use kyu_expression::{BoundExpression, evaluate};
5use kyu_parser::ast::SortOrder;
6use kyu_types::TypedValue;
7
8use crate::context::ExecutionContext;
9use crate::data_chunk::DataChunk;
10use crate::physical_plan::PhysicalOperator;
11
12pub struct OrderByOp {
13    pub child: Box<PhysicalOperator>,
14    pub order_by: Vec<(BoundExpression, SortOrder)>,
15    result: Option<DataChunk>,
16}
17
18impl OrderByOp {
19    pub fn new(child: PhysicalOperator, order_by: Vec<(BoundExpression, SortOrder)>) -> Self {
20        Self {
21            child: Box::new(child),
22            order_by,
23            result: None,
24        }
25    }
26
27    pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
28        if self.result.is_some() {
29            return Ok(None);
30        }
31
32        // Collect all child chunks without merging (avoids ensure_owned).
33        let mut all_chunks: Vec<DataChunk> = Vec::new();
34        while let Some(chunk) = self.child.next(ctx)? {
35            all_chunks.push(chunk);
36        }
37
38        let total_rows: usize = all_chunks.iter().map(|c| c.num_rows()).sum();
39        if total_rows == 0 {
40            self.result = Some(DataChunk::empty(0));
41            return Ok(None);
42        }
43
44        let num_cols = all_chunks[0].num_columns();
45
46        // Build a (chunk_idx, local_row) index + evaluate sort keys.
47        let num_keys = self.order_by.len();
48        let mut row_locs: Vec<(usize, usize)> = Vec::with_capacity(total_rows);
49        let mut sort_keys: Vec<Vec<TypedValue>> = Vec::with_capacity(total_rows);
50        for (ci, chunk) in all_chunks.iter().enumerate() {
51            for row_idx in 0..chunk.num_rows() {
52                row_locs.push((ci, row_idx));
53                let row_ref = chunk.row_ref(row_idx);
54                let keys: Vec<TypedValue> = self
55                    .order_by
56                    .iter()
57                    .map(|(expr, _)| evaluate(expr, &row_ref))
58                    .collect::<KyuResult<_>>()?;
59                sort_keys.push(keys);
60            }
61        }
62
63        // Sort a permutation array — swaps move 8-byte indices, not full rows.
64        let mut indices: Vec<usize> = (0..total_rows).collect();
65        let order_specs: Vec<SortOrder> = self.order_by.iter().map(|(_, order)| *order).collect();
66        indices.sort_by(|&a, &b| {
67            #[allow(clippy::needless_range_loop)]
68            for i in 0..num_keys {
69                let cmp = compare_values(&sort_keys[a][i], &sort_keys[b][i]);
70                let cmp = match order_specs.get(i) {
71                    Some(SortOrder::Descending) => cmp.reverse(),
72                    _ => cmp,
73                };
74                if cmp != std::cmp::Ordering::Equal {
75                    return cmp;
76                }
77            }
78            std::cmp::Ordering::Equal
79        });
80
81        // Build output by gathering from source chunks via (chunk_idx, local_row).
82        let mut result_chunk = DataChunk::with_capacity(num_cols, total_rows);
83        for &idx in &indices {
84            let (ci, ri) = row_locs[idx];
85            result_chunk.append_row_from_chunk(&all_chunks[ci], ri);
86        }
87
88        self.result = Some(DataChunk::empty(0));
89        Ok(Some(result_chunk))
90    }
91}
92
93fn compare_values(a: &TypedValue, b: &TypedValue) -> std::cmp::Ordering {
94    match (a, b) {
95        (TypedValue::Null, TypedValue::Null) => std::cmp::Ordering::Equal,
96        (TypedValue::Null, _) => std::cmp::Ordering::Greater, // NULLs sort last
97        (_, TypedValue::Null) => std::cmp::Ordering::Less,
98        (TypedValue::Int64(a), TypedValue::Int64(b)) => a.cmp(b),
99        (TypedValue::Int32(a), TypedValue::Int32(b)) => a.cmp(b),
100        (TypedValue::Double(a), TypedValue::Double(b)) => {
101            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
102        }
103        (TypedValue::Float(a), TypedValue::Float(b)) => {
104            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
105        }
106        (TypedValue::String(a), TypedValue::String(b)) => a.cmp(b),
107        (TypedValue::Bool(a), TypedValue::Bool(b)) => a.cmp(b),
108        _ => std::cmp::Ordering::Equal,
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use kyu_types::LogicalType;

    /// Stores `values` as a one-column Int64 table, runs an `OrderByOp` over a
    /// scan of that table sorting on column 0 in `order`, and returns column 0
    /// of the first three output rows.
    fn sort_single_column(values: [i64; 3], order: SortOrder) -> Vec<TypedValue> {
        let rows: Vec<Vec<TypedValue>> = values
            .iter()
            .map(|&v| vec![TypedValue::Int64(v)])
            .collect();

        let mut storage = MockStorage::new();
        storage.insert_table(kyu_common::id::TableId(0), rows);
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);

        let scan = PhysicalOperator::ScanNode(crate::operators::scan::ScanNodeOp::new(
            kyu_common::id::TableId(0),
        ));
        let key = BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Int64,
        };
        let mut sort = OrderByOp::new(scan, vec![(key, order)]);

        let chunk = sort.next(&ctx).unwrap().unwrap();
        (0..3).map(|row| chunk.get_value(row, 0)).collect()
    }

    #[test]
    fn sort_ascending() {
        let sorted = sort_single_column([30, 10, 20], SortOrder::Ascending);
        assert_eq!(
            sorted,
            vec![
                TypedValue::Int64(10),
                TypedValue::Int64(20),
                TypedValue::Int64(30),
            ]
        );
    }

    #[test]
    fn sort_descending() {
        let sorted = sort_single_column([10, 30, 20], SortOrder::Descending);
        assert_eq!(
            sorted,
            vec![
                TypedValue::Int64(30),
                TypedValue::Int64(20),
                TypedValue::Int64(10),
            ]
        );
    }
}