1use std::fmt::Debug;
3
4use lstsq::lstsq;
5use na::{DMatrix, DVector};
6use nalgebra as na;
7use polyfit_rs::polyfit_rs::polyfit;
8
/// Returns `num!`, the product `1 * 2 * ... * num` (with `0! = 1! = 1`).
///
/// This is plain `usize` arithmetic, so it overflows for large `num` (beyond `20` on
/// 64-bit targets): debug builds panic, release builds wrap.
fn factorial(num: usize) -> usize {
    // The empty range for `num < 2` yields the multiplicative identity `1`.
    (2..=num).product()
}
16
17fn savgol_coeffs(
18 window_length: usize,
19 poly_order: usize,
20 deriv: usize,
21) -> Result<DVector<f64>, String> {
22 if poly_order >= window_length {
23 return Err("poly_order must be less than window_length.".to_string());
24 }
25
26 let (halflen, rem) = (window_length / 2, window_length % 2);
27
28 let pos = match rem {
29 0 => halflen as f64 - 0.5,
30 _ => halflen as f64,
31 };
32
33 if deriv > poly_order {
34 return Ok(DVector::from_element(window_length, 0.0));
35 }
36
37 let x = DVector::from_fn(window_length, |i, _| pos - i as f64);
38 let order = DVector::from_fn(poly_order + 1, |i, _| i);
39 let mat_a = DMatrix::from_fn(poly_order + 1, window_length, |i, j| {
40 x[j].powf(order[i] as f64)
41 });
42
43 let mut y = DVector::from_element(poly_order + 1, 0.0);
44 y[deriv] = factorial(deriv) as f64;
45
46 let epsilon = 1e-14;
47 let results = lstsq(&mat_a, &y, epsilon)?;
48 let solution = results.solution;
49
50 return Ok(solution);
51}
52
/// Differentiates a polynomial whose coefficients are in ascending order of degree.
///
/// `coeffs[k]` multiplies `x^k`; the result uses the same convention and is one element
/// shorter. A constant polynomial, or the empty (zero) polynomial, differentiates to the
/// empty polynomial instead of panicking, so repeated differentiation is always safe.
fn poly_derivative(coeffs: &[f64]) -> Vec<f64> {
    coeffs
        .iter()
        .enumerate()
        // The constant term vanishes; `c * x^power` becomes `power * c * x^(power - 1)`.
        .skip(1)
        .map(|(power, c)| c * power as f64)
        .collect()
}
60
/// Evaluates a polynomial at every point in `values`.
///
/// `poly` holds coefficients in ascending order of degree, so `poly[k]` multiplies `v^k`.
/// An empty `poly` is the zero polynomial and evaluates to `0.0` everywhere.
fn polyval(poly: &[f64], values: &[f64]) -> Vec<f64> {
    let evaluate_at = |point: f64| {
        let mut total = 0.0;
        for (power, coeff) in poly.iter().enumerate() {
            total += coeff * point.powf(power as f64);
        }
        total
    };

    values.iter().copied().map(evaluate_at).collect()
}
71
72fn fit_edge(
73 x: &DVector<f64>,
74 window_start: usize,
75 window_stop: usize,
76 interp_start: usize,
77 interp_stop: usize,
78 poly_order: usize,
79 deriv: usize,
80 y: &mut Vec<f64>,
81) -> Result<(), String> {
82 let x_edge: Vec<f64> = x.as_slice()[window_start..window_stop].to_vec();
83 let y_edge: Vec<f64> = (0..window_stop - window_start).map(|i| i as f64).collect();
84 let mut poly_coeffs = polyfit(&y_edge, &x_edge, poly_order)?;
85
86 let mut deriv = deriv;
87 while deriv > 0 {
88 poly_coeffs = poly_derivative(&poly_coeffs);
89 deriv -= 1;
90 }
91
92 let i: Vec<f64> = (0..interp_stop - interp_start)
93 .map(|i| (interp_start - window_start + i) as f64)
94 .collect();
95 let values = polyval(&poly_coeffs, &i);
96 y.splice(interp_start..interp_stop, values);
97 Ok(())
98}
99
100fn fit_edges_polyfit(
101 x: &DVector<f64>,
102 window_length: usize,
103 poly_order: usize,
104 deriv: usize,
105 y: &mut Vec<f64>,
106) -> Result<(), String> {
107 let halflen = window_length / 2;
108 fit_edge(x, 0, window_length, 0, halflen, poly_order, deriv, y)?;
109 let n = x.len();
110 fit_edge(
111 x,
112 n - window_length,
113 n,
114 n - halflen,
115 n,
116 poly_order,
117 deriv,
118 y,
119 )?;
120
121 Ok(())
122}
123
/// Parameters for [`savgol_filter`].
#[derive(Clone, Debug)]
pub struct SavGolInput<'a, T> {
    /// The samples to filter; each element must be convertible to `f64`.
    pub data: &'a [T],
    /// Number of samples in the filter window. `savgol_filter` requires it to be odd and
    /// no larger than `data.len()`.
    pub window_length: usize,
    /// Degree of the fitted polynomial; must be less than `window_length`.
    pub poly_order: usize,
    /// Order of the derivative to compute (`0` smooths the data itself).
    pub derivative: usize,
}
131
132pub fn savgol_filter<T>(input: &SavGolInput<T>) -> Result<Vec<f64>, String>
133where
134 T: Clone + TryInto<f64>,
135 <T as TryInto<f64>>::Error: Debug,
136{
137 if input.window_length > input.data.len() {
138 return Err(
139 "window_length must be less than or equal to the size of the input data".to_string(),
140 );
141 }
142
143 if input.window_length % 2 == 0 {
144 return Err("window_length must be odd".to_string());
147 }
148
149 let coeffs = savgol_coeffs(input.window_length, input.poly_order, input.derivative)?;
150
151 let x = match input.data.iter().cloned().map(|i| i.try_into()).collect() {
152 Err(error) => return Err(format!("{:?}", error)),
153 Ok(x) => DVector::from_vec(x)
154 };
155
156 let y = x.convolve_full(coeffs);
157
158 let padding = y.len() - x.len();
161 let padding = padding / 2;
162 let y = y.as_slice();
163 let mut y = y[padding..y.len().saturating_sub(padding)].to_vec();
164
165 fit_edges_polyfit(
166 &x,
167 input.window_length,
168 input.poly_order,
169 input.derivative,
170 &mut y,
171 )?;
172 return Ok(y);
173}