// cjc_runtime/timeseries.rs

1//! Time series analysis functions.
2//!
3//! Provides autocorrelation (ACF), partial autocorrelation (PACF),
4//! exponential smoothing (EWMA, EMA), seasonal decomposition (additive
5//! and multiplicative), differencing, and AR model fitting/forecasting.
6//!
7//! # Determinism
8//!
9//! All floating-point reductions use [`BinnedAccumulatorF64`] for
10//! order-insensitive, bit-identical results across runs and platforms.
11//! The Yule-Walker AR fit delegates to [`Tensor::solve`](crate::tensor::Tensor::solve)
12//! which uses deterministic LU decomposition.
13
14use crate::accumulator::BinnedAccumulatorF64;
15
16// ---------------------------------------------------------------------------
17// Helper: deterministic sum of a slice
18// ---------------------------------------------------------------------------
19
20/// Compute the sum of `data` using [`BinnedAccumulatorF64`] for determinism.
21fn binned_sum(data: &[f64]) -> f64 {
22    let mut acc = BinnedAccumulatorF64::new();
23    acc.add_slice(data);
24    acc.finalize()
25}
26
27// ---------------------------------------------------------------------------
28// Helper: deterministic mean
29// ---------------------------------------------------------------------------
30
31/// Compute the arithmetic mean of `data` using [`BinnedAccumulatorF64`].
32///
33/// Returns `0.0` for empty slices.
34fn binned_mean(data: &[f64]) -> f64 {
35    if data.is_empty() {
36        return 0.0;
37    }
38    binned_sum(data) / data.len() as f64
39}
40
41// ---------------------------------------------------------------------------
42// ACF — Autocorrelation function
43// ---------------------------------------------------------------------------
44
45/// Compute the autocorrelation function for lags 0..=max_lag.
46///
47/// Returns `Vec<f64>` of length `max_lag + 1` where `result[0] = 1.0`.
48/// Uses the standard formula: ACF(k) = gamma(k) / gamma(0), where gamma(k)
49/// is the autocovariance at lag k computed from the demeaned series.
50pub fn acf(data: &[f64], max_lag: usize) -> Vec<f64> {
51    let n = data.len();
52    if n == 0 {
53        return vec![f64::NAN; max_lag + 1];
54    }
55
56    let mean = binned_mean(data);
57
58    // Demeaned series
59    let centered: Vec<f64> = data.iter().map(|&x| x - mean).collect();
60
61    // Lag-0 autocovariance (variance)
62    let sq: Vec<f64> = centered.iter().map(|&x| x * x).collect();
63    let gamma0 = binned_sum(&sq);
64
65    if gamma0 == 0.0 {
66        // Constant series: lag-0 = 1.0, all others = 0.0
67        let mut result = vec![0.0; max_lag + 1];
68        result[0] = 1.0;
69        return result;
70    }
71
72    let mut result = Vec::with_capacity(max_lag + 1);
73    for k in 0..=max_lag {
74        if k >= n {
75            result.push(f64::NAN);
76            continue;
77        }
78        let prods: Vec<f64> = (0..n - k).map(|t| centered[t] * centered[t + k]).collect();
79        let gamma_k = binned_sum(&prods);
80        result.push(gamma_k / gamma0);
81    }
82
83    result
84}
85
86// ---------------------------------------------------------------------------
87// PACF — Partial autocorrelation function (Durbin-Levinson)
88// ---------------------------------------------------------------------------
89
90/// Compute the partial autocorrelation function via the Durbin-Levinson algorithm.
91///
92/// Returns `Vec<f64>` of length `max_lag + 1` where `result[0] = 1.0`.
93pub fn pacf(data: &[f64], max_lag: usize) -> Vec<f64> {
94    let n = data.len();
95    if n == 0 || max_lag == 0 {
96        return vec![1.0];
97    }
98
99    // First compute the ACF values we need
100    let r = acf(data, max_lag);
101
102    let mut result = vec![0.0; max_lag + 1];
103    result[0] = 1.0;
104
105    if max_lag >= n {
106        // Can't compute beyond data length
107        for i in n..=max_lag {
108            result[i] = f64::NAN;
109        }
110    }
111
112    // Durbin-Levinson recursion
113    // phi[m][j] = AR coefficient at order m, index j
114    // We only need two rows: current and previous.
115    let effective_max = max_lag.min(n - 1);
116
117    let mut phi_prev = vec![0.0; effective_max + 1];
118    // Order 1
119    phi_prev[1] = r[1];
120    result[1] = r[1];
121
122    for m in 2..=effective_max {
123        // Compute phi[m][m] using Durbin-Levinson:
124        // phi[m][m] = (r[m] - sum_{j=1}^{m-1} phi[m-1][j] * r[m-j]) / (1 - sum_{j=1}^{m-1} phi[m-1][j] * r[j])
125        let num_terms: Vec<f64> = (1..m).map(|j| phi_prev[j] * r[m - j]).collect();
126        let den_terms: Vec<f64> = (1..m).map(|j| phi_prev[j] * r[j]).collect();
127
128        let num = r[m] - binned_sum(&num_terms);
129        let den = 1.0 - binned_sum(&den_terms);
130
131        if den.abs() < 1e-15 {
132            result[m] = f64::NAN;
133            break;
134        }
135
136        let phi_mm = num / den;
137        result[m] = phi_mm;
138
139        // Update phi coefficients for next iteration
140        let mut phi_new = vec![0.0; effective_max + 1];
141        for j in 1..m {
142            phi_new[j] = phi_prev[j] - phi_mm * phi_prev[m - j];
143        }
144        phi_new[m] = phi_mm;
145        phi_prev = phi_new;
146    }
147
148    result
149}
150
151// ---------------------------------------------------------------------------
152// EWMA — Exponential weighted moving average
153// ---------------------------------------------------------------------------
154
/// Exponentially weighted moving average with smoothing factor `alpha`.
///
/// `alpha` is the smoothing factor (0 < alpha <= 1). The output has the
/// same length as `data`: the first value is `data[0]`, and each later
/// value is `alpha * data[i] + (1 - alpha) * out[i - 1]`. An empty input
/// produces an empty output.
pub fn ewma(data: &[f64], alpha: f64) -> Vec<f64> {
    let mut smoothed = Vec::with_capacity(data.len());
    if let Some((&first, rest)) = data.split_first() {
        let mut level = first;
        smoothed.push(level);
        for &x in rest {
            level = alpha * x + (1.0 - alpha) * level;
            smoothed.push(level);
        }
    }
    smoothed
}
175
176// ---------------------------------------------------------------------------
177// EMA — Exponential moving average (span-based)
178// ---------------------------------------------------------------------------
179
180/// Compute the exponential moving average with span-based smoothing.
181///
182/// `alpha = 2 / (span + 1)`.
183/// Returns `Vec<f64>` of the same length as `data`.
184pub fn ema(data: &[f64], span: usize) -> Vec<f64> {
185    let alpha = 2.0 / (span as f64 + 1.0);
186    ewma(data, alpha)
187}
188
189// ---------------------------------------------------------------------------
190// Seasonal decomposition
191// ---------------------------------------------------------------------------
192
/// Decompose a time series into trend, seasonal, and residual components.
///
/// `period`: the seasonal period (e.g., 12 for monthly data with yearly seasonality).
/// `model`: `"additive"` or `"multiplicative"`.
///
/// Returns `(trend, seasonal, residual)` each as `Vec<f64>` of the same length as `data`.
/// Boundary values where the centered moving average cannot be computed are set to `f64::NAN`.
///
/// For the additive model, `data[i] ≈ trend[i] + seasonal[i] + residual[i]`;
/// for the multiplicative model, `data[i] ≈ trend[i] * seasonal[i] * residual[i]`
/// (at positions where trend/residual are not NaN).
///
/// # Errors
///
/// - `period < 2`
/// - fewer than `2 * period` observations
/// - `model` is neither `"additive"` nor `"multiplicative"`
/// - multiplicative model with any non-positive data value
pub fn seasonal_decompose(
    data: &[f64],
    period: usize,
    model: &str,
) -> Result<(Vec<f64>, Vec<f64>, Vec<f64>), String> {
    let n = data.len();

    if period < 2 {
        return Err("seasonal_decompose: period must be >= 2".into());
    }
    // Need at least two full cycles so every period position gets at least
    // one valid detrended observation.
    if n < 2 * period {
        return Err(format!(
            "seasonal_decompose: need at least {} observations for period {}, got {}",
            2 * period,
            period,
            n
        ));
    }
    if model != "additive" && model != "multiplicative" {
        return Err(format!(
            "seasonal_decompose: model must be \"additive\" or \"multiplicative\", got \"{}\"",
            model
        ));
    }

    let is_mult = model == "multiplicative";

    // Check for zeros/negatives in multiplicative mode: ratios data/trend
    // and data/(trend*seasonal) are only meaningful for positive data.
    if is_mult {
        for &v in data {
            if v <= 0.0 {
                return Err(
                    "seasonal_decompose: multiplicative model requires all positive data".into(),
                );
            }
        }
    }

    // Step 1: Centered moving average for trend extraction
    // (NaN near the boundaries where the full window does not fit).
    let trend = centered_moving_average(data, period);

    // Step 2: Detrend — subtract (additive) or divide out (multiplicative)
    // the trend; positions with NaN trend stay NaN.
    let mut detrended = vec![f64::NAN; n];
    for i in 0..n {
        if trend[i].is_nan() {
            continue;
        }
        if is_mult {
            if trend[i] != 0.0 {
                detrended[i] = data[i] / trend[i];
            }
        } else {
            detrended[i] = data[i] - trend[i];
        }
    }

    // Step 3: Average detrended values for each period position
    // (skipping NaN boundary values), using the deterministic mean.
    let mut seasonal = vec![0.0; n];
    let mut period_avgs = vec![0.0; period];

    for p in 0..period {
        let mut vals = Vec::new();
        let mut idx = p;
        while idx < n {
            if !detrended[idx].is_nan() {
                vals.push(detrended[idx]);
            }
            idx += period;
        }
        if !vals.is_empty() {
            period_avgs[p] = binned_mean(&vals);
        }
    }

    // Normalize seasonal component so it sums to 0 (additive) or averages to 1 (multiplicative)
    if is_mult {
        let avg = binned_mean(&period_avgs);
        if avg != 0.0 {
            for v in &mut period_avgs {
                *v /= avg;
            }
        }
    } else {
        let avg = binned_mean(&period_avgs);
        for v in &mut period_avgs {
            *v -= avg;
        }
    }

    // Tile the seasonal pattern across the full series length.
    for i in 0..n {
        seasonal[i] = period_avgs[i % period];
    }

    // Step 4: Residual — whatever the trend and seasonal components do not
    // explain; NaN wherever the trend is NaN.
    let mut residual = vec![f64::NAN; n];
    for i in 0..n {
        if trend[i].is_nan() {
            continue;
        }
        if is_mult {
            if seasonal[i] != 0.0 {
                residual[i] = data[i] / (trend[i] * seasonal[i]);
            }
        } else {
            residual[i] = data[i] - trend[i] - seasonal[i];
        }
    }

    Ok((trend, seasonal, residual))
}
311
312/// Compute the centered moving average of length `period`.
313///
314/// For odd periods, uses a simple symmetric window of `period` elements.
315/// For even periods, applies a two-pass convolution: first a trailing
316/// `period`-length MA, then averages adjacent pairs to center the result.
317/// Boundary positions where the full window cannot be placed are set to
318/// [`f64::NAN`].
319fn centered_moving_average(data: &[f64], period: usize) -> Vec<f64> {
320    let n = data.len();
321    let mut result = vec![f64::NAN; n];
322
323    if period % 2 == 1 {
324        // Odd period: simple centered MA
325        let half = period / 2;
326        for i in half..n.saturating_sub(half) {
327            let window = &data[i - half..=i + half];
328            result[i] = binned_mean(window);
329        }
330    } else {
331        // Even period: first compute period-length MA, then average adjacent pairs
332        let mut ma = vec![f64::NAN; n];
333        let half = period / 2;
334        // First pass: period-length trailing MA starting at index period-1
335        for i in (period - 1)..n {
336            let window = &data[i + 1 - period..=i];
337            ma[i] = binned_mean(window);
338        }
339        // Second pass: center by averaging adjacent MA values
340        for i in half..n.saturating_sub(half) {
341            let left_idx = i + half - 1; // index in ma for trailing MA ending at i+half-1
342            let right_idx = left_idx + 1;
343            if left_idx < n && right_idx < n && !ma[left_idx].is_nan() && !ma[right_idx].is_nan()
344            {
345                result[i] = (ma[left_idx] + ma[right_idx]) / 2.0;
346            }
347        }
348    }
349
350    result
351}
352
353// ---------------------------------------------------------------------------
354// Differencing
355// ---------------------------------------------------------------------------
356
/// Difference a series: `y[i] = data[i + periods] - data[i]`.
///
/// Returns `Vec<f64>` of length `data.len() - periods`; if
/// `periods >= data.len()` the result is empty.
pub fn diff(data: &[f64], periods: usize) -> Vec<f64> {
    if periods >= data.len() {
        return Vec::new();
    }
    data[periods..]
        .iter()
        .zip(data.iter())
        .map(|(&later, &earlier)| later - earlier)
        .collect()
}
368
369// ---------------------------------------------------------------------------
370// ARIMA primitives
371// ---------------------------------------------------------------------------
372
373/// ARIMA(p,d,q) differencing step.
374///
375/// Applies first-order differencing `d` times to produce a stationary series.
376/// Returns the d-th order differenced series.
377/// After one round: result\[i\] = data\[i+1\] - data\[i\], length n-1.
378/// After d rounds: length n-d.
379pub fn arima_diff(data: &[f64], d: usize) -> Vec<f64> {
380    let mut current = data.to_vec();
381    for _ in 0..d {
382        if current.len() <= 1 {
383            return Vec::new();
384        }
385        current = diff(&current, 1);
386    }
387    current
388}
389
390/// Fit an AR(p) model using the Yule-Walker method.
391///
392/// 1. Compute autocorrelation r\[0..=p\] using `acf`.
393/// 2. Build the p x p Toeplitz matrix R where R\[i,j\] = r\[|i-j|\].
394/// 3. Solve R * phi = r\[1..=p\] using LU decomposition (via `Tensor::solve`).
395///
396/// Returns the AR coefficients phi\[1..p\] as a `Vec<f64>`.
397///
398/// **Determinism:** ACF uses `BinnedAccumulatorF64`; solve uses deterministic LU.
399pub fn ar_fit(data: &[f64], p: usize) -> Result<Vec<f64>, String> {
400    if p == 0 {
401        return Err("ar_fit: p must be > 0".into());
402    }
403    if data.len() <= p {
404        return Err(format!(
405            "ar_fit: need at least {} observations for AR({}), got {}",
406            p + 1,
407            p,
408            data.len()
409        ));
410    }
411
412    let r = acf(data, p);
413
414    // Build Toeplitz matrix R: R[i][j] = r[|i-j|]
415    let mut mat_data = vec![0.0f64; p * p];
416    for i in 0..p {
417        for j in 0..p {
418            let lag = if i >= j { i - j } else { j - i };
419            mat_data[i * p + j] = r[lag];
420        }
421    }
422
423    // RHS: r[1..=p]
424    let rhs: Vec<f64> = (1..=p).map(|k| r[k]).collect();
425
426    use crate::tensor::Tensor;
427    let r_matrix =
428        Tensor::from_vec(mat_data, &[p, p]).map_err(|e| format!("ar_fit: {e}"))?;
429    let r_vec = Tensor::from_vec(rhs, &[p]).map_err(|e| format!("ar_fit: {e}"))?;
430    let phi_tensor = r_matrix.solve(&r_vec).map_err(|e| format!("ar_fit: {e}"))?;
431
432    Ok(phi_tensor.to_vec())
433}
434
/// AR forecast: given fitted AR coefficients and recent history, predict the
/// next `steps` values.
///
/// `coeffs`: AR coefficients \[phi_1, phi_2, ..., phi_p\] (length p).
/// `history`: recent observations, at least p values.
/// `steps`: number of future values to predict.
///
/// Each prediction is: y_hat = sum(phi_i * y\[t-i\]) for i=1..p.
/// Forecasts are appended to the working buffer, so step k conditions on
/// the k-1 previously predicted values (recursive multi-step forecasting).
/// Uses [`BinnedAccumulatorF64`] for the reduction so results are
/// deterministic and order-insensitive (NOT Kahan summation, despite the
/// similar goal).
pub fn ar_forecast(coeffs: &[f64], history: &[f64], steps: usize) -> Result<Vec<f64>, String> {
    let p = coeffs.len();
    if p == 0 {
        return Err("ar_forecast: need at least one coefficient".into());
    }
    if history.len() < p {
        return Err(format!(
            "ar_forecast: need at least {} history values for AR({}), got {}",
            p,
            p,
            history.len()
        ));
    }

    // Work buffer: copy of the history + space for predictions, so later
    // steps can reference earlier forecasts.
    let mut buf: Vec<f64> = history.to_vec();
    let mut predictions = Vec::with_capacity(steps);

    for _ in 0..steps {
        let n = buf.len();
        // y_hat = sum_i coeffs[i] * buf[n - 1 - i]: coeffs[0] multiplies the
        // most recent value, coeffs[p-1] the oldest of the last p values.
        let mut acc = BinnedAccumulatorF64::new();
        for i in 0..p {
            acc.add(coeffs[i] * buf[n - 1 - i]);
        }
        let val = acc.finalize();
        predictions.push(val);
        buf.push(val);
    }

    Ok(predictions)
}
475
476// ═══════════════════════════════════════════════════════════════
477// Tests
478// ═══════════════════════════════════════════════════════════════
479
// Unit tests. Determinism tests compare f64 bit patterns across repeated
// runs; numeric tests pin exact values only where the FP arithmetic is
// exact in binary (e.g. multiples of 0.5).
#[cfg(test)]
mod tests {
    use super::*;

    // -- ACF tests ----------------------------------------------------------

    #[test]
    fn test_acf_constant_series() {
        // Zero-variance input exercises the gamma0 == 0 special case.
        let data = vec![5.0; 100];
        let result = acf(&data, 5);
        assert_eq!(result.len(), 6);
        assert_eq!(result[0], 1.0);
        for k in 1..=5 {
            assert_eq!(result[k], 0.0, "ACF at lag {} should be 0 for constant series", k);
        }
    }

    #[test]
    fn test_acf_lag_zero_is_one() {
        let data: Vec<f64> = (0..50).map(|i| (i as f64).sin()).collect();
        let result = acf(&data, 10);
        assert!((result[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_acf_sinusoidal_periodicity() {
        // Sine wave with period 20: ACF should show periodic pattern
        let data: Vec<f64> = (0..200)
            .map(|i| (2.0 * std::f64::consts::PI * i as f64 / 20.0).sin())
            .collect();
        let result = acf(&data, 25);
        // ACF at lag 20 should be close to 1.0 (same phase)
        assert!(
            result[20] > 0.9,
            "ACF at lag=period should be high, got {}",
            result[20]
        );
        // ACF at lag 10 should be close to -1.0 (half period)
        assert!(
            result[10] < -0.9,
            "ACF at lag=period/2 should be negative, got {}",
            result[10]
        );
    }

    #[test]
    fn test_acf_empty() {
        // Empty input: documented contract is a NaN-filled vec of max_lag + 1.
        let result = acf(&[], 3);
        assert_eq!(result.len(), 4);
        assert!(result[0].is_nan());
    }

    #[test]
    fn test_acf_determinism() {
        // Bit-identical output on repeated runs (BinnedAccumulatorF64).
        let data: Vec<f64> = (0..100).map(|i| (i as f64) * 0.1 + (i as f64).sin()).collect();
        let r1 = acf(&data, 10);
        let r2 = acf(&data, 10);
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    // -- PACF tests ---------------------------------------------------------

    #[test]
    fn test_pacf_lag_zero_is_one() {
        let data: Vec<f64> = (0..100).map(|i| (i as f64) * 0.3).collect();
        let result = pacf(&data, 5);
        assert!((result[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_pacf_ar1_process() {
        // AR(1) process: x[t] = 0.8 * x[t-1] + noise
        // PACF should have significant value at lag 1, near zero after
        let mut data = vec![0.0; 500];
        let mut rng = cjc_repro::Rng::seeded(42);
        for t in 1..500 {
            data[t] = 0.8 * data[t - 1] + (rng.next_f64() - 0.5) * 0.1;
        }
        let result = pacf(&data, 5);
        assert!(
            result[1].abs() > 0.5,
            "PACF at lag 1 should be significant for AR(1), got {}",
            result[1]
        );
        // Lags 2+ should be much smaller
        for k in 2..=5 {
            assert!(
                result[k].abs() < 0.3,
                "PACF at lag {} should be small for AR(1), got {}",
                k,
                result[k]
            );
        }
    }

    #[test]
    fn test_pacf_determinism() {
        let data: Vec<f64> = (0..100).map(|i| (i as f64).cos()).collect();
        let r1 = pacf(&data, 5);
        let r2 = pacf(&data, 5);
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    // -- EWMA tests ---------------------------------------------------------

    #[test]
    fn test_ewma_alpha_one_returns_original() {
        // alpha=1 gives full weight to the current observation.
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = ewma(&data, 1.0);
        assert_eq!(result, data);
    }

    #[test]
    fn test_ewma_alpha_zero_returns_first() {
        let data = vec![10.0, 20.0, 30.0, 40.0];
        let result = ewma(&data, 0.0);
        // alpha=0 means ewma[i] = ewma[i-1] for all i, so all values = data[0]
        for &v in &result {
            assert_eq!(v, 10.0);
        }
    }

    #[test]
    fn test_ewma_length() {
        let data = vec![1.0, 2.0, 3.0];
        let result = ewma(&data, 0.5);
        assert_eq!(result.len(), data.len());
    }

    #[test]
    fn test_ewma_empty() {
        let result = ewma(&[], 0.5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_ewma_smoothing() {
        // With alpha=0.5: ewma[0]=1, ewma[1]=0.5*3+0.5*1=2, ewma[2]=0.5*5+0.5*2=3.5
        let data = vec![1.0, 3.0, 5.0];
        let result = ewma(&data, 0.5);
        assert_eq!(result[0], 1.0);
        assert!((result[1] - 2.0).abs() < 1e-12);
        assert!((result[2] - 3.5).abs() < 1e-12);
    }

    // -- EMA tests ----------------------------------------------------------

    #[test]
    fn test_ema_span_relationship() {
        // ema(span) must be bit-identical to ewma with alpha = 2/(span+1).
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let span = 3;
        let alpha = 2.0 / (span as f64 + 1.0); // 0.5
        let ema_result = ema(&data, span);
        let ewma_result = ewma(&data, alpha);
        for (a, b) in ema_result.iter().zip(ewma_result.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    // -- diff tests ---------------------------------------------------------

    #[test]
    fn test_diff_first_differences() {
        let data = vec![1.0, 3.0, 6.0, 10.0, 15.0];
        let result = diff(&data, 1);
        assert_eq!(result, vec![2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_diff_periods_two() {
        let data = vec![1.0, 2.0, 4.0, 7.0, 11.0];
        let result = diff(&data, 2);
        // result[0] = data[2] - data[0] = 3, result[1] = data[3] - data[1] = 5, result[2] = data[4] - data[2] = 7
        assert_eq!(result, vec![3.0, 5.0, 7.0]);
    }

    #[test]
    fn test_diff_empty_when_periods_too_large() {
        let data = vec![1.0, 2.0];
        let result = diff(&data, 5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_diff_length() {
        let data = vec![1.0; 10];
        let result = diff(&data, 3);
        assert_eq!(result.len(), 7);
    }

    // -- seasonal_decompose tests -------------------------------------------

    #[test]
    fn test_seasonal_decompose_additive_reconstruction() {
        // Create data with known trend + seasonal + noise
        let period = 4;
        let n = 40;
        let mut data = vec![0.0; n];
        for i in 0..n {
            let trend = 10.0 + 0.5 * i as f64;
            let seasonal = [2.0, -1.0, 0.5, -1.5][i % period];
            data[i] = trend + seasonal;
        }

        let (trend, seasonal, residual) =
            seasonal_decompose(&data, period, "additive").unwrap();

        // For non-NaN positions, trend + seasonal + residual ≈ original
        for i in 0..n {
            if trend[i].is_nan() || residual[i].is_nan() {
                continue;
            }
            let reconstructed = trend[i] + seasonal[i] + residual[i];
            assert!(
                (reconstructed - data[i]).abs() < 1e-10,
                "Reconstruction failed at i={}: {} vs {}",
                i,
                reconstructed,
                data[i]
            );
        }
    }

    #[test]
    fn test_seasonal_decompose_multiplicative_reconstruction() {
        // Multiplicative analogue: trend * seasonal * residual ≈ original.
        let period = 4;
        let n = 40;
        let mut data = vec![0.0; n];
        for i in 0..n {
            let trend = 100.0 + 2.0 * i as f64;
            let seasonal = [1.1, 0.9, 1.05, 0.95][i % period];
            data[i] = trend * seasonal;
        }

        let (trend, seasonal, residual) =
            seasonal_decompose(&data, period, "multiplicative").unwrap();

        for i in 0..n {
            if trend[i].is_nan() || residual[i].is_nan() {
                continue;
            }
            let reconstructed = trend[i] * seasonal[i] * residual[i];
            assert!(
                (reconstructed - data[i]).abs() < 1e-6,
                "Multiplicative reconstruction failed at i={}: {} vs {}",
                i,
                reconstructed,
                data[i]
            );
        }
    }

    #[test]
    fn test_seasonal_decompose_invalid_period() {
        let data = vec![1.0; 20];
        assert!(seasonal_decompose(&data, 1, "additive").is_err());
    }

    #[test]
    fn test_seasonal_decompose_too_short() {
        // 5 observations < 2 * period (8) must be rejected.
        let data = vec![1.0; 5];
        assert!(seasonal_decompose(&data, 4, "additive").is_err());
    }

    #[test]
    fn test_seasonal_decompose_invalid_model() {
        let data = vec![1.0; 20];
        assert!(seasonal_decompose(&data, 4, "invalid").is_err());
    }

    #[test]
    fn test_seasonal_decompose_multiplicative_negative_data() {
        // Multiplicative model requires strictly positive data.
        let data = vec![1.0, -1.0, 2.0, 3.0, 1.0, -1.0, 2.0, 3.0];
        assert!(seasonal_decompose(&data, 4, "multiplicative").is_err());
    }

    #[test]
    fn test_seasonal_decompose_seasonal_component_sums_to_zero() {
        let period = 4;
        let n = 40;
        let mut data = vec![0.0; n];
        for i in 0..n {
            data[i] = 10.0 + 0.5 * i as f64 + [2.0, -1.0, 0.5, -1.5][i % period];
        }

        let (_, seasonal, _) = seasonal_decompose(&data, period, "additive").unwrap();

        // One full period of seasonal component should sum to ~0
        let one_period: Vec<f64> = (0..period).map(|i| seasonal[i]).collect();
        let period_sum = binned_sum(&one_period);
        assert!(
            period_sum.abs() < 1e-10,
            "Seasonal component should sum to ~0 over one period, got {}",
            period_sum
        );
    }

    #[test]
    fn test_seasonal_decompose_determinism() {
        // All three components must be bit-identical across runs.
        let period = 4;
        let n = 40;
        let data: Vec<f64> = (0..n).map(|i| 10.0 + (i as f64).sin()).collect();

        let (t1, s1, r1) = seasonal_decompose(&data, period, "additive").unwrap();
        let (t2, s2, r2) = seasonal_decompose(&data, period, "additive").unwrap();

        for i in 0..n {
            assert_eq!(t1[i].to_bits(), t2[i].to_bits());
            assert_eq!(s1[i].to_bits(), s2[i].to_bits());
            assert_eq!(r1[i].to_bits(), r2[i].to_bits());
        }
    }
}