1use std::sync::Arc;
23
24use rustfft::{
25 num_complex::{Complex, ComplexFloat},
26 Fft, FftPlanner,
27};
28use thiserror::Error;
29
30pub struct FreqDetector {
32 fft: Arc<dyn Fft<f32>>,
33 sample_count: usize,
34 sample_rate: usize,
35}
36
37impl FreqDetector {
38 pub fn new(sample_rate: usize, sample_count: usize) -> Result<Self, DetectorCreateError> {
48 let mut planner = FftPlanner::new();
49 if sample_rate < 1 {
50 return Err(DetectorCreateError::SampleRateTooLow);
51 }
52 if sample_count < 4 {
53 return Err(DetectorCreateError::TooFewSamples);
54 }
55 Ok(Self {
56 fft: planner.plan_fft_forward(sample_count),
57
58 sample_count,
59 sample_rate,
60 })
61 }
62
63 pub fn detect(&self, samples: &[f32]) -> Result<f32, DetectError> {
68 if samples.len() != self.sample_count {
69 return Err(DetectError::SampleCountMismatch {
70 expected: self.sample_count,
71 passed: samples.len(),
72 });
73 }
74 let mut fft_buf = samples
75 .iter()
76 .copied()
77 .map(|s| Complex { re: s, im: 0.0 })
78 .collect::<Vec<_>>();
79
80 self.fft.process(&mut fft_buf);
81
82 let antialised_power_values = fft_buf
83 .windows(2)
84 .take(self.sample_count / 2)
86 .map(|w| [w[0].abs(), w[1].abs()])
87 .enumerate()
88 .collect::<Vec<_>>();
89
90 let peak_window = antialised_power_values
91 .iter()
92 .copied()
93 .max_by(|c1, c2| {
94 c1.1.iter()
95 .sum::<f32>()
96 .total_cmp(&c2.1.iter().sum::<f32>())
97 })
98 .expect("to have at least 1 positive frequency");
99
100 if peak_window.1.iter().sum::<f32>() < 0.0001 {
101 return Ok(0.0);
102 }
103
104 let res = (self.fft_bucket_to_freq(peak_window.0) * peak_window.1[0]
107 + self.fft_bucket_to_freq(peak_window.0 + 1) * peak_window.1[1])
108 / peak_window.1.iter().sum::<f32>();
109 if res.is_nan() {
110 Err(DetectError::NansFound)
111 } else {
112 Ok(res)
113 }
114 }
115
116 fn fft_bucket_to_freq(&self, bucket: usize) -> f32 {
117 bucket as f32 * self.sample_rate as f32 / self.sample_count as f32
118 }
119}
120
121#[derive(Error, Debug)]
122pub enum DetectError {
123 #[error("Invalid sample count passed (expected {expected}, passed {passed})")]
124 SampleCountMismatch { expected: usize, passed: usize },
125 #[error("NaNs in the samples")]
126 NansFound,
127}
128
129#[derive(Error, Debug)]
130pub enum DetectorCreateError {
131 #[error("Detector does not support sample rate < 1 sample per second")]
132 SampleRateTooLow,
133 #[error("Needs at least 4 samples for detection")]
134 TooFewSamples,
135}
136
#[cfg(test)]
mod tests {
    use super::FreqDetector;

    /// Feeds the detector a unit-amplitude sine at each target frequency, mixed
    /// with two 0.9-amplitude sines at 101 Hz and 120 Hz, and checks that the
    /// detected frequency is within 0.5 Hz of the target.
    #[test]
    fn freq_detector_smoke_test() {
        use std::f32::consts::TAU;

        const SAMPLE_RATE: usize = 44100;
        const SAMPLE_COUNT: usize = 4096 * 2;

        let freq_detector = FreqDetector::new(SAMPLE_RATE, SAMPLE_COUNT).unwrap();
        // Value of a unit-amplitude sine of `hz` at sample index `i`.
        let tone = |i: usize, hz: f32| (i as f32 / SAMPLE_RATE as f32 * hz * TAU).sin();

        for freq in [10, 20, 30, 100, 1000, 2000] {
            let expected = freq as f32;
            let sin_samples = (0..SAMPLE_COUNT)
                .map(|i| tone(i, expected) + 0.9 * tone(i, 101.0) + 0.9 * tone(i, 120.0))
                .collect::<Vec<_>>();

            let detected_freq = freq_detector.detect(&sin_samples).unwrap();
            // The failure message already reports both values, so no extra
            // debug printing is needed.
            assert!(
                (detected_freq - expected).abs() < 0.5,
                "detected {detected_freq} expected {freq}"
            );
        }
    }
}