1use nalgebra::DVector;
2use num_complex::Complex;
3use rustfft::{Fft, FftPlanner};
4use std::sync::Arc;
5
6pub struct FdafAec {
11 fft_size: usize,
12 frame_size: usize,
13 fft: Arc<dyn Fft<f32>>,
14 ifft: Arc<dyn Fft<f32>>,
15 weights: DVector<Complex<f32>>,
16 far_end_buffer: DVector<f32>,
17 mu: f32,
18 psd: DVector<f32>,
19 smoothing_factor: f32,
20}
21
22impl FdafAec {
23 pub fn new(fft_size: usize, step_size: f32) -> Self {
35 assert!(fft_size > 0 && fft_size.is_power_of_two(), "fft_size must be a power of two.");
36 let frame_size = fft_size / 2;
37 let mut fft_planner = FftPlanner::new();
38 let fft = fft_planner.plan_fft_forward(fft_size);
39 let ifft = fft_planner.plan_fft_inverse(fft_size);
40
41 Self {
42 fft_size,
43 frame_size,
44 fft,
45 ifft,
46 weights: DVector::from_element(fft_size, Complex::new(0.0, 0.0)),
47 far_end_buffer: DVector::from_element(fft_size, 0.0),
48 mu: step_size,
49 psd: DVector::from_element(fft_size, 1.0), smoothing_factor: 0.98,
51 }
52 }
53
54 pub fn process(&mut self, far_end_frame: &[f32], mic_frame: &[f32]) -> Vec<f32> {
67 assert_eq!(far_end_frame.len(), self.frame_size, "Input far-end frame size must be half of FFT size.");
68 assert_eq!(mic_frame.len(), self.frame_size, "Input mic frame size must be half of FFT size.");
69
70 self.far_end_buffer.as_mut_slice().copy_within(self.frame_size.., 0);
73 self.far_end_buffer
74 .rows_mut(self.frame_size, self.frame_size)
75 .copy_from_slice(far_end_frame);
76
77 let mut x_t_buffer: Vec<Complex<f32>> = self
79 .far_end_buffer
80 .iter()
81 .map(|&x| Complex::new(x, 0.0))
82 .collect();
83 self.fft.process(&mut x_t_buffer);
84 let x_f = DVector::from_vec(x_t_buffer);
85
86 for i in 0..self.fft_size {
88 let power = x_f[i].norm_sqr();
89 self.psd[i] = self.smoothing_factor * self.psd[i] + (1.0 - self.smoothing_factor) * power;
90 }
91
92 let y_f = self.weights.component_mul(&x_f);
94
95 let mut y_t_complex = y_f.as_slice().to_vec();
97 self.ifft.process(&mut y_t_complex);
98
99 let fft_size_f32 = self.fft_size as f32;
101 let y_t: DVector<f32> = DVector::from_iterator(
102 self.fft_size,
103 y_t_complex.iter().map(|c| c.re / fft_size_f32),
104 );
105
106 let estimated_echo = y_t.rows(self.frame_size, self.frame_size);
108
109 let error_signal: Vec<f32> = mic_frame
111 .iter()
112 .zip(estimated_echo.iter())
113 .map(|(mic, echo)| mic - echo)
114 .collect();
115
116 let mut e_t_buffer = vec![Complex::new(0.0, 0.0); self.fft_size];
120 for (i, &sample) in error_signal.iter().enumerate() {
121 e_t_buffer[i + self.frame_size] = Complex::new(sample, 0.0);
122 }
123
124 self.fft.process(&mut e_t_buffer);
125 let e_f = DVector::from_vec(e_t_buffer);
126
127 let mut gradient = x_f.map(|c| c.conj()).component_mul(&e_f);
129 for i in 0..self.fft_size {
130 gradient[i] /= self.psd[i] + 1e-10; }
133 self.weights += &gradient * Complex::new(self.mu, 0.0);
134
135 error_signal
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    const FFT_SIZE: usize = 512;
    const FRAME_SIZE: usize = FFT_SIZE / 2;
    const STEP_SIZE: f32 = 0.5;

    #[test]
    fn new_instance_and_process_frame() {
        let mut aec = FdafAec::new(FFT_SIZE, STEP_SIZE);

        let far_end_frame = vec![0.0; FRAME_SIZE];
        let mic_frame = vec![0.1; FRAME_SIZE];
        let error_signal = aec.process(&far_end_frame, &mic_frame);

        assert_eq!(error_signal.len(), FRAME_SIZE);
        assert!(
            error_signal.iter().all(|&x| x.is_finite()),
            "Output contains NaN or Infinity"
        );
    }

    #[test]
    fn silent_far_end_passes_mic_through_unchanged() {
        // With no far-end signal there is no echo to estimate, and the weights
        // never move away from zero, so each output frame must equal the mic frame.
        let mut aec = FdafAec::new(FFT_SIZE, STEP_SIZE);
        let far_end_frame = vec![0.0; FRAME_SIZE];
        let mic_frame: Vec<f32> = (0..FRAME_SIZE).map(|i| (i as f32 * 0.01).sin()).collect();

        for _ in 0..3 {
            assert_eq!(aec.process(&far_end_frame, &mic_frame), mic_frame);
        }
    }

    #[test]
    fn first_frame_output_equals_mic_because_weights_start_at_zero() {
        let mut aec = FdafAec::new(64, STEP_SIZE);
        let far_end_frame: Vec<f32> = (0..32).map(|i| ((i * 7) % 5) as f32 - 2.0).collect();
        let mic_frame: Vec<f32> = (0..32).map(|i| i as f32 * 0.25).collect();

        assert_eq!(aec.process(&far_end_frame, &mic_frame), mic_frame);
    }

    #[test]
    #[should_panic]
    fn test_new_with_non_power_of_two_fft_size() {
        FdafAec::new(511, 0.5);
    }

    #[test]
    #[should_panic]
    fn test_new_with_zero_fft_size() {
        FdafAec::new(0, 0.5);
    }

    #[test]
    #[should_panic]
    fn test_process_with_wrong_frame_size() {
        let mut aec = FdafAec::new(512, 0.5);
        let far_end_frame = vec![0.0; 128];
        let mic_frame = vec![0.0; 256];
        aec.process(&far_end_frame, &mic_frame);
    }

    #[test]
    #[should_panic(expected = "mic frame size")]
    fn test_process_with_wrong_mic_frame_size() {
        let mut aec = FdafAec::new(512, 0.5);
        let far_end_frame = vec![0.0; 256];
        let mic_frame = vec![0.0; 128];
        aec.process(&far_end_frame, &mic_frame);
    }
}