ki/lib.rs
//! A library for multidimensional scattered-data interpolation with radial basis functions.
2
3use nalgebra::{DMatrix, DVector, SVD};
4
/// A radial basis function kernel, evaluated on the distance `r` from a center.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Basis {
    /// Polyharmonic spline with exponent `k`: `r^k` for odd `k`, `r^k ln r` for even `k`
    /// (`PolyHarmonic(2)`, i.e. `r^2 ln r`, is the classic thin-plate spline).
    PolyHarmonic(i32),
    /// Gaussian kernel `exp(-(r / c)^2)` with shape parameter `c`.
    Gaussian(f64),
    /// Multiquadric kernel `sqrt(r^2 + c^2)`.
    MultiQuadric(f64),
    /// Inverse multiquadric kernel `1 / sqrt(r^2 + c^2)`.
    InverseMultiQuadric(f64),
}
11
12pub struct Scatter {
13 // Note: could make basis a type-level parameter
14 basis: Basis,
15 // TODO(explore): use matrix & slicing instead (fewer allocs).
16 // An array of n vectors each of size m.
17 centers: Vec<DVector<f64>>,
18 // An m x n' matrix, where n' is the number of basis functions (including polynomial),
19 // and m is the number of coords.
20 deltas: DMatrix<f64>,
21}
22
23impl Basis {
24 fn eval(&self, r: f64) -> f64 {
25 match self {
26 Basis::PolyHarmonic(n) if n % 2 == 0 => {
27 // Somewhat arbitrary but don't expect tiny nonzero values.
28 if r < 1e-12 {
29 0.0
30 } else {
31 r.powi(*n) * r.ln()
32 }
33 }
34 Basis::PolyHarmonic(n) => r.powi(*n),
35 // Note: it might be slightly more efficient to pre-recip c, but
36 // let's keep code clean for now.
37 Basis::Gaussian(c) => (-(r / c).powi(2)).exp(),
38 Basis::MultiQuadric(c) => r.hypot(*c),
39 Basis::InverseMultiQuadric(c) => (r * r + c * c).powf(-0.5),
40 }
41 }
42}
43
44impl Scatter {
45 pub fn eval(&self, coords: DVector<f64>) -> DVector<f64> {
46 let n = self.centers.len();
47 let basis = DVector::from_fn(self.deltas.ncols(), |row, _c| {
48 if row < n {
49 // component from basis functions
50 self.basis.eval((&coords - &self.centers[row]).norm())
51 } else if row == n {
52 // constant component
53 1.0
54 } else {
55 // linear component
56 coords[row - n - 1]
57 }
58 });
59 &self.deltas * basis
60 }
61
62 // The order for the polynomial part, meaning terms up to (order - 1) are included.
63 // This usage is consistent with Wilna du Toit's masters thesis "Radial Basis
64 // Function Interpolation"
65 pub fn create(
66 centers: Vec<DVector<f64>>,
67 vals: Vec<DVector<f64>>,
68 basis: Basis,
69 order: usize,
70 ) -> Scatter {
71 let n = centers.len();
72 // n x m matrix. There's probably a better way to do this, ah well.
73 let mut vals = DMatrix::from_columns(&vals).transpose();
74 let n_aug = match order {
75 // Pure radial basis functions
76 0 => n,
77 // Constant term
78 1 => n + 1,
79 // Affine terms
80 2 => n + 1 + centers[0].len(),
81 _ => unimplemented!("don't yet support higher order polynomials"),
82 };
83 // Augment to n' x m matrix, where n' is the total number of basis functions.
84 if n_aug > n {
85 vals = vals.resize_vertically(n_aug, 0.0);
86 }
87 // We translate the system to center the mean at the origin so that when
88 // the system is degenerate, the pseudoinverse below minimizes the linear
89 // coefficients.
90 let means: Vec<_> = if order == 2 {
91 let n = centers.len();
92 let n_recip = (n as f64).recip();
93 (0..centers[0].len())
94 .map(|i| centers.iter().map(|c| c[i]).sum::<f64>() * n_recip)
95 .collect()
96 } else {
97 Vec::new()
98 };
99 let mat = DMatrix::from_fn(n_aug, n_aug, |r, c| {
100 if r < n && c < n {
101 basis.eval((¢ers[r] - ¢ers[c]).norm())
102 } else if r < n {
103 if c == n {
104 1.0
105 } else {
106 centers[r][c - n - 1] - means[c - n - 1]
107 }
108 } else if c < n {
109 if r == n {
110 1.0
111 } else {
112 centers[c][r - n - 1] - means[r - n - 1]
113 }
114 } else {
115 0.0
116 }
117 });
118 // inv is an n' x n' matrix.
119 let svd = SVD::new(mat, true, true);
120 // Use pseudo-inverse here to get "least squares fit" when there's
121 // no unique result (for example, when dimensionality is too small).
122 let inv = svd.pseudo_inverse(1e-6).expect("error inverting matrix");
123 // Again, this transpose feels like I don't know what I'm doing.
124 let mut deltas = (inv * vals).transpose();
125 if order == 2 {
126 let m = centers[0].len();
127 for i in 0..deltas.nrows() {
128 let offset: f64 = (0..m).map(|j| means[j] * deltas[(i, n + 1 + j)]).sum();
129 deltas[(i, n)] -= offset;
130 }
131 }
132 Scatter {
133 basis,
134 centers,
135 deltas,
136 }
137 }
138}