use nalgebra::{DMatrix, DVector};
use crate::Real;
use crate::error::KrigingError;
use crate::spacetime::empirical::{
EmpiricalSpaceTimeVariogram, spatial_marginal, temporal_marginal,
};
use crate::spacetime::variogram::SpaceTimeVariogram;
use crate::variogram::fitting::fit_variogram;
use crate::variogram::models::{VariogramModel, VariogramType};
/// Families of space-time variogram models that [`fit_spacetime_variogram`] can fit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpaceTimeVariogramType {
/// Model assembled with `SpaceTimeVariogram::new_separable` from the fitted
/// spatial and temporal marginal models.
Separable,
/// Model assembled with `SpaceTimeVariogram::new_product_sum` from the fitted
/// marginal models plus three fitted coefficients `k1`, `k2`, `k3`.
ProductSum,
}
/// Options controlling [`fit_spacetime_variogram`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SpaceTimeFitConfig {
/// Space-time family used to combine the two fitted marginal models.
pub family: SpaceTimeVariogramType,
/// Model type handed to `fit_variogram` when fitting the spatial marginal.
pub spatial_model: VariogramType,
/// Model type handed to `fit_variogram` when fitting the temporal marginal.
pub temporal_model: VariogramType,
}
/// Outcome of [`fit_spacetime_variogram`].
#[derive(Debug, Clone, Copy)]
pub struct SpaceTimeFitResult {
/// The fitted space-time variogram model.
pub model: SpaceTimeVariogram,
/// Sum over all bins with at least one pair of
/// `n_pairs * (empirical semivariance - model semivariance)^2`.
pub residuals: Real,
}
/// Fits a space-time variogram model to an empirical space-time variogram.
///
/// The spatial and temporal marginals of `empirical` are each fitted with
/// [`fit_variogram`], using the model types requested in `config`. The two
/// marginal models are then combined according to `config.family`; for
/// [`SpaceTimeVariogramType::ProductSum`] the coefficients `k1`, `k2`, `k3`
/// are additionally estimated from the populated bins of `empirical`.
///
/// The returned [`SpaceTimeFitResult`] carries the model together with its
/// pair-count-weighted sum of squared residuals against `empirical`.
///
/// # Errors
///
/// Returns [`KrigingError::FittingError`] when either marginal has no
/// distances. Errors from the marginal fits, the coefficient fit and the
/// model constructors are propagated.
pub fn fit_spacetime_variogram(
    empirical: &EmpiricalSpaceTimeVariogram,
    config: SpaceTimeFitConfig,
) -> Result<SpaceTimeFitResult, KrigingError> {
    let space_marg = spatial_marginal(empirical);
    let time_marg = temporal_marginal(empirical);

    let both_populated = !space_marg.distances.is_empty() && !time_marg.distances.is_empty();
    if !both_populated {
        return Err(KrigingError::FittingError(
            "empirical ST variogram lacks a populated spatial or temporal marginal".to_string(),
        ));
    }

    let spatial = fit_variogram(&space_marg, config.spatial_model)?.model;
    let temporal = fit_variogram(&time_marg, config.temporal_model)?.model;

    // Only the model construction differs between families; the residual
    // computation and result assembly are shared below.
    let model = match config.family {
        SpaceTimeVariogramType::Separable => SpaceTimeVariogram::new_separable(spatial, temporal)?,
        SpaceTimeVariogramType::ProductSum => {
            let (k1, k2, k3) = fit_product_sum_coefficients(empirical, spatial, temporal)?;
            SpaceTimeVariogram::new_product_sum(spatial, temporal, k1, k2, k3)?
        }
    };

    let residuals = weighted_residuals(empirical, model);
    Ok(SpaceTimeFitResult { model, residuals })
}
/// Sums the squared misfit between `emp` and `model`, weighting each bin by
/// its pair count.
///
/// Bins are visited spatial-index-major; bins with zero pairs contribute
/// nothing.
fn weighted_residuals(emp: &EmpiricalSpaceTimeVariogram, model: SpaceTimeVariogram) -> Real {
    (0..emp.n_spatial_bins)
        .flat_map(|i| (0..emp.n_temporal_bins).map(move |j| emp.index(i, j)))
        .filter(|&bin| emp.n_pairs[bin] != 0)
        .fold(0.0, |acc, bin| {
            let misfit = emp.semivariances[bin]
                - model.semivariance(emp.spatial_lags[bin], emp.temporal_lags[bin]);
            acc + (emp.n_pairs[bin] as Real) * misfit * misfit
        })
}
/// Estimates the product-sum coefficients `(k1, k2, k3)` from the populated
/// bins of `emp`, given already-fitted marginal models.
///
/// For every bin with at least one pair, a design row
/// `[Cs(0)*Ct(0) - Cs*Ct, Cs(0) - Cs, Ct(0) - Ct]` is built from the marginal
/// covariances evaluated at the bin's spatial and temporal lags. The target is
/// the bin's empirical semivariance and the weight is the square root of its
/// pair count. The system is then passed to `solve_nnls_3`.
///
/// # Errors
///
/// Returns [`KrigingError::FittingError`] when fewer than three bins are
/// populated.
fn fit_product_sum_coefficients(
    emp: &EmpiricalSpaceTimeVariogram,
    spatial: VariogramModel,
    temporal: VariogramModel,
) -> Result<(Real, Real, Real), KrigingError> {
    let cov_s0 = spatial.covariance(0.0);
    let cov_t0 = temporal.covariance(0.0);

    let populated_bins = (0..emp.n_spatial_bins)
        .flat_map(|i| (0..emp.n_temporal_bins).map(move |j| emp.index(i, j)))
        .filter(|&bin| emp.n_pairs[bin] != 0);

    let mut design = Vec::new();
    let mut observed = Vec::new();
    let mut root_counts = Vec::new();
    for bin in populated_bins {
        let cov_s = spatial.covariance(emp.spatial_lags[bin]);
        let cov_t = temporal.covariance(emp.temporal_lags[bin]);
        design.push([cov_s0 * cov_t0 - cov_s * cov_t, cov_s0 - cov_s, cov_t0 - cov_t]);
        observed.push(emp.semivariances[bin]);
        root_counts.push((emp.n_pairs[bin] as Real).sqrt());
    }

    if design.len() < 3 {
        return Err(KrigingError::FittingError(
            "need at least 3 populated bins to fit product-sum coefficients".to_string(),
        ));
    }

    let [k1, k2, k3] = solve_nnls_3(&design, &observed, &root_counts);
    Ok((k1, k2, k3))
}
/// Solves a three-unknown weighted least-squares problem with every
/// coefficient constrained to be non-negative.
///
/// Each subset of the three coefficients is tried in turn as the set of
/// "free" coefficients, with the others pinned to zero. For a non-empty
/// subset, the weighted normal equations are solved through an LU
/// decomposition. A candidate is discarded when the solve fails or when any
/// entry is non-finite or negative. The surviving candidate with the strictly
/// smallest weighted sum of squared errors is returned; the all-zero vector is
/// the baseline that every other candidate has to beat.
fn solve_nnls_3(rows: &[[Real; 3]], targets: &[Real], weights: &[Real]) -> [Real; 3] {
    let n_obs = rows.len();

    // Baseline: the empty free set, i.e. all coefficients pinned to zero.
    let mut best_k: [Real; 3] = [0.0; 3];
    let mut best_sse = weighted_sse(rows, targets, weights, &best_k);

    // Bit `slot` of `mask` set means coefficient `slot` is free.
    for mask in 1u8..8 {
        let free: Vec<usize> = (0..3usize)
            .filter(|&slot| (mask >> slot) & 1 == 1)
            .collect();

        // Row-scaled design matrix and right-hand side restricted to the free columns.
        let design =
            DMatrix::<Real>::from_fn(n_obs, free.len(), |r, c| weights[r] * rows[r][free[c]]);
        let rhs = DVector::<Real>::from_fn(n_obs, |r, _| weights[r] * targets[r]);

        let normal = design.transpose() * &design;
        let moment = design.transpose() * &rhs;
        let Some(solution) = normal.lu().solve(&moment) else {
            continue;
        };

        let feasible = solution.iter().all(|v| v.is_finite() && *v >= 0.0);
        if !feasible {
            continue;
        }

        // Scatter the reduced solution back into the full coefficient vector.
        let mut candidate = [0.0 as Real; 3];
        for (&slot, &value) in free.iter().zip(solution.iter()) {
            candidate[slot] = value;
        }

        let sse = weighted_sse(rows, targets, weights, &candidate);
        if sse < best_sse {
            best_k = candidate;
            best_sse = sse;
        }
    }

    best_k
}
/// Weighted sum of squared errors of the linear model `k · row` against
/// `targets`.
///
/// Each residual `target - k · row` is multiplied by its weight before being
/// squared. Iteration stops at the shortest of `rows`, `targets` and
/// `weights`.
fn weighted_sse(rows: &[[Real; 3]], targets: &[Real], weights: &[Real], k: &[Real; 3]) -> Real {
    let [k0, k1, k2] = *k;
    rows.iter()
        .zip(targets)
        .zip(weights)
        .map(|((&[f0, f1, f2], &target), &weight)| {
            let predicted = k0 * f0 + k1 * f1 + k2 * f2;
            let scaled = (target - predicted) * weight;
            scaled * scaled
        })
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::GeoCoord;
    use crate::spacetime::SpaceTimeCoord;
    use crate::spacetime::dataset::SpaceTimeDataset;
    use crate::spacetime::empirical::{
        SpaceTimeVariogramConfig, compute_empirical_spacetime_variogram,
    };
    use crate::spacetime::metric::GeoMetric;
    use std::num::NonZeroUsize;

    /// Builds a deterministic dataset on an `n_spatial` x `n_time` grid.
    ///
    /// `true_model` is accepted but deliberately ignored: the values come from
    /// a fixed sine/cosine expression of the grid indices, not from a
    /// simulation of the model.
    fn synthetic_dataset(
        true_model: SpaceTimeVariogram,
        n_spatial: usize,
        n_time: usize,
    ) -> SpaceTimeDataset<GeoCoord> {
        let _ = true_model;
        let (coords, values): (Vec<_>, Vec<_>) = (0..n_spatial)
            .flat_map(|i| (0..n_time).map(move |t| (i, t)))
            .map(|(i, t)| {
                let coord = SpaceTimeCoord::new(
                    GeoCoord::try_new(i as Real * 0.1, 0.0).unwrap(),
                    t as Real,
                );
                let value = ((i as Real) * 0.7).sin() + 0.5 * ((t as Real) * 0.5).cos();
                (coord, value)
            })
            .unzip();
        SpaceTimeDataset::new(coords, values).unwrap()
    }

    /// Computes the empirical space-time variogram of `ds` under `GeoMetric`,
    /// with `n_bins` spatial bins and `n_bins` temporal bins.
    fn empirical_with_bins(
        ds: &SpaceTimeDataset<GeoCoord>,
        n_bins: usize,
    ) -> EmpiricalSpaceTimeVariogram {
        let bins = NonZeroUsize::new(n_bins).unwrap();
        compute_empirical_spacetime_variogram(
            &GeoMetric,
            ds,
            &SpaceTimeVariogramConfig {
                n_spatial_bins: bins,
                n_temporal_bins: bins,
                ..Default::default()
            },
        )
        .unwrap()
    }

    /// Fit configuration for `family` with exponential spatial and temporal marginals.
    fn exponential_config(family: SpaceTimeVariogramType) -> SpaceTimeFitConfig {
        SpaceTimeFitConfig {
            family,
            spatial_model: VariogramType::Exponential,
            temporal_model: VariogramType::Exponential,
        }
    }

    #[test]
    fn separable_fit_returns_positive_marginals() {
        let spatial = VariogramModel::new(0.01, 1.0, 30.0, VariogramType::Exponential).unwrap();
        let temporal = VariogramModel::new(0.01, 0.5, 3.0, VariogramType::Exponential).unwrap();
        let truth = SpaceTimeVariogram::new_separable(spatial, temporal).unwrap();
        let ds = synthetic_dataset(truth, 8, 6);
        let emp = empirical_with_bins(&ds, 5);

        let fit =
            fit_spacetime_variogram(&emp, exponential_config(SpaceTimeVariogramType::Separable))
                .unwrap();

        let (n_s, s_s, r_s) = fit.model.spatial().params();
        let (n_t, s_t, r_t) = fit.model.temporal().params();
        assert!(n_s >= 0.0 && s_s > n_s && r_s > 0.0);
        assert!(n_t >= 0.0 && s_t > n_t && r_t > 0.0);
        assert!(fit.residuals.is_finite());
    }

    #[test]
    fn product_sum_coefficients_are_non_negative() {
        let truth = SpaceTimeVariogram::new_separable(
            VariogramModel::new(0.01, 1.0, 20.0, VariogramType::Exponential).unwrap(),
            VariogramModel::new(0.01, 1.0, 3.0, VariogramType::Exponential).unwrap(),
        )
        .unwrap();
        let ds = synthetic_dataset(truth, 10, 8);
        let emp = empirical_with_bins(&ds, 6);

        let fit =
            fit_spacetime_variogram(&emp, exponential_config(SpaceTimeVariogramType::ProductSum))
                .unwrap();

        let SpaceTimeVariogram::ProductSum { k1, k2, k3, .. } = fit.model else {
            panic!("expected ProductSum")
        };
        assert!(k1 >= 0.0 && k2 >= 0.0 && k3 >= 0.0);
        assert!(k1 + k2 + k3 > 0.0);
        assert!(fit.residuals.is_finite());
    }

    #[test]
    fn product_sum_fit_is_not_worse_than_separable_fit() {
        let truth = SpaceTimeVariogram::new_separable(
            VariogramModel::new(0.05, 1.0, 15.0, VariogramType::Exponential).unwrap(),
            VariogramModel::new(0.05, 1.0, 2.5, VariogramType::Exponential).unwrap(),
        )
        .unwrap();
        let ds = synthetic_dataset(truth, 10, 10);
        let emp = empirical_with_bins(&ds, 6);

        let sep =
            fit_spacetime_variogram(&emp, exponential_config(SpaceTimeVariogramType::Separable))
                .unwrap();
        let ps =
            fit_spacetime_variogram(&emp, exponential_config(SpaceTimeVariogramType::ProductSum))
                .unwrap();

        assert!(
            ps.residuals <= sep.residuals * 1.0001,
            "product-sum residuals {} should not exceed separable residuals {}",
            ps.residuals,
            sep.residuals
        );
    }

    #[test]
    fn fitting_rejects_empty_marginal() {
        // Every bin has zero pairs.
        let emp = EmpiricalSpaceTimeVariogram {
            n_spatial_bins: 3,
            n_temporal_bins: 3,
            spatial_lags: vec![0.0; 9],
            temporal_lags: vec![0.0; 9],
            semivariances: vec![0.0; 9],
            n_pairs: vec![0; 9],
        };
        let outcome =
            fit_spacetime_variogram(&emp, exponential_config(SpaceTimeVariogramType::Separable));
        assert!(outcome.is_err());
    }
}