use crate::{compute_rms, zero_crossing_rate};
/// Coarse acoustic scene categories produced by the classifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AudioScene {
    /// Fallback for low-centroid, low-flatness material.
    Indoor,
    /// Fallback for brighter, flatter material.
    Outdoor,
    /// Most analysis frames fall below the silence threshold.
    Quiet,
    /// Noise-like (spectrally flat) content with substantial energy.
    Noisy,
    /// Features fall in the speech-like centroid / zero-crossing band.
    Speech,
    /// Tonal content with a sufficiently high spectral rolloff.
    Music,
    /// No other rule matched.
    Mixed,
}

impl AudioScene {
    /// Returns the variant's name as a static, human-readable string.
    #[must_use]
    pub fn label(self) -> &'static str {
        match self {
            AudioScene::Indoor => "Indoor",
            AudioScene::Outdoor => "Outdoor",
            AudioScene::Quiet => "Quiet",
            AudioScene::Noisy => "Noisy",
            AudioScene::Speech => "Speech",
            AudioScene::Music => "Music",
            AudioScene::Mixed => "Mixed",
        }
    }
}
/// Frame-averaged features that drive the rule-based scene decision.
///
/// Every field is the mean of a per-frame measurement over all analysed frames.
#[derive(Debug, Clone)]
pub struct SceneFeatures {
/// Mean spectral centroid, divided by the Nyquist frequency (`sample_rate / 2`).
pub spectral_centroid_norm: f32,
/// Mean of `zero_crossing_rate` over the analysed frames.
pub zcr: f32,
/// Mean per-frame RMS level as returned by `compute_rms`.
pub energy: f32,
/// Mean spectral flatness (geometric mean / arithmetic mean of magnitudes), in `[0, 1]`.
pub flatness: f32,
/// Mean 85 %-energy spectral rolloff frequency as a fraction of Nyquist, in `[0, 1]`.
pub rolloff_norm: f32,
}
/// Framing parameters and silence thresholds for the scene classifier.
#[derive(Debug, Clone)]
pub struct SceneClassifierConfig {
    /// Analysis frame length in samples (also the FFT length).
    pub fft_size: usize,
    /// Number of samples between the starts of consecutive frames.
    pub hop_size: usize,
    /// Frames whose RMS is below this value count as silent.
    pub silence_threshold: f32,
    /// Fraction of silent frames at or above which the signal is labelled `Quiet`.
    pub quiet_frame_fraction: f32,
}

impl Default for SceneClassifierConfig {
    /// Returns the default configuration: 2048-sample frames, 512-sample hop,
    /// silence threshold 0.005 and quiet fraction 0.85.
    fn default() -> Self {
        let fft_size = 2048;
        let hop_size = 512;
        let silence_threshold = 0.005;
        let quiet_frame_fraction = 0.85;
        SceneClassifierConfig {
            fft_size,
            hop_size,
            silence_threshold,
            quiet_frame_fraction,
        }
    }
}
/// Outcome of a single classification call.
#[derive(Debug, Clone)]
pub struct ClassificationResult {
/// The scene label chosen by the rule set.
pub scene: AudioScene,
/// Heuristic confidence score assigned by the rule that matched.
pub confidence: f32,
/// The frame-averaged features the decision was based on.
pub features: SceneFeatures,
}
/// Rule-based audio scene classifier; holds only its configuration.
pub struct AudioSceneClassifier {
// Framing parameters and thresholds used by every classification call.
config: SceneClassifierConfig,
}
impl AudioSceneClassifier {
    /// Creates a classifier that uses `config` for framing and thresholds.
    #[must_use]
    pub fn new(config: SceneClassifierConfig) -> Self {
        Self { config }
    }

    /// Classifies `samples` (sampled at `sample_rate` Hz) into an [`AudioScene`].
    ///
    /// Features are averaged over overlapping frames of `fft_size` samples
    /// taken every `hop_size` samples, then passed through the rule set.
    ///
    /// # Errors
    ///
    /// Returns `AnalysisError::InsufficientSamples` when `samples` holds fewer
    /// than `fft_size` values.
    ///
    /// # Panics
    ///
    /// Panics if the configured `hop_size` is zero, because the frame count is
    /// computed by dividing by the hop.
    pub fn classify(
        &self,
        samples: &[f32],
        sample_rate: u32,
    ) -> crate::Result<ClassificationResult> {
        if samples.len() < self.config.fft_size {
            return Err(crate::AnalysisError::InsufficientSamples {
                needed: self.config.fft_size,
                got: samples.len(),
            });
        }
        let features = self.extract_features(samples, sample_rate);
        let (scene, confidence) = classify_features(&features, &self.config, samples);
        Ok(ClassificationResult {
            scene,
            confidence,
            features,
        })
    }

    /// Computes the frame-averaged features for `samples`.
    ///
    /// RMS and zero-crossing rate are measured on the raw frame; the spectral
    /// features are measured on the magnitude spectrum of the Hann-windowed
    /// frame, keeping bins `0..=fft_size / 2`.
    fn extract_features(&self, samples: &[f32], sample_rate: u32) -> SceneFeatures {
        let n_fft = self.config.fft_size;
        let hop = self.config.hop_size;
        let n_bins = n_fft / 2 + 1;
        let nyquist = sample_rate as f32 / 2.0;
        let window = Self::hann_window(n_fft);

        let num_frames = (samples.len().saturating_sub(n_fft)) / hop + 1;
        let mut sum_centroid = 0.0_f64;
        let mut sum_zcr = 0.0_f64;
        let mut sum_energy = 0.0_f64;
        let mut sum_flatness = 0.0_f64;
        let mut sum_rolloff = 0.0_f64;
        let mut counted = 0_usize;

        for fi in 0..num_frames {
            let start = fi * hop;
            let end = start + n_fft;
            // Only complete frames are analysed.
            if end > samples.len() {
                break;
            }
            let frame = &samples[start..end];

            // Time-domain features use the un-windowed frame.
            let energy = compute_rms(frame);
            sum_energy += f64::from(energy);
            let zcr = zero_crossing_rate(frame);
            sum_zcr += f64::from(zcr);

            let windowed: Vec<oxifft::Complex<f64>> = frame
                .iter()
                .zip(window.iter())
                .map(|(&s, &w)| oxifft::Complex::new(f64::from(s * w), 0.0))
                .collect();
            let spectrum = oxifft::fft(&windowed);
            // Real input: only the first n_fft / 2 + 1 bins are kept.
            let magnitude: Vec<f32> = spectrum[..n_bins]
                .iter()
                .map(|c| (c.re * c.re + c.im * c.im).sqrt() as f32)
                .collect();

            sum_centroid += f64::from(compute_spectral_centroid_norm(
                &magnitude,
                n_fft,
                sample_rate,
            ));
            sum_flatness += f64::from(compute_spectral_flatness(&magnitude));
            sum_rolloff += f64::from(compute_spectral_rolloff_norm(
                &magnitude,
                n_fft,
                sample_rate,
            ));
            counted += 1;
        }

        // Guard the averages against an empty frame set.
        let n = counted.max(1) as f64;
        SceneFeatures {
            // The per-frame centroid is in Hz; normalise by Nyquist here.
            spectral_centroid_norm: (sum_centroid / n / f64::from(nyquist)) as f32,
            zcr: (sum_zcr / n) as f32,
            energy: (sum_energy / n) as f32,
            flatness: (sum_flatness / n) as f32,
            rolloff_norm: (sum_rolloff / n) as f32,
        }
    }

    /// Builds a symmetric Hann window, `w[i] = 0.5 * (1 - cos(2πi / (len - 1)))`.
    ///
    /// The full `2π` phase is required so the window tapers to zero at both
    /// ends; using `π` yields a one-sided ramp that ends at full amplitude and
    /// causes heavy spectral leakage. Windows shorter than two samples are
    /// returned as all ones, which avoids a `0 / 0` phase.
    fn hann_window(len: usize) -> Vec<f32> {
        if len < 2 {
            return vec![1.0; len];
        }
        let denom = (len - 1) as f64;
        (0..len)
            .map(|i| {
                let phase = 2.0 * std::f64::consts::PI * i as f64 / denom;
                (0.5 * (1.0 - phase.cos())) as f32
            })
            .collect()
    }
}
/// Returns the magnitude-weighted mean frequency of `magnitude`, in Hz.
///
/// Bin `k` is mapped to `k * sample_rate / n_fft`. The result is not yet
/// normalised (despite the name); an empty or all-zero spectrum yields `0.0`.
fn compute_spectral_centroid_norm(magnitude: &[f32], n_fft: usize, sample_rate: u32) -> f32 {
    let (weighted_sum, magnitude_sum) = magnitude.iter().enumerate().fold(
        (0.0_f32, 0.0_f32),
        |(num, den), (bin, &mag)| {
            let hz = bin as f32 * sample_rate as f32 / n_fft as f32;
            (num + hz * mag, den + mag)
        },
    );
    if magnitude_sum > 0.0 {
        weighted_sum / magnitude_sum
    } else {
        0.0
    }
}
/// Returns the spectral flatness of `magnitude`: the ratio of the geometric
/// mean to the arithmetic mean, clamped to `[0, 1]`.
///
/// A small offset is added to every bin so that zero magnitudes do not send
/// the logarithm to negative infinity. An empty slice yields `0.0`.
fn compute_spectral_flatness(magnitude: &[f32]) -> f32 {
    if magnitude.is_empty() {
        return 0.0;
    }
    let offset = 1e-10_f64;
    let count = magnitude.len() as f64;
    // Each call produces a fresh pass over the offset magnitudes.
    let shifted = || magnitude.iter().map(|&m| f64::from(m) + offset);

    let mean_log = shifted().map(|v| v.ln()).sum::<f64>() / count;
    let geometric = mean_log.exp();
    let arithmetic = shifted().sum::<f64>() / count;

    if arithmetic > 0.0 {
        (geometric / arithmetic).clamp(0.0, 1.0) as f32
    } else {
        0.0
    }
}
/// Returns the spectral rolloff as a fraction of the Nyquist frequency.
///
/// The rolloff is the frequency of the first bin at which the cumulative
/// squared magnitude reaches 85 % of the total. A spectrum with no energy
/// yields `0.0`; if the threshold is never reached the result is `1.0`.
fn compute_spectral_rolloff_norm(magnitude: &[f32], n_fft: usize, sample_rate: u32) -> f32 {
    let total_energy = magnitude.iter().map(|&m| m * m).sum::<f32>();
    if total_energy <= 0.0 {
        return 0.0;
    }

    let cutoff = 0.85 * total_energy;
    let mut running = 0.0_f32;
    // Index of the first bin where the running energy reaches the cutoff.
    let crossing = magnitude.iter().position(|&m| {
        running += m * m;
        running >= cutoff
    });

    match crossing {
        Some(bin) => {
            let half_rate = sample_rate as f32 / 2.0;
            let hz = bin as f32 * sample_rate as f32 / n_fft as f32;
            (hz / half_rate).clamp(0.0, 1.0)
        }
        None => 1.0,
    }
}
/// Maps averaged features to a scene label and a heuristic confidence.
///
/// Rules are tried in order and the first match wins: Quiet (by fraction of
/// silent frames in `samples`), Noisy, Speech, Music, Outdoor, Indoor, and
/// finally Mixed as the fallback.
fn classify_features(
    f: &SceneFeatures,
    cfg: &SceneClassifierConfig,
    samples: &[f32],
) -> (AudioScene, f32) {
    // --- Quiet: count frames whose RMS is below the silence threshold. ---
    let frame_len = cfg.fft_size;
    let step = cfg.hop_size;
    let frame_count = (samples.len().saturating_sub(frame_len)) / step + 1;
    let mut silent_frames = 0_usize;
    for index in 0..frame_count {
        let begin = index * step;
        // The last frame may be cut short by the end of the buffer.
        let finish = (begin + frame_len).min(samples.len());
        if compute_rms(&samples[begin..finish]) < cfg.silence_threshold {
            silent_frames += 1;
        }
    }
    let silent_fraction = if frame_count > 0 {
        silent_frames as f32 / frame_count as f32
    } else {
        1.0
    };
    if silent_fraction >= cfg.quiet_frame_fraction {
        return (AudioScene::Quiet, 0.85 + 0.15 * silent_fraction);
    }

    // --- Noisy: flat spectrum with substantial energy. ---
    if f.flatness > 0.55 && f.energy > 0.06 {
        let energy_term = (f.energy / 0.5).min(1.0) * 0.3;
        let confidence = (f.flatness * 0.7 + energy_term).min(1.0);
        return (AudioScene::Noisy, confidence);
    }

    // --- Speech: centroid and ZCR both inside the open band (0.02, 0.30). ---
    let in_speech_band = |value: f32| value > 0.02 && value < 0.30;
    // Score peaks at 0.10 and falls off linearly with distance from it.
    let closeness = |value: f32| 1.0 - (value - 0.10).abs() / 0.20;
    let speech_like = in_speech_band(f.spectral_centroid_norm)
        && in_speech_band(f.zcr)
        && f.flatness < 0.40
        && f.energy > cfg.silence_threshold;
    if speech_like {
        let centroid_score = closeness(f.spectral_centroid_norm);
        let zcr_score = closeness(f.zcr);
        let confidence = ((centroid_score + zcr_score) / 2.0).clamp(0.5, 0.95);
        return (AudioScene::Speech, confidence);
    }

    // --- Music: tonal, enough high-frequency content, clearly above silence. ---
    let music_like = f.flatness < 0.35
        && f.rolloff_norm > 0.20
        && f.energy > cfg.silence_threshold * 2.0;
    if music_like {
        let rolloff_score = (f.rolloff_norm * 2.0).min(1.0);
        let confidence = (0.5 + rolloff_score * 0.4).min(0.95);
        return (AudioScene::Music, confidence);
    }

    // --- Fallbacks with fixed confidences. ---
    if f.spectral_centroid_norm > 0.25 && f.flatness > 0.30 {
        (AudioScene::Outdoor, 0.65)
    } else if f.spectral_centroid_norm < 0.20 && f.flatness < 0.40 {
        (AudioScene::Indoor, 0.60)
    } else {
        (AudioScene::Mixed, 0.50)
    }
}
// Unit tests for the scene classifier and its spectral helper functions.
#[cfg(test)]
mod tests {
use super::*;
// Generates a unit-amplitude sine at `freq` Hz lasting `duration_secs` seconds.
fn sine_wave(freq: f64, sample_rate: u32, duration_secs: f64) -> Vec<f32> {
let num = (f64::from(sample_rate) * duration_secs) as usize;
(0..num)
.map(|i| {
(2.0 * std::f64::consts::PI * freq * i as f64 / f64::from(sample_rate)).sin() as f32
})
.collect()
}
// Generates `n` pseudo-random samples scaled by `amplitude`, using a fixed
// seed and a shift/xor update so the sequence is identical on every run.
fn white_noise(n: usize, amplitude: f32) -> Vec<f32> {
let mut state: u64 = 0xDEAD_BEEF_1234_5678;
(0..n)
.map(|_| {
state ^= state << 13;
state ^= state >> 7;
state ^= state << 17;
// Reinterpret the state as signed and scale it into roughly [-1, 1].
let v = (state as i64 as f64) / (i64::MAX as f64);
v as f32 * amplitude
})
.collect()
}
// An all-zero buffer must be labelled Quiet.
#[test]
fn test_classify_quiet_silence() {
let silence = vec![0.0_f32; 44100];
let classifier = AudioSceneClassifier::new(SceneClassifierConfig::default());
let result = classifier
.classify(&silence, 44100)
.expect("should succeed");
assert_eq!(result.scene, AudioScene::Quiet, "silence should be Quiet");
}
// Loud white noise is accepted as either Noisy or Mixed.
#[test]
fn test_classify_noisy_white_noise() {
let noise = white_noise(44100, 0.5);
let classifier = AudioSceneClassifier::new(SceneClassifierConfig::default());
let result = classifier.classify(&noise, 44100).expect("should succeed");
assert!(
matches!(result.scene, AudioScene::Noisy | AudioScene::Mixed),
"white noise should be Noisy/Mixed, got {:?}",
result.scene
);
}
// A pure 440 Hz tone is accepted as Music, Speech or Indoor.
#[test]
fn test_classify_music_tonal_sine() {
let samples = sine_wave(440.0, 44100, 1.0);
let classifier = AudioSceneClassifier::new(SceneClassifierConfig::default());
let result = classifier
.classify(&samples, 44100)
.expect("should succeed");
assert!(
matches!(
result.scene,
AudioScene::Music | AudioScene::Speech | AudioScene::Indoor
),
"pure tone should be Music/Speech/Indoor, got {:?}",
result.scene
);
}
// Fewer samples than one FFT frame must produce an error.
#[test]
fn test_classify_insufficient_samples() {
let short = vec![0.0_f32; 100];
let classifier = AudioSceneClassifier::new(SceneClassifierConfig::default());
assert!(classifier.classify(&short, 44100).is_err());
}
// The averaged RMS of a unit sine should sit near 1/sqrt(2).
#[test]
fn test_features_energy_for_sine() {
let samples = sine_wave(440.0, 44100, 0.5);
let classifier = AudioSceneClassifier::new(SceneClassifierConfig::default());
let features = classifier.extract_features(&samples, 44100);
assert!(
features.energy > 0.5 && features.energy < 0.8,
"RMS energy of unit sine should be ~0.707, got {}",
features.energy
);
}
// Averaged flatness must stay inside [0, 1].
#[test]
fn test_features_flatness_range() {
let noise = white_noise(22050, 0.3);
let classifier = AudioSceneClassifier::new(SceneClassifierConfig::default());
let f = classifier.extract_features(&noise, 22050);
assert!(
f.flatness >= 0.0 && f.flatness <= 1.0,
"flatness out of range: {}",
f.flatness
);
}
// Every variant's label equals its name.
#[test]
fn test_scene_label_strings() {
assert_eq!(AudioScene::Indoor.label(), "Indoor");
assert_eq!(AudioScene::Outdoor.label(), "Outdoor");
assert_eq!(AudioScene::Quiet.label(), "Quiet");
assert_eq!(AudioScene::Noisy.label(), "Noisy");
assert_eq!(AudioScene::Speech.label(), "Speech");
assert_eq!(AudioScene::Music.label(), "Music");
assert_eq!(AudioScene::Mixed.label(), "Mixed");
}
// The reported confidence must lie in [0, 1].
#[test]
fn test_confidence_in_range() {
let samples = sine_wave(1000.0, 22050, 0.5);
let classifier = AudioSceneClassifier::new(SceneClassifierConfig::default());
let result = classifier.classify(&samples, 22050).expect("ok");
assert!(
result.confidence >= 0.0 && result.confidence <= 1.0,
"confidence {} out of [0,1]",
result.confidence
);
}
// A tone scaled far below the silence threshold must be labelled Quiet.
#[test]
fn test_classify_low_energy_is_quiet() {
let samples: Vec<f32> = sine_wave(440.0, 44100, 0.5)
.into_iter()
.map(|s| s * 0.001) .collect();
let classifier = AudioSceneClassifier::new(SceneClassifierConfig::default());
let result = classifier.classify(&samples, 44100).expect("ok");
assert_eq!(
result.scene,
AudioScene::Quiet,
"very low amplitude should be Quiet, got {:?}",
result.scene
);
}
// A single non-zero bin should give flatness close to zero.
#[test]
fn test_spectral_flatness_pure_tone() {
let mut mag = vec![0.0_f32; 513];
mag[100] = 1.0;
let flatness = compute_spectral_flatness(&mag);
assert!(
flatness < 0.05,
"pure tone flatness should be near 0, got {}",
flatness
);
}
// A constant spectrum should give flatness close to one.
#[test]
fn test_spectral_flatness_uniform() {
let mag = vec![1.0_f32; 512];
let flatness = compute_spectral_flatness(&mag);
assert!(
(flatness - 1.0).abs() < 0.01,
"uniform flatness should be ~1, got {}",
flatness
);
}
}