Skip to main content

fdars_core/
detrend.rs

1//! Detrending and decomposition functions for non-stationary functional data.
2//!
3//! This module provides methods for removing trends from functional data
4//! to enable more accurate seasonal analysis. It includes:
5//! - Linear detrending (least squares)
//! - Polynomial detrending (least squares solved via SVD)
7//! - Differencing (first and second order)
8//! - LOESS detrending (local polynomial regression)
9//! - Spline detrending (P-splines)
10//! - Automatic method selection via AIC
11
12use crate::iter_maybe_parallel;
13use crate::matrix::FdMatrix;
14use crate::smoothing::local_polynomial;
15use nalgebra::{DMatrix, DVector, Dyn, SVD};
16#[cfg(feature = "parallel")]
17use rayon::iter::ParallelIterator;
18
/// Result of detrending operation.
///
/// Returned by `detrend_linear`, `detrend_polynomial`, `detrend_diff`,
/// `detrend_loess` and `auto_detrend`.
#[derive(Debug, Clone)]
pub struct TrendResult {
    /// Estimated trend values (n x m)
    pub trend: FdMatrix,
    /// Detrended data (n x m).
    ///
    /// For the fitting methods (linear, polynomial, LOESS) this is
    /// `data - trend`. For differencing it holds the `m - order` differences
    /// left-aligned in each row, padded with trailing zeros.
    pub detrended: FdMatrix,
    /// Method used for detrending, e.g. `"linear"`, `"polynomial(3)"`,
    /// `"diff1"`, `"loess"`, or `"auto(<chosen method>)"`.
    pub method: String,
    /// Per-sample coefficients, when the method produces any:
    /// - linear: n x 2, `(intercept, slope)`;
    /// - polynomial of degree d: n x (d+1), constant term first;
    /// - differencing of order k: n x k, the first k values of each curve.
    ///
    /// `None` for LOESS and whenever the input was rejected as degenerate.
    pub coefficients: Option<FdMatrix>,
    /// Residual sum of squares for each sample (sum of squares of the
    /// corresponding `detrended` row's meaningful entries)
    pub rss: Vec<f64>,
    /// Number of parameters (for AIC calculation); for LOESS this is the
    /// rough approximation `ceil(m * bandwidth)`.
    pub n_params: usize,
}
36
/// Result of seasonal decomposition.
///
/// Produced by `decompose_additive` and `decompose_multiplicative`.
#[derive(Debug, Clone)]
pub struct DecomposeResult {
    /// Trend component (n x m), on the original data scale
    pub trend: FdMatrix,
    /// Seasonal component (n x m): an additive offset for `"additive"`,
    /// a multiplicative factor for `"multiplicative"`
    pub seasonal: FdMatrix,
    /// Remainder/residual component (n x m): additive residual or
    /// multiplicative factor, matching `method`
    pub remainder: FdMatrix,
    /// Period used for decomposition (in `argvals` units)
    pub period: f64,
    /// Decomposition method ("additive" or "multiplicative")
    pub method: String,
}
51
52/// Remove linear trend from functional data using least squares.
53///
54/// # Arguments
55/// * `data` - Matrix (n x m): n samples, m evaluation points
56/// * `argvals` - Time/argument values of length m
57///
58/// # Returns
59/// TrendResult with trend, detrended data, and coefficients (intercept, slope)
60pub fn detrend_linear(data: &FdMatrix, argvals: &[f64]) -> TrendResult {
61    let (n, m) = data.shape();
62    if n == 0 || m < 2 || argvals.len() != m {
63        return TrendResult {
64            trend: FdMatrix::zeros(n, m),
65            detrended: FdMatrix::from_slice(data.as_slice(), n, m)
66                .unwrap_or_else(|| FdMatrix::zeros(n, m)),
67            method: "linear".to_string(),
68            coefficients: None,
69            rss: vec![0.0; n],
70            n_params: 2,
71        };
72    }
73
74    let mean_t: f64 = argvals.iter().sum::<f64>() / m as f64;
75    let ss_t: f64 = argvals.iter().map(|&t| (t - mean_t).powi(2)).sum();
76
77    let results: Vec<(Vec<f64>, Vec<f64>, f64, f64, f64)> = iter_maybe_parallel!(0..n)
78        .map(|i| {
79            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
80            let mean_y: f64 = curve.iter().sum::<f64>() / m as f64;
81            let mut sp = 0.0;
82            for j in 0..m {
83                sp += (argvals[j] - mean_t) * (curve[j] - mean_y);
84            }
85            let slope = if ss_t.abs() > 1e-15 { sp / ss_t } else { 0.0 };
86            let intercept = mean_y - slope * mean_t;
87            let mut trend = vec![0.0; m];
88            let mut detrended = vec![0.0; m];
89            let mut rss = 0.0;
90            for j in 0..m {
91                trend[j] = intercept + slope * argvals[j];
92                detrended[j] = curve[j] - trend[j];
93                rss += detrended[j].powi(2);
94            }
95            (trend, detrended, intercept, slope, rss)
96        })
97        .collect();
98
99    let mut trend = FdMatrix::zeros(n, m);
100    let mut detrended = FdMatrix::zeros(n, m);
101    let mut coefficients = FdMatrix::zeros(n, 2);
102    let mut rss = vec![0.0; n];
103
104    for (i, (t, d, intercept, slope, r)) in results.into_iter().enumerate() {
105        for j in 0..m {
106            trend[(i, j)] = t[j];
107            detrended[(i, j)] = d[j];
108        }
109        coefficients[(i, 0)] = intercept;
110        coefficients[(i, 1)] = slope;
111        rss[i] = r;
112    }
113
114    TrendResult {
115        trend,
116        detrended,
117        method: "linear".to_string(),
118        coefficients: Some(coefficients),
119        rss,
120        n_params: 2,
121    }
122}
123
124fn build_vandermonde_matrix(t_norm: &[f64], m: usize, n_coef: usize) -> DMatrix<f64> {
125    let mut design = DMatrix::zeros(m, n_coef);
126    for j in 0..m {
127        let t = t_norm[j];
128        let mut power = 1.0;
129        for k in 0..n_coef {
130            design[(j, k)] = power;
131            power *= t;
132        }
133    }
134    design
135}
136
137fn fit_polynomial_single_curve(
138    curve: &[f64],
139    svd: &SVD<f64, Dyn, Dyn>,
140    design: &DMatrix<f64>,
141    n_coef: usize,
142    m: usize,
143) -> (Vec<f64>, Vec<f64>, Vec<f64>, f64) {
144    let y = DVector::from_row_slice(curve);
145    let beta = svd
146        .solve(&y, 1e-10)
147        .unwrap_or_else(|_| DVector::zeros(n_coef));
148    let fitted = design * &beta;
149    let mut trend = vec![0.0; m];
150    let mut detrended = vec![0.0; m];
151    let mut rss = 0.0;
152    for j in 0..m {
153        trend[j] = fitted[j];
154        detrended[j] = curve[j] - fitted[j];
155        rss += detrended[j].powi(2);
156    }
157    let coefs: Vec<f64> = beta.iter().cloned().collect();
158    (trend, detrended, coefs, rss)
159}
160
/// Difference a single curve once (`order == 1`) or twice (`order == 2`).
///
/// Returns `(trend, detrended, initial_values, rss)`:
/// - `detrended` holds the `m - order` differences left-aligned, followed by
///   `order` zeros so every row keeps length `m`;
/// - `trend[j]` is `curve[j]` for `j < order`. Otherwise it is the value
///   predicted from the preceding `order` points, so that
///   `curve[j] - trend[j] == differences[j - order]` for both orders.
///   Mathematically this is the previous value for order 1 and the linear
///   extrapolation `2 * curve[j-1] - curve[j-2]` for order 2;
/// - `initial_values` are the first `order` values of the curve, which are
///   needed to invert the differencing;
/// - `rss` is the sum of squared differences.
///
/// Any `order` other than 2 is treated as first-order differencing. Curves
/// with `m <= order` points cannot be differenced: they are returned as pure
/// trend with zero differences and zero RSS.
fn diff_single_curve(curve: &[f64], m: usize, order: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, f64) {
    let order = if order == 2 { 2 } else { 1 };
    let m = m.min(curve.len());
    let curve = &curve[..m];
    if m <= order {
        // Too short to difference; avoid the `m - 1` underflow.
        return (curve.to_vec(), vec![0.0; m], curve.to_vec(), 0.0);
    }

    let first: Vec<f64> = curve.windows(2).map(|w| w[1] - w[0]).collect();
    let differences: Vec<f64> = if order == 2 {
        first.windows(2).map(|w| w[1] - w[0]).collect()
    } else {
        first
    };
    let rss: f64 = differences.iter().map(|&d| d.powi(2)).sum();

    // Same rule for both orders: the trend is whatever the difference does not
    // explain. (Previously order 2 returned the raw curve as its "trend".)
    let trend: Vec<f64> = (0..m)
        .map(|j| {
            if j < order {
                curve[j]
            } else {
                curve[j] - differences[j - order]
            }
        })
        .collect();

    let mut detrended = vec![0.0; m];
    detrended[..differences.len()].copy_from_slice(&differences);
    (trend, detrended, curve[..order].to_vec(), rss)
}
190
191fn reassemble_polynomial_results(
192    results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)>,
193    n: usize,
194    m: usize,
195    n_coef: usize,
196) -> (FdMatrix, FdMatrix, FdMatrix, Vec<f64>) {
197    let mut trend = FdMatrix::zeros(n, m);
198    let mut detrended = FdMatrix::zeros(n, m);
199    let mut coefficients = FdMatrix::zeros(n, n_coef);
200    let mut rss = vec![0.0; n];
201    for (i, (t, d, coefs, r)) in results.into_iter().enumerate() {
202        for j in 0..m {
203            trend[(i, j)] = t[j];
204            detrended[(i, j)] = d[j];
205        }
206        for k in 0..n_coef {
207            coefficients[(i, k)] = coefs[k];
208        }
209        rss[i] = r;
210    }
211    (trend, detrended, coefficients, rss)
212}
213
214/// Remove polynomial trend from functional data using QR decomposition.
215pub fn detrend_polynomial(data: &FdMatrix, argvals: &[f64], degree: usize) -> TrendResult {
216    let (n, m) = data.shape();
217    if n == 0 || m < degree + 1 || argvals.len() != m || degree == 0 {
218        return TrendResult {
219            trend: FdMatrix::zeros(n, m),
220            detrended: FdMatrix::from_slice(data.as_slice(), n, m)
221                .unwrap_or_else(|| FdMatrix::zeros(n, m)),
222            method: format!("polynomial({})", degree),
223            coefficients: None,
224            rss: vec![0.0; n],
225            n_params: degree + 1,
226        };
227    }
228    if degree == 1 {
229        let mut result = detrend_linear(data, argvals);
230        result.method = "polynomial(1)".to_string();
231        return result;
232    }
233    let n_coef = degree + 1;
234    let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
235    let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
236    let t_range = if (t_max - t_min).abs() > 1e-15 {
237        t_max - t_min
238    } else {
239        1.0
240    };
241    let t_norm: Vec<f64> = argvals.iter().map(|&t| (t - t_min) / t_range).collect();
242    let design = build_vandermonde_matrix(&t_norm, m, n_coef);
243    let svd = design.clone().svd(true, true);
244    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
245        .map(|i| {
246            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
247            fit_polynomial_single_curve(&curve, &svd, &design, n_coef, m)
248        })
249        .collect();
250    let (trend, detrended, coefficients, rss) =
251        reassemble_polynomial_results(results, n, m, n_coef);
252    TrendResult {
253        trend,
254        detrended,
255        method: format!("polynomial({})", degree),
256        coefficients: Some(coefficients),
257        rss,
258        n_params: n_coef,
259    }
260}
261
262/// Remove trend by differencing.
263pub fn detrend_diff(data: &FdMatrix, order: usize) -> TrendResult {
264    let (n, m) = data.shape();
265    if n == 0 || m <= order || order == 0 || order > 2 {
266        return TrendResult {
267            trend: FdMatrix::zeros(n, m),
268            detrended: FdMatrix::from_slice(data.as_slice(), n, m)
269                .unwrap_or_else(|| FdMatrix::zeros(n, m)),
270            method: format!("diff{}", order),
271            coefficients: None,
272            rss: vec![0.0; n],
273            n_params: order,
274        };
275    }
276    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
277        .map(|i| {
278            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
279            diff_single_curve(&curve, m, order)
280        })
281        .collect();
282    let mut trend = FdMatrix::zeros(n, m);
283    let mut detrended = FdMatrix::zeros(n, m);
284    let mut coefficients = FdMatrix::zeros(n, order);
285    let mut rss = vec![0.0; n];
286    for (i, (t, d, init, r)) in results.into_iter().enumerate() {
287        for j in 0..m {
288            trend[(i, j)] = t[j];
289            detrended[(i, j)] = d[j];
290        }
291        for k in 0..order {
292            coefficients[(i, k)] = init[k];
293        }
294        rss[i] = r;
295    }
296    TrendResult {
297        trend,
298        detrended,
299        method: format!("diff{}", order),
300        coefficients: Some(coefficients),
301        rss,
302        n_params: order,
303    }
304}
305
306/// Remove trend using LOESS (local polynomial regression).
307pub fn detrend_loess(
308    data: &FdMatrix,
309    argvals: &[f64],
310    bandwidth: f64,
311    degree: usize,
312) -> TrendResult {
313    let (n, m) = data.shape();
314    if n == 0 || m < 3 || argvals.len() != m || bandwidth <= 0.0 {
315        return TrendResult {
316            trend: FdMatrix::zeros(n, m),
317            detrended: FdMatrix::from_slice(data.as_slice(), n, m)
318                .unwrap_or_else(|| FdMatrix::zeros(n, m)),
319            method: "loess".to_string(),
320            coefficients: None,
321            rss: vec![0.0; n],
322            n_params: (m as f64 * bandwidth).ceil() as usize,
323        };
324    }
325    let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
326    let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
327    let abs_bandwidth = (t_max - t_min) * bandwidth;
328    let results: Vec<(Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
329        .map(|i| {
330            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
331            let trend =
332                local_polynomial(argvals, &curve, argvals, abs_bandwidth, degree, "gaussian");
333            let mut detrended = vec![0.0; m];
334            let mut rss = 0.0;
335            for j in 0..m {
336                detrended[j] = curve[j] - trend[j];
337                rss += detrended[j].powi(2);
338            }
339            (trend, detrended, rss)
340        })
341        .collect();
342    let mut trend = FdMatrix::zeros(n, m);
343    let mut detrended = FdMatrix::zeros(n, m);
344    let mut rss = vec![0.0; n];
345    for (i, (t, d, r)) in results.into_iter().enumerate() {
346        for j in 0..m {
347            trend[(i, j)] = t[j];
348            detrended[(i, j)] = d[j];
349        }
350        rss[i] = r;
351    }
352    let n_params = (m as f64 * bandwidth).ceil() as usize;
353    TrendResult {
354        trend,
355        detrended,
356        method: "loess".to_string(),
357        coefficients: None,
358        rss,
359        n_params,
360    }
361}
362
363/// Automatically select the best detrending method using AIC.
364pub fn auto_detrend(data: &FdMatrix, argvals: &[f64]) -> TrendResult {
365    let (n, m) = data.shape();
366    if n == 0 || m < 4 || argvals.len() != m {
367        return TrendResult {
368            trend: FdMatrix::zeros(n, m),
369            detrended: FdMatrix::from_slice(data.as_slice(), n, m)
370                .unwrap_or_else(|| FdMatrix::zeros(n, m)),
371            method: "auto(none)".to_string(),
372            coefficients: None,
373            rss: vec![0.0; n],
374            n_params: 0,
375        };
376    }
377    let compute_aic = |result: &TrendResult| -> f64 {
378        let mut total_aic = 0.0;
379        for i in 0..n {
380            let rss = result.rss[i];
381            let k = result.n_params as f64;
382            let aic = if rss > 1e-15 {
383                m as f64 * (rss / m as f64).ln() + 2.0 * k
384            } else {
385                f64::NEG_INFINITY
386            };
387            total_aic += aic;
388        }
389        total_aic / n as f64
390    };
391    let linear = detrend_linear(data, argvals);
392    let poly2 = detrend_polynomial(data, argvals, 2);
393    let poly3 = detrend_polynomial(data, argvals, 3);
394    let loess = detrend_loess(data, argvals, 0.3, 2);
395    let aic_linear = compute_aic(&linear);
396    let aic_poly2 = compute_aic(&poly2);
397    let aic_poly3 = compute_aic(&poly3);
398    let aic_loess = compute_aic(&loess);
399    let methods = [
400        (aic_linear, "linear", linear),
401        (aic_poly2, "polynomial(2)", poly2),
402        (aic_poly3, "polynomial(3)", poly3),
403        (aic_loess, "loess", loess),
404    ];
405    let (_, best_name, mut best_result) = methods
406        .into_iter()
407        .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
408        .unwrap();
409    best_result.method = format!("auto({})", best_name);
410    best_result
411}
412
413fn fit_fourier_seasonal(
414    detrended_i: &[f64],
415    argvals: &[f64],
416    omega: f64,
417    n_harm: usize,
418    m: usize,
419) -> (Vec<f64>, Vec<f64>) {
420    let n_coef = 2 * n_harm;
421    let mut design = DMatrix::zeros(m, n_coef);
422    for j in 0..m {
423        let t = argvals[j];
424        for k in 0..n_harm {
425            let freq = (k + 1) as f64 * omega;
426            design[(j, 2 * k)] = (freq * t).cos();
427            design[(j, 2 * k + 1)] = (freq * t).sin();
428        }
429    }
430    let y = DVector::from_row_slice(detrended_i);
431    let svd = design.clone().svd(true, true);
432    let coef = svd
433        .solve(&y, 1e-10)
434        .unwrap_or_else(|_| DVector::zeros(n_coef));
435    let fitted = &design * &coef;
436    let seasonal: Vec<f64> = fitted.iter().cloned().collect();
437    let remainder: Vec<f64> = (0..m).map(|j| detrended_i[j] - seasonal[j]).collect();
438    (seasonal, remainder)
439}
440
441/// Additive seasonal decomposition: data = trend + seasonal + remainder
442pub fn decompose_additive(
443    data: &FdMatrix,
444    argvals: &[f64],
445    period: f64,
446    trend_method: &str,
447    bandwidth: f64,
448    n_harmonics: usize,
449) -> DecomposeResult {
450    let (n, m) = data.shape();
451    if n == 0 || m < 4 || argvals.len() != m || period <= 0.0 {
452        return DecomposeResult {
453            trend: FdMatrix::zeros(n, m),
454            seasonal: FdMatrix::zeros(n, m),
455            remainder: FdMatrix::from_slice(data.as_slice(), n, m)
456                .unwrap_or_else(|| FdMatrix::zeros(n, m)),
457            period,
458            method: "additive".to_string(),
459        };
460    }
461    let _ = trend_method;
462    let trend_result = detrend_loess(data, argvals, bandwidth.max(0.3), 2);
463    let n_harm = n_harmonics.max(1).min(m / 4);
464    let omega = 2.0 * std::f64::consts::PI / period;
465    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
466        .map(|i| {
467            let trend_i: Vec<f64> = (0..m).map(|j| trend_result.trend[(i, j)]).collect();
468            let detrended_i: Vec<f64> = (0..m).map(|j| trend_result.detrended[(i, j)]).collect();
469            let (seasonal, remainder) =
470                fit_fourier_seasonal(&detrended_i, argvals, omega, n_harm, m);
471            (trend_i, seasonal, remainder)
472        })
473        .collect();
474    let mut trend = FdMatrix::zeros(n, m);
475    let mut seasonal = FdMatrix::zeros(n, m);
476    let mut remainder = FdMatrix::zeros(n, m);
477    for (i, (t, s, r)) in results.into_iter().enumerate() {
478        for j in 0..m {
479            trend[(i, j)] = t[j];
480            seasonal[(i, j)] = s[j];
481            remainder[(i, j)] = r[j];
482        }
483    }
484    DecomposeResult {
485        trend,
486        seasonal,
487        remainder,
488        period,
489        method: "additive".to_string(),
490    }
491}
492
493/// Multiplicative seasonal decomposition: data = trend * seasonal * remainder
494pub fn decompose_multiplicative(
495    data: &FdMatrix,
496    argvals: &[f64],
497    period: f64,
498    trend_method: &str,
499    bandwidth: f64,
500    n_harmonics: usize,
501) -> DecomposeResult {
502    let (n, m) = data.shape();
503    if n == 0 || m < 4 || argvals.len() != m || period <= 0.0 {
504        return DecomposeResult {
505            trend: FdMatrix::zeros(n, m),
506            seasonal: FdMatrix::zeros(n, m),
507            remainder: FdMatrix::from_slice(data.as_slice(), n, m)
508                .unwrap_or_else(|| FdMatrix::zeros(n, m)),
509            period,
510            method: "multiplicative".to_string(),
511        };
512    }
513    let min_val = data
514        .as_slice()
515        .iter()
516        .cloned()
517        .fold(f64::INFINITY, f64::min);
518    let shift = if min_val <= 0.0 { -min_val + 1.0 } else { 0.0 };
519    let log_data_vec: Vec<f64> = data.as_slice().iter().map(|&x| (x + shift).ln()).collect();
520    let log_data = FdMatrix::from_column_major(log_data_vec, n, m).unwrap();
521    let additive_result = decompose_additive(
522        &log_data,
523        argvals,
524        period,
525        trend_method,
526        bandwidth,
527        n_harmonics,
528    );
529    let mut trend = FdMatrix::zeros(n, m);
530    let mut seasonal = FdMatrix::zeros(n, m);
531    let mut remainder = FdMatrix::zeros(n, m);
532    for i in 0..n {
533        for j in 0..m {
534            trend[(i, j)] = additive_result.trend[(i, j)].exp() - shift;
535            seasonal[(i, j)] = additive_result.seasonal[(i, j)].exp();
536            remainder[(i, j)] = additive_result.remainder[(i, j)].exp();
537        }
538    }
539    DecomposeResult {
540        trend,
541        seasonal,
542        remainder,
543        period,
544        method: "multiplicative".to_string(),
545    }
546}
547
548// ============================================================================
549// STL Decomposition (Cleveland et al., 1990)
550// ============================================================================
551
/// Result of STL decomposition including robustness weights.
///
/// Produced by `stl_decompose` (and its wrapper `stl_fdata`); each row is
/// decomposed independently.
#[derive(Debug, Clone)]
pub struct StlResult {
    /// Trend component (n x m)
    pub trend: FdMatrix,
    /// Seasonal component (n x m)
    pub seasonal: FdMatrix,
    /// Remainder/residual component (n x m): `data - trend - seasonal`
    pub remainder: FdMatrix,
    /// Robustness weights per point (n x m); all 1.0 unless robust outer
    /// iterations recomputed them
    pub weights: FdMatrix,
    /// Period used for decomposition (observations per cycle)
    pub period: usize,
    /// Seasonal smoothing window actually used (0 if the input was rejected)
    pub s_window: usize,
    /// Trend smoothing window actually used (0 if the input was rejected)
    pub t_window: usize,
    /// Number of inner loop iterations run per outer iteration
    /// (0 if the input was rejected)
    pub inner_iterations: usize,
    /// Number of outer loop iterations run (0 if the input was rejected)
    pub outer_iterations: usize,
}
574
575/// STL Decomposition: Seasonal and Trend decomposition using LOESS
576pub fn stl_decompose(
577    data: &FdMatrix,
578    period: usize,
579    s_window: Option<usize>,
580    t_window: Option<usize>,
581    l_window: Option<usize>,
582    robust: bool,
583    inner_iterations: Option<usize>,
584    outer_iterations: Option<usize>,
585) -> StlResult {
586    let (n, m) = data.shape();
587    if n == 0 || m < 2 * period || period < 2 {
588        return StlResult {
589            trend: FdMatrix::zeros(n, m),
590            seasonal: FdMatrix::zeros(n, m),
591            remainder: FdMatrix::from_slice(data.as_slice(), n, m)
592                .unwrap_or_else(|| FdMatrix::zeros(n, m)),
593            weights: FdMatrix::from_column_major(vec![1.0; n * m], n, m)
594                .unwrap_or_else(|| FdMatrix::zeros(n, m)),
595            period,
596            s_window: 0,
597            t_window: 0,
598            inner_iterations: 0,
599            outer_iterations: 0,
600        };
601    }
602    let s_win = s_window.unwrap_or(7).max(3) | 1;
603    let t_win = t_window.unwrap_or_else(|| {
604        let ratio = 1.5 * period as f64 / (1.0 - 1.5 / s_win as f64);
605        let val = ratio.ceil() as usize;
606        val.max(3) | 1
607    });
608    let l_win = l_window.unwrap_or(period) | 1;
609    let n_inner = inner_iterations.unwrap_or(2);
610    let n_outer = outer_iterations.unwrap_or(if robust { 15 } else { 1 });
611    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
612        .map(|i| {
613            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
614            stl_single_series(
615                &curve, period, s_win, t_win, l_win, robust, n_inner, n_outer,
616            )
617        })
618        .collect();
619    let mut trend = FdMatrix::zeros(n, m);
620    let mut seasonal = FdMatrix::zeros(n, m);
621    let mut remainder = FdMatrix::zeros(n, m);
622    let mut weights = FdMatrix::from_column_major(vec![1.0; n * m], n, m).unwrap();
623    for (i, (t, s, r, w)) in results.into_iter().enumerate() {
624        for j in 0..m {
625            trend[(i, j)] = t[j];
626            seasonal[(i, j)] = s[j];
627            remainder[(i, j)] = r[j];
628            weights[(i, j)] = w[j];
629        }
630    }
631    StlResult {
632        trend,
633        seasonal,
634        remainder,
635        weights,
636        period,
637        s_window: s_win,
638        t_window: t_win,
639        inner_iterations: n_inner,
640        outer_iterations: n_outer,
641    }
642}
643
644fn stl_single_series(
645    data: &[f64],
646    period: usize,
647    s_window: usize,
648    t_window: usize,
649    l_window: usize,
650    robust: bool,
651    n_inner: usize,
652    n_outer: usize,
653) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
654    let m = data.len();
655    let mut trend = vec![0.0; m];
656    let mut seasonal = vec![0.0; m];
657    let mut weights = vec![1.0; m];
658    for _outer in 0..n_outer {
659        for _inner in 0..n_inner {
660            let detrended: Vec<f64> = data
661                .iter()
662                .zip(trend.iter())
663                .map(|(&y, &t)| y - t)
664                .collect();
665            let cycle_smoothed = smooth_cycle_subseries(&detrended, period, s_window, &weights);
666            let low_pass = stl_lowpass_filter(&cycle_smoothed, period, l_window);
667            seasonal = cycle_smoothed
668                .iter()
669                .zip(low_pass.iter())
670                .map(|(&c, &l)| c - l)
671                .collect();
672            let deseasonalized: Vec<f64> = data
673                .iter()
674                .zip(seasonal.iter())
675                .map(|(&y, &s)| y - s)
676                .collect();
677            trend = weighted_loess(&deseasonalized, t_window, &weights);
678        }
679        if robust && _outer < n_outer - 1 {
680            let remainder: Vec<f64> = data
681                .iter()
682                .zip(trend.iter())
683                .zip(seasonal.iter())
684                .map(|((&y, &t), &s)| y - t - s)
685                .collect();
686            weights = compute_robustness_weights(&remainder);
687        }
688    }
689    let remainder: Vec<f64> = data
690        .iter()
691        .zip(trend.iter())
692        .zip(seasonal.iter())
693        .map(|((&y, &t), &s)| y - t - s)
694        .collect();
695    (trend, seasonal, remainder, weights)
696}
697
698fn smooth_cycle_subseries(
699    data: &[f64],
700    period: usize,
701    s_window: usize,
702    weights: &[f64],
703) -> Vec<f64> {
704    let m = data.len();
705    let n_cycles = m.div_ceil(period);
706    let mut result = vec![0.0; m];
707    for pos in 0..period {
708        let mut subseries_idx: Vec<usize> = Vec::new();
709        let mut subseries_vals: Vec<f64> = Vec::new();
710        let mut subseries_weights: Vec<f64> = Vec::new();
711        for cycle in 0..n_cycles {
712            let idx = cycle * period + pos;
713            if idx < m {
714                subseries_idx.push(idx);
715                subseries_vals.push(data[idx]);
716                subseries_weights.push(weights[idx]);
717            }
718        }
719        if subseries_vals.is_empty() {
720            continue;
721        }
722        let smoothed = weighted_loess(&subseries_vals, s_window, &subseries_weights);
723        for (i, &idx) in subseries_idx.iter().enumerate() {
724            result[idx] = smoothed[i];
725        }
726    }
727    result
728}
729
730fn stl_lowpass_filter(data: &[f64], period: usize, _l_window: usize) -> Vec<f64> {
731    let ma1 = moving_average(data, period);
732    let ma2 = moving_average(&ma1, period);
733    moving_average(&ma2, 3)
734}
735
/// Centered moving average with edge truncation.
///
/// Point `i` becomes the mean of `data[i - window/2 ..= i + window/2]`,
/// clipped to the slice bounds. The effective width is therefore
/// `2 * (window / 2) + 1`, which means an even `window` averages one extra
/// point. Empty input or `window == 0` returns a copy of `data`.
fn moving_average(data: &[f64], window: usize) -> Vec<f64> {
    if data.is_empty() || window == 0 {
        return data.to_vec();
    }
    let len = data.len();
    let reach = window / 2;
    (0..len)
        .map(|center| {
            let lo = center.saturating_sub(reach);
            let hi = len.min(center + reach + 1);
            let span = &data[lo..hi];
            span.iter().sum::<f64>() / span.len() as f64
        })
        .collect()
}
752
/// Local linear regression with tricube kernel and per-point weights.
///
/// For each index `i`, a weighted straight line is fitted over the indices
/// `i - window/2 ..= i + window/2`, clipped to the slice. Abscissae are the
/// integer positions themselves. Each point's weight is
/// `tricube(|j - i| / max(window/2, 1)) * weights[j]`, so points at the window
/// edge get zero weight. The fitted line is evaluated at `i`.
///
/// Fallbacks:
/// - if the weighted design is (numerically) singular, the weighted mean is
///   used;
/// - if the total weight is (numerically) zero, the raw `data[i]` is used.
fn weighted_loess(data: &[f64], window: usize, weights: &[f64]) -> Vec<f64> {
    let len = data.len();
    let half = window / 2;
    let scale = half.max(1) as f64;
    (0..len)
        .map(|i| {
            let lo = i.saturating_sub(half);
            let hi = (i + half + 1).min(len);
            let center = i as f64;
            let (sw, swx, swy, swxx, swxy) = (lo..hi).fold(
                (0.0, 0.0, 0.0, 0.0, 0.0),
                |(sw, swx, swy, swxx, swxy), j| {
                    let x = j as f64;
                    let dist = (x - center).abs() / scale;
                    let kernel = if dist < 1.0 {
                        (1.0 - dist.powi(3)).powi(3)
                    } else {
                        0.0
                    };
                    let w = kernel * weights[j];
                    let y = data[j];
                    (
                        sw + w,
                        swx + w * x,
                        swy + w * y,
                        swxx + w * x * x,
                        swxy + w * x * y,
                    )
                },
            );
            if sw > 1e-10 {
                let denom = sw * swxx - swx * swx;
                if denom.abs() > 1e-10 {
                    let intercept = (swxx * swy - swx * swxy) / denom;
                    let slope = (sw * swxy - swx * swy) / denom;
                    intercept + slope * center
                } else {
                    swy / sw
                }
            } else {
                data[i]
            }
        })
        .collect()
}
799
/// Bisquare robustness weights for STL's outer loop.
///
/// Let `h = 6 * median(|r|)`, with the median floored at `1e-10` so that an
/// all-zero residual vector does not divide by zero. Each residual then gets
/// weight `(1 - (|r| / h)^2)^2` when `|r| < h`, and `0` otherwise.
///
/// NaN residuals receive weight 0 and are ordered after all other values when
/// the median is taken. An empty input yields an empty vector.
fn compute_robustness_weights(residuals: &[f64]) -> Vec<f64> {
    if residuals.is_empty() {
        return Vec::new();
    }
    let mut abs_residuals: Vec<f64> = residuals.iter().map(|r| r.abs()).collect();
    // `total_cmp` is a genuine total order (positive NaN sorts last). A
    // `partial_cmp(..).unwrap_or(Equal)` comparator is not, and the standard
    // sorts may panic on such comparators.
    abs_residuals.sort_unstable_by(f64::total_cmp);

    let len = abs_residuals.len();
    let mid = len / 2;
    let median_abs = if len % 2 == 0 {
        (abs_residuals[mid - 1] + abs_residuals[mid]) / 2.0
    } else {
        abs_residuals[mid]
    };
    let h = 6.0 * median_abs.max(1e-10);

    residuals
        .iter()
        .map(|&r| {
            let u = r.abs() / h;
            // NaN fails this comparison and therefore gets weight 0.
            if u < 1.0 {
                (1.0 - u * u).powi(2)
            } else {
                0.0
            }
        })
        .collect()
}
826
827/// Wrapper function for functional data STL decomposition.
828pub fn stl_fdata(
829    data: &FdMatrix,
830    _argvals: &[f64],
831    period: usize,
832    s_window: Option<usize>,
833    t_window: Option<usize>,
834    robust: bool,
835) -> StlResult {
836    stl_decompose(data, period, s_window, t_window, None, robust, None, None)
837}
838
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    // ----------------------------------------------------------------------
    // Shared helpers. These are kept private to the test module so the
    // individual tests read as "build input -> call -> check property".
    // ----------------------------------------------------------------------

    /// `m` equally spaced points covering `[0, end]` with both endpoints included.
    ///
    /// Requires `m >= 2` (division by `m - 1`).
    fn uniform_grid(m: usize, end: f64) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64 * end).collect()
    }

    /// Unit-amplitude sine with the given period, evaluated at `t`.
    fn sine_wave(t: f64, period: f64) -> f64 {
        (2.0 * PI * t / period).sin()
    }

    /// Copy row `i` of `mat` (all of its columns) into a `Vec`.
    fn row_vec(mat: &FdMatrix, i: usize) -> Vec<f64> {
        (0..mat.ncols()).map(|j| mat[(i, j)]).collect()
    }

    /// Arithmetic mean of a non-empty slice.
    fn sample_mean(v: &[f64]) -> f64 {
        v.iter().sum::<f64>() / v.len() as f64
    }

    /// Pearson correlation coefficient between two equal-length slices.
    fn pearson_corr(a: &[f64], b: &[f64]) -> f64 {
        assert_eq!(a.len(), b.len(), "correlation inputs must have equal length");
        let (mean_a, mean_b) = (sample_mean(a), sample_mean(b));
        let (mut num, mut den_a, mut den_b) = (0.0, 0.0, 0.0);
        for (&x, &y) in a.iter().zip(b) {
            let (dx, dy) = (x - mean_a, y - mean_b);
            num += dx * dy;
            den_a += dx * dx;
            den_b += dy * dy;
        }
        num / (den_a.sqrt() * den_b.sqrt())
    }

    /// Assert that `trend + seasonal + remainder` reproduces `data` within `tol`
    /// at every (sample, point) position of `data`.
    fn assert_additive_reconstruction(
        trend: &FdMatrix,
        seasonal: &FdMatrix,
        remainder: &FdMatrix,
        data: &FdMatrix,
        tol: f64,
    ) {
        let (n, m) = data.shape();
        for i in 0..n {
            for j in 0..m {
                let reconstructed = trend[(i, j)] + seasonal[(i, j)] + remainder[(i, j)];
                assert!(
                    (reconstructed - data[(i, j)]).abs() < tol,
                    "Reconstruction error for sample {} at {}: {} vs {}",
                    i,
                    j,
                    reconstructed,
                    data[(i, j)]
                );
            }
        }
    }

    // ----------------------------------------------------------------------
    // Detrending
    // ----------------------------------------------------------------------

    #[test]
    fn test_detrend_linear_removes_linear_trend() {
        let m = 100;
        let argvals = uniform_grid(m, 10.0);
        let data_vec: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.5 * t + sine_wave(t, 2.0))
            .collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = detrend_linear(&data, &argvals);
        // After removing the line, only the sine should remain.
        let max_diff = row_vec(&result.detrended, 0)
            .iter()
            .zip(&argvals)
            .map(|(&d, &t)| (d - sine_wave(t, 2.0)).abs())
            .fold(0.0f64, f64::max);
        assert!(max_diff < 0.2, "Max difference: {}", max_diff);
    }

    #[test]
    fn test_detrend_polynomial_removes_quadratic_trend() {
        let m = 100;
        let argvals = uniform_grid(m, 10.0);
        let data_vec: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + 0.5 * t - 0.1 * t * t + sine_wave(t, 2.0))
            .collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = detrend_polynomial(&data, &argvals, 2);
        let expected: Vec<f64> = argvals.iter().map(|&t| sine_wave(t, 2.0)).collect();
        let corr = pearson_corr(&row_vec(&result.detrended, 0), &expected);
        assert!(corr > 0.95, "Correlation: {}", corr);
    }

    #[test]
    fn test_detrend_diff1() {
        let m = 100;
        let data_vec: Vec<f64> = {
            let mut v = vec![0.0; m];
            v[0] = 1.0;
            for i in 1..m {
                v[i] = v[i - 1] + 0.1 * (i as f64).sin();
            }
            v
        };
        let data = FdMatrix::from_column_major(data_vec.clone(), 1, m).unwrap();
        let result = detrend_diff(&data, 1);
        // Each output position j holds x[j + 1] - x[j].
        for (j, pair) in data_vec.windows(2).enumerate() {
            let expected = pair[1] - pair[0];
            assert!(
                (result.detrended[(0, j)] - expected).abs() < 1e-10,
                "Mismatch at {}: {} vs {}",
                j,
                result.detrended[(0, j)],
                expected
            );
        }
    }

    #[test]
    fn test_auto_detrend_selects_linear_for_linear_data() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
        let data_vec: Vec<f64> = argvals.iter().map(|&t| 2.0 + 0.5 * t).collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = auto_detrend(&data, &argvals);
        assert!(
            result.method.contains("linear") || result.method.contains("polynomial"),
            "Method: {}",
            result.method
        );
    }

    #[test]
    fn test_detrend_loess_removes_linear_trend() {
        let m = 100;
        let argvals = uniform_grid(m, 10.0);
        let data_vec: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.5 * t + sine_wave(t, 2.0))
            .collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = detrend_loess(&data, &argvals, 0.3, 1);
        let expected: Vec<f64> = argvals.iter().map(|&t| sine_wave(t, 2.0)).collect();
        let corr = pearson_corr(&row_vec(&result.detrended, 0), &expected);
        assert!(corr > 0.9, "Correlation: {}", corr);
        assert_eq!(result.method, "loess");
    }

    #[test]
    fn test_detrend_loess_removes_quadratic_trend() {
        let m = 100;
        let argvals = uniform_grid(m, 10.0);
        let data_vec: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + 0.3 * t - 0.05 * t * t + sine_wave(t, 2.0))
            .collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = detrend_loess(&data, &argvals, 0.3, 2);
        assert_eq!(result.trend.ncols(), m);
        assert_eq!(result.detrended.ncols(), m);
        assert!(result.rss[0] > 0.0);
    }

    #[test]
    fn test_detrend_loess_different_bandwidths() {
        let m = 100;
        let argvals = uniform_grid(m, 10.0);
        // Sine plus a deterministic pseudo-noise term.
        let data_vec: Vec<f64> = argvals
            .iter()
            .enumerate()
            .map(|(i, &t)| sine_wave(t, 2.0) + 0.1 * ((i * 17) % 100) as f64 / 100.0)
            .collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result_small = detrend_loess(&data, &argvals, 0.1, 1);
        let result_large = detrend_loess(&data, &argvals, 0.5, 1);
        assert_eq!(result_small.trend.ncols(), m);
        assert_eq!(result_large.trend.ncols(), m);
        assert!(result_large.n_params >= result_small.n_params);
    }

    #[test]
    fn test_detrend_loess_short_series() {
        let m = 10;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
        let data_vec: Vec<f64> = argvals.iter().map(|&t| t * 2.0).collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = detrend_loess(&data, &argvals, 0.3, 1);
        assert_eq!(result.trend.ncols(), m);
        assert_eq!(result.detrended.ncols(), m);
    }

    // ----------------------------------------------------------------------
    // Additive / multiplicative decomposition
    // ----------------------------------------------------------------------

    #[test]
    fn test_decompose_additive_separates_components() {
        let m = 200;
        let period = 2.0;
        let argvals = uniform_grid(m, 10.0);
        let data_vec: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.5 * t + sine_wave(t, period))
            .collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = decompose_additive(&data, &argvals, period, "loess", 0.3, 3);
        assert_eq!(result.trend.ncols(), m);
        assert_eq!(result.seasonal.ncols(), m);
        assert_eq!(result.remainder.ncols(), m);
        assert_eq!(result.method, "additive");
        assert_eq!(result.period, period);
        assert_additive_reconstruction(
            &result.trend,
            &result.seasonal,
            &result.remainder,
            &data,
            0.5,
        );
    }

    #[test]
    fn test_decompose_additive_different_harmonics() {
        let m = 200;
        let period = 2.0;
        let argvals = uniform_grid(m, 10.0);
        let data_vec: Vec<f64> = argvals.iter().map(|&t| 1.0 + sine_wave(t, period)).collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result1 = decompose_additive(&data, &argvals, period, "loess", 0.3, 1);
        let result5 = decompose_additive(&data, &argvals, period, "loess", 0.3, 5);
        assert_eq!(result1.seasonal.ncols(), m);
        assert_eq!(result5.seasonal.ncols(), m);
    }

    #[test]
    fn test_decompose_additive_residual_properties() {
        let m = 200;
        let period = 2.0;
        let argvals = uniform_grid(m, 10.0);
        let data_vec: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.3 * t + sine_wave(t, period))
            .collect();
        let data = FdMatrix::from_column_major(data_vec.clone(), 1, m).unwrap();
        let result = decompose_additive(&data, &argvals, period, "loess", 0.3, 3);

        let remainder_vec = row_vec(&result.remainder, 0);
        let mean_rem = sample_mean(&remainder_vec);
        assert!(mean_rem.abs() < 0.5, "Remainder mean: {}", mean_rem);

        // Population variances (divide by m) of the input and of the remainder.
        let data_mean = sample_mean(&data_vec);
        let var_data: f64 = data_vec
            .iter()
            .map(|&x| (x - data_mean).powi(2))
            .sum::<f64>()
            / m as f64;
        let var_rem: f64 = remainder_vec
            .iter()
            .map(|&x| (x - mean_rem).powi(2))
            .sum::<f64>()
            / m as f64;
        assert!(
            var_rem < var_data,
            "Remainder variance {} should be < data variance {}",
            var_rem,
            var_data
        );
    }

    #[test]
    fn test_decompose_additive_multi_sample() {
        let n = 3;
        let m = 100;
        let period = 2.0;
        let argvals = uniform_grid(m, 10.0);
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            let amp = (i + 1) as f64;
            for j in 0..m {
                data[(i, j)] = 1.0 + 0.1 * argvals[j] + amp * sine_wave(argvals[j], period);
            }
        }
        let result = decompose_additive(&data, &argvals, period, "loess", 0.3, 2);
        assert_eq!(result.trend.shape(), (n, m));
        assert_eq!(result.seasonal.shape(), (n, m));
        assert_eq!(result.remainder.shape(), (n, m));
    }

    #[test]
    fn test_decompose_multiplicative_basic() {
        let m = 200;
        let period = 2.0;
        let argvals = uniform_grid(m, 10.0);
        let data_vec: Vec<f64> = argvals
            .iter()
            .map(|&t| (2.0 + 0.1 * t) * (1.0 + 0.3 * sine_wave(t, period)))
            .collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = decompose_multiplicative(&data, &argvals, period, "loess", 0.3, 3);
        assert_eq!(result.trend.ncols(), m);
        assert_eq!(result.seasonal.ncols(), m);
        assert_eq!(result.remainder.ncols(), m);
        assert_eq!(result.method, "multiplicative");
        // Multiplicative seasonal factors should hover around 1.
        let mean_seasonal = sample_mean(&row_vec(&result.seasonal, 0));
        assert!(
            (mean_seasonal - 1.0).abs() < 0.5,
            "Mean seasonal factor: {}",
            mean_seasonal
        );
    }

    #[test]
    fn test_decompose_multiplicative_non_positive_data() {
        let m = 100;
        let period = 2.0;
        let argvals = uniform_grid(m, 10.0);
        let data_vec: Vec<f64> = argvals.iter().map(|&t| -1.0 + sine_wave(t, period)).collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = decompose_multiplicative(&data, &argvals, period, "loess", 0.3, 2);
        assert_eq!(result.trend.ncols(), m);
        assert_eq!(result.seasonal.ncols(), m);
        for (j, s) in row_vec(&result.seasonal, 0).iter().enumerate() {
            assert!(s.is_finite(), "Seasonal should be finite at {}: {}", j, s);
        }
    }

    #[test]
    fn test_decompose_multiplicative_vs_additive() {
        let m = 200;
        let period = 2.0;
        let argvals = uniform_grid(m, 10.0);
        let data_vec: Vec<f64> = argvals.iter().map(|&t| 5.0 + sine_wave(t, period)).collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let add_result = decompose_additive(&data, &argvals, period, "loess", 0.3, 3);
        let mult_result = decompose_multiplicative(&data, &argvals, period, "loess", 0.3, 3);
        assert_eq!(add_result.seasonal.ncols(), m);
        assert_eq!(mult_result.seasonal.ncols(), m);
        // Additive seasonal centres on 0, multiplicative seasonal on 1.
        let add_mean = sample_mean(&row_vec(&add_result.seasonal, 0));
        let mult_mean = sample_mean(&row_vec(&mult_result.seasonal, 0));
        assert!(
            add_mean.abs() < mult_mean,
            "Additive mean {} vs mult mean {}",
            add_mean,
            mult_mean
        );
    }

    #[test]
    fn test_decompose_multiplicative_edge_cases() {
        let empty = FdMatrix::zeros(0, 0);
        let result = decompose_multiplicative(&empty, &[], 2.0, "loess", 0.3, 2);
        assert_eq!(result.trend.len(), 0);
        let m = 5;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
        let data = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 2.0, 1.0], 1, m).unwrap();
        let result = decompose_multiplicative(&data, &argvals, 2.0, "loess", 0.3, 1);
        assert_eq!(result.remainder.ncols(), m);
    }

    // ----------------------------------------------------------------------
    // STL decomposition
    // ----------------------------------------------------------------------

    #[test]
    fn test_stl_decompose_basic() {
        let period = 12;
        let n_cycles = 10;
        let m = period * n_cycles;
        let data_vec: Vec<f64> = (0..m)
            .map(|i| {
                let t = i as f64;
                0.01 * t + sine_wave(t, period as f64)
            })
            .collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = stl_decompose(&data, period, None, None, None, false, None, None);
        assert_eq!(result.trend.ncols(), m);
        assert_eq!(result.seasonal.ncols(), m);
        assert_eq!(result.remainder.ncols(), m);
        assert_eq!(result.period, period);
        // STL's remainder is defined as data - trend - seasonal, so the sum is exact.
        assert_additive_reconstruction(
            &result.trend,
            &result.seasonal,
            &result.remainder,
            &data,
            1e-8,
        );
    }

    #[test]
    fn test_stl_decompose_robust() {
        let period = 12;
        let n_cycles = 10;
        let m = period * n_cycles;
        let mut data_vec: Vec<f64> = (0..m)
            .map(|i| {
                let t = i as f64;
                0.01 * t + sine_wave(t, period as f64)
            })
            .collect();
        // Inject two large outliers that robust weighting should down-weight.
        data_vec[30] += 10.0;
        data_vec[60] += 10.0;
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = stl_decompose(&data, period, None, None, None, true, None, Some(5));
        assert!(
            result.weights[(0, 30)] < 1.0,
            "Weight at outlier should be < 1.0: {}",
            result.weights[(0, 30)]
        );
        assert!(
            result.weights[(0, 60)] < 1.0,
            "Weight at outlier should be < 1.0: {}",
            result.weights[(0, 60)]
        );
        let non_outlier_weight = result.weights[(0, 15)];
        assert!(
            non_outlier_weight > result.weights[(0, 30)],
            "Non-outlier weight {} should be > outlier weight {}",
            non_outlier_weight,
            result.weights[(0, 30)]
        );
    }

    #[test]
    fn test_stl_decompose_default_params() {
        let period = 10;
        let m = period * 8;
        let data_vec: Vec<f64> = (0..m)
            .map(|i| sine_wave(i as f64, period as f64))
            .collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = stl_decompose(&data, period, None, None, None, false, None, None);
        assert_eq!(result.trend.ncols(), m);
        assert_eq!(result.seasonal.ncols(), m);
        assert!(result.s_window >= 3);
        assert!(result.t_window >= 3);
        assert_eq!(result.inner_iterations, 2);
        assert_eq!(result.outer_iterations, 1);
    }

    #[test]
    fn test_stl_decompose_invalid() {
        // Period below the minimum.
        let data = FdMatrix::from_column_major(vec![1.0, 2.0], 1, 2).unwrap();
        let result = stl_decompose(&data, 1, None, None, None, false, None, None);
        assert_eq!(result.s_window, 0);
        // Series shorter than the period.
        let data = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 1, 3).unwrap();
        let result = stl_decompose(&data, 5, None, None, None, false, None, None);
        assert_eq!(result.s_window, 0);
        // Empty input.
        let data = FdMatrix::zeros(0, 0);
        let result = stl_decompose(&data, 10, None, None, None, false, None, None);
        assert_eq!(result.trend.len(), 0);
    }

    #[test]
    fn test_stl_fdata() {
        let n = 3;
        let period = 10;
        let m = period * 5;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            let amp = (i + 1) as f64;
            for j in 0..m {
                data[(i, j)] = amp * sine_wave(argvals[j], period as f64);
            }
        }
        let result = stl_fdata(&data, &argvals, period, None, None, false);
        assert_eq!(result.trend.shape(), (n, m));
        assert_eq!(result.seasonal.shape(), (n, m));
        assert_eq!(result.remainder.shape(), (n, m));
        assert_additive_reconstruction(
            &result.trend,
            &result.seasonal,
            &result.remainder,
            &data,
            1e-8,
        );
    }

    #[test]
    fn test_stl_decompose_multi_sample() {
        let n = 5;
        let period = 10;
        let m = period * 6;
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            let offset = i as f64 * 0.5;
            for j in 0..m {
                data[(i, j)] = offset + 0.01 * j as f64 + sine_wave(j as f64, period as f64);
            }
        }
        let result = stl_decompose(&data, period, None, None, None, false, None, None);
        assert_eq!(result.trend.shape(), (n, m));
        assert_eq!(result.seasonal.shape(), (n, m));
        assert_eq!(result.remainder.shape(), (n, m));
        assert_eq!(result.weights.shape(), (n, m));
    }

    // ----------------------------------------------------------------------
    // Additional detrending checks
    // ----------------------------------------------------------------------

    #[test]
    fn test_detrend_diff_order2() {
        let m = 50;
        let data_vec: Vec<f64> = (0..m).map(|i| (i as f64).powi(2)).collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = detrend_diff(&data, 2);
        // The second difference of i^2 is the constant 2.
        for j in 0..m - 2 {
            assert!(
                (result.detrended[(0, j)] - 2.0).abs() < 1e-10,
                "Second diff at {}: expected 2.0, got {}",
                j,
                result.detrended[(0, j)]
            );
        }
    }

    #[test]
    fn test_detrend_polynomial_degree3() {
        let m = 100;
        let argvals = uniform_grid(m, 5.0);
        let data_vec: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + 2.0 * t - 0.5 * t * t + 0.1 * t * t * t)
            .collect();
        let data = FdMatrix::from_column_major(data_vec, 1, m).unwrap();
        let result = detrend_polynomial(&data, &argvals, 3);
        assert_eq!(result.method, "polynomial(3)");
        assert!(result.coefficients.is_some());
        let max_detrend: f64 = row_vec(&result.detrended, 0)
            .iter()
            .map(|d| d.abs())
            .fold(0.0, f64::max);
        assert!(
            max_detrend < 0.1,
            "Pure cubic should be nearly zero after degree-3 detrend: {}",
            max_detrend
        );
    }

    #[test]
    fn test_detrend_loess_invalid() {
        // A negative bandwidth leaves the data untouched.
        let data = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5).unwrap();
        let argvals = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let result = detrend_loess(&data, &argvals, -0.1, 1);
        assert_eq!(result.detrended.as_slice(), data.as_slice());
        // A too-short series leaves the data untouched.
        let data2 = FdMatrix::from_column_major(vec![1.0, 2.0], 1, 2).unwrap();
        let result = detrend_loess(&data2, &[0.0, 1.0], 0.3, 1);
        assert_eq!(result.detrended.as_slice(), &[1.0, 2.0]);
    }
}