Skip to main content

chromaprint/fft/
mod.rs

1pub mod window;
2
3use realfft::RealFftPlanner;
4use realfft::RealToComplex;
5use std::sync::Arc;
6
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// Thin wrapper over [`usize::next_power_of_two`]; an input of `0` yields `1`.
fn next_power_of_two(n: usize) -> usize {
    usize::next_power_of_two(n)
}
11
/// FFT processor that handles frame slicing, windowing, and FFT computation.
///
/// Incoming `i16` samples are buffered in a ring buffer. Whenever a full frame
/// of `frame_size` samples is available it is multiplied by the window,
/// transformed with a forward real FFT, and reduced to a power spectrum
/// (`re^2 + im^2` per bin).
pub struct FftProcessor {
    // Number of samples in one FFT frame.
    frame_size: usize,
    // Number of samples the read position advances after each frame
    // (`frame_size - frame_overlap`).
    increment: usize,
    // Window coefficients from `window::hamming(frame_size)`, multiplied
    // element-wise with each frame before the FFT.
    window: Vec<f32>,
    // Ring buffer for incoming samples (power-of-2 sized for fast masking)
    ring_buffer: Vec<i16>,
    // `ring_buffer.len() - 1`; AND-ing an index with it wraps the index.
    ring_mask: usize,
    // Index at which the next incoming sample is stored.
    ring_write: usize,
    // Index of the first sample of the next frame to process.
    ring_read: usize,
    // Number of buffered samples that have not yet been skipped over.
    ring_count: usize,
    // FFT state
    // Planned forward real-to-complex transform of length `frame_size`.
    fft: Arc<dyn RealToComplex<f32>>,
    // Windowed frame handed to the FFT; refilled for every frame.
    fft_input: Vec<f32>,
    // Scratch space sized by the FFT plan (`make_scratch_vec`).
    fft_scratch: Vec<realfft::num_complex::Complex<f32>>,
    // Complex FFT result sized by the FFT plan (`make_output_vec`).
    fft_output: Vec<realfft::num_complex::Complex<f32>>,
    // Power spectrum output
    // `frame_size / 2 + 1` values, overwritten for every processed frame.
    power_spectrum: Vec<f32>,
}
31
32impl FftProcessor {
33    pub fn new(frame_size: usize, frame_overlap: usize) -> Self {
34        let increment = frame_size - frame_overlap;
35        let window = window::hamming(frame_size);
36
37        let mut planner = RealFftPlanner::<f32>::new();
38        let fft = planner.plan_fft_forward(frame_size);
39
40        let fft_input = vec![0.0f32; frame_size];
41        let fft_output = fft.make_output_vec();
42        let fft_scratch = fft.make_scratch_vec();
43        let power_spectrum = vec![0.0f32; frame_size / 2 + 1];
44
45        // Power-of-2 ring buffer for bitmask indexing
46        let ring_size = next_power_of_two(frame_size * 4);
47
48        Self {
49            frame_size,
50            increment,
51            window,
52            ring_buffer: vec![0i16; ring_size],
53            ring_mask: ring_size - 1,
54            ring_write: 0,
55            ring_read: 0,
56            ring_count: 0,
57            fft,
58            fft_input,
59            fft_scratch,
60            fft_output,
61            power_spectrum,
62        }
63    }
64
65    pub fn reset(&mut self) {
66        self.ring_write = 0;
67        self.ring_read = 0;
68        self.ring_count = 0;
69    }
70
71    /// Feed samples into the FFT processor.
72    /// Calls `callback` for each complete FFT frame's power spectrum.
73    pub fn consume<F>(&mut self, samples: &[i16], mut callback: F)
74    where
75        F: FnMut(&[f32]),
76    {
77        let ring_mask = self.ring_mask;
78        let frame_size = self.frame_size;
79        let increment = self.increment;
80        let mut pos = 0;
81
82        while pos < samples.len() {
83            // Bulk-copy as many samples as possible into the ring buffer
84            let space_before_wrap = (ring_mask + 1) - self.ring_write;
85            let space_before_full = (ring_mask + 1) - self.ring_count;
86            let can_write = space_before_wrap
87                .min(space_before_full)
88                .min(samples.len() - pos);
89            // Limit to what gets us to exactly one frame, so we don't overshoot
90            let need_for_frame = frame_size.saturating_sub(self.ring_count);
91            let to_copy = can_write.min(need_for_frame.max(1));
92
93            self.ring_buffer[self.ring_write..self.ring_write + to_copy]
94                .copy_from_slice(&samples[pos..pos + to_copy]);
95            self.ring_write = (self.ring_write + to_copy) & ring_mask;
96            self.ring_count += to_copy;
97            pos += to_copy;
98
99            // Process frames as soon as we have enough data
100            while self.ring_count >= frame_size {
101                self.process_frame();
102                callback(&self.power_spectrum);
103                self.ring_read = (self.ring_read + increment) & ring_mask;
104                self.ring_count -= increment;
105            }
106        }
107    }
108
109    fn process_frame(&mut self) {
110        let ring_mask = self.ring_mask;
111        let ring_read = self.ring_read;
112
113        // Check if the frame wraps around the ring buffer
114        let end = ring_read + self.frame_size;
115        if end <= ring_mask + 1 {
116            // Contiguous: no wrapping needed
117            for i in 0..self.frame_size {
118                self.fft_input[i] = self.ring_buffer[ring_read + i] as f32 * self.window[i];
119            }
120        } else {
121            // Wrapping: apply mask per element
122            for i in 0..self.frame_size {
123                let ring_idx = (ring_read + i) & ring_mask;
124                self.fft_input[i] = self.ring_buffer[ring_idx] as f32 * self.window[i];
125            }
126        }
127
128        // Compute real FFT
129        self.fft
130            .process_with_scratch(&mut self.fft_input, &mut self.fft_output, &mut self.fft_scratch)
131            .expect("FFT computation failed");
132
133        // Compute power spectrum: |X[k]|^2 = re^2 + im^2
134        for (i, c) in self.fft_output.iter().enumerate() {
135            self.power_spectrum[i] = c.re * c.re + c.im * c.im;
136        }
137    }
138}