use std::sync::Arc;
use nalgebra::{DMatrix, DVector, Dyn, linalg::LU};
#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;
use crate::Real;
use crate::distance::{GeoCoord, PreparedGeoCoord, haversine_distance_prepared, prepare_geo_coord};
use crate::error::KrigingError;
use crate::geo_dataset::GeoDataset;
use crate::kriging::ordinary::{Prediction, kriging_diagonal_jitter};
use crate::variogram::models::VariogramModel;
/// Polynomial trend (drift) selectable for [`UniversalKrigingModel`].
///
/// Each variant fixes the set of basis functions of latitude and longitude
/// that are appended to the kriging system.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UniversalTrend {
/// One basis function: the constant `1`.
Constant,
/// Three basis functions: `1`, `lat`, `lon`.
Linear,
/// Six basis functions: `1`, `lat`, `lon`, `lat * lat`, `lat * lon`, `lon * lon`.
Quadratic,
}
impl UniversalTrend {
pub fn n_basis(self) -> usize {
match self {
UniversalTrend::Constant => 1,
UniversalTrend::Linear => 3,
UniversalTrend::Quadratic => 6,
}
}
fn eval(self, coord: GeoCoord, out: &mut [Real]) {
let lat = coord.lat();
let lon = coord.lon();
match self {
UniversalTrend::Constant => {
out[0] = 1.0;
}
UniversalTrend::Linear => {
out[0] = 1.0;
out[1] = lat;
out[2] = lon;
}
UniversalTrend::Quadratic => {
out[0] = 1.0;
out[1] = lat;
out[2] = lon;
out[3] = lat * lat;
out[4] = lat * lon;
out[5] = lon * lon;
}
}
}
}
/// Universal kriging model: sample data, a variogram, a polynomial trend and
/// the pre-factorized kriging system built from them.
#[derive(Debug)]
pub struct UniversalKrigingModel {
/// Sample locations taken from the input dataset.
coords: Vec<GeoCoord>,
/// `coords` passed through `prepare_geo_coord`, kept so distance evaluation
/// does not have to prepare the samples again for every prediction.
prepared_coords: Vec<PreparedGeoCoord>,
/// Observed values taken from the input dataset alongside `coords`.
values: Vec<Real>,
/// Variogram model used to turn distances into covariances.
variogram: VariogramModel,
/// Trend whose basis functions augment the kriging system.
trend: UniversalTrend,
/// Cached `variogram.covariance(0.0)`, used in the prediction variance.
cov_at_zero: Real,
/// The assembled `(n + p) x (n + p)` kriging matrix
/// (`n` samples, `p = trend.n_basis()`).
system: DMatrix<Real>,
/// LU factorization of `system`; held in an `Arc` so clones of the model
/// share it instead of copying it.
system_lu: Arc<LU<Real, Dyn, Dyn>>,
}
impl Clone for UniversalKrigingModel {
    /// Duplicates the model.
    ///
    /// The sample vectors and the system matrix are copied, while the LU
    /// factorization is shared with the original through its `Arc`.
    fn clone(&self) -> Self {
        let Self {
            coords,
            prepared_coords,
            values,
            variogram,
            trend,
            cov_at_zero,
            system,
            system_lu,
        } = self;

        Self {
            coords: coords.clone(),
            prepared_coords: prepared_coords.clone(),
            values: values.clone(),
            variogram: *variogram,
            trend: *trend,
            cov_at_zero: *cov_at_zero,
            system: system.clone(),
            // Reference-count bump only; the factorization itself is not copied.
            system_lu: Arc::clone(system_lu),
        }
    }
}
impl UniversalKrigingModel {
pub fn new(
dataset: GeoDataset,
variogram: VariogramModel,
trend: UniversalTrend,
) -> Result<Self, KrigingError> {
let (coords, values) = dataset.into_parts();
let n = coords.len();
let p = trend.n_basis();
if n < p + 1 {
return Err(KrigingError::InsufficientData(p + 1));
}
let prepared_coords = coords
.iter()
.copied()
.map(prepare_geo_coord)
.collect::<Vec<_>>();
let system = build_universal_system(&coords, &prepared_coords, variogram, trend);
let system_lu = Arc::new(system.clone().lu());
let probe = DVector::from_element(n + p, 0.0);
if system_lu.solve(&probe).is_none() {
return Err(KrigingError::MatrixError(
"could not factorize universal kriging system".to_string(),
));
}
Ok(Self {
coords,
prepared_coords,
values,
variogram,
trend,
cov_at_zero: variogram.covariance(0.0),
system,
system_lu,
})
}
pub fn trend(&self) -> UniversalTrend {
self.trend
}
pub fn predict(&self, coord: GeoCoord) -> Result<Prediction, KrigingError> {
let n = self.coords.len();
let p = self.trend.n_basis();
let mut rhs = DVector::from_element(n + p, 0.0);
self.predict_with_rhs(coord, &mut rhs)
}
pub fn predict_batch(&self, coords: &[GeoCoord]) -> Result<Vec<Prediction>, KrigingError> {
#[cfg(not(target_arch = "wasm32"))]
{
let n = self.coords.len();
let p = self.trend.n_basis();
coords
.par_iter()
.map_init(
|| DVector::<Real>::from_element(n + p, 0.0),
|rhs, c| self.predict_with_rhs(*c, rhs),
)
.collect()
}
#[cfg(target_arch = "wasm32")]
{
let n = self.coords.len();
let p = self.trend.n_basis();
let mut rhs = DVector::from_element(n + p, 0.0);
let mut out = Vec::with_capacity(coords.len());
for &c in coords {
out.push(self.predict_with_rhs(c, &mut rhs)?);
}
Ok(out)
}
}
fn predict_with_rhs(
&self,
coord: GeoCoord,
rhs: &mut DVector<Real>,
) -> Result<Prediction, KrigingError> {
let n = self.coords.len();
let p = self.trend.n_basis();
let prepared = prepare_geo_coord(coord);
for i in 0..n {
rhs[i] = self.variogram.covariance(haversine_distance_prepared(
self.prepared_coords[i],
prepared,
));
}
let mut f0 = vec![0.0 as Real; p];
self.trend.eval(coord, &mut f0);
for l in 0..p {
rhs[n + l] = f0[l];
}
let sol = self.system_lu.solve(rhs).ok_or_else(|| {
KrigingError::MatrixError("could not solve universal kriging system".to_string())
})?;
let mut value: Real = 0.0;
let mut cov_dot: Real = 0.0;
for i in 0..n {
value += sol[i] * self.values[i];
cov_dot += sol[i] * rhs[i];
}
let mut mu_dot: Real = 0.0;
for l in 0..p {
mu_dot += sol[n + l] * f0[l];
}
let variance = (self.cov_at_zero - cov_dot - mu_dot).max(0.0);
Ok(Prediction { value, variance })
}
}
/// Assembles the dense `(n + p) x (n + p)` universal kriging matrix.
///
/// The top-left `n x n` block holds the pairwise sample covariances, with the
/// value from `kriging_diagonal_jitter` added on its diagonal. The trend basis
/// evaluated at each sample fills the `n x p` block to the right and, mirrored,
/// the `p x n` block below. The bottom-right `p x p` block stays zero.
///
/// `prepared` is expected to hold the prepared form of `coords`, element for
/// element.
fn build_universal_system(
    coords: &[GeoCoord],
    prepared: &[PreparedGeoCoord],
    variogram: VariogramModel,
    trend: UniversalTrend,
) -> DMatrix<Real> {
    let n = coords.len();
    let p = trend.n_basis();
    let dim = n + p;
    let jitter = kriging_diagonal_jitter(n, variogram);
    let mut matrix = DMatrix::<Real>::zeros(dim, dim);

    // Covariance block: diagonal first, then each off-diagonal pair once,
    // mirrored across the diagonal.
    let prepared = &prepared[..n];
    for (row, &a) in prepared.iter().enumerate() {
        matrix[(row, row)] = variogram.covariance(haversine_distance_prepared(a, a)) + jitter;
        for (col, &b) in prepared.iter().enumerate().skip(row + 1) {
            let cov = variogram.covariance(haversine_distance_prepared(a, b));
            matrix[(row, col)] = cov;
            matrix[(col, row)] = cov;
        }
    }

    // Trend blocks: one reusable buffer for the basis values of each sample.
    let mut basis = vec![0.0 as Real; p];
    for (row, &sample) in coords.iter().enumerate() {
        trend.eval(sample, &mut basis);
        for (k, &term) in basis.iter().enumerate() {
            let col = n + k;
            matrix[(row, col)] = term;
            matrix[(col, row)] = term;
        }
    }

    matrix
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::variogram::models::VariogramType;

    // With a constant trend the universal system has the same structure as the
    // ordinary kriging one, so both models should agree at an interior point.
    #[test]
    fn constant_trend_matches_ordinary_kriging_closely() {
        let coords: Vec<GeoCoord> = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
            .iter()
            .map(|&(lat, lon)| GeoCoord::try_new(lat, lon).unwrap())
            .collect();
        let values = vec![10.0, 12.0, 14.0, 16.0];
        let variogram = VariogramModel::new(0.01, 5.0, 300.0, VariogramType::Exponential).unwrap();

        let uk_dataset = GeoDataset::new(coords.clone(), values.clone()).unwrap();
        let uk = UniversalKrigingModel::new(uk_dataset, variogram, UniversalTrend::Constant)
            .expect("uk");
        let ok_dataset = GeoDataset::new(coords, values).unwrap();
        let ok = crate::OrdinaryKrigingModel::new(ok_dataset, variogram).expect("ok");

        let target = GeoCoord::try_new(0.5, 0.5).unwrap();
        let uk_pred = uk.predict(target).expect("uk predict");
        let ok_pred = ok.predict(target).expect("ok predict");

        assert!((uk_pred.value - ok_pred.value).abs() < 1e-3);
        assert!((uk_pred.variance - ok_pred.variance).abs() < 1e-3);
    }

    // Samples lie on a plane in (lat, lon); a linear trend should reproduce the
    // plane at an unsampled location within the stated tolerance.
    #[test]
    fn linear_trend_fits_planar_surface_exactly() {
        let plane = |lat: Real, lon: Real| -> Real { 1.0 + 2.0 * lat + 3.0 * lon };

        let coords: Vec<GeoCoord> = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0), (2.0, 0.5)]
            .iter()
            .map(|&(lat, lon)| GeoCoord::try_new(lat, lon).unwrap())
            .collect();
        let values: Vec<Real> = coords.iter().map(|c| plane(c.lat(), c.lon())).collect();
        let variogram = VariogramModel::new(0.01, 1.0, 500.0, VariogramType::Exponential).unwrap();

        let dataset = GeoDataset::new(coords, values).unwrap();
        let model =
            UniversalKrigingModel::new(dataset, variogram, UniversalTrend::Linear).expect("uk");

        let target = GeoCoord::try_new(0.7, 0.3).unwrap();
        let expected = plane(0.7, 0.3);
        let pred = model.predict(target).expect("prediction");
        assert!(
            (pred.value - expected).abs() < 0.1,
            "got {}, expected {}",
            pred.value,
            expected
        );
    }

    // A quadratic trend has six basis functions, so three samples are too few.
    #[test]
    fn rejects_insufficient_data_for_quadratic_trend() {
        let coords: Vec<GeoCoord> = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0)]
            .iter()
            .map(|&(lat, lon)| GeoCoord::try_new(lat, lon).unwrap())
            .collect();
        let values = vec![1.0, 2.0, 3.0];
        let variogram = VariogramModel::new(0.01, 1.0, 100.0, VariogramType::Exponential).unwrap();

        let dataset = GeoDataset::new(coords, values).unwrap();
        let err = UniversalKrigingModel::new(dataset, variogram, UniversalTrend::Quadratic)
            .expect_err("should fail");

        match err {
            KrigingError::InsufficientData(_) => {}
            other => panic!("expected InsufficientData, got {other:?}"),
        }
    }
}