ki/lib.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
//! A library for multidimensional interpolation.
use nalgebra::{DMatrix, DVector, SVD};
/// The radial kernel used by a [`Scatter`] interpolant, expressed as a function
/// of the distance `r` from a center.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Basis {
    /// Polyharmonic spline of degree `k`: `r^k * ln(r)` for even `k` (taken as
    /// `0` when `r < 1e-12`), and `r^k` for odd `k`.
    PolyHarmonic(i32),
    /// Gaussian with shape parameter `c`: `exp(-(r / c)^2)`.
    Gaussian(f64),
    /// Multiquadric with shape parameter `c`: `sqrt(r^2 + c^2)`.
    MultiQuadric(f64),
    /// Inverse multiquadric with shape parameter `c`: `1 / sqrt(r^2 + c^2)`.
    InverseMultiQuadric(f64),
}
/// A radial basis function interpolant over scattered data.
///
/// Built by [`Scatter::create`] and evaluated with [`Scatter::eval`].
pub struct Scatter {
    // Note: could make basis a type-level parameter
    /// Radial kernel applied to the distance between a query point and each center.
    basis: Basis,
    // TODO(explore): use matrix & slicing instead (fewer allocs).
    /// The `n` interpolation centers, each a vector of dimension `d` (the number
    /// of input coordinates).
    centers: Vec<DVector<f64>>,
    /// An `m x n'` coefficient matrix. `m` is the dimension of the interpolated
    /// values (not of the input coordinates), and `n'` is the number of basis
    /// functions: the `n` radial terms, followed by any polynomial terms (a
    /// constant, then `d` linear terms).
    deltas: DMatrix<f64>,
}
impl Basis {
    /// Evaluates this kernel at distance `r` from a center.
    ///
    /// Even-degree polyharmonic kernels compute `r^k * ln(r)`, except that any
    /// `r < 1e-12` yields `0.0` instead of evaluating the logarithm there.
    fn eval(&self, r: f64) -> f64 {
        match *self {
            Basis::PolyHarmonic(k) => {
                if k % 2 != 0 {
                    r.powi(k)
                } else if r < 1e-12 {
                    // Somewhat arbitrary cutoff; tiny nonzero distances aren't expected.
                    0.0
                } else {
                    r.powi(k) * r.ln()
                }
            }
            Basis::Gaussian(c) => {
                // Pre-computing 1/c might be marginally faster; dividing keeps it simple.
                let scaled = r / c;
                (-scaled.powi(2)).exp()
            }
            Basis::MultiQuadric(c) => r.hypot(c),
            Basis::InverseMultiQuadric(c) => {
                let sum_sq = r * r + c * c;
                sum_sq.powf(-0.5)
            }
        }
    }
}
impl Scatter {
    /// Evaluates the interpolant at `coords`.
    ///
    /// The result is `deltas * b`, where `b` holds, in order: the kernel of the
    /// distance from `coords` to each center, then (if the interpolant was built
    /// with `order >= 1`) a constant `1.0`, then (if `order == 2`) the components
    /// of `coords` themselves.
    ///
    /// # Panics
    ///
    /// Panics (via nalgebra's shape checks) if `coords` does not have the same
    /// dimension as the centers.
    pub fn eval(&self, coords: DVector<f64>) -> DVector<f64> {
        let n = self.centers.len();
        let basis = DVector::from_fn(self.deltas.ncols(), |row, _c| {
            if row < n {
                // Radial component: kernel of the distance to center `row`.
                self.basis.eval((&coords - &self.centers[row]).norm())
            } else if row == n {
                // Constant component.
                1.0
            } else {
                // Linear component.
                coords[row - n - 1]
            }
        });
        &self.deltas * basis
    }

    /// Builds a radial basis function interpolant through `vals` at `centers`.
    ///
    /// `order` is the order of the polynomial part, meaning terms up to degree
    /// `order - 1` are included. This usage is consistent with Wilna du Toit's
    /// master's thesis "Radial Basis Function Interpolation".
    ///
    /// * `0`: pure radial basis functions.
    /// * `1`: radial basis functions plus a constant term.
    /// * `2`: radial basis functions plus constant and linear (affine) terms.
    ///
    /// # Panics
    ///
    /// Panics if `centers` is empty, if `centers` and `vals` have different
    /// lengths, or if `order > 2` (higher-order polynomials are not supported).
    pub fn create(
        centers: Vec<DVector<f64>>,
        vals: Vec<DVector<f64>>,
        basis: Basis,
        order: usize,
    ) -> Scatter {
        assert!(!centers.is_empty(), "at least one center is required");
        // Without this check, a mismatch would be silently zero-padded or
        // truncated by `resize_vertically` below, yielding a meaningless fit.
        assert_eq!(centers.len(), vals.len(), "centers and vals must have the same length");
        let n = centers.len();
        // Dimension of the input space (length of each center).
        let dim = centers[0].len();
        // n x m matrix: one row per center, one column per value component.
        let mut vals = DMatrix::from_columns(&vals).transpose();
        let n_aug = match order {
            // Pure radial basis functions
            0 => n,
            // Constant term
            1 => n + 1,
            // Affine terms
            2 => n + 1 + dim,
            _ => unimplemented!("don't yet support higher order polynomials"),
        };
        // Augment to n' x m, where n' is the total number of basis functions; the
        // added rows are the zero right-hand side of the polynomial constraints.
        if n_aug > n {
            vals = vals.resize_vertically(n_aug, 0.0);
        }
        // We translate the system to center the mean at the origin so that when
        // the system is degenerate, the pseudoinverse below minimizes the linear
        // coefficients.
        let means: Vec<f64> = if order == 2 {
            let n_recip = (n as f64).recip();
            (0..dim)
                .map(|i| centers.iter().map(|c| c[i]).sum::<f64>() * n_recip)
                .collect()
        } else {
            Vec::new()
        };
        // Value of polynomial basis function `k` at center `i`: `k == 0` is the
        // constant term, `k >= 1` is the (mean-centered) coordinate `k - 1`.
        let poly = |i: usize, k: usize| -> f64 {
            if k == 0 {
                1.0
            } else {
                centers[i][k - 1] - means[k - 1]
            }
        };
        // n' x n' block matrix [[A, P], [P^T, 0]]: A holds the kernel of pairwise
        // center distances, P the polynomial terms evaluated at each center.
        let mat = DMatrix::from_fn(n_aug, n_aug, |r, c| match (r < n, c < n) {
            (true, true) => basis.eval((&centers[r] - &centers[c]).norm()),
            (true, false) => poly(r, c - n),
            (false, true) => poly(c, r - n),
            (false, false) => 0.0,
        });
        let svd = SVD::new(mat, true, true);
        // Use pseudo-inverse here to get a "least squares fit" when there's no
        // unique result (for example, when dimensionality is too small). inv is
        // an n' x n' matrix.
        // NOTE(review): both U and V^T are requested and eps is non-negative, so
        // this presumably cannot fail; confirm against the nalgebra version in use.
        let inv = svd.pseudo_inverse(1e-6).expect("error inverting matrix");
        // m x n': one row of coefficients per value component.
        let mut deltas = (inv * vals).transpose();
        if order == 2 {
            // Undo the centering: c0 + a . (x - mean) == (c0 - a . mean) + a . x,
            // so fold the shift into the constant coefficient.
            for i in 0..deltas.nrows() {
                let offset: f64 = means
                    .iter()
                    .enumerate()
                    .map(|(j, mu)| mu * deltas[(i, n + 1 + j)])
                    .sum();
                deltas[(i, n)] -= offset;
            }
        }
        Scatter {
            basis,
            centers,
            deltas,
        }
    }
}