Skip to main content

fdars_core/
detrend.rs

1//! Detrending and decomposition functions for non-stationary functional data.
2//!
3//! This module provides methods for removing trends from functional data
4//! to enable more accurate seasonal analysis. It includes:
5//! - Linear detrending (least squares)
//! - Polynomial detrending (least squares via SVD)
7//! - Differencing (first and second order)
8//! - LOESS detrending (local polynomial regression)
9//! - Spline detrending (P-splines)
10//! - Automatic method selection via AIC
11
12use crate::iter_maybe_parallel;
13use crate::smoothing::local_polynomial;
14use nalgebra::{DMatrix, DVector, Dyn, SVD};
15#[cfg(feature = "parallel")]
16use rayon::iter::ParallelIterator;
17
/// Output of a detrending routine.
///
/// Matrix-valued fields share the input's column-major `n x m` layout: the value for
/// sample `i` at evaluation point `j` is stored at index `i + j * n`.
#[derive(Clone, Debug)]
pub struct TrendResult {
    /// Trend estimate (column-major `n x m`).
    pub trend: Vec<f64>,
    /// Input with the trend removed (column-major `n x m`).
    pub detrended: Vec<f64>,
    /// Name of the method that produced this result, e.g. `"linear"` or `"polynomial(2)"`.
    pub method: String,
    /// Per-sample coefficients packed sample-major, or `None` when the method has none
    /// (or the input was rejected).
    ///
    /// For a polynomial of degree `d`, `coefficients[i * (d + 1) + k]` is coefficient `k`
    /// of sample `i`. The differencing methods store the first `order` values of each
    /// curve here instead (`coefficients[i * order + k]`).
    pub coefficients: Option<Vec<f64>>,
    /// Residual sum of squares, one entry per sample.
    pub rss: Vec<f64>,
    /// Parameter count used as `k` in AIC comparisons.
    pub n_params: usize,
}
35
/// Output of a seasonal decomposition into trend, seasonal and remainder parts.
///
/// Matrix-valued fields share the input's column-major `n x m` layout.
#[derive(Clone, Debug)]
pub struct DecomposeResult {
    /// Trend part (column-major `n x m`).
    pub trend: Vec<f64>,
    /// Seasonal part (column-major `n x m`).
    pub seasonal: Vec<f64>,
    /// What is left after removing trend and seasonal parts (column-major `n x m`).
    pub remainder: Vec<f64>,
    /// Seasonal period, in the units of the argument values.
    pub period: f64,
    /// Either `"additive"` or `"multiplicative"`.
    pub method: String,
}
50
51/// Remove linear trend from functional data using least squares.
52///
53/// # Arguments
54/// * `data` - Column-major matrix (n x m): n samples, m evaluation points
55/// * `n` - Number of samples
56/// * `m` - Number of evaluation points
57/// * `argvals` - Time/argument values of length m
58///
59/// # Returns
60/// TrendResult with trend, detrended data, and coefficients (intercept, slope)
61pub fn detrend_linear(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> TrendResult {
62    if n == 0 || m < 2 || data.len() != n * m || argvals.len() != m {
63        return TrendResult {
64            trend: vec![0.0; n * m],
65            detrended: data.to_vec(),
66            method: "linear".to_string(),
67            coefficients: None,
68            rss: vec![0.0; n],
69            n_params: 2,
70        };
71    }
72
73    // Precompute t statistics
74    let mean_t: f64 = argvals.iter().sum::<f64>() / m as f64;
75    let ss_t: f64 = argvals.iter().map(|&t| (t - mean_t).powi(2)).sum();
76
77    // Process each sample in parallel
78    let results: Vec<(Vec<f64>, Vec<f64>, f64, f64, f64)> = iter_maybe_parallel!(0..n)
79        .map(|i| {
80            // Extract curve
81            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
82            let mean_y: f64 = curve.iter().sum::<f64>() / m as f64;
83
84            // Compute slope: sum((t - mean_t) * (y - mean_y)) / sum((t - mean_t)^2)
85            let mut sp = 0.0;
86            for j in 0..m {
87                sp += (argvals[j] - mean_t) * (curve[j] - mean_y);
88            }
89            let slope = if ss_t.abs() > 1e-15 { sp / ss_t } else { 0.0 };
90            let intercept = mean_y - slope * mean_t;
91
92            // Compute trend and detrended
93            let mut trend = vec![0.0; m];
94            let mut detrended = vec![0.0; m];
95            let mut rss = 0.0;
96            for j in 0..m {
97                trend[j] = intercept + slope * argvals[j];
98                detrended[j] = curve[j] - trend[j];
99                rss += detrended[j].powi(2);
100            }
101
102            (trend, detrended, intercept, slope, rss)
103        })
104        .collect();
105
106    // Reassemble into column-major format
107    let mut trend = vec![0.0; n * m];
108    let mut detrended = vec![0.0; n * m];
109    let mut coefficients = vec![0.0; n * 2];
110    let mut rss = vec![0.0; n];
111
112    for (i, (t, d, intercept, slope, r)) in results.into_iter().enumerate() {
113        for j in 0..m {
114            trend[i + j * n] = t[j];
115            detrended[i + j * n] = d[j];
116        }
117        coefficients[i * 2] = intercept;
118        coefficients[i * 2 + 1] = slope;
119        rss[i] = r;
120    }
121
122    TrendResult {
123        trend,
124        detrended,
125        method: "linear".to_string(),
126        coefficients: Some(coefficients),
127        rss,
128        n_params: 2,
129    }
130}
131
132/// Build a Vandermonde matrix (m x n_coef) from normalized argument values.
133fn build_vandermonde_matrix(t_norm: &[f64], m: usize, n_coef: usize) -> DMatrix<f64> {
134    let mut design = DMatrix::zeros(m, n_coef);
135    for j in 0..m {
136        let t = t_norm[j];
137        let mut power = 1.0;
138        for k in 0..n_coef {
139            design[(j, k)] = power;
140            power *= t;
141        }
142    }
143    design
144}
145
146/// Fit a polynomial to a single curve using a pre-computed SVD, returning (trend, detrended, coefs, rss).
147fn fit_polynomial_single_curve(
148    curve: &[f64],
149    svd: &SVD<f64, Dyn, Dyn>,
150    design: &DMatrix<f64>,
151    n_coef: usize,
152    m: usize,
153) -> (Vec<f64>, Vec<f64>, Vec<f64>, f64) {
154    let y = DVector::from_row_slice(curve);
155
156    let beta = svd
157        .solve(&y, 1e-10)
158        .unwrap_or_else(|_| DVector::zeros(n_coef));
159
160    let fitted = design * &beta;
161    let mut trend = vec![0.0; m];
162    let mut detrended = vec![0.0; m];
163    let mut rss = 0.0;
164    for j in 0..m {
165        trend[j] = fitted[j];
166        detrended[j] = curve[j] - fitted[j];
167        rss += detrended[j].powi(2);
168    }
169
170    let coefs: Vec<f64> = beta.iter().cloned().collect();
171    (trend, detrended, coefs, rss)
172}
173
/// Difference one curve (`order` 1 or 2) and split it into a baseline and differences.
///
/// Returns `(trend, detrended, initial_values, rss)`:
/// * `detrended` holds the `m - order` differences at indices `0..m - order`, zero-padded
///   to length `m`.
/// * `trend[j] = curve[j] - diffs[j - order]` for `j >= order`, and `trend[j] = curve[j]`
///   for the first `order` points. This is the part of `curve[j]` that the difference
///   ending at `j` does not explain. For order 1 it equals `curve[j - 1]`; for order 2 it
///   is the linear extrapolation `2 * curve[j - 1] - curve[j - 2]`.
/// * `initial_values` are the first `order` points, needed to invert the differencing.
/// * `rss` is the sum of squared differences.
///
/// The caller guarantees `order` is 1 or 2, `m > order` and `curve.len() >= m`.
fn diff_single_curve(curve: &[f64], m: usize, order: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, f64) {
    let curve = &curve[..m];

    let first_diff: Vec<f64> = curve.windows(2).map(|w| w[1] - w[0]).collect();
    let diffs: Vec<f64> = if order == 2 {
        first_diff.windows(2).map(|w| w[1] - w[0]).collect()
    } else {
        first_diff
    };

    let initial_values = if order == 2 {
        vec![curve[0], curve[1]]
    } else {
        vec![curve[0]]
    };

    let rss: f64 = diffs.iter().map(|&x| x.powi(2)).sum();

    // Number of leading points consumed by differencing (equals `order` for 1 and 2).
    // Using the same alignment for both orders keeps `trend[j] + diffs[j - lag] == curve[j]`;
    // previously order 2 returned the raw curve as its trend.
    let lag = m - diffs.len();
    let trend: Vec<f64> = curve
        .iter()
        .enumerate()
        .map(|(j, &y)| if j < lag { y } else { y - diffs[j - lag] })
        .collect();

    let new_m = m - order;
    let mut detrended = vec![0.0; m];
    detrended[..new_m].copy_from_slice(&diffs[..new_m]);

    (trend, detrended, initial_values, rss)
}
212
/// Scatter per-curve `(trend, detrended, coefs, rss)` tuples into column-major outputs.
///
/// Curve `i`'s value at point `j` goes to index `i + j * n` of the trend/detrended
/// matrices. Its coefficients are packed at `i * n_coef .. (i + 1) * n_coef`, and its
/// RSS goes to slot `i`.
fn reassemble_polynomial_results(
    results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)>,
    n: usize,
    m: usize,
    n_coef: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
    let cells = n * m;
    let (mut trend, mut detrended) = (vec![0.0; cells], vec![0.0; cells]);
    let mut coefficients = vec![0.0; n * n_coef];
    let mut rss = vec![0.0; n];

    for (sample, (curve_trend, curve_detrended, curve_coefs, curve_rss)) in
        results.into_iter().enumerate()
    {
        for j in 0..m {
            let idx = sample + j * n;
            trend[idx] = curve_trend[j];
            detrended[idx] = curve_detrended[j];
        }
        let start = sample * n_coef;
        coefficients[start..start + n_coef].copy_from_slice(&curve_coefs[..n_coef]);
        rss[sample] = curve_rss;
    }

    (trend, detrended, coefficients, rss)
}
236
237/// Remove polynomial trend from functional data using QR decomposition.
238///
239/// # Arguments
240/// * `data` - Column-major matrix (n x m)
241/// * `n` - Number of samples
242/// * `m` - Number of evaluation points
243/// * `argvals` - Time/argument values of length m
244/// * `degree` - Polynomial degree (1 = linear, 2 = quadratic, etc.)
245///
246/// # Returns
247/// TrendResult with trend, detrended data, and polynomial coefficients
248pub fn detrend_polynomial(
249    data: &[f64],
250    n: usize,
251    m: usize,
252    argvals: &[f64],
253    degree: usize,
254) -> TrendResult {
255    if n == 0 || m < degree + 1 || data.len() != n * m || argvals.len() != m || degree == 0 {
256        // For degree 0 or invalid input, return original data
257        return TrendResult {
258            trend: vec![0.0; n * m],
259            detrended: data.to_vec(),
260            method: format!("polynomial({})", degree),
261            coefficients: None,
262            rss: vec![0.0; n],
263            n_params: degree + 1,
264        };
265    }
266
267    // Special case: degree 1 is linear
268    if degree == 1 {
269        let mut result = detrend_linear(data, n, m, argvals);
270        result.method = "polynomial(1)".to_string();
271        return result;
272    }
273
274    let n_coef = degree + 1;
275
276    // Normalize argvals to avoid numerical issues with high-degree polynomials
277    let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
278    let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
279    let t_range = if (t_max - t_min).abs() > 1e-15 {
280        t_max - t_min
281    } else {
282        1.0
283    };
284    let t_norm: Vec<f64> = argvals.iter().map(|&t| (t - t_min) / t_range).collect();
285
286    // Build Vandermonde matrix (m x n_coef)
287    let design = build_vandermonde_matrix(&t_norm, m, n_coef);
288
289    // SVD for stable least squares
290    let svd = design.clone().svd(true, true);
291
292    // Process each sample in parallel
293    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
294        .map(|i| {
295            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
296            fit_polynomial_single_curve(&curve, &svd, &design, n_coef, m)
297        })
298        .collect();
299
300    let (trend, detrended, coefficients, rss) =
301        reassemble_polynomial_results(results, n, m, n_coef);
302
303    TrendResult {
304        trend,
305        detrended,
306        method: format!("polynomial({})", degree),
307        coefficients: Some(coefficients),
308        rss,
309        n_params: n_coef,
310    }
311}
312
313/// Remove trend by differencing.
314///
315/// # Arguments
316/// * `data` - Column-major matrix (n x m)
317/// * `n` - Number of samples
318/// * `m` - Number of evaluation points
319/// * `order` - Differencing order (1 or 2)
320///
321/// # Returns
322/// TrendResult with trend (cumulative sum to reverse), detrended (differences),
323/// and original first values as "coefficients"
324///
325/// Note: Differencing reduces the series length by `order` points.
326/// The returned detrended data has m - order points padded with zeros at the end.
327pub fn detrend_diff(data: &[f64], n: usize, m: usize, order: usize) -> TrendResult {
328    if n == 0 || m <= order || data.len() != n * m || order == 0 || order > 2 {
329        return TrendResult {
330            trend: vec![0.0; n * m],
331            detrended: data.to_vec(),
332            method: format!("diff{}", order),
333            coefficients: None,
334            rss: vec![0.0; n],
335            n_params: order,
336        };
337    }
338
339    // Process each sample in parallel
340    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
341        .map(|i| {
342            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
343            diff_single_curve(&curve, m, order)
344        })
345        .collect();
346
347    // Reassemble
348    let mut trend = vec![0.0; n * m];
349    let mut detrended = vec![0.0; n * m];
350    let mut coefficients = vec![0.0; n * order];
351    let mut rss = vec![0.0; n];
352
353    for (i, (t, d, init, r)) in results.into_iter().enumerate() {
354        for j in 0..m {
355            trend[i + j * n] = t[j];
356            detrended[i + j * n] = d[j];
357        }
358        for k in 0..order {
359            coefficients[i * order + k] = init[k];
360        }
361        rss[i] = r;
362    }
363
364    TrendResult {
365        trend,
366        detrended,
367        method: format!("diff{}", order),
368        coefficients: Some(coefficients),
369        rss,
370        n_params: order,
371    }
372}
373
374/// Remove trend using LOESS (local polynomial regression).
375///
376/// # Arguments
377/// * `data` - Column-major matrix (n x m)
378/// * `n` - Number of samples
379/// * `m` - Number of evaluation points
380/// * `argvals` - Time/argument values
381/// * `bandwidth` - Bandwidth as fraction of data range (0.1 to 0.5 typical)
382/// * `degree` - Local polynomial degree (1 or 2)
383///
384/// # Returns
385/// TrendResult with LOESS-smoothed trend
386pub fn detrend_loess(
387    data: &[f64],
388    n: usize,
389    m: usize,
390    argvals: &[f64],
391    bandwidth: f64,
392    degree: usize,
393) -> TrendResult {
394    if n == 0 || m < 3 || data.len() != n * m || argvals.len() != m || bandwidth <= 0.0 {
395        return TrendResult {
396            trend: vec![0.0; n * m],
397            detrended: data.to_vec(),
398            method: "loess".to_string(),
399            coefficients: None,
400            rss: vec![0.0; n],
401            n_params: (m as f64 * bandwidth).ceil() as usize,
402        };
403    }
404
405    // Convert bandwidth from fraction to absolute units
406    let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
407    let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
408    let abs_bandwidth = (t_max - t_min) * bandwidth;
409
410    // Process each sample in parallel
411    let results: Vec<(Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
412        .map(|i| {
413            // Extract curve
414            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
415
416            // Apply local polynomial regression
417            let trend =
418                local_polynomial(argvals, &curve, argvals, abs_bandwidth, degree, "gaussian");
419
420            // Compute detrended and RSS
421            let mut detrended = vec![0.0; m];
422            let mut rss = 0.0;
423            for j in 0..m {
424                detrended[j] = curve[j] - trend[j];
425                rss += detrended[j].powi(2);
426            }
427
428            (trend, detrended, rss)
429        })
430        .collect();
431
432    // Reassemble
433    let mut trend = vec![0.0; n * m];
434    let mut detrended = vec![0.0; n * m];
435    let mut rss = vec![0.0; n];
436
437    for (i, (t, d, r)) in results.into_iter().enumerate() {
438        for j in 0..m {
439            trend[i + j * n] = t[j];
440            detrended[i + j * n] = d[j];
441        }
442        rss[i] = r;
443    }
444
445    // Effective number of parameters for LOESS is approximately n * bandwidth
446    let n_params = (m as f64 * bandwidth).ceil() as usize;
447
448    TrendResult {
449        trend,
450        detrended,
451        method: "loess".to_string(),
452        coefficients: None,
453        rss,
454        n_params,
455    }
456}
457
458/// Automatically select the best detrending method using AIC.
459///
460/// Compares linear, polynomial (degree 2 and 3), and LOESS,
461/// selecting the method with lowest AIC.
462///
463/// # Arguments
464/// * `data` - Column-major matrix (n x m)
465/// * `n` - Number of samples
466/// * `m` - Number of evaluation points
467/// * `argvals` - Time/argument values
468///
469/// # Returns
470/// TrendResult from the best method
471pub fn auto_detrend(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> TrendResult {
472    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m {
473        return TrendResult {
474            trend: vec![0.0; n * m],
475            detrended: data.to_vec(),
476            method: "auto(none)".to_string(),
477            coefficients: None,
478            rss: vec![0.0; n],
479            n_params: 0,
480        };
481    }
482
483    // Compute AIC for a result: AIC = n * log(RSS/n) + 2*k
484    // We use mean AIC across all samples
485    let compute_aic = |result: &TrendResult| -> f64 {
486        let mut total_aic = 0.0;
487        for i in 0..n {
488            let rss = result.rss[i];
489            let k = result.n_params as f64;
490            let aic = if rss > 1e-15 {
491                m as f64 * (rss / m as f64).ln() + 2.0 * k
492            } else {
493                f64::NEG_INFINITY // Perfect fit (unlikely)
494            };
495            total_aic += aic;
496        }
497        total_aic / n as f64
498    };
499
500    // Try different methods
501    let linear = detrend_linear(data, n, m, argvals);
502    let poly2 = detrend_polynomial(data, n, m, argvals, 2);
503    let poly3 = detrend_polynomial(data, n, m, argvals, 3);
504    let loess = detrend_loess(data, n, m, argvals, 0.3, 2);
505
506    let aic_linear = compute_aic(&linear);
507    let aic_poly2 = compute_aic(&poly2);
508    let aic_poly3 = compute_aic(&poly3);
509    let aic_loess = compute_aic(&loess);
510
511    // Find minimum AIC
512    let methods = [
513        (aic_linear, "linear", linear),
514        (aic_poly2, "polynomial(2)", poly2),
515        (aic_poly3, "polynomial(3)", poly3),
516        (aic_loess, "loess", loess),
517    ];
518
519    let (_, best_name, mut best_result) = methods
520        .into_iter()
521        .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
522        .unwrap();
523
524    best_result.method = format!("auto({})", best_name);
525    best_result
526}
527
528/// Fit Fourier harmonics to detrended data, returning (seasonal, remainder).
529fn fit_fourier_seasonal(
530    detrended_i: &[f64],
531    argvals: &[f64],
532    omega: f64,
533    n_harm: usize,
534    m: usize,
535) -> (Vec<f64>, Vec<f64>) {
536    let n_coef = 2 * n_harm;
537    let mut design = DMatrix::zeros(m, n_coef);
538    for j in 0..m {
539        let t = argvals[j];
540        for k in 0..n_harm {
541            let freq = (k + 1) as f64 * omega;
542            design[(j, 2 * k)] = (freq * t).cos();
543            design[(j, 2 * k + 1)] = (freq * t).sin();
544        }
545    }
546    let y = DVector::from_row_slice(detrended_i);
547    let svd = design.clone().svd(true, true);
548    let coef = svd
549        .solve(&y, 1e-10)
550        .unwrap_or_else(|_| DVector::zeros(n_coef));
551    let fitted = &design * &coef;
552    let seasonal: Vec<f64> = fitted.iter().cloned().collect();
553    let remainder: Vec<f64> = (0..m).map(|j| detrended_i[j] - seasonal[j]).collect();
554    (seasonal, remainder)
555}
556
557/// Additive seasonal decomposition: data = trend + seasonal + remainder
558///
559/// Uses LOESS or spline for trend extraction, then averages within-period
560/// residuals to estimate the seasonal component.
561///
562/// # Arguments
563/// * `data` - Column-major matrix (n x m)
564/// * `n` - Number of samples
565/// * `m` - Number of evaluation points
566/// * `argvals` - Time/argument values
567/// * `period` - Seasonal period in same units as argvals
568/// * `trend_method` - "loess" or "spline"
569/// * `bandwidth` - Bandwidth for LOESS (fraction, e.g., 0.3)
570/// * `n_harmonics` - Number of Fourier harmonics for seasonal component
571///
572/// # Returns
573/// DecomposeResult with trend, seasonal, and remainder components
574pub fn decompose_additive(
575    data: &[f64],
576    n: usize,
577    m: usize,
578    argvals: &[f64],
579    period: f64,
580    trend_method: &str,
581    bandwidth: f64,
582    n_harmonics: usize,
583) -> DecomposeResult {
584    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m || period <= 0.0 {
585        return DecomposeResult {
586            trend: vec![0.0; n * m],
587            seasonal: vec![0.0; n * m],
588            remainder: data.to_vec(),
589            period,
590            method: "additive".to_string(),
591        };
592    }
593
594    // Step 1: Extract trend using LOESS (trend_method parameter preserved for API compatibility)
595    let _ = trend_method;
596    let trend_result = detrend_loess(data, n, m, argvals, bandwidth.max(0.3), 2);
597
598    // Step 2: Extract seasonal component using Fourier basis on detrended data
599    let n_harm = n_harmonics.max(1).min(m / 4);
600    let omega = 2.0 * std::f64::consts::PI / period;
601
602    // Process each sample
603    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
604        .map(|i| {
605            let trend_i: Vec<f64> = (0..m).map(|j| trend_result.trend[i + j * n]).collect();
606            let detrended_i: Vec<f64> = (0..m).map(|j| trend_result.detrended[i + j * n]).collect();
607            let (seasonal, remainder) =
608                fit_fourier_seasonal(&detrended_i, argvals, omega, n_harm, m);
609            (trend_i, seasonal, remainder)
610        })
611        .collect();
612
613    // Reassemble into column-major format
614    let mut trend = vec![0.0; n * m];
615    let mut seasonal = vec![0.0; n * m];
616    let mut remainder = vec![0.0; n * m];
617
618    for (i, (t, s, r)) in results.into_iter().enumerate() {
619        for j in 0..m {
620            trend[i + j * n] = t[j];
621            seasonal[i + j * n] = s[j];
622            remainder[i + j * n] = r[j];
623        }
624    }
625
626    DecomposeResult {
627        trend,
628        seasonal,
629        remainder,
630        period,
631        method: "additive".to_string(),
632    }
633}
634
635/// Multiplicative seasonal decomposition: data = trend * seasonal * remainder
636///
637/// Applies log transformation, then additive decomposition, then back-transforms.
638/// Handles non-positive values by adding a shift.
639///
640/// # Arguments
641/// * `data` - Column-major matrix (n x m)
642/// * `n` - Number of samples
643/// * `m` - Number of evaluation points
644/// * `argvals` - Time/argument values
645/// * `period` - Seasonal period
646/// * `trend_method` - "loess" or "spline"
647/// * `bandwidth` - Bandwidth for LOESS
648/// * `n_harmonics` - Number of Fourier harmonics
649///
650/// # Returns
651/// DecomposeResult with multiplicative components
652pub fn decompose_multiplicative(
653    data: &[f64],
654    n: usize,
655    m: usize,
656    argvals: &[f64],
657    period: f64,
658    trend_method: &str,
659    bandwidth: f64,
660    n_harmonics: usize,
661) -> DecomposeResult {
662    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m || period <= 0.0 {
663        return DecomposeResult {
664            trend: vec![0.0; n * m],
665            seasonal: vec![0.0; n * m],
666            remainder: data.to_vec(),
667            period,
668            method: "multiplicative".to_string(),
669        };
670    }
671
672    // Find minimum value and add shift if needed to make all values positive
673    let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
674    let shift = if min_val <= 0.0 { -min_val + 1.0 } else { 0.0 };
675
676    // Log transform
677    let log_data: Vec<f64> = data.iter().map(|&x| (x + shift).ln()).collect();
678
679    // Apply additive decomposition to log data
680    let additive_result = decompose_additive(
681        &log_data,
682        n,
683        m,
684        argvals,
685        period,
686        trend_method,
687        bandwidth,
688        n_harmonics,
689    );
690
691    // Back transform: exp of each component
692    // For multiplicative: data = trend * seasonal * remainder
693    // In log space: log(data) = log(trend) + log(seasonal) + log(remainder)
694    // So: trend_mult = exp(trend_add), seasonal_mult = exp(seasonal_add), etc.
695
696    let mut trend = vec![0.0; n * m];
697    let mut seasonal = vec![0.0; n * m];
698    let mut remainder = vec![0.0; n * m];
699
700    for idx in 0..n * m {
701        // Back-transform trend (subtract shift)
702        trend[idx] = additive_result.trend[idx].exp() - shift;
703
704        // Seasonal is a multiplicative factor (centered around 1)
705        // We interpret the additive seasonal component as log(seasonal factor)
706        seasonal[idx] = additive_result.seasonal[idx].exp();
707
708        // Remainder is also multiplicative
709        remainder[idx] = additive_result.remainder[idx].exp();
710    }
711
712    DecomposeResult {
713        trend,
714        seasonal,
715        remainder,
716        period,
717        method: "multiplicative".to_string(),
718    }
719}
720
721// ============================================================================
722// STL Decomposition (Cleveland et al., 1990)
723// ============================================================================
724
725/// Result of STL decomposition including robustness weights.
726#[derive(Debug, Clone)]
727pub struct StlResult {
728    /// Trend component (n x m column-major)
729    pub trend: Vec<f64>,
730    /// Seasonal component (n x m column-major)
731    pub seasonal: Vec<f64>,
732    /// Remainder/residual component (n x m column-major)
733    pub remainder: Vec<f64>,
734    /// Robustness weights per point (n x m column-major)
735    pub weights: Vec<f64>,
736    /// Period used for decomposition
737    pub period: usize,
738    /// Seasonal smoothing window
739    pub s_window: usize,
740    /// Trend smoothing window
741    pub t_window: usize,
742    /// Number of inner loop iterations performed
743    pub inner_iterations: usize,
744    /// Number of outer loop iterations performed
745    pub outer_iterations: usize,
746}
747
748/// STL Decomposition: Seasonal and Trend decomposition using LOESS
749///
750/// Implements the Cleveland et al. (1990) algorithm for robust iterative
751/// decomposition of time series into trend, seasonal, and remainder components.
752///
753/// # Algorithm Overview
754/// - **Inner Loop**: Extracts seasonal and trend components using LOESS smoothing
755/// - **Outer Loop**: Computes robustness weights to downweight outliers
756///
757/// # Arguments
758/// * `data` - Column-major matrix (n x m): n samples, m evaluation points
759/// * `n` - Number of samples
760/// * `m` - Number of evaluation points
761/// * `period` - Seasonal period (number of observations per cycle)
762/// * `s_window` - Seasonal smoothing window (must be odd, ≥7 recommended)
763/// * `t_window` - Trend smoothing window. If None, uses default formula
764/// * `l_window` - Low-pass filter window. If None, uses period
765/// * `robust` - Whether to perform robustness iterations
766/// * `inner_iterations` - Number of inner loop iterations. Default: 2
767/// * `outer_iterations` - Number of outer loop iterations. Default: 1 (or 15 if robust)
768///
769/// # Returns
770/// `StlResult` with trend, seasonal, remainder, and robustness weights
771///
772/// # References
773/// Cleveland, R. B., Cleveland, W. S., McRae, J. E., & Terpenning, I. (1990).
774/// STL: A Seasonal-Trend Decomposition Procedure Based on Loess.
775/// Journal of Official Statistics, 6(1), 3-73.
776pub fn stl_decompose(
777    data: &[f64],
778    n: usize,
779    m: usize,
780    period: usize,
781    s_window: Option<usize>,
782    t_window: Option<usize>,
783    l_window: Option<usize>,
784    robust: bool,
785    inner_iterations: Option<usize>,
786    outer_iterations: Option<usize>,
787) -> StlResult {
788    // Validate inputs
789    if n == 0 || m < 2 * period || data.len() != n * m || period < 2 {
790        return StlResult {
791            trend: vec![0.0; n * m],
792            seasonal: vec![0.0; n * m],
793            remainder: data.to_vec(),
794            weights: vec![1.0; n * m],
795            period,
796            s_window: 0,
797            t_window: 0,
798            inner_iterations: 0,
799            outer_iterations: 0,
800        };
801    }
802
803    // Set default parameters following Cleveland et al. recommendations
804    let s_win = s_window.unwrap_or(7).max(3) | 1; // Ensure odd
805
806    // Default t_window: smallest odd integer >= (1.5 * period) / (1 - 1.5/s_window)
807    let t_win = t_window.unwrap_or_else(|| {
808        let ratio = 1.5 * period as f64 / (1.0 - 1.5 / s_win as f64);
809        let val = ratio.ceil() as usize;
810        val.max(3) | 1 // Ensure odd
811    });
812
813    // Low-pass filter window: smallest odd integer >= period
814    let l_win = l_window.unwrap_or(period) | 1;
815
816    let n_inner = inner_iterations.unwrap_or(2);
817    let n_outer = outer_iterations.unwrap_or(if robust { 15 } else { 1 });
818
819    // Process each sample in parallel
820    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
821        .map(|i| {
822            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
823            stl_single_series(
824                &curve, period, s_win, t_win, l_win, robust, n_inner, n_outer,
825            )
826        })
827        .collect();
828
829    // Reassemble into column-major format
830    let mut trend = vec![0.0; n * m];
831    let mut seasonal = vec![0.0; n * m];
832    let mut remainder = vec![0.0; n * m];
833    let mut weights = vec![1.0; n * m];
834
835    for (i, (t, s, r, w)) in results.into_iter().enumerate() {
836        for j in 0..m {
837            trend[i + j * n] = t[j];
838            seasonal[i + j * n] = s[j];
839            remainder[i + j * n] = r[j];
840            weights[i + j * n] = w[j];
841        }
842    }
843
844    StlResult {
845        trend,
846        seasonal,
847        remainder,
848        weights,
849        period,
850        s_window: s_win,
851        t_window: t_win,
852        inner_iterations: n_inner,
853        outer_iterations: n_outer,
854    }
855}
856
857/// STL decomposition for a single time series.
858fn stl_single_series(
859    data: &[f64],
860    period: usize,
861    s_window: usize,
862    t_window: usize,
863    l_window: usize,
864    robust: bool,
865    n_inner: usize,
866    n_outer: usize,
867) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
868    let m = data.len();
869
870    // Initialize components
871    let mut trend = vec![0.0; m];
872    let mut seasonal = vec![0.0; m];
873    let mut weights = vec![1.0; m];
874
875    // Outer loop for robustness
876    for _outer in 0..n_outer {
877        // Inner loop
878        for _inner in 0..n_inner {
879            // Step 1: Detrending
880            let detrended: Vec<f64> = data
881                .iter()
882                .zip(trend.iter())
883                .map(|(&y, &t)| y - t)
884                .collect();
885
886            // Step 2: Cycle-subseries smoothing
887            let cycle_smoothed = smooth_cycle_subseries(&detrended, period, s_window, &weights);
888
889            // Step 3: Low-pass filtering of smoothed cycle-subseries
890            let low_pass = stl_lowpass_filter(&cycle_smoothed, period, l_window);
891
892            // Step 4: Detrending the smoothed cycle-subseries
893            seasonal = cycle_smoothed
894                .iter()
895                .zip(low_pass.iter())
896                .map(|(&c, &l)| c - l)
897                .collect();
898
899            // Step 5: Deseasonalizing
900            let deseasonalized: Vec<f64> = data
901                .iter()
902                .zip(seasonal.iter())
903                .map(|(&y, &s)| y - s)
904                .collect();
905
906            // Step 6: Trend smoothing (weighted LOESS)
907            trend = weighted_loess(&deseasonalized, t_window, &weights);
908        }
909
910        // After inner loop: compute residuals and robustness weights
911        if robust && _outer < n_outer - 1 {
912            let remainder: Vec<f64> = data
913                .iter()
914                .zip(trend.iter())
915                .zip(seasonal.iter())
916                .map(|((&y, &t), &s)| y - t - s)
917                .collect();
918
919            weights = compute_robustness_weights(&remainder);
920        }
921    }
922
923    // Final remainder
924    let remainder: Vec<f64> = data
925        .iter()
926        .zip(trend.iter())
927        .zip(seasonal.iter())
928        .map(|((&y, &t), &s)| y - t - s)
929        .collect();
930
931    (trend, seasonal, remainder, weights)
932}
933
934/// Smooth cycle-subseries: for each seasonal position, smooth across cycles.
935fn smooth_cycle_subseries(
936    data: &[f64],
937    period: usize,
938    s_window: usize,
939    weights: &[f64],
940) -> Vec<f64> {
941    let m = data.len();
942    let n_cycles = m.div_ceil(period);
943    let mut result = vec![0.0; m];
944
945    // For each position in the cycle (0, 1, ..., period-1)
946    for pos in 0..period {
947        // Extract subseries at this position
948        let mut subseries_idx: Vec<usize> = Vec::new();
949        let mut subseries_vals: Vec<f64> = Vec::new();
950        let mut subseries_weights: Vec<f64> = Vec::new();
951
952        for cycle in 0..n_cycles {
953            let idx = cycle * period + pos;
954            if idx < m {
955                subseries_idx.push(idx);
956                subseries_vals.push(data[idx]);
957                subseries_weights.push(weights[idx]);
958            }
959        }
960
961        if subseries_vals.is_empty() {
962            continue;
963        }
964
965        // Smooth this subseries using weighted LOESS
966        let smoothed = weighted_loess(&subseries_vals, s_window, &subseries_weights);
967
968        // Put smoothed values back
969        for (i, &idx) in subseries_idx.iter().enumerate() {
970            result[idx] = smoothed[i];
971        }
972    }
973
974    result
975}
976
/// Low-pass filter for STL (combination of moving averages).
/// Applies: MA(period) -> MA(period) -> MA(3)
///
/// Each stage is a box filter whose window is clipped at the series edges.
/// It is renormalised by the number of samples actually averaged.
///
/// Both period stages average exactly `period` samples:
/// * For odd `period` both boxes are centred.
/// * For even `period` a length-`period` box cannot be centred on a sample, so
///   the first box is shifted half a sample left and the second half a sample
///   right. Their composition is symmetric.
///
/// In either case a zero-mean component with period `period` is cancelled
/// exactly away from the edges, so it does not leak into the low-pass output.
/// A symmetric `period + 1` box would leak it.
///
/// `period == 0` skips the two period stages. `_l_window` is accepted for
/// interface compatibility but is currently unused (no final LOESS stage).
fn stl_lowpass_filter(data: &[f64], period: usize, _l_window: usize) -> Vec<f64> {
    /// Mean of `data[i - before ..= i + after]`, clipped to the valid index range.
    fn box_filter(data: &[f64], before: usize, after: usize) -> Vec<f64> {
        let m = data.len();
        (0..m)
            .map(|i| {
                let start = i.saturating_sub(before);
                let end = (i + after + 1).min(m);
                let window = &data[start..end];
                window.iter().sum::<f64>() / window.len() as f64
            })
            .collect()
    }

    let smoothed = if period == 0 {
        data.to_vec()
    } else {
        // First MA with period: covers [i - period/2, i + period - 1 - period/2].
        let ma1 = box_filter(data, period / 2, period - 1 - period / 2);
        // Second MA with period: the mirror image, covering
        // [i - (period-1)/2, i + period/2].
        box_filter(&ma1, (period - 1) / 2, period / 2)
    };

    // Third MA with 3 (centred).
    box_filter(&smoothed, 1, 1)
}
987
/// Centred moving average with edge truncation.
///
/// Output `i` is the mean of `data[i - window/2 ..= i + window/2]`, clipped to
/// the valid index range and divided by the number of samples actually used.
/// Because the span is symmetric, an even `window` averages `window + 1`
/// samples in the interior. Empty input or `window == 0` returns a copy of
/// `data`.
fn moving_average(data: &[f64], window: usize) -> Vec<f64> {
    if data.is_empty() || window == 0 {
        return data.to_vec();
    }

    let len = data.len();
    let reach = window / 2;

    (0..len)
        .map(|center| {
            let lo = center.saturating_sub(reach);
            let hi = len.min(center + reach + 1);
            let span = &data[lo..hi];
            span.iter().sum::<f64>() / span.len() as f64
        })
        .collect()
}
1008
/// Weighted local-linear LOESS smoother with a tricube kernel.
///
/// For each index `i`, a straight line is fitted by weighted least squares to
/// the points `j` in `[i - window / 2, i + window / 2]` (clipped to the data
/// range), and its value at `i` is returned.
///
/// The weight of point `j` is `tricube(|j - i| / max(window / 2, 1)) * weights[j]`,
/// with `tricube(d) = (1 - d^3)^3` for `d < 1` and `0` otherwise. For
/// `window >= 2` the two outermost points of a full window therefore get zero
/// kernel weight.
///
/// Fallbacks:
/// * total weight `<= 1e-10` (or NaN): `data[i]` is returned unchanged;
/// * degenerate design (|determinant| `<= 1e-10`, e.g. only one point carries
///   weight): the weighted mean of the window is returned.
///
/// # Panics
/// Panics if `weights` is shorter than `data`.
fn weighted_loess(data: &[f64], window: usize, weights: &[f64]) -> Vec<f64> {
    let m = data.len();
    if m == 0 {
        return vec![];
    }

    let half = window / 2;
    let scale = half.max(1) as f64;

    (0..m)
        .map(|i| {
            let start = i.saturating_sub(half);
            let end = (i + half + 1).min(m);

            // Abscissae are measured relative to `i`. With raw indices the
            // determinant `sum_w * sum_wxx - sum_wx^2` subtracts two numbers
            // of size ~i^2. For long series (or when only one point carries
            // weight) its rounding noise can exceed the 1e-10 threshold, and the
            // resulting "fit" is garbage. Centring keeps |x| <= window/2 and
            // makes the fitted value at `i` simply the intercept.
            let mut sum_w = 0.0;
            let mut sum_wx = 0.0;
            let mut sum_wy = 0.0;
            let mut sum_wxx = 0.0;
            let mut sum_wxy = 0.0;

            for j in start..end {
                let x = j as f64 - i as f64;
                let dist = x.abs() / scale;
                let tricube = if dist < 1.0 {
                    (1.0 - dist.powi(3)).powi(3)
                } else {
                    0.0
                };

                let w = tricube * weights[j];
                let y = data[j];

                sum_w += w;
                sum_wx += w * x;
                sum_wy += w * y;
                sum_wxx += w * x * x;
                sum_wxy += w * x * y;
            }

            // Solve weighted least squares; the value at x = 0 is the intercept.
            if sum_w > 1e-10 {
                let denom = sum_w * sum_wxx - sum_wx * sum_wx;
                if denom.abs() > 1e-10 {
                    (sum_wxx * sum_wy - sum_wx * sum_wxy) / denom
                } else {
                    sum_wy / sum_w
                }
            } else {
                data[i]
            }
        })
        .collect()
}
1067
/// Compute STL robustness weights from residuals using the bisquare function.
///
/// The scale is `h = 6 * median(|r|)`, with the median floored at `1e-10` to
/// avoid division by zero. Each residual `r` gets weight `(1 - (|r|/h)^2)^2`
/// when `|r| < h` and `0` otherwise (Cleveland et al., 1990).
///
/// The scale is the median *absolute residual*, not a deviation about the
/// residual median.
///
/// Returns an empty vector for empty input. A NaN residual receives weight 0.
fn compute_robustness_weights(residuals: &[f64]) -> Vec<f64> {
    let m = residuals.len();
    if m == 0 {
        return vec![];
    }

    // `total_cmp` is a genuine total order; after `abs`, NaN sorts last.
    // `partial_cmp(..).unwrap_or(Equal)` is inconsistent when NaN is present,
    // and std's sort functions document that they may panic on such comparators.
    let mut abs_residuals: Vec<f64> = residuals.iter().map(|r| r.abs()).collect();
    abs_residuals.sort_unstable_by(f64::total_cmp);

    let mid = m / 2;
    let median_abs = if m % 2 == 0 {
        (abs_residuals[mid - 1] + abs_residuals[mid]) / 2.0
    } else {
        abs_residuals[mid]
    };

    // Scale factor: 6 * median |residual| (Cleveland et al.)
    let h = 6.0 * median_abs.max(1e-10);

    // Bisquare weight function
    residuals
        .iter()
        .map(|&r| {
            let u = r.abs() / h;
            if u < 1.0 {
                (1.0 - u * u).powi(2)
            } else {
                0.0
            }
        })
        .collect()
}
1102
1103/// Wrapper function for functional data STL decomposition.
1104///
1105/// Computes STL decomposition for each curve in the functional data object
1106/// and returns aggregated results.
1107///
1108/// # Arguments
1109/// * `data` - Column-major matrix (n x m) of functional data
1110/// * `n` - Number of samples (rows)
1111/// * `m` - Number of evaluation points (columns)
1112/// * `argvals` - Time points of length m (used to infer period if needed)
1113/// * `period` - Seasonal period (in number of observations)
1114/// * `s_window` - Seasonal smoothing window
1115/// * `t_window` - Trend smoothing window (0 for auto)
1116/// * `robust` - Whether to use robustness iterations
1117///
1118/// # Returns
1119/// `StlResult` with decomposed components.
1120pub fn stl_fdata(
1121    data: &[f64],
1122    n: usize,
1123    m: usize,
1124    _argvals: &[f64],
1125    period: usize,
1126    s_window: Option<usize>,
1127    t_window: Option<usize>,
1128    robust: bool,
1129) -> StlResult {
1130    stl_decompose(
1131        data, n, m, period, s_window, t_window, None, robust, None, None,
1132    )
1133}
1134
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Pearson correlation coefficient of two equal-length series.
    fn correlation(a: &[f64], b: &[f64]) -> f64 {
        let len = a.len() as f64;
        let mean_a: f64 = a.iter().sum::<f64>() / len;
        let mean_b: f64 = b.iter().sum::<f64>() / len;
        let mut num = 0.0;
        let mut den_a = 0.0;
        let mut den_b = 0.0;
        for (&x, &y) in a.iter().zip(b) {
            num += (x - mean_a) * (y - mean_b);
            den_a += (x - mean_a).powi(2);
            den_b += (y - mean_b).powi(2);
        }
        num / (den_a.sqrt() * den_b.sqrt())
    }

    #[test]
    fn test_detrend_linear_removes_linear_trend() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // y = 2 + 0.5*t + sin(2*pi*t/2)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / 2.0).sin())
            .collect();

        let result = detrend_linear(&data, 1, m, &argvals);

        // Detrended should be approximately sin wave
        let expected: Vec<f64> = argvals
            .iter()
            .map(|&t| (2.0 * PI * t / 2.0).sin())
            .collect();

        let mut max_diff = 0.0f64;
        for j in 0..m {
            let diff = (result.detrended[j] - expected[j]).abs();
            max_diff = max_diff.max(diff);
        }
        assert!(max_diff < 0.2, "Max difference: {}", max_diff);
    }

    #[test]
    fn test_detrend_polynomial_removes_quadratic_trend() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // y = 1 + 0.5*t - 0.1*t^2 + sin(2*pi*t/2)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + 0.5 * t - 0.1 * t * t + (2.0 * PI * t / 2.0).sin())
            .collect();

        let result = detrend_polynomial(&data, 1, m, &argvals, 2);

        // Detrended should be approximately sin wave
        let expected: Vec<f64> = argvals
            .iter()
            .map(|&t| (2.0 * PI * t / 2.0).sin())
            .collect();

        let corr = correlation(&result.detrended, &expected);
        assert!(corr > 0.95, "Correlation: {}", corr);
    }

    #[test]
    fn test_detrend_diff1() {
        let m = 100;
        // Deterministic random-walk-like series: cumulative sum of 0.1*sin(i) increments
        let data: Vec<f64> = {
            let mut v = vec![0.0; m];
            v[0] = 1.0;
            for i in 1..m {
                v[i] = v[i - 1] + 0.1 * (i as f64).sin();
            }
            v
        };

        let result = detrend_diff(&data, 1, m, 1);

        // First difference should recover the increments
        for j in 0..m - 1 {
            let expected = data[j + 1] - data[j];
            assert!(
                (result.detrended[j] - expected).abs() < 1e-10,
                "Mismatch at {}: {} vs {}",
                j,
                result.detrended[j],
                expected
            );
        }
    }

    #[test]
    fn test_auto_detrend_selects_linear_for_linear_data() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();

        // Pure linear trend (no noise)
        let data: Vec<f64> = argvals.iter().map(|&t| 2.0 + 0.5 * t).collect();

        let result = auto_detrend(&data, 1, m, &argvals);

        // Should select linear (or poly 2/3 with linear being sufficient)
        assert!(
            result.method.contains("linear") || result.method.contains("polynomial"),
            "Method: {}",
            result.method
        );
    }

    // ========================================================================
    // Tests for detrend_loess
    // ========================================================================

    #[test]
    fn test_detrend_loess_removes_linear_trend() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // y = 2 + 0.5*t + sin(2*pi*t/2)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / 2.0).sin())
            .collect();

        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 1);

        // Detrended should be approximately sin wave
        let expected: Vec<f64> = argvals
            .iter()
            .map(|&t| (2.0 * PI * t / 2.0).sin())
            .collect();

        // Compare by correlation (LOESS may smooth slightly)
        let corr = correlation(&result.detrended, &expected);
        assert!(corr > 0.9, "Correlation: {}", corr);
        assert_eq!(result.method, "loess");
    }

    #[test]
    fn test_detrend_loess_removes_quadratic_trend() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // y = 1 + 0.3*t - 0.05*t^2 + sin(2*pi*t/2)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + 0.3 * t - 0.05 * t * t + (2.0 * PI * t / 2.0).sin())
            .collect();

        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 2);

        // Output shapes match the input
        assert_eq!(result.trend.len(), m);
        assert_eq!(result.detrended.len(), m);

        // Check that RSS is computed
        assert!(result.rss[0] > 0.0);
    }

    #[test]
    fn test_detrend_loess_different_bandwidths() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Sine wave plus deterministic pseudo-noise
        let data: Vec<f64> = argvals
            .iter()
            .enumerate()
            .map(|(i, &t)| (2.0 * PI * t / 2.0).sin() + 0.1 * ((i * 17) % 100) as f64 / 100.0)
            .collect();

        // Small bandwidth = more local = rougher trend
        let result_small = detrend_loess(&data, 1, m, &argvals, 0.1, 1);
        // Large bandwidth = smoother trend
        let result_large = detrend_loess(&data, 1, m, &argvals, 0.5, 1);

        // Both should produce valid results
        assert_eq!(result_small.trend.len(), m);
        assert_eq!(result_large.trend.len(), m);

        // Large bandwidth should have more parameters
        assert!(result_large.n_params >= result_small.n_params);
    }

    #[test]
    fn test_detrend_loess_short_series() {
        let m = 10;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
        let data: Vec<f64> = argvals.iter().map(|&t| t * 2.0).collect();

        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 1);

        // Should still work on short series
        assert_eq!(result.trend.len(), m);
        assert_eq!(result.detrended.len(), m);
    }

    // ========================================================================
    // Tests for decompose_additive
    // ========================================================================

    #[test]
    fn test_decompose_additive_separates_components() {
        let m = 200;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // data = trend + seasonal: y = 2 + 0.5*t + sin(2*pi*t/2)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / period).sin())
            .collect();

        let result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);

        assert_eq!(result.trend.len(), m);
        assert_eq!(result.seasonal.len(), m);
        assert_eq!(result.remainder.len(), m);
        assert_eq!(result.method, "additive");
        assert_eq!(result.period, period);

        // Check that components approximately sum to original
        for j in 0..m {
            let reconstructed = result.trend[j] + result.seasonal[j] + result.remainder[j];
            assert!(
                (reconstructed - data[j]).abs() < 0.5,
                "Reconstruction error at {}: {} vs {}",
                j,
                reconstructed,
                data[j]
            );
        }
    }

    #[test]
    fn test_decompose_additive_different_harmonics() {
        let m = 200;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Simple seasonal pattern
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + (2.0 * PI * t / period).sin())
            .collect();

        // 1 harmonic
        let result1 = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 1);
        // 5 harmonics
        let result5 = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 5);

        // Both should produce valid results
        assert_eq!(result1.seasonal.len(), m);
        assert_eq!(result5.seasonal.len(), m);
    }

    #[test]
    fn test_decompose_additive_residual_properties() {
        let m = 200;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Data with trend and seasonal
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.3 * t + (2.0 * PI * t / period).sin())
            .collect();

        let result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);

        // Remainder should have mean close to zero
        let mean_rem: f64 = result.remainder.iter().sum::<f64>() / m as f64;
        assert!(mean_rem.abs() < 0.5, "Remainder mean: {}", mean_rem);

        // Remainder variance should be smaller than original variance
        let mean_data: f64 = data.iter().sum::<f64>() / m as f64;
        let var_data: f64 = data
            .iter()
            .map(|&x| (x - mean_data).powi(2))
            .sum::<f64>()
            / m as f64;
        let var_rem: f64 = result
            .remainder
            .iter()
            .map(|&x| (x - mean_rem).powi(2))
            .sum::<f64>()
            / m as f64;
        assert!(
            var_rem < var_data,
            "Remainder variance {} should be < data variance {}",
            var_rem,
            var_data
        );
    }

    #[test]
    fn test_decompose_additive_multi_sample() {
        let n = 3;
        let m = 100;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Create 3 samples with different amplitudes
        let mut data = vec![0.0; n * m];
        for i in 0..n {
            let amp = (i + 1) as f64;
            for j in 0..m {
                data[i + j * n] =
                    1.0 + 0.1 * argvals[j] + amp * (2.0 * PI * argvals[j] / period).sin();
            }
        }

        let result = decompose_additive(&data, n, m, &argvals, period, "loess", 0.3, 2);

        assert_eq!(result.trend.len(), n * m);
        assert_eq!(result.seasonal.len(), n * m);
        assert_eq!(result.remainder.len(), n * m);
    }

    // ========================================================================
    // Tests for decompose_multiplicative
    // ========================================================================

    #[test]
    fn test_decompose_multiplicative_basic() {
        let m = 200;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Multiplicative: data = trend * seasonal
        // trend = 2 + 0.1*t, seasonal = 1 + 0.3*sin(...)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| (2.0 + 0.1 * t) * (1.0 + 0.3 * (2.0 * PI * t / period).sin()))
            .collect();

        let result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 3);

        assert_eq!(result.trend.len(), m);
        assert_eq!(result.seasonal.len(), m);
        assert_eq!(result.remainder.len(), m);
        assert_eq!(result.method, "multiplicative");

        // Seasonal factors should be centered around 1
        let mean_seasonal: f64 = result.seasonal.iter().sum::<f64>() / m as f64;
        assert!(
            (mean_seasonal - 1.0).abs() < 0.5,
            "Mean seasonal factor: {}",
            mean_seasonal
        );
    }

    #[test]
    fn test_decompose_multiplicative_non_positive_data() {
        let m = 100;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Data with negative values
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| -1.0 + (2.0 * PI * t / period).sin())
            .collect();

        // Should handle negative values by shifting
        let result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 2);

        assert_eq!(result.trend.len(), m);
        assert_eq!(result.seasonal.len(), m);
        // Seasonal factors must stay finite despite the non-positive input
        for &s in result.seasonal.iter() {
            assert!(s.is_finite(), "Seasonal should be finite");
        }
    }

    #[test]
    fn test_decompose_multiplicative_vs_additive() {
        let m = 200;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Simple positive data
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 5.0 + (2.0 * PI * t / period).sin())
            .collect();

        let add_result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);
        let mult_result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 3);

        // Both should produce valid decompositions
        assert_eq!(add_result.seasonal.len(), m);
        assert_eq!(mult_result.seasonal.len(), m);

        // Additive seasonal oscillates around 0
        let add_mean: f64 = add_result.seasonal.iter().sum::<f64>() / m as f64;
        // Multiplicative seasonal oscillates around 1
        let mult_mean: f64 = mult_result.seasonal.iter().sum::<f64>() / m as f64;

        assert!(
            add_mean.abs() < mult_mean,
            "Additive mean {} vs mult mean {}",
            add_mean,
            mult_mean
        );
    }

    #[test]
    fn test_decompose_multiplicative_edge_cases() {
        // Empty data
        let result = decompose_multiplicative(&[], 0, 0, &[], 2.0, "loess", 0.3, 2);
        assert_eq!(result.trend.len(), 0);

        // Very short series
        let m = 5;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 2.0, 1.0];
        let result = decompose_multiplicative(&data, 1, m, &argvals, 2.0, "loess", 0.3, 1);
        // A too-short series must still yield a remainder of matching length
        assert_eq!(result.remainder.len(), m);
    }

    // ========================================================================
    // STL Decomposition Tests
    // ========================================================================

    #[test]
    fn test_stl_decompose_basic() {
        let period = 12;
        let n_cycles = 10;
        let m = period * n_cycles;
        let data: Vec<f64> = (0..m)
            .map(|i| {
                let t = i as f64;
                // linear trend + seasonal (no noise)
                0.01 * t + (2.0 * PI * t / period as f64).sin()
            })
            .collect();

        let result = stl_decompose(&data, 1, m, period, None, None, None, false, None, None);

        assert_eq!(result.trend.len(), m);
        assert_eq!(result.seasonal.len(), m);
        assert_eq!(result.remainder.len(), m);
        assert_eq!(result.period, period);

        // trend + seasonal + remainder ≈ original
        for j in 0..m {
            let reconstructed = result.trend[j] + result.seasonal[j] + result.remainder[j];
            assert!(
                (reconstructed - data[j]).abs() < 1e-8,
                "Reconstruction error at {}: {} vs {}",
                j,
                reconstructed,
                data[j]
            );
        }
    }

    #[test]
    fn test_stl_decompose_robust() {
        let period = 12;
        let n_cycles = 10;
        let m = period * n_cycles;
        let mut data: Vec<f64> = (0..m)
            .map(|i| {
                let t = i as f64;
                0.01 * t + (2.0 * PI * t / period as f64).sin()
            })
            .collect();

        // Add outlier spikes
        data[30] += 10.0;
        data[60] += 10.0;

        let result = stl_decompose(&data, 1, m, period, None, None, None, true, None, Some(5));

        // Weights should be < 1.0 near outliers
        assert!(
            result.weights[30] < 1.0,
            "Weight at outlier should be < 1.0: {}",
            result.weights[30]
        );
        assert!(
            result.weights[60] < 1.0,
            "Weight at outlier should be < 1.0: {}",
            result.weights[60]
        );

        // Non-outlier points should have higher weights
        let non_outlier_weight = result.weights[15];
        assert!(
            non_outlier_weight > result.weights[30],
            "Non-outlier weight {} should be > outlier weight {}",
            non_outlier_weight,
            result.weights[30]
        );
    }

    #[test]
    fn test_stl_decompose_default_params() {
        let period = 10;
        let m = period * 8;
        let data: Vec<f64> = (0..m)
            .map(|i| (2.0 * PI * i as f64 / period as f64).sin())
            .collect();

        // All None options
        let result = stl_decompose(&data, 1, m, period, None, None, None, false, None, None);

        assert_eq!(result.trend.len(), m);
        assert_eq!(result.seasonal.len(), m);
        assert!(result.s_window >= 3);
        assert!(result.t_window >= 3);
        assert_eq!(result.inner_iterations, 2);
        assert_eq!(result.outer_iterations, 1);
    }

    #[test]
    fn test_stl_decompose_invalid() {
        // period < 2
        let result = stl_decompose(&[1.0, 2.0], 1, 2, 1, None, None, None, false, None, None);
        assert_eq!(result.s_window, 0);

        // m < 2*period
        let result = stl_decompose(
            &[1.0, 2.0, 3.0],
            1,
            3,
            5,
            None,
            None,
            None,
            false,
            None,
            None,
        );
        assert_eq!(result.s_window, 0);

        // empty data
        let result = stl_decompose(&[], 0, 0, 10, None, None, None, false, None, None);
        assert_eq!(result.trend.len(), 0);
    }

    #[test]
    fn test_stl_fdata() {
        let n = 3;
        let period = 10;
        let m = period * 5;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
        let mut data = vec![0.0; n * m];
        for i in 0..n {
            let amp = (i + 1) as f64;
            for j in 0..m {
                data[i + j * n] = amp * (2.0 * PI * argvals[j] / period as f64).sin();
            }
        }

        let result = stl_fdata(&data, n, m, &argvals, period, None, None, false);

        assert_eq!(result.trend.len(), n * m);
        assert_eq!(result.seasonal.len(), n * m);
        assert_eq!(result.remainder.len(), n * m);

        // Verify reconstruction for each sample
        for i in 0..n {
            for j in 0..m {
                let idx = i + j * n;
                let reconstructed =
                    result.trend[idx] + result.seasonal[idx] + result.remainder[idx];
                assert!(
                    (reconstructed - data[idx]).abs() < 1e-8,
                    "Reconstruction error for sample {} at {}: {} vs {}",
                    i,
                    j,
                    reconstructed,
                    data[idx]
                );
            }
        }
    }

    #[test]
    fn test_stl_decompose_multi_sample() {
        let n = 5;
        let period = 10;
        let m = period * 6;
        let mut data = vec![0.0; n * m];
        for i in 0..n {
            let offset = i as f64 * 0.5;
            for j in 0..m {
                data[i + j * n] =
                    offset + 0.01 * j as f64 + (2.0 * PI * j as f64 / period as f64).sin();
            }
        }

        let result = stl_decompose(&data, n, m, period, None, None, None, false, None, None);

        assert_eq!(result.trend.len(), n * m);
        assert_eq!(result.seasonal.len(), n * m);
        assert_eq!(result.remainder.len(), n * m);
        assert_eq!(result.weights.len(), n * m);
    }

    // ========================================================================
    // Additional edge case tests
    // ========================================================================

    #[test]
    fn test_detrend_diff_order2() {
        let m = 50;
        // Quadratic data: y = t^2
        let data: Vec<f64> = (0..m).map(|i| (i as f64).powi(2)).collect();

        let result = detrend_diff(&data, 1, m, 2);

        // Second differences of t^2 should be constant = 2
        for j in 0..m - 2 {
            assert!(
                (result.detrended[j] - 2.0).abs() < 1e-10,
                "Second diff at {}: expected 2.0, got {}",
                j,
                result.detrended[j]
            );
        }
    }

    #[test]
    fn test_detrend_polynomial_degree3() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 5.0).collect();

        // Cubic data: y = 1 + 2*t - 0.5*t^2 + 0.1*t^3
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + 2.0 * t - 0.5 * t * t + 0.1 * t * t * t)
            .collect();

        let result = detrend_polynomial(&data, 1, m, &argvals, 3);

        assert_eq!(result.method, "polynomial(3)");
        assert!(result.coefficients.is_some());

        // Detrended should be close to zero since data is a pure cubic
        let max_detrend: f64 = result.detrended.iter().map(|x| x.abs()).fold(0.0, f64::max);
        assert!(
            max_detrend < 0.1,
            "Pure cubic should be nearly zero after degree-3 detrend: {}",
            max_detrend
        );
    }

    #[test]
    fn test_detrend_loess_invalid() {
        // bandwidth <= 0
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let argvals = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let result = detrend_loess(&data, 1, 5, &argvals, -0.1, 1);
        assert_eq!(result.detrended, data);

        // m = 2 (< 3)
        let result = detrend_loess(&[1.0, 2.0], 1, 2, &[0.0, 1.0], 0.3, 1);
        assert_eq!(result.detrended, vec![1.0, 2.0]);
    }

    // ========================================================================
    // STL helper tests
    // ========================================================================

    #[test]
    fn test_moving_average_truncates_at_edges() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(moving_average(&data, 3), vec![1.5, 2.0, 3.0, 4.0, 4.5]);
        assert_eq!(moving_average(&data, 0), data.to_vec());
        assert!(moving_average(&[], 3).is_empty());
    }

    #[test]
    fn test_compute_robustness_weights_bisquare() {
        // median |r| = 1, so h = 6
        let w = compute_robustness_weights(&[0.0, 1.0, -1.0, 2.0, 10.0]);
        assert_eq!(w.len(), 5);
        assert!((w[0] - 1.0).abs() < 1e-12);
        let u: f64 = 1.0 / 6.0;
        assert!((w[1] - (1.0 - u * u).powi(2)).abs() < 1e-12);
        assert!((w[2] - w[1]).abs() < 1e-12);
        assert_eq!(w[4], 0.0);
        assert!(compute_robustness_weights(&[]).is_empty());
    }

    #[test]
    fn test_weighted_loess_reproduces_linear() {
        let data: Vec<f64> = (0..20).map(|i| 2.0 + 3.0 * i as f64).collect();
        let weights = vec![1.0; 20];
        let fitted = weighted_loess(&data, 7, &weights);
        assert_eq!(fitted.len(), data.len());
        for (j, (&f, &d)) in fitted.iter().zip(&data).enumerate() {
            assert!((f - d).abs() < 1e-9, "Mismatch at {}: {} vs {}", j, f, d);
        }
    }
}