// rlx_fft/learned_model.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Fast learned FFT model (Tier D): a pruned (gated) butterfly FFT with optional Q8 twiddles,
//! a per-bin frequency mask and a learned spectrum denoiser.

18use crate::config::FftLearnConfig;
19use crate::denoise::SpectrumDenoiser;
20use crate::mel::{hann_window, log_mel_from_spectrum_batch, mel_filterbank, ref_log_mel_batch};
21use crate::peak::{WelchPeakParams, welch_peaks_from_segment_spectrum};
22use crate::pruned::{init_gates, pruned_forward_real_batch};
23use crate::q8::Q8Twiddles;
24use crate::reference::{fft_real_batch, max_abs_error};
25use crate::twiddle::exact_twiddles;
26use crate::welch::{WelchParams, average_welch_psd, welch_rustfft, welch_windowed_segments};
27use anyhow::{Result, ensure};
28
29/// End-to-end learned spectral model for fast inference paths.
30#[derive(Debug, Clone)]
31pub struct FastLearnedFftModel {
32    pub n_fft: usize,
33    pub n_mels: usize,
34    pub sample_rate: f32,
35    pub twiddles: Vec<f32>,
36    pub gates: Vec<f32>,
37    /// Per complex-bin mask (learnable), init 1.
38    pub freq_mask: Vec<f32>,
39    pub denoiser: SpectrumDenoiser,
40    pub use_q8: bool,
41    q8: Option<Q8Twiddles>,
42    mel_filters: Vec<f32>,
43    /// When set, inference uses hard gates (0/1) at this threshold.
44    pub hard_gate_threshold: Option<f32>,
45}
46
47impl FastLearnedFftModel {
48    pub fn new(cfg: &FftLearnConfig, n_mels: usize, sample_rate: f32) -> Self {
49        let n_fft = cfg.n_fft;
50        Self {
51            n_fft,
52            n_mels,
53            sample_rate,
54            twiddles: exact_twiddles(cfg),
55            gates: init_gates(n_fft),
56            freq_mask: vec![1.0; n_fft * 2],
57            denoiser: SpectrumDenoiser::identity(n_fft),
58            use_q8: false,
59            q8: None,
60            mel_filters: mel_filterbank(n_fft, n_mels, sample_rate),
61            hard_gate_threshold: None,
62        }
63    }
64
65    pub fn with_hard_gates(mut self, threshold: f32) -> Self {
66        self.hard_gate_threshold = Some(threshold);
67        self
68    }
69
70    pub fn mel_filters(&self) -> &[f32] {
71        &self.mel_filters
72    }
73
74    fn gates_for_inference(&self) -> Vec<f32> {
75        match self.hard_gate_threshold {
76            Some(t) => crate::pruned::hard_gates(&self.gates, t),
77            None => self.gates.clone(),
78        }
79    }
80
81    fn forward_spectrum(
82        &self,
83        signal: &[f32],
84        batch: usize,
85        apply_denoiser: bool,
86    ) -> Result<Vec<f32>> {
87        ensure!(signal.len() == batch * self.n_fft);
88        let tw = self.effective_twiddles();
89        let gates = self.gates_for_inference();
90        let mut spec = pruned_forward_real_batch(signal, &tw, &gates, batch, self.n_fft)?;
91        for b in 0..batch {
92            for i in 0..self.n_fft * 2 {
93                let idx = b * self.n_fft * 2 + i;
94                spec[idx] *= self.freq_mask[i];
95            }
96        }
97        if apply_denoiser {
98            self.denoiser.apply_batch(&spec, batch, self.n_fft)
99        } else {
100            Ok(spec)
101        }
102    }
103
104    pub fn with_q8(mut self) -> Self {
105        self.use_q8 = true;
106        self.q8 = Some(Q8Twiddles::from_f32(&self.twiddles));
107        self
108    }
109
110    pub fn sync_q8(&mut self) {
111        if self.use_q8 {
112            self.q8 = Some(Q8Twiddles::from_f32(&self.twiddles));
113        }
114    }
115
116    pub fn twiddles_for_forward(&self) -> Vec<f32> {
117        self.effective_twiddles()
118    }
119
120    fn effective_twiddles(&self) -> Vec<f32> {
121        if self.use_q8 {
122            self.q8.as_ref().expect("q8").dequant()
123        } else {
124            self.twiddles.clone()
125        }
126    }
127
128    pub fn spectrum_batch_raw(&self, signal: &[f32], batch: usize) -> Result<Vec<f32>> {
129        self.forward_spectrum(signal, batch, false)
130    }
131
132    pub fn spectrum_batch(&self, signal: &[f32], batch: usize) -> Result<Vec<f32>> {
133        self.forward_spectrum(signal, batch, true)
134    }
135
136    pub fn log_mel_batch(&self, signal: &[f32], batch: usize) -> Result<Vec<f32>> {
137        let window = hann_window(self.n_fft);
138        let mut windowed = signal.to_vec();
139        for b in 0..batch {
140            for i in 0..self.n_fft {
141                windowed[b * self.n_fft + i] *= window[i];
142            }
143        }
144        let spec = self.spectrum_batch(&windowed, batch)?;
145        log_mel_from_spectrum_batch(&spec, &self.mel_filters, batch, self.n_fft, self.n_mels)
146    }
147
148    pub fn welch_psd_batch(
149        &self,
150        signal: &[f32],
151        batch: usize,
152        params: WelchParams,
153    ) -> Result<Vec<f32>> {
154        ensure!(params.n_fft == self.n_fft);
155        let window = crate::welch::hann_window(self.n_fft);
156        let segs = welch_windowed_segments(signal, batch, params, &window)?;
157        let tw = self.effective_twiddles();
158        let gates = self.gates_for_inference();
159        let mut spec =
160            pruned_forward_real_batch(&segs, &tw, &gates, batch * params.n_segments, self.n_fft)?;
161        for seg in 0..(batch * params.n_segments) {
162            for i in 0..self.n_fft * 2 {
163                let idx = seg * self.n_fft * 2 + i;
164                spec[idx] *= self.freq_mask[i];
165            }
166        }
167        let spec = self
168            .denoiser
169            .apply_batch(&spec, batch * params.n_segments, self.n_fft)?;
170        Ok(average_welch_psd(
171            &spec,
172            batch,
173            params.n_segments,
174            self.n_fft,
175        ))
176    }
177
178    /// Fast Welch path (few segments) → top-K `(bin, power)` spikes only.
179    pub fn welch_peaks_batch(
180        &self,
181        signal: &[f32],
182        batch: usize,
183        params: WelchPeakParams,
184    ) -> Result<Vec<f32>> {
185        ensure!(params.welch.n_fft == self.n_fft);
186        let window = crate::welch::hann_window(self.n_fft);
187        let segs = welch_windowed_segments(signal, batch, params.welch, &window)?;
188        let tw = self.effective_twiddles();
189        let gates = self.gates_for_inference();
190        let mut spec = pruned_forward_real_batch(
191            &segs,
192            &tw,
193            &gates,
194            batch * params.welch.n_segments,
195            self.n_fft,
196        )?;
197        for seg in 0..(batch * params.welch.n_segments) {
198            for i in 0..self.n_fft * 2 {
199                let idx = seg * self.n_fft * 2 + i;
200                spec[idx] *= self.freq_mask[i];
201            }
202        }
203        let spec = self
204            .denoiser
205            .apply_batch(&spec, batch * params.welch.n_segments, self.n_fft)?;
206        Ok(welch_peaks_from_segment_spectrum(&spec, batch, params))
207    }
208
209    pub fn mean_gate(&self) -> f32 {
210        crate::pruned::mean_gate(&self.gates)
211    }
212
213    pub fn active_gates(&self, threshold: f32) -> usize {
214        crate::pruned::active_gate_count(&self.gates, threshold)
215    }
216}
217
218/// Reference pipelines for validation.
219pub fn ref_spectrum_batch(signal: &[f32], batch: usize, n_fft: usize) -> Result<Vec<f32>> {
220    fft_real_batch(signal, batch, n_fft)
221}
222
223pub fn ref_log_mel(
224    signal: &[f32],
225    batch: usize,
226    n_fft: usize,
227    n_mels: usize,
228    sr: f32,
229) -> Result<Vec<f32>> {
230    ref_log_mel_batch(signal, batch, n_fft, n_mels, sr)
231}
232
233pub fn ref_welch(signal: &[f32], batch: usize, params: WelchParams) -> Result<Vec<f32>> {
234    welch_rustfft(signal, batch, params)
235}
236
237pub fn pipeline_max_err(pred: &[f32], target: &[f32]) -> f32 {
238    max_abs_error(pred, target)
239}