fdaf_aec/
lib.rs

1use nalgebra::DVector;
2use num_complex::Complex;
3use rustfft::{Fft, FftPlanner};
4use std::sync::Arc;
5
6/// Implements an Acoustic Echo Canceller using the Frequency Domain Adaptive Filter (FDAF)
7/// algorithm with the Overlap-Save method.
8///
9/// This struct holds the state for the AEC and processes audio in frames.
10pub struct FdafAec {
11    fft_size: usize,
12    frame_size: usize,
13    fft: Arc<dyn Fft<f32>>,
14    ifft: Arc<dyn Fft<f32>>,
15    weights: DVector<Complex<f32>>,
16    far_end_buffer: DVector<f32>,
17    mu: f32,
18    psd: DVector<f32>,
19    smoothing_factor: f32,
20}
21
22impl FdafAec {
23    /// Creates a new `FdafAec` instance.
24    ///
25    /// # Arguments
26    ///
27    /// * `fft_size`: The size of the FFT. This determines the filter length and the trade-off
28    ///   between computational complexity and filter performance. A larger `fft_size` provides
29    ///   a longer filter, which can cancel more delayed echoes, but increases latency and
30    ///   computational cost. Must be a power of two.
31    /// * `step_size`: The learning rate (mu) for the adaptive filter. It controls how fast the
32    ///   filter adapts. A larger value leads to faster convergence but can be less stable.
33    ///   A typical value is between 0.1 and 1.0.
34    pub fn new(fft_size: usize, step_size: f32) -> Self {
35        assert!(fft_size > 0 && fft_size.is_power_of_two(), "fft_size must be a power of two.");
36        let frame_size = fft_size / 2;
37        let mut fft_planner = FftPlanner::new();
38        let fft = fft_planner.plan_fft_forward(fft_size);
39        let ifft = fft_planner.plan_fft_inverse(fft_size);
40
41        Self {
42            fft_size,
43            frame_size,
44            fft,
45            ifft,
46            weights: DVector::from_element(fft_size, Complex::new(0.0, 0.0)),
47            far_end_buffer: DVector::from_element(fft_size, 0.0),
48            mu: step_size,
49            psd: DVector::from_element(fft_size, 1.0), // Initialize with 1 to avoid division by zero
50            smoothing_factor: 0.98,
51        }
52    }
53
54    /// Processes a frame of audio data to remove echo.
55    ///
56    /// # Arguments
57    ///
58    /// * `far_end_frame`: A slice representing the audio frame from the far-end (the reference signal, e.g., loudspeaker).
59    ///   Its length must be `fft_size / 2`.
60    /// * `mic_frame`: A slice representing the audio frame from the near-end microphone, containing both the
61    ///   near-end speaker's voice and the echo from the far-end. Its length must be `fft_size / 2`.
62    ///
63    /// # Returns
64    ///
65    /// A `Vec<f32>` containing the echo-cancelled audio frame. The length of the vector is `fft_size / 2`.
66    pub fn process(&mut self, far_end_frame: &[f32], mic_frame: &[f32]) -> Vec<f32> {
67        assert_eq!(far_end_frame.len(), self.frame_size, "Input far-end frame size must be half of FFT size.");
68        assert_eq!(mic_frame.len(), self.frame_size, "Input mic frame size must be half of FFT size.");
69
70        // 1. Update far-end buffer (shift old data, add new data)
71        // This creates a rolling window of the last `fft_size` samples.
72        self.far_end_buffer.as_mut_slice().copy_within(self.frame_size.., 0);
73        self.far_end_buffer
74            .rows_mut(self.frame_size, self.frame_size)
75            .copy_from_slice(far_end_frame);
76
77        // 2. FFT of the far-end signal block
78        let mut x_t_buffer: Vec<Complex<f32>> = self
79            .far_end_buffer
80            .iter()
81            .map(|&x| Complex::new(x, 0.0))
82            .collect();
83        self.fft.process(&mut x_t_buffer);
84        let x_f = DVector::from_vec(x_t_buffer);
85
86        // 3. Update Power Spectral Density (PSD) of the far-end signal
87        for i in 0..self.fft_size {
88            let power = x_f[i].norm_sqr();
89            self.psd[i] = self.smoothing_factor * self.psd[i] + (1.0 - self.smoothing_factor) * power;
90        }
91
92        // 4. Estimate echo in frequency domain
93        let y_f = self.weights.component_mul(&x_f);
94
95        // 5. Inverse FFT of the estimated echo
96        let mut y_t_complex = y_f.as_slice().to_vec();
97        self.ifft.process(&mut y_t_complex);
98
99        // IFFT normalization and extract real part
100        let fft_size_f32 = self.fft_size as f32;
101        let y_t: DVector<f32> = DVector::from_iterator(
102            self.fft_size,
103            y_t_complex.iter().map(|c| c.re / fft_size_f32),
104        );
105
106        // 6. Extract the valid part of the convolution (Overlap-Save method)
107        let estimated_echo = y_t.rows(self.frame_size, self.frame_size);
108
109        // 7. Calculate the error signal (mic signal - estimated echo)
110        let error_signal: Vec<f32> = mic_frame
111            .iter()
112            .zip(estimated_echo.iter())
113            .map(|(mic, echo)| mic - echo)
114            .collect();
115
116        // 8. FFT of the error signal for weight update
117        // The error signal is placed in the second half of the buffer (the first half
118        // is zero-padded) to ensure correct time alignment for the gradient calculation.
119        let mut e_t_buffer = vec![Complex::new(0.0, 0.0); self.fft_size];
120        for (i, &sample) in error_signal.iter().enumerate() {
121            e_t_buffer[i + self.frame_size] = Complex::new(sample, 0.0);
122        }
123        
124        self.fft.process(&mut e_t_buffer);
125        let e_f = DVector::from_vec(e_t_buffer);
126        
127        // 9. Update filter weights using Normalized LMS algorithm
128        let mut gradient = x_f.map(|c| c.conj()).component_mul(&e_f);
129        for i in 0..self.fft_size {
130            // Normalize by the PSD of the far-end signal
131            gradient[i] /= self.psd[i] + 1e-10; // Add a small epsilon for stability
132        }
133        self.weights += &gradient * Complex::new(self.mu, 0.0);
134
135        // 10. Return the echo-cancelled (error) signal
136        error_signal
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random samples in [-1, 1) from a 32-bit linear congruential
    /// generator, so the tests are reproducible and need no RNG crate.
    fn lcg_noise(len: usize, seed: u32) -> Vec<f32> {
        let mut state = seed;
        (0..len)
            .map(|_| {
                state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
                // Use the top 24 bits: exactly representable in f32.
                (state >> 8) as f32 / (1u32 << 23) as f32 - 1.0
            })
            .collect()
    }

    #[test]
    fn new_instance_and_process_frame() {
        const FFT_SIZE: usize = 512;
        const FRAME_SIZE: usize = FFT_SIZE / 2;
        const STEP_SIZE: f32 = 0.5;

        let mut aec = FdafAec::new(FFT_SIZE, STEP_SIZE);

        let far_end_frame = vec![0.0; FRAME_SIZE];
        let mic_frame = vec![0.1; FRAME_SIZE]; // Some non-zero value

        let error_signal = aec.process(&far_end_frame, &mic_frame);

        // Check output length
        assert_eq!(error_signal.len(), FRAME_SIZE);

        // Check for NaN or Infinity
        assert!(
            error_signal.iter().all(|&x| x.is_finite()),
            "Output contains NaN or Infinity"
        );
    }

    #[test]
    fn cancels_a_delayed_echo() {
        const FFT_SIZE: usize = 128;
        const FRAME_SIZE: usize = FFT_SIZE / 2;
        const FRAMES: usize = 400;
        const TAIL_FRAMES: usize = 20;
        const DELAY: usize = 10; // well inside the filter length
        const GAIN: f32 = 0.5;

        let far_end = lcg_noise(FRAMES * FRAME_SIZE, 12_345);
        // Echo path: attenuated copy of the far-end signal delayed by DELAY samples.
        let mic: Vec<f32> = std::iter::repeat(0.0)
            .take(DELAY)
            .chain(far_end.iter().map(|&x| GAIN * x))
            .take(far_end.len())
            .collect();

        let mut aec = FdafAec::new(FFT_SIZE, 0.5);
        let mut tail_error = 0.0f64;
        let mut tail_mic = 0.0f64;
        let frames = far_end
            .chunks_exact(FRAME_SIZE)
            .zip(mic.chunks_exact(FRAME_SIZE))
            .enumerate();
        for (index, (far, near)) in frames {
            let out = aec.process(far, near);
            assert!(out.iter().all(|x| x.is_finite()), "non-finite output at frame {index}");
            if index >= FRAMES - TAIL_FRAMES {
                tail_error += out.iter().map(|&e| f64::from(e) * f64::from(e)).sum::<f64>();
                tail_mic += near.iter().map(|&d| f64::from(d) * f64::from(d)).sum::<f64>();
            }
        }

        // After convergence the residual echo should be at least 30 dB below the echo itself.
        assert!(
            tail_error < 1e-3 * tail_mic,
            "echo not cancelled: residual energy {tail_error}, echo energy {tail_mic}"
        );
    }

    #[test]
    #[should_panic]
    fn test_new_with_non_power_of_two_fft_size() {
        let _ = FdafAec::new(511, 0.5);
    }

    #[test]
    #[should_panic]
    fn test_process_with_wrong_frame_size() {
        let mut aec = FdafAec::new(512, 0.5);
        let far_end_frame = vec![0.0; 128];
        let mic_frame = vec![0.0; 256];
        aec.process(&far_end_frame, &mic_frame);
    }
}