/// Errors reported while building integration grids or propagating densities over them.
#[derive(Clone, Debug, PartialEq)]
pub enum IntegrationError {
    /// The stencil parameter `r` was below the minimum of 2.
    InvalidStencil { r: usize },
    /// The bounds did not satisfy `a < b` (NaN bounds are reported here too).
    NonIncreasingBounds { a: f64, b: f64 },
    /// The information value was not strictly positive (or was NaN).
    NonPositiveInformation { info: f64 },
    /// The previous information value was not strictly positive (or was NaN).
    NonPositivePreviousInformation { info_prev: f64 },
    /// The current information did not strictly exceed the previous information.
    NonIncreasingInformation { info_prev: f64, info: f64 },
    /// Parallel grid vectors had different lengths.
    ///
    /// `values` is `None` when the checked grid has no values vector.
    ShapeMismatch {
        points: usize,
        weights: usize,
        values: Option<usize>,
    },
}
impl std::fmt::Display for IntegrationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidStencil { r } => {
write!(f, "r must be at least 2 for Simpson integration (got {r})")
}
Self::NonIncreasingBounds { a, b } => {
write!(f, "Lower bound a must satisfy a < b (got a={a}, b={b})")
}
Self::NonPositiveInformation { info } => {
write!(f, "Information must be positive (got {info})")
}
Self::NonPositivePreviousInformation { info_prev } => {
write!(f, "Previous information must be positive (got {info_prev})")
}
Self::NonIncreasingInformation { info_prev, info } => {
write!(
f,
"Current information ({info}) must exceed previous information ({info_prev})"
)
}
Self::ShapeMismatch {
points,
weights,
values,
} => {
if let Some(values) = values {
write!(
f,
"Grid metadata mismatch: points={points}, weights={weights}, values={values}"
)
} else {
write!(
f,
"Grid metadata mismatch: points={points}, weights={weights}"
)
}
}
}
}
}
impl std::error::Error for IntegrationError {}
/// A quadrature grid: evaluation `points` paired index-by-index with `weights`.
///
/// The two vectors are parallel and are expected to have the same length.
#[derive(Clone, Debug, PartialEq)]
pub struct IntegrationGrid {
    /// Abscissae at which an integrand is evaluated.
    pub points: Vec<f64>,
    /// Quadrature weight for the point at the same index.
    pub weights: Vec<f64>,
}
impl IntegrationGrid {
    /// Checks that every point has exactly one weight.
    ///
    /// # Errors
    /// Returns [`IntegrationError::ShapeMismatch`] (with `values: None`) when
    /// `points` and `weights` differ in length.
    fn validate(&self) -> Result<(), IntegrationError> {
        let (n_points, n_weights) = (self.points.len(), self.weights.len());
        if n_points == n_weights {
            Ok(())
        } else {
            Err(IntegrationError::ShapeMismatch {
                points: n_points,
                weights: n_weights,
                values: None,
            })
        }
    }
}
/// A quadrature grid that also carries one `value` per point.
///
/// `points`, `weights` and `values` are parallel and are expected to share a
/// single length.
#[derive(Clone, Debug, PartialEq)]
pub struct DensityGrid {
    /// Abscissae of the grid.
    pub points: Vec<f64>,
    /// Quadrature weight for the point at the same index.
    pub weights: Vec<f64>,
    /// Per-point value. For grids built by `h1`/`hupdate` in this module, each
    /// value has already been multiplied by the weight at the same index.
    pub values: Vec<f64>,
}
impl DensityGrid {
    /// Checks that `points`, `weights` and `values` all have the same length.
    ///
    /// # Errors
    /// Returns [`IntegrationError::ShapeMismatch`] carrying all three lengths
    /// when any of them disagree.
    fn validate(&self) -> Result<(), IntegrationError> {
        let lens = [self.points.len(), self.weights.len(), self.values.len()];
        if lens.iter().all(|&len| len == lens[0]) {
            Ok(())
        } else {
            Err(IntegrationError::ShapeMismatch {
                points: lens[0],
                weights: lens[1],
                values: Some(lens[2]),
            })
        }
    }
}
/// `sqrt(2π) = 2.506_628_274_631_000_502_415_765…`, the normalising constant of
/// the standard normal density.
///
/// The literal is the shortest decimal that round-trips to the `f64` nearest
/// that value. Spelling out more digits adds no accuracy and trips
/// `clippy::excessive_precision`.
const SQRT_2PI: f64 = 2.506_628_274_631_000_7_f64;

/// Standard normal probability density `φ(x) = exp(-x² / 2) / sqrt(2π)`.
fn normal_pdf(x: f64) -> f64 {
    (-0.5 * x * x).exp() / SQRT_2PI
}
/// Builds a composite Simpson's-rule grid for integrating over `[a, b]`, with
/// nodes concentrated around `mu`.
///
/// Before clipping there are `6r - 1` primary nodes, in ascending order:
/// * `r - 1` lower-tail nodes `mu - 3 - 4 ln(r / i)` for `i = 1..r-1`;
/// * `4r + 1` equally spaced nodes from `mu - 3` to `mu + 3` (step `3 / (2r)`);
/// * `r - 1` upper-tail nodes mirroring the lower tail.
///
/// Nodes outside `[a, b]` are dropped, and any bound that cut the grid becomes
/// a node itself. With finite `mu`, infinite bounds leave the grid untouched.
///
/// A midpoint is then inserted between each pair of neighbouring nodes, so `m`
/// primary nodes give `2m - 1` points. The Simpson weights are:
/// * `(x[i+1] - x[i-1]) / 6` on primary nodes (one-sided at the two ends);
/// * `4 (x[i+1] - x[i]) / 6` on midpoints.
///
/// If only one node survives clipping, the grid is that single point with weight `1`.
///
/// `mu` is not validated: a NaN `mu` yields NaN points and weights.
///
/// NOTE(review): this appears to mirror `gridpts` from the R package gsDesign
/// (Jennison & Turnbull, 2000, ch. 19) — confirm before relying on that.
///
/// # Errors
/// * [`IntegrationError::InvalidStencil`] if `r < 2`.
/// * [`IntegrationError::NonIncreasingBounds`] unless `a < b` (NaN bounds included).
pub fn gridpts(r: usize, mu: f64, a: f64, b: f64) -> Result<IntegrationGrid, IntegrationError> {
    if r < 2 {
        return Err(IntegrationError::InvalidStencil { r });
    }
    if a.is_nan() || b.is_nan() || a >= b {
        return Err(IntegrationError::NonIncreasingBounds { a, b });
    }
    // Clipping always leaves at least one node: there are >= 11 nodes to start
    // with, and each clip re-inserts the bound it cut at.
    let nodes = gridpts_clip_nodes(gridpts_base_nodes(r, mu), a, b);
    let grid = gridpts_simpson(&nodes);
    // Cheap O(1) invariant check; points and weights are built in lockstep.
    grid.validate()?;
    Ok(grid)
}

/// Returns the `6r - 1` unclipped primary nodes centred at `mu`, in ascending order.
fn gridpts_base_nodes(r: usize, mu: f64) -> Vec<f64> {
    let r_f64 = r as f64;
    // Distance from `mu` of the i-th tail node; shrinks towards 3 as `i` grows to `r - 1`.
    let tail = |i: usize| 3.0 + 4.0 * (r_f64 / i as f64).ln();
    let mut nodes = Vec::with_capacity(6 * r - 1);
    // Lower tail, farthest node first.
    nodes.extend((1..r).map(|i| mu - tail(i)));
    // Core: 4r + 1 equally spaced nodes covering [mu - 3, mu + 3].
    nodes.extend((0..=4 * r).map(|k| mu - 3.0 + 3.0 * (k as f64 / (2.0 * r_f64))));
    // Upper tail, nearest node first (mirror image of the lower tail).
    nodes.extend((1..r).rev().map(|i| mu + tail(i)));
    nodes
}

/// Restricts `nodes` to `[a, b]`, making any bound that cut the grid a node.
fn gridpts_clip_nodes(mut nodes: Vec<f64>, a: f64, b: f64) -> Vec<f64> {
    let lowest = nodes.iter().copied().fold(f64::INFINITY, f64::min);
    if lowest < a {
        nodes.retain(|&x| x > a);
        nodes.insert(0, a);
    }
    let highest = nodes.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if highest > b {
        nodes.retain(|&x| x < b);
        nodes.push(b);
    }
    nodes
}

/// Expands `m >= 1` primary nodes into the `2m - 1` Simpson points and weights.
fn gridpts_simpson(nodes: &[f64]) -> IntegrationGrid {
    debug_assert!(!nodes.is_empty(), "clipping always leaves at least one node");
    let m = nodes.len();
    if m == 1 {
        // Degenerate grid: a single point carrying unit weight.
        return IntegrationGrid {
            points: vec![nodes[0]],
            weights: vec![1.0],
        };
    }
    let last = m - 1;
    let mut points = Vec::with_capacity(2 * m - 1);
    let mut weights = Vec::with_capacity(2 * m - 1);
    for (i, &node) in nodes.iter().enumerate() {
        // A primary node collects 1/6 of the width of each adjacent panel.
        let left = nodes[i.saturating_sub(1)];
        let right = nodes[(i + 1).min(last)];
        points.push(node);
        weights.push((right - left) / 6.0);
        // The panel midpoint to the right carries 4/6 of that panel's width.
        if let Some(&next) = nodes.get(i + 1) {
            points.push(0.5 * (node + next));
            weights.push(4.0 * (next - node) / 6.0);
        }
    }
    IntegrationGrid { points, weights }
}
/// Builds the density grid for the first analysis.
///
/// The grid comes from `gridpts(r, mu, a, b)` with `mu = theta * sqrt(info)`.
/// Each value is the weighted normal density `weights[i] * φ(points[i] - mu)`.
///
/// NOTE(review): presumably this is the sub-density of the first-look Z
/// statistic in a group sequential design — confirm against callers.
///
/// NOTE(review): `info = +inf` passes the check below. With `theta == 0.0` the
/// centre `0 * inf` is NaN and the resulting grid is non-finite; consider
/// rejecting non-finite `info`.
///
/// # Errors
/// * [`IntegrationError::NonPositiveInformation`] unless `info > 0` (NaN included).
/// * Any error returned by `gridpts` (invalid `r` or bounds).
pub fn h1(
    r: usize,
    theta: f64,
    info: f64,
    a: f64,
    b: f64,
) -> Result<DensityGrid, IntegrationError> {
    if info.is_nan() || info <= 0.0 {
        return Err(IntegrationError::NonPositiveInformation { info });
    }
    let drift = theta * info.sqrt();
    let grid = gridpts(r, drift, a, b)?;
    let values: Vec<f64> = grid
        .points
        .iter()
        .zip(&grid.weights)
        .map(|(&z, &w)| w * normal_pdf(z - drift))
        .collect();
    let density = DensityGrid {
        points: grid.points,
        weights: grid.weights,
        values,
    };
    density.validate()?;
    Ok(density)
}
/// Propagates the density grid `prev` from the previous analysis to a new grid.
///
/// The new grid's points and weights come from `gridpts(r, theta * sqrt(info), a, b)`.
/// Let `rt_delta = sqrt(info - info_prev)` and
/// `drift = theta * info - theta_prev * info_prev`. Each new value is
///
/// `w_j * sqrt(info) / rt_delta * Σ_i prev.values[i] * φ((z_j * sqrt(info) - zp_i * sqrt(info_prev) - drift) / rt_delta)`
///
/// where `z_j`/`w_j` are the new points/weights and `zp_i` are `prev.points`.
///
/// `prev.weights` is only length-checked; the sum uses `prev.values` directly.
/// For grids built by `h1`/`hupdate`, `prev.values` already include the weights.
///
/// Cost is one `normal_pdf` evaluation per (new point, previous point) pair.
///
/// # Errors
/// * [`IntegrationError::NonPositivePreviousInformation`] unless `info_prev > 0` (NaN included).
/// * [`IntegrationError::NonIncreasingInformation`] unless `info > info_prev` (NaN included).
/// * [`IntegrationError::ShapeMismatch`] if `prev`'s vectors differ in length.
/// * Any error returned by `gridpts` (invalid `r` or bounds).
pub fn hupdate(
    r: usize,
    theta: f64,
    info: f64,
    a: f64,
    b: f64,
    theta_prev: f64,
    info_prev: f64,
    prev: &DensityGrid,
) -> Result<DensityGrid, IntegrationError> {
    if info_prev.is_nan() || info_prev <= 0.0 {
        return Err(IntegrationError::NonPositivePreviousInformation { info_prev });
    }
    if info.is_nan() || info <= info_prev {
        return Err(IntegrationError::NonIncreasingInformation { info_prev, info });
    }
    prev.validate()?;

    let rt_info = info.sqrt();
    let rt_delta = (info - info_prev).sqrt();
    let grid = gridpts(r, theta * rt_info, a, b)?;

    let drift = theta * info - theta_prev * info_prev;
    let scale = rt_info / rt_delta;
    let rt_info_prev = info_prev.sqrt();
    // Previous points mapped onto the standardized increment scale, computed once.
    let shifted: Vec<f64> = prev
        .points
        .iter()
        .map(|&z_prev| (z_prev * rt_info_prev + drift) / rt_delta)
        .collect();

    let values: Vec<f64> = grid
        .points
        .iter()
        .zip(&grid.weights)
        .map(|(&z, &w)| {
            let base = z * scale;
            // Explicit fold from +0.0 keeps the exact summation order and the empty-sum sign.
            let conv = prev
                .values
                .iter()
                .zip(&shifted)
                .fold(0.0, |acc, (&h_prev, &t)| acc + h_prev * normal_pdf(base - t));
            conv * (w * scale)
        })
        .collect();

    let updated = DensityGrid {
        points: grid.points,
        weights: grid.weights,
        values,
    };
    updated.validate()?;
    Ok(updated)
}