1use std::ops::Mul;
2
3use glam::{Affine3A, DAffine3, DMat3, DVec3, Mat3A, Vec3A};
4use glam_traits_ext::{FloatAffine, FloatMat, FloatScalar, FloatVec, TAffine3, TMat3, TVec3};
5
6use crate::{DekeError, SRobotQ};
7
/// Private home of the sealing marker trait.
///
/// The module itself is not exported, so code outside this module cannot name
/// `Sealed` and therefore cannot implement any trait that lists it as a
/// supertrait.
mod sealed {
    /// Marker implemented only for the scalar types supported by this file.
    pub trait Sealed {}

    impl Sealed for f64 {}
    impl Sealed for f32 {}
}
13
14pub trait KinScalar:
21 FloatScalar + Copy + std::fmt::Debug + Send + Sync + 'static + sealed::Sealed + crate::BatchLimits
22{
23 type AVec3: TVec3<Self, MaybeAligned = Self::AVec3>;
24 type AMat3: TMat3<Self, MaybeAligned = Self::AMat3>
25 + FloatMat<Self, Col = Self::AVec3>
26 + Mul<Self::AVec3, Output = Self::AVec3>;
27 type AAffine3: TAffine3<Self, MaybeAligned = Self::AAffine3>
28 + FloatAffine<Self, Vec = Self::AVec3, Mat = Self::AMat3>
29 + Mul<Self::AAffine3, Output = Self::AAffine3>;
30}
31
32impl KinScalar for f32 {
33 type AVec3 = Vec3A;
34 type AMat3 = Mat3A;
35 type AAffine3 = Affine3A;
36}
37
38impl KinScalar for f64 {
39 type AVec3 = DVec3;
40 type AMat3 = DMat3;
41 type AAffine3 = DAffine3;
42}
43
44#[allow(type_alias_bounds)]
45pub type AAffine3<F: KinScalar> = F::AAffine3;
46#[allow(type_alias_bounds)]
47pub type AVec3<F: KinScalar> = F::AVec3;
48
49#[allow(type_alias_bounds)]
50pub type AllFk<const N: usize, F: KinScalar> = (AAffine3<F>, [AAffine3<F>; N], AAffine3<F>);
51
52#[inline(always)]
53#[cfg(debug_assertions)]
54pub fn check_finite<const N: usize, F: FloatScalar>(q: &SRobotQ<N, F>) -> Result<(), DekeError> {
55 if q.any_non_finite() {
56 return Err(DekeError::JointsNonFinite);
57 }
58 Ok(())
59}
60
/// Release-build counterpart of the joint-vector validation.
///
/// The argument is ignored and the call always succeeds; the error type is
/// [`std::convert::Infallible`], so no failure value can exist.
#[cfg(not(debug_assertions))]
#[inline(always)]
pub fn check_finite<const N: usize, F: FloatScalar>(
    _q: &SRobotQ<N, F>,
) -> Result<(), std::convert::Infallible> {
    Ok(())
}
68
69#[derive(Clone, Copy, Debug, PartialEq)]
70pub enum JointSpec<F: KinScalar> {
71 Revolute { axis_local: AVec3<F> },
72 Prismatic { axis_local: AVec3<F> },
73}
74
75#[derive(Clone, Copy, Debug, PartialEq)]
76pub struct KinSpec<F: KinScalar, const N: usize> {
77 pub base_to_first: AAffine3<F>,
78 pub joints: [(AAffine3<F>, JointSpec<F>); N],
80 pub end_to_ee: AAffine3<F>,
81}
82
83impl<F: KinScalar, const N: usize> KinSpec<F, N> {
84 pub fn new(
85 base_to_first: AAffine3<F>,
86 joints: [(AAffine3<F>, JointSpec<F>); N],
87 end_to_ee: AAffine3<F>,
88 ) -> Self {
89 Self {
90 base_to_first,
91 joints,
92 end_to_ee,
93 }
94 }
95}
96
97pub trait FKChain<const N: usize, F: KinScalar = f32>: Clone + Send + Sync {
98 type Error: Into<DekeError>;
99
100 fn dof(&self) -> usize {
101 N
102 }
103 fn base_tf(&self) -> AAffine3<F> {
106 AAffine3::<F>::IDENTITY
107 }
108
109 fn ee_tf(&self) -> AAffine3<F> {
110 AAffine3::<F>::IDENTITY
111 }
112
113 fn fk(&self, q: &SRobotQ<N, F>) -> Result<[AAffine3<F>; N], Self::Error>;
114
115 fn fk_end(&self, q: &SRobotQ<N, F>) -> Result<AAffine3<F>, Self::Error> {
117 let frames = self.fk(q)?;
118 Ok(if N > 0 {
119 frames[N - 1] * self.ee_tf()
120 } else {
121 AAffine3::<F>::IDENTITY
122 })
123 }
124
125 fn all_fk(&self, q: &SRobotQ<N, F>) -> Result<AllFk<N, F>, Self::Error> {
128 let base = self.base_tf();
129 let frames = self.fk(q)?;
130 let end = self.fk_end(q)?;
131 Ok((base, frames, end))
132 }
133}
134
135pub trait ContinuousFKChain<const N: usize, F: KinScalar = f32>:
142 FKChain<N, F, Error = DekeError>
143{
144 fn structure(&self) -> KinSpec<F, N>;
145
146 fn max_reach(&self) -> Result<F, Self::Error> {
149 let spec = self.structure();
150 let (_, p, p_ee) = forward_pass(&spec, &SRobotQ::zeros());
151 let mut total = F::zero();
152 let mut prev = p[0];
153 for &point in p.iter().take(N).skip(1) {
154 total = total + (point - prev).length();
155 prev = point;
156 }
157 total = total + (p_ee - prev).length();
158 Ok(total)
159 }
160
161 fn jacobian(&self, q: &SRobotQ<N, F>) -> Result<[[F; N]; 6], Self::Error> {
164 #[cfg(debug_assertions)]
165 check_finite::<N, F>(q)?;
166 let spec = self.structure();
167 let (z, p, p_ee) = forward_pass(&spec, q);
168 let mut j = [[F::zero(); N]; 6];
169 for i in 0..N {
170 match spec.joints[i].1 {
171 JointSpec::Revolute { .. } => {
172 let c = z[i].cross(p_ee - p[i]);
173 j[0][i] = c.x();
174 j[1][i] = c.y();
175 j[2][i] = c.z();
176 j[3][i] = z[i].x();
177 j[4][i] = z[i].y();
178 j[5][i] = z[i].z();
179 }
180 JointSpec::Prismatic { .. } => {
181 j[0][i] = z[i].x();
182 j[1][i] = z[i].y();
183 j[2][i] = z[i].z();
184 }
185 }
186 }
187 Ok(j)
188 }
189
190 fn manipulability(&self, q: &SRobotQ<N, F>) -> Result<F, Self::Error> {
200 let j = self.jacobian(q)?;
201 if N == 6 {
205 let mut m = [[F::zero(); 6]; 6];
206 for (mr, jr) in m.iter_mut().zip(j.iter()) {
207 for (mv, &jv) in mr.iter_mut().zip(jr.iter()) {
208 *mv = jv;
209 }
210 }
211 return Ok(gram_determinant::<F>(m, 6).abs());
212 }
213 let k = if N >= 6 { 6 } else { N };
218 let mut g = [[F::zero(); 6]; 6];
219 if N >= 6 {
223 for r in 0..6 {
225 for c in 0..=r {
226 let dot = j[r]
227 .iter()
228 .zip(j[c].iter())
229 .fold(F::zero(), |acc, (&a, &b)| acc + a * b);
230 g[r][c] = dot;
231 g[c][r] = dot;
232 }
233 }
234 } else {
235 for r in 0..N {
237 for c in 0..=r {
238 let dot = j
239 .iter()
240 .fold(F::zero(), |acc, jrow| acc + jrow[r] * jrow[c]);
241 g[r][c] = dot;
242 g[c][r] = dot;
243 }
244 }
245 }
246 Ok(gram_determinant::<F>(g, k).max(F::zero()).sqrt())
249 }
250
251 fn jacobian_dot(
253 &self,
254 q: &SRobotQ<N, F>,
255 qdot: &SRobotQ<N, F>,
256 ) -> Result<[[F; N]; 6], Self::Error> {
257 #[cfg(debug_assertions)]
258 {
259 check_finite::<N, F>(q)?;
260 check_finite::<N, F>(qdot)?;
261 }
262 let spec = self.structure();
263 let (z, p, p_ee) = forward_pass(&spec, q);
264
265 let mut omega = AVec3::<F>::ZERO;
266 let mut z_dot = [AVec3::<F>::ZERO; N];
267 let mut p_dot = [AVec3::<F>::ZERO; N];
268 let mut pdot_acc = AVec3::<F>::ZERO;
269
270 for i in 0..N {
271 p_dot[i] = pdot_acc;
272 z_dot[i] = omega.cross(z[i]);
273 match spec.joints[i].1 {
274 JointSpec::Revolute { .. } => {
275 omega += z[i] * qdot.0[i];
276 }
277 JointSpec::Prismatic { .. } => {
278 pdot_acc += z[i] * qdot.0[i];
279 }
280 }
281 let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
282 pdot_acc += omega.cross(next_p - p[i]);
283 }
284 let p_ee_dot = pdot_acc;
285
286 let mut jd = [[F::zero(); N]; 6];
287 for i in 0..N {
288 match spec.joints[i].1 {
289 JointSpec::Revolute { .. } => {
290 let dp = p_ee - p[i];
291 let dp_dot = p_ee_dot - p_dot[i];
292 let c1 = z_dot[i].cross(dp);
293 let c2 = z[i].cross(dp_dot);
294 jd[0][i] = c1.x() + c2.x();
295 jd[1][i] = c1.y() + c2.y();
296 jd[2][i] = c1.z() + c2.z();
297 jd[3][i] = z_dot[i].x();
298 jd[4][i] = z_dot[i].y();
299 jd[5][i] = z_dot[i].z();
300 }
301 JointSpec::Prismatic { .. } => {
302 jd[0][i] = z_dot[i].x();
303 jd[1][i] = z_dot[i].y();
304 jd[2][i] = z_dot[i].z();
305 }
306 }
307 }
308 Ok(jd)
309 }
310
311 fn jacobian_ddot(
313 &self,
314 q: &SRobotQ<N, F>,
315 qdot: &SRobotQ<N, F>,
316 qddot: &SRobotQ<N, F>,
317 ) -> Result<[[F; N]; 6], Self::Error> {
318 #[cfg(debug_assertions)]
319 {
320 check_finite::<N, F>(q)?;
321 check_finite::<N, F>(qdot)?;
322 check_finite::<N, F>(qddot)?;
323 }
324 let spec = self.structure();
325 let (z, p, p_ee) = forward_pass(&spec, q);
326
327 let mut omega = AVec3::<F>::ZERO;
328 let mut omega_dot = AVec3::<F>::ZERO;
329 let mut z_dot = [AVec3::<F>::ZERO; N];
330 let mut z_ddot = [AVec3::<F>::ZERO; N];
331 let mut p_dot = [AVec3::<F>::ZERO; N];
332 let mut p_ddot = [AVec3::<F>::ZERO; N];
333 let mut pdot_acc = AVec3::<F>::ZERO;
334 let mut pddot_acc = AVec3::<F>::ZERO;
335
336 for i in 0..N {
337 p_dot[i] = pdot_acc;
338 p_ddot[i] = pddot_acc;
339 let zd = omega.cross(z[i]);
340 z_dot[i] = zd;
341 z_ddot[i] = omega_dot.cross(z[i]) + omega.cross(zd);
342 match spec.joints[i].1 {
343 JointSpec::Revolute { .. } => {
344 omega_dot += z[i] * qddot.0[i] + zd * qdot.0[i];
345 omega += z[i] * qdot.0[i];
346 }
347 JointSpec::Prismatic { .. } => {
348 pddot_acc += z[i] * qddot.0[i] + zd * qdot.0[i];
349 pdot_acc += z[i] * qdot.0[i];
350 }
351 }
352 let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
353 let delta = next_p - p[i];
354 let delta_dot = omega.cross(delta);
355 pdot_acc += delta_dot;
356 pddot_acc += omega_dot.cross(delta) + omega.cross(delta_dot);
357 }
358 let p_ee_dot = pdot_acc;
359 let p_ee_ddot = pddot_acc;
360
361 let mut jdd = [[F::zero(); N]; 6];
362 for i in 0..N {
363 match spec.joints[i].1 {
364 JointSpec::Revolute { .. } => {
365 let dp = p_ee - p[i];
366 let dp_dot = p_ee_dot - p_dot[i];
367 let dp_ddot = p_ee_ddot - p_ddot[i];
368 let c1 = z_ddot[i].cross(dp);
369 let c2 = z_dot[i].cross(dp_dot);
370 let c3 = z[i].cross(dp_ddot);
371 let two = F::one() + F::one();
372 jdd[0][i] = c1.x() + two * c2.x() + c3.x();
373 jdd[1][i] = c1.y() + two * c2.y() + c3.y();
374 jdd[2][i] = c1.z() + two * c2.z() + c3.z();
375 jdd[3][i] = z_ddot[i].x();
376 jdd[4][i] = z_ddot[i].y();
377 jdd[5][i] = z_ddot[i].z();
378 }
379 JointSpec::Prismatic { .. } => {
380 jdd[0][i] = z_ddot[i].x();
381 jdd[1][i] = z_ddot[i].y();
382 jdd[2][i] = z_ddot[i].z();
383 }
384 }
385 }
386 Ok(jdd)
387 }
388}
389
390fn gram_determinant<F: KinScalar>(mut m: [[F; 6]; 6], k: usize) -> F {
396 let mut det = F::one();
397 for col in 0..k {
398 let mut pivot = col;
399 let mut pivot_abs = m[col][col].abs();
400 for (r, row) in m.iter().enumerate().take(k).skip(col + 1) {
401 let v = row[col].abs();
402 if v > pivot_abs {
403 pivot_abs = v;
404 pivot = r;
405 }
406 }
407 if pivot_abs.partial_cmp(&F::zero()) != Some(core::cmp::Ordering::Greater) {
410 return F::zero();
411 }
412 if pivot != col {
413 m.swap(pivot, col);
414 det = -det;
415 }
416 let pivot_row = m[col];
417 let diag = pivot_row[col];
418 det = det * diag;
419 for row in m.iter_mut().take(k).skip(col + 1) {
420 let factor = row[col] / diag;
421 for (c, &pv) in pivot_row.iter().enumerate().take(k).skip(col) {
422 row[c] = row[c] - factor * pv;
423 }
424 }
425 }
426 det
427}
428
429fn forward_pass<F: KinScalar, const N: usize>(
435 spec: &KinSpec<F, N>,
436 q: &SRobotQ<N, F>,
437) -> ([AVec3<F>; N], [AVec3<F>; N], AVec3<F>) {
438 let mut z_out = [AVec3::<F>::ZERO; N];
439 let mut p_out = [AVec3::<F>::ZERO; N];
440 let mut current = spec.base_to_first;
441
442 for i in 0..N {
443 current *= spec.joints[i].0;
444 p_out[i] = current.translation();
445 match spec.joints[i].1 {
446 JointSpec::Revolute { axis_local } => {
447 let axis = axis_local.normalize();
448 z_out[i] = current.matrix3() * axis;
449 current *= AAffine3::<F>::from_axis_angle(axis, q.0[i]);
450 }
451 JointSpec::Prismatic { axis_local } => {
452 let axis = axis_local.normalize();
453 z_out[i] = current.matrix3() * axis;
454 current *= AAffine3::<F>::from_translation(axis * q.0[i]);
455 }
456 }
457 }
458
459 current *= spec.end_to_ee;
460 let p_ee = current.translation();
461 (z_out, p_out, p_ee)
462}
463
464pub type IkSolutions<const N: usize, F> = smallvec::SmallVec<[SRobotQ<N, F>; 8]>;
468
469pub enum IkOutcome<const N: usize, F: KinScalar> {
470 Solved(IkSolutions<N, F>),
471 Failed {
472 partial: Option<IkSolutions<N, F>>,
473 residual: F,
474 },
475}
476
477impl<const N: usize, F: KinScalar> IkOutcome<N, F> {
478 pub fn unwrap(self) -> IkSolutions<N, F> {
479 match self {
480 IkOutcome::Solved(solutions) => solutions,
481 _ => IkSolutions::new(),
482 }
483 }
484
485 pub fn as_result(self) -> Result<IkSolutions<N, F>, DekeError> {
486 match self {
487 IkOutcome::Solved(solutions) => Ok(solutions),
488 IkOutcome::Failed { residual, .. } => Err(DekeError::IkSolverFailed(
489 residual.to_f64().unwrap_or(f64::MAX),
490 )),
491 }
492 }
493
494 pub fn residual(&self) -> Option<F> {
495 match self {
496 IkOutcome::Solved(_) => Some(F::zero()),
497 IkOutcome::Failed { residual, .. } => Some(*residual),
498 }
499 }
500
501 pub fn is_solved(&self) -> bool {
502 matches!(self, IkOutcome::Solved(_))
503 }
504
505 pub fn is_failed(&self) -> bool {
506 matches!(self, IkOutcome::Failed { .. })
507 }
508}
509
510pub trait IkSolver<const N: usize, F: KinScalar = f32>: FKChain<N, F> {
511 type IkConfig: Default + Clone + Send + Sync + 'static;
512
513 fn ik_with_config(
514 &self,
515 target: AAffine3<F>,
516 config: &Self::IkConfig,
517 ) -> Result<IkOutcome<N, F>, Self::Error>;
518 fn ik(&self, target: AAffine3<F>) -> Result<IkOutcome<N, F>, Self::Error> {
519 self.ik_with_config(target, &Self::IkConfig::default())
520 }
521}
522
523trait ErasedFK<const N: usize, F: KinScalar>: Send + Sync {
524 fn base_tf(&self) -> AAffine3<F>;
525 fn fk(&self, q: &SRobotQ<N, F>) -> Result<[AAffine3<F>; N], DekeError>;
526 fn fk_end(&self, q: &SRobotQ<N, F>) -> Result<AAffine3<F>, DekeError>;
527 fn all_fk(&self, q: &SRobotQ<N, F>) -> Result<AllFk<N, F>, DekeError>;
528 fn clone_box(&self) -> Box<dyn ErasedFK<N, F>>;
529}
530
531impl<const N: usize, F: KinScalar, FK: FKChain<N, F> + 'static> ErasedFK<N, F> for FK {
532 fn base_tf(&self) -> AAffine3<F> {
533 FKChain::base_tf(self)
534 }
535
536 fn fk(&self, q: &SRobotQ<N, F>) -> Result<[AAffine3<F>; N], DekeError> {
537 FKChain::fk(self, q).map_err(Into::into)
538 }
539
540 fn fk_end(&self, q: &SRobotQ<N, F>) -> Result<AAffine3<F>, DekeError> {
541 FKChain::fk_end(self, q).map_err(Into::into)
542 }
543
544 fn all_fk(
545 &self,
546 q: &SRobotQ<N, F>,
547 ) -> Result<(AAffine3<F>, [AAffine3<F>; N], AAffine3<F>), DekeError> {
548 FKChain::all_fk(self, q).map_err(Into::into)
549 }
550
551 fn clone_box(&self) -> Box<dyn ErasedFK<N, F>> {
552 Box::new(self.clone())
553 }
554}
555
556pub struct BoxFK<const N: usize, F: KinScalar = f32>(Box<dyn ErasedFK<N, F>>);
557
558impl<const N: usize, F: KinScalar> BoxFK<N, F> {
559 pub fn new(fk: impl FKChain<N, F> + 'static) -> Self {
560 Self(Box::new(fk))
561 }
562}
563
564impl<const N: usize, F: KinScalar> Clone for BoxFK<N, F> {
565 fn clone(&self) -> Self {
566 Self(self.0.clone_box())
567 }
568}
569
570impl<const N: usize, F: KinScalar> FKChain<N, F> for BoxFK<N, F> {
571 type Error = DekeError;
572
573 fn base_tf(&self) -> AAffine3<F> {
574 self.0.base_tf()
575 }
576
577 fn fk(&self, q: &SRobotQ<N, F>) -> Result<[AAffine3<F>; N], DekeError> {
578 self.0.fk(q)
579 }
580
581 fn fk_end(&self, q: &SRobotQ<N, F>) -> Result<AAffine3<F>, DekeError> {
582 self.0.fk_end(q)
583 }
584
585 fn all_fk(
586 &self,
587 q: &SRobotQ<N, F>,
588 ) -> Result<(AAffine3<F>, [AAffine3<F>; N], AAffine3<F>), DekeError> {
589 self.0.all_fk(q)
590 }
591}