1pub mod aggregate;
4pub mod filter;
5pub mod join;
6pub mod scan;
7pub mod sort;
8
9use crate::error::{QueryError, Result};
10use crate::parser::ast::*;
11use aggregate::{Aggregate, AggregateFunc, AggregateFunction};
12use filter::Filter;
13use join::Join;
14use scan::{DataSource, RecordBatch, TableScan};
15use sort::Sort;
16use std::collections::HashMap;
17use std::sync::Arc;
18
19pub struct Executor {
21 data_sources: HashMap<String, Arc<dyn DataSource>>,
23}
24
25impl Executor {
26 pub fn new() -> Self {
28 Self {
29 data_sources: HashMap::new(),
30 }
31 }
32
33 pub fn register_data_source(&mut self, name: String, source: Arc<dyn DataSource>) {
35 self.data_sources.insert(name, source);
36 }
37
38 pub async fn execute(&self, stmt: &Statement) -> Result<Vec<RecordBatch>> {
40 match stmt {
41 Statement::Select(select) => self.execute_select(select).await,
42 }
43 }
44
45 async fn execute_select(&self, select: &SelectStatement) -> Result<Vec<RecordBatch>> {
47 let mut batches = if let Some(ref table_ref) = select.from {
49 self.execute_table_reference(table_ref).await?
50 } else {
51 return Err(QueryError::semantic("SELECT without FROM not supported"));
52 };
53
54 if let Some(ref selection) = select.selection {
56 batches = self.execute_filter(batches, selection)?;
57 }
58
59 if !select.group_by.is_empty() || self.has_aggregates(&select.projection) {
61 batches = self.execute_aggregate(batches, select)?;
62 }
63
64 if !select.order_by.is_empty() {
66 batches = self.execute_sort(batches, &select.order_by)?;
67 }
68
69 if select.limit.is_some() || select.offset.is_some() {
71 batches = self.execute_limit_offset(batches, select.limit, select.offset)?;
72 }
73
74 Ok(batches)
75 }
76
77 async fn execute_table_reference(
79 &self,
80 table_ref: &TableReference,
81 ) -> Result<Vec<RecordBatch>> {
82 match table_ref {
83 TableReference::Table { name, .. } => {
84 let source = self
85 .data_sources
86 .get(name)
87 .ok_or_else(|| QueryError::TableNotFound(name.clone()))?;
88
89 let scan = TableScan::new(name.clone(), source.clone());
90 scan.execute().await
91 }
92 TableReference::Join {
93 left,
94 right,
95 join_type,
96 on,
97 } => {
98 let left_batches = Box::pin(self.execute_table_reference(left)).await?;
100 let right_batches = Box::pin(self.execute_table_reference(right)).await?;
101
102 let join = Join::new(*join_type, on.clone());
103 let mut result = Vec::new();
104
105 for left_batch in &left_batches {
106 for right_batch in &right_batches {
107 result.push(join.execute(left_batch, right_batch)?);
108 }
109 }
110
111 Ok(result)
112 }
113 TableReference::Subquery { query, .. } => Box::pin(self.execute_select(query)).await,
114 }
115 }
116
117 fn execute_filter(
119 &self,
120 batches: Vec<RecordBatch>,
121 predicate: &Expr,
122 ) -> Result<Vec<RecordBatch>> {
123 let filter = Filter::new(predicate.clone());
124 let mut result = Vec::new();
125
126 for batch in batches {
127 result.push(filter.execute(&batch)?);
128 }
129
130 Ok(result)
131 }
132
133 fn execute_aggregate(
135 &self,
136 batches: Vec<RecordBatch>,
137 select: &SelectStatement,
138 ) -> Result<Vec<RecordBatch>> {
139 let mut agg_funcs = Vec::new();
141
142 for item in &select.projection {
143 if let SelectItem::Expr { expr, alias } = item {
144 if let Some(agg_func) = self.extract_aggregate(expr) {
145 let func_alias = alias.clone().or_else(|| Some("agg".to_string()));
146 agg_funcs.push(AggregateFunction {
147 func: agg_func.0,
148 column: agg_func.1,
149 alias: func_alias,
150 });
151 }
152 }
153 }
154
155 let aggregate = Aggregate::new(select.group_by.clone(), agg_funcs);
156 let mut result = Vec::new();
157
158 for batch in batches {
159 result.push(aggregate.execute(&batch)?);
160 }
161
162 Ok(result)
163 }
164
165 fn extract_aggregate(&self, expr: &Expr) -> Option<(AggregateFunc, String)> {
167 if let Expr::Function { name, args } = expr {
168 let func = match name.to_uppercase().as_str() {
169 "COUNT" => Some(AggregateFunc::Count),
170 "SUM" => Some(AggregateFunc::Sum),
171 "AVG" => Some(AggregateFunc::Avg),
172 "MIN" => Some(AggregateFunc::Min),
173 "MAX" => Some(AggregateFunc::Max),
174 _ => None,
175 }?;
176
177 if let Some(arg) = args.first() {
178 match arg {
179 Expr::Column { name, .. } => {
180 return Some((func, name.clone()));
181 }
182 Expr::Wildcard => {
183 return Some((func, "*".to_string()));
185 }
186 _ => {}
187 }
188 } else if matches!(func, AggregateFunc::Count) {
189 return Some((func, "*".to_string()));
191 }
192 }
193 None
194 }
195
196 fn has_aggregates(&self, projection: &[SelectItem]) -> bool {
198 for item in projection {
199 if let SelectItem::Expr { expr, .. } = item {
200 if self.extract_aggregate(expr).is_some() {
201 return true;
202 }
203 }
204 }
205 false
206 }
207
208 fn execute_sort(
210 &self,
211 batches: Vec<RecordBatch>,
212 order_by: &[OrderByExpr],
213 ) -> Result<Vec<RecordBatch>> {
214 let sort = Sort::new(order_by.to_vec());
215 let mut result = Vec::new();
216
217 for batch in batches {
218 result.push(sort.execute(&batch)?);
219 }
220
221 Ok(result)
222 }
223
224 fn execute_limit_offset(
226 &self,
227 batches: Vec<RecordBatch>,
228 limit: Option<usize>,
229 offset: Option<usize>,
230 ) -> Result<Vec<RecordBatch>> {
231 let offset = offset.unwrap_or(0);
232 let mut current_row = 0;
233 let mut result = Vec::new();
234 let mut remaining = limit;
235
236 for batch in batches {
237 if let Some(rem) = remaining {
238 if rem == 0 {
239 break;
240 }
241 }
242
243 let start = if current_row < offset {
244 let skip = (offset - current_row).min(batch.num_rows);
245 current_row += skip;
246 skip
247 } else {
248 0
249 };
250
251 let end = if let Some(rem) = remaining {
252 (start + rem).min(batch.num_rows)
253 } else {
254 batch.num_rows
255 };
256
257 if start < end {
258 let slice_batch = self.slice_batch(&batch, start, end)?;
259 let slice_rows = slice_batch.num_rows;
260 result.push(slice_batch);
261
262 if let Some(rem) = &mut remaining {
263 *rem = rem.saturating_sub(slice_rows);
264 }
265 }
266
267 current_row += batch.num_rows;
268 }
269
270 Ok(result)
271 }
272
273 fn slice_batch(&self, batch: &RecordBatch, start: usize, end: usize) -> Result<RecordBatch> {
275 let mut sliced_columns = Vec::new();
276
277 for column in &batch.columns {
278 sliced_columns.push(self.slice_column(column, start, end));
279 }
280
281 RecordBatch::new(batch.schema.clone(), sliced_columns, end - start)
282 }
283
284 fn slice_column(
286 &self,
287 column: &scan::ColumnData,
288 start: usize,
289 end: usize,
290 ) -> scan::ColumnData {
291 use scan::ColumnData;
292
293 match column {
294 ColumnData::Boolean(data) => ColumnData::Boolean(data[start..end].to_vec()),
295 ColumnData::Int32(data) => ColumnData::Int32(data[start..end].to_vec()),
296 ColumnData::Int64(data) => ColumnData::Int64(data[start..end].to_vec()),
297 ColumnData::Float32(data) => ColumnData::Float32(data[start..end].to_vec()),
298 ColumnData::Float64(data) => ColumnData::Float64(data[start..end].to_vec()),
299 ColumnData::String(data) => ColumnData::String(data[start..end].to_vec()),
300 ColumnData::Binary(data) => ColumnData::Binary(data[start..end].to_vec()),
301 }
302 }
303}
304
305impl Default for Executor {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::scan::{DataType, Field, MemoryDataSource, Schema};
    use crate::parser::sql::parse_sql;

    /// Builds an executor exposing one table, `test_table`, with three rows and two
    /// `Int64` columns (`id`, `value`).
    fn executor_with_test_table() -> Result<Executor> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("value".to_string(), DataType::Int64, false),
        ]));

        let columns = vec![
            scan::ColumnData::Int64(vec![Some(1), Some(2), Some(3)]),
            scan::ColumnData::Int64(vec![Some(10), Some(20), Some(30)]),
        ];

        let batch = RecordBatch::new(schema.clone(), columns, 3)?;
        let source = Arc::new(MemoryDataSource::new(schema, vec![batch]));

        let mut executor = Executor::new();
        executor.register_data_source("test_table".to_string(), source);
        Ok(executor)
    }

    #[tokio::test]
    async fn test_executor_simple_query() -> Result<()> {
        let executor = executor_with_test_table()?;
        let stmt = parse_sql("SELECT * FROM test_table")?;

        let result = executor.execute(&stmt).await?;
        assert!(!result.is_empty());

        // Every source row comes back, and no projection narrows the columns.
        let total_rows: usize = result.iter().map(|batch| batch.num_rows).sum();
        assert_eq!(total_rows, 3);
        assert!(result.iter().all(|batch| batch.columns.len() == 2));

        Ok(())
    }

    #[tokio::test]
    async fn test_executor_unknown_table() -> Result<()> {
        let executor = executor_with_test_table()?;
        let stmt = parse_sql("SELECT * FROM missing_table")?;

        match executor.execute(&stmt).await {
            Err(QueryError::TableNotFound(name)) => assert_eq!(name, "missing_table"),
            Err(_) => panic!("expected QueryError::TableNotFound"),
            Ok(_) => panic!("expected an error for an unregistered table"),
        }

        Ok(())
    }
}
343}