Skip to main content

deke_types/fk/
urdf.rs

1use glam_traits_ext::{FloatAffine, FloatVec, TAffine3, TMat3, TVec3};
2
3use crate::{DekeError, SRobotQ};
4
5use super::{
6    AAffine3, AMat3, AVec3, AffineRaw, AffineRaw64, FKChain, FKScalar, check_finite, const_sin_cos,
7    const_sin_cos_f64,
8};
9
/// The kind of motion a URDF joint allows.
///
/// `Fixed` joints are rigid. `Revolute` joints rotate about `axis` and
/// `Prismatic` joints translate along it; as in the URDF specification the
/// axis is expressed in the joint's own frame (after its `<origin>`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum URDFJointType {
    /// Rigid connection with no degree of freedom.
    Fixed,
    /// Rotation about `axis` (joint frame).
    Revolute { axis: (f64, f64, f64) },
    /// Translation along `axis` (joint frame).
    Prismatic { axis: (f64, f64, f64) },
}
19
20/// A URDF joint: its type plus the `<origin>` transform (xyz translation and
21/// rpy Euler rotation) from the parent link's frame to the joint's own frame.
22#[derive(Debug, Clone, Copy)]
23pub struct URDFJoint {
24    pub r#type: URDFJointType,
25    pub xyz: (f64, f64, f64),
26    pub rpy: (f64, f64, f64),
27}
28
29impl URDFJoint {
30    pub const fn fixed(xyz: (f64, f64, f64), rpy: (f64, f64, f64)) -> Self {
31        Self {
32            r#type: URDFJointType::Fixed,
33            xyz,
34            rpy,
35        }
36    }
37
38    pub const fn revolute(
39        xyz: (f64, f64, f64),
40        rpy: (f64, f64, f64),
41        axis: (f64, f64, f64),
42    ) -> Self {
43        Self {
44            r#type: URDFJointType::Revolute { axis },
45            xyz,
46            rpy,
47        }
48    }
49
50    pub const fn prismatic(
51        xyz: (f64, f64, f64),
52        rpy: (f64, f64, f64),
53        axis: (f64, f64, f64),
54    ) -> Self {
55        Self {
56            r#type: URDFJointType::Prismatic { axis },
57            xyz,
58            rpy,
59        }
60    }
61
62    /// Build the `Affine3A` corresponding to this joint's `<origin>`, using
63    /// the URDF RPY convention `R = Rz(yaw) · Ry(pitch) · Rx(roll)`.
64    pub const fn origin_affine(&self) -> glam::Affine3A {
65        AffineRaw::from_xyz_rpy(self.xyz, self.rpy).to_affine3a()
66    }
67}
68
69/// Const-friendly error type for the `URDFChain` / `URDFJoint` const
70/// constructors. Trivially `Copy`, so values can be matched/returned inside
71/// `const fn`s (unlike [`DekeError`], whose `RetimerFailed(String)` variant
72/// carries a non-const destructor). Converts into [`DekeError`] via `From`.
73#[derive(Debug, Clone, Copy, PartialEq, thiserror::Error)]
74pub enum URDFBuildError {
75    #[error(
76        "URDF joint at index {index} has an unexpected type: expected {expected}, found {found}"
77    )]
78    JointTypeMismatch {
79        index: usize,
80        expected: &'static str,
81        found: &'static str,
82    },
83    #[error("URDFChain<{expected}> requires {expected} revolute joints, found {found}")]
84    RevoluteCountMismatch { expected: usize, found: usize },
85}
86
87impl From<URDFBuildError> for DekeError {
88    fn from(e: URDFBuildError) -> Self {
89        match e {
90            URDFBuildError::JointTypeMismatch {
91                index,
92                expected,
93                found,
94            } => DekeError::URDFJointTypeMismatch {
95                index,
96                expected,
97                found,
98            },
99            URDFBuildError::RevoluteCountMismatch { expected, found } => {
100                DekeError::URDFRevoluteCountMismatch { expected, found }
101            }
102        }
103    }
104}
105
106const fn joint_kind_name(k: URDFJointType) -> &'static str {
107    match k {
108        URDFJointType::Fixed => "Fixed",
109        URDFJointType::Revolute { .. } => "Revolute",
110        URDFJointType::Prismatic { .. } => "Prismatic",
111    }
112}
113
114const fn compose_fixed_joints_raw(joints: &[URDFJoint]) -> Result<AffineRaw, URDFBuildError> {
115    let mut acc = AffineRaw::IDENTITY;
116    let n = joints.len();
117    let mut i = 0;
118    while i < n {
119        let j = &joints[i];
120        if !matches!(j.r#type, URDFJointType::Fixed) {
121            return Err(URDFBuildError::JointTypeMismatch {
122                index: i,
123                expected: "Fixed",
124                found: joint_kind_name(j.r#type),
125            });
126        }
127        acc = acc.mul(AffineRaw::from_xyz_rpy(j.xyz, j.rpy));
128        i += 1;
129    }
130    Ok(acc)
131}
132
133/// Compose the `<origin>` transforms of a sequence of fixed joints
134/// (parent→child order) into a single `Affine3A`. Returns an error if any
135/// joint in `joints` is not `Fixed`.
136pub const fn compose_fixed_joints(joints: &[URDFJoint]) -> Result<glam::Affine3A, URDFBuildError> {
137    match compose_fixed_joints_raw(joints) {
138        Ok(a) => Ok(a.to_affine3a()),
139        Err(e) => Err(e),
140    }
141}
142
143const fn compose_fixed_joints_raw_f64(
144    joints: &[URDFJoint],
145) -> Result<AffineRaw64, URDFBuildError> {
146    let mut acc = AffineRaw64::IDENTITY;
147    let n = joints.len();
148    let mut i = 0;
149    while i < n {
150        let j = &joints[i];
151        if !matches!(j.r#type, URDFJointType::Fixed) {
152            return Err(URDFBuildError::JointTypeMismatch {
153                index: i,
154                expected: "Fixed",
155                found: joint_kind_name(j.r#type),
156            });
157        }
158        acc = acc.mul(AffineRaw64::from_xyz_rpy(j.xyz, j.rpy));
159        i += 1;
160    }
161    Ok(acc)
162}
163
164/// `f64` analogue of [`compose_fixed_joints`], producing a `DAffine3`.
165pub const fn compose_fixed_joints_f64(
166    joints: &[URDFJoint],
167) -> Result<glam::DAffine3, URDFBuildError> {
168    match compose_fixed_joints_raw_f64(joints) {
169        Ok(a) => Ok(a.to_daffine3()),
170        Err(e) => Err(e),
171    }
172}
173
174/// Precomputed per-joint axis type for column-rotation FK.
175#[derive(Debug, Clone, Copy)]
176enum JointAxis<F: FKScalar> {
177    Z,
178    Y(F),
179    X(F),
180}
181
/// FK chain using exact URDF joint transforms.
///
/// Accumulation works directly on columns:
///   1. Translation: `t += fx·c0 + fy·c1 + fz·c2`
///   2. Fixed rotation: `(c0,c1,c2) = (c0,c1,c2) * fixed_rot`
///   3. Joint rotation: 2D rotation on the appropriate column pair
///
/// When `fixed_rot` is identity (RPY = 0, the common case), step 2 is
/// skipped entirely, making per-joint cost a single 2D column rotation
/// plus translation — cheaper than DH.
///
/// Two optional fixed transforms surround the actuated joints. The *prefix*
/// (base side) is the starting frame for joint 0. The *suffix* (tool side)
/// is applied after the last joint, and only when producing the
/// end-effector frame.
#[derive(Debug, Clone)]
pub struct URDFChain<const N: usize, F: FKScalar = f32> {
    // Per-joint columns of the fixed `<origin>` rotation (with any
    // intermediate fixed joints folded in when built via `from_urdf`).
    fr_c0: [AVec3<F>; N],
    fr_c1: [AVec3<F>; N],
    fr_c2: [AVec3<F>; N],
    // `true` when the fixed rotation is treated as identity, so step 2 is
    // skipped for that joint.
    fr_identity: [bool; N],
    // Per-joint `<origin>` translation, applied in the accumulated parent frame.
    fixed_trans: [AVec3<F>; N],
    // Principal rotation axis (with sign for X/Y) of each revolute joint.
    axis: [JointAxis<F>; N],
    // Base-side fixed transform (rotation columns + translation); used as the
    // starting frame unless `prefix_identity` is set.
    prefix_c0: AVec3<F>,
    prefix_c1: AVec3<F>,
    prefix_c2: AVec3<F>,
    prefix_t: AVec3<F>,
    prefix_identity: bool,
    // Tool-side fixed transform applied after the last joint for the
    // end-effector frame only; skipped when `suffix_identity` is set.
    suffix_c0: AVec3<F>,
    suffix_c1: AVec3<F>,
    suffix_c2: AVec3<F>,
    suffix_t: AVec3<F>,
    suffix_identity: bool,
}
211
212impl<const N: usize> URDFChain<N, f32> {
213    /// Build a chain from exactly `N` actuated (revolute) joints. Returns
214    /// [`URDFBuildError::JointTypeMismatch`] if any entry is `Fixed` or
215    /// `Prismatic`. For a slice that mixes fixed joints in, use
216    /// [`URDFChain::from_urdf`] instead.
217    pub const fn new(joints: [URDFJoint; N]) -> Result<Self, URDFBuildError> {
218        let mut fr_c0 = [glam::Vec3A::X; N];
219        let mut fr_c1 = [glam::Vec3A::Y; N];
220        let mut fr_c2 = [glam::Vec3A::Z; N];
221        let mut fr_identity = [true; N];
222        let mut fixed_trans = [glam::Vec3A::ZERO; N];
223        let mut axis = [JointAxis::Z; N];
224
225        let mut i = 0;
226        while i < N {
227            let (ox, oy, oz) = joints[i].xyz;
228            let (roll, pitch, yaw) = joints[i].rpy;
229
230            let is_identity = roll.abs() < 1e-10 && pitch.abs() < 1e-10 && yaw.abs() < 1e-10;
231            fr_identity[i] = is_identity;
232
233            if !is_identity {
234                let (sr, cr) = const_sin_cos(roll as f32);
235                let (sp, cp) = const_sin_cos(pitch as f32);
236                let (sy, cy) = const_sin_cos(yaw as f32);
237                fr_c0[i] = glam::Vec3A::new(cy * cp, sy * cp, -sp);
238                fr_c1[i] = glam::Vec3A::new(cy * sp * sr - sy * cr, sy * sp * sr + cy * cr, cp * sr);
239                fr_c2[i] = glam::Vec3A::new(cy * sp * cr + sy * sr, sy * sp * cr - cy * sr, cp * cr);
240            }
241
242            fixed_trans[i] = glam::Vec3A::new(ox as f32, oy as f32, oz as f32);
243
244            let (ax, ay, az) = match joints[i].r#type {
245                URDFJointType::Revolute { axis } => axis,
246                _ => {
247                    return Err(URDFBuildError::JointTypeMismatch {
248                        index: i,
249                        expected: "Revolute",
250                        found: joint_kind_name(joints[i].r#type),
251                    });
252                }
253            };
254            if az.abs() > 0.5 {
255                axis[i] = JointAxis::Z;
256            } else if ay.abs() > 0.5 {
257                axis[i] = JointAxis::Y(ay.signum() as f32);
258            } else {
259                axis[i] = JointAxis::X(ax.signum() as f32);
260            }
261            i += 1;
262        }
263
264        Ok(Self {
265            fr_c0,
266            fr_c1,
267            fr_c2,
268            fr_identity,
269            fixed_trans,
270            axis,
271            prefix_c0: glam::Vec3A::X,
272            prefix_c1: glam::Vec3A::Y,
273            prefix_c2: glam::Vec3A::Z,
274            prefix_t: glam::Vec3A::ZERO,
275            prefix_identity: true,
276            suffix_c0: glam::Vec3A::X,
277            suffix_c1: glam::Vec3A::Y,
278            suffix_c2: glam::Vec3A::Z,
279            suffix_t: glam::Vec3A::ZERO,
280            suffix_identity: true,
281        })
282    }
283
284    /// Build a chain from a flat URDF joint list (any mix of `Fixed`,
285    /// `Revolute`, and/or `Prismatic`). The list must describe a single
286    /// branch in parent→child order.
287    ///
288    /// - Leading `Fixed` joints become the prefix (applied before joint 0).
289    /// - Trailing `Fixed` joints become the suffix (applied after the last
290    ///   actuated joint).
291    /// - `Fixed` joints sandwiched between actuated joints are folded into
292    ///   the origin of the next actuated joint so the kinematics are
293    ///   preserved exactly.
294    /// - The number of `Revolute` joints must equal `N`.
295    ///
296    /// Returns [`URDFBuildError::JointTypeMismatch`] if a `Prismatic` joint
297    /// appears (not handled by `URDFChain` itself — wrap the result in
298    /// [`PrismaticFK`](crate::PrismaticFK) for a prismatic joint at the start
299    /// or end), or [`URDFBuildError::RevoluteCountMismatch`] if the revolute
300    /// count doesn't match `N`.
301    pub const fn from_urdf(joints: &[URDFJoint]) -> Result<Self, URDFBuildError> {
302        let mut fr_c0 = [glam::Vec3A::X; N];
303        let mut fr_c1 = [glam::Vec3A::Y; N];
304        let mut fr_c2 = [glam::Vec3A::Z; N];
305        let mut fr_identity = [true; N];
306        let mut fixed_trans = [glam::Vec3A::ZERO; N];
307        let mut axis_out = [JointAxis::Z; N];
308
309        let mut pending = AffineRaw::IDENTITY;
310        let mut prefix = AffineRaw::IDENTITY;
311        let mut prefix_set = false;
312        let mut r_count = 0usize;
313
314        let n = joints.len();
315        let mut i = 0;
316        while i < n {
317            let joint = &joints[i];
318            match joint.r#type {
319                URDFJointType::Fixed => {
320                    pending = pending.mul(AffineRaw::from_xyz_rpy(joint.xyz, joint.rpy));
321                }
322                URDFJointType::Revolute { axis } => {
323                    if r_count >= N {
324                        return Err(URDFBuildError::RevoluteCountMismatch {
325                            expected: N,
326                            found: r_count + 1,
327                        });
328                    }
329                    let local = AffineRaw::from_xyz_rpy(joint.xyz, joint.rpy);
330                    let effective = if !prefix_set {
331                        prefix = pending;
332                        prefix_set = true;
333                        local
334                    } else {
335                        pending.mul(local)
336                    };
337
338                    fr_identity[r_count] = effective.is_identity();
339                    fr_c0[r_count] = effective.c0_vec3a();
340                    fr_c1[r_count] = effective.c1_vec3a();
341                    fr_c2[r_count] = effective.c2_vec3a();
342                    fixed_trans[r_count] = effective.t_vec3a();
343
344                    let (ax, ay, az) = axis;
345                    axis_out[r_count] = if az.abs() > 0.5 {
346                        JointAxis::Z
347                    } else if ay.abs() > 0.5 {
348                        JointAxis::Y(ay.signum() as f32)
349                    } else {
350                        JointAxis::X(ax.signum() as f32)
351                    };
352
353                    pending = AffineRaw::IDENTITY;
354                    r_count += 1;
355                }
356                URDFJointType::Prismatic { .. } => {
357                    return Err(URDFBuildError::JointTypeMismatch {
358                        index: i,
359                        expected: "Fixed or Revolute",
360                        found: "Prismatic",
361                    });
362                }
363            }
364            i += 1;
365        }
366        if r_count != N {
367            return Err(URDFBuildError::RevoluteCountMismatch {
368                expected: N,
369                found: r_count,
370            });
371        }
372
373        let prefix_identity = !prefix_set || prefix.is_identity();
374        let suffix_identity = pending.is_identity();
375
376        Ok(Self {
377            fr_c0,
378            fr_c1,
379            fr_c2,
380            fr_identity,
381            fixed_trans,
382            axis: axis_out,
383            prefix_c0: prefix.c0_vec3a(),
384            prefix_c1: prefix.c1_vec3a(),
385            prefix_c2: prefix.c2_vec3a(),
386            prefix_t: prefix.t_vec3a(),
387            prefix_identity,
388            suffix_c0: pending.c0_vec3a(),
389            suffix_c1: pending.c1_vec3a(),
390            suffix_c2: pending.c2_vec3a(),
391            suffix_t: pending.t_vec3a(),
392            suffix_identity,
393        })
394    }
395
396    /// Bake a sequence of URDF fixed-joint origins (parent→child order) into
397    /// the base side of the chain.
398    pub const fn with_fixed_prefix(
399        mut self,
400        joints: &[URDFJoint],
401    ) -> Result<Self, URDFBuildError> {
402        if joints.is_empty() {
403            self.prefix_c0 = glam::Vec3A::X;
404            self.prefix_c1 = glam::Vec3A::Y;
405            self.prefix_c2 = glam::Vec3A::Z;
406            self.prefix_t = glam::Vec3A::ZERO;
407            self.prefix_identity = true;
408        } else {
409            let a = match compose_fixed_joints_raw(joints) {
410                Ok(a) => a,
411                Err(e) => return Err(e),
412            };
413            self.prefix_identity = a.is_identity();
414            self.prefix_c0 = a.c0_vec3a();
415            self.prefix_c1 = a.c1_vec3a();
416            self.prefix_c2 = a.c2_vec3a();
417            self.prefix_t = a.t_vec3a();
418        }
419        Ok(self)
420    }
421
422    /// Bake a sequence of URDF fixed-joint origins (parent→child order) into
423    /// the tool side of the chain.
424    pub const fn with_fixed_suffix(
425        mut self,
426        joints: &[URDFJoint],
427    ) -> Result<Self, URDFBuildError> {
428        if joints.is_empty() {
429            self.suffix_c0 = glam::Vec3A::X;
430            self.suffix_c1 = glam::Vec3A::Y;
431            self.suffix_c2 = glam::Vec3A::Z;
432            self.suffix_t = glam::Vec3A::ZERO;
433            self.suffix_identity = true;
434        } else {
435            let a = match compose_fixed_joints_raw(joints) {
436                Ok(a) => a,
437                Err(e) => return Err(e),
438            };
439            self.suffix_identity = a.is_identity();
440            self.suffix_c0 = a.c0_vec3a();
441            self.suffix_c1 = a.c1_vec3a();
442            self.suffix_c2 = a.c2_vec3a();
443            self.suffix_t = a.t_vec3a();
444        }
445        Ok(self)
446    }
447
448    /// Convenience: set both a fixed-joint prefix and suffix in one call.
449    pub const fn with_fixed_joints(
450        self,
451        prefix: &[URDFJoint],
452        suffix: &[URDFJoint],
453    ) -> Result<Self, URDFBuildError> {
454        match self.with_fixed_prefix(prefix) {
455            Ok(s) => s.with_fixed_suffix(suffix),
456            Err(e) => Err(e),
457        }
458    }
459}
460
461impl<const N: usize> URDFChain<N, f64> {
462    /// `const`-evaluable f64 analogue of [`URDFChain::<N, f32>::new`].
463    pub const fn new_f64(joints: [URDFJoint; N]) -> Result<Self, URDFBuildError> {
464        let mut fr_c0 = [glam::DVec3::X; N];
465        let mut fr_c1 = [glam::DVec3::Y; N];
466        let mut fr_c2 = [glam::DVec3::Z; N];
467        let mut fr_identity = [true; N];
468        let mut fixed_trans = [glam::DVec3::ZERO; N];
469        let mut axis = [JointAxis::Z; N];
470
471        let mut i = 0;
472        while i < N {
473            let (ox, oy, oz) = joints[i].xyz;
474            let (roll, pitch, yaw) = joints[i].rpy;
475
476            let is_identity = roll.abs() < 1e-12 && pitch.abs() < 1e-12 && yaw.abs() < 1e-12;
477            fr_identity[i] = is_identity;
478
479            if !is_identity {
480                let (sr, cr) = const_sin_cos_f64(roll);
481                let (sp, cp) = const_sin_cos_f64(pitch);
482                let (sy, cy) = const_sin_cos_f64(yaw);
483                fr_c0[i] = glam::DVec3::new(cy * cp, sy * cp, -sp);
484                fr_c1[i] = glam::DVec3::new(
485                    cy * sp * sr - sy * cr,
486                    sy * sp * sr + cy * cr,
487                    cp * sr,
488                );
489                fr_c2[i] = glam::DVec3::new(
490                    cy * sp * cr + sy * sr,
491                    sy * sp * cr - cy * sr,
492                    cp * cr,
493                );
494            }
495
496            fixed_trans[i] = glam::DVec3::new(ox, oy, oz);
497
498            let (ax, ay, az) = match joints[i].r#type {
499                URDFJointType::Revolute { axis } => axis,
500                _ => {
501                    return Err(URDFBuildError::JointTypeMismatch {
502                        index: i,
503                        expected: "Revolute",
504                        found: joint_kind_name(joints[i].r#type),
505                    });
506                }
507            };
508            if az.abs() > 0.5 {
509                axis[i] = JointAxis::Z;
510            } else if ay.abs() > 0.5 {
511                axis[i] = JointAxis::Y(ay.signum());
512            } else {
513                axis[i] = JointAxis::X(ax.signum());
514            }
515            i += 1;
516        }
517
518        Ok(Self {
519            fr_c0,
520            fr_c1,
521            fr_c2,
522            fr_identity,
523            fixed_trans,
524            axis,
525            prefix_c0: glam::DVec3::X,
526            prefix_c1: glam::DVec3::Y,
527            prefix_c2: glam::DVec3::Z,
528            prefix_t: glam::DVec3::ZERO,
529            prefix_identity: true,
530            suffix_c0: glam::DVec3::X,
531            suffix_c1: glam::DVec3::Y,
532            suffix_c2: glam::DVec3::Z,
533            suffix_t: glam::DVec3::ZERO,
534            suffix_identity: true,
535        })
536    }
537
538    /// `const`-evaluable f64 analogue of [`URDFChain::<N, f32>::from_urdf`].
539    pub const fn from_urdf_f64(joints: &[URDFJoint]) -> Result<Self, URDFBuildError> {
540        let mut fr_c0 = [glam::DVec3::X; N];
541        let mut fr_c1 = [glam::DVec3::Y; N];
542        let mut fr_c2 = [glam::DVec3::Z; N];
543        let mut fr_identity = [true; N];
544        let mut fixed_trans = [glam::DVec3::ZERO; N];
545        let mut axis_out = [JointAxis::Z; N];
546
547        let mut pending = AffineRaw64::IDENTITY;
548        let mut prefix = AffineRaw64::IDENTITY;
549        let mut prefix_set = false;
550        let mut r_count = 0usize;
551
552        let n = joints.len();
553        let mut i = 0;
554        while i < n {
555            let joint = &joints[i];
556            match joint.r#type {
557                URDFJointType::Fixed => {
558                    pending = pending.mul(AffineRaw64::from_xyz_rpy(joint.xyz, joint.rpy));
559                }
560                URDFJointType::Revolute { axis } => {
561                    if r_count >= N {
562                        return Err(URDFBuildError::RevoluteCountMismatch {
563                            expected: N,
564                            found: r_count + 1,
565                        });
566                    }
567                    let local = AffineRaw64::from_xyz_rpy(joint.xyz, joint.rpy);
568                    let effective = if !prefix_set {
569                        prefix = pending;
570                        prefix_set = true;
571                        local
572                    } else {
573                        pending.mul(local)
574                    };
575
576                    fr_identity[r_count] = effective.is_identity();
577                    fr_c0[r_count] = effective.c0_dvec3();
578                    fr_c1[r_count] = effective.c1_dvec3();
579                    fr_c2[r_count] = effective.c2_dvec3();
580                    fixed_trans[r_count] = effective.t_dvec3();
581
582                    let (ax, ay, az) = axis;
583                    axis_out[r_count] = if az.abs() > 0.5 {
584                        JointAxis::Z
585                    } else if ay.abs() > 0.5 {
586                        JointAxis::Y(ay.signum())
587                    } else {
588                        JointAxis::X(ax.signum())
589                    };
590
591                    pending = AffineRaw64::IDENTITY;
592                    r_count += 1;
593                }
594                URDFJointType::Prismatic { .. } => {
595                    return Err(URDFBuildError::JointTypeMismatch {
596                        index: i,
597                        expected: "Fixed or Revolute",
598                        found: "Prismatic",
599                    });
600                }
601            }
602            i += 1;
603        }
604        if r_count != N {
605            return Err(URDFBuildError::RevoluteCountMismatch {
606                expected: N,
607                found: r_count,
608            });
609        }
610
611        let prefix_identity = !prefix_set || prefix.is_identity();
612        let suffix_identity = pending.is_identity();
613
614        Ok(Self {
615            fr_c0,
616            fr_c1,
617            fr_c2,
618            fr_identity,
619            fixed_trans,
620            axis: axis_out,
621            prefix_c0: prefix.c0_dvec3(),
622            prefix_c1: prefix.c1_dvec3(),
623            prefix_c2: prefix.c2_dvec3(),
624            prefix_t: prefix.t_dvec3(),
625            prefix_identity,
626            suffix_c0: pending.c0_dvec3(),
627            suffix_c1: pending.c1_dvec3(),
628            suffix_c2: pending.c2_dvec3(),
629            suffix_t: pending.t_dvec3(),
630            suffix_identity,
631        })
632    }
633
634    /// `const`-evaluable f64 analogue of [`URDFChain::<N, f32>::with_fixed_prefix`].
635    pub const fn with_fixed_prefix_f64(
636        mut self,
637        joints: &[URDFJoint],
638    ) -> Result<Self, URDFBuildError> {
639        if joints.is_empty() {
640            self.prefix_c0 = glam::DVec3::X;
641            self.prefix_c1 = glam::DVec3::Y;
642            self.prefix_c2 = glam::DVec3::Z;
643            self.prefix_t = glam::DVec3::ZERO;
644            self.prefix_identity = true;
645        } else {
646            let a = match compose_fixed_joints_raw_f64(joints) {
647                Ok(a) => a,
648                Err(e) => return Err(e),
649            };
650            self.prefix_identity = a.is_identity();
651            self.prefix_c0 = a.c0_dvec3();
652            self.prefix_c1 = a.c1_dvec3();
653            self.prefix_c2 = a.c2_dvec3();
654            self.prefix_t = a.t_dvec3();
655        }
656        Ok(self)
657    }
658
659    /// `const`-evaluable f64 analogue of [`URDFChain::<N, f32>::with_fixed_suffix`].
660    pub const fn with_fixed_suffix_f64(
661        mut self,
662        joints: &[URDFJoint],
663    ) -> Result<Self, URDFBuildError> {
664        if joints.is_empty() {
665            self.suffix_c0 = glam::DVec3::X;
666            self.suffix_c1 = glam::DVec3::Y;
667            self.suffix_c2 = glam::DVec3::Z;
668            self.suffix_t = glam::DVec3::ZERO;
669            self.suffix_identity = true;
670        } else {
671            let a = match compose_fixed_joints_raw_f64(joints) {
672                Ok(a) => a,
673                Err(e) => return Err(e),
674            };
675            self.suffix_identity = a.is_identity();
676            self.suffix_c0 = a.c0_dvec3();
677            self.suffix_c1 = a.c1_dvec3();
678            self.suffix_c2 = a.c2_dvec3();
679            self.suffix_t = a.t_dvec3();
680        }
681        Ok(self)
682    }
683
684    /// Convenience: set both a fixed-joint prefix and suffix in one call (f64).
685    pub const fn with_fixed_joints_f64(
686        self,
687        prefix: &[URDFJoint],
688        suffix: &[URDFJoint],
689    ) -> Result<Self, URDFBuildError> {
690        match self.with_fixed_prefix_f64(prefix) {
691            Ok(s) => s.with_fixed_suffix_f64(suffix),
692            Err(e) => Err(e),
693        }
694    }
695}
696
697impl<const N: usize, F: FKScalar> URDFChain<N, F> {
698    #[inline(always)]
699    fn initial_frame(&self) -> (AVec3<F>, AVec3<F>, AVec3<F>, AVec3<F>) {
700        if self.prefix_identity {
701            (AVec3::<F>::X, AVec3::<F>::Y, AVec3::<F>::Z, AVec3::<F>::ZERO)
702        } else {
703            (self.prefix_c0, self.prefix_c1, self.prefix_c2, self.prefix_t)
704        }
705    }
706
707    #[inline(always)]
708    fn apply_suffix(
709        &self,
710        c0: &mut AVec3<F>,
711        c1: &mut AVec3<F>,
712        c2: &mut AVec3<F>,
713        t: &mut AVec3<F>,
714    ) {
715        let st = self.suffix_t;
716        *t = *c0 * st.x() + *c1 * st.y() + *c2 * st.z() + *t;
717
718        let fc0 = self.suffix_c0;
719        let fc1 = self.suffix_c1;
720        let fc2 = self.suffix_c2;
721        let new_c0 = *c0 * fc0.x() + *c1 * fc0.y() + *c2 * fc0.z();
722        let new_c1 = *c0 * fc1.x() + *c1 * fc1.y() + *c2 * fc1.z();
723        let new_c2 = *c0 * fc2.x() + *c1 * fc2.y() + *c2 * fc2.z();
724        *c0 = new_c0;
725        *c1 = new_c1;
726        *c2 = new_c2;
727    }
728
729    /// Apply fixed rotation + joint rotation to accumulator columns.
730    #[inline(always)]
731    fn accumulate_joint(
732        &self,
733        i: usize,
734        st: F,
735        ct: F,
736        c0: &mut AVec3<F>,
737        c1: &mut AVec3<F>,
738        c2: &mut AVec3<F>,
739        t: &mut AVec3<F>,
740    ) {
741        let ft = self.fixed_trans[i];
742        *t = *c0 * ft.x() + *c1 * ft.y() + *c2 * ft.z() + *t;
743
744        let (f0, f1, f2) = if self.fr_identity[i] {
745            (*c0, *c1, *c2)
746        } else {
747            let fc0 = self.fr_c0[i];
748            let fc1 = self.fr_c1[i];
749            let fc2 = self.fr_c2[i];
750            (
751                *c0 * fc0.x() + *c1 * fc0.y() + *c2 * fc0.z(),
752                *c0 * fc1.x() + *c1 * fc1.y() + *c2 * fc1.z(),
753                *c0 * fc2.x() + *c1 * fc2.y() + *c2 * fc2.z(),
754            )
755        };
756
757        match self.axis[i] {
758            JointAxis::Z => {
759                let new_c0 = f0 * ct + f1 * st;
760                let new_c1 = f1 * ct - f0 * st;
761                *c0 = new_c0;
762                *c1 = new_c1;
763                *c2 = f2;
764            }
765            JointAxis::Y(s) => {
766                let sst = s * st;
767                let new_c0 = f0 * ct - f2 * sst;
768                let new_c2 = f0 * sst + f2 * ct;
769                *c0 = new_c0;
770                *c1 = f1;
771                *c2 = new_c2;
772            }
773            JointAxis::X(s) => {
774                let sst = s * st;
775                let new_c1 = f1 * ct + f2 * sst;
776                let new_c2 = f2 * ct - f1 * sst;
777                *c0 = f0;
778                *c1 = new_c1;
779                *c2 = new_c2;
780            }
781        }
782    }
783}
784
785impl<const N: usize, F: FKScalar> FKChain<N, F> for URDFChain<N, F> {
786    #[cfg(debug_assertions)]
787    type Error = DekeError;
788    #[cfg(not(debug_assertions))]
789    type Error = std::convert::Infallible;
790
791    fn base_tf(&self) -> AAffine3<F> {
792        if self.prefix_identity {
793            AAffine3::<F>::IDENTITY
794        } else {
795            AAffine3::<F>::from_mat3_translation(
796                AMat3::<F>::from_cols(self.prefix_c0, self.prefix_c1, self.prefix_c2),
797                self.prefix_t,
798            )
799        }
800    }
801
802    fn fk(&self, q: &SRobotQ<N, F>) -> Result<[AAffine3<F>; N], Self::Error> {
803        check_finite::<N, F>(q)?;
804        let mut out = [AAffine3::<F>::IDENTITY; N];
805        let (mut c0, mut c1, mut c2, mut t) = self.initial_frame();
806
807        let mut i = 0;
808        while i < N {
809            let (st, ct) = q.0[i].sin_cos();
810            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
811
812            out[i] = AAffine3::<F>::from_mat3_translation(
813                AMat3::<F>::from_cols(c0, c1, c2),
814                t,
815            );
816            i += 1;
817        }
818        // The trailing fixed-joint suffix is part of the EE frame, not the
819        // last revolute link's frame; it is applied only in `fk_end` /
820        // `all_fk`.
821        Ok(out)
822    }
823
824    fn fk_end(&self, q: &SRobotQ<N, F>) -> Result<AAffine3<F>, Self::Error> {
825        check_finite::<N, F>(q)?;
826        let (mut c0, mut c1, mut c2, mut t) = self.initial_frame();
827
828        let mut i = 0;
829        while i < N {
830            let (st, ct) = q.0[i].sin_cos();
831            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
832            i += 1;
833        }
834
835        if !self.suffix_identity {
836            self.apply_suffix(&mut c0, &mut c1, &mut c2, &mut t);
837        }
838
839        Ok(AAffine3::<F>::from_mat3_translation(
840            AMat3::<F>::from_cols(c0, c1, c2),
841            t,
842        ))
843    }
844
845    fn all_fk(
846        &self,
847        q: &SRobotQ<N, F>,
848    ) -> Result<(AAffine3<F>, [AAffine3<F>; N], AAffine3<F>), Self::Error> {
849        check_finite::<N, F>(q)?;
850        let mut frames = [AAffine3::<F>::IDENTITY; N];
851        let (mut c0, mut c1, mut c2, mut t) = self.initial_frame();
852
853        let mut i = 0;
854        while i < N {
855            let (st, ct) = q.0[i].sin_cos();
856            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
857            frames[i] = AAffine3::<F>::from_mat3_translation(
858                AMat3::<F>::from_cols(c0, c1, c2),
859                t,
860            );
861            i += 1;
862        }
863
864        // The suffix only contributes to the EE frame; apply it to a
865        // separate copy of the post-loop accumulator so per-link frames
866        // remain untouched.
867        if !self.suffix_identity {
868            self.apply_suffix(&mut c0, &mut c1, &mut c2, &mut t);
869        }
870        let end = AAffine3::<F>::from_mat3_translation(
871            AMat3::<F>::from_cols(c0, c1, c2),
872            t,
873        );
874
875        Ok((self.base_tf(), frames, end))
876    }
877
878    fn joint_axes_positions(
879        &self,
880        q: &SRobotQ<N, F>,
881    ) -> Result<([AVec3<F>; N], [AVec3<F>; N], AVec3<F>), Self::Error> {
882        check_finite::<N, F>(q)?;
883        let mut frames = [AAffine3::<F>::IDENTITY; N];
884        let (mut c0, mut c1, mut c2, mut t) = self.initial_frame();
885
886        let mut i = 0;
887        while i < N {
888            let (st, ct) = q.0[i].sin_cos();
889            self.accumulate_joint(i, st, ct, &mut c0, &mut c1, &mut c2, &mut t);
890            frames[i] = AAffine3::<F>::from_mat3_translation(
891                AMat3::<F>::from_cols(c0, c1, c2),
892                t,
893            );
894            i += 1;
895        }
896
897        let mut axes = [AVec3::<F>::ZERO; N];
898        let mut positions = [AVec3::<F>::ZERO; N];
899
900        for i in 0..N {
901            axes[i] = match self.axis[i] {
902                JointAxis::Z => frames[i].matrix3().z_axis(),
903                JointAxis::Y(s) => frames[i].matrix3().y_axis() * s,
904                JointAxis::X(s) => frames[i].matrix3().x_axis() * s,
905            };
906            positions[i] = frames[i].translation();
907        }
908
909        let p_ee = if N == 0 {
910            AVec3::<F>::ZERO
911        } else if !self.suffix_identity {
912            self.apply_suffix(&mut c0, &mut c1, &mut c2, &mut t);
913            t
914        } else {
915            frames[N - 1].translation()
916        };
917
918        Ok((axes, positions, p_ee))
919    }
920}
921
922
923impl From<JointAxis<f32>> for JointAxis<f64> {
924    #[inline]
925    fn from(j: JointAxis<f32>) -> Self {
926        match j {
927            JointAxis::Z => JointAxis::Z,
928            JointAxis::Y(s) => JointAxis::Y(s as f64),
929            JointAxis::X(s) => JointAxis::X(s as f64),
930        }
931    }
932}
933
934impl From<JointAxis<f64>> for JointAxis<f32> {
935    #[inline]
936    fn from(j: JointAxis<f64>) -> Self {
937        match j {
938            JointAxis::Z => JointAxis::Z,
939            JointAxis::Y(s) => JointAxis::Y(s as f32),
940            JointAxis::X(s) => JointAxis::X(s as f32),
941        }
942    }
943}
944
/// Applies `cast` to each element of a fixed-size array, preserving order and
/// length.
#[inline]
fn cast_arr<const N: usize, A: Copy, B: Copy>(src: [A; N], cast: impl Fn(A) -> B) -> [B; N] {
    src.map(cast)
}
949
950impl<const N: usize> From<URDFChain<N, f32>> for URDFChain<N, f64> {
951    #[inline]
952    fn from(c: URDFChain<N, f32>) -> Self {
953        URDFChain::<N, f64> {
954            fr_c0: cast_arr(c.fr_c0, |v| v.as_dvec3()),
955            fr_c1: cast_arr(c.fr_c1, |v| v.as_dvec3()),
956            fr_c2: cast_arr(c.fr_c2, |v| v.as_dvec3()),
957            fr_identity: c.fr_identity,
958            fixed_trans: cast_arr(c.fixed_trans, |v| v.as_dvec3()),
959            axis: cast_arr(c.axis, JointAxis::<f64>::from),
960            prefix_c0: c.prefix_c0.as_dvec3(),
961            prefix_c1: c.prefix_c1.as_dvec3(),
962            prefix_c2: c.prefix_c2.as_dvec3(),
963            prefix_t: c.prefix_t.as_dvec3(),
964            prefix_identity: c.prefix_identity,
965            suffix_c0: c.suffix_c0.as_dvec3(),
966            suffix_c1: c.suffix_c1.as_dvec3(),
967            suffix_c2: c.suffix_c2.as_dvec3(),
968            suffix_t: c.suffix_t.as_dvec3(),
969            suffix_identity: c.suffix_identity,
970        }
971    }
972}
973
974impl<const N: usize> From<URDFChain<N, f64>> for URDFChain<N, f32> {
975    #[inline]
976    fn from(c: URDFChain<N, f64>) -> Self {
977        URDFChain::<N, f32> {
978            fr_c0: cast_arr(c.fr_c0, |v| v.as_vec3a()),
979            fr_c1: cast_arr(c.fr_c1, |v| v.as_vec3a()),
980            fr_c2: cast_arr(c.fr_c2, |v| v.as_vec3a()),
981            fr_identity: c.fr_identity,
982            fixed_trans: cast_arr(c.fixed_trans, |v| v.as_vec3a()),
983            axis: cast_arr(c.axis, JointAxis::<f32>::from),
984            prefix_c0: c.prefix_c0.as_vec3a(),
985            prefix_c1: c.prefix_c1.as_vec3a(),
986            prefix_c2: c.prefix_c2.as_vec3a(),
987            prefix_t: c.prefix_t.as_vec3a(),
988            prefix_identity: c.prefix_identity,
989            suffix_c0: c.suffix_c0.as_vec3a(),
990            suffix_c1: c.suffix_c1.as_vec3a(),
991            suffix_c2: c.suffix_c2.as_vec3a(),
992            suffix_t: c.suffix_t.as_vec3a(),
993            suffix_identity: c.suffix_identity,
994        }
995    }
996}