cepstrum-extractor 0.2.1

An easy-to-use crate to compute the cepstrum of a signal.
Documentation
use std::{iter, ptr};
use std::sync::{Arc, Mutex};

use rustfft::{Fft, FftPlanner};
use rustfft::num_complex::Complex;
use rustfft::num_traits::Zero;
use crate::CepFloat;

pub struct CepFft<T: CepFloat> {
    len: usize,
    scratch_len: usize,
    scratches: Mutex<Vec<Vec<Complex<T>>>>,

    fft_instance: Arc<dyn Fft<T>>,
    ifft_instance: Arc<dyn Fft<T>>
}

impl<T: CepFloat> CepFft<T> {
    pub fn new(len: usize) -> CepFft<T> {
        let mut fft_planner = FftPlanner::<T>::new();

        let fft_instance = fft_planner.plan_fft_forward(len);

        CepFft {
            len,
            scratch_len: fft_instance.get_inplace_scratch_len(),
            scratches: Mutex::new(vec![vec![Complex::zero(); fft_instance.get_inplace_scratch_len()]]),

            fft_instance,
            ifft_instance: fft_planner.plan_fft_inverse(len),
        }
    }

    pub fn set_len(&mut self, len: usize) {
        if len == self.len {
            return;
        }

        let mut fft_planner = FftPlanner::<T>::new();

        self.fft_instance = fft_planner.plan_fft_forward(len);
        self.ifft_instance = fft_planner.plan_fft_inverse(len);

        let new_scratches_len = self.fft_instance.get_inplace_scratch_len();

        for scratch in self.scratches.lock().unwrap().iter_mut() {
            if new_scratches_len > scratch.len() {
                scratch.extend(iter::repeat(Complex::zero()).take(new_scratches_len - scratch.len()));
            } else {
                scratch.truncate(new_scratches_len);
            }
        }

        self.scratch_len = new_scratches_len;
        self.len = len;
    }

    pub fn extend_scratches(&self, new_count: usize) {
        let mut s = self.scratches.lock().unwrap();
        let len = s.len();

        if new_count > s.len() {
            s.extend(
                iter::repeat(vec![Complex::zero(); self.scratch_len]).take(new_count - len)
            )
        }
    }

    #[inline(always)]
    fn retrieve_scratch(&self, i: usize) -> &mut [Complex<T>] {
        let mut scratches = self.scratches.lock().unwrap();
        assert!(i < scratches.len(), "Index out of bounds: {}", i);

        unsafe {
            &mut *ptr::slice_from_raw_parts_mut(
                scratches.get_unchecked_mut(i).as_mut_ptr(),
                self.scratch_len
            )
        }
    }

    #[inline]
    pub fn do_fft(&self, input: &mut [Complex<T>], i: usize) {
        self.fft_instance.process_with_scratch(
            input,
            self.retrieve_scratch(i)
        );
    }

    #[inline]
    pub fn do_ifft(&self, input: &mut [Complex<T>], i: usize) {
        self.ifft_instance.process_with_scratch(
            input,
            self.retrieve_scratch(i)
        );
    }
}

#[cfg(test)]
mod tests {
    use super::CepFft;

    const LEN: usize = 10;

    #[test]
    fn check_scratches() {
        let mut inst: CepFft<f32> = CepFft::new(LEN);

        assert_eq!(inst.scratches.lock().unwrap().len(), 1);

        inst.extend_scratches(10);

        assert_eq!(inst.scratches.lock().unwrap().len(), 10);

        inst.extend_scratches(9);

        assert_eq!(inst.scratches.lock().unwrap().len(), 10);

        assert!(inst.scratches.lock().unwrap().iter().all(|s| s.len() == inst.scratch_len && s.len() == LEN));

        inst.set_len(LEN * 2);

        assert!(inst.scratches.lock().unwrap().iter().all(|s| s.len() == inst.scratch_len && s.len() == LEN * 2));
    }
}