Skip to main content

bench_rolling_var/
bench_rolling_var.rs

//! Bench + golden digest for the `apply_rolling`-backed rolling aggregations
//! (`var`, `std`, `skew`, `kurt`, `sem`, `prod`) on all-valid Float64 input.
//!
//! Run: cargo run -p fp-frame --example bench_rolling_var --release
//!
//! `apply_rolling` previously rebuilt a fresh `Vec<f64>` for every window via
//! `window_values` (a `filter_map` over `Scalar::to_f64`). For an all-valid
//! Float64 column the window IS a contiguous slice of the typed buffer, so the
//! per-window allocation and per-element Scalar dispatch are pure overhead. The
//! typed fast path feeds `&data[start..end]` straight to the aggregation
//! closure. var/std are two-pass per window (mean, then sum of squared
//! deviations), so the win compounds with window size.
//!
//! The golden battery pins the exact f64 bits across windows, min_periods, and
//! centered/trailing variants for every aggregation listed above. Rolling
//! sum/mean are served by a separate online sweep and are excluded. The
//! randomized cross-check compares the API's var/std against an independent
//! verbatim per-window reference, bit-for-bit.
17
18use std::time::Instant;
19
20use fp_frame::Series;
21use fp_index::IndexLabel;
22use fp_types::Scalar;
23
24fn s_from(vals: &[f64]) -> Series {
25    let idx: Vec<IndexLabel> = (0..vals.len() as i64).map(IndexLabel::Int64).collect();
26    let sc: Vec<Scalar> = vals.iter().map(|&v| Scalar::Float64(v)).collect();
27    Series::from_values("s", idx, sc).unwrap()
28}
29
/// Rolling aggregations exercised by this bench.
///
/// Only the aggregations whose back end is `apply_rolling` (the function this
/// lever touches) are listed. Rolling sum/mean route to slt1p's O(n) online
/// sweep and are deliberately NOT bit-identical to the naive fold, so they are
/// excluded from the golden.
#[derive(Copy, Clone)]
enum Agg {
    /// `.var()` on the rolling view.
    Var,
    /// `.std()` on the rolling view.
    Std,
    /// `.skew()` on the rolling view.
    Skew,
    /// `.kurt()` on the rolling view.
    Kurt,
    /// `.sem()` on the rolling view.
    Sem,
    /// `.prod()` on the rolling view.
    Prod,
}
43
44fn agg_name(a: Agg) -> &'static str {
45    match a {
46        Agg::Var => "var",
47        Agg::Std => "std",
48        Agg::Skew => "skew",
49        Agg::Kurt => "kurt",
50        Agg::Sem => "sem",
51        Agg::Prod => "prod",
52    }
53}
54
55const ALL_AGGS: [Agg; 6] = [
56    Agg::Var,
57    Agg::Std,
58    Agg::Skew,
59    Agg::Kurt,
60    Agg::Sem,
61    Agg::Prod,
62];
63
64fn run_api(s: &Series, window: usize, min_periods: Option<usize>, center: bool, a: Agg) -> Series {
65    let r = if center {
66        s.rolling_center(window, min_periods)
67    } else {
68        s.rolling(window, min_periods)
69    };
70    match a {
71        Agg::Var => r.var(),
72        Agg::Std => r.std(),
73        Agg::Skew => r.skew(),
74        Agg::Kurt => r.kurt(),
75        Agg::Sem => r.sem(),
76        Agg::Prod => r.prod(),
77    }
78    .unwrap()
79}
80
81/// Independent verbatim reference for the all-valid window aggregation. Mirrors
82/// the original generic `apply_rolling` semantics exactly (collect window, gate
83/// on min_periods by length, apply the same closures in the same fold order).
84fn ref_values(vals: &[f64], window: usize, min_periods: usize, center: bool, a: Agg) -> Vec<f64> {
85    let len = vals.len();
86    let mut out = Vec::with_capacity(len);
87    for i in 0..len {
88        let (start, end) = if center {
89            let half = window / 2;
90            let start = i.saturating_sub(half);
91            let end = (i + half + window % 2).min(len);
92            (start, end)
93        } else {
94            ((i + 1).saturating_sub(window), i + 1)
95        };
96        let nums = &vals[start..end];
97        if nums.len() < min_periods {
98            out.push(f64::NAN); // sentinel "null"; compared as null below
99            continue;
100        }
101        // Verbatim reference only for var/std (formulas the bench owns); other
102        // aggregations are proven by the before==after FNV comparison instead.
103        let v = match a {
104            Agg::Var => {
105                if nums.len() < 2 {
106                    f64::NAN
107                } else {
108                    let mean = nums.iter().sum::<f64>() / nums.len() as f64;
109                    nums.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (nums.len() - 1) as f64
110                }
111            }
112            Agg::Std => {
113                if nums.len() < 2 {
114                    f64::NAN
115                } else {
116                    let mean = nums.iter().sum::<f64>() / nums.len() as f64;
117                    (nums.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
118                        / (nums.len() - 1) as f64)
119                        .sqrt()
120                }
121            }
122            _ => unreachable!("ref_values only covers var/std"),
123        };
124        out.push(v);
125    }
126    out
127}
128
129fn fmt_series(s: &Series) -> String {
130    let mut out = String::new();
131    for v in s.values() {
132        match v {
133            Scalar::Float64(f) => {
134                if f.is_nan() {
135                    out.push_str("nan ");
136                } else {
137                    out.push_str(&format!("{:016x} ", f.to_bits()));
138                }
139            }
140            Scalar::Null(_) => out.push_str("null "),
141            other => out.push_str(&format!("?{other:?} ")),
142        }
143    }
144    out
145}
146
147fn golden() -> String {
148    let mut out = String::new();
149    let data: &[&[f64]] = &[
150        &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
151        &[3.5, -1.0, 2.25, 0.0, 9.9, -4.4, 8.1, 2.2, 7.7, -3.3],
152        &[1e8, 1.0, 1e-8, -1e8, 42.0],
153    ];
154    for (di, d) in data.iter().enumerate() {
155        let s = s_from(d);
156        for &window in &[2usize, 3, 5] {
157            for &mp in &[None, Some(1usize), Some(window)] {
158                for &center in &[false, true] {
159                    for a in ALL_AGGS {
160                        let res = run_api(&s, window, mp, center, a);
161                        out.push_str(&format!(
162                            "d{di} w{window} mp{mp:?} c{center} {}: {}\n",
163                            agg_name(a),
164                            fmt_series(&res)
165                        ));
166                    }
167                }
168            }
169        }
170    }
171    out
172}
173
174fn cross_check() -> (usize, usize) {
175    let mut state: u64 = 0x9E3779B97F4A7C15;
176    let mut next = || {
177        state = state
178            .wrapping_mul(6364136223846793005)
179            .wrapping_add(1442695040888963407);
180        state
181    };
182    let (mut ok, mut bad) = (0usize, 0usize);
183    for _ in 0..3000 {
184        let n = (next() % 60) as usize + 1;
185        let xs: Vec<f64> = (0..n)
186            .map(|_| (next() % 20000) as f64 / 100.0 - 100.0)
187            .collect();
188        let window = (next() % 8) as usize + 1;
189        let mp_pick = next() % 3;
190        let mp = match mp_pick {
191            0 => None,
192            1 => Some(1usize),
193            _ => Some(window),
194        };
195        let center = next() % 2 == 0;
196        let a = if next() % 2 == 0 { Agg::Var } else { Agg::Std };
197        let s = s_from(&xs);
198        let got = run_api(&s, window, mp, center, a);
199        let mp_eff = mp.unwrap_or(window);
200        let want = ref_values(&xs, window, mp_eff, center, a);
201        let got_vals = got.values();
202        let mut row_ok = got_vals.len() == want.len();
203        if row_ok {
204            for (g, w) in got_vals.iter().zip(want.iter()) {
205                let matches = match g {
206                    Scalar::Null(_) => w.is_nan() && mp_eff > 0, // min_periods-null sentinel
207                    Scalar::Float64(f) => {
208                        f.to_bits() == w.to_bits() || (f.is_nan() && w.is_nan())
209                    }
210                    _ => false,
211                };
212                if !matches {
213                    row_ok = false;
214                    break;
215                }
216            }
217        }
218        if row_ok {
219            ok += 1;
220        } else {
221            bad += 1;
222            if bad <= 3 {
223                eprintln!("MISMATCH n={n} w={window} mp={mp:?} c={center} agg={}", agg_name(a));
224            }
225        }
226    }
227    (ok, bad)
228}
229
/// 64-bit FNV-1a hash of the UTF-8 bytes of `s`.
///
/// Used as a compact parity fingerprint of the whole golden battery. A single
/// short output line survives rch's tail-truncated output even when a verbose
/// per-row dump would scroll off.
fn fnv1a64(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;
    s.bytes()
        .fold(OFFSET_BASIS, |hash, byte| (hash ^ u64::from(byte)).wrapping_mul(PRIME))
}
241
242fn main() {
243    let g = golden();
244    println!("GOLDEN_FNV1A64 {:016x} len={}", fnv1a64(&g), g.len());
245
246    let (ok, bad) = cross_check();
247    println!("CROSSCHECK ok={ok} bad={bad}");
248
249    // Large all-valid Float64 series; moderate window so the per-window
250    // two-pass var/std dominates. Deterministic data.
251    let n: usize = 200_000;
252    let window: usize = 250;
253    let xs: Vec<f64> = (0..n)
254        .map(|i| ((i as f64) * 0.31).sin() * 100.0 + ((i % 997) as f64))
255        .collect();
256    let s = s_from(&xs);
257
258    for a in [Agg::Var, Agg::Std, Agg::Skew, Agg::Sem] {
259        let t = Instant::now();
260        let res = run_api(&s, window, None, false, a);
261        let d = t.elapsed();
262        // touch output so it is not optimized away
263        let last = res.values().last().cloned();
264        println!(
265            "TIMING n={n} window={window} {}={:.3}ms last={:?}",
266            agg_name(a),
267            d.as_secs_f64() * 1e3,
268            last
269        );
270    }
271}