Skip to main content

rlx_fft/
denoise.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Learned spectrum denoiser — per-bin affine correction (legacy / teacher paths).
17
18use crate::reference::max_abs_error;
19use anyhow::{Result, ensure};
20
/// Affine spectrum correction applied independently to every bin.
///
/// Operates on an interleaved complex spectrum laid out as `[batch, n_fft, 2]`:
/// each stored component `x` is mapped to `x * scale + bias`, using the
/// coefficients at that component's position within a batch row.
#[derive(Clone, Debug)]
pub struct SpectrumDenoiser {
    /// Multiplicative coefficients, one per interleaved component of a batch row.
    pub scale: Vec<f32>,
    /// Additive coefficients, one per interleaved component of a batch row.
    pub bias: Vec<f32>,
}
27
28impl SpectrumDenoiser {
29    pub fn identity(n_fft: usize) -> Self {
30        Self {
31            scale: vec![1.0; n_fft * 2],
32            bias: vec![0.0; n_fft * 2],
33        }
34    }
35
36    pub fn apply_batch(&self, spectrum: &[f32], batch: usize, n_fft: usize) -> Result<Vec<f32>> {
37        ensure!(spectrum.len() == batch * n_fft * 2);
38        ensure!(self.scale.len() == n_fft * 2 && self.bias.len() == n_fft * 2);
39        let mut out = spectrum.to_vec();
40        for b in 0..batch {
41            for i in 0..n_fft * 2 {
42                let idx = b * n_fft * 2 + i;
43                out[idx] = spectrum[idx] * self.scale[i] + self.bias[i];
44            }
45        }
46        Ok(out)
47    }
48
49    pub fn train_step_affine(
50        &mut self,
51        pred: &[f32],
52        target: &[f32],
53        batch: usize,
54        n_fft: usize,
55        lr: f32,
56    ) -> Result<f32> {
57        ensure!(pred.len() == target.len() && pred.len() == batch * n_fft * 2);
58        let mut mse = 0f32;
59        let n = (batch * n_fft * 2) as f32;
60        for i in 0..n_fft * 2 {
61            let mut ds = 0f32;
62            let mut db = 0f32;
63            for b in 0..batch {
64                let idx = b * n_fft * 2 + i;
65                let p = pred[idx] * self.scale[i] + self.bias[i];
66                let d = p - target[idx];
67                mse += d * d;
68                ds += d * pred[idx];
69                db += d;
70            }
71            self.scale[i] -= lr * 2.0 * ds / n;
72            self.bias[i] -= lr * 2.0 * db / n;
73        }
74        Ok(mse / n)
75    }
76}
77
78pub fn denoised_max_err(pred: &[f32], denoised: &[f32], target: &[f32]) -> f32 {
79    max_abs_error(denoised, target).min(max_abs_error(pred, target))
80}