1use hologram::{
2 kernels::{
3 cubic_kernel, gaussian_kernel, inverse_multi_kernel, linear_kernel,
4 multiquadric_kernel, thin_plate_spline_kernel,
5 },
6 rbf::Rbf,
7 Interpolator,
8};
9
10use crate::error::{Error, Result};
11
12pub struct RbfDeformer {
26 x_mean: [f64; 3],
27 x_std: [f64; 3],
28 y_mean: [f64; 3],
29 y_std: [f64; 3],
30 removed_columns: Vec<usize>,
31 rbf: Rbf<[f64; 3], [f64; 3]>,
32}
33
34impl RbfDeformer {
35 pub fn new(
52 x: Vec<[f64; 3]>,
53 y: Vec<[f64; 3]>,
54 kernel_name: Option<&str>,
55 epsilon: Option<f64>,
56 ) -> Result<Self> {
57 assert_eq!(x.len(), y.len(), "x and y must have the same length");
58
59 let epsilon = epsilon.unwrap_or(1.0);
60 let kernel: fn(f64, f64) -> f64 = match kernel_name.unwrap_or("gaussian") {
61 "linear" => linear_kernel,
62 "cubic" => cubic_kernel,
63 "gaussian" => gaussian_kernel,
64 "multiquadric" => multiquadric_kernel,
65 "inverse_multiquadratic" => inverse_multi_kernel,
66 "thin_plate_spline" => thin_plate_spline_kernel,
67 other => {
68 return Err(Error::Deformation(format!("Unsupported kernel: {other}")))
69 }
70 };
71
72 let n = x.len();
73
74 let mut x_mean = [0.0; 3];
76 let mut x_std = [1.0; 3];
77 for d in 0..3 {
78 let mean = x.iter().map(|p| p[d]).sum::<f64>() / n as f64;
79 let std =
80 (x.iter().map(|p| (p[d] - mean).powi(2)).sum::<f64>() / n as f64).sqrt();
81 x_mean[d] = mean;
82 x_std[d] = if std < 1e-8 { 1.0 } else { std };
83 }
84
85 let normalized_x: Vec<[f64; 3]> = x
86 .iter()
87 .map(|p| {
88 let mut np = [0.0; 3];
89 for d in 0..3 {
90 np[d] = (p[d] - x_mean[d]) / x_std[d];
91 }
92 np
93 })
94 .collect();
95
96 let mut y_mean = [0.0; 3];
98 let mut y_std = [1.0; 3];
99 let mut removed_columns = Vec::new();
100
101 for d in 0..3 {
102 let mean = y.iter().map(|p| p[d]).sum::<f64>() / n as f64;
103 let std =
104 (y.iter().map(|p| (p[d] - mean).powi(2)).sum::<f64>() / n as f64).sqrt();
105 y_mean[d] = mean;
106 if std < 1e-8 {
107 removed_columns.push(d);
108 } else {
109 y_std[d] = std;
110 }
111 }
112
113 let normalized_y: Vec<[f64; 3]> = y
114 .iter()
115 .map(|p| {
116 let mut np = [0.0; 3];
117 for d in 0..3 {
118 if !removed_columns.contains(&d) {
119 np[d] = (p[d] - y_mean[d]) / y_std[d];
120 }
121 }
122 np
123 })
124 .collect();
125
126 let rbf = Rbf::new(normalized_x, normalized_y, Some(kernel), Some(epsilon))
127 .map_err(|e| Error::Deformation(format!("Failed to create RBF: {e}")))?;
128
129 Ok(Self {
130 x_mean,
131 x_std,
132 y_mean,
133 y_std,
134 removed_columns,
135 rbf,
136 })
137 }
138
139 pub fn deform(&self, points: &[[f64; 3]]) -> Result<Vec<[f64; 3]>> {
159 let normalized_input: Vec<[f64; 3]> = points
160 .iter()
161 .map(|p| {
162 let mut np = [0.0; 3];
163 for d in 0..3 {
164 np[d] = (p[d] - self.x_mean[d]) / self.x_std[d];
165 }
166 np
167 })
168 .collect();
169
170 let normalized_output = self
171 .rbf
172 .predict(&normalized_input)
173 .map_err(|e| Error::Deformation(format!("Prediction failed: {e}")))?;
174
175 let mut result = vec![[0.0; 3]; points.len()];
176 for (i, p) in normalized_output.iter().enumerate() {
177 for d in 0..3 {
178 result[i][d] = if self.removed_columns.contains(&d) {
179 self.y_mean[d]
180 } else {
181 p[d] * self.y_std[d] + self.y_mean[d]
182 };
183 }
184 }
185
186 Ok(result)
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_single_point() {
        let rbf =
            RbfDeformer::new(vec![[1.0, 2.0, 3.0]], vec![[2.0, 3.0, 4.0]], None, None)
                .unwrap();

        let result = rbf.deform(&[[1.0, 2.0, 3.0]]).unwrap();
        // With a single point every output axis is constant, so the target is
        // reproduced exactly.
        assert_eq!(result[0], [2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_constant_deformation() {
        let original = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let deformed = vec![[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]];
        let rbf = RbfDeformer::new(original, deformed, None, None).unwrap();

        let result = rbf.deform(&[[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]).unwrap();
        assert_eq!(result, vec![[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]]);
    }

    #[test]
    fn test_identity_deformation() {
        let points = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let rbf = RbfDeformer::new(points.clone(), points.clone(), None, None).unwrap();

        let result = rbf.deform(&points).unwrap();
        for (res, pt) in result.iter().zip(points.iter()) {
            assert_relative_eq!(res[0], pt[0], epsilon = 1e-10);
            assert_relative_eq!(res[1], pt[1], epsilon = 1e-10);
            assert_relative_eq!(res[2], pt[2], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_deform_standard() {
        let rbf = RbfDeformer::new(
            vec![[1.0, 2.0, 1.0], [3.0, 4.0, 2.0]],
            vec![[2.0, 3.0, 2.0], [4.0, 5.0, 3.0]],
            None,
            None,
        )
        .unwrap();

        let x_new = vec![[1.5, 2.6, 1.8]];
        let prediction = rbf.deform(&x_new).unwrap();

        // Reference values for the default Gaussian kernel with epsilon = 1.0.
        assert_relative_eq!(prediction[0][0], 2.9073001606088247, epsilon = 1e-10);
        assert_relative_eq!(prediction[0][1], 3.9073001606088247, epsilon = 1e-10);
        assert_relative_eq!(prediction[0][2], 2.4536500803044126, epsilon = 1e-10);
    }

    #[test]
    fn test_different_kernels() {
        let points = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        for kernel in &[
            "gaussian",
            "multiquadric",
            "inverse_multiquadratic",
            "thin_plate_spline",
        ] {
            let rbf =
                RbfDeformer::new(points.clone(), points.clone(), Some(*kernel), None)
                    .unwrap();

            let result = rbf.deform(&points).unwrap();
            for (res, pt) in result.iter().zip(points.iter()) {
                assert_relative_eq!(res[0], pt[0], epsilon = 1e-10);
                assert_relative_eq!(res[1], pt[1], epsilon = 1e-10);
                assert_relative_eq!(res[2], pt[2], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_unsupported_kernel() {
        let points = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let result = RbfDeformer::new(points.clone(), points, Some("not_a_kernel"), None);
        assert!(matches!(
            &result,
            Err(Error::Deformation(msg)) if msg.contains("not_a_kernel")
        ));
    }

    #[test]
    fn test_constant_output_axis_is_reproduced_exactly() {
        let original = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let deformed = vec![[0.0, 0.0, 5.0], [1.0, 1.0, 5.0], [2.0, 0.0, 5.0]];
        let rbf = RbfDeformer::new(original.clone(), deformed.clone(), None, None).unwrap();

        // The constant z target is returned verbatim for arbitrary queries.
        let result = rbf.deform(&[[0.3, 0.2, 0.0], [0.7, 0.9, 0.0]]).unwrap();
        for p in &result {
            assert_eq!(p[2], 5.0);
        }

        // The varying axes still interpolate the control points.
        let at_controls = rbf.deform(&original).unwrap();
        for (res, target) in at_controls.iter().zip(&deformed) {
            assert_relative_eq!(res[0], target[0], epsilon = 1e-8);
            assert_relative_eq!(res[1], target[1], epsilon = 1e-8);
            assert_eq!(res[2], 5.0);
        }
    }

    #[test]
    #[should_panic(expected = "x and y must have the same length")]
    fn test_mismatched_lengths() {
        RbfDeformer::new(
            vec![[1.0, 2.0, 3.0]],
            vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
            None,
            None,
        )
        .unwrap();
    }
}