Skip to main content

gmac/morph/
transforms.rs

1use crate::error::Result;
2
3#[cfg(feature = "rayon")]
4use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
5
6/// Apply an affine transformation to a set of 3D points.
7///
8/// # Arguments
9/// * `points`: A vector of points in 3D space, represented as arrays `[x, y, z]`.
10/// * `affine_weights`: A 4x4 array representing the affine transformation weights.
11///
12/// # Returns
13/// * `Result<Vec<[f64; 3]>>`: A Result containing either:
14/// A vector of transformed points (`Ok`)
15/// An error if something goes wrong (`Err`)
16pub fn apply_affine_transform(
17    points: &[[f64; 3]],
18    affine_weights: &[[f64; 4]; 4],
19) -> Result<Vec<[f64; 3]>> {
20    let transformed_points = {
21        #[cfg(feature = "rayon")]
22        {
23            points.par_iter()
24        }
25        #[cfg(not(feature = "rayon"))]
26        {
27            points.iter()
28        }
29    }
30    .map(|point| {
31        let mut transformed = [0.0; 3];
32        for i in 0..3 {
33            transformed[i] = point[0] * affine_weights[0][i]
34                + point[1] * affine_weights[1][i]
35                + point[2] * affine_weights[2][i]
36                + 1.0 * affine_weights[3][i];
37        }
38        transformed
39    })
40    .collect();
41
42    Ok(transformed_points)
43}
44
45/// Apply a Bernstein transform to a set of 3D points.
46///
47/// # Arguments
48/// * `points`: A vector of points in 3D space, represented as arrays of f64 numbers `[x, y, z]`.
49/// * `deltas`: A vector of delta shifts for each point, also in 3D `[dx, dy, dz]`.
50/// * `resolution`: An array specifying the resolution in each dimension `[res_x, res_y, res_z]`.
51///
52/// # Returns
53///
54/// * `Result<Vec<[f64; 3]>, String>`: A Result containing either:
55/// A vector of transformed points in the same format as the input (`Ok`)
56/// An error message if something goes wrong (`Err`)
57pub fn apply_bernstein_transform(
58    points: &[[f64; 3]],
59    deltas: &[[f64; 3]],
60    resolution: &[usize; 3],
61) -> Result<Vec<[f64; 3]>> {
62    let dimension = [resolution[0] + 1, resolution[1] + 1, resolution[2] + 1];
63
64    // Pre-compute all binomial coefficients once (common logic)
65    let coeffs_x: Vec<f64> = (0..dimension[0])
66        .map(|i| binomial_coefficient(dimension[0] - 1, i))
67        .collect();
68    let coeffs_y: Vec<f64> = (0..dimension[1])
69        .map(|j| binomial_coefficient(dimension[1] - 1, j))
70        .collect();
71    let coeffs_z: Vec<f64> = (0..dimension[2])
72        .map(|k| binomial_coefficient(dimension[2] - 1, k))
73        .collect();
74
75    // Process all points using the appropriate iterator
76    #[cfg(feature = "rayon")]
77    let transformed_points: Vec<[f64; 3]> = points
78        .par_iter()
79        .map(|point| {
80            transform_single_point(
81                point, deltas, &dimension, &coeffs_x, &coeffs_y, &coeffs_z,
82            )
83        })
84        .collect();
85
86    #[cfg(not(feature = "rayon"))]
87    let transformed_points: Vec<[f64; 3]> = points
88        .iter()
89        .map(|point| {
90            transform_single_point(
91                point, deltas, &dimension, &coeffs_x, &coeffs_y, &coeffs_z,
92            )
93        })
94        .collect();
95
96    Ok(transformed_points)
97}
98
/// Displace one point by the Bernstein-weighted sum of the lattice deltas.
///
/// # Arguments
/// * `point`: The point `[x, y, z]`; each coordinate is used directly as the Bernstein
///   parameter for its axis.
/// * `deltas`: Per-node displacements, node `(i, j, k)` at index
///   `i * dimension[1] * dimension[2] + j * dimension[2] + k`.
/// * `dimension`: Number of control nodes per axis (polynomial degree + 1).
/// * `coeffs_x`, `coeffs_y`, `coeffs_z`: Per-axis binomial coefficients, indexed by node.
///
/// # Returns
/// The input point plus the accumulated shift.
///
/// # Panics
/// Panics on out-of-bounds indexing if `deltas` or a coefficient slice is too short for
/// `dimension`.
fn transform_single_point(
    point: &[f64; 3],
    deltas: &[[f64; 3]],
    dimension: &[usize; 3],
    coeffs_x: &[f64],
    coeffs_y: &[f64],
    coeffs_z: &[f64],
) -> [f64; 3] {
    // Evaluate all `count` Bernstein basis polynomials of degree `count - 1` at `t`.
    // The degree is computed per term so that `count == 0` yields an empty basis
    // without underflowing.
    fn basis(t: f64, count: usize, binomials: &[f64]) -> Vec<f64> {
        (0..count)
            .map(|idx| binomials[idx] * (1.0 - t).powi((count - 1 - idx) as i32) * t.powi(idx as i32))
            .collect()
    }

    let [nx, ny, nz] = *dimension;
    let weights_x = basis(point[0], nx, coeffs_x);
    let weights_y = basis(point[1], ny, coeffs_y);
    let weights_z = basis(point[2], nz, coeffs_z);

    let mut shift = [0.0_f64; 3];
    for (i, &wx) in weights_x.iter().enumerate() {
        for (j, &wy) in weights_y.iter().enumerate() {
            // `(wx * wy) * wz` is the same evaluation order as `wx * wy * wz`.
            let wxy = wx * wy;
            // First delta index of the contiguous k-run for lattice row (i, j).
            let row_start = i * ny * nz + j * nz;
            for (k, &wz) in weights_z.iter().enumerate() {
                let weight = wxy * wz;
                let delta = deltas[row_start + k];
                for (acc, component) in shift.iter_mut().zip(delta) {
                    *acc += weight * component;
                }
            }
        }
    }

    let [x, y, z] = *point;
    [x + shift[0], y + shift[1], z + shift[2]]
}
153
/// Compute the binomial coefficient "n choose k" as an `f64`.
///
/// Uses the multiplicative recurrence `C(n, i + 1) = C(n, i) * (n - i) / (i + 1)`. After
/// each step the accumulator equals `C(n, i + 1)`, and the product before the division is
/// an integer divisible by `i + 1`. The result is therefore exact as long as
/// `C(n, i) * (n - i)` stays below 2^53. The symmetry `C(n, k) = C(n, n - k)` is used to
/// take the shorter product.
///
/// # Arguments
/// * `n`: The total number of items.
/// * `k`: The number of items to choose.
///
/// # Returns
/// * `f64`: The computed binomial coefficient, or `0.0` when `k > n`.
fn binomial_coefficient(n: usize, k: usize) -> f64 {
    // Choosing more items than exist is impossible. This also avoids the `n - i` usize
    // underflow the loop would otherwise hit.
    if k > n {
        return 0.0;
    }
    let k = k.min(n - k);
    (0..k).fold(1.0, |acc, i| acc * (n - i) as f64 / (i + 1) as f64)
}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPSILON: f64 = 1e-9;

    #[test]
    fn test_affine_identity() {
        let points = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let identity_matrix = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        let transformed = apply_affine_transform(&points, &identity_matrix).unwrap();
        assert_eq!(
            points, transformed,
            "Identity matrix should not change points"
        );
    }

    #[test]
    fn test_affine_translation() {
        let points = vec![[1.0, 2.0, 3.0]];
        let translation_matrix = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [10.0, -5.0, 2.0, 1.0], // Translate by (10, -5, 2)
        ];
        let transformed = apply_affine_transform(&points, &translation_matrix).unwrap();
        let expected = vec![[11.0, -3.0, 5.0]];

        for i in 0..3 {
            assert!(
                (transformed[0][i] - expected[0][i]).abs() < EPSILON,
                "Translation failed at index {}",
                i
            );
        }
    }

    #[test]
    fn test_affine_row_vector_convention() {
        // A non-symmetric shear distinguishes `point * M` from `M * point`:
        // weights[1][0] = 2 means output x gains 2 * input y.
        let points = vec![[1.0, 2.0, 3.0]];
        let shear = [
            [1.0, 0.0, 0.0, 0.0],
            [2.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        let transformed = apply_affine_transform(&points, &shear).unwrap();
        assert_eq!(transformed, vec![[5.0, 2.0, 3.0]]);

        let empty: Vec<[f64; 3]> = Vec::new();
        assert!(apply_affine_transform(&empty, &shear).unwrap().is_empty());
    }

    #[test]
    fn test_bernstein_zero_deltas() {
        let points = vec![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]];
        let resolution = [2, 2, 2];
        let num_deltas = (resolution[0] + 1) * (resolution[1] + 1) * (resolution[2] + 1);
        let deltas = vec![[0.0, 0.0, 0.0]; num_deltas];

        let transformed =
            apply_bernstein_transform(&points, &deltas, &resolution).unwrap();
        assert_eq!(
            points, transformed,
            "Zero deltas should result in no change"
        );
    }

    #[test]
    fn test_bernstein_simple_linear() {
        let resolution = [1, 1, 1];
        let point = vec![[0.5, 0.5, 0.5]];
        let num_deltas = 2 * 2 * 2;
        let deltas = vec![[1.0, 2.0, 3.0]; num_deltas];

        let transformed =
            apply_bernstein_transform(&point, &deltas, &resolution).unwrap();

        let expected = vec![[0.5 + 1.0, 0.5 + 2.0, 0.5 + 3.0]];

        for i in 0..3 {
            assert!(
                (transformed[0][i] - expected[0][i]).abs() < EPSILON,
                "Linear Bernstein failed at index {}",
                i
            );
        }
    }

    #[test]
    fn test_bernstein_corners_select_lattice_nodes() {
        // With degree 1 per axis, each corner of the unit cube receives exactly the delta of
        // its own lattice node. This pins the (i, j, k) -> i * 4 + j * 2 + k layout of `deltas`,
        // which uniform deltas cannot detect.
        let resolution = [1, 1, 1];
        let deltas: Vec<[f64; 3]> = (0..8).map(|n| [n as f64, 0.0, 0.0]).collect();
        let corners = vec![
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [1.0, 1.0, 1.0],
        ];
        let expected_nodes = [0usize, 1, 2, 4, 7];

        let transformed = apply_bernstein_transform(&corners, &deltas, &resolution).unwrap();

        for ((corner, moved), &node) in corners.iter().zip(&transformed).zip(&expected_nodes) {
            let expected = [corner[0] + node as f64, corner[1], corner[2]];
            assert!(
                moved
                    .iter()
                    .zip(&expected)
                    .all(|(a, b)| (a - b).abs() < EPSILON),
                "corner {:?}: got {:?}, expected {:?}",
                corner,
                moved,
                expected
            );
        }
    }
}