1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
//! A library for multidimensional scattered-data interpolation using radial basis functions.

use nalgebra::{DMatrix, DVector, SVD};

/// Radial kernel used to build an interpolant.
///
/// Every variant describes a function of the distance `r` between a query point
/// and a center; the payload is the kernel's single parameter.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Basis {
    /// Polyharmonic spline of the given degree `k`: `r^k` for odd `k`,
    /// `r^k * ln(r)` for even `k`.
    PolyHarmonic(i32),
    /// Gaussian `exp(-(r / c)^2)` with scale `c`.
    Gaussian(f64),
    /// Multiquadric `sqrt(r^2 + c^2)` with shape parameter `c`.
    MultiQuadric(f64),
    /// Inverse multiquadric `1 / sqrt(r^2 + c^2)` with shape parameter `c`.
    InverseMultiQuadric(f64),
}

/// A fitted radial-basis-function interpolant over scattered data points.
///
/// Instances are produced by `Scatter::create` and queried with `Scatter::eval`.
pub struct Scatter {
    // Note: could make basis a type-level parameter
    /// Radial kernel applied to the distance between a query point and each center.
    basis: Basis,
    // TODO(explore): use matrix & slicing instead (fewer allocs).
    /// The `n` data sites the interpolant was fitted at, one vector per site.
    centers: Vec<DVector<f64>>,
    /// Fitted weights as an m x n' matrix: one row per component of the
    /// interpolated values and one column per basis function, where n' counts
    /// the `n` radial terms plus any polynomial (constant / linear) terms.
    deltas: DMatrix<f64>,
}

impl Basis {
    /// Evaluates the kernel at radial distance `r`.
    fn eval(&self, r: f64) -> f64 {
        match *self {
            Basis::PolyHarmonic(degree) => {
                let even_degree = degree % 2 == 0;
                if !even_degree {
                    // Odd degree: plain power of the distance.
                    r.powi(degree)
                } else if r < 1e-12 {
                    // Even degree carries a log factor; treat the limit r -> 0 as
                    // exactly zero. The cutoff is somewhat arbitrary, but tiny
                    // nonzero distances are not expected.
                    0.0
                } else {
                    r.powi(degree) * r.ln()
                }
            }
            Basis::Gaussian(scale) => {
                // Dividing each time is marginally slower than storing the
                // reciprocal of the scale, but keeps the code simple.
                let ratio = r / scale;
                (-ratio.powi(2)).exp()
            }
            Basis::MultiQuadric(shape) => r.hypot(shape),
            Basis::InverseMultiQuadric(shape) => {
                let sum_of_squares = r * r + shape * shape;
                sum_of_squares.powf(-0.5)
            }
        }
    }
}

impl Scatter {
    /// Evaluates the interpolant at `coords`.
    ///
    /// Builds the vector of basis-function values at `coords` (one radial term
    /// per center, then the constant term and one linear term per coordinate if
    /// the fit included them) and multiplies it by the fitted weight matrix.
    ///
    /// `coords` is expected to have the same dimension as the centers passed to
    /// [`Scatter::create`]; the vector subtraction below will otherwise fail.
    pub fn eval(&self, coords: DVector<f64>) -> DVector<f64> {
        let n = self.centers.len();
        let basis = DVector::from_fn(self.deltas.ncols(), |row, _| match row {
            // component from basis functions
            i if i < n => self.basis.eval((&coords - &self.centers[i]).norm()),
            // constant component
            i if i == n => 1.0,
            // linear component
            i => coords[i - n - 1],
        });
        &self.deltas * basis
    }

    /// Fits an interpolant through `vals[i]` at `centers[i]`.
    ///
    /// `order` is the order of the polynomial part, meaning terms up to
    /// `order - 1` are included: `0` uses pure radial basis functions, `1` adds
    /// a constant term and `2` adds constant and linear (affine) terms. This
    /// usage is consistent with Wilna du Toit's masters thesis "Radial Basis
    /// Function Interpolation".
    ///
    /// # Panics
    ///
    /// Panics if `centers` is empty, if `centers` and `vals` differ in length,
    /// or if `order` is greater than 2.
    pub fn create(
        centers: Vec<DVector<f64>>,
        vals: Vec<DVector<f64>>,
        basis: Basis,
        order: usize,
    ) -> Scatter {
        // Without these checks a length mismatch would be silently truncated or
        // zero-padded by the resize below and produce a wrong fit.
        assert!(!centers.is_empty(), "at least one center is required");
        assert_eq!(
            centers.len(),
            vals.len(),
            "centers and vals must have the same length"
        );
        let n = centers.len();
        // Dimension of the space the centers live in.
        let dim = centers[0].len();
        // n x m matrix, where m is the length of each value vector.
        let mut vals = DMatrix::from_columns(&vals).transpose();
        let n_aug = match order {
            // Pure radial basis functions
            0 => n,
            // Constant term
            1 => n + 1,
            // Affine terms
            2 => n + 1 + dim,
            _ => unimplemented!("don't yet support higher order polynomials"),
        };
        // Augment to n' x m matrix, where n' is the total number of basis functions.
        // The extra rows are the zero right-hand sides of the polynomial constraints.
        if n_aug > n {
            vals = vals.resize_vertically(n_aug, 0.0);
        }
        // We translate the system to center the mean at the origin so that when
        // the system is degenerate, the pseudoinverse below minimizes the linear
        // coefficients.
        let means: Vec<f64> = if order == 2 {
            let n_recip = (n as f64).recip();
            (0..dim)
                .map(|i| centers.iter().map(|c| c[i]).sum::<f64>() * n_recip)
                .collect()
        } else {
            Vec::new()
        };
        // Value of polynomial term `term` at center `point`: term 0 is the
        // constant, term k > 0 is the mean-centered (k - 1)-th coordinate.
        let poly_term = |point: usize, term: usize| -> f64 {
            if term == 0 {
                1.0
            } else {
                centers[point][term - 1] - means[term - 1]
            }
        };
        // Symmetric block matrix [[A, P], [P^T, 0]] with A the kernel matrix and
        // P the polynomial terms evaluated at the centers.
        let mat = DMatrix::from_fn(n_aug, n_aug, |r, c| match (r < n, c < n) {
            (true, true) => basis.eval((&centers[r] - &centers[c]).norm()),
            (true, false) => poly_term(r, c - n),
            (false, true) => poly_term(c, r - n),
            (false, false) => 0.0,
        });
        // inv is an n' x n' matrix.
        let svd = SVD::new(mat, true, true);
        // Use pseudo-inverse here to get "least squares fit" when there's
        // no unique result (for example, when dimensionality is too small).
        let inv = svd.pseudo_inverse(1e-6).expect("error inverting matrix");
        // Transposed so that `eval` can multiply the weights by a column of
        // basis values: m x n'.
        let mut deltas = (inv * vals).transpose();
        if order == 2 {
            // Undo the mean-centering: fold `-mean . linear_coeffs` into the
            // constant coefficient so `eval` can use raw coordinates.
            for i in 0..deltas.nrows() {
                let offset: f64 = means
                    .iter()
                    .enumerate()
                    .map(|(j, mean)| mean * deltas[(i, n + 1 + j)])
                    .sum();
                deltas[(i, n)] -= offset;
            }
        }
        Scatter {
            basis,
            centers,
            deltas,
        }
    }
}