// oxigdal_query/executor/mod.rs
1//! Query execution engine.
2
3pub mod aggregate;
4pub mod filter;
5pub mod join;
6pub mod scan;
7pub mod sort;
8
9use crate::error::{QueryError, Result};
10use crate::parser::ast::*;
11use aggregate::{Aggregate, AggregateFunc, AggregateFunction};
12use filter::Filter;
13use join::Join;
14use scan::{DataSource, RecordBatch, TableScan};
15use sort::Sort;
16use std::collections::HashMap;
17use std::sync::Arc;
18
/// Query executor.
///
/// Holds a registry of named data sources and runs parsed statements
/// against them. Register tables with `register_data_source` before
/// calling `execute`.
pub struct Executor {
    /// Data sources registry, keyed by table name as it appears in SQL.
    data_sources: HashMap<String, Arc<dyn DataSource>>,
}
24
25impl Executor {
26    /// Create a new executor.
27    pub fn new() -> Self {
28        Self {
29            data_sources: HashMap::new(),
30        }
31    }
32
33    /// Register a data source.
34    pub fn register_data_source(&mut self, name: String, source: Arc<dyn DataSource>) {
35        self.data_sources.insert(name, source);
36    }
37
38    /// Execute a query.
39    pub async fn execute(&self, stmt: &Statement) -> Result<Vec<RecordBatch>> {
40        match stmt {
41            Statement::Select(select) => self.execute_select(select).await,
42        }
43    }
44
45    /// Execute a SELECT statement.
46    async fn execute_select(&self, select: &SelectStatement) -> Result<Vec<RecordBatch>> {
47        // Execute FROM clause
48        let mut batches = if let Some(ref table_ref) = select.from {
49            self.execute_table_reference(table_ref).await?
50        } else {
51            return Err(QueryError::semantic("SELECT without FROM not supported"));
52        };
53
54        // Execute WHERE clause
55        if let Some(ref selection) = select.selection {
56            batches = self.execute_filter(batches, selection)?;
57        }
58
59        // Execute GROUP BY / aggregation
60        if !select.group_by.is_empty() || self.has_aggregates(&select.projection) {
61            batches = self.execute_aggregate(batches, select)?;
62        }
63
64        // Execute ORDER BY
65        if !select.order_by.is_empty() {
66            batches = self.execute_sort(batches, &select.order_by)?;
67        }
68
69        // Execute LIMIT and OFFSET
70        if select.limit.is_some() || select.offset.is_some() {
71            batches = self.execute_limit_offset(batches, select.limit, select.offset)?;
72        }
73
74        Ok(batches)
75    }
76
77    /// Execute a table reference.
78    async fn execute_table_reference(
79        &self,
80        table_ref: &TableReference,
81    ) -> Result<Vec<RecordBatch>> {
82        match table_ref {
83            TableReference::Table { name, .. } => {
84                let source = self
85                    .data_sources
86                    .get(name)
87                    .ok_or_else(|| QueryError::TableNotFound(name.clone()))?;
88
89                let scan = TableScan::new(name.clone(), source.clone());
90                scan.execute().await
91            }
92            TableReference::Join {
93                left,
94                right,
95                join_type,
96                on,
97            } => {
98                // Use Box::pin to avoid infinite size for recursive async fn
99                let left_batches = Box::pin(self.execute_table_reference(left)).await?;
100                let right_batches = Box::pin(self.execute_table_reference(right)).await?;
101
102                let join = Join::new(*join_type, on.clone());
103                let mut result = Vec::new();
104
105                for left_batch in &left_batches {
106                    for right_batch in &right_batches {
107                        result.push(join.execute(left_batch, right_batch)?);
108                    }
109                }
110
111                Ok(result)
112            }
113            TableReference::Subquery { query, .. } => Box::pin(self.execute_select(query)).await,
114        }
115    }
116
117    /// Execute filter operation.
118    fn execute_filter(
119        &self,
120        batches: Vec<RecordBatch>,
121        predicate: &Expr,
122    ) -> Result<Vec<RecordBatch>> {
123        let filter = Filter::new(predicate.clone());
124        let mut result = Vec::new();
125
126        for batch in batches {
127            result.push(filter.execute(&batch)?);
128        }
129
130        Ok(result)
131    }
132
133    /// Execute aggregation.
134    fn execute_aggregate(
135        &self,
136        batches: Vec<RecordBatch>,
137        select: &SelectStatement,
138    ) -> Result<Vec<RecordBatch>> {
139        // Extract aggregate functions from projection
140        let mut agg_funcs = Vec::new();
141
142        for item in &select.projection {
143            if let SelectItem::Expr { expr, alias } = item {
144                if let Some(agg_func) = self.extract_aggregate(expr) {
145                    let func_alias = alias.clone().or_else(|| Some("agg".to_string()));
146                    agg_funcs.push(AggregateFunction {
147                        func: agg_func.0,
148                        column: agg_func.1,
149                        alias: func_alias,
150                    });
151                }
152            }
153        }
154
155        let aggregate = Aggregate::new(select.group_by.clone(), agg_funcs);
156        let mut result = Vec::new();
157
158        for batch in batches {
159            result.push(aggregate.execute(&batch)?);
160        }
161
162        Ok(result)
163    }
164
165    /// Extract aggregate function from expression.
166    fn extract_aggregate(&self, expr: &Expr) -> Option<(AggregateFunc, String)> {
167        if let Expr::Function { name, args } = expr {
168            let func = match name.to_uppercase().as_str() {
169                "COUNT" => Some(AggregateFunc::Count),
170                "SUM" => Some(AggregateFunc::Sum),
171                "AVG" => Some(AggregateFunc::Avg),
172                "MIN" => Some(AggregateFunc::Min),
173                "MAX" => Some(AggregateFunc::Max),
174                _ => None,
175            }?;
176
177            if let Some(arg) = args.first() {
178                match arg {
179                    Expr::Column { name, .. } => {
180                        return Some((func, name.clone()));
181                    }
182                    Expr::Wildcard => {
183                        // COUNT(*) uses any column
184                        return Some((func, "*".to_string()));
185                    }
186                    _ => {}
187                }
188            } else if matches!(func, AggregateFunc::Count) {
189                // COUNT(*) with no args
190                return Some((func, "*".to_string()));
191            }
192        }
193        None
194    }
195
196    /// Check if projection has aggregates.
197    fn has_aggregates(&self, projection: &[SelectItem]) -> bool {
198        for item in projection {
199            if let SelectItem::Expr { expr, .. } = item {
200                if self.extract_aggregate(expr).is_some() {
201                    return true;
202                }
203            }
204        }
205        false
206    }
207
208    /// Execute sort operation.
209    fn execute_sort(
210        &self,
211        batches: Vec<RecordBatch>,
212        order_by: &[OrderByExpr],
213    ) -> Result<Vec<RecordBatch>> {
214        let sort = Sort::new(order_by.to_vec());
215        let mut result = Vec::new();
216
217        for batch in batches {
218            result.push(sort.execute(&batch)?);
219        }
220
221        Ok(result)
222    }
223
224    /// Execute LIMIT and OFFSET.
225    fn execute_limit_offset(
226        &self,
227        batches: Vec<RecordBatch>,
228        limit: Option<usize>,
229        offset: Option<usize>,
230    ) -> Result<Vec<RecordBatch>> {
231        let offset = offset.unwrap_or(0);
232        let mut current_row = 0;
233        let mut result = Vec::new();
234        let mut remaining = limit;
235
236        for batch in batches {
237            if let Some(rem) = remaining {
238                if rem == 0 {
239                    break;
240                }
241            }
242
243            let start = if current_row < offset {
244                let skip = (offset - current_row).min(batch.num_rows);
245                current_row += skip;
246                skip
247            } else {
248                0
249            };
250
251            let end = if let Some(rem) = remaining {
252                (start + rem).min(batch.num_rows)
253            } else {
254                batch.num_rows
255            };
256
257            if start < end {
258                let slice_batch = self.slice_batch(&batch, start, end)?;
259                let slice_rows = slice_batch.num_rows;
260                result.push(slice_batch);
261
262                if let Some(rem) = &mut remaining {
263                    *rem = rem.saturating_sub(slice_rows);
264                }
265            }
266
267            current_row += batch.num_rows;
268        }
269
270        Ok(result)
271    }
272
273    /// Slice a record batch.
274    fn slice_batch(&self, batch: &RecordBatch, start: usize, end: usize) -> Result<RecordBatch> {
275        let mut sliced_columns = Vec::new();
276
277        for column in &batch.columns {
278            sliced_columns.push(self.slice_column(column, start, end));
279        }
280
281        RecordBatch::new(batch.schema.clone(), sliced_columns, end - start)
282    }
283
284    /// Slice a column.
285    fn slice_column(
286        &self,
287        column: &scan::ColumnData,
288        start: usize,
289        end: usize,
290    ) -> scan::ColumnData {
291        use scan::ColumnData;
292
293        match column {
294            ColumnData::Boolean(data) => ColumnData::Boolean(data[start..end].to_vec()),
295            ColumnData::Int32(data) => ColumnData::Int32(data[start..end].to_vec()),
296            ColumnData::Int64(data) => ColumnData::Int64(data[start..end].to_vec()),
297            ColumnData::Float32(data) => ColumnData::Float32(data[start..end].to_vec()),
298            ColumnData::Float64(data) => ColumnData::Float64(data[start..end].to_vec()),
299            ColumnData::String(data) => ColumnData::String(data[start..end].to_vec()),
300            ColumnData::Binary(data) => ColumnData::Binary(data[start..end].to_vec()),
301        }
302    }
303}
304
305impl Default for Executor {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::scan::{DataType, Field, MemoryDataSource, Schema};
    use crate::parser::sql::parse_sql;

    /// End-to-end smoke test: register an in-memory table and run a
    /// `SELECT *` through the full executor pipeline.
    #[tokio::test]
    async fn test_executor_simple_query() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("value".to_string(), DataType::Int64, false),
        ]));

        let data = vec![
            scan::ColumnData::Int64(vec![Some(1), Some(2), Some(3)]),
            scan::ColumnData::Int64(vec![Some(10), Some(20), Some(30)]),
        ];

        let input = RecordBatch::new(schema.clone(), data, 3)?;
        let source = Arc::new(MemoryDataSource::new(schema, vec![input]));

        let mut exec = Executor::new();
        exec.register_data_source("test_table".to_string(), source);

        let stmt = parse_sql("SELECT * FROM test_table")?;
        let batches = exec.execute(&stmt).await?;

        assert!(!batches.is_empty());
        Ok(())
    }
}
343}