Skip to main content

grafeo_core/execution/operators/
accumulator.rs

1//! Shared accumulator types for both pull-based and push-based aggregate operators.
2//!
3//! Provides the canonical definitions of [`AggregateFunction`], [`AggregateExpr`],
4//! and [`HashableValue`] used by both `aggregate.rs` (pull) and `push/aggregate.rs`.
5
6use grafeo_common::types::Value;
7
/// The set of aggregate functions understood by the aggregate operators.
///
/// Variants carry no payload; any per-call parameters (columns, DISTINCT,
/// percentile) live on the expression that references the function.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateFunction {
    /// `COUNT(*)`: number of rows.
    Count,
    /// `COUNT(column)`: number of non-null values.
    CountNonNull,
    /// `SUM`: total of the values.
    Sum,
    /// `AVG`: mean of the values.
    Avg,
    /// `MIN`: smallest value.
    Min,
    /// `MAX`: largest value.
    Max,
    /// The first value seen in the group.
    First,
    /// The last value seen in the group.
    Last,
    /// Gathers the values into a list.
    Collect,
    /// `STDEV`: sample standard deviation.
    StdDev,
    /// `STDEVP`: population standard deviation.
    StdDevPop,
    /// `VAR_SAMP` / `VARIANCE`: sample variance.
    Variance,
    /// `VAR_POP`: population variance.
    VariancePop,
    /// `PERCENTILE_DISC`: discrete percentile.
    PercentileDisc,
    /// `PERCENTILE_CONT`: continuous percentile.
    PercentileCont,
    /// `GROUP_CONCAT`: values joined with a separator.
    GroupConcat,
    /// `SAMPLE`: an arbitrary value from the group.
    Sample,
    /// `COVAR_SAMP(y, x)`: sample covariance.
    CovarSamp,
    /// `COVAR_POP(y, x)`: population covariance.
    CovarPop,
    /// `CORR(y, x)`: Pearson correlation coefficient.
    Corr,
    /// `REGR_SLOPE(y, x)`: regression slope.
    RegrSlope,
    /// `REGR_INTERCEPT(y, x)`: regression intercept.
    RegrIntercept,
    /// `REGR_R2(y, x)`: coefficient of determination.
    RegrR2,
    /// `REGR_COUNT(y, x)`: number of non-null pairs.
    RegrCount,
    /// `REGR_SXX(y, x)`: regression sum of squares for x.
    RegrSxx,
    /// `REGR_SYY(y, x)`: regression sum of squares for y.
    RegrSyy,
    /// `REGR_SXY(y, x)`: regression sum of cross-products.
    RegrSxy,
    /// `REGR_AVGX(y, x)`: regression average of x.
    RegrAvgx,
    /// `REGR_AVGY(y, x)`: regression average of y.
    RegrAvgy,
}
70
71/// An aggregation expression.
72#[derive(Debug, Clone)]
73pub struct AggregateExpr {
74    /// The aggregation function.
75    pub function: AggregateFunction,
76    /// Column index to aggregate (None for COUNT(*), y column for binary set functions).
77    pub column: Option<usize>,
78    /// Second column index for binary set functions (x column for COVAR, CORR, REGR_*).
79    pub column2: Option<usize>,
80    /// Whether to aggregate distinct values only.
81    pub distinct: bool,
82    /// Output alias (for naming the result column).
83    pub alias: Option<String>,
84    /// Percentile parameter for PERCENTILE_DISC/PERCENTILE_CONT (0.0 to 1.0).
85    pub percentile: Option<f64>,
86}
87
88impl AggregateExpr {
89    /// Creates a COUNT(*) expression.
90    pub fn count_star() -> Self {
91        Self {
92            function: AggregateFunction::Count,
93            column: None,
94            column2: None,
95            distinct: false,
96            alias: None,
97            percentile: None,
98        }
99    }
100
101    /// Creates a COUNT(column) expression.
102    pub fn count(column: usize) -> Self {
103        Self {
104            function: AggregateFunction::CountNonNull,
105            column: Some(column),
106            column2: None,
107            distinct: false,
108            alias: None,
109            percentile: None,
110        }
111    }
112
113    /// Creates a SUM(column) expression.
114    pub fn sum(column: usize) -> Self {
115        Self {
116            function: AggregateFunction::Sum,
117            column: Some(column),
118            column2: None,
119            distinct: false,
120            alias: None,
121            percentile: None,
122        }
123    }
124
125    /// Creates an AVG(column) expression.
126    pub fn avg(column: usize) -> Self {
127        Self {
128            function: AggregateFunction::Avg,
129            column: Some(column),
130            column2: None,
131            distinct: false,
132            alias: None,
133            percentile: None,
134        }
135    }
136
137    /// Creates a MIN(column) expression.
138    pub fn min(column: usize) -> Self {
139        Self {
140            function: AggregateFunction::Min,
141            column: Some(column),
142            column2: None,
143            distinct: false,
144            alias: None,
145            percentile: None,
146        }
147    }
148
149    /// Creates a MAX(column) expression.
150    pub fn max(column: usize) -> Self {
151        Self {
152            function: AggregateFunction::Max,
153            column: Some(column),
154            column2: None,
155            distinct: false,
156            alias: None,
157            percentile: None,
158        }
159    }
160
161    /// Creates a FIRST(column) expression.
162    pub fn first(column: usize) -> Self {
163        Self {
164            function: AggregateFunction::First,
165            column: Some(column),
166            column2: None,
167            distinct: false,
168            alias: None,
169            percentile: None,
170        }
171    }
172
173    /// Creates a LAST(column) expression.
174    pub fn last(column: usize) -> Self {
175        Self {
176            function: AggregateFunction::Last,
177            column: Some(column),
178            column2: None,
179            distinct: false,
180            alias: None,
181            percentile: None,
182        }
183    }
184
185    /// Creates a COLLECT(column) expression.
186    pub fn collect(column: usize) -> Self {
187        Self {
188            function: AggregateFunction::Collect,
189            column: Some(column),
190            column2: None,
191            distinct: false,
192            alias: None,
193            percentile: None,
194        }
195    }
196
197    /// Creates a STDEV(column) expression (sample standard deviation).
198    pub fn stdev(column: usize) -> Self {
199        Self {
200            function: AggregateFunction::StdDev,
201            column: Some(column),
202            column2: None,
203            distinct: false,
204            alias: None,
205            percentile: None,
206        }
207    }
208
209    /// Creates a STDEVP(column) expression (population standard deviation).
210    pub fn stdev_pop(column: usize) -> Self {
211        Self {
212            function: AggregateFunction::StdDevPop,
213            column: Some(column),
214            column2: None,
215            distinct: false,
216            alias: None,
217            percentile: None,
218        }
219    }
220
221    /// Creates a PERCENTILE_DISC(column, percentile) expression.
222    ///
223    /// # Arguments
224    /// * `column` - Column index to aggregate
225    /// * `percentile` - Percentile value between 0.0 and 1.0 (e.g., 0.5 for median)
226    pub fn percentile_disc(column: usize, percentile: f64) -> Self {
227        Self {
228            function: AggregateFunction::PercentileDisc,
229            column: Some(column),
230            column2: None,
231            distinct: false,
232            alias: None,
233            percentile: Some(percentile.clamp(0.0, 1.0)),
234        }
235    }
236
237    /// Creates a PERCENTILE_CONT(column, percentile) expression.
238    ///
239    /// # Arguments
240    /// * `column` - Column index to aggregate
241    /// * `percentile` - Percentile value between 0.0 and 1.0 (e.g., 0.5 for median)
242    pub fn percentile_cont(column: usize, percentile: f64) -> Self {
243        Self {
244            function: AggregateFunction::PercentileCont,
245            column: Some(column),
246            column2: None,
247            distinct: false,
248            alias: None,
249            percentile: Some(percentile.clamp(0.0, 1.0)),
250        }
251    }
252
253    /// Sets the distinct flag.
254    pub fn with_distinct(mut self) -> Self {
255        self.distinct = true;
256        self
257    }
258
259    /// Sets the output alias.
260    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
261        self.alias = Some(alias.into());
262        self
263    }
264}
265
/// Hashable stand-in for a [`Value`], used for DISTINCT tracking.
///
/// Every variant holds only types that implement `Eq` and `Hash`, which is
/// what lets the enum derive both.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum HashableValue {
    /// The null value.
    Null,
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A float carried as a `u64` bit pattern, so hashing is deterministic.
    Float64Bits(u64),
    /// A string.
    String(String),
    /// Any other value type, keyed by its `Debug` representation.
    Other(String),
}
282
283impl From<&Value> for HashableValue {
284    fn from(v: &Value) -> Self {
285        match v {
286            Value::Null => HashableValue::Null,
287            Value::Bool(b) => HashableValue::Bool(*b),
288            Value::Int64(i) => HashableValue::Int64(*i),
289            Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
290            Value::String(s) => HashableValue::String(s.to_string()),
291            other => HashableValue::Other(format!("{other:?}")),
292        }
293    }
294}
295
296impl From<Value> for HashableValue {
297    fn from(v: Value) -> Self {
298        Self::from(&v)
299    }
300}