Skip to main content

deke_types/
fk.rs

1use glam::{Affine3A, Mat3A, Vec3A};
2
3use crate::{DekeError, SRobotQ};
4
5#[inline(always)]
6const fn fast_sin_cos(x: f32) -> (f32, f32) {
7    const FRAC_2_PI: f32 = std::f32::consts::FRAC_2_PI;
8    const PI_2_HI: f32 = 1.570_796_4_f32;
9    const PI_2_LO: f32 = -4.371_139e-8_f32;
10
11    const S1: f32 = -0.166_666_67;
12    const S2: f32 = 0.008_333_294;
13    const S3: f32 = -0.000_198_074_14;
14
15    const C1: f32 = -0.5;
16    const C2: f32 = 0.041_666_52;
17    const C3: f32 = -0.001_388_523_4;
18
19    let q = (x * FRAC_2_PI).round();
20    let qi = q as i32;
21    let r = x - q * PI_2_HI - q * PI_2_LO;
22    let r2 = r * r;
23
24    let sin_r = r * (1.0 + r2 * (S1 + r2 * (S2 + r2 * S3)));
25    let cos_r = 1.0 + r2 * (C1 + r2 * (C2 + r2 * C3));
26
27    let (s, c) = match qi & 3 {
28        0 => (sin_r, cos_r),
29        1 => (cos_r, -sin_r),
30        2 => (-sin_r, -cos_r),
31        3 => (-cos_r, sin_r),
32        _ => unsafe { std::hint::unreachable_unchecked() },
33    };
34
35    (s, c)
36}
37
/// `const`-evaluable square root for `f64`.
///
/// Returns NaN for negative or NaN input. `0.0`, `-0.0` and `+∞` are returned
/// unchanged. Otherwise runs Newton's iteration `g ← (g + x/g) / 2` from an
/// exponent-halved initial guess.
///
/// Termination: in exact arithmetic every iterate after the first lies at or
/// above `√x` and decreases monotonically. In floating point the iterates can
/// end up alternating between two neighbouring values, so a
/// `while guess != prev` rule may never stop. Instead, one step is taken to
/// land above the root, and iteration then continues only while the iterate
/// strictly decreases. A strictly decreasing sequence of `f64`s is finite, so
/// the loop always ends.
#[inline]
const fn const_sqrt(x: f64) -> f64 {
    if x < 0.0 || x.is_nan() {
        return f64::NAN;
    }
    if x == 0.0 || x == f64::INFINITY {
        return x;
    }

    // Initial guess: halve the unbiased exponent and keep the mantissa bits.
    // For normal x = m·2^e this gives roughly m·2^(e/2), a small factor from √x.
    let bits = x.to_bits();
    let exp = ((bits >> 52) & 0x7ff) as i64;
    let new_exp = ((exp - 1023) / 2 + 1023) as u64;
    let mut guess = f64::from_bits((new_exp << 52) | (bits & 0x000f_ffff_ffff_ffff));

    // By AM–GM, one Newton step puts the guess at or above √x (within rounding).
    guess = (guess + x / guess) * 0.5;
    loop {
        let next = (guess + x / guess) * 0.5;
        if next >= guess {
            return guess;
        }
        guess = next;
    }
}
57
58pub trait FKChain<const N: usize>: Clone + Send + Sync {
59    type Error: Into<DekeError>;
60    fn dof(&self) -> usize {
61        N
62    }
63    /// Configuration-independent transform from the robot's base frame to the
64    /// world frame. Defaults to identity; wrappers that install a static
65    /// prefix (e.g. [`TransformedFK`] with a prefix set, or [`URDFChain`]
66    /// with fixed leading joints baked in) override this so downstream
67    /// consumers (collision validators, visualizers) can place the static
68    /// base body at the correct pose.
69    fn base_tf(&self) -> Affine3A {
70        Affine3A::IDENTITY
71    }
72    /// Theoretical maximum reach: sum of link lengths (upper bound, ignores joint limits).
73    fn max_reach(&self) -> Result<f32, Self::Error> {
74        let (_, p, p_ee) = self.joint_axes_positions(&SRobotQ::zeros())?;
75        let mut total = 0.0f32;
76        let mut prev = p[0];
77        for i in 1..N {
78            total += (p[i] - prev).length();
79            prev = p[i];
80        }
81        total += (p_ee - prev).length();
82        Ok(total)
83    }
84
85    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error>;
86    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error>;
87    /// Returns joint rotation axes and axis-origin positions in world frame at
88    /// configuration `q`, plus the end-effector position.
89    fn joint_axes_positions(
90        &self,
91        q: &SRobotQ<N>,
92    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error>;
93
94    /// Geometric Jacobian (6×N) at configuration `q`.
95    /// Rows 0–2: linear velocity, rows 3–5: angular velocity.
96    fn jacobian(&self, q: &SRobotQ<N>) -> Result<[[f32; N]; 6], Self::Error> {
97        let (z, p, p_ee) = self.joint_axes_positions(q)?;
98        let mut j = [[0.0f32; N]; 6];
99        for i in 0..N {
100            let dp = p_ee - p[i];
101            let c = z[i].cross(dp);
102            j[0][i] = c.x;
103            j[1][i] = c.y;
104            j[2][i] = c.z;
105            j[3][i] = z[i].x;
106            j[4][i] = z[i].y;
107            j[5][i] = z[i].z;
108        }
109        Ok(j)
110    }
111
112    /// First time-derivative of the geometric Jacobian.
113    fn jacobian_dot(
114        &self,
115        q: &SRobotQ<N>,
116        qdot: &SRobotQ<N>,
117    ) -> Result<[[f32; N]; 6], Self::Error> {
118        let (z, p, p_ee) = self.joint_axes_positions(q)?;
119
120        let mut omega = Vec3A::ZERO;
121        let mut z_dot = [Vec3A::ZERO; N];
122        let mut p_dot = [Vec3A::ZERO; N];
123        let mut pdot_acc = Vec3A::ZERO;
124
125        for i in 0..N {
126            p_dot[i] = pdot_acc;
127            z_dot[i] = omega.cross(z[i]);
128            omega += qdot.0[i] * z[i];
129            let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
130            pdot_acc += omega.cross(next_p - p[i]);
131        }
132        let p_ee_dot = pdot_acc;
133
134        let mut jd = [[0.0f32; N]; 6];
135        for i in 0..N {
136            let dp = p_ee - p[i];
137            let dp_dot = p_ee_dot - p_dot[i];
138            let c1 = z_dot[i].cross(dp);
139            let c2 = z[i].cross(dp_dot);
140            jd[0][i] = c1.x + c2.x;
141            jd[1][i] = c1.y + c2.y;
142            jd[2][i] = c1.z + c2.z;
143            jd[3][i] = z_dot[i].x;
144            jd[4][i] = z_dot[i].y;
145            jd[5][i] = z_dot[i].z;
146        }
147        Ok(jd)
148    }
149
150    /// Second time-derivative of the geometric Jacobian.
151    fn jacobian_ddot(
152        &self,
153        q: &SRobotQ<N>,
154        qdot: &SRobotQ<N>,
155        qddot: &SRobotQ<N>,
156    ) -> Result<[[f32; N]; 6], Self::Error> {
157        let (z, p, p_ee) = self.joint_axes_positions(q)?;
158
159        let mut omega = Vec3A::ZERO;
160        let mut omega_dot = Vec3A::ZERO;
161        let mut z_dot = [Vec3A::ZERO; N];
162        let mut z_ddot = [Vec3A::ZERO; N];
163        let mut p_dot = [Vec3A::ZERO; N];
164        let mut p_ddot = [Vec3A::ZERO; N];
165        let mut pdot_acc = Vec3A::ZERO;
166        let mut pddot_acc = Vec3A::ZERO;
167
168        for i in 0..N {
169            p_dot[i] = pdot_acc;
170            p_ddot[i] = pddot_acc;
171            let zd = omega.cross(z[i]);
172            z_dot[i] = zd;
173            z_ddot[i] = omega_dot.cross(z[i]) + omega.cross(zd);
174            omega_dot += qddot.0[i] * z[i] + qdot.0[i] * zd;
175            omega += qdot.0[i] * z[i];
176            let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
177            let delta = next_p - p[i];
178            let delta_dot = omega.cross(delta);
179            pdot_acc += delta_dot;
180            pddot_acc += omega_dot.cross(delta) + omega.cross(delta_dot);
181        }
182        let p_ee_dot = pdot_acc;
183        let p_ee_ddot = pddot_acc;
184
185        let mut jdd = [[0.0f32; N]; 6];
186        for i in 0..N {
187            let dp = p_ee - p[i];
188            let dp_dot = p_ee_dot - p_dot[i];
189            let dp_ddot = p_ee_ddot - p_ddot[i];
190            let c1 = z_ddot[i].cross(dp);
191            let c2 = z_dot[i].cross(dp_dot);
192            let c3 = z[i].cross(dp_ddot);
193            jdd[0][i] = c1.x + 2.0 * c2.x + c3.x;
194            jdd[1][i] = c1.y + 2.0 * c2.y + c3.y;
195            jdd[2][i] = c1.z + 2.0 * c2.z + c3.z;
196            jdd[3][i] = z_ddot[i].x;
197            jdd[4][i] = z_ddot[i].y;
198            jdd[5][i] = z_ddot[i].z;
199        }
200        Ok(jdd)
201    }
202}
203
204#[inline(always)]
205#[cfg(debug_assertions)]
206fn check_finite<const N: usize>(q: &SRobotQ<N>) -> Result<(), DekeError> {
207    if q.any_non_finite() {
208        return Err(DekeError::JointsNonFinite);
209    }
210    Ok(())
211}
212
213#[inline(always)]
214#[cfg(not(debug_assertions))]
215fn check_finite<const N: usize>(_: &SRobotQ<N>) -> Result<(), std::convert::Infallible> {
216    Ok(())
217}
218
/// `const` absolute value for `f32`.
///
/// Delegates to [`f32::abs`], which works in `const fn` and clears the sign
/// bit unconditionally. As a result `-0.0` maps to `+0.0` and NaN loses its
/// sign, which a `x < 0.0` comparison would not do.
#[inline(always)]
const fn abs_f32(x: f32) -> f32 {
    x.abs()
}
223
/// Rigid/affine transform stored as plain `f32` arrays so it can be built
/// and composed inside `const fn`.
///
/// `glam`'s `Vec3A`/`Mat3A` are SIMD-backed and expose their components
/// through a non-`const` `Deref`. Const code that needs per-component
/// arithmetic (composition, identity test) therefore works on this type and
/// converts to `Affine3A` only at the boundary.
#[derive(Debug, Clone, Copy)]
struct AffineRaw {
    /// First column of the 3×3 linear part.
    c0: [f32; 3],
    /// Second column of the 3×3 linear part.
    c1: [f32; 3],
    /// Third column of the 3×3 linear part.
    c2: [f32; 3],
    /// Translation.
    t: [f32; 3],
}
236
237impl AffineRaw {
238    const IDENTITY: Self = Self {
239        c0: [1.0, 0.0, 0.0],
240        c1: [0.0, 1.0, 0.0],
241        c2: [0.0, 0.0, 1.0],
242        t: [0.0, 0.0, 0.0],
243    };
244
245    /// `self * other` — applies `other` first, then `self`.
246    #[inline(always)]
247    const fn mul(self, other: Self) -> Self {
248        let nc0 = [
249            self.c0[0] * other.c0[0] + self.c1[0] * other.c0[1] + self.c2[0] * other.c0[2],
250            self.c0[1] * other.c0[0] + self.c1[1] * other.c0[1] + self.c2[1] * other.c0[2],
251            self.c0[2] * other.c0[0] + self.c1[2] * other.c0[1] + self.c2[2] * other.c0[2],
252        ];
253        let nc1 = [
254            self.c0[0] * other.c1[0] + self.c1[0] * other.c1[1] + self.c2[0] * other.c1[2],
255            self.c0[1] * other.c1[0] + self.c1[1] * other.c1[1] + self.c2[1] * other.c1[2],
256            self.c0[2] * other.c1[0] + self.c1[2] * other.c1[1] + self.c2[2] * other.c1[2],
257        ];
258        let nc2 = [
259            self.c0[0] * other.c2[0] + self.c1[0] * other.c2[1] + self.c2[0] * other.c2[2],
260            self.c0[1] * other.c2[0] + self.c1[1] * other.c2[1] + self.c2[1] * other.c2[2],
261            self.c0[2] * other.c2[0] + self.c1[2] * other.c2[1] + self.c2[2] * other.c2[2],
262        ];
263        let nt = [
264            self.c0[0] * other.t[0]
265                + self.c1[0] * other.t[1]
266                + self.c2[0] * other.t[2]
267                + self.t[0],
268            self.c0[1] * other.t[0]
269                + self.c1[1] * other.t[1]
270                + self.c2[1] * other.t[2]
271                + self.t[1],
272            self.c0[2] * other.t[0]
273                + self.c1[2] * other.t[1]
274                + self.c2[2] * other.t[2]
275                + self.t[2],
276        ];
277        Self {
278            c0: nc0,
279            c1: nc1,
280            c2: nc2,
281            t: nt,
282        }
283    }
284
285    #[inline(always)]
286    const fn is_identity(&self) -> bool {
287        const EPS: f32 = 1e-6;
288        abs_f32(self.c0[0] - 1.0) <= EPS
289            && abs_f32(self.c0[1]) <= EPS
290            && abs_f32(self.c0[2]) <= EPS
291            && abs_f32(self.c1[0]) <= EPS
292            && abs_f32(self.c1[1] - 1.0) <= EPS
293            && abs_f32(self.c1[2]) <= EPS
294            && abs_f32(self.c2[0]) <= EPS
295            && abs_f32(self.c2[1]) <= EPS
296            && abs_f32(self.c2[2] - 1.0) <= EPS
297            && abs_f32(self.t[0]) <= EPS
298            && abs_f32(self.t[1]) <= EPS
299            && abs_f32(self.t[2]) <= EPS
300    }
301
302    /// Build the URDF RPY-convention rotation (`Rz(yaw)·Ry(pitch)·Rx(roll)`)
303    /// and translate by `xyz`, using [`fast_sin_cos`] for const evaluation.
304    #[inline(always)]
305    const fn from_xyz_rpy(xyz: (f64, f64, f64), rpy: (f64, f64, f64)) -> Self {
306        let (ox, oy, oz) = xyz;
307        let (roll, pitch, yaw) = rpy;
308        let (sr, cr) = fast_sin_cos(roll as f32);
309        let (sp, cp) = fast_sin_cos(pitch as f32);
310        let (sy, cy) = fast_sin_cos(yaw as f32);
311        Self {
312            c0: [cy * cp, sy * cp, -sp],
313            c1: [cy * sp * sr - sy * cr, sy * sp * sr + cy * cr, cp * sr],
314            c2: [cy * sp * cr + sy * sr, sy * sp * cr - cy * sr, cp * cr],
315            t: [ox as f32, oy as f32, oz as f32],
316        }
317    }
318
319    #[inline(always)]
320    const fn to_affine3a(self) -> Affine3A {
321        Affine3A {
322            matrix3: Mat3A::from_cols(
323                Vec3A::new(self.c0[0], self.c0[1], self.c0[2]),
324                Vec3A::new(self.c1[0], self.c1[1], self.c1[2]),
325                Vec3A::new(self.c2[0], self.c2[1], self.c2[2]),
326            ),
327            translation: Vec3A::new(self.t[0], self.t[1], self.t[2]),
328        }
329    }
330
331    #[inline(always)]
332    const fn c0_vec3a(&self) -> Vec3A {
333        Vec3A::new(self.c0[0], self.c0[1], self.c0[2])
334    }
335
336    #[inline(always)]
337    const fn c1_vec3a(&self) -> Vec3A {
338        Vec3A::new(self.c1[0], self.c1[1], self.c1[2])
339    }
340
341    #[inline(always)]
342    const fn c2_vec3a(&self) -> Vec3A {
343        Vec3A::new(self.c2[0], self.c2[1], self.c2[2])
344    }
345
346    #[inline(always)]
347    const fn t_vec3a(&self) -> Vec3A {
348        Vec3A::new(self.t[0], self.t[1], self.t[2])
349    }
350}
351
352/// Accumulate a local rotation + translation into the running transform.
353/// Shared by both DH and HP — the only difference is how each convention
354/// builds `local_c0..c2` and `local_t`.
355#[inline(always)]
356fn accumulate(
357    acc_m: &mut Mat3A,
358    acc_t: &mut Vec3A,
359    local_c0: Vec3A,
360    local_c1: Vec3A,
361    local_c2: Vec3A,
362    local_t: Vec3A,
363) {
364    let new_c0 = *acc_m * local_c0;
365    let new_c1 = *acc_m * local_c1;
366    let new_c2 = *acc_m * local_c2;
367    *acc_t = *acc_m * local_t + *acc_t;
368    *acc_m = Mat3A::from_cols(new_c0, new_c1, new_c2);
369}
370
/// One joint's standard Denavit–Hartenberg parameters, consumed by
/// [`DHChain::new`].
#[derive(Debug, Clone, Copy)]
pub struct DHJoint {
    /// Link length: translation along the joint's new x-axis.
    pub a: f32,
    /// Link twist about the new x-axis, in radians.
    pub alpha: f32,
    /// Link offset: translation along the previous z-axis.
    pub d: f32,
    /// Constant added to the joint variable before forming `Rz(θ)`.
    pub theta_offset: f32,
}
378
/// Standard-DH chain with parameters precomputed into a
/// structure-of-arrays layout.
///
/// Convention: `T_i = Rz(θ) · Tz(d) · Tx(a) · Rx(α)`
#[derive(Debug, Clone)]
pub struct DHChain<const N: usize> {
    /// Per-joint link length.
    a: [f32; N],
    /// Per-joint link offset.
    d: [f32; N],
    /// Per-joint `sin α`, precomputed.
    sin_alpha: [f32; N],
    /// Per-joint `cos α`, precomputed.
    cos_alpha: [f32; N],
    /// Per-joint offset added to the joint variable.
    theta_offset: [f32; N],
}
390
391impl<const N: usize> DHChain<N> {
392    pub const fn new(joints: [DHJoint; N]) -> Self {
393        let mut a = [0.0; N];
394        let mut d = [0.0; N];
395        let mut sin_alpha = [0.0; N];
396        let mut cos_alpha = [0.0; N];
397        let mut theta_offset = [0.0; N];
398
399        let mut i = 0;
400        while i < N {
401            a[i] = joints[i].a;
402            d[i] = joints[i].d;
403            let (sa, ca) = fast_sin_cos(joints[i].alpha);
404            sin_alpha[i] = sa;
405            cos_alpha[i] = ca;
406            theta_offset[i] = joints[i].theta_offset;
407            i += 1;
408        }
409
410        Self {
411            a,
412            d,
413            sin_alpha,
414            cos_alpha,
415            theta_offset,
416        }
417    }
418
419    /// Construct from the row-major `DH_PARAMS` const array emitted by the
420    /// workcell macro.
421    ///
422    /// `params`: `[[f64; N]; 4]` — rows are `(a, alpha, d, theta_offset)`
423    /// across joints.
424    pub const fn from_dh(params: &[[f64; N]; 4]) -> Self {
425        let mut a = [0.0f32; N];
426        let mut d = [0.0f32; N];
427        let mut sin_alpha = [0.0f32; N];
428        let mut cos_alpha = [0.0f32; N];
429        let mut theta_offset = [0.0f32; N];
430
431        let mut i = 0;
432        while i < N {
433            a[i] = params[0][i] as f32;
434            let (sa, ca) = fast_sin_cos(params[1][i] as f32);
435            sin_alpha[i] = sa;
436            cos_alpha[i] = ca;
437            d[i] = params[2][i] as f32;
438            theta_offset[i] = params[3][i] as f32;
439            i += 1;
440        }
441
442        Self {
443            a,
444            d,
445            sin_alpha,
446            cos_alpha,
447            theta_offset,
448        }
449    }
450}
451
452impl<const N: usize> FKChain<N> for DHChain<N> {
453    #[cfg(debug_assertions)]
454    type Error = DekeError;
455    #[cfg(not(debug_assertions))]
456    type Error = std::convert::Infallible;
457
458    /// DH forward kinematics exploiting the structure of `Rz(θ)·Rx(α)`.
459    ///
460    /// The per-joint accumulation decomposes into two 2D column rotations:
461    ///   1. Rotate `(c0, c1)` by θ  →  `(new_c0, perp)`
462    ///   2. Rotate `(perp, c2)` by α  →  `(new_c1, new_c2)`
463    /// Translation reuses `new_c0`:  `t += a·new_c0 + d·old_c2`
464    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
465        check_finite::<N>(q)?;
466        let mut out = [Affine3A::IDENTITY; N];
467        let mut c0 = Vec3A::X;
468        let mut c1 = Vec3A::Y;
469        let mut c2 = Vec3A::Z;
470        let mut t = Vec3A::ZERO;
471
472        let mut i = 0;
473        while i < N {
474            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
475            let sa = self.sin_alpha[i];
476            let ca = self.cos_alpha[i];
477
478            let new_c0 = ct * c0 + st * c1;
479            let perp = ct * c1 - st * c0;
480
481            let new_c1 = ca * perp + sa * c2;
482            let new_c2 = ca * c2 - sa * perp;
483
484            t = self.a[i] * new_c0 + self.d[i] * c2 + t;
485
486            c0 = new_c0;
487            c1 = new_c1;
488            c2 = new_c2;
489
490            out[i] = Affine3A {
491                matrix3: Mat3A::from_cols(c0, c1, c2),
492                translation: t,
493            };
494            i += 1;
495        }
496        Ok(out)
497    }
498
499    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
500        check_finite::<N>(q)?;
501        let mut c0 = Vec3A::X;
502        let mut c1 = Vec3A::Y;
503        let mut c2 = Vec3A::Z;
504        let mut t = Vec3A::ZERO;
505
506        let mut i = 0;
507        while i < N {
508            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
509            let sa = self.sin_alpha[i];
510            let ca = self.cos_alpha[i];
511
512            let new_c0 = ct * c0 + st * c1;
513            let perp = ct * c1 - st * c0;
514
515            let new_c1 = ca * perp + sa * c2;
516            let new_c2 = ca * c2 - sa * perp;
517
518            t = self.a[i] * new_c0 + self.d[i] * c2 + t;
519
520            c0 = new_c0;
521            c1 = new_c1;
522            c2 = new_c2;
523            i += 1;
524        }
525
526        Ok(Affine3A {
527            matrix3: Mat3A::from_cols(c0, c1, c2),
528            translation: t,
529        })
530    }
531
532    fn joint_axes_positions(
533        &self,
534        q: &SRobotQ<N>,
535    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
536        let frames = self.fk(q)?;
537        let mut axes = [Vec3A::Z; N];
538        let mut positions = [Vec3A::ZERO; N];
539
540        for i in 1..N {
541            axes[i] = frames[i - 1].matrix3.z_axis;
542            positions[i] = frames[i - 1].translation;
543        }
544
545        Ok((axes, positions, frames[N - 1].translation))
546    }
547}
548
/// One joint's Hayati–Paul parameters, consumed by [`HPChain::new`].
#[derive(Debug, Clone, Copy)]
pub struct HPJoint {
    /// Translation along the local x-axis (after `Rz·Rx·Ry`).
    pub a: f32,
    /// Rotation about x, in radians.
    pub alpha: f32,
    /// Rotation about y, in radians.
    pub beta: f32,
    /// Translation along the local z-axis (after `Rz·Rx·Ry`).
    pub d: f32,
    /// Constant added to the joint variable before forming `Rz(θ)`.
    pub theta_offset: f32,
}
557
/// Hayati–Paul chain with parameters precomputed into a
/// structure-of-arrays layout.
///
/// Convention: `T_i = Rz(θ) · Rx(α) · Ry(β) · Tx(a) · Tz(d)`
///
/// The extra `β` rotation about Y keeps the parameterisation well-conditioned
/// for nearly-parallel consecutive joint axes, where standard DH is singular.
#[derive(Debug, Clone)]
pub struct HPChain<const N: usize> {
    /// Per-joint translation along the local x-axis.
    a: [f32; N],
    /// Per-joint translation along the local z-axis.
    d: [f32; N],
    /// Per-joint `sin α`, precomputed.
    sin_alpha: [f32; N],
    /// Per-joint `cos α`, precomputed.
    cos_alpha: [f32; N],
    /// Per-joint `sin β`, precomputed.
    sin_beta: [f32; N],
    /// Per-joint `cos β`, precomputed.
    cos_beta: [f32; N],
    /// Per-joint offset added to the joint variable.
    theta_offset: [f32; N],
}
574
575impl<const N: usize> HPChain<N> {
576    pub const fn new(joints: [HPJoint; N]) -> Self {
577        let mut a = [0.0; N];
578        let mut d = [0.0; N];
579        let mut sin_alpha = [0.0; N];
580        let mut cos_alpha = [0.0; N];
581        let mut sin_beta = [0.0; N];
582        let mut cos_beta = [0.0; N];
583        let mut theta_offset = [0.0; N];
584
585        let mut i = 0;
586        while i < N {
587            a[i] = joints[i].a;
588            d[i] = joints[i].d;
589            let (sa, ca) = fast_sin_cos(joints[i].alpha);
590            sin_alpha[i] = sa;
591            cos_alpha[i] = ca;
592            let (sb, cb) = fast_sin_cos(joints[i].beta);
593            sin_beta[i] = sb;
594            cos_beta[i] = cb;
595            theta_offset[i] = joints[i].theta_offset;
596            i += 1;
597        }
598
599        Self {
600            a,
601            d,
602            sin_alpha,
603            cos_alpha,
604            sin_beta,
605            cos_beta,
606            theta_offset,
607        }
608    }
609
610    /// Construct from the row-major `HP_H` and `HP_P` const arrays emitted by
611    /// the workcell macro.
612    ///
613    /// `h`: `[[f64; N]; 3]` — rows are (x, y, z) components across joints.
614    /// `p`: `[[f64; N]; 3]` — rows are (x, y, z) components across points.
615    ///
616    /// Each `h[_][i]` is joint `i`'s axis in the base frame at zero config.
617    /// `p[_][0]` is the base-to-joint-0 offset; `p[_][i]` for `1..N` is the
618    /// offset from joint `i-1`'s origin to joint `i`'s origin. The tool
619    /// offset from joint `N-1` to the flange is not part of this input
620    /// because `HPChain` has no end-effector slot — wrap the result in a
621    /// [`TransformedFK`](crate::TransformedFK) if a tool offset is required.
622    ///
623    /// `theta_offset` is set to zero for every joint: at zero config each
624    /// local x-axis is pinned to `Rx(α) · Ry(β) · [1, 0, 0]` expressed in the
625    /// parent frame.
626    pub const fn from_hp(h: &[[f32; N]; 3], p: &[[f32; N]; 3]) -> Self {
627
628        let mut a = [0.0f32; N];
629        let mut d = [0.0f32; N];
630        let mut sin_alpha = [0.0f32; N];
631        let mut cos_alpha = [0.0f32; N];
632        let mut sin_beta = [0.0f32; N];
633        let mut cos_beta = [0.0f32; N];
634
635        let mut c0 = [1.0f32, 0.0, 0.0];
636        let mut c1 = [0.0f32, 1.0, 0.0];
637        let mut c2 = [0.0f32, 0.0, 1.0];
638
639        const EPS: f32 = 1e-12;
640
641        let mut i = 0;
642        while i < N {
643            let hx = h[0][i];
644            let hy = h[1][i];
645            let hz = h[2][i];
646            let px = p[0][i];
647            let py = p[1][i];
648            let pz = p[2][i];
649
650            let vx = c0[0] * hx + c0[1] * hy + c0[2] * hz;
651            let vy = c1[0] * hx + c1[1] * hy + c1[2] * hz;
652            let vz = c2[0] * hx + c2[1] * hy + c2[2] * hz;
653            let ux = c0[0] * px + c0[1] * py + c0[2] * pz;
654            let uy = c1[0] * px + c1[1] * py + c1[2] * pz;
655            let uz = c2[0] * px + c2[1] * py + c2[2] * pz;
656
657            let sb = vx;
658            let cb = const_sqrt((vy * vy + vz * vz) as f64) as f32;
659
660            let (sa, ca) = if cb > EPS {
661                (-vy / cb, vz / cb)
662            } else {
663                (0.0, 1.0)
664            };
665
666            let big_a = ux;
667            let big_b = sa * uy - ca * uz;
668            let ai = cb * big_a + sb * big_b;
669            let di = sb * big_a - cb * big_b;
670
671            a[i] = ai;
672            d[i] = di;
673            sin_alpha[i] = sa;
674            cos_alpha[i] = ca;
675            sin_beta[i] = sb;
676            cos_beta[i] = cb;
677
678            let sasb = sa * sb;
679            let casb = ca * sb;
680            let sacb = sa * cb;
681            let cacb = ca * cb;
682            let new_c0 = [
683                cb * c0[0] + sasb * c1[0] - casb * c2[0],
684                cb * c0[1] + sasb * c1[1] - casb * c2[1],
685                cb * c0[2] + sasb * c1[2] - casb * c2[2],
686            ];
687            let new_c1 = [
688                ca * c1[0] + sa * c2[0],
689                ca * c1[1] + sa * c2[1],
690                ca * c1[2] + sa * c2[2],
691            ];
692            let new_c2 = [
693                sb * c0[0] - sacb * c1[0] + cacb * c2[0],
694                sb * c0[1] - sacb * c1[1] + cacb * c2[1],
695                sb * c0[2] - sacb * c1[2] + cacb * c2[2],
696            ];
697            c0 = new_c0;
698            c1 = new_c1;
699            c2 = new_c2;
700
701            i += 1;
702        }
703
704        Self {
705            a,
706            d,
707            sin_alpha,
708            cos_alpha,
709            sin_beta,
710            cos_beta,
711            theta_offset: [0.0f32; N],
712        }
713    }
714
715    /// Build the local rotation columns and translation for joint `i`.
716    ///
717    /// R = Rz(θ) · Rx(α) · Ry(β), then t = R · [a, 0, d].
718    ///
719    /// Rx(α)·Ry(β) columns:
720    ///   col0 = [ cβ,       sα·sβ,     -cα·sβ     ]
721    ///   col1 = [ 0,        cα,          sα        ]
722    ///   col2 = [ sβ,      -sα·cβ,      cα·cβ     ]
723    ///
724    /// Then Rz(θ) rotates each column: [cθ·x - sθ·y, sθ·x + cθ·y, z]
725    ///
726    /// Translation = a·col0 + d·col2  (since R·[a,0,d] = a·col0 + d·col2).
727    #[inline(always)]
728    fn local_frame(&self, i: usize, st: f32, ct: f32) -> (Vec3A, Vec3A, Vec3A, Vec3A) {
729        let sa = self.sin_alpha[i];
730        let ca = self.cos_alpha[i];
731        let sb = self.sin_beta[i];
732        let cb = self.cos_beta[i];
733        let ai = self.a[i];
734        let di = self.d[i];
735
736        let sa_sb = sa * sb;
737        let sa_cb = sa * cb;
738        let ca_sb = ca * sb;
739        let ca_cb = ca * cb;
740
741        let c0 = Vec3A::new(ct * cb - st * sa_sb, st * cb + ct * sa_sb, -ca_sb);
742        let c1 = Vec3A::new(-st * ca, ct * ca, sa);
743        let c2 = Vec3A::new(ct * sb + st * sa_cb, st * sb - ct * sa_cb, ca_cb);
744        let t = Vec3A::new(
745            ai * c0.x + di * c2.x,
746            ai * c0.y + di * c2.y,
747            ai * c0.z + di * c2.z,
748        );
749
750        (c0, c1, c2, t)
751    }
752}
753
754impl<const N: usize> FKChain<N> for HPChain<N> {
755    #[cfg(debug_assertions)]
756    type Error = DekeError;
757    #[cfg(not(debug_assertions))]
758    type Error = std::convert::Infallible;
759
760    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
761        check_finite::<N>(q)?;
762        let mut out = [Affine3A::IDENTITY; N];
763        let mut acc_m = Mat3A::IDENTITY;
764        let mut acc_t = Vec3A::ZERO;
765
766        let mut i = 0;
767        while i < N {
768            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
769            let (c0, c1, c2, t) = self.local_frame(i, st, ct);
770            accumulate(&mut acc_m, &mut acc_t, c0, c1, c2, t);
771
772            out[i] = Affine3A {
773                matrix3: acc_m,
774                translation: acc_t,
775            };
776            i += 1;
777        }
778        Ok(out)
779    }
780
781    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
782        check_finite(q)?;
783        let mut acc_m = Mat3A::IDENTITY;
784        let mut acc_t = Vec3A::ZERO;
785
786        let mut i = 0;
787        while i < N {
788            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
789            let (c0, c1, c2, t) = self.local_frame(i, st, ct);
790            accumulate(&mut acc_m, &mut acc_t, c0, c1, c2, t);
791            i += 1;
792        }
793
794        Ok(Affine3A {
795            matrix3: acc_m,
796            translation: acc_t,
797        })
798    }
799
800    fn joint_axes_positions(
801        &self,
802        q: &SRobotQ<N>,
803    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
804        let frames = self.fk(q)?;
805        let mut axes = [Vec3A::Z; N];
806        let mut positions = [Vec3A::ZERO; N];
807
808        for i in 1..N {
809            axes[i] = frames[i - 1].matrix3.z_axis;
810            positions[i] = frames[i - 1].translation;
811        }
812
813        Ok((axes, positions, frames[N - 1].translation))
814    }
815}
816
/// Kind of URDF joint.
///
/// The `axis` of a moving joint is expressed in the joint's own frame, as the
/// URDF spec prescribes.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum URDFJointType {
    /// Rigid connection; contributes only its `<origin>` transform.
    Fixed,
    /// Rotates about `axis`.
    Revolute { axis: (f64, f64, f64) },
    /// Translates along `axis`.
    Prismatic { axis: (f64, f64, f64) },
}
826
827/// A URDF joint: its type plus the `<origin>` transform (xyz translation and
828/// rpy Euler rotation) from the parent link's frame to the joint's own frame.
829#[derive(Debug, Clone, Copy)]
830pub struct URDFJoint {
831    pub r#type: URDFJointType,
832    pub xyz: (f64, f64, f64),
833    pub rpy: (f64, f64, f64),
834}
835
836impl URDFJoint {
837    pub const fn fixed(xyz: (f64, f64, f64), rpy: (f64, f64, f64)) -> Self {
838        Self {
839            r#type: URDFJointType::Fixed,
840            xyz,
841            rpy,
842        }
843    }
844
845    pub const fn revolute(
846        xyz: (f64, f64, f64),
847        rpy: (f64, f64, f64),
848        axis: (f64, f64, f64),
849    ) -> Self {
850        Self {
851            r#type: URDFJointType::Revolute { axis },
852            xyz,
853            rpy,
854        }
855    }
856
857    pub const fn prismatic(
858        xyz: (f64, f64, f64),
859        rpy: (f64, f64, f64),
860        axis: (f64, f64, f64),
861    ) -> Self {
862        Self {
863            r#type: URDFJointType::Prismatic { axis },
864            xyz,
865            rpy,
866        }
867    }
868
869    /// Build the `Affine3A` corresponding to this joint's `<origin>`, using
870    /// the URDF RPY convention `R = Rz(yaw) · Ry(pitch) · Rx(roll)`.
871    pub const fn origin_affine(&self) -> Affine3A {
872        AffineRaw::from_xyz_rpy(self.xyz, self.rpy).to_affine3a()
873    }
874}
875
876/// Const-friendly error type for the `URDFChain` / `URDFJoint` const
877/// constructors. Trivially `Copy`, so values can be matched/returned inside
878/// `const fn`s (unlike [`DekeError`], whose `RetimerFailed(String)` variant
879/// carries a non-const destructor). Converts into [`DekeError`] via `From`.
880#[derive(Debug, Clone, Copy, PartialEq, thiserror::Error)]
881pub enum URDFBuildError {
882    #[error(
883        "URDF joint at index {index} has an unexpected type: expected {expected}, found {found}"
884    )]
885    JointTypeMismatch {
886        index: usize,
887        expected: &'static str,
888        found: &'static str,
889    },
890    #[error("URDFChain<{expected}> requires {expected} revolute joints, found {found}")]
891    RevoluteCountMismatch { expected: usize, found: usize },
892}
893
894impl From<URDFBuildError> for DekeError {
895    fn from(e: URDFBuildError) -> Self {
896        match e {
897            URDFBuildError::JointTypeMismatch {
898                index,
899                expected,
900                found,
901            } => DekeError::URDFJointTypeMismatch {
902                index,
903                expected,
904                found,
905            },
906            URDFBuildError::RevoluteCountMismatch { expected, found } => {
907                DekeError::URDFRevoluteCountMismatch { expected, found }
908            }
909        }
910    }
911}
912
913const fn joint_kind_name(k: URDFJointType) -> &'static str {
914    match k {
915        URDFJointType::Fixed => "Fixed",
916        URDFJointType::Revolute { .. } => "Revolute",
917        URDFJointType::Prismatic { .. } => "Prismatic",
918    }
919}
920
921const fn compose_fixed_joints_raw(joints: &[URDFJoint]) -> Result<AffineRaw, URDFBuildError> {
922    let mut acc = AffineRaw::IDENTITY;
923    let n = joints.len();
924    let mut i = 0;
925    while i < n {
926        let j = &joints[i];
927        if !matches!(j.r#type, URDFJointType::Fixed) {
928            return Err(URDFBuildError::JointTypeMismatch {
929                index: i,
930                expected: "Fixed",
931                found: joint_kind_name(j.r#type),
932            });
933        }
934        acc = acc.mul(AffineRaw::from_xyz_rpy(j.xyz, j.rpy));
935        i += 1;
936    }
937    Ok(acc)
938}
939
940/// Compose the `<origin>` transforms of a sequence of fixed joints
941/// (parent→child order) into a single `Affine3A`. Returns an error if any
942/// joint in `joints` is not `Fixed`.
943pub const fn compose_fixed_joints(joints: &[URDFJoint]) -> Result<Affine3A, URDFBuildError> {
944    match compose_fixed_joints_raw(joints) {
945        Ok(a) => Ok(a.to_affine3a()),
946        Err(e) => Err(e),
947    }
948}
949
950/// Precomputed per-joint axis type for column-rotation FK.
951#[derive(Debug, Clone, Copy)]
952enum JointAxis {
953    Z,
954    Y(f32),
955    X(f32),
956}
957
958/// FK chain using exact URDF joint transforms.
959///
960/// Accumulation works directly on columns:
961///   1. Translation: `t += fx·c0 + fy·c1 + fz·c2`
962///   2. Fixed rotation: `(c0,c1,c2) = (c0,c1,c2) * fixed_rot`
963///   3. Joint rotation: 2D rotation on the appropriate column pair
964///
965/// When `fixed_rot` is identity (RPY = 0, the common case), step 2 is
966/// skipped entirely, making per-joint cost a single 2D column rotation
967/// plus translation — cheaper than DH.
968#[derive(Debug, Clone)]
969pub struct URDFChain<const N: usize> {
970    fr_c0: [Vec3A; N],
971    fr_c1: [Vec3A; N],
972    fr_c2: [Vec3A; N],
973    fr_identity: [bool; N],
974    fixed_trans: [Vec3A; N],
975    axis: [JointAxis; N],
976    prefix_c0: Vec3A,
977    prefix_c1: Vec3A,
978    prefix_c2: Vec3A,
979    prefix_t: Vec3A,
980    prefix_identity: bool,
981    suffix_c0: Vec3A,
982    suffix_c1: Vec3A,
983    suffix_c2: Vec3A,
984    suffix_t: Vec3A,
985    suffix_identity: bool,
986}
987
988impl<const N: usize> URDFChain<N> {
989    /// Build a chain from exactly `N` actuated (revolute) joints. Returns
990    /// [`URDFBuildError::JointTypeMismatch`] if any entry is `Fixed` or
991    /// `Prismatic`. For a slice that mixes fixed joints in, use
992    /// [`URDFChain::from_urdf`] instead.
993    pub const fn new(joints: [URDFJoint; N]) -> Result<Self, URDFBuildError> {
994        let mut fr_c0 = [Vec3A::X; N];
995        let mut fr_c1 = [Vec3A::Y; N];
996        let mut fr_c2 = [Vec3A::Z; N];
997        let mut fr_identity = [true; N];
998        let mut fixed_trans = [Vec3A::ZERO; N];
999        let mut axis = [JointAxis::Z; N];
1000
1001        let mut i = 0;
1002        while i < N {
1003            let (ox, oy, oz) = joints[i].xyz;
1004            let (roll, pitch, yaw) = joints[i].rpy;
1005
1006            let is_identity = roll.abs() < 1e-10 && pitch.abs() < 1e-10 && yaw.abs() < 1e-10;
1007            fr_identity[i] = is_identity;
1008
1009            if !is_identity {
1010                let (sr, cr) = fast_sin_cos(roll as f32);
1011                let (sp, cp) = fast_sin_cos(pitch as f32);
1012                let (sy, cy) = fast_sin_cos(yaw as f32);
1013                fr_c0[i] = Vec3A::new(cy * cp, sy * cp, -sp);
1014                fr_c1[i] = Vec3A::new(cy * sp * sr - sy * cr, sy * sp * sr + cy * cr, cp * sr);
1015                fr_c2[i] = Vec3A::new(cy * sp * cr + sy * sr, sy * sp * cr - cy * sr, cp * cr);
1016            }
1017
1018            fixed_trans[i] = Vec3A::new(ox as f32, oy as f32, oz as f32);
1019
1020            let (ax, ay, az) = match joints[i].r#type {
1021                URDFJointType::Revolute { axis } => axis,
1022                _ => {
1023                    return Err(URDFBuildError::JointTypeMismatch {
1024                        index: i,
1025                        expected: "Revolute",
1026                        found: joint_kind_name(joints[i].r#type),
1027                    });
1028                }
1029            };
1030            if az.abs() > 0.5 {
1031                axis[i] = JointAxis::Z;
1032            } else if ay.abs() > 0.5 {
1033                axis[i] = JointAxis::Y(ay.signum() as f32);
1034            } else {
1035                axis[i] = JointAxis::X(ax.signum() as f32);
1036            }
1037            i += 1;
1038        }
1039
1040        Ok(Self {
1041            fr_c0,
1042            fr_c1,
1043            fr_c2,
1044            fr_identity,
1045            fixed_trans,
1046            axis,
1047            prefix_c0: Vec3A::X,
1048            prefix_c1: Vec3A::Y,
1049            prefix_c2: Vec3A::Z,
1050            prefix_t: Vec3A::ZERO,
1051            prefix_identity: true,
1052            suffix_c0: Vec3A::X,
1053            suffix_c1: Vec3A::Y,
1054            suffix_c2: Vec3A::Z,
1055            suffix_t: Vec3A::ZERO,
1056            suffix_identity: true,
1057        })
1058    }
1059
1060    /// Build a chain from a flat URDF joint list (any mix of `Fixed`,
1061    /// `Revolute`, and/or `Prismatic`). The list must describe a single
1062    /// branch in parent→child order.
1063    ///
1064    /// - Leading `Fixed` joints become the prefix (applied before joint 0).
1065    /// - Trailing `Fixed` joints become the suffix (applied after the last
1066    ///   actuated joint).
1067    /// - `Fixed` joints sandwiched between actuated joints are folded into
1068    ///   the origin of the next actuated joint so the kinematics are
1069    ///   preserved exactly.
1070    /// - The number of `Revolute` joints must equal `N`.
1071    ///
1072    /// Returns [`URDFBuildError::JointTypeMismatch`] if a `Prismatic` joint
1073    /// appears (not handled by `URDFChain` itself — wrap the result in
1074    /// [`PrismaticFK`] for a prismatic joint at the start or end), or
1075    /// [`URDFBuildError::RevoluteCountMismatch`] if the revolute count
1076    /// doesn't match `N`.
1077    pub const fn from_urdf(joints: &[URDFJoint]) -> Result<Self, URDFBuildError> {
1078        let mut fr_c0 = [Vec3A::X; N];
1079        let mut fr_c1 = [Vec3A::Y; N];
1080        let mut fr_c2 = [Vec3A::Z; N];
1081        let mut fr_identity = [true; N];
1082        let mut fixed_trans = [Vec3A::ZERO; N];
1083        let mut axis_out = [JointAxis::Z; N];
1084
1085        let mut pending = AffineRaw::IDENTITY;
1086        let mut prefix = AffineRaw::IDENTITY;
1087        let mut prefix_set = false;
1088        let mut r_count = 0usize;
1089
1090        let n = joints.len();
1091        let mut i = 0;
1092        while i < n {
1093            let joint = &joints[i];
1094            match joint.r#type {
1095                URDFJointType::Fixed => {
1096                    pending = pending.mul(AffineRaw::from_xyz_rpy(joint.xyz, joint.rpy));
1097                }
1098                URDFJointType::Revolute { axis } => {
1099                    if r_count >= N {
1100                        return Err(URDFBuildError::RevoluteCountMismatch {
1101                            expected: N,
1102                            found: r_count + 1,
1103                        });
1104                    }
1105                    let local = AffineRaw::from_xyz_rpy(joint.xyz, joint.rpy);
1106                    let effective = if !prefix_set {
1107                        prefix = pending;
1108                        prefix_set = true;
1109                        local
1110                    } else {
1111                        pending.mul(local)
1112                    };
1113
1114                    fr_identity[r_count] = effective.is_identity();
1115                    fr_c0[r_count] = effective.c0_vec3a();
1116                    fr_c1[r_count] = effective.c1_vec3a();
1117                    fr_c2[r_count] = effective.c2_vec3a();
1118                    fixed_trans[r_count] = effective.t_vec3a();
1119
1120                    let (ax, ay, az) = axis;
1121                    axis_out[r_count] = if az.abs() > 0.5 {
1122                        JointAxis::Z
1123                    } else if ay.abs() > 0.5 {
1124                        JointAxis::Y(ay.signum() as f32)
1125                    } else {
1126                        JointAxis::X(ax.signum() as f32)
1127                    };
1128
1129                    pending = AffineRaw::IDENTITY;
1130                    r_count += 1;
1131                }
1132                URDFJointType::Prismatic { .. } => {
1133                    return Err(URDFBuildError::JointTypeMismatch {
1134                        index: i,
1135                        expected: "Fixed or Revolute",
1136                        found: "Prismatic",
1137                    });
1138                }
1139            }
1140            i += 1;
1141        }
1142        if r_count != N {
1143            return Err(URDFBuildError::RevoluteCountMismatch {
1144                expected: N,
1145                found: r_count,
1146            });
1147        }
1148
1149        let prefix_identity = !prefix_set || prefix.is_identity();
1150        let suffix_identity = pending.is_identity();
1151
1152        Ok(Self {
1153            fr_c0,
1154            fr_c1,
1155            fr_c2,
1156            fr_identity,
1157            fixed_trans,
1158            axis: axis_out,
1159            prefix_c0: prefix.c0_vec3a(),
1160            prefix_c1: prefix.c1_vec3a(),
1161            prefix_c2: prefix.c2_vec3a(),
1162            prefix_t: prefix.t_vec3a(),
1163            prefix_identity,
1164            suffix_c0: pending.c0_vec3a(),
1165            suffix_c1: pending.c1_vec3a(),
1166            suffix_c2: pending.c2_vec3a(),
1167            suffix_t: pending.t_vec3a(),
1168            suffix_identity,
1169        })
1170    }
1171
1172    /// Bake a sequence of URDF fixed-joint origins (parent→child order) into
1173    /// the base side of the chain. The composed transform is applied before
1174    /// joint 0, so every joint frame returned by [`FKChain::fk`] and every
1175    /// position returned by [`FKChain::joint_axes_positions`] reflects the
1176    /// fixed prefix.
1177    ///
1178    /// Each joint in `joints` must be `URDFJointType::Fixed`. An empty slice
1179    /// clears any previously set prefix. Returns
1180    /// [`DekeError::URDFJointTypeMismatch`] if any joint is non-Fixed.
1181    pub const fn with_fixed_prefix(
1182        mut self,
1183        joints: &[URDFJoint],
1184    ) -> Result<Self, URDFBuildError> {
1185        if joints.is_empty() {
1186            self.prefix_c0 = Vec3A::X;
1187            self.prefix_c1 = Vec3A::Y;
1188            self.prefix_c2 = Vec3A::Z;
1189            self.prefix_t = Vec3A::ZERO;
1190            self.prefix_identity = true;
1191        } else {
1192            let a = match compose_fixed_joints_raw(joints) {
1193                Ok(a) => a,
1194                Err(e) => return Err(e),
1195            };
1196            self.prefix_identity = a.is_identity();
1197            self.prefix_c0 = a.c0_vec3a();
1198            self.prefix_c1 = a.c1_vec3a();
1199            self.prefix_c2 = a.c2_vec3a();
1200            self.prefix_t = a.t_vec3a();
1201        }
1202        Ok(self)
1203    }
1204
1205    /// Bake a sequence of URDF fixed-joint origins (parent→child order) into
1206    /// the tool side of the chain. The composed transform is applied after
1207    /// the last actuated joint, so the final frame of [`FKChain::fk`], the
1208    /// result of [`FKChain::fk_end`], and the `p_ee` returned by
1209    /// [`FKChain::joint_axes_positions`] all include the fixed suffix.
1210    ///
1211    /// Joint pivot positions and axes (`positions[0..N]`, `axes[0..N]`)
1212    /// remain at the actuated joint origins — they are not shifted by the
1213    /// suffix.
1214    ///
1215    /// Each joint in `joints` must be `URDFJointType::Fixed`. An empty slice
1216    /// clears any previously set suffix. Returns
1217    /// [`DekeError::URDFJointTypeMismatch`] if any joint is non-Fixed.
1218    pub const fn with_fixed_suffix(
1219        mut self,
1220        joints: &[URDFJoint],
1221    ) -> Result<Self, URDFBuildError> {
1222        if joints.is_empty() {
1223            self.suffix_c0 = Vec3A::X;
1224            self.suffix_c1 = Vec3A::Y;
1225            self.suffix_c2 = Vec3A::Z;
1226            self.suffix_t = Vec3A::ZERO;
1227            self.suffix_identity = true;
1228        } else {
1229            let a = match compose_fixed_joints_raw(joints) {
1230                Ok(a) => a,
1231                Err(e) => return Err(e),
1232            };
1233            self.suffix_identity = a.is_identity();
1234            self.suffix_c0 = a.c0_vec3a();
1235            self.suffix_c1 = a.c1_vec3a();
1236            self.suffix_c2 = a.c2_vec3a();
1237            self.suffix_t = a.t_vec3a();
1238        }
1239        Ok(self)
1240    }
1241
1242    /// Convenience: set both a fixed-joint prefix and suffix in one call.
1243    pub const fn with_fixed_joints(
1244        self,
1245        prefix: &[URDFJoint],
1246        suffix: &[URDFJoint],
1247    ) -> Result<Self, URDFBuildError> {
1248        match self.with_fixed_prefix(prefix) {
1249            Ok(s) => s.with_fixed_suffix(suffix),
1250            Err(e) => Err(e),
1251        }
1252    }
1253
1254    #[inline(always)]
1255    fn initial_frame(&self) -> (Vec3A, Vec3A, Vec3A, Vec3A) {
1256        if self.prefix_identity {
1257            (Vec3A::X, Vec3A::Y, Vec3A::Z, Vec3A::ZERO)
1258        } else {
1259            (self.prefix_c0, self.prefix_c1, self.prefix_c2, self.prefix_t)
1260        }
1261    }
1262
1263    #[inline(always)]
1264    fn apply_suffix(
1265        &self,
1266        c0: &mut Vec3A,
1267        c1: &mut Vec3A,
1268        c2: &mut Vec3A,
1269        t: &mut Vec3A,
1270    ) {
1271        let st = self.suffix_t;
1272        *t = st.x * *c0 + st.y * *c1 + st.z * *c2 + *t;
1273
1274        let fc0 = self.suffix_c0;
1275        let fc1 = self.suffix_c1;
1276        let fc2 = self.suffix_c2;
1277        let new_c0 = fc0.x * *c0 + fc0.y * *c1 + fc0.z * *c2;
1278        let new_c1 = fc1.x * *c0 + fc1.y * *c1 + fc1.z * *c2;
1279        let new_c2 = fc2.x * *c0 + fc2.y * *c1 + fc2.z * *c2;
1280        *c0 = new_c0;
1281        *c1 = new_c1;
1282        *c2 = new_c2;
1283    }
1284
1285    /// Apply fixed rotation + joint rotation to accumulator columns.
1286    #[inline(always)]
1287    fn accumulate_joint(
1288        &self,
1289        i: usize,
1290        st: f32,
1291        ct: f32,
1292        c0: &mut Vec3A,
1293        c1: &mut Vec3A,
1294        c2: &mut Vec3A,
1295        t: &mut Vec3A,
1296    ) {
1297        let ft = self.fixed_trans[i];
1298        *t = ft.x * *c0 + ft.y * *c1 + ft.z * *c2 + *t;
1299
1300        let (f0, f1, f2) = if self.fr_identity[i] {
1301            (*c0, *c1, *c2)
1302        } else {
1303            let fc0 = self.fr_c0[i];
1304            let fc1 = self.fr_c1[i];
1305            let fc2 = self.fr_c2[i];
1306            (
1307                fc0.x * *c0 + fc0.y * *c1 + fc0.z * *c2,
1308                fc1.x * *c0 + fc1.y * *c1 + fc1.z * *c2,
1309                fc2.x * *c0 + fc2.y * *c1 + fc2.z * *c2,
1310            )
1311        };
1312
1313        match self.axis[i] {
1314            JointAxis::Z => {
1315                let new_c0 = ct * f0 + st * f1;
1316                let new_c1 = ct * f1 - st * f0;
1317                *c0 = new_c0;
1318                *c1 = new_c1;
1319                *c2 = f2;
1320            }
1321            JointAxis::Y(s) => {
1322                let sst = s * st;
1323                let new_c0 = ct * f0 - sst * f2;
1324                let new_c2 = sst * f0 + ct * f2;
1325                *c0 = new_c0;
1326                *c1 = f1;
1327                *c2 = new_c2;
1328            }
1329            JointAxis::X(s) => {
1330                let sst = s * st;
1331                let new_c1 = ct * f1 + sst * f2;
1332                let new_c2 = ct * f2 - sst * f1;
1333                *c0 = f0;
1334                *c1 = new_c1;
1335                *c2 = new_c2;
1336            }
1337        }
1338    }
1339}
1340
1341impl<const N: usize> FKChain<N> for URDFChain<N> {
1342    #[cfg(debug_assertions)]
1343    type Error = DekeError;
1344    #[cfg(not(debug_assertions))]
1345    type Error = std::convert::Infallible;
1346
1347    fn base_tf(&self) -> Affine3A {
1348        if self.prefix_identity {
1349            Affine3A::IDENTITY
1350        } else {
1351            Affine3A {
1352                matrix3: Mat3A::from_cols(self.prefix_c0, self.prefix_c1, self.prefix_c2),
1353                translation: self.prefix_t,
1354            }
1355        }
1356    }
1357
1358    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
1359        check_finite(q)?;
1360        let mut out = [Affine3A::IDENTITY; N];
1361        let (mut c0, mut c1, mut c2, mut t) = self.initial_frame();
1362
1363        let mut i = 0;
1364        while i < N {
1365            let (st, ct) = fast_sin_cos(q.0[i]);
1366            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
1367
1368            out[i] = Affine3A {
1369                matrix3: Mat3A::from_cols(c0, c1, c2),
1370                translation: t,
1371            };
1372            i += 1;
1373        }
1374
1375        if N > 0 && !self.suffix_identity {
1376            self.apply_suffix(&mut c0, &mut c1, &mut c2, &mut t);
1377            out[N - 1] = Affine3A {
1378                matrix3: Mat3A::from_cols(c0, c1, c2),
1379                translation: t,
1380            };
1381        }
1382        Ok(out)
1383    }
1384
1385    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
1386        check_finite(q)?;
1387        let (mut c0, mut c1, mut c2, mut t) = self.initial_frame();
1388
1389        let mut i = 0;
1390        while i < N {
1391            let (st, ct) = fast_sin_cos(q.0[i]);
1392            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
1393            i += 1;
1394        }
1395
1396        if !self.suffix_identity {
1397            self.apply_suffix(&mut c0, &mut c1, &mut c2, &mut t);
1398        }
1399
1400        Ok(Affine3A {
1401            matrix3: Mat3A::from_cols(c0, c1, c2),
1402            translation: t,
1403        })
1404    }
1405
1406    fn joint_axes_positions(
1407        &self,
1408        q: &SRobotQ<N>,
1409    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
1410        check_finite(q)?;
1411        let mut frames = [Affine3A::IDENTITY; N];
1412        let (mut c0, mut c1, mut c2, mut t) = self.initial_frame();
1413
1414        let mut i = 0;
1415        while i < N {
1416            let (st, ct) = fast_sin_cos(q.0[i]);
1417            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
1418            frames[i] = Affine3A {
1419                matrix3: Mat3A::from_cols(c0, c1, c2),
1420                translation: t,
1421            };
1422            i += 1;
1423        }
1424
1425        let mut axes = [Vec3A::ZERO; N];
1426        let mut positions = [Vec3A::ZERO; N];
1427
1428        for i in 0..N {
1429            axes[i] = match self.axis[i] {
1430                JointAxis::Z => frames[i].matrix3.z_axis,
1431                JointAxis::Y(s) => s * frames[i].matrix3.y_axis,
1432                JointAxis::X(s) => s * frames[i].matrix3.x_axis,
1433            };
1434            positions[i] = frames[i].translation;
1435        }
1436
1437        let p_ee = if N == 0 {
1438            Vec3A::ZERO
1439        } else if !self.suffix_identity {
1440            self.apply_suffix(&mut c0, &mut c1, &mut c2, &mut t);
1441            t
1442        } else {
1443            frames[N - 1].translation
1444        };
1445
1446        Ok((axes, positions, p_ee))
1447    }
1448}
1449
1450/// Wraps an `FKChain` with an optional prefix (base) and/or suffix (tool) transform.
1451///
1452/// - `fk` applies only the prefix — intermediate frames stay in world coordinates
1453///   without the tool offset.
1454/// - `fk_end` and `joint_axes_positions` apply both — the end-effector includes
1455///   the tool tip.
1456#[derive(Debug, Clone)]
1457pub struct TransformedFK<const N: usize, FK: FKChain<N>> {
1458    inner: FK,
1459    prefix: Option<Affine3A>,
1460    suffix: Option<Affine3A>,
1461}
1462
1463impl<const N: usize, FK: FKChain<N>> TransformedFK<N, FK> {
1464    pub const fn new(inner: FK) -> Self {
1465        Self {
1466            inner,
1467            prefix: None,
1468            suffix: None,
1469        }
1470    }
1471
1472    pub const fn with_prefix(mut self, prefix: Affine3A) -> Self {
1473        self.prefix = Some(prefix);
1474        self
1475    }
1476
1477    pub const fn with_suffix(mut self, suffix: Affine3A) -> Self {
1478        self.suffix = Some(suffix);
1479        self
1480    }
1481
1482    /// Infallible `const`-usable setter for the prefix. `None` clears any
1483    /// previously set prefix. Pair with [`compose_fixed_joints`] (const) to
1484    /// build the prefix from a slice of `Fixed` joints in a `const` context.
1485    pub const fn with_prefix_opt(mut self, prefix: Option<Affine3A>) -> Self {
1486        self.prefix = prefix;
1487        self
1488    }
1489
1490    /// Infallible `const`-usable setter for the suffix. `None` clears any
1491    /// previously set suffix. Pair with [`compose_fixed_joints`] (const) to
1492    /// build the suffix from a slice of `Fixed` joints in a `const` context.
1493    pub const fn with_suffix_opt(mut self, suffix: Option<Affine3A>) -> Self {
1494        self.suffix = suffix;
1495        self
1496    }
1497
1498    /// Compose a slice of fixed URDF joints (parent→child order) and set the
1499    /// result as the base-side prefix transform, replacing any existing
1500    /// prefix. Every joint in `joints` must be `URDFJointType::Fixed`. An
1501    /// empty slice clears the prefix.
1502    pub fn with_prefix_joints(mut self, joints: &[URDFJoint]) -> Result<Self, URDFBuildError> {
1503        if joints.is_empty() {
1504            self.prefix = None;
1505            Ok(self)
1506        } else {
1507            self.prefix = Some(compose_fixed_joints(joints)?);
1508            Ok(self)
1509        }
1510    }
1511
1512    /// Compose a slice of fixed URDF joints (parent→child order) and set the
1513    /// result as the tool-side suffix transform, replacing any existing
1514    /// suffix. Every joint in `joints` must be `URDFJointType::Fixed`. An
1515    /// empty slice clears the suffix.
1516    pub fn with_suffix_joints(mut self, joints: &[URDFJoint]) -> Result<Self, URDFBuildError> {
1517        if joints.is_empty() {
1518            self.suffix = None;
1519            Ok(self)
1520        } else {
1521            self.suffix = Some(compose_fixed_joints(joints)?);
1522            Ok(self)
1523        }
1524    }
1525
1526    pub fn set_prefix(&mut self, prefix: Option<Affine3A>) {
1527        self.prefix = prefix;
1528    }
1529
1530    pub fn set_suffix(&mut self, suffix: Option<Affine3A>) {
1531        self.suffix = suffix;
1532    }
1533
1534    pub fn prefix(&self) -> Option<&Affine3A> {
1535        self.prefix.as_ref()
1536    }
1537
1538    pub fn suffix(&self) -> Option<&Affine3A> {
1539        self.suffix.as_ref()
1540    }
1541
1542    pub fn inner(&self) -> &FK {
1543        &self.inner
1544    }
1545}
1546
1547impl<const N: usize, FK: FKChain<N>> FKChain<N> for TransformedFK<N, FK> {
1548    type Error = FK::Error;
1549
1550    fn base_tf(&self) -> Affine3A {
1551        match &self.prefix {
1552            Some(p) => *p * self.inner.base_tf(),
1553            None => self.inner.base_tf(),
1554        }
1555    }
1556
1557    fn max_reach(&self) -> Result<f32, Self::Error> {
1558        let mut reach = self.inner.max_reach()?;
1559        if let Some(suf) = &self.suffix {
1560            reach += Vec3A::from(suf.translation).length();
1561        }
1562        Ok(reach)
1563    }
1564
1565    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
1566        let mut frames = self.inner.fk(q)?;
1567        if let Some(pre) = &self.prefix {
1568            for f in &mut frames {
1569                *f = *pre * *f;
1570            }
1571        }
1572        Ok(frames)
1573    }
1574
1575    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
1576        let mut end = self.inner.fk_end(q)?;
1577        if let Some(pre) = &self.prefix {
1578            end = *pre * end;
1579        }
1580        if let Some(suf) = &self.suffix {
1581            end = end * *suf;
1582        }
1583        Ok(end)
1584    }
1585
1586    fn joint_axes_positions(
1587        &self,
1588        q: &SRobotQ<N>,
1589    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
1590        let (mut axes, mut positions, inner_p_ee) = self.inner.joint_axes_positions(q)?;
1591
1592        if let Some(pre) = &self.prefix {
1593            let rot = pre.matrix3;
1594            let t = Vec3A::from(pre.translation);
1595            for i in 0..N {
1596                axes[i] = rot * axes[i];
1597                positions[i] = rot * positions[i] + t;
1598            }
1599        }
1600
1601        let p_ee = if self.prefix.is_some() || self.suffix.is_some() {
1602            self.fk_end(q)?.translation
1603        } else {
1604            inner_p_ee
1605        };
1606
1607        Ok((axes, positions, p_ee))
1608    }
1609}
1610
1611/// Wraps an `FKChain<N>` and prepends a prismatic (linear) joint, producing
1612/// an `FKChain<M>` where `M = N + 1`.
1613///
1614/// The prismatic joint always acts first in the kinematic chain — it
1615/// translates the entire arm along `axis` (world frame).  The
1616/// `q_index_first` flag only controls where the prismatic value is read
1617/// from in `SRobotQ<M>`: when `true` it is `q[0]`, when `false` it is
1618/// `q[M-1]`.
1619///
1620/// Jacobian columns for the prismatic joint are `[axis; 0]` (pure linear,
1621/// no angular contribution).  Because the prismatic uniformly shifts all
1622/// positions, the revolute Jacobian columns are identical to the inner
1623/// chain's.
1624#[derive(Debug, Clone)]
1625pub struct PrismaticFK<const M: usize, const N: usize, FK: FKChain<N>> {
1626    inner: FK,
1627    axis: Vec3A,
1628    q_index_first: bool,
1629}
1630
1631impl<const M: usize, const N: usize, FK: FKChain<N>> PrismaticFK<M, N, FK> {
1632    pub const fn new(inner: FK, axis: Vec3A, q_index_first: bool) -> Self {
1633        const { assert!(M == N + 1, "M must equal N + 1") };
1634        Self {
1635            inner,
1636            axis,
1637            q_index_first,
1638        }
1639    }
1640
1641    pub fn inner(&self) -> &FK {
1642        &self.inner
1643    }
1644
1645    pub fn axis(&self) -> Vec3A {
1646        self.axis
1647    }
1648
1649    pub fn q_index_first(&self) -> bool {
1650        self.q_index_first
1651    }
1652
1653    fn split_q(&self, q: &SRobotQ<M>) -> (f32, SRobotQ<N>) {
1654        let mut inner = [0.0f32; N];
1655        if self.q_index_first {
1656            inner.copy_from_slice(&q.0[1..M]);
1657            (q.0[0], SRobotQ(inner))
1658        } else {
1659            inner.copy_from_slice(&q.0[..N]);
1660            (q.0[M - 1], SRobotQ(inner))
1661        }
1662    }
1663
1664    fn prismatic_col(&self) -> usize {
1665        if self.q_index_first { 0 } else { N }
1666    }
1667
1668    fn revolute_offset(&self) -> usize {
1669        if self.q_index_first { 1 } else { 0 }
1670    }
1671}
1672
1673impl<const M: usize, const N: usize, FK: FKChain<N>> FKChain<M> for PrismaticFK<M, N, FK> {
1674    type Error = FK::Error;
1675
1676    fn base_tf(&self) -> Affine3A {
1677        self.inner.base_tf()
1678    }
1679
1680    fn fk(&self, q: &SRobotQ<M>) -> Result<[Affine3A; M], Self::Error> {
1681        let (q_p, inner_q) = self.split_q(q);
1682        let offset = q_p * self.axis;
1683        let inner_frames = self.inner.fk(&inner_q)?;
1684        let mut out = [Affine3A::IDENTITY; M];
1685
1686        out[0] = Affine3A::from_translation(offset.into());
1687        for i in 0..N {
1688            let mut f = inner_frames[i];
1689            f.translation += offset;
1690            out[i + 1] = f;
1691        }
1692
1693        Ok(out)
1694    }
1695
1696    fn fk_end(&self, q: &SRobotQ<M>) -> Result<Affine3A, Self::Error> {
1697        let (q_p, inner_q) = self.split_q(q);
1698        let mut end = self.inner.fk_end(&inner_q)?;
1699        end.translation += q_p * self.axis;
1700        Ok(end)
1701    }
1702
1703    fn joint_axes_positions(
1704        &self,
1705        q: &SRobotQ<M>,
1706    ) -> Result<([Vec3A; M], [Vec3A; M], Vec3A), Self::Error> {
1707        let (q_p, inner_q) = self.split_q(q);
1708        let offset = q_p * self.axis;
1709        let (inner_axes, inner_pos, inner_p_ee) = self.inner.joint_axes_positions(&inner_q)?;
1710
1711        let mut axes = [Vec3A::ZERO; M];
1712        let mut positions = [Vec3A::ZERO; M];
1713
1714        axes[0] = self.axis;
1715        for i in 0..N {
1716            axes[i + 1] = inner_axes[i];
1717            positions[i + 1] = inner_pos[i] + offset;
1718        }
1719
1720        Ok((axes, positions, inner_p_ee + offset))
1721    }
1722
1723    fn jacobian(&self, q: &SRobotQ<M>) -> Result<[[f32; M]; 6], Self::Error> {
1724        let (_q_p, inner_q) = self.split_q(q);
1725        let inner_j = self.inner.jacobian(&inner_q)?;
1726        let p_col = self.prismatic_col();
1727        let r_off = self.revolute_offset();
1728
1729        let mut j = [[0.0f32; M]; 6];
1730        j[0][p_col] = self.axis.x;
1731        j[1][p_col] = self.axis.y;
1732        j[2][p_col] = self.axis.z;
1733
1734        for row in 0..6 {
1735            for col in 0..N {
1736                j[row][col + r_off] = inner_j[row][col];
1737            }
1738        }
1739
1740        Ok(j)
1741    }
1742
1743    fn jacobian_dot(
1744        &self,
1745        q: &SRobotQ<M>,
1746        qdot: &SRobotQ<M>,
1747    ) -> Result<[[f32; M]; 6], Self::Error> {
1748        let (_q_p, inner_q) = self.split_q(q);
1749        let (_qdot_p, inner_qdot) = self.split_q(qdot);
1750        let inner_jd = self.inner.jacobian_dot(&inner_q, &inner_qdot)?;
1751        let r_off = self.revolute_offset();
1752
1753        let mut jd = [[0.0f32; M]; 6];
1754        for row in 0..6 {
1755            for col in 0..N {
1756                jd[row][col + r_off] = inner_jd[row][col];
1757            }
1758        }
1759
1760        Ok(jd)
1761    }
1762
1763    fn jacobian_ddot(
1764        &self,
1765        q: &SRobotQ<M>,
1766        qdot: &SRobotQ<M>,
1767        qddot: &SRobotQ<M>,
1768    ) -> Result<[[f32; M]; 6], Self::Error> {
1769        let (_q_p, inner_q) = self.split_q(q);
1770        let (_qdot_p, inner_qdot) = self.split_q(qdot);
1771        let (_qddot_p, inner_qddot) = self.split_q(qddot);
1772        let inner_jdd = self
1773            .inner
1774            .jacobian_ddot(&inner_q, &inner_qdot, &inner_qddot)?;
1775        let r_off = self.revolute_offset();
1776
1777        let mut jdd = [[0.0f32; M]; 6];
1778        for row in 0..6 {
1779            for col in 0..N {
1780                jdd[row][col + r_off] = inner_jdd[row][col];
1781            }
1782        }
1783
1784        Ok(jdd)
1785    }
1786}