Skip to main content

audio_engine_core/processor/
spectrum.rs

1//! FFT-based spectrum analyzer for visualization
2
3use rustfft::{num_complex::Complex, FftPlanner};
4use std::sync::Arc;
5
6/// FFT-based spectrum analyzer for visualization
7pub struct SpectrumAnalyzer {
8    fft_size: usize,
9    fft: Arc<dyn rustfft::Fft<f64>>,
10    window: Vec<f64>,
11    num_bins: usize,
12    fft_buffer: Vec<Complex<f64>>,
13    magnitudes: Vec<f64>,
14    result: Vec<f32>,
15    bin_ranges: Vec<(usize, usize)>,
16    bin_sample_rate: Option<u32>,
17}
18
19impl SpectrumAnalyzer {
20    pub fn new(fft_size: usize, num_bins: usize) -> Self {
21        let mut planner = FftPlanner::new();
22        let fft = planner.plan_fft_forward(fft_size);
23        let window: Vec<f64> = (0..fft_size)
24            .map(|i| 0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / fft_size as f64).cos()))
25            .collect();
26
27        Self {
28            fft_size,
29            fft,
30            window,
31            num_bins,
32            fft_buffer: vec![Complex::new(0.0, 0.0); fft_size],
33            magnitudes: vec![0.0; fft_size.saturating_div(2).saturating_sub(1)],
34            result: vec![0.0; num_bins],
35            bin_ranges: Vec::with_capacity(num_bins),
36            bin_sample_rate: None,
37        }
38    }
39
40    pub fn analyze(&mut self, samples: &[f64], sample_rate: u32) -> &[f32] {
41        if samples.len() < self.fft_size {
42            self.result.fill(0.0);
43            return &self.result;
44        }
45
46        for ((slot, &sample), &window) in self
47            .fft_buffer
48            .iter_mut()
49            .zip(samples.iter().take(self.fft_size))
50            .zip(&self.window)
51        {
52            *slot = Complex::new(sample * window, 0.0);
53        }
54
55        self.fft.process(&mut self.fft_buffer);
56
57        for (dst, c) in self
58            .magnitudes
59            .iter_mut()
60            .zip(self.fft_buffer[1..self.fft_size / 2].iter())
61        {
62            *dst = c.norm() / self.fft_size as f64;
63        }
64
65        self.ensure_bin_ranges(sample_rate);
66        self.log_bin();
67        &self.result
68    }
69
70    fn ensure_bin_ranges(&mut self, sample_rate: u32) {
71        if self.bin_sample_rate == Some(sample_rate) && self.bin_ranges.len() == self.num_bins {
72            return;
73        }
74
75        let nyquist = sample_rate as f64 / 2.0;
76        let min_freq = 20.0f64;
77        let max_freq = nyquist;
78        let log_min = min_freq.log10();
79        let log_max = max_freq.log10();
80        let freq_per_bin = nyquist / self.magnitudes.len().max(1) as f64;
81
82        self.bin_ranges.clear();
83        for bin_idx in 0..self.num_bins {
84            let freq_low = 10.0_f64
85                .powf(log_min + (log_max - log_min) * bin_idx as f64 / self.num_bins as f64);
86            let freq_high = 10.0_f64
87                .powf(log_min + (log_max - log_min) * (bin_idx + 1) as f64 / self.num_bins as f64);
88            let idx_low = ((freq_low / freq_per_bin) as usize)
89                .clamp(0, self.magnitudes.len().saturating_sub(1));
90            let idx_high =
91                ((freq_high / freq_per_bin) as usize).clamp(idx_low + 1, self.magnitudes.len());
92            self.bin_ranges.push((idx_low, idx_high));
93        }
94        self.bin_sample_rate = Some(sample_rate);
95    }
96
97    fn log_bin(&mut self) {
98        self.result.fill(0.0);
99        for (result_val, &(idx_low, idx_high)) in self.result.iter_mut().zip(&self.bin_ranges) {
100            if idx_high > idx_low {
101                let sum: f64 = self.magnitudes[idx_low..idx_high]
102                    .iter()
103                    .map(|m| m * m)
104                    .sum();
105                let rms = (sum / (idx_high - idx_low) as f64).sqrt();
106                let db = 20.0 * (rms + 1e-9).log10();
107                *result_val = ((db + 90.0) / 90.0).clamp(0.0, 1.0) as f32;
108            }
109        }
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use rustfft::FftPlanner;

    #[test]
    fn short_input_returns_reused_zero_bins() {
        let mut analyzer = SpectrumAnalyzer::new(16, 4);
        let first_ptr = analyzer.analyze(&[0.0; 8], 48_000).as_ptr();
        assert_eq!(analyzer.analyze(&[0.0; 8], 48_000), &[0.0; 4]);
        assert_eq!(analyzer.analyze(&[0.0; 8], 48_000).as_ptr(), first_ptr);
    }

    #[test]
    fn analyze_reuses_result_and_recomputes_ranges_on_sample_rate_change() {
        let mut analyzer = SpectrumAnalyzer::new(64, 8);
        let samples: Vec<f64> = (0..64).map(|i| (i as f64 * 0.1).sin()).collect();

        let first_ptr = analyzer.analyze(&samples, 48_000).as_ptr();
        let first_ranges = analyzer.bin_ranges.clone();
        assert!(analyzer.analyze(&samples, 48_000).iter().any(|&v| v > 0.0));
        assert_eq!(analyzer.analyze(&samples, 48_000).as_ptr(), first_ptr);
        assert_eq!(analyzer.bin_ranges, first_ranges);

        analyzer.analyze(&samples, 96_000);
        assert_ne!(analyzer.bin_ranges, first_ranges);
    }

    #[test]
    fn analyzer_output_matches_legacy_allocation_path() {
        let mut analyzer = SpectrumAnalyzer::new(128, 16);
        let samples: Vec<f64> = (0..128)
            .map(|i| {
                let t = i as f64 / 48_000.0;
                (2.0 * std::f64::consts::PI * 997.0 * t).sin() * 0.4
            })
            .collect();

        let actual = analyzer.analyze(&samples, 48_000).to_vec();
        let expected = legacy_analyze(&samples, 128, 16, 48_000);

        // `zip` below stops at the shorter side, so a missing or extra band would
        // otherwise go unnoticed.
        assert_eq!(actual.len(), expected.len(), "band count mismatch");
        for (idx, (actual, expected)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (actual - expected).abs() <= 1e-6,
                "bin {idx}: actual={actual}, expected={expected}"
            );
        }
    }

    /// Reference analysis that builds a fresh FFT plan, window, and buffers on
    /// every call; `SpectrumAnalyzer::analyze` is expected to reproduce its output.
    fn legacy_analyze(
        samples: &[f64],
        fft_size: usize,
        num_bins: usize,
        sample_rate: u32,
    ) -> Vec<f32> {
        if samples.len() < fft_size {
            return vec![0.0; num_bins];
        }

        let mut planner = FftPlanner::new();
        let fft = planner.plan_fft_forward(fft_size);
        let window: Vec<f64> = (0..fft_size)
            .map(|i| 0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / fft_size as f64).cos()))
            .collect();
        let mut buffer: Vec<Complex<f64>> = samples[..fft_size]
            .iter()
            .zip(&window)
            .map(|(&s, &w)| Complex::new(s * w, 0.0))
            .collect();

        fft.process(&mut buffer);
        let magnitudes: Vec<f64> = buffer[1..fft_size / 2]
            .iter()
            .map(|c| c.norm() / fft_size as f64)
            .collect();
        legacy_log_bin(&magnitudes, sample_rate, num_bins)
    }

    /// Reference log-band mapping: recomputes each band's magnitude range on every
    /// call and converts the band RMS to a normalized dB level.
    fn legacy_log_bin(magnitudes: &[f64], sample_rate: u32, num_bins: usize) -> Vec<f32> {
        let mut result = vec![0.0f32; num_bins];
        let nyquist = sample_rate as f64 / 2.0;
        let min_freq = 20.0f64;
        let max_freq = nyquist;
        let log_min = min_freq.log10();
        let log_max = max_freq.log10();

        for (bin_idx, result_val) in result.iter_mut().enumerate() {
            let freq_low =
                10.0_f64.powf(log_min + (log_max - log_min) * bin_idx as f64 / num_bins as f64);
            let freq_high = 10.0_f64
                .powf(log_min + (log_max - log_min) * (bin_idx + 1) as f64 / num_bins as f64);
            let freq_per_bin = nyquist / magnitudes.len() as f64;
            let idx_low =
                ((freq_low / freq_per_bin) as usize).clamp(0, magnitudes.len().saturating_sub(1));
            let idx_high =
                ((freq_high / freq_per_bin) as usize).clamp(idx_low + 1, magnitudes.len());

            if idx_high > idx_low {
                let sum: f64 = magnitudes[idx_low..idx_high].iter().map(|m| m * m).sum();
                let rms = (sum / (idx_high - idx_low) as f64).sqrt();
                let db = 20.0 * (rms + 1e-9).log10();
                *result_val = ((db + 90.0) / 90.0).clamp(0.0, 1.0) as f32;
            }
        }

        result
    }
}