use statrs::function::erf::erf;
use std::f64::consts::PI;
/// Expected value of `clamp(X, lower, upper)` for `X ~ Normal(avg, sigma²)`.
///
/// The result is the partial expectation of `X` over `(lower, upper)` (from `integral`)
/// plus the two clamped tails, `lower · P(X < lower)` and `upper · P(X > upper)`.
///
/// - A missing `lower` integrates from `0.0` with no lower tail. This is identical to
///   `Some(0.0)`, because that tail term would be `0 · P(X < 0) = 0`.
/// - A missing `upper` uses `avg / 2`, the limit of `integral` as `t → +∞`, and has no
///   upper tail term.
///
/// Assumes `sigma > 0`: a negative value flips the sign of the erf terms.
fn truncated_bandwidth(avg: f64, sigma: f64, lower: Option<f64>, upper: Option<f64>) -> f64 {
    let (integral_to_upper, upper_clamp_term) = match upper {
        Some(hi) => (
            integral(avg, hi, sigma),
            hi * (1f64 - cdf(hi, avg, sigma)),
        ),
        None => (avg * 0.5f64, 0.0f64),
    };
    let integral_to_lower = integral(avg, lower.unwrap_or(0f64), sigma);
    let lower_clamp_term = lower.map_or(0.0f64, |lo| lo * cdf(lo, avg, sigma));
    integral_to_upper - integral_to_lower + lower_clamp_term + upper_clamp_term
}
/// Antiderivative in `t` of `t · pdf(t)` for `Normal(avg, sigma²)`.
///
/// The additive constant is chosen so that the value tends to `avg / 2` as `t → +∞` and to
/// `-avg / 2` as `t → -∞`. The difference of two evaluations is the partial expectation
/// over that interval.
fn integral(avg: f64, t: f64, sigma: f64) -> f64 {
    let offset = t - avg;
    // avg · (Φ(t) - 1/2), written via erf.
    let erf_term = avg * 0.5f64 * erf(offset / sigma / 2.0_f64.sqrt());
    // -sigma² · pdf(t).
    let density_term =
        -sigma / (2.0f64 * PI).sqrt() * (-offset * offset * 0.5f64 / sigma / sigma).exp();
    erf_term + density_term
}
/// Cumulative distribution function of `Normal(avg, sigma²)`, evaluated at `t`.
fn cdf(t: f64, avg: f64, sigma: f64) -> f64 {
    let scaled = (t - avg) / sigma / 2f64.sqrt();
    0.5f64 * (1f64 + erf(scaled))
}
/// Partial derivative of `truncated_bandwidth` with respect to `avg`, computed term by term.
///
/// `solve` uses it as the Newton slope. The handling of missing bounds mirrors
/// `truncated_bandwidth`:
/// - a missing `upper` contributes the constant `0.5`, the derivative of `avg / 2`;
/// - a missing `lower` differentiates the integral from `0.0`.
fn deri_truncated_bandwidth(avg: f64, sigma: f64, lower: Option<f64>, upper: Option<f64>) -> f64 {
    let (d_integral_to_upper, d_upper_clamp_term) = match upper {
        Some(hi) => (
            deri_integral(avg, hi, sigma),
            hi * (-deri_cdf(hi, avg, sigma)),
        ),
        None => (0.5f64, 0.0f64),
    };
    let d_integral_to_lower = deri_integral(avg, lower.unwrap_or(0f64), sigma);
    let d_lower_clamp_term = lower.map_or(0.0f64, |lo| lo * deri_cdf(lo, avg, sigma));
    d_integral_to_upper - d_integral_to_lower + d_lower_clamp_term + d_upper_clamp_term
}
/// Partial derivative of `integral` with respect to `avg`.
///
/// It simplifies to `erf(...) / 2 - t · pdf(t)` for `Normal(avg, sigma²)`.
fn deri_integral(avg: f64, t: f64, sigma: f64) -> f64 {
    let offset = t - avg;
    let erf_term = 0.5f64 * erf(offset / sigma / 2.0_f64.sqrt());
    let gaussian = (-offset * offset * 0.5f64 / sigma / sigma).exp();
    let density_term = gaussian * (-t) / (2.0f64 * PI).sqrt() / sigma;
    erf_term + density_term
}
/// Partial derivative of `cdf(t, avg, sigma)` with respect to `avg`.
///
/// This is minus the `Normal(avg, sigma²)` density at `t`.
fn deri_cdf(t: f64, avg: f64, sigma: f64) -> f64 {
    let offset = t - avg;
    let gaussian = (-offset * offset / 2.0f64 / sigma / sigma).exp();
    -gaussian / sigma / (2.0f64 * PI).sqrt()
}
/// Inverts `truncated_bandwidth` in its `avg` argument.
///
/// It finds the mean `avg` of a normal distribution with standard deviation `sigma` whose
/// expectation, after clamping samples to `[lower, upper]`, equals `x`. A missing or
/// negative `lower` is treated as `0.0`, and a missing `upper` means no upper clamp. The
/// root is found with Newton's method starting from `avg = x`.
///
/// Early returns, checked in this order:
/// - `Some(x)` when `sigma` is zero (within `f64::EPSILON`);
/// - `lower` when `x` lies below `lower`, with a relative `f64::EPSILON` margin;
/// - `Some(0.0)` when `lower` is missing or negative and `x <= f64::EPSILON`, because the
///   zero-floored expectation cannot be driven down to such a target;
/// - `upper` when `x` lies above `upper`, with a relative `f64::EPSILON` margin.
///
/// Otherwise it returns the evaluated Newton iterate whose expectation is closest to `x`.
/// The result is always `Some`.
pub fn solve(x: f64, sigma: f64, lower: Option<f64>, upper: Option<f64>) -> Option<f64> {
    // Newton steps allowed without improvement, before the first improvement and after
    // each subsequent one.
    const INITIAL_PATIENCE: u32 = 10;
    const PATIENCE_AFTER_IMPROVEMENT: u32 = 100;

    if sigma.abs() <= f64::EPSILON {
        return Some(x);
    }
    // The distribution depends only on |sigma|; the helpers assume a positive value.
    let sigma = sigma.abs();
    if lower.is_some_and(|lower| lower >= x * (1.0 + f64::EPSILON)) {
        return lower;
    }
    // Missing or negative lower bounds are floored at zero for the solve below.
    let floored_at_zero = lower.is_none() || lower.is_some_and(|l| l < 0.0);
    // With a zero floor, a target at or below ~0 cannot be reached by any finite mean.
    // Newton would run towards -inf until the slope underflows and the result becomes NaN.
    if floored_at_zero && x <= f64::EPSILON {
        return Some(0.0);
    }
    if upper.is_some_and(|upper| upper * (1.0 + f64::EPSILON) <= x) {
        return upper;
    }
    let lower = if floored_at_zero { Some(0.0f64) } else { lower };

    let mut current = x;
    let mut best = x;
    let mut best_diff = f64::MAX;
    let mut patience = INITIAL_PATIENCE;
    while patience > 0 {
        let residual = truncated_bandwidth(current, sigma, lower, upper) - x;
        let diff = residual.abs();
        if diff < best_diff {
            best_diff = diff;
            best = current;
            if diff == 0.0 {
                break;
            }
            patience = PATIENCE_AFTER_IMPROVEMENT;
        } else {
            patience -= 1;
        }
        let next = current - residual / deri_truncated_bandwidth(current, sigma, lower, upper);
        // A vanishing slope (nearly all mass clamped) gives an infinite or NaN step.
        // Stop and keep the best evaluated point rather than propagating it.
        if !next.is_finite() {
            break;
        }
        current = next;
    }
    Some(best)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Checks `deri` against a forward finite difference of `func` at 1000 seeded random
    /// points drawn from `low..high`.
    ///
    /// The analytic derivative is evaluated at the midpoint of each step, so the two deltas
    /// must agree to a relative `1e-7`. When their signs differ, or either one is zero, both
    /// must instead be below `f32::EPSILON`.
    fn test_deri<F, G>(func: F, deri: G, low: f64, high: f64)
    where
        F: Fn(f64) -> f64,
        G: Fn(f64) -> f64,
    {
        let mut rng = StdRng::seed_from_u64(42);
        // The step is proportional to the magnitude of the range. Using `abs` keeps it
        // non-zero for ranges centred on 0, where `low + high` would make every check vacuous.
        let eps = 5E-8 * (low.abs() + high.abs());
        let tol: f64 = f32::EPSILON.into();
        for _ in 0..1000 {
            let x = rng.random_range(low..high);
            let delta1 = func(x + eps) - func(x);
            let delta2 = eps * deri(x + eps * 0.5);
            if delta1 * delta2 > 0.0 {
                assert!(
                    delta1 / delta2 < 1.0000001 && delta2 / delta1 < 1.0000001,
                    "x = {x}: finite difference {delta1} vs analytic {delta2}"
                );
            } else {
                assert!(
                    delta1.abs() < tol && delta2.abs() < tol,
                    "x = {x}: sign mismatch, finite difference {delta1} vs analytic {delta2}"
                );
            }
        }
    }

    #[test]
    fn test_truncated_bandwidth() {
        assert_eq!(
            truncated_bandwidth(10.0, 5.0, None, None),
            10.042453513094314
        );
        test_deri(
            |x| truncated_bandwidth(x, 3.0, None, None),
            |x| deri_truncated_bandwidth(x, 3.0, None, None),
            0.0,
            10.0,
        );
        test_deri(
            |x| truncated_bandwidth(x, 3.0, Some(3.0), None),
            |x| deri_truncated_bandwidth(x, 3.0, Some(3.0), None),
            0.0,
            10.0,
        );
        test_deri(
            |x| truncated_bandwidth(x, 3.0, Some(3.0), Some(20.0)),
            |x| deri_truncated_bandwidth(x, 3.0, Some(3.0), Some(20.0)),
            0.0,
            10.0,
        );
    }

    #[test]
    fn test_integral() {
        assert_eq!(integral(10.0, 5.0, 2.0), -4.972959947732017);
        assert_eq!(integral(10.0, 1E9, 2.0), 5.0);
        assert_eq!(integral(10.0, -1E9, 2.0), -5.0);
        test_deri(
            |x| integral(x, 8.0, 5.0),
            |x| deri_integral(x, 8.0, 5.0),
            6.0,
            10.0,
        );
        test_deri(
            |x| integral(x, 8.0, 15.0),
            |x| deri_integral(x, 8.0, 15.0),
            6.0,
            10.0,
        );
    }

    #[test]
    fn test_cdf() {
        assert_eq!(cdf(12.0, 12.0, 4.0) * 2.0 - 1.0, 0.0);
        assert_eq!(
            cdf(16.0, 12.0, 4.0) - cdf(8.0, 12.0, 4.0),
            0.6826894921098856
        );
        assert_eq!(
            cdf(20.0, 12.0, 4.0) - cdf(4.0, 12.0, 4.0),
            0.9544997361056748
        );
        assert_eq!(
            cdf(24.0, 12.0, 4.0) - cdf(0.0, 12.0, 4.0),
            0.997300203936851
        );
        test_deri(|x| cdf(8.0, x, 5.0), |x| deri_cdf(8.0, x, 5.0), 6.0, 10.0);
        test_deri(
            |x| cdf(123.0, x, 15.0),
            |x| deri_cdf(123.0, x, 15.0),
            110.0,
            136.0,
        );
    }

    #[test]
    fn test_solve() {
        // Early returns.
        assert_eq!(solve(3.0, 0.0, None, None), Some(3.0));
        assert_eq!(solve(1.0, 2.0, Some(5.0), None), Some(5.0));
        assert_eq!(solve(0.0, 2.0, None, None), Some(0.0));
        assert_eq!(solve(30.0, 2.0, Some(3.0), Some(20.0)), Some(20.0));

        // Round trip: the solved mean reproduces the requested clamped expectation.
        let cases = [
            (5.0, 2.0, Some(3.0), Some(20.0)),
            (10.0, 5.0, None, None),
            (4.0, 3.0, Some(3.0), None),
        ];
        for (x, sigma, lower, upper) in cases {
            let avg = solve(x, sigma, lower, upper).expect("solve always returns Some");
            let back = truncated_bandwidth(avg, sigma, lower, upper);
            assert!(
                (back - x).abs() < 1e-9,
                "x = {x}: solved avg = {avg} maps back to {back}"
            );
        }
    }
}