1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use nalgebra::{DMatrix, DVector, SVD};
/// Radial kernel `φ(r)` used by the scattered-data interpolator, where `r` is the Euclidean
/// distance from an evaluation point to an interpolation center.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Basis {
    /// Polyharmonic spline of degree `k`: `r^k · ln r` for even `k`, `r^k` for odd `k`.
    /// `PolyHarmonic(2)` is the thin-plate spline kernel `r² ln r`.
    PolyHarmonic(i32),
    /// Gaussian `exp(-(r / c)^2)` with shape parameter `c`.
    Gaussian(f64),
    /// Multiquadric `sqrt(r^2 + c^2)` with shape parameter `c`.
    MultiQuadric(f64),
    /// Inverse multiquadric `1 / sqrt(r^2 + c^2)` with shape parameter `c`.
    InverseMultiQuadric(f64),
}
/// A fitted radial-basis-function interpolant.
///
/// Instances are produced by `Scatter::create` and queried with `Scatter::eval`.
pub struct Scatter {
    /// Interpolation nodes; each one contributes one radial term to the interpolant.
    centers: Vec<DVector<f64>>,
    /// Fitted coefficients: one row per output component. Each row holds one column per
    /// center, followed by the polynomial coefficients (constant, then one per coordinate)
    /// when a polynomial term was fitted.
    deltas: DMatrix<f64>,
    /// Radial kernel applied to the distance between a query point and each center.
    basis: Basis,
}
impl Basis {
    /// Evaluates this kernel at distance `r`.
    ///
    /// * `PolyHarmonic(k)`: `r^k` when `k` is odd. When `k` is even the result is
    ///   `r^k · ln r`, except that distances below `1e-12` yield `0.0`; this avoids evaluating
    ///   `ln 0`.
    /// * `Gaussian(c)`: `exp(-(r / c)^2)`.
    /// * `MultiQuadric(c)`: `sqrt(r^2 + c^2)`, computed with `hypot`.
    /// * `InverseMultiQuadric(c)`: `(r^2 + c^2)^(-1/2)`.
    fn eval(&self, r: f64) -> f64 {
        match *self {
            Basis::PolyHarmonic(k) => {
                let power = r.powi(k);
                if k % 2 != 0 {
                    // Odd degree: plain power, well-defined at the origin.
                    power
                } else if r < 1e-12 {
                    // Even degree: treat (near-)zero distance as 0 instead of evaluating ln(0).
                    0.0
                } else {
                    power * r.ln()
                }
            }
            Basis::Gaussian(width) => {
                let scaled = r / width;
                (-scaled.powi(2)).exp()
            }
            Basis::MultiQuadric(width) => r.hypot(width),
            Basis::InverseMultiQuadric(width) => (r * r + width * width).powf(-0.5),
        }
    }
}
impl Scatter {
    /// Evaluates the fitted interpolant at `coords`.
    ///
    /// The basis vector is `[φ(|x - c_0|), …, φ(|x - c_{n-1}|), 1, x_0, x_1, …]`, truncated to
    /// the number of columns of `deltas`. This makes it match the polynomial order chosen in
    /// `Scatter::create`. The result is `deltas · basis`, with one entry per output component.
    ///
    /// # Panics
    ///
    /// Panics if `coords` does not have the same dimension as the centers. The panic comes
    /// from nalgebra's vector subtraction or from indexing.
    pub fn eval(&self, coords: DVector<f64>) -> DVector<f64> {
        let n = self.centers.len();
        let basis = DVector::from_fn(self.deltas.ncols(), |row, _| {
            if row < n {
                self.basis.eval((&coords - &self.centers[row]).norm())
            } else {
                Self::poly_term(&coords, row - n)
            }
        });
        &self.deltas * basis
    }

    /// Returns entry `k` of the polynomial row for point `x`.
    ///
    /// Entry 0 is the constant term (`1.0`). Entry `k >= 1` is the linear term `x[k - 1]`.
    fn poly_term(x: &DVector<f64>, k: usize) -> f64 {
        if k == 0 {
            1.0
        } else {
            x[k - 1]
        }
    }

    /// Fits a radial-basis-function interpolant through `vals` sampled at `centers`.
    ///
    /// `vals[i]` is the (vector-valued) sample at `centers[i]`. `order` selects the polynomial
    /// term appended to the radial sum:
    ///
    /// * `0`: no polynomial;
    /// * `1`: a constant;
    /// * `2`: a constant plus one linear term per coordinate.
    ///
    /// The augmented system `[[Φ, P], [Pᵀ, 0]] · w = [vals; 0]` is solved with an SVD
    /// pseudo-inverse. Singular values below `1e-6` are treated as zero, so singular systems
    /// (e.g. duplicated centers) still yield a least-squares solution instead of failing.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    ///
    /// * `centers` is empty;
    /// * `vals.len() != centers.len()`;
    /// * the centers do not all share one dimension;
    /// * `order > 2`;
    /// * the pseudo-inverse cannot be computed.
    pub fn create(
        centers: Vec<DVector<f64>>,
        vals: Vec<DVector<f64>>,
        basis: Basis,
        order: usize,
    ) -> Scatter {
        assert!(
            !centers.is_empty(),
            "Scatter::create needs at least one center"
        );
        assert_eq!(
            centers.len(),
            vals.len(),
            "Scatter::create needs exactly one value per center"
        );
        let n = centers.len();
        let dim = centers[0].len();
        assert!(
            centers.iter().all(|c| c.len() == dim),
            "Scatter::create needs all centers to have the same dimension"
        );

        let n_aug = match order {
            // Pure radial sum.
            0 => n,
            // Radial sum plus a constant.
            1 => n + 1,
            // Radial sum plus a constant and one linear term per coordinate.
            2 => n + 1 + dim,
            _ => unimplemented!("don't yet support higher order polynomials"),
        };

        // Right-hand side: one row per center and one column per output component. The extra
        // rows for the polynomial side conditions (Pᵀ · w = 0) are zero-filled.
        let rhs = DMatrix::from_columns(&vals)
            .transpose()
            .resize_vertically(n_aug, 0.0);

        let mat = DMatrix::from_fn(n_aug, n_aug, |r, c| match (r < n, c < n) {
            // Φ block: kernel evaluated at pairwise center distances.
            (true, true) => basis.eval((&centers[r] - &centers[c]).norm()),
            // P block: polynomial terms of center r.
            (true, false) => Self::poly_term(&centers[r], c - n),
            // Pᵀ block: polynomial terms of center c.
            (false, true) => Self::poly_term(&centers[c], r - n),
            // Zero block between the polynomial coefficients.
            (false, false) => 0.0,
        });

        // The pseudo-inverse (rather than an LU solve) tolerates singular or
        // ill-conditioned systems.
        let inv = SVD::new(mat, true, true)
            .pseudo_inverse(1e-6)
            .expect("error inverting matrix");

        // Transposed so that each row holds the coefficients of one output component.
        let deltas = (inv * rhs).transpose();
        Scatter {
            basis,
            centers,
            deltas,
        }
    }
}