freq_det/
lib.rs

1//! Frequency detection made easy
2//!
3//! ```
4//! use freq_det::FreqDetector;
5//!
6//! let sample_count = 4096;
7//!
8//! let sinusoid_440hz = (0..sample_count)
9//!     .map(|i| {
10//!         use std::f32::consts::TAU;
11//!         (i as f32 / 44100.0 * 440.0 * TAU).sin()
12//!         // noise
13//!         + 0.9 * (i as f32 / 44100.0 * 100.0 * TAU).sin()
14//!         + 0.9 * (i as f32 / 44100.0 * 120.0 * TAU).sin()
15//!     })
16//!     .collect::<Vec<_>>();
17//!
18//! let freq_detector = FreqDetector::new(44100, sample_count).unwrap();
19//! assert_eq!(freq_detector.detect(&sinusoid_440hz).unwrap().round(), 440.0);
20//! ```
21
22use std::sync::Arc;
23
24use rustfft::{
25    num_complex::{Complex, ComplexFloat},
26    Fft, FftPlanner,
27};
28use thiserror::Error;
29
30/// Frequency detector
31pub struct FreqDetector {
32    fft: Arc<dyn Fft<f32>>,
33    sample_count: usize,
34    sample_rate: usize,
35}
36
37impl FreqDetector {
38    /// `sample_rate` is `44100` for most modern applications
39    ///
40    /// `sample_count` numbers between `2048` and `8192` work well.
41    /// More samples usually means more accuracy, but requires more audio,
42    /// which also means more latency for realtime application.
43    ///
44    /// # Errors
45    /// - if sample rate is 0
46    /// - if fewer than 4 samples are passed
47    pub fn new(sample_rate: usize, sample_count: usize) -> Result<Self, DetectorCreateError> {
48        let mut planner = FftPlanner::new();
49        if sample_rate < 1 {
50            return Err(DetectorCreateError::SampleRateTooLow);
51        }
52        if sample_count < 4 {
53            return Err(DetectorCreateError::TooFewSamples);
54        }
55        Ok(Self {
56            fft: planner.plan_fft_forward(sample_count),
57
58            sample_count,
59            sample_rate,
60        })
61    }
62
63    /// # Errors
64    ///
65    /// - if `samples.len()` does not match the `sample_count` passed to [Self::new]
66    /// - if there are `NaN`s in the sample slice
67    pub fn detect(&self, samples: &[f32]) -> Result<f32, DetectError> {
68        if samples.len() != self.sample_count {
69            return Err(DetectError::SampleCountMismatch {
70                expected: self.sample_count,
71                passed: samples.len(),
72            });
73        }
74        let mut fft_buf = samples
75            .iter()
76            .copied()
77            .map(|s| Complex { re: s, im: 0.0 })
78            .collect::<Vec<_>>();
79
80        self.fft.process(&mut fft_buf);
81
82        let antialised_power_values = fft_buf
83            .windows(2)
84            // only interested in positive frequencies
85            .take(self.sample_count / 2)
86            .map(|w| [w[0].abs(), w[1].abs()])
87            .enumerate()
88            .collect::<Vec<_>>();
89
90        let peak_window = antialised_power_values
91            .iter()
92            .copied()
93            .max_by(|c1, c2| {
94                c1.1.iter()
95                    .sum::<f32>()
96                    .total_cmp(&c2.1.iter().sum::<f32>())
97            })
98            .expect("to have at least 1 positive frequency");
99
100        if peak_window.1.iter().sum::<f32>() < 0.0001 {
101            return Ok(0.0);
102        }
103
104        // take weighted average of the two biggest
105        // not sure what is the math behind why this works, but it does
106        let res = (self.fft_bucket_to_freq(peak_window.0) * peak_window.1[0]
107            + self.fft_bucket_to_freq(peak_window.0 + 1) * peak_window.1[1])
108            / peak_window.1.iter().sum::<f32>();
109        if res.is_nan() {
110            Err(DetectError::NansFound)
111        } else {
112            Ok(res)
113        }
114    }
115
116    fn fft_bucket_to_freq(&self, bucket: usize) -> f32 {
117        bucket as f32 * self.sample_rate as f32 / self.sample_count as f32
118    }
119}
120
121#[derive(Error, Debug)]
122pub enum DetectError {
123    #[error("Invalid sample count passed (expected {expected}, passed {passed})")]
124    SampleCountMismatch { expected: usize, passed: usize },
125    #[error("NaNs in the samples")]
126    NansFound,
127}
128
129#[derive(Error, Debug)]
130pub enum DetectorCreateError {
131    #[error("Detector does not support sample rate < 1 sample per second")]
132    SampleRateTooLow,
133    #[error("Needs at least 4 samples for detection")]
134    TooFewSamples,
135}
136
#[cfg(test)]
mod tests {
    use super::{DetectError, DetectorCreateError, FreqDetector};

    const SAMPLE_RATE: usize = 44100;

    /// Detects several pure tones mixed with two strong fixed-frequency interferers.
    #[test]
    fn freq_detector_smoke_test() {
        use std::f32::consts::TAU;
        let sample_count = 4096 * 2;
        let freq_detector = FreqDetector::new(SAMPLE_RATE, sample_count).unwrap();

        for freq in [10, 20, 30, 100, 1000, 2000] {
            let sin_samples = (0..sample_count)
                .map(|i| {
                    (i as f32 / 44100.0 * freq as f32 * TAU).sin()
                        // noise
                        + 0.9 * (i as f32 / 44100.0 * 101.0 * TAU).sin()
                        + 0.9 * (i as f32 / 44100.0 * 120.0 * TAU).sin()
                })
                .collect::<Vec<_>>();

            let detected_freq = freq_detector.detect(&sin_samples).unwrap();
            assert!(
                (detected_freq - freq as f32).abs() < 0.5,
                "detected {detected_freq} expected {freq}"
            );
        }
    }

    #[test]
    fn rejects_invalid_construction_arguments() {
        assert!(matches!(
            FreqDetector::new(0, 4096),
            Err(DetectorCreateError::SampleRateTooLow)
        ));
        assert!(matches!(
            FreqDetector::new(SAMPLE_RATE, 3),
            Err(DetectorCreateError::TooFewSamples)
        ));
    }

    #[test]
    fn rejects_wrong_sample_count() {
        let freq_detector = FreqDetector::new(SAMPLE_RATE, 1024).unwrap();
        assert!(matches!(
            freq_detector.detect(&[0.0; 512]),
            Err(DetectError::SampleCountMismatch {
                expected: 1024,
                passed: 512
            })
        ));
    }

    #[test]
    fn reports_nans() {
        let freq_detector = FreqDetector::new(SAMPLE_RATE, 1024).unwrap();
        let mut samples = vec![0.0; 1024];
        samples[10] = f32::NAN;
        assert!(matches!(
            freq_detector.detect(&samples),
            Err(DetectError::NansFound)
        ));
    }

    #[test]
    fn silence_is_zero_hz() {
        let freq_detector = FreqDetector::new(SAMPLE_RATE, 1024).unwrap();
        assert_eq!(freq_detector.detect(&[0.0; 1024]).unwrap(), 0.0);
    }
}