use nalgebra::{DMatrix, DVector};
use rayon::prelude::*;

use crate::smoothing::local_polynomial;

/// Output of the detrending routines in this module.
///
/// Every per-sample vector uses the same column-major layout as the input
/// data: sample `j` of curve `i` is stored at index `i + j * n`.
#[derive(Debug, Clone, PartialEq)]
pub struct TrendResult {
    /// Estimated trend, one value per sample.
    pub trend: Vec<f64>,
    /// Input minus trend for the fitted methods; for differencing, the
    /// differenced series zero-padded to the original length.
    pub detrended: Vec<f64>,
    /// Label of the method that produced this result, e.g. `"linear"`,
    /// `"polynomial(2)"`, `"diff1"`, `"loess"` or `"auto(...)"`.
    pub method: String,
    /// Method-specific per-curve coefficients, or `None` when the method
    /// reports none or the input was rejected.
    pub coefficients: Option<Vec<f64>>,
    /// Residual sum of squares, one entry per curve.
    pub rss: Vec<f64>,
    /// Number of parameters attributed to the method; `auto_detrend` uses it
    /// as the AIC penalty term.
    pub n_params: usize,
}

/// Output of [`decompose_additive`] and [`decompose_multiplicative`].
///
/// Every per-sample vector uses the same column-major layout as the input
/// data: sample `j` of curve `i` is stored at index `i + j * n`.
#[derive(Debug, Clone, PartialEq)]
pub struct DecomposeResult {
    /// Trend component, one value per sample.
    pub trend: Vec<f64>,
    /// Seasonal (periodic) component, one value per sample.
    pub seasonal: Vec<f64>,
    /// What remains after trend and seasonal parts are removed.
    pub remainder: Vec<f64>,
    /// The period the seasonal component was fitted with (as passed in).
    pub period: f64,
    /// `"additive"` or `"multiplicative"`.
    pub method: String,
}

49pub fn detrend_linear(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> TrendResult {
60 if n == 0 || m < 2 || data.len() != n * m || argvals.len() != m {
61 return TrendResult {
62 trend: vec![0.0; n * m],
63 detrended: data.to_vec(),
64 method: "linear".to_string(),
65 coefficients: None,
66 rss: vec![0.0; n],
67 n_params: 2,
68 };
69 }
70
71 let mean_t: f64 = argvals.iter().sum::<f64>() / m as f64;
73 let ss_t: f64 = argvals.iter().map(|&t| (t - mean_t).powi(2)).sum();
74
75 let results: Vec<(Vec<f64>, Vec<f64>, f64, f64, f64)> = (0..n)
77 .into_par_iter()
78 .map(|i| {
79 let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
81 let mean_y: f64 = curve.iter().sum::<f64>() / m as f64;
82
83 let mut sp = 0.0;
85 for j in 0..m {
86 sp += (argvals[j] - mean_t) * (curve[j] - mean_y);
87 }
88 let slope = if ss_t.abs() > 1e-15 { sp / ss_t } else { 0.0 };
89 let intercept = mean_y - slope * mean_t;
90
91 let mut trend = vec![0.0; m];
93 let mut detrended = vec![0.0; m];
94 let mut rss = 0.0;
95 for j in 0..m {
96 trend[j] = intercept + slope * argvals[j];
97 detrended[j] = curve[j] - trend[j];
98 rss += detrended[j].powi(2);
99 }
100
101 (trend, detrended, intercept, slope, rss)
102 })
103 .collect();
104
105 let mut trend = vec![0.0; n * m];
107 let mut detrended = vec![0.0; n * m];
108 let mut coefficients = vec![0.0; n * 2];
109 let mut rss = vec![0.0; n];
110
111 for (i, (t, d, intercept, slope, r)) in results.into_iter().enumerate() {
112 for j in 0..m {
113 trend[i + j * n] = t[j];
114 detrended[i + j * n] = d[j];
115 }
116 coefficients[i * 2] = intercept;
117 coefficients[i * 2 + 1] = slope;
118 rss[i] = r;
119 }
120
121 TrendResult {
122 trend,
123 detrended,
124 method: "linear".to_string(),
125 coefficients: Some(coefficients),
126 rss,
127 n_params: 2,
128 }
129}
130
131pub fn detrend_polynomial(
143 data: &[f64],
144 n: usize,
145 m: usize,
146 argvals: &[f64],
147 degree: usize,
148) -> TrendResult {
149 if n == 0 || m < degree + 1 || data.len() != n * m || argvals.len() != m || degree == 0 {
150 return TrendResult {
152 trend: vec![0.0; n * m],
153 detrended: data.to_vec(),
154 method: format!("polynomial({})", degree),
155 coefficients: None,
156 rss: vec![0.0; n],
157 n_params: degree + 1,
158 };
159 }
160
161 if degree == 1 {
163 let mut result = detrend_linear(data, n, m, argvals);
164 result.method = "polynomial(1)".to_string();
165 return result;
166 }
167
168 let n_coef = degree + 1;
169
170 let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
172 let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
173 let t_range = if (t_max - t_min).abs() > 1e-15 {
174 t_max - t_min
175 } else {
176 1.0
177 };
178 let t_norm: Vec<f64> = argvals.iter().map(|&t| (t - t_min) / t_range).collect();
179
180 let mut design = DMatrix::zeros(m, n_coef);
182 for j in 0..m {
183 let t = t_norm[j];
184 let mut power = 1.0;
185 for k in 0..n_coef {
186 design[(j, k)] = power;
187 power *= t;
188 }
189 }
190
191 let svd = design.clone().svd(true, true);
193
194 let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = (0..n)
196 .into_par_iter()
197 .map(|i| {
198 let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
200 let y = DVector::from_row_slice(&curve);
201
202 let beta = svd
204 .solve(&y, 1e-10)
205 .unwrap_or_else(|_| DVector::zeros(n_coef));
206
207 let fitted = &design * β
209 let mut trend = vec![0.0; m];
210 let mut detrended = vec![0.0; m];
211 let mut rss = 0.0;
212 for j in 0..m {
213 trend[j] = fitted[j];
214 detrended[j] = curve[j] - fitted[j];
215 rss += detrended[j].powi(2);
216 }
217
218 let coefs: Vec<f64> = beta.iter().cloned().collect();
220
221 (trend, detrended, coefs, rss)
222 })
223 .collect();
224
225 let mut trend = vec![0.0; n * m];
227 let mut detrended = vec![0.0; n * m];
228 let mut coefficients = vec![0.0; n * n_coef];
229 let mut rss = vec![0.0; n];
230
231 for (i, (t, d, coefs, r)) in results.into_iter().enumerate() {
232 for j in 0..m {
233 trend[i + j * n] = t[j];
234 detrended[i + j * n] = d[j];
235 }
236 for k in 0..n_coef {
237 coefficients[i * n_coef + k] = coefs[k];
238 }
239 rss[i] = r;
240 }
241
242 TrendResult {
243 trend,
244 detrended,
245 method: format!("polynomial({})", degree),
246 coefficients: Some(coefficients),
247 rss,
248 n_params: n_coef,
249 }
250}
251
252pub fn detrend_diff(data: &[f64], n: usize, m: usize, order: usize) -> TrendResult {
267 if n == 0 || m <= order || data.len() != n * m || order == 0 || order > 2 {
268 return TrendResult {
269 trend: vec![0.0; n * m],
270 detrended: data.to_vec(),
271 method: format!("diff{}", order),
272 coefficients: None,
273 rss: vec![0.0; n],
274 n_params: order,
275 };
276 }
277
278 let new_m = m - order;
279
280 let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, f64)> = (0..n)
282 .into_par_iter()
283 .map(|i| {
284 let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
286
287 let diff1: Vec<f64> = (0..m - 1).map(|j| curve[j + 1] - curve[j]).collect();
289
290 let detrended = if order == 2 {
292 (0..diff1.len() - 1)
293 .map(|j| diff1[j + 1] - diff1[j])
294 .collect()
295 } else {
296 diff1.clone()
297 };
298
299 let initial_values = if order == 2 {
301 vec![curve[0], curve[1]]
302 } else {
303 vec![curve[0]]
304 };
305
306 let rss: f64 = detrended.iter().map(|&x| x.powi(2)).sum();
308
309 let mut trend = vec![0.0; m];
312 trend[0] = curve[0];
313 if order == 1 {
314 for j in 1..m {
315 trend[j] = curve[j] - if j <= new_m { detrended[j - 1] } else { 0.0 };
316 }
317 } else {
318 trend = curve.clone();
320 }
321
322 let mut det_full = vec![0.0; m];
324 det_full[..new_m].copy_from_slice(&detrended[..new_m]);
325
326 (trend, det_full, initial_values, rss)
327 })
328 .collect();
329
330 let mut trend = vec![0.0; n * m];
332 let mut detrended = vec![0.0; n * m];
333 let mut coefficients = vec![0.0; n * order];
334 let mut rss = vec![0.0; n];
335
336 for (i, (t, d, init, r)) in results.into_iter().enumerate() {
337 for j in 0..m {
338 trend[i + j * n] = t[j];
339 detrended[i + j * n] = d[j];
340 }
341 for k in 0..order {
342 coefficients[i * order + k] = init[k];
343 }
344 rss[i] = r;
345 }
346
347 TrendResult {
348 trend,
349 detrended,
350 method: format!("diff{}", order),
351 coefficients: Some(coefficients),
352 rss,
353 n_params: order,
354 }
355}
356
357pub fn detrend_loess(
370 data: &[f64],
371 n: usize,
372 m: usize,
373 argvals: &[f64],
374 bandwidth: f64,
375 degree: usize,
376) -> TrendResult {
377 if n == 0 || m < 3 || data.len() != n * m || argvals.len() != m || bandwidth <= 0.0 {
378 return TrendResult {
379 trend: vec![0.0; n * m],
380 detrended: data.to_vec(),
381 method: "loess".to_string(),
382 coefficients: None,
383 rss: vec![0.0; n],
384 n_params: (m as f64 * bandwidth).ceil() as usize,
385 };
386 }
387
388 let t_min = argvals.iter().cloned().fold(f64::INFINITY, f64::min);
390 let t_max = argvals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
391 let abs_bandwidth = (t_max - t_min) * bandwidth;
392
393 let results: Vec<(Vec<f64>, Vec<f64>, f64)> = (0..n)
395 .into_par_iter()
396 .map(|i| {
397 let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
399
400 let trend =
402 local_polynomial(argvals, &curve, argvals, abs_bandwidth, degree, "gaussian");
403
404 let mut detrended = vec![0.0; m];
406 let mut rss = 0.0;
407 for j in 0..m {
408 detrended[j] = curve[j] - trend[j];
409 rss += detrended[j].powi(2);
410 }
411
412 (trend, detrended, rss)
413 })
414 .collect();
415
416 let mut trend = vec![0.0; n * m];
418 let mut detrended = vec![0.0; n * m];
419 let mut rss = vec![0.0; n];
420
421 for (i, (t, d, r)) in results.into_iter().enumerate() {
422 for j in 0..m {
423 trend[i + j * n] = t[j];
424 detrended[i + j * n] = d[j];
425 }
426 rss[i] = r;
427 }
428
429 let n_params = (m as f64 * bandwidth).ceil() as usize;
431
432 TrendResult {
433 trend,
434 detrended,
435 method: "loess".to_string(),
436 coefficients: None,
437 rss,
438 n_params,
439 }
440}
441
442pub fn auto_detrend(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> TrendResult {
456 if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m {
457 return TrendResult {
458 trend: vec![0.0; n * m],
459 detrended: data.to_vec(),
460 method: "auto(none)".to_string(),
461 coefficients: None,
462 rss: vec![0.0; n],
463 n_params: 0,
464 };
465 }
466
467 let compute_aic = |result: &TrendResult| -> f64 {
470 let mut total_aic = 0.0;
471 for i in 0..n {
472 let rss = result.rss[i];
473 let k = result.n_params as f64;
474 let aic = if rss > 1e-15 {
475 m as f64 * (rss / m as f64).ln() + 2.0 * k
476 } else {
477 f64::NEG_INFINITY };
479 total_aic += aic;
480 }
481 total_aic / n as f64
482 };
483
484 let linear = detrend_linear(data, n, m, argvals);
486 let poly2 = detrend_polynomial(data, n, m, argvals, 2);
487 let poly3 = detrend_polynomial(data, n, m, argvals, 3);
488 let loess = detrend_loess(data, n, m, argvals, 0.3, 2);
489
490 let aic_linear = compute_aic(&linear);
491 let aic_poly2 = compute_aic(&poly2);
492 let aic_poly3 = compute_aic(&poly3);
493 let aic_loess = compute_aic(&loess);
494
495 let methods = [
497 (aic_linear, "linear", linear),
498 (aic_poly2, "polynomial(2)", poly2),
499 (aic_poly3, "polynomial(3)", poly3),
500 (aic_loess, "loess", loess),
501 ];
502
503 let (_, best_name, mut best_result) = methods
504 .into_iter()
505 .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
506 .unwrap();
507
508 best_result.method = format!("auto({})", best_name);
509 best_result
510}
511
512pub fn decompose_additive(
530 data: &[f64],
531 n: usize,
532 m: usize,
533 argvals: &[f64],
534 period: f64,
535 trend_method: &str,
536 bandwidth: f64,
537 n_harmonics: usize,
538) -> DecomposeResult {
539 if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m || period <= 0.0 {
540 return DecomposeResult {
541 trend: vec![0.0; n * m],
542 seasonal: vec![0.0; n * m],
543 remainder: data.to_vec(),
544 period,
545 method: "additive".to_string(),
546 };
547 }
548
549 let trend_result = match trend_method {
551 "spline" => {
552 detrend_loess(data, n, m, argvals, bandwidth.max(0.3), 2)
554 }
555 _ => detrend_loess(data, n, m, argvals, bandwidth.max(0.3), 2),
556 };
557
558 let n_harm = n_harmonics.max(1).min(m / 4);
560 let omega = 2.0 * std::f64::consts::PI / period;
561
562 let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = (0..n)
564 .into_par_iter()
565 .map(|i| {
566 let trend_i: Vec<f64> = (0..m).map(|j| trend_result.trend[i + j * n]).collect();
567 let detrended_i: Vec<f64> = (0..m).map(|j| trend_result.detrended[i + j * n]).collect();
568
569 let n_coef = 2 * n_harm;
572 let mut design = DMatrix::zeros(m, n_coef);
573 for j in 0..m {
574 let t = argvals[j];
575 for k in 0..n_harm {
576 let freq = (k + 1) as f64 * omega;
577 design[(j, 2 * k)] = (freq * t).cos();
578 design[(j, 2 * k + 1)] = (freq * t).sin();
579 }
580 }
581
582 let y = DVector::from_row_slice(&detrended_i);
584 let svd = design.clone().svd(true, true);
585 let coef = svd
586 .solve(&y, 1e-10)
587 .unwrap_or_else(|_| DVector::zeros(n_coef));
588
589 let fitted = &design * &coef;
591 let seasonal: Vec<f64> = fitted.iter().cloned().collect();
592
593 let remainder: Vec<f64> = (0..m).map(|j| detrended_i[j] - seasonal[j]).collect();
595
596 (trend_i, seasonal, remainder)
597 })
598 .collect();
599
600 let mut trend = vec![0.0; n * m];
602 let mut seasonal = vec![0.0; n * m];
603 let mut remainder = vec![0.0; n * m];
604
605 for (i, (t, s, r)) in results.into_iter().enumerate() {
606 for j in 0..m {
607 trend[i + j * n] = t[j];
608 seasonal[i + j * n] = s[j];
609 remainder[i + j * n] = r[j];
610 }
611 }
612
613 DecomposeResult {
614 trend,
615 seasonal,
616 remainder,
617 period,
618 method: "additive".to_string(),
619 }
620}
621
622pub fn decompose_multiplicative(
640 data: &[f64],
641 n: usize,
642 m: usize,
643 argvals: &[f64],
644 period: f64,
645 trend_method: &str,
646 bandwidth: f64,
647 n_harmonics: usize,
648) -> DecomposeResult {
649 if n == 0 || m < 4 || data.len() != n * m || argvals.len() != m || period <= 0.0 {
650 return DecomposeResult {
651 trend: vec![0.0; n * m],
652 seasonal: vec![0.0; n * m],
653 remainder: data.to_vec(),
654 period,
655 method: "multiplicative".to_string(),
656 };
657 }
658
659 let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
661 let shift = if min_val <= 0.0 { -min_val + 1.0 } else { 0.0 };
662
663 let log_data: Vec<f64> = data.iter().map(|&x| (x + shift).ln()).collect();
665
666 let additive_result = decompose_additive(
668 &log_data,
669 n,
670 m,
671 argvals,
672 period,
673 trend_method,
674 bandwidth,
675 n_harmonics,
676 );
677
678 let mut trend = vec![0.0; n * m];
684 let mut seasonal = vec![0.0; n * m];
685 let mut remainder = vec![0.0; n * m];
686
687 for idx in 0..n * m {
688 trend[idx] = additive_result.trend[idx].exp() - shift;
690
691 seasonal[idx] = additive_result.seasonal[idx].exp();
694
695 remainder[idx] = additive_result.remainder[idx].exp();
697 }
698
699 DecomposeResult {
700 trend,
701 seasonal,
702 remainder,
703 period,
704 method: "multiplicative".to_string(),
705 }
706}
707
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `m` evenly spaced sample points covering `[0, 10]`.
    fn grid_to_ten(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64 * 10.0).collect()
    }

    /// Unit-amplitude sinusoid with period 2.
    fn period_two_sine(t: f64) -> f64 {
        (2.0 * PI * t / 2.0).sin()
    }

    /// Pearson correlation of two equally long series.
    fn pearson(a: &[f64], b: &[f64]) -> f64 {
        let len = a.len() as f64;
        let mean_a = a.iter().sum::<f64>() / len;
        let mean_b = b.iter().sum::<f64>() / len;
        let (cross, ss_a, ss_b) =
            a.iter()
                .zip(b)
                .fold((0.0, 0.0, 0.0), |(c, sa, sb), (&x, &y)| {
                    let dx = x - mean_a;
                    let dy = y - mean_b;
                    (c + dx * dy, sa + dx.powi(2), sb + dy.powi(2))
                });
        cross / (ss_a.sqrt() * ss_b.sqrt())
    }

    #[test]
    fn test_detrend_linear_removes_linear_trend() {
        let m = 100;
        let argvals = grid_to_ten(m);
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 2.0 + 0.5 * t + period_two_sine(t))
            .collect();

        let result = detrend_linear(&data, 1, m, &argvals);

        let max_diff = result
            .detrended
            .iter()
            .zip(&argvals)
            .map(|(&d, &t)| (d - period_two_sine(t)).abs())
            .fold(0.0f64, f64::max);
        assert!(max_diff < 0.2, "Max difference: {}", max_diff);
    }

    #[test]
    fn test_detrend_polynomial_removes_quadratic_trend() {
        let m = 100;
        let argvals = grid_to_ten(m);
        let data: Vec<f64> = argvals
            .iter()
            .map(|&t| 1.0 + 0.5 * t - 0.1 * t * t + period_two_sine(t))
            .collect();

        let result = detrend_polynomial(&data, 1, m, &argvals, 2);

        let expected: Vec<f64> = argvals.iter().map(|&t| period_two_sine(t)).collect();
        let corr = pearson(&result.detrended, &expected);
        assert!(corr > 0.95, "Correlation: {}", corr);
    }

    #[test]
    fn test_detrend_diff1() {
        let m = 100;
        let mut data: Vec<f64> = Vec::with_capacity(m);
        data.push(1.0);
        for i in 1..m {
            let next = data[i - 1] + 0.1 * (i as f64).sin();
            data.push(next);
        }

        let result = detrend_diff(&data, 1, m, 1);

        for (j, pair) in data.windows(2).enumerate() {
            let expected = pair[1] - pair[0];
            assert!(
                (result.detrended[j] - expected).abs() < 1e-10,
                "Mismatch at {}: {} vs {}",
                j,
                result.detrended[j],
                expected
            );
        }
    }

    #[test]
    fn test_auto_detrend_selects_linear_for_linear_data() {
        let m = 100;
        let argvals: Vec<f64> = (0..m).map(|i| i as f64).collect();
        let data: Vec<f64> = argvals.iter().map(|&t| 2.0 + 0.5 * t).collect();

        let result = auto_detrend(&data, 1, m, &argvals);

        let method = &result.method;
        assert!(
            ["linear", "polynomial"].iter().any(|p| method.contains(p)),
            "Method: {}",
            method
        );
    }
}