Skip to main content

math_audio_dsp/audio_features/
mod.rs

1//! Audio feature extraction for music similarity analysis.
2//!
3//! Pure Rust replacement for bliss-audio, producing a compatible 23-element
4//! feature vector:
5//!
6//! | Index | Feature |
7//! |-------|---------|
8//! | 0 | Tempo (BPM) |
9//! | 1 | Zero-crossing rate |
10//! | 2-3 | Spectral centroid (mean, std) |
11//! | 4-5 | Spectral rolloff (mean, std) |
12//! | 6-7 | Spectral flatness (mean, std) |
13//! | 8-9 | Loudness (mean, std) |
14//! | 10-22 | Chroma interval features (13) |
15
16pub mod chroma;
17pub mod loudness;
18pub mod spectral;
19pub mod tempo;
20pub mod utils;
21pub mod zcr;
22
/// Number of features in the analysis vector (bliss v2 compatible).
///
/// Layout: tempo (1) + zero-crossing rate (1) + spectral centroid/rolloff/flatness
/// mean & std (3 × 2) + loudness mean & std (2) + chroma interval features (13).
pub const FEATURES_COUNT: usize = 1 + 1 + 3 * 2 + 2 + 13;

/// Minimum number of samples required for analysis (largest FFT window, 2^13).
pub const MIN_SAMPLES: usize = 1 << 13;
28
/// Error type for audio feature analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnalysisError {
    /// The input held fewer than [`MIN_SAMPLES`] samples.
    TooShort,
    /// The chroma descriptor failed; carries the error text it reported.
    ChromaError(String),
}
35
36impl std::fmt::Display for AnalysisError {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            AnalysisError::TooShort => write!(
40                f,
41                "audio too short for analysis (need >= {MIN_SAMPLES} samples)"
42            ),
43            AnalysisError::ChromaError(s) => write!(f, "chroma analysis error: {s}"),
44        }
45    }
46}
47
48impl std::error::Error for AnalysisError {}
49
50/// Analyze audio samples and return a 23-element feature vector.
51///
52/// Input should be mono, 22050 Hz samples (same as bliss-audio requirements).
53/// The returned vector is ordered identically to bliss v2 Analysis for database compatibility.
54pub fn analyze_audio_features(
55    samples: &[f32],
56    sample_rate: u32,
57) -> Result<Vec<f32>, AnalysisError> {
58    if samples.len() < MIN_SAMPLES {
59        return Err(AnalysisError::TooShort);
60    }
61
62    // Run descriptors in parallel using scoped threads
63    std::thread::scope(|s| {
64        let child_tempo = s.spawn(|| tempo::compute_tempo(samples, sample_rate));
65
66        let child_zcr = s.spawn(|| zcr::compute_zcr(samples));
67
68        let child_spectral = s.spawn(|| spectral::compute_spectral_features(samples, sample_rate));
69
70        let child_loudness = s.spawn(|| loudness::compute_loudness(samples));
71
72        let child_chroma = s.spawn(|| chroma::compute_chroma_features(samples, sample_rate));
73
74        let tempo_val = child_tempo.join().unwrap();
75        let zcr_val = child_zcr.join().unwrap();
76        let spectral_vals = child_spectral.join().unwrap();
77        let loudness_vals = child_loudness.join().unwrap();
78        let chroma_vals = child_chroma
79            .join()
80            .unwrap()
81            .map_err(|e| AnalysisError::ChromaError(e.0))?;
82
83        // Assemble in bliss order:
84        // [tempo, zcr, centroid(2), rolloff(2), flatness(2), loudness(2), chroma(13)]
85        let mut result = Vec::with_capacity(FEATURES_COUNT);
86        result.push(tempo_val);
87        result.push(zcr_val);
88        result.extend_from_slice(&spectral_vals); // 6 values: centroid(2), rolloff(2), flatness(2)
89        result.extend_from_slice(&loudness_vals); // 2 values
90        result.extend_from_slice(&chroma_vals); // 13 values
91
92        assert_eq!(result.len(), FEATURES_COUNT);
93        Ok(result)
94    })
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_analyze_too_short() {
        let samples = vec![0.0; 100];
        assert!(matches!(
            analyze_audio_features(&samples, 22050),
            Err(AnalysisError::TooShort)
        ));
    }

    #[test]
    fn test_analyze_one_below_minimum() {
        // The length check is strict (`< MIN_SAMPLES`), so one sample short must be rejected.
        let samples = vec![0.0; MIN_SAMPLES - 1];
        assert!(matches!(
            analyze_audio_features(&samples, 22050),
            Err(AnalysisError::TooShort)
        ));
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            AnalysisError::TooShort.to_string(),
            format!("audio too short for analysis (need >= {MIN_SAMPLES} samples)")
        );
        assert_eq!(
            AnalysisError::ChromaError("bad frame".to_string()).to_string(),
            "chroma analysis error: bad frame"
        );
    }

    #[test]
    fn test_analyze_features_count() {
        // Generate a 440 Hz tone, 5 seconds long (well above MIN_SAMPLES).
        let sr = 22050u32;
        let signal: Vec<f32> = (0..sr as usize * 5)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sr as f32).sin())
            .collect();

        let features = analyze_audio_features(&signal, sr).unwrap();
        assert_eq!(features.len(), FEATURES_COUNT);

        // All values are expected to lie within [-1.5, 1.5]. `contains` is false for NaN,
        // so this also rejects non-finite output.
        for (i, &f) in features.iter().enumerate() {
            assert!(
                (-1.5..=1.5).contains(&f),
                "feature[{i}] = {f} out of expected range"
            );
        }
    }

    #[test]
    fn test_analyze_silence() {
        // Silence should not crash
        let samples = vec![0.0; 22050 * 3];
        let features = analyze_audio_features(&samples, 22050).unwrap();
        assert_eq!(features.len(), FEATURES_COUNT);
    }
}