vmd_rs/
vmd.rs

1use ndarray::{concatenate, prelude::*};
2use ndarray_rand::{rand_distr::Uniform, RandomExt};
3use ndarray_slice::Slice1Ext;
4use num_complex::{Complex, ComplexFloat};
5use rustfft::FftPlanner;
6use std::cell::RefCell;
7
8use crate::errors::VmdError;
9use crate::utils::array::{fftshift1d, ifftshift1d, Flip};
10
11thread_local! {
12  static FFT_PLANNER: RefCell<FftPlanner<f64>> = RefCell::new(FftPlanner::new());
13}
14
15#[allow(non_snake_case, clippy::type_complexity)]
16/// Description
17/// ---
18/// u,u_hat,omega = VMD(input, alpha, tau, K, DC, init, tol) <br>
19/// Variational mode decomposition <br>
20/// Based on Python implementation by @vrcarfa <br>
21/// Original paper: <br>
22/// Dragomiretskiy, K. and Zosso, D. (2014) ‘Variational Mode Decomposition’,  <br>
23/// IEEE Transactions on Signal Processing, 62(3), pp. 531–544. doi: 10.1109/TSP.2013.2288675. <br>
24///
25/// Input and Parameters:
26/// ---------------------
27/// input   - the time domain signal (1D) to be decomposed <br>
28/// alpha   - the balancing parameter of the data-fidelity constraint <br>
29/// tau     - time-step of the dual ascent ( pick 0 for noise-slack ) <br>
30/// K       - the number of modes to be recovered <br>
31/// DC      - true if the first mode is put and kept at DC (0-freq) <br>
32/// init    - 0 = all omegas start at 0 <br>
33///           1 = all omegas start uniformly distributed <br>
34///           2 = all omegas initialized randomly <br>
35/// tol     - tolerance of convergence criterion; typically around 1e-6 <br>
36///
37/// Output:
38/// -------
39/// u       - the collection of decomposed modes <br>
40/// u_hat   - spectra of the modes <br>
41/// omega   - estimated mode center-frequencies <br>
42///
43pub fn vmd(
44    input: &[f64],
45    alpha: f64,
46    tau: f64,
47    K: usize,
48    DC: i32,
49    init: i32,
50    tol: f64,
51) -> Result<(Array2<f64>, Array2<Complex<f64>>, Array2<f64>), VmdError> {
52    // Output of python code did not work for odd number
53
54    // Period and sampling of input frequency
55    let fs = 1.0 / input.len() as f64;
56
57    let T = input.len();
58    let midpoint = (input.len() as f64 / 2.0).ceil() as usize;
59
60    let mut f_mirr = {
61        let input = ArrayView1::from_shape(T, input)?;
62        let first_half = input.slice(s![..midpoint]);
63        let second_half = input.slice(s![midpoint..]);
64        concatenate(Axis(0), &[first_half.flip(), input, second_half.flip()])?
65            // .unwrap()
66            .map(|&f| Complex::new(f, 0.))
67    };
68
69    let T = f_mirr.len() as f64;
70    let t = Array::range(1., T + 1., 1.) / T;
71    let t_len = t.len();
72    let freqs = t - 0.5 - (1. / T);
73    const N_ITER: usize = 500;
74
75    // Construct and center f_hat
76    let fft_fhat = {
77        FFT_PLANNER.with(|planner| {
78            let fft = planner.borrow_mut().plan_fft_forward(T as usize);
79            fft.process(f_mirr.as_slice_mut().unwrap());
80            f_mirr
81            // # Safety
82            // The output buffer is immediately filled in the .process() call below
83            // let mut output_buf = unsafe {
84            // Real-to-Complex FFT skips redundant calculations, effectively returning N/2 + 1 values
85            //     // let arr = Array1::uninit((T as usize / 2) + 1);
86            //     let arr = Array1::uninit(T as usize);
87            //     let arr = arr.assume_init();
88            //     arr
89            // };
90            // return match fft.process(
91            //     f_mirr.as_slice_mut().unwrap(),
92            //     output_buf.as_slice_mut().unwrap(),
93            // ) {
94            //     Ok(_) => output_buf,
95            //     Err(e) => panic!("{}", e),
96            // };
97        })
98    };
99
100    let f_hat = fftshift1d(fft_fhat.view());
101    let mut f_hat_plus = f_hat;
102    f_hat_plus
103        .slice_mut(s![..T as usize / 2])
104        .map_inplace(|v| *v = Complex::new(0., 0.));
105
106    // Initialization of omega k
107    let mut omega_plus = Array::from_shape_fn((N_ITER, K), |(_, _)| 0.);
108    match init {
109        1 => {
110            for i in 0..K {
111                omega_plus[[0, i]] = (0.5 / K as f64) * i as f64
112            }
113        }
114        // PY => omega_plus[0,:] = np.sort(np.exp(np.log(fs) + (np.log(0.5)-np.log(fs))*np.random.rand(1,K)))
115        2 => {
116            // TODO: reduce allocs
117            let rexpr = fs.log(std::f64::consts::E);
118            let random = Array::random([1, K], Uniform::new(0., 1.));
119            // let random = ndarray::Array2::from_shape_vec([1, 4], vec![1., 2., 3., 4.]).unwrap();
120            let rexpr2 = (0.5_f64.log(std::f64::consts::E) - rexpr) * random;
121            let mut expr = rexpr + rexpr2;
122
123            expr.map_inplace(|f| *f = f.exp());
124            let mut axis_sort = expr.slice_axis_mut(Axis(0), ndarray::Slice::new(0, None, 1));
125            axis_sort
126                .row_mut(0)
127                .sort_unstable_by(|f1, f2| f1.partial_cmp(f2).unwrap());
128            expr.row_mut(0).assign_to(
129                omega_plus
130                    .slice_axis_mut(Axis(0), ndarray::Slice::new(0, None, 1))
131                    .row_mut(0),
132            );
133        }
134        _ => {
135            omega_plus.slice_mut(s![.., ..]).map_inplace(|f| *f = 0.);
136        }
137    };
138    if DC != 0 {
139        omega_plus[[0, 0]] = 0.;
140    }
141
142    // Huge allocs here
143    // start with empty dual variables
144
145    // optimization: only need 2 rows, but we use 3 because its simpler to write
146    const ROWS: usize = 3;
147    let mut lambda_hat: Array2<Complex<f64>> = Array::zeros((ROWS, freqs.len()));
148
149    // Huge allocs!
150    // matrix keeping track of every iterant // could be discarded for mem
151    // optimization: use only 3 rows
152    let mut u_hat_plus: Array3<Complex<f64>> = Array::zeros((ROWS, freqs.len(), K));
153    let mut udiff = tol + f64::EPSILON;
154    let mut n = 0;
155    let mut sum_uk: Array1<Complex<f64>> = Array::zeros(freqs.len());
156
157    let mut cur: usize = 0; // n % ROWS
158    let mut next: usize = 1; // (n+1) % ROWS
159    let mut prev: usize;
160
161    // For future generalizations: individual alpha for each mode
162    let alpha: Array1<f64> = Array::ones(K) * alpha;
163
164    // Main loop for iterative updates
165    while udiff > tol && n < N_ITER - 1 {
166        let T = T as usize;
167        // Not converged and below iteration limit
168
169        // update first mode accumulator
170        let k = 0;
171        let s1 = u_hat_plus.slice(s![cur, .., K - 1]);
172        let s2 = u_hat_plus.slice(s![cur, .., 0]);
173        sum_uk += &s1;
174        sum_uk -= &s2;
175
176        // Update spectrum of first mode through Wiener filter of residuals
177        let lambda_hat_slice = &lambda_hat.slice(s![cur, ..]) / Complex::new(2., 0.);
178        let lexpr = &f_hat_plus - &sum_uk - &lambda_hat_slice;
179        let rexpr = 1. + alpha[k] * (&freqs - omega_plus[[n, k]]).map_mut(|f| f.powi(2));
180        (lexpr / rexpr).move_into(u_hat_plus.slice_mut(s![next, .., k]));
181
182        if DC == 0 {
183            let expr1 = freqs.slice(s![T / 2..T]);
184            let subexpr2 = u_hat_plus.slice(s![next, T / 2..T, k]);
185            let expr2 = subexpr2.map(|f| ComplexFloat::abs(*f).powi(2));
186            let expr1: f64 = expr1.dot(&expr2);
187            let expr2 = expr2.sum();
188            omega_plus[[n + 1, k]] = expr1 / expr2;
189        }
190
191        // update of any other node
192        for k in 1..K {
193            // accumulator
194            sum_uk += &u_hat_plus.slice(s![next, .., k - 1]);
195            sum_uk -= &u_hat_plus.slice(s![cur, .., k]);
196
197            // mode spectrum
198            // let lexpr = &lambda_hat.slice(s![cur, ..]) / Complex::new(2., 0.);
199            let lexpr = &f_hat_plus - &sum_uk - &lambda_hat_slice;
200            let rexpr = 1. + alpha[k] * (&freqs - omega_plus[[n, k]]).map(|v| v.powi(2));
201            (lexpr / rexpr).move_into(u_hat_plus.slice_mut(s![next, .., k]));
202
203            // center frequencies
204            let expr1 = freqs.slice(s![T / 2..T]);
205            let subexpr2 = u_hat_plus.slice(s![next, T / 2..T, k]);
206            let expr2 = subexpr2.map(|f| ComplexFloat::abs(*f).powi(2));
207            let expr1: f64 = expr1.dot(&expr2);
208            let expr2 = expr2.sum();
209            omega_plus[[n + 1, k]] = expr1 / expr2;
210        }
211
212        // dual ascent
213        let expr1 = (&u_hat_plus
214            .slice(s![next, .., ..])
215            .sum_axis(ndarray::Axis(1))
216            - &f_hat_plus)
217            * tau;
218        let expr1 = &lambda_hat.slice(s![cur, ..]) + expr1;
219        expr1.move_into(lambda_hat.slice_mut(s![next, ..]));
220
221        // loop counters
222        n += 1;
223        cur = n % ROWS;
224        next = (n + 1) % ROWS;
225        prev = (n - 1) % ROWS;
226
227        let mut udiff_ = Complex::new(f64::EPSILON, 0.);
228        for i in 0..K {
229            let expr1 = &u_hat_plus.slice(s![cur, .., i]) - &u_hat_plus.slice(s![prev, .., i]);
230            let expr2 = expr1.map(|f| f.conj());
231            let expr = expr1.dot(&expr2) * (1. / T as f64);
232
233            udiff_ += expr;
234        }
235        udiff = ComplexFloat::abs(udiff_);
236    }
237    // Postprocessing and cleanup
238    // discard empty space if converged early
239    let n_iter = std::cmp::min(n, N_ITER);
240    let omega = omega_plus.slice(s![..n_iter, ..]);
241
242    // signal reconstruction (slight optimization)
243
244    // let idxs = np_flip(&ndarray::Array::range(1., T as f64/2.+1., 1.).view());
245    // let idxs = ndarray::Array::range(T as f64/2., 0., -1.);
246    // let idxs = T/2..0;
247    let T = T as usize;
248    let mut u_hat = Array::from_elem([T, K], Complex::new(0.0, 0.0));
249    u_hat
250        .slice_mut(s![T / 2..T, ..])
251        .assign(&u_hat_plus.slice(s![(n_iter - 1) % ROWS, T / 2..T, ..]));
252    // idxs = 1..T/2+1;-1
253    u_hat_plus
254        .slice(s![(n_iter - 1) % ROWS, T / 2..T, ..])
255        .map(|f| f.conj())
256        .move_into(u_hat.slice_mut(s![1..T/2+1;-1,..]));
257    u_hat
258        .slice(s![-1, ..])
259        .map(|f| f.conj())
260        .move_into(u_hat.slice_mut(s![0, ..]));
261
262    let mut u: Array2<f64> = ndarray::Array::zeros([K, t_len]);
263    FFT_PLANNER.with(|planner| {
264        let ffti = planner
265            .borrow_mut()
266            .plan_fft_inverse(u_hat.slice(s![.., 0]).len());
267        for k in 0..K {
268            let subexpr = u_hat.slice(s![.., k]);
269            let mut ishifted = ifftshift1d(subexpr);
270            ffti.process(ishifted.as_slice_mut().unwrap());
271            // rustfft does not normalize, normalize ourselves
272            // https://numpy.org/doc/stable/reference/routines.fft.html#normalization
273            let len = ishifted.len() as f64;
274            // ishifted = ishifted / len;
275            // println!("{:?}", &ishifted);
276            (ishifted / len)
277                .map(|f| f.re())
278                .move_into(u.slice_mut(s![k, ..]));
279        }
280    });
281
282    // Remove mirror part
283    let u = u.slice_mut(s![.., T / 4..3 * T / 4]);
284
285    // Recompute spectrum
286    let mut u_hat: Array2<Complex<f64>> = Array::zeros([u.shape()[1], K]);
287    FFT_PLANNER.with(|planner| {
288        for k in 0..K {
289            let mut u_ = u.slice(s![k, ..]).map(|f| Complex::new(*f, 0.));
290            let fft = planner.borrow_mut().plan_fft_forward(u_.len());
291            fft.process(u_.as_slice_mut().unwrap());
292            fftshift1d(u_.view()).move_into(u_hat.slice_mut(s![.., k]));
293        }
294    });
295
296    Ok((u.to_owned(), u_hat, omega.to_owned()))
297}