Skip to main content

deke_types/
q.rs

1use ndarray::Array1;
2use wide::{CmpGt, CmpLt, f32x8};
3
4use crate::DekeError;
5
6#[inline(always)]
7fn simd_load(slice: &[f32], off: usize) -> f32x8 {
8    let n = 8.min(slice.len().saturating_sub(off));
9    let mut buf = [0.0; 8];
10    buf[..n].copy_from_slice(&slice[off..off + n]);
11    f32x8::new(buf)
12}
13
14#[inline(always)]
15fn simd_store(v: f32x8, dst: &mut [f32], off: usize) {
16    let n = 8.min(dst.len().saturating_sub(off));
17    dst[off..off + n].copy_from_slice(&v.to_array()[..n]);
18}
19
20#[inline(always)]
21fn simd_binop<const N: usize>(
22    a: &[f32; N],
23    b: &[f32; N],
24    out: &mut [f32; N],
25    op: fn(f32x8, f32x8) -> f32x8,
26) {
27    let mut off = 0;
28    while off < N {
29        simd_store(op(simd_load(a, off), simd_load(b, off)), out, off);
30        off += 8;
31    }
32}
33
34#[inline(always)]
35fn simd_unaryop<const N: usize>(a: &[f32; N], out: &mut [f32; N], op: fn(f32x8) -> f32x8) {
36    let mut off = 0;
37    while off < N {
38        simd_store(op(simd_load(a, off)), out, off);
39        off += 8;
40    }
41}
42
43#[inline(always)]
44fn simd_scalarop<const N: usize>(
45    a: &[f32; N],
46    s: f32x8,
47    out: &mut [f32; N],
48    op: fn(f32x8, f32x8) -> f32x8,
49) {
50    let mut off = 0;
51    while off < N {
52        simd_store(op(simd_load(a, off), s), out, off);
53        off += 8;
54    }
55}
56
57#[inline(always)]
58fn simd_hsum<const N: usize>(a: &[f32; N]) -> f32 {
59    let mut acc = f32x8::ZERO;
60    let mut off = 0;
61    while off < N {
62        acc += simd_load(a, off);
63        off += 8;
64    }
65    acc.reduce_add()
66}
67
68#[inline(always)]
69fn simd_load_neg_inf(slice: &[f32], off: usize) -> f32x8 {
70    let n = 8.min(slice.len().saturating_sub(off));
71    let mut buf = [f32::NEG_INFINITY; 8];
72    buf[..n].copy_from_slice(&slice[off..off + n]);
73    f32x8::new(buf)
74}
75
76#[inline(always)]
77fn simd_load_inf(slice: &[f32], off: usize) -> f32x8 {
78    let n = 8.min(slice.len().saturating_sub(off));
79    let mut buf = [f32::INFINITY; 8];
80    buf[..n].copy_from_slice(&slice[off..off + n]);
81    f32x8::new(buf)
82}
83
84#[inline(always)]
85fn simd_dot<const N: usize>(a: &[f32; N], b: &[f32; N]) -> f32 {
86    let mut acc = f32x8::ZERO;
87    let mut off = 0;
88    while off < N {
89        acc = simd_load(a, off).mul_add(simd_load(b, off), acc);
90        off += 8;
91    }
92    acc.reduce_add()
93}
94
/// Dynamically-sized joint configuration backed by `ndarray::Array1<f32>`.
pub type RobotQ = Array1<f32>;
96
/// Statically-sized joint configuration backed by `[f32; N]`.
///
/// `N` is the element count, fixed at compile time. The wrapped array is a public
/// field, and the type is `Copy`, so values can be passed around by value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SRobotQ<const N: usize>(pub [f32; N]);
100
101impl<const N: usize> SRobotQ<N> {
102    pub const fn zeros() -> Self {
103        Self([0.0; N])
104    }
105
106    pub const fn from_array(arr: [f32; N]) -> Self {
107        Self(arr)
108    }
109
110    pub const fn as_slice(&self) -> &[f32] {
111        &self.0
112    }
113
114    pub const fn as_mut_slice(&mut self) -> &mut [f32] {
115        &mut self.0
116    }
117
118    pub fn to_robotq(&self) -> RobotQ {
119        RobotQ::from(self.0.to_vec())
120    }
121
122    pub fn force_from_robotq(q: &RobotQ) -> Self {
123        if let Ok(sq) = Self::try_from(q) {
124            sq
125        } else {
126            let slice = q.as_slice().unwrap_or(&[]);
127            let mut arr = [0.0; N];
128            for i in 0..N {
129                arr[i] = *slice.get(i).unwrap_or(&0.0);
130            }
131            Self(arr)
132        }
133    }
134
135    pub fn norm(&self) -> f32 {
136        if N <= 16 {
137            self.dot(self).sqrt()
138        } else {
139            self.0.iter().map(|x| x * x).sum::<f32>().sqrt()
140        }
141    }
142
143    pub fn dot(&self, other: &Self) -> f32 {
144        if N <= 16 {
145            simd_dot(&self.0, &other.0)
146        } else {
147            self.0.iter().zip(other.0.iter()).map(|(a, b)| a * b).sum()
148        }
149    }
150
151    pub fn map(&self, f: impl Fn(f32) -> f32) -> Self {
152        let mut out = [0.0; N];
153        for i in 0..N {
154            out[i] = f(self.0[i]);
155        }
156        Self(out)
157    }
158
159    pub fn sum(&self) -> f32 {
160        if N <= 16 {
161            simd_hsum(&self.0)
162        } else {
163            self.0.iter().sum()
164        }
165    }
166
167    pub fn splat(val: f32) -> Self {
168        Self([val; N])
169    }
170
171    pub fn from_fn(f: impl Fn(usize) -> f32) -> Self {
172        let mut out = [0.0; N];
173        for i in 0..N {
174            out[i] = f(i);
175        }
176        Self(out)
177    }
178
179    pub fn norm_squared(&self) -> f32 {
180        self.dot(self)
181    }
182
183    pub fn normalize(&self) -> Self {
184        let n = self.norm();
185        debug_assert!(n > 0.0, "cannot normalize zero-length SRobotQ");
186        *self / n
187    }
188
189    pub fn distance(&self, other: &Self) -> f32 {
190        (*self - *other).norm()
191    }
192
193    pub fn distance_squared(&self, other: &Self) -> f32 {
194        (*self - *other).norm_squared()
195    }
196
197    pub fn abs(&self) -> Self {
198        if N <= 16 {
199            let mut out = [0.0; N];
200            simd_unaryop(&self.0, &mut out, |a| a.abs());
201            Self(out)
202        } else {
203            self.map(f32::abs)
204        }
205    }
206
207    pub fn clamp(&self, min: &Self, max: &Self) -> Self {
208        if N <= 16 {
209            let mut out = [0.0; N];
210            let mut off = 0;
211            while off < N {
212                let v = simd_load(&self.0, off);
213                let lo = simd_load(&min.0, off);
214                let hi = simd_load(&max.0, off);
215                simd_store(v.fast_max(lo).fast_min(hi), &mut out, off);
216                off += 8;
217            }
218            Self(out)
219        } else {
220            let mut out = [0.0; N];
221            for i in 0..N {
222                out[i] = self.0[i].clamp(min.0[i], max.0[i]);
223            }
224            Self(out)
225        }
226    }
227
228    pub fn clamp_scalar(&self, min: f32, max: f32) -> Self {
229        if N <= 16 {
230            let mut out = [0.0; N];
231            let lo = f32x8::splat(min);
232            let hi = f32x8::splat(max);
233            let mut off = 0;
234            while off < N {
235                let v = simd_load(&self.0, off);
236                simd_store(v.fast_max(lo).fast_min(hi), &mut out, off);
237                off += 8;
238            }
239            Self(out)
240        } else {
241            self.map(|x| x.clamp(min, max))
242        }
243    }
244
245    pub fn max_element(&self) -> f32 {
246        if N <= 16 {
247            let mut acc = f32x8::splat(f32::NEG_INFINITY);
248            let mut off = 0;
249            while off < N {
250                acc = acc.fast_max(simd_load_neg_inf(&self.0, off));
251                off += 8;
252            }
253            let a = acc.to_array();
254            a[0].max(a[1])
255                .max(a[2].max(a[3]))
256                .max(a[4].max(a[5]).max(a[6].max(a[7])))
257        } else {
258            self.0.iter().copied().fold(f32::NEG_INFINITY, f32::max)
259        }
260    }
261
262    pub fn min_element(&self) -> f32 {
263        if N <= 16 {
264            let mut acc = f32x8::splat(f32::INFINITY);
265            let mut off = 0;
266            while off < N {
267                acc = acc.fast_min(simd_load_inf(&self.0, off));
268                off += 8;
269            }
270            let a = acc.to_array();
271            a[0].min(a[1])
272                .min(a[2].min(a[3]))
273                .min(a[4].min(a[5]).min(a[6].min(a[7])))
274        } else {
275            self.0.iter().copied().fold(f32::INFINITY, f32::min)
276        }
277    }
278
279    pub fn linf_norm(&self) -> f32 {
280        self.abs().max_element()
281    }
282
283    pub fn elementwise_mul(&self, other: &Self) -> Self {
284        let mut out = [0.0; N];
285        if N <= 16 {
286            simd_binop(&self.0, &other.0, &mut out, |a, b| a * b);
287        } else {
288            for i in 0..N {
289                out[i] = self.0[i] * other.0[i];
290            }
291        }
292        Self(out)
293    }
294
295    pub fn elementwise_div(&self, other: &Self) -> Self {
296        let mut out = [0.0; N];
297        if N <= 16 {
298            simd_binop(&self.0, &other.0, &mut out, |a, b| a / b);
299        } else {
300            for i in 0..N {
301                out[i] = self.0[i] / other.0[i];
302            }
303        }
304        Self(out)
305    }
306
307    pub fn zip_map(&self, other: &Self, f: impl Fn(f32, f32) -> f32) -> Self {
308        let mut out = [0.0; N];
309        for i in 0..N {
310            out[i] = f(self.0[i], other.0[i]);
311        }
312        Self(out)
313    }
314
315    pub fn sqrt(&self) -> Self {
316        if N <= 16 {
317            let mut out = [0.0; N];
318            simd_unaryop(&self.0, &mut out, |a| a.sqrt());
319            Self(out)
320        } else {
321            self.map(f32::sqrt)
322        }
323    }
324
325    pub fn mul_add(&self, mul: &Self, add: &Self) -> Self {
326        if N <= 16 {
327            let mut out = [0.0; N];
328            let mut off = 0;
329            while off < N {
330                let a = simd_load(&self.0, off);
331                let m = simd_load(&mul.0, off);
332                let d = simd_load(&add.0, off);
333                simd_store(a.mul_add(m, d), &mut out, off);
334                off += 8;
335            }
336            Self(out)
337        } else {
338            let mut out = [0.0; N];
339            for i in 0..N {
340                out[i] = self.0[i].mul_add(mul.0[i], add.0[i]);
341            }
342            Self(out)
343        }
344    }
345
346    /// Returns `true` if any element of `self` is greater than the corresponding element of `other`.
347    pub fn any_non_finite(&self) -> bool {
348        let mut off = 0;
349        while off < N {
350            let v = simd_load(&self.0, off);
351            let bad = v.is_nan() | v.is_inf();
352            if (bad.to_bitmask() & Self::lane_mask(off)) != 0 {
353                return true;
354            }
355            off += 8;
356        }
357        false
358    }
359
360    pub fn any_gt(&self, other: &Self) -> bool {
361        let mut off = 0;
362        while off < N {
363            let a = simd_load(&self.0, off);
364            let b = simd_load(&other.0, off);
365            if (a.simd_gt(b).to_bitmask() & Self::lane_mask(off)) != 0 {
366                return true;
367            }
368            off += 8;
369        }
370        false
371    }
372
373    /// Returns `true` if any element of `self` is less than the corresponding element of `other`.
374    pub fn any_lt(&self, other: &Self) -> bool {
375        let mut off = 0;
376        while off < N {
377            let a = simd_load(&self.0, off);
378            let b = simd_load(&other.0, off);
379            if (a.simd_lt(b).to_bitmask() & Self::lane_mask(off)) != 0 {
380                return true;
381            }
382            off += 8;
383        }
384        false
385    }
386
387    #[inline(always)]
388    const fn lane_mask(off: usize) -> u32 {
389        let active = N.saturating_sub(off);
390        if active >= 8 {
391            0b11111111
392        } else {
393            (1 << active) - 1
394        }
395    }
396
397    pub fn is_close(&self, other: &Self, tol: f32) -> bool {
398        let diff = *self - *other;
399        diff.dot(&diff).sqrt() < tol
400    }
401
402    pub fn interpolate(&self, other: &Self, t: f32) -> Self {
403        *self + ((*other - *self) * t)
404    }
405}
406
407impl<const N: usize> std::ops::Index<usize> for SRobotQ<N> {
408    type Output = f32;
409    #[inline]
410    fn index(&self, i: usize) -> &f32 {
411        &self.0[i]
412    }
413}
414
415impl<const N: usize> std::ops::IndexMut<usize> for SRobotQ<N> {
416    #[inline]
417    fn index_mut(&mut self, i: usize) -> &mut f32 {
418        &mut self.0[i]
419    }
420}
421
422impl<const N: usize> std::ops::Add for SRobotQ<N> {
423    type Output = Self;
424    #[inline]
425    fn add(self, rhs: Self) -> Self {
426        let mut out = [0.0; N];
427        if N <= 16 {
428            simd_binop(&self.0, &rhs.0, &mut out, |a, b| a + b);
429        } else {
430            for i in 0..N {
431                out[i] = self.0[i] + rhs.0[i];
432            }
433        }
434        Self(out)
435    }
436}
437
438impl<const N: usize> std::ops::Sub for SRobotQ<N> {
439    type Output = Self;
440    #[inline]
441    fn sub(self, rhs: Self) -> Self {
442        let mut out = [0.0; N];
443        if N <= 16 {
444            simd_binop(&self.0, &rhs.0, &mut out, |a, b| a - b);
445        } else {
446            for i in 0..N {
447                out[i] = self.0[i] - rhs.0[i];
448            }
449        }
450        Self(out)
451    }
452}
453
454impl<const N: usize> std::ops::Neg for SRobotQ<N> {
455    type Output = Self;
456    #[inline]
457    fn neg(self) -> Self {
458        let mut out = [0.0; N];
459        if N <= 16 {
460            simd_unaryop(&self.0, &mut out, |a| f32x8::ZERO - a);
461        } else {
462            for i in 0..N {
463                out[i] = -self.0[i];
464            }
465        }
466        Self(out)
467    }
468}
469
470impl<const N: usize> std::ops::Mul<f32> for SRobotQ<N> {
471    type Output = Self;
472    #[inline]
473    fn mul(self, rhs: f32) -> Self {
474        let mut out = [0.0; N];
475        if N <= 16 {
476            simd_scalarop(&self.0, f32x8::splat(rhs), &mut out, |a, s| a * s);
477        } else {
478            for i in 0..N {
479                out[i] = self.0[i] * rhs;
480            }
481        }
482        Self(out)
483    }
484}
485
486impl<const N: usize> std::ops::Mul<SRobotQ<N>> for f32 {
487    type Output = SRobotQ<N>;
488    #[inline]
489    fn mul(self, rhs: SRobotQ<N>) -> SRobotQ<N> {
490        rhs * self
491    }
492}
493
494impl<const N: usize> std::ops::Div<f32> for SRobotQ<N> {
495    type Output = Self;
496    #[inline]
497    fn div(self, rhs: f32) -> Self {
498        let mut out = [0.0; N];
499        if N <= 16 {
500            simd_scalarop(&self.0, f32x8::splat(rhs), &mut out, |a, s| a / s);
501        } else {
502            for i in 0..N {
503                out[i] = self.0[i] / rhs;
504            }
505        }
506        Self(out)
507    }
508}
509
510impl<const N: usize> std::ops::AddAssign for SRobotQ<N> {
511    #[inline]
512    fn add_assign(&mut self, rhs: Self) {
513        if N <= 16 {
514            let mut out = [0.0; N];
515            simd_binop(&self.0, &rhs.0, &mut out, |a, b| a + b);
516            self.0 = out;
517        } else {
518            for i in 0..N {
519                self.0[i] += rhs.0[i];
520            }
521        }
522    }
523}
524
525impl<const N: usize> std::ops::SubAssign for SRobotQ<N> {
526    #[inline]
527    fn sub_assign(&mut self, rhs: Self) {
528        if N <= 16 {
529            let mut out = [0.0; N];
530            simd_binop(&self.0, &rhs.0, &mut out, |a, b| a - b);
531            self.0 = out;
532        } else {
533            for i in 0..N {
534                self.0[i] -= rhs.0[i];
535            }
536        }
537    }
538}
539
540impl<const N: usize> std::ops::MulAssign<f32> for SRobotQ<N> {
541    #[inline]
542    fn mul_assign(&mut self, rhs: f32) {
543        if N <= 16 {
544            let mut out = [0.0; N];
545            simd_scalarop(&self.0, f32x8::splat(rhs), &mut out, |a, s| a * s);
546            self.0 = out;
547        } else {
548            for i in 0..N {
549                self.0[i] *= rhs;
550            }
551        }
552    }
553}
554
555impl<const N: usize> std::ops::DivAssign<f32> for SRobotQ<N> {
556    #[inline]
557    fn div_assign(&mut self, rhs: f32) {
558        if N <= 16 {
559            let mut out = [0.0; N];
560            simd_scalarop(&self.0, f32x8::splat(rhs), &mut out, |a, s| a / s);
561            self.0 = out;
562        } else {
563            for i in 0..N {
564                self.0[i] /= rhs;
565            }
566        }
567    }
568}
569
570impl<const N: usize> std::ops::Add<SRobotQ<N>> for &RobotQ {
571    type Output = SRobotQ<N>;
572    #[inline]
573    fn add(self, rhs: SRobotQ<N>) -> SRobotQ<N> {
574        SRobotQ::<N>::force_from_robotq(self) + rhs
575    }
576}
577
578impl<const N: usize> std::ops::Sub<SRobotQ<N>> for &RobotQ {
579    type Output = SRobotQ<N>;
580    #[inline]
581    fn sub(self, rhs: SRobotQ<N>) -> SRobotQ<N> {
582        SRobotQ::<N>::force_from_robotq(self) - rhs
583    }
584}
585
586impl<const N: usize> Default for SRobotQ<N> {
587    #[inline]
588    fn default() -> Self {
589        Self::zeros()
590    }
591}
592
593impl<const N: usize> AsRef<[f32; N]> for SRobotQ<N> {
594    #[inline]
595    fn as_ref(&self) -> &[f32; N] {
596        &self.0
597    }
598}
599
600impl<const N: usize> AsMut<[f32; N]> for SRobotQ<N> {
601    #[inline]
602    fn as_mut(&mut self) -> &mut [f32; N] {
603        &mut self.0
604    }
605}
606
607impl<const N: usize> AsRef<[f32]> for SRobotQ<N> {
608    #[inline]
609    fn as_ref(&self) -> &[f32] {
610        &self.0
611    }
612}
613
614impl<const N: usize> AsMut<[f32]> for SRobotQ<N> {
615    #[inline]
616    fn as_mut(&mut self) -> &mut [f32] {
617        &mut self.0
618    }
619}
620
621impl<const N: usize> From<[f32; N]> for SRobotQ<N> {
622    #[inline]
623    fn from(arr: [f32; N]) -> Self {
624        Self(arr)
625    }
626}
627
628impl<const N: usize> From<&[f32; N]> for SRobotQ<N> {
629    #[inline]
630    fn from(arr: &[f32; N]) -> Self {
631        Self(*arr)
632    }
633}
634
635impl<const N: usize> From<[f64; N]> for SRobotQ<N> {
636    #[inline]
637    fn from(arr: [f64; N]) -> Self {
638        let mut out = [0.0f32; N];
639        let mut i = 0;
640        while i < N {
641            out[i] = arr[i] as f32;
642            i += 1;
643        }
644        Self(out)
645    }
646}
647
648impl<const N: usize> From<&[f64; N]> for SRobotQ<N> {
649    #[inline]
650    fn from(arr: &[f64; N]) -> Self {
651        Self::from(*arr)
652    }
653}
654
655impl<const N: usize> From<SRobotQ<N>> for [f32; N] {
656    #[inline]
657    fn from(q: SRobotQ<N>) -> [f32; N] {
658        q.0
659    }
660}
661
662impl<const N: usize> From<SRobotQ<N>> for Vec<f32> {
663    #[inline]
664    fn from(q: SRobotQ<N>) -> Vec<f32> {
665        q.0.to_vec()
666    }
667}
668
669impl<const N: usize> From<SRobotQ<N>> for RobotQ {
670    #[inline]
671    fn from(q: SRobotQ<N>) -> RobotQ {
672        q.to_robotq()
673    }
674}
675
676impl<const N: usize> TryFrom<&SRobotQ<N>> for SRobotQ<N> {
677    type Error = DekeError;
678
679    #[inline]
680    fn try_from(q: &SRobotQ<N>) -> Result<Self, Self::Error> {
681        Ok(*q)
682    }
683}
684
685impl<const N: usize> TryFrom<&[f32]> for SRobotQ<N> {
686    type Error = DekeError;
687
688    #[inline]
689    fn try_from(slice: &[f32]) -> Result<Self, Self::Error> {
690        if slice.len() != N {
691            return Err(DekeError::ShapeMismatch {
692                expected: N,
693                found: slice.len(),
694            });
695        }
696        let mut arr = [0.0; N];
697        arr.copy_from_slice(slice);
698        Ok(Self(arr))
699    }
700}
701
702impl<const N: usize> TryFrom<Vec<f32>> for SRobotQ<N> {
703    type Error = DekeError;
704
705    #[inline]
706    fn try_from(v: Vec<f32>) -> Result<Self, Self::Error> {
707        Self::try_from(v.as_slice())
708    }
709}
710
711impl<const N: usize> TryFrom<&Vec<f32>> for SRobotQ<N> {
712    type Error = DekeError;
713
714    #[inline]
715    fn try_from(v: &Vec<f32>) -> Result<Self, Self::Error> {
716        Self::try_from(v.as_slice())
717    }
718}
719
720impl<const N: usize> TryFrom<&[f64]> for SRobotQ<N> {
721    type Error = DekeError;
722
723    #[inline]
724    fn try_from(slice: &[f64]) -> Result<Self, Self::Error> {
725        if slice.len() != N {
726            return Err(DekeError::ShapeMismatch {
727                expected: N,
728                found: slice.len(),
729            });
730        }
731        let mut arr = [0.0f32; N];
732        let mut i = 0;
733        while i < N {
734            arr[i] = slice[i] as f32;
735            i += 1;
736        }
737        Ok(Self(arr))
738    }
739}
740
741impl<const N: usize> TryFrom<Vec<f64>> for SRobotQ<N> {
742    type Error = DekeError;
743
744    #[inline]
745    fn try_from(v: Vec<f64>) -> Result<Self, Self::Error> {
746        Self::try_from(v.as_slice())
747    }
748}
749
750impl<const N: usize> TryFrom<&Vec<f64>> for SRobotQ<N> {
751    type Error = DekeError;
752
753    #[inline]
754    fn try_from(v: &Vec<f64>) -> Result<Self, Self::Error> {
755        Self::try_from(v.as_slice())
756    }
757}
758
759impl<const N: usize> TryFrom<&RobotQ> for SRobotQ<N> {
760    type Error = DekeError;
761
762    #[inline]
763    fn try_from(q: &RobotQ) -> Result<Self, Self::Error> {
764        let slice = q.as_slice().unwrap_or(&[]);
765        if slice.len() != N {
766            return Err(DekeError::ShapeMismatch {
767                expected: N,
768                found: slice.len(),
769            });
770        }
771        let mut arr = [0.0; N];
772        arr.copy_from_slice(slice);
773        Ok(Self(arr))
774    }
775}
776
777impl<const N: usize> TryFrom<RobotQ> for SRobotQ<N> {
778    type Error = DekeError;
779
780    #[inline]
781    fn try_from(q: RobotQ) -> Result<Self, Self::Error> {
782        let slice = q.as_slice().unwrap_or(&[]);
783        if slice.len() != N {
784            return Err(DekeError::ShapeMismatch {
785                expected: N,
786                found: slice.len(),
787            });
788        }
789        let mut arr = [0.0; N];
790        arr.copy_from_slice(slice);
791        Ok(Self(arr))
792    }
793}