ki/
lib.rs

//! A library for multidimensional interpolation.
2
3use nalgebra::{DMatrix, DVector, SVD};
4
/// A radial kernel `phi(r)`, where `r` is the Euclidean distance from a point
/// to an interpolation center.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Basis {
    /// Polyharmonic spline of order `k`: `r^k * ln(r)` for even `k` (taken as
    /// 0 when `r < 1e-12`), and `r^k` for odd `k`.
    PolyHarmonic(i32),
    /// Gaussian `exp(-(r / c)^2)` with width parameter `c`.
    Gaussian(f64),
    /// Multiquadric `sqrt(r^2 + c^2)` with shape parameter `c`.
    MultiQuadric(f64),
    /// Inverse multiquadric `1 / sqrt(r^2 + c^2)` with shape parameter `c`.
    InverseMultiQuadric(f64),
}
11
12pub struct Scatter {
13    // Note: could make basis a type-level parameter
14    basis: Basis,
15    // TODO(explore): use matrix & slicing instead (fewer allocs).
16    // An array of n vectors each of size m.
17    centers: Vec<DVector<f64>>,
18    // An m x n' matrix, where n' is the number of basis functions (including polynomial),
19    // and m is the number of coords.
20    deltas: DMatrix<f64>,
21}
22
23impl Basis {
24    fn eval(&self, r: f64) -> f64 {
25        match self {
26            Basis::PolyHarmonic(n) if n % 2 == 0 => {
27                // Somewhat arbitrary but don't expect tiny nonzero values.
28                if r < 1e-12 {
29                    0.0
30                } else {
31                    r.powi(*n) * r.ln()
32                }
33            }
34            Basis::PolyHarmonic(n) => r.powi(*n),
35            // Note: it might be slightly more efficient to pre-recip c, but
36            // let's keep code clean for now.
37            Basis::Gaussian(c) => (-(r / c).powi(2)).exp(),
38            Basis::MultiQuadric(c) => r.hypot(*c),
39            Basis::InverseMultiQuadric(c) => (r * r + c * c).powf(-0.5),
40        }
41    }
42}
43
44impl Scatter {
45    pub fn eval(&self, coords: DVector<f64>) -> DVector<f64> {
46        let n = self.centers.len();
47        let basis = DVector::from_fn(self.deltas.ncols(), |row, _c| {
48            if row < n {
49                // component from basis functions
50                self.basis.eval((&coords - &self.centers[row]).norm())
51            } else if row == n {
52                // constant component
53                1.0
54            } else {
55                // linear component
56                coords[row - n - 1]
57            }
58        });
59        &self.deltas * basis
60    }
61
62    // The order for the polynomial part, meaning terms up to (order - 1) are included.
63    // This usage is consistent with Wilna du Toit's masters thesis "Radial Basis
64    // Function Interpolation"
65    pub fn create(
66        centers: Vec<DVector<f64>>,
67        vals: Vec<DVector<f64>>,
68        basis: Basis,
69        order: usize,
70    ) -> Scatter {
71        let n = centers.len();
72        // n x m matrix. There's probably a better way to do this, ah well.
73        let mut vals = DMatrix::from_columns(&vals).transpose();
74        let n_aug = match order {
75            // Pure radial basis functions
76            0 => n,
77            // Constant term
78            1 => n + 1,
79            // Affine terms
80            2 => n + 1 + centers[0].len(),
81            _ => unimplemented!("don't yet support higher order polynomials"),
82        };
83        // Augment to n' x m matrix, where n' is the total number of basis functions.
84        if n_aug > n {
85            vals = vals.resize_vertically(n_aug, 0.0);
86        }
87        // We translate the system to center the mean at the origin so that when
88        // the system is degenerate, the pseudoinverse below minimizes the linear
89        // coefficients.
90        let means: Vec<_> = if order == 2 {
91            let n = centers.len();
92            let n_recip = (n as f64).recip();
93            (0..centers[0].len())
94                .map(|i| centers.iter().map(|c| c[i]).sum::<f64>() * n_recip)
95                .collect()
96        } else {
97            Vec::new()
98        };
99        let mat = DMatrix::from_fn(n_aug, n_aug, |r, c| {
100            if r < n && c < n {
101                basis.eval((&centers[r] - &centers[c]).norm())
102            } else if r < n {
103                if c == n {
104                    1.0
105                } else {
106                    centers[r][c - n - 1] - means[c - n - 1]
107                }
108            } else if c < n {
109                if r == n {
110                    1.0
111                } else {
112                    centers[c][r - n - 1] - means[r - n - 1]
113                }
114            } else {
115                0.0
116            }
117        });
118        // inv is an n' x n' matrix.
119        let svd = SVD::new(mat, true, true);
120        // Use pseudo-inverse here to get "least squares fit" when there's
121        // no unique result (for example, when dimensionality is too small).
122        let inv = svd.pseudo_inverse(1e-6).expect("error inverting matrix");
123        // Again, this transpose feels like I don't know what I'm doing.
124        let mut deltas = (inv * vals).transpose();
125        if order == 2 {
126            let m = centers[0].len();
127            for i in 0..deltas.nrows() {
128                let offset: f64 = (0..m).map(|j| means[j] * deltas[(i, n + 1 + j)]).sum();
129                deltas[(i, n)] -= offset;
130            }
131        }
132        Scatter {
133            basis,
134            centers,
135            deltas,
136        }
137    }
138}