Skip to main content

gmac/morph/
rbf.rs

1use hologram::{
2    kernels::{
3        cubic_kernel, gaussian_kernel, inverse_multi_kernel, linear_kernel,
4        multiquadric_kernel, thin_plate_spline_kernel,
5    },
6    rbf::Rbf,
7    Interpolator,
8};
9
10use crate::error::{Error, Result};
11
12/// A Radial Basis Function (RBF) deformer for 3D point transformations.
13///
14/// This struct implements a deformable model that can smoothly interpolate between
15/// a set of control points in 3D space. It's particularly useful for mesh deformation,
16/// shape morphing, and other spatial transformations.
17///
18/// # Fields
19/// - `x_mean`: Mean of the input control points for normalization
20/// - `x_std`: Standard deviation of the input control points for normalization
21/// - `y_mean`: Mean of the output control points for denormalization
22/// - `y_std`: Standard deviation of the output control points for denormalization
23/// - `removed_columns`: Indices of dimensions with zero variance in the output
24/// - `rbf`: The underlying RBF interpolator
25pub struct RbfDeformer {
26    x_mean: [f64; 3],
27    x_std: [f64; 3],
28    y_mean: [f64; 3],
29    y_std: [f64; 3],
30    removed_columns: Vec<usize>,
31    rbf: Rbf<[f64; 3], [f64; 3]>,
32}
33
34impl RbfDeformer {
35    /// Creates a new RbfDeformer instance.
36    ///
37    /// # Arguments
38    /// * `x` - Input control points (n×3 array)
39    /// * `y` - Corresponding output control points (n×3 array)
40    /// * `kernel_name` - Name of the kernel function to use (optional, defaults to "gaussian"):
41    ///   - "linear": Linear kernel
42    ///   - "cubic": Cubic kernel
43    ///   - "gaussian": Gaussian kernel (default)
44    ///   - "multiquadric": Multiquadric kernel
45    ///   - "inverse_multiquadratic": Inverse multiquadric kernel
46    ///   - "thin_plate_spline": Thin plate spline kernel
47    /// * `epsilon` - Bandwidth parameter for the kernel (optional, defaults to 1.0)
48    ///
49    /// # Returns
50    /// A new `RbfDeformer` instance or an error if creation fails.
51    pub fn new(
52        x: Vec<[f64; 3]>,
53        y: Vec<[f64; 3]>,
54        kernel_name: Option<&str>,
55        epsilon: Option<f64>,
56    ) -> Result<Self> {
57        assert_eq!(x.len(), y.len(), "x and y must have the same length");
58
59        let epsilon = epsilon.unwrap_or(1.0);
60        let kernel: fn(f64, f64) -> f64 = match kernel_name.unwrap_or("gaussian") {
61            "linear" => linear_kernel,
62            "cubic" => cubic_kernel,
63            "gaussian" => gaussian_kernel,
64            "multiquadric" => multiquadric_kernel,
65            "inverse_multiquadratic" => inverse_multi_kernel,
66            "thin_plate_spline" => thin_plate_spline_kernel,
67            other => {
68                return Err(Error::Deformation(format!("Unsupported kernel: {other}")))
69            }
70        };
71
72        let n = x.len();
73
74        // Compute x mean and std
75        let mut x_mean = [0.0; 3];
76        let mut x_std = [1.0; 3];
77        for d in 0..3 {
78            let mean = x.iter().map(|p| p[d]).sum::<f64>() / n as f64;
79            let std =
80                (x.iter().map(|p| (p[d] - mean).powi(2)).sum::<f64>() / n as f64).sqrt();
81            x_mean[d] = mean;
82            x_std[d] = if std < 1e-8 { 1.0 } else { std };
83        }
84
85        let normalized_x: Vec<[f64; 3]> = x
86            .iter()
87            .map(|p| {
88                let mut np = [0.0; 3];
89                for d in 0..3 {
90                    np[d] = (p[d] - x_mean[d]) / x_std[d];
91                }
92                np
93            })
94            .collect();
95
96        // Normalize y and detect constant columns
97        let mut y_mean = [0.0; 3];
98        let mut y_std = [1.0; 3];
99        let mut removed_columns = Vec::new();
100
101        for d in 0..3 {
102            let mean = y.iter().map(|p| p[d]).sum::<f64>() / n as f64;
103            let std =
104                (y.iter().map(|p| (p[d] - mean).powi(2)).sum::<f64>() / n as f64).sqrt();
105            y_mean[d] = mean;
106            if std < 1e-8 {
107                removed_columns.push(d);
108            } else {
109                y_std[d] = std;
110            }
111        }
112
113        let normalized_y: Vec<[f64; 3]> = y
114            .iter()
115            .map(|p| {
116                let mut np = [0.0; 3];
117                for d in 0..3 {
118                    if !removed_columns.contains(&d) {
119                        np[d] = (p[d] - y_mean[d]) / y_std[d];
120                    }
121                }
122                np
123            })
124            .collect();
125
126        let rbf = Rbf::new(normalized_x, normalized_y, Some(kernel), Some(epsilon))
127            .map_err(|e| Error::Deformation(format!("Failed to create RBF: {e}")))?;
128
129        Ok(Self {
130            x_mean,
131            x_std,
132            y_mean,
133            y_std,
134            removed_columns,
135            rbf,
136        })
137    }
138
139    /// Deforms input points using the learned RBF transformation.
140    ///
141    /// # Arguments
142    /// * `points` - A slice of 3D points to transform
143    ///
144    /// # Returns
145    /// A `Vec` of transformed points with the same length as the input, or an error string
146    /// if the transformation fails.
147    ///
148    /// # Example
149    /// ```
150    /// # use gmac::morph::rbf::RbfDeformer;
151    /// # let x = vec![[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]];
152    /// # let y = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
153    /// # let deformer = RbfDeformer::new(x, y, Some("gaussian"), Some(1.0)).unwrap();
154    /// let points = [[0.5, 0.5, 0.5], [0.2, 0.8, 0.4]];
155    /// let deformed = deformer.deform(&points).unwrap();
156    /// assert_eq!(deformed.len(), 2);
157    /// ```
158    pub fn deform(&self, points: &[[f64; 3]]) -> Result<Vec<[f64; 3]>> {
159        let normalized_input: Vec<[f64; 3]> = points
160            .iter()
161            .map(|p| {
162                let mut np = [0.0; 3];
163                for d in 0..3 {
164                    np[d] = (p[d] - self.x_mean[d]) / self.x_std[d];
165                }
166                np
167            })
168            .collect();
169
170        let normalized_output = self
171            .rbf
172            .predict(&normalized_input)
173            .map_err(|e| Error::Deformation(format!("Prediction failed: {e}")))?;
174
175        let mut result = vec![[0.0; 3]; points.len()];
176        for (i, p) in normalized_output.iter().enumerate() {
177            for d in 0..3 {
178                result[i][d] = if self.removed_columns.contains(&d) {
179                    self.y_mean[d]
180                } else {
181                    p[d] * self.y_std[d] + self.y_mean[d]
182                };
183            }
184        }
185
186        Ok(result)
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Compares `actual` with `expected` pair by pair, checking every coordinate with
    /// `assert_relative_eq!` at `epsilon = 1e-10`. Extra points in the longer slice are
    /// not inspected.
    fn assert_points_close(actual: &[[f64; 3]], expected: &[[f64; 3]]) {
        for (got, want) in actual.iter().zip(expected.iter()) {
            for axis in 0..3 {
                assert_relative_eq!(got[axis], want[axis], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_single_point() {
        let deformer =
            RbfDeformer::new(vec![[1.0, 2.0, 3.0]], vec![[2.0, 3.0, 4.0]], None, None)
                .unwrap();

        // A training point must land exactly on its target.
        let deformed = deformer.deform(&[[1.0, 2.0, 3.0]]).unwrap();
        assert_eq!(deformed[0], [2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_constant_deformation() {
        let sources = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let targets = vec![[10.0; 3]; 2];
        let deformer = RbfDeformer::new(sources, targets, None, None).unwrap();

        // Every target axis is constant, so any query collapses onto [10, 10, 10].
        let deformed = deformer
            .deform(&[[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]])
            .unwrap();
        assert_eq!(deformed, vec![[10.0; 3]; 2]);
    }

    #[test]
    fn test_identity_deformation() {
        let points = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let deformer = RbfDeformer::new(points.clone(), points.clone(), None, None).unwrap();

        // Mapping the control points onto themselves must reproduce them.
        let deformed = deformer.deform(&points).unwrap();
        assert_points_close(&deformed, &points);
    }

    #[test]
    fn test_deform_standard() {
        let sources = vec![[1.0, 2.0, 1.0], [3.0, 4.0, 2.0]];
        let targets = vec![[2.0, 3.0, 2.0], [4.0, 5.0, 3.0]];
        let deformer = RbfDeformer::new(sources, targets, None, None).unwrap();

        let query = vec![[1.5, 2.6, 1.8]];
        let prediction = deformer.deform(&query).unwrap();

        // Reference values for the default (gaussian, epsilon = 1.0) configuration.
        let expected = [2.9073001606088247, 3.9073001606088247, 2.4536500803044126];
        for (axis, &want) in expected.iter().enumerate() {
            assert_relative_eq!(prediction[0][axis], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_different_kernels() {
        const KERNELS: [&str; 4] = [
            "gaussian",
            "multiquadric",
            "inverse_multiquadratic",
            "thin_plate_spline",
        ];
        let points = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        // Every supported kernel must interpolate its own control points exactly.
        for &kernel in KERNELS.iter() {
            let deformer =
                RbfDeformer::new(points.clone(), points.clone(), Some(kernel), None).unwrap();
            let deformed = deformer.deform(&points).unwrap();
            assert_points_close(&deformed, &points);
        }
    }

    #[test]
    #[should_panic(expected = "x and y must have the same length")]
    fn test_mismatched_lengths() {
        let sources = vec![[1.0, 2.0, 3.0]];
        let targets = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        RbfDeformer::new(sources, targets, None, None).unwrap();
    }
}