1use crate::error::{AnalyticsError, Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8
/// Which kriging system the interpolator assembles.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KrigingType {
    /// Ordinary kriging: a single unbiasedness constraint (weights sum to one),
    /// enforced through one Lagrange multiplier.
    Ordinary,
    /// Universal kriging: unbiasedness constraints for the drift functions
    /// `1, x, y, x·y` built from the first two coordinates.
    Universal,
}
17
/// Theoretical variogram model families supported by [`Variogram::evaluate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariogramModel {
    /// Rises as `1.5·(h/a) − 0.5·(h/a)³` and reaches the sill exactly at the range `a`.
    Spherical,
    /// Approaches the sill asymptotically as `1 − exp(−h/a)`.
    Exponential,
    /// Approaches the sill asymptotically as `1 − exp(−h²/a²)`.
    Gaussian,
    /// Rises linearly from the nugget and reaches the sill at the range `a`, then stays flat.
    Linear,
}

/// Parameters of a theoretical variogram `γ(h)`.
///
/// `sill` is the total sill (the nugget is included), so the structured part of the
/// variance is `sill - nugget`.
#[derive(Debug, Clone, Copy)]
pub struct Variogram {
    /// Value the variogram jumps to immediately above zero lag.
    pub nugget: f64,
    /// Plateau value of the variogram (nugget included).
    pub sill: f64,
    /// Distance scale of the model; spherical and linear models reach the sill here.
    pub range: f64,
    /// Model family.
    pub model: VariogramModel,
}

impl Variogram {
    /// Creates a variogram from its model family and parameters.
    ///
    /// No parameters are validated here.
    pub fn new(model: VariogramModel, nugget: f64, sill: f64, range: f64) -> Self {
        Self {
            nugget,
            sill,
            range,
            model,
        }
    }

    /// Evaluates the semivariance `γ(h)` at lag distance `h`.
    ///
    /// Returns `0.0` for `h < f64::EPSILON` (including negative lags). For positive lags the
    /// value starts at `nugget` and rises towards `sill`. Spherical and linear models reach
    /// `sill` exactly at `range`. A zero `range` yields `sill` for every positive lag.
    pub fn evaluate(&self, h: f64) -> f64 {
        // γ(0) = 0 by definition; the nugget is a discontinuity just above zero lag.
        if h < f64::EPSILON {
            return 0.0;
        }

        let partial_sill = self.sill - self.nugget;

        match self.model {
            VariogramModel::Spherical => {
                if h >= self.range {
                    self.sill
                } else {
                    let h_r = h / self.range;
                    self.nugget + partial_sill * (1.5 * h_r - 0.5 * h_r.powi(3))
                }
            }
            VariogramModel::Exponential => {
                self.nugget + partial_sill * (1.0 - (-h / self.range).exp())
            }
            VariogramModel::Gaussian => {
                self.nugget + partial_sill * (1.0 - (-(h * h) / (self.range * self.range)).exp())
            }
            VariogramModel::Linear => {
                // Only the structured part (sill - nugget) grows with h, so the model
                // ends exactly at the sill at `range`. Checking `h >= range` first also
                // avoids a 0/0 NaN when `range == 0`.
                if h >= self.range {
                    self.sill
                } else {
                    self.nugget + partial_sill * (h / self.range)
                }
            }
        }
    }
}
85
86#[derive(Debug, Clone)]
88pub struct KrigingResult {
89 pub values: Array1<f64>,
91 pub variances: Array1<f64>,
93 pub coordinates: Array2<f64>,
95}
96
97pub struct KrigingInterpolator {
99 kriging_type: KrigingType,
100 variogram: Variogram,
101}
102
103impl KrigingInterpolator {
104 pub fn new(kriging_type: KrigingType, variogram: Variogram) -> Self {
110 Self {
111 kriging_type,
112 variogram,
113 }
114 }
115
116 pub fn interpolate(
126 &self,
127 points: &Array2<f64>,
128 values: &ArrayView1<f64>,
129 targets: &Array2<f64>,
130 ) -> Result<KrigingResult> {
131 let n_points = points.nrows();
132 let n_targets = targets.nrows();
133
134 if values.len() != n_points {
135 return Err(AnalyticsError::dimension_mismatch(
136 format!("{}", n_points),
137 format!("{}", values.len()),
138 ));
139 }
140
141 if targets.ncols() != points.ncols() {
142 return Err(AnalyticsError::dimension_mismatch(
143 format!("{}", points.ncols()),
144 format!("{}", targets.ncols()),
145 ));
146 }
147
148 let cov_matrix = self.build_covariance_matrix(points)?;
150
151 let weights_matrix = self.solve_kriging_system(&cov_matrix)?;
153
154 let mut interpolated = Array1::zeros(n_targets);
155 let mut variances = Array1::zeros(n_targets);
156
157 for i in 0..n_targets {
158 let target = targets.row(i);
159 let (value, variance) =
160 self.interpolate_point(&target, points, values, &weights_matrix)?;
161 interpolated[i] = value;
162 variances[i] = variance;
163 }
164
165 Ok(KrigingResult {
166 values: interpolated,
167 variances,
168 coordinates: targets.clone(),
169 })
170 }
171
172 fn build_covariance_matrix(&self, points: &Array2<f64>) -> Result<Array2<f64>> {
174 let n = points.nrows();
175 let size = match self.kriging_type {
176 KrigingType::Ordinary => n + 1, KrigingType::Universal => n + 4, };
179
180 let mut matrix = Array2::zeros((size, size));
181
182 for i in 0..n {
184 for j in 0..n {
185 let dist = self.calculate_distance(&points.row(i), &points.row(j))?;
186 let covariance = self.variogram.sill - self.variogram.evaluate(dist);
187 matrix[[i, j]] = covariance;
188 }
189 }
190
191 match self.kriging_type {
193 KrigingType::Ordinary => {
194 for i in 0..n {
196 matrix[[i, n]] = 1.0;
197 matrix[[n, i]] = 1.0;
198 }
199 }
200 KrigingType::Universal => {
201 for i in 0..n {
203 let x = points[[i, 0]];
204 let y = points[[i, 1]];
205 matrix[[i, n]] = 1.0;
206 matrix[[n, i]] = 1.0;
207 matrix[[i, n + 1]] = x;
208 matrix[[n + 1, i]] = x;
209 matrix[[i, n + 2]] = y;
210 matrix[[n + 2, i]] = y;
211 matrix[[i, n + 3]] = x * y;
212 matrix[[n + 3, i]] = x * y;
213 }
214 }
215 }
216
217 Ok(matrix)
218 }
219
220 fn solve_kriging_system(&self, cov_matrix: &Array2<f64>) -> Result<Array2<f64>> {
222 self.matrix_inverse(cov_matrix)
225 }
226
227 fn matrix_inverse(&self, matrix: &Array2<f64>) -> Result<Array2<f64>> {
229 let n = matrix.nrows();
230 if n != matrix.ncols() {
231 return Err(AnalyticsError::matrix_error("Matrix must be square"));
232 }
233
234 let mut aug = Array2::zeros((n, 2 * n));
236 for i in 0..n {
237 for j in 0..n {
238 aug[[i, j]] = matrix[[i, j]];
239 }
240 aug[[i, n + i]] = 1.0;
241 }
242
243 for i in 0..n {
245 let mut max_row = i;
247 let mut max_val = aug[[i, i]].abs();
248 for k in (i + 1)..n {
249 if aug[[k, i]].abs() > max_val {
250 max_val = aug[[k, i]].abs();
251 max_row = k;
252 }
253 }
254
255 if max_val < f64::EPSILON {
256 return Err(AnalyticsError::matrix_error("Matrix is singular"));
257 }
258
259 if max_row != i {
261 for j in 0..(2 * n) {
262 let tmp = aug[[i, j]];
263 aug[[i, j]] = aug[[max_row, j]];
264 aug[[max_row, j]] = tmp;
265 }
266 }
267
268 let pivot = aug[[i, i]];
270 for j in 0..(2 * n) {
271 aug[[i, j]] /= pivot;
272 }
273
274 for k in 0..n {
275 if k != i {
276 let factor = aug[[k, i]];
277 for j in 0..(2 * n) {
278 aug[[k, j]] -= factor * aug[[i, j]];
279 }
280 }
281 }
282 }
283
284 let mut inverse = Array2::zeros((n, n));
286 for i in 0..n {
287 for j in 0..n {
288 inverse[[i, j]] = aug[[i, n + j]];
289 }
290 }
291
292 Ok(inverse)
293 }
294
295 fn interpolate_point(
297 &self,
298 target: &scirs2_core::ndarray::ArrayView1<f64>,
299 points: &Array2<f64>,
300 values: &ArrayView1<f64>,
301 weights_matrix: &Array2<f64>,
302 ) -> Result<(f64, f64)> {
303 let n = points.nrows();
304
305 let rhs_size = match self.kriging_type {
307 KrigingType::Ordinary => n + 1,
308 KrigingType::Universal => n + 4,
309 };
310
311 let mut rhs = Array1::zeros(rhs_size);
312
313 for i in 0..n {
315 let dist = self.calculate_distance(&points.row(i), target)?;
316 rhs[i] = self.variogram.sill - self.variogram.evaluate(dist);
317 }
318
319 match self.kriging_type {
321 KrigingType::Ordinary => {
322 rhs[n] = 1.0;
323 }
324 KrigingType::Universal => {
325 rhs[n] = 1.0;
326 rhs[n + 1] = target[0];
327 rhs[n + 2] = target[1];
328 rhs[n + 3] = target[0] * target[1];
329 }
330 }
331
332 let mut weights: Array1<f64> = Array1::zeros(rhs_size);
334 for i in 0..rhs_size {
335 for j in 0..rhs_size {
336 weights[i] += weights_matrix[[i, j]] * rhs[j];
337 }
338 }
339
340 let mut value: f64 = 0.0;
342 for i in 0..n {
343 value += weights[i] * values[i];
344 }
345
346 let mut variance = self.variogram.sill;
348 for i in 0..rhs_size {
349 variance -= weights[i] * rhs[i];
350 }
351
352 Ok((value, variance.max(0.0)))
353 }
354
355 fn calculate_distance(
357 &self,
358 p1: &scirs2_core::ndarray::ArrayView1<f64>,
359 p2: &scirs2_core::ndarray::ArrayView1<f64>,
360 ) -> Result<f64> {
361 if p1.len() != p2.len() {
362 return Err(AnalyticsError::dimension_mismatch(
363 format!("{}", p1.len()),
364 format!("{}", p2.len()),
365 ));
366 }
367
368 let dist_sq: f64 = p1.iter().zip(p2.iter()).map(|(a, b)| (a - b).powi(2)).sum();
369 Ok(dist_sq.sqrt())
370 }
371}
372
/// Stateless calculator for empirical (binned) semivariograms.
#[derive(Debug, Clone, Copy, Default)]
pub struct SemivariogramCalculator;
375
376impl SemivariogramCalculator {
377 pub fn calculate(
387 points: &Array2<f64>,
388 values: &ArrayView1<f64>,
389 n_bins: usize,
390 ) -> Result<(Array1<f64>, Array1<f64>)> {
391 let n = points.nrows();
392 if values.len() != n {
393 return Err(AnalyticsError::dimension_mismatch(
394 format!("{}", n),
395 format!("{}", values.len()),
396 ));
397 }
398
399 let mut pairs = Vec::new();
401 for i in 0..n {
402 for j in (i + 1)..n {
403 let mut dist_sq = 0.0;
404 for k in 0..points.ncols() {
405 let diff = points[[i, k]] - points[[j, k]];
406 dist_sq += diff * diff;
407 }
408 let dist = dist_sq.sqrt();
409 let semivar = 0.5 * (values[i] - values[j]).powi(2);
410 pairs.push((dist, semivar));
411 }
412 }
413
414 if pairs.is_empty() {
415 return Err(AnalyticsError::insufficient_data("Need at least 2 points"));
416 }
417
418 let max_dist = pairs
420 .iter()
421 .map(|(d, _)| *d)
422 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
423 .ok_or_else(|| AnalyticsError::insufficient_data("No valid distances"))?;
424
425 let bin_width = max_dist / (n_bins as f64);
426
427 let mut bin_sums = vec![0.0; n_bins];
429 let mut bin_counts = vec![0usize; n_bins];
430
431 for (dist, semivar) in pairs {
432 let bin = ((dist / bin_width).floor() as usize).min(n_bins - 1);
433 bin_sums[bin] += semivar;
434 bin_counts[bin] += 1;
435 }
436
437 let mut distances = Vec::new();
439 let mut semivariances = Vec::new();
440
441 for i in 0..n_bins {
442 if bin_counts[i] > 0 {
443 distances.push((i as f64 + 0.5) * bin_width);
444 semivariances.push(bin_sums[i] / (bin_counts[i] as f64));
445 }
446 }
447
448 Ok((Array1::from_vec(distances), Array1::from_vec(semivariances)))
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Spherical model: zero at the origin, equal to the sill at and beyond the range.
    #[test]
    fn test_variogram_spherical() {
        let spherical = Variogram::new(VariogramModel::Spherical, 0.1, 1.0, 10.0);
        let expectations = [(0.0, 0.0), (10.0, 1.0), (20.0, 1.0)];
        for &(lag, expected) in expectations.iter() {
            assert_abs_diff_eq!(spherical.evaluate(lag), expected, epsilon = 1e-10);
        }
    }

    /// Ordinary kriging at the centre of a unit square lands between neighbouring samples.
    #[test]
    fn test_kriging_simple() {
        let corners = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let observed = array![1.0, 2.0, 3.0, 4.0];
        let centre = array![[0.5, 0.5]];
        let model = Variogram::new(VariogramModel::Spherical, 0.0, 1.0, 2.0);

        let result = KrigingInterpolator::new(KrigingType::Ordinary, model)
            .interpolate(&corners, &observed.view(), &centre)
            .expect("Kriging interpolation should succeed for valid data");

        assert_eq!(result.values.len(), 1);
        assert_eq!(result.variances.len(), 1);
        let estimate = result.values[0];
        assert!(estimate > 2.0 && estimate < 3.0);
    }

    /// The empirical semivariogram yields non-empty, equally sized output arrays.
    #[test]
    fn test_semivariogram_calculation() {
        let locations = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let samples = array![1.0, 2.0, 3.0];

        let (lags, gammas) = SemivariogramCalculator::calculate(&locations, &samples.view(), 3)
            .expect("Semivariogram calculation should succeed");

        assert!(!lags.is_empty());
        assert_eq!(lags.len(), gammas.len());
    }
}