Skip to main content

oxihuman_morph/
pose_driver.rs

// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! RBF-based pose-driven corrective shapes.

/// One example pose together with the corrective deltas to apply at that pose.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct PoseDriverSample {
    /// Driver values describing this example pose; compared positionally with the query pose.
    pub pose_values: Vec<f32>,
    /// Per-vertex offsets (three components each); entry `i` applies to vertex `i`.
    pub deltas: Vec<[f32; 3]>,
    /// Scale applied to this sample's kernel value before any normalization.
    pub weight: f32,
}

14#[allow(dead_code)]
15#[derive(Debug, Clone)]
16pub struct PoseDriverConfig {
17    pub rbf_radius: f32,
18    pub normalize: bool,
19    pub falloff: RbfFalloff,
20}
21
22impl Default for PoseDriverConfig {
23    fn default() -> Self {
24        Self {
25            rbf_radius: 1.0,
26            normalize: true,
27            falloff: RbfFalloff::Gaussian,
28        }
29    }
30}
31
/// Kernel used to convert a pose distance `d` into a raw sample weight.
///
/// The enum carries no data, so it derives `Copy`, `Eq` and `Hash` and can be
/// passed by value, compared, or used as a map key.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RbfFalloff {
    /// `exp(-(d / radius)^2)`; see [`rbf_gaussian`].
    Gaussian,
    /// `1 / (d + eps)` for a small caller-chosen `eps`; see [`rbf_inverse_distance`].
    InverseDistance,
    /// `d^2 * ln(d)`; see [`rbf_thin_plate`]. Not a decaying kernel: it is 0 at
    /// `d = 0`, negative for `0 < d < 1`, and positive and increasing for `d > 1`.
    ThinPlate,
}

40#[allow(dead_code)]
41#[derive(Debug, Clone)]
42pub struct PoseDriver {
43    pub samples: Vec<PoseDriverSample>,
44    pub config: PoseDriverConfig,
45    pub vertex_count: usize,
46}
47
48impl PoseDriver {
49    #[allow(dead_code)]
50    pub fn new(vertex_count: usize, config: PoseDriverConfig) -> Self {
51        Self {
52            samples: Vec::new(),
53            config,
54            vertex_count,
55        }
56    }
57
58    #[allow(dead_code)]
59    pub fn add_sample(&mut self, sample: PoseDriverSample) {
60        self.samples.push(sample);
61    }
62
63    #[allow(dead_code)]
64    pub fn evaluate(&self, pose: &[f32]) -> Vec<[f32; 3]> {
65        let mut result = vec![[0.0f32; 3]; self.vertex_count];
66
67        if self.samples.is_empty() {
68            return result;
69        }
70
71        let radius = self.config.rbf_radius;
72        let mut rbf_weights: Vec<f32> = self
73            .samples
74            .iter()
75            .map(|s| {
76                let dist = pose_distance(pose, &s.pose_values);
77                let rbf = match self.config.falloff {
78                    RbfFalloff::Gaussian => rbf_gaussian(dist, radius),
79                    RbfFalloff::InverseDistance => rbf_inverse_distance(dist, 1e-6),
80                    RbfFalloff::ThinPlate => rbf_thin_plate(dist),
81                };
82                rbf * s.weight
83            })
84            .collect();
85
86        if self.config.normalize {
87            normalize_weights(&mut rbf_weights);
88        }
89
90        for (sample, &w) in self.samples.iter().zip(rbf_weights.iter()) {
91            let deltas = &sample.deltas;
92            let vcount = deltas.len().min(self.vertex_count);
93            for i in 0..vcount {
94                result[i][0] += w * deltas[i][0];
95                result[i][1] += w * deltas[i][1];
96                result[i][2] += w * deltas[i][2];
97            }
98        }
99
100        result
101    }
102}
103
/// Gaussian kernel `exp(-(dist / radius)^2)`.
///
/// For a finite radius this is 1 at `dist == 0` and decays towards 0 as
/// `|dist|` grows. A radius whose magnitude is below `1e-10` is replaced by
/// `1e-10` to avoid dividing by zero. The sign of `radius` does not affect
/// the result.
#[allow(dead_code)]
pub fn rbf_gaussian(dist: f32, radius: f32) -> f32 {
    const MIN_RADIUS: f32 = 1e-10;
    let scale = if radius.abs() < MIN_RADIUS {
        MIN_RADIUS
    } else {
        radius
    };
    let scaled = dist / scale;
    (-(scaled * scaled)).exp()
}

/// Inverse-distance kernel `1 / (dist + eps)`.
///
/// A positive `eps` keeps the result finite when `dist` is zero.
#[allow(dead_code)]
pub fn rbf_inverse_distance(dist: f32, eps: f32) -> f32 {
    let shifted = dist + eps;
    shifted.recip()
}

/// Thin-plate spline kernel `dist^2 * ln(dist + 1e-10)`.
///
/// The `1e-10` offset keeps the logarithm finite at `dist == 0`, where the
/// result is zero. The value is negative for `0 < dist < 1`, zero at
/// `dist == 1`, and grows without bound for larger distances, so it is not a
/// decaying falloff.
#[allow(dead_code)]
pub fn rbf_thin_plate(dist: f32) -> f32 {
    let squared = dist * dist;
    let log_term = (dist + 1e-10_f32).ln();
    squared * log_term
}

/// Euclidean distance between two pose vectors.
///
/// Components are paired by position. If the slices differ in length, the
/// trailing components of the longer one are ignored.
#[allow(dead_code)]
pub fn pose_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| (lhs - rhs) * (lhs - rhs))
        .sum();
    squared_sum.sqrt()
}

/// Rescales `weights` in place so they sum to 1.
///
/// If the magnitude of the sum is below `1e-10`, every entry is set to
/// `1 / len` instead. This happens even when the entries are not all zero
/// (e.g. all tiny, or cancelling signs). An empty slice is left untouched.
/// Mixed-sign weights whose sum is small but above the threshold can be
/// scaled up substantially.
#[allow(dead_code)]
pub fn normalize_weights(weights: &mut [f32]) {
    let total: f32 = weights.iter().sum();
    if total.abs() < 1e-10 {
        // Degenerate total: fall back to an even split across all entries.
        let count = weights.len();
        if count > 0 {
            weights.fill(1.0 / count as f32);
        }
    } else {
        weights.iter_mut().for_each(|w| *w /= total);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rbf_gaussian_zero_distance() {
        let v = rbf_gaussian(0.0, 1.0);
        assert!(
            (v - 1.0).abs() < 1e-6,
            "zero dist should return 1.0, got {v}"
        );
    }

    #[test]
    fn test_rbf_gaussian_decay() {
        let v0 = rbf_gaussian(0.0, 1.0);
        let v1 = rbf_gaussian(1.0, 1.0);
        let v2 = rbf_gaussian(2.0, 1.0);
        assert!(v0 > v1, "should decay with distance");
        assert!(v1 > v2, "should decay with distance");
        // at dist=radius, value = exp(-1)
        let expected = (-1.0_f32).exp();
        assert!((v1 - expected).abs() < 1e-6);
    }

    #[test]
    fn test_rbf_gaussian_large_radius() {
        // with large radius, decay is slow
        let v = rbf_gaussian(1.0, 100.0);
        assert!(v > 0.999, "large radius should give near-1: {v}");
    }

    #[test]
    fn test_rbf_gaussian_zero_radius_is_clamped() {
        // A zero radius is clamped to a tiny positive one instead of dividing by zero.
        assert!((rbf_gaussian(0.0, 0.0) - 1.0).abs() < 1e-6);
        assert!(rbf_gaussian(1.0, 0.0) < 1e-6);
    }

    #[test]
    fn test_rbf_inverse_distance() {
        let v = rbf_inverse_distance(0.0, 1e-6);
        assert!(v > 1e5, "near zero dist should give large value");
        let v1 = rbf_inverse_distance(1.0, 1e-6);
        let v2 = rbf_inverse_distance(2.0, 1e-6);
        assert!(v1 > v2);
    }

    #[test]
    fn test_rbf_thin_plate_zero() {
        let v = rbf_thin_plate(0.0);
        // 0 * 0 * ln(1e-10) — should be near 0
        assert!(v.abs() < 1e-6, "thin plate at zero: {v}");
    }

    #[test]
    fn test_rbf_thin_plate_positive() {
        let v = rbf_thin_plate(2.0);
        let expected = 4.0_f32 * (2.0_f32 + 1e-10_f32).ln();
        assert!((v - expected).abs() < 1e-5);
    }

    #[test]
    fn test_pose_distance_identical() {
        let d = pose_distance(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!(d.abs() < 1e-6);
    }

    #[test]
    fn test_pose_distance_basic() {
        let d = pose_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((d - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_pose_distance_ignores_extra_components() {
        // Components are paired positionally; the longer slice's tail is dropped.
        let d = pose_distance(&[3.0], &[0.0, 100.0]);
        assert!((d - 3.0).abs() < 1e-6, "extra components should be ignored: {d}");
    }

    #[test]
    fn test_normalize_weights_sum_one() {
        let mut w = vec![1.0, 2.0, 3.0, 4.0];
        normalize_weights(&mut w);
        let sum: f32 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_weights_zero_sum() {
        let mut w = vec![0.0, 0.0, 0.0];
        normalize_weights(&mut w);
        let sum: f32 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "zero-sum should uniform: {sum}");
    }

    #[test]
    fn test_normalize_weights_empty_is_noop() {
        let mut w: Vec<f32> = Vec::new();
        normalize_weights(&mut w);
        assert!(w.is_empty());
    }

    #[test]
    fn test_single_sample_evaluate_returns_deltas() {
        let cfg = PoseDriverConfig::default();
        let mut driver = PoseDriver::new(3, cfg);
        let deltas = vec![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        driver.add_sample(PoseDriverSample {
            pose_values: vec![0.0],
            deltas: deltas.clone(),
            weight: 1.0,
        });
        let result = driver.evaluate(&[0.0]);
        // single sample normalized → weight = 1.0, so result == deltas
        for i in 0..3 {
            for j in 0..3 {
                assert!((result[i][j] - deltas[i][j]).abs() < 1e-5);
            }
        }
    }

    #[test]
    fn test_multiple_samples_interpolation() {
        let cfg = PoseDriverConfig {
            rbf_radius: 1.0,
            normalize: true,
            falloff: RbfFalloff::Gaussian,
        };
        let mut driver = PoseDriver::new(1, cfg);
        driver.add_sample(PoseDriverSample {
            pose_values: vec![0.0],
            deltas: vec![[1.0, 0.0, 0.0]],
            weight: 1.0,
        });
        driver.add_sample(PoseDriverSample {
            pose_values: vec![2.0],
            deltas: vec![[0.0, 1.0, 0.0]],
            weight: 1.0,
        });
        // Query at midpoint
        let result = driver.evaluate(&[1.0]);
        // Both have equal distance → equal weights after normalize, i.e. 0.5 each.
        assert!(
            (result[0][0] - result[0][1]).abs() < 1e-4,
            "midpoint interpolation: x={} y={}",
            result[0][0],
            result[0][1]
        );
        assert!((result[0][0] - 0.5).abs() < 1e-6, "x={}", result[0][0]);
        assert!((result[0][1] - 0.5).abs() < 1e-6, "y={}", result[0][1]);
        assert!(result[0][2].abs() < 1e-6, "z={}", result[0][2]);
    }

    #[test]
    fn test_normalized_blend_of_identical_deltas_is_unchanged() {
        // Normalized weights sum to 1, so blending identical deltas reproduces them.
        let mut driver = PoseDriver::new(1, PoseDriverConfig::default());
        for &p in &[0.0_f32, 0.5, 3.0] {
            driver.add_sample(PoseDriverSample {
                pose_values: vec![p],
                deltas: vec![[2.0, 0.0, 0.0]],
                weight: 1.0,
            });
        }
        let result = driver.evaluate(&[0.25]);
        assert!((result[0][0] - 2.0).abs() < 1e-5, "x={}", result[0][0]);
    }

    #[test]
    fn test_evaluate_without_normalize_scales_by_kernel_and_weight() {
        let cfg = PoseDriverConfig {
            rbf_radius: 1.0,
            normalize: false,
            falloff: RbfFalloff::Gaussian,
        };
        let mut driver = PoseDriver::new(1, cfg);
        driver.add_sample(PoseDriverSample {
            pose_values: vec![0.0],
            deltas: vec![[1.0, 2.0, 3.0]],
            weight: 2.0,
        });
        let result = driver.evaluate(&[1.0]);
        // Gaussian at dist == radius is exp(-1); the sample weight doubles it.
        let w = 2.0 * (-1.0_f32).exp();
        assert!((result[0][0] - w).abs() < 1e-5, "x={}", result[0][0]);
        assert!((result[0][1] - 2.0 * w).abs() < 1e-5, "y={}", result[0][1]);
        assert!((result[0][2] - 3.0 * w).abs() < 1e-5, "z={}", result[0][2]);
    }

    #[test]
    fn test_evaluate_handles_delta_length_mismatch() {
        // Fewer deltas than vertices: the remaining vertices stay at zero.
        let mut driver = PoseDriver::new(3, PoseDriverConfig::default());
        driver.add_sample(PoseDriverSample {
            pose_values: vec![0.0],
            deltas: vec![[1.0, 1.0, 1.0]],
            weight: 1.0,
        });
        let result = driver.evaluate(&[0.0]);
        assert_eq!(result.len(), 3);
        for j in 0..3 {
            assert!((result[0][j] - 1.0).abs() < 1e-6);
        }
        assert_eq!(result[1], [0.0, 0.0, 0.0]);
        assert_eq!(result[2], [0.0, 0.0, 0.0]);

        // More deltas than vertices: the output length still follows vertex_count.
        let mut small = PoseDriver::new(1, PoseDriverConfig::default());
        small.add_sample(PoseDriverSample {
            pose_values: vec![0.0],
            deltas: vec![[1.0, 0.0, 0.0], [5.0, 5.0, 5.0]],
            weight: 1.0,
        });
        assert_eq!(small.evaluate(&[0.0]).len(), 1);
    }

    #[test]
    fn test_evaluate_empty_samples() {
        let cfg = PoseDriverConfig::default();
        let driver = PoseDriver::new(2, cfg);
        let result = driver.evaluate(&[0.0, 1.0]);
        assert_eq!(result.len(), 2);
        for v in &result {
            assert_eq!(*v, [0.0, 0.0, 0.0]);
        }
    }

    #[test]
    fn test_inverse_distance_falloff() {
        let cfg = PoseDriverConfig {
            rbf_radius: 1.0,
            normalize: true,
            falloff: RbfFalloff::InverseDistance,
        };
        let mut driver = PoseDriver::new(1, cfg);
        driver.add_sample(PoseDriverSample {
            pose_values: vec![0.0],
            deltas: vec![[1.0, 0.0, 0.0]],
            weight: 1.0,
        });
        driver.add_sample(PoseDriverSample {
            pose_values: vec![10.0],
            deltas: vec![[0.0, 1.0, 0.0]],
            weight: 1.0,
        });
        let result = driver.evaluate(&[0.0]);
        // Near sample 0 → result[0][0] >> result[0][1]
        assert!(
            result[0][0] > result[0][1],
            "inverse: close sample should dominate"
        );
    }

    #[test]
    fn test_thin_plate_falloff_evaluate() {
        let cfg = PoseDriverConfig {
            rbf_radius: 1.0,
            normalize: true,
            falloff: RbfFalloff::ThinPlate,
        };
        let mut driver = PoseDriver::new(1, cfg);
        driver.add_sample(PoseDriverSample {
            pose_values: vec![0.0],
            deltas: vec![[2.0, 0.0, 0.0]],
            weight: 1.0,
        });
        let result = driver.evaluate(&[0.0]);
        // rbf_thin_plate(0) is zero, so the weight sum is zero and
        // normalize_weights falls back to a uniform weight of 1 for the single sample.
        assert!(
            (result[0][0] - 2.0).abs() < 1e-4,
            "thin plate single: {}",
            result[0][0]
        );
    }
}