use crate::Curve;
use ndarray::Array1;
/// Tuning parameters for [`correction_depth_mask`] and [`analyze_spatial_robustness`].
///
/// Despite the `variance_*` naming, the per-bin spread these parameters are compared
/// against is the spatial *standard deviation* of the SPL values in dB, as produced by
/// [`spatial_std_dev`].
#[derive(Debug, Clone, PartialEq)]
pub struct SpatialRobustnessConfig {
    /// Spatial standard deviation (dB) at which the correction depth sits halfway between
    /// `min_correction_depth` and 1.0. Default: 3.0 dB.
    pub variance_threshold_db: f64,
    /// Width (dB) of the logistic transition around `variance_threshold_db`; larger values
    /// give a gentler roll-off. A value `<= 0.0` selects a hard step instead: depth 1.0 at
    /// or below the threshold, `min_correction_depth` above it. Default: 2.0 dB.
    pub transition_width_db: f64,
    /// Lower limit of the correction depth: approached as the spread rises far above the
    /// threshold, and reached exactly in hard-step mode. Default: 0.1.
    pub min_correction_depth: f64,
    /// Total width, in octaves, of the log-frequency moving average applied to the mask
    /// (the window extends half this width on each side). `<= 0.0` disables smoothing.
    /// Default: 1/6 octave.
    pub mask_smoothing_octaves: f64,
}

impl Default for SpatialRobustnessConfig {
    fn default() -> Self {
        Self {
            variance_threshold_db: 3.0,
            transition_width_db: 2.0,
            min_correction_depth: 0.1,
            mask_smoothing_octaves: 1.0 / 6.0,
        }
    }
}
/// Output of [`analyze_spatial_robustness`].
#[derive(Debug, Clone)]
pub struct SpatialRobustnessResult {
    /// Power (RMS) average of the input curves from [`rms_average`]: it uses the first
    /// curve's frequency grid and carries no phase.
    pub averaged_curve: Curve,
    /// Per-bin spatial spread from [`spatial_std_dev`].
    ///
    /// NOTE: despite the field name this is the sample standard deviation of the SPL values
    /// (dB), not the variance. It is all zeros when only one curve was supplied.
    pub spatial_variance: Array1<f64>,
    /// Per-bin weight from [`correction_depth_mask`]. It is close to 1.0 where the spatial
    /// spread is low and falls toward the configured `min_correction_depth` where it is high.
    pub correction_depth: Array1<f64>,
}
pub fn rms_average(curves: &[Curve]) -> Curve {
assert!(!curves.is_empty(), "need at least one curve");
let n = curves.len() as f64;
let len = curves[0].freq.len();
let mut avg_spl = Array1::zeros(len);
for bin in 0..len {
let sum_power: f64 = curves
.iter()
.map(|c| 10.0_f64.powf(c.spl[bin] / 10.0))
.sum();
avg_spl[bin] = 10.0 * (sum_power / n).log10();
}
Curve {
freq: curves[0].freq.clone(),
spl: avg_spl,
phase: None,
}
}
/// Per-bin sample standard deviation (dB) of the SPL values across `curves`.
///
/// Uses the `n - 1` (Bessel-corrected) denominator. With a single curve there is no spread
/// to measure, so an all-zero array with one entry per frequency bin is returned.
///
/// # Panics
///
/// Panics if `curves` is empty. With two or more curves, it also panics if any curve's `spl`
/// does not hold exactly one value per frequency bin of `curves[0]`.
pub fn spatial_std_dev(curves: &[Curve]) -> Array1<f64> {
    assert!(!curves.is_empty(), "need at least one curve");
    let len = curves[0].freq.len();
    if curves.len() == 1 {
        return Array1::zeros(len);
    }
    assert!(
        curves.iter().all(|c| c.spl.len() == len),
        "all curves must have one SPL value per frequency bin of the first curve"
    );
    let n = curves.len() as f64;

    (0..len)
        .map(|bin| {
            // Two-pass formula: compute the mean first, then the squared deviations. This is
            // more stable than sum(x^2) - n * mean^2 for large dB values.
            let mean = curves.iter().map(|c| c.spl[bin]).sum::<f64>() / n;
            let sum_sq: f64 = curves.iter().map(|c| (c.spl[bin] - mean).powi(2)).sum();
            (sum_sq / (n - 1.0)).sqrt()
        })
        .collect()
}
/// Maps per-bin spatial spread to a correction-depth weight.
///
/// For each bin with spread `sd` (dB), the weight is
/// `min_correction_depth + (1 - min_correction_depth) * s`, where `s` is:
/// - a logistic `1 / (1 + exp(-(threshold - sd) / transition_width))` when
///   `transition_width_db > 0`, which equals 0.5 exactly at the threshold;
/// - otherwise a hard step: 1.0 when `sd <= threshold`, else 0.0.
///
/// If `mask_smoothing_octaves > 0`, the mask is then smoothed with a log-frequency moving
/// average over that many octaves.
///
/// # Panics
///
/// Panics if `freq` and `spatial_variance` have different lengths.
pub fn correction_depth_mask(
    freq: &Array1<f64>,
    spatial_variance: &Array1<f64>,
    config: &SpatialRobustnessConfig,
) -> Array1<f64> {
    assert_eq!(
        freq.len(),
        spatial_variance.len(),
        "freq and spatial_variance must have the same number of bins"
    );
    let floor = config.min_correction_depth;
    let threshold = config.variance_threshold_db;
    let width = config.transition_width_db;

    // Fraction of the full correction (0..=1) granted to a bin with the given spread.
    let weight = |sd: f64| -> f64 {
        if width <= 0.0 {
            if sd <= threshold {
                1.0
            } else {
                0.0
            }
        } else {
            let x = (threshold - sd) / width;
            1.0 / (1.0 + (-x).exp())
        }
    };

    let mask = spatial_variance.mapv(|sd| floor + (1.0 - floor) * weight(sd));
    if config.mask_smoothing_octaves > 0.0 {
        smooth_log_frequency(&mask, freq, config.mask_smoothing_octaves)
    } else {
        mask
    }
}
/// Runs the complete spatial-robustness analysis on a set of measurements.
///
/// The steps, in order, are:
/// 1. power-average the curves with [`rms_average`];
/// 2. compute the per-bin spread with [`spatial_std_dev`];
/// 3. derive the correction-depth mask with [`correction_depth_mask`] on the averaged
///    curve's frequency grid.
///
/// # Panics
///
/// Panics if `curves` is empty (see the functions above for the other preconditions).
pub fn analyze_spatial_robustness(
    curves: &[Curve],
    config: &SpatialRobustnessConfig,
) -> SpatialRobustnessResult {
    let averaged_curve = rms_average(curves);
    let spatial_variance = spatial_std_dev(curves);
    // Field initialisers run in source order: the mask borrows the two locals above before
    // they are moved into the struct.
    SpatialRobustnessResult {
        correction_depth: correction_depth_mask(&averaged_curve.freq, &spatial_variance, config),
        averaged_curve,
        spatial_variance,
    }
}
/// Moving-average smoothing of `data` over a window measured in octaves.
///
/// Each output bin `i` is the plain mean of every `data[j]` whose `log2(freq[j])` lies in
/// `[log2(freq[i]) - width_octaves / 2, log2(freq[i]) + width_octaves / 2]` (inclusive).
/// If no bin qualifies (e.g. a NaN frequency or a negative width), `data[i]` is passed
/// through unchanged.
///
/// `freq` is expected to have at least `data.len()` entries; only the first `data.len()` are
/// used.
///
/// When the frequencies are sorted ascending (the usual case for a measurement grid), each
/// window is found by binary search, costing O(n log n + total window size). Otherwise
/// every bin is scanned for every output bin, which is O(n²). Both paths sum the same bins
/// in the same order, so the results are identical.
fn smooth_log_frequency(data: &Array1<f64>, freq: &Array1<f64>, width_octaves: f64) -> Array1<f64> {
    let len = data.len();
    let half_width = width_octaves / 2.0;
    // log2 of each frequency is loop-invariant: compute it once, not once per (i, j) pair.
    let log_freq: Vec<f64> = freq.iter().take(len).map(|f| f.log2()).collect();
    // NaN never compares as ordered, so a grid containing NaN takes the general path.
    let is_sorted = log_freq.windows(2).all(|w| w[0] <= w[1]);

    let mut smoothed = Array1::zeros(len);
    for i in 0..len {
        let center_log = log_freq[i];
        let low_log = center_log - half_width;
        let high_log = center_log + half_width;
        let (sum, count) = if is_sorted {
            // On a sorted grid the in-window bins form one contiguous run, summed in index
            // order exactly as the full scan would.
            let start = log_freq.partition_point(|&f| f < low_log);
            let end = log_freq.partition_point(|&f| f <= high_log);
            (start..end).fold((0.0, 0_usize), |(s, c), j| (s + data[j], c + 1))
        } else {
            log_freq
                .iter()
                .zip(data.iter())
                .filter(|&(&f, _)| f >= low_log && f <= high_log)
                .fold((0.0, 0_usize), |(s, c), (_, &v)| (s + v, c + 1))
        };
        smoothed[i] = if count > 0 {
            sum / count as f64
        } else {
            data[i]
        };
    }
    smoothed
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a magnitude-only curve from parallel frequency / SPL vectors.
    fn make_curve(freq: Vec<f64>, spl: Vec<f64>) -> Curve {
        Curve {
            freq: Array1::from_vec(freq),
            spl: Array1::from_vec(spl),
            phase: None,
        }
    }

    /// Default config with mask smoothing disabled, so per-bin values can be checked directly.
    fn unsmoothed_config() -> SpatialRobustnessConfig {
        SpatialRobustnessConfig {
            mask_smoothing_octaves: 0.0,
            ..Default::default()
        }
    }

    /// Peak-to-peak range of an array.
    fn value_range(values: &Array1<f64>) -> f64 {
        let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
        max - min
    }

    #[test]
    fn test_rms_average_identical_curves() {
        let curve = make_curve(vec![100.0, 1000.0, 10000.0], vec![80.0, 85.0, 75.0]);
        let avg = rms_average(&[curve.clone(), curve.clone()]);
        for i in 0..3 {
            assert!(
                (avg.spl[i] - curve.spl[i]).abs() < 0.01,
                "bin {}: expected {}, got {}",
                i,
                curve.spl[i],
                avg.spl[i]
            );
        }
    }

    #[test]
    fn test_rms_average_vs_arithmetic() {
        let c1 = make_curve(vec![100.0], vec![80.0]);
        let c2 = make_curve(vec![100.0], vec![90.0]);
        let avg = rms_average(&[c1, c2]);
        // Power averaging weights the louder curve more than a plain dB mean does.
        let arithmetic_mean = (80.0 + 90.0) / 2.0;
        assert!(
            avg.spl[0] > arithmetic_mean,
            "RMS average ({:.2}) should be > arithmetic mean ({:.2})",
            avg.spl[0],
            arithmetic_mean
        );
    }

    #[test]
    fn test_spatial_std_dev_identical() {
        let curve = make_curve(vec![100.0, 1000.0], vec![80.0, 85.0]);
        let std = spatial_std_dev(&[curve.clone(), curve.clone()]);
        assert!(std[0] < 0.01);
        assert!(std[1] < 0.01);
    }

    #[test]
    fn test_spatial_std_dev_different() {
        let c1 = make_curve(vec![100.0], vec![80.0]);
        let c2 = make_curve(vec![100.0], vec![86.0]);
        let std = spatial_std_dev(&[c1, c2]);
        // Sample std dev of {80, 86}: sqrt((9 + 9) / 1) = sqrt(18) ≈ 4.24.
        assert!(
            (std[0] - 4.24).abs() < 0.1,
            "expected ~4.24, got {}",
            std[0]
        );
    }

    #[test]
    fn test_correction_depth_low_variance() {
        let freq = Array1::from_vec(vec![100.0]);
        let variance = Array1::from_vec(vec![0.5]);
        let config = unsmoothed_config();
        let depth = correction_depth_mask(&freq, &variance, &config);
        assert!(
            depth[0] > 0.75,
            "low variance should give high correction, got {}",
            depth[0]
        );
    }

    #[test]
    fn test_correction_depth_high_variance() {
        let freq = Array1::from_vec(vec![100.0]);
        let variance = Array1::from_vec(vec![10.0]);
        let config = unsmoothed_config();
        let depth = correction_depth_mask(&freq, &variance, &config);
        assert!(
            depth[0] < 0.3,
            "high variance should give reduced correction, got {}",
            depth[0]
        );
        assert!(
            depth[0] >= config.min_correction_depth,
            "should never go below min_correction_depth"
        );
    }

    #[test]
    fn test_correction_depth_at_threshold() {
        let freq = Array1::from_vec(vec![100.0]);
        let variance = Array1::from_vec(vec![3.0]);
        let config = unsmoothed_config();
        let depth = correction_depth_mask(&freq, &variance, &config);
        // The logistic is exactly 0.5 at the threshold.
        let expected = 0.1 + 0.9 * 0.5;
        assert!(
            (depth[0] - expected).abs() < 0.01,
            "expected ~{:.2}, got {:.2}",
            expected,
            depth[0]
        );
    }

    #[test]
    fn test_correction_depth_zero_transition_width() {
        let freq = Array1::from_vec(vec![100.0, 200.0]);
        let variance = Array1::from_vec(vec![1.0, 5.0]);
        let config = SpatialRobustnessConfig {
            variance_threshold_db: 3.0,
            transition_width_db: 0.0,
            min_correction_depth: 0.1,
            mask_smoothing_octaves: 0.0,
        };
        let depth = correction_depth_mask(&freq, &variance, &config);
        assert!(
            depth[0] > 0.9,
            "below threshold should give full correction, got {}",
            depth[0]
        );
        assert!(
            (depth[1] - 0.1).abs() < 0.01,
            "above threshold should give min correction, got {}",
            depth[1]
        );
    }

    #[test]
    fn test_spatial_std_dev_single_curve() {
        let curve = make_curve(vec![100.0, 1000.0], vec![80.0, 85.0]);
        let std = spatial_std_dev(&[curve]);
        assert_eq!(std[0], 0.0);
        assert_eq!(std[1], 0.0);
    }

    #[test]
    fn test_analyze_spatial_robustness_single_curve() {
        let curve = make_curve(vec![100.0, 1000.0], vec![80.0, 85.0]);
        let config = unsmoothed_config();
        let result = analyze_spatial_robustness(&[curve], &config);
        assert!(result.spatial_variance.iter().all(|&v| v == 0.0));
        assert!(
            result.correction_depth.iter().all(|&d| d > 0.8),
            "single curve should have high correction depth, got min={:.3}",
            result
                .correction_depth
                .iter()
                .cloned()
                .fold(f64::INFINITY, f64::min)
        );
    }

    #[test]
    fn test_full_analysis() {
        let c1 = make_curve(vec![100.0, 5000.0], vec![90.0, 80.0]);
        let c2 = make_curve(vec![100.0, 5000.0], vec![91.0, 72.0]);
        let c3 = make_curve(vec![100.0, 5000.0], vec![89.0, 85.0]);
        let config = unsmoothed_config();
        let result = analyze_spatial_robustness(&[c1, c2, c3], &config);
        // Low-frequency bin: positions agree closely -> strong correction.
        assert!(result.spatial_variance[0] < 2.0);
        assert!(result.correction_depth[0] > 0.7);
        // High-frequency bin: positions disagree -> correction backed off.
        assert!(result.spatial_variance[1] > 5.0);
        assert!(result.correction_depth[1] < 0.5);
    }

    #[test]
    fn test_rms_average_negative_spl() {
        let c1 = make_curve(vec![100.0], vec![-10.0]);
        let c2 = make_curve(vec![100.0], vec![-20.0]);
        let avg = rms_average(&[c1, c2]);
        assert!(avg.spl[0] > -20.0 && avg.spl[0] < -10.0);
        assert!(avg.spl[0].is_finite());
    }

    #[test]
    fn test_smooth_log_frequency_reduces_variation() {
        let freq = Array1::from_vec(vec![
            50.0, 70.0, 100.0, 140.0, 200.0, 280.0, 400.0, 560.0, 800.0, 1120.0, 1600.0,
        ]);
        let data = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
        let smoothed = smooth_log_frequency(&data, &freq, 1.5);
        let orig_range = value_range(&data);
        let smooth_range = value_range(&smoothed);
        assert!(
            smooth_range < orig_range,
            "smoothing should reduce range: orig={:.2}, smoothed={:.2}",
            orig_range,
            smooth_range
        );
    }

    #[test]
    fn test_smooth_log_frequency_preserves_constant() {
        let freq = Array1::from_vec(vec![100.0, 200.0, 400.0, 800.0]);
        let data = Array1::from_vec(vec![0.5, 0.5, 0.5, 0.5]);
        let smoothed = smooth_log_frequency(&data, &freq, 1.0);
        for &v in smoothed.iter() {
            assert!((v - 0.5).abs() < 0.001);
        }
    }

    #[test]
    fn test_smooth_log_frequency_order_independent() {
        // Smoothing depends on frequencies, not storage order: a shuffled grid must give
        // the same value for each frequency as the sorted one.
        let freq = [100.0, 200.0, 400.0, 800.0, 1600.0];
        let data = [1.0, 0.0, 3.0, 0.5, 2.0];
        // 2.5 octaves -> +/-1.25 octave window, so each bin also sees its neighbours.
        let width = 2.5;
        let sorted = smooth_log_frequency(
            &Array1::from_vec(data.to_vec()),
            &Array1::from_vec(freq.to_vec()),
            width,
        );
        let perm = [3_usize, 0, 4, 1, 2];
        let shuffled_freq: Vec<f64> = perm.iter().map(|&k| freq[k]).collect();
        let shuffled_data: Vec<f64> = perm.iter().map(|&k| data[k]).collect();
        let shuffled = smooth_log_frequency(
            &Array1::from_vec(shuffled_data),
            &Array1::from_vec(shuffled_freq),
            width,
        );
        for (pos, &k) in perm.iter().enumerate() {
            assert!(
                (shuffled[pos] - sorted[k]).abs() < 1e-12,
                "freq {}: sorted={}, shuffled={}",
                freq[k],
                sorted[k],
                shuffled[pos]
            );
        }
    }

    #[test]
    fn test_correction_depth_with_smoothing_enabled() {
        let freq = Array1::from_vec(vec![50.0, 100.0, 200.0, 500.0, 1000.0]);
        let variance = Array1::from_vec(vec![1.0, 8.0, 1.0, 8.0, 1.0]);
        let config = SpatialRobustnessConfig {
            mask_smoothing_octaves: 0.5,
            ..Default::default()
        };
        let depth = correction_depth_mask(&freq, &variance, &config);
        for &d in depth.iter() {
            assert!(d.is_finite(), "depth should be finite");
            // Averaging values in [min_correction_depth, 1] cannot leave that interval.
            assert!(
                d >= config.min_correction_depth - 1e-12 && d <= 1.0,
                "depth should be in [{}, 1], got {}",
                config.min_correction_depth,
                d
            );
        }
    }
}