fdars_core/
detrend.rs

1//! Detrending and decomposition functions for non-stationary functional data.
2//!
3//! This module provides methods for removing trends from functional data
4//! to enable more accurate seasonal analysis. It includes:
5//! - Linear detrending (least squares)
6//! - Polynomial detrending (QR decomposition)
7//! - Differencing (first and second order)
8//! - LOESS detrending (local polynomial regression)
9//! - Spline detrending (P-splines)
10//! - Automatic method selection via AIC
11
12use crate::smoothing::local_polynomial;
13use nalgebra::{DMatrix, DVector};
14use rayon::prelude::*;
15
/// Result of a detrending operation.
///
/// Every matrix-shaped field uses the same column-major `n x m` layout as the
/// input data: the value for sample `i` at evaluation point `j` is stored at
/// index `i + j * n`.
#[derive(Debug, Clone, PartialEq)]
pub struct TrendResult {
    /// Estimated trend values (n x m column-major).
    pub trend: Vec<f64>,
    /// Detrended data (n x m column-major).
    pub detrended: Vec<f64>,
    /// Name of the method that produced this result, e.g. `"linear"`,
    /// `"polynomial(2)"`, `"diff1"`, `"loess"` or `"auto(linear)"`.
    pub method: String,
    /// Per-sample coefficients, stored sample by sample: with `p` values per
    /// sample, `coefficients[i * p + k]` is value `k` of sample `i`.
    ///
    /// - linear / `polynomial(1)`: `p = 2`, `(intercept, slope)` on the original
    ///   `argvals` scale;
    /// - `polynomial(d)` with `d >= 2`: `p = d + 1`, ascending-power coefficients
    ///   in the normalised variable `(t - t_min) / (t_max - t_min)`;
    /// - differencing of order `d`: `p = d`, the first `d` raw values of each curve;
    /// - `None` for LOESS and for invalid input.
    pub coefficients: Option<Vec<f64>>,
    /// Residual sum of squares for each sample (length n). For differencing this
    /// is the sum of squared differences.
    pub rss: Vec<f64>,
    /// Number of parameters charged in the AIC penalty.
    pub n_params: usize,
}
33
/// Result of a seasonal decomposition.
///
/// Component matrices share the column-major `n x m` layout of the input data
/// (sample `i`, evaluation point `j` at index `i + j * n`).
///
/// For `"additive"` results, `trend + seasonal + remainder` reproduces the data
/// up to floating-point rounding. For `"multiplicative"` results, `trend` is in
/// data units, while `seasonal` and `remainder` are multiplicative factors
/// (exponentials of log-scale components).
#[derive(Debug, Clone, PartialEq)]
pub struct DecomposeResult {
    /// Trend component (n x m column-major).
    pub trend: Vec<f64>,
    /// Seasonal component (n x m column-major).
    pub seasonal: Vec<f64>,
    /// Remainder/residual component (n x m column-major).
    pub remainder: Vec<f64>,
    /// Period used for decomposition, in the same units as `argvals`.
    pub period: f64,
    /// Decomposition method (`"additive"` or `"multiplicative"`).
    pub method: String,
}
48
49/// Remove linear trend from functional data using least squares.
50///
51/// # Arguments
52/// * `data` - Column-major matrix (n x m): n samples, m evaluation points
53/// * `n` - Number of samples
54/// * `m` - Number of evaluation points
55/// * `argvals` - Time/argument values of length m
56///
57/// # Returns
58/// TrendResult with trend, detrended data, and coefficients (intercept, slope)
59pub fn detrend_linear(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> TrendResult {
60    if n == 0 || m < 2 || data.len() != n * m || argvals.len() != m {
61        return TrendResult {
62            trend: vec![0.0; n * m],
63            detrended: data.to_vec(),
64            method: "linear".to_string(),
65            coefficients: None,
66            rss: vec![0.0; n],
67            n_params: 2,
68        };
69    }
70
71    // Precompute t statistics
72    let mean_t: f64 = argvals.iter().sum::<f64>() / m as f64;
73    let ss_t: f64 = argvals.iter().map(|&t| (t - mean_t).powi(2)).sum();
74
75    // Process each sample in parallel
76    let results: Vec<(Vec<f64>, Vec<f64>, f64, f64, f64)> = (0..n)
77        .into_par_iter()
78        .map(|i| {
79            // Extract curve
80            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
81            let mean_y: f64 = curve.iter().sum::<f64>() / m as f64;
82
83            // Compute slope: sum((t - mean_t) * (y - mean_y)) / sum((t - mean_t)^2)
84            let mut sp = 0.0;
85            for j in 0..m {
86                sp += (argvals[j] - mean_t) * (curve[j] - mean_y);
87            }
88            let slope = if ss_t.abs() > 1e-15 { sp / ss_t } else { 0.0 };
89            let intercept = mean_y - slope * mean_t;
90
91            // Compute trend and detrended
92            let mut trend = vec![0.0; m];
93            let mut detrended = vec![0.0; m];
94            let mut rss = 0.0;
95            for j in 0..m {
96                trend[j] = intercept + slope * argvals[j];
97                detrended[j] = curve[j] - trend[j];
98                rss += detrended[j].powi(2);
99            }
100
101            (trend, detrended, intercept, slope, rss)
102        })
103        .collect();
104
105    // Reassemble into column-major format
106    let mut trend = vec![0.0; n * m];
107    let mut detrended = vec![0.0; n * m];
108    let mut coefficients = vec![0.0; n * 2];
109    let mut rss = vec![0.0; n];
110
111    for (i, (t, d, intercept, slope, r)) in results.into_iter().enumerate() {
112        for j in 0..m {
113            trend[i + j * n] = t[j];
114            detrended[i + j * n] = d[j];
115        }
116        coefficients[i * 2] = intercept;
117        coefficients[i * 2 + 1] = slope;
118        rss[i] = r;
119    }
120
121    TrendResult {
122        trend,
123        detrended,
124        method: "linear".to_string(),
125        coefficients: Some(coefficients),
126        rss,
127        n_params: 2,
128    }
129}
130
131/// Remove polynomial trend from functional data using QR decomposition.
132///
133/// # Arguments
134/// * `data` - Column-major matrix (n x m)
135/// * `n` - Number of samples
136/// * `m` - Number of evaluation points
137/// * `argvals` - Time/argument values of length m
138/// * `degree` - Polynomial degree (1 = linear, 2 = quadratic, etc.)
139///
140/// # Returns
141/// TrendResult with trend, detrended data, and polynomial coefficients
142pub fn detrend_polynomial(
143    data: &[f64],
144    n: usize,
145    m: usize,
146    argvals: &[f64],
147    degree: usize,
148) -> TrendResult {
149    if n == 0 || m < degree + 1 || data.len() != n * m || argvals.len() != m || degree == 0 {
150        // For degree 0 or invalid input, return original data
151        return TrendResult {
152            trend: vec![0.0; n * m],
153            detrended: data.to_vec(),
154            method: format!("polynomial({})", degree),
155            coefficients: None,
156            rss: vec![0.0; n],
157            n_params: degree + 1,
158        };
159    }
160
161    // Special case: degree 1 is linear
162    if degree == 1 {
163        let mut result = detrend_linear(data, n, m, argvals);
164        result.method = "polynomial(1)".to_string();
165        return result;
166    }
167
168    let n_coef = degree + 1;
169
170    // Normalize argvals to avoid numerical issues with high-degree polynomials
171    let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
172    let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
173    let t_range = if (t_max - t_min).abs() > 1e-15 {
174        t_max - t_min
175    } else {
176        1.0
177    };
178    let t_norm: Vec<f64> = argvals.iter().map(|&t| (t - t_min) / t_range).collect();
179
180    // Build Vandermonde matrix (m x n_coef)
181    let mut design = DMatrix::zeros(m, n_coef);
182    for j in 0..m {
183        let t = t_norm[j];
184        let mut power = 1.0;
185        for k in 0..n_coef {
186            design[(j, k)] = power;
187            power *= t;
188        }
189    }
190
191    // SVD for stable least squares
192    let svd = design.clone().svd(true, true);
193
194    // Process each sample in parallel
195    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = (0..n)
196        .into_par_iter()
197        .map(|i| {
198            // Extract curve
199            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
200            let y = DVector::from_row_slice(&curve);
201
202            // Solve least squares using SVD
203            let beta = svd
204                .solve(&y, 1e-10)
205                .unwrap_or_else(|_| DVector::zeros(n_coef));
206
207            // Compute fitted values (trend) and residuals
208            let fitted = &design * &beta;
209            let mut trend = vec![0.0; m];
210            let mut detrended = vec![0.0; m];
211            let mut rss = 0.0;
212            for j in 0..m {
213                trend[j] = fitted[j];
214                detrended[j] = curve[j] - fitted[j];
215                rss += detrended[j].powi(2);
216            }
217
218            // Extract coefficients
219            let coefs: Vec<f64> = beta.iter().cloned().collect();
220
221            (trend, detrended, coefs, rss)
222        })
223        .collect();
224
225    // Reassemble into column-major format
226    let mut trend = vec![0.0; n * m];
227    let mut detrended = vec![0.0; n * m];
228    let mut coefficients = vec![0.0; n * n_coef];
229    let mut rss = vec![0.0; n];
230
231    for (i, (t, d, coefs, r)) in results.into_iter().enumerate() {
232        for j in 0..m {
233            trend[i + j * n] = t[j];
234            detrended[i + j * n] = d[j];
235        }
236        for k in 0..n_coef {
237            coefficients[i * n_coef + k] = coefs[k];
238        }
239        rss[i] = r;
240    }
241
242    TrendResult {
243        trend,
244        detrended,
245        method: format!("polynomial({})", degree),
246        coefficients: Some(coefficients),
247        rss,
248        n_params: n_coef,
249    }
250}
251
252/// Remove trend by differencing.
253///
254/// # Arguments
255/// * `data` - Column-major matrix (n x m)
256/// * `n` - Number of samples
257/// * `m` - Number of evaluation points
258/// * `order` - Differencing order (1 or 2)
259///
260/// # Returns
261/// TrendResult with trend (cumulative sum to reverse), detrended (differences),
262/// and original first values as "coefficients"
263///
264/// Note: Differencing reduces the series length by `order` points.
265/// The returned detrended data has m - order points padded with zeros at the end.
266pub fn detrend_diff(data: &[f64], n: usize, m: usize, order: usize) -> TrendResult {
267    if n == 0 || m <= order || data.len() != n * m || order == 0 || order > 2 {
268        return TrendResult {
269            trend: vec![0.0; n * m],
270            detrended: data.to_vec(),
271            method: format!("diff{}", order),
272            coefficients: None,
273            rss: vec![0.0; n],
274            n_params: order,
275        };
276    }
277
278    let new_m = m - order;
279
280    // Process each sample in parallel
281    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = (0..n)
282        .into_par_iter()
283        .map(|i| {
284            // Extract curve
285            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
286
287            // First difference
288            let diff1: Vec<f64> = (0..m - 1).map(|j| curve[j + 1] - curve[j]).collect();
289
290            // Second difference if order == 2
291            let detrended = if order == 2 {
292                (0..diff1.len() - 1)
293                    .map(|j| diff1[j + 1] - diff1[j])
294                    .collect()
295            } else {
296                diff1.clone()
297            };
298
299            // Store initial values needed for reconstruction
300            let initial_values = if order == 2 {
301                vec![curve[0], curve[1]]
302            } else {
303                vec![curve[0]]
304            };
305
306            // Compute RSS (sum of squared differences as "residuals" - interpretation varies)
307            let rss: f64 = detrended.iter().map(|&x| x.powi(2)).sum();
308
309            // For "trend", we reconstruct as cumsum of differences
310            // This is a rough approximation; true trend would need integration
311            let mut trend = vec![0.0; m];
312            trend[0] = curve[0];
313            if order == 1 {
314                for j in 1..m {
315                    trend[j] = curve[j] - if j <= new_m { detrended[j - 1] } else { 0.0 };
316                }
317            } else {
318                // For order 2, trend is less meaningful
319                trend = curve.clone();
320            }
321
322            // Pad detrended to full length
323            let mut det_full = vec![0.0; m];
324            det_full[..new_m].copy_from_slice(&detrended[..new_m]);
325
326            (trend, det_full, initial_values, rss)
327        })
328        .collect();
329
330    // Reassemble
331    let mut trend = vec![0.0; n * m];
332    let mut detrended = vec![0.0; n * m];
333    let mut coefficients = vec![0.0; n * order];
334    let mut rss = vec![0.0; n];
335
336    for (i, (t, d, init, r)) in results.into_iter().enumerate() {
337        for j in 0..m {
338            trend[i + j * n] = t[j];
339            detrended[i + j * n] = d[j];
340        }
341        for k in 0..order {
342            coefficients[i * order + k] = init[k];
343        }
344        rss[i] = r;
345    }
346
347    TrendResult {
348        trend,
349        detrended,
350        method: format!("diff{}", order),
351        coefficients: Some(coefficients),
352        rss,
353        n_params: order,
354    }
355}
356
357/// Remove trend using LOESS (local polynomial regression).
358///
359/// # Arguments
360/// * `data` - Column-major matrix (n x m)
361/// * `n` - Number of samples
362/// * `m` - Number of evaluation points
363/// * `argvals` - Time/argument values
364/// * `bandwidth` - Bandwidth as fraction of data range (0.1 to 0.5 typical)
365/// * `degree` - Local polynomial degree (1 or 2)
366///
367/// # Returns
368/// TrendResult with LOESS-smoothed trend
369pub fn detrend_loess(
370    data: &[f64],
371    n: usize,
372    m: usize,
373    argvals: &[f64],
374    bandwidth: f64,
375    degree: usize,
376) -> TrendResult {
377    if n == 0 || m < 3 || data.len() != n * m || argvals.len() != m || bandwidth <= 0.0 {
378        return TrendResult {
379            trend: vec![0.0; n * m],
380            detrended: data.to_vec(),
381            method: "loess".to_string(),
382            coefficients: None,
383            rss: vec![0.0; n],
384            n_params: (m as f64 * bandwidth).ceil() as usize,
385        };
386    }
387
388    // Convert bandwidth from fraction to absolute units
389    let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
390    let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
391    let abs_bandwidth = (t_max - t_min) * bandwidth;
392
393    // Process each sample in parallel
394    let results: Vec<(Vec<f64>, Vec<f64>, f64)> = (0..n)
395        .into_par_iter()
396        .map(|i| {
397            // Extract curve
398            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
399
400            // Apply local polynomial regression
401            let trend =
402                local_polynomial(argvals, &curve, argvals, abs_bandwidth, degree, "gaussian");
403
404            // Compute detrended and RSS
405            let mut detrended = vec![0.0; m];
406            let mut rss = 0.0;
407            for j in 0..m {
408                detrended[j] = curve[j] - trend[j];
409                rss += detrended[j].powi(2);
410            }
411
412            (trend, detrended, rss)
413        })
414        .collect();
415
416    // Reassemble
417    let mut trend = vec![0.0; n * m];
418    let mut detrended = vec![0.0; n * m];
419    let mut rss = vec![0.0; n];
420
421    for (i, (t, d, r)) in results.into_iter().enumerate() {
422        for j in 0..m {
423            trend[i + j * n] = t[j];
424            detrended[i + j * n] = d[j];
425        }
426        rss[i] = r;
427    }
428
429    // Effective number of parameters for LOESS is approximately n * bandwidth
430    let n_params = (m as f64 * bandwidth).ceil() as usize;
431
432    TrendResult {
433        trend,
434        detrended,
435        method: "loess".to_string(),
436        coefficients: None,
437        rss,
438        n_params,
439    }
440}
441
442/// Automatically select the best detrending method using AIC.
443///
444/// Compares linear, polynomial (degree 2 and 3), and LOESS,
445/// selecting the method with lowest AIC.
446///
447/// # Arguments
448/// * `data` - Column-major matrix (n x m)
449/// * `n` - Number of samples
450/// * `m` - Number of evaluation points
451/// * `argvals` - Time/argument values
452///
453/// # Returns
454/// TrendResult from the best method
455pub fn auto_detrend(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> TrendResult {
456    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m {
457        return TrendResult {
458            trend: vec![0.0; n * m],
459            detrended: data.to_vec(),
460            method: "auto(none)".to_string(),
461            coefficients: None,
462            rss: vec![0.0; n],
463            n_params: 0,
464        };
465    }
466
467    // Compute AIC for a result: AIC = n * log(RSS/n) + 2*k
468    // We use mean AIC across all samples
469    let compute_aic = |result: &TrendResult| -> f64 {
470        let mut total_aic = 0.0;
471        for i in 0..n {
472            let rss = result.rss[i];
473            let k = result.n_params as f64;
474            let aic = if rss > 1e-15 {
475                m as f64 * (rss / m as f64).ln() + 2.0 * k
476            } else {
477                f64::NEG_INFINITY // Perfect fit (unlikely)
478            };
479            total_aic += aic;
480        }
481        total_aic / n as f64
482    };
483
484    // Try different methods
485    let linear = detrend_linear(data, n, m, argvals);
486    let poly2 = detrend_polynomial(data, n, m, argvals, 2);
487    let poly3 = detrend_polynomial(data, n, m, argvals, 3);
488    let loess = detrend_loess(data, n, m, argvals, 0.3, 2);
489
490    let aic_linear = compute_aic(&linear);
491    let aic_poly2 = compute_aic(&poly2);
492    let aic_poly3 = compute_aic(&poly3);
493    let aic_loess = compute_aic(&loess);
494
495    // Find minimum AIC
496    let methods = [
497        (aic_linear, "linear", linear),
498        (aic_poly2, "polynomial(2)", poly2),
499        (aic_poly3, "polynomial(3)", poly3),
500        (aic_loess, "loess", loess),
501    ];
502
503    let (_, best_name, mut best_result) = methods
504        .into_iter()
505        .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
506        .unwrap();
507
508    best_result.method = format!("auto({})", best_name);
509    best_result
510}
511
512/// Additive seasonal decomposition: data = trend + seasonal + remainder
513///
514/// Uses LOESS or spline for trend extraction, then averages within-period
515/// residuals to estimate the seasonal component.
516///
517/// # Arguments
518/// * `data` - Column-major matrix (n x m)
519/// * `n` - Number of samples
520/// * `m` - Number of evaluation points
521/// * `argvals` - Time/argument values
522/// * `period` - Seasonal period in same units as argvals
523/// * `trend_method` - "loess" or "spline"
524/// * `bandwidth` - Bandwidth for LOESS (fraction, e.g., 0.3)
525/// * `n_harmonics` - Number of Fourier harmonics for seasonal component
526///
527/// # Returns
528/// DecomposeResult with trend, seasonal, and remainder components
529pub fn decompose_additive(
530    data: &[f64],
531    n: usize,
532    m: usize,
533    argvals: &[f64],
534    period: f64,
535    trend_method: &str,
536    bandwidth: f64,
537    n_harmonics: usize,
538) -> DecomposeResult {
539    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m || period <= 0.0 {
540        return DecomposeResult {
541            trend: vec![0.0; n * m],
542            seasonal: vec![0.0; n * m],
543            remainder: data.to_vec(),
544            period,
545            method: "additive".to_string(),
546        };
547    }
548
549    // Step 1: Extract trend using LOESS or spline
550    let trend_result = match trend_method {
551        "spline" => {
552            // Use P-spline fitting - use a larger bandwidth for trend
553            detrend_loess(data, n, m, argvals, bandwidth.max(0.3), 2)
554        }
555        _ => detrend_loess(data, n, m, argvals, bandwidth.max(0.3), 2),
556    };
557
558    // Step 2: Extract seasonal component using Fourier basis on detrended data
559    let n_harm = n_harmonics.max(1).min(m / 4);
560    let omega = 2.0 * std::f64::consts::PI / period;
561
562    // Process each sample
563    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = (0..n)
564        .into_par_iter()
565        .map(|i| {
566            let trend_i: Vec<f64> = (0..m).map(|j| trend_result.trend[i + j * n]).collect();
567            let detrended_i: Vec<f64> = (0..m).map(|j| trend_result.detrended[i + j * n]).collect();
568
569            // Fit Fourier model to detrended data: sum of sin and cos terms
570            // y = sum_k (a_k * cos(k*omega*t) + b_k * sin(k*omega*t))
571            let n_coef = 2 * n_harm;
572            let mut design = DMatrix::zeros(m, n_coef);
573            for j in 0..m {
574                let t = argvals[j];
575                for k in 0..n_harm {
576                    let freq = (k + 1) as f64 * omega;
577                    design[(j, 2 * k)] = (freq * t).cos();
578                    design[(j, 2 * k + 1)] = (freq * t).sin();
579                }
580            }
581
582            // Solve least squares using SVD
583            let y = DVector::from_row_slice(&detrended_i);
584            let svd = design.clone().svd(true, true);
585            let coef = svd
586                .solve(&y, 1e-10)
587                .unwrap_or_else(|_| DVector::zeros(n_coef));
588
589            // Compute seasonal component
590            let fitted = &design * &coef;
591            let seasonal: Vec<f64> = fitted.iter().cloned().collect();
592
593            // Compute remainder
594            let remainder: Vec<f64> = (0..m).map(|j| detrended_i[j] - seasonal[j]).collect();
595
596            (trend_i, seasonal, remainder)
597        })
598        .collect();
599
600    // Reassemble into column-major format
601    let mut trend = vec![0.0; n * m];
602    let mut seasonal = vec![0.0; n * m];
603    let mut remainder = vec![0.0; n * m];
604
605    for (i, (t, s, r)) in results.into_iter().enumerate() {
606        for j in 0..m {
607            trend[i + j * n] = t[j];
608            seasonal[i + j * n] = s[j];
609            remainder[i + j * n] = r[j];
610        }
611    }
612
613    DecomposeResult {
614        trend,
615        seasonal,
616        remainder,
617        period,
618        method: "additive".to_string(),
619    }
620}
621
622/// Multiplicative seasonal decomposition: data = trend * seasonal * remainder
623///
624/// Applies log transformation, then additive decomposition, then back-transforms.
625/// Handles non-positive values by adding a shift.
626///
627/// # Arguments
628/// * `data` - Column-major matrix (n x m)
629/// * `n` - Number of samples
630/// * `m` - Number of evaluation points
631/// * `argvals` - Time/argument values
632/// * `period` - Seasonal period
633/// * `trend_method` - "loess" or "spline"
634/// * `bandwidth` - Bandwidth for LOESS
635/// * `n_harmonics` - Number of Fourier harmonics
636///
637/// # Returns
638/// DecomposeResult with multiplicative components
639pub fn decompose_multiplicative(
640    data: &[f64],
641    n: usize,
642    m: usize,
643    argvals: &[f64],
644    period: f64,
645    trend_method: &str,
646    bandwidth: f64,
647    n_harmonics: usize,
648) -> DecomposeResult {
649    if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m || period <= 0.0 {
650        return DecomposeResult {
651            trend: vec![0.0; n * m],
652            seasonal: vec![0.0; n * m],
653            remainder: data.to_vec(),
654            period,
655            method: "multiplicative".to_string(),
656        };
657    }
658
659    // Find minimum value and add shift if needed to make all values positive
660    let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
661    let shift = if min_val <= 0.0 { -min_val + 1.0 } else { 0.0 };
662
663    // Log transform
664    let log_data: Vec<f64> = data.iter().map(|&x| (x + shift).ln()).collect();
665
666    // Apply additive decomposition to log data
667    let additive_result = decompose_additive(
668        &log_data,
669        n,
670        m,
671        argvals,
672        period,
673        trend_method,
674        bandwidth,
675        n_harmonics,
676    );
677
678    // Back transform: exp of each component
679    // For multiplicative: data = trend * seasonal * remainder
680    // In log space: log(data) = log(trend) + log(seasonal) + log(remainder)
681    // So: trend_mult = exp(trend_add), seasonal_mult = exp(seasonal_add), etc.
682
683    let mut trend = vec![0.0; n * m];
684    let mut seasonal = vec![0.0; n * m];
685    let mut remainder = vec![0.0; n * m];
686
687    for idx in 0..n * m {
688        // Back-transform trend (subtract shift)
689        trend[idx] = additive_result.trend[idx].exp() - shift;
690
691        // Seasonal is a multiplicative factor (centered around 1)
692        // We interpret the additive seasonal component as log(seasonal factor)
693        seasonal[idx] = additive_result.seasonal[idx].exp();
694
695        // Remainder is also multiplicative
696        remainder[idx] = additive_result.remainder[idx].exp();
697    }
698
699    DecomposeResult {
700        trend,
701        seasonal,
702        remainder,
703        period,
704        method: "multiplicative".to_string(),
705    }
706}
707
708#[cfg(test)]
709mod tests {
710    use super::*;
711    use std::f64::consts::PI;
712
713    #[test]
714    fn test_detrend_linear_removes_linear_trend() {
715        let m = 100;
716        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
717
718        // y = 2 + 0.5*t + sin(2*pi*t/2)
719        let data: Vec<f64> = argvals
720            .iter()
721            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / 2.0).sin())
722            .collect();
723
724        let result = detrend_linear(&data, 1, m, &argvals);
725
726        // Detrended should be approximately sin wave
727        let expected: Vec<f64> = argvals
728            .iter()
729            .map(|&t| (2.0 * PI * t / 2.0).sin())
730            .collect();
731
732        let mut max_diff = 0.0f64;
733        for j in 0..m {
734            let diff = (result.detrended[j] - expected[j]).abs();
735            max_diff = max_diff.max(diff);
736        }
737        assert!(max_diff < 0.2, "Max difference: {}", max_diff);
738    }
739
740    #[test]
741    fn test_detrend_polynomial_removes_quadratic_trend() {
742        let m = 100;
743        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
744
745        // y = 1 + 0.5*t - 0.1*t^2 + sin(2*pi*t/2)
746        let data: Vec<f64> = argvals
747            .iter()
748            .map(|&t| 1.0 + 0.5 * t - 0.1 * t * t + (2.0 * PI * t / 2.0).sin())
749            .collect();
750
751        let result = detrend_polynomial(&data, 1, m, &argvals, 2);
752
753        // Detrended should be approximately sin wave
754        let expected: Vec<f64> = argvals
755            .iter()
756            .map(|&t| (2.0 * PI * t / 2.0).sin())
757            .collect();
758
759        // Compute correlation
760        let mean_det: f64 = result.detrended.iter().sum::<f64>() / m as f64;
761        let mean_exp: f64 = expected.iter().sum::<f64>() / m as f64;
762        let mut num = 0.0;
763        let mut den_det = 0.0;
764        let mut den_exp = 0.0;
765        for j in 0..m {
766            num += (result.detrended[j] - mean_det) * (expected[j] - mean_exp);
767            den_det += (result.detrended[j] - mean_det).powi(2);
768            den_exp += (expected[j] - mean_exp).powi(2);
769        }
770        let corr = num / (den_det.sqrt() * den_exp.sqrt());
771        assert!(corr > 0.95, "Correlation: {}", corr);
772    }
773
774    #[test]
775    fn test_detrend_diff1() {
776        let m = 100;
777        // Random walk: cumsum of random values
778        let data: Vec<f64> = {
779            let mut v = vec![0.0; m];
780            v[0] = 1.0;
781            for i in 1..m {
782                v[i] = v[i - 1] + 0.1 * (i as f64).sin();
783            }
784            v
785        };
786
787        let result = detrend_diff(&data, 1, m, 1);
788
789        // First difference should recover the increments
790        for j in 0..m - 1 {
791            let expected = data[j + 1] - data[j];
792            assert!(
793                (result.detrended[j] - expected).abs() < 1e-10,
794                "Mismatch at {}: {} vs {}",
795                j,
796                result.detrended[j],
797                expected
798            );
799        }
800    }
801
802    #[test]
803    fn test_auto_detrend_selects_linear_for_linear_data() {
804        let m = 100;
805        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
806
807        // Pure linear trend with small noise
808        let data: Vec<f64> = argvals.iter().map(|&t| 2.0 + 0.5 * t).collect();
809
810        let result = auto_detrend(&data, 1, m, &argvals);
811
812        // Should select linear (or poly 2/3 with linear being sufficient)
813        assert!(
814            result.method.contains("linear") || result.method.contains("polynomial"),
815            "Method: {}",
816            result.method
817        );
818    }
819
820    // ========================================================================
821    // Tests for detrend_loess
822    // ========================================================================
823
824    #[test]
825    fn test_detrend_loess_removes_linear_trend() {
826        let m = 100;
827        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
828
829        // y = 2 + 0.5*t + sin(2*pi*t/2)
830        let data: Vec<f64> = argvals
831            .iter()
832            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / 2.0).sin())
833            .collect();
834
835        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 1);
836
837        // Detrended should be approximately sin wave
838        let expected: Vec<f64> = argvals
839            .iter()
840            .map(|&t| (2.0 * PI * t / 2.0).sin())
841            .collect();
842
843        // Compute correlation (LOESS may smooth slightly)
844        let mean_det: f64 = result.detrended.iter().sum::<f64>() / m as f64;
845        let mean_exp: f64 = expected.iter().sum::<f64>() / m as f64;
846        let mut num = 0.0;
847        let mut den_det = 0.0;
848        let mut den_exp = 0.0;
849        for j in 0..m {
850            num += (result.detrended[j] - mean_det) * (expected[j] - mean_exp);
851            den_det += (result.detrended[j] - mean_det).powi(2);
852            den_exp += (expected[j] - mean_exp).powi(2);
853        }
854        let corr = num / (den_det.sqrt() * den_exp.sqrt());
855        assert!(corr > 0.9, "Correlation: {}", corr);
856        assert_eq!(result.method, "loess");
857    }
858
859    #[test]
860    fn test_detrend_loess_removes_quadratic_trend() {
861        let m = 100;
862        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
863
864        // y = 1 + 0.3*t - 0.05*t^2 + sin(2*pi*t/2)
865        let data: Vec<f64> = argvals
866            .iter()
867            .map(|&t| 1.0 + 0.3 * t - 0.05 * t * t + (2.0 * PI * t / 2.0).sin())
868            .collect();
869
870        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 2);
871
872        // Trend should follow the quadratic shape
873        assert_eq!(result.trend.len(), m);
874        assert_eq!(result.detrended.len(), m);
875
876        // Check that RSS is computed
877        assert!(result.rss[0] > 0.0);
878    }
879
880    #[test]
881    fn test_detrend_loess_different_bandwidths() {
882        let m = 100;
883        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
884
885        // Noisy sine wave
886        let data: Vec<f64> = argvals
887            .iter()
888            .enumerate()
889            .map(|(i, &t)| (2.0 * PI * t / 2.0).sin() + 0.1 * ((i * 17) % 100) as f64 / 100.0)
890            .collect();
891
892        // Small bandwidth = more local = rougher trend
893        let result_small = detrend_loess(&data, 1, m, &argvals, 0.1, 1);
894        // Large bandwidth = smoother trend
895        let result_large = detrend_loess(&data, 1, m, &argvals, 0.5, 1);
896
897        // Both should produce valid results
898        assert_eq!(result_small.trend.len(), m);
899        assert_eq!(result_large.trend.len(), m);
900
901        // Large bandwidth should have more parameters
902        assert!(result_large.n_params >= result_small.n_params);
903    }
904
905    #[test]
906    fn test_detrend_loess_short_series() {
907        let m = 10;
908        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
909        let data: Vec<f64> = argvals.iter().map(|&t| t * 2.0).collect();
910
911        let result = detrend_loess(&data, 1, m, &argvals, 0.3, 1);
912
913        // Should still work on short series
914        assert_eq!(result.trend.len(), m);
915        assert_eq!(result.detrended.len(), m);
916    }
917
918    // ========================================================================
919    // Tests for decompose_additive
920    // ========================================================================
921
922    #[test]
923    fn test_decompose_additive_separates_components() {
924        let m = 200;
925        let period = 2.0;
926        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
927
928        // data = trend + seasonal: y = 2 + 0.5*t + sin(2*pi*t/2)
929        let data: Vec<f64> = argvals
930            .iter()
931            .map(|&t| 2.0 + 0.5 * t + (2.0 * PI * t / period).sin())
932            .collect();
933
934        let result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);
935
936        assert_eq!(result.trend.len(), m);
937        assert_eq!(result.seasonal.len(), m);
938        assert_eq!(result.remainder.len(), m);
939        assert_eq!(result.method, "additive");
940        assert_eq!(result.period, period);
941
942        // Check that components approximately sum to original
943        for j in 0..m {
944            let reconstructed = result.trend[j] + result.seasonal[j] + result.remainder[j];
945            assert!(
946                (reconstructed - data[j]).abs() < 0.5,
947                "Reconstruction error at {}: {} vs {}",
948                j,
949                reconstructed,
950                data[j]
951            );
952        }
953    }
954
955    #[test]
956    fn test_decompose_additive_different_harmonics() {
957        let m = 200;
958        let period = 2.0;
959        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
960
961        // Simple seasonal pattern
962        let data: Vec<f64> = argvals
963            .iter()
964            .map(|&t| 1.0 + (2.0 * PI * t / period).sin())
965            .collect();
966
967        // 1 harmonic
968        let result1 = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 1);
969        // 5 harmonics
970        let result5 = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 5);
971
972        // Both should produce valid results
973        assert_eq!(result1.seasonal.len(), m);
974        assert_eq!(result5.seasonal.len(), m);
975    }
976
977    #[test]
978    fn test_decompose_additive_residual_properties() {
979        let m = 200;
980        let period = 2.0;
981        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
982
983        // Data with trend and seasonal
984        let data: Vec<f64> = argvals
985            .iter()
986            .map(|&t| 2.0 + 0.3 * t + (2.0 * PI * t / period).sin())
987            .collect();
988
989        let result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);
990
991        // Remainder should have mean close to zero
992        let mean_rem: f64 = result.remainder.iter().sum::<f64>() / m as f64;
993        assert!(mean_rem.abs() < 0.5, "Remainder mean: {}", mean_rem);
994
995        // Remainder variance should be smaller than original variance
996        let var_data: f64 = data
997            .iter()
998            .map(|&x| (x - data.iter().sum::<f64>() / m as f64).powi(2))
999            .sum::<f64>()
1000            / m as f64;
1001        let var_rem: f64 = result
1002            .remainder
1003            .iter()
1004            .map(|&x| (x - mean_rem).powi(2))
1005            .sum::<f64>()
1006            / m as f64;
1007        assert!(
1008            var_rem < var_data,
1009            "Remainder variance {} should be < data variance {}",
1010            var_rem,
1011            var_data
1012        );
1013    }
1014
1015    #[test]
1016    fn test_decompose_additive_multi_sample() {
1017        let n = 3;
1018        let m = 100;
1019        let period = 2.0;
1020        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
1021
1022        // Create 3 samples with different amplitudes
1023        let mut data = vec![0.0; n * m];
1024        for i in 0..n {
1025            let amp = (i + 1) as f64;
1026            for j in 0..m {
1027                data[i + j * n] =
1028                    1.0 + 0.1 * argvals[j] + amp * (2.0 * PI * argvals[j] / period).sin();
1029            }
1030        }
1031
1032        let result = decompose_additive(&data, n, m, &argvals, period, "loess", 0.3, 2);
1033
1034        assert_eq!(result.trend.len(), n * m);
1035        assert_eq!(result.seasonal.len(), n * m);
1036        assert_eq!(result.remainder.len(), n * m);
1037    }
1038
1039    // ========================================================================
1040    // Tests for decompose_multiplicative
1041    // ========================================================================
1042
1043    #[test]
1044    fn test_decompose_multiplicative_basic() {
1045        let m = 200;
1046        let period = 2.0;
1047        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
1048
1049        // Multiplicative: data = trend * seasonal
1050        // trend = 2 + 0.1*t, seasonal = 1 + 0.3*sin(...)
1051        let data: Vec<f64> = argvals
1052            .iter()
1053            .map(|&t| (2.0 + 0.1 * t) * (1.0 + 0.3 * (2.0 * PI * t / period).sin()))
1054            .collect();
1055
1056        let result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 3);
1057
1058        assert_eq!(result.trend.len(), m);
1059        assert_eq!(result.seasonal.len(), m);
1060        assert_eq!(result.remainder.len(), m);
1061        assert_eq!(result.method, "multiplicative");
1062
1063        // Seasonal factors should be centered around 1
1064        let mean_seasonal: f64 = result.seasonal.iter().sum::<f64>() / m as f64;
1065        assert!(
1066            (mean_seasonal - 1.0).abs() < 0.5,
1067            "Mean seasonal factor: {}",
1068            mean_seasonal
1069        );
1070    }
1071
1072    #[test]
1073    fn test_decompose_multiplicative_non_positive_data() {
1074        let m = 100;
1075        let period = 2.0;
1076        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
1077
1078        // Data with negative values
1079        let data: Vec<f64> = argvals
1080            .iter()
1081            .map(|&t| -1.0 + (2.0 * PI * t / period).sin())
1082            .collect();
1083
1084        // Should handle negative values by shifting
1085        let result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 2);
1086
1087        assert_eq!(result.trend.len(), m);
1088        assert_eq!(result.seasonal.len(), m);
1089        // All seasonal values should be positive (multiplicative factors)
1090        for &s in result.seasonal.iter() {
1091            assert!(s.is_finite(), "Seasonal should be finite");
1092        }
1093    }
1094
1095    #[test]
1096    fn test_decompose_multiplicative_vs_additive() {
1097        let m = 200;
1098        let period = 2.0;
1099        let argvals: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect();
1100
1101        // Simple positive data
1102        let data: Vec<f64> = argvals
1103            .iter()
1104            .map(|&t| 5.0 + (2.0 * PI * t / period).sin())
1105            .collect();
1106
1107        let add_result = decompose_additive(&data, 1, m, &argvals, period, "loess", 0.3, 3);
1108        let mult_result = decompose_multiplicative(&data, 1, m, &argvals, period, "loess", 0.3, 3);
1109
1110        // Both should produce valid decompositions
1111        assert_eq!(add_result.seasonal.len(), m);
1112        assert_eq!(mult_result.seasonal.len(), m);
1113
1114        // Additive seasonal oscillates around 0
1115        let add_mean: f64 = add_result.seasonal.iter().sum::<f64>() / m as f64;
1116        // Multiplicative seasonal oscillates around 1
1117        let mult_mean: f64 = mult_result.seasonal.iter().sum::<f64>() / m as f64;
1118
1119        assert!(
1120            add_mean.abs() < mult_mean,
1121            "Additive mean {} vs mult mean {}",
1122            add_mean,
1123            mult_mean
1124        );
1125    }
1126
1127    #[test]
1128    fn test_decompose_multiplicative_edge_cases() {
1129        // Empty data
1130        let result = decompose_multiplicative(&[], 0, 0, &[], 2.0, "loess", 0.3, 2);
1131        assert_eq!(result.trend.len(), 0);
1132
1133        // Very short series
1134        let m = 5;
1135        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
1136        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 2.0, 1.0];
1137        let result = decompose_multiplicative(&data, 1, m, &argvals, 2.0, "loess", 0.3, 1);
1138        // Should return original data as remainder for too-short series
1139        assert_eq!(result.remainder.len(), m);
1140    }
1141}