#![allow(dead_code, clippy::cast_precision_loss)]
/// Genre labels that the heuristic classifier in this file can report.
///
/// Every variant except [`Genre::Unknown`] is a "known" genre and is listed by
/// `Genre::all_known`. The type is a plain fieldless enum, so it is cheap to
/// copy, compare and use as a hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Genre {
Electronic,
Rock,
Pop,
Classical,
Jazz,
/// Rendered as `"HipHop"` (no separator) by `name` and `Display`.
HipHop,
Country,
/// Rendered as `"RnB"` by `name` and `Display`.
RnB,
Metal,
Folk,
Latin,
World,
Ambient,
Soundtrack,
/// Fallback label: returned by `classify_genre` when no genre scores above
/// zero. It is deliberately absent from `Genre::all_known`.
Unknown,
}
impl Genre {
#[must_use]
pub fn name(&self) -> &'static str {
match self {
Self::Electronic => "Electronic",
Self::Rock => "Rock",
Self::Pop => "Pop",
Self::Classical => "Classical",
Self::Jazz => "Jazz",
Self::HipHop => "HipHop",
Self::Country => "Country",
Self::RnB => "RnB",
Self::Metal => "Metal",
Self::Folk => "Folk",
Self::Latin => "Latin",
Self::World => "World",
Self::Ambient => "Ambient",
Self::Soundtrack => "Soundtrack",
Self::Unknown => "Unknown",
}
}
#[must_use]
pub fn all_known() -> &'static [Genre] {
&[
Self::Electronic,
Self::Rock,
Self::Pop,
Self::Classical,
Self::Jazz,
Self::HipHop,
Self::Country,
Self::RnB,
Self::Metal,
Self::Folk,
Self::Latin,
Self::World,
Self::Ambient,
Self::Soundtrack,
]
}
}
impl std::fmt::Display for Genre {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.name())
}
}
/// Soft membership of `value` in the closed band `[ideal_min, ideal_max]`.
///
/// Returns `1.0` anywhere inside the band. Outside it, the score falls off
/// linearly and reaches `0.0` once `value` is `tolerance` away from the
/// nearest edge. Either edge may be infinite to describe a one-sided band.
///
/// `tolerance` is floored at `1e-6`, so a zero or negative tolerance behaves
/// as a hard cut-off instead of dividing by zero. A NaN `value` yields NaN.
fn range_score(value: f32, ideal_min: f32, ideal_max: f32, tolerance: f32) -> f32 {
    let inside_band = value >= ideal_min && value <= ideal_max;
    if inside_band {
        return 1.0;
    }

    let tol = tolerance.max(1e-6);
    // Position along the linear ramp on whichever side of the band `value` fell.
    let ramp = if value < ideal_min {
        (value - (ideal_min - tol)) / tol
    } else {
        (ideal_max + tol - value) / tol
    };
    ramp.clamp(0.0, 1.0)
}
/// Soft "at least `threshold`" test.
///
/// Scores `1.0` for any `value >= threshold` and ramps down to `0.0` as
/// `value` drops `tolerance` below it. Implemented as a `range_score` band
/// with no upper limit.
fn above_score(value: f32, threshold: f32, tolerance: f32) -> f32 {
    let no_upper_limit = f32::INFINITY;
    range_score(value, threshold, no_upper_limit, tolerance)
}
/// Soft "at most `threshold`" test.
///
/// Scores `1.0` for any `value <= threshold` and ramps down to `0.0` as
/// `value` rises `tolerance` above it. Implemented as a `range_score` band
/// with no lower limit.
fn below_score(value: f32, threshold: f32, tolerance: f32) -> f32 {
    let no_lower_limit = f32::NEG_INFINITY;
    range_score(value, no_lower_limit, threshold, tolerance)
}
/// Population variance of the twelve chroma bins.
///
/// The result is the mean squared deviation from the bins' mean (divided by
/// 12, not 11). A flat chroma vector gives `0.0`; energy concentrated in a
/// few bins gives a larger value.
fn chroma_variance(chroma: &[f32; 12]) -> f32 {
    const BIN_COUNT: f32 = 12.0;

    let mean = chroma.iter().sum::<f32>() / BIN_COUNT;
    let squared_deviation_sum: f32 = chroma
        .iter()
        .map(|&bin| {
            let deviation = bin - mean;
            deviation * deviation
        })
        .sum();
    squared_deviation_sum / BIN_COUNT
}
#[must_use]
pub fn classify_genre(
spectral_centroid: f32,
spectral_rolloff: f32,
zero_crossing_rate: f32,
chroma: &[f32; 12],
tempo_bpm: f32,
) -> (Genre, f32) {
let chroma_var = chroma_variance(chroma);
let scores: [(Genre, f32); 14] = [
(
Genre::Electronic,
above_score(spectral_centroid, 0.25, 0.10) * 0.30
+ above_score(spectral_rolloff, 0.30, 0.10) * 0.20
+ range_score(tempo_bpm, 118.0, 148.0, 20.0) * 0.30
+ below_score(zero_crossing_rate, 0.12, 0.05) * 0.20,
),
(
Genre::Rock,
above_score(spectral_rolloff, 0.35, 0.10) * 0.25
+ above_score(zero_crossing_rate, 0.08, 0.04) * 0.30
+ range_score(tempo_bpm, 108.0, 165.0, 20.0) * 0.30
+ above_score(spectral_centroid, 0.18, 0.08) * 0.15,
),
(
Genre::Metal,
above_score(zero_crossing_rate, 0.14, 0.04) * 0.35
+ above_score(spectral_centroid, 0.25, 0.08) * 0.25
+ above_score(spectral_rolloff, 0.40, 0.10) * 0.20
+ range_score(tempo_bpm, 130.0, 200.0, 25.0) * 0.20,
),
(
Genre::Classical,
below_score(zero_crossing_rate, 0.05, 0.03) * 0.35
+ above_score(spectral_rolloff, 0.20, 0.10) * 0.20
+ below_score(spectral_centroid, 0.25, 0.10) * 0.20
+ above_score(chroma_var, 0.01, 0.005) * 0.25,
),
(
Genre::Jazz,
range_score(spectral_centroid, 0.12, 0.30, 0.08) * 0.25
+ above_score(chroma_var, 0.015, 0.005) * 0.35
+ range_score(tempo_bpm, 80.0, 140.0, 20.0) * 0.25
+ range_score(zero_crossing_rate, 0.04, 0.10, 0.03) * 0.15,
),
(
Genre::HipHop,
below_score(spectral_centroid, 0.20, 0.08) * 0.30
+ range_score(tempo_bpm, 70.0, 110.0, 15.0) * 0.35
+ range_score(zero_crossing_rate, 0.03, 0.09, 0.03) * 0.20
+ below_score(spectral_rolloff, 0.35, 0.10) * 0.15,
),
(
Genre::Pop,
range_score(spectral_centroid, 0.15, 0.30, 0.08) * 0.20
+ range_score(tempo_bpm, 100.0, 132.0, 15.0) * 0.40
+ range_score(spectral_rolloff, 0.25, 0.45, 0.10) * 0.25
+ range_score(zero_crossing_rate, 0.05, 0.12, 0.03) * 0.15,
),
(
Genre::Country,
range_score(spectral_centroid, 0.12, 0.25, 0.06) * 0.25
+ below_score(zero_crossing_rate, 0.08, 0.04) * 0.25
+ range_score(tempo_bpm, 78.0, 132.0, 15.0) * 0.35
+ range_score(spectral_rolloff, 0.20, 0.40, 0.08) * 0.15,
),
(
Genre::RnB,
below_score(spectral_centroid, 0.22, 0.08) * 0.30
+ range_score(tempo_bpm, 60.0, 100.0, 15.0) * 0.35
+ range_score(zero_crossing_rate, 0.03, 0.08, 0.03) * 0.25
+ range_score(spectral_rolloff, 0.15, 0.35, 0.08) * 0.10,
),
(
Genre::Folk,
below_score(spectral_centroid, 0.15, 0.06) * 0.35
+ below_score(zero_crossing_rate, 0.06, 0.03) * 0.30
+ range_score(tempo_bpm, 60.0, 130.0, 20.0) * 0.20
+ below_score(spectral_rolloff, 0.30, 0.10) * 0.15,
),
(
Genre::Latin,
range_score(spectral_centroid, 0.18, 0.32, 0.08) * 0.25
+ range_score(tempo_bpm, 95.0, 145.0, 20.0) * 0.40
+ range_score(spectral_rolloff, 0.25, 0.45, 0.10) * 0.20
+ above_score(zero_crossing_rate, 0.06, 0.03) * 0.15,
),
(
Genre::Ambient,
below_score(zero_crossing_rate, 0.04, 0.02) * 0.35
+ below_score(spectral_centroid, 0.12, 0.05) * 0.30
+ range_score(tempo_bpm, 0.0, 80.0, 20.0) * 0.20
+ below_score(spectral_rolloff, 0.20, 0.08) * 0.15,
),
(
Genre::Soundtrack,
above_score(spectral_rolloff, 0.40, 0.10) * 0.30
+ above_score(spectral_centroid, 0.20, 0.08) * 0.25
+ above_score(chroma_var, 0.012, 0.005) * 0.25
+ range_score(tempo_bpm, 60.0, 160.0, 30.0) * 0.20,
),
(
Genre::World,
range_score(spectral_centroid, 0.10, 0.25, 0.08) * 0.25
+ above_score(chroma_var, 0.018, 0.005) * 0.40
+ range_score(tempo_bpm, 60.0, 160.0, 30.0) * 0.20
+ range_score(zero_crossing_rate, 0.04, 0.12, 0.04) * 0.15,
),
];
let (best_genre, best_score) =
scores
.iter()
.copied()
.fold((Genre::Unknown, 0.0_f32), |acc, (g, s)| {
if s > acc.1 {
(g, s)
} else {
acc
}
});
let total: f32 = scores.iter().map(|(_, s)| s).sum();
let confidence = if total > 0.0 {
(best_score / total * scores.len() as f32).clamp(0.0, 1.0)
} else {
0.0
};
if best_score <= 0.0 {
(Genre::Unknown, 0.0)
} else {
(best_genre, confidence)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Perfectly flat chroma vector (zero variance) shared by most tests.
    fn flat_chroma() -> [f32; 12] {
        [1.0_f32 / 12.0; 12]
    }

    // --- Genre labels -----------------------------------------------------

    #[test]
    fn test_genre_electronic_name() {
        let label = Genre::Electronic.name();
        assert_eq!(label, "Electronic");
    }

    #[test]
    fn test_genre_unknown_name() {
        let label = Genre::Unknown.name();
        assert_eq!(label, "Unknown");
    }

    #[test]
    fn test_genre_display_trait() {
        let rendered = Genre::Rock.to_string();
        assert_eq!(rendered, "Rock");
    }

    #[test]
    fn test_all_known_genres_have_non_empty_names() {
        Genre::all_known().iter().for_each(|g| {
            assert!(!g.name().is_empty(), "Genre {g:?} has an empty name");
        });
    }

    #[test]
    fn test_genre_all_known_count() {
        let known = Genre::all_known();
        assert_eq!(known.len(), 14);
    }

    // --- classify_genre ---------------------------------------------------

    #[test]
    fn test_classify_genre_confidence_in_range() {
        let (_, conf) = classify_genre(0.40, 0.50, 0.05, &flat_chroma(), 130.0);
        let unit_interval = 0.0..=1.0;
        assert!(
            unit_interval.contains(&conf),
            "Confidence must be in [0, 1], got {conf}"
        );
    }

    #[test]
    fn test_classify_genre_high_centroid_high_tempo_not_ambient() {
        let (genre, _) = classify_genre(0.40, 0.55, 0.08, &flat_chroma(), 130.0);
        assert_ne!(
            genre,
            Genre::Ambient,
            "Bright, 130 BPM signal should not be Ambient"
        );
    }

    #[test]
    fn test_classify_genre_ambient_features() {
        let (genre, conf) = classify_genre(0.05, 0.08, 0.01, &flat_chroma(), 40.0);
        let is_calm_genre = matches!(genre, Genre::Ambient | Genre::Classical | Genre::Folk);
        assert!(
            is_calm_genre,
            "Low-energy slow signal should be Ambient/Classical/Folk, got {genre:?} ({conf})"
        );
    }

    #[test]
    fn test_classify_genre_high_zcr_high_centroid_tends_metal_or_rock() {
        let (genre, _) = classify_genre(0.40, 0.60, 0.20, &flat_chroma(), 160.0);
        let is_energetic_genre = matches!(genre, Genre::Metal | Genre::Rock | Genre::Electronic);
        assert!(
            is_energetic_genre,
            "High ZCR/centroid fast signal should be Metal/Rock/Electronic, got {genre:?}"
        );
    }

    #[test]
    fn test_classify_genre_not_unknown_for_valid_features() {
        // Alternating strong/weak bins give the chroma a non-zero variance.
        let textured_chroma = [
            0.1_f32, 0.05, 0.1, 0.05, 0.1, 0.05, 0.1, 0.05, 0.1, 0.05, 0.1, 0.05,
        ];
        let (genre, _) = classify_genre(0.20, 0.30, 0.07, &textured_chroma, 120.0);
        assert_ne!(
            genre,
            Genre::Unknown,
            "Typical features should yield a known genre"
        );
    }

    #[test]
    fn test_classify_genre_returns_genre_and_confidence() {
        let (_, conf) = classify_genre(0.25, 0.35, 0.06, &flat_chroma(), 120.0);
        assert!(conf.is_finite());
    }
}