Skip to main content

cjc_data/
agg_kernels.rs

//! Specialized aggregate kernels for grouped operations.
//!
//! These functions operate on gathered data slices with segment boundaries,
//! or on arbitrary row-index groups. f64 sums and means use Kahan summation,
//! and f64 variance and standard deviation use Welford's online algorithm,
//! for deterministic, numerically stable results.
6
7use cjc_repro::kahan::KahanAccumulatorF64;
8use cjc_repro::kahan_sum_f64;
9use std::collections::BTreeSet;
10
11// ── Segment-based kernels ────────────────────────────────────────────────────
12// Segments are (start, end) ranges into a contiguous data slice.
13
14/// Kahan-stable sum over contiguous segments.
15pub fn agg_sum_f64(data: &[f64], segments: &[(usize, usize)]) -> Vec<f64> {
16    segments
17        .iter()
18        .map(|&(start, end)| kahan_sum_f64(&data[start..end]))
19        .collect()
20}
21
22/// Kahan-stable mean over contiguous segments.
23pub fn agg_mean_f64(data: &[f64], segments: &[(usize, usize)]) -> Vec<f64> {
24    segments
25        .iter()
26        .map(|&(start, end)| {
27            let n = end - start;
28            if n == 0 {
29                return f64::NAN;
30            }
31            kahan_sum_f64(&data[start..end]) / n as f64
32        })
33        .collect()
34}
35
/// Number of rows in each segment, i.e. `end - start` for every
/// `(start, end)` pair, in segment order.
pub fn agg_count(segments: &[(usize, usize)]) -> Vec<i64> {
    let mut counts = Vec::with_capacity(segments.len());
    for seg in segments {
        counts.push((seg.1 - seg.0) as i64);
    }
    counts
}
43
/// Smallest f64 in each contiguous segment.
///
/// An empty segment (`start == end`) yields `NAN`. A segment containing a NaN
/// also yields NaN: once the running value is NaN no later comparison can
/// replace it except another NaN.
pub fn agg_min_f64(data: &[f64], segments: &[(usize, usize)]) -> Vec<f64> {
    let mut mins = Vec::with_capacity(segments.len());
    for &(lo, hi) in segments {
        if lo == hi {
            mins.push(f64::NAN);
            continue;
        }
        let mut best = f64::INFINITY;
        for &x in &data[lo..hi] {
            if x.is_nan() || x < best {
                best = x;
            }
        }
        mins.push(best);
    }
    mins
}
59
/// Largest f64 in each contiguous segment.
///
/// An empty segment (`start == end`) yields `NAN`. A segment containing a NaN
/// also yields NaN: once the running value is NaN no later comparison can
/// replace it except another NaN.
pub fn agg_max_f64(data: &[f64], segments: &[(usize, usize)]) -> Vec<f64> {
    let mut maxes = Vec::with_capacity(segments.len());
    for &(lo, hi) in segments {
        if lo == hi {
            maxes.push(f64::NAN);
            continue;
        }
        let mut best = f64::NEG_INFINITY;
        for &x in &data[lo..hi] {
            if x.is_nan() || x > best {
                best = x;
            }
        }
        maxes.push(best);
    }
    maxes
}
75
76/// Variance via Welford's online algorithm (numerically stable).
77/// Returns population variance (divide by N, not N-1).
78pub fn agg_var_f64(data: &[f64], segments: &[(usize, usize)]) -> Vec<f64> {
79    segments
80        .iter()
81        .map(|&(start, end)| welford_variance(&data[start..end]))
82        .collect()
83}
84
85/// Standard deviation via Welford's algorithm.
86pub fn agg_sd_f64(data: &[f64], segments: &[(usize, usize)]) -> Vec<f64> {
87    segments
88        .iter()
89        .map(|&(start, end)| {
90            let var = welford_variance(&data[start..end]);
91            if var.is_nan() { f64::NAN } else { var.sqrt() }
92        })
93        .collect()
94}
95
/// Median of each contiguous segment, computed by sorting a copy of the segment.
///
/// Every `(start, end)` pair selects the half-open slice `data[start..end]`.
/// For an even number of values the two middle values are averaged.
///
/// Returns `NAN` for an empty segment and for any segment that contains a NaN.
/// Propagating NaN matches `agg_min_f64` / `agg_max_f64` and keeps the result
/// independent of where the NaN sits in the input.
///
/// # Panics
/// Panics if a segment does not index a valid sub-slice of `data`.
pub fn agg_median_f64(data: &[f64], segments: &[(usize, usize)]) -> Vec<f64> {
    segments
        .iter()
        .map(|&(start, end)| {
            let n = end - start;
            if n == 0 {
                return f64::NAN;
            }
            let seg = &data[start..end];
            if seg.iter().any(|v| v.is_nan()) {
                return f64::NAN;
            }
            let mut buf = seg.to_vec();
            // `total_cmp` is a total order, which the sort requires; a
            // `partial_cmp`-based comparator is not one when NaN is possible.
            buf.sort_unstable_by(|a, b| a.total_cmp(b));
            if n % 2 == 1 {
                buf[n / 2]
            } else {
                (buf[n / 2 - 1] + buf[n / 2]) / 2.0
            }
        })
        .collect()
}
115
/// Quantile `p` of each contiguous segment via sort + linear interpolation.
///
/// `p` is clamped to `[0.0, 1.0]`. For a segment of `n` values the quantile
/// sits at fractional rank `p * (n - 1)` of the sorted values; when that rank
/// is not an integer the two neighbouring values are linearly interpolated.
///
/// Returns `NAN` for an empty segment, for any segment that contains a NaN,
/// and for every non-empty segment when `p` itself is NaN.
///
/// # Panics
/// Panics if a segment does not index a valid sub-slice of `data`.
pub fn agg_quantile_f64(data: &[f64], p: f64, segments: &[(usize, usize)]) -> Vec<f64> {
    // Loop-invariant, so clamp once. `clamp` passes a NaN `p` through unchanged.
    let p = p.clamp(0.0, 1.0);
    segments
        .iter()
        .map(|&(start, end)| {
            let n = end - start;
            if n == 0 {
                return f64::NAN;
            }
            let seg = &data[start..end];
            // A NaN rank would otherwise be cast to index 0 and silently
            // return the minimum; a NaN value would break the sort order.
            if p.is_nan() || seg.iter().any(|v| v.is_nan()) {
                return f64::NAN;
            }
            let mut buf = seg.to_vec();
            // `total_cmp` is a total order, which the sort requires.
            buf.sort_unstable_by(|a, b| a.total_cmp(b));
            let idx = p * (n - 1) as f64;
            let lo = idx.floor() as usize;
            let hi = idx.ceil() as usize;
            if lo == hi {
                buf[lo]
            } else {
                let frac = idx - lo as f64;
                buf[lo] * (1.0 - frac) + buf[hi] * frac
            }
        })
        .collect()
}
140
/// Number of distinct strings in each contiguous segment.
///
/// Every `(start, end)` pair selects `data[start..end]`; its values are
/// collected into a `BTreeSet` and the set size is reported.
pub fn agg_n_distinct_str(data: &[String], segments: &[(usize, usize)]) -> Vec<i64> {
    let mut counts = Vec::with_capacity(segments.len());
    for &(lo, hi) in segments {
        let mut seen: BTreeSet<&String> = BTreeSet::new();
        for value in &data[lo..hi] {
            seen.insert(value);
        }
        counts.push(seen.len() as i64);
    }
    counts
}
151
/// Number of distinct i64 values in each contiguous segment.
///
/// Every `(start, end)` pair selects `data[start..end]`; its values are
/// collected into a `BTreeSet` and the set size is reported.
pub fn agg_n_distinct_i64(data: &[i64], segments: &[(usize, usize)]) -> Vec<i64> {
    let mut counts = Vec::with_capacity(segments.len());
    for &(lo, hi) in segments {
        let mut seen: BTreeSet<&i64> = BTreeSet::new();
        for value in &data[lo..hi] {
            seen.insert(value);
        }
        counts.push(seen.len() as i64);
    }
    counts
}
162
/// First value (`data[start]`) of each contiguous segment.
///
/// An empty segment (`start == end`) yields `NAN`.
pub fn agg_first_f64(data: &[f64], segments: &[(usize, usize)]) -> Vec<f64> {
    let mut firsts = Vec::with_capacity(segments.len());
    for &(lo, hi) in segments {
        if lo == hi {
            firsts.push(f64::NAN);
        } else {
            firsts.push(data[lo]);
        }
    }
    firsts
}
172
/// Last value (`data[end - 1]`) of each contiguous segment.
///
/// An empty segment (`start == end`) yields `NAN`.
pub fn agg_last_f64(data: &[f64], segments: &[(usize, usize)]) -> Vec<f64> {
    let mut lasts = Vec::with_capacity(segments.len());
    for &(lo, hi) in segments {
        if lo == hi {
            lasts.push(f64::NAN);
        } else {
            lasts.push(data[hi - 1]);
        }
    }
    lasts
}
182
/// Sum of the i64 values in each contiguous segment.
///
/// Addition wraps on overflow (`wrapping_add`), so the result is the
/// two's-complement sum. An empty segment sums to 0.
pub fn agg_sum_i64(data: &[i64], segments: &[(usize, usize)]) -> Vec<i64> {
    let mut sums = Vec::with_capacity(segments.len());
    for &(lo, hi) in segments {
        let mut total = 0i64;
        for &x in &data[lo..hi] {
            total = total.wrapping_add(x);
        }
        sums.push(total);
    }
    sums
}
194
/// Smallest i64 in each contiguous segment.
///
/// An empty segment (`start == end`) yields the sentinel `i64::MAX`.
pub fn agg_min_i64(data: &[i64], segments: &[(usize, usize)]) -> Vec<i64> {
    let mut mins = Vec::with_capacity(segments.len());
    for &(lo, hi) in segments {
        if lo == hi {
            mins.push(i64::MAX);
            continue;
        }
        // Non-empty here, so the first element is a valid starting minimum.
        let seg = &data[lo..hi];
        let mut best = seg[0];
        for &x in &seg[1..] {
            if x < best {
                best = x;
            }
        }
        mins.push(best);
    }
    mins
}
208
/// Largest i64 in each contiguous segment.
///
/// An empty segment (`start == end`) yields the sentinel `i64::MIN`.
pub fn agg_max_i64(data: &[i64], segments: &[(usize, usize)]) -> Vec<i64> {
    let mut maxes = Vec::with_capacity(segments.len());
    for &(lo, hi) in segments {
        if lo == hi {
            maxes.push(i64::MIN);
            continue;
        }
        // Non-empty here, so the first element is a valid starting maximum.
        let seg = &data[lo..hi];
        let mut best = seg[0];
        for &x in &seg[1..] {
            if x > best {
                best = x;
            }
        }
        maxes.push(best);
    }
    maxes
}
222
223// ── Gather-based kernels ─────────────────────────────────────────────────────
224// These work with arbitrary (non-contiguous) row indices, as produced by
225// GroupIndex. They gather values first, then aggregate.
226
227/// Kahan-stable sum over gathered f64 rows.
228pub fn gather_agg_sum_f64(data: &[f64], groups: &[Vec<usize>]) -> Vec<f64> {
229    groups
230        .iter()
231        .map(|indices| {
232            let mut acc = KahanAccumulatorF64::new();
233            for &i in indices {
234                acc.add(data[i]);
235            }
236            acc.finalize()
237        })
238        .collect()
239}
240
241/// Kahan-stable mean over gathered f64 rows.
242pub fn gather_agg_mean_f64(data: &[f64], groups: &[Vec<usize>]) -> Vec<f64> {
243    groups
244        .iter()
245        .map(|indices| {
246            if indices.is_empty() {
247                return f64::NAN;
248            }
249            let mut acc = KahanAccumulatorF64::new();
250            for &i in indices {
251                acc.add(data[i]);
252            }
253            acc.finalize() / indices.len() as f64
254        })
255        .collect()
256}
257
258/// Welford variance over gathered f64 rows.
259pub fn gather_agg_var_f64(data: &[f64], groups: &[Vec<usize>]) -> Vec<f64> {
260    groups
261        .iter()
262        .map(|indices| {
263            let gathered: Vec<f64> = indices.iter().map(|&i| data[i]).collect();
264            welford_variance(&gathered)
265        })
266        .collect()
267}
268
/// Number of distinct strings among the rows of each group.
///
/// Each group's row indices are looked up in `data`, the referenced strings
/// are collected into a `BTreeSet`, and the set size is reported.
pub fn gather_agg_n_distinct_str(data: &[String], groups: &[Vec<usize>]) -> Vec<i64> {
    let mut counts = Vec::with_capacity(groups.len());
    for rows in groups {
        let mut seen: BTreeSet<&String> = BTreeSet::new();
        for &row in rows {
            seen.insert(&data[row]);
        }
        counts.push(seen.len() as i64);
    }
    counts
}
279
280// ── Internal helpers ─────────────────────────────────────────────────────────
281
/// Population variance (divide by N) via Welford's online update.
///
/// Returns `NAN` for an empty slice and `0.0` for a single-element slice.
/// Otherwise the running mean and the running sum of squared deviations are
/// updated once per value, in slice order.
fn welford_variance(values: &[f64]) -> f64 {
    match values.len() {
        0 => f64::NAN,
        1 => 0.0,
        n => {
            let (_, sum_sq) = values.iter().zip(1usize..).fold(
                (0.0f64, 0.0f64),
                |(mean, sum_sq), (&x, k)| {
                    // Deviation from the mean before and after folding in `x`.
                    let before = x - mean;
                    let mean = mean + before / k as f64;
                    let after = x - mean;
                    (mean, sum_sq + before * after)
                },
            );
            sum_sq / n as f64
        }
    }
}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// The population variance of [2, 4, 4, 4, 5, 5, 7, 9] is 4.0.
    #[test]
    fn test_welford_known() {
        let sample = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let var = welford_variance(&sample);
        let err = (var - 4.0).abs();
        assert!(err < 1e-12, "got {}", var);
    }

    /// An empty slice has no variance and is reported as NaN.
    #[test]
    fn test_welford_empty() {
        let var = welford_variance(&[]);
        assert!(var.is_nan());
    }

    /// A single value has zero variance.
    #[test]
    fn test_welford_single() {
        let var = welford_variance(&[42.0]);
        assert_eq!(var, 0.0);
    }
}