Skip to main content

deke_types/
fk.rs

1use glam::{Affine3A, Mat3A, Vec3A};
2
3use crate::{DekeError, SRobotQ};
4
5#[inline(always)]
6const fn fast_sin_cos(x: f32) -> (f32, f32) {
7    const FRAC_2_PI: f32 = std::f32::consts::FRAC_2_PI;
8    const PI_2_HI: f32 = 1.570_796_4_f32;
9    const PI_2_LO: f32 = -4.371_139e-8_f32;
10
11    const S1: f32 = -0.166_666_67;
12    const S2: f32 = 0.008_333_294;
13    const S3: f32 = -0.000_198_074_14;
14
15    const C1: f32 = -0.5;
16    const C2: f32 = 0.041_666_52;
17    const C3: f32 = -0.001_388_523_4;
18
19    let q = (x * FRAC_2_PI).round();
20    let qi = q as i32;
21    let r = x - q * PI_2_HI - q * PI_2_LO;
22    let r2 = r * r;
23
24    let sin_r = r * (1.0 + r2 * (S1 + r2 * (S2 + r2 * S3)));
25    let cos_r = 1.0 + r2 * (C1 + r2 * (C2 + r2 * C3));
26
27    let (s, c) = match qi & 3 {
28        0 => (sin_r, cos_r),
29        1 => (cos_r, -sin_r),
30        2 => (-sin_r, -cos_r),
31        3 => (-cos_r, sin_r),
32        _ => unsafe { std::hint::unreachable_unchecked() },
33    };
34
35    (s, c)
36}
37
/// `const`-evaluable square root for `f64` using Newton–Raphson.
///
/// Special cases: negative or `NaN` input returns `NaN`; `±0.0` and `+∞` are
/// returned unchanged.
///
/// The iteration `g ← (g + x/g) / 2` starts from a guess whose exponent is
/// halved, and stops at the first exact fixed point. Floating-point rounding
/// could leave Newton's method alternating between two values instead of
/// settling, which would loop forever (or abort const evaluation). The loop
/// therefore also stops when it detects a 2-cycle, and after a hard iteration
/// cap. An input that reaches a fixed point gets exactly the same result as
/// an unguarded loop would give.
#[inline]
const fn const_sqrt(x: f64) -> f64 {
    // From the exponent-halved guess, even the smallest subnormal needs ~32
    // steps; normal inputs need only a handful. This cap only bounds
    // pathological cases.
    const MAX_ITERS: u32 = 64;

    if x < 0.0 || x.is_nan() {
        return f64::NAN;
    }
    if x == 0.0 || x == f64::INFINITY {
        return x;
    }

    // Initial guess: halve the exponent. For x = m * 2^e,
    // sqrt(x) ≈ sqrt(m) * 2^(e/2). Extract, halve, reassemble.
    let bits = x.to_bits();
    let exp = ((bits >> 52) & 0x7ff) as i64;
    let new_exp = ((exp - 1023) / 2 + 1023) as u64;
    let mut guess = f64::from_bits((new_exp << 52) | (bits & 0x000f_ffff_ffff_ffff));

    // `prev2 == guess` (with `prev != guess`) means the sequence has entered
    // a period-2 cycle, so it can never reach a fixed point.
    let mut prev = 0.0;
    let mut prev2 = 0.0;
    let mut iters = 0;
    while guess != prev && guess != prev2 && iters < MAX_ITERS {
        prev2 = prev;
        prev = guess;
        guess = (guess + x / guess) * 0.5;
        iters += 1;
    }
    guess
}
57
58pub trait FKChain<const N: usize>: Clone + Send + Sync {
59    type Error: Into<DekeError>;
60    fn dof(&self) -> usize {
61        N
62    }
63    /// Theoretical maximum reach: sum of link lengths (upper bound, ignores joint limits).
64    fn max_reach(&self) -> Result<f32, Self::Error> {
65        let (_, p, p_ee) = self.joint_axes_positions(&SRobotQ::zeros())?;
66        let mut total = 0.0f32;
67        let mut prev = p[0];
68        for i in 1..N {
69            total += (p[i] - prev).length();
70            prev = p[i];
71        }
72        total += (p_ee - prev).length();
73        Ok(total)
74    }
75
76    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error>;
77    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error>;
78    /// Returns joint rotation axes and axis-origin positions in world frame at
79    /// configuration `q`, plus the end-effector position.
80    fn joint_axes_positions(
81        &self,
82        q: &SRobotQ<N>,
83    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error>;
84
85    /// Geometric Jacobian (6×N) at configuration `q`.
86    /// Rows 0–2: linear velocity, rows 3–5: angular velocity.
87    fn jacobian(&self, q: &SRobotQ<N>) -> Result<[[f32; N]; 6], Self::Error> {
88        let (z, p, p_ee) = self.joint_axes_positions(q)?;
89        let mut j = [[0.0f32; N]; 6];
90        for i in 0..N {
91            let dp = p_ee - p[i];
92            let c = z[i].cross(dp);
93            j[0][i] = c.x;
94            j[1][i] = c.y;
95            j[2][i] = c.z;
96            j[3][i] = z[i].x;
97            j[4][i] = z[i].y;
98            j[5][i] = z[i].z;
99        }
100        Ok(j)
101    }
102
103    /// First time-derivative of the geometric Jacobian.
104    fn jacobian_dot(
105        &self,
106        q: &SRobotQ<N>,
107        qdot: &SRobotQ<N>,
108    ) -> Result<[[f32; N]; 6], Self::Error> {
109        let (z, p, p_ee) = self.joint_axes_positions(q)?;
110
111        let mut omega = Vec3A::ZERO;
112        let mut z_dot = [Vec3A::ZERO; N];
113        let mut p_dot = [Vec3A::ZERO; N];
114        let mut pdot_acc = Vec3A::ZERO;
115
116        for i in 0..N {
117            p_dot[i] = pdot_acc;
118            z_dot[i] = omega.cross(z[i]);
119            omega += qdot.0[i] * z[i];
120            let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
121            pdot_acc += omega.cross(next_p - p[i]);
122        }
123        let p_ee_dot = pdot_acc;
124
125        let mut jd = [[0.0f32; N]; 6];
126        for i in 0..N {
127            let dp = p_ee - p[i];
128            let dp_dot = p_ee_dot - p_dot[i];
129            let c1 = z_dot[i].cross(dp);
130            let c2 = z[i].cross(dp_dot);
131            jd[0][i] = c1.x + c2.x;
132            jd[1][i] = c1.y + c2.y;
133            jd[2][i] = c1.z + c2.z;
134            jd[3][i] = z_dot[i].x;
135            jd[4][i] = z_dot[i].y;
136            jd[5][i] = z_dot[i].z;
137        }
138        Ok(jd)
139    }
140
141    /// Second time-derivative of the geometric Jacobian.
142    fn jacobian_ddot(
143        &self,
144        q: &SRobotQ<N>,
145        qdot: &SRobotQ<N>,
146        qddot: &SRobotQ<N>,
147    ) -> Result<[[f32; N]; 6], Self::Error> {
148        let (z, p, p_ee) = self.joint_axes_positions(q)?;
149
150        let mut omega = Vec3A::ZERO;
151        let mut omega_dot = Vec3A::ZERO;
152        let mut z_dot = [Vec3A::ZERO; N];
153        let mut z_ddot = [Vec3A::ZERO; N];
154        let mut p_dot = [Vec3A::ZERO; N];
155        let mut p_ddot = [Vec3A::ZERO; N];
156        let mut pdot_acc = Vec3A::ZERO;
157        let mut pddot_acc = Vec3A::ZERO;
158
159        for i in 0..N {
160            p_dot[i] = pdot_acc;
161            p_ddot[i] = pddot_acc;
162            let zd = omega.cross(z[i]);
163            z_dot[i] = zd;
164            z_ddot[i] = omega_dot.cross(z[i]) + omega.cross(zd);
165            omega_dot += qddot.0[i] * z[i] + qdot.0[i] * zd;
166            omega += qdot.0[i] * z[i];
167            let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
168            let delta = next_p - p[i];
169            let delta_dot = omega.cross(delta);
170            pdot_acc += delta_dot;
171            pddot_acc += omega_dot.cross(delta) + omega.cross(delta_dot);
172        }
173        let p_ee_dot = pdot_acc;
174        let p_ee_ddot = pddot_acc;
175
176        let mut jdd = [[0.0f32; N]; 6];
177        for i in 0..N {
178            let dp = p_ee - p[i];
179            let dp_dot = p_ee_dot - p_dot[i];
180            let dp_ddot = p_ee_ddot - p_ddot[i];
181            let c1 = z_ddot[i].cross(dp);
182            let c2 = z_dot[i].cross(dp_dot);
183            let c3 = z[i].cross(dp_ddot);
184            jdd[0][i] = c1.x + 2.0 * c2.x + c3.x;
185            jdd[1][i] = c1.y + 2.0 * c2.y + c3.y;
186            jdd[2][i] = c1.z + 2.0 * c2.z + c3.z;
187            jdd[3][i] = z_ddot[i].x;
188            jdd[4][i] = z_ddot[i].y;
189            jdd[5][i] = z_ddot[i].z;
190        }
191        Ok(jdd)
192    }
193}
194
195#[inline(always)]
196#[cfg(debug_assertions)]
197fn check_finite<const N: usize>(q: &SRobotQ<N>) -> Result<(), DekeError> {
198    if q.any_non_finite() {
199        return Err(DekeError::JointsNonFinite);
200    }
201    Ok(())
202}
203
204#[inline(always)]
205#[cfg(not(debug_assertions))]
206fn check_finite<const N: usize>(_: &SRobotQ<N>) -> Result<(), std::convert::Infallible> {
207    Ok(())
208}
209
210/// Accumulate a local rotation + translation into the running transform.
211/// Shared by both DH and HP — the only difference is how each convention
212/// builds `local_c0..c2` and `local_t`.
213#[inline(always)]
214fn accumulate(
215    acc_m: &mut Mat3A,
216    acc_t: &mut Vec3A,
217    local_c0: Vec3A,
218    local_c1: Vec3A,
219    local_c2: Vec3A,
220    local_t: Vec3A,
221) {
222    let new_c0 = *acc_m * local_c0;
223    let new_c1 = *acc_m * local_c1;
224    let new_c2 = *acc_m * local_c2;
225    *acc_t = *acc_m * local_t + *acc_t;
226    *acc_m = Mat3A::from_cols(new_c0, new_c1, new_c2);
227}
228
/// Standard Denavit–Hartenberg parameters for one revolute joint.
///
/// Consumed by `DHChain::new` under `T_i = Rz(θ) · Tz(d) · Tx(a) · Rx(α)`,
/// where `θ = q_i + theta_offset`.
#[derive(Debug, Clone, Copy)]
pub struct DHJoint {
    /// Link length: translation along the rotated x-axis (`Tx(a)`).
    pub a: f32,
    /// Link twist in radians: rotation about x (`Rx(α)`).
    pub alpha: f32,
    /// Link offset: translation along z (`Tz(d)`).
    pub d: f32,
    /// Constant (radians) added to the joint variable before `Rz(θ)`.
    pub theta_offset: f32,
}
236
/// Precomputed standard-DH chain with SoA layout.
///
/// Convention: `T_i = Rz(θ) · Tz(d) · Tx(a) · Rx(α)`
///
/// Each parameter lives in its own per-joint array, and the twist `α` is
/// stored only as its sine/cosine, so forward kinematics never re-evaluates it.
#[derive(Debug, Clone)]
pub struct DHChain<const N: usize> {
    /// Link length `a` per joint.
    a: [f32; N],
    /// Link offset `d` per joint.
    d: [f32; N],
    /// `sin(α)` per joint.
    sin_alpha: [f32; N],
    /// `cos(α)` per joint.
    cos_alpha: [f32; N],
    /// Added to `q[i]` to form the joint angle `θ`.
    theta_offset: [f32; N],
}
248
249impl<const N: usize> DHChain<N> {
250    pub fn new(joints: [DHJoint; N]) -> Self {
251        let mut a = [0.0; N];
252        let mut d = [0.0; N];
253        let mut sin_alpha = [0.0; N];
254        let mut cos_alpha = [0.0; N];
255        let mut theta_offset = [0.0; N];
256
257        let mut i = 0;
258        while i < N {
259            a[i] = joints[i].a;
260            d[i] = joints[i].d;
261            let (sa, ca) = joints[i].alpha.sin_cos();
262            sin_alpha[i] = sa;
263            cos_alpha[i] = ca;
264            theta_offset[i] = joints[i].theta_offset;
265            i += 1;
266        }
267
268        Self {
269            a,
270            d,
271            sin_alpha,
272            cos_alpha,
273            theta_offset,
274        }
275    }
276
277    /// Construct from the row-major `DH_PARAMS` const array emitted by the
278    /// workcell macro.
279    ///
280    /// `params`: `[[f64; N]; 4]` — rows are `(a, alpha, d, theta_offset)`
281    /// across joints.
282    pub const fn from_dh(params: &[[f64; N]; 4]) -> Self {
283        let mut a = [0.0f32; N];
284        let mut d = [0.0f32; N];
285        let mut sin_alpha = [0.0f32; N];
286        let mut cos_alpha = [0.0f32; N];
287        let mut theta_offset = [0.0f32; N];
288
289        let mut i = 0;
290        while i < N {
291            a[i] = params[0][i] as f32;
292            let (sa, ca) = fast_sin_cos(params[1][i] as f32);
293            sin_alpha[i] = sa;
294            cos_alpha[i] = ca;
295            d[i] = params[2][i] as f32;
296            theta_offset[i] = params[3][i] as f32;
297            i += 1;
298        }
299
300        Self {
301            a,
302            d,
303            sin_alpha,
304            cos_alpha,
305            theta_offset,
306        }
307    }
308}
309
310impl<const N: usize> FKChain<N> for DHChain<N> {
311    #[cfg(debug_assertions)]
312    type Error = DekeError;
313    #[cfg(not(debug_assertions))]
314    type Error = std::convert::Infallible;
315
316    /// DH forward kinematics exploiting the structure of `Rz(θ)·Rx(α)`.
317    ///
318    /// The per-joint accumulation decomposes into two 2D column rotations:
319    ///   1. Rotate `(c0, c1)` by θ  →  `(new_c0, perp)`
320    ///   2. Rotate `(perp, c2)` by α  →  `(new_c1, new_c2)`
321    /// Translation reuses `new_c0`:  `t += a·new_c0 + d·old_c2`
322    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
323        check_finite::<N>(q)?;
324        let mut out = [Affine3A::IDENTITY; N];
325        let mut c0 = Vec3A::X;
326        let mut c1 = Vec3A::Y;
327        let mut c2 = Vec3A::Z;
328        let mut t = Vec3A::ZERO;
329
330        let mut i = 0;
331        while i < N {
332            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
333            let sa = self.sin_alpha[i];
334            let ca = self.cos_alpha[i];
335
336            let new_c0 = ct * c0 + st * c1;
337            let perp = ct * c1 - st * c0;
338
339            let new_c1 = ca * perp + sa * c2;
340            let new_c2 = ca * c2 - sa * perp;
341
342            t = self.a[i] * new_c0 + self.d[i] * c2 + t;
343
344            c0 = new_c0;
345            c1 = new_c1;
346            c2 = new_c2;
347
348            out[i] = Affine3A {
349                matrix3: Mat3A::from_cols(c0, c1, c2),
350                translation: t,
351            };
352            i += 1;
353        }
354        Ok(out)
355    }
356
357    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
358        check_finite::<N>(q)?;
359        let mut c0 = Vec3A::X;
360        let mut c1 = Vec3A::Y;
361        let mut c2 = Vec3A::Z;
362        let mut t = Vec3A::ZERO;
363
364        let mut i = 0;
365        while i < N {
366            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
367            let sa = self.sin_alpha[i];
368            let ca = self.cos_alpha[i];
369
370            let new_c0 = ct * c0 + st * c1;
371            let perp = ct * c1 - st * c0;
372
373            let new_c1 = ca * perp + sa * c2;
374            let new_c2 = ca * c2 - sa * perp;
375
376            t = self.a[i] * new_c0 + self.d[i] * c2 + t;
377
378            c0 = new_c0;
379            c1 = new_c1;
380            c2 = new_c2;
381            i += 1;
382        }
383
384        Ok(Affine3A {
385            matrix3: Mat3A::from_cols(c0, c1, c2),
386            translation: t,
387        })
388    }
389
390    fn joint_axes_positions(
391        &self,
392        q: &SRobotQ<N>,
393    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
394        let frames = self.fk(q)?;
395        let mut axes = [Vec3A::Z; N];
396        let mut positions = [Vec3A::ZERO; N];
397
398        for i in 1..N {
399            axes[i] = frames[i - 1].matrix3.z_axis;
400            positions[i] = frames[i - 1].translation;
401        }
402
403        Ok((axes, positions, frames[N - 1].translation))
404    }
405}
406
/// Hayati–Paul parameters for one revolute joint.
///
/// Consumed by `HPChain::new` under
/// `T_i = Rz(θ) · Rx(α) · Ry(β) · Tx(a) · Tz(d)`, where `θ = q_i + theta_offset`.
#[derive(Debug, Clone, Copy)]
pub struct HPJoint {
    /// Translation along the local x-axis (`Tx(a)`).
    pub a: f32,
    /// Rotation about x in radians (`Rx(α)`).
    pub alpha: f32,
    /// Rotation about y in radians (`Ry(β)`).
    pub beta: f32,
    /// Translation along the local z-axis (`Tz(d)`).
    pub d: f32,
    /// Constant (radians) added to the joint variable before `Rz(θ)`.
    pub theta_offset: f32,
}
415
/// Precomputed Hayati-Paul chain with SoA layout.
///
/// Convention: `T_i = Rz(θ) · Rx(α) · Ry(β) · Tx(a) · Tz(d)`
///
/// HP adds a `β` rotation about Y, which makes it numerically stable for
/// nearly-parallel consecutive joint axes where standard DH is singular.
/// `α` and `β` are stored only as sine/cosine pairs.
#[derive(Debug, Clone)]
pub struct HPChain<const N: usize> {
    /// `a` per joint (translation along local x).
    a: [f32; N],
    /// `d` per joint (translation along local z).
    d: [f32; N],
    /// `sin(α)` per joint.
    sin_alpha: [f32; N],
    /// `cos(α)` per joint.
    cos_alpha: [f32; N],
    /// `sin(β)` per joint.
    sin_beta: [f32; N],
    /// `cos(β)` per joint.
    cos_beta: [f32; N],
    /// Added to `q[i]` to form the joint angle `θ`.
    theta_offset: [f32; N],
}
432
433impl<const N: usize> HPChain<N> {
434    pub const fn new(joints: [HPJoint; N]) -> Self {
435        let mut a = [0.0; N];
436        let mut d = [0.0; N];
437        let mut sin_alpha = [0.0; N];
438        let mut cos_alpha = [0.0; N];
439        let mut sin_beta = [0.0; N];
440        let mut cos_beta = [0.0; N];
441        let mut theta_offset = [0.0; N];
442
443        let mut i = 0;
444        while i < N {
445            a[i] = joints[i].a;
446            d[i] = joints[i].d;
447            let (sa, ca) = fast_sin_cos(joints[i].alpha);
448            sin_alpha[i] = sa;
449            cos_alpha[i] = ca;
450            let (sb, cb) = fast_sin_cos(joints[i].beta);
451            sin_beta[i] = sb;
452            cos_beta[i] = cb;
453            theta_offset[i] = joints[i].theta_offset;
454            i += 1;
455        }
456
457        Self {
458            a,
459            d,
460            sin_alpha,
461            cos_alpha,
462            sin_beta,
463            cos_beta,
464            theta_offset,
465        }
466    }
467
468    /// Construct from the row-major `HP_H` and `HP_P` const arrays emitted by
469    /// the workcell macro.
470    ///
471    /// `h`: `[[f64; N]; 3]` — rows are (x, y, z) components across joints.
472    /// `p`: `[[f64; N]; 3]` — rows are (x, y, z) components across points.
473    ///
474    /// Each `h[_][i]` is joint `i`'s axis in the base frame at zero config.
475    /// `p[_][0]` is the base-to-joint-0 offset; `p[_][i]` for `1..N` is the
476    /// offset from joint `i-1`'s origin to joint `i`'s origin. The tool
477    /// offset from joint `N-1` to the flange is not part of this input
478    /// because `HPChain` has no end-effector slot — wrap the result in a
479    /// [`TransformedFK`](crate::TransformedFK) if a tool offset is required.
480    ///
481    /// `theta_offset` is set to zero for every joint: at zero config each
482    /// local x-axis is pinned to `Rx(α) · Ry(β) · [1, 0, 0]` expressed in the
483    /// parent frame.
484    pub const fn from_hp(h: &[[f32; N]; 3], p: &[[f32; N]; 3]) -> Self {
485
486        let mut a = [0.0f32; N];
487        let mut d = [0.0f32; N];
488        let mut sin_alpha = [0.0f32; N];
489        let mut cos_alpha = [0.0f32; N];
490        let mut sin_beta = [0.0f32; N];
491        let mut cos_beta = [0.0f32; N];
492
493        let mut c0 = [1.0f32, 0.0, 0.0];
494        let mut c1 = [0.0f32, 1.0, 0.0];
495        let mut c2 = [0.0f32, 0.0, 1.0];
496
497        const EPS: f32 = 1e-12;
498
499        let mut i = 0;
500        while i < N {
501            let hx = h[0][i];
502            let hy = h[1][i];
503            let hz = h[2][i];
504            let px = p[0][i];
505            let py = p[1][i];
506            let pz = p[2][i];
507
508            let vx = c0[0] * hx + c0[1] * hy + c0[2] * hz;
509            let vy = c1[0] * hx + c1[1] * hy + c1[2] * hz;
510            let vz = c2[0] * hx + c2[1] * hy + c2[2] * hz;
511            let ux = c0[0] * px + c0[1] * py + c0[2] * pz;
512            let uy = c1[0] * px + c1[1] * py + c1[2] * pz;
513            let uz = c2[0] * px + c2[1] * py + c2[2] * pz;
514
515            let sb = vx;
516            let cb = const_sqrt((vy * vy + vz * vz) as f64) as f32;
517
518            let (sa, ca) = if cb > EPS {
519                (-vy / cb, vz / cb)
520            } else {
521                (0.0, 1.0)
522            };
523
524            let big_a = ux;
525            let big_b = sa * uy - ca * uz;
526            let ai = cb * big_a + sb * big_b;
527            let di = sb * big_a - cb * big_b;
528
529            a[i] = ai;
530            d[i] = di;
531            sin_alpha[i] = sa;
532            cos_alpha[i] = ca;
533            sin_beta[i] = sb;
534            cos_beta[i] = cb;
535
536            let sasb = sa * sb;
537            let casb = ca * sb;
538            let sacb = sa * cb;
539            let cacb = ca * cb;
540            let new_c0 = [
541                cb * c0[0] + sasb * c1[0] - casb * c2[0],
542                cb * c0[1] + sasb * c1[1] - casb * c2[1],
543                cb * c0[2] + sasb * c1[2] - casb * c2[2],
544            ];
545            let new_c1 = [
546                ca * c1[0] + sa * c2[0],
547                ca * c1[1] + sa * c2[1],
548                ca * c1[2] + sa * c2[2],
549            ];
550            let new_c2 = [
551                sb * c0[0] - sacb * c1[0] + cacb * c2[0],
552                sb * c0[1] - sacb * c1[1] + cacb * c2[1],
553                sb * c0[2] - sacb * c1[2] + cacb * c2[2],
554            ];
555            c0 = new_c0;
556            c1 = new_c1;
557            c2 = new_c2;
558
559            i += 1;
560        }
561
562        Self {
563            a,
564            d,
565            sin_alpha,
566            cos_alpha,
567            sin_beta,
568            cos_beta,
569            theta_offset: [0.0f32; N],
570        }
571    }
572
573    /// Build the local rotation columns and translation for joint `i`.
574    ///
575    /// R = Rz(θ) · Rx(α) · Ry(β), then t = R · [a, 0, d].
576    ///
577    /// Rx(α)·Ry(β) columns:
578    ///   col0 = [ cβ,       sα·sβ,     -cα·sβ     ]
579    ///   col1 = [ 0,        cα,          sα        ]
580    ///   col2 = [ sβ,      -sα·cβ,      cα·cβ     ]
581    ///
582    /// Then Rz(θ) rotates each column: [cθ·x - sθ·y, sθ·x + cθ·y, z]
583    ///
584    /// Translation = a·col0 + d·col2  (since R·[a,0,d] = a·col0 + d·col2).
585    #[inline(always)]
586    fn local_frame(&self, i: usize, st: f32, ct: f32) -> (Vec3A, Vec3A, Vec3A, Vec3A) {
587        let sa = self.sin_alpha[i];
588        let ca = self.cos_alpha[i];
589        let sb = self.sin_beta[i];
590        let cb = self.cos_beta[i];
591        let ai = self.a[i];
592        let di = self.d[i];
593
594        let sa_sb = sa * sb;
595        let sa_cb = sa * cb;
596        let ca_sb = ca * sb;
597        let ca_cb = ca * cb;
598
599        let c0 = Vec3A::new(ct * cb - st * sa_sb, st * cb + ct * sa_sb, -ca_sb);
600        let c1 = Vec3A::new(-st * ca, ct * ca, sa);
601        let c2 = Vec3A::new(ct * sb + st * sa_cb, st * sb - ct * sa_cb, ca_cb);
602        let t = Vec3A::new(
603            ai * c0.x + di * c2.x,
604            ai * c0.y + di * c2.y,
605            ai * c0.z + di * c2.z,
606        );
607
608        (c0, c1, c2, t)
609    }
610}
611
612impl<const N: usize> FKChain<N> for HPChain<N> {
613    #[cfg(debug_assertions)]
614    type Error = DekeError;
615    #[cfg(not(debug_assertions))]
616    type Error = std::convert::Infallible;
617
618    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
619        check_finite::<N>(q)?;
620        let mut out = [Affine3A::IDENTITY; N];
621        let mut acc_m = Mat3A::IDENTITY;
622        let mut acc_t = Vec3A::ZERO;
623
624        let mut i = 0;
625        while i < N {
626            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
627            let (c0, c1, c2, t) = self.local_frame(i, st, ct);
628            accumulate(&mut acc_m, &mut acc_t, c0, c1, c2, t);
629
630            out[i] = Affine3A {
631                matrix3: acc_m,
632                translation: acc_t,
633            };
634            i += 1;
635        }
636        Ok(out)
637    }
638
639    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
640        check_finite(q)?;
641        let mut acc_m = Mat3A::IDENTITY;
642        let mut acc_t = Vec3A::ZERO;
643
644        let mut i = 0;
645        while i < N {
646            let (st, ct) = fast_sin_cos(q.0[i] + self.theta_offset[i]);
647            let (c0, c1, c2, t) = self.local_frame(i, st, ct);
648            accumulate(&mut acc_m, &mut acc_t, c0, c1, c2, t);
649            i += 1;
650        }
651
652        Ok(Affine3A {
653            matrix3: acc_m,
654            translation: acc_t,
655        })
656    }
657
658    fn joint_axes_positions(
659        &self,
660        q: &SRobotQ<N>,
661    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
662        let frames = self.fk(q)?;
663        let mut axes = [Vec3A::Z; N];
664        let mut positions = [Vec3A::ZERO; N];
665
666        for i in 1..N {
667            axes[i] = frames[i - 1].matrix3.z_axis;
668            positions[i] = frames[i - 1].translation;
669        }
670
671        Ok((axes, positions, frames[N - 1].translation))
672    }
673}
674
/// One URDF joint: the fixed origin transform plus the joint axis.
///
/// The field names mirror URDF's `<origin xyz=… rpy=…>` and `<axis xyz=…>`.
#[derive(Debug, Clone, Copy)]
pub struct URDFJoint {
    /// Fixed translation from the parent frame.
    pub origin_xyz: [f64; 3],
    /// Fixed roll/pitch/yaw in radians, composed by `URDFChain::new` as
    /// `Rz(yaw) · Ry(pitch) · Rx(roll)`.
    pub origin_rpy: [f64; 3],
    /// Rotation axis, expressed after the fixed rotation.
    pub axis: [f64; 3],
}
681
682/// Precomputed per-joint axis type for column-rotation FK.
683#[derive(Debug, Clone, Copy)]
684enum JointAxis {
685    Z,
686    Y(f32),
687    X(f32),
688}
689
690/// FK chain using exact URDF joint transforms.
691///
692/// Accumulation works directly on columns:
693///   1. Translation: `t += fx·c0 + fy·c1 + fz·c2`
694///   2. Fixed rotation: `(c0,c1,c2) = (c0,c1,c2) * fixed_rot`
695///   3. Joint rotation: 2D rotation on the appropriate column pair
696///
697/// When `fixed_rot` is identity (RPY = 0, the common case), step 2 is
698/// skipped entirely, making per-joint cost a single 2D column rotation
699/// plus translation — cheaper than DH.
700#[derive(Debug, Clone)]
701pub struct URDFChain<const N: usize> {
702    fr_c0: [Vec3A; N],
703    fr_c1: [Vec3A; N],
704    fr_c2: [Vec3A; N],
705    fr_identity: [bool; N],
706    fixed_trans: [Vec3A; N],
707    axis: [JointAxis; N],
708}
709
710impl<const N: usize> URDFChain<N> {
711    pub fn new(joints: [URDFJoint; N]) -> Self {
712        let mut fr_c0 = [Vec3A::X; N];
713        let mut fr_c1 = [Vec3A::Y; N];
714        let mut fr_c2 = [Vec3A::Z; N];
715        let mut fr_identity = [true; N];
716        let mut fixed_trans = [Vec3A::ZERO; N];
717        let mut axis = [JointAxis::Z; N];
718
719        for i in 0..N {
720            let [ox, oy, oz] = joints[i].origin_xyz;
721            let [roll, pitch, yaw] = joints[i].origin_rpy;
722
723            let is_identity = roll.abs() < 1e-10 && pitch.abs() < 1e-10 && yaw.abs() < 1e-10;
724            fr_identity[i] = is_identity;
725
726            if !is_identity {
727                let (sr, cr) = roll.sin_cos();
728                let (sp, cp) = pitch.sin_cos();
729                let (sy, cy) = yaw.sin_cos();
730                fr_c0[i] = Vec3A::new((cy * cp) as f32, (sy * cp) as f32, (-sp) as f32);
731                fr_c1[i] = Vec3A::new(
732                    (cy * sp * sr - sy * cr) as f32,
733                    (sy * sp * sr + cy * cr) as f32,
734                    (cp * sr) as f32,
735                );
736                fr_c2[i] = Vec3A::new(
737                    (cy * sp * cr + sy * sr) as f32,
738                    (sy * sp * cr - cy * sr) as f32,
739                    (cp * cr) as f32,
740                );
741            }
742
743            fixed_trans[i] = Vec3A::new(ox as f32, oy as f32, oz as f32);
744
745            let [ax, ay, az] = joints[i].axis;
746            if az.abs() > 0.5 {
747                axis[i] = JointAxis::Z;
748            } else if ay.abs() > 0.5 {
749                axis[i] = JointAxis::Y(ay.signum() as f32);
750            } else {
751                axis[i] = JointAxis::X(ax.signum() as f32);
752            }
753        }
754
755        Self {
756            fr_c0,
757            fr_c1,
758            fr_c2,
759            fr_identity,
760            fixed_trans,
761            axis,
762        }
763    }
764
765    /// Apply fixed rotation + joint rotation to accumulator columns.
766    #[inline(always)]
767    fn accumulate_joint(
768        &self,
769        i: usize,
770        st: f32,
771        ct: f32,
772        c0: &mut Vec3A,
773        c1: &mut Vec3A,
774        c2: &mut Vec3A,
775        t: &mut Vec3A,
776    ) {
777        let ft = self.fixed_trans[i];
778        *t = ft.x * *c0 + ft.y * *c1 + ft.z * *c2 + *t;
779
780        let (f0, f1, f2) = if self.fr_identity[i] {
781            (*c0, *c1, *c2)
782        } else {
783            let fc0 = self.fr_c0[i];
784            let fc1 = self.fr_c1[i];
785            let fc2 = self.fr_c2[i];
786            (
787                fc0.x * *c0 + fc0.y * *c1 + fc0.z * *c2,
788                fc1.x * *c0 + fc1.y * *c1 + fc1.z * *c2,
789                fc2.x * *c0 + fc2.y * *c1 + fc2.z * *c2,
790            )
791        };
792
793        match self.axis[i] {
794            JointAxis::Z => {
795                let new_c0 = ct * f0 + st * f1;
796                let new_c1 = ct * f1 - st * f0;
797                *c0 = new_c0;
798                *c1 = new_c1;
799                *c2 = f2;
800            }
801            JointAxis::Y(s) => {
802                let sst = s * st;
803                let new_c0 = ct * f0 - sst * f2;
804                let new_c2 = sst * f0 + ct * f2;
805                *c0 = new_c0;
806                *c1 = f1;
807                *c2 = new_c2;
808            }
809            JointAxis::X(s) => {
810                let sst = s * st;
811                let new_c1 = ct * f1 + sst * f2;
812                let new_c2 = ct * f2 - sst * f1;
813                *c0 = f0;
814                *c1 = new_c1;
815                *c2 = new_c2;
816            }
817        }
818    }
819}
820
821impl<const N: usize> FKChain<N> for URDFChain<N> {
822    #[cfg(debug_assertions)]
823    type Error = DekeError;
824    #[cfg(not(debug_assertions))]
825    type Error = std::convert::Infallible;
826
827    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
828        check_finite(q)?;
829        let mut out = [Affine3A::IDENTITY; N];
830        let mut c0 = Vec3A::X;
831        let mut c1 = Vec3A::Y;
832        let mut c2 = Vec3A::Z;
833        let mut t = Vec3A::ZERO;
834
835        let mut i = 0;
836        while i < N {
837            let (st, ct) = fast_sin_cos(q.0[i]);
838            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
839
840            out[i] = Affine3A {
841                matrix3: Mat3A::from_cols(c0, c1, c2),
842                translation: t,
843            };
844            i += 1;
845        }
846        Ok(out)
847    }
848
849    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
850        check_finite(q)?;
851        let mut c0 = Vec3A::X;
852        let mut c1 = Vec3A::Y;
853        let mut c2 = Vec3A::Z;
854        let mut t = Vec3A::ZERO;
855
856        let mut i = 0;
857        while i < N {
858            let (st, ct) = fast_sin_cos(q.0[i]);
859            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
860            i += 1;
861        }
862
863        Ok(Affine3A {
864            matrix3: Mat3A::from_cols(c0, c1, c2),
865            translation: t,
866        })
867    }
868
869    fn joint_axes_positions(
870        &self,
871        q: &SRobotQ<N>,
872    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
873        let frames = self.fk(q)?;
874        let mut axes = [Vec3A::ZERO; N];
875        let mut positions = [Vec3A::ZERO; N];
876
877        for i in 0..N {
878            axes[i] = match self.axis[i] {
879                JointAxis::Z => frames[i].matrix3.z_axis,
880                JointAxis::Y(s) => s * frames[i].matrix3.y_axis,
881                JointAxis::X(s) => s * frames[i].matrix3.x_axis,
882            };
883            positions[i] = frames[i].translation;
884        }
885
886        Ok((axes, positions, frames[N - 1].translation))
887    }
888}
889
890/// Wraps an `FKChain` with an optional prefix (base) and/or suffix (tool) transform.
891///
892/// - `fk` applies only the prefix — intermediate frames stay in world coordinates
893///   without the tool offset.
894/// - `fk_end` and `joint_axes_positions` apply both — the end-effector includes
895///   the tool tip.
896#[derive(Debug, Clone)]
897pub struct TransformedFK<const N: usize, FK: FKChain<N>> {
898    inner: FK,
899    prefix: Option<Affine3A>,
900    suffix: Option<Affine3A>,
901}
902
903impl<const N: usize, FK: FKChain<N>> TransformedFK<N, FK> {
904    pub fn new(inner: FK) -> Self {
905        Self {
906            inner,
907            prefix: None,
908            suffix: None,
909        }
910    }
911
912    pub fn with_prefix(mut self, prefix: Affine3A) -> Self {
913        self.prefix = Some(prefix);
914        self
915    }
916
917    pub fn with_suffix(mut self, suffix: Affine3A) -> Self {
918        self.suffix = Some(suffix);
919        self
920    }
921
922    pub fn set_prefix(&mut self, prefix: Option<Affine3A>) {
923        self.prefix = prefix;
924    }
925
926    pub fn set_suffix(&mut self, suffix: Option<Affine3A>) {
927        self.suffix = suffix;
928    }
929
930    pub fn prefix(&self) -> Option<&Affine3A> {
931        self.prefix.as_ref()
932    }
933
934    pub fn suffix(&self) -> Option<&Affine3A> {
935        self.suffix.as_ref()
936    }
937
938    pub fn inner(&self) -> &FK {
939        &self.inner
940    }
941}
942
943impl<const N: usize, FK: FKChain<N>> FKChain<N> for TransformedFK<N, FK> {
944    type Error = FK::Error;
945
946    fn max_reach(&self) -> Result<f32, Self::Error> {
947        let mut reach = self.inner.max_reach()?;
948        if let Some(suf) = &self.suffix {
949            reach += Vec3A::from(suf.translation).length();
950        }
951        Ok(reach)
952    }
953
954    fn fk(&self, q: &SRobotQ<N>) -> Result<[Affine3A; N], Self::Error> {
955        let mut frames = self.inner.fk(q)?;
956        if let Some(pre) = &self.prefix {
957            for f in &mut frames {
958                *f = *pre * *f;
959            }
960        }
961        Ok(frames)
962    }
963
964    fn fk_end(&self, q: &SRobotQ<N>) -> Result<Affine3A, Self::Error> {
965        let mut end = self.inner.fk_end(q)?;
966        if let Some(pre) = &self.prefix {
967            end = *pre * end;
968        }
969        if let Some(suf) = &self.suffix {
970            end = end * *suf;
971        }
972        Ok(end)
973    }
974
975    fn joint_axes_positions(
976        &self,
977        q: &SRobotQ<N>,
978    ) -> Result<([Vec3A; N], [Vec3A; N], Vec3A), Self::Error> {
979        let (mut axes, mut positions, inner_p_ee) = self.inner.joint_axes_positions(q)?;
980
981        if let Some(pre) = &self.prefix {
982            let rot = pre.matrix3;
983            let t = Vec3A::from(pre.translation);
984            for i in 0..N {
985                axes[i] = rot * axes[i];
986                positions[i] = rot * positions[i] + t;
987            }
988        }
989
990        let p_ee = if self.prefix.is_some() || self.suffix.is_some() {
991            self.fk_end(q)?.translation
992        } else {
993            inner_p_ee
994        };
995
996        Ok((axes, positions, p_ee))
997    }
998}
999
1000/// Wraps an `FKChain<N>` and prepends a prismatic (linear) joint, producing
1001/// an `FKChain<M>` where `M = N + 1`.
1002///
1003/// The prismatic joint always acts first in the kinematic chain — it
1004/// translates the entire arm along `axis` (world frame).  The
1005/// `q_index_first` flag only controls where the prismatic value is read
1006/// from in `SRobotQ<M>`: when `true` it is `q[0]`, when `false` it is
1007/// `q[M-1]`.
1008///
1009/// Jacobian columns for the prismatic joint are `[axis; 0]` (pure linear,
1010/// no angular contribution).  Because the prismatic uniformly shifts all
1011/// positions, the revolute Jacobian columns are identical to the inner
1012/// chain's.
1013#[derive(Debug, Clone)]
1014pub struct PrismaticFK<const M: usize, const N: usize, FK: FKChain<N>> {
1015    inner: FK,
1016    axis: Vec3A,
1017    q_index_first: bool,
1018}
1019
1020impl<const M: usize, const N: usize, FK: FKChain<N>> PrismaticFK<M, N, FK> {
1021    pub fn new(inner: FK, axis: Vec3A, q_index_first: bool) -> Self {
1022        const { assert!(M == N + 1, "M must equal N + 1") };
1023        Self {
1024            inner,
1025            axis,
1026            q_index_first,
1027        }
1028    }
1029
1030    pub fn inner(&self) -> &FK {
1031        &self.inner
1032    }
1033
1034    pub fn axis(&self) -> Vec3A {
1035        self.axis
1036    }
1037
1038    pub fn q_index_first(&self) -> bool {
1039        self.q_index_first
1040    }
1041
1042    fn split_q(&self, q: &SRobotQ<M>) -> (f32, SRobotQ<N>) {
1043        let mut inner = [0.0f32; N];
1044        if self.q_index_first {
1045            inner.copy_from_slice(&q.0[1..M]);
1046            (q.0[0], SRobotQ(inner))
1047        } else {
1048            inner.copy_from_slice(&q.0[..N]);
1049            (q.0[M - 1], SRobotQ(inner))
1050        }
1051    }
1052
1053    fn prismatic_col(&self) -> usize {
1054        if self.q_index_first { 0 } else { N }
1055    }
1056
1057    fn revolute_offset(&self) -> usize {
1058        if self.q_index_first { 1 } else { 0 }
1059    }
1060}
1061
1062impl<const M: usize, const N: usize, FK: FKChain<N>> FKChain<M> for PrismaticFK<M, N, FK> {
1063    type Error = FK::Error;
1064
1065    fn fk(&self, q: &SRobotQ<M>) -> Result<[Affine3A; M], Self::Error> {
1066        let (q_p, inner_q) = self.split_q(q);
1067        let offset = q_p * self.axis;
1068        let inner_frames = self.inner.fk(&inner_q)?;
1069        let mut out = [Affine3A::IDENTITY; M];
1070
1071        out[0] = Affine3A::from_translation(offset.into());
1072        for i in 0..N {
1073            let mut f = inner_frames[i];
1074            f.translation += offset;
1075            out[i + 1] = f;
1076        }
1077
1078        Ok(out)
1079    }
1080
1081    fn fk_end(&self, q: &SRobotQ<M>) -> Result<Affine3A, Self::Error> {
1082        let (q_p, inner_q) = self.split_q(q);
1083        let mut end = self.inner.fk_end(&inner_q)?;
1084        end.translation += q_p * self.axis;
1085        Ok(end)
1086    }
1087
1088    fn joint_axes_positions(
1089        &self,
1090        q: &SRobotQ<M>,
1091    ) -> Result<([Vec3A; M], [Vec3A; M], Vec3A), Self::Error> {
1092        let (q_p, inner_q) = self.split_q(q);
1093        let offset = q_p * self.axis;
1094        let (inner_axes, inner_pos, inner_p_ee) = self.inner.joint_axes_positions(&inner_q)?;
1095
1096        let mut axes = [Vec3A::ZERO; M];
1097        let mut positions = [Vec3A::ZERO; M];
1098
1099        axes[0] = self.axis;
1100        for i in 0..N {
1101            axes[i + 1] = inner_axes[i];
1102            positions[i + 1] = inner_pos[i] + offset;
1103        }
1104
1105        Ok((axes, positions, inner_p_ee + offset))
1106    }
1107
1108    fn jacobian(&self, q: &SRobotQ<M>) -> Result<[[f32; M]; 6], Self::Error> {
1109        let (_q_p, inner_q) = self.split_q(q);
1110        let inner_j = self.inner.jacobian(&inner_q)?;
1111        let p_col = self.prismatic_col();
1112        let r_off = self.revolute_offset();
1113
1114        let mut j = [[0.0f32; M]; 6];
1115        j[0][p_col] = self.axis.x;
1116        j[1][p_col] = self.axis.y;
1117        j[2][p_col] = self.axis.z;
1118
1119        for row in 0..6 {
1120            for col in 0..N {
1121                j[row][col + r_off] = inner_j[row][col];
1122            }
1123        }
1124
1125        Ok(j)
1126    }
1127
1128    fn jacobian_dot(
1129        &self,
1130        q: &SRobotQ<M>,
1131        qdot: &SRobotQ<M>,
1132    ) -> Result<[[f32; M]; 6], Self::Error> {
1133        let (_q_p, inner_q) = self.split_q(q);
1134        let (_qdot_p, inner_qdot) = self.split_q(qdot);
1135        let inner_jd = self.inner.jacobian_dot(&inner_q, &inner_qdot)?;
1136        let r_off = self.revolute_offset();
1137
1138        let mut jd = [[0.0f32; M]; 6];
1139        for row in 0..6 {
1140            for col in 0..N {
1141                jd[row][col + r_off] = inner_jd[row][col];
1142            }
1143        }
1144
1145        Ok(jd)
1146    }
1147
1148    fn jacobian_ddot(
1149        &self,
1150        q: &SRobotQ<M>,
1151        qdot: &SRobotQ<M>,
1152        qddot: &SRobotQ<M>,
1153    ) -> Result<[[f32; M]; 6], Self::Error> {
1154        let (_q_p, inner_q) = self.split_q(q);
1155        let (_qdot_p, inner_qdot) = self.split_q(qdot);
1156        let (_qddot_p, inner_qddot) = self.split_q(qddot);
1157        let inner_jdd = self
1158            .inner
1159            .jacobian_ddot(&inner_q, &inner_qdot, &inner_qddot)?;
1160        let r_off = self.revolute_offset();
1161
1162        let mut jdd = [[0.0f32; M]; 6];
1163        for row in 0..6 {
1164            for col in 0..N {
1165                jdd[row][col + r_off] = inner_jdd[row][col];
1166            }
1167        }
1168
1169        Ok(jdd)
1170    }
1171}