use super::{BandMetric, BandMetrics, BandLayout, PsychoacousticConfig};
use non_empty_slice::NonEmptySlice;
/// Kind of masker a frequency band is treated as.
///
/// `classify_masker_types` assigns one of these to every band, and
/// `spreading_attenuation` uses it to pick which base gain from the
/// configuration is applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaskerType {
/// The band stands out from its neighbouring bands by at least the tonal threshold.
Tonal,
/// The band does not stand out from its neighbouring bands.
Noise,
}
/// Labels every band as a tonal or noise masker.
///
/// Each band's level is compared against a reference built from its direct
/// neighbours: the mean of both neighbours for interior bands, the single
/// existing neighbour for the first and last band, and the band itself when
/// it is the only one. A band whose level exceeds that reference by at least
/// `tonal_threshold_db` is [`MaskerType::Tonal`]; every other band is
/// [`MaskerType::Noise`].
///
/// Returns one entry per element of `band_energy_db` (empty input gives an
/// empty vector).
#[must_use]
pub fn classify_masker_types(band_energy_db: &[f32], tonal_threshold_db: f32) -> Vec<MaskerType> {
    band_energy_db
        .iter()
        .enumerate()
        .map(|(idx, &level)| {
            // `checked_sub` yields `None` for the first band, `get` yields `None`
            // past the last one, so edge bands fall back to a single neighbour.
            let below = idx.checked_sub(1).map(|prev| band_energy_db[prev]);
            let above = band_energy_db.get(idx + 1).copied();

            let reference = match (below, above) {
                (Some(lo), Some(hi)) => (lo + hi) / 2.0,
                (Some(lo), None) => lo,
                (None, Some(hi)) => hi,
                (None, None) => level,
            };

            let prominence = level - reference;
            if prominence >= tonal_threshold_db {
                MaskerType::Tonal
            } else {
                MaskerType::Noise
            }
        })
        .collect()
}
/// Approximates the absolute threshold of hearing at `freq_hz`.
///
/// The frequency is converted to kHz and fed into a three-term closed-form
/// curve (the coefficients match the commonly cited Terhardt approximation):
/// a low-frequency rise, a sensitivity dip centred at 3.3 kHz, and a steep
/// high-frequency rise. The result is limited to the range `[-20, 60]`.
///
/// Non-positive frequencies short-circuit to the upper limit, `60.0`.
#[inline]
#[must_use]
pub fn absolute_threshold_of_hearing(freq_hz: f32) -> f32 {
    const LOWER_LIMIT: f32 = -20.0;
    const UPPER_LIMIT: f32 = 60.0;

    if freq_hz <= 0.0 {
        return UPPER_LIMIT;
    }

    let khz = freq_hz / 1000.0;

    let low_frequency_rise = 3.64_f32 * khz.powf(-0.8_f32);
    let dip_distance = khz - 3.3_f32;
    let sensitivity_dip = -6.5_f32 * (-0.6_f32 * dip_distance.powi(2)).exp();
    let high_frequency_rise = 1e-3_f32 * khz.powi(4);

    let unclamped = low_frequency_rise + sensitivity_dip + high_frequency_rise;
    unclamped.clamp(LOWER_LIMIT, UPPER_LIMIT)
}
/// Returns how much a masker's level is reduced at the maskee's Bark position.
///
/// The result is the sum of two parts:
/// * a base gain chosen by `masker_type` (`config.masking_gain` for tonal
///   maskers, `config.noise_masking_gain` for noise maskers), and
/// * a distance term that grows linearly with the Bark separation, using
///   `config.upward_spread` when the maskee is at or above the masker and
///   `config.downward_spread` when it is below.
#[inline]
#[must_use]
pub fn spreading_attenuation(
    masker_bark: f32,
    maskee_bark: f32,
    masker_type: MaskerType,
    config: &PsychoacousticConfig,
) -> f32 {
    let type_offset = match masker_type {
        MaskerType::Noise => config.noise_masking_gain,
        MaskerType::Tonal => config.masking_gain,
    };

    // Pick the slope for the side the maskee is on, together with the
    // (non-negative) distance it is multiplied by.
    let bark_delta = maskee_bark - masker_bark;
    let (slope, separation) = if bark_delta >= 0.0 {
        (config.upward_spread, bark_delta)
    } else {
        (config.downward_spread, -bark_delta)
    };

    type_offset + slope * separation
}
/// Derives per-band psychoacoustic metrics from one frame of spectral bin energies.
///
/// For every band in `band_layout` this:
/// 1. averages the bin energies the band covers and converts the mean to
///    decibels as `10 * log10(mean + config.epsilon)`,
/// 2. evaluates the absolute threshold of hearing and the Bark position at
///    the band's centre frequency,
/// 3. labels the band as a tonal or noise masker,
/// 4. takes the strongest spread contribution from all bands (the band itself
///    included), floored at the absolute threshold, as the masking threshold,
/// 5. derives the signal-to-mask ratio, the weighted importance
///    (`weight * max(smr, 0)`) and the allowed noise
///    (`masking_threshold - max(smr, 0)`).
///
/// Band bin ranges are clamped to both `n_bins` and `bin_energies.len()`, and
/// an inverted range is treated as empty. An inconsistent `n_bins` or band
/// therefore yields the `epsilon` floor for that band instead of an
/// out-of-bounds slice panic or a subtraction overflow.
///
/// # Panics
///
/// `config.perceptual_weights` is indexed once per band, so this presumably
/// panics when fewer weights than bands are configured.
pub fn compute_band_metrics(
    bin_energies: &[f32],
    band_layout: &BandLayout,
    config: &PsychoacousticConfig,
    n_bins: usize,
) -> BandMetrics {
    use super::bands::hz_to_bark;

    // Minimum excess (dB) over the neighbouring bands for a band to count as tonal.
    const TONAL_THRESHOLD_DB: f32 = 7.0;

    let bands = band_layout.as_slice();
    let n_bands = bands.len().get();
    let weights = config.perceptual_weights.as_non_empty_slice();

    // Never read past the data that was actually supplied, even if the caller's
    // `n_bins` is larger than the slice.
    let bin_limit = n_bins.min(bin_energies.len());

    let band_energy_db: Vec<f32> = bands
        .iter()
        .map(|band| {
            let start = band.start_bin.min(bin_limit);
            // `max(start)` turns an inverted band into an empty range.
            let end = band.end_bin.min(bin_limit).max(start);
            let bins = &bin_energies[start..end];
            // Empty bands divide by 1 so the result is the epsilon floor, not NaN.
            let width = bins.len().max(1);
            let sum: f32 = bins.iter().sum();
            10.0_f32 * (sum / width as f32 + config.epsilon).log10()
        })
        .collect();

    let ath_db: Vec<f32> = bands
        .iter()
        .map(|band| absolute_threshold_of_hearing(band.centre_frequency))
        .collect();

    let bark_positions: Vec<f32> = bands
        .iter()
        .map(|band| hz_to_bark(band.centre_frequency))
        .collect();

    let masker_types = classify_masker_types(&band_energy_db, TONAL_THRESHOLD_DB);

    // For each maskee band, keep the loudest masker level after spreading
    // attenuation, then floor it at the threshold in quiet.
    let masking_thresholds: Vec<f32> = bark_positions
        .iter()
        .zip(ath_db.iter())
        .map(|(&maskee_bark, &quiet_db)| {
            let strongest_masker = band_energy_db
                .iter()
                .zip(bark_positions.iter())
                .zip(masker_types.iter())
                .map(|((&level_db, &masker_bark), &kind)| {
                    level_db - spreading_attenuation(masker_bark, maskee_bark, kind, config)
                })
                .fold(f32::NEG_INFINITY, f32::max);
            strongest_masker.max(quiet_db)
        })
        .collect();

    let metrics: Vec<BandMetric> = (0..n_bands)
        .map(|i| {
            let energy = band_energy_db[i];
            let masking_threshold = masking_thresholds[i];
            let smr = compute_smr(energy, masking_threshold);
            // Only the part of the signal above the mask contributes.
            let audible_excess = smr.max(0.0);
            let importance = weights[i] * audible_excess;
            let allowed_noise = masking_threshold - audible_excess;
            BandMetric::new(
                bands[i].clone(),
                energy,
                masking_threshold,
                smr,
                importance,
                allowed_noise,
            )
        })
        .collect();

    // SAFETY: `n_bands` comes from the layout's non-zero length and `metrics`
    // holds exactly one entry per band, so the slice is never empty.
    let metrics_slice = unsafe { NonEmptySlice::new_unchecked(&metrics) };
    BandMetrics::new(metrics_slice)
}
/// Raises per-band masking thresholds across neighbouring frames.
///
/// Two passes run over a working copy of every frame's thresholds:
///
/// * **Post-masking** – each frame's original threshold is pushed forward in
///   time, losing `post_masking_decay_db_per_ms * hop_duration_ms` per frame
///   of distance.
/// * **Pre-masking** – each frame's original threshold is pushed backward in
///   time, losing `pre_masking_decay_db_per_ms * hop_duration_ms` per frame
///   of distance.
///
/// In both passes propagation for a given masker and band stops at the first
/// frame whose current threshold is already at least as high as the decayed
/// contribution. The signal-to-mask ratio, importance and allowed noise are
/// then recomputed from each band's stored energy and its new threshold.
///
/// The band count is taken from the first frame. Returns an empty vector for
/// empty input, otherwise one `BandMetrics` per input frame.
pub fn apply_temporal_masking(
    frame_metrics: &[BandMetrics],
    config: &PsychoacousticConfig,
    hop_duration_ms: f32,
    post_masking_decay_db_per_ms: f32,
    pre_masking_decay_db_per_ms: f32,
) -> Vec<BandMetrics> {
    let band_count = match frame_metrics.first() {
        Some(first) => first.metrics.len().get(),
        None => return Vec::new(),
    };
    let frame_count = frame_metrics.len();

    // Working copy of the thresholds; maskers are always read from the
    // untouched originals in `frame_metrics`.
    let mut raised: Vec<Vec<f32>> = frame_metrics
        .iter()
        .map(|frame| {
            frame
                .metrics
                .iter()
                .map(|metric| metric.masking_threshold)
                .collect()
        })
        .collect();

    // Post-masking: spread each masker into later frames.
    for (src, frame) in frame_metrics.iter().enumerate() {
        for band in 0..band_count {
            let masker_db = frame.metrics[band].masking_threshold;
            for dst in (src + 1)..frame_count {
                let frames_apart = dst - src;
                let decay_db =
                    post_masking_decay_db_per_ms * (frames_apart as f32 * hop_duration_ms);
                let spread_db = masker_db - decay_db;
                if spread_db <= raised[dst][band] {
                    break;
                }
                raised[dst][band] = spread_db;
            }
        }
    }

    // Pre-masking: spread each masker into earlier frames (frame 0 has none).
    for (src, frame) in frame_metrics.iter().enumerate().skip(1) {
        for band in 0..band_count {
            let masker_db = frame.metrics[band].masking_threshold;
            for frames_apart in 1..=src {
                let dst = src - frames_apart;
                let decay_db =
                    pre_masking_decay_db_per_ms * (frames_apart as f32 * hop_duration_ms);
                let spread_db = masker_db - decay_db;
                if spread_db <= raised[dst][band] {
                    break;
                }
                raised[dst][band] = spread_db;
            }
        }
    }

    // Rebuild every frame's metrics around the raised thresholds.
    frame_metrics
        .iter()
        .zip(raised.iter())
        .map(|(frame, row)| {
            let rebuilt: Vec<BandMetric> = frame
                .metrics
                .iter()
                .zip(row.iter())
                .enumerate()
                .map(|(band, (previous, &threshold_db))| {
                    let energy = previous.energy;
                    let smr = compute_smr(energy, threshold_db);
                    let weight = config.perceptual_weights[band];
                    let audible_excess = smr.max(0.0);
                    let importance = weight * audible_excess;
                    let allowed_noise = threshold_db - audible_excess;
                    BandMetric::new(
                        previous.band.clone(),
                        energy,
                        threshold_db,
                        smr,
                        importance,
                        allowed_noise,
                    )
                })
                .collect();
            // SAFETY: `row` was built with one entry per element of
            // `frame.metrics`, so `rebuilt` has as many entries as
            // `frame.metrics`, which is assumed to be a non-empty collection
            // (its `len()` exposes `.get()`).
            let rebuilt_ne =
                unsafe { non_empty_slice::NonEmptySlice::new_unchecked(&rebuilt) };
            BandMetrics::new(rebuilt_ne)
        })
        .collect()
}
#[must_use]
pub fn detect_transient_windows(
samples: &non_empty_slice::NonEmptySlice<f32>,
window_size: std::num::NonZeroUsize,
hop_size: std::num::NonZeroUsize,
threshold_ratio: f32,
) -> non_empty_slice::NonEmptyVec<bool> {
let ws = window_size.get();
let hs = hop_size.get();
let n = samples.len().get();
let energies: Vec<f32> = (0..)
.map(|i: usize| i * hs)
.take_while(|&pos| pos < n)
.map(|pos| {
let end = (pos + ws).min(n);
let w = &samples[pos..end];
w.iter().map(|s| s * s).sum::<f32>() / w.len() as f32
})
.collect();
let result: Vec<bool> = energies
.iter()
.enumerate()
.map(|(i, &e)| {
i > 0 && energies[i - 1] > 1e-10 && e > energies[i - 1] * threshold_ratio
})
.collect();
unsafe { non_empty_slice::NonEmptyVec::new_unchecked(result) }
}
/// Signal-to-mask ratio: how far a band's energy lies above its masking threshold.
///
/// Both arguments are in decibels, so the ratio is their difference. A
/// positive result means the band's energy exceeds the threshold; a negative
/// one means it lies below it.
#[inline]
#[must_use]
pub fn compute_smr(band_energy_db: f32, masking_threshold_db: f32) -> f32 {
    let margin_db = band_energy_db - masking_threshold_db;
    margin_db
}