Skip to main content

fdars_core/
detrend.rs

1//! Detrending and decomposition functions for non-stationary functional data.
2//!
3//! This module provides methods for removing trends from functional data
4//! to enable more accurate seasonal analysis. It includes:
5//! - Linear detrending (least squares)
//! - Polynomial detrending (SVD-based least squares)
7//! - Differencing (first and second order)
8//! - LOESS detrending (local polynomial regression)
9//! - Spline detrending (P-splines)
10//! - Automatic method selection via AIC
11
12use crate::iter_maybe_parallel;
13use crate::smoothing::local_polynomial;
14use nalgebra::{DMatrix, DVector};
15#[cfg(feature = "parallel")]
16use rayon::iter::ParallelIterator;
17
/// Outcome of a detrending call: the fitted trend, the residual signal and fit diagnostics.
///
/// Matrices share the input layout: column-major with `n` samples and `m` evaluation
/// points, so entry `(i, j)` lives at index `i + j * n`.
#[derive(Debug, Clone)]
pub struct TrendResult {
    /// Fitted trend, same layout as the input data.
    pub trend: Vec<f64>,
    /// Data after trend removal, same layout as the input data.
    pub detrended: Vec<f64>,
    /// Label of the method that produced the result, e.g. `"linear"` or `"polynomial(2)"`.
    pub method: String,
    /// Per-sample fit values stored sample after sample: with `p` values per sample,
    /// `coefficients[i * p + k]` is value `k` of sample `i`. For a polynomial of degree `d`,
    /// `p = d + 1`. For `d >= 2`, the polynomial coefficients refer to `argvals` rescaled to
    /// `[0, 1]`. `None` when the method has no coefficients or the input was rejected.
    pub coefficients: Option<Vec<f64>>,
    /// Residual sum of squares of every sample (length `n`).
    pub rss: Vec<f64>,
    /// Number of parameters attributed to the method (the AIC penalty term).
    pub n_params: usize,
}
35
/// Outcome of a seasonal decomposition into trend, seasonal and remainder parts.
///
/// Component matrices use the input layout: column-major, `n` samples by `m` points,
/// with entry `(i, j)` at index `i + j * n`.
#[derive(Debug, Clone)]
pub struct DecomposeResult {
    /// Trend component.
    pub trend: Vec<f64>,
    /// Seasonal component.
    pub seasonal: Vec<f64>,
    /// Remainder (what is left after trend and seasonal are accounted for).
    pub remainder: Vec<f64>,
    /// Seasonal period, in the same units as the `argvals` that were decomposed.
    pub period: f64,
    /// Kind of decomposition: `"additive"` or `"multiplicative"`.
    pub method: String,
}
50
51/// Remove linear trend from functional data using least squares.
52///
53/// # Arguments
54/// * `data` - Column-major matrix (n x m): n samples, m evaluation points
55/// * `n` - Number of samples
56/// * `m` - Number of evaluation points
57/// * `argvals` - Time/argument values of length m
58///
59/// # Returns
60/// TrendResult with trend, detrended data, and coefficients (intercept, slope)
61pub fn detrend_linear(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> TrendResult {
62    if n == 0 || m < 2 || data.len() != n * m || argvals.len() != m {
63        return TrendResult {
64            trend: vec![0.0; n * m],
65            detrended: data.to_vec(),
66            method: "linear".to_string(),
67            coefficients: None,
68            rss: vec![0.0; n],
69            n_params: 2,
70        };
71    }
72
73    // Precompute t statistics
74    let mean_t: f64 = argvals.iter().sum::<f64>() / m as f64;
75    let ss_t: f64 = argvals.iter().map(|&t| (t - mean_t).powi(2)).sum();
76
77    // Process each sample in parallel
78    let results: Vec<(Vec<f64>, Vec<f64>, f64, f64, f64)> = iter_maybe_parallel!(0..n)
79        .map(|i| {
80            // Extract curve
81            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
82            let mean_y: f64 = curve.iter().sum::<f64>() / m as f64;
83
84            // Compute slope: sum((t - mean_t) * (y - mean_y)) / sum((t - mean_t)^2)
85            let mut sp = 0.0;
86            for j in 0..m {
87                sp += (argvals[j] - mean_t) * (curve[j] - mean_y);
88            }
89            let slope = if ss_t.abs() > 1e-15 { sp / ss_t } else { 0.0 };
90            let intercept = mean_y - slope * mean_t;
91
92            // Compute trend and detrended
93            let mut trend = vec![0.0; m];
94            let mut detrended = vec![0.0; m];
95            let mut rss = 0.0;
96            for j in 0..m {
97                trend[j] = intercept + slope * argvals[j];
98                detrended[j] = curve[j] - trend[j];
99                rss += detrended[j].powi(2);
100            }
101
102            (trend, detrended, intercept, slope, rss)
103        })
104        .collect();
105
106    // Reassemble into column-major format
107    let mut trend = vec![0.0; n * m];
108    let mut detrended = vec![0.0; n * m];
109    let mut coefficients = vec![0.0; n * 2];
110    let mut rss = vec![0.0; n];
111
112    for (i, (t, d, intercept, slope, r)) in results.into_iter().enumerate() {
113        for j in 0..m {
114            trend[i + j * n] = t[j];
115            detrended[i + j * n] = d[j];
116        }
117        coefficients[i * 2] = intercept;
118        coefficients[i * 2 + 1] = slope;
119        rss[i] = r;
120    }
121
122    TrendResult {
123        trend,
124        detrended,
125        method: "linear".to_string(),
126        coefficients: Some(coefficients),
127        rss,
128        n_params: 2,
129    }
130}
131
132/// Remove polynomial trend from functional data using QR decomposition.
133///
134/// # Arguments
135/// * `data` - Column-major matrix (n x m)
136/// * `n` - Number of samples
137/// * `m` - Number of evaluation points
138/// * `argvals` - Time/argument values of length m
139/// * `degree` - Polynomial degree (1 = linear, 2 = quadratic, etc.)
140///
141/// # Returns
142/// TrendResult with trend, detrended data, and polynomial coefficients
143pub fn detrend_polynomial(
144    data: &[f64],
145    n: usize,
146    m: usize,
147    argvals: &[f64],
148    degree: usize,
149) -> TrendResult {
150    if n == 0 || m < degree + 1 || data.len() != n * m || argvals.len() != m || degree == 0 {
151        // For degree 0 or invalid input, return original data
152        return TrendResult {
153            trend: vec![0.0; n * m],
154            detrended: data.to_vec(),
155            method: format!("polynomial({})", degree),
156            coefficients: None,
157            rss: vec![0.0; n],
158            n_params: degree + 1,
159        };
160    }
161
162    // Special case: degree 1 is linear
163    if degree == 1 {
164        let mut result = detrend_linear(data, n, m, argvals);
165        result.method = "polynomial(1)".to_string();
166        return result;
167    }
168
169    let n_coef = degree + 1;
170
171    // Normalize argvals to avoid numerical issues with high-degree polynomials
172    let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
173    let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
174    let t_range = if (t_max - t_min).abs() > 1e-15 {
175        t_max - t_min
176    } else {
177        1.0
178    };
179    let t_norm: Vec<f64> = argvals.iter().map(|&t| (t - t_min) / t_range).collect();
180
181    // Build Vandermonde matrix (m x n_coef)
182    let mut design = DMatrix::zeros(m, n_coef);
183    for j in 0..m {
184        let t = t_norm[j];
185        let mut power = 1.0;
186        for k in 0..n_coef {
187            design[(j, k)] = power;
188            power *= t;
189        }
190    }
191
192    // SVD for stable least squares
193    let svd = design.clone().svd(true, true);
194
195    // Process each sample in parallel
196    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
197        .map(|i| {
198            // Extract curve
199            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
200            let y = DVector::from_row_slice(&curve);
201
202            // Solve least squares using SVD
203            let beta = svd
204                .solve(&y, 1e-10)
205                .unwrap_or_else(|_| DVector::zeros(n_coef));
206
207            // Compute fitted values (trend) and residuals
208            let fitted = &design * &beta;
209            let mut trend = vec![0.0; m];
210            let mut detrended = vec![0.0; m];
211            let mut rss = 0.0;
212            for j in 0..m {
213                trend[j] = fitted[j];
214                detrended[j] = curve[j] - fitted[j];
215                rss += detrended[j].powi(2);
216            }
217
218            // Extract coefficients
219            let coefs: Vec<f64> = beta.iter().cloned().collect();
220
221            (trend, detrended, coefs, rss)
222        })
223        .collect();
224
225    // Reassemble into column-major format
226    let mut trend = vec![0.0; n * m];
227    let mut detrended = vec![0.0; n * m];
228    let mut coefficients = vec![0.0; n * n_coef];
229    let mut rss = vec![0.0; n];
230
231    for (i, (t, d, coefs, r)) in results.into_iter().enumerate() {
232        for j in 0..m {
233            trend[i + j * n] = t[j];
234            detrended[i + j * n] = d[j];
235        }
236        for k in 0..n_coef {
237            coefficients[i * n_coef + k] = coefs[k];
238        }
239        rss[i] = r;
240    }
241
242    TrendResult {
243        trend,
244        detrended,
245        method: format!("polynomial({})", degree),
246        coefficients: Some(coefficients),
247        rss,
248        n_params: n_coef,
249    }
250}
251
252/// Remove trend by differencing.
253///
254/// # Arguments
255/// * `data` - Column-major matrix (n x m)
256/// * `n` - Number of samples
257/// * `m` - Number of evaluation points
258/// * `order` - Differencing order (1 or 2)
259///
260/// # Returns
261/// TrendResult with trend (cumulative sum to reverse), detrended (differences),
262/// and original first values as "coefficients"
263///
264/// Note: Differencing reduces the series length by `order` points.
265/// The returned detrended data has m - order points padded with zeros at the end.
266pub fn detrend_diff(data: &[f64], n: usize, m: usize, order: usize) -> TrendResult {
267    if n == 0 || m <= order || data.len() != n * m || order == 0 || order > 2 {
268        return TrendResult {
269            trend: vec![0.0; n * m],
270            detrended: data.to_vec(),
271            method: format!("diff{}", order),
272            coefficients: None,
273            rss: vec![0.0; n],
274            n_params: order,
275        };
276    }
277
278    let new_m = m - order;
279
280    // Process each sample in parallel
281    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
282        .map(|i| {
283            // Extract curve
284            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
285
286            // First difference
287            let diff1: Vec<f64> = (0..m - 1).map(|j| curve[j + 1] - curve[j]).collect();
288
289            // Second difference if order == 2
290            let detrended = if order == 2 {
291                (0..diff1.len() - 1)
292                    .map(|j| diff1[j + 1] - diff1[j])
293                    .collect()
294            } else {
295                diff1.clone()
296            };
297
298            // Store initial values needed for reconstruction
299            let initial_values = if order == 2 {
300                vec![curve[0], curve[1]]
301            } else {
302                vec![curve[0]]
303            };
304
305            // Compute RSS (sum of squared differences as "residuals" - interpretation varies)
306            let rss: f64 = detrended.iter().map(|&x| x.powi(2)).sum();
307
308            // For "trend", we reconstruct as cumsum of differences
309            // This is a rough approximation; true trend would need integration
310            let mut trend = vec![0.0; m];
311            trend[0] = curve[0];
312            if order == 1 {
313                for j in 1..m {
314                    trend[j] = curve[j] - if j <= new_m { detrended[j - 1] } else { 0.0 };
315                }
316            } else {
317                // For order 2, trend is less meaningful
318                trend = curve.clone();
319            }
320
321            // Pad detrended to full length
322            let mut det_full = vec![0.0; m];
323            det_full[..new_m].copy_from_slice(&detrended[..new_m]);
324
325            (trend, det_full, initial_values, rss)
326        })
327        .collect();
328
329    // Reassemble
330    let mut trend = vec![0.0; n * m];
331    let mut detrended = vec![0.0; n * m];
332    let mut coefficients = vec![0.0; n * order];
333    let mut rss = vec![0.0; n];
334
335    for (i, (t, d, init, r)) in results.into_iter().enumerate() {
336        for j in 0..m {
337            trend[i + j * n] = t[j];
338            detrended[i + j * n] = d[j];
339        }
340        for k in 0..order {
341            coefficients[i * order + k] = init[k];
342        }
343        rss[i] = r;
344    }
345
346    TrendResult {
347        trend,
348        detrended,
349        method: format!("diff{}", order),
350        coefficients: Some(coefficients),
351        rss,
352        n_params: order,
353    }
354}
355
356/// Remove trend using LOESS (local polynomial regression).
357///
358/// # Arguments
359/// * `data` - Column-major matrix (n x m)
360/// * `n` - Number of samples
361/// * `m` - Number of evaluation points
362/// * `argvals` - Time/argument values
363/// * `bandwidth` - Bandwidth as fraction of data range (0.1 to 0.5 typical)
364/// * `degree` - Local polynomial degree (1 or 2)
365///
366/// # Returns
367/// TrendResult with LOESS-smoothed trend
368pub fn detrend_loess(
369    data: &[f64],
370    n: usize,
371    m: usize,
372    argvals: &[f64],
373    bandwidth: f64,
374    degree: usize,
375) -> TrendResult {
376    if n == 0 || m < 3 || data.len() != n * m || argvals.len() != m || bandwidth <= 0.0 {
377        return TrendResult {
378            trend: vec![0.0; n * m],
379            detrended: data.to_vec(),
380            method: "loess".to_string(),
381            coefficients: None,
382            rss: vec![0.0; n],
383            n_params: (m as f64 * bandwidth).ceil() as usize,
384        };
385    }
386
387    // Convert bandwidth from fraction to absolute units
388    let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
389    let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
390    let abs_bandwidth = (t_max - t_min) * bandwidth;
391
392    // Process each sample in parallel
393    let results: Vec<(Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
394        .map(|i| {
395            // Extract curve
396            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
397
398            // Apply local polynomial regression
399            let trend =
400                local_polynomial(argvals, &curve, argvals, abs_bandwidth, degree, "gaussian");
401
402            // Compute detrended and RSS
403            let mut detrended = vec![0.0; m];
404            let mut rss = 0.0;
405            for j in 0..m {
406                detrended[j] = curve[j] - trend[j];
407                rss += detrended[j].powi(2);
408            }
409
410            (trend, detrended, rss)
411        })
412        .collect();
413
414    // Reassemble
415    let mut trend = vec![0.0; n * m];
416    let mut detrended = vec![0.0; n * m];
417    let mut rss = vec![0.0; n];
418
419    for (i, (t, d, r)) in results.into_iter().enumerate() {
420        for j in 0..m {
421            trend[i + j * n] = t[j];
422            detrended[i + j * n] = d[j];
423        }
424        rss[i] = r;
425    }
426
427    // Effective number of parameters for LOESS is approximately n * bandwidth
428    let n_params = (m as f64 * bandwidth).ceil() as usize;
429
430    TrendResult {
431        trend,
432        detrended,
433        method: "loess".to_string(),
434        coefficients: None,
435        rss,
436        n_params,
437    }
438}
439
440/// Automatically select the best detrending method using AIC.
441///
442/// Compares linear, polynomial (degree 2 and 3), and LOESS,
443/// selecting the method with lowest AIC.
444///
445/// # Arguments
446/// * `data` - Column-major matrix (n x m)
447/// * `n` - Number of samples
448/// * `m` - Number of evaluation points
449/// * `argvals` - Time/argument values
450///
451/// # Returns
452/// TrendResult from the best method
453pub fn auto_detrend(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> TrendResult {
454    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m {
455        return TrendResult {
456            trend: vec![0.0; n * m],
457            detrended: data.to_vec(),
458            method: "auto(none)".to_string(),
459            coefficients: None,
460            rss: vec![0.0; n],
461            n_params: 0,
462        };
463    }
464
465    // Compute AIC for a result: AIC = n * log(RSS/n) + 2*k
466    // We use mean AIC across all samples
467    let compute_aic = |result: &TrendResult| -> f64 {
468        let mut total_aic = 0.0;
469        for i in 0..n {
470            let rss = result.rss[i];
471            let k = result.n_params as f64;
472            let aic = if rss > 1e-15 {
473                m as f64 * (rss / m as f64).ln() + 2.0 * k
474            } else {
475                f64::NEG_INFINITY // Perfect fit (unlikely)
476            };
477            total_aic += aic;
478        }
479        total_aic / n as f64
480    };
481
482    // Try different methods
483    let linear = detrend_linear(data, n, m, argvals);
484    let poly2 = detrend_polynomial(data, n, m, argvals, 2);
485    let poly3 = detrend_polynomial(data, n, m, argvals, 3);
486    let loess = detrend_loess(data, n, m, argvals, 0.3, 2);
487
488    let aic_linear = compute_aic(&linear);
489    let aic_poly2 = compute_aic(&poly2);
490    let aic_poly3 = compute_aic(&poly3);
491    let aic_loess = compute_aic(&loess);
492
493    // Find minimum AIC
494    let methods = [
495        (aic_linear, "linear", linear),
496        (aic_poly2, "polynomial(2)", poly2),
497        (aic_poly3, "polynomial(3)", poly3),
498        (aic_loess, "loess", loess),
499    ];
500
501    let (_, best_name, mut best_result) = methods
502        .into_iter()
503        .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
504        .unwrap();
505
506    best_result.method = format!("auto({})", best_name);
507    best_result
508}
509
510/// Additive seasonal decomposition: data = trend + seasonal + remainder
511///
512/// Uses LOESS or spline for trend extraction, then averages within-period
513/// residuals to estimate the seasonal component.
514///
515/// # Arguments
516/// * `data` - Column-major matrix (n x m)
517/// * `n` - Number of samples
518/// * `m` - Number of evaluation points
519/// * `argvals` - Time/argument values
520/// * `period` - Seasonal period in same units as argvals
521/// * `trend_method` - "loess" or "spline"
522/// * `bandwidth` - Bandwidth for LOESS (fraction, e.g., 0.3)
523/// * `n_harmonics` - Number of Fourier harmonics for seasonal component
524///
525/// # Returns
526/// DecomposeResult with trend, seasonal, and remainder components
527pub fn decompose_additive(
528    data: &[f64],
529    n: usize,
530    m: usize,
531    argvals: &[f64],
532    period: f64,
533    trend_method: &str,
534    bandwidth: f64,
535    n_harmonics: usize,
536) -> DecomposeResult {
537    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m || period <= 0.0 {
538        return DecomposeResult {
539            trend: vec![0.0; n * m],
540            seasonal: vec![0.0; n * m],
541            remainder: data.to_vec(),
542            period,
543            method: "additive".to_string(),
544        };
545    }
546
547    // Step 1: Extract trend using LOESS or spline
548    let trend_result = match trend_method {
549        "spline" => {
550            // Use P-spline fitting - use a larger bandwidth for trend
551            detrend_loess(data, n, m, argvals, bandwidth.max(0.3), 2)
552        }
553        _ => detrend_loess(data, n, m, argvals, bandwidth.max(0.3), 2),
554    };
555
556    // Step 2: Extract seasonal component using Fourier basis on detrended data
557    let n_harm = n_harmonics.max(1).min(m / 4);
558    let omega = 2.0 * std::f64::consts::PI / period;
559
560    // Process each sample
561    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
562        .map(|i| {
563            let trend_i: Vec<f64> = (0..m).map(|j| trend_result.trend[i + j * n]).collect();
564            let detrended_i: Vec<f64> = (0..m).map(|j| trend_result.detrended[i + j * n]).collect();
565
566            // Fit Fourier model to detrended data: sum of sin and cos terms
567            // y = sum_k (a_k * cos(k*omega*t) + b_k * sin(k*omega*t))
568            let n_coef = 2 * n_harm;
569            let mut design = DMatrix::zeros(m, n_coef);
570            for j in 0..m {
571                let t = argvals[j];
572                for k in 0..n_harm {
573                    let freq = (k + 1) as f64 * omega;
574                    design[(j, 2 * k)] = (freq * t).cos();
575                    design[(j, 2 * k + 1)] = (freq * t).sin();
576                }
577            }
578
579            // Solve least squares using SVD
580            let y = DVector::from_row_slice(&detrended_i);
581            let svd = design.clone().svd(true, true);
582            let coef = svd
583                .solve(&y, 1e-10)
584                .unwrap_or_else(|_| DVector::zeros(n_coef));
585
586            // Compute seasonal component
587            let fitted = &design * &coef;
588            let seasonal: Vec<f64> = fitted.iter().cloned().collect();
589
590            // Compute remainder
591            let remainder: Vec<f64> = (0..m).map(|j| detrended_i[j] - seasonal[j]).collect();
592
593            (trend_i, seasonal, remainder)
594        })
595        .collect();
596
597    // Reassemble into column-major format
598    let mut trend = vec![0.0; n * m];
599    let mut seasonal = vec![0.0; n * m];
600    let mut remainder = vec![0.0; n * m];
601
602    for (i, (t, s, r)) in results.into_iter().enumerate() {
603        for j in 0..m {
604            trend[i + j * n] = t[j];
605            seasonal[i + j * n] = s[j];
606            remainder[i + j * n] = r[j];
607        }
608    }
609
610    DecomposeResult {
611        trend,
612        seasonal,
613        remainder,
614        period,
615        method: "additive".to_string(),
616    }
617}
618
619/// Multiplicative seasonal decomposition: data = trend * seasonal * remainder
620///
621/// Applies log transformation, then additive decomposition, then back-transforms.
622/// Handles non-positive values by adding a shift.
623///
624/// # Arguments
625/// * `data` - Column-major matrix (n x m)
626/// * `n` - Number of samples
627/// * `m` - Number of evaluation points
628/// * `argvals` - Time/argument values
629/// * `period` - Seasonal period
630/// * `trend_method` - "loess" or "spline"
631/// * `bandwidth` - Bandwidth for LOESS
632/// * `n_harmonics` - Number of Fourier harmonics
633///
634/// # Returns
635/// DecomposeResult with multiplicative components
636pub fn decompose_multiplicative(
637    data: &[f64],
638    n: usize,
639    m: usize,
640    argvals: &[f64],
641    period: f64,
642    trend_method: &str,
643    bandwidth: f64,
644    n_harmonics: usize,
645) -> DecomposeResult {
646    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m || period <= 0.0 {
647        return DecomposeResult {
648            trend: vec![0.0; n * m],
649            seasonal: vec![0.0; n * m],
650            remainder: data.to_vec(),
651            period,
652            method: "multiplicative".to_string(),
653        };
654    }
655
656    // Find minimum value and add shift if needed to make all values positive
657    let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
658    let shift = if min_val <= 0.0 { -min_val + 1.0 } else { 0.0 };
659
660    // Log transform
661    let log_data: Vec<f64> = data.iter().map(|&x| (x + shift).ln()).collect();
662
663    // Apply additive decomposition to log data
664    let additive_result = decompose_additive(
665        &log_data,
666        n,
667        m,
668        argvals,
669        period,
670        trend_method,
671        bandwidth,
672        n_harmonics,
673    );
674
675    // Back transform: exp of each component
676    // For multiplicative: data = trend * seasonal * remainder
677    // In log space: log(data) = log(trend) + log(seasonal) + log(remainder)
678    // So: trend_mult = exp(trend_add), seasonal_mult = exp(seasonal_add), etc.
679
680    let mut trend = vec![0.0; n * m];
681    let mut seasonal = vec![0.0; n * m];
682    let mut remainder = vec![0.0; n * m];
683
684    for idx in 0..n * m {
685        // Back-transform trend (subtract shift)
686        trend[idx] = additive_result.trend[idx].exp() - shift;
687
688        // Seasonal is a multiplicative factor (centered around 1)
689        // We interpret the additive seasonal component as log(seasonal factor)
690        seasonal[idx] = additive_result.seasonal[idx].exp();
691
692        // Remainder is also multiplicative
693        remainder[idx] = additive_result.remainder[idx].exp();
694    }
695
696    DecomposeResult {
697        trend,
698        seasonal,
699        remainder,
700        period,
701        method: "multiplicative".to_string(),
702    }
703}
704
705// ============================================================================
706// STL Decomposition (Cleveland et al., 1990)
707// ============================================================================
708
/// Output of [`stl_decompose`]: the STL components plus the settings actually used.
///
/// Component matrices use the input layout: column-major, `n` samples by `m` points,
/// with entry `(i, j)` at index `i + j * n`.
#[derive(Debug, Clone)]
pub struct StlResult {
    /// Trend component.
    pub trend: Vec<f64>,
    /// Seasonal component.
    pub seasonal: Vec<f64>,
    /// Remainder: data minus trend minus seasonal.
    pub remainder: Vec<f64>,
    /// Robustness weight of every point (all 1.0 unless robustness iterations ran).
    pub weights: Vec<f64>,
    /// Seasonal period, in observations per cycle.
    pub period: usize,
    /// Seasonal (cycle-subseries) smoothing window that was applied; 0 if input was rejected.
    pub s_window: usize,
    /// Trend smoothing window that was applied; 0 if input was rejected.
    pub t_window: usize,
    /// Inner-loop passes per outer iteration; 0 if input was rejected.
    pub inner_iterations: usize,
    /// Outer (robustness) iterations; 0 if input was rejected.
    pub outer_iterations: usize,
}
731
732/// STL Decomposition: Seasonal and Trend decomposition using LOESS
733///
734/// Implements the Cleveland et al. (1990) algorithm for robust iterative
735/// decomposition of time series into trend, seasonal, and remainder components.
736///
737/// # Algorithm Overview
738/// - **Inner Loop**: Extracts seasonal and trend components using LOESS smoothing
739/// - **Outer Loop**: Computes robustness weights to downweight outliers
740///
741/// # Arguments
742/// * `data` - Column-major matrix (n x m): n samples, m evaluation points
743/// * `n` - Number of samples
744/// * `m` - Number of evaluation points
745/// * `period` - Seasonal period (number of observations per cycle)
746/// * `s_window` - Seasonal smoothing window (must be odd, ≥7 recommended)
747/// * `t_window` - Trend smoothing window. If None, uses default formula
748/// * `l_window` - Low-pass filter window. If None, uses period
749/// * `robust` - Whether to perform robustness iterations
750/// * `inner_iterations` - Number of inner loop iterations. Default: 2
751/// * `outer_iterations` - Number of outer loop iterations. Default: 1 (or 15 if robust)
752///
753/// # Returns
754/// `StlResult` with trend, seasonal, remainder, and robustness weights
755///
756/// # References
757/// Cleveland, R. B., Cleveland, W. S., McRae, J. E., & Terpenning, I. (1990).
758/// STL: A Seasonal-Trend Decomposition Procedure Based on Loess.
759/// Journal of Official Statistics, 6(1), 3-73.
760pub fn stl_decompose(
761    data: &[f64],
762    n: usize,
763    m: usize,
764    period: usize,
765    s_window: Option<usize>,
766    t_window: Option<usize>,
767    l_window: Option<usize>,
768    robust: bool,
769    inner_iterations: Option<usize>,
770    outer_iterations: Option<usize>,
771) -> StlResult {
772    // Validate inputs
773    if n == 0 || m < 2 * period || data.len() != n * m || period < 2 {
774        return StlResult {
775            trend: vec![0.0; n * m],
776            seasonal: vec![0.0; n * m],
777            remainder: data.to_vec(),
778            weights: vec![1.0; n * m],
779            period,
780            s_window: 0,
781            t_window: 0,
782            inner_iterations: 0,
783            outer_iterations: 0,
784        };
785    }
786
787    // Set default parameters following Cleveland et al. recommendations
788    let s_win = s_window.unwrap_or(7).max(3) | 1; // Ensure odd
789
790    // Default t_window: smallest odd integer >= (1.5 * period) / (1 - 1.5/s_window)
791    let t_win = t_window.unwrap_or_else(|| {
792        let ratio = 1.5 * period as f64 / (1.0 - 1.5 / s_win as f64);
793        let val = ratio.ceil() as usize;
794        val.max(3) | 1 // Ensure odd
795    });
796
797    // Low-pass filter window: smallest odd integer >= period
798    let l_win = l_window.unwrap_or(period) | 1;
799
800    let n_inner = inner_iterations.unwrap_or(2);
801    let n_outer = outer_iterations.unwrap_or(if robust { 15 } else { 1 });
802
803    // Process each sample in parallel
804    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
805        .map(|i| {
806            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
807            stl_single_series(
808                &curve, period, s_win, t_win, l_win, robust, n_inner, n_outer,
809            )
810        })
811        .collect();
812
813    // Reassemble into column-major format
814    let mut trend = vec![0.0; n * m];
815    let mut seasonal = vec![0.0; n * m];
816    let mut remainder = vec![0.0; n * m];
817    let mut weights = vec![1.0; n * m];
818
819    for (i, (t, s, r, w)) in results.into_iter().enumerate() {
820        for j in 0..m {
821            trend[i + j * n] = t[j];
822            seasonal[i + j * n] = s[j];
823            remainder[i + j * n] = r[j];
824            weights[i + j * n] = w[j];
825        }
826    }
827
828    StlResult {
829        trend,
830        seasonal,
831        remainder,
832        weights,
833        period,
834        s_window: s_win,
835        t_window: t_win,
836        inner_iterations: n_inner,
837        outer_iterations: n_outer,
838    }
839}
840
841/// STL decomposition for a single time series.
842fn stl_single_series(
843    data: &[f64],
844    period: usize,
845    s_window: usize,
846    t_window: usize,
847    l_window: usize,
848    robust: bool,
849    n_inner: usize,
850    n_outer: usize,
851) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
852    let m = data.len();
853
854    // Initialize components
855    let mut trend = vec![0.0; m];
856    let mut seasonal = vec![0.0; m];
857    let mut weights = vec![1.0; m];
858
859    // Outer loop for robustness
860    for _outer in 0..n_outer {
861        // Inner loop
862        for _inner in 0..n_inner {
863            // Step 1: Detrending
864            let detrended: Vec<f64> = data
865                .iter()
866                .zip(trend.iter())
867                .map(|(&y, &t)| y - t)
868                .collect();
869
870            // Step 2: Cycle-subseries smoothing
871            let cycle_smoothed = smooth_cycle_subseries(&detrended, period, s_window, &weights);
872
873            // Step 3: Low-pass filtering of smoothed cycle-subseries
874            let low_pass = stl_lowpass_filter(&cycle_smoothed, period, l_window);
875
876            // Step 4: Detrending the smoothed cycle-subseries
877            seasonal = cycle_smoothed
878                .iter()
879                .zip(low_pass.iter())
880                .map(|(&c, &l)| c - l)
881                .collect();
882
883            // Step 5: Deseasonalizing
884            let deseasonalized: Vec<f64> = data
885                .iter()
886                .zip(seasonal.iter())
887                .map(|(&y, &s)| y - s)
888                .collect();
889
890            // Step 6: Trend smoothing (weighted LOESS)
891            trend = weighted_loess(&deseasonalized, t_window, &weights);
892        }
893
894        // After inner loop: compute residuals and robustness weights
895        if robust && _outer < n_outer - 1 {
896            let remainder: Vec<f64> = data
897                .iter()
898                .zip(trend.iter())
899                .zip(seasonal.iter())
900                .map(|((&y, &t), &s)| y - t - s)
901                .collect();
902
903            weights = compute_robustness_weights(&remainder);
904        }
905    }
906
907    // Final remainder
908    let remainder: Vec<f64> = data
909        .iter()
910        .zip(trend.iter())
911        .zip(seasonal.iter())
912        .map(|((&y, &t), &s)| y - t - s)
913        .collect();
914
915    (trend, seasonal, remainder, weights)
916}
917
918/// Smooth cycle-subseries: for each seasonal position, smooth across cycles.
919fn smooth_cycle_subseries(
920    data: &[f64],
921    period: usize,
922    s_window: usize,
923    weights: &[f64],
924) -> Vec<f64> {
925    let m = data.len();
926    let n_cycles = (m + period - 1) / period;
927    let mut result = vec![0.0; m];
928
929    // For each position in the cycle (0, 1, ..., period-1)
930    for pos in 0..period {
931        // Extract subseries at this position
932        let mut subseries_idx: Vec<usize> = Vec::new();
933        let mut subseries_vals: Vec<f64> = Vec::new();
934        let mut subseries_weights: Vec<f64> = Vec::new();
935
936        for cycle in 0..n_cycles {
937            let idx = cycle * period + pos;
938            if idx < m {
939                subseries_idx.push(idx);
940                subseries_vals.push(data[idx]);
941                subseries_weights.push(weights[idx]);
942            }
943        }
944
945        if subseries_vals.is_empty() {
946            continue;
947        }
948
949        // Smooth this subseries using weighted LOESS
950        let smoothed = weighted_loess(&subseries_vals, s_window, &subseries_weights);
951
952        // Put smoothed values back
953        for (i, &idx) in subseries_idx.iter().enumerate() {
954            result[idx] = smoothed[i];
955        }
956    }
957
958    result
959}
960
961/// Low-pass filter for STL (combination of moving averages).
962/// Applies: MA(period) -> MA(period) -> MA(3)
963fn stl_lowpass_filter(data: &[f64], period: usize, _l_window: usize) -> Vec<f64> {
964    // First MA with period
965    let ma1 = moving_average(data, period);
966    // Second MA with period
967    let ma2 = moving_average(&ma1, period);
968    // Third MA with 3
969    moving_average(&ma2, 3)
970}
971
/// Centred moving average whose window is truncated at the series edges.
///
/// Output `i` is the plain mean of `data[i - window/2 ..= i + window/2]`,
/// clipped to the valid index range. Fewer points are therefore averaged
/// near the boundaries. The span is `2 * (window / 2) + 1`, so an even
/// `window` actually averages `window + 1` points.
///
/// An empty input or `window == 0` returns an unchanged copy of `data`.
fn moving_average(data: &[f64], window: usize) -> Vec<f64> {
    let len = data.len();
    if len == 0 || window == 0 {
        return data.to_vec();
    }

    let radius = window / 2;
    (0..len)
        .map(|centre| {
            let lo = centre.saturating_sub(radius);
            let hi = len.min(centre + radius + 1);
            let span = &data[lo..hi];
            span.iter().sum::<f64>() / span.len() as f64
        })
        .collect()
}
992
/// Weighted LOESS smoothing (local linear regression with a tricube kernel).
///
/// For every index `i`, fits a weighted straight line to the points
/// `j` in `[i - window/2, i + window/2]` (clipped to the series bounds) and
/// evaluates it at `i`. The weight of point `j` is the product of:
/// - a tricube kernel on `|j - i| / max(window/2, 1)`, where points at
///   distance exactly `window/2` receive zero kernel weight, and
/// - the caller-supplied `weights[j]` (for example, STL robustness weights).
///
/// The abscissa is centred on `i` before accumulating the normal equations,
/// so the fitted value at `i` is simply the intercept. With raw indices
/// (which grow up to `data.len()`), `sum_w * sum_wxx` and `sum_wx^2` are huge,
/// nearly equal numbers whose difference loses precision on long series.
/// The centred form is mathematically identical but well conditioned.
///
/// Degenerate cases:
/// - a singular local design (for example, a single effective point) falls
///   back to the weighted mean;
/// - vanishing total weight falls back to `data[i]`.
///
/// # Panics
/// Panics if `weights` is shorter than `data`.
fn weighted_loess(data: &[f64], window: usize, weights: &[f64]) -> Vec<f64> {
    let m = data.len();
    if m == 0 {
        return vec![];
    }

    let half = window / 2;
    // Kernel bandwidth; `max(1)` avoids dividing by zero for windows < 2.
    let bandwidth = half.max(1) as f64;

    (0..m)
        .map(|i| {
            let start = i.saturating_sub(half);
            let end = (i + half + 1).min(m);

            let mut sum_w = 0.0;
            let mut sum_wx = 0.0;
            let mut sum_wy = 0.0;
            let mut sum_wxx = 0.0;
            let mut sum_wxy = 0.0;

            for j in start..end {
                // Offset from the target point: also the centred abscissa.
                let x = j as f64 - i as f64;
                let dist = x.abs() / bandwidth;
                let tricube = if dist < 1.0 {
                    (1.0 - dist.powi(3)).powi(3)
                } else {
                    0.0
                };

                let w = tricube * weights[j];
                let y = data[j];

                sum_w += w;
                sum_wx += w * x;
                sum_wy += w * y;
                sum_wxx += w * x * x;
                sum_wxy += w * x * y;
            }

            if sum_w > 1e-10 {
                // Translation invariant: sum_w * weighted variance of x.
                let denom = sum_w * sum_wxx - sum_wx * sum_wx;
                if denom.abs() > 1e-10 {
                    // Line evaluated at x = 0 (the target point) is its intercept.
                    (sum_wxx * sum_wy - sum_wx * sum_wxy) / denom
                } else {
                    sum_wy / sum_w
                }
            } else {
                data[i]
            }
        })
        .collect()
}
1051
/// Compute STL robustness weights from residuals using the bisquare function.
///
/// The scale is `h = 6 * median(|r|)`, following the "6 MAD" rule of
/// Cleveland et al. Here the median is taken over absolute residuals, not
/// deviations about their median. The median is floored at `1e-10`, so an
/// all-zero residual vector yields unit weights. Each weight is
/// `(1 - (|r| / h)^2)^2` for `|r| < h` and `0` otherwise.
///
/// Absolute residuals are ordered with `f64::total_cmp`, which is a true
/// total order, so NaN residuals sort after every finite value. A
/// `partial_cmp(..).unwrap_or(Equal)` comparator is not a total order once
/// NaNs appear: it can corrupt the median, and the standard sort may panic
/// (Rust >= 1.81). NaN residuals receive weight `0`.
fn compute_robustness_weights(residuals: &[f64]) -> Vec<f64> {
    if residuals.is_empty() {
        return vec![];
    }

    let mut abs_residuals: Vec<f64> = residuals.iter().map(|r| r.abs()).collect();
    // Stability is irrelevant for plain floats, so the unstable sort is fine.
    abs_residuals.sort_unstable_by(f64::total_cmp);

    let m = abs_residuals.len();
    let mid = m / 2;
    let median_abs = if m % 2 == 0 {
        (abs_residuals[mid - 1] + abs_residuals[mid]) / 2.0
    } else {
        abs_residuals[mid]
    };

    let h = 6.0 * median_abs.max(1e-10);

    residuals
        .iter()
        .map(|&r| {
            let u = r.abs() / h;
            if u < 1.0 {
                (1.0 - u * u).powi(2)
            } else {
                0.0
            }
        })
        .collect()
}
1086
1087/// Wrapper function for functional data STL decomposition.
1088///
1089/// Computes STL decomposition for each curve in the functional data object
1090/// and returns aggregated results.
1091///
1092/// # Arguments
1093/// * `data` - Column-major matrix (n x m) of functional data
1094/// * `n` - Number of samples (rows)
1095/// * `m` - Number of evaluation points (columns)
1096/// * `argvals` - Time points of length m (used to infer period if needed)
1097/// * `period` - Seasonal period (in number of observations)
1098/// * `s_window` - Seasonal smoothing window
1099/// * `t_window` - Trend smoothing window (0 for auto)
1100/// * `robust` - Whether to use robustness iterations
1101///
1102/// # Returns
1103/// `StlResult` with decomposed components.
1104pub fn stl_fdata(
1105    data: &[f64],
1106    n: usize,
1107    m: usize,
1108    _argvals: &[f64],
1109    period: usize,
1110    s_window: Option<usize>,
1111    t_window: Option<usize>,
1112    robust: bool,
1113) -> StlResult {
1114    stl_decompose(
1115        data, n, m, period, s_window, t_window, None, robust, None, None,
1116    )
1117}
1118
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Pearson correlation coefficient between two equally long slices.
    fn pearson_correlation(a: &[f64], b: &[f64]) -> f64 {
        assert_eq!(a.len(), b.len(), "correlation inputs must have equal length");
        let len = a.len() as f64;
        let mean_a = a.iter().sum::<f64>() / len;
        let mean_b = b.iter().sum::<f64>() / len;
        let mut num = 0.0;
        let mut den_a = 0.0;
        let mut den_b = 0.0;
        for (&x, &y) in a.iter().zip(b) {
            num += (x - mean_a) * (y - mean_b);
            den_a += (x - mean_a).powi(2);
            den_b += (y - mean_b).powi(2);
        }
        num / (den_a.sqrt() * den_b.sqrt())
    }

    #[test]
    fn test_detrend_linear_removes_linear_trend() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // y = 2 + 0.5*t + sin(2*pi*t/2)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / 2.0).sin())
            .collect();

        let result = detrend_linear(&data, 1, m, &argvals);

        // Detrended should be approximately sin wave
        let expected: Vec<f64> = argvals
            .iter()
            .map(|&t| (2.0 * PI * t / 2.0).sin())
            .collect();

        let mut max_diff = 0.0f64;
        for j in 0..m {
            let diff = (result.detrended[j] - expected[j]).abs();
            max_diff = max_diff.max(diff);
        }
        assert!(max_diff < 0.2, "Max difference: {}", max_diff);
    }

    #[test]
    fn test_detrend_polynomial_removes_quadratic_trend() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // y = 1 + 0.5*t - 0.1*t^2 + sin(2*pi*t/2)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + 0.5 * t - 0.1 * t * t + (2.0 * PI * t / 2.0).sin())
            .collect();

        let result = detrend_polynomial(&data, 1, m, &argvals, 2);

        // Detrended should be approximately sin wave
        let expected: Vec<f64> = argvals
            .iter()
            .map(|&t| (2.0 * PI * t / 2.0).sin())
            .collect();

        let corr = pearson_correlation(&result.detrended, &expected);
        assert!(corr > 0.95, "Correlation: {}", corr);
    }

    #[test]
    fn test_detrend_diff1() {
        let m = 100;
        // Deterministic walk: cumulative sum of sine-shaped increments
        let data: Vec<f64> = {
            let mut v = vec![0.0; m];
            v[0] = 1.0;
            for i in 1..m {
                v[i] = v[i - 1] + 0.1 * (i as f64).sin();
            }
            v
        };

        let result = detrend_diff(&data, 1, m, 1);

        // First difference should recover the increments
        for j in 0..m - 1 {
            let expected = data[j + 1] - data[j];
            assert!(
                (result.detrended[j] - expected).abs() < 1e-10,
                "Mismatch at {}: {} vs {}",
                j,
                result.detrended[j],
                expected
            );
        }
    }

    #[test]
    fn test_auto_detrend_selects_linear_for_linear_data() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();

        // Pure linear trend (no noise)
        let data: Vec<f64> = argvals.iter().map(|&t| 2.0 + 0.5 * t).collect();

        let result = auto_detrend(&data, 1, m, &argvals);

        // Should select linear (or poly 2/3 with linear being sufficient)
        assert!(
            result.method.contains("linear") || result.method.contains("polynomial"),
            "Method: {}",
            result.method
        );
    }

    // ========================================================================
    // Tests for detrend_loess
    // ========================================================================

    #[test]
    fn test_detrend_loess_removes_linear_trend() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // y = 2 + 0.5*t + sin(2*pi*t/2)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / 2.0).sin())
            .collect();

        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 1);

        // Detrended should be approximately sin wave
        let expected: Vec<f64> = argvals
            .iter()
            .map(|&t| (2.0 * PI * t / 2.0).sin())
            .collect();

        // Correlation rather than pointwise error: LOESS may smooth slightly
        let corr = pearson_correlation(&result.detrended, &expected);
        assert!(corr > 0.9, "Correlation: {}", corr);
        assert_eq!(result.method, "loess");
    }

    #[test]
    fn test_detrend_loess_removes_quadratic_trend() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // y = 1 + 0.3*t - 0.05*t^2 + sin(2*pi*t/2)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + 0.3 * t - 0.05 * t * t + (2.0 * PI * t / 2.0).sin())
            .collect();

        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 2);

        // Output shapes match the input
        assert_eq!(result.trend.len(), m);
        assert_eq!(result.detrended.len(), m);

        // Check that RSS is computed
        assert!(result.rss[0] > 0.0);
    }

    #[test]
    fn test_detrend_loess_different_bandwidths() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Noisy sine wave
        let data: Vec<f64> = argvals
            .iter()
            .enumerate()
            .map(|(i, &t)| (2.0 * PI * t / 2.0).sin() + 0.1 * ((i * 17) % 100) as f64 / 100.0)
            .collect();

        // Small bandwidth = more local = rougher trend
        let result_small = detrend_loess(&data, 1, m, &argvals, 0.1, 1);
        // Large bandwidth = smoother trend
        let result_large = detrend_loess(&data, 1, m, &argvals, 0.5, 1);

        // Both should produce valid results
        assert_eq!(result_small.trend.len(), m);
        assert_eq!(result_large.trend.len(), m);

        // Large bandwidth should have more parameters
        assert!(result_large.n_params >= result_small.n_params);
    }

    #[test]
    fn test_detrend_loess_short_series() {
        let m = 10;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
        let data: Vec<f64> = argvals.iter().map(|&t| t * 2.0).collect();

        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 1);

        // Should still work on short series
        assert_eq!(result.trend.len(), m);
        assert_eq!(result.detrended.len(), m);
    }

    // ========================================================================
    // Tests for decompose_additive
    // ========================================================================

    #[test]
    fn test_decompose_additive_separates_components() {
        let m = 200;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // data = trend + seasonal: y = 2 + 0.5*t + sin(2*pi*t/2)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / period).sin())
            .collect();

        let result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);

        assert_eq!(result.trend.len(), m);
        assert_eq!(result.seasonal.len(), m);
        assert_eq!(result.remainder.len(), m);
        assert_eq!(result.method, "additive");
        assert_eq!(result.period, period);

        // Check that components approximately sum to original
        for j in 0..m {
            let reconstructed = result.trend[j] + result.seasonal[j] + result.remainder[j];
            assert!(
                (reconstructed - data[j]).abs() < 0.5,
                "Reconstruction error at {}: {} vs {}",
                j,
                reconstructed,
                data[j]
            );
        }
    }

    #[test]
    fn test_decompose_additive_different_harmonics() {
        let m = 200;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Simple seasonal pattern
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + (2.0 * PI * t / period).sin())
            .collect();

        // 1 harmonic
        let result1 = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 1);
        // 5 harmonics
        let result5 = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 5);

        // Both should produce valid results
        assert_eq!(result1.seasonal.len(), m);
        assert_eq!(result5.seasonal.len(), m);
    }

    #[test]
    fn test_decompose_additive_residual_properties() {
        let m = 200;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Data with trend and seasonal
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.3 * t + (2.0 * PI * t / period).sin())
            .collect();

        let result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);

        // Remainder should have mean close to zero
        let mean_rem: f64 = result.remainder.iter().sum::<f64>() / m as f64;
        assert!(mean_rem.abs() < 0.5, "Remainder mean: {}", mean_rem);

        // Remainder variance should be smaller than original variance.
        // The data mean is computed once instead of once per element.
        let mean_data: f64 = data.iter().sum::<f64>() / m as f64;
        let var_data: f64 = data
            .iter()
            .map(|&x| (x - mean_data).powi(2))
            .sum::<f64>()
            / m as f64;
        let var_rem: f64 = result
            .remainder
            .iter()
            .map(|&x| (x - mean_rem).powi(2))
            .sum::<f64>()
            / m as f64;
        assert!(
            var_rem < var_data,
            "Remainder variance {} should be < data variance {}",
            var_rem,
            var_data
        );
    }

    #[test]
    fn test_decompose_additive_multi_sample() {
        let n = 3;
        let m = 100;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Create 3 samples with different amplitudes
        let mut data = vec![0.0; n * m];
        for i in 0..n {
            let amp = (i + 1) as f64;
            for j in 0..m {
                data[i + j * n] =
                    1.0 + 0.1 * argvals[j] + amp * (2.0 * PI * argvals[j] / period).sin();
            }
        }

        let result = decompose_additive(&data, n, m, &argvals, period, "loess", 0.3, 2);

        assert_eq!(result.trend.len(), n * m);
        assert_eq!(result.seasonal.len(), n * m);
        assert_eq!(result.remainder.len(), n * m);
    }

    // ========================================================================
    // Tests for decompose_multiplicative
    // ========================================================================

    #[test]
    fn test_decompose_multiplicative_basic() {
        let m = 200;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Multiplicative: data = trend * seasonal
        // trend = 2 + 0.1*t, seasonal = 1 + 0.3*sin(...)
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| (2.0 + 0.1 * t) * (1.0 + 0.3 * (2.0 * PI * t / period).sin()))
            .collect();

        let result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 3);

        assert_eq!(result.trend.len(), m);
        assert_eq!(result.seasonal.len(), m);
        assert_eq!(result.remainder.len(), m);
        assert_eq!(result.method, "multiplicative");

        // Seasonal factors should be centered around 1
        let mean_seasonal: f64 = result.seasonal.iter().sum::<f64>() / m as f64;
        assert!(
            (mean_seasonal - 1.0).abs() < 0.5,
            "Mean seasonal factor: {}",
            mean_seasonal
        );
    }

    #[test]
    fn test_decompose_multiplicative_non_positive_data() {
        let m = 100;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Data with negative values
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| -1.0 + (2.0 * PI * t / period).sin())
            .collect();

        // Should handle negative values by shifting
        let result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 2);

        assert_eq!(result.trend.len(), m);
        assert_eq!(result.seasonal.len(), m);
        // Every seasonal factor must be finite even though the input is negative
        for &s in result.seasonal.iter() {
            assert!(s.is_finite(), "Seasonal should be finite");
        }
    }

    #[test]
    fn test_decompose_multiplicative_vs_additive() {
        let m = 200;
        let period = 2.0;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();

        // Simple positive data
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 5.0 + (2.0 * PI * t / period).sin())
            .collect();

        let add_result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);
        let mult_result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 3);

        // Both should produce valid decompositions
        assert_eq!(add_result.seasonal.len(), m);
        assert_eq!(mult_result.seasonal.len(), m);

        // Additive seasonal oscillates around 0
        let add_mean: f64 = add_result.seasonal.iter().sum::<f64>() / m as f64;
        // Multiplicative seasonal oscillates around 1
        let mult_mean: f64 = mult_result.seasonal.iter().sum::<f64>() / m as f64;

        assert!(
            add_mean.abs() < mult_mean,
            "Additive mean {} vs mult mean {}",
            add_mean,
            mult_mean
        );
    }

    #[test]
    fn test_decompose_multiplicative_edge_cases() {
        // Empty data
        let result = decompose_multiplicative(&[], 0, 0, &[], 2.0, "loess", 0.3, 2);
        assert_eq!(result.trend.len(), 0);

        // Very short series
        let m = 5;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 2.0, 1.0];
        let result = decompose_multiplicative(&data, 1, m, &argvals, 2.0, "loess", 0.3, 1);
        // Should return original data as remainder for too-short series
        assert_eq!(result.remainder.len(), m);
    }
}