Skip to main content

oxigdal_analytics/interpolation/
kriging.rs

1//! Kriging Interpolation
2//!
3//! Kriging is a geostatistical interpolation method that uses variogram models
4//! to provide Best Linear Unbiased Predictions (BLUP).
5
6use crate::error::{AnalyticsError, Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8
/// Flavour of kriging to perform.
///
/// The variant selects which unbiasedness (drift) constraints are appended to
/// the kriging system.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KrigingType {
    /// Ordinary kriging: the mean is unknown but assumed constant, so the
    /// kriging weights are constrained to sum to one.
    Ordinary,
    /// Universal kriging: the mean follows a trend surface built from a
    /// constant, `x`, `y` and `x * y` (the first two coordinates).
    Universal,
}
17
/// Variogram models
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariogramModel {
    /// Spherical variogram: rises as `1.5 (h/a) - 0.5 (h/a)^3` and reaches the
    /// sill exactly at the range `a`.
    Spherical,
    /// Exponential variogram: `1 - exp(-h/a)`; approaches the sill asymptotically.
    Exponential,
    /// Gaussian variogram: `1 - exp(-(h/a)^2)`; approaches the sill asymptotically.
    Gaussian,
    /// Linear variogram: rises linearly from the nugget and is capped at the
    /// sill once `h >= a`.
    Linear,
}

/// Parameters of a variogram model.
#[derive(Debug, Clone, Copy)]
pub struct Variogram {
    /// Nugget effect: semivariance just above zero lag.
    pub nugget: f64,
    /// Sill: total variance, the plateau the model reaches or approaches.
    pub sill: f64,
    /// Range parameter: distance scale of the model (expected to be positive).
    pub range: f64,
    /// Functional form of the model.
    pub model: VariogramModel,
}

impl Variogram {
    /// Create a new variogram.
    ///
    /// # Arguments
    /// * `model` - Functional form of the variogram
    /// * `nugget` - Semivariance just above zero lag
    /// * `sill` - Total variance (plateau value)
    /// * `range` - Distance scale of the model; expected to be positive
    pub fn new(model: VariogramModel, nugget: f64, sill: f64, range: f64) -> Self {
        Self {
            nugget,
            sill,
            range,
            model,
        }
    }

    /// Evaluate the semivariance γ(h) at lag distance `h`.
    ///
    /// Returns `0.0` at (numerically) zero lag, so the nugget shows up as a
    /// discontinuity at the origin. For `h > 0` the value starts at `nugget`
    /// and rises towards `sill`. A non-positive `range` is treated as a pure
    /// nugget effect (`sill` at every non-zero lag) instead of dividing by zero.
    pub fn evaluate(&self, h: f64) -> f64 {
        if h < f64::EPSILON {
            return 0.0;
        }

        // Every model divides by `range`; a zero range made `Linear` return
        // NaN (inf * 0) and a negative one produces meaningless values.
        if self.range <= 0.0 {
            return self.sill;
        }

        let partial_sill = self.sill - self.nugget;

        match self.model {
            VariogramModel::Spherical => {
                if h >= self.range {
                    self.sill
                } else {
                    let h_r = h / self.range;
                    self.nugget + partial_sill * (1.5 * h_r - 0.5 * h_r.powi(3))
                }
            }
            VariogramModel::Exponential => {
                self.nugget + partial_sill * (1.0 - (-h / self.range).exp())
            }
            VariogramModel::Gaussian => {
                self.nugget + partial_sill * (1.0 - (-(h * h) / (self.range * self.range)).exp())
            }
            VariogramModel::Linear => {
                // The slope uses the partial sill so the curve climbs from the
                // nugget to exactly `sill` at `range` (using the full sill
                // overshot to `nugget + sill` whenever the nugget was non-zero).
                if h >= self.range {
                    self.sill
                } else {
                    self.nugget + partial_sill / self.range * h
                }
            }
        }
    }
}
85
86/// Kriging result
87#[derive(Debug, Clone)]
88pub struct KrigingResult {
89    /// Interpolated values
90    pub values: Array1<f64>,
91    /// Prediction variances
92    pub variances: Array1<f64>,
93    /// Target coordinates
94    pub coordinates: Array2<f64>,
95}
96
97/// Kriging interpolator
98pub struct KrigingInterpolator {
99    kriging_type: KrigingType,
100    variogram: Variogram,
101}
102
103impl KrigingInterpolator {
104    /// Create a new Kriging interpolator
105    ///
106    /// # Arguments
107    /// * `kriging_type` - Type of kriging
108    /// * `variogram` - Variogram model
109    pub fn new(kriging_type: KrigingType, variogram: Variogram) -> Self {
110        Self {
111            kriging_type,
112            variogram,
113        }
114    }
115
116    /// Interpolate values at target locations
117    ///
118    /// # Arguments
119    /// * `points` - Known point coordinates (n_points × n_dim)
120    /// * `values` - Known values (n_points)
121    /// * `targets` - Target coordinates (n_targets × n_dim)
122    ///
123    /// # Errors
124    /// Returns error if interpolation fails
125    pub fn interpolate(
126        &self,
127        points: &Array2<f64>,
128        values: &ArrayView1<f64>,
129        targets: &Array2<f64>,
130    ) -> Result<KrigingResult> {
131        let n_points = points.nrows();
132        let n_targets = targets.nrows();
133
134        if values.len() != n_points {
135            return Err(AnalyticsError::dimension_mismatch(
136                format!("{}", n_points),
137                format!("{}", values.len()),
138            ));
139        }
140
141        if targets.ncols() != points.ncols() {
142            return Err(AnalyticsError::dimension_mismatch(
143                format!("{}", points.ncols()),
144                format!("{}", targets.ncols()),
145            ));
146        }
147
148        // Build covariance matrix
149        let cov_matrix = self.build_covariance_matrix(points)?;
150
151        // Solve kriging system once for efficiency
152        let weights_matrix = self.solve_kriging_system(&cov_matrix)?;
153
154        let mut interpolated = Array1::zeros(n_targets);
155        let mut variances = Array1::zeros(n_targets);
156
157        for i in 0..n_targets {
158            let target = targets.row(i);
159            let (value, variance) =
160                self.interpolate_point(&target, points, values, &weights_matrix)?;
161            interpolated[i] = value;
162            variances[i] = variance;
163        }
164
165        Ok(KrigingResult {
166            values: interpolated,
167            variances,
168            coordinates: targets.clone(),
169        })
170    }
171
172    /// Build covariance matrix from variogram
173    fn build_covariance_matrix(&self, points: &Array2<f64>) -> Result<Array2<f64>> {
174        let n = points.nrows();
175        let size = match self.kriging_type {
176            KrigingType::Ordinary => n + 1,  // Add Lagrange multiplier
177            KrigingType::Universal => n + 4, // Add trend terms (constant + x + y + xy)
178        };
179
180        let mut matrix = Array2::zeros((size, size));
181
182        // Fill in covariances
183        for i in 0..n {
184            for j in 0..n {
185                let dist = self.calculate_distance(&points.row(i), &points.row(j))?;
186                let covariance = self.variogram.sill - self.variogram.evaluate(dist);
187                matrix[[i, j]] = covariance;
188            }
189        }
190
191        // Add constraint equations
192        match self.kriging_type {
193            KrigingType::Ordinary => {
194                // Sum of weights = 1
195                for i in 0..n {
196                    matrix[[i, n]] = 1.0;
197                    matrix[[n, i]] = 1.0;
198                }
199            }
200            KrigingType::Universal => {
201                // Trend surface constraints
202                for i in 0..n {
203                    let x = points[[i, 0]];
204                    let y = points[[i, 1]];
205                    matrix[[i, n]] = 1.0;
206                    matrix[[n, i]] = 1.0;
207                    matrix[[i, n + 1]] = x;
208                    matrix[[n + 1, i]] = x;
209                    matrix[[i, n + 2]] = y;
210                    matrix[[n + 2, i]] = y;
211                    matrix[[i, n + 3]] = x * y;
212                    matrix[[n + 3, i]] = x * y;
213                }
214            }
215        }
216
217        Ok(matrix)
218    }
219
220    /// Solve kriging system using matrix inversion
221    fn solve_kriging_system(&self, cov_matrix: &Array2<f64>) -> Result<Array2<f64>> {
222        // For simplicity, use Gaussian elimination
223        // In production, would use proper linear algebra library
224        self.matrix_inverse(cov_matrix)
225    }
226
227    /// Simple matrix inversion using Gauss-Jordan elimination
228    fn matrix_inverse(&self, matrix: &Array2<f64>) -> Result<Array2<f64>> {
229        let n = matrix.nrows();
230        if n != matrix.ncols() {
231            return Err(AnalyticsError::matrix_error("Matrix must be square"));
232        }
233
234        // Create augmented matrix [A | I]
235        let mut aug = Array2::zeros((n, 2 * n));
236        for i in 0..n {
237            for j in 0..n {
238                aug[[i, j]] = matrix[[i, j]];
239            }
240            aug[[i, n + i]] = 1.0;
241        }
242
243        // Gauss-Jordan elimination
244        for i in 0..n {
245            // Find pivot
246            let mut max_row = i;
247            let mut max_val = aug[[i, i]].abs();
248            for k in (i + 1)..n {
249                if aug[[k, i]].abs() > max_val {
250                    max_val = aug[[k, i]].abs();
251                    max_row = k;
252                }
253            }
254
255            if max_val < f64::EPSILON {
256                return Err(AnalyticsError::matrix_error("Matrix is singular"));
257            }
258
259            // Swap rows
260            if max_row != i {
261                for j in 0..(2 * n) {
262                    let tmp = aug[[i, j]];
263                    aug[[i, j]] = aug[[max_row, j]];
264                    aug[[max_row, j]] = tmp;
265                }
266            }
267
268            // Eliminate column
269            let pivot = aug[[i, i]];
270            for j in 0..(2 * n) {
271                aug[[i, j]] /= pivot;
272            }
273
274            for k in 0..n {
275                if k != i {
276                    let factor = aug[[k, i]];
277                    for j in 0..(2 * n) {
278                        aug[[k, j]] -= factor * aug[[i, j]];
279                    }
280                }
281            }
282        }
283
284        // Extract inverse matrix
285        let mut inverse = Array2::zeros((n, n));
286        for i in 0..n {
287            for j in 0..n {
288                inverse[[i, j]] = aug[[i, n + j]];
289            }
290        }
291
292        Ok(inverse)
293    }
294
295    /// Interpolate at a single point
296    fn interpolate_point(
297        &self,
298        target: &scirs2_core::ndarray::ArrayView1<f64>,
299        points: &Array2<f64>,
300        values: &ArrayView1<f64>,
301        weights_matrix: &Array2<f64>,
302    ) -> Result<(f64, f64)> {
303        let n = points.nrows();
304
305        // Build right-hand side vector
306        let rhs_size = match self.kriging_type {
307            KrigingType::Ordinary => n + 1,
308            KrigingType::Universal => n + 4,
309        };
310
311        let mut rhs = Array1::zeros(rhs_size);
312
313        // Fill in covariances to target
314        for i in 0..n {
315            let dist = self.calculate_distance(&points.row(i), target)?;
316            rhs[i] = self.variogram.sill - self.variogram.evaluate(dist);
317        }
318
319        // Add constraints
320        match self.kriging_type {
321            KrigingType::Ordinary => {
322                rhs[n] = 1.0;
323            }
324            KrigingType::Universal => {
325                rhs[n] = 1.0;
326                rhs[n + 1] = target[0];
327                rhs[n + 2] = target[1];
328                rhs[n + 3] = target[0] * target[1];
329            }
330        }
331
332        // Solve for weights
333        let mut weights: Array1<f64> = Array1::zeros(rhs_size);
334        for i in 0..rhs_size {
335            for j in 0..rhs_size {
336                weights[i] += weights_matrix[[i, j]] * rhs[j];
337            }
338        }
339
340        // Calculate interpolated value
341        let mut value: f64 = 0.0;
342        for i in 0..n {
343            value += weights[i] * values[i];
344        }
345
346        // Calculate kriging variance
347        let mut variance = self.variogram.sill;
348        for i in 0..rhs_size {
349            variance -= weights[i] * rhs[i];
350        }
351
352        Ok((value, variance.max(0.0)))
353    }
354
355    /// Calculate distance between two points
356    fn calculate_distance(
357        &self,
358        p1: &scirs2_core::ndarray::ArrayView1<f64>,
359        p2: &scirs2_core::ndarray::ArrayView1<f64>,
360    ) -> Result<f64> {
361        if p1.len() != p2.len() {
362            return Err(AnalyticsError::dimension_mismatch(
363                format!("{}", p1.len()),
364                format!("{}", p2.len()),
365            ));
366        }
367
368        let dist_sq: f64 = p1.iter().zip(p2.iter()).map(|(a, b)| (a - b).powi(2)).sum();
369        Ok(dist_sq.sqrt())
370    }
371}
372
/// Semivariogram calculator.
///
/// A stateless namespace for computing experimental semivariograms from
/// scattered samples.
#[derive(Debug, Clone, Copy, Default)]
pub struct SemivariogramCalculator;
375
376impl SemivariogramCalculator {
377    /// Calculate experimental semivariogram
378    ///
379    /// # Arguments
380    /// * `points` - Point coordinates
381    /// * `values` - Values at points
382    /// * `n_bins` - Number of distance bins
383    ///
384    /// # Errors
385    /// Returns error if calculation fails
386    pub fn calculate(
387        points: &Array2<f64>,
388        values: &ArrayView1<f64>,
389        n_bins: usize,
390    ) -> Result<(Array1<f64>, Array1<f64>)> {
391        let n = points.nrows();
392        if values.len() != n {
393            return Err(AnalyticsError::dimension_mismatch(
394                format!("{}", n),
395                format!("{}", values.len()),
396            ));
397        }
398
399        // Calculate all pairwise distances and semivariances
400        let mut pairs = Vec::new();
401        for i in 0..n {
402            for j in (i + 1)..n {
403                let mut dist_sq = 0.0;
404                for k in 0..points.ncols() {
405                    let diff = points[[i, k]] - points[[j, k]];
406                    dist_sq += diff * diff;
407                }
408                let dist = dist_sq.sqrt();
409                let semivar = 0.5 * (values[i] - values[j]).powi(2);
410                pairs.push((dist, semivar));
411            }
412        }
413
414        if pairs.is_empty() {
415            return Err(AnalyticsError::insufficient_data("Need at least 2 points"));
416        }
417
418        // Find max distance for binning
419        let max_dist = pairs
420            .iter()
421            .map(|(d, _)| *d)
422            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
423            .ok_or_else(|| AnalyticsError::insufficient_data("No valid distances"))?;
424
425        let bin_width = max_dist / (n_bins as f64);
426
427        // Bin semivariances
428        let mut bin_sums = vec![0.0; n_bins];
429        let mut bin_counts = vec![0usize; n_bins];
430
431        for (dist, semivar) in pairs {
432            let bin = ((dist / bin_width).floor() as usize).min(n_bins - 1);
433            bin_sums[bin] += semivar;
434            bin_counts[bin] += 1;
435        }
436
437        // Calculate average semivariance for each bin
438        let mut distances = Vec::new();
439        let mut semivariances = Vec::new();
440
441        for i in 0..n_bins {
442            if bin_counts[i] > 0 {
443                distances.push((i as f64 + 0.5) * bin_width);
444                semivariances.push(bin_sums[i] / (bin_counts[i] as f64));
445            }
446        }
447
448        Ok((Array1::from_vec(distances), Array1::from_vec(semivariances)))
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_variogram_spherical() {
        let var = Variogram::new(VariogramModel::Spherical, 0.1, 1.0, 10.0);

        assert_abs_diff_eq!(var.evaluate(0.0), 0.0, epsilon = 1e-10);
        // Halfway to the range: 0.1 + 0.9 * (1.5 * 0.5 - 0.5 * 0.5^3).
        assert_abs_diff_eq!(var.evaluate(5.0), 0.71875, epsilon = 1e-10);
        assert_abs_diff_eq!(var.evaluate(10.0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(var.evaluate(20.0), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_variogram_asymptotic_models() {
        // At h == range both models sit at (1 - e^-1) of the partial sill.
        let expected = 1.0 - (-1.0_f64).exp();

        let exponential = Variogram::new(VariogramModel::Exponential, 0.0, 1.0, 2.0);
        assert_abs_diff_eq!(exponential.evaluate(2.0), expected, epsilon = 1e-10);

        let gaussian = Variogram::new(VariogramModel::Gaussian, 0.0, 1.0, 2.0);
        assert_abs_diff_eq!(gaussian.evaluate(2.0), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_kriging_simple() {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let values = array![1.0, 2.0, 3.0, 4.0];
        let targets = array![[0.5, 0.5]];

        let var = Variogram::new(VariogramModel::Spherical, 0.0, 1.0, 2.0);
        let interpolator = KrigingInterpolator::new(KrigingType::Ordinary, var);

        let result = interpolator
            .interpolate(&points, &values.view(), &targets)
            .expect("Kriging interpolation should succeed for valid data");

        assert_eq!(result.values.len(), 1);
        assert_eq!(result.variances.len(), 1);
        // The target is equidistant from all four corners, so by symmetry every
        // weight is 0.25 and the estimate is the mean of the data.
        assert_abs_diff_eq!(result.values[0], 2.5, epsilon = 1e-9);
        assert!(result.variances[0] >= 0.0);
    }

    #[test]
    fn test_kriging_is_exact_at_data_points() {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let values = array![1.0, 2.0, 3.0, 4.0];

        let var = Variogram::new(VariogramModel::Spherical, 0.0, 1.0, 2.0);
        let interpolator = KrigingInterpolator::new(KrigingType::Ordinary, var);

        let result = interpolator
            .interpolate(&points, &values.view(), &points)
            .expect("Kriging at the known points should succeed");

        for i in 0..values.len() {
            assert_abs_diff_eq!(result.values[i], values[i], epsilon = 1e-9);
            assert_abs_diff_eq!(result.variances[i], 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_kriging_rejects_mismatched_values() {
        let points = array![[0.0, 0.0], [1.0, 0.0]];
        let values = array![1.0];
        let targets = array![[0.5, 0.0]];

        let var = Variogram::new(VariogramModel::Spherical, 0.0, 1.0, 2.0);
        let interpolator = KrigingInterpolator::new(KrigingType::Ordinary, var);

        assert!(interpolator
            .interpolate(&points, &values.view(), &targets)
            .is_err());
    }

    #[test]
    fn test_semivariogram_calculation() {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let values = array![1.0, 2.0, 3.0];

        let (distances, semivariances) =
            SemivariogramCalculator::calculate(&points, &values.view(), 3)
                .expect("Semivariogram calculation should succeed");

        assert!(!distances.is_empty());
        assert_eq!(distances.len(), semivariances.len());

        // Pair distances are 1, 1 and sqrt(2); with three bins over [0, sqrt(2)]
        // all of them fall into the last bin, whose mean is (0.5 + 2.0 + 0.5) / 3.
        assert_eq!(distances.len(), 1);
        assert_abs_diff_eq!(distances[0], 2.5 * 2.0_f64.sqrt() / 3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(semivariances[0], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_semivariogram_requires_two_points() {
        let points = array![[0.0, 0.0]];
        let values = array![1.0];

        assert!(SemivariogramCalculator::calculate(&points, &values.view(), 3).is_err());
    }
}
493
494        assert!(!distances.is_empty());
495        assert_eq!(distances.len(), semivariances.len());
496    }
497}