Skip to main content

fdars_core/detrend/
stl.rs

1use crate::iter_maybe_parallel;
2use crate::matrix::FdMatrix;
3#[cfg(feature = "parallel")]
4use rayon::iter::ParallelIterator;
5
6// ============================================================================
7// STL Decomposition (Cleveland et al., 1990)
8// ============================================================================
9
10/// Result of STL decomposition including robustness weights.
11#[derive(Debug, Clone)]
12pub struct StlResult {
13    /// Trend component (n x m)
14    pub trend: FdMatrix,
15    /// Seasonal component (n x m)
16    pub seasonal: FdMatrix,
17    /// Remainder/residual component (n x m)
18    pub remainder: FdMatrix,
19    /// Robustness weights per point (n x m)
20    pub weights: FdMatrix,
21    /// Period used for decomposition
22    pub period: usize,
23    /// Seasonal smoothing window
24    pub s_window: usize,
25    /// Trend smoothing window
26    pub t_window: usize,
27    /// Number of inner loop iterations performed
28    pub inner_iterations: usize,
29    /// Number of outer loop iterations performed
30    pub outer_iterations: usize,
31}
32
/// Configuration for STL decomposition.
///
/// Collects all tuning parameters for [`stl_decompose_with_config`]. Every
/// `None` field falls back to the automatic choice made by [`stl_decompose`].
/// [`StlConfig::default()`] leaves all windows and iteration counts automatic,
/// with robustness switched off.
///
/// # Example
/// ```no_run
/// use fdars_core::detrend::stl::StlConfig;
///
/// let config = StlConfig {
///     robust: true,
///     s_window: Some(13),
///     ..StlConfig::default()
/// };
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StlConfig {
    /// Seasonal (cycle-subseries) smoothing window; `None` means 7. The value
    /// actually used is raised to at least 3 and rounded up to an odd number.
    pub s_window: Option<usize>,
    /// Trend smoothing window; `None` means the smallest odd integer that is
    /// at least `1.5 * period / (1 - 1.5 / s_window)` and at least 3.
    pub t_window: Option<usize>,
    /// Low-pass filter window; `None` means `period`, rounded up to an odd number.
    pub l_window: Option<usize>,
    /// Whether to down-weight outliers with bisquare robustness weights (default: false).
    pub robust: bool,
    /// Inner loop passes per outer iteration; `None` means 2.
    pub inner_iterations: Option<usize>,
    /// Outer loop iterations; `None` means 1, or 15 when `robust` is set.
    pub outer_iterations: Option<usize>,
}
63
64/// STL Decomposition using a [`StlConfig`] struct.
65///
66/// This is the config-based alternative to [`stl_decompose`]. It takes data
67/// and period directly, and reads all tuning parameters from the config.
68///
69/// # Arguments
70/// * `data` — Functional data matrix (n x m)
71/// * `period` — Seasonal period length
72/// * `config` — Tuning parameters
73pub fn stl_decompose_with_config(data: &FdMatrix, period: usize, config: &StlConfig) -> StlResult {
74    stl_decompose(
75        data,
76        period,
77        config.s_window,
78        config.t_window,
79        config.l_window,
80        config.robust,
81        config.inner_iterations,
82        config.outer_iterations,
83    )
84}
85
86/// STL Decomposition: Seasonal and Trend decomposition using LOESS
87pub fn stl_decompose(
88    data: &FdMatrix,
89    period: usize,
90    s_window: Option<usize>,
91    t_window: Option<usize>,
92    l_window: Option<usize>,
93    robust: bool,
94    inner_iterations: Option<usize>,
95    outer_iterations: Option<usize>,
96) -> StlResult {
97    let (n, m) = data.shape();
98    if n == 0 || m < 2 * period || period < 2 {
99        return StlResult {
100            trend: FdMatrix::zeros(n, m),
101            seasonal: FdMatrix::zeros(n, m),
102            remainder: FdMatrix::from_slice(data.as_slice(), n, m)
103                .unwrap_or_else(|_| FdMatrix::zeros(n, m)),
104            weights: FdMatrix::from_column_major(vec![1.0; n * m], n, m)
105                .unwrap_or_else(|_| FdMatrix::zeros(n, m)),
106            period,
107            s_window: 0,
108            t_window: 0,
109            inner_iterations: 0,
110            outer_iterations: 0,
111        };
112    }
113    let s_win = s_window.unwrap_or(7).max(3) | 1;
114    let t_win = t_window.unwrap_or_else(|| {
115        let ratio = 1.5 * period as f64 / (1.0 - 1.5 / s_win as f64);
116        let val = ratio.ceil() as usize;
117        val.max(3) | 1
118    });
119    let l_win = l_window.unwrap_or(period) | 1;
120    let n_inner = inner_iterations.unwrap_or(2);
121    let n_outer = outer_iterations.unwrap_or(if robust { 15 } else { 1 });
122    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
123        .map(|i| {
124            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
125            stl_single_series(
126                &curve, period, s_win, t_win, l_win, robust, n_inner, n_outer,
127            )
128        })
129        .collect();
130    let mut trend = FdMatrix::zeros(n, m);
131    let mut seasonal = FdMatrix::zeros(n, m);
132    let mut remainder = FdMatrix::zeros(n, m);
133    let mut weights = FdMatrix::from_column_major(vec![1.0; n * m], n, m)
134        .expect("dimension invariant: data.len() == n * m");
135    for (i, (t, s, r, w)) in results.into_iter().enumerate() {
136        for j in 0..m {
137            trend[(i, j)] = t[j];
138            seasonal[(i, j)] = s[j];
139            remainder[(i, j)] = r[j];
140            weights[(i, j)] = w[j];
141        }
142    }
143    StlResult {
144        trend,
145        seasonal,
146        remainder,
147        weights,
148        period,
149        s_window: s_win,
150        t_window: t_win,
151        inner_iterations: n_inner,
152        outer_iterations: n_outer,
153    }
154}
155
156fn stl_single_series(
157    data: &[f64],
158    period: usize,
159    s_window: usize,
160    t_window: usize,
161    l_window: usize,
162    robust: bool,
163    n_inner: usize,
164    n_outer: usize,
165) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
166    let m = data.len();
167    let mut trend = vec![0.0; m];
168    let mut seasonal = vec![0.0; m];
169    let mut weights = vec![1.0; m];
170    for _outer in 0..n_outer {
171        for _inner in 0..n_inner {
172            let detrended: Vec<f64> = data
173                .iter()
174                .zip(trend.iter())
175                .map(|(&y, &t)| y - t)
176                .collect();
177            let cycle_smoothed = smooth_cycle_subseries(&detrended, period, s_window, &weights);
178            let low_pass = stl_lowpass_filter(&cycle_smoothed, period, l_window);
179            seasonal = cycle_smoothed
180                .iter()
181                .zip(low_pass.iter())
182                .map(|(&c, &l)| c - l)
183                .collect();
184            let deseasonalized: Vec<f64> = data
185                .iter()
186                .zip(seasonal.iter())
187                .map(|(&y, &s)| y - s)
188                .collect();
189            trend = weighted_loess(&deseasonalized, t_window, &weights);
190        }
191        if robust && _outer < n_outer - 1 {
192            let remainder: Vec<f64> = data
193                .iter()
194                .zip(trend.iter())
195                .zip(seasonal.iter())
196                .map(|((&y, &t), &s)| y - t - s)
197                .collect();
198            weights = compute_robustness_weights(&remainder);
199        }
200    }
201    let remainder: Vec<f64> = data
202        .iter()
203        .zip(trend.iter())
204        .zip(seasonal.iter())
205        .map(|((&y, &t), &s)| y - t - s)
206        .collect();
207    (trend, seasonal, remainder, weights)
208}
209
210fn smooth_cycle_subseries(
211    data: &[f64],
212    period: usize,
213    s_window: usize,
214    weights: &[f64],
215) -> Vec<f64> {
216    let m = data.len();
217    let n_cycles = m.div_ceil(period);
218    let mut result = vec![0.0; m];
219    for pos in 0..period {
220        let mut subseries_idx: Vec<usize> = Vec::new();
221        let mut subseries_vals: Vec<f64> = Vec::new();
222        let mut subseries_weights: Vec<f64> = Vec::new();
223        for cycle in 0..n_cycles {
224            let idx = cycle * period + pos;
225            if idx < m {
226                subseries_idx.push(idx);
227                subseries_vals.push(data[idx]);
228                subseries_weights.push(weights[idx]);
229            }
230        }
231        if subseries_vals.is_empty() {
232            continue;
233        }
234        let smoothed = weighted_loess(&subseries_vals, s_window, &subseries_weights);
235        for (i, &idx) in subseries_idx.iter().enumerate() {
236            result[idx] = smoothed[i];
237        }
238    }
239    result
240}
241
242fn stl_lowpass_filter(data: &[f64], period: usize, _l_window: usize) -> Vec<f64> {
243    let ma1 = moving_average(data, period);
244    let ma2 = moving_average(&ma1, period);
245    moving_average(&ma2, 3)
246}
247
/// Centered moving average whose window shrinks at the boundaries.
///
/// Output `k` is the mean of `data[k - window/2 ..= k + window/2]`, clipped to
/// the valid index range. An even `window` therefore averages `window + 1`
/// samples in the interior. An empty input or a zero `window` returns an
/// unchanged copy of `data`.
fn moving_average(data: &[f64], window: usize) -> Vec<f64> {
    let len = data.len();
    if len == 0 || window == 0 {
        return data.to_vec();
    }
    let reach = window / 2;
    (0..len)
        .map(|k| {
            let lo = k.saturating_sub(reach);
            let hi = len.min(k + reach + 1);
            let span = &data[lo..hi];
            span.iter().sum::<f64>() / span.len() as f64
        })
        .collect()
}
264
/// Local-linear LOESS smoother with a tricube kernel and per-point weights.
///
/// For each index `i`, a weighted straight line `y ≈ a + b·x` is fitted to
/// the points `i - window/2 ..= i + window/2` (clipped to the data). Abscissae
/// are the sample indices. Point `j` gets weight
/// `tricube(|j - i| / max(window/2, 1)) * weights[j]`, so points exactly at
/// the window edge receive zero weight. The fitted line is evaluated at `i`.
///
/// Fallbacks:
/// - if the total weight is not above `1e-10`, `data[i]` is returned as is;
/// - if the normal-equation determinant is not above `1e-10` in magnitude,
///   the weighted mean is returned.
///
/// `weights` must be at least as long as `data`.
fn weighted_loess(data: &[f64], window: usize, weights: &[f64]) -> Vec<f64> {
    let len = data.len();
    if len == 0 {
        return vec![];
    }
    let half = window / 2;
    // Distances are in index units, scaled so the window edge maps to 1.
    let scale = half.max(1) as f64;

    let fit_at = |i: usize| -> f64 {
        let lo = i.saturating_sub(half);
        let hi = (i + half + 1).min(len);
        // Weighted sums for the normal equations of the local line.
        let (mut s_w, mut s_wx, mut s_wy, mut s_wxx, mut s_wxy) = (0.0, 0.0, 0.0, 0.0, 0.0);
        for j in lo..hi {
            let u = (j as f64 - i as f64).abs() / scale;
            let kernel = if u < 1.0 {
                (1.0 - u.powi(3)).powi(3)
            } else {
                0.0
            };
            let w = kernel * weights[j];
            let (x, y) = (j as f64, data[j]);
            s_w += w;
            s_wx += w * x;
            s_wy += w * y;
            s_wxx += w * x * x;
            s_wxy += w * x * y;
        }
        if s_w > 1e-10 {
            let det = s_w * s_wxx - s_wx * s_wx;
            if det.abs() > 1e-10 {
                let intercept = (s_wxx * s_wy - s_wx * s_wxy) / det;
                let slope = (s_w * s_wxy - s_wx * s_wy) / det;
                intercept + slope * i as f64
            } else {
                // Degenerate design (weight effectively on one abscissa).
                s_wy / s_w
            }
        } else {
            // No usable neighbours: keep the observation itself.
            data[i]
        }
    };

    (0..len).map(fit_at).collect()
}
311
312fn compute_robustness_weights(residuals: &[f64]) -> Vec<f64> {
313    let m = residuals.len();
314    if m == 0 {
315        return vec![];
316    }
317    let mut abs_residuals: Vec<f64> = residuals.iter().map(|&r| r.abs()).collect();
318    crate::helpers::sort_nan_safe(&mut abs_residuals);
319    let median_idx = m / 2;
320    let mad = if m % 2 == 0 {
321        (abs_residuals[median_idx - 1] + abs_residuals[median_idx]) / 2.0
322    } else {
323        abs_residuals[median_idx]
324    };
325    let h = 6.0 * mad.max(1e-10);
326    residuals
327        .iter()
328        .map(|&r| {
329            let u = r.abs() / h;
330            if u < 1.0 {
331                (1.0 - u * u).powi(2)
332            } else {
333                0.0
334            }
335        })
336        .collect()
337}
338
339/// Wrapper function for functional data STL decomposition.
340pub fn stl_fdata(
341    data: &FdMatrix,
342    _argvals: &[f64],
343    period: usize,
344    s_window: Option<usize>,
345    t_window: Option<usize>,
346    robust: bool,
347) -> StlResult {
348    stl_decompose(data, period, s_window, t_window, None, robust, None, None)
349}