1use std::sync::Arc;
18
19use nalgebra::{DMatrix, DVector, Dyn, linalg::LU};
20#[cfg(not(target_arch = "wasm32"))]
21use rayon::prelude::*;
22
23use crate::Real;
24use crate::distance::{GeoCoord, PreparedGeoCoord, haversine_distance_prepared, prepare_geo_coord};
25use crate::error::KrigingError;
26use crate::geo_dataset::GeoDataset;
27use crate::kriging::ordinary::{Prediction, kriging_diagonal_jitter};
28use crate::variogram::models::VariogramModel;
29
30#[derive(Debug)]
32pub struct SimpleKrigingModel {
33 coords: Vec<GeoCoord>,
34 prepared_coords: Vec<PreparedGeoCoord>,
35 residuals: Vec<Real>,
36 mean: Real,
37 variogram: VariogramModel,
38 cov_at_zero: Real,
39 system: DMatrix<Real>,
40 system_lu: Arc<LU<Real, Dyn, Dyn>>,
42}
43
44impl Clone for SimpleKrigingModel {
45 fn clone(&self) -> Self {
46 Self {
47 coords: self.coords.clone(),
48 prepared_coords: self.prepared_coords.clone(),
49 residuals: self.residuals.clone(),
50 mean: self.mean,
51 variogram: self.variogram,
52 cov_at_zero: self.cov_at_zero,
53 system: self.system.clone(),
54 system_lu: Arc::clone(&self.system_lu),
55 }
56 }
57}
58
59impl SimpleKrigingModel {
60 pub fn new(
62 dataset: GeoDataset,
63 variogram: VariogramModel,
64 mean: Real,
65 ) -> Result<Self, KrigingError> {
66 let (coords, values) = dataset.into_parts();
67 let prepared_coords = coords
68 .iter()
69 .copied()
70 .map(prepare_geo_coord)
71 .collect::<Vec<_>>();
72 let residuals: Vec<Real> = values.iter().map(|v| *v - mean).collect();
73
74 let system = build_simple_system(&prepared_coords, variogram);
75 let system_lu = Arc::new(system.clone().lu());
76 let probe = DVector::from_element(coords.len(), 1.0);
78 if system_lu.solve(&probe).is_none() {
79 return Err(KrigingError::MatrixError(
80 "could not factorize simple kriging system".to_string(),
81 ));
82 }
83 Ok(Self {
84 coords,
85 prepared_coords,
86 residuals,
87 mean,
88 variogram,
89 cov_at_zero: variogram.covariance(0.0),
90 system,
91 system_lu,
92 })
93 }
94
95 pub fn mean(&self) -> Real {
97 self.mean
98 }
99
100 pub fn predict(&self, coord: GeoCoord) -> Result<Prediction, KrigingError> {
102 let mut rhs = DVector::from_element(self.coords.len(), 0.0);
103 self.predict_with_rhs(coord, &mut rhs)
104 }
105
106 pub fn predict_batch(&self, coords: &[GeoCoord]) -> Result<Vec<Prediction>, KrigingError> {
108 #[cfg(not(target_arch = "wasm32"))]
109 {
110 let n = self.coords.len();
111 coords
112 .par_iter()
113 .map_init(
114 || DVector::<Real>::from_element(n, 0.0),
115 |rhs, c| self.predict_with_rhs(*c, rhs),
116 )
117 .collect()
118 }
119 #[cfg(target_arch = "wasm32")]
120 {
121 let mut rhs = DVector::from_element(self.coords.len(), 0.0);
122 let mut out = Vec::with_capacity(coords.len());
123 for &c in coords {
124 out.push(self.predict_with_rhs(c, &mut rhs)?);
125 }
126 Ok(out)
127 }
128 }
129
130 fn predict_with_rhs(
131 &self,
132 coord: GeoCoord,
133 rhs: &mut DVector<Real>,
134 ) -> Result<Prediction, KrigingError> {
135 let n = self.coords.len();
136 let prepared = prepare_geo_coord(coord);
137 for i in 0..n {
138 rhs[i] = self.variogram.covariance(haversine_distance_prepared(
139 self.prepared_coords[i],
140 prepared,
141 ));
142 }
143 let w = self.system_lu.solve(rhs).ok_or_else(|| {
144 KrigingError::MatrixError("could not solve simple kriging system".to_string())
145 })?;
146 let mut residual_pred: Real = 0.0;
147 let mut cov_dot: Real = 0.0;
148 for i in 0..n {
149 residual_pred += w[i] * self.residuals[i];
150 cov_dot += w[i] * rhs[i];
151 }
152 let variance = (self.cov_at_zero - cov_dot).max(0.0);
153 Ok(Prediction {
154 value: self.mean + residual_pred,
155 variance,
156 })
157 }
158}
159
160fn build_simple_system(coords: &[PreparedGeoCoord], variogram: VariogramModel) -> DMatrix<Real> {
161 let n = coords.len();
162 let diag_eps = kriging_diagonal_jitter(n, variogram);
163 let mut m = DMatrix::from_element(n, n, 0.0);
164 for i in 0..n {
165 for j in i..n {
166 let mut cov = variogram.covariance(haversine_distance_prepared(coords[i], coords[j]));
167 if i == j {
168 cov += diag_eps;
169 }
170 m[(i, j)] = cov;
171 m[(j, i)] = cov;
172 }
173 }
174 m
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::variogram::models::VariogramType;

    /// Predicting exactly at a training location should reproduce its observation.
    #[test]
    fn recovers_training_value_at_collocated_point() {
        let stations = vec![
            GeoCoord::try_new(0.0, 0.0).unwrap(),
            GeoCoord::try_new(0.0, 1.0).unwrap(),
            GeoCoord::try_new(1.0, 0.0).unwrap(),
        ];
        let observed = vec![10.0, 20.0, 15.0];
        let expected = observed[0];
        let vgm = VariogramModel::new(0.01, 5.0, 300.0, VariogramType::Exponential).unwrap();
        let data = GeoDataset::new(stations.clone(), observed).unwrap();
        let sk = SimpleKrigingModel::new(data, vgm, 15.0).expect("model");

        let at_station = sk.predict(stations[0]).expect("prediction");

        assert!((at_station.value - expected).abs() < 1e-3);
        assert!(at_station.variance >= 0.0);
    }

    /// Far outside the variogram range every weight vanishes, leaving the known mean.
    #[test]
    fn reverts_to_mean_far_from_any_station() {
        let mean = 20.0;
        let stations = vec![
            GeoCoord::try_new(0.0, 0.0).unwrap(),
            GeoCoord::try_new(0.0, 0.1).unwrap(),
            GeoCoord::try_new(0.1, 0.0).unwrap(),
        ];
        let vgm = VariogramModel::new(0.01, 1.0, 5.0, VariogramType::Exponential).unwrap();
        let data = GeoDataset::new(stations, vec![10.0, 12.0, 14.0]).unwrap();
        let sk = SimpleKrigingModel::new(data, vgm, mean).expect("model");

        let remote = GeoCoord::try_new(50.0, 50.0).unwrap();
        let pred = sk.predict(remote).expect("prediction");

        assert!((pred.value - mean).abs() < 1e-3, "got {}", pred.value);
    }

    /// The batch path must agree with point-by-point prediction.
    #[test]
    fn batch_matches_single_predictions() {
        let stations = vec![
            GeoCoord::try_new(0.0, 0.0).unwrap(),
            GeoCoord::try_new(0.0, 1.0).unwrap(),
            GeoCoord::try_new(1.0, 0.0).unwrap(),
            GeoCoord::try_new(1.0, 1.0).unwrap(),
        ];
        let vgm = VariogramModel::new(0.01, 5.0, 300.0, VariogramType::Exponential).unwrap();
        let data = GeoDataset::new(stations, vec![10.0, 12.0, 14.0, 16.0]).unwrap();
        let sk = SimpleKrigingModel::new(data, vgm, 13.0).expect("model");
        let queries = [
            GeoCoord::try_new(0.2, 0.3).unwrap(),
            GeoCoord::try_new(0.7, 0.4).unwrap(),
        ];

        let batch = sk.predict_batch(&queries).expect("batch");

        for (idx, &query) in queries.iter().enumerate() {
            let from_batch = &batch[idx];
            let single = sk.predict(query).expect("single");
            assert!((from_batch.value - single.value).abs() < 1e-5);
            assert!((from_batch.variance - single.variance).abs() < 1e-5);
        }
    }
}