cardio_rs/utils/
welch.rs

1//! This module provides an implementation of spectral analysis tools,
2//! including the Hann window and Welch's method for power spectral density estimation.
3//!
4//! Unlike existing Rust implementations, this approach closely follows SciPy’s methodology,
5//! ensuring compatibility and familiar behavior for users migrating from Python.
6//!
7//! This implementation should (and will) eventually be extracted into its own crate,
8//! with additional configuration options for greater flexibility and usability.
9//!
10//! The module includes:
11//! - `Hann`: A Hann window function for reducing spectral leakage.
12//! - `Welch`: Welch's method for estimating the power spectral density.
13//! - `HannBuilder` and `WelchBuilder` for constructing instances with configurable parameters.
14//!
15//! Future improvements should focus on extensibility, performance optimizations,
16//! and additional options to match SciPy's functionality.
17
18#[cfg(not(feature = "std"))]
19extern crate alloc;
20#[cfg(not(feature = "std"))]
21use alloc::vec;
22#[cfg(not(feature = "std"))]
23use alloc::vec::Vec;
24use core::{cmp, f64::consts::PI};
25use num::{Complex, Float};
26#[cfg(feature = "std")]
27use rustfft::{Fft, FftDirection};
28
/// A windowing function that can be applied to fixed-length chunks of a signal.
///
/// Window functions mitigate spectral leakage before a Fourier transform: every sample of
/// a chunk is weighted by a predefined tapering curve.
pub trait Window<T> {
    /// Applies the window to one chunk of the signal.
    ///
    /// # Parameters
    /// * `chunk` - The samples to window; its length is expected to equal the window length.
    ///
    /// # Returns
    /// A lazy iterator over the windowed samples, in the same order as `chunk`.
    /// Implementations may preprocess the samples (for example by removing their mean)
    /// before applying the weights.
    ///
    /// # Panics
    /// Implementations may panic if `chunk.len()` does not match the window length.
    fn apply(&self, chunk: &[T]) -> impl Iterator<Item = T>;

    /// Returns the power of the window.
    ///
    /// This is expected to be the sum of the squared window coefficients, which is used to
    /// normalize a periodogram into a power spectral density.
    fn power(&self) -> T;

    /// Returns the plain sum of the window coefficients.
    ///
    /// Its square is used to normalize a periodogram into a power spectrum.
    fn sum(&self) -> T;
}
60
61impl<T: Float + Copy + core::fmt::Debug + core::iter::Sum> Window<T> for Hann<T> {
62    fn apply(&self, chunk: &[T]) -> impl Iterator<Item = T> {
63        if chunk.len() != self.weights.len() {
64            panic!("Signal and Window should have the same size");
65        }
66        let chunk_mean = chunk.iter().copied().sum::<T>() / T::from(chunk.len()).unwrap();
67        chunk
68            .iter()
69            .zip(self.weights.iter())
70            .map(move |(&a, &b)| (a - chunk_mean) * b)
71    }
72
73    fn power(&self) -> T {
74        self.weights.iter().map(|&w| w * w).sum()
75    }
76
77    fn sum(&self) -> T {
78        self.weights.iter().copied().sum()
79    }
80}
81
/// A Hann window: a raised-cosine taper used to reduce spectral leakage.
///
/// Instances are created with [`HannBuilder`], and the precomputed coefficients are applied
/// through the [`Window`] trait.
#[derive(Debug, Clone, PartialEq)]
pub struct Hann<T> {
    /// Precomputed window coefficients, one per sample of a segment.
    weights: Vec<T>,
}
89
/// Builder for a [`Hann`] window of a fixed length.
///
/// The builder only stores the requested length; the coefficients are computed by
/// `build`, which chooses the floating-point type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HannBuilder {
    /// Number of coefficients in the window to build.
    n: usize,
}
96
97impl HannBuilder {
98    /// Creates a new HannBuilder with the given window size.
99    ///
100    /// # Arguments
101    /// * `n` - The number of points in the window.
102    ///
103    /// # Returns
104    /// A new `HannBuilder` instance.
105    pub fn new(n: usize) -> Self {
106        HannBuilder { n }
107    }
108
109    /// Constructs a Hann window of the specified type.
110    ///
111    /// # Type Parameters
112    /// * `T` - A floating-point type that supports FFT operations.
113    ///
114    /// # Returns
115    /// A `Hann<T>` instance containing the computed window weights.
116    pub fn build<T: Float + Copy + core::fmt::Debug + core::iter::Sum>(&self) -> Hann<T> {
117        let weights = (0..self.n)
118            .map(|i| {
119                T::from(0.5 * (1.0 - ((2.0 * PI * i as f64) / (self.n as f64 - 1.0)).cos()))
120                    .unwrap()
121            })
122            .collect::<Vec<T>>();
123        Hann { weights }
124    }
125}
126
/// Selects how the segment-averaged periodogram is scaled.
///
/// This enum chooses whether Welch's method produces a power spectral density (PSD) or a
/// power spectrum.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum Normalization<T> {
    /// Power spectral density scaling.
    ///
    /// The averaged periodogram is divided by the sampling frequency and by the sum of the
    /// squared window coefficients (the window power). If the input signal is in volts and
    /// the sampling frequency in hertz, the result has units of `V²/Hz`. The total power is
    /// obtained by integrating over frequency.
    Density,

    /// Power spectrum scaling.
    ///
    /// The averaged periodogram is divided by the square of the sum of the window
    /// coefficients. If the input signal is in volts, the result has units of `V²`.
    Spectrum,

    /// User-supplied normalization value.
    ///
    /// Useful for advanced applications where neither `Density` nor `Spectrum` provides the
    /// desired scaling. See `WelchBuilder::build` for how the value is applied.
    Custom(T),
}
152
/// Read access to a spectral estimate: its periodogram and the matching frequency axis.
///
/// Implementors expose two parallel sequences, so callers can zip them together. The
/// `k`-th periodogram value is expected to belong to the `k`-th frequency.
pub trait Periodogram<T> {
    /// Iterates over the periodogram values.
    ///
    /// Depending on how the estimate was normalized, the values form either a power
    /// spectral density or a power spectrum.
    fn periodogram(&self) -> impl Iterator<Item = T> + '_;

    /// Iterates over the frequencies at which the periodogram was evaluated.
    ///
    /// How they are derived (for example from the sampling rate and the transform length)
    /// is up to the implementor.
    fn frequencies(&self) -> impl Iterator<Item = T> + '_;
}
177
178impl<T: Float + Copy + core::fmt::Debug + core::iter::Sum> Periodogram<T> for Welch<T> {
179    fn periodogram(&self) -> impl Iterator<Item = T> + '_ {
180        self.periodogram.iter().copied()
181    }
182
183    fn frequencies(&self) -> impl Iterator<Item = T> + '_ {
184        self.frequencies.iter().copied()
185    }
186}
187
/// Result of Welch's method: an averaged periodogram and its frequency axis.
///
/// Welch's method reduces the variance of a spectral estimate by averaging the
/// periodograms of (possibly overlapping) windowed segments. The values are read through
/// the [`Periodogram`] trait.
///
/// # Type Parameters
/// * `T` - The floating-point type of the estimate.
#[derive(Debug, Clone, PartialEq)]
pub struct Welch<T> {
    /// Averaged, normalized periodogram values, one per frequency bin.
    periodogram: Vec<T>,
    /// Frequency of each periodogram bin, in the units of the sampling frequency.
    frequencies: Vec<T>,
}
199
200/// Builder for constructing a Welch power spectral density estimator.
201///
202/// This struct provides a flexible way to configure the parameters for Welch's method,
203/// such as segment size, overlap, and FFT size.
204///
205/// # Type Parameters
206/// * `T` - A floating-point type that supports FFT operations.
207pub struct WelchBuilder<T> {
208    normalization: Normalization<T>,
209    segment_size: usize,
210    dft_size: usize,
211    overlap_size: usize,
212    fs: T,
213    signal: Vec<T>,
214}
215
216impl<
217    T: Float
218        + Copy
219        + core::fmt::Debug
220        + core::marker::Sync
221        + core::marker::Send
222        + core::iter::Sum
223        + core::ops::AddAssign
224        + num::Signed
225        + num::FromPrimitive
226        + 'static,
227> WelchBuilder<T>
228{
229    /// Creates a new WelchBuilder with a given input signal.
230    ///
231    /// # Arguments
232    /// * `signal` - The input signal as a vector of type `T`.
233    ///
234    /// # Returns
235    /// A new `WelchBuilder` instance with default parameters.
236    pub fn new(signal: Vec<T>) -> Self {
237        WelchBuilder {
238            normalization: Normalization::Density,
239            signal,
240            segment_size: 256,
241            dft_size: 4096,
242            overlap_size: 128,
243            fs: T::from(4).unwrap(),
244        }
245    }
246
247    /// Sets the segment size for Welch's method.
248    pub fn with_segment_size(mut self, n: usize) -> Self {
249        self.segment_size = n;
250        self
251    }
252
253    /// Sets the FFT size.
254    #[cfg(feature = "std")]
255    pub fn with_dft_size(mut self, n: usize) -> Self {
256        self.dft_size = n;
257        self
258    }
259
260    /// Sets the overlap size between segments.
261    pub fn with_overlap_size(mut self, n: usize) -> Self {
262        self.overlap_size = n;
263        self
264    }
265
266    /// Sets the sampling frequency.
267    pub fn with_fs(mut self, n: T) -> Self {
268        self.fs = n;
269        self
270    }
271
272    /// Sets the normalization.
273    pub fn with_normalization(mut self, norma: Normalization<T>) -> Self {
274        self.normalization = norma;
275        self
276    }
277
278    /// Constructs the Welch power spectral density estimator.
279    ///
280    /// # Returns
281    /// A `Welch<T>` instance containing the computed periodogram and frequencies.
282    pub fn build(&self) -> Welch<T> {
283        let window = HannBuilder::new(self.segment_size).build::<T>();
284        let mut periodogram: Vec<T> = vec![T::from(0).unwrap(); self.dft_size / 2];
285        let mut i = 0;
286        let mut n_segments = 0;
287        while i + self.segment_size < self.signal.len() {
288            let chunk = &self.signal[i..cmp::min(i + self.segment_size, self.signal.len())];
289            let chunk = window.apply(chunk);
290
291            let mut buffer = vec![
292                Complex {
293                    re: T::from(0).unwrap(),
294                    im: T::from(0).unwrap(),
295                };
296                self.dft_size
297            ];
298            chunk.enumerate().for_each(|(i, j)| {
299                buffer[i].re = j;
300            });
301
302            #[cfg(feature = "std")]
303            {
304                let fft = rustfft::algorithm::Radix4::new(self.dft_size, FftDirection::Forward);
305                fft.process(&mut buffer);
306            }
307
308            #[cfg(not(feature = "std"))]
309            {
310                naive_fft(&mut buffer);
311            }
312
313            let pdg: Vec<T> = buffer
314                .into_iter()
315                .take(self.dft_size / 2)
316                .map(|i| i.norm_sqr())
317                .collect();
318            periodogram
319                .iter_mut()
320                .zip(pdg.into_iter())
321                .for_each(|(a, b)| *a += b);
322
323            i += self.segment_size - self.overlap_size;
324            n_segments += 1;
325        }
326        let frequencies: Vec<T> = (0..self.dft_size / 2)
327            .map(|i| T::from(i).unwrap() * self.fs / T::from(self.dft_size).unwrap())
328            .collect();
329        let norma = match self.normalization {
330            Normalization::Density => {
331                (window.power() * T::from(n_segments).unwrap() * self.fs).recip()
332            }
333            Normalization::Spectrum => {
334                (window.sum() * window.sum() * T::from(n_segments).unwrap()).recip()
335            }
336            Normalization::Custom(e) => e * T::from(n_segments).unwrap(),
337        };
338        let periodogram = periodogram.into_iter().map(|p| p * norma).collect();
339
340        Welch {
341            periodogram,
342            frequencies,
343        }
344    }
345}
346
347#[cfg(not(feature = "std"))]
348fn naive_fft<T: Float + Copy + core::fmt::Debug + core::iter::Sum>(input: &mut [Complex<T>]) {
349    let n = input.len();
350    if n <= 1 {
351        return;
352    }
353
354    let mut even: Vec<Complex<T>> = input.iter().copied().step_by(2).collect();
355    let mut odd: Vec<Complex<T>> = input.iter().copied().skip(1).step_by(2).collect();
356
357    naive_fft(&mut even);
358    naive_fft(&mut odd);
359
360    for k in 0..n / 2 {
361        let twiddle = Complex::from_polar(
362            T::one(),
363            -T::from(2.0 * PI).unwrap() * T::from(k).unwrap() / T::from(n).unwrap(),
364        ) * odd[k];
365        input[k] = even[k] + twiddle;
366        input[k + n / 2] = even[k] - twiddle;
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use rand::{SeedableRng, rngs::StdRng};
    use rand_distr::{Distribution, Normal};

    /// Number of noise samples fed to every estimate.
    const N_SAMPLES: usize = 100_000;

    /// Returns `n` draws of zero-mean, unit-variance Gaussian noise.
    ///
    /// A fixed seed keeps the tests reproducible. With an unseeded generator, an unlucky
    /// draw could make the tolerance checks fail intermittently.
    fn white_noise(n: usize) -> Vec<f64> {
        let normal = Normal::new(0., 1.).unwrap();
        let mut rng = StdRng::seed_from_u64(42);
        (0..n).map(|_| normal.sample(&mut rng)).collect()
    }

    /// Average value of the periodogram bins.
    fn mean_power(welch: &Welch<f64>) -> f64 {
        welch.periodogram().sum::<f64>() / welch.periodogram().count() as f64
    }

    #[test]
    fn test_normal_density() {
        let welch = WelchBuilder::new(white_noise(N_SAMPLES))
            .with_fs(4.)
            .with_overlap_size(128)
            .with_segment_size(256)
            .with_normalization(Normalization::Density)
            .build();
        assert_relative_eq!(mean_power(&welch), 0.249, epsilon = 1e-2);
    }

    #[test]
    fn test_normal_spectrum() {
        let welch = WelchBuilder::new(white_noise(N_SAMPLES))
            .with_fs(4.)
            .with_overlap_size(128)
            .with_segment_size(256)
            .with_normalization(Normalization::Spectrum)
            .build();
        assert_relative_eq!(mean_power(&welch), 0.00585, epsilon = 1e-4);
    }

    #[test]
    fn test_high_sampling_frequency() {
        let welch = WelchBuilder::new(white_noise(N_SAMPLES))
            .with_fs(400.)
            .with_overlap_size(128)
            .with_segment_size(256)
            .with_normalization(Normalization::Density)
            .build();
        assert_relative_eq!(mean_power(&welch), 0.0024846134053086084, epsilon = 1e-4);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_dft_size() {
        let welch = WelchBuilder::new(white_noise(N_SAMPLES))
            .with_fs(400.)
            .with_dft_size(512)
            .with_overlap_size(128)
            .with_segment_size(256)
            .with_normalization(Normalization::Density)
            .build();
        assert_relative_eq!(mean_power(&welch), 0.0025083335651038905, epsilon = 1e-4);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_short_segments() {
        let welch = WelchBuilder::new(white_noise(N_SAMPLES))
            .with_fs(4.)
            .with_dft_size(128)
            .with_overlap_size(1)
            .with_segment_size(8)
            .with_normalization(Normalization::Density)
            .build();
        assert_relative_eq!(mean_power(&welch), 0.21985243050737127, epsilon = 1e-2);
    }
}