use crate::error::{InterpolateError, InterpolateResult};
/// A semivariogram model `γ(h)`: dissimilarity between two samples as a
/// function of their separation distance `h`.
///
/// Implementors must be `Send + Sync` and must be able to duplicate
/// themselves behind a trait object via [`Variogram::clone_box`].
pub trait Variogram: Send + Sync {
    /// Semivariance at lag distance `h`.
    fn gamma(&self, h: f64) -> f64;

    /// Reports whether the model levels off at a finite sill.
    ///
    /// The default answer is `true`; models that grow without bound (for
    /// example a power law) override this to return `false`.
    fn is_bounded(&self) -> bool {
        true
    }

    /// Returns a heap-allocated copy of `self`, so owners of a
    /// `Box<dyn Variogram>` can duplicate it.
    fn clone_box(&self) -> Box<dyn Variogram>;
}
/// Spherical semivariogram model.
///
/// * `γ(h) = 0` for `h <= 0`;
/// * `γ(h) = nugget + sill · (1.5·r − 0.5·r³)` with `r = h / range` for
///   `0 < h < range`;
/// * `γ(h) = nugget + sill` for `h >= range`.
#[derive(Debug, Clone, Copy)]
pub struct SphericalVariogram {
    /// Jump added to the semivariance at every strictly positive lag.
    pub nugget: f64,
    /// Partial sill: the rise above the nugget reached at `range`.
    pub sill: f64,
    /// Lag at which the model becomes flat at `nugget + sill`.
    pub range: f64,
}

impl Variogram for SphericalVariogram {
    fn gamma(&self, h: f64) -> f64 {
        match h {
            lag if lag <= 0.0 => 0.0,
            lag if lag >= self.range => self.nugget + self.sill,
            // Also reached for NaN input, which then propagates as NaN.
            lag => {
                let r = lag / self.range;
                self.nugget + self.sill * (1.5 * r - 0.5 * r * r * r)
            }
        }
    }

    fn clone_box(&self) -> Box<dyn Variogram> {
        Box::new(*self)
    }
}
/// Exponential semivariogram model in practical-range form.
///
/// `γ(h) = nugget + sill · (1 − exp(−3h / range))` for `h > 0`, and
/// `γ(h) = 0` for `h <= 0`. Because of the factor 3, the model reaches about
/// 95% of the partial sill at `h = range` (`1 − e⁻³ ≈ 0.95`).
#[derive(Debug, Clone, Copy)]
pub struct ExponentialVariogram {
    /// Jump added to the semivariance at every strictly positive lag.
    pub nugget: f64,
    /// Partial sill approached asymptotically above the nugget.
    pub sill: f64,
    /// Practical range (lag at ~95% of the partial sill).
    pub range: f64,
}

impl Variogram for ExponentialVariogram {
    fn gamma(&self, h: f64) -> f64 {
        if h <= 0.0 {
            0.0
        } else {
            let decay = (-3.0 * h / self.range).exp();
            self.nugget + self.sill * (1.0 - decay)
        }
    }

    fn clone_box(&self) -> Box<dyn Variogram> {
        Box::new(*self)
    }
}
/// Gaussian semivariogram model in practical-range form.
///
/// `γ(h) = nugget + sill · (1 − exp(−3·(h / range)²))` for `h > 0`, and
/// `γ(h) = 0` for `h <= 0`.
#[derive(Debug, Clone, Copy)]
pub struct GaussianVariogram {
    /// Jump added to the semivariance at every strictly positive lag.
    pub nugget: f64,
    /// Partial sill approached asymptotically above the nugget.
    pub sill: f64,
    /// Practical range (lag at ~95% of the partial sill).
    pub range: f64,
}

impl Variogram for GaussianVariogram {
    fn gamma(&self, h: f64) -> f64 {
        if h <= 0.0 {
            0.0
        } else {
            let scaled = h / self.range;
            let decay = (-3.0 * scaled * scaled).exp();
            self.nugget + self.sill * (1.0 - decay)
        }
    }

    fn clone_box(&self) -> Box<dyn Variogram> {
        Box::new(*self)
    }
}
/// Power-law semivariogram: `γ(h) = nugget + slope · h^power` for `h > 0`,
/// and `γ(h) = 0` for `h <= 0`.
///
/// This model never levels off, so it reports itself as unbounded.
/// NOTE(review): `power` is not validated here; the power model is normally
/// only admissible for `0 < power < 2`, so callers should ensure that.
#[derive(Debug, Clone, Copy)]
pub struct PowerVariogram {
    /// Jump added to the semivariance at every strictly positive lag.
    pub nugget: f64,
    /// Multiplier applied to `h^power`.
    pub slope: f64,
    /// Exponent applied to the lag distance.
    pub power: f64,
}

impl Variogram for PowerVariogram {
    fn gamma(&self, h: f64) -> f64 {
        if h <= 0.0 {
            0.0
        } else {
            self.nugget + self.slope * h.powf(self.power)
        }
    }

    /// Always `false`: the semivariance grows without a sill.
    fn is_bounded(&self) -> bool {
        false
    }

    fn clone_box(&self) -> Box<dyn Variogram> {
        Box::new(*self)
    }
}
/// Factors the dense row-major `n × n` matrix `a` as `P·A = L·U` using
/// Gaussian elimination with partial (row) pivoting, reusing `a`'s storage.
///
/// On success returns `(lu, piv)`:
/// * `lu` holds `U` on and above the diagonal and the multipliers of the
///   unit-lower-triangular `L` strictly below it;
/// * `piv[i]` is the original row index that ended up in row `i`.
///
/// # Errors
///
/// Returns [`InterpolateError::ComputationError`] if any entry of `a` is NaN
/// or infinite, or if every candidate pivot in some column has magnitude
/// below `1e-15` (the matrix is treated as singular).
fn lu_factor(mut a: Vec<f64>, n: usize) -> InterpolateResult<(Vec<f64>, Vec<usize>)> {
    // Absolute pivot magnitude below which the matrix is deemed singular.
    const PIVOT_TOL: f64 = 1e-15;

    debug_assert_eq!(a.len(), n * n, "lu_factor: matrix length must be n*n");

    // NaN compares false against everything, so it would slip past both the
    // pivot search and the singularity test and silently poison the solution.
    if a.iter().any(|v| !v.is_finite()) {
        return Err(InterpolateError::ComputationError(
            "Non-finite entry in kriging matrix; check coordinates and variogram parameters"
                .into(),
        ));
    }

    let mut piv: Vec<usize> = (0..n).collect();
    for k in 0..n {
        // Partial pivoting: choose the row (>= k) with the largest |a[i][k]|;
        // ties keep the earliest row.
        let mut max_val = a[k * n + k].abs();
        let mut max_row = k;
        for i in (k + 1)..n {
            let v = a[i * n + k].abs();
            if v > max_val {
                max_val = v;
                max_row = i;
            }
        }
        if max_val < PIVOT_TOL {
            return Err(InterpolateError::ComputationError(
                "Singular kriging matrix; add nugget > 0 or check data".into(),
            ));
        }
        if max_row != k {
            piv.swap(k, max_row);
            // Swap entire rows, including already-stored L multipliers, so
            // the final factors satisfy P·A = L·U.
            for j in 0..n {
                a.swap(k * n + j, max_row * n + j);
            }
        }
        let pivot = a[k * n + k];
        for i in (k + 1)..n {
            // Store the multiplier in place of the eliminated entry.
            let factor = a[i * n + k] / pivot;
            a[i * n + k] = factor;
            for j in (k + 1)..n {
                a[i * n + j] -= factor * a[k * n + j];
            }
        }
    }
    Ok((a, piv))
}
/// Solves `A·x = b` from the packed factors produced by `lu_factor`.
///
/// `lu` is the row-major `n × n` combined matrix: unit-lower `L` multipliers
/// below the diagonal, `U` on and above it. `piv[i]` names the entry of `b`
/// that belongs in position `i` after the row permutation.
fn lu_solve(lu: &[f64], piv: &[usize], b: &[f64], n: usize) -> Vec<f64> {
    // Permute the right-hand side.
    let mut x: Vec<f64> = piv[..n].iter().map(|&row| b[row]).collect();

    // Forward substitution with L (implicit ones on the diagonal).
    for row in 1..n {
        let start = row * n;
        let reduced = lu[start..start + row]
            .iter()
            .zip(&x[..row])
            .fold(x[row], |acc, (&l, &xj)| acc - l * xj);
        x[row] = reduced;
    }

    // Back substitution with U, from the last row upwards.
    for row in (0..n).rev() {
        let start = row * n;
        let reduced = lu[start + row + 1..start + n]
            .iter()
            .zip(&x[row + 1..])
            .fold(x[row], |acc, (&u, &xj)| acc - u * xj);
        x[row] = reduced / lu[start + row];
    }
    x
}
/// An ordinary-kriging interpolator fitted to scattered samples.
///
/// Build one with `OrdinaryKriging::fit` and query it with
/// `OrdinaryKriging::predict`. The bordered kriging system is LU-factored once
/// at fit time and the factors are reused for every prediction.
///
/// NOTE(review): `points` is public, but the cached factorization is built
/// from it; mutating `points` after fitting leaves the factors stale.
pub struct OrdinaryKriging {
    /// Sample locations; each inner vector holds one point's coordinates.
    pub points: Vec<Vec<f64>>,
    /// Observed value for the matching entry of `points`.
    pub values: Vec<f64>,
    /// Number of samples (`points.len()` at fit time).
    n: usize,
    /// Semivariogram model the system was built from.
    variogram: Box<dyn Variogram>,
    /// Cached `variogram.is_bounded()`: selects covariance (`c0 − γ`) vs. raw `γ` form.
    bounded: bool,
    /// `γ(1e12)` for bounded models (used as the zero-lag covariance), else 0.
    c0: f64,
    /// Packed LU factors of the `(n + 1) × (n + 1)` kriging matrix.
    lu_mat: Vec<f64>,
    /// Row permutation produced by the LU factorization.
    lu_piv: Vec<usize>,
}
impl Clone for OrdinaryKriging {
    /// Deep-copies the model; the boxed variogram is duplicated through
    /// `Variogram::clone_box`.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field becomes a compile error
        // here instead of a silently incomplete clone.
        let Self {
            points,
            values,
            variogram,
            lu_mat,
            lu_piv,
            c0,
            n,
            bounded,
        } = self;
        Self {
            points: points.clone(),
            values: values.clone(),
            variogram: variogram.clone_box(),
            lu_mat: lu_mat.clone(),
            lu_piv: lu_piv.clone(),
            c0: *c0,
            n: *n,
            bounded: *bounded,
        }
    }
}
impl std::fmt::Debug for OrdinaryKriging {
    /// Prints only the sample count; the variogram trait object, data and LU
    /// buffers are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("OrdinaryKriging");
        out.field("n", &self.n);
        out.finish()
    }
}
impl OrdinaryKriging {
    /// Fits an ordinary-kriging model to `points` / `values` with `variogram`.
    ///
    /// Builds the `(n + 1) × (n + 1)` bordered system `[[K, 1], [1ᵀ, 0]]`.
    /// `K` holds the pairwise entries `c0 − γ(h)` for bounded variograms, or
    /// the raw `γ(h)` for unbounded ones. The row and column of ones enforce
    /// that the weights sum to one. The system is LU-factored once, so
    /// `predict` needs only two triangular solves.
    ///
    /// # Errors
    ///
    /// * `InterpolateError::InvalidInput` if there are no points, or any
    ///   coordinate or value is NaN/infinite.
    /// * `InterpolateError::ShapeMismatch` if `values.len() != points.len()`.
    /// * `InterpolateError::DimensionMismatch` if the points do not all have
    ///   the same number of coordinates.
    /// * Any error from the LU factorization, e.g. a numerically singular
    ///   matrix (for example duplicate sample locations).
    pub fn fit(
        points: Vec<Vec<f64>>,
        values: Vec<f64>,
        variogram: Box<dyn Variogram>,
    ) -> InterpolateResult<OrdinaryKriging> {
        let n = points.len();
        if n == 0 {
            return Err(InterpolateError::InvalidInput {
                message: "no data points".into(),
            });
        }
        if values.len() != n {
            return Err(InterpolateError::ShapeMismatch {
                expected: format!("{}", n),
                actual: format!("{}", values.len()),
                object: "values".into(),
            });
        }
        // `euclidean_dist` zips coordinates, so ragged points would silently
        // ignore trailing coordinates instead of failing.
        let dim = points[0].len();
        if let Some((i, p)) = points.iter().enumerate().find(|(_, p)| p.len() != dim) {
            return Err(InterpolateError::DimensionMismatch(format!(
                "point {} has dim {}, expected {}",
                i,
                p.len(),
                dim
            )));
        }
        // NaN/inf inputs would otherwise propagate into every prediction.
        if let Some(i) = points
            .iter()
            .position(|p| p.iter().any(|c| !c.is_finite()))
        {
            return Err(InterpolateError::InvalidInput {
                message: format!("point {} has a non-finite coordinate", i),
            });
        }
        if let Some(i) = values.iter().position(|v| !v.is_finite()) {
            return Err(InterpolateError::InvalidInput {
                message: format!("value {} is not finite", i),
            });
        }

        let bounded = variogram.is_bounded();
        // For bounded models γ at a very large lag stands in for the total
        // sill, making `c0 - γ(h)` the covariance; unbounded models use γ.
        let c0 = if bounded { variogram.gamma(1e12) } else { 0.0 };

        let m = n + 1;
        let mut mat = vec![0.0_f64; m * m];
        for i in 0..n {
            // K is symmetric (distance is symmetric bit-for-bit), so evaluate
            // each pair once and mirror it.
            for j in i..n {
                let gamma = variogram.gamma(euclidean_dist(&points[i], &points[j]));
                let entry = if bounded { c0 - gamma } else { gamma };
                mat[i * m + j] = entry;
                mat[j * m + i] = entry;
            }
            // Unbiasedness border: weights must sum to one.
            mat[i * m + n] = 1.0;
            mat[n * m + i] = 1.0;
        }
        // mat[n * m + n] stays 0.0 from initialisation.
        let (lu_mat, lu_piv) = lu_factor(mat, m)?;
        Ok(OrdinaryKriging {
            points,
            values,
            variogram,
            lu_mat,
            lu_piv,
            c0,
            n,
            bounded,
        })
    }

    /// Predicts at location `x`, returning `(estimate, kriging_variance)`.
    ///
    /// The weights `w` and the multiplier `μ` (the extra unknown from the
    /// sum-to-one row) come from solving the factored system. The right-hand
    /// side is the covariance (bounded) or semivariance (unbounded) between
    /// `x` and each sample. The estimate is `Σ wᵢ·valueᵢ`. The variance is
    /// `c0 − Σ wᵢ·rhsᵢ − μ` (bounded) or `Σ wᵢ·rhsᵢ + μ` (unbounded), clamped
    /// at 0 to absorb round-off.
    ///
    /// # Errors
    ///
    /// `InterpolateError::DimensionMismatch` if `x.len()` differs from the
    /// dimensionality of the fitted points.
    pub fn predict(&self, x: &[f64]) -> InterpolateResult<(f64, f64)> {
        if let Some(first) = self.points.first() {
            if x.len() != first.len() {
                return Err(InterpolateError::DimensionMismatch(format!(
                    "expected dim {}, got {}",
                    first.len(),
                    x.len()
                )));
            }
        }
        let m = self.n + 1;
        let mut rhs = vec![0.0_f64; m];
        for i in 0..self.n {
            let gamma = self.variogram.gamma(euclidean_dist(x, &self.points[i]));
            rhs[i] = if self.bounded { self.c0 - gamma } else { gamma };
        }
        rhs[self.n] = 1.0;
        let sol = lu_solve(&self.lu_mat, &self.lu_piv, &rhs, m);
        let estimate: f64 = (0..self.n).map(|i| sol[i] * self.values[i]).sum();
        let rhs_dot_w: f64 = (0..self.n).map(|i| rhs[i] * sol[i]).sum();
        let variance = if self.bounded {
            (self.c0 - rhs_dot_w - sol[self.n]).max(0.0)
        } else {
            (rhs_dot_w + sol[self.n]).max(0.0)
        };
        Ok((estimate, variance))
    }
}
/// Euclidean distance between the coordinate slices `a` and `b`.
///
/// Coordinates are paired with `zip`, so if the slices differ in length the
/// extra trailing coordinates of the longer one are ignored.
fn euclidean_dist(a: &[f64], b: &[f64]) -> f64 {
    let squared: f64 = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| {
            let d = p - q;
            d * d
        })
        .sum();
    squared.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample locations 0, 1, 2, 3, 4 on a line, with value `x²` at each.
    fn make_1d_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        let xs = [0.0_f64, 1.0, 2.0, 3.0, 4.0];
        let pts: Vec<Vec<f64>> = xs.iter().map(|&x| vec![x]).collect();
        let vals: Vec<f64> = xs.iter().map(|&x| x * x).collect();
        (pts, vals)
    }

    /// Fits `vgm` to the 1-D fixture and returns the model with its data.
    fn fit_1d(vgm: Box<dyn Variogram>) -> (OrdinaryKriging, Vec<Vec<f64>>, Vec<f64>) {
        let (pts, vals) = make_1d_data();
        let ok = OrdinaryKriging::fit(pts.clone(), vals.clone(), vgm).expect("fit failed");
        (ok, pts, vals)
    }

    #[test]
    fn test_spherical_kriging_interpolates_data() {
        let (ok, pts, vals) = fit_1d(Box::new(SphericalVariogram {
            nugget: 0.0,
            sill: 20.0,
            range: 10.0,
        }));
        for (p, v) in pts.iter().zip(vals.iter().copied()) {
            let (est, _var) = ok.predict(p).expect("predict failed");
            assert!(
                (est - v).abs() < 1e-6,
                "spherical: at {:?} expected {} got {}",
                p,
                v,
                est
            );
        }
    }

    #[test]
    fn test_exponential_kriging_interpolates_data() {
        let (ok, pts, vals) = fit_1d(Box::new(ExponentialVariogram {
            nugget: 0.0,
            sill: 20.0,
            range: 10.0,
        }));
        for (p, v) in pts.iter().zip(vals.iter().copied()) {
            let (est, _) = ok.predict(p).expect("predict");
            assert!((est - v).abs() < 1e-6, "exp: {:?} {} {}", p, v, est);
        }
    }

    #[test]
    fn test_gaussian_kriging_interpolates_data() {
        let (ok, pts, vals) = fit_1d(Box::new(GaussianVariogram {
            nugget: 0.0,
            sill: 20.0,
            range: 10.0,
        }));
        for (p, v) in pts.iter().zip(vals.iter().copied()) {
            let (est, _) = ok.predict(p).expect("predict");
            assert!((est - v).abs() < 1e-6, "gauss: {:?} {} {}", p, v, est);
        }
    }

    #[test]
    fn test_power_variogram() {
        let (ok, pts, vals) = fit_1d(Box::new(PowerVariogram {
            nugget: 0.0,
            slope: 1.0,
            power: 1.5,
        }));
        for (p, v) in pts.iter().zip(vals.iter().copied()) {
            let (est, _) = ok.predict(p).expect("predict");
            assert!((est - v).abs() < 1e-4, "power: {:?} {} {}", p, v, est);
        }
    }

    #[test]
    fn test_variance_is_nonnegative() {
        let (ok, _, _) = fit_1d(Box::new(SphericalVariogram {
            nugget: 0.01,
            sill: 20.0,
            range: 10.0,
        }));
        // Query points lie strictly between samples.
        for p in [[0.5_f64], [1.5], [2.5]].iter() {
            let (_est, var) = ok.predict(p).expect("predict");
            assert!(var >= 0.0, "variance negative at {:?}: {}", p, var);
        }
    }

    #[test]
    fn test_variogram_gamma_at_zero() {
        // Every model returns exactly zero at zero lag, even with a nugget.
        let models: Vec<Box<dyn Variogram>> = vec![
            Box::new(SphericalVariogram {
                nugget: 0.1,
                sill: 1.0,
                range: 2.0,
            }),
            Box::new(ExponentialVariogram {
                nugget: 0.1,
                sill: 1.0,
                range: 2.0,
            }),
            Box::new(GaussianVariogram {
                nugget: 0.1,
                sill: 1.0,
                range: 2.0,
            }),
            Box::new(PowerVariogram {
                nugget: 0.1,
                slope: 1.0,
                power: 1.5,
            }),
        ];
        for model in &models {
            assert_eq!(model.gamma(0.0), 0.0);
        }
    }

    #[test]
    fn test_spherical_reaches_sill() {
        let v = SphericalVariogram {
            nugget: 0.0,
            sill: 5.0,
            range: 2.0,
        }
        .gamma(100.0);
        assert!((v - 5.0).abs() < 1e-10, "should reach sill: {}", v);
    }

    #[test]
    fn test_error_on_empty() {
        let vgm = SphericalVariogram {
            nugget: 0.0,
            sill: 1.0,
            range: 1.0,
        };
        assert!(OrdinaryKriging::fit(Vec::new(), Vec::new(), Box::new(vgm)).is_err());
    }

    #[test]
    fn test_error_on_dim_mismatch_predict() {
        let vgm = GaussianVariogram {
            nugget: 0.0,
            sill: 1.0,
            range: 5.0,
        };
        let ok = OrdinaryKriging::fit(
            vec![vec![0.0_f64, 0.0], vec![1.0, 1.0]],
            vec![0.0_f64, 1.0],
            Box::new(vgm),
        )
        .expect("fit");
        // A 1-D query against 2-D training points must be rejected.
        assert!(ok.predict(&[0.5]).is_err());
    }
}