1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use nalgebra::{DMatrix, DVector, SVD};
/// Radial kernel `phi(r)` used by [`Scatter`], where `r` is the Euclidean
/// distance from an evaluation point to a center.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Basis {
    /// Polyharmonic spline of degree `k`: `r^k` for odd `k`, and
    /// `r^k * ln(r)` for even `k` (the thin-plate spline when `k == 2`).
    PolyHarmonic(i32),
    /// Gaussian with shape parameter `c`: `exp(-(r / c)^2)`.
    Gaussian(f64),
    /// Multiquadric with shape parameter `c`: `sqrt(r^2 + c^2)`.
    MultiQuadric(f64),
    /// Inverse multiquadric with shape parameter `c`: `1 / sqrt(r^2 + c^2)`.
    InverseMultiQuadric(f64),
}
/// A fitted radial-basis-function interpolant.
///
/// Built by [`Scatter::create`] and evaluated with [`Scatter::eval`].
pub struct Scatter {
    /// Interpolation nodes; one kernel term is centred on each of them.
    centers: Vec<DVector<f64>>,
    /// Coefficient matrix with one row per output component. The first
    /// `centers.len()` columns weight the kernel terms; any remaining
    /// columns weight the polynomial augmentation terms.
    deltas: DMatrix<f64>,
    /// Radial kernel shared by every center.
    basis: Basis,
}
impl Basis {
    /// Evaluates the radial kernel at distance `r`.
    ///
    /// * `PolyHarmonic(k)`: `r^k` when `k` is odd; `r^k * ln(r)` when `k` is
    ///   even, except that any `r < 1e-12` yields `0.0` so `ln` is never
    ///   taken near zero.
    /// * `Gaussian(c)`: `exp(-(r / c)^2)`.
    /// * `MultiQuadric(c)`: `sqrt(r^2 + c^2)` (via `hypot`).
    /// * `InverseMultiQuadric(c)`: `(r^2 + c^2)^(-1/2)`.
    fn eval(&self, r: f64) -> f64 {
        match *self {
            Basis::PolyHarmonic(k) => {
                let is_odd = k % 2 != 0;
                if is_odd {
                    r.powi(k)
                } else if r < 1e-12 {
                    // r^k * ln(r) -> 0 as r -> 0; avoid evaluating ln(0).
                    0.0
                } else {
                    r.powi(k) * r.ln()
                }
            }
            Basis::Gaussian(width) => {
                let scaled = r / width;
                (-scaled.powi(2)).exp()
            }
            Basis::MultiQuadric(shape) => r.hypot(shape),
            Basis::InverseMultiQuadric(shape) => {
                let sum_sq = r * r + shape * shape;
                sum_sq.powf(-0.5)
            }
        }
    }
}
impl Scatter {
    /// Evaluates the interpolant at `coords`.
    ///
    /// The result has one entry per output component, i.e. the length of the
    /// value vectors that were passed to [`Scatter::create`].
    ///
    /// # Panics
    ///
    /// Panics if `coords` does not have the same dimension as the centers.
    pub fn eval(&self, coords: DVector<f64>) -> DVector<f64> {
        let n = self.centers.len();
        // Feature vector laid out like the columns of `deltas`: one kernel
        // value per center, followed by the polynomial terms (if any).
        let features = DVector::from_fn(self.deltas.ncols(), |row, _| {
            if row < n {
                // `metric_distance` avoids allocating a difference vector per center.
                self.basis.eval(coords.metric_distance(&self.centers[row]))
            } else {
                Self::poly_term(&coords, row - n)
            }
        });
        &self.deltas * features
    }

    /// Fits an interpolant through `vals[i]` at `centers[i]` using `basis`.
    ///
    /// `order` selects the polynomial augmentation: `0` uses no polynomial,
    /// `1` adds a constant term, and `2` adds a constant plus one linear term
    /// per coordinate. The augmented system is solved with an SVD
    /// pseudo-inverse (singular values below `1e-6` are treated as zero), so
    /// singular or ill-conditioned systems still produce a least-squares fit.
    ///
    /// # Panics
    ///
    /// * If `centers` is empty.
    /// * If `vals.len() != centers.len()`. A mismatch would otherwise either
    ///   silently zero-pad missing values or push extra values into the
    ///   polynomial constraint rows.
    /// * If `order > 2` (not implemented).
    /// * If the centers do not all share the same dimension.
    /// * If the SVD pseudo-inverse reports an error.
    pub fn create(
        centers: Vec<DVector<f64>>,
        vals: Vec<DVector<f64>>,
        basis: Basis,
        order: usize,
    ) -> Scatter {
        let n = centers.len();
        assert!(n > 0, "Scatter::create requires at least one center");
        assert_eq!(
            vals.len(),
            n,
            "Scatter::create requires exactly one value vector per center"
        );
        let n_aug = n + match order {
            0 => 0,
            1 => 1,
            2 => 1 + centers[0].len(),
            _ => unimplemented!("don't yet support higher order polynomials"),
        };
        // Right-hand side: one row per center, then zero rows for the
        // polynomial side conditions.
        let mut rhs = DMatrix::from_columns(&vals).transpose();
        if n_aug > n {
            rhs = rhs.resize_vertically(n_aug, 0.0);
        }
        // Symmetric block system [[A, P], [P^T, 0]]: A holds the kernel
        // between every pair of centers; P holds each center's polynomial terms.
        let mat = DMatrix::from_fn(n_aug, n_aug, |r, c| match (r < n, c < n) {
            (true, true) => basis.eval(centers[r].metric_distance(&centers[c])),
            (true, false) => Self::poly_term(&centers[r], c - n),
            (false, true) => Self::poly_term(&centers[c], r - n),
            (false, false) => 0.0,
        });
        let svd = SVD::new(mat, true, true);
        let inv = svd.pseudo_inverse(1e-6).expect("error inverting matrix");
        // One row per output component, one column per unknown.
        let deltas = (inv * rhs).transpose();
        Scatter {
            basis,
            centers,
            deltas,
        }
    }

    /// Returns the `k`-th polynomial augmentation term at `point`.
    ///
    /// `k == 0` is the constant term `1.0`; `k >= 1` is the coordinate
    /// `point[k - 1]`. This is shared by `eval` and `create` so both use the
    /// same column layout.
    fn poly_term(point: &DVector<f64>, k: usize) -> f64 {
        if k == 0 {
            1.0
        } else {
            point[k - 1]
        }
    }
}