scirs2-interpolate 0.4.0

Interpolation module for SciRS2 (scirs2-interpolate)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
//! Fast Universal Kriging Implementation
//!
//! This module contains the implementation of universal kriging for large datasets
//! using various approximation methods to improve computational efficiency.
//! Universal kriging allows for trend modeling using basis functions.

use crate::advanced::enhanced_kriging::{AnisotropicCovariance, TrendFunction};
use crate::advanced::fast_kriging::{
    FastKriging, FastKrigingBuilder, FastKrigingMethod, FastPredictionResult,
};
use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::ndarray::ArrayStatCompat;
use std::fmt::{Debug, Display};
use std::ops::{Add, Div, Mul, Sub};

// Import shared utility functions from parent module
use super::covariance::{
use statrs::statistics::Statistics;
    compute_anisotropic_distance, compute_covariance, compute_low_rank_approximation,
    compute_tapered_covariance, find_nearest_neighbors, project_to_feature,
};

/// Create basis functions for the trend model
#[allow(dead_code)]
pub fn create_basis_functions<F: Float + FromPrimitive>(
    points: &ArrayView2<F>,
    trend_fn: TrendFunction,
) -> InterpolateResult<Array2<F>> {
    let n_points = points.shape()[0];
    let n_dims = points.shape()[1];

    match trend_fn {
        TrendFunction::Constant => {
            // Just a constant term: X = [1, 1, ..., 1]
            let mut basis = Array2::zeros((n_points, 1));
            for i in 0..n_points {
                basis[[i, 0]] = F::one();
            }
            Ok(basis)
        }
        TrendFunction::Linear => {
            // Constant term + linear terms: X = [1, x1, x2, ..., xn]
            let n_basis = n_dims + 1;
            let mut basis = Array2::zeros((n_points, n_basis));

            // Constant term
            for i in 0..n_points {
                basis[[i, 0]] = F::one();
            }

            // Linear terms
            for i in 0..n_points {
                for j in 0..n_dims {
                    basis[[i, j + 1]] = points[[i, j]];
                }
            }

            Ok(basis)
        }
        TrendFunction::Quadratic => {
            // Constant + linear + quadratic terms
            // X = [1, x1, x2, ..., xn, x1^2, x1*x2, ..., xn^2]

            // Number of basis functions:
            // 1 (constant) + n_dims (linear) + n_dims*(n_dims+1)/2 (quadratic)
            let n_quad_terms = n_dims * (n_dims + 1) / 2;
            let n_basis = 1 + n_dims + n_quad_terms;

            let mut basis = Array2::zeros((n_points, n_basis));

            // Constant term
            for i in 0..n_points {
                basis[[i, 0]] = F::one();
            }

            // Linear terms
            for i in 0..n_points {
                for j in 0..n_dims {
                    basis[[i, j + 1]] = points[[i, j]];
                }
            }

            // Quadratic terms
            let mut term_idx = 1 + n_dims;
            for i in 0..n_points {
                for j in 0..n_dims {
                    for k in j..n_dims {
                        if j == k {
                            // x_j^2
                            basis[[i, term_idx]] = points[[i, j]] * points[[i, j]];
                        } else {
                            // x_j * x_k
                            basis[[i, term_idx]] = points[[i, j]] * points[[i, k]];
                        }
                        term_idx += 1;
                    }
                }
            }

            Ok(basis)
        }
        TrendFunction::Custom(_) => {
            // For custom basis functions, default to constant
            let mut basis = Array2::zeros((n_points, 1));
            for i in 0..n_points {
                basis[[i, 0]] = F::one();
            }
            Ok(basis)
        }
    }
}

/// Compute trend coefficients using least squares
#[allow(dead_code)]
pub fn compute_trend_coefficients<F: Float + FromPrimitive + 'static>(
    _points: &Array2<F>,
    values: &Array1<F>,
    basis_functions: &Array2<F>, _trend_fn: TrendFunction,
) -> InterpolateResult<Array1<F>> {
    // Basic least squares: β = (X'X)^(-1) X'y
    // Compute matrix products for least squares
    #[allow(unused_variables)]
    let xtx = basis_functions.t().dot(basis_functions);
    #[allow(unused_variables)]
    let xty = basis_functions.t().dot(values);

    #[cfg(feature = "linalg")]
    {
        use ndarray_linalg::Solve;
        // Convert to f64 for linear algebra
        let xtx_f64 = xtx.mapv(|x| x.to_f64().expect("Operation failed"));
        let xty_f64 = xty.mapv(|x| x.to_f64().expect("Operation failed"));
        match xtx_f64.solve(&xty_f64) {
            Ok(coeffs) => Ok(coeffs.mapv(|x| F::from_f64(x).expect("Operation failed"))),
            Err(_) => {
                // Fallback to simple mean for constant trend
                let mut coeffs = Array1::zeros(basis_functions.shape()[1]);
                coeffs[0] = values.mean_or(F::zero());
                Ok(coeffs)
            }
        }
    }

    #[cfg(not(feature = "linalg"))]
    {
        // Without linalg, just use mean as constant
        let mut coeffs = Array1::zeros(basis_functions.shape()[1]);
        coeffs[0] = values.mean_or(F::zero());
        Ok(coeffs)
    }
}

impl<F> FastKriging<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Display
        + Add<Output = F>
        + Sub<Output = F>
        + Mul<Output = F>
        + Div<Output = F>
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + std::ops::RemAssign
        + 'static,
{
    /// Local kriging prediction with trend model
    pub fn predict_local_with_trend(
        &self,
        query_points: &ArrayView2<F>,
    ) -> InterpolateResult<FastPredictionResult<F>> {
        let n_query = query_points.shape()[0];
        let mut values = Array1::zeros(n_query);
        let mut variances = Array1::zeros(n_query);

        // Compute basis functions for query _points
        let query_basis = create_basis_functions(query_points, self.trend_fn)?;

        // For each query point
        for i in 0..n_query {
            let query_point = query_points.slice(scirs2_core::ndarray::s![i, ..]);

            // Find nearest neighbors
            let (indices_distances) = find_nearest_neighbors(
                &query_point,
                &self._points,
                self.max_neighbors,
                self.radius_multiplier,
            )?;

            // Skip if no neighbors found
            if indices.is_empty() {
                // Use global mean as fallback
                values[i] = self.values.mean_or(F::zero());
                variances[i] = self.anisotropic_cov.sigma_sq;
                continue;
            }

            // Extract local neighborhood
            let n_neighbors = indices.len();
            let mut local_points = Array2::zeros((n_neighbors, query_point.len()));
            let mut local_values = Array1::zeros(n_neighbors);

            for (j, &idx) in indices.iter().enumerate() {
                local_points
                    .slice_mut(scirs2_core::ndarray::s![j, ..])
                    .assign(&self._points.slice(scirs2_core::ndarray::s![idx, ..]));
                local_values[j] = self.values[idx];
            }

            // Compute local covariance matrix
            let mut cov_matrix = Array2::zeros((n_neighbors, n_neighbors));
            for j in 0..n_neighbors {
                for k in 0..n_neighbors {
                    if j == k {
                        cov_matrix[[j, k]] =
                            self.anisotropic_cov.sigma_sq + self.anisotropic_cov.nugget;
                    } else {
                        let dist = compute_anisotropic_distance(
                            &local_points.slice(scirs2_core::ndarray::s![j, ..]),
                            &local_points.slice(scirs2_core::ndarray::s![k, ..]),
                            &self.anisotropic_cov,
                        )?;
                        cov_matrix[[j, k]] = compute_covariance(dist, &self.anisotropic_cov);
                    }
                }
            }

            // Compute local trend basis
            let local_basis = create_basis_functions(&local_points.view(), self.trend_fn)?;
            let n_basis = local_basis.shape()[1];

            // Augmented system for Universal Kriging
            let mut aug_matrix = Array2::zeros((n_neighbors + n_basis, n_neighbors + n_basis));

            // Fill covariance block
            for j in 0..n_neighbors {
                for k in 0..n_neighbors {
                    aug_matrix[[j, k]] = cov_matrix[[j, k]];
                }
            }

            // Fill basis function blocks
            for j in 0..n_neighbors {
                for k in 0..n_basis {
                    let idx = n_neighbors + k;
                    aug_matrix[[j, idx]] = local_basis[[j, k]];
                    aug_matrix[[idx, j]] = local_basis[[j, k]];
                }
            }

            // Zero block in lower right
            for j in 0..n_basis {
                for k in 0..n_basis {
                    let idx1 = n_neighbors + j;
                    let idx2 = n_neighbors + k;
                    aug_matrix[[idx1, idx2]] = F::zero();
                }
            }

            // Create right-hand side
            let mut rhs = Array1::zeros(n_neighbors + n_basis);
            for j in 0..n_neighbors {
                rhs[j] = local_values[j];
            }

            // Solve the system
            #[cfg(not(feature = "linalg"))]
            {
                // Fallback without linalg
                values[i] = local_values.mean_or(F::zero());
                variances[i] = self.anisotropic_cov.sigma_sq;
                continue;
            }

            // Only gets here if linalg is enabled
            #[cfg(feature = "linalg")]
            {
                use ndarray_linalg::Solve;
                // Convert to f64 for linear algebra
                let aug_matrix_f64 = aug_matrix.mapv(|x| x.to_f64().expect("Operation failed"));
                let rhs_f64 = rhs.mapv(|x| x.to_f64().expect("Operation failed"));
                let solution = match aug_matrix_f64.solve(&rhs_f64) {
                    Ok(sol) => sol.mapv(|x| F::from_f64(x).expect("Operation failed")),
                    Err(_) => {
                        // Fallback to standard kriging if system can't be solved
                        let cov_matrix_f64 = cov_matrix.mapv(|x| x.to_f64().expect("Operation failed"));
                        let local_values_f64 = local_values.mapv(|x| x.to_f64().expect("Operation failed"));
                        let weights = match cov_matrix_f64.solve(&local_values_f64) {
                            Ok(w) => w.mapv(|x| F::from_f64(x).expect("Operation failed")),
                            Err(_) => {
                                // Return mean as last resort
                                values[i] = local_values.mean_or(F::zero());
                                variances[i] = self.anisotropic_cov.sigma_sq;
                                continue;
                            }
                        };

                        // Use weights for prediction
                        let mut prediction = F::zero();
                        for j in 0..n_neighbors {
                            prediction += weights[j] * local_values[j];
                        }

                        // Return basic prediction
                        values[i] = prediction;
                        variances[i] = self.anisotropic_cov.sigma_sq; // Simplified variance
                        continue;
                    }
                };

                // Extract weights and trend coefficients
                let weights = solution.slice(scirs2_core::ndarray::s![0..n_neighbors]).to_owned();
                let trend_coeffs = solution.slice(scirs2_core::ndarray::s![n_neighbors..]).to_owned();

                // Compute prediction
                let mut trend = F::zero();
                for j in 0..n_basis {
                    trend += trend_coeffs[j] * query_basis[[i, j]];
                }

                let mut residual = F::zero();
                for j in 0..n_neighbors {
                    residual += weights[j] * local_values[j];
                }

                // Store prediction
                values[i] = trend + residual;

                // Compute approximate variance (simplified)
                let mut k_star = Array1::zeros(n_neighbors);
                for j in 0..n_neighbors {
                    let dist = compute_anisotropic_distance(
                        &query_point,
                        &local_points.slice(scirs2_core::ndarray::s![j, ..]),
                        &self.anisotropic_cov,
                    )?;
                    k_star[j] = compute_covariance(dist, &self.anisotropic_cov);
                }

                let variance = self.anisotropic_cov.sigma_sq - k_star.dot(&weights);
                variances[i] = if variance < F::zero() {
                    F::zero()
                } else {
                    variance
                };
            }
        }

        Ok(FastPredictionResult {
            value: values,
            variance: variances,
            method: self.approx_method,
            computation_time_ms: None,
        })
    }

    /// HODLR kriging prediction with hierarchical matrices (universal version)
    pub fn predict_hodlr(
        &self,
        query_points: &ArrayView2<F>,
    ) -> InterpolateResult<FastPredictionResult<F>> {
        // Extract leaf size from method
        let leaf_size = match self.approx_method {
            FastKrigingMethod::HODLR(size) => size_ => {
                return Err(InterpolateError::InvalidOperation(
                    "Invalid method type for HODLR prediction".to_string(),
                ));
            }
        };

        // For HODLR, we divide the dataset into a hierarchical tree of blocks
        // For this implementation, we'll use a simplified approach with recursion

        let n_query = query_points.shape()[0];
        let mut values = Array1::zeros(n_query);
        let mut variances = Array1::zeros(n_query);

        // Compute basis functions for query _points if needed for universal kriging
        let query_basis = create_basis_functions(query_points, self.trend_fn)?;

        // Create temporary trend coefficients if not pre-computed
        let trend_coeffs = match &self.trend_coeffs {
            Some(coeffs) => coeffs.clone(),
            None => {
                // We need to compute trend coefficients
                let basis_functions = match &self.basis_functions {
                    Some(basis) => basis,
                    None => {
                        // Create basis functions first
                        &create_basis_functions(&self._points.view(), self.trend_fn)?
                    }
                };

                compute_trend_coefficients(
                    &self._points,
                    &self.values,
                    basis_functions,
                    self.trend_fn,
                )?
            }
        };

        // Recursion helper: Start with the full dataset
        let n_train = self._points.shape()[0];
        let train_indices: Vec<usize> = (0..n_train).collect();

        // For each query point
        for i in 0..n_query {
            let query_point = query_points.slice(scirs2_core::ndarray::s![i, ..]);

            // Compute trend component
            let mut trend = F::zero();
            for j in 0..query_basis.shape()[1] {
                trend += trend_coeffs[j] * query_basis[[i, j]];
            }

            // Compute prediction using HODLR approximation
            let residual =
                self.hodlr_predict_point(&query_point, &train_indices, leaf_size, trend)?;

            // Store results
            values[i] = trend + residual;

            // For variance, we use a simplified approximation
            // In a full implementation, this would involve traversing the hierarchy
            variances[i] = self.anisotropic_cov.sigma_sq * F::from_f64(0.1).expect("Operation failed");
        }

        Ok(FastPredictionResult {
            value: values,
            variance: variances,
            method: self.approx_method,
            computation_time_ms: None,
        })
    }

    /// Helper function for HODLR prediction of a single point
    pub fn hodlr_predict_point(
        &self,
        query_point: &ArrayView1<F>,
        indices: &[usize],
        leaf_size: usize,
        trend: F,
    ) -> InterpolateResult<F> {
        // If we're at a leaf node or only have a few points, solve directly
        if indices.len() <= leaf_size {
            // Use direct solution for small blocks
            let n_points = indices.len();

            // Extract subset of points
            let mut block_points = Array2::zeros((n_points, query_point.len()));
            let mut block_values = Array1::zeros(n_points);

            for (i, &idx) in indices.iter().enumerate() {
                block_points
                    .slice_mut(scirs2_core::ndarray::s![i, ..])
                    .assign(&self.points.slice(scirs2_core::ndarray::s![idx, ..]));
                block_values[i] = self.values[idx] - trend; // Use residuals
            }

            // Compute covariance matrix for this block
            let mut cov_matrix = Array2::zeros((n_points, n_points));
            for i in 0..n_points {
                for j in 0..n_points {
                    if i == j {
                        cov_matrix[[i, j]] =
                            self.anisotropic_cov.sigma_sq + self.anisotropic_cov.nugget;
                    } else {
                        let dist = compute_anisotropic_distance(
                            &block_points.slice(scirs2_core::ndarray::s![i, ..]),
                            &block_points.slice(scirs2_core::ndarray::s![j, ..]),
                            &self.anisotropic_cov,
                        )?;
                        cov_matrix[[i, j]] = compute_covariance(dist, &self.anisotropic_cov);
                    }
                }
            }

            // Compute cross-covariance vector
            let mut k_star = Array1::zeros(n_points);
            for i in 0..n_points {
                let dist = compute_anisotropic_distance(
                    query_point,
                    &block_points.slice(scirs2_core::ndarray::s![i, ..]),
                    &self.anisotropic_cov,
                )?;
                k_star[i] = compute_covariance(dist, &self.anisotropic_cov);
            }

            // Solve the system for weights
            #[cfg(feature = "linalg")]
            let weights = {
                use ndarray_linalg::Solve;
                // Convert to f64 for linear algebra
                let cov_matrix_f64 = cov_matrix.mapv(|x| x.to_f64().expect("Operation failed"));
                let block_values_f64 = block_values.mapv(|x| x.to_f64().expect("Operation failed"));
                match cov_matrix_f64.solve(&block_values_f64) {
                    Ok(w) => w.mapv(|x| F::from_f64(x).expect("Operation failed")),
                    Err(_) => {
                        // Fallback to diagonal approximation
                        let mut w = Array1::zeros(n_points);
                        for i in 0..n_points {
                            w[i] = block_values[i]
                                / (self.anisotropic_cov.sigma_sq + self.anisotropic_cov.nugget);
                        }
                        w
                    }
                }
            };

            #[cfg(not(feature = "linalg"))]
            let weights = {
                // Fallback without linalg
                let mut w = Array1::zeros(n_points);
                for i in 0..n_points {
                    w[i] = block_values[i]
                        / (self.anisotropic_cov.sigma_sq + self.anisotropic_cov.nugget);
                }
                w
            };

            // Compute prediction
            let prediction = k_star.dot(&weights);

            return Ok(prediction);
        }

        // Otherwise, partition the points into near and far sets
        // For simplicity, we split on the dimension with largest extent
        let n_dims = self.points.shape()[1];
        let mut max_extent = F::zero();
        let mut split_dim = 0;

        for d in 0..n_dims {
            let mut min_val = F::infinity();
            let mut max_val = F::neg_infinity();

            for &idx in indices {
                let val = self.points[[idx, d]];
                if val < min_val {
                    min_val = val;
                }
                if val > max_val {
                    max_val = val;
                }
            }

            let extent = max_val - min_val;
            if extent > max_extent {
                max_extent = extent;
                split_dim = d;
            }
        }

        // Find median value for splitting
        let mut values_at_dim: Vec<F> = indices
            .iter()
            .map(|&idx| self.points[[idx, split_dim]])
            .collect();
        values_at_dim.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        // Partition into near and far sets
        let query_val = query_point[split_dim];
        let (near_indices, far_indices): (Vec<usize>, Vec<usize>) =
            indices.iter().copied().partition(|&idx| {
                let dist = (self.points[[idx, split_dim]] - query_val).abs();
                dist <= max_extent * F::from_f64(0.5).expect("Operation failed")
            });

        // Recursively compute prediction for near points
        let near_prediction = if !near_indices.is_empty() {
            self.hodlr_predict_point(query_point, &near_indices, leaf_size, trend)?
        } else {
            F::zero()
        };

        // For far points, use low-rank approximation based on a few sample points
        let far_prediction = if !far_indices.is_empty() {
            // Select a subsample of far points for low-rank approximation
            let n_samples =
                std::cmp::min(far_indices.len(), std::cmp::max(5, far_indices.len() / 10));

            let step = if n_samples >= far_indices.len() {
                1
            } else {
                far_indices.len() / n_samples
            };

            let mut sample_indices = Vec::with_capacity(n_samples);
            for i in (0..far_indices.len()).step_by(step) {
                if sample_indices.len() < n_samples {
                    sample_indices.push(far_indices[i]);
                } else {
                    break;
                }
            }

            // Compute a simplified low-rank approximation for the far points
            self.hodlr_predict_point(query_point, &sample_indices, leaf_size, trend)?
                * F::from_f64(far_indices.len() as f64 / sample_indices.len() as f64).expect("Operation failed")
        } else {
            F::zero()
        };

        // Combine results with appropriate weighting
        let total_points = near_indices.len() + far_indices.len();
        let near_weight = F::from_f64(near_indices.len() as f64 / total_points as f64).expect("Operation failed");
        let far_weight = F::from_f64(far_indices.len() as f64 / total_points as f64).expect("Operation failed");

        Ok(near_weight * near_prediction + far_weight * far_prediction)
    }
}

/// Builder extension methods for universal kriging
impl<F> FastKrigingBuilder<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Display
        + Add<Output = F>
        + Sub<Output = F>
        + Mul<Output = F>
        + Div<Output = F>
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + std::ops::RemAssign
        + 'static,
{
    /// Build a universal fast kriging model
    pub fn build_universal(self) -> InterpolateResult<FastKriging<F>> {
        FastKriging::from_builder(self)
    }
}