1use crate::iter_maybe_parallel;
2use crate::matrix::FdMatrix;
3#[cfg(feature = "parallel")]
4use rayon::iter::ParallelIterator;
5
6#[derive(Debug, Clone)]
12#[non_exhaustive]
13pub struct StlResult {
14 pub trend: FdMatrix,
16 pub seasonal: FdMatrix,
18 pub remainder: FdMatrix,
20 pub weights: FdMatrix,
22 pub period: usize,
24 pub s_window: usize,
26 pub t_window: usize,
28 pub inner_iterations: usize,
30 pub outer_iterations: usize,
32}
33
/// Optional tuning parameters for [`stl_decompose_with_config`].
///
/// Each `None` field falls back to the default chosen by [`stl_decompose`];
/// `StlConfig::default()` therefore reproduces the plain, non-robust decomposition.
#[derive(Clone, Debug, Default, PartialEq)]
#[non_exhaustive]
pub struct StlConfig {
    /// Span of the seasonal (cycle-subseries) LOESS; defaults to 7.
    pub s_window: Option<usize>,
    /// Span of the trend LOESS; defaults to a value derived from `period` and the seasonal span.
    pub t_window: Option<usize>,
    /// Span handed to the low-pass filter stage; defaults to `period` (made odd).
    pub l_window: Option<usize>,
    /// Enables bisquare robustness reweighting between outer passes.
    pub robust: bool,
    /// Inner smoothing iterations per outer pass; defaults to 2.
    pub inner_iterations: Option<usize>,
    /// Number of outer passes; defaults to 15 when `robust`, otherwise 1.
    pub outer_iterations: Option<usize>,
}
63
64pub fn stl_decompose_with_config(data: &FdMatrix, period: usize, config: &StlConfig) -> StlResult {
74 stl_decompose(
75 data,
76 period,
77 config.s_window,
78 config.t_window,
79 config.l_window,
80 config.robust,
81 config.inner_iterations,
82 config.outer_iterations,
83 )
84}
85
86pub fn stl_decompose(
119 data: &FdMatrix,
120 period: usize,
121 s_window: Option<usize>,
122 t_window: Option<usize>,
123 l_window: Option<usize>,
124 robust: bool,
125 inner_iterations: Option<usize>,
126 outer_iterations: Option<usize>,
127) -> StlResult {
128 let (n, m) = data.shape();
129 if n == 0 || m < 2 * period || period < 2 {
130 return StlResult {
131 trend: FdMatrix::zeros(n, m),
132 seasonal: FdMatrix::zeros(n, m),
133 remainder: FdMatrix::from_slice(data.as_slice(), n, m)
134 .unwrap_or_else(|_| FdMatrix::zeros(n, m)),
135 weights: FdMatrix::from_column_major(vec![1.0; n * m], n, m)
136 .unwrap_or_else(|_| FdMatrix::zeros(n, m)),
137 period,
138 s_window: 0,
139 t_window: 0,
140 inner_iterations: 0,
141 outer_iterations: 0,
142 };
143 }
144 let s_win = s_window.unwrap_or(7).max(3) | 1;
145 let t_win = t_window.unwrap_or_else(|| {
146 let ratio = 1.5 * period as f64 / (1.0 - 1.5 / s_win as f64);
147 let val = ratio.ceil() as usize;
148 val.max(3) | 1
149 });
150 let l_win = l_window.unwrap_or(period) | 1;
151 let n_inner = inner_iterations.unwrap_or(2);
152 let n_outer = outer_iterations.unwrap_or(if robust { 15 } else { 1 });
153 let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
154 .map(|i| {
155 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
156 stl_single_series(
157 &curve, period, s_win, t_win, l_win, robust, n_inner, n_outer,
158 )
159 })
160 .collect();
161 let mut trend = FdMatrix::zeros(n, m);
162 let mut seasonal = FdMatrix::zeros(n, m);
163 let mut remainder = FdMatrix::zeros(n, m);
164 let mut weights = FdMatrix::from_column_major(vec![1.0; n * m], n, m)
165 .expect("dimension invariant: data.len() == n * m");
166 for (i, (t, s, r, w)) in results.into_iter().enumerate() {
167 for j in 0..m {
168 trend[(i, j)] = t[j];
169 seasonal[(i, j)] = s[j];
170 remainder[(i, j)] = r[j];
171 weights[(i, j)] = w[j];
172 }
173 }
174 StlResult {
175 trend,
176 seasonal,
177 remainder,
178 weights,
179 period,
180 s_window: s_win,
181 t_window: t_win,
182 inner_iterations: n_inner,
183 outer_iterations: n_outer,
184 }
185}
186
187fn stl_single_series(
188 data: &[f64],
189 period: usize,
190 s_window: usize,
191 t_window: usize,
192 l_window: usize,
193 robust: bool,
194 n_inner: usize,
195 n_outer: usize,
196) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
197 let m = data.len();
198 let mut trend = vec![0.0; m];
199 let mut seasonal = vec![0.0; m];
200 let mut weights = vec![1.0; m];
201 for outer in 0..n_outer {
202 for _inner in 0..n_inner {
203 let detrended: Vec<f64> = data
204 .iter()
205 .zip(trend.iter())
206 .map(|(&y, &t)| y - t)
207 .collect();
208 let cycle_smoothed = smooth_cycle_subseries(&detrended, period, s_window, &weights);
209 let low_pass = stl_lowpass_filter(&cycle_smoothed, period, l_window);
210 seasonal = cycle_smoothed
211 .iter()
212 .zip(low_pass.iter())
213 .map(|(&c, &l)| c - l)
214 .collect();
215 let deseasonalized: Vec<f64> = data
216 .iter()
217 .zip(seasonal.iter())
218 .map(|(&y, &s)| y - s)
219 .collect();
220 trend = weighted_loess(&deseasonalized, t_window, &weights);
221 }
222 if robust && outer < n_outer - 1 {
223 let remainder: Vec<f64> = data
224 .iter()
225 .zip(trend.iter())
226 .zip(seasonal.iter())
227 .map(|((&y, &t), &s)| y - t - s)
228 .collect();
229 weights = compute_robustness_weights(&remainder);
230 }
231 }
232 let remainder: Vec<f64> = data
233 .iter()
234 .zip(trend.iter())
235 .zip(seasonal.iter())
236 .map(|((&y, &t), &s)| y - t - s)
237 .collect();
238 (trend, seasonal, remainder, weights)
239}
240
241fn smooth_cycle_subseries(
242 data: &[f64],
243 period: usize,
244 s_window: usize,
245 weights: &[f64],
246) -> Vec<f64> {
247 let m = data.len();
248 let n_cycles = m.div_ceil(period);
249 let mut result = vec![0.0; m];
250 for pos in 0..period {
251 let mut subseries_idx: Vec<usize> = Vec::new();
252 let mut subseries_vals: Vec<f64> = Vec::new();
253 let mut subseries_weights: Vec<f64> = Vec::new();
254 for cycle in 0..n_cycles {
255 let idx = cycle * period + pos;
256 if idx < m {
257 subseries_idx.push(idx);
258 subseries_vals.push(data[idx]);
259 subseries_weights.push(weights[idx]);
260 }
261 }
262 if subseries_vals.is_empty() {
263 continue;
264 }
265 let smoothed = weighted_loess(&subseries_vals, s_window, &subseries_weights);
266 for (i, &idx) in subseries_idx.iter().enumerate() {
267 result[idx] = smoothed[i];
268 }
269 }
270 result
271}
272
273fn stl_lowpass_filter(data: &[f64], period: usize, _l_window: usize) -> Vec<f64> {
274 let ma1 = moving_average(data, period);
275 let ma2 = moving_average(&ma1, period);
276 moving_average(&ma2, 3)
277}
278
/// Centred moving average of length `window`, with windows truncated at the series ends.
///
/// For odd `window` each output is the plain mean of the `window` samples centred on it.
///
/// For even `window` the symmetric span `i - window/2 ..= i + window/2` holds `window + 1`
/// samples. The two outermost samples get half weight, giving the classic centred
/// "2 x window" average. An equally weighted `window + 1`-point mean does NOT cancel a
/// component of period `window`, so seasonality used to leak through for even periods.
///
/// Near the ends the window is truncated and the result is normalised by the weight actually
/// present. Empty input or `window == 0` returns `data` unchanged.
fn moving_average(data: &[f64], window: usize) -> Vec<f64> {
    let m = data.len();
    if m == 0 || window == 0 {
        return data.to_vec();
    }
    let half = window / 2;
    let end_weight = if window % 2 == 0 { 0.5 } else { 1.0 };
    (0..m)
        .map(|i| {
            let start = i.saturating_sub(half);
            let end = (i + half + 1).min(m);
            let (sum, total) = data[start..end]
                .iter()
                .zip(start..end)
                .fold((0.0, 0.0), |(s, t), (&y, j)| {
                    let w = if i.abs_diff(j) == half { end_weight } else { 1.0 };
                    (s + w * y, t + w)
                });
            sum / total
        })
        .collect()
}
295
/// Locally weighted linear regression (degree-1 LOESS) evaluated at every index of `data`.
///
/// For target index `i` the fit uses samples `i - window/2 ..= i + window/2`, truncated (not
/// shifted) at the series ends. Each sample's weight is a tricube kernel of half-width
/// `max(window/2, 1)` multiplied by the caller's `weights`. The abscissa is the sample index.
///
/// Samples exactly at the half-width get zero kernel weight, so `window <= 3` effectively
/// fits every point from itself alone.
///
/// Fallbacks:
/// * degenerate weighted design (e.g. one effective point): the weighted mean is returned;
/// * total weight <= 1e-10: `data[i]` is kept.
///
/// `weights` must be at least as long as `data` (indexing panics otherwise).
fn weighted_loess(data: &[f64], window: usize, weights: &[f64]) -> Vec<f64> {
    let m = data.len();
    if m == 0 {
        return vec![];
    }
    let half = window / 2;
    let bandwidth = half.max(1) as f64;
    let mut result = vec![0.0; m];
    for (i, out) in result.iter_mut().enumerate() {
        let start = i.saturating_sub(half);
        let end = (i + half + 1).min(m);
        let mut sum_w = 0.0;
        let mut sum_wx = 0.0;
        let mut sum_wy = 0.0;
        let mut sum_wxx = 0.0;
        let mut sum_wxy = 0.0;
        for j in start..end {
            // Centre the abscissa on the target point. With raw indices (x = j),
            // `sum_w * sum_wxx` and `sum_wx^2` are both O(j^2) and their difference cancels
            // catastrophically on long series. Centring yields the same fitted line with small,
            // well-conditioned sums.
            let x = j as f64 - i as f64;
            let dist = x.abs() / bandwidth;
            let tricube = if dist < 1.0 {
                (1.0 - dist.powi(3)).powi(3)
            } else {
                0.0
            };
            let w = tricube * weights[j];
            let y = data[j];
            sum_w += w;
            sum_wx += w * x;
            sum_wy += w * y;
            sum_wxx += w * x * x;
            sum_wxy += w * x * y;
        }
        *out = if sum_w > 1e-10 {
            let denom = sum_w * sum_wxx - sum_wx * sum_wx;
            if denom.abs() > 1e-10 {
                // x is centred at i, so the fitted value at i is just the intercept.
                (sum_wxx * sum_wy - sum_wx * sum_wxy) / denom
            } else {
                sum_wy / sum_w
            }
        } else {
            data[i]
        };
    }
    result
}
342
343fn compute_robustness_weights(residuals: &[f64]) -> Vec<f64> {
344 let m = residuals.len();
345 if m == 0 {
346 return vec![];
347 }
348 let mut abs_residuals: Vec<f64> = residuals.iter().map(|&r| r.abs()).collect();
349 crate::helpers::sort_nan_safe(&mut abs_residuals);
350 let median_idx = m / 2;
351 let mad = if m % 2 == 0 {
352 (abs_residuals[median_idx - 1] + abs_residuals[median_idx]) / 2.0
353 } else {
354 abs_residuals[median_idx]
355 };
356 let h = 6.0 * mad.max(1e-10);
357 residuals
358 .iter()
359 .map(|&r| {
360 let u = r.abs() / h;
361 if u < 1.0 {
362 (1.0 - u * u).powi(2)
363 } else {
364 0.0
365 }
366 })
367 .collect()
368}
369
370pub fn stl_fdata(
372 data: &FdMatrix,
373 _argvals: &[f64],
374 period: usize,
375 s_window: Option<usize>,
376 t_window: Option<usize>,
377 robust: bool,
378) -> StlResult {
379 stl_decompose(data, period, s_window, t_window, None, robust, None, None)
380}