1use crate::reference::max_abs_error;
19use anyhow::{Result, ensure};
20
/// Per-position affine correction (`x * scale[i] + bias[i]`) for batches of flattened spectra.
///
/// Position `i` of a spectrum row has its own `scale[i]` and `bias[i]`. Both vectors are
/// expected to hold `n_fft * 2` entries, matching the row width of the spectra they are applied to.
#[derive(Debug, Clone, PartialEq)]
pub struct SpectrumDenoiser {
    /// Multiplicative factor for each position in a row (`n_fft * 2` entries).
    pub scale: Vec<f32>,
    /// Additive offset for each position in a row (`n_fft * 2` entries).
    pub bias: Vec<f32>,
}
27
28impl SpectrumDenoiser {
29 pub fn identity(n_fft: usize) -> Self {
30 Self {
31 scale: vec![1.0; n_fft * 2],
32 bias: vec![0.0; n_fft * 2],
33 }
34 }
35
36 pub fn apply_batch(&self, spectrum: &[f32], batch: usize, n_fft: usize) -> Result<Vec<f32>> {
37 ensure!(spectrum.len() == batch * n_fft * 2);
38 ensure!(self.scale.len() == n_fft * 2 && self.bias.len() == n_fft * 2);
39 let mut out = spectrum.to_vec();
40 for b in 0..batch {
41 for i in 0..n_fft * 2 {
42 let idx = b * n_fft * 2 + i;
43 out[idx] = spectrum[idx] * self.scale[i] + self.bias[i];
44 }
45 }
46 Ok(out)
47 }
48
49 pub fn train_step_affine(
50 &mut self,
51 pred: &[f32],
52 target: &[f32],
53 batch: usize,
54 n_fft: usize,
55 lr: f32,
56 ) -> Result<f32> {
57 ensure!(pred.len() == target.len() && pred.len() == batch * n_fft * 2);
58 let mut mse = 0f32;
59 let n = (batch * n_fft * 2) as f32;
60 for i in 0..n_fft * 2 {
61 let mut ds = 0f32;
62 let mut db = 0f32;
63 for b in 0..batch {
64 let idx = b * n_fft * 2 + i;
65 let p = pred[idx] * self.scale[i] + self.bias[i];
66 let d = p - target[idx];
67 mse += d * d;
68 ds += d * pred[idx];
69 db += d;
70 }
71 self.scale[i] -= lr * 2.0 * ds / n;
72 self.bias[i] -= lr * 2.0 * db / n;
73 }
74 Ok(mse / n)
75 }
76}
77
78pub fn denoised_max_err(pred: &[f32], denoised: &[f32], target: &[f32]) -> f32 {
79 max_abs_error(denoised, target).min(max_abs_error(pred, target))
80}