use crate::analyse::rms_db;
use rustfft::{FftPlanner, num_complex::Complex};
use time::Duration;
/// One analysis window that `scan` flagged because too large a share of its
/// spectral power fell inside a sibilance or plosive band.
#[derive(Debug, Clone)]
pub struct SpectralViolation {
    /// Start of the flagged analysis window, measured from the first sample
    /// passed to `scan`.
    pub time: Duration,
    /// Power in the band associated with `kind` divided by the window's total
    /// one-sided spectral power; always above that band's threshold constant.
    pub band_ratio: f32,
    /// Which band exceeded its threshold.
    pub kind: SpectralViolationKind,
}
/// The frequency band whose share of a window's spectral power triggered a
/// [`SpectralViolation`].
///
/// A fieldless enum, so it is `Copy` (no `.clone()` needed) and `Hash` (usable
/// as a map/set key, e.g. for counting violations per kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpectralViolationKind {
    /// Too much energy in the 4–10 kHz band.
    Sibilance,
    /// Too much energy in the 20–150 Hz band.
    Plosive,
}
/// Share of a window's one-sided spectral power in the 4–10 kHz band above
/// which `scan` reports [`SpectralViolationKind::Sibilance`].
pub const SIBILANCE_RATIO_THRESHOLD: f32 = 0.55;
/// Share of a window's one-sided spectral power in the 20–150 Hz band above
/// which `scan` reports [`SpectralViolationKind::Plosive`].
pub const PLOSIVE_RATIO_THRESHOLD: f32 = 0.35;
/// Windows whose `rms_db` level is below this are skipped by `scan` before any
/// spectral analysis. NOTE(review): presumably dB relative to i16 full scale —
/// confirm against `rms_db`.
const MIN_SPEECH_RMS_DB: f32 = -50.0;
/// Length of each analysis window in milliseconds; `scan` walks the input in
/// consecutive, non-overlapping windows of this length.
const WINDOW_MS: usize = 50;
pub fn scan(samples: &[i16], sample_rate: u32) -> Vec<SpectralViolation> {
let window_size = (sample_rate as usize * WINDOW_MS) / 1000;
if window_size == 0 || samples.is_empty() {
return Vec::new();
}
let hann: Vec<f32> = hann_window(window_size);
let mut planner = FftPlanner::<f32>::new();
let fft = planner.plan_fft_forward(window_size);
let mut violations = Vec::new();
for (win_idx, chunk) in samples.chunks(window_size).enumerate() {
if chunk.len() < window_size {
break; }
if rms_db(chunk) < MIN_SPEECH_RMS_DB {
continue;
}
let mut buffer: Vec<Complex<f32>> = chunk
.iter()
.zip(hann.iter())
.map(|(&s, &w)| Complex {
re: s as f32 * w / i16::MAX as f32,
im: 0.0,
})
.collect();
fft.process(&mut buffer);
let half = window_size / 2;
let power: Vec<f32> = buffer[..half]
.iter()
.map(|c| c.re * c.re + c.im * c.im)
.collect();
let total: f32 = power.iter().sum();
if total < f32::EPSILON {
continue;
}
let freq_resolution = sample_rate as f32 / window_size as f32;
let time = Duration::milliseconds((win_idx * WINDOW_MS) as i64);
let sib_lo = (4000.0 / freq_resolution) as usize;
let sib_hi = (10000.0 / freq_resolution).min(half as f32) as usize;
if sib_lo < sib_hi {
let sib_energy: f32 = power[sib_lo..sib_hi].iter().sum();
let ratio = sib_energy / total;
if ratio > SIBILANCE_RATIO_THRESHOLD {
violations.push(SpectralViolation {
time,
band_ratio: ratio,
kind: SpectralViolationKind::Sibilance,
});
}
}
let plo_lo = (20.0 / freq_resolution).max(1.0) as usize;
let plo_hi = (150.0 / freq_resolution) as usize;
if plo_lo < plo_hi && plo_hi <= half {
let plo_energy: f32 = power[plo_lo..plo_hi].iter().sum();
let ratio = plo_energy / total;
if ratio > PLOSIVE_RATIO_THRESHOLD {
violations.push(SpectralViolation {
time,
band_ratio: ratio,
kind: SpectralViolationKind::Plosive,
});
}
}
}
violations
}
/// Returns a symmetric Hann window of `n` coefficients,
/// `w[i] = 0.5 * (1 - cos(2πi / (n - 1)))`.
///
/// Both endpoints are 0, and the peak is 1 at the centre (for odd `n`).
/// `n == 0` yields an empty window. `n == 1` yields `[1.0]`: it is special-cased
/// because the general formula would compute `0 / 0` and produce NaN.
fn hann_window(n: usize) -> Vec<f32> {
    match n {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let denom = n as f32 - 1.0;
            (0..n)
                .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / denom).cos()))
                .collect()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const SR: u32 = 24_000;

    /// Generates `secs` seconds of a sine wave at `freq_hz` with peak `amplitude`
    /// (in i16 sample units), sampled at `sr` Hz.
    fn pure_tone(freq_hz: f32, amplitude: f32, secs: f32, sr: u32) -> Vec<i16> {
        let n = (sr as f32 * secs) as usize;
        (0..n)
            .map(|i| {
                let v =
                    amplitude * (2.0 * std::f32::consts::PI * freq_hz * i as f32 / sr as f32).sin();
                v.clamp(i16::MIN as f32, i16::MAX as f32) as i16
            })
            .collect()
    }

    #[test]
    fn pure_sibilance_tone_is_flagged() {
        let samples = pure_tone(7000.0, 8000.0, 0.5, SR);
        let violations = scan(&samples, SR);
        assert!(
            violations
                .iter()
                .any(|v| v.kind == SpectralViolationKind::Sibilance),
            "Expected sibilance violation for 7 kHz tone"
        );
        assert!(
            !violations
                .iter()
                .any(|v| v.kind == SpectralViolationKind::Plosive),
            "7 kHz tone must not be reported as a plosive"
        );
    }

    #[test]
    fn pure_plosive_tone_is_flagged() {
        let samples = pure_tone(80.0, 8000.0, 0.5, SR);
        let violations = scan(&samples, SR);
        assert!(
            violations
                .iter()
                .any(|v| v.kind == SpectralViolationKind::Plosive),
            "Expected plosive violation for 80 Hz tone"
        );
        assert!(
            !violations
                .iter()
                .any(|v| v.kind == SpectralViolationKind::Sibilance),
            "80 Hz tone must not be reported as sibilance"
        );
    }

    #[test]
    fn speech_frequency_tone_not_flagged() {
        let samples = pure_tone(1000.0, 8000.0, 0.5, SR);
        let violations = scan(&samples, SR);
        assert!(
            violations.is_empty(),
            "Unexpected violations for 1 kHz speech-range tone: {:?}",
            violations.iter().map(|v| &v.kind).collect::<Vec<_>>()
        );
    }

    #[test]
    fn empty_input_returns_no_violations() {
        assert!(scan(&[], SR).is_empty());
    }

    #[test]
    fn silence_has_no_violations() {
        let samples = vec![0i16; SR as usize];
        assert!(scan(&samples, SR).is_empty());
    }

    #[test]
    fn input_shorter_than_one_window_has_no_violations() {
        // 40 ms of a clearly sibilant tone: less than one 50 ms window, so nothing is analysed.
        let samples = pure_tone(7000.0, 8000.0, 0.04, SR);
        assert!(scan(&samples, SR).is_empty());
    }

    #[test]
    fn sample_rate_too_low_for_a_window_returns_nothing() {
        // 10 Hz * 50 ms rounds down to a zero-sample window.
        assert!(scan(&[8000i16; 64], 10).is_empty());
    }

    #[test]
    fn violation_timestamps_are_reasonable() {
        let window = SR as usize * WINDOW_MS / 1000;
        let window_secs = WINDOW_MS as f32 / 1000.0;
        // Window 0: 1 kHz (not flagged); window 1: 7 kHz (sibilance); window 2: silence (skipped).
        let mut samples = pure_tone(1000.0, 8000.0, window_secs, SR);
        samples.extend(pure_tone(7000.0, 8000.0, window_secs, SR));
        samples.extend(vec![0i16; window]);
        assert_eq!(
            samples.len(),
            3 * window,
            "test signal must line up with analysis windows"
        );

        let violations = scan(&samples, SR);
        let sib: Vec<_> = violations
            .iter()
            .filter(|v| v.kind == SpectralViolationKind::Sibilance)
            .collect();
        assert_eq!(
            sib.len(),
            1,
            "Expected exactly one sibilance violation, got {:?}",
            sib
        );
        assert_eq!(
            sib[0].time,
            Duration::milliseconds(WINDOW_MS as i64),
            "Sibilance should be reported at the start of the second window"
        );
    }
}