Skip to main content

fdars_core/detrend/
stl.rs

1use crate::iter_maybe_parallel;
2use crate::matrix::FdMatrix;
3#[cfg(feature = "parallel")]
4use rayon::iter::ParallelIterator;
5
6// ============================================================================
7// STL Decomposition (Cleveland et al., 1990)
8// ============================================================================
9
10/// Result of STL decomposition including robustness weights.
11#[derive(Debug, Clone)]
12#[non_exhaustive]
13pub struct StlResult {
14    /// Trend component (n x m)
15    pub trend: FdMatrix,
16    /// Seasonal component (n x m)
17    pub seasonal: FdMatrix,
18    /// Remainder/residual component (n x m)
19    pub remainder: FdMatrix,
20    /// Robustness weights per point (n x m)
21    pub weights: FdMatrix,
22    /// Period used for decomposition
23    pub period: usize,
24    /// Seasonal smoothing window
25    pub s_window: usize,
26    /// Trend smoothing window
27    pub t_window: usize,
28    /// Number of inner loop iterations performed
29    pub inner_iterations: usize,
30    /// Number of outer loop iterations performed
31    pub outer_iterations: usize,
32}
33
/// Tuning parameters for STL decomposition.
///
/// Used by [`stl_decompose_with_config`]. [`StlConfig::default()`] leaves every
/// window and iteration count as `None` (resolved automatically) and turns
/// robust fitting off.
///
/// # Example
/// ```no_run
/// use fdars_core::detrend::stl::StlConfig;
///
/// let mut config = StlConfig::default();
/// config.robust = true;
/// config.s_window = Some(13);
/// ```
#[non_exhaustive]
#[derive(Clone, Debug, Default, PartialEq)]
pub struct StlConfig {
    /// Seasonal smoothing window; `None` selects 7. The value used is at least 3 and odd.
    pub s_window: Option<usize>,
    /// Trend smoothing window; `None` derives it from the period and seasonal window.
    pub t_window: Option<usize>,
    /// Low-pass filter window; `None` selects the period (rounded up to odd).
    pub l_window: Option<usize>,
    /// Whether to use robust (bisquare) weights; `false` by default.
    pub robust: bool,
    /// Number of inner loop iterations; `None` selects 2.
    pub inner_iterations: Option<usize>,
    /// Number of outer loop iterations; `None` selects 15 when `robust`, otherwise 1.
    pub outer_iterations: Option<usize>,
}
63
64/// STL Decomposition using a [`StlConfig`] struct.
65///
66/// This is the config-based alternative to [`stl_decompose`]. It takes data
67/// and period directly, and reads all tuning parameters from the config.
68///
69/// # Arguments
70/// * `data` — Functional data matrix (n x m)
71/// * `period` — Seasonal period length
72/// * `config` — Tuning parameters
73pub fn stl_decompose_with_config(data: &FdMatrix, period: usize, config: &StlConfig) -> StlResult {
74    stl_decompose(
75        data,
76        period,
77        config.s_window,
78        config.t_window,
79        config.l_window,
80        config.robust,
81        config.inner_iterations,
82        config.outer_iterations,
83    )
84}
85
86/// STL Decomposition: Seasonal and Trend decomposition using LOESS.
87///
88/// # Arguments
89/// * `data` - Functional data matrix (n x m)
90/// * `period` - Seasonal period length
91/// * `s_window` - Seasonal smoothing window (None for auto)
92/// * `t_window` - Trend smoothing window (None for auto)
93/// * `l_window` - Low-pass filter window (None for auto)
94/// * `robust` - Whether to use robust weights
95/// * `inner_iterations` - Number of inner loop iterations (None for auto)
96/// * `outer_iterations` - Number of outer loop iterations (None for auto)
97///
98/// # Examples
99///
100/// ```
101/// use fdars_core::matrix::FdMatrix;
102/// use fdars_core::detrend::stl::stl_decompose;
103///
104/// let n = 3;
105/// let m = 40; // must be >= 2 * period
106/// let data = FdMatrix::from_column_major(
107///     (0..n * m).map(|i| {
108///         let t = (i % m) as f64;
109///         (t * std::f64::consts::PI / 5.0).sin() + t * 0.01
110///     }).collect(),
111///     n, m,
112/// ).unwrap();
113/// let result = stl_decompose(&data, 10, None, None, None, false, None, None);
114/// assert_eq!(result.trend.shape(), (n, m));
115/// assert_eq!(result.seasonal.shape(), (n, m));
116/// assert_eq!(result.remainder.shape(), (n, m));
117/// ```
118pub fn stl_decompose(
119    data: &FdMatrix,
120    period: usize,
121    s_window: Option<usize>,
122    t_window: Option<usize>,
123    l_window: Option<usize>,
124    robust: bool,
125    inner_iterations: Option<usize>,
126    outer_iterations: Option<usize>,
127) -> StlResult {
128    let (n, m) = data.shape();
129    if n == 0 || m < 2 * period || period < 2 {
130        return StlResult {
131            trend: FdMatrix::zeros(n, m),
132            seasonal: FdMatrix::zeros(n, m),
133            remainder: FdMatrix::from_slice(data.as_slice(), n, m)
134                .unwrap_or_else(|_| FdMatrix::zeros(n, m)),
135            weights: FdMatrix::from_column_major(vec![1.0; n * m], n, m)
136                .unwrap_or_else(|_| FdMatrix::zeros(n, m)),
137            period,
138            s_window: 0,
139            t_window: 0,
140            inner_iterations: 0,
141            outer_iterations: 0,
142        };
143    }
144    let s_win = s_window.unwrap_or(7).max(3) | 1;
145    let t_win = t_window.unwrap_or_else(|| {
146        let ratio = 1.5 * period as f64 / (1.0 - 1.5 / s_win as f64);
147        let val = ratio.ceil() as usize;
148        val.max(3) | 1
149    });
150    let l_win = l_window.unwrap_or(period) | 1;
151    let n_inner = inner_iterations.unwrap_or(2);
152    let n_outer = outer_iterations.unwrap_or(if robust { 15 } else { 1 });
153    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
154        .map(|i| {
155            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
156            stl_single_series(
157                &curve, period, s_win, t_win, l_win, robust, n_inner, n_outer,
158            )
159        })
160        .collect();
161    let mut trend = FdMatrix::zeros(n, m);
162    let mut seasonal = FdMatrix::zeros(n, m);
163    let mut remainder = FdMatrix::zeros(n, m);
164    let mut weights = FdMatrix::from_column_major(vec![1.0; n * m], n, m)
165        .expect("dimension invariant: data.len() == n * m");
166    for (i, (t, s, r, w)) in results.into_iter().enumerate() {
167        for j in 0..m {
168            trend[(i, j)] = t[j];
169            seasonal[(i, j)] = s[j];
170            remainder[(i, j)] = r[j];
171            weights[(i, j)] = w[j];
172        }
173    }
174    StlResult {
175        trend,
176        seasonal,
177        remainder,
178        weights,
179        period,
180        s_window: s_win,
181        t_window: t_win,
182        inner_iterations: n_inner,
183        outer_iterations: n_outer,
184    }
185}
186
187fn stl_single_series(
188    data: &[f64],
189    period: usize,
190    s_window: usize,
191    t_window: usize,
192    l_window: usize,
193    robust: bool,
194    n_inner: usize,
195    n_outer: usize,
196) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
197    let m = data.len();
198    let mut trend = vec![0.0; m];
199    let mut seasonal = vec![0.0; m];
200    let mut weights = vec![1.0; m];
201    for outer in 0..n_outer {
202        for _inner in 0..n_inner {
203            let detrended: Vec<f64> = data
204                .iter()
205                .zip(trend.iter())
206                .map(|(&y, &t)| y - t)
207                .collect();
208            let cycle_smoothed = smooth_cycle_subseries(&detrended, period, s_window, &weights);
209            let low_pass = stl_lowpass_filter(&cycle_smoothed, period, l_window);
210            seasonal = cycle_smoothed
211                .iter()
212                .zip(low_pass.iter())
213                .map(|(&c, &l)| c - l)
214                .collect();
215            let deseasonalized: Vec<f64> = data
216                .iter()
217                .zip(seasonal.iter())
218                .map(|(&y, &s)| y - s)
219                .collect();
220            trend = weighted_loess(&deseasonalized, t_window, &weights);
221        }
222        if robust && outer < n_outer - 1 {
223            let remainder: Vec<f64> = data
224                .iter()
225                .zip(trend.iter())
226                .zip(seasonal.iter())
227                .map(|((&y, &t), &s)| y - t - s)
228                .collect();
229            weights = compute_robustness_weights(&remainder);
230        }
231    }
232    let remainder: Vec<f64> = data
233        .iter()
234        .zip(trend.iter())
235        .zip(seasonal.iter())
236        .map(|((&y, &t), &s)| y - t - s)
237        .collect();
238    (trend, seasonal, remainder, weights)
239}
240
241fn smooth_cycle_subseries(
242    data: &[f64],
243    period: usize,
244    s_window: usize,
245    weights: &[f64],
246) -> Vec<f64> {
247    let m = data.len();
248    let n_cycles = m.div_ceil(period);
249    let mut result = vec![0.0; m];
250    for pos in 0..period {
251        let mut subseries_idx: Vec<usize> = Vec::new();
252        let mut subseries_vals: Vec<f64> = Vec::new();
253        let mut subseries_weights: Vec<f64> = Vec::new();
254        for cycle in 0..n_cycles {
255            let idx = cycle * period + pos;
256            if idx < m {
257                subseries_idx.push(idx);
258                subseries_vals.push(data[idx]);
259                subseries_weights.push(weights[idx]);
260            }
261        }
262        if subseries_vals.is_empty() {
263            continue;
264        }
265        let smoothed = weighted_loess(&subseries_vals, s_window, &subseries_weights);
266        for (i, &idx) in subseries_idx.iter().enumerate() {
267            result[idx] = smoothed[i];
268        }
269    }
270    result
271}
272
273fn stl_lowpass_filter(data: &[f64], period: usize, _l_window: usize) -> Vec<f64> {
274    let ma1 = moving_average(data, period);
275    let ma2 = moving_average(&ma1, period);
276    moving_average(&ma2, 3)
277}
278
/// Centred moving average with edge truncation.
///
/// Sample `i` becomes the mean of `data[i - window/2 ..= i + window/2]`, clipped
/// to the series bounds (the mean is taken over the samples actually present).
/// An odd window therefore averages `window` samples and an even window
/// `window + 1`. Empty input or a zero window returns a copy of `data`.
fn moving_average(data: &[f64], window: usize) -> Vec<f64> {
    if data.is_empty() || window == 0 {
        return data.to_vec();
    }
    let len = data.len();
    let reach = window / 2;
    (0..len)
        .map(|center| {
            let lo = center.saturating_sub(reach);
            let hi = len.min(center + reach + 1);
            let span = &data[lo..hi];
            span.iter().sum::<f64>() / span.len() as f64
        })
        .collect()
}
295
/// Weighted local-linear (LOESS, degree 1) smoother over sample indices.
///
/// For each index `i`, fits a weighted straight line to the samples within
/// `window / 2` positions of `i` (truncated at the edges) and evaluates it at `i`.
/// Each sample's weight is the tricube kernel of its distance (scaled by
/// `max(window / 2, 1)`) times the matching entry of `weights`.
///
/// Fallbacks:
/// * total weight `<= 1e-10` → the input value `data[i]`;
/// * degenerate line fit (e.g. a single effective point) → the weighted mean.
///
/// `weights` must be at least as long as `data`.
fn weighted_loess(data: &[f64], window: usize, weights: &[f64]) -> Vec<f64> {
    let m = data.len();
    if m == 0 {
        return vec![];
    }
    let half = window / 2;
    // Kernel bandwidth; clamped so that window < 2 reduces to the point itself.
    let bandwidth = half.max(1) as f64;
    (0..m)
        .map(|i| {
            let start = i.saturating_sub(half);
            let end = (i + half + 1).min(m);
            let mut sum_w = 0.0;
            let mut sum_wx = 0.0;
            let mut sum_wy = 0.0;
            let mut sum_wxx = 0.0;
            let mut sum_wxy = 0.0;
            for j in start..end {
                // Measure x relative to the evaluation point. With absolute indices the
                // normal-equation determinant suffers catastrophic cancellation for
                // long series; centred offsets keep every sum O(window).
                let x = j as f64 - i as f64;
                let dist = x.abs() / bandwidth;
                let tricube = if dist < 1.0 {
                    (1.0 - dist.powi(3)).powi(3)
                } else {
                    0.0
                };
                let w = tricube * weights[j];
                let y = data[j];
                sum_w += w;
                sum_wx += w * x;
                sum_wy += w * y;
                sum_wxx += w * x * x;
                sum_wxy += w * x * y;
            }
            if sum_w > 1e-10 {
                let denom = sum_w * sum_wxx - sum_wx * sum_wx;
                if denom.abs() > 1e-10 {
                    // The fitted line evaluated at x = 0 (i.e. at index i) is its intercept.
                    (sum_wxx * sum_wy - sum_wx * sum_wxy) / denom
                } else {
                    sum_wy / sum_w
                }
            } else {
                data[i]
            }
        })
        .collect()
}
342
343fn compute_robustness_weights(residuals: &[f64]) -> Vec<f64> {
344    let m = residuals.len();
345    if m == 0 {
346        return vec![];
347    }
348    let mut abs_residuals: Vec<f64> = residuals.iter().map(|&r| r.abs()).collect();
349    crate::helpers::sort_nan_safe(&mut abs_residuals);
350    let median_idx = m / 2;
351    let mad = if m % 2 == 0 {
352        (abs_residuals[median_idx - 1] + abs_residuals[median_idx]) / 2.0
353    } else {
354        abs_residuals[median_idx]
355    };
356    let h = 6.0 * mad.max(1e-10);
357    residuals
358        .iter()
359        .map(|&r| {
360            let u = r.abs() / h;
361            if u < 1.0 {
362                (1.0 - u * u).powi(2)
363            } else {
364                0.0
365            }
366        })
367        .collect()
368}
369
370/// Wrapper function for functional data STL decomposition.
371pub fn stl_fdata(
372    data: &FdMatrix,
373    _argvals: &[f64],
374    period: usize,
375    s_window: Option<usize>,
376    t_window: Option<usize>,
377    robust: bool,
378) -> StlResult {
379    stl_decompose(data, period, s_window, t_window, None, robust, None, None)
380}