#![allow(dead_code)]
use crate::{MirError, MirResult};
use oxifft::Complex;
use std::f32::consts::PI;
/// Lowest MIDI note number tried as a pitch candidate by the salience search.
const MIDI_MIN: u8 = 36;
/// Highest MIDI note number tried as a pitch candidate (inclusive).
const MIDI_MAX: u8 = 96;
/// Number of harmonics (the fundamental counts as the first) summed per pitch candidate.
const N_HARMONICS: usize = 8;
/// Fraction of the mean frame salience a frame must reach to be kept as voiced.
const VOICED_THRESHOLD: f32 = 0.8;
/// Lowest estimated modulation rate, in Hz, accepted as vibrato.
const VIBRATO_MIN_RATE_HZ: f32 = 4.0;
/// Highest estimated modulation rate, in Hz, accepted as vibrato.
const VIBRATO_MAX_RATE_HZ: f32 = 8.0;
/// Minimum peak-to-peak pitch excursion, in semitones, for a voiced run to count as vibrato.
const VIBRATO_MIN_EXTENT_ST: f32 = 0.25;
/// Minimum frame count — for the whole track and for each voiced run — before vibrato is analysed.
const VIBRATO_MIN_FRAMES: usize = 16;
/// Tunable parameters for the melody extractor.
#[derive(Debug, Clone)]
pub struct MelodyExtractorConfig {
    /// Analysis window length in samples; each frame of this length is fed to the FFT.
    pub window_size: usize,
    /// Distance between consecutive frame starts, in samples (0 is treated as 1).
    pub hop_size: usize,
    /// Exponent `p` of the `1 / k^p` weight given to the `k`-th harmonic.
    pub harmonic_decay_exp: f32,
    /// Half-width, in frames, of the median filter applied to the pitch track.
    pub smooth_radius: usize,
}

impl Default for MelodyExtractorConfig {
    /// Returns the stock settings: a 2048-sample window, a 512-sample hop,
    /// `1 / k` harmonic weighting and a median radius of 2 frames.
    fn default() -> Self {
        MelodyExtractorConfig {
            smooth_radius: 2,
            harmonic_decay_exp: 1.0,
            hop_size: 512,
            window_size: 2048,
        }
    }
}
/// Pitch information for a single analysis frame.
#[derive(Debug, Clone)]
pub struct PitchFrame {
    /// Pitch in Hz; `None` marks the frame as unvoiced (see [`PitchFrame::is_voiced`]).
    pub pitch_hz: Option<f32>,
    /// MIDI note number for the frame, when one is available.
    pub midi_pitch: Option<u8>,
    /// Salience score attached to the frame.
    pub salience: f32,
    /// Frame time — presumably seconds from the start of the signal (not set anywhere in this file).
    pub time_s: f32,
}

impl PitchFrame {
    /// Reports whether this frame carries a pitch, i.e. `pitch_hz` is `Some`.
    #[must_use]
    pub fn is_voiced(&self) -> bool {
        matches!(self.pitch_hz, Some(_))
    }
}
/// A contiguous voiced run of the pitch track that was classified as vibrato.
#[derive(Debug, Clone)]
pub struct VibratoSegment {
/// Timestamp, in seconds, of the first frame of the run.
pub start_s: f32,
/// Timestamp, in seconds, of the last frame of the run.
pub end_s: f32,
/// Estimated modulation rate in Hz, derived from sign changes of the pitch deviation.
pub rate_hz: f32,
/// Peak-to-peak pitch deviation within the run, in semitones.
pub extent_semitones: f32,
}
/// Frame-wise output of a melody extraction run plus summary statistics.
///
/// The per-frame vectors produced by the extractor all have one entry per
/// analysis frame.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    /// Pitch per frame in Hz; `None` for unvoiced frames.
    pub pitch_hz: Vec<Option<f32>>,
    /// Pitch per frame as a MIDI note number; `None` for unvoiced frames.
    pub midi_pitch: Vec<Option<u8>>,
    /// Salience of the best pitch candidate in each frame.
    pub salience: Vec<f32>,
    /// Start time of each frame in seconds.
    pub timestamps_s: Vec<f32>,
    /// Voiced runs that were classified as vibrato.
    pub vibrato: Vec<VibratoSegment>,
    /// Share of frames that are voiced, in `0.0..=1.0`.
    pub voiced_fraction: f32,
    /// Arithmetic mean of the voiced pitches in Hz (`0.0` when nothing is voiced).
    pub mean_pitch_hz: f32,
}

impl ExtractionResult {
    /// Total number of frames in the pitch track.
    #[must_use]
    pub fn n_frames(&self) -> usize {
        let track = &self.pitch_hz;
        track.len()
    }

    /// Number of frames that carry a pitch.
    #[must_use]
    pub fn n_voiced(&self) -> usize {
        self.pitch_hz.iter().flatten().count()
    }

    /// Pitches of the voiced frames, in frame order, with unvoiced frames skipped.
    #[must_use]
    pub fn voiced_pitches(&self) -> Vec<f32> {
        self.pitch_hz.iter().flatten().copied().collect()
    }
}
/// Melody extractor that picks, per frame, the MIDI note with the strongest
/// harmonic-sum salience in the FFT magnitude spectrum.
pub struct MelodyExtractor {
/// Sample rate of the input audio in Hz.
sample_rate: u32,
/// Framing, harmonic-weighting and smoothing parameters.
config: MelodyExtractorConfig,
}
impl MelodyExtractor {
#[must_use]
pub fn new(sample_rate: u32, config: MelodyExtractorConfig) -> Self {
Self {
sample_rate,
config,
}
}
/// Estimates a frame-by-frame melody (dominant pitch) track for `samples`.
///
/// Processing steps:
/// 1. Cut the signal into Hann-windowed frames of `config.window_size`
///    samples, advancing by `config.hop_size` samples (at least 1).
/// 2. For each frame, take the FFT magnitude spectrum and pick the MIDI note
///    with the largest harmonic-sum salience.
/// 3. Discard the note of every frame whose salience is below
///    `VOICED_THRESHOLD` times the mean salience over all frames.
/// 4. Median-smooth the remaining notes, convert them to Hz, and compute the
///    voiced fraction, mean pitch and vibrato segments.
///
/// # Errors
///
/// Returns [`MirError::InsufficientData`] when `samples` is shorter than one
/// analysis window.
pub fn extract(&self, samples: &[f32]) -> MirResult<ExtractionResult> {
    let window_size = self.config.window_size;
    if samples.len() < window_size {
        return Err(MirError::InsufficientData(format!(
            "need ≥{} samples, got {}",
            window_size,
            samples.len()
        )));
    }
    // NOTE(review): `window_size == 0` passes the check above and reaches the
    // FFT with an empty frame — confirm that `oxifft::fft` tolerates that.
    let window = hann_window(window_size);
    let hop = self.config.hop_size.max(1);
    let sr = self.sample_rate as f32;
    // The length check above guarantees the subtraction cannot underflow, and
    // this frame count guarantees every frame below fits inside `samples`.
    let n_frames = (samples.len() - window_size) / hop + 1;

    let mut raw_midi: Vec<Option<u8>> = Vec::with_capacity(n_frames);
    let mut saliences: Vec<f32> = Vec::with_capacity(n_frames);
    let mut timestamps: Vec<f32> = Vec::with_capacity(n_frames);
    for frame_idx in 0..n_frames {
        let start = frame_idx * hop;
        let frame = &samples[start..start + window_size];
        let fft_in: Vec<Complex<f32>> = frame
            .iter()
            .zip(window.iter())
            .map(|(&s, &w)| Complex::new(s * w, 0.0))
            .collect();
        let spectrum = oxifft::fft(&fft_in);
        // Only the lower half of the spectrum is used for the salience search.
        let n_bins = spectrum.len() / 2;
        let mags: Vec<f32> = spectrum[..n_bins].iter().map(|c| c.norm()).collect();
        let (best_midi, best_sal) = self.compute_salience(&mags, n_bins, sr);
        raw_midi.push(best_midi);
        saliences.push(best_sal);
        timestamps.push(frame_idx as f32 * hop as f32 / sr);
    }

    // Voicing decision: the threshold is relative to this signal's own mean
    // salience rather than an absolute level.
    let mean_sal: f32 = if saliences.is_empty() {
        0.0
    } else {
        saliences.iter().sum::<f32>() / saliences.len() as f32
    };
    let threshold = mean_sal * VOICED_THRESHOLD;
    let voiced_midi: Vec<Option<u8>> = raw_midi
        .iter()
        .zip(saliences.iter())
        .map(|(&m, &s)| if s >= threshold { m } else { None })
        .collect();

    let midi_pitch = self.smooth_pitch(&voiced_midi, self.config.smooth_radius);
    let pitch_hz: Vec<Option<f32>> = midi_pitch.iter().map(|m| m.map(midi_to_hz)).collect();

    let voiced_pitches: Vec<f32> = pitch_hz.iter().flatten().copied().collect();
    // `.max(1)` keeps the ratio defined even for an empty track.
    let voiced_fraction = voiced_pitches.len() as f32 / pitch_hz.len().max(1) as f32;
    let mean_pitch_hz = if voiced_pitches.is_empty() {
        0.0
    } else {
        voiced_pitches.iter().sum::<f32>() / voiced_pitches.len() as f32
    };

    // `sr / hop` is the frame rate in frames per second.
    let vibrato = self.detect_vibrato(&pitch_hz, &timestamps, sr / hop as f32);
    Ok(ExtractionResult {
        pitch_hz,
        midi_pitch,
        salience: saliences,
        timestamps_s: timestamps,
        vibrato,
        voiced_fraction,
        mean_pitch_hz,
    })
}
/// Finds the MIDI note in `MIDI_MIN..=MIDI_MAX` with the largest harmonic-sum salience.
///
/// For each candidate note, the magnitudes at the bins nearest to its first
/// `N_HARMONICS` harmonics are summed, the `k`-th harmonic weighted by
/// `1 / k^harmonic_decay_exp`. Harmonics at or beyond bin `n_bins` are not
/// counted.
///
/// `mags` holds `n_bins` magnitudes from the lower half of the spectrum and
/// `sr` is the sample rate in Hz. Returns the winning note and its salience,
/// or `(None, 0.0)` when no candidate has a salience above zero or the bin
/// width is degenerate.
fn compute_salience(&self, mags: &[f32], n_bins: usize, sr: f32) -> (Option<u8>, f32) {
    // `n_bins` is half the FFT length, so one bin spans `sr / (2 * n_bins)` Hz.
    let bin_width_hz = sr / (2.0 * n_bins as f32);
    if bin_width_hz < f32::EPSILON {
        return (None, 0.0);
    }

    let decay = self.config.harmonic_decay_exp;
    let mut winner: (Option<u8>, f32) = (None, 0.0);
    for note in MIDI_MIN..=MIDI_MAX {
        let fundamental = midi_to_hz(note);
        let score = (1..=N_HARMONICS as u32)
            .map(|k| {
                let k = k as f32;
                ((fundamental * k / bin_width_hz).round() as usize, k)
            })
            // Stop at the first harmonic that falls outside the analysed band.
            .take_while(|&(idx, _)| idx < n_bins)
            .fold(0.0_f32, |acc, (idx, k)| {
                acc + mags[idx] * (1.0 / k.powf(decay))
            });
        // Strict comparison: on ties the lowest note wins.
        if score > winner.1 {
            winner = (Some(note), score);
        }
    }
    winner
}
/// Median-filters a MIDI pitch track over a window of `radius` frames on each side.
///
/// For every frame, the voiced (`Some`) notes inside the window — clipped at
/// the ends of the track — are sorted and the middle one is taken; with an
/// even count this is the upper of the two central values. A frame whose
/// window holds no voiced note stays `None`. An unvoiced frame with voiced
/// neighbours inside the window is therefore filled in.
///
/// Returns a vector of the same length as `midi`.
fn smooth_pitch(&self, midi: &[Option<u8>], radius: usize) -> Vec<Option<u8>> {
    let n = midi.len();
    (0..n)
        .map(|i| {
            let lo = i.saturating_sub(radius);
            // Saturating adds keep a very large `radius` from overflowing;
            // the window is clipped to the track length either way.
            let hi = i.saturating_add(radius).saturating_add(1).min(n);
            let mut voiced: Vec<u8> = midi[lo..hi].iter().flatten().copied().collect();
            if voiced.is_empty() {
                return None;
            }
            voiced.sort_unstable();
            Some(voiced[voiced.len() / 2])
        })
        .collect()
}
/// Scans the pitch track for contiguous voiced runs and reports those that look like vibrato.
///
/// A run is reported when all of the following hold:
/// - it has at least `VIBRATO_MIN_FRAMES` frames;
/// - its peak-to-peak deviation from the run's mean pitch is at least
///   `VIBRATO_MIN_EXTENT_ST` semitones;
/// - its modulation rate, estimated as half the number of sign changes of
///   the deviation per second, lies within
///   `VIBRATO_MIN_RATE_HZ..=VIBRATO_MAX_RATE_HZ`.
///
/// `timestamps` supplies the start and end time of each reported segment and
/// `frame_rate` is the number of frames per second. Returns an empty vector
/// when the track is shorter than `VIBRATO_MIN_FRAMES` or `frame_rate` is
/// effectively zero.
fn detect_vibrato(
    &self,
    pitch_hz: &[Option<f32>],
    timestamps: &[f32],
    frame_rate: f32,
) -> Vec<VibratoSegment> {
    let total = pitch_hz.len();
    if frame_rate < f32::EPSILON || total < VIBRATO_MIN_FRAMES {
        return Vec::new();
    }

    // Evaluates the run `first..past_last` and builds a segment if it qualifies.
    let analyse_run = |first: usize, past_last: usize| -> Option<VibratoSegment> {
        let run_hz: Vec<f32> = pitch_hz[first..past_last]
            .iter()
            .flatten()
            .copied()
            .collect();
        if run_hz.len() < VIBRATO_MIN_FRAMES {
            return None;
        }

        let centre_hz = run_hz.iter().sum::<f32>() / run_hz.len() as f32;
        if centre_hz < f32::EPSILON {
            return None;
        }

        // Deviation of every frame from the run's mean pitch, in semitones.
        let deviation_st: Vec<f32> = run_hz
            .iter()
            .map(|&hz| 12.0 * (hz / centre_hz).log2())
            .collect();
        let (lowest, highest) = deviation_st.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lo, hi), &st| (lo.min(st), hi.max(st)),
        );
        let extent = highest - lowest;
        if extent < VIBRATO_MIN_EXTENT_ST {
            return None;
        }

        // Two sign changes of the deviation make up one modulation cycle.
        let sign_changes = deviation_st
            .windows(2)
            .filter(|pair| pair[0] * pair[1] < 0.0)
            .count();
        let duration_s = run_hz.len() as f32 / frame_rate;
        let rate_hz = sign_changes as f32 / (2.0 * duration_s);
        if rate_hz < VIBRATO_MIN_RATE_HZ || rate_hz > VIBRATO_MAX_RATE_HZ {
            return None;
        }

        Some(VibratoSegment {
            start_s: timestamps.get(first).copied().unwrap_or(0.0),
            end_s: timestamps
                .get(past_last.saturating_sub(1))
                .copied()
                .unwrap_or(duration_s),
            rate_hz,
            extent_semitones: extent,
        })
    };

    let mut segments = Vec::new();
    let mut run_start: Option<usize> = None;
    // The extra iteration at `idx == total` acts as an unvoiced sentinel that
    // closes a run reaching the end of the track.
    for idx in 0..=total {
        let is_voiced = matches!(pitch_hz.get(idx), Some(Some(_)));
        if is_voiced {
            if run_start.is_none() {
                run_start = Some(idx);
            }
        } else if let Some(first) = run_start.take() {
            segments.extend(analyse_run(first, idx));
        }
    }
    segments
}
}
/// Converts a MIDI note number to its equal-tempered frequency in Hz,
/// with note 69 (A4) tuned to 440 Hz and 12 notes per octave.
#[must_use]
pub fn midi_to_hz(midi: u8) -> f32 {
    let semitones_from_a4 = f32::from(midi) - 69.0;
    let octaves_from_a4 = semitones_from_a4 / 12.0;
    440.0 * 2.0_f32.powf(octaves_from_a4)
}
/// Converts a frequency in Hz to the nearest MIDI note number (A4 = 440 Hz = note 69).
///
/// Returns `None` when `hz` is NaN, infinite, zero or negative, or when the
/// unrounded note value falls outside `0.0..=127.0`.
#[must_use]
pub fn hz_to_midi(hz: f32) -> Option<u8> {
    // NaN fails every ordered comparison, so it needs an explicit rejection:
    // otherwise it would slip through the range check and be cast to note 0.
    if !hz.is_finite() || hz <= 0.0 {
        return None;
    }
    let midi = 69.0 + 12.0 * (hz / 440.0).log2();
    if (0.0..=127.0).contains(&midi) {
        Some(midi.round() as u8)
    } else {
        None
    }
}
/// Builds a symmetric Hann window of `size` coefficients:
/// `w[i] = 0.5 * (1 - cos(2π·i / (size - 1)))`.
///
/// Both end points are (approximately) zero and the centre is one. An empty
/// vector is returned for `size == 0`, and `[1.0]` for `size == 1`, where the
/// formula would otherwise divide zero by zero.
fn hann_window(size: usize) -> Vec<f32> {
    match size {
        0 => Vec::new(),
        // A single-point window has no taper to apply.
        1 => vec![1.0],
        _ => {
            let denom = (size - 1) as f32;
            (0..size)
                .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / denom).cos()))
                .collect()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::TAU;

    /// Generates `secs` seconds of a unit-amplitude sine at `freq_hz`, sampled at `sr` Hz.
    fn sine(freq_hz: f32, sr: u32, secs: f32) -> Vec<f32> {
        let len = (sr as f32 * secs) as usize;
        let mut wave = Vec::with_capacity(len);
        for i in 0..len {
            wave.push((TAU * freq_hz * i as f32 / sr as f32).sin());
        }
        wave
    }

    /// Runs the extractor on a 440 Hz sine with the given framing and returns the result.
    fn extract_a440(sr: u32, window_size: usize, hop_size: usize, secs: f32) -> ExtractionResult {
        let config = MelodyExtractorConfig {
            window_size,
            hop_size,
            ..Default::default()
        };
        let samples = sine(440.0, sr, secs);
        MelodyExtractor::new(sr, config)
            .extract(&samples)
            .expect("extract failed")
    }

    // --- MIDI <-> Hz conversions ---

    #[test]
    fn test_midi_69_is_a440() {
        let hz = midi_to_hz(69);
        assert!((hz - 440.0).abs() < 1e-3, "A4 = {hz}");
    }

    #[test]
    fn test_midi_60_is_middle_c() {
        let hz = midi_to_hz(60);
        assert!((hz - 261.63).abs() < 1.0, "C4 = {hz}");
    }

    #[test]
    fn test_hz_to_midi_a440() {
        assert_eq!(hz_to_midi(440.0), Some(69));
    }

    #[test]
    fn test_hz_to_midi_zero_is_none() {
        assert!(hz_to_midi(0.0).is_none());
    }

    #[test]
    fn test_midi_roundtrip() {
        for midi in 36_u8..=96 {
            let back = hz_to_midi(midi_to_hz(midi)).expect("should round-trip");
            assert_eq!(back, midi, "round-trip failed for MIDI {midi}");
        }
    }

    // --- ExtractionResult accessors ---

    #[test]
    fn test_n_voiced_counts_some() {
        let result = ExtractionResult {
            pitch_hz: vec![Some(440.0), None, Some(220.0)],
            midi_pitch: vec![Some(69), None, Some(57)],
            salience: vec![1.0, 0.0, 0.8],
            timestamps_s: vec![0.0, 0.1, 0.2],
            vibrato: vec![],
            voiced_fraction: 2.0 / 3.0,
            mean_pitch_hz: 330.0,
        };
        assert_eq!(result.n_voiced(), 2);
        assert_eq!(result.n_frames(), 3);
    }

    #[test]
    fn test_voiced_pitches_filters_nones() {
        let result = ExtractionResult {
            pitch_hz: vec![None, Some(440.0), None, Some(220.0)],
            midi_pitch: vec![None, Some(69), None, Some(57)],
            salience: vec![0.0, 1.0, 0.0, 0.9],
            timestamps_s: vec![0.0, 0.1, 0.2, 0.3],
            vibrato: vec![],
            voiced_fraction: 0.5,
            mean_pitch_hz: 330.0,
        };
        assert_eq!(result.voiced_pitches(), vec![440.0, 220.0]);
    }

    // --- End-to-end extraction ---

    #[test]
    fn test_extract_sine_returns_voiced_frames() {
        let result = extract_a440(22050, 1024, 256, 1.0);
        assert!(
            result.n_voiced() > 0,
            "expected voiced frames for 440 Hz sine"
        );
    }

    #[test]
    fn test_extract_too_short_returns_error() {
        let extractor = MelodyExtractor::new(22050, MelodyExtractorConfig::default());
        let too_short = vec![0.0_f32; 100];
        assert!(extractor.extract(&too_short).is_err());
    }

    #[test]
    fn test_pitch_frames_have_timestamps() {
        let result = extract_a440(22050, 512, 128, 0.5);
        assert_eq!(result.timestamps_s.len(), result.pitch_hz.len());
        for pair in result.timestamps_s.windows(2) {
            let (earlier, later) = (pair[0], pair[1]);
            assert!(
                later >= earlier,
                "timestamps not monotone: {} >= {}",
                earlier,
                later
            );
        }
    }

    #[test]
    fn test_voiced_fraction_in_unit_range() {
        let result = extract_a440(22050, 1024, 512, 1.0);
        assert!(
            result.voiced_fraction >= 0.0 && result.voiced_fraction <= 1.0,
            "voiced_fraction = {}",
            result.voiced_fraction
        );
    }

    #[test]
    fn test_a440_pitch_detected_near_440hz() {
        let result = extract_a440(44100, 4096, 1024, 2.0);
        let vp = result.voiced_pitches();
        assert!(!vp.is_empty(), "no voiced pitches detected");
        let mean: f32 = vp.iter().sum::<f32>() / vp.len() as f32;
        assert!(
            mean > 380.0 && mean < 520.0,
            "mean pitch far from 440 Hz: {mean}"
        );
    }

    // --- Window function ---

    #[test]
    fn test_hann_window_length() {
        assert_eq!(hann_window(1024).len(), 1024);
    }

    #[test]
    fn test_hann_window_edges_near_zero() {
        let w = hann_window(512);
        assert!(w[0].abs() < 1e-4);
        assert!(w[511].abs() < 1e-2);
    }
}