1#[cfg(not(feature = "std"))]
19extern crate alloc;
20#[cfg(not(feature = "std"))]
21use alloc::vec;
22#[cfg(not(feature = "std"))]
23use alloc::vec::Vec;
24use core::{cmp, f64::consts::PI};
25use num::{Complex, Float};
26#[cfg(feature = "std")]
27use rustfft::{Fft, FftDirection};
28
/// A tapering window applied to fixed-length chunks of a signal before a DFT.
///
/// Besides applying the taper, a window exposes the two aggregate quantities
/// needed to normalise a periodogram: its power and the plain sum of its weights.
pub trait Window<T> {
    /// Tapers `chunk`, yielding one output sample per input sample.
    ///
    /// Implementations may require `chunk` to have the same length as the window
    /// (the [`Hann`] implementation panics otherwise).
    fn apply(&self, chunk: &[T]) -> impl Iterator<Item = T>;

    /// Returns the window power (for [`Hann`], the sum of the squared weights).
    fn power(&self) -> T;

    /// Returns the sum of the window weights.
    fn sum(&self) -> T;
}
60
61impl<T: Float + Copy + core::fmt::Debug + core::iter::Sum> Window<T> for Hann<T> {
62    fn apply(&self, chunk: &[T]) -> impl Iterator<Item = T> {
63        if chunk.len() != self.weights.len() {
64            panic!("Signal and Window should have the same size");
65        }
66        let chunk_mean = chunk.iter().copied().sum::<T>() / T::from(chunk.len()).unwrap();
67        chunk
68            .iter()
69            .zip(self.weights.iter())
70            .map(move |(&a, &b)| (a - chunk_mean) * b)
71    }
72
73    fn power(&self) -> T {
74        self.weights.iter().map(|&w| w * w).sum()
75    }
76
77    fn sum(&self) -> T {
78        self.weights.iter().copied().sum()
79    }
80}
81
/// A Hann window of fixed length, produced by [`HannBuilder::build`].
#[derive(Debug, Clone)]
pub struct Hann<T> {
    /// Window weights, one per sample of the chunks the window is applied to.
    weights: Vec<T>,
}
89
/// Builder for a [`Hann`] window of a given length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HannBuilder {
    /// Number of samples in the window to build.
    n: usize,
}
96
97impl HannBuilder {
98    pub fn new(n: usize) -> Self {
106        HannBuilder { n }
107    }
108
109    pub fn build<T: Float + Copy + core::fmt::Debug + core::iter::Sum>(&self) -> Hann<T> {
117        let weights = (0..self.n)
118            .map(|i| {
119                T::from(0.5 * (1.0 - ((2.0 * PI * i as f64) / (self.n as f64 - 1.0)).cos()))
120                    .unwrap()
121            })
122            .collect::<Vec<T>>();
123        Hann { weights }
124    }
125}
126
/// How [`WelchBuilder::build`] scales the periodogram summed over all segments.
///
/// Below, `w` are the window weights and `n_segments` is the number of segments
/// that were averaged.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Normalization<T> {
    /// Power spectral density: the sum is multiplied by `1 / (sum(w^2) * fs * n_segments)`.
    Density,

    /// Power spectrum: the sum is multiplied by `1 / (sum(w)^2 * n_segments)`.
    Spectrum,

    /// Caller-supplied factor `e`: the sum is multiplied by `e * n_segments`.
    ///
    /// NOTE(review): this grows with the segment count, unlike the other variants
    /// which divide by it; confirm this is intended.
    Custom(T),
}
152
/// Read access to an estimated periodogram and its frequency axis.
pub trait Periodogram<T> {
    /// Yields the estimated power values, one per frequency bin.
    fn periodogram(&self) -> impl Iterator<Item = T> + '_;

    /// Yields the frequency of each bin, in the same order as [`Self::periodogram`].
    ///
    /// For [`Welch`], bin `k` sits at `k * fs / dft_size`.
    fn frequencies(&self) -> impl Iterator<Item = T> + '_;
}
177
178impl<T: Float + Copy + core::fmt::Debug + core::iter::Sum> Periodogram<T> for Welch<T> {
179    fn periodogram(&self) -> impl Iterator<Item = T> + '_ {
180        self.periodogram.iter().copied()
181    }
182
183    fn frequencies(&self) -> impl Iterator<Item = T> + '_ {
184        self.frequencies.iter().copied()
185    }
186}
187
/// A normalised Welch periodogram, produced by [`WelchBuilder::build`].
#[derive(Debug, Clone)]
pub struct Welch<T> {
    /// Averaged and normalised power for bins `0..dft_size / 2`.
    periodogram: Vec<T>,
    /// Frequency of each bin (`k * fs / dft_size`), aligned with `periodogram`.
    frequencies: Vec<T>,
}
199
200pub struct WelchBuilder<T> {
208    normalization: Normalization<T>,
209    segment_size: usize,
210    dft_size: usize,
211    overlap_size: usize,
212    fs: T,
213    signal: Vec<T>,
214}
215
216impl<
217    T: Float
218        + Copy
219        + core::fmt::Debug
220        + core::marker::Sync
221        + core::marker::Send
222        + core::iter::Sum
223        + core::ops::AddAssign
224        + num::Signed
225        + num::FromPrimitive
226        + 'static,
227> WelchBuilder<T>
228{
229    pub fn new(signal: Vec<T>) -> Self {
237        WelchBuilder {
238            normalization: Normalization::Density,
239            signal,
240            segment_size: 256,
241            dft_size: 4096,
242            overlap_size: 128,
243            fs: T::from(4).unwrap(),
244        }
245    }
246
247    pub fn with_segment_size(mut self, n: usize) -> Self {
249        self.segment_size = n;
250        self
251    }
252
253    #[cfg(feature = "std")]
255    pub fn with_dft_size(mut self, n: usize) -> Self {
256        self.dft_size = n;
257        self
258    }
259
260    pub fn with_overlap_size(mut self, n: usize) -> Self {
262        self.overlap_size = n;
263        self
264    }
265
266    pub fn with_fs(mut self, n: T) -> Self {
268        self.fs = n;
269        self
270    }
271
272    pub fn with_normalization(mut self, norma: Normalization<T>) -> Self {
274        self.normalization = norma;
275        self
276    }
277
278    pub fn build(&self) -> Welch<T> {
283        let window = HannBuilder::new(self.segment_size).build::<T>();
284        let mut periodogram: Vec<T> = vec![T::from(0).unwrap(); self.dft_size / 2];
285        let mut i = 0;
286        let mut n_segments = 0;
287        while i + self.segment_size < self.signal.len() {
288            let chunk = &self.signal[i..cmp::min(i + self.segment_size, self.signal.len())];
289            let chunk = window.apply(chunk);
290
291            let mut buffer = vec![
292                Complex {
293                    re: T::from(0).unwrap(),
294                    im: T::from(0).unwrap(),
295                };
296                self.dft_size
297            ];
298            chunk.enumerate().for_each(|(i, j)| {
299                buffer[i].re = j;
300            });
301
302            #[cfg(feature = "std")]
303            {
304                let fft = rustfft::algorithm::Radix4::new(self.dft_size, FftDirection::Forward);
305                fft.process(&mut buffer);
306            }
307
308            #[cfg(not(feature = "std"))]
309            {
310                naive_fft(&mut buffer);
311            }
312
313            let pdg: Vec<T> = buffer
314                .into_iter()
315                .take(self.dft_size / 2)
316                .map(|i| i.norm_sqr())
317                .collect();
318            periodogram
319                .iter_mut()
320                .zip(pdg.into_iter())
321                .for_each(|(a, b)| *a += b);
322
323            i += self.segment_size - self.overlap_size;
324            n_segments += 1;
325        }
326        let frequencies: Vec<T> = (0..self.dft_size / 2)
327            .map(|i| T::from(i).unwrap() * self.fs / T::from(self.dft_size).unwrap())
328            .collect();
329        let norma = match self.normalization {
330            Normalization::Density => {
331                (window.power() * T::from(n_segments).unwrap() * self.fs).recip()
332            }
333            Normalization::Spectrum => {
334                (window.sum() * window.sum() * T::from(n_segments).unwrap()).recip()
335            }
336            Normalization::Custom(e) => e * T::from(n_segments).unwrap(),
337        };
338        let periodogram = periodogram.into_iter().map(|p| p * norma).collect();
339
340        Welch {
341            periodogram,
342            frequencies,
343        }
344    }
345}
346
347#[cfg(not(feature = "std"))]
348fn naive_fft<T: Float + Copy + core::fmt::Debug + core::iter::Sum>(input: &mut [Complex<T>]) {
349    let n = input.len();
350    if n <= 1 {
351        return;
352    }
353
354    let mut even: Vec<Complex<T>> = input.iter().copied().step_by(2).collect();
355    let mut odd: Vec<Complex<T>> = input.iter().copied().skip(1).step_by(2).collect();
356
357    naive_fft(&mut even);
358    naive_fft(&mut odd);
359
360    for k in 0..n / 2 {
361        let twiddle = Complex::from_polar(
362            T::one(),
363            -T::from(2.0 * PI).unwrap() * T::from(k).unwrap() / T::from(n).unwrap(),
364        ) * odd[k];
365        input[k] = even[k] + twiddle;
366        input[k + n / 2] = even[k] - twiddle;
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use rand::rng;
    use rand_distr::{Distribution, Normal};

    /// Draws 100 000 samples of zero-mean, unit-variance Gaussian noise.
    fn white_noise() -> Vec<f64> {
        let normal = Normal::new(0., 1.).unwrap();
        let mut generator = rng();
        (0..100_000).map(|_| normal.sample(&mut generator)).collect()
    }

    /// Average value over all bins of the estimated periodogram.
    fn mean_power(welch: &Welch<f64>) -> f64 {
        let total: f64 = welch.periodogram().sum();
        let bins = welch.periodogram().count();
        total / bins as f64
    }

    #[test]
    fn test_normal_density() {
        let welch = WelchBuilder::new(white_noise())
            .with_fs(4.)
            .with_overlap_size(128)
            .with_segment_size(256)
            .with_normalization(Normalization::Density)
            .build();
        assert_relative_eq!(mean_power(&welch), 0.249, epsilon = 1e-2);
    }

    #[test]
    fn test_normal_spectrum() {
        let welch = WelchBuilder::new(white_noise())
            .with_fs(4.)
            .with_overlap_size(128)
            .with_segment_size(256)
            .with_normalization(Normalization::Spectrum)
            .build();
        assert_relative_eq!(mean_power(&welch), 0.00585, epsilon = 1e-4);
    }

    #[test]
    fn test_frequence() {
        let welch = WelchBuilder::new(white_noise())
            .with_fs(400.)
            .with_overlap_size(128)
            .with_segment_size(256)
            .with_normalization(Normalization::Density)
            .build();
        assert_relative_eq!(mean_power(&welch), 0.0024846134053086084, epsilon = 1e-4);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_dtf() {
        let welch = WelchBuilder::new(white_noise())
            .with_fs(400.)
            .with_dft_size(512)
            .with_overlap_size(128)
            .with_segment_size(256)
            .with_normalization(Normalization::Density)
            .build();
        assert_relative_eq!(mean_power(&welch), 0.0025083335651038905, epsilon = 1e-4);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_short() {
        let welch = WelchBuilder::new(white_noise())
            .with_fs(4.)
            .with_dft_size(128)
            .with_overlap_size(1)
            .with_segment_size(8)
            .with_normalization(Normalization::Density)
            .build();
        assert_relative_eq!(mean_power(&welch), 0.21985243050737127, epsilon = 1e-2);
    }
}