Skip to main content

grafeo_core/execution/operators/
accumulator.rs

//! Shared accumulator types for both pull-based and push-based aggregate operators.
//!
//! Provides the canonical definitions of [`AggregateFunction`], [`AggregateExpr`],
//! and [`HashableValue`], and re-exports [`AggregateState`], so that both
//! `aggregate.rs` (pull) and `push/aggregate.rs` (push) import them from one place.
6
7// Re-export AggregateState so both pull and push operators import from one place.
8pub use super::aggregate::AggregateState;
9
10use grafeo_common::types::Value;
11
/// The aggregate functions an aggregate operator can evaluate.
///
/// Unary functions read [`AggregateExpr::column`]; the binary set functions
/// (`COVAR_*`, `CORR`, `REGR_*`) take `(y, x)` and also read
/// [`AggregateExpr::column2`] as the x column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AggregateFunction {
    /// `COUNT(*)`: number of rows in the group.
    Count,
    /// `COUNT(column)`: number of non-null values.
    CountNonNull,
    /// `SUM(column)`: sum of the values.
    Sum,
    /// `AVG(column)`: average of the values.
    Avg,
    /// `MIN(column)`: smallest value.
    Min,
    /// `MAX(column)`: largest value.
    Max,
    /// `FIRST(column)`: first value in the group.
    First,
    /// `LAST(column)`: last value in the group.
    Last,
    /// `COLLECT(column)`: gathers the values into a list.
    Collect,
    /// `STDEV`: sample standard deviation.
    StdDev,
    /// `STDEVP`: population standard deviation.
    StdDevPop,
    /// `VAR_SAMP` / `VARIANCE`: sample variance.
    Variance,
    /// `VAR_POP`: population variance.
    VariancePop,
    /// `PERCENTILE_DISC`: discrete percentile.
    PercentileDisc,
    /// `PERCENTILE_CONT`: continuous percentile.
    PercentileCont,
    /// `GROUP_CONCAT`: concatenates the values using a separator.
    GroupConcat,
    /// `SAMPLE`: an arbitrary value from the group.
    Sample,
    /// `COVAR_SAMP(y, x)`: sample covariance.
    CovarSamp,
    /// `COVAR_POP(y, x)`: population covariance.
    CovarPop,
    /// `CORR(y, x)`: Pearson correlation coefficient.
    Corr,
    /// `REGR_SLOPE(y, x)`: slope of the regression line.
    RegrSlope,
    /// `REGR_INTERCEPT(y, x)`: intercept of the regression line.
    RegrIntercept,
    /// `REGR_R2(y, x)`: coefficient of determination.
    RegrR2,
    /// `REGR_COUNT(y, x)`: number of pairs where both inputs are non-null.
    RegrCount,
    /// `REGR_SXX(y, x)`: regression sum of squares for x.
    RegrSxx,
    /// `REGR_SYY(y, x)`: regression sum of squares for y.
    RegrSyy,
    /// `REGR_SXY(y, x)`: regression sum of cross-products.
    RegrSxy,
    /// `REGR_AVGX(y, x)`: regression average of x.
    RegrAvgx,
    /// `REGR_AVGY(y, x)`: regression average of y.
    RegrAvgy,
}
75
76/// An aggregation expression.
77#[derive(Debug, Clone)]
78pub struct AggregateExpr {
79    /// The aggregation function.
80    pub function: AggregateFunction,
81    /// Column index to aggregate (None for COUNT(*), y column for binary set functions).
82    pub column: Option<usize>,
83    /// Second column index for binary set functions (x column for COVAR, CORR, REGR_*).
84    pub column2: Option<usize>,
85    /// Whether to aggregate distinct values only.
86    pub distinct: bool,
87    /// Output alias (for naming the result column).
88    pub alias: Option<String>,
89    /// Percentile parameter for PERCENTILE_DISC/PERCENTILE_CONT (0.0 to 1.0).
90    pub percentile: Option<f64>,
91    /// Separator string for GROUP_CONCAT / LISTAGG.
92    pub separator: Option<String>,
93}
94
95impl AggregateExpr {
96    /// Creates a COUNT(*) expression.
97    pub fn count_star() -> Self {
98        Self {
99            function: AggregateFunction::Count,
100            column: None,
101            column2: None,
102            distinct: false,
103            alias: None,
104            percentile: None,
105            separator: None,
106        }
107    }
108
109    /// Creates a COUNT(column) expression.
110    pub fn count(column: usize) -> Self {
111        Self {
112            function: AggregateFunction::CountNonNull,
113            column: Some(column),
114            column2: None,
115            distinct: false,
116            alias: None,
117            percentile: None,
118            separator: None,
119        }
120    }
121
122    /// Creates a SUM(column) expression.
123    pub fn sum(column: usize) -> Self {
124        Self {
125            function: AggregateFunction::Sum,
126            column: Some(column),
127            column2: None,
128            distinct: false,
129            alias: None,
130            percentile: None,
131            separator: None,
132        }
133    }
134
135    /// Creates an AVG(column) expression.
136    pub fn avg(column: usize) -> Self {
137        Self {
138            function: AggregateFunction::Avg,
139            column: Some(column),
140            column2: None,
141            distinct: false,
142            alias: None,
143            percentile: None,
144            separator: None,
145        }
146    }
147
148    /// Creates a MIN(column) expression.
149    pub fn min(column: usize) -> Self {
150        Self {
151            function: AggregateFunction::Min,
152            column: Some(column),
153            column2: None,
154            distinct: false,
155            alias: None,
156            percentile: None,
157            separator: None,
158        }
159    }
160
161    /// Creates a MAX(column) expression.
162    pub fn max(column: usize) -> Self {
163        Self {
164            function: AggregateFunction::Max,
165            column: Some(column),
166            column2: None,
167            distinct: false,
168            alias: None,
169            percentile: None,
170            separator: None,
171        }
172    }
173
174    /// Creates a FIRST(column) expression.
175    pub fn first(column: usize) -> Self {
176        Self {
177            function: AggregateFunction::First,
178            column: Some(column),
179            column2: None,
180            distinct: false,
181            alias: None,
182            percentile: None,
183            separator: None,
184        }
185    }
186
187    /// Creates a LAST(column) expression.
188    pub fn last(column: usize) -> Self {
189        Self {
190            function: AggregateFunction::Last,
191            column: Some(column),
192            column2: None,
193            distinct: false,
194            alias: None,
195            percentile: None,
196            separator: None,
197        }
198    }
199
200    /// Creates a COLLECT(column) expression.
201    pub fn collect(column: usize) -> Self {
202        Self {
203            function: AggregateFunction::Collect,
204            column: Some(column),
205            column2: None,
206            distinct: false,
207            alias: None,
208            percentile: None,
209            separator: None,
210        }
211    }
212
213    /// Creates a STDEV(column) expression (sample standard deviation).
214    pub fn stdev(column: usize) -> Self {
215        Self {
216            function: AggregateFunction::StdDev,
217            column: Some(column),
218            column2: None,
219            distinct: false,
220            alias: None,
221            percentile: None,
222            separator: None,
223        }
224    }
225
226    /// Creates a STDEVP(column) expression (population standard deviation).
227    pub fn stdev_pop(column: usize) -> Self {
228        Self {
229            function: AggregateFunction::StdDevPop,
230            column: Some(column),
231            column2: None,
232            distinct: false,
233            alias: None,
234            percentile: None,
235            separator: None,
236        }
237    }
238
239    /// Creates a PERCENTILE_DISC(column, percentile) expression.
240    ///
241    /// # Arguments
242    /// * `column` - Column index to aggregate
243    /// * `percentile` - Percentile value between 0.0 and 1.0 (e.g., 0.5 for median)
244    pub fn percentile_disc(column: usize, percentile: f64) -> Self {
245        Self {
246            function: AggregateFunction::PercentileDisc,
247            column: Some(column),
248            column2: None,
249            distinct: false,
250            alias: None,
251            percentile: Some(percentile.clamp(0.0, 1.0)),
252            separator: None,
253        }
254    }
255
256    /// Creates a PERCENTILE_CONT(column, percentile) expression.
257    ///
258    /// # Arguments
259    /// * `column` - Column index to aggregate
260    /// * `percentile` - Percentile value between 0.0 and 1.0 (e.g., 0.5 for median)
261    pub fn percentile_cont(column: usize, percentile: f64) -> Self {
262        Self {
263            function: AggregateFunction::PercentileCont,
264            column: Some(column),
265            column2: None,
266            distinct: false,
267            alias: None,
268            percentile: Some(percentile.clamp(0.0, 1.0)),
269            separator: None,
270        }
271    }
272
273    /// Sets the distinct flag.
274    pub fn with_distinct(mut self) -> Self {
275        self.distinct = true;
276        self
277    }
278
279    /// Sets the output alias.
280    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
281        self.alias = Some(alias.into());
282        self
283    }
284}
285
/// Hashable, `Eq`-comparable stand-in for a [`Value`], used as the key when
/// tracking DISTINCT values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HashableValue {
    /// The null value.
    Null,
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A 64-bit float stored as its IEEE-754 bit pattern, since `f64` itself
    /// implements neither `Eq` nor `Hash`.
    Float64Bits(u64),
    /// A string.
    String(String),
    /// Any other value, keyed by its `Debug` rendering.
    Other(String),
}
303
304impl From<&Value> for HashableValue {
305    fn from(v: &Value) -> Self {
306        match v {
307            Value::Null => HashableValue::Null,
308            Value::Bool(b) => HashableValue::Bool(*b),
309            Value::Int64(i) => HashableValue::Int64(*i),
310            Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
311            Value::String(s) => HashableValue::String(s.to_string()),
312            other => HashableValue::Other(format!("{other:?}")),
313        }
314    }
315}
316
317impl From<Value> for HashableValue {
318    fn from(v: Value) -> Self {
319        Self::from(&v)
320    }
321}