Skip to main content

oxihuman_morph/
pose_interpolation.rs

1//! Advanced pose interpolation: SQUAD, cubic Hermite, tension-continuity-bias (TCB).
2
/// A single keyframe on a [`PoseCurve`].
///
/// `pose` holds flat joint rotation values. `in_tangent` / `out_tangent` are the
/// per-component Hermite tangents used by the segment arriving at / leaving this
/// key; they may be left empty until computed (e.g. by `compute_cubic_tangents`).
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Default)]
pub struct PoseKey {
    /// Key time; keys on a curve are expected in ascending order.
    pub time: f32,
    /// Flat joint rotations.
    pub pose: Vec<f32>,
    /// Tangent used by the segment that ends at this key.
    pub in_tangent: Vec<f32>,
    /// Tangent used by the segment that starts at this key.
    pub out_tangent: Vec<f32>,
}

/// A time-ordered sequence of pose keys plus the interpolation mode used to
/// sample between them.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct PoseCurve {
    /// Keys sorted by ascending `time`.
    pub keys: Vec<PoseKey>,
    /// How values between two keys are computed.
    pub interpolation: InterpMode,
}

/// Interpolation mode of a [`PoseCurve`].
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterpMode {
    /// Component-wise linear blend between neighbouring keys.
    Linear,
    /// Component-wise cubic Hermite using the keys' stored tangents.
    Cubic,
    /// SQUAD mode; pose sampling currently evaluates it like `Cubic`.
    Squad,
    /// Tension/continuity/bias mode; pose sampling evaluates it like `Cubic`
    /// with TCB-derived tangents stored on the keys.
    Tcb,
}

/// Kochanek–Bartels parameters consumed by `tcb_tangents`.
///
/// Each value is conventionally in `[-1, 1]`. The all-zero `Default` yields
/// Catmull-Rom-like tangents `0.5 * (dp + dn)`; `tension == 1.0` zeroes the
/// tangents entirely.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct TcbParams {
    /// Scales tangent magnitude by `1 - tension`.
    pub tension: f32,
    /// Controls how sharply the tangent changes across the key.
    pub continuity: f32,
    /// Weights the previous versus the next segment.
    pub bias: f32,
}
31
32// ── Core lerp ────────────────────────────────────────────────────────────────
33
/// Linearly blends two flat poses component by component.
///
/// Each output component is `a[i] + (b[i] - a[i]) * t`. Only indices present in
/// both slices are produced; extra trailing components are ignored. `t` is not
/// clamped, so values outside `[0, 1]` extrapolate.
#[allow(dead_code)]
pub fn lerp_poses(a: &[f32], b: &[f32], t: f32) -> Vec<f32> {
    a.iter()
        .zip(b.iter())
        .map(|(&from, &to)| from + (to - from) * t)
        .collect()
}
39
40// ── Cubic Hermite ─────────────────────────────────────────────────────────────
41
/// Evaluates the scalar cubic Hermite polynomial between `p0` and `p1`.
///
/// `m0` is the tangent at `p0`, `m1` the tangent at `p1`, both expressed per unit
/// of the parameter `t` (the curve passes through `p0` at `t = 0` and `p1` at `t = 1`).
#[allow(dead_code)]
pub fn cubic_hermite_interp(p0: f32, p1: f32, m0: f32, m1: f32, t: f32) -> f32 {
    let sq = t * t;
    let cube = sq * t;
    // Standard Hermite basis functions.
    let h00 = 2.0 * cube - 3.0 * sq + 1.0;
    let h10 = cube - 2.0 * sq + t;
    let h01 = -2.0 * cube + 3.0 * sq;
    let h11 = cube - sq;
    h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
}
51
52#[allow(dead_code)]
53pub fn cubic_hermite_pose(a: &[f32], b: &[f32], ta: &[f32], tb: &[f32], t: f32) -> Vec<f32> {
54    let len = a.len().min(b.len()).min(ta.len()).min(tb.len());
55    (0..len)
56        .map(|i| cubic_hermite_interp(a[i], b[i], ta[i], tb[i], t))
57        .collect()
58}
59
60// ── Quaternion utilities ───────────────────────────────────────────────────────
61
/// Four-component dot product of two quaternions stored as `[x, y, z, w]`.
#[allow(dead_code)]
pub fn quat_dot(a: [f32; 4], b: [f32; 4]) -> f32 {
    let [ax, ay, az, aw] = a;
    let [bx, by, bz, bw] = b;
    ax * bx + ay * by + az * bz + aw * bw
}
66
/// Scales `q` to unit length.
///
/// A quaternion whose length is below `1e-9` has no usable direction and is
/// replaced by the identity `[0, 0, 0, 1]`.
#[allow(dead_code)]
pub fn normalize_quat(q: [f32; 4]) -> [f32; 4] {
    let [x, y, z, w] = q;
    let norm = (x * x + y * y + z * z + w * w).sqrt();
    if norm < 1e-9 {
        return [0.0, 0.0, 0.0, 1.0];
    }
    [x / norm, y / norm, z / norm, w / norm]
}
76
77#[allow(dead_code)]
78pub fn quat_slerp_interp(a: [f32; 4], b: [f32; 4], t: f32) -> [f32; 4] {
79    let mut dot = quat_dot(a, b);
80    let bq = if dot < 0.0 {
81        dot = -dot;
82        [-b[0], -b[1], -b[2], -b[3]]
83    } else {
84        b
85    };
86
87    if dot > 0.9995 {
88        let r = [
89            a[0] + t * (bq[0] - a[0]),
90            a[1] + t * (bq[1] - a[1]),
91            a[2] + t * (bq[2] - a[2]),
92            a[3] + t * (bq[3] - a[3]),
93        ];
94        normalize_quat(r)
95    } else {
96        let theta_0 = dot.acos();
97        let theta = theta_0 * t;
98        let sin_theta = theta.sin();
99        let sin_theta_0 = theta_0.sin();
100        let s0 = (theta_0 * (1.0 - t)).sin() / sin_theta_0;
101        let s1 = sin_theta / sin_theta_0;
102        normalize_quat([
103            s0 * a[0] + s1 * bq[0],
104            s0 * a[1] + s1 * bq[1],
105            s0 * a[2] + s1 * bq[2],
106            s0 * a[3] + s1 * bq[3],
107        ])
108    }
109}
110
/// Hamilton product `a * b` of two quaternions stored as `[x, y, z, w]`.
///
/// The product is not commutative: `a * b` and `b * a` generally differ.
#[allow(dead_code)]
pub fn quat_multiply(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    let [ax, ay, az, aw] = a;
    let [bx, by, bz, bw] = b;
    let x = aw * bx + ax * bw + ay * bz - az * by;
    let y = aw * by - ax * bz + ay * bw + az * bx;
    let z = aw * bz + ax * by - ay * bx + az * bw;
    let w = aw * bw - ax * bx - ay * by - az * bz;
    [x, y, z, w]
}
123
/// Conjugate of `[x, y, z, w]`: the vector part is negated and `w` is kept.
/// For unit quaternions this is the inverse.
fn quat_conjugate(q: [f32; 4]) -> [f32; 4] {
    let [x, y, z, w] = q;
    [-x, -y, -z, w]
}
127
/// Logarithm of a quaternion `[x, y, z, w]` (not necessarily normalized).
///
/// Returns the pure quaternion `[theta*ux, theta*uy, theta*uz, 0]`, where `u` is
/// the unit vector part and `theta` the half-angle.
///
/// `theta` is computed as `atan2(|v|, w)` rather than `acos(w)`. In `f32`, `w`
/// of a very small rotation (half-angle below roughly 2.4e-4 rad) rounds to
/// exactly `1.0`, so `acos` would collapse the rotation to zero; `atan2` keeps
/// full relative precision. Both `atan2` and the division by `|v|` are scale
/// invariant, so no prior normalization is needed.
///
/// A quaternion with a zero vector part (identity, its negation, or the zero
/// quaternion) has no rotation axis and maps to the zero quaternion.
fn quat_log(q: [f32; 4]) -> [f32; 4] {
    let [x, y, z, w] = q;
    let v_len = (x * x + y * y + z * z).sqrt();
    if v_len == 0.0 {
        return [0.0, 0.0, 0.0, 0.0];
    }
    let half_angle = v_len.atan2(w);
    // Divide before multiplying so a tiny |v| cannot overflow the scale factor.
    [
        x / v_len * half_angle,
        y / v_len * half_angle,
        z / v_len * half_angle,
        0.0,
    ]
}
140
141fn quat_exp(q: [f32; 4]) -> [f32; 4] {
142    let theta = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2]).sqrt();
143    let sin_theta = theta.sin();
144    let cos_theta = theta.cos();
145    if theta < 1e-9 {
146        normalize_quat([0.0, 0.0, 0.0, cos_theta])
147    } else {
148        let s = sin_theta / theta;
149        normalize_quat([q[0] * s, q[1] * s, q[2] * s, cos_theta])
150    }
151}
152
153// ── SQUAD ─────────────────────────────────────────────────────────────────────
154
155#[allow(dead_code)]
156pub fn squad_intermediate(q_prev: [f32; 4], q_curr: [f32; 4], q_next: [f32; 4]) -> [f32; 4] {
157    // s_i = q_i * exp(-(log(q_i^-1 * q_{i+1}) + log(q_i^-1 * q_{i-1})) / 4)
158    let q_inv = quat_conjugate(q_curr);
159    let log1 = quat_log(quat_multiply(q_inv, q_next));
160    let log2 = quat_log(quat_multiply(q_inv, q_prev));
161    let sum = [
162        -(log1[0] + log2[0]) / 4.0,
163        -(log1[1] + log2[1]) / 4.0,
164        -(log1[2] + log2[2]) / 4.0,
165        -(log1[3] + log2[3]) / 4.0,
166    ];
167    normalize_quat(quat_multiply(q_curr, quat_exp(sum)))
168}
169
170#[allow(dead_code)]
171pub fn squad_quat(q0: [f32; 4], q1: [f32; 4], s0: [f32; 4], s1: [f32; 4], t: f32) -> [f32; 4] {
172    let slerp_q = quat_slerp_interp(q0, q1, t);
173    let slerp_s = quat_slerp_interp(s0, s1, t);
174    quat_slerp_interp(slerp_q, slerp_s, 2.0 * t * (1.0 - t))
175}
176
177// ── TCB tangents ──────────────────────────────────────────────────────────────
178
179#[allow(dead_code)]
180pub fn tcb_tangents(keys: &[PoseKey], idx: usize, params: &TcbParams) -> (Vec<f32>, Vec<f32>) {
181    let n = keys.len();
182    if n == 0 {
183        return (Vec::new(), Vec::new());
184    }
185    let dim = keys[idx].pose.len();
186    let (tc, c, b) = (params.tension, params.continuity, params.bias);
187
188    if n == 1 || idx == 0 {
189        return (Vec::new(), Vec::new());
190    }
191
192    let prev = if idx > 0 { idx - 1 } else { 0 };
193    let next = if idx + 1 < n { idx + 1 } else { idx };
194
195    // TCB incoming tangent: (1-t)(1+c)(1+b)/2 * (p[i]-p[i-1]) + (1-t)(1-c)(1-b)/2 * (p[i+1]-p[i])
196    // TCB outgoing tangent: (1-t)(1+c)(1-b)/2 * (p[i]-p[i-1]) + (1-t)(1-c)(1+b)/2 * (p[i+1]-p[i])
197    let a_in = (1.0 - tc) * (1.0 + c) * (1.0 + b) / 2.0;
198    let b_in = (1.0 - tc) * (1.0 - c) * (1.0 - b) / 2.0;
199    let a_out = (1.0 - tc) * (1.0 + c) * (1.0 - b) / 2.0;
200    let b_out = (1.0 - tc) * (1.0 - c) * (1.0 + b) / 2.0;
201
202    let in_t: Vec<f32> = (0..dim)
203        .map(|d| {
204            let dp = keys[idx].pose[d] - keys[prev].pose[d];
205            let dn = keys[next].pose[d] - keys[idx].pose[d];
206            a_in * dp + b_in * dn
207        })
208        .collect();
209    let out_t: Vec<f32> = (0..dim)
210        .map(|d| {
211            let dp = keys[idx].pose[d] - keys[prev].pose[d];
212            let dn = keys[next].pose[d] - keys[idx].pose[d];
213            a_out * dp + b_out * dn
214        })
215        .collect();
216
217    (in_t, out_t)
218}
219
220// ── Curve operations ──────────────────────────────────────────────────────────
221
222#[allow(dead_code)]
223pub fn sample_pose_curve(curve: &PoseCurve, time: f32) -> Vec<f32> {
224    let keys = &curve.keys;
225    if keys.is_empty() {
226        return Vec::new();
227    }
228    if keys.len() == 1 {
229        return keys[0].pose.clone();
230    }
231
232    // Clamp to range
233    if time <= keys[0].time {
234        return keys[0].pose.clone();
235    }
236    if time >= keys[keys.len() - 1].time {
237        return keys[keys.len() - 1].pose.clone();
238    }
239
240    // Find bracket
241    let idx = keys
242        .windows(2)
243        .position(|w| time >= w[0].time && time < w[1].time)
244        .unwrap_or(0);
245
246    let k0 = &keys[idx];
247    let k1 = &keys[idx + 1];
248    let dt = k1.time - k0.time;
249    let t = if dt.abs() < 1e-9 {
250        0.0
251    } else {
252        (time - k0.time) / dt
253    };
254
255    match curve.interpolation {
256        InterpMode::Linear => lerp_poses(&k0.pose, &k1.pose, t),
257        InterpMode::Cubic => {
258            cubic_hermite_pose(&k0.pose, &k1.pose, &k0.out_tangent, &k1.in_tangent, t)
259        }
260        InterpMode::Squad | InterpMode::Tcb => {
261            cubic_hermite_pose(&k0.pose, &k1.pose, &k0.out_tangent, &k1.in_tangent, t)
262        }
263    }
264}
265
266#[allow(dead_code)]
267pub fn compute_cubic_tangents(keys: &mut [PoseKey]) {
268    let n = keys.len();
269    if n < 2 {
270        return;
271    }
272
273    // Catmull-Rom: compute tangents for each key
274    let poses: Vec<Vec<f32>> = keys.iter().map(|k| k.pose.clone()).collect();
275
276    for i in 0..n {
277        let dim = poses[i].len();
278        let tangent: Vec<f32> = (0..dim)
279            .map(|d| {
280                let prev = if i > 0 { poses[i - 1][d] } else { poses[i][d] };
281                let next = if i + 1 < n {
282                    poses[i + 1][d]
283                } else {
284                    poses[i][d]
285                };
286                0.5 * (next - prev)
287            })
288            .collect();
289        keys[i].in_tangent = tangent.clone();
290        keys[i].out_tangent = tangent;
291    }
292}
293
294#[allow(dead_code)]
295pub fn curve_duration(curve: &PoseCurve) -> f32 {
296    if curve.keys.is_empty() {
297        return 0.0;
298    }
299    let first = curve.keys[0].time;
300    let last = curve.keys[curve.keys.len() - 1].time;
301    last - first
302}
303
304#[allow(dead_code)]
305pub fn add_pose_key(curve: &mut PoseCurve, key: PoseKey) {
306    // Insert in sorted order by time
307    let pos = curve.keys.partition_point(|k| k.time <= key.time);
308    curve.keys.insert(pos, key);
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    fn id_quat() -> [f32; 4] {
        [0.0, 0.0, 0.0, 1.0]
    }

    #[test]
    fn test_lerp_at_t0() {
        let a = vec![0.0, 1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        let r = lerp_poses(&a, &b, 0.0);
        assert_eq!(r, a);
    }

    #[test]
    fn test_lerp_at_t1() {
        let a = vec![0.0, 1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        let r = lerp_poses(&a, &b, 1.0);
        assert_eq!(r, b);
    }

    #[test]
    fn test_lerp_at_half() {
        let a = vec![0.0, 0.0];
        let b = vec![2.0, 4.0];
        let r = lerp_poses(&a, &b, 0.5);
        assert!((r[0] - 1.0).abs() < 1e-6);
        assert!((r[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_cubic_hermite_at_t0() {
        let v = cubic_hermite_interp(1.0, 2.0, 0.0, 0.0, 0.0);
        assert!((v - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cubic_hermite_at_t1() {
        let v = cubic_hermite_interp(1.0, 2.0, 0.0, 0.0, 1.0);
        assert!((v - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_cubic_hermite_pose_length() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 1.0, 1.0];
        let ta = vec![0.5, 0.5, 0.5];
        let tb = vec![0.5, 0.5, 0.5];
        let r = cubic_hermite_pose(&a, &b, &ta, &tb, 0.5);
        assert_eq!(r.len(), 3);
    }

    #[test]
    fn test_squad_returns_normalized() {
        let q0 = id_quat();
        let q1 = [0.0, 0.0, 0.707, 0.707];
        let s0 = squad_intermediate(q0, q0, q1);
        let s1 = squad_intermediate(q0, q1, q0);
        let result = squad_quat(q0, q1, s0, s1, 0.5);
        let len = result.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((len - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_squad_intermediate_constant_keys_is_identity() {
        // With no rotation between neighbours, the control point equals the key.
        let q = id_quat();
        let s = squad_intermediate(q, q, q);
        for i in 0..4 {
            assert!((s[i] - q[i]).abs() < 1e-6);
        }
    }

    #[test]
    fn test_quat_slerp_at_endpoints() {
        let a = id_quat();
        // Use a properly normalized quaternion for b
        let b = [
            0.0_f32,
            0.0,
            std::f32::consts::FRAC_1_SQRT_2,
            std::f32::consts::FRAC_1_SQRT_2,
        ];
        let r0 = quat_slerp_interp(a, b, 0.0);
        let r1 = quat_slerp_interp(a, b, 1.0);
        for i in 0..4 {
            assert!((r0[i] - a[i]).abs() < 1e-3);
            assert!((r1[i] - b[i]).abs() < 1e-3);
        }
    }

    #[test]
    fn test_normalize_quat() {
        let q = [1.0, 0.0, 0.0, 0.0];
        let n = normalize_quat(q);
        let len: f32 = n.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((len - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_quat_dot() {
        let q = id_quat();
        assert!((quat_dot(q, q) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_quat_multiply_identity() {
        let q = [0.0, 0.0, 0.5, 0.866];
        let id = id_quat();
        let r = quat_multiply(q, id);
        for i in 0..4 {
            assert!((r[i] - q[i]).abs() < 1e-5);
        }
    }

    #[test]
    fn test_sample_curve_before_start() {
        // A single-key curve always returns that key's pose.
        let key = PoseKey {
            time: 1.0,
            pose: vec![1.0, 2.0],
            in_tangent: vec![0.0, 0.0],
            out_tangent: vec![0.0, 0.0],
        };
        let curve = PoseCurve {
            keys: vec![key],
            interpolation: InterpMode::Linear,
        };
        let result = sample_pose_curve(&curve, 0.0);
        assert_eq!(result, vec![1.0, 2.0]);
    }

    #[test]
    fn test_sample_curve_clamps_outside_range() {
        let curve = PoseCurve {
            keys: vec![
                PoseKey {
                    time: 1.0,
                    pose: vec![1.0],
                    in_tangent: vec![0.0],
                    out_tangent: vec![0.0],
                },
                PoseKey {
                    time: 3.0,
                    pose: vec![5.0],
                    in_tangent: vec![0.0],
                    out_tangent: vec![0.0],
                },
            ],
            interpolation: InterpMode::Linear,
        };
        assert_eq!(sample_pose_curve(&curve, 0.0), vec![1.0]);
        assert_eq!(sample_pose_curve(&curve, 10.0), vec![5.0]);
        let mid = sample_pose_curve(&curve, 2.0);
        assert!((mid[0] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_sample_curve_linear() {
        let k0 = PoseKey {
            time: 0.0,
            pose: vec![0.0],
            in_tangent: vec![0.0],
            out_tangent: vec![0.0],
        };
        let k1 = PoseKey {
            time: 1.0,
            pose: vec![1.0],
            in_tangent: vec![0.0],
            out_tangent: vec![0.0],
        };
        let curve = PoseCurve {
            keys: vec![k0, k1],
            interpolation: InterpMode::Linear,
        };
        let r = sample_pose_curve(&curve, 0.5);
        assert!((r[0] - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_curve_duration() {
        let k0 = PoseKey {
            time: 0.0,
            pose: vec![0.0],
            in_tangent: vec![],
            out_tangent: vec![],
        };
        let k1 = PoseKey {
            time: 2.0,
            pose: vec![1.0],
            in_tangent: vec![],
            out_tangent: vec![],
        };
        let curve = PoseCurve {
            keys: vec![k0, k1],
            interpolation: InterpMode::Linear,
        };
        assert!((curve_duration(&curve) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_add_pose_key_sorted() {
        let mut curve = PoseCurve {
            keys: Vec::new(),
            interpolation: InterpMode::Linear,
        };
        add_pose_key(
            &mut curve,
            PoseKey {
                time: 1.0,
                pose: vec![1.0],
                in_tangent: vec![],
                out_tangent: vec![],
            },
        );
        add_pose_key(
            &mut curve,
            PoseKey {
                time: 0.0,
                pose: vec![0.0],
                in_tangent: vec![],
                out_tangent: vec![],
            },
        );
        assert!(curve.keys[0].time <= curve.keys[1].time);
    }

    #[test]
    fn test_tcb_tangents_single_key() {
        let key = PoseKey {
            time: 0.0,
            pose: vec![1.0, 2.0],
            in_tangent: vec![],
            out_tangent: vec![],
        };
        let params = TcbParams {
            tension: 0.0,
            continuity: 0.0,
            bias: 0.0,
        };
        let (inn, out) = tcb_tangents(&[key], 0, &params);
        assert_eq!(inn.len(), 0);
        assert_eq!(out.len(), 0);
    }

    #[test]
    fn test_compute_cubic_tangents() {
        let mut keys = vec![
            PoseKey {
                time: 0.0,
                pose: vec![0.0, 0.0],
                in_tangent: vec![],
                out_tangent: vec![],
            },
            PoseKey {
                time: 1.0,
                pose: vec![1.0, 2.0],
                in_tangent: vec![],
                out_tangent: vec![],
            },
            PoseKey {
                time: 2.0,
                pose: vec![0.0, 0.0],
                in_tangent: vec![],
                out_tangent: vec![],
            },
        ];
        compute_cubic_tangents(&mut keys);
        // Catmull-Rom: 0.5 * (next - prev), with the key itself standing in for
        // the missing neighbour at either end.
        assert_eq!(keys[0].out_tangent, vec![0.5, 1.0]);
        // The middle key's neighbours are equal, so its tangent is exactly zero.
        assert_eq!(keys[1].in_tangent, vec![0.0, 0.0]);
        assert_eq!(keys[1].out_tangent, vec![0.0, 0.0]);
        assert_eq!(keys[2].in_tangent, vec![-0.5, -1.0]);
    }
}
546        // Middle key should have non-zero tangents
547        assert_eq!(keys[1].in_tangent.len(), 2);
548        assert_eq!(keys[1].out_tangent.len(), 2);
549    }
550}