oxigdal_query/executor/
aggregate.rs1use crate::error::{QueryError, Result};
4use crate::executor::scan::{ColumnData, Field, RecordBatch, Schema};
5use crate::parser::ast::Expr;
6use std::sync::Arc;
7
/// Aggregation operator: evaluates aggregate functions over a [`RecordBatch`].
pub struct Aggregate {
    /// Grouping expressions. When empty, `execute` aggregates the whole batch
    /// into a single output row; otherwise grouped aggregation is requested
    /// (currently reported as unsupported).
    pub group_by: Vec<Expr>,
    /// Aggregate calls to evaluate; in global mode each one produces one
    /// output column.
    pub aggregates: Vec<AggregateFunction>,
}
15
16impl Aggregate {
17 pub fn new(group_by: Vec<Expr>, aggregates: Vec<AggregateFunction>) -> Self {
19 Self {
20 group_by,
21 aggregates,
22 }
23 }
24
25 pub fn execute(&self, batch: &RecordBatch) -> Result<RecordBatch> {
27 if self.group_by.is_empty() {
28 self.execute_global_aggregate(batch)
30 } else {
31 self.execute_grouped_aggregate(batch)
33 }
34 }
35
36 fn execute_global_aggregate(&self, batch: &RecordBatch) -> Result<RecordBatch> {
38 let mut result_fields = Vec::new();
39 let mut result_columns = Vec::new();
40
41 for agg in &self.aggregates {
42 let value = if agg.column == "*" {
43 if matches!(agg.func, AggregateFunc::Count) {
45 Some(batch.num_rows as f64)
46 } else {
47 return Err(QueryError::semantic(
48 "Wildcard (*) can only be used with COUNT function",
49 ));
50 }
51 } else {
52 let column = batch
53 .column_by_name(&agg.column)
54 .ok_or_else(|| QueryError::ColumnNotFound(agg.column.clone()))?;
55 self.compute_aggregate(agg.func, column)?
56 };
57
58 result_fields.push(Field::new(
59 agg.alias.clone().unwrap_or_else(|| {
60 if agg.column == "*" {
61 "count".to_string()
62 } else {
63 agg.column.clone()
64 }
65 }),
66 crate::executor::scan::DataType::Float64,
67 true,
68 ));
69 result_columns.push(ColumnData::Float64(vec![value]));
70 }
71
72 let schema = Arc::new(Schema::new(result_fields));
73 RecordBatch::new(schema, result_columns, 1)
74 }
75
76 fn execute_grouped_aggregate(&self, _batch: &RecordBatch) -> Result<RecordBatch> {
78 Err(QueryError::unsupported(
80 "Grouped aggregation not implemented",
81 ))
82 }
83
84 fn compute_aggregate(&self, func: AggregateFunc, column: &ColumnData) -> Result<Option<f64>> {
86 match func {
87 AggregateFunc::Count => Ok(Some(self.count(column))),
88 AggregateFunc::Sum => self.sum(column),
89 AggregateFunc::Avg => self.avg(column),
90 AggregateFunc::Min => self.min(column),
91 AggregateFunc::Max => self.max(column),
92 }
93 }
94
95 fn count(&self, column: &ColumnData) -> f64 {
97 let non_null_count = match column {
98 ColumnData::Boolean(data) => data.iter().filter(|v| v.is_some()).count(),
99 ColumnData::Int32(data) => data.iter().filter(|v| v.is_some()).count(),
100 ColumnData::Int64(data) => data.iter().filter(|v| v.is_some()).count(),
101 ColumnData::Float32(data) => data.iter().filter(|v| v.is_some()).count(),
102 ColumnData::Float64(data) => data.iter().filter(|v| v.is_some()).count(),
103 ColumnData::String(data) => data.iter().filter(|v| v.is_some()).count(),
104 ColumnData::Binary(data) => data.iter().filter(|v| v.is_some()).count(),
105 };
106 non_null_count as f64
107 }
108
109 fn sum(&self, column: &ColumnData) -> Result<Option<f64>> {
111 match column {
112 ColumnData::Int32(data) => {
113 let sum: i64 = data.iter().filter_map(|v| v.map(|i| i as i64)).sum();
114 Ok(Some(sum as f64))
115 }
116 ColumnData::Int64(data) => {
117 let sum: i64 = data.iter().filter_map(|v| *v).sum();
118 Ok(Some(sum as f64))
119 }
120 ColumnData::Float32(data) => {
121 let sum: f32 = data.iter().filter_map(|v| *v).sum();
122 Ok(Some(sum as f64))
123 }
124 ColumnData::Float64(data) => {
125 let sum: f64 = data.iter().filter_map(|v| *v).sum();
126 Ok(Some(sum))
127 }
128 _ => Err(QueryError::type_mismatch("numeric", "non-numeric")),
129 }
130 }
131
132 fn avg(&self, column: &ColumnData) -> Result<Option<f64>> {
134 let sum = self.sum(column)?;
135 let count = self.count(column);
136 if count > 0.0 {
137 Ok(sum.map(|s| s / count))
138 } else {
139 Ok(None)
140 }
141 }
142
143 fn min(&self, column: &ColumnData) -> Result<Option<f64>> {
145 match column {
146 ColumnData::Int32(data) => {
147 let min = data.iter().filter_map(|v| *v).min();
148 Ok(min.map(|m| m as f64))
149 }
150 ColumnData::Int64(data) => {
151 let min = data.iter().filter_map(|v| *v).min();
152 Ok(min.map(|m| m as f64))
153 }
154 ColumnData::Float32(data) => {
155 let min = data
156 .iter()
157 .filter_map(|v| *v)
158 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
159 Ok(min.map(|m| m as f64))
160 }
161 ColumnData::Float64(data) => {
162 let min = data
163 .iter()
164 .filter_map(|v| *v)
165 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
166 Ok(min)
167 }
168 _ => Err(QueryError::type_mismatch("numeric", "non-numeric")),
169 }
170 }
171
172 fn max(&self, column: &ColumnData) -> Result<Option<f64>> {
174 match column {
175 ColumnData::Int32(data) => {
176 let max = data.iter().filter_map(|v| *v).max();
177 Ok(max.map(|m| m as f64))
178 }
179 ColumnData::Int64(data) => {
180 let max = data.iter().filter_map(|v| *v).max();
181 Ok(max.map(|m| m as f64))
182 }
183 ColumnData::Float32(data) => {
184 let max = data
185 .iter()
186 .filter_map(|v| *v)
187 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
188 Ok(max.map(|m| m as f64))
189 }
190 ColumnData::Float64(data) => {
191 let max = data
192 .iter()
193 .filter_map(|v| *v)
194 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
195 Ok(max)
196 }
197 _ => Err(QueryError::type_mismatch("numeric", "non-numeric")),
198 }
199 }
200}
201
/// A single aggregate call, e.g. `SUM(value) AS total`.
#[derive(Debug, Clone)]
pub struct AggregateFunction {
    /// Which aggregate to compute.
    pub func: AggregateFunc,
    /// Input column name, or `"*"` (accepted only together with
    /// [`AggregateFunc::Count`]).
    pub column: String,
    /// Output column name; when `None`, the input column name is used
    /// (`"count"` for `*`).
    pub alias: Option<String>,
}
212
/// The aggregate functions supported by [`Aggregate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunc {
    /// Number of non-null values in a column, or number of rows for `COUNT(*)`.
    Count,
    /// Sum of the non-null values of a numeric column.
    Sum,
    /// Arithmetic mean of the non-null values of a numeric column.
    Avg,
    /// Smallest non-null value of a numeric column.
    Min,
    /// Largest non-null value of a numeric column.
    Max,
}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::scan::DataType;

    /// Schema with a single nullable `Int64` column named `value`.
    fn value_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new(
            "value".to_string(),
            DataType::Int64,
            true,
        )]))
    }

    /// Shorthand for an aliased aggregate call.
    fn agg_fn(func: AggregateFunc, column: &str, alias: &str) -> AggregateFunction {
        AggregateFunction {
            func,
            column: column.to_string(),
            alias: Some(alias.to_string()),
        }
    }

    /// Extracts the single value of output column `index` of an aggregate result.
    fn single_value(batch: &RecordBatch, index: usize) -> Option<f64> {
        let column: &ColumnData = &batch.columns[index];
        match column {
            ColumnData::Float64(values) => {
                assert_eq!(values.len(), 1);
                values[0]
            }
            _ => panic!("aggregate output column {} is not Float64", index),
        }
    }

    #[test]
    fn test_global_aggregate() -> Result<()> {
        let columns = vec![ColumnData::Int64(vec![
            Some(10),
            Some(20),
            Some(30),
            Some(40),
            Some(50),
        ])];
        let batch = RecordBatch::new(value_schema(), columns, 5)?;

        let agg = Aggregate::new(
            vec![],
            vec![
                agg_fn(AggregateFunc::Sum, "value", "sum"),
                agg_fn(AggregateFunc::Avg, "value", "avg"),
            ],
        );

        let result = agg.execute(&batch)?;
        assert_eq!(result.num_rows, 1);
        assert_eq!(result.columns.len(), 2);
        assert_eq!(single_value(&result, 0), Some(150.0));
        assert_eq!(single_value(&result, 1), Some(30.0));

        Ok(())
    }

    #[test]
    fn test_aggregates_skip_nulls() -> Result<()> {
        let columns = vec![ColumnData::Int64(vec![Some(1), None, Some(3)])];
        let batch = RecordBatch::new(value_schema(), columns, 3)?;

        let agg = Aggregate::new(
            vec![],
            vec![
                agg_fn(AggregateFunc::Count, "*", "rows"),
                agg_fn(AggregateFunc::Count, "value", "non_null"),
                agg_fn(AggregateFunc::Avg, "value", "avg"),
                agg_fn(AggregateFunc::Min, "value", "min"),
                agg_fn(AggregateFunc::Max, "value", "max"),
            ],
        );

        let result = agg.execute(&batch)?;
        assert_eq!(result.num_rows, 1);
        // COUNT(*) counts rows; COUNT(value) counts only non-null values.
        assert_eq!(single_value(&result, 0), Some(3.0));
        assert_eq!(single_value(&result, 1), Some(2.0));
        assert_eq!(single_value(&result, 2), Some(2.0));
        assert_eq!(single_value(&result, 3), Some(1.0));
        assert_eq!(single_value(&result, 4), Some(3.0));

        Ok(())
    }

    #[test]
    fn test_wildcard_requires_count() -> Result<()> {
        let columns = vec![ColumnData::Int64(vec![Some(1)])];
        let batch = RecordBatch::new(value_schema(), columns, 1)?;

        let agg = Aggregate::new(vec![], vec![agg_fn(AggregateFunc::Sum, "*", "s")]);
        assert!(agg.execute(&batch).is_err());

        Ok(())
    }

    #[test]
    fn test_missing_column_is_error() -> Result<()> {
        let columns = vec![ColumnData::Int64(vec![Some(1)])];
        let batch = RecordBatch::new(value_schema(), columns, 1)?;

        let agg = Aggregate::new(vec![], vec![agg_fn(AggregateFunc::Sum, "missing", "s")]);
        assert!(agg.execute(&batch).is_err());

        Ok(())
    }
}