Skip to main content

grafeo_core/execution/operators/
accumulator.rs

1//! Shared accumulator types for both pull-based and push-based aggregate operators.
2//!
3//! Provides the canonical definitions of [`AggregateFunction`], [`AggregateExpr`],
4//! and [`HashableValue`] used by both `aggregate.rs` (pull) and `push/aggregate.rs`.
5
6use grafeo_common::types::Value;
7
/// The set of aggregate functions understood by the aggregate operators.
///
/// The enum is `#[non_exhaustive]`: code in other crates that matches on it
/// must include a wildcard arm.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AggregateFunction {
    /// `COUNT(*)`: number of rows.
    Count,
    /// `COUNT(column)`: number of non-null values.
    CountNonNull,
    /// Sum of the values.
    Sum,
    /// Average of the values.
    Avg,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// First value of the group.
    First,
    /// Last value of the group.
    Last,
    /// Gathers the values into a list.
    Collect,
    /// `STDEV`: sample standard deviation.
    StdDev,
    /// `STDEVP`: population standard deviation.
    StdDevPop,
    /// `VAR_SAMP` / `VARIANCE`: sample variance.
    Variance,
    /// `VAR_POP`: population variance.
    VariancePop,
    /// `PERCENTILE_DISC`: discrete percentile.
    PercentileDisc,
    /// `PERCENTILE_CONT`: continuous percentile.
    PercentileCont,
    /// `GROUP_CONCAT`: values concatenated with a separator.
    GroupConcat,
    /// `SAMPLE`: an arbitrary value taken from the group.
    Sample,
    /// `COVAR_SAMP(y, x)`: sample covariance.
    CovarSamp,
    /// `COVAR_POP(y, x)`: population covariance.
    CovarPop,
    /// `CORR(y, x)`: Pearson correlation coefficient.
    Corr,
    /// `REGR_SLOPE(y, x)`: regression slope.
    RegrSlope,
    /// `REGR_INTERCEPT(y, x)`: regression intercept.
    RegrIntercept,
    /// `REGR_R2(y, x)`: coefficient of determination.
    RegrR2,
    /// `REGR_COUNT(y, x)`: number of non-null pairs.
    RegrCount,
    /// `REGR_SXX(y, x)`: regression sum of squares for x.
    RegrSxx,
    /// `REGR_SYY(y, x)`: regression sum of squares for y.
    RegrSyy,
    /// `REGR_SXY(y, x)`: regression sum of cross-products.
    RegrSxy,
    /// `REGR_AVGX(y, x)`: regression average of x.
    RegrAvgx,
    /// `REGR_AVGY(y, x)`: regression average of y.
    RegrAvgy,
}
71
72/// An aggregation expression.
73#[derive(Debug, Clone)]
74pub struct AggregateExpr {
75    /// The aggregation function.
76    pub function: AggregateFunction,
77    /// Column index to aggregate (None for COUNT(*), y column for binary set functions).
78    pub column: Option<usize>,
79    /// Second column index for binary set functions (x column for COVAR, CORR, REGR_*).
80    pub column2: Option<usize>,
81    /// Whether to aggregate distinct values only.
82    pub distinct: bool,
83    /// Output alias (for naming the result column).
84    pub alias: Option<String>,
85    /// Percentile parameter for PERCENTILE_DISC/PERCENTILE_CONT (0.0 to 1.0).
86    pub percentile: Option<f64>,
87    /// Separator string for GROUP_CONCAT / LISTAGG.
88    pub separator: Option<String>,
89}
90
91impl AggregateExpr {
92    /// Creates a COUNT(*) expression.
93    pub fn count_star() -> Self {
94        Self {
95            function: AggregateFunction::Count,
96            column: None,
97            column2: None,
98            distinct: false,
99            alias: None,
100            percentile: None,
101            separator: None,
102        }
103    }
104
105    /// Creates a COUNT(column) expression.
106    pub fn count(column: usize) -> Self {
107        Self {
108            function: AggregateFunction::CountNonNull,
109            column: Some(column),
110            column2: None,
111            distinct: false,
112            alias: None,
113            percentile: None,
114            separator: None,
115        }
116    }
117
118    /// Creates a SUM(column) expression.
119    pub fn sum(column: usize) -> Self {
120        Self {
121            function: AggregateFunction::Sum,
122            column: Some(column),
123            column2: None,
124            distinct: false,
125            alias: None,
126            percentile: None,
127            separator: None,
128        }
129    }
130
131    /// Creates an AVG(column) expression.
132    pub fn avg(column: usize) -> Self {
133        Self {
134            function: AggregateFunction::Avg,
135            column: Some(column),
136            column2: None,
137            distinct: false,
138            alias: None,
139            percentile: None,
140            separator: None,
141        }
142    }
143
144    /// Creates a MIN(column) expression.
145    pub fn min(column: usize) -> Self {
146        Self {
147            function: AggregateFunction::Min,
148            column: Some(column),
149            column2: None,
150            distinct: false,
151            alias: None,
152            percentile: None,
153            separator: None,
154        }
155    }
156
157    /// Creates a MAX(column) expression.
158    pub fn max(column: usize) -> Self {
159        Self {
160            function: AggregateFunction::Max,
161            column: Some(column),
162            column2: None,
163            distinct: false,
164            alias: None,
165            percentile: None,
166            separator: None,
167        }
168    }
169
170    /// Creates a FIRST(column) expression.
171    pub fn first(column: usize) -> Self {
172        Self {
173            function: AggregateFunction::First,
174            column: Some(column),
175            column2: None,
176            distinct: false,
177            alias: None,
178            percentile: None,
179            separator: None,
180        }
181    }
182
183    /// Creates a LAST(column) expression.
184    pub fn last(column: usize) -> Self {
185        Self {
186            function: AggregateFunction::Last,
187            column: Some(column),
188            column2: None,
189            distinct: false,
190            alias: None,
191            percentile: None,
192            separator: None,
193        }
194    }
195
196    /// Creates a COLLECT(column) expression.
197    pub fn collect(column: usize) -> Self {
198        Self {
199            function: AggregateFunction::Collect,
200            column: Some(column),
201            column2: None,
202            distinct: false,
203            alias: None,
204            percentile: None,
205            separator: None,
206        }
207    }
208
209    /// Creates a STDEV(column) expression (sample standard deviation).
210    pub fn stdev(column: usize) -> Self {
211        Self {
212            function: AggregateFunction::StdDev,
213            column: Some(column),
214            column2: None,
215            distinct: false,
216            alias: None,
217            percentile: None,
218            separator: None,
219        }
220    }
221
222    /// Creates a STDEVP(column) expression (population standard deviation).
223    pub fn stdev_pop(column: usize) -> Self {
224        Self {
225            function: AggregateFunction::StdDevPop,
226            column: Some(column),
227            column2: None,
228            distinct: false,
229            alias: None,
230            percentile: None,
231            separator: None,
232        }
233    }
234
235    /// Creates a PERCENTILE_DISC(column, percentile) expression.
236    ///
237    /// # Arguments
238    /// * `column` - Column index to aggregate
239    /// * `percentile` - Percentile value between 0.0 and 1.0 (e.g., 0.5 for median)
240    pub fn percentile_disc(column: usize, percentile: f64) -> Self {
241        Self {
242            function: AggregateFunction::PercentileDisc,
243            column: Some(column),
244            column2: None,
245            distinct: false,
246            alias: None,
247            percentile: Some(percentile.clamp(0.0, 1.0)),
248            separator: None,
249        }
250    }
251
252    /// Creates a PERCENTILE_CONT(column, percentile) expression.
253    ///
254    /// # Arguments
255    /// * `column` - Column index to aggregate
256    /// * `percentile` - Percentile value between 0.0 and 1.0 (e.g., 0.5 for median)
257    pub fn percentile_cont(column: usize, percentile: f64) -> Self {
258        Self {
259            function: AggregateFunction::PercentileCont,
260            column: Some(column),
261            column2: None,
262            distinct: false,
263            alias: None,
264            percentile: Some(percentile.clamp(0.0, 1.0)),
265            separator: None,
266        }
267    }
268
269    /// Sets the distinct flag.
270    pub fn with_distinct(mut self) -> Self {
271        self.distinct = true;
272        self
273    }
274
275    /// Sets the output alias.
276    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
277        self.alias = Some(alias.into());
278        self
279    }
280}
281
/// Hashable stand-in for a [`Value`], used to track DISTINCT values.
///
/// Every variant holds only types that implement `Eq` and `Hash`, which is
/// what lets the enum derive both traits.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum HashableValue {
    /// The null value.
    Null,
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A 64-bit float stored as its bit pattern, so that hashing is
    /// deterministic.
    Float64Bits(u64),
    /// A string.
    String(String),
    /// Catch-all for any other value type, keyed by its `Debug` rendering.
    Other(String),
}
299
300impl From<&Value> for HashableValue {
301    fn from(v: &Value) -> Self {
302        match v {
303            Value::Null => HashableValue::Null,
304            Value::Bool(b) => HashableValue::Bool(*b),
305            Value::Int64(i) => HashableValue::Int64(*i),
306            Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
307            Value::String(s) => HashableValue::String(s.to_string()),
308            other => HashableValue::Other(format!("{other:?}")),
309        }
310    }
311}
312
313impl From<Value> for HashableValue {
314    fn from(v: Value) -> Self {
315        Self::from(&v)
316    }
317}