1use ndarray::Array1;
2use wide::{CmpGt, CmpLt, f32x8};
3
4use crate::DekeError;
5
6#[inline(always)]
7fn simd_load(slice: &[f32], off: usize) -> f32x8 {
8 let n = 8.min(slice.len().saturating_sub(off));
9 let mut buf = [0.0; 8];
10 buf[..n].copy_from_slice(&slice[off..off + n]);
11 f32x8::new(buf)
12}
13
14#[inline(always)]
15fn simd_store(v: f32x8, dst: &mut [f32], off: usize) {
16 let n = 8.min(dst.len().saturating_sub(off));
17 dst[off..off + n].copy_from_slice(&v.to_array()[..n]);
18}
19
20#[inline(always)]
21fn simd_binop<const N: usize>(
22 a: &[f32; N],
23 b: &[f32; N],
24 out: &mut [f32; N],
25 op: fn(f32x8, f32x8) -> f32x8,
26) {
27 let mut off = 0;
28 while off < N {
29 simd_store(op(simd_load(a, off), simd_load(b, off)), out, off);
30 off += 8;
31 }
32}
33
34#[inline(always)]
35fn simd_unaryop<const N: usize>(a: &[f32; N], out: &mut [f32; N], op: fn(f32x8) -> f32x8) {
36 let mut off = 0;
37 while off < N {
38 simd_store(op(simd_load(a, off)), out, off);
39 off += 8;
40 }
41}
42
43#[inline(always)]
44fn simd_scalarop<const N: usize>(
45 a: &[f32; N],
46 s: f32x8,
47 out: &mut [f32; N],
48 op: fn(f32x8, f32x8) -> f32x8,
49) {
50 let mut off = 0;
51 while off < N {
52 simd_store(op(simd_load(a, off), s), out, off);
53 off += 8;
54 }
55}
56
57#[inline(always)]
58fn simd_hsum<const N: usize>(a: &[f32; N]) -> f32 {
59 let mut acc = f32x8::ZERO;
60 let mut off = 0;
61 while off < N {
62 acc += simd_load(a, off);
63 off += 8;
64 }
65 acc.reduce_add()
66}
67
68#[inline(always)]
69fn simd_load_neg_inf(slice: &[f32], off: usize) -> f32x8 {
70 let n = 8.min(slice.len().saturating_sub(off));
71 let mut buf = [f32::NEG_INFINITY; 8];
72 buf[..n].copy_from_slice(&slice[off..off + n]);
73 f32x8::new(buf)
74}
75
76#[inline(always)]
77fn simd_load_inf(slice: &[f32], off: usize) -> f32x8 {
78 let n = 8.min(slice.len().saturating_sub(off));
79 let mut buf = [f32::INFINITY; 8];
80 buf[..n].copy_from_slice(&slice[off..off + n]);
81 f32x8::new(buf)
82}
83
84#[inline(always)]
85fn simd_dot<const N: usize>(a: &[f32; N], b: &[f32; N]) -> f32 {
86 let mut acc = f32x8::ZERO;
87 let mut off = 0;
88 while off < N {
89 acc = simd_load(a, off).mul_add(simd_load(b, off), acc);
90 off += 8;
91 }
92 acc.reduce_add()
93}
94
/// Dynamically sized one-dimensional `f32` array (`ndarray::Array1<f32>`).
pub type RobotQ = Array1<f32>;
96
/// Fixed-size vector of `N` `f32` values stored inline in an array.
///
/// NOTE(review): the name suggests a robot joint configuration — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SRobotQ<const N: usize>(pub [f32; N]);
100
101impl<const N: usize> SRobotQ<N> {
102 pub const fn zeros() -> Self {
103 Self([0.0; N])
104 }
105
106 pub const fn from_array(arr: [f32; N]) -> Self {
107 Self(arr)
108 }
109
110 pub const fn as_slice(&self) -> &[f32] {
111 &self.0
112 }
113
114 pub const fn as_mut_slice(&mut self) -> &mut [f32] {
115 &mut self.0
116 }
117
118 pub fn to_robotq(&self) -> RobotQ {
119 RobotQ::from(self.0.to_vec())
120 }
121
122 pub fn force_from_robotq(q: &RobotQ) -> Self {
123 if let Ok(sq) = Self::try_from(q) {
124 sq
125 } else {
126 let slice = q.as_slice().unwrap_or(&[]);
127 let mut arr = [0.0; N];
128 for i in 0..N {
129 arr[i] = *slice.get(i).unwrap_or(&0.0);
130 }
131 Self(arr)
132 }
133 }
134
135 pub fn norm(&self) -> f32 {
136 if N <= 16 {
137 self.dot(self).sqrt()
138 } else {
139 self.0.iter().map(|x| x * x).sum::<f32>().sqrt()
140 }
141 }
142
143 pub fn dot(&self, other: &Self) -> f32 {
144 if N <= 16 {
145 simd_dot(&self.0, &other.0)
146 } else {
147 self.0.iter().zip(other.0.iter()).map(|(a, b)| a * b).sum()
148 }
149 }
150
151 pub fn map(&self, f: impl Fn(f32) -> f32) -> Self {
152 let mut out = [0.0; N];
153 for i in 0..N {
154 out[i] = f(self.0[i]);
155 }
156 Self(out)
157 }
158
159 pub fn sum(&self) -> f32 {
160 if N <= 16 {
161 simd_hsum(&self.0)
162 } else {
163 self.0.iter().sum()
164 }
165 }
166
167 pub fn splat(val: f32) -> Self {
168 Self([val; N])
169 }
170
171 pub fn from_fn(f: impl Fn(usize) -> f32) -> Self {
172 let mut out = [0.0; N];
173 for i in 0..N {
174 out[i] = f(i);
175 }
176 Self(out)
177 }
178
179 pub fn norm_squared(&self) -> f32 {
180 self.dot(self)
181 }
182
183 pub fn normalize(&self) -> Self {
184 let n = self.norm();
185 debug_assert!(n > 0.0, "cannot normalize zero-length SRobotQ");
186 *self / n
187 }
188
189 pub fn distance(&self, other: &Self) -> f32 {
190 (*self - *other).norm()
191 }
192
193 pub fn distance_squared(&self, other: &Self) -> f32 {
194 (*self - *other).norm_squared()
195 }
196
197 pub fn abs(&self) -> Self {
198 if N <= 16 {
199 let mut out = [0.0; N];
200 simd_unaryop(&self.0, &mut out, |a| a.abs());
201 Self(out)
202 } else {
203 self.map(f32::abs)
204 }
205 }
206
207 pub fn clamp(&self, min: &Self, max: &Self) -> Self {
208 if N <= 16 {
209 let mut out = [0.0; N];
210 let mut off = 0;
211 while off < N {
212 let v = simd_load(&self.0, off);
213 let lo = simd_load(&min.0, off);
214 let hi = simd_load(&max.0, off);
215 simd_store(v.fast_max(lo).fast_min(hi), &mut out, off);
216 off += 8;
217 }
218 Self(out)
219 } else {
220 let mut out = [0.0; N];
221 for i in 0..N {
222 out[i] = self.0[i].clamp(min.0[i], max.0[i]);
223 }
224 Self(out)
225 }
226 }
227
228 pub fn clamp_scalar(&self, min: f32, max: f32) -> Self {
229 if N <= 16 {
230 let mut out = [0.0; N];
231 let lo = f32x8::splat(min);
232 let hi = f32x8::splat(max);
233 let mut off = 0;
234 while off < N {
235 let v = simd_load(&self.0, off);
236 simd_store(v.fast_max(lo).fast_min(hi), &mut out, off);
237 off += 8;
238 }
239 Self(out)
240 } else {
241 self.map(|x| x.clamp(min, max))
242 }
243 }
244
245 pub fn max_element(&self) -> f32 {
246 if N <= 16 {
247 let mut acc = f32x8::splat(f32::NEG_INFINITY);
248 let mut off = 0;
249 while off < N {
250 acc = acc.fast_max(simd_load_neg_inf(&self.0, off));
251 off += 8;
252 }
253 let a = acc.to_array();
254 a[0].max(a[1])
255 .max(a[2].max(a[3]))
256 .max(a[4].max(a[5]).max(a[6].max(a[7])))
257 } else {
258 self.0.iter().copied().fold(f32::NEG_INFINITY, f32::max)
259 }
260 }
261
262 pub fn min_element(&self) -> f32 {
263 if N <= 16 {
264 let mut acc = f32x8::splat(f32::INFINITY);
265 let mut off = 0;
266 while off < N {
267 acc = acc.fast_min(simd_load_inf(&self.0, off));
268 off += 8;
269 }
270 let a = acc.to_array();
271 a[0].min(a[1])
272 .min(a[2].min(a[3]))
273 .min(a[4].min(a[5]).min(a[6].min(a[7])))
274 } else {
275 self.0.iter().copied().fold(f32::INFINITY, f32::min)
276 }
277 }
278
279 pub fn linf_norm(&self) -> f32 {
280 self.abs().max_element()
281 }
282
283 pub fn elementwise_mul(&self, other: &Self) -> Self {
284 let mut out = [0.0; N];
285 if N <= 16 {
286 simd_binop(&self.0, &other.0, &mut out, |a, b| a * b);
287 } else {
288 for i in 0..N {
289 out[i] = self.0[i] * other.0[i];
290 }
291 }
292 Self(out)
293 }
294
295 pub fn elementwise_div(&self, other: &Self) -> Self {
296 let mut out = [0.0; N];
297 if N <= 16 {
298 simd_binop(&self.0, &other.0, &mut out, |a, b| a / b);
299 } else {
300 for i in 0..N {
301 out[i] = self.0[i] / other.0[i];
302 }
303 }
304 Self(out)
305 }
306
307 pub fn zip_map(&self, other: &Self, f: impl Fn(f32, f32) -> f32) -> Self {
308 let mut out = [0.0; N];
309 for i in 0..N {
310 out[i] = f(self.0[i], other.0[i]);
311 }
312 Self(out)
313 }
314
315 pub fn sqrt(&self) -> Self {
316 if N <= 16 {
317 let mut out = [0.0; N];
318 simd_unaryop(&self.0, &mut out, |a| a.sqrt());
319 Self(out)
320 } else {
321 self.map(f32::sqrt)
322 }
323 }
324
325 pub fn mul_add(&self, mul: &Self, add: &Self) -> Self {
326 if N <= 16 {
327 let mut out = [0.0; N];
328 let mut off = 0;
329 while off < N {
330 let a = simd_load(&self.0, off);
331 let m = simd_load(&mul.0, off);
332 let d = simd_load(&add.0, off);
333 simd_store(a.mul_add(m, d), &mut out, off);
334 off += 8;
335 }
336 Self(out)
337 } else {
338 let mut out = [0.0; N];
339 for i in 0..N {
340 out[i] = self.0[i].mul_add(mul.0[i], add.0[i]);
341 }
342 Self(out)
343 }
344 }
345
346 pub fn any_non_finite(&self) -> bool {
348 let mut off = 0;
349 while off < N {
350 let v = simd_load(&self.0, off);
351 let bad = v.is_nan() | v.is_inf();
352 if (bad.to_bitmask() & Self::lane_mask(off)) != 0 {
353 return true;
354 }
355 off += 8;
356 }
357 false
358 }
359
360 pub fn any_gt(&self, other: &Self) -> bool {
361 let mut off = 0;
362 while off < N {
363 let a = simd_load(&self.0, off);
364 let b = simd_load(&other.0, off);
365 if (a.simd_gt(b).to_bitmask() & Self::lane_mask(off)) != 0 {
366 return true;
367 }
368 off += 8;
369 }
370 false
371 }
372
373 pub fn any_lt(&self, other: &Self) -> bool {
375 let mut off = 0;
376 while off < N {
377 let a = simd_load(&self.0, off);
378 let b = simd_load(&other.0, off);
379 if (a.simd_lt(b).to_bitmask() & Self::lane_mask(off)) != 0 {
380 return true;
381 }
382 off += 8;
383 }
384 false
385 }
386
387 #[inline(always)]
388 const fn lane_mask(off: usize) -> u32 {
389 let active = N.saturating_sub(off);
390 if active >= 8 {
391 0b11111111
392 } else {
393 (1 << active) - 1
394 }
395 }
396
397 pub fn is_close(&self, other: &Self, tol: f32) -> bool {
398 let diff = *self - *other;
399 diff.dot(&diff).sqrt() < tol
400 }
401
402 pub fn interpolate(&self, other: &Self, t: f32) -> Self {
403 *self + ((*other - *self) * t)
404 }
405}
406
407impl<const N: usize> std::ops::Index<usize> for SRobotQ<N> {
408 type Output = f32;
409 #[inline]
410 fn index(&self, i: usize) -> &f32 {
411 &self.0[i]
412 }
413}
414
415impl<const N: usize> std::ops::IndexMut<usize> for SRobotQ<N> {
416 #[inline]
417 fn index_mut(&mut self, i: usize) -> &mut f32 {
418 &mut self.0[i]
419 }
420}
421
422impl<const N: usize> std::ops::Add for SRobotQ<N> {
423 type Output = Self;
424 #[inline]
425 fn add(self, rhs: Self) -> Self {
426 let mut out = [0.0; N];
427 if N <= 16 {
428 simd_binop(&self.0, &rhs.0, &mut out, |a, b| a + b);
429 } else {
430 for i in 0..N {
431 out[i] = self.0[i] + rhs.0[i];
432 }
433 }
434 Self(out)
435 }
436}
437
438impl<const N: usize> std::ops::Sub for SRobotQ<N> {
439 type Output = Self;
440 #[inline]
441 fn sub(self, rhs: Self) -> Self {
442 let mut out = [0.0; N];
443 if N <= 16 {
444 simd_binop(&self.0, &rhs.0, &mut out, |a, b| a - b);
445 } else {
446 for i in 0..N {
447 out[i] = self.0[i] - rhs.0[i];
448 }
449 }
450 Self(out)
451 }
452}
453
454impl<const N: usize> std::ops::Neg for SRobotQ<N> {
455 type Output = Self;
456 #[inline]
457 fn neg(self) -> Self {
458 let mut out = [0.0; N];
459 if N <= 16 {
460 simd_unaryop(&self.0, &mut out, |a| f32x8::ZERO - a);
461 } else {
462 for i in 0..N {
463 out[i] = -self.0[i];
464 }
465 }
466 Self(out)
467 }
468}
469
470impl<const N: usize> std::ops::Mul<f32> for SRobotQ<N> {
471 type Output = Self;
472 #[inline]
473 fn mul(self, rhs: f32) -> Self {
474 let mut out = [0.0; N];
475 if N <= 16 {
476 simd_scalarop(&self.0, f32x8::splat(rhs), &mut out, |a, s| a * s);
477 } else {
478 for i in 0..N {
479 out[i] = self.0[i] * rhs;
480 }
481 }
482 Self(out)
483 }
484}
485
486impl<const N: usize> std::ops::Mul<SRobotQ<N>> for f32 {
487 type Output = SRobotQ<N>;
488 #[inline]
489 fn mul(self, rhs: SRobotQ<N>) -> SRobotQ<N> {
490 rhs * self
491 }
492}
493
494impl<const N: usize> std::ops::Div<f32> for SRobotQ<N> {
495 type Output = Self;
496 #[inline]
497 fn div(self, rhs: f32) -> Self {
498 let mut out = [0.0; N];
499 if N <= 16 {
500 simd_scalarop(&self.0, f32x8::splat(rhs), &mut out, |a, s| a / s);
501 } else {
502 for i in 0..N {
503 out[i] = self.0[i] / rhs;
504 }
505 }
506 Self(out)
507 }
508}
509
510impl<const N: usize> std::ops::AddAssign for SRobotQ<N> {
511 #[inline]
512 fn add_assign(&mut self, rhs: Self) {
513 if N <= 16 {
514 let mut out = [0.0; N];
515 simd_binop(&self.0, &rhs.0, &mut out, |a, b| a + b);
516 self.0 = out;
517 } else {
518 for i in 0..N {
519 self.0[i] += rhs.0[i];
520 }
521 }
522 }
523}
524
525impl<const N: usize> std::ops::SubAssign for SRobotQ<N> {
526 #[inline]
527 fn sub_assign(&mut self, rhs: Self) {
528 if N <= 16 {
529 let mut out = [0.0; N];
530 simd_binop(&self.0, &rhs.0, &mut out, |a, b| a - b);
531 self.0 = out;
532 } else {
533 for i in 0..N {
534 self.0[i] -= rhs.0[i];
535 }
536 }
537 }
538}
539
540impl<const N: usize> std::ops::MulAssign<f32> for SRobotQ<N> {
541 #[inline]
542 fn mul_assign(&mut self, rhs: f32) {
543 if N <= 16 {
544 let mut out = [0.0; N];
545 simd_scalarop(&self.0, f32x8::splat(rhs), &mut out, |a, s| a * s);
546 self.0 = out;
547 } else {
548 for i in 0..N {
549 self.0[i] *= rhs;
550 }
551 }
552 }
553}
554
555impl<const N: usize> std::ops::DivAssign<f32> for SRobotQ<N> {
556 #[inline]
557 fn div_assign(&mut self, rhs: f32) {
558 if N <= 16 {
559 let mut out = [0.0; N];
560 simd_scalarop(&self.0, f32x8::splat(rhs), &mut out, |a, s| a / s);
561 self.0 = out;
562 } else {
563 for i in 0..N {
564 self.0[i] /= rhs;
565 }
566 }
567 }
568}
569
570impl<const N: usize> std::ops::Add<SRobotQ<N>> for &RobotQ {
571 type Output = SRobotQ<N>;
572 #[inline]
573 fn add(self, rhs: SRobotQ<N>) -> SRobotQ<N> {
574 SRobotQ::<N>::force_from_robotq(self) + rhs
575 }
576}
577
578impl<const N: usize> std::ops::Sub<SRobotQ<N>> for &RobotQ {
579 type Output = SRobotQ<N>;
580 #[inline]
581 fn sub(self, rhs: SRobotQ<N>) -> SRobotQ<N> {
582 SRobotQ::<N>::force_from_robotq(self) - rhs
583 }
584}
585
586impl<const N: usize> Default for SRobotQ<N> {
587 #[inline]
588 fn default() -> Self {
589 Self::zeros()
590 }
591}
592
593impl<const N: usize> AsRef<[f32; N]> for SRobotQ<N> {
594 #[inline]
595 fn as_ref(&self) -> &[f32; N] {
596 &self.0
597 }
598}
599
600impl<const N: usize> AsMut<[f32; N]> for SRobotQ<N> {
601 #[inline]
602 fn as_mut(&mut self) -> &mut [f32; N] {
603 &mut self.0
604 }
605}
606
607impl<const N: usize> AsRef<[f32]> for SRobotQ<N> {
608 #[inline]
609 fn as_ref(&self) -> &[f32] {
610 &self.0
611 }
612}
613
614impl<const N: usize> AsMut<[f32]> for SRobotQ<N> {
615 #[inline]
616 fn as_mut(&mut self) -> &mut [f32] {
617 &mut self.0
618 }
619}
620
621impl<const N: usize> From<[f32; N]> for SRobotQ<N> {
622 #[inline]
623 fn from(arr: [f32; N]) -> Self {
624 Self(arr)
625 }
626}
627
628impl<const N: usize> From<&[f32; N]> for SRobotQ<N> {
629 #[inline]
630 fn from(arr: &[f32; N]) -> Self {
631 Self(*arr)
632 }
633}
634
635impl<const N: usize> From<[f64; N]> for SRobotQ<N> {
636 #[inline]
637 fn from(arr: [f64; N]) -> Self {
638 let mut out = [0.0f32; N];
639 let mut i = 0;
640 while i < N {
641 out[i] = arr[i] as f32;
642 i += 1;
643 }
644 Self(out)
645 }
646}
647
648impl<const N: usize> From<&[f64; N]> for SRobotQ<N> {
649 #[inline]
650 fn from(arr: &[f64; N]) -> Self {
651 Self::from(*arr)
652 }
653}
654
655impl<const N: usize> From<SRobotQ<N>> for [f32; N] {
656 #[inline]
657 fn from(q: SRobotQ<N>) -> [f32; N] {
658 q.0
659 }
660}
661
662impl<const N: usize> From<SRobotQ<N>> for Vec<f32> {
663 #[inline]
664 fn from(q: SRobotQ<N>) -> Vec<f32> {
665 q.0.to_vec()
666 }
667}
668
669impl<const N: usize> From<SRobotQ<N>> for RobotQ {
670 #[inline]
671 fn from(q: SRobotQ<N>) -> RobotQ {
672 q.to_robotq()
673 }
674}
675
676impl<const N: usize> TryFrom<&SRobotQ<N>> for SRobotQ<N> {
677 type Error = DekeError;
678
679 #[inline]
680 fn try_from(q: &SRobotQ<N>) -> Result<Self, Self::Error> {
681 Ok(*q)
682 }
683}
684
685impl<const N: usize> TryFrom<&[f32]> for SRobotQ<N> {
686 type Error = DekeError;
687
688 #[inline]
689 fn try_from(slice: &[f32]) -> Result<Self, Self::Error> {
690 if slice.len() != N {
691 return Err(DekeError::ShapeMismatch {
692 expected: N,
693 found: slice.len(),
694 });
695 }
696 let mut arr = [0.0; N];
697 arr.copy_from_slice(slice);
698 Ok(Self(arr))
699 }
700}
701
702impl<const N: usize> TryFrom<Vec<f32>> for SRobotQ<N> {
703 type Error = DekeError;
704
705 #[inline]
706 fn try_from(v: Vec<f32>) -> Result<Self, Self::Error> {
707 Self::try_from(v.as_slice())
708 }
709}
710
711impl<const N: usize> TryFrom<&Vec<f32>> for SRobotQ<N> {
712 type Error = DekeError;
713
714 #[inline]
715 fn try_from(v: &Vec<f32>) -> Result<Self, Self::Error> {
716 Self::try_from(v.as_slice())
717 }
718}
719
720impl<const N: usize> TryFrom<&[f64]> for SRobotQ<N> {
721 type Error = DekeError;
722
723 #[inline]
724 fn try_from(slice: &[f64]) -> Result<Self, Self::Error> {
725 if slice.len() != N {
726 return Err(DekeError::ShapeMismatch {
727 expected: N,
728 found: slice.len(),
729 });
730 }
731 let mut arr = [0.0f32; N];
732 let mut i = 0;
733 while i < N {
734 arr[i] = slice[i] as f32;
735 i += 1;
736 }
737 Ok(Self(arr))
738 }
739}
740
741impl<const N: usize> TryFrom<Vec<f64>> for SRobotQ<N> {
742 type Error = DekeError;
743
744 #[inline]
745 fn try_from(v: Vec<f64>) -> Result<Self, Self::Error> {
746 Self::try_from(v.as_slice())
747 }
748}
749
750impl<const N: usize> TryFrom<&Vec<f64>> for SRobotQ<N> {
751 type Error = DekeError;
752
753 #[inline]
754 fn try_from(v: &Vec<f64>) -> Result<Self, Self::Error> {
755 Self::try_from(v.as_slice())
756 }
757}
758
759impl<const N: usize> TryFrom<&RobotQ> for SRobotQ<N> {
760 type Error = DekeError;
761
762 #[inline]
763 fn try_from(q: &RobotQ) -> Result<Self, Self::Error> {
764 let slice = q.as_slice().unwrap_or(&[]);
765 if slice.len() != N {
766 return Err(DekeError::ShapeMismatch {
767 expected: N,
768 found: slice.len(),
769 });
770 }
771 let mut arr = [0.0; N];
772 arr.copy_from_slice(slice);
773 Ok(Self(arr))
774 }
775}
776
777impl<const N: usize> TryFrom<RobotQ> for SRobotQ<N> {
778 type Error = DekeError;
779
780 #[inline]
781 fn try_from(q: RobotQ) -> Result<Self, Self::Error> {
782 let slice = q.as_slice().unwrap_or(&[]);
783 if slice.len() != N {
784 return Err(DekeError::ShapeMismatch {
785 expected: N,
786 found: slice.len(),
787 });
788 }
789 let mut arr = [0.0; N];
790 arr.copy_from_slice(slice);
791 Ok(Self(arr))
792 }
793}