Skip to main content

deke_types/
fk.rs

1use glam::{Affine3A, Mat3A, Vec3A};
2
3use crate::{DekeError, SRobotQ};
4
/// Fast single-precision `(sin(x), cos(x))`.
///
/// Range reduction writes `x = q·(π/2) + r` with `q = round(x·2/π)`, so
/// `|r| ≲ π/4`. `π/2` is split into its nearest `f32` (`PI_2_HI`) plus a
/// small correction (`PI_2_LO`) to reduce cancellation error in `r`.
/// `sin(r)` and `cos(r)` are evaluated with short odd/even polynomials. The
/// quadrant `q mod 4` then selects which of them, with which sign, becomes
/// `sin(x)` / `cos(x)`.
///
/// Accuracy degrades as `|x|` grows, because `q·PI_2_HI` is rounded to
/// `f32`. Callers in this file pass joint angles (plus constant offsets).
/// NaN or infinite input yields NaN for both outputs.
#[inline(always)]
fn fast_sin_cos(x: f32) -> (f32, f32) {
    const FRAC_2_PI: f32 = std::f32::consts::FRAC_2_PI;
    const PI_2_HI: f32 = 1.570_796_4_f32;
    const PI_2_LO: f32 = -4.371_139e-8_f32;

    // sin(r) ≈ r·(1 + S1·r² + S2·r⁴ + S3·r⁶)
    const S1: f32 = -0.166_666_67;
    const S2: f32 = 0.008_333_294;
    const S3: f32 = -0.000_198_074_14;

    // cos(r) ≈ 1 + C1·r² + C2·r⁴ + C3·r⁶
    const C1: f32 = -0.5;
    const C2: f32 = 0.041_666_52;
    const C3: f32 = -0.001_388_523_4;

    let q = (x * FRAC_2_PI).round();
    let r = x - q * PI_2_HI - q * PI_2_LO;
    let r2 = r * r;

    let sin_r = r * (1.0 + r2 * (S1 + r2 * (S2 + r2 * S3)));
    let cos_r = 1.0 + r2 * (C1 + r2 * (C2 + r2 * C3));

    // `& 3` maps every i32, including negatives (two's complement), into
    // 0..=3. The last arm is therefore exactly quadrant 3 and needs no
    // `unsafe` unreachable hint.
    match (q as i32) & 3 {
        0 => (sin_r, cos_r),
        1 => (cos_r, -sin_r),
        2 => (-sin_r, -cos_r),
        _ => (-cos_r, sin_r),
    }
}
37
38pub trait FKChain<const N: usize>: Clone + Send + Sync {
39    type Error: Into<DekeError>;
40    fn dof(&self) -> usize {
41        N
42    }
43    /// Theoretical maximum reach: sum of link lengths (upper bound, ignores joint limits).
44    fn max_reach(&self) -> Result<f32, Self::Error> {
45        let (_, p, p_ee) = self.joint_axes_positions(&SRobotQ::zeros())?;
46        let mut total = 0.0f32;
47        let mut prev = p[0];
48        for i in 1..N {
49            total += (p[i] - prev).length();
50            prev = p[i];
51        }
52        total += (p_ee - prev).length();
53        Ok(total)
54    }
55
56    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error>;
57    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error>;
58    /// Returns joint rotation axes and axis-origin positions in world frame at
59    /// configuration `q`, plus the end-effector position.
60    fn joint_axes_positions(
61        &self,
62        q: &SRobotQ<N>,
63    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error>;
64
65    /// Geometric Jacobian (6×N) at configuration `q`.
66    /// Rows 0–2: linear velocity, rows 3–5: angular velocity.
67    fn jacobian(&self, q: &SRobotQ<N>) -> Result<[[f32; N]; 6], Self::Error> {
68        let (z, p, p_ee) = self.joint_axes_positions(q)?;
69        let mut j = [[0.0f32; N]; 6];
70        for i in 0..N {
71            let dp = p_ee - p[i];
72            let c = z[i].cross(dp);
73            j[0][i] = c.x;
74            j[1][i] = c.y;
75            j[2][i] = c.z;
76            j[3][i] = z[i].x;
77            j[4][i] = z[i].y;
78            j[5][i] = z[i].z;
79        }
80        Ok(j)
81    }
82
83    /// First time-derivative of the geometric Jacobian.
84    fn jacobian_dot(
85        &self,
86        q: &SRobotQ<N>,
87        qdot: &SRobotQ<N>,
88    ) -> Result<[[f32; N]; 6], Self::Error> {
89        let (z, p, p_ee) = self.joint_axes_positions(q)?;
90
91        let mut omega = Vec3A::ZERO;
92        let mut z_dot = [Vec3A::ZERO; N];
93        let mut p_dot = [Vec3A::ZERO; N];
94        let mut pdot_acc = Vec3A::ZERO;
95
96        for i in 0..N {
97            p_dot[i] = pdot_acc;
98            z_dot[i] = omega.cross(z[i]);
99            omega += qdot.0[i] * z[i];
100            let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
101            pdot_acc += omega.cross(next_p - p[i]);
102        }
103        let p_ee_dot = pdot_acc;
104
105        let mut jd = [[0.0f32; N]; 6];
106        for i in 0..N {
107            let dp = p_ee - p[i];
108            let dp_dot = p_ee_dot - p_dot[i];
109            let c1 = z_dot[i].cross(dp);
110            let c2 = z[i].cross(dp_dot);
111            jd[0][i] = c1.x + c2.x;
112            jd[1][i] = c1.y + c2.y;
113            jd[2][i] = c1.z + c2.z;
114            jd[3][i] = z_dot[i].x;
115            jd[4][i] = z_dot[i].y;
116            jd[5][i] = z_dot[i].z;
117        }
118        Ok(jd)
119    }
120
121    /// Second time-derivative of the geometric Jacobian.
122    fn jacobian_ddot(
123        &self,
124        q: &SRobotQ<N>,
125        qdot: &SRobotQ<N>,
126        qddot: &SRobotQ<N>,
127    ) -> Result<[[f32; N]; 6], Self::Error> {
128        let (z, p, p_ee) = self.joint_axes_positions(q)?;
129
130        let mut omega = Vec3A::ZERO;
131        let mut omega_dot = Vec3A::ZERO;
132        let mut z_dot = [Vec3A::ZERO; N];
133        let mut z_ddot = [Vec3A::ZERO; N];
134        let mut p_dot = [Vec3A::ZERO; N];
135        let mut p_ddot = [Vec3A::ZERO; N];
136        let mut pdot_acc = Vec3A::ZERO;
137        let mut pddot_acc = Vec3A::ZERO;
138
139        for i in 0..N {
140            p_dot[i] = pdot_acc;
141            p_ddot[i] = pddot_acc;
142            let zd = omega.cross(z[i]);
143            z_dot[i] = zd;
144            z_ddot[i] = omega_dot.cross(z[i]) + omega.cross(zd);
145            omega_dot += qddot.0[i] * z[i] + qdot.0[i] * zd;
146            omega += qdot.0[i] * z[i];
147            let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
148            let delta = next_p - p[i];
149            let delta_dot = omega.cross(delta);
150            pdot_acc += delta_dot;
151            pddot_acc += omega_dot.cross(delta) + omega.cross(delta_dot);
152        }
153        let p_ee_dot = pdot_acc;
154        let p_ee_ddot = pddot_acc;
155
156        let mut jdd = [[0.0f32; N]; 6];
157        for i in 0..N {
158            let dp = p_ee - p[i];
159            let dp_dot = p_ee_dot - p_dot[i];
160            let dp_ddot = p_ee_ddot - p_ddot[i];
161            let c1 = z_ddot[i].cross(dp);
162            let c2 = z_dot[i].cross(dp_dot);
163            let c3 = z[i].cross(dp_ddot);
164            jdd[0][i] = c1.x + 2.0 * c2.x + c3.x;
165            jdd[1][i] = c1.y + 2.0 * c2.y + c3.y;
166            jdd[2][i] = c1.z + 2.0 * c2.z + c3.z;
167            jdd[3][i] = z_ddot[i].x;
168            jdd[4][i] = z_ddot[i].y;
169            jdd[5][i] = z_ddot[i].z;
170        }
171        Ok(jdd)
172    }
173}
174
175#[inline(always)]
176#[cfg(debug_assertions)]
177fn check_finite<const N: usize>(q: &SRobotQ<N>) -> Result<(), DekeError> {
178    if q.any_non_finite() {
179        return Err(DekeError::JointsNonFinite);
180    }
181    Ok(())
182}
183
/// Release-build guard: the check is compiled out. The `Infallible` error
/// type means `?` on the result can never take the error path.
#[cfg(not(debug_assertions))]
#[inline(always)]
fn check_finite<const N: usize>(_q: &SRobotQ<N>) -> Result<(), std::convert::Infallible> {
    Ok(())
}
189
190/// Accumulate a local rotation + translation into the running transform.
191/// Shared by both DH and HP — the only difference is how each convention
192/// builds `local_c0..c2` and `local_t`.
193#[inline(always)]
194fn accumulate(
195    acc_m: &mut Mat3A,
196    acc_t: &mut Vec3A,
197    local_c0: Vec3A,
198    local_c1: Vec3A,
199    local_c2: Vec3A,
200    local_t: Vec3A,
201) {
202    let new_c0 = *acc_m * local_c0;
203    let new_c1 = *acc_m * local_c1;
204    let new_c2 = *acc_m * local_c2;
205    *acc_t = *acc_m * local_t + *acc_t;
206    *acc_m = Mat3A::from_cols(new_c0, new_c1, new_c2);
207}
208
/// Parameters of one joint in a standard Denavit–Hartenberg chain.
///
/// Angles are in radians (they are passed to `sin_cos`).
#[derive(Debug, Clone, Copy)]
pub struct DHJoint {
    /// Link length: translation along the rotated x axis (`Tx(a)`).
    pub a: f32,
    /// Link twist: rotation about the rotated x axis (`Rx(α)`).
    pub alpha: f32,
    /// Link offset: translation along the joint's z axis (`Tz(d)`).
    pub d: f32,
    /// Constant added to the joint variable before `Rz(θ)` is applied.
    pub theta_offset: f32,
}

/// Precomputed standard-DH chain with SoA layout.
///
/// Convention: `T_i = Rz(θ) · Tz(d) · Tx(a) · Rx(α)`
///
/// `sin(α)` and `cos(α)` are computed once here, so FK only has to evaluate
/// the sine/cosine of each joint angle.
#[derive(Debug, Clone)]
pub struct DHChain<const N: usize> {
    a: [f32; N],
    d: [f32; N],
    sin_alpha: [f32; N],
    cos_alpha: [f32; N],
    theta_offset: [f32; N],
}

impl<const N: usize> DHChain<N> {
    /// Builds the chain from per-joint DH parameters, precomputing `sin α` / `cos α`.
    pub fn new(joints: [DHJoint; N]) -> Self {
        let mut chain = Self {
            a: [0.0; N],
            d: [0.0; N],
            sin_alpha: [0.0; N],
            cos_alpha: [0.0; N],
            theta_offset: [0.0; N],
        };

        for (idx, joint) in joints.iter().enumerate() {
            let (sin_a, cos_a) = joint.alpha.sin_cos();
            chain.a[idx] = joint.a;
            chain.d[idx] = joint.d;
            chain.sin_alpha[idx] = sin_a;
            chain.cos_alpha[idx] = cos_a;
            chain.theta_offset[idx] = joint.theta_offset;
        }

        chain
    }
}
257
258impl<const N: usize> FKChain<N> for DHChain<N> {
259    #[cfg(debug_assertions)]
260    type Error = DekeError;
261    #[cfg(not(debug_assertions))]
262    type Error = std::convert::Infallible;
263
264    /// DH forward kinematics exploiting the structure of `Rz(θ)·Rx(α)`.
265    ///
266    /// The per-joint accumulation decomposes into two 2D column rotations:
267    ///   1. Rotate `(c0, c1)` by θ  →  `(new_c0, perp)`
268    ///   2. Rotate `(perp, c2)` by α  →  `(new_c1, new_c2)`
269    /// Translation reuses `new_c0`:  `t += a·new_c0 + d·old_c2`
270    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
271        check_finite::<N>(q)?;
272        let mut out = [Affine3A::IDENTITY; N];
273        let mut c0 = Vec3A::X;
274        let mut c1 = Vec3A::Y;
275        let mut c2 = Vec3A::Z;
276        let mut t = Vec3A::ZERO;
277
278        let mut i = 0;
279        while i < N {
280            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
281            let sa = self.sin_alpha[i];
282            let ca = self.cos_alpha[i];
283
284            let new_c0 = ct * c0 + st * c1;
285            let perp = ct * c1 - st * c0;
286
287            let new_c1 = ca * perp + sa * c2;
288            let new_c2 = ca * c2 - sa * perp;
289
290            t = self.a[i] * new_c0 + self.d[i] * c2 + t;
291
292            c0 = new_c0;
293            c1 = new_c1;
294            c2 = new_c2;
295
296            out[i] = Affine3A {
297                matrix3: Mat3A::from_cols(c0, c1, c2),
298                translation: t,
299            };
300            i += 1;
301        }
302        Ok(out)
303    }
304
305    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
306        check_finite::<N>(q)?;
307        let mut c0 = Vec3A::X;
308        let mut c1 = Vec3A::Y;
309        let mut c2 = Vec3A::Z;
310        let mut t = Vec3A::ZERO;
311
312        let mut i = 0;
313        while i < N {
314            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
315            let sa = self.sin_alpha[i];
316            let ca = self.cos_alpha[i];
317
318            let new_c0 = ct * c0 + st * c1;
319            let perp = ct * c1 - st * c0;
320
321            let new_c1 = ca * perp + sa * c2;
322            let new_c2 = ca * c2 - sa * perp;
323
324            t = self.a[i] * new_c0 + self.d[i] * c2 + t;
325
326            c0 = new_c0;
327            c1 = new_c1;
328            c2 = new_c2;
329            i += 1;
330        }
331
332        Ok(Affine3A {
333            matrix3: Mat3A::from_cols(c0, c1, c2),
334            translation: t,
335        })
336    }
337
338    fn joint_axes_positions(
339        &self,
340        q: &SRobotQ<N>,
341    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
342        let frames = self.fk(q)?;
343        let mut axes = [Vec3A::Z; N];
344        let mut positions = [Vec3A::ZERO; N];
345
346        for i in 1..N {
347            axes[i] = frames[i - 1].matrix3.z_axis;
348            positions[i] = frames[i - 1].translation;
349        }
350
351        Ok((axes, positions, frames[N - 1].translation))
352    }
353}
354
355#[derive(Debug, Clone, Copy)]
356pub struct HPJoint {
357    pub a: f32,
358    pub alpha: f32,
359    pub beta: f32,
360    pub d: f32,
361    pub theta_offset: f32,
362}
363
364/// Precomputed Hayati-Paul chain with SoA layout.
365///
366/// Convention: `T_i = Rz(θ) · Rx(α) · Ry(β) · Tx(a) · Tz(d)`
367///
368/// HP adds a `β` rotation about Y, which makes it numerically stable for
369/// nearly-parallel consecutive joint axes where standard DH is singular.
370#[derive(Debug, Clone)]
371pub struct HPChain<const N: usize> {
372    a: [f32; N],
373    d: [f32; N],
374    sin_alpha: [f32; N],
375    cos_alpha: [f32; N],
376    sin_beta: [f32; N],
377    cos_beta: [f32; N],
378    theta_offset: [f32; N],
379}
380
381impl<const N: usize> HPChain<N> {
382    pub fn new(joints: [HPJoint; N]) -> Self {
383        let mut a = [0.0; N];
384        let mut d = [0.0; N];
385        let mut sin_alpha = [0.0; N];
386        let mut cos_alpha = [0.0; N];
387        let mut sin_beta = [0.0; N];
388        let mut cos_beta = [0.0; N];
389        let mut theta_offset = [0.0; N];
390
391        let mut i = 0;
392        while i < N {
393            a[i] = joints[i].a;
394            d[i] = joints[i].d;
395            let (sa, ca) = joints[i].alpha.sin_cos();
396            sin_alpha[i] = sa;
397            cos_alpha[i] = ca;
398            let (sb, cb) = joints[i].beta.sin_cos();
399            sin_beta[i] = sb;
400            cos_beta[i] = cb;
401            theta_offset[i] = joints[i].theta_offset;
402            i += 1;
403        }
404
405        Self {
406            a,
407            d,
408            sin_alpha,
409            cos_alpha,
410            sin_beta,
411            cos_beta,
412            theta_offset,
413        }
414    }
415
416    /// Build the local rotation columns and translation for joint `i`.
417    ///
418    /// R = Rz(θ) · Rx(α) · Ry(β), then t = R · [a, 0, d].
419    ///
420    /// Rx(α)·Ry(β) columns:
421    ///   col0 = [ cβ,       sα·sβ,     -cα·sβ     ]
422    ///   col1 = [ 0,        cα,          sα        ]
423    ///   col2 = [ sβ,      -sα·cβ,      cα·cβ     ]
424    ///
425    /// Then Rz(θ) rotates each column: [cθ·x - sθ·y, sθ·x + cθ·y, z]
426    ///
427    /// Translation = a·col0 + d·col2  (since R·[a,0,d] = a·col0 + d·col2).
428    #[inline(always)]
429    fn local_frame(&self, i: usize, st: f32, ct: f32) -> (Vec3A, Vec3A, Vec3A, Vec3A) {
430        let sa = self.sin_alpha[i];
431        let ca = self.cos_alpha[i];
432        let sb = self.sin_beta[i];
433        let cb = self.cos_beta[i];
434        let ai = self.a[i];
435        let di = self.d[i];
436
437        let sa_sb = sa * sb;
438        let sa_cb = sa * cb;
439        let ca_sb = ca * sb;
440        let ca_cb = ca * cb;
441
442        let c0 = Vec3A::new(ct * cb - st * sa_sb, st * cb + ct * sa_sb, -ca_sb);
443        let c1 = Vec3A::new(-st * ca, ct * ca, sa);
444        let c2 = Vec3A::new(ct * sb + st * sa_cb, st * sb - ct * sa_cb, ca_cb);
445        let t = Vec3A::new(
446            ai * c0.x + di * c2.x,
447            ai * c0.y + di * c2.y,
448            ai * c0.z + di * c2.z,
449        );
450
451        (c0, c1, c2, t)
452    }
453}
454
455impl<const N: usize> FKChain<N> for HPChain<N> {
456    #[cfg(debug_assertions)]
457    type Error = DekeError;
458    #[cfg(not(debug_assertions))]
459    type Error = std::convert::Infallible;
460
461    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
462        check_finite::<N>(q)?;
463        let mut out = [Affine3A::IDENTITY; N];
464        let mut acc_m = Mat3A::IDENTITY;
465        let mut acc_t = Vec3A::ZERO;
466
467        let mut i = 0;
468        while i < N {
469            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
470            let (c0, c1, c2, t) = self.local_frame(i, st, ct);
471            accumulate(&mut acc_m, &mut acc_t, c0, c1, c2, t);
472
473            out[i] = Affine3A {
474                matrix3: acc_m,
475                translation: acc_t,
476            };
477            i += 1;
478        }
479        Ok(out)
480    }
481
482    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
483        check_finite(q)?;
484        let mut acc_m = Mat3A::IDENTITY;
485        let mut acc_t = Vec3A::ZERO;
486
487        let mut i = 0;
488        while i < N {
489            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
490            let (c0, c1, c2, t) = self.local_frame(i, st, ct);
491            accumulate(&mut acc_m, &mut acc_t, c0, c1, c2, t);
492            i += 1;
493        }
494
495        Ok(Affine3A {
496            matrix3: acc_m,
497            translation: acc_t,
498        })
499    }
500
501    fn joint_axes_positions(
502        &self,
503        q: &SRobotQ<N>,
504    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
505        let frames = self.fk(q)?;
506        let mut axes = [Vec3A::Z; N];
507        let mut positions = [Vec3A::ZERO; N];
508
509        for i in 1..N {
510            axes[i] = frames[i - 1].matrix3.z_axis;
511            positions[i] = frames[i - 1].translation;
512        }
513
514        Ok((axes, positions, frames[N - 1].translation))
515    }
516}
517
518#[derive(Debug, Clone, Copy)]
519pub struct URDFJoint {
520    pub origin_xyz: [f64; 3],
521    pub origin_rpy: [f64; 3],
522    pub axis: [f64; 3],
523}
524
525/// Precomputed per-joint axis type for column-rotation FK.
526#[derive(Debug, Clone, Copy)]
527enum JointAxis {
528    Z,
529    Y(f32),
530    X(f32),
531}
532
533/// FK chain using exact URDF joint transforms.
534///
535/// Accumulation works directly on columns:
536///   1. Translation: `t += fx·c0 + fy·c1 + fz·c2`
537///   2. Fixed rotation: `(c0,c1,c2) = (c0,c1,c2) * fixed_rot`
538///   3. Joint rotation: 2D rotation on the appropriate column pair
539///
540/// When `fixed_rot` is identity (RPY = 0, the common case), step 2 is
541/// skipped entirely, making per-joint cost a single 2D column rotation
542/// plus translation — cheaper than DH.
543#[derive(Debug, Clone)]
544pub struct URDFChain<const N: usize> {
545    fr_c0: [Vec3A; N],
546    fr_c1: [Vec3A; N],
547    fr_c2: [Vec3A; N],
548    fr_identity: [bool; N],
549    fixed_trans: [Vec3A; N],
550    axis: [JointAxis; N],
551}
552
553impl<const N: usize> URDFChain<N> {
554    pub fn new(joints: [URDFJoint; N]) -> Self {
555        let mut fr_c0 = [Vec3A::X; N];
556        let mut fr_c1 = [Vec3A::Y; N];
557        let mut fr_c2 = [Vec3A::Z; N];
558        let mut fr_identity = [true; N];
559        let mut fixed_trans = [Vec3A::ZERO; N];
560        let mut axis = [JointAxis::Z; N];
561
562        for i in 0..N {
563            let [ox, oy, oz] = joints[i].origin_xyz;
564            let [roll, pitch, yaw] = joints[i].origin_rpy;
565
566            let is_identity = roll.abs() < 1e-10 && pitch.abs() < 1e-10 && yaw.abs() < 1e-10;
567            fr_identity[i] = is_identity;
568
569            if !is_identity {
570                let (sr, cr) = roll.sin_cos();
571                let (sp, cp) = pitch.sin_cos();
572                let (sy, cy) = yaw.sin_cos();
573                fr_c0[i] = Vec3A::new((cy * cp) as f32, (sy * cp) as f32, (-sp) as f32);
574                fr_c1[i] = Vec3A::new(
575                    (cy * sp * sr - sy * cr) as f32,
576                    (sy * sp * sr + cy * cr) as f32,
577                    (cp * sr) as f32,
578                );
579                fr_c2[i] = Vec3A::new(
580                    (cy * sp * cr + sy * sr) as f32,
581                    (sy * sp * cr - cy * sr) as f32,
582                    (cp * cr) as f32,
583                );
584            }
585
586            fixed_trans[i] = Vec3A::new(ox as f32, oy as f32, oz as f32);
587
588            let [ax, ay, az] = joints[i].axis;
589            if az.abs() > 0.5 {
590                axis[i] = JointAxis::Z;
591            } else if ay.abs() > 0.5 {
592                axis[i] = JointAxis::Y(ay.signum() as f32);
593            } else {
594                axis[i] = JointAxis::X(ax.signum() as f32);
595            }
596        }
597
598        Self {
599            fr_c0,
600            fr_c1,
601            fr_c2,
602            fr_identity,
603            fixed_trans,
604            axis,
605        }
606    }
607
608    /// Apply fixed rotation + joint rotation to accumulator columns.
609    #[inline(always)]
610    fn accumulate_joint(
611        &self,
612        i: usize,
613        st: f32,
614        ct: f32,
615        c0: &mut Vec3A,
616        c1: &mut Vec3A,
617        c2: &mut Vec3A,
618        t: &mut Vec3A,
619    ) {
620        let ft = self.fixed_trans[i];
621        *t = ft.x * *c0 + ft.y * *c1 + ft.z * *c2 + *t;
622
623        let (f0, f1, f2) = if self.fr_identity[i] {
624            (*c0, *c1, *c2)
625        } else {
626            let fc0 = self.fr_c0[i];
627            let fc1 = self.fr_c1[i];
628            let fc2 = self.fr_c2[i];
629            (
630                fc0.x * *c0 + fc0.y * *c1 + fc0.z * *c2,
631                fc1.x * *c0 + fc1.y * *c1 + fc1.z * *c2,
632                fc2.x * *c0 + fc2.y * *c1 + fc2.z * *c2,
633            )
634        };
635
636        match self.axis[i] {
637            JointAxis::Z => {
638                let new_c0 = ct * f0 + st * f1;
639                let new_c1 = ct * f1 - st * f0;
640                *c0 = new_c0;
641                *c1 = new_c1;
642                *c2 = f2;
643            }
644            JointAxis::Y(s) => {
645                let sst = s * st;
646                let new_c0 = ct * f0 - sst * f2;
647                let new_c2 = sst * f0 + ct * f2;
648                *c0 = new_c0;
649                *c1 = f1;
650                *c2 = new_c2;
651            }
652            JointAxis::X(s) => {
653                let sst = s * st;
654                let new_c1 = ct * f1 + sst * f2;
655                let new_c2 = ct * f2 - sst * f1;
656                *c0 = f0;
657                *c1 = new_c1;
658                *c2 = new_c2;
659            }
660        }
661    }
662}
663
664impl<const N: usize> FKChain<N> for URDFChain<N> {
665    #[cfg(debug_assertions)]
666    type Error = DekeError;
667    #[cfg(not(debug_assertions))]
668    type Error = std::convert::Infallible;
669
670    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
671        check_finite(q)?;
672        let mut out = [Affine3A::IDENTITY; N];
673        let mut c0 = Vec3A::X;
674        let mut c1 = Vec3A::Y;
675        let mut c2 = Vec3A::Z;
676        let mut t = Vec3A::ZERO;
677
678        let mut i = 0;
679        while i < N {
680            let (st, ct) = fast_sin_cos(q.0[i]);
681            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
682
683            out[i] = Affine3A {
684                matrix3: Mat3A::from_cols(c0, c1, c2),
685                translation: t,
686            };
687            i += 1;
688        }
689        Ok(out)
690    }
691
692    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
693        check_finite(q)?;
694        let mut c0 = Vec3A::X;
695        let mut c1 = Vec3A::Y;
696        let mut c2 = Vec3A::Z;
697        let mut t = Vec3A::ZERO;
698
699        let mut i = 0;
700        while i < N {
701            let (st, ct) = fast_sin_cos(q.0[i]);
702            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
703            i += 1;
704        }
705
706        Ok(Affine3A {
707            matrix3: Mat3A::from_cols(c0, c1, c2),
708            translation: t,
709        })
710    }
711
712    fn joint_axes_positions(
713        &self,
714        q: &SRobotQ<N>,
715    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
716        let frames = self.fk(q)?;
717        let mut axes = [Vec3A::ZERO; N];
718        let mut positions = [Vec3A::ZERO; N];
719
720        for i in 0..N {
721            axes[i] = match self.axis[i] {
722                JointAxis::Z => frames[i].matrix3.z_axis,
723                JointAxis::Y(s) => s * frames[i].matrix3.y_axis,
724                JointAxis::X(s) => s * frames[i].matrix3.x_axis,
725            };
726            positions[i] = frames[i].translation;
727        }
728
729        Ok((axes, positions, frames[N - 1].translation))
730    }
731}
732
733/// Wraps an `FKChain` with an optional prefix (base) and/or suffix (tool) transform.
734///
735/// - `fk` applies only the prefix — intermediate frames stay in world coordinates
736///   without the tool offset.
737/// - `fk_end` and `joint_axes_positions` apply both — the end-effector includes
738///   the tool tip.
739#[derive(Debug, Clone)]
740pub struct TransformedFK<const N: usize, FK: FKChain<N>> {
741    inner: FK,
742    prefix: Option<Affine3A>,
743    suffix: Option<Affine3A>,
744}
745
746impl<const N: usize, FK: FKChain<N>> TransformedFK<N, FK> {
747    pub fn new(inner: FK) -> Self {
748        Self {
749            inner,
750            prefix: None,
751            suffix: None,
752        }
753    }
754
755    pub fn with_prefix(mut self, prefix: Affine3A) -> Self {
756        self.prefix = Some(prefix);
757        self
758    }
759
760    pub fn with_suffix(mut self, suffix: Affine3A) -> Self {
761        self.suffix = Some(suffix);
762        self
763    }
764
765    pub fn set_prefix(&mut self, prefix: Option<Affine3A>) {
766        self.prefix = prefix;
767    }
768
769    pub fn set_suffix(&mut self, suffix: Option<Affine3A>) {
770        self.suffix = suffix;
771    }
772
773    pub fn prefix(&self) -> Option<&Affine3A> {
774        self.prefix.as_ref()
775    }
776
777    pub fn suffix(&self) -> Option<&Affine3A> {
778        self.suffix.as_ref()
779    }
780
781    pub fn inner(&self) -> &FK {
782        &self.inner
783    }
784}
785
786impl<const N: usize, FK: FKChain<N>> FKChain<N> for TransformedFK<N, FK> {
787    type Error = FK::Error;
788
789    fn max_reach(&self) -> Result<f32, Self::Error> {
790        let mut reach = self.inner.max_reach()?;
791        if let Some(suf) = &self.suffix {
792            reach += Vec3A::from(suf.translation).length();
793        }
794        Ok(reach)
795    }
796
797    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
798        let mut frames = self.inner.fk(q)?;
799        if let Some(pre) = &self.prefix {
800            for f in &mut frames {
801                *f = *pre * *f;
802            }
803        }
804        Ok(frames)
805    }
806
807    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
808        let mut end = self.inner.fk_end(q)?;
809        if let Some(pre) = &self.prefix {
810            end = *pre * end;
811        }
812        if let Some(suf) = &self.suffix {
813            end = end * *suf;
814        }
815        Ok(end)
816    }
817
818    fn joint_axes_positions(
819        &self,
820        q: &SRobotQ<N>,
821    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
822        let (mut axes, mut positions, inner_p_ee) = self.inner.joint_axes_positions(q)?;
823
824        if let Some(pre) = &self.prefix {
825            let rot = pre.matrix3;
826            let t = Vec3A::from(pre.translation);
827            for i in 0..N {
828                axes[i] = rot * axes[i];
829                positions[i] = rot * positions[i] + t;
830            }
831        }
832
833        let p_ee = if self.prefix.is_some() || self.suffix.is_some() {
834            self.fk_end(q)?.translation
835        } else {
836            inner_p_ee
837        };
838
839        Ok((axes, positions, p_ee))
840    }
841}
842
843/// Wraps an `FKChain<N>` and prepends a prismatic (linear) joint, producing
844/// an `FKChain<M>` where `M = N + 1`.
845///
846/// The prismatic joint always acts first in the kinematic chain — it
847/// translates the entire arm along `axis` (world frame).  The
848/// `q_index_first` flag only controls where the prismatic value is read
849/// from in `SRobotQ<M>`: when `true` it is `q[0]`, when `false` it is
850/// `q[M-1]`.
851///
852/// Jacobian columns for the prismatic joint are `[axis; 0]` (pure linear,
853/// no angular contribution).  Because the prismatic uniformly shifts all
854/// positions, the revolute Jacobian columns are identical to the inner
855/// chain's.
856#[derive(Debug, Clone)]
857pub struct PrismaticFK<const M: usize, const N: usize, FK: FKChain<N>> {
858    inner: FK,
859    axis: Vec3A,
860    q_index_first: bool,
861}
862
863impl<const M: usize, const N: usize, FK: FKChain<N>> PrismaticFK<M, N, FK> {
864    pub fn new(inner: FK, axis: Vec3A, q_index_first: bool) -> Self {
865        const { assert!(M == N + 1, "M must equal N + 1") };
866        Self {
867            inner,
868            axis,
869            q_index_first,
870        }
871    }
872
873    pub fn inner(&self) -> &FK {
874        &self.inner
875    }
876
877    pub fn axis(&self) -> Vec3A {
878        self.axis
879    }
880
881    pub fn q_index_first(&self) -> bool {
882        self.q_index_first
883    }
884
885    fn split_q(&self, q: &SRobotQ<M>) -> (f32, SRobotQ<N>) {
886        let mut inner = [0.0f32; N];
887        if self.q_index_first {
888            inner.copy_from_slice(&q.0[1..M]);
889            (q.0[0], SRobotQ(inner))
890        } else {
891            inner.copy_from_slice(&q.0[..N]);
892            (q.0[M - 1], SRobotQ(inner))
893        }
894    }
895
896    fn prismatic_col(&self) -> usize {
897        if self.q_index_first { 0 } else { N }
898    }
899
900    fn revolute_offset(&self) -> usize {
901        if self.q_index_first { 1 } else { 0 }
902    }
903}
904
905impl<const M: usize, const N: usize, FK: FKChain<N>> FKChain<M> for PrismaticFK<M, N, FK> {
906    type Error = FK::Error;
907
908    fn fk(&self, q: &SRobotQ<M>) -> Result<[Affine3A; M], Self::Error> {
909        let (q_p, inner_q) = self.split_q(q);
910        let offset = q_p * self.axis;
911        let inner_frames = self.inner.fk(&inner_q)?;
912        let mut out = [Affine3A::IDENTITY; M];
913
914        out[0] = Affine3A::from_translation(offset.into());
915        for i in 0..N {
916            let mut f = inner_frames[i];
917            f.translation += offset;
918            out[i + 1] = f;
919        }
920
921        Ok(out)
922    }
923
924    fn fk_end(&self, q: &SRobotQ<M>) -> Result<Affine3A, Self::Error> {
925        let (q_p, inner_q) = self.split_q(q);
926        let mut end = self.inner.fk_end(&inner_q)?;
927        end.translation += q_p * self.axis;
928        Ok(end)
929    }
930
931    fn joint_axes_positions(
932        &self,
933        q: &SRobotQ<M>,
934    ) -> Result<([Vec3A; M], [Vec3A; M], Vec3A), Self::Error> {
935        let (q_p, inner_q) = self.split_q(q);
936        let offset = q_p * self.axis;
937        let (inner_axes, inner_pos, inner_p_ee) = self.inner.joint_axes_positions(&inner_q)?;
938
939        let mut axes = [Vec3A::ZERO; M];
940        let mut positions = [Vec3A::ZERO; M];
941
942        axes[0] = self.axis;
943        for i in 0..N {
944            axes[i + 1] = inner_axes[i];
945            positions[i + 1] = inner_pos[i] + offset;
946        }
947
948        Ok((axes, positions, inner_p_ee + offset))
949    }
950
951    fn jacobian(&self, q: &SRobotQ<M>) -> Result<[[f32; M]; 6], Self::Error> {
952        let (_q_p, inner_q) = self.split_q(q);
953        let inner_j = self.inner.jacobian(&inner_q)?;
954        let p_col = self.prismatic_col();
955        let r_off = self.revolute_offset();
956
957        let mut j = [[0.0f32; M]; 6];
958        j[0][p_col] = self.axis.x;
959        j[1][p_col] = self.axis.y;
960        j[2][p_col] = self.axis.z;
961
962        for row in 0..6 {
963            for col in 0..N {
964                j[row][col + r_off] = inner_j[row][col];
965            }
966        }
967
968        Ok(j)
969    }
970
971    fn jacobian_dot(
972        &self,
973        q: &SRobotQ<M>,
974        qdot: &SRobotQ<M>,
975    ) -> Result<[[f32; M]; 6], Self::Error> {
976        let (_q_p, inner_q) = self.split_q(q);
977        let (_qdot_p, inner_qdot) = self.split_q(qdot);
978        let inner_jd = self.inner.jacobian_dot(&inner_q, &inner_qdot)?;
979        let r_off = self.revolute_offset();
980
981        let mut jd = [[0.0f32; M]; 6];
982        for row in 0..6 {
983            for col in 0..N {
984                jd[row][col + r_off] = inner_jd[row][col];
985            }
986        }
987
988        Ok(jd)
989    }
990
991    fn jacobian_ddot(
992        &self,
993        q: &SRobotQ<M>,
994        qdot: &SRobotQ<M>,
995        qddot: &SRobotQ<M>,
996    ) -> Result<[[f32; M]; 6], Self::Error> {
997        let (_q_p, inner_q) = self.split_q(q);
998        let (_qdot_p, inner_qdot) = self.split_q(qdot);
999        let (_qddot_p, inner_qddot) = self.split_q(qddot);
1000        let inner_jdd = self
1001            .inner
1002            .jacobian_ddot(&inner_q, &inner_qdot, &inner_qddot)?;
1003        let r_off = self.revolute_offset();
1004
1005        let mut jdd = [[0.0f32; M]; 6];
1006        for row in 0..6 {
1007            for col in 0..N {
1008                jdd[row][col + r_off] = inner_jdd[row][col];
1009            }
1010        }
1011
1012        Ok(jdd)
1013    }
1014}