savgol_rs/
lib.rs

//! Savitzky-Golay filter implementation.
2use std::fmt::Debug;
3
4use lstsq::lstsq;
5use na::{DMatrix, DVector};
6use nalgebra as na;
7use polyfit_rs::polyfit_rs::polyfit;
8
/// Returns `num!`, i.e. `1 * 2 * ... * num`, with `0! == 1`.
///
/// The result exceeds `usize::MAX` for `num > 20` on 64-bit targets.
fn factorial(num: usize) -> usize {
    (1..=num).product()
}
16
17fn savgol_coeffs(
18    window_length: usize,
19    poly_order: usize,
20    deriv: usize,
21) -> Result<DVector<f64>, String> {
22    if poly_order >= window_length {
23        return Err("poly_order must be less than window_length.".to_string());
24    }
25
26    let (halflen, rem) = (window_length / 2, window_length % 2);
27
28    let pos = match rem {
29        0 => halflen as f64 - 0.5,
30        _ => halflen as f64,
31    };
32
33    if deriv > poly_order {
34        return Ok(DVector::from_element(window_length, 0.0));
35    }
36
37    let x = DVector::from_fn(window_length, |i, _| pos - i as f64);
38    let order = DVector::from_fn(poly_order + 1, |i, _| i);
39    let mat_a = DMatrix::from_fn(poly_order + 1, window_length, |i, j| {
40        x[j].powf(order[i] as f64)
41    });
42
43    let mut y = DVector::from_element(poly_order + 1, 0.0);
44    y[deriv] = factorial(deriv) as f64;
45
46    let epsilon = 1e-14;
47    let results = lstsq(&mat_a, &y, epsilon)?;
48    let solution = results.solution;
49
50    return Ok(solution);
51}
52
/// Differentiates a polynomial given by coefficients in ascending order of power
/// (`coeffs[i]` multiplies `x^i`). The derivative's coefficients are returned in
/// the same order.
///
/// The result has one coefficient fewer than the input. A constant or empty
/// polynomial yields an empty vector, which represents the zero polynomial; empty
/// input is accepted rather than panicking, so repeated differentiation is always safe.
fn poly_derivative(coeffs: &[f64]) -> Vec<f64> {
    coeffs
        .iter()
        .enumerate()
        // The constant term vanishes; for the rest, `power` is the exponent of the term.
        .skip(1)
        .map(|(power, c)| c * power as f64)
        .collect()
}
60
/// Evaluates the polynomial `poly` at every point in `values`.
///
/// Coefficients are in ascending order of power: `poly[i]` multiplies `v^i`. One
/// result is returned per point, in input order. An empty `poly` is the zero
/// polynomial and evaluates to `0.0` everywhere.
fn polyval(poly: &[f64], values: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(values.len());
    for &point in values {
        let mut acc = 0.0;
        for (power, &coeff) in poly.iter().enumerate() {
            acc += coeff * point.powf(power as f64);
        }
        out.push(acc);
    }
    out
}
71
72fn fit_edge(
73    x: &DVector<f64>,
74    window_start: usize,
75    window_stop: usize,
76    interp_start: usize,
77    interp_stop: usize,
78    poly_order: usize,
79    deriv: usize,
80    y: &mut Vec<f64>,
81) -> Result<(), String> {
82    let x_edge: Vec<f64> = x.as_slice()[window_start..window_stop].to_vec();
83    let y_edge: Vec<f64> = (0..window_stop - window_start).map(|i| i as f64).collect();
84    let mut poly_coeffs = polyfit(&y_edge, &x_edge, poly_order)?;
85
86    let mut deriv = deriv;
87    while deriv > 0 {
88        poly_coeffs = poly_derivative(&poly_coeffs);
89        deriv -= 1;
90    }
91
92    let i: Vec<f64> = (0..interp_stop - interp_start)
93        .map(|i| (interp_start - window_start + i) as f64)
94        .collect();
95    let values = polyval(&poly_coeffs, &i);
96    y.splice(interp_start..interp_stop, values);
97    Ok(())
98}
99
100fn fit_edges_polyfit(
101    x: &DVector<f64>,
102    window_length: usize,
103    poly_order: usize,
104    deriv: usize,
105    y: &mut Vec<f64>,
106) -> Result<(), String> {
107    let halflen = window_length / 2;
108    fit_edge(x, 0, window_length, 0, halflen, poly_order, deriv, y)?;
109    let n = x.len();
110    fit_edge(
111        x,
112        n - window_length,
113        n,
114        n - halflen,
115        n,
116        poly_order,
117        deriv,
118        y,
119    )?;
120
121    Ok(())
122}
123
/// Parameters for [`savgol_filter`].
///
/// The struct only borrows the input samples, so it is `Copy` whenever `T: Copy`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SavGolInput<'a, T> {
    /// Samples to filter; [`savgol_filter`] converts each element to `f64`.
    pub data: &'a [T],
    /// Number of samples in each fitting window. [`savgol_filter`] requires it to be
    /// odd and no larger than `data.len()`.
    pub window_length: usize,
    /// Degree of the fitted polynomial; must be less than `window_length`.
    pub poly_order: usize,
    /// Order of the derivative to compute, taken with respect to the sample index
    /// (unit spacing). `0` yields the smoothed signal itself.
    pub derivative: usize,
}
131
132pub fn savgol_filter<T>(input: &SavGolInput<T>) -> Result<Vec<f64>, String>
133where
134    T: Clone + TryInto<f64>,
135    <T as TryInto<f64>>::Error: Debug,
136{
137    if input.window_length > input.data.len() {
138        return Err(
139            "window_length must be less than or equal to the size of the input data".to_string(),
140        );
141    }
142
143    if input.window_length % 2 == 0 {
144        // TODO: figure out how scipy implementation handles the convolution
145        // in this case
146        return Err("window_length must be odd".to_string());
147    }
148
149    let coeffs = savgol_coeffs(input.window_length, input.poly_order, input.derivative)?;
150
151    let x = match input.data.iter().cloned().map(|i| i.try_into()).collect() {
152        Err(error) => return Err(format!("{:?}", error)),
153        Ok(x) => DVector::from_vec(x)
154    };
155
156    let y = x.convolve_full(coeffs);
157
158    // trim extra length gained during convolution to mimic scipy convolve1d
159    // with mode="constant"
160    let padding = y.len() - x.len();
161    let padding = padding / 2;
162    let y = y.as_slice();
163    let mut y = y[padding..y.len().saturating_sub(padding)].to_vec();
164
165    fit_edges_polyfit(
166        &x,
167        input.window_length,
168        input.poly_order,
169        input.derivative,
170        &mut y,
171    )?;
172    return Ok(y);
173}