1use std::ops::Mul;
2
3use glam::{Affine3A, DAffine3, DMat3, DVec3, Mat3A, Vec3A};
4use glam_traits_ext::{FloatAffine, FloatMat, FloatScalar, FloatVec, TAffine3, TMat3, TVec3};
5
6use crate::{DekeError, SRobotQ};
7
/// Private module holding the marker trait that seals `KinScalar`.
///
/// Because this module is private, `Sealed` cannot be named from outside the enclosing
/// module, so no other type can satisfy the `sealed::Sealed` supertrait of `KinScalar`.
mod sealed {
    /// Marker implemented exactly for the supported scalar types.
    pub trait Sealed {}

    impl Sealed for f64 {}
    impl Sealed for f32 {}
}
13
/// Scalar precision supported by the kinematics code.
///
/// The trait is sealed: it requires the private `sealed::Sealed` marker, which is only
/// implemented for `f32` and `f64`. Each associated type selects the glam vector,
/// 3×3 matrix and affine-transform types used at that precision; the concrete choices
/// are given by the `impl KinScalar for f32` / `impl KinScalar for f64` blocks.
pub trait KinScalar:
    FloatScalar + Copy + std::fmt::Debug + Send + Sync + 'static + sealed::Sealed + crate::BatchLimits
{
    /// 3-component vector type; its `MaybeAligned` form is itself.
    type AVec3: TVec3<Self, MaybeAligned = Self::AVec3>;
    /// 3×3 matrix type whose columns are `AVec3` and which maps an `AVec3` to an `AVec3`
    /// via `*`.
    type AMat3: TMat3<Self, MaybeAligned = Self::AMat3>
        + FloatMat<Self, Col = Self::AVec3>
        + Mul<Self::AVec3, Output = Self::AVec3>;
    /// Affine transform type whose `Vec`/`Mat` parts are `AVec3`/`AMat3`; transforms
    /// compose with `*`.
    type AAffine3: TAffine3<Self, MaybeAligned = Self::AAffine3>
        + FloatAffine<Self, Vec = Self::AVec3, Mat = Self::AMat3>
        + Mul<Self::AAffine3, Output = Self::AAffine3>;
}
31
32impl KinScalar for f32 {
33 type AVec3 = Vec3A;
34 type AMat3 = Mat3A;
35 type AAffine3 = Affine3A;
36}
37
38impl KinScalar for f64 {
39 type AVec3 = DVec3;
40 type AMat3 = DMat3;
41 type AAffine3 = DAffine3;
42}
43
44#[allow(type_alias_bounds)]
45pub type AAffine3<F: KinScalar> = F::AAffine3;
46#[allow(type_alias_bounds)]
47pub type AVec3<F: KinScalar> = F::AVec3;
48
49#[allow(type_alias_bounds)]
50pub type AllFk<const N: usize, F: KinScalar> = (AAffine3<F>, [AAffine3<F>; N], AAffine3<F>);
51
52#[inline(always)]
53#[cfg(debug_assertions)]
54pub fn check_finite<const N: usize, F: FloatScalar>(q: &SRobotQ<N, F>) -> Result<(), DekeError> {
55 if q.any_non_finite() {
56 return Err(DekeError::JointsNonFinite);
57 }
58 Ok(())
59}
60
/// Release-build stand-in for the debug validation: inspects nothing and always succeeds.
///
/// The `Infallible` error type records at the type level that this variant cannot fail.
#[inline(always)]
#[cfg(not(debug_assertions))]
pub fn check_finite<const N: usize, F: FloatScalar>(
    _q: &SRobotQ<N, F>,
) -> Result<(), std::convert::Infallible> {
    Result::Ok(())
}
68
/// Motion type of a single joint, with its axis expressed in the joint's local frame.
///
/// `forward_pass` normalizes `axis_local` before use, so it need not be unit length.
/// NOTE(review): a zero-length axis would presumably normalize to NaN — confirm callers
/// never construct one.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum JointSpec<F: KinScalar> {
    /// Rotation by the joint value about `axis_local` (applied via `from_axis_angle`;
    /// presumably radians — confirm against glam).
    Revolute { axis_local: AVec3<F> },
    /// Translation by the joint value along the normalized `axis_local`.
    Prismatic { axis_local: AVec3<F> },
}
74
/// Kinematic structure of an `N`-joint serial chain, as consumed by `forward_pass`.
///
/// Composition order used by `forward_pass`: `base_to_first`, then for each joint its
/// fixed offset followed by that joint's motion, and finally `end_to_ee`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct KinSpec<F: KinScalar, const N: usize> {
    /// Fixed transform applied before the first joint.
    pub base_to_first: AAffine3<F>,
    /// Per joint: the fixed transform into that joint's frame, and the joint's motion type.
    pub joints: [(AAffine3<F>, JointSpec<F>); N],
    /// Fixed transform from the last joint's moved frame to the end effector.
    pub end_to_ee: AAffine3<F>,
}
82
83impl<F: KinScalar, const N: usize> KinSpec<F, N> {
84 pub fn new(
85 base_to_first: AAffine3<F>,
86 joints: [(AAffine3<F>, JointSpec<F>); N],
87 end_to_ee: AAffine3<F>,
88 ) -> Self {
89 Self {
90 base_to_first,
91 joints,
92 end_to_ee,
93 }
94 }
95}
96
97pub trait FKChain<const N: usize, F: KinScalar = f32>: Clone + Send + Sync {
98 type Error: Into<DekeError>;
99
100 fn dof(&self) -> usize {
101 N
102 }
103 fn base_tf(&self) -> AAffine3<F> {
106 AAffine3::<F>::IDENTITY
107 }
108
109 fn ee_tf(&self) -> AAffine3<F> {
110 AAffine3::<F>::IDENTITY
111 }
112
113 fn fk(&self, q: &SRobotQ<N, F>) -> Result<[AAffine3<F>; N], Self::Error>;
114
115 fn fk_end(&self, q: &SRobotQ<N, F>) -> Result<AAffine3<F>, Self::Error> {
117 let frames = self.fk(q)?;
118 Ok(if N > 0 {
119 frames[N - 1] * self.ee_tf()
120 } else {
121 AAffine3::<F>::IDENTITY
122 })
123 }
124
125 fn all_fk(&self, q: &SRobotQ<N, F>) -> Result<AllFk<N, F>, Self::Error> {
128 let base = self.base_tf();
129 let frames = self.fk(q)?;
130 let end = self.fk_end(q)?;
131 Ok((base, frames, end))
132 }
133}
134
135pub trait ContinuousFKChain<const N: usize, F: KinScalar = f32>:
142 FKChain<N, F, Error = DekeError>
143{
144 fn structure(&self) -> KinSpec<F, N>;
145
146 fn max_reach(&self) -> Result<F, Self::Error> {
149 let spec = self.structure();
150 let (_, p, p_ee) = forward_pass(&spec, &SRobotQ::zeros());
151 let mut total = F::zero();
152 let mut prev = p[0];
153 for &point in p.iter().take(N).skip(1) {
154 total = total + (point - prev).length();
155 prev = point;
156 }
157 total = total + (p_ee - prev).length();
158 Ok(total)
159 }
160
161 fn jacobian(&self, q: &SRobotQ<N, F>) -> Result<[[F; N]; 6], Self::Error> {
164 #[cfg(debug_assertions)]
165 check_finite::<N, F>(q)?;
166 let spec = self.structure();
167 let (z, p, p_ee) = forward_pass(&spec, q);
168 let mut j = [[F::zero(); N]; 6];
169 for i in 0..N {
170 match spec.joints[i].1 {
171 JointSpec::Revolute { .. } => {
172 let c = z[i].cross(p_ee - p[i]);
173 j[0][i] = c.x();
174 j[1][i] = c.y();
175 j[2][i] = c.z();
176 j[3][i] = z[i].x();
177 j[4][i] = z[i].y();
178 j[5][i] = z[i].z();
179 }
180 JointSpec::Prismatic { .. } => {
181 j[0][i] = z[i].x();
182 j[1][i] = z[i].y();
183 j[2][i] = z[i].z();
184 }
185 }
186 }
187 Ok(j)
188 }
189
190 fn manipulability(&self, q: &SRobotQ<N, F>) -> Result<F, Self::Error> {
200 let j = self.jacobian(q)?;
201 let k = if N >= 6 { 6 } else { N };
206 let mut g = [[F::zero(); 6]; 6];
207 if N >= 6 {
208 for (r, grow) in g.iter_mut().enumerate() {
210 for (c, gval) in grow.iter_mut().enumerate() {
211 *gval = j[r]
212 .iter()
213 .zip(j[c].iter())
214 .fold(F::zero(), |acc, (&a, &b)| acc + a * b);
215 }
216 }
217 } else {
218 for (r, grow) in g.iter_mut().enumerate().take(N) {
220 for (c, gval) in grow.iter_mut().enumerate().take(N) {
221 *gval = j
222 .iter()
223 .fold(F::zero(), |acc, jrow| acc + jrow[r] * jrow[c]);
224 }
225 }
226 }
227 Ok(gram_determinant::<F>(g, k).max(F::zero()).sqrt())
230 }
231
232 fn jacobian_dot(
234 &self,
235 q: &SRobotQ<N, F>,
236 qdot: &SRobotQ<N, F>,
237 ) -> Result<[[F; N]; 6], Self::Error> {
238 #[cfg(debug_assertions)]
239 {
240 check_finite::<N, F>(q)?;
241 check_finite::<N, F>(qdot)?;
242 }
243 let spec = self.structure();
244 let (z, p, p_ee) = forward_pass(&spec, q);
245
246 let mut omega = AVec3::<F>::ZERO;
247 let mut z_dot = [AVec3::<F>::ZERO; N];
248 let mut p_dot = [AVec3::<F>::ZERO; N];
249 let mut pdot_acc = AVec3::<F>::ZERO;
250
251 for i in 0..N {
252 p_dot[i] = pdot_acc;
253 z_dot[i] = omega.cross(z[i]);
254 match spec.joints[i].1 {
255 JointSpec::Revolute { .. } => {
256 omega += z[i] * qdot.0[i];
257 }
258 JointSpec::Prismatic { .. } => {
259 pdot_acc += z[i] * qdot.0[i];
260 }
261 }
262 let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
263 pdot_acc += omega.cross(next_p - p[i]);
264 }
265 let p_ee_dot = pdot_acc;
266
267 let mut jd = [[F::zero(); N]; 6];
268 for i in 0..N {
269 match spec.joints[i].1 {
270 JointSpec::Revolute { .. } => {
271 let dp = p_ee - p[i];
272 let dp_dot = p_ee_dot - p_dot[i];
273 let c1 = z_dot[i].cross(dp);
274 let c2 = z[i].cross(dp_dot);
275 jd[0][i] = c1.x() + c2.x();
276 jd[1][i] = c1.y() + c2.y();
277 jd[2][i] = c1.z() + c2.z();
278 jd[3][i] = z_dot[i].x();
279 jd[4][i] = z_dot[i].y();
280 jd[5][i] = z_dot[i].z();
281 }
282 JointSpec::Prismatic { .. } => {
283 jd[0][i] = z_dot[i].x();
284 jd[1][i] = z_dot[i].y();
285 jd[2][i] = z_dot[i].z();
286 }
287 }
288 }
289 Ok(jd)
290 }
291
292 fn jacobian_ddot(
294 &self,
295 q: &SRobotQ<N, F>,
296 qdot: &SRobotQ<N, F>,
297 qddot: &SRobotQ<N, F>,
298 ) -> Result<[[F; N]; 6], Self::Error> {
299 #[cfg(debug_assertions)]
300 {
301 check_finite::<N, F>(q)?;
302 check_finite::<N, F>(qdot)?;
303 check_finite::<N, F>(qddot)?;
304 }
305 let spec = self.structure();
306 let (z, p, p_ee) = forward_pass(&spec, q);
307
308 let mut omega = AVec3::<F>::ZERO;
309 let mut omega_dot = AVec3::<F>::ZERO;
310 let mut z_dot = [AVec3::<F>::ZERO; N];
311 let mut z_ddot = [AVec3::<F>::ZERO; N];
312 let mut p_dot = [AVec3::<F>::ZERO; N];
313 let mut p_ddot = [AVec3::<F>::ZERO; N];
314 let mut pdot_acc = AVec3::<F>::ZERO;
315 let mut pddot_acc = AVec3::<F>::ZERO;
316
317 for i in 0..N {
318 p_dot[i] = pdot_acc;
319 p_ddot[i] = pddot_acc;
320 let zd = omega.cross(z[i]);
321 z_dot[i] = zd;
322 z_ddot[i] = omega_dot.cross(z[i]) + omega.cross(zd);
323 match spec.joints[i].1 {
324 JointSpec::Revolute { .. } => {
325 omega_dot += z[i] * qddot.0[i] + zd * qdot.0[i];
326 omega += z[i] * qdot.0[i];
327 }
328 JointSpec::Prismatic { .. } => {
329 pddot_acc += z[i] * qddot.0[i] + zd * qdot.0[i];
330 pdot_acc += z[i] * qdot.0[i];
331 }
332 }
333 let next_p = if i + 1 < N { p[i + 1] } else { p_ee };
334 let delta = next_p - p[i];
335 let delta_dot = omega.cross(delta);
336 pdot_acc += delta_dot;
337 pddot_acc += omega_dot.cross(delta) + omega.cross(delta_dot);
338 }
339 let p_ee_dot = pdot_acc;
340 let p_ee_ddot = pddot_acc;
341
342 let mut jdd = [[F::zero(); N]; 6];
343 for i in 0..N {
344 match spec.joints[i].1 {
345 JointSpec::Revolute { .. } => {
346 let dp = p_ee - p[i];
347 let dp_dot = p_ee_dot - p_dot[i];
348 let dp_ddot = p_ee_ddot - p_ddot[i];
349 let c1 = z_ddot[i].cross(dp);
350 let c2 = z_dot[i].cross(dp_dot);
351 let c3 = z[i].cross(dp_ddot);
352 let two = F::one() + F::one();
353 jdd[0][i] = c1.x() + two * c2.x() + c3.x();
354 jdd[1][i] = c1.y() + two * c2.y() + c3.y();
355 jdd[2][i] = c1.z() + two * c2.z() + c3.z();
356 jdd[3][i] = z_ddot[i].x();
357 jdd[4][i] = z_ddot[i].y();
358 jdd[5][i] = z_ddot[i].z();
359 }
360 JointSpec::Prismatic { .. } => {
361 jdd[0][i] = z_ddot[i].x();
362 jdd[1][i] = z_ddot[i].y();
363 jdd[2][i] = z_ddot[i].z();
364 }
365 }
366 }
367 Ok(jdd)
368 }
369}
370
371fn gram_determinant<F: KinScalar>(mut m: [[F; 6]; 6], k: usize) -> F {
377 let mut det = F::one();
378 for col in 0..k {
379 let mut pivot = col;
380 let mut pivot_abs = m[col][col].abs();
381 for (r, row) in m.iter().enumerate().take(k).skip(col + 1) {
382 let v = row[col].abs();
383 if v > pivot_abs {
384 pivot_abs = v;
385 pivot = r;
386 }
387 }
388 if pivot_abs.partial_cmp(&F::zero()) != Some(core::cmp::Ordering::Greater) {
391 return F::zero();
392 }
393 if pivot != col {
394 m.swap(pivot, col);
395 det = -det;
396 }
397 let pivot_row = m[col];
398 let diag = pivot_row[col];
399 det = det * diag;
400 for row in m.iter_mut().take(k).skip(col + 1) {
401 let factor = row[col] / diag;
402 for (c, &pv) in pivot_row.iter().enumerate().take(k).skip(col) {
403 row[c] = row[c] - factor * pv;
404 }
405 }
406 }
407 det
408}
409
410fn forward_pass<F: KinScalar, const N: usize>(
416 spec: &KinSpec<F, N>,
417 q: &SRobotQ<N, F>,
418) -> ([AVec3<F>; N], [AVec3<F>; N], AVec3<F>) {
419 let mut z_out = [AVec3::<F>::ZERO; N];
420 let mut p_out = [AVec3::<F>::ZERO; N];
421 let mut current = spec.base_to_first;
422
423 for i in 0..N {
424 current *= spec.joints[i].0;
425 p_out[i] = current.translation();
426 match spec.joints[i].1 {
427 JointSpec::Revolute { axis_local } => {
428 let axis = axis_local.normalize();
429 z_out[i] = current.matrix3() * axis;
430 current *= AAffine3::<F>::from_axis_angle(axis, q.0[i]);
431 }
432 JointSpec::Prismatic { axis_local } => {
433 let axis = axis_local.normalize();
434 z_out[i] = current.matrix3() * axis;
435 current *= AAffine3::<F>::from_translation(axis * q.0[i]);
436 }
437 }
438 }
439
440 current *= spec.end_to_ee;
441 let p_ee = current.translation();
442 (z_out, p_out, p_ee)
443}
444
/// Solutions returned by an IK query: a `SmallVec` holding up to 8 configurations inline
/// before spilling to the heap.
pub type IkSolutions<const N: usize, F> = smallvec::SmallVec<[SRobotQ<N, F>; 8]>;
449
/// Outcome of an inverse-kinematics query.
pub enum IkOutcome<const N: usize, F: KinScalar> {
    /// Solving succeeded; holds the solutions found.
    Solved(IkSolutions<N, F>),
    /// Solving failed.
    Failed {
        /// Candidate configurations the solver produced anyway, if it reports any.
        partial: Option<IkSolutions<N, F>>,
        /// Residual error reported by the solver (units are solver-defined — confirm).
        residual: F,
    },
}
457
458impl<const N: usize, F: KinScalar> IkOutcome<N, F> {
459 pub fn unwrap(self) -> IkSolutions<N, F> {
460 match self {
461 IkOutcome::Solved(solutions) => solutions,
462 _ => IkSolutions::new(),
463 }
464 }
465
466 pub fn as_result(self) -> Result<IkSolutions<N, F>, DekeError> {
467 match self {
468 IkOutcome::Solved(solutions) => Ok(solutions),
469 IkOutcome::Failed { residual, .. } => Err(DekeError::IkSolverFailed(
470 residual.to_f64().unwrap_or(f64::MAX),
471 )),
472 }
473 }
474
475 pub fn residual(&self) -> Option<F> {
476 match self {
477 IkOutcome::Solved(_) => Some(F::zero()),
478 IkOutcome::Failed { residual, .. } => Some(*residual),
479 }
480 }
481
482 pub fn is_solved(&self) -> bool {
483 matches!(self, IkOutcome::Solved(_))
484 }
485
486 pub fn is_failed(&self) -> bool {
487 matches!(self, IkOutcome::Failed { .. })
488 }
489}
490
491pub trait IkSolver<const N: usize, F: KinScalar = f32>: FKChain<N, F> {
492 type IkConfig: Default + Clone + Send + Sync + 'static;
493
494 fn ik_with_config(
495 &self,
496 target: AAffine3<F>,
497 config: &Self::IkConfig,
498 ) -> Result<IkOutcome<N, F>, Self::Error>;
499 fn ik(&self, target: AAffine3<F>) -> Result<IkOutcome<N, F>, Self::Error> {
500 self.ik_with_config(target, &Self::IkConfig::default())
501 }
502}
503
504trait ErasedFK<const N: usize, F: KinScalar>: Send + Sync {
505 fn base_tf(&self) -> AAffine3<F>;
506 fn fk(&self, q: &SRobotQ<N, F>) -> Result<[AAffine3<F>; N], DekeError>;
507 fn fk_end(&self, q: &SRobotQ<N, F>) -> Result<AAffine3<F>, DekeError>;
508 fn all_fk(&self, q: &SRobotQ<N, F>) -> Result<AllFk<N, F>, DekeError>;
509 fn clone_box(&self) -> Box<dyn ErasedFK<N, F>>;
510}
511
512impl<const N: usize, F: KinScalar, FK: FKChain<N, F> + 'static> ErasedFK<N, F> for FK {
513 fn base_tf(&self) -> AAffine3<F> {
514 FKChain::base_tf(self)
515 }
516
517 fn fk(&self, q: &SRobotQ<N, F>) -> Result<[AAffine3<F>; N], DekeError> {
518 FKChain::fk(self, q).map_err(Into::into)
519 }
520
521 fn fk_end(&self, q: &SRobotQ<N, F>) -> Result<AAffine3<F>, DekeError> {
522 FKChain::fk_end(self, q).map_err(Into::into)
523 }
524
525 fn all_fk(
526 &self,
527 q: &SRobotQ<N, F>,
528 ) -> Result<(AAffine3<F>, [AAffine3<F>; N], AAffine3<F>), DekeError> {
529 FKChain::all_fk(self, q).map_err(Into::into)
530 }
531
532 fn clone_box(&self) -> Box<dyn ErasedFK<N, F>> {
533 Box::new(self.clone())
534 }
535}
536
537pub struct BoxFK<const N: usize, F: KinScalar = f32>(Box<dyn ErasedFK<N, F>>);
538
539impl<const N: usize, F: KinScalar> BoxFK<N, F> {
540 pub fn new(fk: impl FKChain<N, F> + 'static) -> Self {
541 Self(Box::new(fk))
542 }
543}
544
545impl<const N: usize, F: KinScalar> Clone for BoxFK<N, F> {
546 fn clone(&self) -> Self {
547 Self(self.0.clone_box())
548 }
549}
550
551impl<const N: usize, F: KinScalar> FKChain<N, F> for BoxFK<N, F> {
552 type Error = DekeError;
553
554 fn base_tf(&self) -> AAffine3<F> {
555 self.0.base_tf()
556 }
557
558 fn fk(&self, q: &SRobotQ<N, F>) -> Result<[AAffine3<F>; N], DekeError> {
559 self.0.fk(q)
560 }
561
562 fn fk_end(&self, q: &SRobotQ<N, F>) -> Result<AAffine3<F>, DekeError> {
563 self.0.fk_end(q)
564 }
565
566 fn all_fk(
567 &self,
568 q: &SRobotQ<N, F>,
569 ) -> Result<(AAffine3<F>, [AAffine3<F>; N], AAffine3<F>), DekeError> {
570 self.0.all_fk(q)
571 }
572}