use ndarray::Array1;
/// Tuning parameters for the decomposed (region-aware) room-correction
/// analysis: modal detection thresholds plus the per-region base
/// correction weights.
#[derive(Debug, Clone)]
pub struct DecomposedCorrectionConfig {
    /// Boundary (Hz) between the modal region and the steady-state
    /// region; mode detection stops above this frequency.
    pub schroeder_freq: f64,
    /// Minimum estimated Q for a peak to be reported as a room mode.
    pub min_mode_q: f64,
    /// Minimum prominence (dB above the local baseline) for a peak to
    /// be reported as a room mode.
    pub min_mode_prominence_db: f64,
    /// Correction weight applied inside a detected mode's bandwidth.
    pub mode_correction_weight: f64,
    /// Base correction weight below the Schroeder frequency.
    pub early_reflection_weight: f64,
    /// Base correction weight above the Schroeder frequency.
    pub steady_state_weight: f64,
    /// Width (octaves) of the smooth blend between the two base
    /// weights around the Schroeder frequency; <= 0 gives a hard step.
    pub transition_width_oct: f64,
}
impl Default for DecomposedCorrectionConfig {
fn default() -> Self {
Self {
schroeder_freq: 250.0,
min_mode_q: 3.0,
min_mode_prominence_db: 3.0,
mode_correction_weight: 1.0,
early_reflection_weight: 0.3,
steady_state_weight: 0.4,
transition_width_oct: 0.5,
}
}
}
/// A resonant peak detected below the Schroeder frequency.
#[derive(Debug, Clone)]
pub struct RoomMode {
    /// Centre frequency of the mode, Hz.
    pub frequency: f64,
    /// Estimated quality factor (centre frequency / -3 dB bandwidth).
    pub q: f64,
    /// Peak level above the local baseline, dB.
    pub prominence_db: f64,
    /// Index of the peak sample in the analysed frequency grid.
    pub index: usize,
}
/// Output of the decomposed correction analysis.
#[derive(Debug, Clone)]
pub struct DecomposedCorrectionResult {
    /// Room modes detected below the Schroeder frequency.
    pub room_modes: Vec<RoomMode>,
    /// Per-frequency-bin correction weight: blends from the
    /// early-reflection weight to the steady-state weight across the
    /// Schroeder boundary, raised to the mode-correction weight inside
    /// detected modal bands.
    pub correction_weights: Array1<f64>,
    /// Schroeder frequency used for the analysis, Hz.
    pub schroeder_freq: f64,
}
/// Run the full decomposed-correction analysis: detect room modes
/// below the Schroeder frequency, then derive the per-bin correction
/// weight curve from them.
pub fn analyze_decomposed_correction(
    freq: &Array1<f64>,
    spl: &Array1<f64>,
    config: &DecomposedCorrectionConfig,
) -> DecomposedCorrectionResult {
    let modes = detect_room_modes(freq, spl, config);
    let weights = build_correction_weights(freq, &modes, config);
    DecomposedCorrectionResult {
        schroeder_freq: config.schroeder_freq,
        correction_weights: weights,
        room_modes: modes,
    }
}
/// Scan the sub-Schroeder region of the response for resonant peaks.
///
/// A candidate must be a strict local maximum over a five-sample
/// window, stand out from its local baseline (one octave either side)
/// by at least `min_mode_prominence_db`, and have an estimated Q of at
/// least `min_mode_q`.
pub fn detect_room_modes(
    freq: &Array1<f64>,
    spl: &Array1<f64>,
    config: &DecomposedCorrectionConfig,
) -> Vec<RoomMode> {
    let n = freq.len();
    let mut found = Vec::new();
    // A candidate needs two neighbours on each side.
    if n < 5 {
        return found;
    }
    for idx in 2..n - 2 {
        // Modal analysis only applies below the Schroeder frequency.
        if freq[idx] > config.schroeder_freq {
            break;
        }
        let center = spl[idx];
        let strict_max = [idx - 2, idx - 1, idx + 1, idx + 2]
            .iter()
            .all(|&j| center > spl[j]);
        if !strict_max {
            continue;
        }
        let baseline =
            compute_local_baseline(freq, spl, idx, freq[idx] / 2.0, freq[idx] * 2.0);
        let prominence = center - baseline;
        if prominence < config.min_mode_prominence_db {
            continue;
        }
        let q = estimate_peak_q(freq, spl, idx);
        if q >= config.min_mode_q {
            found.push(RoomMode {
                frequency: freq[idx],
                q,
                prominence_db: prominence,
                index: idx,
            });
        }
    }
    found
}
/// Mean SPL over all samples whose frequency lies in `[f_low, f_high]`,
/// excluding the candidate sample itself. Falls back to the centre
/// sample's own level when no other sample falls inside the window.
fn compute_local_baseline(
    freq: &Array1<f64>,
    spl: &Array1<f64>,
    center_idx: usize,
    f_low: f64,
    f_high: f64,
) -> f64 {
    let (sum, count) = freq
        .iter()
        .zip(spl.iter())
        .enumerate()
        .filter(|&(j, (&f, _))| j != center_idx && f >= f_low && f <= f_high)
        .fold((0.0_f64, 0_usize), |(s, c), (_, (_, &level))| (s + level, c + 1));
    if count == 0 {
        spl[center_idx]
    } else {
        sum / count as f64
    }
}
/// Estimate the Q of the peak at `peak_idx` from its -3 dB bandwidth.
///
/// Walks outward from the peak in both directions to the first sample
/// at or below (peak - 3 dB) and linearly interpolates the crossing
/// frequency. If only one side crosses (e.g. the peak sits near an
/// array edge), the one-sided half-bandwidth is doubled; if neither
/// side crosses, a conservatively high Q of 20 is returned.
fn estimate_peak_q(freq: &Array1<f64>, spl: &Array1<f64>, peak_idx: usize) -> f64 {
    let peak_spl = spl[peak_idx];
    // -3 dB down from the peak defines the half-power bandwidth.
    let threshold = peak_spl - 3.0;
    let f_center = freq[peak_idx];
    // Lower crossing: scan downward from the peak.
    let mut f_low: Option<f64> = None;
    for i in (0..peak_idx).rev() {
        if spl[i] <= threshold {
            // Interpolate between samples i and i+1 for the exact
            // crossing frequency.
            let denom = spl[i + 1] - spl[i];
            if denom.abs() > 1e-12 {
                let t = ((threshold - spl[i]) / denom).clamp(0.0, 1.0);
                f_low = Some(freq[i] + t * (freq[i + 1] - freq[i]));
            } else {
                // Flat segment: no slope to interpolate along, so use
                // the sample frequency directly.
                f_low = Some(freq[i]);
            }
            break;
        }
    }
    // Upper crossing: scan upward from the peak.
    let mut f_high: Option<f64> = None;
    for i in (peak_idx + 1)..freq.len() {
        if spl[i] <= threshold {
            let denom = spl[i] - spl[i - 1];
            if denom.abs() > 1e-12 {
                let t = ((threshold - spl[i - 1]) / denom).clamp(0.0, 1.0);
                f_high = Some(freq[i - 1] + t * (freq[i] - freq[i - 1]));
            } else {
                f_high = Some(freq[i]);
            }
            break;
        }
    }
    let bandwidth = match (f_low, f_high) {
        (Some(lo), Some(hi)) => hi - lo,
        // One-sided crossing: mirror the measured half-bandwidth.
        (Some(lo), None) => 2.0 * (f_center - lo),
        (None, Some(hi)) => 2.0 * (hi - f_center),
        (None, None) => 0.0,
    };
    if bandwidth > 0.0 {
        f_center / bandwidth
    } else {
        // No measurable bandwidth: report a high default Q.
        20.0
    }
}
fn build_correction_weights(
freq: &Array1<f64>,
modes: &[RoomMode],
config: &DecomposedCorrectionConfig,
) -> Array1<f64> {
let n = freq.len();
let mut weights = Array1::zeros(n);
let schroeder_log = config.schroeder_freq.log2();
let half_transition = config.transition_width_oct / 2.0;
for i in 0..n {
let f = freq[i];
let f_log = f.log2();
let schroeder_blend = if config.transition_width_oct <= 0.0 {
if f <= config.schroeder_freq { 0.0 } else { 1.0 }
} else {
let x = (f_log - schroeder_log) / half_transition;
1.0 / (1.0 + (-x).exp())
};
let base_weight = config.early_reflection_weight
+ (config.steady_state_weight - config.early_reflection_weight) * schroeder_blend;
weights[i] = base_weight;
}
for mode in modes {
let bandwidth = mode.frequency / mode.q;
let f_low = mode.frequency - bandwidth / 2.0;
let f_high = mode.frequency + bandwidth / 2.0;
for i in 0..n {
if freq[i] >= f_low && freq[i] <= f_high {
weights[i] = weights[i].max(config.mode_correction_weight);
}
}
}
weights
}
/// Thresholds for narrow-null (cancellation notch) detection.
#[derive(Debug, Clone)]
pub struct NullDetectionConfig {
    /// Minimum estimated Q for a dip to be reported as a narrow null.
    pub min_null_q: f64,
    /// Minimum depth (dB below the local baseline) for a dip to be
    /// reported as a narrow null.
    pub min_null_depth_db: f64,
}
impl Default for NullDetectionConfig {
fn default() -> Self {
Self {
min_null_q: 3.0,
min_null_depth_db: 4.0,
}
}
}
/// A narrow cancellation notch detected in the response.
#[derive(Debug, Clone)]
pub struct NarrowNull {
    /// Centre frequency of the null, Hz.
    pub frequency: f64,
    /// Estimated quality factor (centre frequency / +3 dB bandwidth).
    pub q: f64,
    /// Depth below the local baseline, dB.
    pub depth_db: f64,
    /// Index of the dip sample in the analysed frequency grid.
    pub index: usize,
}
/// Scan the response for narrow, deep dips (cancellation nulls).
///
/// A candidate must be a strict local minimum over a five-sample
/// window, sit at least `min_null_depth_db` below its local baseline
/// (one octave either side), and have an estimated Q of at least
/// `min_null_q`.
pub fn detect_narrow_nulls(
    freq: &Array1<f64>,
    spl: &Array1<f64>,
    config: &NullDetectionConfig,
) -> Vec<NarrowNull> {
    let n = freq.len();
    let mut found = Vec::new();
    // A candidate needs two neighbours on each side.
    if n < 5 {
        return found;
    }
    for idx in 2..n - 2 {
        let center = spl[idx];
        let strict_min = [idx - 2, idx - 1, idx + 1, idx + 2]
            .iter()
            .all(|&j| center < spl[j]);
        if !strict_min {
            continue;
        }
        let baseline =
            compute_local_baseline(freq, spl, idx, freq[idx] / 2.0, freq[idx] * 2.0);
        let depth = baseline - center;
        if depth < config.min_null_depth_db {
            continue;
        }
        let q = estimate_dip_q(freq, spl, idx);
        if q >= config.min_null_q {
            found.push(NarrowNull {
                frequency: freq[idx],
                q,
                depth_db: depth,
                index: idx,
            });
        }
    }
    found
}
/// Estimate the Q of the dip at `dip_idx` from its +3 dB bandwidth
/// (mirror of `estimate_peak_q` for minima).
///
/// Walks outward from the dip in both directions to the first sample
/// at or above (dip + 3 dB) and linearly interpolates the crossing
/// frequency. One-sided crossings double the measured half-bandwidth;
/// no crossing at all yields a conservatively high Q of 20.
fn estimate_dip_q(freq: &Array1<f64>, spl: &Array1<f64>, dip_idx: usize) -> f64 {
    let dip_spl = spl[dip_idx];
    // +3 dB above the dip floor defines the notch bandwidth.
    let threshold = dip_spl + 3.0;
    let f_center = freq[dip_idx];
    // Lower crossing: scan downward from the dip.
    let mut f_low: Option<f64> = None;
    for i in (0..dip_idx).rev() {
        if spl[i] >= threshold {
            // Interpolate between samples i and i+1 for the exact
            // crossing frequency.
            let denom = spl[i + 1] - spl[i];
            if denom.abs() > 1e-12 {
                let t = ((threshold - spl[i]) / denom).clamp(0.0, 1.0);
                f_low = Some(freq[i] + t * (freq[i + 1] - freq[i]));
            } else {
                // Flat segment: use the sample frequency directly.
                f_low = Some(freq[i]);
            }
            break;
        }
    }
    // Upper crossing: scan upward from the dip.
    let mut f_high: Option<f64> = None;
    for i in (dip_idx + 1)..freq.len() {
        if spl[i] >= threshold {
            let denom = spl[i] - spl[i - 1];
            if denom.abs() > 1e-12 {
                let t = ((threshold - spl[i - 1]) / denom).clamp(0.0, 1.0);
                f_high = Some(freq[i - 1] + t * (freq[i] - freq[i - 1]));
            } else {
                f_high = Some(freq[i]);
            }
            break;
        }
    }
    let bandwidth = match (f_low, f_high) {
        (Some(lo), Some(hi)) => hi - lo,
        // One-sided crossing: mirror the measured half-bandwidth.
        (Some(lo), None) => 2.0 * (f_center - lo),
        (None, Some(hi)) => 2.0 * (hi - f_center),
        (None, None) => 0.0,
    };
    if bandwidth > 0.0 {
        f_center / bandwidth
    } else {
        // No measurable bandwidth: report a high default Q.
        20.0
    }
}
/// Build a multiplicative mask in [0, 1] that suppresses correction
/// inside each detected null's bandwidth.
///
/// The inner half of each null's half-bandwidth is fully suppressed
/// (mask 0); the outer half ramps back to 1 with a raised-cosine
/// taper. Overlapping nulls keep the most suppressive (smallest)
/// value; bins outside every null stay at 1.
pub fn build_null_suppression_mask(freq: &Array1<f64>, nulls: &[NarrowNull]) -> Array1<f64> {
    let mut mask = Array1::ones(freq.len());
    if nulls.is_empty() {
        return mask;
    }
    // Fraction of the half-bandwidth used for the cosine taper.
    const TAPER_FRAC: f64 = 0.5;
    for null in nulls {
        // Guard against a degenerate Q of zero.
        let bw = null.frequency / null.q.max(1e-6);
        let inner = bw * 0.5 * (1.0 - TAPER_FRAC);
        let outer = bw * 0.5;
        for (i, &f) in freq.iter().enumerate() {
            let dist = (f - null.frequency).abs();
            if dist > outer {
                continue;
            }
            let value = if dist <= inner {
                0.0
            } else {
                // Raised-cosine ramp from 0 at `inner` to 1 at `outer`.
                let t = (dist - inner) / (outer - inner);
                0.5 * (1.0 - (std::f64::consts::PI * t).cos())
            };
            mask[i] = mask[i].min(value);
        }
    }
    mask
}
/// Build correction weights informed by an SSIR (impulse-response)
/// analysis.
///
/// Works like `build_correction_weights`: the base weight blends
/// (logistic in log-frequency) from an early-reflection weight below
/// the boundary frequency to the steady-state weight above it, and
/// detected room modes raise the weight inside their bandwidth. The
/// early-reflection weight is additionally scaled by the measured
/// reflection count: a dense reflection field (> 8) backs off to 0.7x,
/// a moderate one (5-8) keeps the configured weight, and a sparse one
/// (<= 4) boosts it by 1.5x capped at 1.0.
pub fn build_ssir_correction_weights(
    freq: &Array1<f64>,
    spl: &Array1<f64>,
    ssir_result: &math_rir::SsirResult,
    config: &DecomposedCorrectionConfig,
) -> DecomposedCorrectionResult {
    let n = freq.len();
    let room_modes = detect_room_modes(freq, spl, config);
    let ssir_boundary_freq = config.schroeder_freq;
    let num_reflections = ssir_result.num_reflections();
    let reflection_weight = if num_reflections > 8 {
        // Dense reflection field: back off to avoid over-correcting.
        config.early_reflection_weight * 0.7
    } else if num_reflections > 4 {
        config.early_reflection_weight
    } else {
        // Sparse reflections: strengthen by 1.5x but cap at full weight.
        // BUGFIX: the previous expression `w * 1.5_f64.min(1.0)` applied
        // `.min` to the literal 1.5 (method calls bind tighter than `*`),
        // making the cap a no-op and the factor always 1.0; the cap must
        // apply to the product.
        (config.early_reflection_weight * 1.5).min(1.0)
    };
    let mut weights = Array1::zeros(n);
    let boundary_log = ssir_boundary_freq.log2();
    let half_transition = config.transition_width_oct / 2.0;
    for i in 0..n {
        let f = freq[i];
        let f_log = f.log2();
        // Logistic blend in log-frequency; hard step when the
        // transition width is non-positive.
        let blend = if config.transition_width_oct <= 0.0 {
            if f <= ssir_boundary_freq { 0.0 } else { 1.0 }
        } else {
            let x = (f_log - boundary_log) / half_transition;
            1.0 / (1.0 + (-x).exp())
        };
        let base_weight =
            reflection_weight + (config.steady_state_weight - reflection_weight) * blend;
        weights[i] = base_weight;
    }
    // Raise the weight to the modal correction weight inside each
    // detected mode's -3 dB bandwidth.
    for mode in &room_modes {
        let bandwidth = mode.frequency / mode.q;
        let f_low = mode.frequency - bandwidth / 2.0;
        let f_high = mode.frequency + bandwidth / 2.0;
        for i in 0..n {
            if freq[i] >= f_low && freq[i] <= f_high {
                weights[i] = weights[i].max(config.mode_correction_weight);
            }
        }
    }
    DecomposedCorrectionResult {
        room_modes,
        correction_weights: weights,
        schroeder_freq: ssir_boundary_freq,
    }
}
// Unit tests covering mode detection, Q estimation, local-baseline
// computation, correction-weight construction, and narrow-null
// detection / suppression-mask building.
#[cfg(test)]
mod tests {
    use super::*;

    // A perfectly flat response has no local maxima, so no modes.
    #[test]
    fn test_detect_room_modes_flat_response() {
        let n = 100;
        let freq = Array1::linspace(20.0, 500.0, n);
        let spl = Array1::from_elem(n, 80.0);
        let config = DecomposedCorrectionConfig::default();
        let modes = detect_room_modes(&freq, &spl, &config);
        assert!(modes.is_empty(), "flat response should have no modes");
    }

    // A synthetic Q=10 Lorentzian bump at 60 Hz must be detected near
    // its centre with a sufficiently high estimated Q.
    #[test]
    fn test_detect_room_modes_with_peak() {
        let n = 200;
        let freq = Array1::linspace(20.0, 300.0, n);
        let mut spl = Array1::from_elem(n, 80.0);
        for i in 0..n {
            let f = freq[i];
            let q = 10.0;
            let bw = 60.0 / q;
            let response = 10.0 / (1.0 + ((f - 60.0) / (bw / 2.0_f64)).powi(2));
            spl[i] += response;
        }
        let config = DecomposedCorrectionConfig::default();
        let modes = detect_room_modes(&freq, &spl, &config);
        assert!(
            !modes.is_empty(),
            "should detect the 60 Hz peak as a room mode"
        );
        let nearest = modes
            .iter()
            .min_by(|a, b| {
                (a.frequency - 60.0)
                    .abs()
                    .partial_cmp(&(b.frequency - 60.0).abs())
                    .unwrap()
            })
            .unwrap();
        assert!(
            (nearest.frequency - 60.0).abs() < 10.0,
            "detected mode at {:.1} Hz should be near 60 Hz",
            nearest.frequency
        );
        assert!(
            nearest.q >= 3.0,
            "detected Q={:.1} should be >= 3.0",
            nearest.q
        );
    }

    // With a hard step (zero transition width), a bin below the
    // Schroeder frequency takes the early-reflection weight exactly.
    #[test]
    fn test_correction_weights_below_schroeder() {
        let freq = Array1::from_vec(vec![50.0]);
        let modes = vec![];
        let config = DecomposedCorrectionConfig {
            schroeder_freq: 200.0,
            early_reflection_weight: 0.3,
            steady_state_weight: 0.5,
            transition_width_oct: 0.0,
            ..Default::default()
        };
        let weights = build_correction_weights(&freq, &modes, &config);
        assert!(
            (weights[0] - 0.3).abs() < 0.01,
            "below Schroeder should use early_reflection_weight, got {}",
            weights[0]
        );
    }

    // Counterpart: a bin above the boundary takes the steady-state weight.
    #[test]
    fn test_correction_weights_above_schroeder() {
        let freq = Array1::from_vec(vec![500.0]);
        let modes = vec![];
        let config = DecomposedCorrectionConfig {
            schroeder_freq: 200.0,
            early_reflection_weight: 0.3,
            steady_state_weight: 0.5,
            transition_width_oct: 0.0,
            ..Default::default()
        };
        let weights = build_correction_weights(&freq, &modes, &config);
        assert!(
            (weights[0] - 0.5).abs() < 0.01,
            "above Schroeder should use steady_state_weight, got {}",
            weights[0]
        );
    }

    // A bin inside a supplied mode's bandwidth is raised to the
    // mode-correction weight.
    #[test]
    fn test_correction_weights_mode_boost() {
        let freq = Array1::from_vec(vec![50.0, 60.0, 70.0]);
        let modes = vec![RoomMode {
            frequency: 60.0,
            q: 5.0,
            prominence_db: 8.0,
            index: 1,
        }];
        let config = DecomposedCorrectionConfig {
            schroeder_freq: 200.0,
            early_reflection_weight: 0.3,
            mode_correction_weight: 1.0,
            transition_width_oct: 0.0,
            ..Default::default()
        };
        let weights = build_correction_weights(&freq, &modes, &config);
        assert!(
            weights[1] > 0.9,
            "mode frequency should have boosted weight, got {}",
            weights[1]
        );
    }

    // End-to-end: result carries the configured Schroeder frequency
    // and produces finite weights within [0, 1] for default config.
    #[test]
    fn test_full_decomposed_analysis() {
        let n = 200;
        let freq = Array1::linspace(20.0, 500.0, n);
        let mut spl = Array1::from_elem(n, 80.0);
        for i in 0..n {
            let f = freq[i];
            let q = 8.0;
            let bw = 80.0 / q;
            spl[i] += 8.0 / (1.0 + ((f - 80.0) / (bw / 2.0_f64)).powi(2));
        }
        let config = DecomposedCorrectionConfig::default();
        let result = analyze_decomposed_correction(&freq, &spl, &config);
        assert_eq!(result.schroeder_freq, 250.0);
        assert!(!result.correction_weights.iter().any(|w| w.is_nan()));
        assert!(
            result
                .correction_weights
                .iter()
                .all(|&w| (0.0..=1.0).contains(&w))
        );
    }

    // A synthetic Q=15 bump should yield a clearly high Q estimate.
    #[test]
    fn test_estimate_peak_q_narrow() {
        let n = 100;
        let freq = Array1::linspace(40.0, 120.0, n);
        let mut spl = Array1::from_elem(n, 80.0);
        for i in 0..n {
            let f = freq[i];
            let bw = 80.0 / 15.0;
            spl[i] += 10.0 / (1.0 + ((f - 80.0) / (bw / 2.0_f64)).powi(2));
        }
        let peak_idx = spl
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .unwrap();
        let q = estimate_peak_q(&freq, &spl, peak_idx);
        assert!(q > 5.0, "narrow peak should have high Q, got {:.1}", q);
    }

    // Peak at index 0: only the upper -3 dB crossing exists, so the
    // one-sided fallback path must still produce a sane Q.
    #[test]
    fn test_estimate_peak_q_at_array_edge() {
        let freq = Array1::from_vec(vec![20.0, 30.0, 40.0, 50.0, 60.0]);
        let spl = Array1::from_vec(vec![90.0, 88.0, 85.0, 83.0, 80.0]);
        let q = estimate_peak_q(&freq, &spl, 0);
        assert!(
            q.is_finite() && q > 0.0,
            "Q at edge should be positive finite, got {}",
            q
        );
    }

    // The 95 dB outlier at the centre index must not skew its own baseline.
    #[test]
    fn test_compute_local_baseline_excludes_center() {
        let freq = Array1::from_vec(vec![50.0, 60.0, 70.0, 80.0, 90.0]);
        let spl = Array1::from_vec(vec![80.0, 80.0, 95.0, 80.0, 80.0]);
        let baseline = compute_local_baseline(&freq, &spl, 2, 40.0, 100.0);
        assert!(
            (baseline - 80.0).abs() < 0.5,
            "baseline should be ~80, got {:.1}",
            baseline
        );
    }

    // With a 1-octave transition, the weight curve should pass near the
    // midpoint at the boundary and approach each endpoint away from it.
    #[test]
    fn test_correction_weights_smooth_transition() {
        let freq = Array1::from_vec(vec![50.0, 200.0, 800.0]);
        let config = DecomposedCorrectionConfig {
            schroeder_freq: 200.0,
            early_reflection_weight: 0.2,
            steady_state_weight: 0.6,
            transition_width_oct: 1.0,
            ..Default::default()
        };
        let weights = build_correction_weights(&freq, &[], &config);
        assert!(
            weights[0] < 0.4,
            "50 Hz weight should be near 0.2, got {}",
            weights[0]
        );
        let midpoint = (0.2 + 0.6) / 2.0;
        assert!(
            (weights[1] - midpoint).abs() < 0.15,
            "200 Hz weight should be near {:.1}, got {:.2}",
            midpoint,
            weights[1]
        );
        assert!(
            weights[2] > 0.4,
            "800 Hz weight should be near 0.6, got {}",
            weights[2]
        );
    }

    // Fewer than 5 samples: detection bails out early.
    #[test]
    fn test_detect_room_modes_short_array() {
        let freq = Array1::from_vec(vec![50.0, 60.0, 70.0]);
        let spl = Array1::from_vec(vec![85.0, 90.0, 85.0]);
        let config = DecomposedCorrectionConfig::default();
        let modes = detect_room_modes(&freq, &spl, &config);
        assert!(modes.is_empty());
    }

    // A narrow peak at 400 Hz is above the 200 Hz Schroeder boundary
    // and must not be reported as a mode.
    #[test]
    fn test_detect_room_modes_ignores_above_schroeder() {
        let n = 200;
        let freq = Array1::linspace(20.0, 1000.0, n);
        let mut spl = Array1::from_elem(n, 80.0);
        for i in 0..n {
            let f = freq[i];
            spl[i] += 10.0 / (1.0 + ((f - 400.0) / 5.0_f64).powi(2));
        }
        let config = DecomposedCorrectionConfig {
            schroeder_freq: 200.0,
            ..Default::default()
        };
        let modes = detect_room_modes(&freq, &spl, &config);
        assert!(
            modes.iter().all(|m| m.frequency <= 200.0),
            "modes above Schroeder should not be detected"
        );
    }

    // Peak near the low edge: only one -3 dB crossing is found, and
    // the mirrored half-bandwidth should not double the Q estimate.
    #[test]
    fn test_estimate_peak_q_one_sided_low() {
        let n = 100;
        let freq = Array1::linspace(20.0, 200.0, n);
        let mut spl = Array1::from_elem(n, 80.0);
        for i in 0..n {
            let f = freq[i];
            let bw = 25.0 / 5.0;
            spl[i] += 10.0 / (1.0 + ((f - 25.0) / (bw / 2.0_f64)).powi(2));
        }
        let peak_idx = spl
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .unwrap();
        let q = estimate_peak_q(&freq, &spl, peak_idx);
        assert!(
            q < 15.0,
            "one-sided Q estimate should not double: got {:.1}, expected ~5",
            q
        );
        assert!(q > 2.0, "Q should still be reasonable: got {:.1}", q);
    }

    // Flat segments around the -3 dB level exercise the zero-slope
    // (denominator ~0) fallback in the interpolation.
    #[test]
    fn test_estimate_peak_q_interpolation_denom_zero() {
        let freq = Array1::from_vec(vec![50.0, 55.0, 60.0, 65.0, 70.0, 75.0, 80.0]);
        let spl = Array1::from_vec(vec![80.0, 83.0, 83.0, 90.0, 83.0, 83.0, 80.0]);
        let q = estimate_peak_q(&freq, &spl, 3);
        assert!(
            q.is_finite() && q > 0.0,
            "Q should be finite positive, got {:.1}",
            q
        );
    }

    // Pin the documented default configuration values.
    #[test]
    fn test_decomposed_correction_config_defaults() {
        let config = DecomposedCorrectionConfig::default();
        assert_eq!(config.schroeder_freq, 250.0);
        assert_eq!(config.steady_state_weight, 0.4);
        assert_eq!(config.min_mode_q, 3.0);
        assert_eq!(config.min_mode_prominence_db, 3.0);
        assert_eq!(config.mode_correction_weight, 1.0);
        assert_eq!(config.early_reflection_weight, 0.3);
        assert_eq!(config.transition_width_oct, 0.5);
    }

    /// Logarithmically spaced frequency grid from `f_min` to `f_max`
    /// (inclusive) with `n` points.
    fn log_linspace(f_min: f64, f_max: f64, n: usize) -> Array1<f64> {
        let lo = f_min.ln();
        let hi = f_max.ln();
        Array1::from_iter((0..n).map(|i| (lo + (hi - lo) * i as f64 / (n - 1) as f64).exp()))
    }

    // A flat response has no local minima, so no nulls.
    #[test]
    fn test_detect_narrow_nulls_flat_response_is_empty() {
        let freq = log_linspace(20.0, 20000.0, 512);
        let spl = Array1::from_elem(freq.len(), 80.0);
        let nulls = detect_narrow_nulls(&freq, &spl, &NullDetectionConfig::default());
        assert!(
            nulls.is_empty(),
            "flat response must not produce narrow nulls, got {nulls:?}"
        );
    }

    // A synthetic Q=10, 15 dB notch at 80 Hz must be detected near its
    // centre with Q and depth above the default thresholds.
    #[test]
    fn test_detect_narrow_nulls_finds_high_q_notch() {
        let freq = log_linspace(20.0, 20000.0, 512);
        let f0 = 80.0;
        let q = 10.0;
        let bw = f0 / q;
        let spl: Array1<f64> = freq.mapv(|f| {
            let x = (f - f0) / (bw / 2.0);
            80.0 - 15.0 / (1.0 + x * x)
        });
        let nulls = detect_narrow_nulls(&freq, &spl, &NullDetectionConfig::default());
        assert!(
            !nulls.is_empty(),
            "should detect the 80 Hz Q=10 notch as a narrow null"
        );
        let nearest = nulls
            .iter()
            .min_by(|a, b| {
                (a.frequency - f0)
                    .abs()
                    .partial_cmp(&(b.frequency - f0).abs())
                    .unwrap()
            })
            .unwrap();
        assert!(
            (nearest.frequency - f0).abs() < 5.0,
            "detected null at {:.1} Hz should be near {f0} Hz",
            nearest.frequency
        );
        assert!(
            nearest.q >= 3.0,
            "detected Q={:.1} should exceed min_null_q=3",
            nearest.q
        );
        assert!(
            nearest.depth_db >= 4.0,
            "detected depth={:.1} should exceed min_null_depth_db=4",
            nearest.depth_db
        );
    }

    // A broad Q=0.8 dip fails the min_null_q gate and is ignored.
    #[test]
    fn test_detect_narrow_nulls_ignores_broad_dip() {
        let freq = log_linspace(20.0, 20000.0, 512);
        let f0 = 400.0;
        let q = 0.8;
        let bw = f0 / q;
        let spl: Array1<f64> = freq.mapv(|f| {
            let x = (f - f0) / (bw / 2.0);
            80.0 - 8.0 / (1.0 + x * x)
        });
        let nulls = detect_narrow_nulls(&freq, &spl, &NullDetectionConfig::default());
        assert!(
            nulls.is_empty(),
            "a broad Q=0.8 dip must not be flagged as a narrow null, got {nulls:?}"
        );
    }

    // Mask must be ~0 at the null centre, exactly 1 far away, and stay
    // within [0, 1] everywhere.
    #[test]
    fn test_build_null_suppression_mask_is_zero_at_null() {
        let freq = log_linspace(20.0, 20000.0, 512);
        let nulls = vec![NarrowNull {
            frequency: 80.0,
            q: 10.0,
            depth_db: 15.0,
            index: 0,
        }];
        let mask = build_null_suppression_mask(&freq, &nulls);
        assert_eq!(mask.len(), freq.len());
        let center_idx = freq
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| (*a - 80.0).abs().partial_cmp(&(*b - 80.0).abs()).unwrap())
            .unwrap()
            .0;
        assert!(
            mask[center_idx] < 1e-6,
            "mask at null centre must be ~0, got {}",
            mask[center_idx]
        );
        let far_idx = freq
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| {
                (*a - 5000.0)
                    .abs()
                    .partial_cmp(&(*b - 5000.0).abs())
                    .unwrap()
            })
            .unwrap()
            .0;
        assert!(
            (mask[far_idx] - 1.0).abs() < 1e-12,
            "mask far from any null must be 1.0, got {}",
            mask[far_idx]
        );
        for (i, &m) in mask.iter().enumerate() {
            assert!(
                (0.0..=1.0).contains(&m),
                "mask[{i}] = {m} must be in [0, 1]"
            );
        }
    }

    // No nulls -> identity (all-ones) mask.
    #[test]
    fn test_build_null_suppression_mask_empty_input_is_all_ones() {
        let freq = log_linspace(20.0, 20000.0, 256);
        let mask = build_null_suppression_mask(&freq, &[]);
        assert!(
            mask.iter().all(|&m| (m - 1.0).abs() < 1e-12),
            "empty null list must yield an all-ones mask"
        );
    }

    // Pin the documented null-detection defaults.
    #[test]
    fn test_null_detection_config_defaults() {
        let config = NullDetectionConfig::default();
        assert_eq!(config.min_null_q, 3.0);
        assert_eq!(config.min_null_depth_db, 4.0);
    }
}