Skip to main content

oxigdal_query/executor/
aggregate.rs

1//! Aggregation executor.
2
3use crate::error::{QueryError, Result};
4use crate::executor::scan::{ColumnData, Field, RecordBatch, Schema};
5use crate::parser::ast::Expr;
6use std::sync::Arc;
7
8/// Aggregate operator.
9pub struct Aggregate {
10    /// GROUP BY expressions.
11    pub group_by: Vec<Expr>,
12    /// Aggregate functions.
13    pub aggregates: Vec<AggregateFunction>,
14}
15
16impl Aggregate {
17    /// Create a new aggregate operator.
18    pub fn new(group_by: Vec<Expr>, aggregates: Vec<AggregateFunction>) -> Self {
19        Self {
20            group_by,
21            aggregates,
22        }
23    }
24
25    /// Execute aggregation.
26    pub fn execute(&self, batch: &RecordBatch) -> Result<RecordBatch> {
27        if self.group_by.is_empty() {
28            // Global aggregation
29            self.execute_global_aggregate(batch)
30        } else {
31            // Grouped aggregation
32            self.execute_grouped_aggregate(batch)
33        }
34    }
35
36    /// Execute global aggregation (no GROUP BY).
37    fn execute_global_aggregate(&self, batch: &RecordBatch) -> Result<RecordBatch> {
38        let mut result_fields = Vec::new();
39        let mut result_columns = Vec::new();
40
41        for agg in &self.aggregates {
42            let value = if agg.column == "*" {
43                // COUNT(*) - count all rows regardless of NULL values
44                if matches!(agg.func, AggregateFunc::Count) {
45                    Some(batch.num_rows as f64)
46                } else {
47                    return Err(QueryError::semantic(
48                        "Wildcard (*) can only be used with COUNT function",
49                    ));
50                }
51            } else {
52                let column = batch
53                    .column_by_name(&agg.column)
54                    .ok_or_else(|| QueryError::ColumnNotFound(agg.column.clone()))?;
55                self.compute_aggregate(agg.func, column)?
56            };
57
58            result_fields.push(Field::new(
59                agg.alias.clone().unwrap_or_else(|| {
60                    if agg.column == "*" {
61                        "count".to_string()
62                    } else {
63                        agg.column.clone()
64                    }
65                }),
66                crate::executor::scan::DataType::Float64,
67                true,
68            ));
69            result_columns.push(ColumnData::Float64(vec![value]));
70        }
71
72        let schema = Arc::new(Schema::new(result_fields));
73        RecordBatch::new(schema, result_columns, 1)
74    }
75
76    /// Execute grouped aggregation.
77    fn execute_grouped_aggregate(&self, _batch: &RecordBatch) -> Result<RecordBatch> {
78        // Simplified implementation
79        Err(QueryError::unsupported(
80            "Grouped aggregation not implemented",
81        ))
82    }
83
84    /// Compute aggregate function.
85    fn compute_aggregate(&self, func: AggregateFunc, column: &ColumnData) -> Result<Option<f64>> {
86        match func {
87            AggregateFunc::Count => Ok(Some(self.count(column))),
88            AggregateFunc::Sum => self.sum(column),
89            AggregateFunc::Avg => self.avg(column),
90            AggregateFunc::Min => self.min(column),
91            AggregateFunc::Max => self.max(column),
92        }
93    }
94
95    /// Count aggregate.
96    fn count(&self, column: &ColumnData) -> f64 {
97        let non_null_count = match column {
98            ColumnData::Boolean(data) => data.iter().filter(|v| v.is_some()).count(),
99            ColumnData::Int32(data) => data.iter().filter(|v| v.is_some()).count(),
100            ColumnData::Int64(data) => data.iter().filter(|v| v.is_some()).count(),
101            ColumnData::Float32(data) => data.iter().filter(|v| v.is_some()).count(),
102            ColumnData::Float64(data) => data.iter().filter(|v| v.is_some()).count(),
103            ColumnData::String(data) => data.iter().filter(|v| v.is_some()).count(),
104            ColumnData::Binary(data) => data.iter().filter(|v| v.is_some()).count(),
105        };
106        non_null_count as f64
107    }
108
109    /// Sum aggregate.
110    fn sum(&self, column: &ColumnData) -> Result<Option<f64>> {
111        match column {
112            ColumnData::Int32(data) => {
113                let sum: i64 = data.iter().filter_map(|v| v.map(|i| i as i64)).sum();
114                Ok(Some(sum as f64))
115            }
116            ColumnData::Int64(data) => {
117                let sum: i64 = data.iter().filter_map(|v| *v).sum();
118                Ok(Some(sum as f64))
119            }
120            ColumnData::Float32(data) => {
121                let sum: f32 = data.iter().filter_map(|v| *v).sum();
122                Ok(Some(sum as f64))
123            }
124            ColumnData::Float64(data) => {
125                let sum: f64 = data.iter().filter_map(|v| *v).sum();
126                Ok(Some(sum))
127            }
128            _ => Err(QueryError::type_mismatch("numeric", "non-numeric")),
129        }
130    }
131
132    /// Average aggregate.
133    fn avg(&self, column: &ColumnData) -> Result<Option<f64>> {
134        let sum = self.sum(column)?;
135        let count = self.count(column);
136        if count > 0.0 {
137            Ok(sum.map(|s| s / count))
138        } else {
139            Ok(None)
140        }
141    }
142
143    /// Minimum aggregate.
144    fn min(&self, column: &ColumnData) -> Result<Option<f64>> {
145        match column {
146            ColumnData::Int32(data) => {
147                let min = data.iter().filter_map(|v| *v).min();
148                Ok(min.map(|m| m as f64))
149            }
150            ColumnData::Int64(data) => {
151                let min = data.iter().filter_map(|v| *v).min();
152                Ok(min.map(|m| m as f64))
153            }
154            ColumnData::Float32(data) => {
155                let min = data
156                    .iter()
157                    .filter_map(|v| *v)
158                    .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
159                Ok(min.map(|m| m as f64))
160            }
161            ColumnData::Float64(data) => {
162                let min = data
163                    .iter()
164                    .filter_map(|v| *v)
165                    .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
166                Ok(min)
167            }
168            _ => Err(QueryError::type_mismatch("numeric", "non-numeric")),
169        }
170    }
171
172    /// Maximum aggregate.
173    fn max(&self, column: &ColumnData) -> Result<Option<f64>> {
174        match column {
175            ColumnData::Int32(data) => {
176                let max = data.iter().filter_map(|v| *v).max();
177                Ok(max.map(|m| m as f64))
178            }
179            ColumnData::Int64(data) => {
180                let max = data.iter().filter_map(|v| *v).max();
181                Ok(max.map(|m| m as f64))
182            }
183            ColumnData::Float32(data) => {
184                let max = data
185                    .iter()
186                    .filter_map(|v| *v)
187                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
188                Ok(max.map(|m| m as f64))
189            }
190            ColumnData::Float64(data) => {
191                let max = data
192                    .iter()
193                    .filter_map(|v| *v)
194                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
195                Ok(max)
196            }
197            _ => Err(QueryError::type_mismatch("numeric", "non-numeric")),
198        }
199    }
200}
201
202/// Aggregate function.
203#[derive(Debug, Clone)]
204pub struct AggregateFunction {
205    /// Function type.
206    pub func: AggregateFunc,
207    /// Column to aggregate.
208    pub column: String,
209    /// Output alias.
210    pub alias: Option<String>,
211}
212
/// Kind of aggregate function.
///
/// The type is a plain, field-less enum, so it is `Copy` and can be compared
/// and hashed (for example to key a map or de-duplicate a set of functions).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateFunc {
    /// COUNT: number of rows (`*`) or of non-NULL values in a column.
    Count,
    /// SUM: total of the non-NULL values of a numeric column.
    Sum,
    /// AVG: arithmetic mean of the non-NULL values of a numeric column.
    Avg,
    /// MIN: smallest non-NULL value of a numeric column.
    Min,
    /// MAX: largest non-NULL value of a numeric column.
    Max,
}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::scan::DataType;

    /// Five-row batch with a single non-NULL Int64 column named `value`.
    fn sample_batch() -> Result<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value".to_string(),
            DataType::Int64,
            false,
        )]));

        let columns = vec![ColumnData::Int64(vec![
            Some(10),
            Some(20),
            Some(30),
            Some(40),
            Some(50),
        ])];

        RecordBatch::new(schema, columns, 5)
    }

    /// Three-row batch whose nullable Int64 column `value` holds one NULL.
    fn batch_with_null() -> Result<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value".to_string(),
            DataType::Int64,
            true,
        )]));

        let columns = vec![ColumnData::Int64(vec![Some(1), None, Some(3)])];

        RecordBatch::new(schema, columns, 3)
    }

    #[test]
    fn test_global_aggregate() -> Result<()> {
        let batch = sample_batch()?;

        let agg = Aggregate::new(
            vec![],
            vec![
                AggregateFunction {
                    func: AggregateFunc::Sum,
                    column: "value".to_string(),
                    alias: Some("sum".to_string()),
                },
                AggregateFunction {
                    func: AggregateFunc::Avg,
                    column: "value".to_string(),
                    alias: Some("avg".to_string()),
                },
            ],
        );

        let result = agg.execute(&batch)?;
        assert_eq!(result.num_rows, 1);
        assert_eq!(result.columns.len(), 2);

        // Check the computed values, not just the shape of the output.
        assert!(matches!(
            result.columns.first(),
            Some(ColumnData::Float64(v)) if v.first() == Some(&Some(150.0))
        ));
        assert!(matches!(
            result.columns.get(1),
            Some(ColumnData::Float64(v)) if v.first() == Some(&Some(30.0))
        ));

        Ok(())
    }

    #[test]
    fn test_count_star_includes_null_rows() -> Result<()> {
        let batch = batch_with_null()?;

        let agg = Aggregate::new(
            vec![],
            vec![
                AggregateFunction {
                    func: AggregateFunc::Count,
                    column: "*".to_string(),
                    alias: None,
                },
                AggregateFunction {
                    func: AggregateFunc::Count,
                    column: "value".to_string(),
                    alias: Some("non_null".to_string()),
                },
            ],
        );

        let result = agg.execute(&batch)?;
        assert_eq!(result.num_rows, 1);
        assert_eq!(result.columns.len(), 2);

        // COUNT(*) sees every row; COUNT(value) skips the NULL.
        assert!(matches!(
            result.columns.first(),
            Some(ColumnData::Float64(v)) if v.first() == Some(&Some(3.0))
        ));
        assert!(matches!(
            result.columns.get(1),
            Some(ColumnData::Float64(v)) if v.first() == Some(&Some(2.0))
        ));

        Ok(())
    }

    #[test]
    fn test_wildcard_requires_count() -> Result<()> {
        let batch = sample_batch()?;

        let agg = Aggregate::new(
            vec![],
            vec![AggregateFunction {
                func: AggregateFunc::Sum,
                column: "*".to_string(),
                alias: None,
            }],
        );

        assert!(agg.execute(&batch).is_err());

        Ok(())
    }

    #[test]
    fn test_missing_column_is_error() -> Result<()> {
        let batch = sample_batch()?;

        let agg = Aggregate::new(
            vec![],
            vec![AggregateFunction {
                func: AggregateFunc::Max,
                column: "missing".to_string(),
                alias: None,
            }],
        );

        assert!(agg.execute(&batch).is_err());

        Ok(())
    }
}
273}