1pub mod aggregate;
4pub mod filter;
5pub mod join;
6pub mod scan;
7pub mod sort;
8pub mod spatial_funcs;
9
10pub use spatial_funcs::evaluate_spatial_function;
11
12pub mod window;
13
14pub use window::{OrderKey, WindowFunction, WindowSpec, evaluate_window, evaluate_window_batch};
15
16use crate::error::{QueryError, Result};
17use crate::parser::ast::*;
18use aggregate::{Aggregate, AggregateFunc, AggregateFunction};
19use filter::Filter;
20use join::Join;
21use scan::{DataSource, RecordBatch, TableScan};
22use sort::Sort;
23use std::collections::HashMap;
24use std::sync::Arc;
25
26pub struct Executor {
28 data_sources: HashMap<String, Arc<dyn DataSource>>,
30}
31
32impl Executor {
33 pub fn new() -> Self {
35 Self {
36 data_sources: HashMap::new(),
37 }
38 }
39
40 pub fn register_data_source(&mut self, name: String, source: Arc<dyn DataSource>) {
42 self.data_sources.insert(name, source);
43 }
44
45 pub async fn execute(&self, stmt: &Statement) -> Result<Vec<RecordBatch>> {
47 match stmt {
48 Statement::Select(select) => self.execute_select(select).await,
49 }
50 }
51
52 async fn execute_select(&self, select: &SelectStatement) -> Result<Vec<RecordBatch>> {
54 let mut batches = if let Some(ref table_ref) = select.from {
56 self.execute_table_reference(table_ref).await?
57 } else {
58 return Err(QueryError::semantic("SELECT without FROM not supported"));
59 };
60
61 if let Some(ref selection) = select.selection {
63 batches = self.execute_filter(batches, selection)?;
64 }
65
66 if !select.group_by.is_empty() || self.has_aggregates(&select.projection) {
68 batches = self.execute_aggregate(batches, select)?;
69 }
70
71 if !select.order_by.is_empty() {
73 batches = self.execute_sort(batches, &select.order_by)?;
74 }
75
76 if select.limit.is_some() || select.offset.is_some() {
78 batches = self.execute_limit_offset(batches, select.limit, select.offset)?;
79 }
80
81 Ok(batches)
82 }
83
84 async fn execute_table_reference(
86 &self,
87 table_ref: &TableReference,
88 ) -> Result<Vec<RecordBatch>> {
89 match table_ref {
90 TableReference::Table { name, .. } => {
91 let source = self
92 .data_sources
93 .get(name)
94 .ok_or_else(|| QueryError::TableNotFound(name.clone()))?;
95
96 let scan = TableScan::new(name.clone(), source.clone());
97 scan.execute().await
98 }
99 TableReference::Join {
100 left,
101 right,
102 join_type,
103 on,
104 } => {
105 let left_batches = Box::pin(self.execute_table_reference(left)).await?;
107 let right_batches = Box::pin(self.execute_table_reference(right)).await?;
108
109 let join = Join::new(*join_type, on.clone());
110 let mut result = Vec::new();
111
112 for left_batch in &left_batches {
113 for right_batch in &right_batches {
114 result.push(join.execute(left_batch, right_batch)?);
115 }
116 }
117
118 Ok(result)
119 }
120 TableReference::Subquery { query, .. } => Box::pin(self.execute_select(query)).await,
121 }
122 }
123
124 fn execute_filter(
126 &self,
127 batches: Vec<RecordBatch>,
128 predicate: &Expr,
129 ) -> Result<Vec<RecordBatch>> {
130 let filter = Filter::new(predicate.clone());
131 let mut result = Vec::new();
132
133 for batch in batches {
134 result.push(filter.execute(&batch)?);
135 }
136
137 Ok(result)
138 }
139
140 fn execute_aggregate(
142 &self,
143 batches: Vec<RecordBatch>,
144 select: &SelectStatement,
145 ) -> Result<Vec<RecordBatch>> {
146 let mut agg_funcs = Vec::new();
148
149 for item in &select.projection {
150 if let SelectItem::Expr { expr, alias } = item {
151 if let Some(agg_func) = self.extract_aggregate(expr) {
152 let func_alias = alias.clone().or_else(|| Some("agg".to_string()));
153 agg_funcs.push(AggregateFunction {
154 func: agg_func.0,
155 column: agg_func.1,
156 alias: func_alias,
157 });
158 }
159 }
160 }
161
162 let aggregate = Aggregate::new(select.group_by.clone(), agg_funcs);
163 let mut result = Vec::new();
164
165 for batch in batches {
166 result.push(aggregate.execute(&batch)?);
167 }
168
169 Ok(result)
170 }
171
172 fn extract_aggregate(&self, expr: &Expr) -> Option<(AggregateFunc, String)> {
174 if let Expr::Function { name, args } = expr {
175 let func = match name.to_uppercase().as_str() {
176 "COUNT" => Some(AggregateFunc::Count),
177 "SUM" => Some(AggregateFunc::Sum),
178 "AVG" => Some(AggregateFunc::Avg),
179 "MIN" => Some(AggregateFunc::Min),
180 "MAX" => Some(AggregateFunc::Max),
181 _ => None,
182 }?;
183
184 if let Some(arg) = args.first() {
185 match arg {
186 Expr::Column { name, .. } => {
187 return Some((func, name.clone()));
188 }
189 Expr::Wildcard => {
190 return Some((func, "*".to_string()));
192 }
193 _ => {}
194 }
195 } else if matches!(func, AggregateFunc::Count) {
196 return Some((func, "*".to_string()));
198 }
199 }
200 None
201 }
202
203 fn has_aggregates(&self, projection: &[SelectItem]) -> bool {
205 for item in projection {
206 if let SelectItem::Expr { expr, .. } = item {
207 if self.extract_aggregate(expr).is_some() {
208 return true;
209 }
210 }
211 }
212 false
213 }
214
215 fn execute_sort(
217 &self,
218 batches: Vec<RecordBatch>,
219 order_by: &[OrderByExpr],
220 ) -> Result<Vec<RecordBatch>> {
221 let sort = Sort::new(order_by.to_vec());
222 let mut result = Vec::new();
223
224 for batch in batches {
225 result.push(sort.execute(&batch)?);
226 }
227
228 Ok(result)
229 }
230
231 fn execute_limit_offset(
233 &self,
234 batches: Vec<RecordBatch>,
235 limit: Option<usize>,
236 offset: Option<usize>,
237 ) -> Result<Vec<RecordBatch>> {
238 let offset = offset.unwrap_or(0);
239 let mut current_row = 0;
240 let mut result = Vec::new();
241 let mut remaining = limit;
242
243 for batch in batches {
244 if let Some(rem) = remaining {
245 if rem == 0 {
246 break;
247 }
248 }
249
250 let start = if current_row < offset {
251 let skip = (offset - current_row).min(batch.num_rows);
252 current_row += skip;
253 skip
254 } else {
255 0
256 };
257
258 let end = if let Some(rem) = remaining {
259 (start + rem).min(batch.num_rows)
260 } else {
261 batch.num_rows
262 };
263
264 if start < end {
265 let slice_batch = self.slice_batch(&batch, start, end)?;
266 let slice_rows = slice_batch.num_rows;
267 result.push(slice_batch);
268
269 if let Some(rem) = &mut remaining {
270 *rem = rem.saturating_sub(slice_rows);
271 }
272 }
273
274 current_row += batch.num_rows;
275 }
276
277 Ok(result)
278 }
279
280 fn slice_batch(&self, batch: &RecordBatch, start: usize, end: usize) -> Result<RecordBatch> {
282 let mut sliced_columns = Vec::new();
283
284 for column in &batch.columns {
285 sliced_columns.push(self.slice_column(column, start, end));
286 }
287
288 RecordBatch::new(batch.schema.clone(), sliced_columns, end - start)
289 }
290
291 fn slice_column(
293 &self,
294 column: &scan::ColumnData,
295 start: usize,
296 end: usize,
297 ) -> scan::ColumnData {
298 use scan::ColumnData;
299
300 match column {
301 ColumnData::Boolean(data) => ColumnData::Boolean(data[start..end].to_vec()),
302 ColumnData::Int32(data) => ColumnData::Int32(data[start..end].to_vec()),
303 ColumnData::Int64(data) => ColumnData::Int64(data[start..end].to_vec()),
304 ColumnData::Float32(data) => ColumnData::Float32(data[start..end].to_vec()),
305 ColumnData::Float64(data) => ColumnData::Float64(data[start..end].to_vec()),
306 ColumnData::String(data) => ColumnData::String(data[start..end].to_vec()),
307 ColumnData::Binary(data) => ColumnData::Binary(data[start..end].to_vec()),
308 }
309 }
310}
311
312impl Default for Executor {
313 fn default() -> Self {
314 Self::new()
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::scan::{DataType, Field, MemoryDataSource, Schema};
    use crate::parser::sql::parse_sql;

    /// `SELECT *` over a registered in-memory table returns all of the table's
    /// rows and columns, not just a non-empty result.
    #[tokio::test]
    async fn test_executor_simple_query() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("value".to_string(), DataType::Int64, false),
        ]));

        let columns = vec![
            scan::ColumnData::Int64(vec![Some(1), Some(2), Some(3)]),
            scan::ColumnData::Int64(vec![Some(10), Some(20), Some(30)]),
        ];

        let batch = RecordBatch::new(schema.clone(), columns, 3)?;
        let source = Arc::new(MemoryDataSource::new(schema, vec![batch]));

        let mut executor = Executor::new();
        executor.register_data_source("test_table".to_string(), source);

        let sql = "SELECT * FROM test_table";
        let stmt = parse_sql(sql)?;

        let result = executor.execute(&stmt).await?;
        assert!(!result.is_empty());

        // Every input row comes back exactly once...
        let total_rows: usize = result.iter().map(|batch| batch.num_rows).sum();
        assert_eq!(total_rows, 3);
        // ...with both columns present.
        assert!(result.iter().all(|batch| batch.columns.len() == 2));

        Ok(())
    }
}