1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//! A library for multidimensional interpolation.

use nalgebra::{DMatrix, DVector, SVD};

/// The radial kernel used by the interpolant.
///
/// Each variant maps a non-negative distance `r` to a kernel value; the
/// payload is the kernel's order or shape parameter.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Basis {
    /// Polyharmonic spline of order `n`: `r^n` when `n` is odd and
    /// `r^n * ln(r)` when `n` is even.
    PolyHarmonic(i32),
    /// Gaussian with scale `c`: `exp(-(r / c)^2)`.
    Gaussian(f64),
    /// Multiquadric with shape parameter `c`: `sqrt(r^2 + c^2)`.
    MultiQuadric(f64),
    /// Inverse multiquadric with shape parameter `c`: `1 / sqrt(r^2 + c^2)`.
    InverseMultiQuadric(f64),
}

/// A scattered-data interpolant: a weighted sum of radial basis functions
/// placed at a set of centers, optionally augmented with constant and linear
/// polynomial terms.
///
/// Built by `Scatter::create` and evaluated with `Scatter::eval`.
pub struct Scatter {
    // Note: could make basis a type-level parameter
    /// Radial kernel applied to the distance between a query point and each center.
    basis: Basis,
    // TODO(explore): use matrix & slicing instead (fewer allocs).
    /// The n center points, each a vector in the input (coordinate) space.
    centers: Vec<DVector<f64>>,
    /// An m x n' coefficient matrix, where n' is the total number of basis
    /// functions (the n radial terms plus any polynomial terms) and m is the
    /// length of each value vector the interpolant was fitted to.
    deltas: DMatrix<f64>,
}

impl Basis {
    /// Evaluates the kernel at distance `r`.
    ///
    /// * `PolyHarmonic(n)`: `r^n` for odd `n`, `r^n * ln(r)` for even `n`
    ///   (defined as `0` for `r` below `1e-12`, where `ln(r)` diverges).
    /// * `Gaussian(c)`: `exp(-(r / c)^2)`.
    /// * `MultiQuadric(c)`: `sqrt(r^2 + c^2)`.
    /// * `InverseMultiQuadric(c)`: `1 / sqrt(r^2 + c^2)`.
    fn eval(&self, r: f64) -> f64 {
        match *self {
            Basis::PolyHarmonic(n) if n % 2 == 0 => {
                // Somewhat arbitrary cutoff; tiny nonzero distances are not
                // expected, and this avoids evaluating ln(0).
                if r < 1e-12 {
                    0.0
                } else {
                    r.powi(n) * r.ln()
                }
            }
            Basis::PolyHarmonic(n) => r.powi(n),
            // Note: it might be slightly more efficient to pre-recip c, but
            // let's keep code clean for now.
            Basis::Gaussian(c) => (-(r / c).powi(2)).exp(),
            Basis::MultiQuadric(c) => r.hypot(c),
            // Use hypot rather than squaring by hand: r * r + c * c overflows
            // (or underflows to zero) at extreme scales even though the
            // result itself is representable.
            Basis::InverseMultiQuadric(c) => r.hypot(c).recip(),
        }
    }
}

impl Scatter {
    /// Evaluates the interpolant at `coords`.
    ///
    /// Builds the vector of basis values at `coords` (one radial term per
    /// center, then the constant term and the coordinates themselves when the
    /// interpolant carries polynomial terms) and multiplies it by the fitted
    /// coefficient matrix.
    ///
    /// `coords` is expected to have the same dimension as the centers given
    /// to [`Scatter::create`].
    pub fn eval(&self, coords: DVector<f64>) -> DVector<f64> {
        let n = self.centers.len();
        let basis = DVector::from_fn(self.deltas.ncols(), |row, _c| {
            if row < n {
                // Radial term: kernel of the distance to center `row`.
                self.basis.eval((&coords - &self.centers[row]).norm())
            } else if row == n {
                // Constant polynomial term.
                1.0
            } else {
                // Linear polynomial term: one coordinate per remaining row.
                coords[row - n - 1]
            }
        });
        &self.deltas * basis
    }

    /// Fits an interpolant through `vals` sampled at `centers`.
    ///
    /// `order` is the order of the polynomial part, meaning terms up to
    /// degree `order - 1` are included: `0` gives pure radial basis functions,
    /// `1` adds a constant term and `2` adds affine terms. This usage is
    /// consistent with Wilna du Toit's masters thesis "Radial Basis Function
    /// Interpolation".
    ///
    /// # Panics
    ///
    /// Panics if `centers` and `vals` differ in length, if `order` is greater
    /// than 2 (not yet implemented), or if the pseudo-inverse of the
    /// interpolation matrix cannot be computed.
    pub fn create(
        centers: Vec<DVector<f64>>,
        vals: Vec<DVector<f64>>,
        basis: Basis,
        order: usize,
    ) -> Scatter {
        let n = centers.len();
        // Without this check a mismatched `vals` would be silently truncated
        // or zero-padded by the resize below, producing a meaningless fit.
        assert_eq!(
            n,
            vals.len(),
            "centers and vals must have the same length"
        );
        // n x m matrix: one row per center, one column per output component.
        let samples = DMatrix::from_columns(&vals).transpose();
        let n_aug = match order {
            // Pure radial basis functions
            0 => n,
            // Constant term
            1 => n + 1,
            // Affine terms
            2 => n + 1 + centers[0].len(),
            _ => unimplemented!("don't yet support higher order polynomials"),
        };
        // Augment to an n' x m right-hand side, where n' is the total number
        // of basis functions; the rows for the polynomial terms are zero.
        let rhs = if n_aug > n {
            samples.resize_vertically(n_aug, 0.0)
        } else {
            samples
        };
        // Symmetric block matrix [[A, P], [P^T, 0]]: A holds the kernel values
        // between centers, P holds the polynomial terms at each center.
        let mat = DMatrix::from_fn(n_aug, n_aug, |r, c| match (r < n, c < n) {
            (true, true) => basis.eval((&centers[r] - &centers[c]).norm()),
            // Column n is the constant term, later columns are coordinates.
            (true, false) if c == n => 1.0,
            (true, false) => centers[r][c - n - 1],
            // Mirror image of the block above.
            (false, true) if r == n => 1.0,
            (false, true) => centers[c][r - n - 1],
            (false, false) => 0.0,
        });
        let svd = SVD::new(mat, true, true);
        // Use pseudo-inverse here to get "least squares fit" when there's
        // no unique result (for example, when dimensionality is too small).
        // inv is an n' x n' matrix.
        let inv = svd.pseudo_inverse(1e-6).expect("error inverting matrix");
        // Stored transposed (m x n') so `eval` is a single matrix-vector product.
        let deltas = (inv * rhs).transpose();
        Scatter {
            basis,
            centers,
            deltas,
        }
    }
}