Skip to main content

deke_types/
fk.rs

1use glam::{Affine3A, Mat3A, Vec3A};
2
3use crate::{DekeError, SRobotQ};
4
/// Approximates `(sin(x), cos(x))` for an angle `x` in radians.
///
/// The argument is reduced to `r = x - q·π/2` with `q = round(x · 2/π)`,
/// which places `r` in roughly `[-π/4, π/4]`. `sin(r)` and `cos(r)` are then
/// evaluated with short odd/even polynomials and mapped back to the original
/// quadrant using `q mod 4`.
///
/// The whole computation is carried out in `f32`, so the reduction loses
/// accuracy as `|x|` grows. NaN or infinite input yields NaN in both results.
///
/// Returns `(sin, cos)` in that order.
#[inline(always)]
fn fast_sin_cos(x: f32) -> (f32, f32) {
    const FRAC_2_PI: f32 = std::f32::consts::FRAC_2_PI;
    // π/2 split into the nearest `f32` plus a small residual, so that the
    // subtraction `x - q·π/2` is performed in two steps with less rounding
    // error than a single-constant reduction.
    const PI_2_HI: f32 = 1.570_796_4_f32;
    const PI_2_LO: f32 = -4.371_139e-8_f32;

    // sin(r) ≈ r·(1 + S1·r² + S2·r⁴ + S3·r⁶); values sit close to the Taylor
    // coefficients -1/3!, 1/5!, -1/7!.
    const S1: f32 = -0.166_666_67;
    const S2: f32 = 0.008_333_294;
    const S3: f32 = -0.000_198_074_14;

    // cos(r) ≈ 1 + C1·r² + C2·r⁴ + C3·r⁶; values sit close to the Taylor
    // coefficients -1/2!, 1/4!, -1/6!.
    const C1: f32 = -0.5;
    const C2: f32 = 0.041_666_52;
    const C3: f32 = -0.001_388_523_4;

    // Nearest multiple of π/2. The float-to-int cast saturates (and maps NaN
    // to 0), so it cannot overflow for any input.
    let q = (x * FRAC_2_PI).round();
    let qi = q as i32;
    let r = x - q * PI_2_HI - q * PI_2_LO;
    let r2 = r * r;

    let sin_r = r * (1.0 + r2 * (S1 + r2 * (S2 + r2 * S3)));
    let cos_r = 1.0 + r2 * (C1 + r2 * (C2 + r2 * C3));

    // `qi & 3` is the quadrant index and works for negative `qi` as well
    // (two's complement). Its value is always in 0..=3, so the final arm is
    // exactly the residue-3 case; no unreachable/unsafe fallback is needed.
    match qi & 3 {
        0 => (sin_r, cos_r),
        1 => (cos_r, -sin_r),
        2 => (-sin_r, -cos_r),
        _ => (-cos_r, sin_r),
    }
}
37
38pub trait FKChain<const N: usize>: Clone + Send + Sync {
39    type Error: Into<DekeError>;
40    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error>;
41    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error>;
42}
43
44#[inline(always)]
45#[cfg(debug_assertions)]
46fn check_finite<const N: usize>(q: &SRobotQ<N>) -> Result<(), DekeError> {
47    if q.any_non_finite() {
48        return Err(DekeError::JointsNonFinite);
49    }
50    Ok(())
51}
52
/// Release-build stand-in for the joint validity check.
///
/// Performs no inspection of the joint vector and always succeeds; the error
/// type is `Infallible`, so the `?` at call sites can never take the error
/// path. Only compiled when `debug_assertions` is disabled.
#[cfg(not(debug_assertions))]
#[inline(always)]
fn check_finite<const N: usize>(_q: &SRobotQ<N>) -> Result<(), std::convert::Infallible> {
    Ok(())
}
58
59/// Accumulate a local rotation + translation into the running transform.
60/// Shared by both DH and HP — the only difference is how each convention
61/// builds `local_c0..c2` and `local_t`.
62#[inline(always)]
63fn accumulate(
64    acc_m: &mut Mat3A,
65    acc_t: &mut Vec3A,
66    local_c0: Vec3A,
67    local_c1: Vec3A,
68    local_c2: Vec3A,
69    local_t: Vec3A,
70) {
71    let new_c0 = *acc_m * local_c0;
72    let new_c1 = *acc_m * local_c1;
73    let new_c2 = *acc_m * local_c2;
74    *acc_t = *acc_m * local_t + *acc_t;
75    *acc_m = Mat3A::from_cols(new_c0, new_c1, new_c2);
76}
77
/// Standard Denavit–Hartenberg parameters describing a single joint.
///
/// Angles are consumed by `sin`/`cos`, i.e. they are in radians.
#[derive(Clone, Copy, Debug)]
pub struct DHJoint {
    /// Translation along the X axis (the `Tx(a)` factor).
    pub a: f32,
    /// Rotation about the X axis (the `Rx(α)` factor), in radians.
    pub alpha: f32,
    /// Translation along the Z axis (the `Tz(d)` factor).
    pub d: f32,
    /// Constant added to the joint value before the `Rz(θ)` rotation, in radians.
    pub theta_offset: f32,
}
85
/// Standard-DH kinematic chain stored as parallel per-joint arrays
/// (structure-of-arrays), with the trigonometry of `α` evaluated up front.
///
/// Per-joint transform convention: `T_i = Rz(θ) · Tz(d) · Tx(a) · Rx(α)`
#[derive(Clone, Debug)]
pub struct DHChain<const N: usize> {
    /// `a` parameter of each joint (translation along X).
    a: [f32; N],
    /// `d` parameter of each joint (translation along Z).
    d: [f32; N],
    /// `sin(α)` of each joint.
    sin_alpha: [f32; N],
    /// `cos(α)` of each joint.
    cos_alpha: [f32; N],
    /// Offset added to each joint value to obtain `θ`.
    theta_offset: [f32; N],
}
97
98impl<const N: usize> DHChain<N> {
99    pub fn new(joints: [DHJoint; N]) -> Self {
100        let mut a = [0.0; N];
101        let mut d = [0.0; N];
102        let mut sin_alpha = [0.0; N];
103        let mut cos_alpha = [0.0; N];
104        let mut theta_offset = [0.0; N];
105
106        let mut i = 0;
107        while i < N {
108            a[i] = joints[i].a;
109            d[i] = joints[i].d;
110            let (sa, ca) = joints[i].alpha.sin_cos();
111            sin_alpha[i] = sa;
112            cos_alpha[i] = ca;
113            theta_offset[i] = joints[i].theta_offset;
114            i += 1;
115        }
116
117        Self {
118            a,
119            d,
120            sin_alpha,
121            cos_alpha,
122            theta_offset,
123        }
124    }
125}
126
127impl<const N: usize> FKChain<N> for DHChain<N> {
128    #[cfg(debug_assertions)]
129    type Error = DekeError;
130    #[cfg(not(debug_assertions))]
131    type Error = std::convert::Infallible;
132
133    /// DH forward kinematics exploiting the structure of `Rz(θ)·Rx(α)`.
134    ///
135    /// The per-joint accumulation decomposes into two 2D column rotations:
136    ///   1. Rotate `(c0, c1)` by θ  →  `(new_c0, perp)`
137    ///   2. Rotate `(perp, c2)` by α  →  `(new_c1, new_c2)`
138    /// Translation reuses `new_c0`:  `t += a·new_c0 + d·old_c2`
139    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
140        check_finite::<N>(q)?;
141        let mut out = [Affine3A::IDENTITY; N];
142        let mut c0 = Vec3A::X;
143        let mut c1 = Vec3A::Y;
144        let mut c2 = Vec3A::Z;
145        let mut t = Vec3A::ZERO;
146
147        let mut i = 0;
148        while i < N {
149            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
150            let sa = self.sin_alpha[i];
151            let ca = self.cos_alpha[i];
152
153            let new_c0 = ct * c0 + st * c1;
154            let perp = ct * c1 - st * c0;
155
156            let new_c1 = ca * perp + sa * c2;
157            let new_c2 = ca * c2 - sa * perp;
158
159            t = self.a[i] * new_c0 + self.d[i] * c2 + t;
160
161            c0 = new_c0;
162            c1 = new_c1;
163            c2 = new_c2;
164
165            out[i] = Affine3A {
166                matrix3: Mat3A::from_cols(c0, c1, c2),
167                translation: t,
168            };
169            i += 1;
170        }
171        Ok(out)
172    }
173
174    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
175        check_finite::<N>(q)?;
176        let mut c0 = Vec3A::X;
177        let mut c1 = Vec3A::Y;
178        let mut c2 = Vec3A::Z;
179        let mut t = Vec3A::ZERO;
180
181        let mut i = 0;
182        while i < N {
183            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
184            let sa = self.sin_alpha[i];
185            let ca = self.cos_alpha[i];
186
187            let new_c0 = ct * c0 + st * c1;
188            let perp = ct * c1 - st * c0;
189
190            let new_c1 = ca * perp + sa * c2;
191            let new_c2 = ca * c2 - sa * perp;
192
193            t = self.a[i] * new_c0 + self.d[i] * c2 + t;
194
195            c0 = new_c0;
196            c1 = new_c1;
197            c2 = new_c2;
198            i += 1;
199        }
200
201        Ok(Affine3A {
202            matrix3: Mat3A::from_cols(c0, c1, c2),
203            translation: t,
204        })
205    }
206}
207
/// Hayati-Paul parameters describing a single joint.
///
/// Angles are consumed by `sin`/`cos`, i.e. they are in radians.
#[derive(Clone, Copy, Debug)]
pub struct HPJoint {
    /// Translation along the X axis (the `Tx(a)` factor).
    pub a: f32,
    /// Rotation about the X axis (the `Rx(α)` factor), in radians.
    pub alpha: f32,
    /// Rotation about the Y axis (the `Ry(β)` factor), in radians.
    pub beta: f32,
    /// Translation along the Z axis (the `Tz(d)` factor).
    pub d: f32,
    /// Constant added to the joint value before the `Rz(θ)` rotation, in radians.
    pub theta_offset: f32,
}
216
/// Hayati-Paul kinematic chain stored as parallel per-joint arrays
/// (structure-of-arrays), with the trigonometry of `α` and `β` evaluated up
/// front.
///
/// Per-joint transform convention:
/// `T_i = Rz(θ) · Rx(α) · Ry(β) · Tx(a) · Tz(d)`
///
/// The extra `β` rotation about Y keeps the parameterisation numerically
/// stable for nearly-parallel consecutive joint axes, a configuration in
/// which standard DH is singular.
#[derive(Clone, Debug)]
pub struct HPChain<const N: usize> {
    /// `a` parameter of each joint (translation along X).
    a: [f32; N],
    /// `d` parameter of each joint (translation along Z).
    d: [f32; N],
    /// `sin(α)` of each joint.
    sin_alpha: [f32; N],
    /// `cos(α)` of each joint.
    cos_alpha: [f32; N],
    /// `sin(β)` of each joint.
    sin_beta: [f32; N],
    /// `cos(β)` of each joint.
    cos_beta: [f32; N],
    /// Offset added to each joint value to obtain `θ`.
    theta_offset: [f32; N],
}
233
234impl<const N: usize> HPChain<N> {
235    pub fn new(joints: [HPJoint; N]) -> Self {
236        let mut a = [0.0; N];
237        let mut d = [0.0; N];
238        let mut sin_alpha = [0.0; N];
239        let mut cos_alpha = [0.0; N];
240        let mut sin_beta = [0.0; N];
241        let mut cos_beta = [0.0; N];
242        let mut theta_offset = [0.0; N];
243
244        let mut i = 0;
245        while i < N {
246            a[i] = joints[i].a;
247            d[i] = joints[i].d;
248            let (sa, ca) = joints[i].alpha.sin_cos();
249            sin_alpha[i] = sa;
250            cos_alpha[i] = ca;
251            let (sb, cb) = joints[i].beta.sin_cos();
252            sin_beta[i] = sb;
253            cos_beta[i] = cb;
254            theta_offset[i] = joints[i].theta_offset;
255            i += 1;
256        }
257
258        Self {
259            a,
260            d,
261            sin_alpha,
262            cos_alpha,
263            sin_beta,
264            cos_beta,
265            theta_offset,
266        }
267    }
268
269    /// Build the local rotation columns and translation for joint `i`.
270    ///
271    /// R = Rz(θ) · Rx(α) · Ry(β), then t = R · [a, 0, d].
272    ///
273    /// Rx(α)·Ry(β) columns:
274    ///   col0 = [ cβ,       sα·sβ,     -cα·sβ     ]
275    ///   col1 = [ 0,        cα,          sα        ]
276    ///   col2 = [ sβ,      -sα·cβ,      cα·cβ     ]
277    ///
278    /// Then Rz(θ) rotates each column: [cθ·x - sθ·y, sθ·x + cθ·y, z]
279    ///
280    /// Translation = a·col0 + d·col2  (since R·[a,0,d] = a·col0 + d·col2).
281    #[inline(always)]
282    fn local_frame(&self, i: usize, st: f32, ct: f32) -> (Vec3A, Vec3A, Vec3A, Vec3A) {
283        let sa = self.sin_alpha[i];
284        let ca = self.cos_alpha[i];
285        let sb = self.sin_beta[i];
286        let cb = self.cos_beta[i];
287        let ai = self.a[i];
288        let di = self.d[i];
289
290        let sa_sb = sa * sb;
291        let sa_cb = sa * cb;
292        let ca_sb = ca * sb;
293        let ca_cb = ca * cb;
294
295        let c0 = Vec3A::new(ct * cb - st * sa_sb, st * cb + ct * sa_sb, -ca_sb);
296        let c1 = Vec3A::new(-st * ca, ct * ca, sa);
297        let c2 = Vec3A::new(ct * sb + st * sa_cb, st * sb - ct * sa_cb, ca_cb);
298        let t = Vec3A::new(
299            ai * c0.x + di * c2.x,
300            ai * c0.y + di * c2.y,
301            ai * c0.z + di * c2.z,
302        );
303
304        (c0, c1, c2, t)
305    }
306}
307
308impl<const N: usize> FKChain<N> for HPChain<N> {
309    #[cfg(debug_assertions)]
310    type Error = DekeError;
311    #[cfg(not(debug_assertions))]
312    type Error = std::convert::Infallible;
313
314    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
315        check_finite::<N>(q)?;
316        let mut out = [Affine3A::IDENTITY; N];
317        let mut acc_m = Mat3A::IDENTITY;
318        let mut acc_t = Vec3A::ZERO;
319
320        let mut i = 0;
321        while i < N {
322            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
323            let (c0, c1, c2, t) = self.local_frame(i, st, ct);
324            accumulate(&mut acc_m, &mut acc_t, c0, c1, c2, t);
325
326            out[i] = Affine3A {
327                matrix3: acc_m,
328                translation: acc_t,
329            };
330            i += 1;
331        }
332        Ok(out)
333    }
334
335    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
336        check_finite(q)?;
337        let mut acc_m = Mat3A::IDENTITY;
338        let mut acc_t = Vec3A::ZERO;
339
340        let mut i = 0;
341        while i < N {
342            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
343            let (c0, c1, c2, t) = self.local_frame(i, st, ct);
344            accumulate(&mut acc_m, &mut acc_t, c0, c1, c2, t);
345            i += 1;
346        }
347
348        Ok(Affine3A {
349            matrix3: acc_m,
350            translation: acc_t,
351        })
352    }
353}
354
/// Joint description in URDF terms: a fixed origin transform plus a joint axis.
#[derive(Clone, Copy, Debug)]
pub struct URDFJoint {
    /// Fixed translation `[x, y, z]` of the joint origin.
    pub origin_xyz: [f64; 3],
    /// Fixed rotation `[roll, pitch, yaw]` of the joint origin, in radians.
    pub origin_rpy: [f64; 3],
    /// Joint axis `[x, y, z]`. The chain constructor snaps this to one of the
    /// principal axes (Z, Y or X) rather than supporting arbitrary directions.
    pub axis: [f64; 3],
}
361
362/// Precomputed per-joint axis type for column-rotation FK.
363#[derive(Debug, Clone, Copy)]
364enum JointAxis {
365    Z,
366    Y(f32),
367    X(f32),
368}
369
370/// FK chain using exact URDF joint transforms.
371///
372/// Accumulation works directly on columns:
373///   1. Translation: `t += fx·c0 + fy·c1 + fz·c2`
374///   2. Fixed rotation: `(c0,c1,c2) = (c0,c1,c2) * fixed_rot`
375///   3. Joint rotation: 2D rotation on the appropriate column pair
376///
377/// When `fixed_rot` is identity (RPY = 0, the common case), step 2 is
378/// skipped entirely, making per-joint cost a single 2D column rotation
379/// plus translation — cheaper than DH.
380#[derive(Debug, Clone)]
381pub struct URDFChain<const N: usize> {
382    fr_c0: [Vec3A; N],
383    fr_c1: [Vec3A; N],
384    fr_c2: [Vec3A; N],
385    fr_identity: [bool; N],
386    fixed_trans: [Vec3A; N],
387    axis: [JointAxis; N],
388}
389
390impl<const N: usize> URDFChain<N> {
391    pub fn new(joints: [URDFJoint; N]) -> Self {
392        let mut fr_c0 = [Vec3A::X; N];
393        let mut fr_c1 = [Vec3A::Y; N];
394        let mut fr_c2 = [Vec3A::Z; N];
395        let mut fr_identity = [true; N];
396        let mut fixed_trans = [Vec3A::ZERO; N];
397        let mut axis = [JointAxis::Z; N];
398
399        for i in 0..N {
400            let [ox, oy, oz] = joints[i].origin_xyz;
401            let [roll, pitch, yaw] = joints[i].origin_rpy;
402
403            let is_identity = roll.abs() < 1e-10 && pitch.abs() < 1e-10 && yaw.abs() < 1e-10;
404            fr_identity[i] = is_identity;
405
406            if !is_identity {
407                let (sr, cr) = roll.sin_cos();
408                let (sp, cp) = pitch.sin_cos();
409                let (sy, cy) = yaw.sin_cos();
410                fr_c0[i] = Vec3A::new((cy * cp) as f32, (sy * cp) as f32, (-sp) as f32);
411                fr_c1[i] = Vec3A::new(
412                    (cy * sp * sr - sy * cr) as f32,
413                    (sy * sp * sr + cy * cr) as f32,
414                    (cp * sr) as f32,
415                );
416                fr_c2[i] = Vec3A::new(
417                    (cy * sp * cr + sy * sr) as f32,
418                    (sy * sp * cr - cy * sr) as f32,
419                    (cp * cr) as f32,
420                );
421            }
422
423            fixed_trans[i] = Vec3A::new(ox as f32, oy as f32, oz as f32);
424
425            let [ax, ay, az] = joints[i].axis;
426            if az.abs() > 0.5 {
427                axis[i] = JointAxis::Z;
428            } else if ay.abs() > 0.5 {
429                axis[i] = JointAxis::Y(ay.signum() as f32);
430            } else {
431                axis[i] = JointAxis::X(ax.signum() as f32);
432            }
433        }
434
435        Self {
436            fr_c0,
437            fr_c1,
438            fr_c2,
439            fr_identity,
440            fixed_trans,
441            axis,
442        }
443    }
444
445    /// Apply fixed rotation + joint rotation to accumulator columns.
446    #[inline(always)]
447    fn accumulate_joint(
448        &self,
449        i: usize,
450        st: f32,
451        ct: f32,
452        c0: &mut Vec3A,
453        c1: &mut Vec3A,
454        c2: &mut Vec3A,
455        t: &mut Vec3A,
456    ) {
457        let ft = self.fixed_trans[i];
458        *t = ft.x * *c0 + ft.y * *c1 + ft.z * *c2 + *t;
459
460        let (f0, f1, f2) = if self.fr_identity[i] {
461            (*c0, *c1, *c2)
462        } else {
463            let fc0 = self.fr_c0[i];
464            let fc1 = self.fr_c1[i];
465            let fc2 = self.fr_c2[i];
466            (
467                fc0.x * *c0 + fc0.y * *c1 + fc0.z * *c2,
468                fc1.x * *c0 + fc1.y * *c1 + fc1.z * *c2,
469                fc2.x * *c0 + fc2.y * *c1 + fc2.z * *c2,
470            )
471        };
472
473        match self.axis[i] {
474            JointAxis::Z => {
475                let new_c0 = ct * f0 + st * f1;
476                let new_c1 = ct * f1 - st * f0;
477                *c0 = new_c0;
478                *c1 = new_c1;
479                *c2 = f2;
480            }
481            JointAxis::Y(s) => {
482                let sst = s * st;
483                let new_c0 = ct * f0 - sst * f2;
484                let new_c2 = sst * f0 + ct * f2;
485                *c0 = new_c0;
486                *c1 = f1;
487                *c2 = new_c2;
488            }
489            JointAxis::X(s) => {
490                let sst = s * st;
491                let new_c1 = ct * f1 + sst * f2;
492                let new_c2 = ct * f2 - sst * f1;
493                *c0 = f0;
494                *c1 = new_c1;
495                *c2 = new_c2;
496            }
497        }
498    }
499}
500
501impl<const N: usize> FKChain<N> for URDFChain<N> {
502    #[cfg(debug_assertions)]
503    type Error = DekeError;
504    #[cfg(not(debug_assertions))]
505    type Error = std::convert::Infallible;
506
507    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
508        check_finite(q)?;
509        let mut out = [Affine3A::IDENTITY; N];
510        let mut c0 = Vec3A::X;
511        let mut c1 = Vec3A::Y;
512        let mut c2 = Vec3A::Z;
513        let mut t = Vec3A::ZERO;
514
515        let mut i = 0;
516        while i < N {
517            let (st, ct) = fast_sin_cos(q.0[i]);
518            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
519
520            out[i] = Affine3A {
521                matrix3: Mat3A::from_cols(c0, c1, c2),
522                translation: t,
523            };
524            i += 1;
525        }
526        Ok(out)
527    }
528
529    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
530        check_finite(q)?;
531        let mut c0 = Vec3A::X;
532        let mut c1 = Vec3A::Y;
533        let mut c2 = Vec3A::Z;
534        let mut t = Vec3A::ZERO;
535
536        let mut i = 0;
537        while i < N {
538            let (st, ct) = fast_sin_cos(q.0[i]);
539            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
540            i += 1;
541        }
542
543        Ok(Affine3A {
544            matrix3: Mat3A::from_cols(c0, c1, c2),
545            translation: t,
546        })
547    }
548}