use ndarray::Array1;
/// Tuning parameters for the decomposed (room-mode / early-reflection /
/// steady-state) correction analysis.
///
/// Stock values come from the `Default` implementation that follows.
#[derive(Debug, Clone)]
pub struct DecomposedCorrectionConfig {
    /// Boundary frequency (Hz) between the modal and the diffuse region.
    /// Mode detection stops above it, and the base-weight transition is
    /// centred on it.
    pub schroeder_freq: f64,
    /// Smallest estimated Q a peak may have and still be reported as a mode.
    pub min_mode_q: f64,
    /// Smallest height (dB) a peak must rise above the mean level of the
    /// surrounding one-octave-each-side window.
    pub min_mode_prominence_db: f64,
    /// Weight enforced (via `max`) inside each detected mode's bandwidth.
    pub mode_correction_weight: f64,
    /// Base weight used below the Schroeder frequency.
    pub early_reflection_weight: f64,
    /// Base weight used above the Schroeder frequency.
    pub steady_state_weight: f64,
    /// Width (octaves) of the smooth blend between the two base weights;
    /// a value `<= 0.0` selects a hard step at `schroeder_freq`.
    pub transition_width_oct: f64,
}

impl Default for DecomposedCorrectionConfig {
    /// Stock configuration: 250 Hz Schroeder boundary, modes need Q >= 3 and
    /// >= 3 dB prominence, full weight on modes, 0.3 -> 0.4 base weights
    /// blended over half an octave.
    fn default() -> Self {
        DecomposedCorrectionConfig {
            early_reflection_weight: 0.3,
            steady_state_weight: 0.4,
            mode_correction_weight: 1.0,
            min_mode_q: 3.0,
            min_mode_prominence_db: 3.0,
            transition_width_oct: 0.5,
            schroeder_freq: 250.0,
        }
    }
}
/// A resonant peak reported by [`detect_room_modes`].
#[derive(Debug, Clone)]
pub struct RoomMode {
    /// Frequency of the peak bin, taken directly from the input `freq` array.
    pub frequency: f64,
    /// Quality factor estimated from the -3 dB bandwidth
    /// (`frequency / bandwidth`).
    pub q: f64,
    /// Peak level minus the mean level of the other bins within one octave
    /// either side of the peak.
    pub prominence_db: f64,
    /// Index of the peak bin in the input arrays.
    pub index: usize,
}
/// Output of [`analyze_decomposed_correction`] and
/// [`build_ssir_correction_weights`].
#[derive(Debug, Clone)]
pub struct DecomposedCorrectionResult {
    /// Room modes detected below the configured Schroeder frequency.
    pub room_modes: Vec<RoomMode>,
    /// Correction weight per input frequency bin (same length as `freq`).
    pub correction_weights: Array1<f64>,
    /// Boundary frequency the weight transition was centred on: the
    /// configured Schroeder frequency, or the SSIR-derived boundary.
    pub schroeder_freq: f64,
}
/// Runs the decomposed-correction analysis on a magnitude response.
///
/// Room modes are found with [`detect_room_modes`], and per-bin weights are
/// then derived from them by `build_correction_weights`. The reported
/// `schroeder_freq` is copied unchanged from `config`.
pub fn analyze_decomposed_correction(
    freq: &Array1<f64>,
    spl: &Array1<f64>,
    config: &DecomposedCorrectionConfig,
) -> DecomposedCorrectionResult {
    let modes = detect_room_modes(freq, spl, config);
    // Fields are evaluated in written order: borrow `modes` for the weights
    // first, then move it into the result.
    DecomposedCorrectionResult {
        correction_weights: build_correction_weights(freq, &modes, config),
        room_modes: modes,
        schroeder_freq: config.schroeder_freq,
    }
}
/// Detects prominent, sufficiently sharp peaks (room modes) below the
/// Schroeder frequency.
///
/// A candidate is a bin `i` whose level is strictly greater than the levels of
/// the two bins on each side of it. Two further checks follow:
///
/// - **Prominence.** The prominence is the level minus the mean of the other
///   bins in `freq[i] / 2 ..= freq[i] * 2`. Candidates below
///   `config.min_mode_prominence_db` are discarded.
/// - **Sharpness.** Survivors are kept only when their estimated Q is at least
///   `config.min_mode_q`.
///
/// Scanning stops at the first bin above `config.schroeder_freq`, so `freq` is
/// expected to be ascending. `spl` is indexed with the same indices as `freq`.
/// Inputs with fewer than five bins yield no modes.
pub fn detect_room_modes(
    freq: &Array1<f64>,
    spl: &Array1<f64>,
    config: &DecomposedCorrectionConfig,
) -> Vec<RoomMode> {
    let len = freq.len();
    // A candidate needs two neighbours on each side.
    if len < 5 {
        return Vec::new();
    }

    let mut found = Vec::new();
    for idx in 2..len - 2 {
        let centre_freq = freq[idx];
        if centre_freq > config.schroeder_freq {
            break;
        }

        let level = spl[idx];
        let dominates_neighbours = [idx - 2, idx - 1, idx + 1, idx + 2]
            .iter()
            .all(|&k| level > spl[k]);
        if !dominates_neighbours {
            continue;
        }

        let baseline =
            compute_local_baseline(freq, spl, idx, centre_freq / 2.0, centre_freq * 2.0);
        let prominence = level - baseline;
        if prominence < config.min_mode_prominence_db {
            continue;
        }

        let q = estimate_peak_q(freq, spl, idx);
        // Kept as `>=` so a NaN Q is rejected.
        if q >= config.min_mode_q {
            found.push(RoomMode {
                frequency: centre_freq,
                q,
                prominence_db: prominence,
                index: idx,
            });
        }
    }
    found
}
/// Mean level of every bin whose frequency lies in `[f_low, f_high]`,
/// excluding `center_idx` itself.
///
/// Levels are accumulated in ascending index order. When no other bin falls
/// inside the window, `spl[center_idx]` is returned, which gives the centre
/// bin zero prominence.
fn compute_local_baseline(
    freq: &Array1<f64>,
    spl: &Array1<f64>,
    center_idx: usize,
    f_low: f64,
    f_high: f64,
) -> f64 {
    let window = f_low..=f_high;
    let (total, bins) = freq
        .iter()
        .enumerate()
        .filter(|&(j, f)| j != center_idx && window.contains(f))
        .fold((0.0_f64, 0_usize), |(acc, cnt), (j, _)| (acc + spl[j], cnt + 1));
    match bins {
        0 => spl[center_idx],
        _ => total / bins as f64,
    }
}
/// Estimates the quality factor of the peak at `peak_idx` from its -3 dB
/// bandwidth.
///
/// The search walks outward from the peak on each side to the first bin at or
/// below `spl[peak_idx] - 3.0`. The crossing frequency is linearly interpolated
/// between that bin and its neighbour towards the peak.
///
/// - If only one side crosses, the peak is assumed symmetric and that
///   half-width is doubled.
/// - The result is `freq[peak_idx] / bandwidth`.
/// - If the bandwidth is not positive (for example, neither side crosses),
///   the fallback `20.0` is returned.
fn estimate_peak_q(freq: &Array1<f64>, spl: &Array1<f64>, peak_idx: usize) -> f64 {
    let threshold = spl[peak_idx] - 3.0;
    let f_center = freq[peak_idx];

    // Frequency at which the line through bins `a` < `b` reaches `threshold`.
    // `fallback` is the bin used when the two levels are numerically equal.
    let crossing = |a: usize, b: usize, fallback: usize| -> f64 {
        let denom = spl[b] - spl[a];
        if denom.abs() > 1e-12 {
            let t = ((threshold - spl[a]) / denom).clamp(0.0, 1.0);
            freq[a] + t * (freq[b] - freq[a])
        } else {
            freq[fallback]
        }
    };

    let lower_edge = (0..peak_idx)
        .rev()
        .find(|&i| spl[i] <= threshold)
        .map(|i| crossing(i, i + 1, i));
    let upper_edge = (peak_idx + 1..freq.len())
        .find(|&i| spl[i] <= threshold)
        .map(|i| crossing(i - 1, i, i));

    let bandwidth = match (lower_edge, upper_edge) {
        (Some(lo), Some(hi)) => hi - lo,
        (Some(lo), None) => 2.0 * (f_center - lo),
        (None, Some(hi)) => 2.0 * (hi - f_center),
        (None, None) => 0.0,
    };

    if bandwidth > 0.0 {
        f_center / bandwidth
    } else {
        20.0
    }
}
fn build_correction_weights(
freq: &Array1<f64>,
modes: &[RoomMode],
config: &DecomposedCorrectionConfig,
) -> Array1<f64> {
let n = freq.len();
let mut weights = Array1::zeros(n);
let schroeder_log = config.schroeder_freq.log2();
let half_transition = config.transition_width_oct / 2.0;
for i in 0..n {
let f = freq[i];
let f_log = f.log2();
let schroeder_blend = if config.transition_width_oct <= 0.0 {
if f <= config.schroeder_freq { 0.0 } else { 1.0 }
} else {
let x = (f_log - schroeder_log) / half_transition;
1.0 / (1.0 + (-x).exp())
};
let base_weight = config.early_reflection_weight
+ (config.steady_state_weight - config.early_reflection_weight) * schroeder_blend;
weights[i] = base_weight;
}
for mode in modes {
let bandwidth = mode.frequency / mode.q;
let f_low = mode.frequency - bandwidth / 2.0;
let f_high = mode.frequency + bandwidth / 2.0;
for i in 0..n {
if freq[i] >= f_low && freq[i] <= f_high {
weights[i] = weights[i].max(config.mode_correction_weight);
}
}
}
weights
}
pub fn build_ssir_correction_weights(
freq: &Array1<f64>,
spl: &Array1<f64>,
ssir_result: &math_rir::SsirResult,
config: &DecomposedCorrectionConfig,
) -> DecomposedCorrectionResult {
let n = freq.len();
let room_modes = detect_room_modes(freq, spl, config);
let mixing_time_s = ssir_result.mixing_time_samples as f64 / ssir_result.sample_rate;
let ssir_boundary_freq = if mixing_time_s > 0.001 {
(1.0 / mixing_time_s).clamp(50.0, 500.0)
} else {
config.schroeder_freq
};
let num_reflections = ssir_result.num_reflections();
let reflection_weight = if num_reflections > 8 {
config.early_reflection_weight * 0.7
} else if num_reflections > 4 {
config.early_reflection_weight
} else {
config.early_reflection_weight * 1.5_f64.min(1.0)
};
let mut weights = Array1::zeros(n);
let boundary_log = ssir_boundary_freq.log2();
let half_transition = config.transition_width_oct / 2.0;
for i in 0..n {
let f = freq[i];
let f_log = f.log2();
let blend = if config.transition_width_oct <= 0.0 {
if f <= ssir_boundary_freq { 0.0 } else { 1.0 }
} else {
let x = (f_log - boundary_log) / half_transition;
1.0 / (1.0 + (-x).exp())
};
let base_weight =
reflection_weight + (config.steady_state_weight - reflection_weight) * blend;
weights[i] = base_weight;
}
for mode in &room_modes {
let bandwidth = mode.frequency / mode.q;
let f_low = mode.frequency - bandwidth / 2.0;
let f_high = mode.frequency + bandwidth / 2.0;
for i in 0..n {
if freq[i] >= f_low && freq[i] <= f_high {
weights[i] = weights[i].max(config.mode_correction_weight);
}
}
}
DecomposedCorrectionResult {
room_modes,
correction_weights: weights,
schroeder_freq: ssir_boundary_freq,
}
}
// Unit tests for mode detection, Q estimation, the local baseline and the
// weight builders. Synthetic responses are a flat 80 dB floor with
// Lorentzian-shaped bumps (in dB) added on top.
#[cfg(test)]
mod tests {
    use super::*;

    // A flat response has no strict local maxima, so nothing is detected.
    #[test]
    fn test_detect_room_modes_flat_response() {
        let n = 100;
        let freq = Array1::linspace(20.0, 500.0, n);
        let spl = Array1::from_elem(n, 80.0);
        let config = DecomposedCorrectionConfig::default();
        let modes = detect_room_modes(&freq, &spl, &config);
        assert!(modes.is_empty(), "flat response should have no modes");
    }

    // A 10 dB bump at 60 Hz (width parameter 60 / 10 Hz) must be reported
    // within 10 Hz of 60 Hz with an estimated Q of at least 3.
    #[test]
    fn test_detect_room_modes_with_peak() {
        let n = 200;
        let freq = Array1::linspace(20.0, 300.0, n);
        let mut spl = Array1::from_elem(n, 80.0);
        for i in 0..n {
            let f = freq[i];
            let q = 10.0;
            let bw = 60.0 / q;
            let response = 10.0 / (1.0 + ((f - 60.0) / (bw / 2.0_f64)).powi(2));
            spl[i] += response;
        }
        let config = DecomposedCorrectionConfig::default();
        let modes = detect_room_modes(&freq, &spl, &config);
        assert!(
            !modes.is_empty(),
            "should detect the 60 Hz peak as a room mode"
        );
        // Pick the detected mode closest to 60 Hz.
        let nearest = modes
            .iter()
            .min_by(|a, b| {
                (a.frequency - 60.0)
                    .abs()
                    .partial_cmp(&(b.frequency - 60.0).abs())
                    .unwrap()
            })
            .unwrap();
        assert!(
            (nearest.frequency - 60.0).abs() < 10.0,
            "detected mode at {:.1} Hz should be near 60 Hz",
            nearest.frequency
        );
        assert!(
            nearest.q >= 3.0,
            "detected Q={:.1} should be >= 3.0",
            nearest.q
        );
    }

    // With a hard step (width 0), a bin below the Schroeder frequency gets
    // the early-reflection weight.
    #[test]
    fn test_correction_weights_below_schroeder() {
        let freq = Array1::from_vec(vec![50.0]);
        let modes = vec![];
        let config = DecomposedCorrectionConfig {
            schroeder_freq: 200.0,
            early_reflection_weight: 0.3,
            steady_state_weight: 0.5,
            transition_width_oct: 0.0,
            ..Default::default()
        };
        let weights = build_correction_weights(&freq, &modes, &config);
        assert!(
            (weights[0] - 0.3).abs() < 0.01,
            "below Schroeder should use early_reflection_weight, got {}",
            weights[0]
        );
    }

    // With a hard step (width 0), a bin above the Schroeder frequency gets
    // the steady-state weight.
    #[test]
    fn test_correction_weights_above_schroeder() {
        let freq = Array1::from_vec(vec![500.0]);
        let modes = vec![];
        let config = DecomposedCorrectionConfig {
            schroeder_freq: 200.0,
            early_reflection_weight: 0.3,
            steady_state_weight: 0.5,
            transition_width_oct: 0.0,
            ..Default::default()
        };
        let weights = build_correction_weights(&freq, &modes, &config);
        assert!(
            (weights[0] - 0.5).abs() < 0.01,
            "above Schroeder should use steady_state_weight, got {}",
            weights[0]
        );
    }

    // The bin at a mode's frequency is raised to mode_correction_weight.
    #[test]
    fn test_correction_weights_mode_boost() {
        let freq = Array1::from_vec(vec![50.0, 60.0, 70.0]);
        let modes = vec![RoomMode {
            frequency: 60.0,
            q: 5.0,
            prominence_db: 8.0,
            index: 1,
        }];
        let config = DecomposedCorrectionConfig {
            schroeder_freq: 200.0,
            early_reflection_weight: 0.3,
            mode_correction_weight: 1.0,
            transition_width_oct: 0.0,
            ..Default::default()
        };
        let weights = build_correction_weights(&freq, &modes, &config);
        assert!(
            weights[1] > 0.9,
            "mode frequency should have boosted weight, got {}",
            weights[1]
        );
    }

    // End-to-end: the defaults produce a finite weight vector within [0, 1]
    // and report the configured Schroeder frequency.
    #[test]
    fn test_full_decomposed_analysis() {
        let n = 200;
        let freq = Array1::linspace(20.0, 500.0, n);
        let mut spl = Array1::from_elem(n, 80.0);
        for i in 0..n {
            let f = freq[i];
            let q = 8.0;
            let bw = 80.0 / q;
            spl[i] += 8.0 / (1.0 + ((f - 80.0) / (bw / 2.0_f64)).powi(2));
        }
        let config = DecomposedCorrectionConfig::default();
        let result = analyze_decomposed_correction(&freq, &spl, &config);
        assert_eq!(result.schroeder_freq, 250.0);
        assert!(!result.correction_weights.iter().any(|w| w.is_nan()));
        assert!(
            result
                .correction_weights
                .iter()
                .all(|&w| (0.0..=1.0).contains(&w))
        );
    }

    // A narrow bump at 80 Hz (width parameter 80 / 15 Hz) should give Q > 5.
    #[test]
    fn test_estimate_peak_q_narrow() {
        let n = 100;
        let freq = Array1::linspace(40.0, 120.0, n);
        let mut spl = Array1::from_elem(n, 80.0);
        for i in 0..n {
            let f = freq[i];
            let bw = 80.0 / 15.0;
            spl[i] += 10.0 / (1.0 + ((f - 80.0) / (bw / 2.0_f64)).powi(2));
        }
        // Index of the loudest bin.
        let peak_idx = spl
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .unwrap();
        let q = estimate_peak_q(&freq, &spl, peak_idx);
        assert!(q > 5.0, "narrow peak should have high Q, got {:.1}", q);
    }

    // Peak at index 0: only the upper side can cross the -3 dB threshold, so
    // the one-sided (doubled half-width) path is used.
    #[test]
    fn test_estimate_peak_q_at_array_edge() {
        let freq = Array1::from_vec(vec![20.0, 30.0, 40.0, 50.0, 60.0]);
        let spl = Array1::from_vec(vec![90.0, 88.0, 85.0, 83.0, 80.0]);
        let q = estimate_peak_q(&freq, &spl, 0);
        assert!(
            q.is_finite() && q > 0.0,
            "Q at edge should be positive finite, got {}",
            q
        );
    }

    // The 95 dB centre bin must not contribute to its own baseline.
    #[test]
    fn test_compute_local_baseline_excludes_center() {
        let freq = Array1::from_vec(vec![50.0, 60.0, 70.0, 80.0, 90.0]);
        let spl = Array1::from_vec(vec![80.0, 80.0, 95.0, 80.0, 80.0]);
        let baseline = compute_local_baseline(&freq, &spl, 2, 40.0, 100.0);
        assert!(
            (baseline - 80.0).abs() < 0.5,
            "baseline should be ~80, got {:.1}",
            baseline
        );
    }

    // With a 1-octave transition, weights move smoothly from the
    // early-reflection value (50 Hz), through roughly the midpoint (at the
    // 200 Hz Schroeder frequency), to the steady-state value (800 Hz).
    #[test]
    fn test_correction_weights_smooth_transition() {
        let freq = Array1::from_vec(vec![50.0, 200.0, 800.0]);
        let config = DecomposedCorrectionConfig {
            schroeder_freq: 200.0,
            early_reflection_weight: 0.2,
            steady_state_weight: 0.6,
            transition_width_oct: 1.0,
            ..Default::default()
        };
        let weights = build_correction_weights(&freq, &[], &config);
        assert!(
            weights[0] < 0.4,
            "50 Hz weight should be near 0.2, got {}",
            weights[0]
        );
        let midpoint = (0.2 + 0.6) / 2.0;
        assert!(
            (weights[1] - midpoint).abs() < 0.15,
            "200 Hz weight should be near {:.1}, got {:.2}",
            midpoint,
            weights[1]
        );
        assert!(
            weights[2] > 0.4,
            "800 Hz weight should be near 0.6, got {}",
            weights[2]
        );
    }

    // Fewer than five bins: detection returns early with no modes.
    #[test]
    fn test_detect_room_modes_short_array() {
        let freq = Array1::from_vec(vec![50.0, 60.0, 70.0]);
        let spl = Array1::from_vec(vec![85.0, 90.0, 85.0]);
        let config = DecomposedCorrectionConfig::default();
        let modes = detect_room_modes(&freq, &spl, &config);
        assert!(modes.is_empty());
    }

    // A bump at 400 Hz lies above the 200 Hz Schroeder frequency and must not
    // be reported.
    #[test]
    fn test_detect_room_modes_ignores_above_schroeder() {
        let n = 200;
        let freq = Array1::linspace(20.0, 1000.0, n);
        let mut spl = Array1::from_elem(n, 80.0);
        for i in 0..n {
            let f = freq[i];
            spl[i] += 10.0 / (1.0 + ((f - 400.0) / 5.0_f64).powi(2));
        }
        let config = DecomposedCorrectionConfig {
            schroeder_freq: 200.0,
            ..Default::default()
        };
        let modes = detect_room_modes(&freq, &spl, &config);
        assert!(
            modes.iter().all(|m| m.frequency <= 200.0),
            "modes above Schroeder should not be detected"
        );
    }

    // Peak near the low end of the frequency array (25 Hz, width parameter
    // 25 / 5 Hz); the Q estimate must stay in a plausible range.
    #[test]
    fn test_estimate_peak_q_one_sided_low() {
        let n = 100;
        let freq = Array1::linspace(20.0, 200.0, n);
        let mut spl = Array1::from_elem(n, 80.0);
        for i in 0..n {
            let f = freq[i];
            let bw = 25.0 / 5.0;
            spl[i] += 10.0 / (1.0 + ((f - 25.0) / (bw / 2.0_f64)).powi(2));
        }
        // Index of the loudest bin.
        let peak_idx = spl
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .unwrap();
        let q = estimate_peak_q(&freq, &spl, peak_idx);
        assert!(
            q < 15.0,
            "one-sided Q estimate should not double: got {:.1}, expected ~5",
            q
        );
        assert!(q > 2.0, "Q should still be reasonable: got {:.1}", q);
    }

    // Peak flanked by flat 83 dB shoulders; Q must stay finite and positive.
    // NOTE(review): the threshold here is 87 dB, so each crossing falls
    // between the 83 dB and 90 dB bins, where the denominator is 7.
    // The near-zero-denominator fallback in estimate_peak_q is therefore not
    // actually exercised by this test.
    #[test]
    fn test_estimate_peak_q_interpolation_denom_zero() {
        let freq = Array1::from_vec(vec![50.0, 55.0, 60.0, 65.0, 70.0, 75.0, 80.0]);
        let spl = Array1::from_vec(vec![80.0, 83.0, 83.0, 90.0, 83.0, 83.0, 80.0]);
        let q = estimate_peak_q(&freq, &spl, 3);
        assert!(
            q.is_finite() && q > 0.0,
            "Q should be finite positive, got {:.1}",
            q
        );
    }

    // Pins every default configuration value.
    #[test]
    fn test_decomposed_correction_config_defaults() {
        let config = DecomposedCorrectionConfig::default();
        assert_eq!(config.schroeder_freq, 250.0);
        assert_eq!(config.steady_state_weight, 0.4);
        assert_eq!(config.min_mode_q, 3.0);
        assert_eq!(config.min_mode_prominence_db, 3.0);
        assert_eq!(config.mode_correction_weight, 1.0);
        assert_eq!(config.early_reflection_weight, 0.3);
        assert_eq!(config.transition_width_oct, 0.5);
    }
}