use std::{collections::HashMap, ops::Deref};
use rustfft::{
num_complex::{Complex, ComplexFloat},
FftPlanner,
};
use crate::core::note::{HasPrimaryHarmonicSeries, ALL_PITCH_NOTES_WITH_FREQUENCY};
use crate::core::{base::Res, note::Note, pitch::HasFrequency};
pub fn get_notes_from_audio_data(data: &[f32], length_in_seconds: u8) -> Res<Vec<Note>> {
if length_in_seconds < 1 {
return Err(anyhow::Error::msg("Listening length in seconds must be greater than 1."));
}
let num_nan = data.iter().filter(|n| n.is_nan()).count();
if num_nan > 0 {
return Err(anyhow::Error::msg(format!("{num_nan} NaNs in audio data.")));
}
let frequency_space = get_frequency_space(data, length_in_seconds);
let smoothed_frequency_space = get_smoothed_frequency_space(&frequency_space, length_in_seconds);
Ok(get_notes_from_smoothed_frequency_space(&smoothed_frequency_space))
}
/// Extracts notes from an already smoothed `(frequency, magnitude)` spectrum.
///
/// Runs the three stages in order: peak isolation, mapping of peaks onto
/// candidate notes, and merging of candidates along their harmonic series.
pub fn get_notes_from_smoothed_frequency_space(smoothed_frequency_space: &[(f32, f32)]) -> Vec<Note> {
    // Relative cutoffs handed to the two note-selection stages.
    const PEAK_CUTOFF: f32 = 0.1;
    const HARMONIC_CUTOFF: f32 = 0.1;

    let peaks = translate_frequency_space_to_peak_space(smoothed_frequency_space);
    let candidates = get_likely_notes_from_peak_space(&peaks, PEAK_CUTOFF);
    reduce_notes_by_harmonic_series(&candidates, HARMONIC_CUTOFF)
}
/// Runs a forward FFT over `data` and returns one `(frequency, magnitude)`
/// pair per FFT bin.
///
/// The frequency of bin `k` is reported as `k / length_in_seconds`; the
/// magnitude is the absolute value of the complex FFT output for that bin.
pub fn get_frequency_space(data: &[f32], length_in_seconds: u8) -> Vec<(f32, f32)> {
    // Lift the real samples into the complex plane (imaginary part zero).
    let mut spectrum = data
        .iter()
        .map(|sample| Complex::new(*sample, 0.0))
        .collect::<Vec<_>>();

    let mut planner = FftPlanner::new();
    planner.plan_fft_forward(data.len()).process(&mut spectrum);

    let duration = length_in_seconds as f32;
    let mut frequency_space = Vec::with_capacity(spectrum.len());
    for (bin, value) in spectrum.into_iter().enumerate() {
        frequency_space.push((bin as f32 / duration, value.abs()));
    }
    frequency_space
}
/// Runs an inverse FFT over `data` and returns `(index, magnitude)` pairs.
///
/// Every input value is treated as a purely real coefficient; the output
/// magnitude is the absolute value of the complex result at that index.
pub fn get_time_space(data: &[f32]) -> Vec<(f32, f32)> {
    let mut samples = data
        .iter()
        .map(|value| Complex::new(*value, 0.0))
        .collect::<Vec<_>>();

    let mut planner = FftPlanner::new();
    planner.plan_fft_inverse(data.len()).process(&mut samples);

    let mut time_space = Vec::with_capacity(samples.len());
    for (index, value) in samples.into_iter().enumerate() {
        time_space.push((index as f32, value.abs()));
    }
    time_space
}
pub fn compute_cqt(frequency_space: &[f32]) -> Vec<f32> {
const Q_FACTOR: f32 = 24.7; const MIN_FREQ: f32 = 65.41; const MAX_FREQ: f32 = 2093.0; const N_BINS: usize = 60;
let mut cqt_output = vec![vec![0.0; frequency_space.len()]; N_BINS];
let log_min_freq = MIN_FREQ.log2();
let log_max_freq = MAX_FREQ.log2();
let log_freq_step = (log_max_freq - log_min_freq) / (N_BINS as f32 - 1.0);
for i in 0..N_BINS {
let log_freq_center = log_min_freq + i as f32 * log_freq_step;
let freq_center = 2.0f32.powf(log_freq_center);
let freq_bw = freq_center / Q_FACTOR;
let fft_freq_step = 1.0;
let start_bin = (freq_center - freq_bw / 2.0) / fft_freq_step;
let end_bin = (freq_center + freq_bw / 2.0) / fft_freq_step;
let mut cqt_bin = vec![rustfft::num_complex::Complex::new(0.0, 0.0); frequency_space.len()];
for j in start_bin as usize..=end_bin as usize {
let weight = (j as f32 - freq_center / fft_freq_step) / freq_bw;
let weight = weight * std::f32::consts::PI * 2.0;
let fft_bin = frequency_space[j];
cqt_bin[j] = rustfft::num_complex::Complex::new(fft_bin * weight.sin(), 0.0);
}
let ifft = rustfft::FftPlanner::<f32>::new().plan_fft_inverse(cqt_bin.len());
ifft.process(&mut cqt_bin);
for j in 0..frequency_space.len() {
cqt_output[i][j] = cqt_bin[j].abs();
}
}
let mut result = vec![];
for k in 0..N_BINS {
let mut sum = 0.0;
for j in 0..frequency_space.len() {
sum += cqt_output[k][j];
}
result.push(sum);
}
result
}
/// Averages a `(frequency, magnitude)` spectrum in consecutive windows of
/// `length_in_seconds` entries.
///
/// Each output pair holds the mean frequency and mean magnitude of one
/// window. When the input length is not a multiple of the window size, the
/// final, shorter window is averaged over the entries it actually contains
/// rather than panicking. A window size of zero is treated as one, so the
/// spectrum is returned unsmoothed.
pub fn get_smoothed_frequency_space(frequency_space: &[(f32, f32)], length_in_seconds: u8) -> Vec<(f32, f32)> {
    // `chunks` panics on a zero chunk size, so fall back to a window of one.
    let size = (length_in_seconds as usize).max(1);

    frequency_space
        .chunks(size)
        .map(|window| {
            let count = window.len() as f32;
            let average_frequency = window.iter().map(|(f, _)| f).sum::<f32>() / count;
            let average_magnitude = window.iter().map(|(_, m)| m).sum::<f32>() / count;
            (average_frequency, average_magnitude)
        })
        .collect()
}
/// Reduces a `(frequency, magnitude)` spectrum to isolated peaks.
///
/// Works on indices `50..8000` of `frequency_space` in two passes over a copy
/// of the input:
///
/// 1. A sliding window, whose length is the entry's frequency divided by 50,
///    keeps only the entries equal to the window's maximum magnitude and
///    zeroes every other entry inside the window.
/// 2. Each remaining entry is zeroed when the mean absolute slope over
///    three entries to either side, divided by the entry's magnitude, is
///    below 0.1.
///
/// Returns the entries at indices `50..8000` of the processed copy.
///
/// # Panics
///
/// Panics with an out-of-bounds index when `frequency_space` is shorter than
/// the indices touched here (up to `8000 + 3`, and up to `k + window_size`
/// in the first pass), and when a NaN magnitude is compared in
/// `partial_cmp(..).unwrap()`.
pub fn translate_frequency_space_to_peak_space(frequency_space: &[(f32, f32)]) -> Vec<(f32, f32)> {
// Divisor that converts an entry's frequency into a window length.
let magic_window_number = 50f32;
// Inclusive start and exclusive end of the index range that is processed.
let min_index = 50;
let max_index = 8_000;
let mut peak_space = frequency_space.to_vec();
let mut last_k = min_index;
let mut k = min_index;
// First pass: windowed maximum search.
while k < max_index {
// Window length grows with frequency; the cast truncates towards zero.
let window_size = (frequency_space[k].0 / magic_window_number) as usize;
// An empty window yields 0.0 through `unwrap_or_default`.
let max_in_window = (k..k + window_size).map(|i| frequency_space[i].1).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap_or_default();
// No-op: rewrites the entry with its own values.
peak_space[k] = (peak_space[k].0, peak_space[k].1);
// Ends up as the last index in the window whose magnitude equals the maximum.
let mut next = 0;
for j in k..(k + window_size) {
if frequency_space[j].1 == max_in_window {
// No-op: the maximum keeps its magnitude.
peak_space[j] = (peak_space[j].0, peak_space[j].1);
next = j;
} else {
// Everything below the window maximum is suppressed.
peak_space[j] = (peak_space[j].0, 0.0);
}
}
// Continue from the maximum that was found.
// NOTE(review): when `window_size` is 0 (frequency below 50) `next` stays 0
// and the cursor jumps back to index 0; this looks like it can loop forever
// on such input -- confirm callers always pass frequencies >= 50 here.
k = next;
// Force progress when the maximum sits at the current cursor position.
if last_k == k {
k += 1;
}
last_k = k;
}
let skip = min_index;
let take = max_index - min_index;
// Second pass: slope test on every entry in the processed range.
for (k, (_, magnitude)) in peak_space.iter_mut().enumerate().skip(skip).take(take) {
let window_size = 3;
// Slopes are measured on the original magnitudes, not on the zeroed copy.
let average_right_derivative = ((frequency_space[k + window_size].1 - frequency_space[k].1) / window_size as f32).abs();
let average_left_derivative = ((frequency_space[k].1 - frequency_space[k - window_size].1) / window_size as f32).abs();
let average_derivative = (average_right_derivative + average_left_derivative) / 2f32;
// For an already zeroed entry the quotient is infinite or NaN, the
// comparison is false, and the entry simply stays at zero.
if average_derivative / *magnitude < 0.1 {
*magnitude = 0.0;
}
}
peak_space.into_iter().skip(min_index).take(max_index - min_index).collect()
}
/// Maps spectral peaks onto notes and accumulates the magnitude per note.
///
/// Peaks with a magnitude of 0.1 or less are ignored, as are peaks that do
/// not exceed `cutoff` times the strongest remaining peak. Each surviving
/// peak is assigned to the closest entry of `ALL_PITCH_NOTES_WITH_FREQUENCY`
/// via [`binary_search_closest`]; peaks for which that lookup returns `None`
/// are dropped.
///
/// Returns one `(note, summed magnitude)` pair per matched note, in no
/// particular order. Returns an empty vector when no peak is above the noise
/// floor.
fn get_likely_notes_from_peak_space(peak_space: &[(f32, f32)], cutoff: f32) -> Vec<(Note, f32)> {
    const NOISE_FLOOR: f32 = 0.1;

    let mut peaks = peak_space.iter().filter(|(_, m)| *m > NOISE_FLOOR).copied().collect::<Vec<_>>();
    // NaN magnitudes fail the `>` filter above, so the comparison cannot fail.
    peaks.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("NaN magnitudes are removed by the noise-floor filter"));

    // Silent or fully suppressed input leaves nothing to analyse.
    let max_power = match peaks.first() {
        Some(&(_, magnitude)) => magnitude,
        None => return Vec::new(),
    };
    let threshold = max_power * cutoff;

    let mut candidates = HashMap::new();
    for &(frequency, magnitude) in peaks.iter().filter(|(_, m)| *m > threshold) {
        if let Some(pair) = binary_search_closest(ALL_PITCH_NOTES_WITH_FREQUENCY.deref(), frequency, |t| t.1) {
            // Start from zero so the first peak of a note is counted once.
            *candidates.entry(pair.0).or_insert(0.0) += magnitude;
        }
    }
    candidates.into_iter().collect::<Vec<_>>()
}
/// Folds candidate notes into their fundamentals and keeps the strongest.
///
/// Candidates are visited in ascending frequency order. Every higher
/// candidate whose frequency equals one of the frequencies in the current
/// note's primary harmonic series is removed and its magnitude is added to
/// the current note. Afterwards only notes whose accumulated magnitude
/// exceeds `cutoff` times the largest accumulated magnitude are returned,
/// strongest first.
///
/// Returns an empty vector for empty input.
///
/// # Panics
///
/// Panics if a frequency or magnitude comparison involves NaN.
fn reduce_notes_by_harmonic_series(notes: &[(Note, f32)], cutoff: f32) -> Vec<Note> {
    if notes.is_empty() {
        return Vec::new();
    }

    let mut working_set = notes.to_vec();
    working_set.sort_unstable_by(|a, b| a.0.frequency().partial_cmp(&b.0.frequency()).unwrap());

    let mut k = 0;
    while k < working_set.len() {
        let note = working_set[k].0;
        // The series depends only on `note`, so compute it once per fundamental.
        let harmonic_frequencies = note
            .primary_harmonic_series()
            .into_iter()
            .map(|harmonic| harmonic.frequency())
            .collect::<Vec<_>>();

        let mut j = k + 1;
        while j < working_set.len() {
            let other_frequency = working_set[j].0.frequency();
            if harmonic_frequencies.contains(&other_frequency) {
                // Merge exactly once; the element that slides into position
                // `j` is examined on the next iteration.
                working_set[k].1 += working_set[j].1;
                working_set.remove(j);
            } else {
                j += 1;
            }
        }
        k += 1;
    }

    working_set.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    // Index 0 is never removed above, so the set is still non-empty here.
    let threshold = working_set[0].1 * cutoff;
    working_set.retain(|(_, magnitude)| *magnitude > threshold);
    working_set.into_iter().map(|(note, _)| note).collect()
}
/// Builds a frequency interval around every interior note of `notes`.
///
/// For each note that has both a predecessor and a successor in the slice,
/// the interval reaches halfway to the predecessor's frequency on the low
/// side and halfway to the successor's frequency on the high side. The first
/// and last notes get no bin, so fewer than three notes yield an empty
/// result.
pub fn get_frequency_bins(notes: &[Note]) -> Vec<(Note, (f32, f32))> {
    notes
        .windows(3)
        .map(|neighbours| {
            let previous = &neighbours[0];
            let note = &neighbours[1];
            let following = &neighbours[2];
            let low = note.frequency() - 0.50 * (note.frequency() - previous.frequency());
            let high = note.frequency() + 0.50 * (following.frequency() - note.frequency());
            (*note, (low, high))
        })
        .collect()
}
/// Finds the element of a sorted slice whose key is nearest to `target`.
///
/// `get_value` extracts the key from an element; `array` is expected to be
/// sorted ascending by that key. The search locates the first element whose
/// key is not less than `target` and compares it with its left neighbour;
/// on a tie the left (smaller) neighbour wins.
///
/// Returns `None` when the slice is empty, when no element has a key below
/// `target`, or when every key is below `target`.
pub fn binary_search_closest<T, F>(array: &[T], target: f32, mut get_value: F) -> Option<&T>
where
    F: FnMut(&T) -> f32,
{
    // Half-open search range [start, end).
    let (mut start, mut end) = (0, array.len());
    while start < end {
        let middle = (start + end) / 2;
        if get_value(&array[middle]) < target {
            start = middle + 1;
        } else {
            end = middle;
        }
    }

    // `start` is now the insertion point; both neighbours must exist.
    let upper = start;
    if upper == 0 || upper == array.len() {
        return None;
    }
    let lower = upper - 1;

    let distance_below = (target - get_value(&array[lower])).abs();
    let distance_above = (get_value(&array[upper]) - target).abs();
    let closest = if distance_above < distance_below { upper } else { lower };
    Some(&array[closest])
}
#[cfg(test)]
pub(crate) mod tests {
    use std::{fs::File, io::Read};

    use crate::core::note::ALL_PITCH_NOTES;

    use super::*;

    /// Loads the raw `f32` samples stored in `tests/vec.bin`.
    ///
    /// The file is read as bytes and decoded four bytes at a time in native
    /// byte order. Trailing bytes that do not form a complete `f32` are
    /// ignored. Decoding goes through `f32::from_ne_bytes` because a byte
    /// buffer carries no alignment guarantee for `f32`, so reinterpreting its
    /// pointer would be undefined behaviour.
    pub fn load_test_data() -> Vec<f32> {
        let mut file = File::open("tests/vec.bin").unwrap();
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes).unwrap();

        bytes
            .chunks_exact(std::mem::size_of::<f32>())
            .map(|chunk| f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()
    }

    /// A zero listening length is rejected, so unwrapping the result panics.
    #[test]
    #[should_panic]
    fn test_get_notes_from_audio_data_length() {
        get_notes_from_audio_data(&[0.0, 0.0, 0.0], 0).unwrap();
    }

    /// NaN samples are rejected, so unwrapping the result panics.
    #[test]
    #[should_panic]
    fn test_get_notes_from_audio_data_nan() {
        get_notes_from_audio_data(&[0.0, 0.0, f32::NAN], 10).unwrap();
    }

    /// Smoke test: the inverse transform runs over real spectrum magnitudes.
    #[test]
    fn test_get_time_space() {
        let data = load_test_data();
        let frequency_space = get_frequency_space(&data, 5).into_iter().map(|(_, v)| v).collect::<Vec<_>>();
        let _ = get_time_space(&frequency_space);
    }

    /// 62 notes have 60 interior notes, each of which receives one bin.
    #[test]
    fn test_get_frequency_bins() {
        let bins = get_frequency_bins(&ALL_PITCH_NOTES.iter().skip(24).take(62).cloned().collect::<Vec<_>>());
        assert_eq!(bins.len(), 60);
    }

    /// An empty slice yields `None`, so unwrapping panics.
    #[test]
    #[should_panic]
    fn test_binary_search_closest_empty() {
        binary_search_closest(&[], 0.0, |x| *x).unwrap();
    }
}