Skip to main content

kriging_rs/kriging/
simple.rs

1//! Simple kriging: interpolation with a known, constant mean.
2//!
3//! Unlike ordinary kriging (which treats the mean as unknown and adds a Lagrangian
4//! constraint that weights sum to one), simple kriging assumes the global mean `m` is known.
5//! The predictor is
6//!
7//! ```text
8//!   Z*(x0) = m + Σ_i w_i [Z(x_i) − m]
9//! ```
10//!
//! where the weights solve the plain covariance system `C · w = c0` (unlike ordinary kriging,
//! there is no Lagrange-multiplier border row/column).
12//! The kriging variance is `σ²_K(x0) = C(0) − wᵀ c0`.
13//!
//! Use simple kriging when you have an independently estimated mean (e.g. from a calibration
//! dataset) and want the slightly lower kriging variance it gives compared with ordinary kriging.
16
17use std::sync::Arc;
18
19use nalgebra::{DMatrix, DVector, Dyn, linalg::LU};
20#[cfg(not(target_arch = "wasm32"))]
21use rayon::prelude::*;
22
23use crate::Real;
24use crate::distance::{GeoCoord, PreparedGeoCoord, haversine_distance_prepared, prepare_geo_coord};
25use crate::error::KrigingError;
26use crate::geo_dataset::GeoDataset;
27use crate::kriging::ordinary::{Prediction, kriging_diagonal_jitter};
28use crate::variogram::models::VariogramModel;
29
30/// Fitted simple kriging model.
31#[derive(Debug)]
32pub struct SimpleKrigingModel {
33    coords: Vec<GeoCoord>,
34    prepared_coords: Vec<PreparedGeoCoord>,
35    residuals: Vec<Real>,
36    mean: Real,
37    variogram: VariogramModel,
38    cov_at_zero: Real,
39    system: DMatrix<Real>,
40    /// Shared LU factorization; `Clone` just bumps the `Arc`.
41    system_lu: Arc<LU<Real, Dyn, Dyn>>,
42}
43
44impl Clone for SimpleKrigingModel {
45    fn clone(&self) -> Self {
46        Self {
47            coords: self.coords.clone(),
48            prepared_coords: self.prepared_coords.clone(),
49            residuals: self.residuals.clone(),
50            mean: self.mean,
51            variogram: self.variogram,
52            cov_at_zero: self.cov_at_zero,
53            system: self.system.clone(),
54            system_lu: Arc::clone(&self.system_lu),
55        }
56    }
57}
58
59impl SimpleKrigingModel {
60    /// Build a simple kriging model using a known `mean`.
61    pub fn new(
62        dataset: GeoDataset,
63        variogram: VariogramModel,
64        mean: Real,
65    ) -> Result<Self, KrigingError> {
66        let (coords, values) = dataset.into_parts();
67        let prepared_coords = coords
68            .iter()
69            .copied()
70            .map(prepare_geo_coord)
71            .collect::<Vec<_>>();
72        let residuals: Vec<Real> = values.iter().map(|v| *v - mean).collect();
73
74        let system = build_simple_system(&prepared_coords, variogram);
75        let system_lu = Arc::new(system.clone().lu());
76        // Probe solvability up front.
77        let probe = DVector::from_element(coords.len(), 1.0);
78        if system_lu.solve(&probe).is_none() {
79            return Err(KrigingError::MatrixError(
80                "could not factorize simple kriging system".to_string(),
81            ));
82        }
83        Ok(Self {
84            coords,
85            prepared_coords,
86            residuals,
87            mean,
88            variogram,
89            cov_at_zero: variogram.covariance(0.0),
90            system,
91            system_lu,
92        })
93    }
94
95    /// The known mean used by the model.
96    pub fn mean(&self) -> Real {
97        self.mean
98    }
99
100    /// Predict at a single target.
101    pub fn predict(&self, coord: GeoCoord) -> Result<Prediction, KrigingError> {
102        let mut rhs = DVector::from_element(self.coords.len(), 0.0);
103        self.predict_with_rhs(coord, &mut rhs)
104    }
105
106    /// Batch predictions; parallel on native builds.
107    pub fn predict_batch(&self, coords: &[GeoCoord]) -> Result<Vec<Prediction>, KrigingError> {
108        #[cfg(not(target_arch = "wasm32"))]
109        {
110            let n = self.coords.len();
111            coords
112                .par_iter()
113                .map_init(
114                    || DVector::<Real>::from_element(n, 0.0),
115                    |rhs, c| self.predict_with_rhs(*c, rhs),
116                )
117                .collect()
118        }
119        #[cfg(target_arch = "wasm32")]
120        {
121            let mut rhs = DVector::from_element(self.coords.len(), 0.0);
122            let mut out = Vec::with_capacity(coords.len());
123            for &c in coords {
124                out.push(self.predict_with_rhs(c, &mut rhs)?);
125            }
126            Ok(out)
127        }
128    }
129
130    fn predict_with_rhs(
131        &self,
132        coord: GeoCoord,
133        rhs: &mut DVector<Real>,
134    ) -> Result<Prediction, KrigingError> {
135        let n = self.coords.len();
136        let prepared = prepare_geo_coord(coord);
137        for i in 0..n {
138            rhs[i] = self.variogram.covariance(haversine_distance_prepared(
139                self.prepared_coords[i],
140                prepared,
141            ));
142        }
143        let w = self.system_lu.solve(rhs).ok_or_else(|| {
144            KrigingError::MatrixError("could not solve simple kriging system".to_string())
145        })?;
146        let mut residual_pred: Real = 0.0;
147        let mut cov_dot: Real = 0.0;
148        for i in 0..n {
149            residual_pred += w[i] * self.residuals[i];
150            cov_dot += w[i] * rhs[i];
151        }
152        let variance = (self.cov_at_zero - cov_dot).max(0.0);
153        Ok(Prediction {
154            value: self.mean + residual_pred,
155            variance,
156        })
157    }
158}
159
160fn build_simple_system(coords: &[PreparedGeoCoord], variogram: VariogramModel) -> DMatrix<Real> {
161    let n = coords.len();
162    let diag_eps = kriging_diagonal_jitter(n, variogram);
163    let mut m = DMatrix::from_element(n, n, 0.0);
164    for i in 0..n {
165        for j in i..n {
166            let mut cov = variogram.covariance(haversine_distance_prepared(coords[i], coords[j]));
167            if i == j {
168                cov += diag_eps;
169            }
170            m[(i, j)] = cov;
171            m[(j, i)] = cov;
172        }
173    }
174    m
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::variogram::models::VariogramType;

    /// Build a `Vec<GeoCoord>` from `(a, b)` pairs. Each pair is passed straight to
    /// `GeoCoord::try_new` and the result is unwrapped, since test inputs are known-valid.
    macro_rules! geo_points {
        ($(($a:expr, $b:expr)),+ $(,)?) => {
            vec![$(GeoCoord::try_new($a, $b).unwrap()),+]
        };
    }

    #[test]
    fn recovers_training_value_at_collocated_point() {
        let stations = geo_points![(0.0, 0.0), (0.0, 1.0), (1.0, 0.0)];
        let readings = vec![10.0, 20.0, 15.0];
        let vgm = VariogramModel::new(0.01, 5.0, 300.0, VariogramType::Exponential).unwrap();
        let dataset = GeoDataset::new(stations.clone(), readings).unwrap();
        let model = SimpleKrigingModel::new(dataset, vgm, 15.0).expect("model");

        // Predicting exactly at a training site should reproduce its observed value.
        let at_station = model.predict(stations[0]).expect("prediction");
        assert!((at_station.value - 10.0).abs() < 1e-3);
        assert!(at_station.variance >= 0.0);
    }

    #[test]
    fn reverts_to_mean_far_from_any_station() {
        let stations = geo_points![(0.0, 0.0), (0.0, 0.1), (0.1, 0.0)];
        let readings = vec![10.0, 12.0, 14.0];
        let known_mean = 20.0;
        let vgm = VariogramModel::new(0.01, 1.0, 5.0, VariogramType::Exponential).unwrap();
        let dataset = GeoDataset::new(stations, readings).unwrap();
        let model = SimpleKrigingModel::new(dataset, vgm, known_mean).expect("model");

        // A target many range units away has near-zero covariance with every station, so
        // all weights vanish and the estimate collapses onto the known mean.
        let far_away = GeoCoord::try_new(50.0, 50.0).unwrap();
        let estimate = model.predict(far_away).expect("prediction");
        assert!(
            (estimate.value - known_mean).abs() < 1e-3,
            "got {}",
            estimate.value
        );
    }

    #[test]
    fn batch_matches_single_predictions() {
        let stations = geo_points![(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)];
        let readings = vec![10.0, 12.0, 14.0, 16.0];
        let vgm = VariogramModel::new(0.01, 5.0, 300.0, VariogramType::Exponential).unwrap();
        let dataset = GeoDataset::new(stations, readings).unwrap();
        let model = SimpleKrigingModel::new(dataset, vgm, 13.0).expect("model");

        let queries = geo_points![(0.2, 0.3), (0.7, 0.4)];
        let batch = model.predict_batch(&queries).expect("batch");
        for (idx, &query) in queries.iter().enumerate() {
            let single = model.predict(query).expect("single");
            let batched = &batch[idx];
            assert!((batched.value - single.value).abs() < 1e-5);
            assert!((batched.variance - single.variance).abs() < 1e-5);
        }
    }
}