Skip to main content

deke_types/
fk.rs

1use glam::{Affine3A, Mat3A, Vec3A};
2
3use crate::{DekeError, SRobotQ};
4
/// Single-precision `(sin x, cos x)` approximation used by the FK hot loops.
///
/// `x` is reduced to `r = x - q·π/2` with `q = round(x·2/π)`, so `|r| <= π/4`,
/// using a two-constant Cody–Waite split of π/2 (`PI_2_HI + PI_2_LO`).
/// `sin r` and `cos r` come from the Cephes `sinf`/`cosf` minimax polynomials
/// (degree 7 and 8), and `q mod 4` selects the quadrant swap/sign.
///
/// Absolute error is on the order of 1e-7 for joint-angle-sized inputs. It
/// grows with |x| because `q * PI_2_HI` is rounded once `q` needs more bits
/// than `PI_2_HI` leaves free. Non-finite `x` yields NaN outputs.
#[inline(always)]
fn fast_sin_cos(x: f32) -> (f32, f32) {
    const FRAC_2_PI: f32 = std::f32::consts::FRAC_2_PI;
    // π/2 split into its nearest f32 plus the residual.
    const PI_2_HI: f32 = 1.570_796_4_f32;
    const PI_2_LO: f32 = -4.371_139e-8_f32;

    // sin r ≈ r + r³·(S1 + r²·(S2 + r²·S3)) on |r| <= π/4.
    const S1: f32 = -1.666_665_5e-1;
    const S2: f32 = 8.332_161e-3;
    const S3: f32 = -1.951_529_6e-4;

    // cos r ≈ 1 + r²·(C1 + r²·(C2 + r²·(C3 + r²·C4))) on |r| <= π/4.
    // The r⁸ term is needed: stopping at r⁶ leaves ~3.6e-6 error at |r| = π/4.
    const C1: f32 = -0.5;
    const C2: f32 = 4.166_664_6e-2;
    const C3: f32 = -1.388_731_6e-3;
    const C4: f32 = 2.443_315_7e-5;

    let q = (x * FRAC_2_PI).round();
    let r = x - q * PI_2_HI - q * PI_2_LO;
    let r2 = r * r;

    let sin_r = r + r * r2 * (S1 + r2 * (S2 + r2 * S3));
    let cos_r = 1.0 + r2 * (C1 + r2 * (C2 + r2 * (C3 + r2 * C4)));

    // `& 3` maps any i32 (negative included, two's complement) into 0..=3, so
    // the last arm is exactly quadrant 3; no unchecked hint is required.
    match (q as i32) & 3 {
        0 => (sin_r, cos_r),
        1 => (cos_r, -sin_r),
        2 => (-sin_r, -cos_r),
        _ => (-cos_r, sin_r),
    }
}
37
38pub trait FKChain<const N: usize>: Clone + Send + Sync {
39    type Error: Into<DekeError>;
40    fn dof(&self) -> usize {
41        N
42    }
43    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error>;
44    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error>;
45    /// Returns joint rotation axes and axis-origin positions in world frame at
46    /// configuration `q`, plus the end-effector position.
47    fn joint_axes_positions(
48        &self,
49        q: &SRobotQ<N>,
50    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error>;
51
52    /// Geometric Jacobian (6×N) at configuration `q`.
53    /// Rows 0–2: linear velocity, rows 3–5: angular velocity.
54    fn jacobian(&self, q: &SRobotQ<N>) -> Result<[[f32; N]; 6], Self::Error> {
55        let (z, p, p_ee) = self.joint_axes_positions(q)?;
56        let mut j = [[0.0f32; N]; 6];
57        for i in 0..N {
58            let dp = p_ee - p[i];
59            let c = z[i].cross(dp);
60            j[0][i] = c.x;
61            j[1][i] = c.y;
62            j[2][i] = c.z;
63            j[3][i] = z[i].x;
64            j[4][i] = z[i].y;
65            j[5][i] = z[i].z;
66        }
67        Ok(j)
68    }
69
70    /// First time-derivative of the geometric Jacobian.
71    fn jacobian_dot(
72        &self,
73        q: &SRobotQ<N>,
74        qdot: &SRobotQ<N>,
75    ) -> Result<[[f32; N]; 6], Self::Error> {
76        let (z, p, p_ee) = self.joint_axes_positions(q)?;
77
78        let mut omega = Vec3A::ZERO;
79        let mut z_dot = [Vec3A::ZERO; N];
80        let mut p_dot = [Vec3A::ZERO; N];
81        let mut pdot_acc = Vec3A::ZERO;
82
83        for i in 0..N {
84            p_dot[i] = pdot_acc;
85            z_dot[i] = omega.cross(z[i]);
86            omega += qdot.0[i] * z[i];
87            let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
88            pdot_acc += omega.cross(next_p - p[i]);
89        }
90        let p_ee_dot = pdot_acc;
91
92        let mut jd = [[0.0f32; N]; 6];
93        for i in 0..N {
94            let dp = p_ee - p[i];
95            let dp_dot = p_ee_dot - p_dot[i];
96            let c1 = z_dot[i].cross(dp);
97            let c2 = z[i].cross(dp_dot);
98            jd[0][i] = c1.x + c2.x;
99            jd[1][i] = c1.y + c2.y;
100            jd[2][i] = c1.z + c2.z;
101            jd[3][i] = z_dot[i].x;
102            jd[4][i] = z_dot[i].y;
103            jd[5][i] = z_dot[i].z;
104        }
105        Ok(jd)
106    }
107
108    /// Second time-derivative of the geometric Jacobian.
109    fn jacobian_ddot(
110        &self,
111        q: &SRobotQ<N>,
112        qdot: &SRobotQ<N>,
113        qddot: &SRobotQ<N>,
114    ) -> Result<[[f32; N]; 6], Self::Error> {
115        let (z, p, p_ee) = self.joint_axes_positions(q)?;
116
117        let mut omega = Vec3A::ZERO;
118        let mut omega_dot = Vec3A::ZERO;
119        let mut z_dot = [Vec3A::ZERO; N];
120        let mut z_ddot = [Vec3A::ZERO; N];
121        let mut p_dot = [Vec3A::ZERO; N];
122        let mut p_ddot = [Vec3A::ZERO; N];
123        let mut pdot_acc = Vec3A::ZERO;
124        let mut pddot_acc = Vec3A::ZERO;
125
126        for i in 0..N {
127            p_dot[i] = pdot_acc;
128            p_ddot[i] = pddot_acc;
129            let zd = omega.cross(z[i]);
130            z_dot[i] = zd;
131            z_ddot[i] = omega_dot.cross(z[i]) + omega.cross(zd);
132            omega_dot += qddot.0[i] * z[i] + qdot.0[i] * zd;
133            omega += qdot.0[i] * z[i];
134            let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
135            let delta = next_p - p[i];
136            let delta_dot = omega.cross(delta);
137            pdot_acc += delta_dot;
138            pddot_acc += omega_dot.cross(delta) + omega.cross(delta_dot);
139        }
140        let p_ee_dot = pdot_acc;
141        let p_ee_ddot = pddot_acc;
142
143        let mut jdd = [[0.0f32; N]; 6];
144        for i in 0..N {
145            let dp = p_ee - p[i];
146            let dp_dot = p_ee_dot - p_dot[i];
147            let dp_ddot = p_ee_ddot - p_ddot[i];
148            let c1 = z_ddot[i].cross(dp);
149            let c2 = z_dot[i].cross(dp_dot);
150            let c3 = z[i].cross(dp_ddot);
151            jdd[0][i] = c1.x + 2.0 * c2.x + c3.x;
152            jdd[1][i] = c1.y + 2.0 * c2.y + c3.y;
153            jdd[2][i] = c1.z + 2.0 * c2.z + c3.z;
154            jdd[3][i] = z_ddot[i].x;
155            jdd[4][i] = z_ddot[i].y;
156            jdd[5][i] = z_ddot[i].z;
157        }
158        Ok(jdd)
159    }
160}
161
162#[inline(always)]
163#[cfg(debug_assertions)]
164fn check_finite<const N: usize>(q: &SRobotQ<N>) -> Result<(), DekeError> {
165    if q.any_non_finite() {
166        return Err(DekeError::JointsNonFinite);
167    }
168    Ok(())
169}
170
171#[inline(always)]
172#[cfg(not(debug_assertions))]
173fn check_finite<const N: usize>(_: &SRobotQ<N>) -> Result<(), std::convert::Infallible> {
174    Ok(())
175}
176
177/// Accumulate a local rotation + translation into the running transform.
178/// Shared by both DH and HP — the only difference is how each convention
179/// builds `local_c0..c2` and `local_t`.
180#[inline(always)]
181fn accumulate(
182    acc_m: &mut Mat3A,
183    acc_t: &mut Vec3A,
184    local_c0: Vec3A,
185    local_c1: Vec3A,
186    local_c2: Vec3A,
187    local_t: Vec3A,
188) {
189    let new_c0 = *acc_m * local_c0;
190    let new_c1 = *acc_m * local_c1;
191    let new_c2 = *acc_m * local_c2;
192    *acc_t = *acc_m * local_t + *acc_t;
193    *acc_m = Mat3A::from_cols(new_c0, new_c1, new_c2);
194}
195
/// Standard Denavit–Hartenberg parameters of one joint.
///
/// Interpreted with the convention `T_i = Rz(θ) · Tz(d) · Tx(a) · Rx(α)`,
/// where `θ` is the joint value plus `theta_offset`.
#[derive(Debug, Clone, Copy)]
pub struct DHJoint {
    /// Link length: translation along X after the `Rz(θ)·Tz(d)` step.
    pub a: f32,
    /// Link twist in radians: the final `Rx(α)` rotation.
    pub alpha: f32,
    /// Link offset: translation along the joint's Z axis.
    pub d: f32,
    /// Constant added to the joint value before evaluating `Rz(θ)`.
    pub theta_offset: f32,
}

/// Precomputed standard-DH chain with SoA layout.
///
/// Convention: `T_i = Rz(θ) · Tz(d) · Tx(a) · Rx(α)`
#[derive(Debug, Clone)]
pub struct DHChain<const N: usize> {
    a: [f32; N],
    d: [f32; N],
    sin_alpha: [f32; N],
    cos_alpha: [f32; N],
    theta_offset: [f32; N],
}

impl<const N: usize> DHChain<N> {
    /// Build the per-field tables from per-joint DH parameters.
    ///
    /// `sin α` / `cos α` are evaluated once here, so FK only has to evaluate
    /// trig for the joint angle.
    pub fn new(joints: [DHJoint; N]) -> Self {
        let alpha_sc = joints.map(|joint| joint.alpha.sin_cos());
        Self {
            a: joints.map(|joint| joint.a),
            d: joints.map(|joint| joint.d),
            sin_alpha: alpha_sc.map(|(sin, _)| sin),
            cos_alpha: alpha_sc.map(|(_, cos)| cos),
            theta_offset: joints.map(|joint| joint.theta_offset),
        }
    }
}
244
245impl<const N: usize> FKChain<N> for DHChain<N> {
246    #[cfg(debug_assertions)]
247    type Error = DekeError;
248    #[cfg(not(debug_assertions))]
249    type Error = std::convert::Infallible;
250
251    /// DH forward kinematics exploiting the structure of `Rz(θ)·Rx(α)`.
252    ///
253    /// The per-joint accumulation decomposes into two 2D column rotations:
254    ///   1. Rotate `(c0, c1)` by θ  →  `(new_c0, perp)`
255    ///   2. Rotate `(perp, c2)` by α  →  `(new_c1, new_c2)`
256    /// Translation reuses `new_c0`:  `t += a·new_c0 + d·old_c2`
257    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
258        check_finite::<N>(q)?;
259        let mut out = [Affine3A::IDENTITY; N];
260        let mut c0 = Vec3A::X;
261        let mut c1 = Vec3A::Y;
262        let mut c2 = Vec3A::Z;
263        let mut t = Vec3A::ZERO;
264
265        let mut i = 0;
266        while i < N {
267            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
268            let sa = self.sin_alpha[i];
269            let ca = self.cos_alpha[i];
270
271            let new_c0 = ct * c0 + st * c1;
272            let perp = ct * c1 - st * c0;
273
274            let new_c1 = ca * perp + sa * c2;
275            let new_c2 = ca * c2 - sa * perp;
276
277            t = self.a[i] * new_c0 + self.d[i] * c2 + t;
278
279            c0 = new_c0;
280            c1 = new_c1;
281            c2 = new_c2;
282
283            out[i] = Affine3A {
284                matrix3: Mat3A::from_cols(c0, c1, c2),
285                translation: t,
286            };
287            i += 1;
288        }
289        Ok(out)
290    }
291
292    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
293        check_finite::<N>(q)?;
294        let mut c0 = Vec3A::X;
295        let mut c1 = Vec3A::Y;
296        let mut c2 = Vec3A::Z;
297        let mut t = Vec3A::ZERO;
298
299        let mut i = 0;
300        while i < N {
301            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
302            let sa = self.sin_alpha[i];
303            let ca = self.cos_alpha[i];
304
305            let new_c0 = ct * c0 + st * c1;
306            let perp = ct * c1 - st * c0;
307
308            let new_c1 = ca * perp + sa * c2;
309            let new_c2 = ca * c2 - sa * perp;
310
311            t = self.a[i] * new_c0 + self.d[i] * c2 + t;
312
313            c0 = new_c0;
314            c1 = new_c1;
315            c2 = new_c2;
316            i += 1;
317        }
318
319        Ok(Affine3A {
320            matrix3: Mat3A::from_cols(c0, c1, c2),
321            translation: t,
322        })
323    }
324
325    fn joint_axes_positions(
326        &self,
327        q: &SRobotQ<N>,
328    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
329        let frames = self.fk(q)?;
330        let mut axes = [Vec3A::Z; N];
331        let mut positions = [Vec3A::ZERO; N];
332
333        for i in 1..N {
334            axes[i] = frames[i - 1].matrix3.z_axis;
335            positions[i] = frames[i - 1].translation;
336        }
337
338        Ok((axes, positions, frames[N - 1].translation))
339    }
340}
341
342#[derive(Debug, Clone, Copy)]
343pub struct HPJoint {
344    pub a: f32,
345    pub alpha: f32,
346    pub beta: f32,
347    pub d: f32,
348    pub theta_offset: f32,
349}
350
351/// Precomputed Hayati-Paul chain with SoA layout.
352///
353/// Convention: `T_i = Rz(θ) · Rx(α) · Ry(β) · Tx(a) · Tz(d)`
354///
355/// HP adds a `β` rotation about Y, which makes it numerically stable for
356/// nearly-parallel consecutive joint axes where standard DH is singular.
357#[derive(Debug, Clone)]
358pub struct HPChain<const N: usize> {
359    a: [f32; N],
360    d: [f32; N],
361    sin_alpha: [f32; N],
362    cos_alpha: [f32; N],
363    sin_beta: [f32; N],
364    cos_beta: [f32; N],
365    theta_offset: [f32; N],
366}
367
368impl<const N: usize> HPChain<N> {
369    pub fn new(joints: [HPJoint; N]) -> Self {
370        let mut a = [0.0; N];
371        let mut d = [0.0; N];
372        let mut sin_alpha = [0.0; N];
373        let mut cos_alpha = [0.0; N];
374        let mut sin_beta = [0.0; N];
375        let mut cos_beta = [0.0; N];
376        let mut theta_offset = [0.0; N];
377
378        let mut i = 0;
379        while i < N {
380            a[i] = joints[i].a;
381            d[i] = joints[i].d;
382            let (sa, ca) = joints[i].alpha.sin_cos();
383            sin_alpha[i] = sa;
384            cos_alpha[i] = ca;
385            let (sb, cb) = joints[i].beta.sin_cos();
386            sin_beta[i] = sb;
387            cos_beta[i] = cb;
388            theta_offset[i] = joints[i].theta_offset;
389            i += 1;
390        }
391
392        Self {
393            a,
394            d,
395            sin_alpha,
396            cos_alpha,
397            sin_beta,
398            cos_beta,
399            theta_offset,
400        }
401    }
402
403    /// Build the local rotation columns and translation for joint `i`.
404    ///
405    /// R = Rz(θ) · Rx(α) · Ry(β), then t = R · [a, 0, d].
406    ///
407    /// Rx(α)·Ry(β) columns:
408    ///   col0 = [ cβ,       sα·sβ,     -cα·sβ     ]
409    ///   col1 = [ 0,        cα,          sα        ]
410    ///   col2 = [ sβ,      -sα·cβ,      cα·cβ     ]
411    ///
412    /// Then Rz(θ) rotates each column: [cθ·x - sθ·y, sθ·x + cθ·y, z]
413    ///
414    /// Translation = a·col0 + d·col2  (since R·[a,0,d] = a·col0 + d·col2).
415    #[inline(always)]
416    fn local_frame(&self, i: usize, st: f32, ct: f32) -> (Vec3A, Vec3A, Vec3A, Vec3A) {
417        let sa = self.sin_alpha[i];
418        let ca = self.cos_alpha[i];
419        let sb = self.sin_beta[i];
420        let cb = self.cos_beta[i];
421        let ai = self.a[i];
422        let di = self.d[i];
423
424        let sa_sb = sa * sb;
425        let sa_cb = sa * cb;
426        let ca_sb = ca * sb;
427        let ca_cb = ca * cb;
428
429        let c0 = Vec3A::new(ct * cb - st * sa_sb, st * cb + ct * sa_sb, -ca_sb);
430        let c1 = Vec3A::new(-st * ca, ct * ca, sa);
431        let c2 = Vec3A::new(ct * sb + st * sa_cb, st * sb - ct * sa_cb, ca_cb);
432        let t = Vec3A::new(
433            ai * c0.x + di * c2.x,
434            ai * c0.y + di * c2.y,
435            ai * c0.z + di * c2.z,
436        );
437
438        (c0, c1, c2, t)
439    }
440}
441
442impl<const N: usize> FKChain<N> for HPChain<N> {
443    #[cfg(debug_assertions)]
444    type Error = DekeError;
445    #[cfg(not(debug_assertions))]
446    type Error = std::convert::Infallible;
447
448    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
449        check_finite::<N>(q)?;
450        let mut out = [Affine3A::IDENTITY; N];
451        let mut acc_m = Mat3A::IDENTITY;
452        let mut acc_t = Vec3A::ZERO;
453
454        let mut i = 0;
455        while i < N {
456            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
457            let (c0, c1, c2, t) = self.local_frame(i, st, ct);
458            accumulate(&mut acc_m, &mut acc_t, c0, c1, c2, t);
459
460            out[i] = Affine3A {
461                matrix3: acc_m,
462                translation: acc_t,
463            };
464            i += 1;
465        }
466        Ok(out)
467    }
468
469    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
470        check_finite(q)?;
471        let mut acc_m = Mat3A::IDENTITY;
472        let mut acc_t = Vec3A::ZERO;
473
474        let mut i = 0;
475        while i < N {
476            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
477            let (c0, c1, c2, t) = self.local_frame(i, st, ct);
478            accumulate(&mut acc_m, &mut acc_t, c0, c1, c2, t);
479            i += 1;
480        }
481
482        Ok(Affine3A {
483            matrix3: acc_m,
484            translation: acc_t,
485        })
486    }
487
488    fn joint_axes_positions(
489        &self,
490        q: &SRobotQ<N>,
491    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
492        let frames = self.fk(q)?;
493        let mut axes = [Vec3A::Z; N];
494        let mut positions = [Vec3A::ZERO; N];
495
496        for i in 1..N {
497            axes[i] = frames[i - 1].matrix3.z_axis;
498            positions[i] = frames[i - 1].translation;
499        }
500
501        Ok((axes, positions, frames[N - 1].translation))
502    }
503}
504
/// One URDF joint description, mirroring its `<origin>` and `<axis>` elements.
///
/// Values are `f64`; `URDFChain::new` builds the fixed rotation in `f64` and
/// stores the results as `f32`.
#[derive(Debug, Clone, Copy)]
pub struct URDFJoint {
    /// `<origin xyz>`: translation of the joint frame in its parent frame.
    pub origin_xyz: [f64; 3],
    /// `<origin rpy>`: fixed `[roll, pitch, yaw]` in radians, composed as
    /// `Rz(yaw)·Ry(pitch)·Rx(roll)`.
    pub origin_rpy: [f64; 3],
    /// `<axis xyz>`: rotation axis in the joint frame. `URDFChain::new` snaps
    /// it to a principal axis (X, Y or Z); arbitrary directions are not supported.
    pub axis: [f64; 3],
}
511
512/// Precomputed per-joint axis type for column-rotation FK.
513#[derive(Debug, Clone, Copy)]
514enum JointAxis {
515    Z,
516    Y(f32),
517    X(f32),
518}
519
520/// FK chain using exact URDF joint transforms.
521///
522/// Accumulation works directly on columns:
523///   1. Translation: `t += fx·c0 + fy·c1 + fz·c2`
524///   2. Fixed rotation: `(c0,c1,c2) = (c0,c1,c2) * fixed_rot`
525///   3. Joint rotation: 2D rotation on the appropriate column pair
526///
527/// When `fixed_rot` is identity (RPY = 0, the common case), step 2 is
528/// skipped entirely, making per-joint cost a single 2D column rotation
529/// plus translation — cheaper than DH.
530#[derive(Debug, Clone)]
531pub struct URDFChain<const N: usize> {
532    fr_c0: [Vec3A; N],
533    fr_c1: [Vec3A; N],
534    fr_c2: [Vec3A; N],
535    fr_identity: [bool; N],
536    fixed_trans: [Vec3A; N],
537    axis: [JointAxis; N],
538}
539
540impl<const N: usize> URDFChain<N> {
541    pub fn new(joints: [URDFJoint; N]) -> Self {
542        let mut fr_c0 = [Vec3A::X; N];
543        let mut fr_c1 = [Vec3A::Y; N];
544        let mut fr_c2 = [Vec3A::Z; N];
545        let mut fr_identity = [true; N];
546        let mut fixed_trans = [Vec3A::ZERO; N];
547        let mut axis = [JointAxis::Z; N];
548
549        for i in 0..N {
550            let [ox, oy, oz] = joints[i].origin_xyz;
551            let [roll, pitch, yaw] = joints[i].origin_rpy;
552
553            let is_identity = roll.abs() < 1e-10 && pitch.abs() < 1e-10 && yaw.abs() < 1e-10;
554            fr_identity[i] = is_identity;
555
556            if !is_identity {
557                let (sr, cr) = roll.sin_cos();
558                let (sp, cp) = pitch.sin_cos();
559                let (sy, cy) = yaw.sin_cos();
560                fr_c0[i] = Vec3A::new((cy * cp) as f32, (sy * cp) as f32, (-sp) as f32);
561                fr_c1[i] = Vec3A::new(
562                    (cy * sp * sr - sy * cr) as f32,
563                    (sy * sp * sr + cy * cr) as f32,
564                    (cp * sr) as f32,
565                );
566                fr_c2[i] = Vec3A::new(
567                    (cy * sp * cr + sy * sr) as f32,
568                    (sy * sp * cr - cy * sr) as f32,
569                    (cp * cr) as f32,
570                );
571            }
572
573            fixed_trans[i] = Vec3A::new(ox as f32, oy as f32, oz as f32);
574
575            let [ax, ay, az] = joints[i].axis;
576            if az.abs() > 0.5 {
577                axis[i] = JointAxis::Z;
578            } else if ay.abs() > 0.5 {
579                axis[i] = JointAxis::Y(ay.signum() as f32);
580            } else {
581                axis[i] = JointAxis::X(ax.signum() as f32);
582            }
583        }
584
585        Self {
586            fr_c0,
587            fr_c1,
588            fr_c2,
589            fr_identity,
590            fixed_trans,
591            axis,
592        }
593    }
594
595    /// Apply fixed rotation + joint rotation to accumulator columns.
596    #[inline(always)]
597    fn accumulate_joint(
598        &self,
599        i: usize,
600        st: f32,
601        ct: f32,
602        c0: &mut Vec3A,
603        c1: &mut Vec3A,
604        c2: &mut Vec3A,
605        t: &mut Vec3A,
606    ) {
607        let ft = self.fixed_trans[i];
608        *t = ft.x * *c0 + ft.y * *c1 + ft.z * *c2 + *t;
609
610        let (f0, f1, f2) = if self.fr_identity[i] {
611            (*c0, *c1, *c2)
612        } else {
613            let fc0 = self.fr_c0[i];
614            let fc1 = self.fr_c1[i];
615            let fc2 = self.fr_c2[i];
616            (
617                fc0.x * *c0 + fc0.y * *c1 + fc0.z * *c2,
618                fc1.x * *c0 + fc1.y * *c1 + fc1.z * *c2,
619                fc2.x * *c0 + fc2.y * *c1 + fc2.z * *c2,
620            )
621        };
622
623        match self.axis[i] {
624            JointAxis::Z => {
625                let new_c0 = ct * f0 + st * f1;
626                let new_c1 = ct * f1 - st * f0;
627                *c0 = new_c0;
628                *c1 = new_c1;
629                *c2 = f2;
630            }
631            JointAxis::Y(s) => {
632                let sst = s * st;
633                let new_c0 = ct * f0 - sst * f2;
634                let new_c2 = sst * f0 + ct * f2;
635                *c0 = new_c0;
636                *c1 = f1;
637                *c2 = new_c2;
638            }
639            JointAxis::X(s) => {
640                let sst = s * st;
641                let new_c1 = ct * f1 + sst * f2;
642                let new_c2 = ct * f2 - sst * f1;
643                *c0 = f0;
644                *c1 = new_c1;
645                *c2 = new_c2;
646            }
647        }
648    }
649}
650
651impl<const N: usize> FKChain<N> for URDFChain<N> {
652    #[cfg(debug_assertions)]
653    type Error = DekeError;
654    #[cfg(not(debug_assertions))]
655    type Error = std::convert::Infallible;
656
657    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
658        check_finite(q)?;
659        let mut out = [Affine3A::IDENTITY; N];
660        let mut c0 = Vec3A::X;
661        let mut c1 = Vec3A::Y;
662        let mut c2 = Vec3A::Z;
663        let mut t = Vec3A::ZERO;
664
665        let mut i = 0;
666        while i < N {
667            let (st, ct) = fast_sin_cos(q.0[i]);
668            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
669
670            out[i] = Affine3A {
671                matrix3: Mat3A::from_cols(c0, c1, c2),
672                translation: t,
673            };
674            i += 1;
675        }
676        Ok(out)
677    }
678
679    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
680        check_finite(q)?;
681        let mut c0 = Vec3A::X;
682        let mut c1 = Vec3A::Y;
683        let mut c2 = Vec3A::Z;
684        let mut t = Vec3A::ZERO;
685
686        let mut i = 0;
687        while i < N {
688            let (st, ct) = fast_sin_cos(q.0[i]);
689            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
690            i += 1;
691        }
692
693        Ok(Affine3A {
694            matrix3: Mat3A::from_cols(c0, c1, c2),
695            translation: t,
696        })
697    }
698
699    fn joint_axes_positions(
700        &self,
701        q: &SRobotQ<N>,
702    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
703        let frames = self.fk(q)?;
704        let mut axes = [Vec3A::ZERO; N];
705        let mut positions = [Vec3A::ZERO; N];
706
707        for i in 0..N {
708            axes[i] = match self.axis[i] {
709                JointAxis::Z => frames[i].matrix3.z_axis,
710                JointAxis::Y(s) => s * frames[i].matrix3.y_axis,
711                JointAxis::X(s) => s * frames[i].matrix3.x_axis,
712            };
713            positions[i] = frames[i].translation;
714        }
715
716        Ok((axes, positions, frames[N - 1].translation))
717    }
718}
719
720/// Wraps an `FKChain` with an optional prefix (base) and/or suffix (tool) transform.
721///
722/// - `fk` applies only the prefix — intermediate frames stay in world coordinates
723///   without the tool offset.
724/// - `fk_end` and `joint_axes_positions` apply both — the end-effector includes
725///   the tool tip.
726#[derive(Debug, Clone)]
727pub struct TransformedFK<const N: usize, FK: FKChain<N>> {
728    inner: FK,
729    prefix: Option<Affine3A>,
730    suffix: Option<Affine3A>,
731}
732
733impl<const N: usize, FK: FKChain<N>> TransformedFK<N, FK> {
734    pub fn new(inner: FK) -> Self {
735        Self {
736            inner,
737            prefix: None,
738            suffix: None,
739        }
740    }
741
742    pub fn with_prefix(mut self, prefix: Affine3A) -> Self {
743        self.prefix = Some(prefix);
744        self
745    }
746
747    pub fn with_suffix(mut self, suffix: Affine3A) -> Self {
748        self.suffix = Some(suffix);
749        self
750    }
751
752    pub fn set_prefix(&mut self, prefix: Option<Affine3A>) {
753        self.prefix = prefix;
754    }
755
756    pub fn set_suffix(&mut self, suffix: Option<Affine3A>) {
757        self.suffix = suffix;
758    }
759
760    pub fn prefix(&self) -> Option<&Affine3A> {
761        self.prefix.as_ref()
762    }
763
764    pub fn suffix(&self) -> Option<&Affine3A> {
765        self.suffix.as_ref()
766    }
767
768    pub fn inner(&self) -> &FK {
769        &self.inner
770    }
771}
772
773impl<const N: usize, FK: FKChain<N>> FKChain<N> for TransformedFK<N, FK> {
774    type Error = FK::Error;
775
776    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
777        let mut frames = self.inner.fk(q)?;
778        if let Some(pre) = &self.prefix {
779            for f in &mut frames {
780                *f = *pre * *f;
781            }
782        }
783        Ok(frames)
784    }
785
786    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
787        let mut end = self.inner.fk_end(q)?;
788        if let Some(pre) = &self.prefix {
789            end = *pre * end;
790        }
791        if let Some(suf) = &self.suffix {
792            end = end * *suf;
793        }
794        Ok(end)
795    }
796
797    fn joint_axes_positions(
798        &self,
799        q: &SRobotQ<N>,
800    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
801        let (mut axes, mut positions, inner_p_ee) = self.inner.joint_axes_positions(q)?;
802
803        if let Some(pre) = &self.prefix {
804            let rot = pre.matrix3;
805            let t = Vec3A::from(pre.translation);
806            for i in 0..N {
807                axes[i] = rot * axes[i];
808                positions[i] = rot * positions[i] + t;
809            }
810        }
811
812        let p_ee = if self.prefix.is_some() || self.suffix.is_some() {
813            self.fk_end(q)?.translation
814        } else {
815            inner_p_ee
816        };
817
818        Ok((axes, positions, p_ee))
819    }
820}
821
822/// Wraps an `FKChain<N>` and prepends a prismatic (linear) joint, producing
823/// an `FKChain<M>` where `M = N + 1`.
824///
825/// The prismatic joint always acts first in the kinematic chain — it
826/// translates the entire arm along `axis` (world frame).  The
827/// `q_index_first` flag only controls where the prismatic value is read
828/// from in `SRobotQ<M>`: when `true` it is `q[0]`, when `false` it is
829/// `q[M-1]`.
830///
831/// Jacobian columns for the prismatic joint are `[axis; 0]` (pure linear,
832/// no angular contribution).  Because the prismatic uniformly shifts all
833/// positions, the revolute Jacobian columns are identical to the inner
834/// chain's.
835#[derive(Debug, Clone)]
836pub struct PrismaticFK<const M: usize, const N: usize, FK: FKChain<N>> {
837    inner: FK,
838    axis: Vec3A,
839    q_index_first: bool,
840}
841
842impl<const M: usize, const N: usize, FK: FKChain<N>> PrismaticFK<M, N, FK> {
843    pub fn new(inner: FK, axis: Vec3A, q_index_first: bool) -> Self {
844        assert!(M == N + 1, "M must equal N + 1");
845        Self {
846            inner,
847            axis,
848            q_index_first,
849        }
850    }
851
852    pub fn inner(&self) -> &FK {
853        &self.inner
854    }
855
856    pub fn axis(&self) -> Vec3A {
857        self.axis
858    }
859
860    pub fn q_index_first(&self) -> bool {
861        self.q_index_first
862    }
863
864    fn split_q(&self, q: &SRobotQ<M>) -> (f32, SRobotQ<N>) {
865        let mut inner = [0.0f32; N];
866        if self.q_index_first {
867            inner.copy_from_slice(&q.0[1..M]);
868            (q.0[0], SRobotQ(inner))
869        } else {
870            inner.copy_from_slice(&q.0[..N]);
871            (q.0[M - 1], SRobotQ(inner))
872        }
873    }
874
875    fn prismatic_col(&self) -> usize {
876        if self.q_index_first { 0 } else { N }
877    }
878
879    fn revolute_offset(&self) -> usize {
880        if self.q_index_first { 1 } else { 0 }
881    }
882}
883
884impl<const M: usize, const N: usize, FK: FKChain<N>> FKChain<M> for PrismaticFK<M, N, FK> {
885    type Error = FK::Error;
886
887    fn fk(&self, q: &SRobotQ<M>) -> Result<[Affine3A; M], Self::Error> {
888        let (q_p, inner_q) = self.split_q(q);
889        let offset = q_p * self.axis;
890        let inner_frames = self.inner.fk(&inner_q)?;
891        let mut out = [Affine3A::IDENTITY; M];
892
893        out[0] = Affine3A::from_translation(offset.into());
894        for i in 0..N {
895            let mut f = inner_frames[i];
896            f.translation += offset;
897            out[i + 1] = f;
898        }
899
900        Ok(out)
901    }
902
903    fn fk_end(&self, q: &SRobotQ<M>) -> Result<Affine3A, Self::Error> {
904        let (q_p, inner_q) = self.split_q(q);
905        let mut end = self.inner.fk_end(&inner_q)?;
906        end.translation += q_p * self.axis;
907        Ok(end)
908    }
909
910    fn joint_axes_positions(
911        &self,
912        q: &SRobotQ<M>,
913    ) -> Result<([Vec3A; M], [Vec3A; M], Vec3A), Self::Error> {
914        let (q_p, inner_q) = self.split_q(q);
915        let offset = q_p * self.axis;
916        let (inner_axes, inner_pos, inner_p_ee) = self.inner.joint_axes_positions(&inner_q)?;
917
918        let mut axes = [Vec3A::ZERO; M];
919        let mut positions = [Vec3A::ZERO; M];
920
921        axes[0] = self.axis;
922        for i in 0..N {
923            axes[i + 1] = inner_axes[i];
924            positions[i + 1] = inner_pos[i] + offset;
925        }
926
927        Ok((axes, positions, inner_p_ee + offset))
928    }
929
930    fn jacobian(&self, q: &SRobotQ<M>) -> Result<[[f32; M]; 6], Self::Error> {
931        let (_q_p, inner_q) = self.split_q(q);
932        let inner_j = self.inner.jacobian(&inner_q)?;
933        let p_col = self.prismatic_col();
934        let r_off = self.revolute_offset();
935
936        let mut j = [[0.0f32; M]; 6];
937        j[0][p_col] = self.axis.x;
938        j[1][p_col] = self.axis.y;
939        j[2][p_col] = self.axis.z;
940
941        for row in 0..6 {
942            for col in 0..N {
943                j[row][col + r_off] = inner_j[row][col];
944            }
945        }
946
947        Ok(j)
948    }
949
950    fn jacobian_dot(
951        &self,
952        q: &SRobotQ<M>,
953        qdot: &SRobotQ<M>,
954    ) -> Result<[[f32; M]; 6], Self::Error> {
955        let (_q_p, inner_q) = self.split_q(q);
956        let (_qdot_p, inner_qdot) = self.split_q(qdot);
957        let inner_jd = self.inner.jacobian_dot(&inner_q, &inner_qdot)?;
958        let r_off = self.revolute_offset();
959
960        let mut jd = [[0.0f32; M]; 6];
961        for row in 0..6 {
962            for col in 0..N {
963                jd[row][col + r_off] = inner_jd[row][col];
964            }
965        }
966
967        Ok(jd)
968    }
969
970    fn jacobian_ddot(
971        &self,
972        q: &SRobotQ<M>,
973        qdot: &SRobotQ<M>,
974        qddot: &SRobotQ<M>,
975    ) -> Result<[[f32; M]; 6], Self::Error> {
976        let (_q_p, inner_q) = self.split_q(q);
977        let (_qdot_p, inner_qdot) = self.split_q(qdot);
978        let (_qddot_p, inner_qddot) = self.split_q(qddot);
979        let inner_jdd = self
980            .inner
981            .jacobian_ddot(&inner_q, &inner_qdot, &inner_qddot)?;
982        let r_off = self.revolute_offset();
983
984        let mut jdd = [[0.0f32; M]; 6];
985        for row in 0..6 {
986            for col in 0..N {
987                jdd[row][col + r_off] = inner_jdd[row][col];
988            }
989        }
990
991        Ok(jdd)
992    }
993}