#![allow(dead_code)]
use std::collections::HashMap;
use std::fmt;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MusicGenre {
Rock,
Pop,
Electronic,
HipHop,
Jazz,
Classical,
Country,
RnB,
Metal,
Blues,
Reggae,
Ambient,
}
impl fmt::Display for MusicGenre {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Rock => write!(f, "Rock"),
Self::Pop => write!(f, "Pop"),
Self::Electronic => write!(f, "Electronic"),
Self::HipHop => write!(f, "Hip-Hop"),
Self::Jazz => write!(f, "Jazz"),
Self::Classical => write!(f, "Classical"),
Self::Country => write!(f, "Country"),
Self::RnB => write!(f, "R&B"),
Self::Metal => write!(f, "Metal"),
Self::Blues => write!(f, "Blues"),
Self::Reggae => write!(f, "Reggae"),
Self::Ambient => write!(f, "Ambient"),
}
}
}
impl MusicGenre {
#[must_use]
pub fn all() -> &'static [MusicGenre] {
&[
Self::Rock,
Self::Pop,
Self::Electronic,
Self::HipHop,
Self::Jazz,
Self::Classical,
Self::Country,
Self::RnB,
Self::Metal,
Self::Blues,
Self::Reggae,
Self::Ambient,
]
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct GenreFeatures {
pub spectral_centroid_hz: f64,
pub spectral_rolloff_hz: f64,
pub zero_crossing_rate: f64,
pub rms_energy: f64,
pub tempo_bpm: f64,
pub spectral_flatness: f64,
pub mfcc1: f64,
}
impl GenreFeatures {
#[must_use]
pub fn new(
spectral_centroid_hz: f64,
spectral_rolloff_hz: f64,
zero_crossing_rate: f64,
rms_energy: f64,
tempo_bpm: f64,
spectral_flatness: f64,
mfcc1: f64,
) -> Self {
Self {
spectral_centroid_hz,
spectral_rolloff_hz,
zero_crossing_rate,
rms_energy,
tempo_bpm,
spectral_flatness,
mfcc1,
}
}
}
#[derive(Debug, Clone)]
pub struct GenreClassification {
pub scores: HashMap<MusicGenre, f64>,
}
impl GenreClassification {
#[must_use]
pub fn top(&self) -> (MusicGenre, f64) {
self.scores
.iter()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map_or((MusicGenre::Pop, 0.0), |(&g, &s)| (g, s))
}
#[must_use]
pub fn top_n(&self, n: usize) -> Vec<(MusicGenre, f64)> {
let mut sorted: Vec<(MusicGenre, f64)> =
self.scores.iter().map(|(&g, &s)| (g, s)).collect();
sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
sorted.truncate(n);
sorted
}
#[must_use]
pub fn is_confident(&self, threshold: f64) -> bool {
self.top().1 >= threshold
}
}
#[derive(Debug)]
pub struct GenreClassifier {
pub min_confidence: f64,
}
impl Default for GenreClassifier {
fn default() -> Self {
Self {
min_confidence: 0.1,
}
}
}
impl GenreClassifier {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn classify(&self, features: &GenreFeatures) -> GenreClassification {
let prototypes = Self::genre_prototypes();
let mut raw_scores: HashMap<MusicGenre, f64> = HashMap::new();
for (genre, proto) in &prototypes {
let dist = Self::feature_distance(features, proto);
let sim = (-dist * 0.5).exp();
raw_scores.insert(*genre, sim);
}
let total: f64 = raw_scores.values().sum();
let scores: HashMap<MusicGenre, f64> = if total > 0.0 {
raw_scores
.into_iter()
.map(|(g, s)| (g, s / total))
.collect()
} else {
raw_scores
};
GenreClassification { scores }
}
fn feature_distance(a: &GenreFeatures, b: &GenreFeatures) -> f64 {
let dc = (a.spectral_centroid_hz - b.spectral_centroid_hz) / 2000.0;
let dr = (a.spectral_rolloff_hz - b.spectral_rolloff_hz) / 4000.0;
let dz = (a.zero_crossing_rate - b.zero_crossing_rate) / 100.0;
let de = (a.rms_energy - b.rms_energy) / 0.3;
let dt = (a.tempo_bpm - b.tempo_bpm) / 40.0;
let df = (a.spectral_flatness - b.spectral_flatness) / 0.3;
let dm = (a.mfcc1 - b.mfcc1) / 50.0;
(dc * dc + dr * dr + dz * dz + de * de + dt * dt + df * df + dm * dm).sqrt()
}
fn genre_prototypes() -> Vec<(MusicGenre, GenreFeatures)> {
vec![
(
MusicGenre::Rock,
GenreFeatures::new(3000.0, 6000.0, 80.0, 0.25, 120.0, 0.15, 20.0),
),
(
MusicGenre::Pop,
GenreFeatures::new(2500.0, 5000.0, 60.0, 0.20, 115.0, 0.12, 15.0),
),
(
MusicGenre::Electronic,
GenreFeatures::new(3500.0, 7000.0, 40.0, 0.30, 128.0, 0.25, 10.0),
),
(
MusicGenre::HipHop,
GenreFeatures::new(2000.0, 4500.0, 50.0, 0.22, 90.0, 0.18, 12.0),
),
(
MusicGenre::Jazz,
GenreFeatures::new(2200.0, 5500.0, 45.0, 0.12, 100.0, 0.10, 25.0),
),
(
MusicGenre::Classical,
GenreFeatures::new(1800.0, 4000.0, 30.0, 0.10, 80.0, 0.05, 30.0),
),
(
MusicGenre::Country,
GenreFeatures::new(2800.0, 5500.0, 55.0, 0.18, 110.0, 0.10, 18.0),
),
(
MusicGenre::RnB,
GenreFeatures::new(2300.0, 5000.0, 50.0, 0.18, 95.0, 0.14, 14.0),
),
(
MusicGenre::Metal,
GenreFeatures::new(4000.0, 8000.0, 120.0, 0.35, 140.0, 0.20, 5.0),
),
(
MusicGenre::Blues,
GenreFeatures::new(2100.0, 4500.0, 40.0, 0.14, 85.0, 0.08, 22.0),
),
(
MusicGenre::Reggae,
GenreFeatures::new(2000.0, 4200.0, 35.0, 0.15, 75.0, 0.12, 16.0),
),
(
MusicGenre::Ambient,
GenreFeatures::new(1500.0, 3500.0, 15.0, 0.05, 60.0, 0.30, 35.0),
),
]
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_genre_display() {
assert_eq!(format!("{}", MusicGenre::Rock), "Rock");
assert_eq!(format!("{}", MusicGenre::HipHop), "Hip-Hop");
assert_eq!(format!("{}", MusicGenre::RnB), "R&B");
}
#[test]
fn test_genre_all() {
let all = MusicGenre::all();
assert_eq!(all.len(), 12);
}
#[test]
fn test_features_creation() {
let f = GenreFeatures::new(3000.0, 6000.0, 80.0, 0.25, 120.0, 0.15, 20.0);
assert!((f.spectral_centroid_hz - 3000.0).abs() < f64::EPSILON);
assert!((f.tempo_bpm - 120.0).abs() < f64::EPSILON);
}
#[test]
fn test_classifier_creation() {
let c = GenreClassifier::new();
assert!(c.min_confidence > 0.0);
}
#[test]
fn test_classify_rock_like() {
let c = GenreClassifier::new();
let features = GenreFeatures::new(3000.0, 6000.0, 80.0, 0.25, 120.0, 0.15, 20.0);
let result = c.classify(&features);
let (top_genre, _) = result.top();
assert_eq!(top_genre, MusicGenre::Rock);
}
#[test]
fn test_classify_classical_like() {
let c = GenreClassifier::new();
let features = GenreFeatures::new(1800.0, 4000.0, 30.0, 0.10, 80.0, 0.05, 30.0);
let result = c.classify(&features);
let (top_genre, _) = result.top();
assert_eq!(top_genre, MusicGenre::Classical);
}
#[test]
fn test_classify_metal_like() {
let c = GenreClassifier::new();
let features = GenreFeatures::new(4000.0, 8000.0, 120.0, 0.35, 140.0, 0.20, 5.0);
let result = c.classify(&features);
let (top_genre, _) = result.top();
assert_eq!(top_genre, MusicGenre::Metal);
}
#[test]
fn test_classification_scores_sum_to_one() {
let c = GenreClassifier::new();
let features = GenreFeatures::new(2500.0, 5000.0, 60.0, 0.20, 115.0, 0.12, 15.0);
let result = c.classify(&features);
let total: f64 = result.scores.values().sum();
assert!((total - 1.0).abs() < 0.01);
}
#[test]
fn test_top_n() {
let c = GenreClassifier::new();
let features = GenreFeatures::new(3000.0, 6000.0, 80.0, 0.25, 120.0, 0.15, 20.0);
let result = c.classify(&features);
let top3 = result.top_n(3);
assert_eq!(top3.len(), 3);
assert!(top3[0].1 >= top3[1].1);
assert!(top3[1].1 >= top3[2].1);
}
#[test]
fn test_is_confident() {
let c = GenreClassifier::new();
let features = GenreFeatures::new(4000.0, 8000.0, 120.0, 0.35, 140.0, 0.20, 5.0);
let result = c.classify(&features);
assert!(result.is_confident(0.05));
}
#[test]
fn test_ambient_classification() {
let c = GenreClassifier::new();
let features = GenreFeatures::new(1500.0, 3500.0, 15.0, 0.05, 60.0, 0.30, 35.0);
let result = c.classify(&features);
let (top_genre, _) = result.top();
assert_eq!(top_genre, MusicGenre::Ambient);
}
#[test]
fn test_electronic_classification() {
let c = GenreClassifier::new();
let features = GenreFeatures::new(3500.0, 7000.0, 40.0, 0.30, 128.0, 0.25, 10.0);
let result = c.classify(&features);
let (top_genre, _) = result.top();
assert_eq!(top_genre, MusicGenre::Electronic);
}
#[test]
fn test_hiphop_classification() {
let c = GenreClassifier::new();
let features = GenreFeatures::new(2000.0, 4500.0, 50.0, 0.22, 90.0, 0.18, 12.0);
let result = c.classify(&features);
let (top_genre, _) = result.top();
assert_eq!(top_genre, MusicGenre::HipHop);
}
#[test]
fn test_classification_has_all_genres() {
let c = GenreClassifier::new();
let features = GenreFeatures::new(2500.0, 5000.0, 60.0, 0.20, 115.0, 0.12, 15.0);
let result = c.classify(&features);
assert_eq!(result.scores.len(), 12);
}
}