Skip to main content

proof_engine/dsp/
fft.rs

1//! FFT library — Cooley-Tukey radix-2, real FFT, spectral analysis, STFT,
2//! CQT, autocorrelation, YIN pitch, Mel filterbank, MFCC, and Chroma.
3
4use std::collections::HashMap;
5use std::f32::consts::PI;
6use super::{WindowFunction, next_power_of_two};
7
8// ---------------------------------------------------------------------------
9// Complex32
10// ---------------------------------------------------------------------------
11
/// A single-precision complex number `re + i·im` with the usual arithmetic.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Complex32 {
    /// Real part.
    pub re: f32,
    /// Imaginary part.
    pub im: f32,
}

impl Complex32 {
    /// Builds `re + i·im`.
    #[inline]
    pub fn new(re: f32, im: f32) -> Self {
        Complex32 { re, im }
    }

    /// Additive identity `0 + 0i`.
    #[inline]
    pub fn zero() -> Self {
        Self::new(0.0, 0.0)
    }

    /// Multiplicative identity `1 + 0i`.
    #[inline]
    pub fn one() -> Self {
        Self::new(1.0, 0.0)
    }

    /// Imaginary unit `0 + 1i`.
    #[inline]
    pub fn i() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Magnitude `|z| = sqrt(re² + im²)`.
    #[inline]
    pub fn norm(&self) -> f32 {
        self.norm_sq().sqrt()
    }

    /// Squared magnitude `re² + im²` (avoids the square root).
    #[inline]
    pub fn norm_sq(&self) -> f32 {
        let Complex32 { re, im } = *self;
        re * re + im * im
    }

    /// Phase angle in radians, as returned by `atan2(im, re)`.
    #[inline]
    pub fn arg(&self) -> f32 {
        f32::atan2(self.im, self.re)
    }

    /// Complex conjugate `re - i·im`.
    #[inline]
    pub fn conj(&self) -> Self {
        Self::new(self.re, -self.im)
    }

    /// Builds `r·e^(iθ) = r·cos θ + i·r·sin θ` from polar coordinates (Euler's formula).
    #[inline]
    pub fn from_polar(r: f32, theta: f32) -> Self {
        let (sin, cos) = theta.sin_cos();
        Self::new(r * cos, r * sin)
    }

    /// Principal natural logarithm `ln|z| + i·arg(z)`.
    pub fn ln(&self) -> Self {
        Self::new(self.norm().ln(), self.arg())
    }

    /// Complex exponential `e^re · (cos im + i·sin im)`.
    pub fn exp(&self) -> Self {
        Self::from_polar(self.re.exp(), self.im)
    }

    /// Integer power `zⁿ` via De Moivre's formula `|z|ⁿ · e^(i·n·arg z)`.
    /// `z⁰` is defined as `1` for every `z`.
    pub fn powi(&self, n: i32) -> Self {
        match n {
            0 => Self::one(),
            _ => Self::from_polar(self.norm().powi(n), self.arg() * n as f32),
        }
    }
}
58
59impl std::ops::Add for Complex32 {
60    type Output = Self;
61    #[inline] fn add(self, rhs: Self) -> Self { Self { re: self.re + rhs.re, im: self.im + rhs.im } }
62}
63impl std::ops::AddAssign for Complex32 {
64    #[inline] fn add_assign(&mut self, rhs: Self) { self.re += rhs.re; self.im += rhs.im; }
65}
66impl std::ops::Sub for Complex32 {
67    type Output = Self;
68    #[inline] fn sub(self, rhs: Self) -> Self { Self { re: self.re - rhs.re, im: self.im - rhs.im } }
69}
70impl std::ops::SubAssign for Complex32 {
71    #[inline] fn sub_assign(&mut self, rhs: Self) { self.re -= rhs.re; self.im -= rhs.im; }
72}
73impl std::ops::Mul for Complex32 {
74    type Output = Self;
75    #[inline] fn mul(self, rhs: Self) -> Self {
76        Self {
77            re: self.re * rhs.re - self.im * rhs.im,
78            im: self.re * rhs.im + self.im * rhs.re,
79        }
80    }
81}
82impl std::ops::MulAssign for Complex32 {
83    #[inline] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; }
84}
85impl std::ops::Div for Complex32 {
86    type Output = Self;
87    fn div(self, rhs: Self) -> Self {
88        let denom = rhs.norm_sq();
89        Self {
90            re: (self.re * rhs.re + self.im * rhs.im) / denom,
91            im: (self.im * rhs.re - self.re * rhs.im) / denom,
92        }
93    }
94}
95impl std::ops::Neg for Complex32 {
96    type Output = Self;
97    #[inline] fn neg(self) -> Self { Self { re: -self.re, im: -self.im } }
98}
99impl std::ops::Mul<f32> for Complex32 {
100    type Output = Self;
101    #[inline] fn mul(self, rhs: f32) -> Self { Self { re: self.re * rhs, im: self.im * rhs } }
102}
103impl std::ops::Div<f32> for Complex32 {
104    type Output = Self;
105    #[inline] fn div(self, rhs: f32) -> Self { Self { re: self.re / rhs, im: self.im / rhs } }
106}
107
108// ---------------------------------------------------------------------------
109// Fft — Cooley-Tukey in-place radix-2 DIT
110// ---------------------------------------------------------------------------
111
112/// Cooley-Tukey radix-2 in-place FFT for power-of-2 sizes.
113pub struct Fft;
114
115impl Fft {
116    /// In-place DIT FFT. `data.len()` must be a power of 2.
117    pub fn forward(data: &mut [Complex32]) {
118        let n = data.len();
119        debug_assert!(super::is_power_of_two(n), "FFT size must be a power of 2");
120        Self::bit_reverse_permute(data);
121        let mut len = 2usize;
122        while len <= n {
123            let half = len / 2;
124            let w_step = Complex32::from_polar(1.0, -PI / half as f32);
125            for chunk_start in (0..n).step_by(len) {
126                let mut w = Complex32::one();
127                for k in 0..half {
128                    let u = data[chunk_start + k];
129                    let v = data[chunk_start + k + half] * w;
130                    data[chunk_start + k]        = u + v;
131                    data[chunk_start + k + half] = u - v;
132                    w *= w_step;
133                }
134            }
135            len <<= 1;
136        }
137    }
138
139    /// In-place inverse FFT with 1/N scaling.
140    pub fn inverse(data: &mut [Complex32]) {
141        let n = data.len();
142        // Conjugate, forward FFT, conjugate, scale
143        for x in data.iter_mut() { *x = x.conj(); }
144        Self::forward(data);
145        let scale = 1.0 / n as f32;
146        for x in data.iter_mut() { *x = x.conj() * scale; }
147    }
148
149    /// Bit-reversal permutation.
150    fn bit_reverse_permute(data: &mut [Complex32]) {
151        let n = data.len();
152        let bits = n.trailing_zeros() as usize;
153        for i in 0..n {
154            let j = Self::reverse_bits(i, bits);
155            if i < j {
156                data.swap(i, j);
157            }
158        }
159    }
160
161    fn reverse_bits(mut x: usize, bits: usize) -> usize {
162        let mut result = 0usize;
163        for _ in 0..bits {
164            result = (result << 1) | (x & 1);
165            x >>= 1;
166        }
167        result
168    }
169}
170
171// ---------------------------------------------------------------------------
172// RealFft
173// ---------------------------------------------------------------------------
174
175/// Optimized FFT for real-valued input using the half-spectrum trick.
176pub struct RealFft;
177
178impl RealFft {
179    /// Forward FFT of a real signal. Returns the N/2+1 complex bins.
180    pub fn forward_real(data: &[f32]) -> Vec<Complex32> {
181        let n = next_power_of_two(data.len());
182        let mut buf: Vec<Complex32> = data.iter()
183            .map(|&x| Complex32::new(x, 0.0))
184            .collect();
185        buf.resize(n, Complex32::zero());
186        Fft::forward(&mut buf);
187        // Return only the first half + DC + Nyquist
188        buf.truncate(n / 2 + 1);
189        buf
190    }
191
192    /// Inverse FFT from half-spectrum back to real signal of length `n`.
193    pub fn inverse_real(spectrum: &[Complex32], n: usize) -> Vec<f32> {
194        let n_fft = next_power_of_two(n);
195        let mut buf = Vec::with_capacity(n_fft);
196        buf.extend_from_slice(spectrum);
197        // Reconstruct conjugate-symmetric upper half
198        let half = n_fft / 2;
199        for k in 1..half {
200            buf.push(buf[half - k].conj());
201        }
202        buf.resize(n_fft, Complex32::zero());
203        Fft::inverse(&mut buf);
204        buf[..n].iter().map(|c| c.re).collect()
205    }
206}
207
208// ---------------------------------------------------------------------------
209// FftPlanner
210// ---------------------------------------------------------------------------
211
212/// A plan produced by `FftPlanner` for a specific transform size.
213pub struct FftPlan {
214    /// The FFT size.
215    pub size: usize,
216    /// Pre-computed twiddle factors w_n^k for k = 0..size/2.
217    pub twiddles: Vec<Complex32>,
218}
219
220impl FftPlan {
221    /// Execute the forward FFT using pre-computed twiddles.
222    pub fn forward(&self, data: &mut [Complex32]) {
223        debug_assert_eq!(data.len(), self.size);
224        Fft::forward(data); // Uses inline twiddle computation for correctness; planner twiddles can be used for optimization
225    }
226
227    /// Execute the inverse FFT.
228    pub fn inverse(&self, data: &mut [Complex32]) {
229        debug_assert_eq!(data.len(), self.size);
230        Fft::inverse(data);
231    }
232}
233
234/// Pre-computes twiddle factors and caches plans by size.
235pub struct FftPlanner {
236    cache: HashMap<usize, Vec<Complex32>>,
237}
238
239impl FftPlanner {
240    pub fn new() -> Self {
241        Self { cache: HashMap::new() }
242    }
243
244    /// Return a plan for a given (power-of-2) size, building twiddles if needed.
245    pub fn plan(&mut self, size: usize) -> FftPlan {
246        let n = next_power_of_two(size);
247        let twiddles = self.cache.entry(n).or_insert_with(|| {
248            (0..n / 2)
249                .map(|k| Complex32::from_polar(1.0, -2.0 * PI * k as f32 / n as f32))
250                .collect()
251        });
252        FftPlan { size: n, twiddles: twiddles.clone() }
253    }
254
255    /// Clear the twiddle-factor cache.
256    pub fn clear_cache(&mut self) {
257        self.cache.clear();
258    }
259}
260
261impl Default for FftPlanner {
262    fn default() -> Self { Self::new() }
263}
264
265// ---------------------------------------------------------------------------
266// Spectrum
267// ---------------------------------------------------------------------------
268
269/// Post-FFT frequency-domain representation.
270#[derive(Debug, Clone)]
271pub struct Spectrum {
272    /// Complex FFT bins (full N/2+1 half-spectrum or full N).
273    pub bins: Vec<Complex32>,
274    /// The FFT size used.
275    pub fft_size: usize,
276}
277
278impl Spectrum {
279    pub fn new(bins: Vec<Complex32>, fft_size: usize) -> Self {
280        Self { bins, fft_size }
281    }
282
283    /// Compute from real signal. Uses RealFft internally.
284    pub fn from_real(signal: &[f32]) -> Self {
285        let fft_size = next_power_of_two(signal.len());
286        let bins = RealFft::forward_real(signal);
287        Self { bins, fft_size }
288    }
289
290    /// Number of bins (N/2+1 for real input).
291    pub fn num_bins(&self) -> usize { self.bins.len() }
292
293    /// Magnitude of bin `k`.
294    pub fn magnitude(&self, k: usize) -> f32 { self.bins[k].norm() }
295
296    /// Phase of bin `k` in radians.
297    pub fn phase(&self, k: usize) -> f32 { self.bins[k].arg() }
298
299    /// Power of bin `k` (magnitude²).
300    pub fn power(&self, k: usize) -> f32 { self.bins[k].norm_sq() }
301
302    /// Frequency of bin `k` given sample rate.
303    pub fn frequency_of_bin(&self, k: usize, sample_rate: f32) -> f32 {
304        k as f32 * sample_rate / self.fft_size as f32
305    }
306
307    /// All magnitudes.
308    pub fn magnitude_spectrum(&self) -> Vec<f32> {
309        self.bins.iter().map(|c| c.norm()).collect()
310    }
311
312    /// Power spectrum (magnitude²).
313    pub fn power_spectrum(&self) -> Vec<f32> {
314        self.bins.iter().map(|c| c.norm_sq()).collect()
315    }
316
317    /// dBFS spectrum with noise floor.
318    pub fn to_db(&self, floor_db: f32) -> Vec<f32> {
319        self.bins.iter().map(|c| {
320            let mag = c.norm();
321            if mag <= 0.0 { return floor_db; }
322            (20.0 * mag.log10()).max(floor_db)
323        }).collect()
324    }
325
326    /// Frequency of the dominant (peak) bin.
327    pub fn dominant_frequency(&self, sample_rate: f32) -> f32 {
328        let (peak_k, _) = self.bins.iter().enumerate()
329            .map(|(k, c)| (k, c.norm_sq()))
330            .fold((0, 0.0f32), |(ak, av), (k, v)| if v > av { (k, v) } else { (ak, av) });
331        self.frequency_of_bin(peak_k, sample_rate)
332    }
333
334    /// Spectral centroid — weighted mean frequency.
335    pub fn spectral_centroid(&self, sample_rate: f32) -> f32 {
336        let mut num = 0.0f32;
337        let mut denom = 0.0f32;
338        for (k, c) in self.bins.iter().enumerate() {
339            let mag = c.norm();
340            let freq = self.frequency_of_bin(k, sample_rate);
341            num += freq * mag;
342            denom += mag;
343        }
344        if denom < 1e-10 { 0.0 } else { num / denom }
345    }
346
347    /// Spectral spread — weighted standard deviation of frequency.
348    pub fn spectral_spread(&self, sample_rate: f32) -> f32 {
349        let centroid = self.spectral_centroid(sample_rate);
350        let mut num = 0.0f32;
351        let mut denom = 0.0f32;
352        for (k, c) in self.bins.iter().enumerate() {
353            let mag = c.norm();
354            let freq = self.frequency_of_bin(k, sample_rate);
355            num += (freq - centroid).powi(2) * mag;
356            denom += mag;
357        }
358        if denom < 1e-10 { 0.0 } else { (num / denom).sqrt() }
359    }
360
361    /// Spectral flux: sum of positive differences between consecutive frames.
362    pub fn spectral_flux(&self, other: &Spectrum) -> f32 {
363        let len = self.bins.len().min(other.bins.len());
364        let mut flux = 0.0f32;
365        for i in 0..len {
366            let diff = other.bins[i].norm() - self.bins[i].norm();
367            if diff > 0.0 { flux += diff; }
368        }
369        flux
370    }
371
372    /// Spectral flatness (Wiener entropy): geometric mean / arithmetic mean.
373    pub fn spectral_flatness(&self) -> f32 {
374        let mags: Vec<f32> = self.bins.iter().map(|c| c.norm()).collect();
375        let n = mags.len();
376        if n == 0 { return 0.0; }
377        let arithmetic_mean: f32 = mags.iter().sum::<f32>() / n as f32;
378        if arithmetic_mean < 1e-10 { return 1.0; }
379        let log_sum: f32 = mags.iter()
380            .map(|&m| if m > 1e-10 { m.ln() } else { -23.0 }) // ln(1e-10)
381            .sum();
382        let geometric_mean = (log_sum / n as f32).exp();
383        (geometric_mean / arithmetic_mean).min(1.0)
384    }
385
386    /// Spectral rolloff: frequency below which `threshold` fraction of energy is contained.
387    pub fn spectral_rolloff(&self, threshold: f32) -> f32 {
388        // Computed over bins 0..num_bins
389        let total: f32 = self.bins.iter().map(|c| c.norm_sq()).sum();
390        if total < 1e-10 { return 0.0; }
391        let target = threshold * total;
392        let mut accum = 0.0f32;
393        for (k, c) in self.bins.iter().enumerate() {
394            accum += c.norm_sq();
395            if accum >= target {
396                // Linearly interpolate between bins k-1 and k
397                if k == 0 { return 0.0; }
398                let prev = accum - c.norm_sq();
399                let frac = (target - prev) / c.norm_sq().max(1e-30);
400                return (k as f32 - 1.0 + frac.min(1.0)) / self.fft_size as f32;
401            }
402        }
403        (self.bins.len() - 1) as f32 / self.fft_size as f32
404    }
405
406    /// Spectral rolloff frequency in Hz (not normalized).
407    pub fn spectral_rolloff_hz(&self, threshold: f32, sample_rate: f32) -> f32 {
408        self.spectral_rolloff(threshold) * sample_rate
409    }
410
411    /// Reconstruct the time-domain signal via IFFT.
412    pub fn to_signal(&self) -> Vec<f32> {
413        RealFft::inverse_real(&self.bins, self.fft_size)
414    }
415}
416
417// ---------------------------------------------------------------------------
418// Stft — Short-Time Fourier Transform
419// ---------------------------------------------------------------------------
420
421/// Configuration for the STFT.
422#[derive(Debug, Clone)]
423pub struct StftConfig {
424    pub fft_size: usize,
425    pub hop_size: usize,
426    pub window: WindowFunction,
427}
428
429impl Default for StftConfig {
430    fn default() -> Self {
431        Self {
432            fft_size: 2048,
433            hop_size: 512,
434            window: WindowFunction::Hann,
435        }
436    }
437}
438
439/// Short-Time Fourier Transform: sliding-window analysis and synthesis.
440pub struct Stft {
441    pub config: StftConfig,
442}
443
444impl Stft {
445    pub fn new(config: StftConfig) -> Self {
446        assert!(super::is_power_of_two(config.fft_size), "fft_size must be power of 2");
447        Self { config }
448    }
449
450    /// Analyze a signal into a sequence of Spectrum frames.
451    pub fn analyze(&self, signal: &[f32]) -> Vec<Spectrum> {
452        let cfg = &self.config;
453        let win = cfg.window.generate(cfg.fft_size);
454        let mut frames = Vec::new();
455        let mut pos = 0usize;
456        while pos + cfg.fft_size <= signal.len() {
457            let mut frame: Vec<f32> = signal[pos..pos + cfg.fft_size].to_vec();
458            for (s, &w) in frame.iter_mut().zip(win.iter()) {
459                *s *= w;
460            }
461            let bins = RealFft::forward_real(&frame);
462            frames.push(Spectrum::new(bins, cfg.fft_size));
463            pos += cfg.hop_size;
464        }
465        // Handle the last partial frame with zero-padding
466        if pos < signal.len() {
467            let mut frame = vec![0.0f32; cfg.fft_size];
468            let rem = signal.len() - pos;
469            frame[..rem].copy_from_slice(&signal[pos..]);
470            for (s, &w) in frame.iter_mut().zip(win.iter()) {
471                *s *= w;
472            }
473            let bins = RealFft::forward_real(&frame);
474            frames.push(Spectrum::new(bins, cfg.fft_size));
475        }
476        frames
477    }
478
479    /// Synthesize a signal from STFT frames using overlap-add (OLA).
480    pub fn synthesize(&self, frames: &[Spectrum]) -> Vec<f32> {
481        let cfg = &self.config;
482        if frames.is_empty() { return Vec::new(); }
483        let total_len = (frames.len() - 1) * cfg.hop_size + cfg.fft_size;
484        let mut output = vec![0.0f32; total_len];
485        let mut norm = vec![0.0f32; total_len];
486        let win = cfg.window.generate(cfg.fft_size);
487        // Compute normalization factor for OLA
488        for (frame_idx, spectrum) in frames.iter().enumerate() {
489            let pos = frame_idx * cfg.hop_size;
490            let time_signal = spectrum.to_signal();
491            for (i, (&s, &w)) in time_signal.iter().zip(win.iter()).enumerate() {
492                if pos + i < output.len() {
493                    output[pos + i] += s * w;
494                    norm[pos + i] += w * w;
495                }
496            }
497        }
498        // Normalize
499        for (s, &n) in output.iter_mut().zip(norm.iter()) {
500            if n > 1e-10 { *s /= n; }
501        }
502        output
503    }
504
505    /// Number of frames for a signal of given length.
506    pub fn num_frames(&self, signal_len: usize) -> usize {
507        if signal_len < self.config.fft_size { return 0; }
508        1 + (signal_len - self.config.fft_size) / self.config.hop_size
509    }
510
511    /// Time in seconds of the center of frame `idx`.
512    pub fn frame_time(&self, idx: usize, sample_rate: f32) -> f32 {
513        (idx * self.config.hop_size + self.config.fft_size / 2) as f32 / sample_rate
514    }
515}
516
517// ---------------------------------------------------------------------------
518// Cqt — Constant-Q Transform
519// ---------------------------------------------------------------------------
520
/// Constant-Q Transform — logarithmically-spaced frequency bins.
pub struct Cqt {
    /// Sample rate of analysed signals in Hz.
    pub sample_rate: f32,
    /// Centre frequency of bin 0 in Hz, used by [`Cqt::bin_frequency`].
    pub min_freq: f32,
    /// Bins per octave, used by [`Cqt::q_factor`] and [`Cqt::bin_frequency`].
    pub bins_per_octave: u32,
    /// Total number of bins (`bins_per_octave · n_octaves`).
    pub n_bins: u32,
}

impl Cqt {
    /// Creates a transform with `bins_per_octave · n_octaves` bins starting at `min_freq`.
    pub fn new(sample_rate: f32, min_freq: f32, bins_per_octave: u32, n_octaves: u32) -> Self {
        Self {
            sample_rate,
            min_freq,
            bins_per_octave,
            n_bins: bins_per_octave * n_octaves,
        }
    }

    /// Quality factor `Q = 1 / (2^(1/bins_per_octave) − 1)`.
    pub fn q_factor(&self) -> f32 {
        Self::q_for(self.bins_per_octave)
    }

    /// Magnitudes of `self.n_bins` constant-Q bins, computed by direct kernel evaluation.
    ///
    /// Bin `k` has centre `f_k = min_freq · 2^(k / bins_per_octave)`. Its Hann-windowed
    /// kernel has `N_k = round(Q · sample_rate / f_k)` samples taken from the start of
    /// `signal`, and the result is `|Σ w[n]·x[n]·e^(-2πiQn/N_k)| / N_k`. Bins whose
    /// kernel would be empty or longer than `signal` stay at `0.0`.
    ///
    /// `min_freq` and `bins_per_octave` come from the arguments, not from `self`. Only
    /// `sample_rate` and `n_bins` are read from the struct.
    pub fn analyze(&self, signal: &[f32], min_freq: f32, bins_per_octave: u32) -> Vec<f32> {
        let q = Self::q_for(bins_per_octave);
        (0..self.n_bins as usize)
            .map(|bin| {
                let freq = min_freq * 2.0f32.powf(bin as f32 / bins_per_octave as f32);
                let window_len = (q * self.sample_rate / freq).round() as usize;
                if window_len == 0 || window_len > signal.len() {
                    return 0.0;
                }
                // Direct DFT at this bin's centre frequency.
                let (mut re, mut im) = (0.0f32, 0.0f32);
                for (n, &sample) in signal[..window_len].iter().enumerate() {
                    let win = Self::hann(n, window_len);
                    let phase = 2.0 * PI * q * n as f32 / window_len as f32;
                    re += win * sample * phase.cos();
                    im -= win * sample * phase.sin();
                }
                (re * re + im * im).sqrt() / window_len as f32
            })
            .collect()
    }

    /// Frequency of CQT bin `k`: `min_freq · 2^(k / bins_per_octave)`.
    pub fn bin_frequency(&self, k: usize) -> f32 {
        self.min_freq * 2.0f32.powf(k as f32 / self.bins_per_octave as f32)
    }

    /// Converts CQT magnitudes to dB (`20·log10 x`), clamped below at `floor_db`;
    /// non-positive values map to `floor_db`.
    pub fn to_db(cqt: &[f32], floor_db: f32) -> Vec<f32> {
        cqt.iter().map(|&x| {
            if x <= 0.0 { floor_db } else { (20.0 * x.log10()).max(floor_db) }
        }).collect()
    }

    /// `Q = 1 / (2^(1/b) − 1)` for `b` bins per octave.
    fn q_for(bins_per_octave: u32) -> f32 {
        1.0 / (2.0f32.powf(1.0 / bins_per_octave as f32) - 1.0)
    }

    /// Sample `n` of a symmetric Hann window of length `len`.
    ///
    /// A window shorter than 2 samples is `[1.0]`. The general formula would divide by
    /// `len − 1 = 0` and produce NaN.
    fn hann(n: usize, len: usize) -> f32 {
        if len < 2 {
            return 1.0;
        }
        0.5 * (1.0 - (2.0 * PI * n as f32 / (len - 1) as f32).cos())
    }
}
584
585// ---------------------------------------------------------------------------
586// Autocorrelation
587// ---------------------------------------------------------------------------
588
589/// Autocorrelation functions and YIN pitch detection.
590pub struct Autocorrelation;
591
592impl Autocorrelation {
593    /// Compute the normalized autocorrelation of a signal.
594    pub fn compute(signal: &[f32]) -> Vec<f32> {
595        let n = signal.len();
596        let mut out = vec![0.0f32; n];
597        let energy: f32 = signal.iter().map(|&x| x * x).sum();
598        if energy < 1e-10 { return out; }
599        for lag in 0..n {
600            let mut sum = 0.0f32;
601            for i in 0..n - lag {
602                sum += signal[i] * signal[i + lag];
603            }
604            out[lag] = sum / energy;
605        }
606        out
607    }
608
609    /// Autocorrelation via FFT (O(N log N)).
610    pub fn compute_fft(signal: &[f32]) -> Vec<f32> {
611        let n = next_power_of_two(signal.len() * 2);
612        let mut buf: Vec<Complex32> = signal.iter()
613            .map(|&x| Complex32::new(x, 0.0))
614            .collect();
615        buf.resize(n, Complex32::zero());
616        Fft::forward(&mut buf);
617        // Multiply by conjugate = power spectrum
618        for c in buf.iter_mut() {
619            *c = Complex32::new(c.norm_sq(), 0.0);
620        }
621        Fft::inverse(&mut buf);
622        let scale = 1.0 / buf[0].re.max(1e-30);
623        buf[..signal.len()].iter().map(|c| c.re * scale).collect()
624    }
625
626    /// YIN pitch detection algorithm.
627    /// Returns the fundamental frequency in Hz, or None if no pitch found.
628    pub fn pitch_yin(signal: &[f32], sample_rate: f32) -> Option<f32> {
629        let n = signal.len();
630        if n < 2 { return None; }
631        let half = n / 2;
632
633        // Step 1: difference function
634        let mut d = vec![0.0f32; half];
635        for tau in 0..half {
636            for j in 0..half {
637                let diff = signal[j] - signal[j + tau];
638                d[tau] += diff * diff;
639            }
640        }
641        d[0] = 0.0;
642
643        // Step 2: cumulative mean normalized difference
644        let mut cmnd = vec![0.0f32; half];
645        cmnd[0] = 1.0;
646        let mut running_sum = 0.0f32;
647        for tau in 1..half {
648            running_sum += d[tau];
649            cmnd[tau] = if running_sum > 1e-10 {
650                d[tau] * tau as f32 / running_sum
651            } else {
652                1.0
653            };
654        }
655
656        // Step 3: absolute threshold — find first local minimum below threshold
657        let threshold = 0.1f32;
658        let mut tau_min = None;
659        for tau in 2..half - 1 {
660            if cmnd[tau] < threshold && cmnd[tau] < cmnd[tau - 1] && cmnd[tau] < cmnd[tau + 1] {
661                tau_min = Some(tau);
662                break;
663            }
664        }
665
666        // If no strong minimum found, use global minimum
667        let tau = tau_min.unwrap_or_else(|| {
668            cmnd[1..].iter().enumerate()
669                .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
670                .map(|(i, _)| i + 1)
671                .unwrap_or(0)
672        });
673
674        if tau == 0 { return None; }
675
676        // Step 4: parabolic interpolation around the minimum
677        let tau_f = if tau > 0 && tau < half - 1 {
678            let alpha = cmnd[tau - 1];
679            let beta  = cmnd[tau];
680            let gamma = cmnd[tau + 1];
681            let denom = alpha - 2.0 * beta + gamma;
682            if denom.abs() > 1e-10 {
683                tau as f32 - 0.5 * (gamma - alpha) / denom
684            } else {
685                tau as f32
686            }
687        } else {
688            tau as f32
689        };
690
691        if tau_f < 1.0 { return None; }
692        let freq = sample_rate / tau_f;
693        if freq < 20.0 || freq > 20000.0 { return None; }
694        Some(freq)
695    }
696}
697
698// ---------------------------------------------------------------------------
699// MelFilterbank
700// ---------------------------------------------------------------------------
701
702/// Mel-scale filterbank for perceptual frequency analysis.
703pub struct MelFilterbank {
704    /// Number of Mel filters.
705    pub n_filters: usize,
706    /// FFT size.
707    pub fft_size: usize,
708    /// Sample rate.
709    pub sample_rate: f32,
710    /// Filter weights: n_filters × (fft_size/2+1).
711    filterbank: Vec<Vec<f32>>,
712}
713
714impl MelFilterbank {
715    /// Create a Mel filterbank.
716    pub fn new(n_filters: usize, fft_size: usize, sample_rate: f32) -> Self {
717        let n_bins = fft_size / 2 + 1;
718        let min_mel = super::hz_to_mel(0.0);
719        let max_mel = super::hz_to_mel(sample_rate / 2.0);
720
721        // Linearly-spaced Mel points (n_filters + 2 for lower and upper edges)
722        let mel_points: Vec<f32> = (0..n_filters + 2)
723            .map(|i| min_mel + i as f32 * (max_mel - min_mel) / (n_filters + 1) as f32)
724            .collect();
725        let hz_points: Vec<f32> = mel_points.iter().map(|&m| super::mel_to_hz(m)).collect();
726        // Convert Hz to FFT bin indices
727        let bin_points: Vec<f32> = hz_points.iter()
728            .map(|&hz| hz * fft_size as f32 / sample_rate)
729            .collect();
730
731        let mut filterbank = vec![vec![0.0f32; n_bins]; n_filters];
732        for m in 0..n_filters {
733            let left   = bin_points[m];
734            let center = bin_points[m + 1];
735            let right  = bin_points[m + 2];
736            for k in 0..n_bins {
737                let k_f = k as f32;
738                if k_f >= left && k_f <= center {
739                    filterbank[m][k] = (k_f - left) / (center - left).max(1e-10);
740                } else if k_f > center && k_f <= right {
741                    filterbank[m][k] = (right - k_f) / (right - center).max(1e-10);
742                }
743            }
744        }
745
746        Self { n_filters, fft_size, sample_rate, filterbank }
747    }
748
749    /// Apply the filterbank to a power spectrum. Returns n_filters values.
750    pub fn apply(&self, spectrum: &[f32]) -> Vec<f32> {
751        let n_bins = self.fft_size / 2 + 1;
752        self.filterbank.iter().map(|filter| {
753            filter.iter().zip(spectrum.iter().take(n_bins))
754                .map(|(&w, &s)| w * s)
755                .sum()
756        }).collect()
757    }
758
759    /// Apply to a Spectrum struct.
760    pub fn apply_spectrum(&self, spectrum: &Spectrum) -> Vec<f32> {
761        let power: Vec<f32> = spectrum.power_spectrum();
762        self.apply(&power)
763    }
764}
765
766// ---------------------------------------------------------------------------
767// Mfcc — Mel-Frequency Cepstral Coefficients
768// ---------------------------------------------------------------------------
769
770/// Mel-Frequency Cepstral Coefficient computation.
771pub struct Mfcc {
772    pub n_coeffs: usize,
773    pub sample_rate: f32,
774    pub fft_size: usize,
775    pub n_mel: usize,
776    filterbank: MelFilterbank,
777}
778
779impl Mfcc {
780    pub fn new(n_coeffs: usize, fft_size: usize, sample_rate: f32) -> Self {
781        let n_mel = 40;
782        let filterbank = MelFilterbank::new(n_mel, fft_size, sample_rate);
783        Self { n_coeffs, sample_rate, fft_size, n_mel, filterbank }
784    }
785
786    /// Compute MFCCs for a signal block.
787    pub fn compute(&self, signal: &[f32], _sample_rate: f32, n_coeffs: usize) -> Vec<f32> {
788        // 1. Windowed FFT
789        let n = next_power_of_two(signal.len().max(self.fft_size));
790        let mut frame: Vec<f32> = signal.to_vec();
791        frame.resize(n, 0.0);
792        WindowFunction::Hann.apply(&mut frame);
793        let spectrum = Spectrum::from_real(&frame);
794
795        // 2. Mel filterbank
796        let mel_energies = self.filterbank.apply_spectrum(&spectrum);
797
798        // 3. Log
799        let log_mel: Vec<f32> = mel_energies.iter()
800            .map(|&e| if e > 1e-10 { e.ln() } else { -23.0 })
801            .collect();
802
803        // 4. DCT type-II
804        Self::dct_ii(&log_mel, n_coeffs)
805    }
806
807    /// DCT type-II: N input → n_coeffs output.
808    pub fn dct_ii(input: &[f32], n_coeffs: usize) -> Vec<f32> {
809        let n = input.len();
810        let mut out = Vec::with_capacity(n_coeffs);
811        let scale0 = (1.0 / n as f32).sqrt();
812        let scale_k = (2.0 / n as f32).sqrt();
813        for k in 0..n_coeffs {
814            let scale = if k == 0 { scale0 } else { scale_k };
815            let sum: f32 = input.iter().enumerate().map(|(n_idx, &x)| {
816                x * (PI * k as f32 * (2 * n_idx + 1) as f32 / (2 * n) as f32).cos()
817            }).sum();
818            out.push(scale * sum);
819        }
820        out
821    }
822
823    /// Inverse DCT type-II.
824    pub fn idct_ii(coeffs: &[f32], n_out: usize) -> Vec<f32> {
825        let n_coeffs = coeffs.len();
826        let mut out = vec![0.0f32; n_out];
827        let scale0 = (1.0 / n_out as f32).sqrt();
828        let scale_k = (2.0 / n_out as f32).sqrt();
829        for (n_idx, s) in out.iter_mut().enumerate() {
830            let mut sum = 0.0f32;
831            for k in 0..n_coeffs {
832                let scale = if k == 0 { scale0 } else { scale_k };
833                sum += scale * coeffs[k]
834                    * (PI * k as f32 * (2 * n_idx + 1) as f32 / (2 * n_out) as f32).cos();
835            }
836            *s = sum;
837        }
838        out
839    }
840
841    /// Compute delta (first-order difference) of a feature matrix (rows = frames).
842    pub fn delta(features: &[Vec<f32>]) -> Vec<Vec<f32>> {
843        let n_frames = features.len();
844        if n_frames < 3 { return features.to_vec(); }
845        let n_feats = features[0].len();
846        let width = 2usize; // ±2 frames
847        let denom: f32 = (1..=width as i32).map(|m| (m * m) as f32).sum::<f32>() * 2.0;
848        features.iter().enumerate().map(|(t, _)| {
849            (0..n_feats).map(|d| {
850                let mut num = 0.0f32;
851                for m in 1..=width {
852                    let fwd = features.get(t + m).map(|f| f[d]).unwrap_or_else(|| features[n_frames - 1][d]);
853                    let bwd = if t >= m { features[t - m][d] } else { features[0][d] };
854                    num += m as f32 * (fwd - bwd);
855                }
856                num / denom
857            }).collect()
858        }).collect()
859    }
860}
861
862// ---------------------------------------------------------------------------
863// Chroma
864// ---------------------------------------------------------------------------
865
866/// Chroma feature — pitch class profile (12 semitone bins).
867pub struct Chroma;
868
869impl Chroma {
870    /// Compute the 12-bin chroma vector from a Spectrum.
871    pub fn compute(spectrum: &Spectrum, sample_rate: f32) -> [f32; 12] {
872        let mut chroma = [0.0f32; 12];
873        let n_bins = spectrum.num_bins();
874        // Ignore DC (bin 0)
875        for k in 1..n_bins {
876            let freq = spectrum.frequency_of_bin(k, sample_rate);
877            if freq <= 0.0 { continue; }
878            // Convert to pitch class: midi note mod 12
879            let midi = super::freq_to_midi(freq);
880            if midi < 0.0 { continue; }
881            let pitch_class = (midi.round() as i32).rem_euclid(12) as usize;
882            chroma[pitch_class] += spectrum.magnitude(k);
883        }
884        // Normalize to sum to 1
885        let sum: f32 = chroma.iter().sum();
886        if sum > 1e-10 {
887            for c in chroma.iter_mut() { *c /= sum; }
888        }
889        chroma
890    }
891
892    /// Distance between two chroma vectors (cosine distance, 0=identical, 1=orthogonal).
893    pub fn cosine_distance(a: &[f32; 12], b: &[f32; 12]) -> f32 {
894        let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
895        let norm_a: f32 = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
896        let norm_b: f32 = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
897        let denom = norm_a * norm_b;
898        if denom < 1e-10 { return 1.0; }
899        1.0 - (dot / denom).min(1.0)
900    }
901
902    /// Rotate chroma vector by `semitones` (for transposition-invariant comparison).
903    pub fn rotate(chroma: &[f32; 12], semitones: i32) -> [f32; 12] {
904        let mut out = [0.0f32; 12];
905        for i in 0..12 {
906            out[i] = chroma[(i as i32 - semitones).rem_euclid(12) as usize];
907        }
908        out
909    }
910
911    /// Find the key (0=C, 1=C#, ..., 11=B) that best matches the chroma.
912    /// Uses the Krumhansl-Schmuckler key profiles.
913    pub fn estimate_key(chroma: &[f32; 12]) -> (usize, bool) {
914        // Major and minor key profiles (Krumhansl-Schmuckler)
915        let major = [6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88];
916        let minor = [6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17];
917
918        let mut best_key = 0usize;
919        let mut best_is_minor = false;
920        let mut best_corr = -f32::INFINITY;
921
922        for root in 0..12 {
923            // Major
924            let maj_corr = Self::profile_correlation(chroma, &major, root);
925            if maj_corr > best_corr {
926                best_corr = maj_corr;
927                best_key = root;
928                best_is_minor = false;
929            }
930            // Minor
931            let min_corr = Self::profile_correlation(chroma, &minor, root);
932            if min_corr > best_corr {
933                best_corr = min_corr;
934                best_key = root;
935                best_is_minor = true;
936            }
937        }
938        (best_key, best_is_minor)
939    }
940
941    fn profile_correlation(chroma: &[f32; 12], profile: &[f64; 12], root: usize) -> f32 {
942        // Pearson correlation
943        let c_mean: f32 = chroma.iter().sum::<f32>() / 12.0;
944        let p_mean: f64 = profile.iter().sum::<f64>() / 12.0;
945        let mut num = 0.0f64;
946        let mut var_c = 0.0f64;
947        let mut var_p = 0.0f64;
948        for i in 0..12 {
949            let ci = (chroma[(i + root) % 12] - c_mean) as f64;
950            let pi = profile[i] - p_mean;
951            num += ci * pi;
952            var_c += ci * ci;
953            var_p += pi * pi;
954        }
955        let denom = (var_c * var_p).sqrt();
956        if denom < 1e-10 { 0.0 } else { (num / denom) as f32 }
957    }
958}
959
960// ---------------------------------------------------------------------------
961// Tests
962// ---------------------------------------------------------------------------
963
#[cfg(test)]
mod tests {
    //! Unit tests for the FFT / spectral-analysis module.

    use super::*;
    use crate::dsp::SignalGenerator;

    fn nearly_eq(a: f32, b: f32, tol: f32) -> bool {
        (a - b).abs() < tol
    }

    // --- Complex32 arithmetic ---

    #[test]
    fn test_complex_add_sub() {
        let a = Complex32::new(1.0, 2.0);
        let b = Complex32::new(3.0, -1.0);
        let s = a + b;
        assert!(nearly_eq(s.re, 4.0, 1e-6));
        assert!(nearly_eq(s.im, 1.0, 1e-6));
        let d = a - b;
        assert!(nearly_eq(d.re, -2.0, 1e-6));
        assert!(nearly_eq(d.im, 3.0, 1e-6));
    }

    #[test]
    fn test_complex_mul() {
        let a = Complex32::new(1.0, 2.0);
        let b = Complex32::new(3.0, 4.0);
        let m = a * b;
        assert!(nearly_eq(m.re, -5.0, 1e-5));
        assert!(nearly_eq(m.im, 10.0, 1e-5));
    }

    #[test]
    fn test_complex_norm_and_conj() {
        let c = Complex32::new(3.0, 4.0);
        assert!(nearly_eq(c.norm(), 5.0, 1e-5));
        let conj = c.conj();
        assert!(nearly_eq(conj.im, -4.0, 1e-6));
    }

    #[test]
    fn test_fft_roundtrip() {
        let n = 64usize;
        let signal: Vec<f32> = (0..n).map(|i| (2.0 * PI * 3.0 * i as f32 / n as f32).sin()).collect();
        let mut buf: Vec<Complex32> = signal.iter().map(|&x| Complex32::new(x, 0.0)).collect();
        Fft::forward(&mut buf);
        Fft::inverse(&mut buf);
        for (orig, recovered) in signal.iter().zip(buf.iter()) {
            assert!(nearly_eq(*orig, recovered.re, 1e-4), "orig={}, rec={}", orig, recovered.re);
        }
    }

    #[test]
    fn test_fft_impulse() {
        // FFT of an impulse should be all-ones in magnitude
        let n = 16usize;
        let mut buf = vec![Complex32::zero(); n];
        buf[0] = Complex32::one();
        Fft::forward(&mut buf);
        for c in &buf {
            assert!(nearly_eq(c.norm(), 1.0, 1e-5));
        }
    }

    #[test]
    fn test_real_fft_roundtrip() {
        let signal: Vec<f32> = (0..256).map(|i| (2.0 * PI * 10.0 * i as f32 / 256.0).sin()).collect();
        let spectrum = RealFft::forward_real(&signal);
        let recovered = RealFft::inverse_real(&spectrum, 256);
        for (orig, rec) in signal.iter().zip(recovered.iter()) {
            assert!(nearly_eq(*orig, *rec, 1e-3), "orig={}, rec={}", orig, rec);
        }
    }

    #[test]
    fn test_spectrum_dominant_frequency() {
        let sr = 44100.0f32;
        let freq = 440.0f32;
        let sig = SignalGenerator::sine(freq, 1.0, 1.0, sr);
        let win_size = 4096;
        let frame: Vec<f32> = sig.samples[..win_size].to_vec();
        let spectrum = Spectrum::from_real(&frame);
        let dom = spectrum.dominant_frequency(sr);
        // Should be within a couple bins of 440 Hz
        assert!((dom - freq).abs() < sr / win_size as f32 * 2.0, "dom={}", dom);
    }

    #[test]
    fn test_spectrum_centroid_noise() {
        // Flat spectrum -> centroid near Nyquist/2
        let buf: Vec<Complex32> = (0..513).map(|_| Complex32::new(1.0, 0.0)).collect();
        let spectrum = Spectrum::new(buf, 1024);
        let centroid = spectrum.spectral_centroid(44100.0);
        assert!(centroid > 10000.0 && centroid < 15000.0, "centroid={}", centroid);
    }

    #[test]
    fn test_spectrum_flatness_flat() {
        let buf: Vec<Complex32> = (0..513).map(|_| Complex32::new(1.0, 0.0)).collect();
        let spectrum = Spectrum::new(buf, 1024);
        let flatness = spectrum.spectral_flatness();
        // All-ones magnitude → flatness ≈ 1.0
        assert!(flatness > 0.9, "flatness={}", flatness);
    }

    #[test]
    fn test_spectrum_flux() {
        let buf1: Vec<Complex32> = (0..65).map(|_| Complex32::new(1.0, 0.0)).collect();
        let buf2: Vec<Complex32> = (0..65).map(|_| Complex32::new(2.0, 0.0)).collect();
        let s1 = Spectrum::new(buf1, 128);
        let s2 = Spectrum::new(buf2, 128);
        let flux = s1.spectral_flux(&s2);
        assert!(flux > 0.0);
    }

    #[test]
    fn test_stft_analyze_synthesize() {
        let sr = 44100.0;
        let sig = SignalGenerator::sine(440.0, 1.0, 0.1, sr);
        let stft = Stft::new(StftConfig { fft_size: 1024, hop_size: 256, window: WindowFunction::Hann });
        let frames = stft.analyze(&sig.samples);
        assert!(!frames.is_empty());
        let reconstructed = stft.synthesize(&frames);
        // Reconstruction should have similar RMS
        let orig_rms: f32 = {
            let s: f32 = sig.samples.iter().map(|&x| x * x).sum();
            (s / sig.samples.len() as f32).sqrt()
        };
        let rec_len = reconstructed.len().min(sig.samples.len());
        let rec_rms: f32 = {
            let s: f32 = reconstructed[..rec_len].iter().map(|&x| x * x).sum();
            (s / rec_len as f32).sqrt()
        };
        // Allow significant tolerance for OLA boundary effects
        assert!((orig_rms - rec_rms).abs() < 0.3, "orig_rms={}, rec_rms={}", orig_rms, rec_rms);
    }

    #[test]
    fn test_yin_pitch_detection() {
        let sr = 44100.0;
        let freq = 220.0f32;
        let sig = SignalGenerator::sine(freq, 1.0, 0.05, sr);
        let detected = Autocorrelation::pitch_yin(&sig.samples, sr);
        assert!(detected.is_some(), "YIN should detect pitch");
        let detected_freq = detected.unwrap();
        assert!((detected_freq - freq).abs() < 10.0, "detected={}, expected={}", detected_freq, freq);
    }

    #[test]
    fn test_mel_filterbank_shape() {
        let fb = MelFilterbank::new(40, 2048, 44100.0);
        assert_eq!(fb.filterbank.len(), 40);
        assert_eq!(fb.filterbank[0].len(), 1025);
    }

    #[test]
    fn test_mel_filterbank_apply() {
        let fb = MelFilterbank::new(40, 2048, 44100.0);
        let flat_spectrum = vec![1.0f32; 1025];
        let mel = fb.apply(&flat_spectrum);
        assert_eq!(mel.len(), 40);
        for &v in &mel {
            assert!(v >= 0.0, "All mel values must be non-negative");
        }
    }

    #[test]
    fn test_mfcc_shape() {
        let sr = 44100.0;
        let sig = SignalGenerator::sine(440.0, 1.0, 0.05, sr);
        let mfcc = Mfcc::new(13, 2048, sr);
        let coeffs = mfcc.compute(&sig.samples, sr, 13);
        assert_eq!(coeffs.len(), 13);
    }

    #[test]
    fn test_chroma_unit_sum() {
        let sr = 44100.0;
        let sig = SignalGenerator::sine(440.0, 1.0, 0.1, sr);
        let spectrum = Spectrum::from_real(&sig.samples[..2048]);
        let chroma = Chroma::compute(&spectrum, sr);
        let sum: f32 = chroma.iter().sum();
        // A 440 Hz sine has energy, so chroma must be non-empty; asserting this
        // unconditionally keeps the test from passing vacuously on all-zero output.
        assert!(sum > 1e-10, "chroma is empty for a sine input");
        // ...and normalised to sum to 1.
        assert!((sum - 1.0).abs() < 1e-4, "sum={}", sum);
    }

    #[test]
    fn test_fft_planner_caching() {
        let mut planner = FftPlanner::new();
        let plan1 = planner.plan(512);
        let plan2 = planner.plan(512);
        assert_eq!(plan1.size, plan2.size);
    }

    #[test]
    fn test_autocorr_unity_at_zero() {
        let sig: Vec<f32> = (0..256).map(|i| (0.1 * i as f32).sin()).collect();
        let ac = Autocorrelation::compute(&sig);
        assert!(nearly_eq(ac[0], 1.0, 1e-4));
    }

    #[test]
    fn test_cqt_output_shape() {
        let cqt = Cqt::new(44100.0, 55.0, 12, 7);
        let sig: Vec<f32> = (0..44100).map(|i| (2.0 * PI * 440.0 * i as f32 / 44100.0).sin()).collect();
        let result = cqt.analyze(&sig, 55.0, 12);
        assert_eq!(result.len(), (12 * 7) as usize);
    }

    #[test]
    fn test_spectrum_rolloff() {
        let n_bins = 513;
        // Bins 0..256 have energy 1, rest have 0
        let bins: Vec<Complex32> = (0..n_bins).map(|k| {
            if k < 256 { Complex32::new(1.0, 0.0) } else { Complex32::zero() }
        }).collect();
        let spec = Spectrum::new(bins, 1024);
        let rolloff = spec.spectral_rolloff(0.85);
        // Should be around bin 256 / 1024 ≈ 0.25
        assert!(rolloff <= 0.5, "rolloff={}", rolloff);
    }

    #[test]
    fn test_chroma_cosine_distance_identical() {
        let a = [1.0f32; 12].map(|x| x / 12.0);
        let dist = Chroma::cosine_distance(&a, &a);
        assert!(dist.abs() < 1e-5, "dist={}", dist);
    }

    #[test]
    fn test_chroma_cosine_distance_orthogonal() {
        let mut a = [0.0f32; 12];
        let mut b = [0.0f32; 12];
        a[0] = 1.0;
        b[1] = 1.0;
        let dist = Chroma::cosine_distance(&a, &b);
        assert!(nearly_eq(dist, 1.0, 1e-6), "dist={}", dist);
    }

    #[test]
    fn test_chroma_key_estimation_follows_rotation() {
        // A chroma vector shaped exactly like the C-major key profile.
        let c_major: [f32; 12] = [6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88];
        assert_eq!(Chroma::estimate_key(&c_major), (0, false));
        // Transposing up a fifth must move the tonic to G (pitch class 7).
        let g_major = Chroma::rotate(&c_major, 7);
        assert_eq!(g_major[7], c_major[0]);
        assert_eq!(Chroma::estimate_key(&g_major), (7, false));
    }

    #[test]
    fn test_dct_round_trip() {
        let input: Vec<f32> = (0..20).map(|i| i as f32 * 0.1).collect();
        let coeffs = Mfcc::dct_ii(&input, 20);
        let recovered = Mfcc::idct_ii(&coeffs, 20);
        for (a, b) in input.iter().zip(recovered.iter()) {
            assert!(nearly_eq(*a, *b, 1e-3), "a={}, b={}", a, b);
        }
    }

    #[test]
    fn test_complex_from_polar_roundtrip() {
        let r = 3.5f32;
        let theta = 1.2f32;
        let c = Complex32::from_polar(r, theta);
        assert!(nearly_eq(c.norm(), r, 1e-5));
        assert!(nearly_eq(c.arg(), theta, 1e-5));
    }

    #[test]
    fn test_fft_linearity() {
        let n = 32;
        let a: Vec<Complex32> = (0..n).map(|i| Complex32::new((i as f32 * 0.1).sin(), 0.0)).collect();
        let b: Vec<Complex32> = (0..n).map(|i| Complex32::new((i as f32 * 0.2).cos(), 0.0)).collect();
        let mut fa = a.clone();
        let mut fb = b.clone();
        Fft::forward(&mut fa);
        Fft::forward(&mut fb);
        // F(a + b) == F(a) + F(b)
        let mut ab: Vec<Complex32> = a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect();
        Fft::forward(&mut ab);
        for ((&fa_k, &fb_k), &fab_k) in fa.iter().zip(fb.iter()).zip(ab.iter()) {
            let diff = (fa_k + fb_k) - fab_k;
            assert!(diff.norm() < 1e-3, "linearity violation at bin: {}", diff.norm());
        }
    }
}