use crate::spectral::SpectralFeatures;
use crate::transient::TransientResult;
/// Instrument categories produced by the heuristic classifiers in this module.
///
/// [`detect_instrument`] returns exactly one of these labels, while
/// [`detect_instrument_scores`] assigns a score to every variant except
/// [`Instrument::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Instrument {
Piano,
Guitar,
Violin,
Flute,
Trumpet,
Drums,
Bass,
Synthesizer,
Vocals,
/// Fallback returned by [`detect_instrument`] when none of its rules match.
Unknown,
}
/// Picks a single [`Instrument`] label from spectral and transient features.
///
/// The rules are evaluated in a fixed order and the first match wins:
///
/// 1. Noise-like spectrum (`flatness > 0.7`) with strong transients
///    (`avg_strength > 0.5`) is classified as drums.
/// 2. Tonal spectrum (`flatness < 0.3`) with a centroid below 500 is bass.
/// 3. When `f0` is `Some`, the pitched rules run in this order: vocals, flute,
///    piano, guitar, violin, trumpet.
/// 4. Anything that is not noise-like and has no formant-like peak structure
///    is labelled a synthesizer.
/// 5. Otherwise the result is [`Instrument::Unknown`].
///
/// # Parameters
/// * `spectral` - spectral features; `flatness`, `centroid`, `bandwidth` and
///   `magnitude_spectrum` are read.
/// * `transients` - transient analysis; only `avg_strength` is read.
/// * `f0` - optional fundamental estimate; `None` skips all pitched rules.
// The lint is kept from the original: the function is one long rule list.
#[allow(clippy::too_many_lines)]
#[must_use]
pub fn detect_instrument(
    spectral: &SpectralFeatures,
    transients: &TransientResult,
    f0: Option<f32>,
) -> Instrument {
    let flatness = spectral.flatness;
    let centroid = spectral.centroid;

    let tonal = flatness < 0.3;
    let noise_like = flatness > 0.7;
    let percussive = transients.avg_strength > 0.5;

    if noise_like && percussive {
        return Instrument::Drums;
    }
    if tonal && centroid < 500.0 {
        return Instrument::Bass;
    }

    // Evaluated lazily so the spectrum is only scanned when a rule needs it.
    let formant_like = || check_formant_structure(&spectral.magnitude_spectrum);

    let pitched_guess = f0.and_then(|pitch| {
        if (80.0..=1000.0).contains(&pitch) && formant_like() {
            Some(Instrument::Vocals)
        } else if !tonal {
            // Every remaining pitched rule requires a tonal spectrum.
            None
        } else if pitch >= 250.0 && flatness < 0.15 {
            Some(Instrument::Flute)
        } else if percussive && spectral.bandwidth > 1000.0 {
            Some(Instrument::Piano)
        } else if !percussive && pitch >= 80.0 {
            Some(Instrument::Guitar)
        } else if centroid > 2000.0 && pitch >= 200.0 {
            Some(Instrument::Violin)
        } else if centroid > 800.0 && centroid < 2000.0 {
            Some(Instrument::Trumpet)
        } else {
            None
        }
    });

    match pitched_guess {
        Some(instrument) => instrument,
        None if !noise_like && !formant_like() => Instrument::Synthesizer,
        None => Instrument::Unknown,
    }
}
/// Reports whether `spectrum` has a formant-like number of local peaks.
///
/// A bin counts as a peak when it is strictly greater than both direct
/// neighbours and greater than `0.1`. Only bins `2..len - 2` are considered,
/// so the two outermost bins on each side can never be peaks. Returns `true`
/// when between two and four peaks (inclusive) are found, and always `false`
/// for spectra shorter than 20 bins.
fn check_formant_structure(spectrum: &[f32]) -> bool {
    let len = spectrum.len();
    if len < 20 {
        return false;
    }

    // Windows over `[1, len - 1)` are centred on exactly the bins 2..=len - 3.
    let peak_count = spectrum[1..len - 1]
        .windows(3)
        .filter(|w| w[1] > w[0] && w[1] > w[2] && w[1] > 0.1)
        .count();

    (2..=4).contains(&peak_count)
}
/// Scores every instrument except [`Instrument::Unknown`] against the features.
///
/// Each instrument has one rule with a fixed weight; its score is that weight
/// when the rule matches and `0.0` otherwise. Unlike [`detect_instrument`],
/// the rules are independent, so several instruments can score at once.
/// Drums and bass are scored from the spectrum alone; every other rule only
/// fires when `f0` is `Some`. `Synthesizer` has no rule and always scores `0.0`.
///
/// # Returns
/// Nine `(instrument, score)` pairs sorted by descending score. The sort is
/// stable, so equal scores keep the order Piano, Guitar, Violin, Flute,
/// Trumpet, Drums, Bass, Vocals, Synthesizer.
#[must_use]
pub fn detect_instrument_scores(
    spectral: &SpectralFeatures,
    transients: &TransientResult,
    f0: Option<f32>,
) -> Vec<(Instrument, f32)> {
    use std::cmp::Ordering;

    let flatness = spectral.flatness;
    let centroid = spectral.centroid;
    let strength = transients.avg_strength;

    let has_pitch = f0.is_some();
    let pitch_at_least = |floor: f32| f0.map_or(false, |pitch| pitch >= floor);
    let vocal_like = f0.map_or(false, |pitch| (80.0..=1000.0).contains(&pitch))
        && check_formant_structure(&spectral.magnitude_spectrum);

    // (instrument, rule matched, weight awarded on a match) in output order.
    let rules: [(Instrument, bool, f32); 9] = [
        (
            Instrument::Piano,
            has_pitch && strength > 0.5 && spectral.bandwidth > 1000.0,
            0.7,
        ),
        (Instrument::Guitar, flatness < 0.3 && pitch_at_least(80.0), 0.6),
        (
            Instrument::Violin,
            has_pitch && centroid > 2000.0 && flatness < 0.3,
            0.6,
        ),
        (Instrument::Flute, pitch_at_least(250.0) && flatness < 0.15, 0.7),
        (
            Instrument::Trumpet,
            has_pitch && centroid > 800.0 && centroid < 2000.0,
            0.6,
        ),
        (Instrument::Drums, flatness > 0.5 && strength > 0.4, 0.8),
        (Instrument::Bass, centroid < 500.0, 0.7),
        (Instrument::Vocals, vocal_like, 0.8),
        (Instrument::Synthesizer, false, 0.0),
    ];

    let mut ranked: Vec<(Instrument, f32)> = rules
        .iter()
        .map(|&(instrument, matched, weight)| (instrument, if matched { weight } else { 0.0 }))
        .collect();

    // Highest score first; incomparable values are treated as equal.
    ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(Ordering::Equal));
    ranked
}
use std::collections::HashMap;
/// Frequency bands in which onsets are searched independently.
///
/// The bands are not disjoint: `Kick` and `Bass` overlap at the low end, and
/// `HiHat` is a sub-range of `Treble`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InstrumentBand {
    Kick,
    Bass,
    MidRange,
    Treble,
    HiHat,
}

impl InstrumentBand {
    /// Every band, in declaration order.
    const ALL: [Self; 5] = [
        Self::Kick,
        Self::Bass,
        Self::MidRange,
        Self::Treble,
        Self::HiHat,
    ];

    /// Returns the `(low, high)` edges of the band in Hz.
    #[must_use]
    pub fn range_hz(self) -> (f64, f64) {
        let lower_edge = match self {
            Self::Kick => 0.0,
            Self::Bass => 20.0,
            Self::MidRange => 300.0,
            Self::Treble => 3_000.0,
            Self::HiHat => 8_000.0,
        };
        let upper_edge = match self {
            Self::Kick => 200.0,
            Self::Bass => 300.0,
            Self::MidRange => 3_000.0,
            Self::Treble | Self::HiHat => 20_000.0,
        };
        (lower_edge, upper_edge)
    }

    /// Returns all five bands in declaration order.
    #[must_use]
    pub fn all() -> [Self; 5] {
        Self::ALL
    }

    /// Base spectral-flux threshold for this band; onset detection adds the
    /// local mean flux on top of it.
    fn threshold(self) -> f64 {
        match self {
            Self::Kick | Self::Treble => 0.02,
            Self::Bass => 0.015,
            Self::MidRange => 0.025,
            Self::HiHat => 0.018,
        }
    }
}
/// Configuration for per-band onset detection.
///
/// Both fields are measured in samples. The defaults are a 1024-sample
/// analysis window advanced by 256 samples per frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InstrumentOnsetDetector {
    /// Length of each analysis window, in samples.
    pub window_size: usize,
    /// Number of samples the window advances between consecutive frames.
    pub hop_size: usize,
}

impl Default for InstrumentOnsetDetector {
    /// Returns a detector with `window_size = 1024` and `hop_size = 256`.
    fn default() -> Self {
        Self {
            window_size: 1024,
            hop_size: 256,
        }
    }
}
impl InstrumentOnsetDetector {
    /// Creates a detector with the default window and hop sizes.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Detects onsets separately in every [`InstrumentBand`].
    ///
    /// The magnitude spectrogram does not depend on the band, so it is
    /// computed once and shared by all bands.
    ///
    /// # Parameters
    /// * `samples` - mono audio samples.
    /// * `sample_rate` - sampling rate of `samples` in Hz.
    ///
    /// # Returns
    /// A map with one entry per band holding onset times in seconds, in
    /// ascending order. Every list is empty when the configuration cannot
    /// produce frames or times: `sample_rate == 0`, `hop_size == 0`, or
    /// `window_size < 2`.
    #[must_use]
    pub fn detect_onsets_per_instrument(
        &self,
        samples: &[f32],
        sample_rate: u32,
    ) -> HashMap<InstrumentBand, Vec<f64>> {
        // A zero hop never advances the framing position and a zero sample
        // rate yields infinite onset times, so skip the analysis entirely.
        let usable = sample_rate > 0 && self.hop_size > 0 && self.window_size >= 2;
        let frames = if usable {
            compute_magnitude_frames(samples, self.window_size, self.hop_size)
        } else {
            Vec::new()
        };

        let bands = InstrumentBand::all();
        let mut result: HashMap<InstrumentBand, Vec<f64>> = HashMap::with_capacity(bands.len());
        for band in bands {
            result.insert(band, self.detect_band_onsets(&frames, sample_rate, band));
        }
        result
    }

    /// Picks onsets for one band from precomputed magnitude frames.
    ///
    /// The positive spectral flux between consecutive frames is summed over
    /// the band's bins. A flux value is reported as an onset when it exceeds
    /// the band's base threshold plus the mean flux of the current and up to
    /// eight preceding values, and is not smaller than either neighbour.
    /// The reported time is the start of the later frame of the pair.
    fn detect_band_onsets(
        &self,
        frames: &[Vec<f32>],
        sample_rate: u32,
        band: InstrumentBand,
    ) -> Vec<f64> {
        /// Number of preceding flux values included in the adaptive mean.
        const HISTORY_FRAMES: usize = 8;

        let (low_hz, high_hz) = band.range_hz();
        let sr = f64::from(sample_rate);
        let n = self.window_size;

        // Map the band edges to DFT bins, clamped to the Nyquist bin.
        let low_bin = ((low_hz * n as f64 / sr).round() as usize).min(n / 2);
        let high_bin = ((high_hz * n as f64 / sr).round() as usize).min(n / 2);
        if high_bin <= low_bin || frames.len() < 2 {
            return Vec::new();
        }

        let flux_values: Vec<f64> = frames
            .windows(2)
            .map(|pair| {
                let (prev, curr) = (&pair[0], &pair[1]);
                let end = high_bin.min(curr.len()).min(prev.len());
                let start = low_bin.min(end);
                curr[start..end]
                    .iter()
                    .zip(&prev[start..end])
                    .fold(0.0_f64, |acc, (&now, &before)| {
                        let rise = f64::from(now) - f64::from(before);
                        // Only energy increases contribute to the flux.
                        if rise > 0.0 {
                            acc + rise
                        } else {
                            acc
                        }
                    })
            })
            .collect();

        let base_threshold = band.threshold();
        let mut onsets = Vec::new();
        for (i, &flux) in flux_values.iter().enumerate() {
            let history = &flux_values[i.saturating_sub(HISTORY_FRAMES)..=i];
            let mean = history.iter().sum::<f64>() / history.len() as f64;
            let adaptive_threshold = (mean + base_threshold).max(base_threshold);
            if flux > adaptive_threshold {
                // Missing neighbours at either end count as zero flux.
                let prev_flux = if i > 0 { flux_values[i - 1] } else { 0.0 };
                let next_flux = flux_values.get(i + 1).copied().unwrap_or(0.0);
                if flux >= prev_flux && flux >= next_flux {
                    // Flux `i` compares frames `i` and `i + 1`.
                    onsets.push((i + 1) as f64 * self.hop_size as f64 / sr);
                }
            }
        }
        onsets
    }
}
/// Computes a magnitude spectrogram with a direct (non-FFT) DFT.
///
/// `samples` is cut into frames of `window_size` samples, each starting
/// `hop_size` samples after the previous one. Every frame is multiplied by a
/// Hann window and transformed bin by bin; trailing samples that do not fill
/// a whole window are ignored.
///
/// # Returns
/// One `Vec<f32>` per frame holding `window_size / 2 + 1` magnitudes
/// (bin 0 up to and including the Nyquist bin). The result is empty when
/// `samples` is shorter than `window_size`, or when `hop_size` is zero.
fn compute_magnitude_frames(samples: &[f32], window_size: usize, hop_size: usize) -> Vec<Vec<f32>> {
    use std::f64::consts::PI;

    // With a zero hop the frame position would never advance, so the loop
    // below would run forever while growing `frames`.
    if hop_size == 0 {
        return Vec::new();
    }

    // The window only depends on the sample index, so build it once instead
    // of once per (frame, bin, sample).
    let hann: Vec<f64> = (0..window_size)
        .map(|j| 0.5 * (1.0 - (2.0 * PI * j as f64 / (window_size - 1) as f64).cos()))
        .collect();

    let num_bins = window_size / 2 + 1;
    let mut frames = Vec::new();
    let mut windowed = vec![0.0_f64; window_size];
    let mut pos = 0_usize;
    while pos + window_size <= samples.len() {
        let frame = &samples[pos..pos + window_size];
        for ((dst, &s), &w) in windowed.iter_mut().zip(frame).zip(&hann) {
            *dst = f64::from(s) * w;
        }

        let magnitudes: Vec<f32> = (0..num_bins)
            .map(|k| {
                let mut re = 0.0_f64;
                let mut im = 0.0_f64;
                for (j, &sv) in windowed.iter().enumerate() {
                    let angle = -2.0 * PI * k as f64 * j as f64 / window_size as f64;
                    re += sv * angle.cos();
                    im += sv * angle.sin();
                }
                (re * re + im * im).sqrt() as f32
            })
            .collect();

        frames.push(magnitudes);
        pos += hop_size;
    }
    frames
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `repeats` sine bursts at `freq_hz` (amplitude 0.8), each
    /// `burst_secs` long and followed by `gap_secs` of silence.
    fn tone_bursts(
        sample_rate: u32,
        freq_hz: f64,
        burst_secs: f64,
        gap_secs: f64,
        repeats: usize,
    ) -> Vec<f32> {
        let sr = sample_rate as f64;
        let burst_len = (burst_secs * sr) as usize;
        let gap_len = (gap_secs * sr) as usize;

        let mut signal: Vec<f32> = Vec::with_capacity(repeats * (burst_len + gap_len));
        for _ in 0..repeats {
            signal.extend((0..burst_len).map(|i| {
                let t = i as f64 / sr;
                (2.0 * std::f64::consts::PI * freq_hz * t).sin() as f32 * 0.8
            }));
            signal.extend(std::iter::repeat(0.0_f32).take(gap_len));
        }
        signal
    }

    /// A noisy spectrum with strong transients is classified as drums.
    #[test]
    fn test_instrument_detection() {
        let transients = TransientResult {
            transient_times: vec![0.1, 0.2, 0.3],
            onset_strength: vec![0.8, 0.7, 0.9],
            num_transients: 3,
            avg_strength: 0.8,
        };
        let spectral = SpectralFeatures {
            centroid: 1000.0,
            flatness: 0.8,
            crest: 5.0,
            bandwidth: 3000.0,
            rolloff: 5000.0,
            flux: 0.0,
            magnitude_spectrum: vec![0.5; 100],
        };

        assert_eq!(
            detect_instrument(&spectral, &transients, None),
            Instrument::Drums
        );
    }

    /// A low spectral centroid gives bass a score above 0.5.
    #[test]
    fn test_instrument_scores() {
        let spectral = SpectralFeatures {
            centroid: 300.0,
            flatness: 0.2,
            crest: 3.0,
            bandwidth: 500.0,
            rolloff: 800.0,
            flux: 0.0,
            magnitude_spectrum: vec![0.5; 100],
        };
        let transients = TransientResult::default();

        let scores = detect_instrument_scores(&spectral, &transients, Some(100.0));
        let bass_score = scores
            .iter()
            .find(|entry| entry.0 == Instrument::Bass)
            .map(|entry| entry.1)
            .expect("unexpected None/Err");

        assert!(bass_score > 0.5);
    }

    /// 80 Hz bursts produce onsets in the kick band, all within the signal.
    #[test]
    fn test_kick_onset_in_low_band() {
        let sample_rate: u32 = 8000;
        let samples = tone_bursts(sample_rate, 80.0, 0.05, 0.1, 3);
        let detector = InstrumentOnsetDetector {
            window_size: 256,
            hop_size: 64,
        };

        let onsets = detector.detect_onsets_per_instrument(&samples, sample_rate);
        let kick_onsets = onsets
            .get(&InstrumentBand::Kick)
            .expect("Kick band missing");

        assert!(
            !kick_onsets.is_empty(),
            "Expected at least one kick onset, got none"
        );

        let duration = samples.len() as f64 / sample_rate as f64;
        for &t in kick_onsets {
            assert!(
                t <= duration + 0.1,
                "Onset time {t} s is out of signal range"
            );
        }
    }

    /// 10 kHz bursts register in the hi-hat band at least as often as in
    /// the kick band.
    #[test]
    fn test_instrument_onset_bands_independent() {
        let sample_rate: u32 = 22050;
        let samples = tone_bursts(sample_rate, 10_000.0, 0.02, 0.08, 3);
        let detector = InstrumentOnsetDetector {
            window_size: 256,
            hop_size: 64,
        };

        let onsets = detector.detect_onsets_per_instrument(&samples, sample_rate);

        let hihat_onsets = onsets.get(&InstrumentBand::HiHat).expect("HiHat missing");
        assert!(
            !hihat_onsets.is_empty(),
            "Expected hi-hat onsets in HiHat band, got none"
        );

        let kick_onsets = onsets.get(&InstrumentBand::Kick).expect("Kick missing");
        assert!(
            kick_onsets.len() <= hihat_onsets.len(),
            "Kick band ({}) should have ≤ onsets than HiHat band ({})",
            kick_onsets.len(),
            hihat_onsets.len()
        );
    }
}