Skip to main content

grafeo_core/execution/operators/
accumulator.rs

//! Shared accumulator types for both pull-based and push-based aggregate operators.
//!
//! Provides the canonical definitions of [`AggregateFunction`], [`AggregateExpr`],
//! and [`HashableValue`] used by both `aggregate.rs` (pull) and `push/aggregate.rs` (push).
5
6use grafeo_common::types::Value;
7
/// The kinds of aggregate a query can request.
///
/// Variants from [`AggregateFunction::CovarSamp`] onwards are binary set
/// functions written as `F(y, x)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// `COUNT(*)`: number of rows.
    Count,
    /// `COUNT(column)`: number of non-null values.
    CountNonNull,
    /// Sum of the values.
    Sum,
    /// Average of the values.
    Avg,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// First value seen in the group.
    First,
    /// Last value seen in the group.
    Last,
    /// Gathers the values into a list.
    Collect,
    /// `STDEV`: sample standard deviation.
    StdDev,
    /// `STDEVP`: population standard deviation.
    StdDevPop,
    /// `VAR_SAMP` / `VARIANCE`: sample variance.
    Variance,
    /// `VAR_POP`: population variance.
    VariancePop,
    /// `PERCENTILE_DISC`: discrete percentile.
    PercentileDisc,
    /// `PERCENTILE_CONT`: continuous percentile.
    PercentileCont,
    /// `GROUP_CONCAT`: values concatenated with a separator.
    GroupConcat,
    /// `SAMPLE`: an arbitrary value from the group.
    Sample,
    /// `COVAR_SAMP(y, x)`: sample covariance.
    CovarSamp,
    /// `COVAR_POP(y, x)`: population covariance.
    CovarPop,
    /// `CORR(y, x)`: Pearson correlation coefficient.
    Corr,
    /// `REGR_SLOPE(y, x)`: regression slope.
    RegrSlope,
    /// `REGR_INTERCEPT(y, x)`: regression intercept.
    RegrIntercept,
    /// `REGR_R2(y, x)`: coefficient of determination.
    RegrR2,
    /// `REGR_COUNT(y, x)`: count of non-null pairs.
    RegrCount,
    /// `REGR_SXX(y, x)`: sum of squares for x.
    RegrSxx,
    /// `REGR_SYY(y, x)`: sum of squares for y.
    RegrSyy,
    /// `REGR_SXY(y, x)`: sum of cross-products.
    RegrSxy,
    /// `REGR_AVGX(y, x)`: average of x.
    RegrAvgx,
    /// `REGR_AVGY(y, x)`: average of y.
    RegrAvgy,
}
70
71/// An aggregation expression.
72#[derive(Debug, Clone)]
73pub struct AggregateExpr {
74    /// The aggregation function.
75    pub function: AggregateFunction,
76    /// Column index to aggregate (None for COUNT(*), y column for binary set functions).
77    pub column: Option<usize>,
78    /// Second column index for binary set functions (x column for COVAR, CORR, REGR_*).
79    pub column2: Option<usize>,
80    /// Whether to aggregate distinct values only.
81    pub distinct: bool,
82    /// Output alias (for naming the result column).
83    pub alias: Option<String>,
84    /// Percentile parameter for PERCENTILE_DISC/PERCENTILE_CONT (0.0 to 1.0).
85    pub percentile: Option<f64>,
86    /// Separator string for GROUP_CONCAT / LISTAGG.
87    pub separator: Option<String>,
88}
89
90impl AggregateExpr {
91    /// Creates a COUNT(*) expression.
92    pub fn count_star() -> Self {
93        Self {
94            function: AggregateFunction::Count,
95            column: None,
96            column2: None,
97            distinct: false,
98            alias: None,
99            percentile: None,
100            separator: None,
101        }
102    }
103
104    /// Creates a COUNT(column) expression.
105    pub fn count(column: usize) -> Self {
106        Self {
107            function: AggregateFunction::CountNonNull,
108            column: Some(column),
109            column2: None,
110            distinct: false,
111            alias: None,
112            percentile: None,
113            separator: None,
114        }
115    }
116
117    /// Creates a SUM(column) expression.
118    pub fn sum(column: usize) -> Self {
119        Self {
120            function: AggregateFunction::Sum,
121            column: Some(column),
122            column2: None,
123            distinct: false,
124            alias: None,
125            percentile: None,
126            separator: None,
127        }
128    }
129
130    /// Creates an AVG(column) expression.
131    pub fn avg(column: usize) -> Self {
132        Self {
133            function: AggregateFunction::Avg,
134            column: Some(column),
135            column2: None,
136            distinct: false,
137            alias: None,
138            percentile: None,
139            separator: None,
140        }
141    }
142
143    /// Creates a MIN(column) expression.
144    pub fn min(column: usize) -> Self {
145        Self {
146            function: AggregateFunction::Min,
147            column: Some(column),
148            column2: None,
149            distinct: false,
150            alias: None,
151            percentile: None,
152            separator: None,
153        }
154    }
155
156    /// Creates a MAX(column) expression.
157    pub fn max(column: usize) -> Self {
158        Self {
159            function: AggregateFunction::Max,
160            column: Some(column),
161            column2: None,
162            distinct: false,
163            alias: None,
164            percentile: None,
165            separator: None,
166        }
167    }
168
169    /// Creates a FIRST(column) expression.
170    pub fn first(column: usize) -> Self {
171        Self {
172            function: AggregateFunction::First,
173            column: Some(column),
174            column2: None,
175            distinct: false,
176            alias: None,
177            percentile: None,
178            separator: None,
179        }
180    }
181
182    /// Creates a LAST(column) expression.
183    pub fn last(column: usize) -> Self {
184        Self {
185            function: AggregateFunction::Last,
186            column: Some(column),
187            column2: None,
188            distinct: false,
189            alias: None,
190            percentile: None,
191            separator: None,
192        }
193    }
194
195    /// Creates a COLLECT(column) expression.
196    pub fn collect(column: usize) -> Self {
197        Self {
198            function: AggregateFunction::Collect,
199            column: Some(column),
200            column2: None,
201            distinct: false,
202            alias: None,
203            percentile: None,
204            separator: None,
205        }
206    }
207
208    /// Creates a STDEV(column) expression (sample standard deviation).
209    pub fn stdev(column: usize) -> Self {
210        Self {
211            function: AggregateFunction::StdDev,
212            column: Some(column),
213            column2: None,
214            distinct: false,
215            alias: None,
216            percentile: None,
217            separator: None,
218        }
219    }
220
221    /// Creates a STDEVP(column) expression (population standard deviation).
222    pub fn stdev_pop(column: usize) -> Self {
223        Self {
224            function: AggregateFunction::StdDevPop,
225            column: Some(column),
226            column2: None,
227            distinct: false,
228            alias: None,
229            percentile: None,
230            separator: None,
231        }
232    }
233
234    /// Creates a PERCENTILE_DISC(column, percentile) expression.
235    ///
236    /// # Arguments
237    /// * `column` - Column index to aggregate
238    /// * `percentile` - Percentile value between 0.0 and 1.0 (e.g., 0.5 for median)
239    pub fn percentile_disc(column: usize, percentile: f64) -> Self {
240        Self {
241            function: AggregateFunction::PercentileDisc,
242            column: Some(column),
243            column2: None,
244            distinct: false,
245            alias: None,
246            percentile: Some(percentile.clamp(0.0, 1.0)),
247            separator: None,
248        }
249    }
250
251    /// Creates a PERCENTILE_CONT(column, percentile) expression.
252    ///
253    /// # Arguments
254    /// * `column` - Column index to aggregate
255    /// * `percentile` - Percentile value between 0.0 and 1.0 (e.g., 0.5 for median)
256    pub fn percentile_cont(column: usize, percentile: f64) -> Self {
257        Self {
258            function: AggregateFunction::PercentileCont,
259            column: Some(column),
260            column2: None,
261            distinct: false,
262            alias: None,
263            percentile: Some(percentile.clamp(0.0, 1.0)),
264            separator: None,
265        }
266    }
267
268    /// Sets the distinct flag.
269    pub fn with_distinct(mut self) -> Self {
270        self.distinct = true;
271        self
272    }
273
274    /// Sets the output alias.
275    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
276        self.alias = Some(alias.into());
277        self
278    }
279}
280
/// Hashable stand-in for a [`Value`], used to track DISTINCT values.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HashableValue {
    /// The null value.
    Null,
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A float kept as its 64-bit pattern so that it hashes deterministically.
    Float64Bits(u64),
    /// A string.
    String(String),
    /// Any other value type, keyed by its `Debug` representation.
    Other(String),
}
297
298impl From<&Value> for HashableValue {
299    fn from(v: &Value) -> Self {
300        match v {
301            Value::Null => HashableValue::Null,
302            Value::Bool(b) => HashableValue::Bool(*b),
303            Value::Int64(i) => HashableValue::Int64(*i),
304            Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
305            Value::String(s) => HashableValue::String(s.to_string()),
306            other => HashableValue::Other(format!("{other:?}")),
307        }
308    }
309}
310
311impl From<Value> for HashableValue {
312    fn from(v: Value) -> Self {
313        Self::from(&v)
314    }
315}