Skip to main content

kyu_executor/operators/
sort.rs

1//! OrderBy operator — materializes all rows, sorts, then emits.
2
3use kyu_common::KyuResult;
4use kyu_expression::{evaluate, BoundExpression};
5use kyu_parser::ast::SortOrder;
6use kyu_types::TypedValue;
7
8use crate::context::ExecutionContext;
9use crate::data_chunk::DataChunk;
10use crate::physical_plan::PhysicalOperator;
11
/// ORDER BY operator: pulls every row from `child`, sorts them by the
/// `order_by` keys, and emits the sorted rows as a single chunk.
pub struct OrderByOp {
    /// Upstream operator whose output rows are sorted.
    pub child: Box<PhysicalOperator>,
    /// Sort keys in priority order, each paired with its direction.
    pub order_by: Vec<(BoundExpression, SortOrder)>,
    /// "Already ran" marker: starts as `None` and is set to `Some` (an empty
    /// placeholder chunk) once `next` has finished, after which `next`
    /// returns `Ok(None)`.
    result: Option<DataChunk>,
}
17
18impl OrderByOp {
19    pub fn new(child: PhysicalOperator, order_by: Vec<(BoundExpression, SortOrder)>) -> Self {
20        Self {
21            child: Box::new(child),
22            order_by,
23            result: None,
24        }
25    }
26
27    pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
28        if self.result.is_some() {
29            return Ok(None);
30        }
31
32        // Collect all child chunks without merging (avoids ensure_owned).
33        let mut all_chunks: Vec<DataChunk> = Vec::new();
34        while let Some(chunk) = self.child.next(ctx)? {
35            all_chunks.push(chunk);
36        }
37
38        let total_rows: usize = all_chunks.iter().map(|c| c.num_rows()).sum();
39        if total_rows == 0 {
40            self.result = Some(DataChunk::empty(0));
41            return Ok(None);
42        }
43
44        let num_cols = all_chunks[0].num_columns();
45
46        // Build a (chunk_idx, local_row) index + evaluate sort keys.
47        let num_keys = self.order_by.len();
48        let mut row_locs: Vec<(usize, usize)> = Vec::with_capacity(total_rows);
49        let mut sort_keys: Vec<Vec<TypedValue>> = Vec::with_capacity(total_rows);
50        for (ci, chunk) in all_chunks.iter().enumerate() {
51            for row_idx in 0..chunk.num_rows() {
52                row_locs.push((ci, row_idx));
53                let row_ref = chunk.row_ref(row_idx);
54                let keys: Vec<TypedValue> = self
55                    .order_by
56                    .iter()
57                    .map(|(expr, _)| evaluate(expr, &row_ref))
58                    .collect::<KyuResult<_>>()?;
59                sort_keys.push(keys);
60            }
61        }
62
63        // Sort a permutation array — swaps move 8-byte indices, not full rows.
64        let mut indices: Vec<usize> = (0..total_rows).collect();
65        let order_specs: Vec<SortOrder> =
66            self.order_by.iter().map(|(_, order)| *order).collect();
67        indices.sort_by(|&a, &b| {
68            #[allow(clippy::needless_range_loop)]
69            for i in 0..num_keys {
70                let cmp = compare_values(&sort_keys[a][i], &sort_keys[b][i]);
71                let cmp = match order_specs.get(i) {
72                    Some(SortOrder::Descending) => cmp.reverse(),
73                    _ => cmp,
74                };
75                if cmp != std::cmp::Ordering::Equal {
76                    return cmp;
77                }
78            }
79            std::cmp::Ordering::Equal
80        });
81
82        // Build output by gathering from source chunks via (chunk_idx, local_row).
83        let mut result_chunk = DataChunk::with_capacity(num_cols, total_rows);
84        for &idx in &indices {
85            let (ci, ri) = row_locs[idx];
86            result_chunk.append_row_from_chunk(&all_chunks[ci], ri);
87        }
88
89        self.result = Some(DataChunk::empty(0));
90        Ok(Some(result_chunk))
91    }
92}
93
94fn compare_values(a: &TypedValue, b: &TypedValue) -> std::cmp::Ordering {
95    match (a, b) {
96        (TypedValue::Null, TypedValue::Null) => std::cmp::Ordering::Equal,
97        (TypedValue::Null, _) => std::cmp::Ordering::Greater, // NULLs sort last
98        (_, TypedValue::Null) => std::cmp::Ordering::Less,
99        (TypedValue::Int64(a), TypedValue::Int64(b)) => a.cmp(b),
100        (TypedValue::Int32(a), TypedValue::Int32(b)) => a.cmp(b),
101        (TypedValue::Double(a), TypedValue::Double(b)) => a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal),
102        (TypedValue::Float(a), TypedValue::Float(b)) => a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal),
103        (TypedValue::String(a), TypedValue::String(b)) => a.cmp(b),
104        (TypedValue::Bool(a), TypedValue::Bool(b)) => a.cmp(b),
105        _ => std::cmp::Ordering::Equal,
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use kyu_types::LogicalType;

    /// Loads `values` into table 0 as single-column `Int64` rows, sorts that
    /// column in the given direction, and returns the operator's first chunk.
    fn sort_int64_column(values: &[i64], direction: SortOrder) -> DataChunk {
        let rows: Vec<Vec<TypedValue>> = values
            .iter()
            .map(|&value| vec![TypedValue::Int64(value)])
            .collect();

        let mut storage = MockStorage::new();
        storage.insert_table(kyu_common::id::TableId(0), rows);
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);

        let scan = PhysicalOperator::ScanNode(crate::operators::scan::ScanNodeOp::new(
            kyu_common::id::TableId(0),
        ));
        let sort_key = BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Int64,
        };
        let mut sort = OrderByOp::new(scan, vec![(sort_key, direction)]);

        sort.next(&ctx).unwrap().unwrap()
    }

    #[test]
    fn sort_ascending() {
        let chunk = sort_int64_column(&[30, 10, 20], SortOrder::Ascending);
        assert_eq!(chunk.get_value(0, 0), TypedValue::Int64(10));
        assert_eq!(chunk.get_value(1, 0), TypedValue::Int64(20));
        assert_eq!(chunk.get_value(2, 0), TypedValue::Int64(30));
    }

    #[test]
    fn sort_descending() {
        let chunk = sort_int64_column(&[10, 30, 20], SortOrder::Descending);
        assert_eq!(chunk.get_value(0, 0), TypedValue::Int64(30));
        assert_eq!(chunk.get_value(1, 0), TypedValue::Int64(20));
        assert_eq!(chunk.get_value(2, 0), TypedValue::Int64(10));
    }
}