1use crate::config::FftLearnConfig;
19use crate::denoise::SpectrumDenoiser;
20use crate::mel::{hann_window, log_mel_from_spectrum_batch, mel_filterbank, ref_log_mel_batch};
21use crate::peak::{WelchPeakParams, welch_peaks_from_segment_spectrum};
22use crate::pruned::{init_gates, pruned_forward_real_batch};
23use crate::q8::Q8Twiddles;
24use crate::reference::{fft_real_batch, max_abs_error};
25use crate::twiddle::exact_twiddles;
26use crate::welch::{WelchParams, average_welch_psd, welch_rustfft, welch_windowed_segments};
27use anyhow::{Result, ensure};
28
29#[derive(Debug, Clone)]
31pub struct FastLearnedFftModel {
32 pub n_fft: usize,
33 pub n_mels: usize,
34 pub sample_rate: f32,
35 pub twiddles: Vec<f32>,
36 pub gates: Vec<f32>,
37 pub freq_mask: Vec<f32>,
39 pub denoiser: SpectrumDenoiser,
40 pub use_q8: bool,
41 q8: Option<Q8Twiddles>,
42 mel_filters: Vec<f32>,
43 pub hard_gate_threshold: Option<f32>,
45}
46
47impl FastLearnedFftModel {
48 pub fn new(cfg: &FftLearnConfig, n_mels: usize, sample_rate: f32) -> Self {
49 let n_fft = cfg.n_fft;
50 Self {
51 n_fft,
52 n_mels,
53 sample_rate,
54 twiddles: exact_twiddles(cfg),
55 gates: init_gates(n_fft),
56 freq_mask: vec![1.0; n_fft * 2],
57 denoiser: SpectrumDenoiser::identity(n_fft),
58 use_q8: false,
59 q8: None,
60 mel_filters: mel_filterbank(n_fft, n_mels, sample_rate),
61 hard_gate_threshold: None,
62 }
63 }
64
65 pub fn with_hard_gates(mut self, threshold: f32) -> Self {
66 self.hard_gate_threshold = Some(threshold);
67 self
68 }
69
70 pub fn mel_filters(&self) -> &[f32] {
71 &self.mel_filters
72 }
73
74 fn gates_for_inference(&self) -> Vec<f32> {
75 match self.hard_gate_threshold {
76 Some(t) => crate::pruned::hard_gates(&self.gates, t),
77 None => self.gates.clone(),
78 }
79 }
80
81 fn forward_spectrum(
82 &self,
83 signal: &[f32],
84 batch: usize,
85 apply_denoiser: bool,
86 ) -> Result<Vec<f32>> {
87 ensure!(signal.len() == batch * self.n_fft);
88 let tw = self.effective_twiddles();
89 let gates = self.gates_for_inference();
90 let mut spec = pruned_forward_real_batch(signal, &tw, &gates, batch, self.n_fft)?;
91 for b in 0..batch {
92 for i in 0..self.n_fft * 2 {
93 let idx = b * self.n_fft * 2 + i;
94 spec[idx] *= self.freq_mask[i];
95 }
96 }
97 if apply_denoiser {
98 self.denoiser.apply_batch(&spec, batch, self.n_fft)
99 } else {
100 Ok(spec)
101 }
102 }
103
104 pub fn with_q8(mut self) -> Self {
105 self.use_q8 = true;
106 self.q8 = Some(Q8Twiddles::from_f32(&self.twiddles));
107 self
108 }
109
110 pub fn sync_q8(&mut self) {
111 if self.use_q8 {
112 self.q8 = Some(Q8Twiddles::from_f32(&self.twiddles));
113 }
114 }
115
116 pub fn twiddles_for_forward(&self) -> Vec<f32> {
117 self.effective_twiddles()
118 }
119
120 fn effective_twiddles(&self) -> Vec<f32> {
121 if self.use_q8 {
122 self.q8.as_ref().expect("q8").dequant()
123 } else {
124 self.twiddles.clone()
125 }
126 }
127
128 pub fn spectrum_batch_raw(&self, signal: &[f32], batch: usize) -> Result<Vec<f32>> {
129 self.forward_spectrum(signal, batch, false)
130 }
131
132 pub fn spectrum_batch(&self, signal: &[f32], batch: usize) -> Result<Vec<f32>> {
133 self.forward_spectrum(signal, batch, true)
134 }
135
136 pub fn log_mel_batch(&self, signal: &[f32], batch: usize) -> Result<Vec<f32>> {
137 let window = hann_window(self.n_fft);
138 let mut windowed = signal.to_vec();
139 for b in 0..batch {
140 for i in 0..self.n_fft {
141 windowed[b * self.n_fft + i] *= window[i];
142 }
143 }
144 let spec = self.spectrum_batch(&windowed, batch)?;
145 log_mel_from_spectrum_batch(&spec, &self.mel_filters, batch, self.n_fft, self.n_mels)
146 }
147
148 pub fn welch_psd_batch(
149 &self,
150 signal: &[f32],
151 batch: usize,
152 params: WelchParams,
153 ) -> Result<Vec<f32>> {
154 ensure!(params.n_fft == self.n_fft);
155 let window = crate::welch::hann_window(self.n_fft);
156 let segs = welch_windowed_segments(signal, batch, params, &window)?;
157 let tw = self.effective_twiddles();
158 let gates = self.gates_for_inference();
159 let mut spec =
160 pruned_forward_real_batch(&segs, &tw, &gates, batch * params.n_segments, self.n_fft)?;
161 for seg in 0..(batch * params.n_segments) {
162 for i in 0..self.n_fft * 2 {
163 let idx = seg * self.n_fft * 2 + i;
164 spec[idx] *= self.freq_mask[i];
165 }
166 }
167 let spec = self
168 .denoiser
169 .apply_batch(&spec, batch * params.n_segments, self.n_fft)?;
170 Ok(average_welch_psd(
171 &spec,
172 batch,
173 params.n_segments,
174 self.n_fft,
175 ))
176 }
177
178 pub fn welch_peaks_batch(
180 &self,
181 signal: &[f32],
182 batch: usize,
183 params: WelchPeakParams,
184 ) -> Result<Vec<f32>> {
185 ensure!(params.welch.n_fft == self.n_fft);
186 let window = crate::welch::hann_window(self.n_fft);
187 let segs = welch_windowed_segments(signal, batch, params.welch, &window)?;
188 let tw = self.effective_twiddles();
189 let gates = self.gates_for_inference();
190 let mut spec = pruned_forward_real_batch(
191 &segs,
192 &tw,
193 &gates,
194 batch * params.welch.n_segments,
195 self.n_fft,
196 )?;
197 for seg in 0..(batch * params.welch.n_segments) {
198 for i in 0..self.n_fft * 2 {
199 let idx = seg * self.n_fft * 2 + i;
200 spec[idx] *= self.freq_mask[i];
201 }
202 }
203 let spec = self
204 .denoiser
205 .apply_batch(&spec, batch * params.welch.n_segments, self.n_fft)?;
206 Ok(welch_peaks_from_segment_spectrum(&spec, batch, params))
207 }
208
209 pub fn mean_gate(&self) -> f32 {
210 crate::pruned::mean_gate(&self.gates)
211 }
212
213 pub fn active_gates(&self, threshold: f32) -> usize {
214 crate::pruned::active_gate_count(&self.gates, threshold)
215 }
216}
217
218pub fn ref_spectrum_batch(signal: &[f32], batch: usize, n_fft: usize) -> Result<Vec<f32>> {
220 fft_real_batch(signal, batch, n_fft)
221}
222
223pub fn ref_log_mel(
224 signal: &[f32],
225 batch: usize,
226 n_fft: usize,
227 n_mels: usize,
228 sr: f32,
229) -> Result<Vec<f32>> {
230 ref_log_mel_batch(signal, batch, n_fft, n_mels, sr)
231}
232
233pub fn ref_welch(signal: &[f32], batch: usize, params: WelchParams) -> Result<Vec<f32>> {
234 welch_rustfft(signal, batch, params)
235}
236
237pub fn pipeline_max_err(pred: &[f32], target: &[f32]) -> f32 {
238 max_abs_error(pred, target)
239}