use crate::estimate::EstimationError;
/// Checks the preconditions shared by the posterior-SNR weighting inputs.
///
/// The checks run in a fixed order and only the first failure is reported:
///
/// 1. `deriv` and `deriv_var` must have the same length;
/// 2. `epsilon` must be finite and strictly positive;
/// 3. walking the pairs in index order, each derivative must be finite and
///    each variance must be finite and non-negative (the derivative of a pair
///    is checked before its variance).
///
/// # Errors
///
/// Returns [`EstimationError::InvalidInput`] with a message describing the
/// first violated precondition.
fn validate(deriv: &[f64], deriv_var: &[f64], epsilon: f64) -> Result<(), EstimationError> {
    // Single place where a message is turned into the error value.
    let reject = |message: String| Err(EstimationError::InvalidInput(message));

    let (n_deriv, n_var) = (deriv.len(), deriv_var.len());
    if n_deriv != n_var {
        return reject(format!(
            "posterior-SNR weights: {} derivatives but {} variances",
            n_deriv, n_var
        ));
    }

    // NaN fails `is_finite`, so it is rejected here together with ±inf.
    if !epsilon.is_finite() || epsilon <= 0.0 {
        return reject(format!(
            "posterior-SNR weights: epsilon must be finite and positive; got {epsilon}"
        ));
    }

    // `try_for_each` stops at the first offending pair, like an early return.
    deriv
        .iter()
        .zip(deriv_var)
        .enumerate()
        .try_for_each(|(k, (&d, &v))| {
            if !d.is_finite() {
                reject(format!(
                    "posterior-SNR weights: derivative[{k}] is not finite ({d})"
                ))
            } else if !v.is_finite() || v < 0.0 {
                reject(format!(
                    "posterior-SNR weights: variance[{k}] must be finite and non-negative; got {v}"
                ))
            } else {
                Ok(())
            }
        })
}
/// Computes magnitude-adaptive weights `w[k] = 1 / sqrt(deriv[k]^2 + epsilon^2)`.
///
/// Larger derivative magnitudes receive smaller weights; `epsilon` bounds the
/// weight of a zero derivative at `1 / epsilon`.
///
/// The denominator is normally evaluated as `sqrt(d * d + epsilon * epsilon)`.
/// When that squared sum is not a normal float (it overflowed to infinity, or
/// underflowed to zero or a subnormal), the denominator is recomputed with
/// [`f64::hypot`]. This way a tiny but valid `epsilon` no longer produces an
/// infinite weight for a zero derivative, and a huge finite derivative no
/// longer produces a weight of exactly zero.
///
/// # Errors
///
/// Returns [`EstimationError::InvalidInput`] if `epsilon` is not finite and
/// strictly positive, or if any derivative is not finite (the message names
/// the index of the first offending element).
pub fn magnitude_adaptive_weights(
    deriv: &[f64],
    epsilon: f64,
) -> Result<Vec<f64>, EstimationError> {
    if !(epsilon.is_finite() && epsilon > 0.0) {
        return Err(EstimationError::InvalidInput(format!(
            "magnitude weights: epsilon must be finite and positive; got {epsilon}"
        )));
    }
    let eps2 = epsilon * epsilon;
    deriv
        .iter()
        .enumerate()
        .map(|(k, &d)| {
            if !d.is_finite() {
                return Err(EstimationError::InvalidInput(format!(
                    "magnitude weights: derivative[{k}] is not finite ({d})"
                )));
            }
            let sum_sq = d * d + eps2;
            // Keep the direct formula whenever it is well-conditioned so that
            // ordinary inputs produce exactly the same bits as before; only
            // the overflow/underflow cases take the scaled `hypot` path.
            let norm = if sum_sq.is_normal() {
                sum_sq.sqrt()
            } else {
                d.hypot(epsilon)
            };
            Ok(1.0 / norm)
        })
        .collect()
}
/// Computes posterior-SNR weights
/// `w[k] = 1 / sqrt(deriv[k]^2 + deriv_var[k] + epsilon^2)`.
///
/// Compared with a magnitude-only weight `1 / sqrt(deriv[k]^2 + epsilon^2)`,
/// the variance term enlarges the denominator, so a more uncertain derivative
/// receives a smaller weight; with zero variance the two formulas coincide.
///
/// The denominator is normally evaluated as `sqrt(d * d + v + epsilon * epsilon)`.
/// When that sum is not a normal float (it overflowed to infinity, or
/// underflowed to zero or a subnormal), the denominator is recomputed as
/// `hypot(hypot(d, sqrt(v)), epsilon)`. This way a tiny but valid `epsilon`
/// no longer yields an infinite weight for a zero derivative with zero
/// variance, and a huge finite derivative no longer yields a weight of
/// exactly zero.
///
/// # Errors
///
/// Propagates the error returned by `validate` for the given inputs, without
/// computing any weights.
pub fn posterior_snr_weights(
    deriv: &[f64],
    deriv_var: &[f64],
    epsilon: f64,
) -> Result<Vec<f64>, EstimationError> {
    validate(deriv, deriv_var, epsilon)?;
    let eps2 = epsilon * epsilon;
    let weights = deriv
        .iter()
        .zip(deriv_var)
        .map(|(&d, &v)| {
            let sum = d * d + v + eps2;
            // Keep the direct formula whenever it is well-conditioned so that
            // ordinary inputs produce exactly the same bits as before; only
            // the overflow/underflow cases take the scaled `hypot` path.
            let norm = if sum.is_normal() {
                sum.sqrt()
            } else {
                d.hypot(v.sqrt()).hypot(epsilon)
            };
            1.0 / norm
        })
        .collect();
    Ok(weights)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With every variance set to zero the posterior-SNR formula has the same
    /// denominator as the magnitude-only one, so the weights must agree.
    #[test]
    fn zero_variance_recovers_magnitude_only_weights() {
        let deriv = [0.0, 0.5, -2.0, 3.1];
        let zero_var = [0.0; 4];
        let snr = posterior_snr_weights(&deriv, &zero_var, 1e-3).unwrap();
        let mag = magnitude_adaptive_weights(&deriv, 1e-3).unwrap();
        for (a, b) in snr.iter().zip(mag.iter()) {
            assert!((a - b).abs() < 1e-15, "{a} vs {b}");
        }
    }

    /// A strictly positive variance enlarges the denominator, so each
    /// posterior-SNR weight must be strictly below its magnitude-only weight.
    #[test]
    fn variance_shrinks_weights_below_magnitude_only() {
        let deriv = [0.0, 0.5, -2.0, 3.1];
        let var = [1.0, 0.2, 4.0, 0.01];
        let snr = posterior_snr_weights(&deriv, &var, 1e-3).unwrap();
        let mag = magnitude_adaptive_weights(&deriv, 1e-3).unwrap();
        for (k, (s, m)) in snr.iter().zip(mag.iter()).enumerate() {
            assert!(s < m, "k={k}: snr {} !< mag {}", s, m);
        }
    }

    /// For a fixed derivative, increasing the variance must decrease the weight.
    #[test]
    fn weights_are_monotone_decreasing_in_variance() {
        let deriv = [0.4_f64];
        let low = posterior_snr_weights(&deriv, &[0.1], 1e-2).unwrap()[0];
        let mid = posterior_snr_weights(&deriv, &[1.0], 1e-2).unwrap()[0];
        let high = posterior_snr_weights(&deriv, &[10.0], 1e-2).unwrap()[0];
        assert!(low > mid && mid > high, "{low} {mid} {high}");
    }

    /// Spot-checks one value: 3^2 + 4 = 13, and `f64::EPSILON` squared is far
    /// below the 1e-6 tolerance, so the weight must match `1 / sqrt(13)`.
    #[test]
    fn exact_value_matches_formula() {
        let w = posterior_snr_weights(&[3.0], &[4.0], f64::EPSILON).unwrap()[0];
        assert!((w - 1.0 / 13.0_f64.sqrt()).abs() < 1e-6, "{w}");
    }

    /// Each class of invalid input must produce an error rather than weights.
    #[test]
    fn invalid_inputs_are_rejected() {
        // Length mismatch, negative variance, zero epsilon, NaN derivative.
        assert!(posterior_snr_weights(&[1.0, 2.0], &[1.0], 1e-3).is_err());
        assert!(posterior_snr_weights(&[1.0], &[-1.0], 1e-3).is_err());
        assert!(posterior_snr_weights(&[1.0], &[1.0], 0.0).is_err());
        assert!(posterior_snr_weights(&[f64::NAN], &[1.0], 1e-3).is_err());
        // Non-finite variance and non-finite epsilon.
        assert!(posterior_snr_weights(&[1.0], &[f64::NAN], 1e-3).is_err());
        assert!(posterior_snr_weights(&[1.0], &[1.0], f64::INFINITY).is_err());
        // The magnitude-only variant enforces its own preconditions.
        assert!(magnitude_adaptive_weights(&[1.0], 0.0).is_err());
        assert!(magnitude_adaptive_weights(&[1.0], f64::NAN).is_err());
        assert!(magnitude_adaptive_weights(&[f64::INFINITY], 1e-3).is_err());
    }
}