fdars_core/
detrend.rs

1//! Detrending and decomposition functions for non-stationary functional data.
2//!
3//! This module provides methods for removing trends from functional data
4//! to enable more accurate seasonal analysis. It includes:
5//! - Linear detrending (least squares)
//! - Polynomial detrending (least squares solved via SVD)
7//! - Differencing (first and second order)
8//! - LOESS detrending (local polynomial regression)
9//! - Spline detrending (P-splines)
10//! - Automatic method selection via AIC
11
12use crate::iter_maybe_parallel;
13use crate::smoothing::local_polynomial;
14use nalgebra::{DMatrix, DVector};
15#[cfg(feature = "parallel")]
16use rayon::iter::ParallelIterator;
17
/// Result of a detrending operation.
///
/// All matrices use the same column-major `n x m` layout as the input data:
/// element `(i, j)` (sample `i`, evaluation point `j`) lives at index `i + j * n`.
#[derive(Debug, Clone)]
pub struct TrendResult {
    /// Estimated trend values (n x m column-major).
    pub trend: Vec<f64>,
    /// Detrended data (n x m column-major).
    pub detrended: Vec<f64>,
    /// Name of the method that produced this result, e.g. `"linear"`,
    /// `"polynomial(2)"`, `"diff1"`, `"loess"` or `"auto(...)"`.
    pub method: String,
    /// Per-sample values stored sample by sample: with `p` values per sample,
    /// `coefficients[i * p + k]` is value `k` for sample `i`.
    ///
    /// What the values mean depends on the producing method:
    /// - linear fits store `(intercept, slope)` (`p = 2`);
    /// - polynomial fits of degree `d` store `d + 1` polynomial coefficients
    ///   (for `d >= 2` these refer to argvals rescaled to `[0, 1]`);
    /// - differencing stores the first `order` raw values of each curve.
    ///
    /// `None` when the method has no coefficients (LOESS) or the input was rejected.
    pub coefficients: Option<Vec<f64>>,
    /// Residual sum of squares for each sample (length `n`); for differencing
    /// this is the sum of squared differences.
    pub rss: Vec<f64>,
    /// Number of parameters charged to this method in the AIC comparison.
    pub n_params: usize,
}
35
/// Result of a seasonal decomposition.
///
/// All component matrices use the column-major `n x m` layout of the input:
/// element `(i, j)` lives at index `i + j * n`.
#[derive(Debug, Clone)]
pub struct DecomposeResult {
    /// Trend component (n x m column-major).
    pub trend: Vec<f64>,
    /// Seasonal component (n x m column-major). For multiplicative
    /// decompositions this is a factor rather than an offset.
    pub seasonal: Vec<f64>,
    /// Remainder/residual component (n x m column-major). For multiplicative
    /// decompositions this is a factor rather than an offset.
    pub remainder: Vec<f64>,
    /// Seasonal period used for the decomposition, in the units of argvals.
    pub period: f64,
    /// Decomposition method: `"additive"` or `"multiplicative"`.
    pub method: String,
}
50
51/// Remove linear trend from functional data using least squares.
52///
53/// # Arguments
54/// * `data` - Column-major matrix (n x m): n samples, m evaluation points
55/// * `n` - Number of samples
56/// * `m` - Number of evaluation points
57/// * `argvals` - Time/argument values of length m
58///
59/// # Returns
60/// TrendResult with trend, detrended data, and coefficients (intercept, slope)
61pub fn detrend_linear(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> TrendResult {
62    if n == 0 || m < 2 || data.len() != n * m || argvals.len() != m {
63        return TrendResult {
64            trend: vec![0.0; n * m],
65            detrended: data.to_vec(),
66            method: "linear".to_string(),
67            coefficients: None,
68            rss: vec![0.0; n],
69            n_params: 2,
70        };
71    }
72
73    // Precompute t statistics
74    let mean_t: f64 = argvals.iter().sum::<f64>() / m as f64;
75    let ss_t: f64 = argvals.iter().map(|&t| (t - mean_t).powi(2)).sum();
76
77    // Process each sample in parallel
78    let results: Vec<(Vec<f64>, Vec<f64>, f64, f64, f64)> = iter_maybe_parallel!(0..n)
79        .map(|i| {
80            // Extract curve
81            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
82            let mean_y: f64 = curve.iter().sum::<f64>() / m as f64;
83
84            // Compute slope: sum((t - mean_t) * (y - mean_y)) / sum((t - mean_t)^2)
85            let mut sp = 0.0;
86            for j in 0..m {
87                sp += (argvals[j] - mean_t) * (curve[j] - mean_y);
88            }
89            let slope = if ss_t.abs() > 1e-15 { sp / ss_t } else { 0.0 };
90            let intercept = mean_y - slope * mean_t;
91
92            // Compute trend and detrended
93            let mut trend = vec![0.0; m];
94            let mut detrended = vec![0.0; m];
95            let mut rss = 0.0;
96            for j in 0..m {
97                trend[j] = intercept + slope * argvals[j];
98                detrended[j] = curve[j] - trend[j];
99                rss += detrended[j].powi(2);
100            }
101
102            (trend, detrended, intercept, slope, rss)
103        })
104        .collect();
105
106    // Reassemble into column-major format
107    let mut trend = vec![0.0; n * m];
108    let mut detrended = vec![0.0; n * m];
109    let mut coefficients = vec![0.0; n * 2];
110    let mut rss = vec![0.0; n];
111
112    for (i, (t, d, intercept, slope, r)) in results.into_iter().enumerate() {
113        for j in 0..m {
114            trend[i + j * n] = t[j];
115            detrended[i + j * n] = d[j];
116        }
117        coefficients[i * 2] = intercept;
118        coefficients[i * 2 + 1] = slope;
119        rss[i] = r;
120    }
121
122    TrendResult {
123        trend,
124        detrended,
125        method: "linear".to_string(),
126        coefficients: Some(coefficients),
127        rss,
128        n_params: 2,
129    }
130}
131
132/// Remove polynomial trend from functional data using QR decomposition.
133///
134/// # Arguments
135/// * `data` - Column-major matrix (n x m)
136/// * `n` - Number of samples
137/// * `m` - Number of evaluation points
138/// * `argvals` - Time/argument values of length m
139/// * `degree` - Polynomial degree (1 = linear, 2 = quadratic, etc.)
140///
141/// # Returns
142/// TrendResult with trend, detrended data, and polynomial coefficients
143pub fn detrend_polynomial(
144    data: &[f64],
145    n: usize,
146    m: usize,
147    argvals: &[f64],
148    degree: usize,
149) -> TrendResult {
150    if n == 0 || m < degree + 1 || data.len() != n * m || argvals.len() != m || degree == 0 {
151        // For degree 0 or invalid input, return original data
152        return TrendResult {
153            trend: vec![0.0; n * m],
154            detrended: data.to_vec(),
155            method: format!("polynomial({})", degree),
156            coefficients: None,
157            rss: vec![0.0; n],
158            n_params: degree + 1,
159        };
160    }
161
162    // Special case: degree 1 is linear
163    if degree == 1 {
164        let mut result = detrend_linear(data, n, m, argvals);
165        result.method = "polynomial(1)".to_string();
166        return result;
167    }
168
169    let n_coef = degree + 1;
170
171    // Normalize argvals to avoid numerical issues with high-degree polynomials
172    let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
173    let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
174    let t_range = if (t_max - t_min).abs() > 1e-15 {
175        t_max - t_min
176    } else {
177        1.0
178    };
179    let t_norm: Vec<f64> = argvals.iter().map(|&t| (t - t_min) / t_range).collect();
180
181    // Build Vandermonde matrix (m x n_coef)
182    let mut design = DMatrix::zeros(m, n_coef);
183    for j in 0..m {
184        let t = t_norm[j];
185        let mut power = 1.0;
186        for k in 0..n_coef {
187            design[(j, k)] = power;
188            power *= t;
189        }
190    }
191
192    // SVD for stable least squares
193    let svd = design.clone().svd(true, true);
194
195    // Process each sample in parallel
196    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
197        .map(|i| {
198            // Extract curve
199            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
200            let y = DVector::from_row_slice(&curve);
201
202            // Solve least squares using SVD
203            let beta = svd
204                .solve(&y, 1e-10)
205                .unwrap_or_else(|_| DVector::zeros(n_coef));
206
207            // Compute fitted values (trend) and residuals
208            let fitted = &design * &beta;
209            let mut trend = vec![0.0; m];
210            let mut detrended = vec![0.0; m];
211            let mut rss = 0.0;
212            for j in 0..m {
213                trend[j] = fitted[j];
214                detrended[j] = curve[j] - fitted[j];
215                rss += detrended[j].powi(2);
216            }
217
218            // Extract coefficients
219            let coefs: Vec<f64> = beta.iter().cloned().collect();
220
221            (trend, detrended, coefs, rss)
222        })
223        .collect();
224
225    // Reassemble into column-major format
226    let mut trend = vec![0.0; n * m];
227    let mut detrended = vec![0.0; n * m];
228    let mut coefficients = vec![0.0; n * n_coef];
229    let mut rss = vec![0.0; n];
230
231    for (i, (t, d, coefs, r)) in results.into_iter().enumerate() {
232        for j in 0..m {
233            trend[i + j * n] = t[j];
234            detrended[i + j * n] = d[j];
235        }
236        for k in 0..n_coef {
237            coefficients[i * n_coef + k] = coefs[k];
238        }
239        rss[i] = r;
240    }
241
242    TrendResult {
243        trend,
244        detrended,
245        method: format!("polynomial({})", degree),
246        coefficients: Some(coefficients),
247        rss,
248        n_params: n_coef,
249    }
250}
251
252/// Remove trend by differencing.
253///
254/// # Arguments
255/// * `data` - Column-major matrix (n x m)
256/// * `n` - Number of samples
257/// * `m` - Number of evaluation points
258/// * `order` - Differencing order (1 or 2)
259///
260/// # Returns
261/// TrendResult with trend (cumulative sum to reverse), detrended (differences),
262/// and original first values as "coefficients"
263///
264/// Note: Differencing reduces the series length by `order` points.
265/// The returned detrended data has m - order points padded with zeros at the end.
266pub fn detrend_diff(data: &[f64], n: usize, m: usize, order: usize) -> TrendResult {
267    if n == 0 || m <= order || data.len() != n * m || order == 0 || order > 2 {
268        return TrendResult {
269            trend: vec![0.0; n * m],
270            detrended: data.to_vec(),
271            method: format!("diff{}", order),
272            coefficients: None,
273            rss: vec![0.0; n],
274            n_params: order,
275        };
276    }
277
278    let new_m = m - order;
279
280    // Process each sample in parallel
281    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
282        .map(|i| {
283            // Extract curve
284            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
285
286            // First difference
287            let diff1: Vec<f64> = (0..m - 1).map(|j| curve[j + 1] - curve[j]).collect();
288
289            // Second difference if order == 2
290            let detrended = if order == 2 {
291                (0..diff1.len() - 1)
292                    .map(|j| diff1[j + 1] - diff1[j])
293                    .collect()
294            } else {
295                diff1.clone()
296            };
297
298            // Store initial values needed for reconstruction
299            let initial_values = if order == 2 {
300                vec![curve[0], curve[1]]
301            } else {
302                vec![curve[0]]
303            };
304
305            // Compute RSS (sum of squared differences as "residuals" - interpretation varies)
306            let rss: f64 = detrended.iter().map(|&x| x.powi(2)).sum();
307
308            // For "trend", we reconstruct as cumsum of differences
309            // This is a rough approximation; true trend would need integration
310            let mut trend = vec![0.0; m];
311            trend[0] = curve[0];
312            if order == 1 {
313                for j in 1..m {
314                    trend[j] = curve[j] - if j <= new_m { detrended[j - 1] } else { 0.0 };
315                }
316            } else {
317                // For order 2, trend is less meaningful
318                trend = curve.clone();
319            }
320
321            // Pad detrended to full length
322            let mut det_full = vec![0.0; m];
323            det_full[..new_m].copy_from_slice(&detrended[..new_m]);
324
325            (trend, det_full, initial_values, rss)
326        })
327        .collect();
328
329    // Reassemble
330    let mut trend = vec![0.0; n * m];
331    let mut detrended = vec![0.0; n * m];
332    let mut coefficients = vec![0.0; n * order];
333    let mut rss = vec![0.0; n];
334
335    for (i, (t, d, init, r)) in results.into_iter().enumerate() {
336        for j in 0..m {
337            trend[i + j * n] = t[j];
338            detrended[i + j * n] = d[j];
339        }
340        for k in 0..order {
341            coefficients[i * order + k] = init[k];
342        }
343        rss[i] = r;
344    }
345
346    TrendResult {
347        trend,
348        detrended,
349        method: format!("diff{}", order),
350        coefficients: Some(coefficients),
351        rss,
352        n_params: order,
353    }
354}
355
356/// Remove trend using LOESS (local polynomial regression).
357///
358/// # Arguments
359/// * `data` - Column-major matrix (n x m)
360/// * `n` - Number of samples
361/// * `m` - Number of evaluation points
362/// * `argvals` - Time/argument values
363/// * `bandwidth` - Bandwidth as fraction of data range (0.1 to 0.5 typical)
364/// * `degree` - Local polynomial degree (1 or 2)
365///
366/// # Returns
367/// TrendResult with LOESS-smoothed trend
368pub fn detrend_loess(
369    data: &[f64],
370    n: usize,
371    m: usize,
372    argvals: &[f64],
373    bandwidth: f64,
374    degree: usize,
375) -> TrendResult {
376    if n == 0 || m < 3 || data.len() != n * m || argvals.len() != m || bandwidth <= 0.0 {
377        return TrendResult {
378            trend: vec![0.0; n * m],
379            detrended: data.to_vec(),
380            method: "loess".to_string(),
381            coefficients: None,
382            rss: vec![0.0; n],
383            n_params: (m as f64 * bandwidth).ceil() as usize,
384        };
385    }
386
387    // Convert bandwidth from fraction to absolute units
388    let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
389    let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
390    let abs_bandwidth = (t_max - t_min) * bandwidth;
391
392    // Process each sample in parallel
393    let results: Vec<(Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n)
394        .map(|i| {
395            // Extract curve
396            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
397
398            // Apply local polynomial regression
399            let trend =
400                local_polynomial(argvals, &curve, argvals, abs_bandwidth, degree, "gaussian");
401
402            // Compute detrended and RSS
403            let mut detrended = vec![0.0; m];
404            let mut rss = 0.0;
405            for j in 0..m {
406                detrended[j] = curve[j] - trend[j];
407                rss += detrended[j].powi(2);
408            }
409
410            (trend, detrended, rss)
411        })
412        .collect();
413
414    // Reassemble
415    let mut trend = vec![0.0; n * m];
416    let mut detrended = vec![0.0; n * m];
417    let mut rss = vec![0.0; n];
418
419    for (i, (t, d, r)) in results.into_iter().enumerate() {
420        for j in 0..m {
421            trend[i + j * n] = t[j];
422            detrended[i + j * n] = d[j];
423        }
424        rss[i] = r;
425    }
426
427    // Effective number of parameters for LOESS is approximately n * bandwidth
428    let n_params = (m as f64 * bandwidth).ceil() as usize;
429
430    TrendResult {
431        trend,
432        detrended,
433        method: "loess".to_string(),
434        coefficients: None,
435        rss,
436        n_params,
437    }
438}
439
440/// Automatically select the best detrending method using AIC.
441///
442/// Compares linear, polynomial (degree 2 and 3), and LOESS,
443/// selecting the method with lowest AIC.
444///
445/// # Arguments
446/// * `data` - Column-major matrix (n x m)
447/// * `n` - Number of samples
448/// * `m` - Number of evaluation points
449/// * `argvals` - Time/argument values
450///
451/// # Returns
452/// TrendResult from the best method
453pub fn auto_detrend(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> TrendResult {
454    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m {
455        return TrendResult {
456            trend: vec![0.0; n * m],
457            detrended: data.to_vec(),
458            method: "auto(none)".to_string(),
459            coefficients: None,
460            rss: vec![0.0; n],
461            n_params: 0,
462        };
463    }
464
465    // Compute AIC for a result: AIC = n * log(RSS/n) + 2*k
466    // We use mean AIC across all samples
467    let compute_aic = |result: &TrendResult| -> f64 {
468        let mut total_aic = 0.0;
469        for i in 0..n {
470            let rss = result.rss[i];
471            let k = result.n_params as f64;
472            let aic = if rss > 1e-15 {
473                m as f64 * (rss / m as f64).ln() + 2.0 * k
474            } else {
475                f64::NEG_INFINITY // Perfect fit (unlikely)
476            };
477            total_aic += aic;
478        }
479        total_aic / n as f64
480    };
481
482    // Try different methods
483    let linear = detrend_linear(data, n, m, argvals);
484    let poly2 = detrend_polynomial(data, n, m, argvals, 2);
485    let poly3 = detrend_polynomial(data, n, m, argvals, 3);
486    let loess = detrend_loess(data, n, m, argvals, 0.3, 2);
487
488    let aic_linear = compute_aic(&linear);
489    let aic_poly2 = compute_aic(&poly2);
490    let aic_poly3 = compute_aic(&poly3);
491    let aic_loess = compute_aic(&loess);
492
493    // Find minimum AIC
494    let methods = [
495        (aic_linear, "linear", linear),
496        (aic_poly2, "polynomial(2)", poly2),
497        (aic_poly3, "polynomial(3)", poly3),
498        (aic_loess, "loess", loess),
499    ];
500
501    let (_, best_name, mut best_result) = methods
502        .into_iter()
503        .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
504        .unwrap();
505
506    best_result.method = format!("auto({})", best_name);
507    best_result
508}
509
510/// Additive seasonal decomposition: data = trend + seasonal + remainder
511///
512/// Uses LOESS or spline for trend extraction, then averages within-period
513/// residuals to estimate the seasonal component.
514///
515/// # Arguments
516/// * `data` - Column-major matrix (n x m)
517/// * `n` - Number of samples
518/// * `m` - Number of evaluation points
519/// * `argvals` - Time/argument values
520/// * `period` - Seasonal period in same units as argvals
521/// * `trend_method` - "loess" or "spline"
522/// * `bandwidth` - Bandwidth for LOESS (fraction, e.g., 0.3)
523/// * `n_harmonics` - Number of Fourier harmonics for seasonal component
524///
525/// # Returns
526/// DecomposeResult with trend, seasonal, and remainder components
527pub fn decompose_additive(
528    data: &[f64],
529    n: usize,
530    m: usize,
531    argvals: &[f64],
532    period: f64,
533    trend_method: &str,
534    bandwidth: f64,
535    n_harmonics: usize,
536) -> DecomposeResult {
537    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m || period <= 0.0 {
538        return DecomposeResult {
539            trend: vec![0.0; n * m],
540            seasonal: vec![0.0; n * m],
541            remainder: data.to_vec(),
542            period,
543            method: "additive".to_string(),
544        };
545    }
546
547    // Step 1: Extract trend using LOESS or spline
548    let trend_result = match trend_method {
549        "spline" => {
550            // Use P-spline fitting - use a larger bandwidth for trend
551            detrend_loess(data, n, m, argvals, bandwidth.max(0.3), 2)
552        }
553        _ => detrend_loess(data, n, m, argvals, bandwidth.max(0.3), 2),
554    };
555
556    // Step 2: Extract seasonal component using Fourier basis on detrended data
557    let n_harm = n_harmonics.max(1).min(m / 4);
558    let omega = 2.0 * std::f64::consts::PI / period;
559
560    // Process each sample
561    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
562        .map(|i| {
563            let trend_i: Vec<f64> = (0..m).map(|j| trend_result.trend[i + j * n]).collect();
564            let detrended_i: Vec<f64> = (0..m).map(|j| trend_result.detrended[i + j * n]).collect();
565
566            // Fit Fourier model to detrended data: sum of sin and cos terms
567            // y = sum_k (a_k * cos(k*omega*t) + b_k * sin(k*omega*t))
568            let n_coef = 2 * n_harm;
569            let mut design = DMatrix::zeros(m, n_coef);
570            for j in 0..m {
571                let t = argvals[j];
572                for k in 0..n_harm {
573                    let freq = (k + 1) as f64 * omega;
574                    design[(j, 2 * k)] = (freq * t).cos();
575                    design[(j, 2 * k + 1)] = (freq * t).sin();
576                }
577            }
578
579            // Solve least squares using SVD
580            let y = DVector::from_row_slice(&detrended_i);
581            let svd = design.clone().svd(true, true);
582            let coef = svd
583                .solve(&y, 1e-10)
584                .unwrap_or_else(|_| DVector::zeros(n_coef));
585
586            // Compute seasonal component
587            let fitted = &design * &coef;
588            let seasonal: Vec<f64> = fitted.iter().cloned().collect();
589
590            // Compute remainder
591            let remainder: Vec<f64> = (0..m).map(|j| detrended_i[j] - seasonal[j]).collect();
592
593            (trend_i, seasonal, remainder)
594        })
595        .collect();
596
597    // Reassemble into column-major format
598    let mut trend = vec![0.0; n * m];
599    let mut seasonal = vec![0.0; n * m];
600    let mut remainder = vec![0.0; n * m];
601
602    for (i, (t, s, r)) in results.into_iter().enumerate() {
603        for j in 0..m {
604            trend[i + j * n] = t[j];
605            seasonal[i + j * n] = s[j];
606            remainder[i + j * n] = r[j];
607        }
608    }
609
610    DecomposeResult {
611        trend,
612        seasonal,
613        remainder,
614        period,
615        method: "additive".to_string(),
616    }
617}
618
619/// Multiplicative seasonal decomposition: data = trend * seasonal * remainder
620///
621/// Applies log transformation, then additive decomposition, then back-transforms.
622/// Handles non-positive values by adding a shift.
623///
624/// # Arguments
625/// * `data` - Column-major matrix (n x m)
626/// * `n` - Number of samples
627/// * `m` - Number of evaluation points
628/// * `argvals` - Time/argument values
629/// * `period` - Seasonal period
630/// * `trend_method` - "loess" or "spline"
631/// * `bandwidth` - Bandwidth for LOESS
632/// * `n_harmonics` - Number of Fourier harmonics
633///
634/// # Returns
635/// DecomposeResult with multiplicative components
636pub fn decompose_multiplicative(
637    data: &[f64],
638    n: usize,
639    m: usize,
640    argvals: &[f64],
641    period: f64,
642    trend_method: &str,
643    bandwidth: f64,
644    n_harmonics: usize,
645) -> DecomposeResult {
646    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m || period <= 0.0 {
647        return DecomposeResult {
648            trend: vec![0.0; n * m],
649            seasonal: vec![0.0; n * m],
650            remainder: data.to_vec(),
651            period,
652            method: "multiplicative".to_string(),
653        };
654    }
655
656    // Find minimum value and add shift if needed to make all values positive
657    let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
658    let shift = if min_val <= 0.0 { -min_val + 1.0 } else { 0.0 };
659
660    // Log transform
661    let log_data: Vec<f64> = data.iter().map(|&x| (x + shift).ln()).collect();
662
663    // Apply additive decomposition to log data
664    let additive_result = decompose_additive(
665        &log_data,
666        n,
667        m,
668        argvals,
669        period,
670        trend_method,
671        bandwidth,
672        n_harmonics,
673    );
674
675    // Back transform: exp of each component
676    // For multiplicative: data = trend * seasonal * remainder
677    // In log space: log(data) = log(trend) + log(seasonal) + log(remainder)
678    // So: trend_mult = exp(trend_add), seasonal_mult = exp(seasonal_add), etc.
679
680    let mut trend = vec![0.0; n * m];
681    let mut seasonal = vec![0.0; n * m];
682    let mut remainder = vec![0.0; n * m];
683
684    for idx in 0..n * m {
685        // Back-transform trend (subtract shift)
686        trend[idx] = additive_result.trend[idx].exp() - shift;
687
688        // Seasonal is a multiplicative factor (centered around 1)
689        // We interpret the additive seasonal component as log(seasonal factor)
690        seasonal[idx] = additive_result.seasonal[idx].exp();
691
692        // Remainder is also multiplicative
693        remainder[idx] = additive_result.remainder[idx].exp();
694    }
695
696    DecomposeResult {
697        trend,
698        seasonal,
699        remainder,
700        period,
701        method: "multiplicative".to_string(),
702    }
703}
704
705#[cfg(test)]
706mod tests {
707    use super::*;
708    use std::f64::consts::PI;
709
710    #[test]
711    fn test_detrend_linear_removes_linear_trend() {
712        let m = 100;
713        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
714
715        // y = 2 + 0.5*t + sin(2*pi*t/2)
716        let data: Vec<f64> = argvals
717            .iter()
718            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / 2.0).sin())
719            .collect();
720
721        let result = detrend_linear(&data, 1, m, &argvals);
722
723        // Detrended should be approximately sin wave
724        let expected: Vec<f64> = argvals
725            .iter()
726            .map(|&t| (2.0 * PI * t / 2.0).sin())
727            .collect();
728
729        let mut max_diff = 0.0f64;
730        for j in 0..m {
731            let diff = (result.detrended[j] - expected[j]).abs();
732            max_diff = max_diff.max(diff);
733        }
734        assert!(max_diff < 0.2, "Max difference: {}", max_diff);
735    }
736
737    #[test]
738    fn test_detrend_polynomial_removes_quadratic_trend() {
739        let m = 100;
740        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
741
742        // y = 1 + 0.5*t - 0.1*t^2 + sin(2*pi*t/2)
743        let data: Vec<f64> = argvals
744            .iter()
745            .map(|&t| 1.0 + 0.5 * t - 0.1 * t * t + (2.0 * PI * t / 2.0).sin())
746            .collect();
747
748        let result = detrend_polynomial(&data, 1, m, &argvals, 2);
749
750        // Detrended should be approximately sin wave
751        let expected: Vec<f64> = argvals
752            .iter()
753            .map(|&t| (2.0 * PI * t / 2.0).sin())
754            .collect();
755
756        // Compute correlation
757        let mean_det: f64 = result.detrended.iter().sum::<f64>() / m as f64;
758        let mean_exp: f64 = expected.iter().sum::<f64>() / m as f64;
759        let mut num = 0.0;
760        let mut den_det = 0.0;
761        let mut den_exp = 0.0;
762        for j in 0..m {
763            num += (result.detrended[j] - mean_det) * (expected[j] - mean_exp);
764            den_det += (result.detrended[j] - mean_det).powi(2);
765            den_exp += (expected[j] - mean_exp).powi(2);
766        }
767        let corr = num / (den_det.sqrt() * den_exp.sqrt());
768        assert!(corr > 0.95, "Correlation: {}", corr);
769    }
770
771    #[test]
772    fn test_detrend_diff1() {
773        let m = 100;
774        // Random walk: cumsum of random values
775        let data: Vec<f64> = {
776            let mut v = vec![0.0; m];
777            v[0] = 1.0;
778            for i in 1..m {
779                v[i] = v[i - 1] + 0.1 * (i as f64).sin();
780            }
781            v
782        };
783
784        let result = detrend_diff(&data, 1, m, 1);
785
786        // First difference should recover the increments
787        for j in 0..m - 1 {
788            let expected = data[j + 1] - data[j];
789            assert!(
790                (result.detrended[j] - expected).abs() < 1e-10,
791                "Mismatch at {}: {} vs {}",
792                j,
793                result.detrended[j],
794                expected
795            );
796        }
797    }
798
799    #[test]
800    fn test_auto_detrend_selects_linear_for_linear_data() {
801        let m = 100;
802        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
803
804        // Pure linear trend with small noise
805        let data: Vec<f64> = argvals.iter().map(|&t| 2.0 + 0.5 * t).collect();
806
807        let result = auto_detrend(&data, 1, m, &argvals);
808
809        // Should select linear (or poly 2/3 with linear being sufficient)
810        assert!(
811            result.method.contains("linear") || result.method.contains("polynomial"),
812            "Method: {}",
813            result.method
814        );
815    }
816
817    // ========================================================================
818    // Tests for detrend_loess
819    // ========================================================================
820
821    #[test]
822    fn test_detrend_loess_removes_linear_trend() {
823        let m = 100;
824        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
825
826        // y = 2 + 0.5*t + sin(2*pi*t/2)
827        let data: Vec<f64> = argvals
828            .iter()
829            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / 2.0).sin())
830            .collect();
831
832        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 1);
833
834        // Detrended should be approximately sin wave
835        let expected: Vec<f64> = argvals
836            .iter()
837            .map(|&t| (2.0 * PI * t / 2.0).sin())
838            .collect();
839
840        // Compute correlation (LOESS may smooth slightly)
841        let mean_det: f64 = result.detrended.iter().sum::<f64>() / m as f64;
842        let mean_exp: f64 = expected.iter().sum::<f64>() / m as f64;
843        let mut num = 0.0;
844        let mut den_det = 0.0;
845        let mut den_exp = 0.0;
846        for j in 0..m {
847            num += (result.detrended[j] - mean_det) * (expected[j] - mean_exp);
848            den_det += (result.detrended[j] - mean_det).powi(2);
849            den_exp += (expected[j] - mean_exp).powi(2);
850        }
851        let corr = num / (den_det.sqrt() * den_exp.sqrt());
852        assert!(corr > 0.9, "Correlation: {}", corr);
853        assert_eq!(result.method, "loess");
854    }
855
856    #[test]
857    fn test_detrend_loess_removes_quadratic_trend() {
858        let m = 100;
859        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
860
861        // y = 1 + 0.3*t - 0.05*t^2 + sin(2*pi*t/2)
862        let data: Vec<f64> = argvals
863            .iter()
864            .map(|&t| 1.0 + 0.3 * t - 0.05 * t * t + (2.0 * PI * t / 2.0).sin())
865            .collect();
866
867        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 2);
868
869        // Trend should follow the quadratic shape
870        assert_eq!(result.trend.len(), m);
871        assert_eq!(result.detrended.len(), m);
872
873        // Check that RSS is computed
874        assert!(result.rss[0] > 0.0);
875    }
876
877    #[test]
878    fn test_detrend_loess_different_bandwidths() {
879        let m = 100;
880        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
881
882        // Noisy sine wave
883        let data: Vec<f64> = argvals
884            .iter()
885            .enumerate()
886            .map(|(i, &t)| (2.0 * PI * t / 2.0).sin() + 0.1 * ((i * 17) % 100) as f64 / 100.0)
887            .collect();
888
889        // Small bandwidth = more local = rougher trend
890        let result_small = detrend_loess(&data, 1, m, &argvals, 0.1, 1);
891        // Large bandwidth = smoother trend
892        let result_large = detrend_loess(&data, 1, m, &argvals, 0.5, 1);
893
894        // Both should produce valid results
895        assert_eq!(result_small.trend.len(), m);
896        assert_eq!(result_large.trend.len(), m);
897
898        // Large bandwidth should have more parameters
899        assert!(result_large.n_params >= result_small.n_params);
900    }
901
902    #[test]
903    fn test_detrend_loess_short_series() {
904        let m = 10;
905        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
906        let data: Vec<f64> = argvals.iter().map(|&t| t * 2.0).collect();
907
908        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 1);
909
910        // Should still work on short series
911        assert_eq!(result.trend.len(), m);
912        assert_eq!(result.detrended.len(), m);
913    }
914
915    // ========================================================================
916    // Tests for decompose_additive
917    // ========================================================================
918
919    #[test]
920    fn test_decompose_additive_separates_components() {
921        let m = 200;
922        let period = 2.0;
923        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
924
925        // data = trend + seasonal: y = 2 + 0.5*t + sin(2*pi*t/2)
926        let data: Vec<f64> = argvals
927            .iter()
928            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / period).sin())
929            .collect();
930
931        let result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);
932
933        assert_eq!(result.trend.len(), m);
934        assert_eq!(result.seasonal.len(), m);
935        assert_eq!(result.remainder.len(), m);
936        assert_eq!(result.method, "additive");
937        assert_eq!(result.period, period);
938
939        // Check that components approximately sum to original
940        for j in 0..m {
941            let reconstructed = result.trend[j] + result.seasonal[j] + result.remainder[j];
942            assert!(
943                (reconstructed - data[j]).abs() < 0.5,
944                "Reconstruction error at {}: {} vs {}",
945                j,
946                reconstructed,
947                data[j]
948            );
949        }
950    }
951
952    #[test]
953    fn test_decompose_additive_different_harmonics() {
954        let m = 200;
955        let period = 2.0;
956        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
957
958        // Simple seasonal pattern
959        let data: Vec<f64> = argvals
960            .iter()
961            .map(|&t| 1.0 + (2.0 * PI * t / period).sin())
962            .collect();
963
964        // 1 harmonic
965        let result1 = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 1);
966        // 5 harmonics
967        let result5 = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 5);
968
969        // Both should produce valid results
970        assert_eq!(result1.seasonal.len(), m);
971        assert_eq!(result5.seasonal.len(), m);
972    }
973
974    #[test]
975    fn test_decompose_additive_residual_properties() {
976        let m = 200;
977        let period = 2.0;
978        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
979
980        // Data with trend and seasonal
981        let data: Vec<f64> = argvals
982            .iter()
983            .map(|&t| 2.0 + 0.3 * t + (2.0 * PI * t / period).sin())
984            .collect();
985
986        let result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);
987
988        // Remainder should have mean close to zero
989        let mean_rem: f64 = result.remainder.iter().sum::<f64>() / m as f64;
990        assert!(mean_rem.abs() < 0.5, "Remainder mean: {}", mean_rem);
991
992        // Remainder variance should be smaller than original variance
993        let var_data: f64 = data
994            .iter()
995            .map(|&x| (x - data.iter().sum::<f64>() / m as f64).powi(2))
996            .sum::<f64>()
997            / m as f64;
998        let var_rem: f64 = result
999            .remainder
1000            .iter()
1001            .map(|&x| (x - mean_rem).powi(2))
1002            .sum::<f64>()
1003            / m as f64;
1004        assert!(
1005            var_rem < var_data,
1006            "Remainder variance {} should be < data variance {}",
1007            var_rem,
1008            var_data
1009        );
1010    }
1011
1012    #[test]
1013    fn test_decompose_additive_multi_sample() {
1014        let n = 3;
1015        let m = 100;
1016        let period = 2.0;
1017        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
1018
1019        // Create 3 samples with different amplitudes
1020        let mut data = vec![0.0; n * m];
1021        for i in 0..n {
1022            let amp = (i + 1) as f64;
1023            for j in 0..m {
1024                data[i + j * n] =
1025                    1.0 + 0.1 * argvals[j] + amp * (2.0 * PI * argvals[j] / period).sin();
1026            }
1027        }
1028
1029        let result = decompose_additive(&data, n, m, &argvals, period, "loess", 0.3, 2);
1030
1031        assert_eq!(result.trend.len(), n * m);
1032        assert_eq!(result.seasonal.len(), n * m);
1033        assert_eq!(result.remainder.len(), n * m);
1034    }
1035
1036    // ========================================================================
1037    // Tests for decompose_multiplicative
1038    // ========================================================================
1039
1040    #[test]
1041    fn test_decompose_multiplicative_basic() {
1042        let m = 200;
1043        let period = 2.0;
1044        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
1045
1046        // Multiplicative: data = trend * seasonal
1047        // trend = 2 + 0.1*t, seasonal = 1 + 0.3*sin(...)
1048        let data: Vec<f64> = argvals
1049            .iter()
1050            .map(|&t| (2.0 + 0.1 * t) * (1.0 + 0.3 * (2.0 * PI * t / period).sin()))
1051            .collect();
1052
1053        let result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 3);
1054
1055        assert_eq!(result.trend.len(), m);
1056        assert_eq!(result.seasonal.len(), m);
1057        assert_eq!(result.remainder.len(), m);
1058        assert_eq!(result.method, "multiplicative");
1059
1060        // Seasonal factors should be centered around 1
1061        let mean_seasonal: f64 = result.seasonal.iter().sum::<f64>() / m as f64;
1062        assert!(
1063            (mean_seasonal - 1.0).abs() < 0.5,
1064            "Mean seasonal factor: {}",
1065            mean_seasonal
1066        );
1067    }
1068
1069    #[test]
1070    fn test_decompose_multiplicative_non_positive_data() {
1071        let m = 100;
1072        let period = 2.0;
1073        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
1074
1075        // Data with negative values
1076        let data: Vec<f64> = argvals
1077            .iter()
1078            .map(|&t| -1.0 + (2.0 * PI * t / period).sin())
1079            .collect();
1080
1081        // Should handle negative values by shifting
1082        let result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 2);
1083
1084        assert_eq!(result.trend.len(), m);
1085        assert_eq!(result.seasonal.len(), m);
1086        // All seasonal values should be positive (multiplicative factors)
1087        for &s in result.seasonal.iter() {
1088            assert!(s.is_finite(), "Seasonal should be finite");
1089        }
1090    }
1091
1092    #[test]
1093    fn test_decompose_multiplicative_vs_additive() {
1094        let m = 200;
1095        let period = 2.0;
1096        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
1097
1098        // Simple positive data
1099        let data: Vec<f64> = argvals
1100            .iter()
1101            .map(|&t| 5.0 + (2.0 * PI * t / period).sin())
1102            .collect();
1103
1104        let add_result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);
1105        let mult_result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 3);
1106
1107        // Both should produce valid decompositions
1108        assert_eq!(add_result.seasonal.len(), m);
1109        assert_eq!(mult_result.seasonal.len(), m);
1110
1111        // Additive seasonal oscillates around 0
1112        let add_mean: f64 = add_result.seasonal.iter().sum::<f64>() / m as f64;
1113        // Multiplicative seasonal oscillates around 1
1114        let mult_mean: f64 = mult_result.seasonal.iter().sum::<f64>() / m as f64;
1115
1116        assert!(
1117            add_mean.abs() < mult_mean,
1118            "Additive mean {} vs mult mean {}",
1119            add_mean,
1120            mult_mean
1121        );
1122    }
1123
1124    #[test]
1125    fn test_decompose_multiplicative_edge_cases() {
1126        // Empty data
1127        let result = decompose_multiplicative(&[], 0, 0, &[], 2.0, "loess", 0.3, 2);
1128        assert_eq!(result.trend.len(), 0);
1129
1130        // Very short series
1131        let m = 5;
1132        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
1133        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 2.0, 1.0];
1134        let result = decompose_multiplicative(&data, 1, m, &argvals, 2.0, "loess", 0.3, 1);
1135        // Should return original data as remainder for too-short series
1136        assert_eq!(result.remainder.len(), m);
1137    }
1138}