stlplus_rs/
lib.rs

1pub mod error;
2mod loess;
3mod util;
4
5use crate::error::Error;
6use crate::loess::loess_stl;
7use crate::util::{NextOddInt, SumAgg, ValidateNotZero};
8use itertools::izip;
9use num_traits::{AsPrimitive, Float};
10use std::fmt::Debug;
11pub use util::NextOdd;
12
/// Degree of the polynomial that is fitted locally by a loess smoother.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum Degree {
    /// Polynomial of degree 0 (locally constant).
    Degree0,
    /// Polynomial of degree 1 (locally linear).
    Degree1,
    /// Polynomial of degree 2 (locally quadratic).
    Degree2,
}

impl Degree {
    /// Column of this degree in the `COEFS_*` lookup tables.
    ///
    /// The tables only have two columns, so degree 0 shares the column of
    /// degree 1 and degree 2 uses the second one.
    fn coef_index(&self) -> usize {
        if matches!(self, Degree::Degree2) {
            1
        } else {
            0
        }
    }
}
30
31macro_rules! impl_tryfrom_int_for_degree {
32    ($($t:ty),*) => {
33        $(
34            impl TryFrom<$t> for Degree {
35                type Error = Error;
36
37                fn try_from(value: $t) -> Result<Self, Self::Error> {
38                    match value {
39                        1 => Ok(Self::Degree0),
40                        2 => Ok(Self::Degree1),
41                        3 => Ok(Self::Degree2),
42                        _ => Err(Error::InvalidDegree),
43                    }
44                }
45            }
46        )*
47    };
48}
49impl_tryfrom_int_for_degree!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize);
50
51///
52/// Comments are still from the original R implementation and Java port.
53#[derive(Clone)]
54pub struct STLOptions {
55    /// The number of observations in each cycle of the seasonal component, n_p
56    pub num_obs_per_period: usize,
57
58    /// s.window either the character string \code{"periodic"} or the span (in lags) of the loess window for seasonal extraction,
59    /// which should be odd.  This has no default.
60    ///
61    /// None is used for periodic
62    pub s_window: Option<usize>,
63
64    /// s.degree degree of locally-fitted polynomial in seasonal extraction.  Should be 0, 1, or 2.
65    pub s_degree: Degree,
66
67    /// t.window the span (in lags) of the loess window for trend extraction, which should be odd.
68    /// If \code{NULL}, the default, \code{nextodd(ceiling((1.5*period) / (1-(1.5/s.window))))}, is taken.
69    pub t_window: Option<usize>,
70
71    /// t.degree degree of locally-fitted polynomial in trend extraction.  Should be 0, 1, or 2.
72    pub t_degree: Degree,
73
74    /// l.window the span (in lags) of the loess window of the low-pass filter used for each subseries.
75    ///
76    /// Defaults to the smallest odd integer greater than or equal to \code{n.p}
77    /// which is recommended since it prevents competition between the trend and seasonal components.
78    /// If not an odd integer its given value is increased to the next odd one.
79    pub l_window: Option<usize>,
80
81    /// l.degree degree of locally-fitted polynomial for the subseries low-pass filter.  Should be 0, 1, or 2.
82    pub l_degree: Degree,
83
84    /// s.jump s.jump,t.jump,l.jump,fc.jump integers at least one to increase speed of the respective smoother.
85    /// Linear interpolation happens between every \code{*.jump}th value.
86    pub s_jump: Option<usize>,
87
88    /// t.jump
89    pub t_jump: Option<usize>,
90
91    /// l.jump
92    pub l_jump: Option<usize>,
93
94    /// critfreq the critical frequency to use for automatic calculation of smoothing windows for the trend and high-pass filter.
95    pub critfreq: f64,
96
97    /// The number of passes through the inner loop, n_i
98    pub number_of_inner_loop_passes: u32,
99
100    /// The number of robustness iterations of the outer loop, n_o
101    pub number_of_robustness_iterations: u32,
102}
103
104impl Default for STLOptions {
105    fn default() -> Self {
106        let num_obs_per_period = 4;
107        Self {
108            num_obs_per_period,
109            s_window: None,
110            s_degree: Degree::Degree1,
111            t_window: None,
112            t_degree: Degree::Degree1,
113            l_window: None,
114            l_degree: Degree::Degree1,
115            s_jump: None,
116            t_jump: None,
117            l_jump: None,
118            critfreq: 0.05,
119            number_of_inner_loop_passes: 2,
120            number_of_robustness_iterations: 1,
121            //number_of_data_points: num_obs_per_period * 2,
122        }
123    }
124}
125
/// Result of an STL decomposition.
///
/// All three vectors have the length of the decomposed input and, element by
/// element, `trend + seasonal + remainder` reproduces the input value.
#[derive(Debug, Clone, PartialEq)]
pub struct STLOutput<VALUE> {
    /// Trend component.
    pub trend: Vec<VALUE>,
    /// Seasonal component.
    pub seasonal: Vec<VALUE>,
    /// What is left of the input after removing trend and seasonal component.
    pub remainder: Vec<VALUE>,
}
131
132pub fn stl_decompose<VALUE>(
133    values: &[VALUE],
134    options: &STLOptions,
135) -> Result<STLOutput<VALUE>, Error>
136where
137    VALUE: Float + 'static + Copy,
138    usize: AsPrimitive<VALUE>,
139    i64: AsPrimitive<VALUE>,
140{
141    let n = values.len();
142    let times_i64: Vec<_> = (0..n).map(|v| v as i64).collect();
143
144    stl_decompose_with_time(&times_i64, values, options)
145}
146
147pub fn stl_decompose_with_time<TIME, VALUE>(
148    times: &[TIME],
149    values: &[VALUE],
150    options: &STLOptions,
151) -> Result<STLOutput<VALUE>, Error>
152where
153    VALUE: Float + 'static + Copy,
154    usize: AsPrimitive<VALUE>,
155    TIME: AsPrimitive<i64>,
156    i64: AsPrimitive<VALUE>,
157{
158    if values.is_empty() || times.is_empty() {
159        return Err(Error::EmptyInputSlice);
160    }
161    let n = values.len();
162    if times.len() != n {
163        return Err(Error::InputSlicesDifferingLength);
164    }
165    if n < (options.num_obs_per_period * 2) {
166        return Err(Error::InputSliceToFewObs);
167    }
168
169    let times_i64: Vec<i64> = times.iter().map(|t| t.as_()).collect();
170
171    let validated_options = validate_options(options, n)?;
172    let mut trend = vec![VALUE::zero(); n];
173    let mut seasonal = vec![VALUE::zero(); n];
174    let mut deseasonalized = vec![VALUE::zero(); n];
175
176    // cycleSubIndices will keep track of what part of the seasonal each observation belongs to
177    let cycle_sub_indices: Vec<_> = (1..=validated_options.num_obs_per_period)
178        .cycle()
179        .take(values.len())
180        .collect();
181    let weights = vec![VALUE::one(); n];
182    let mut detrend = vec![VALUE::zero(); n];
183    // todo: missing stuff from java impl?
184
185    let mut cycle_sub = Vec::with_capacity(
186        (n as f64 / validated_options.num_obs_per_period as f64).ceil() as usize / 2,
187    );
188    let mut sub_weights = Vec::with_capacity(cycle_sub.capacity());
189
190    let (cs1, cs2) = {
191        let mut cs1 = Vec::with_capacity(validated_options.num_obs_per_period);
192        let mut cs2 = Vec::with_capacity(validated_options.num_obs_per_period);
193        for i in 0..validated_options.num_obs_per_period {
194            cs1.push(cycle_sub_indices[i]);
195            cs2.push(
196                cycle_sub_indices[values
197                    .len()
198                    .saturating_sub(validated_options.num_obs_per_period + i)],
199            );
200        }
201        (cs1, cs2)
202    };
203
204    let l_ev = Ev::new(n, validated_options.l_jump);
205    let t_ev = Ev::new(n, validated_options.t_jump);
206
207    let mut c = vec![VALUE::nan(); n + 2 * validated_options.num_obs_per_period];
208
209    // start and end indices for after adding in extra n.p before and after
210    let c_start_idx = validated_options.num_obs_per_period;
211    let c_end_idx = n - 1 + validated_options.num_obs_per_period;
212
213    for _outer_iteration_i in 0..options.number_of_robustness_iterations {
214        for _inner_iteration_i in 0..options.number_of_inner_loop_passes {
215            // Step 1: detrending
216            izip!(detrend.iter_mut(), values.iter(), trend.iter()).for_each(|(dt, v, t)| {
217                *dt = *v - *t;
218            });
219
220            // step 2: smoothing of cycle-subseries
221            for i in 0..validated_options.num_obs_per_period {
222                cycle_sub.clear();
223                sub_weights.clear();
224                let mut j = i;
225                while j < n {
226                    if cycle_sub_indices[j] == i + 1 {
227                        cycle_sub.push(detrend[j]);
228                        sub_weights.push(weights[j])
229                    }
230                    j += validated_options.num_obs_per_period;
231                }
232
233                let weight_mean_ans = weight_mean(&cycle_sub, &sub_weights)?;
234                j = i;
235                while j < validated_options.num_obs_per_period {
236                    if cs1[j] == i + 1 {
237                        c[j] = weight_mean_ans;
238                    }
239                    j += validated_options.num_obs_per_period;
240                }
241                j = i;
242                while j < n {
243                    if cycle_sub_indices[j] == i + 1 {
244                        c[j + validated_options.num_obs_per_period] = weight_mean_ans;
245                    }
246                    j += validated_options.num_obs_per_period;
247                }
248                for j in 0..validated_options.num_obs_per_period {
249                    if cs2[j] == i + 1 {
250                        c[j + validated_options.num_obs_per_period + n] = weight_mean_ans;
251                    }
252                }
253            }
254
255            // Step 3: Low-pass filtering of collection of all the cycle-subseries
256            // moving averages
257            let ma3 = cycle_subseries_moving_averages(&c, validated_options.num_obs_per_period);
258
259            // Step 4: Detrend smoothed cycle-subseries
260            let l = loess_stl(
261                &times_i64,
262                &ma3,
263                validated_options.l_window,
264                validated_options.l_degree,
265                l_ev.as_slice(),
266                &weights,
267                validated_options.l_jump,
268            )?;
269
270            // Step 5: Deseasonalize
271            izip!(
272                seasonal.iter_mut(),
273                (&c)[c_start_idx..=c_end_idx].iter(),
274                l.iter(),
275                values.iter(),
276                deseasonalized.iter_mut()
277            )
278            .for_each(|(s, c, l, v, d)| {
279                *s = *c - *l;
280                *d = *v - *s;
281            });
282
283            // Step 6: Trend Smoothing
284            trend = loess_stl(
285                &times_i64,
286                &deseasonalized,
287                validated_options.t_window,
288                validated_options.t_degree,
289                t_ev.as_slice(),
290                &weights,
291                validated_options.t_jump,
292            )?;
293        }
294    }
295
296    let remainder: Vec<_> = izip!(values.iter(), trend.iter(), seasonal.iter())
297        .map(|(v, t, s)| *v - *t - *s)
298        .collect();
299
300    Ok(STLOutput {
301        trend,
302        seasonal,
303        remainder,
304    })
305}
306
/// Indices `0, jump, 2 * jump, ...` below `n`, always followed by the last
/// index `n - 1`.
///
/// The slice returned by [`Ev::as_slice`] is what gets passed to the loess
/// smoother together with the jump.
struct Ev {
    /// number of observations
    n: usize,
    /// number of multiples of the jump which are smaller than `n`
    array_min_len: usize,
    /// the multiples of the jump followed by `n - 1`
    storage_vec: Vec<usize>,
}

impl Ev {
    /// Builds the indices for `n` observations and the given `jump`.
    fn new(n: usize, jump: usize) -> Self {
        let array_min_len = (n as f64 / jump as f64).ceil() as usize;

        // the last slot always holds (n - 1), even when this duplicates the
        // last multiple of the jump
        let storage_vec: Vec<usize> = (0..array_min_len)
            .map(|step| step * jump)
            .chain(std::iter::once(n - 1))
            .collect();

        Self {
            n,
            array_min_len,
            storage_vec,
        }
    }

    /// return a slice where the last element == `n - 1`
    fn as_slice(&self) -> &[usize] {
        let last_multiple = self.storage_vec[self.array_min_len - 1];
        if last_multiple == self.n - 1 {
            // the multiples already end on the last index, leave out the
            // duplicate stored behind them
            &self.storage_vec[..self.array_min_len]
        } else {
            &self.storage_vec
        }
    }
}
345
346struct ValidatedOptions {
347    num_obs_per_period: usize,
348    //number_of_data_points: usize,
349    //s_window: usize,
350    //s_degree: Degree,
351    //s_jump: usize,
352    t_window: usize,
353    t_degree: Degree,
354    t_jump: usize,
355    l_window: usize,
356    l_degree: Degree,
357    l_jump: usize,
358    //periodic: bool,
359}
360
361fn validate_options(options: &STLOptions, num_values: usize) -> Result<ValidatedOptions, Error> {
362    let num_obs_per_period = if options.num_obs_per_period >= 4 {
363        options.num_obs_per_period
364    } else {
365        return Err(Error::InvalidNumObsPerPeriod);
366    };
367
368    /*
369    let number_of_data_points = if options.number_of_data_points > 2 * num_obs_per_period {
370        options.number_of_data_points
371    } else {
372        return Err(Error::InvalidNumberOfDataPoints);
373    };
374
375     */
376
377    let l_degree = options.l_degree;
378    let l_window = options.l_window.unwrap_or(num_obs_per_period).next_odd();
379    let l_jump = options.l_jump.unwrap_or_else(|| window_to_jump(l_window));
380
381    let (s_window, s_degree, _s_jump, _periodic) = if let Some(s_window) = options.s_window {
382        let s_window = validate_window(s_window)?;
383        let s_jump = options.s_jump.unwrap_or_else(|| window_to_jump(s_window));
384        (s_window, options.s_degree, s_jump, false)
385    } else {
386        // periodic
387        let s_window = 10 * num_values + 1;
388        let s_degree = Degree::Degree0;
389        let s_jump = window_to_jump(s_window);
390        (s_window, s_degree, s_jump, true)
391    };
392
393    let t_degree = options.t_degree;
394    let t_window = if let Some(t_window) = options.t_window {
395        validate_window(t_window)?
396    } else {
397        get_t_window(
398            t_degree,
399            s_degree,
400            s_window,
401            num_obs_per_period,
402            options.critfreq,
403        )?
404    };
405    let t_jump = options.t_jump.unwrap_or_else(|| window_to_jump(t_window));
406
407    Ok(ValidatedOptions {
408        num_obs_per_period,
409        //number_of_data_points,
410        //s_window,
411        //s_degree,
412        //s_jump,
413        t_window,
414        t_degree,
415        t_jump,
416        l_window,
417        l_degree,
418        l_jump,
419        //periodic,
420    })
421}
422
423static COEFS_A: [[f64; 2]; 2] = [
424    [0.000103350651767650, 3.81086166990428e-6],
425    [-0.000216653946625270, 0.000708495976681902],
426];
427static COEFS_B: [[f64; 2]; 3] = [
428    [1.42686036792937, 2.24089552678906],
429    [-3.1503819836694, -3.30435316073732],
430    [5.07481807116087, 5.08099438760489],
431];
432static COEFS_C: [[f64; 2]; 3] = [
433    [1.66534145060448, 2.33114333880815],
434    [-3.87719398039131, -1.8314816166323],
435    [6.46952900183769, 1.85431548427732],
436];
437
438fn get_t_window(
439    t_degree: Degree,
440    s_degree: Degree,
441    s_window: usize,
442    num_obs_per_period: usize,
443    critfreq: f64,
444) -> Result<usize, Error> {
445    let s_index = s_degree.coef_index();
446    let t_index = t_degree.coef_index();
447
448    // estimate critical frequency for seasonal
449    let betac0 = COEFS_A[1][s_index].mul_add(critfreq, COEFS_A[0][s_index]);
450    let betac1 = COEFS_B[1][s_index].mul_add(critfreq, COEFS_B[0][s_index])
451        + COEFS_B[2][s_index] * critfreq.powi(2);
452    let betac2 = COEFS_C[1][s_index].mul_add(critfreq, COEFS_C[0][s_index])
453        + COEFS_C[2][s_index] * critfreq.powi(2);
454
455    let f_c = (1.0 - (betac0 + betac1 / s_window as f64 + betac2 / s_window.pow(2) as f64))
456        / num_obs_per_period as f64;
457
458    // choose
459    let betat0 = COEFS_A[1][t_index].mul_add(critfreq, COEFS_A[0][t_index]);
460    let betat1 = COEFS_B[1][t_index].mul_add(critfreq, COEFS_B[0][t_index])
461        + COEFS_B[2][t_index] * critfreq.powi(2);
462    let betat2 = COEFS_C[1][t_index].mul_add(critfreq, COEFS_C[0][t_index])
463        + COEFS_C[2][t_index] * critfreq.powi(2);
464
465    let betat00 = betat0 - f_c;
466
467    Ok(
468        ((-betat1 - (betat1.powi(2) - 4.0 * betat00 * betat2).sqrt()) / (2.0 * betat00)).next_odd()
469            as usize,
470    )
471}
472
473fn validate_window(window: usize) -> Result<usize, Error> {
474    if window < 1 {
475        Err(Error::InvalidWindow)
476    } else {
477        Ok(window)
478    }
479}
480
/// Default jump of a smoother: a tenth of its window, rounded up.
fn window_to_jump(window: usize) -> usize {
    let tenth_of_window = window as f64 / 10.0;
    tenth_of_window.ceil() as usize
}
484
485/// cycle-subseries moving averages
486///
487/// `num_obs_per_period` is the periodicity `n_p`
488///
489/// This function was called `c_ma` in stlplus
490pub(crate) fn cycle_subseries_moving_averages<F>(x: &[F], num_obs_per_period: usize) -> Vec<F>
491where
492    F: Float + 'static + Copy,
493    usize: AsPrimitive<F>,
494{
495    let nn = x.len().saturating_sub(num_obs_per_period * 2);
496
497    let mut ans = vec![F::zero(); x.len() - 2 * num_obs_per_period];
498    let mut ma = vec![F::zero(); nn + num_obs_per_period + 1];
499    let mut ma2 = vec![F::zero(); nn + 2];
500
501    let mut ma_tmp = x[0..num_obs_per_period].sum_agg();
502    ma[0] = ma_tmp / num_obs_per_period.as_();
503
504    for i in num_obs_per_period..(nn + 2 * num_obs_per_period) {
505        ma_tmp = ma_tmp - x[i - num_obs_per_period] + x[i];
506        ma[i - num_obs_per_period + 1] = ma_tmp / num_obs_per_period.as_();
507    }
508
509    ma_tmp = (ma[0..num_obs_per_period]).sum_agg();
510    ma2[0] = ma_tmp / num_obs_per_period.as_();
511    for i in num_obs_per_period..(nn + num_obs_per_period + 1) {
512        ma_tmp = ma_tmp - ma[i - num_obs_per_period] + ma[i];
513        ma2[i - num_obs_per_period + 1] = ma_tmp / num_obs_per_period.as_();
514    }
515
516    ma_tmp = (ma2[0..3]).sum_agg();
517    ans[0] = ma_tmp / 3usize.as_();
518    for i in 3..(nn + 2) {
519        ma_tmp = ma_tmp - ma2[i - 3] + ma2[i];
520        ans[i - 2] = ma_tmp / 3usize.as_();
521    }
522
523    ans
524}
525
526fn weight_mean<T>(x: &[T], w: &[T]) -> Result<T, Error>
527where
528    T: Float,
529{
530    let (sum, sum_w) = x.iter().zip(w.iter()).fold(
531        (T::zero(), T::zero()),
532        |(sum, sum_w), (x_value, w_value)| {
533            if !x_value.is_nan() {
534                (sum + (*x_value * *w_value), sum_w + *w_value)
535            } else {
536                (sum, sum_w)
537            }
538        },
539    );
540    Ok(sum / sum_w.validate_not_zero()?)
541}
542
#[cfg(test)]
mod tests {
    use crate::{stl_decompose, STLOptions};

    /// The three moving-average passes produce the values worked out by hand.
    #[test]
    fn c_ma() {
        let input = vec![1.0f32, 1.0, 2.0, 2.0, 3.0, 3.0, 2.0, 2.0, 1.0, 1.0];
        let n_p = 3;
        let out = super::cycle_subseries_moving_averages(&input, n_p);

        // Window sums of the first pass (width 3): 4, 5, 7, 8, 8, 7, 5, 4.
        // Averaging those twice more with width 3 gives the values below.
        let expected = [59.0f32 / 27.0, 66.0 / 27.0, 66.0 / 27.0, 59.0 / 27.0];
        assert_eq!(out.len(), expected.len());
        for (got, want) in out.iter().zip(expected.iter()) {
            assert!(
                (got - want).abs() < 1e-5,
                "got {}, expected {}",
                got,
                want
            );
        }
    }

    /// https://github.com/nmandery/stlplus-rs/issues/1
    #[test]
    fn indexing_within_bounds() {
        let input = vec![0.0f32; 2581];
        let options = STLOptions {
            num_obs_per_period: 365,
            ..Default::default()
        };
        // should not cause out-of-bounds panic
        stl_decompose(&input, &options).unwrap();
    }

    /// https://github.com/nmandery/stlplus-rs/issues/1
    #[test]
    fn input_too_short() {
        let input = vec![0.0f32; 25];
        let options = STLOptions {
            num_obs_per_period: 365,
            ..Default::default()
        };
        // fewer than two full periods: must be reported as an error instead of
        // causing an out-of-bounds panic
        assert!(stl_decompose(&input, &options).is_err());
    }
}
578}