1use core::marker::PhantomData;
7use core::ops::{Add, Div, Mul, Sub};
8
9#[cfg(feature = "no-std")]
10use alloc::vec::Vec;
11#[cfg(not(feature = "no-std"))]
12use std::vec::Vec;
13
/// Zero-sized marker that records a SIMD lane count at the type level.
///
/// It holds no data; it only appears inside `PhantomData` so that vectors and
/// operations of different widths are distinct types.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SimdWidth<const WIDTH: usize>;

/// A fixed-width vector of `WIDTH` lanes of `T`, stored as a plain array.
///
/// All lane access is bounds-checked: out-of-range lanes yield `None` / `false`
/// instead of panicking. The type is `Copy` whenever `T` is, so the by-value
/// arithmetic operators do not consume their operands, and `PartialEq` allows
/// whole vectors to be compared.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SafeSimdVector<T, const WIDTH: usize> {
    data: [T; WIDTH],
    _phantom: PhantomData<SimdWidth<WIDTH>>,
}

impl<T, const WIDTH: usize> SafeSimdVector<T, WIDTH>
where
    T: Copy + Default,
{
    /// Wraps an existing lane array.
    pub const fn new(data: [T; WIDTH]) -> Self {
        Self {
            data,
            _phantom: PhantomData,
        }
    }

    /// Builds a vector with every lane set to `value`.
    pub fn splat(value: T) -> Self {
        Self::new([value; WIDTH])
    }

    /// Returns the number of lanes (`WIDTH`).
    pub const fn width(&self) -> usize {
        WIDTH
    }

    /// Borrows the lanes as a slice of length `WIDTH`.
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    /// Mutably borrows the lanes as a slice of length `WIDTH`.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.data
    }

    /// Consumes the vector and returns the underlying lane array.
    pub fn into_array(self) -> [T; WIDTH] {
        self.data
    }

    /// Copies the first `WIDTH` elements of `slice` into a new vector.
    ///
    /// Any elements past `WIDTH` are ignored. Returns `None` when `slice`
    /// holds fewer than `WIDTH` elements.
    pub fn from_slice(slice: &[T]) -> Option<Self> {
        let head = slice.get(..WIDTH)?;
        let mut data = [T::default(); WIDTH];
        data.copy_from_slice(head);
        Some(Self::new(data))
    }

    /// Returns lane `lane`, or `None` if `lane >= WIDTH`.
    pub fn extract_lane(&self, lane: usize) -> Option<T> {
        self.data.get(lane).copied()
    }

    /// Overwrites lane `lane` with `value`.
    ///
    /// Returns `true` on success and `false` (leaving the vector untouched)
    /// if `lane >= WIDTH`.
    pub fn replace_lane(&mut self, lane: usize, value: T) -> bool {
        match self.data.get_mut(lane) {
            Some(slot) => {
                *slot = value;
                true
            }
            None => false,
        }
    }
}
94
95pub trait SimdOperation<T, const WIDTH: usize> {
97 type Output;
98
99 fn apply(&self, input: &SafeSimdVector<T, WIDTH>) -> Self::Output;
100}
101
102pub struct ElementWiseOp<F, T, const WIDTH: usize> {
104 func: F,
105 _phantom: PhantomData<(T, SimdWidth<WIDTH>)>,
106}
107
108impl<F, T, const WIDTH: usize> ElementWiseOp<F, T, WIDTH>
109where
110 F: Fn(T) -> T,
111 T: Copy,
112{
113 pub const fn new(func: F) -> Self {
114 Self {
115 func,
116 _phantom: PhantomData,
117 }
118 }
119}
120
121impl<F, T, const WIDTH: usize> SimdOperation<T, WIDTH> for ElementWiseOp<F, T, WIDTH>
122where
123 F: Fn(T) -> T,
124 T: Copy + Default,
125{
126 type Output = SafeSimdVector<T, WIDTH>;
127
128 fn apply(&self, input: &SafeSimdVector<T, WIDTH>) -> Self::Output {
129 let mut result = [T::default(); WIDTH];
130 for (r, &val) in result.iter_mut().zip(input.data.iter()) {
131 *r = (self.func)(val);
132 }
133 SafeSimdVector::new(result)
134 }
135}
136
137impl<T, const WIDTH: usize> Add for SafeSimdVector<T, WIDTH>
139where
140 T: Add<Output = T> + Copy + Default,
141{
142 type Output = Self;
143
144 fn add(self, rhs: Self) -> Self::Output {
145 let mut result = [T::default(); WIDTH];
146 for (r, (a, b)) in result.iter_mut().zip(self.data.iter().zip(rhs.data.iter())) {
147 *r = *a + *b;
148 }
149 Self::new(result)
150 }
151}
152
153impl<T, const WIDTH: usize> Sub for SafeSimdVector<T, WIDTH>
154where
155 T: Sub<Output = T> + Copy + Default,
156{
157 type Output = Self;
158
159 fn sub(self, rhs: Self) -> Self::Output {
160 let mut result = [T::default(); WIDTH];
161 for (r, (a, b)) in result.iter_mut().zip(self.data.iter().zip(rhs.data.iter())) {
162 *r = *a - *b;
163 }
164 Self::new(result)
165 }
166}
167
168impl<T, const WIDTH: usize> Mul for SafeSimdVector<T, WIDTH>
169where
170 T: Mul<Output = T> + Copy + Default,
171{
172 type Output = Self;
173
174 fn mul(self, rhs: Self) -> Self::Output {
175 let mut result = [T::default(); WIDTH];
176 for (r, (a, b)) in result.iter_mut().zip(self.data.iter().zip(rhs.data.iter())) {
177 *r = *a * *b;
178 }
179 Self::new(result)
180 }
181}
182
183impl<T, const WIDTH: usize> Div for SafeSimdVector<T, WIDTH>
184where
185 T: Div<Output = T> + Copy + Default,
186{
187 type Output = Self;
188
189 fn div(self, rhs: Self) -> Self::Output {
190 let mut result = [T::default(); WIDTH];
191 for (r, (a, b)) in result.iter_mut().zip(self.data.iter().zip(rhs.data.iter())) {
192 *r = *a / *b;
193 }
194 Self::new(result)
195 }
196}
197
198pub mod widths {
200 pub type Scalar = super::SimdWidth<1>;
201 pub type Sse = super::SimdWidth<4>; pub type Avx = super::SimdWidth<8>; pub type Avx512 = super::SimdWidth<16>; pub type SseF64 = super::SimdWidth<2>; pub type AvxF64 = super::SimdWidth<4>; pub type Avx512F64 = super::SimdWidth<8>; }
209
210pub type SimdF32x4 = SafeSimdVector<f32, 4>;
212pub type SimdF32x8 = SafeSimdVector<f32, 8>;
213pub type SimdF32x16 = SafeSimdVector<f32, 16>;
214
215pub type SimdF64x2 = SafeSimdVector<f64, 2>;
216pub type SimdF64x4 = SafeSimdVector<f64, 4>;
217pub type SimdF64x8 = SafeSimdVector<f64, 8>;
218
219pub type SimdU32x4 = SafeSimdVector<u32, 4>;
220pub type SimdU32x8 = SafeSimdVector<u32, 8>;
221pub type SimdU32x16 = SafeSimdVector<u32, 16>;
222
223pub mod capabilities {
225
226 pub trait SimdCapable<const WIDTH: usize> {
228 fn is_supported() -> bool;
229 fn best_width() -> usize;
230 }
231
232 pub struct X86Simd;
234
235 impl SimdCapable<4> for X86Simd {
236 fn is_supported() -> bool {
237 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
238 {
239 crate::simd_feature_detected!("sse")
240 }
241 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
242 {
243 false
244 }
245 }
246
247 fn best_width() -> usize {
248 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
249 {
250 if crate::simd_feature_detected!("avx512f") {
251 16
252 } else if crate::simd_feature_detected!("avx2") {
253 8
254 } else if crate::simd_feature_detected!("sse") {
255 4
256 } else {
257 1
258 }
259 }
260 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
261 {
262 1
263 }
264 }
265 }
266
267 impl SimdCapable<8> for X86Simd {
268 fn is_supported() -> bool {
269 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
270 {
271 crate::simd_feature_detected!("avx2")
272 }
273 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
274 {
275 false
276 }
277 }
278
279 fn best_width() -> usize {
280 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
281 {
282 if crate::simd_feature_detected!("avx512f") {
283 16
284 } else if crate::simd_feature_detected!("avx2") {
285 8
286 } else if crate::simd_feature_detected!("sse") {
287 4
288 } else {
289 1
290 }
291 }
292 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
293 {
294 1
295 }
296 }
297 }
298
299 impl SimdCapable<16> for X86Simd {
300 fn is_supported() -> bool {
301 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
302 {
303 crate::simd_feature_detected!("avx512f")
304 }
305 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
306 {
307 false
308 }
309 }
310
311 fn best_width() -> usize {
312 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
313 {
314 if crate::simd_feature_detected!("avx512f") {
315 16
316 } else if crate::simd_feature_detected!("avx2") {
317 8
318 } else if crate::simd_feature_detected!("sse") {
319 4
320 } else {
321 1
322 }
323 }
324 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
325 {
326 1
327 }
328 }
329 }
330
331 pub struct ArmSimd;
333
334 impl SimdCapable<4> for ArmSimd {
335 fn is_supported() -> bool {
336 #[cfg(target_arch = "aarch64")]
337 {
338 true }
340 #[cfg(not(target_arch = "aarch64"))]
341 {
342 false
343 }
344 }
345
346 fn best_width() -> usize {
347 #[cfg(target_arch = "aarch64")]
348 {
349 4
350 }
351 #[cfg(not(target_arch = "aarch64"))]
352 {
353 1
354 }
355 }
356 }
357}
358
359pub struct OptimizedSimdOp<T, const WIDTH: usize> {
361 _phantom: PhantomData<(T, SimdWidth<WIDTH>)>,
362}
363
364impl<T, const WIDTH: usize> Default for OptimizedSimdOp<T, WIDTH> {
365 fn default() -> Self {
366 Self::new()
367 }
368}
369
370impl<T, const WIDTH: usize> OptimizedSimdOp<T, WIDTH> {
371 pub const fn new() -> Self {
372 Self {
373 _phantom: PhantomData,
374 }
375 }
376
377 pub fn dot_product(a: &SafeSimdVector<T, WIDTH>, b: &SafeSimdVector<T, WIDTH>) -> T
379 where
380 T: Mul<Output = T> + Add<Output = T> + Default + Copy,
381 {
382 let mut result = T::default();
383 for i in 0..WIDTH {
384 result = result + (a.data[i] * b.data[i]);
385 }
386 result
387 }
388
389 pub fn element_wise_multiply(
391 a: &SafeSimdVector<T, WIDTH>,
392 b: &SafeSimdVector<T, WIDTH>,
393 ) -> SafeSimdVector<T, WIDTH>
394 where
395 T: Mul<Output = T> + Default + Copy,
396 {
397 let mut result = [T::default(); WIDTH];
398 for (r, (x, y)) in result.iter_mut().zip(a.data.iter().zip(b.data.iter())) {
399 *r = *x * *y;
400 }
401 SafeSimdVector::new(result)
402 }
403
404 pub fn horizontal_sum(vector: &SafeSimdVector<T, WIDTH>) -> T
406 where
407 T: Add<Output = T> + Default + Copy,
408 {
409 let mut sum = T::default();
410 for i in 0..WIDTH {
411 sum = sum + vector.data[i];
412 }
413 sum
414 }
415
416 pub fn horizontal_max(vector: &SafeSimdVector<T, WIDTH>) -> T
418 where
419 T: PartialOrd + Copy,
420 {
421 let mut max = vector.data[0];
422 for i in 1..WIDTH {
423 if vector.data[i] > max {
424 max = vector.data[i];
425 }
426 }
427 max
428 }
429}
430
431pub struct SafeSliceOps;
433
434impl SafeSliceOps {
435 pub fn process_slice_vectorized<T, F, const CHUNK_SIZE: usize>(
437 data: &[T],
438 mut func: F,
439 ) -> Vec<T>
440 where
441 T: Copy + Default,
442 F: FnMut(&SafeSimdVector<T, CHUNK_SIZE>) -> SafeSimdVector<T, CHUNK_SIZE>,
443 [(); CHUNK_SIZE]:,
444 {
445 let mut result = Vec::with_capacity(data.len());
446
447 for chunk in data.chunks_exact(CHUNK_SIZE) {
449 if let Some(simd_chunk) = SafeSimdVector::<T, CHUNK_SIZE>::from_slice(chunk) {
450 let processed = func(&simd_chunk);
451 result.extend_from_slice(processed.as_slice());
452 }
453 }
454
455 let remainder_start = data.len() - (data.len() % CHUNK_SIZE);
457 for &item in &data[remainder_start..] {
458 let mut single_data = [T::default(); CHUNK_SIZE];
460 single_data[0] = item;
461 if let Some(simd_single) = SafeSimdVector::<T, CHUNK_SIZE>::from_slice(&single_data) {
462 let processed = func(&simd_single);
463 result.push(processed.as_slice()[0]);
464 }
465 }
466
467 result
468 }
469
470 pub fn dot_product_safe<T, const WIDTH: usize>(a: &[T], b: &[T]) -> Option<T>
472 where
473 T: Mul<Output = T> + Add<Output = T> + Default + Copy,
474 [(); WIDTH]:,
475 {
476 if a.len() != b.len() || a.len() < WIDTH {
477 return None;
478 }
479
480 let mut result = T::default();
481 let chunks_a = a.chunks_exact(WIDTH);
482 let chunks_b = b.chunks_exact(WIDTH);
483
484 for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
485 if let (Some(vec_a), Some(vec_b)) = (
486 SafeSimdVector::<T, WIDTH>::from_slice(chunk_a),
487 SafeSimdVector::<T, WIDTH>::from_slice(chunk_b),
488 ) {
489 result = result + OptimizedSimdOp::<T, WIDTH>::dot_product(&vec_a, &vec_b);
490 }
491 }
492
493 let remainder = a.len() % WIDTH;
495 for i in 0..remainder {
496 let idx = a.len() - remainder + i;
497 result = result + (a[idx] * b[idx]);
498 }
499
500 Some(result)
501 }
502}
503
/// An operation that accepts a whole slice of `T` and produces `Self::Output`.
///
/// Unlike [`SimdOperation`], the input length is not fixed at the type level;
/// implementors decide how to handle arbitrary lengths.
pub trait TypeSafeSimd<T> {
    /// What [`TypeSafeSimd::apply_safe`] produces.
    type Output;

    /// Runs the operation over every element of `input`.
    fn apply_safe(&self, input: &[T]) -> Self::Output;
}
510
511pub struct SafeMathOps;
513
514impl SafeMathOps {
515 pub fn sqrt_vectorized<const WIDTH: usize>(data: &[f32]) -> Vec<f32>
517 where
518 [(); WIDTH]:,
519 {
520 SafeSliceOps::process_slice_vectorized::<f32, _, WIDTH>(data, |chunk| {
521 let op = ElementWiseOp::new(|x: f32| x.sqrt());
522 op.apply(chunk)
523 })
524 }
525
526 pub fn exp_vectorized<const WIDTH: usize>(data: &[f32]) -> Vec<f32>
528 where
529 [(); WIDTH]:,
530 {
531 SafeSliceOps::process_slice_vectorized::<f32, _, WIDTH>(data, |chunk| {
532 let op = ElementWiseOp::new(|x: f32| x.exp());
533 op.apply(chunk)
534 })
535 }
536
537 pub fn polynomial_vectorized<const WIDTH: usize>(data: &[f32], coefficients: &[f32]) -> Vec<f32>
539 where
540 [(); WIDTH]:,
541 {
542 SafeSliceOps::process_slice_vectorized::<f32, _, WIDTH>(data, |chunk| {
543 let op = ElementWiseOp::new(|x: f32| {
544 coefficients
545 .iter()
546 .rev()
547 .fold(0.0, |acc, &coeff| acc * x + coeff)
548 });
549 op.apply(chunk)
550 })
551 }
552}
553
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    // This module is only compiled without the `no-std` feature, so `vec!`
    // and `Vec` come from the std prelude; no `alloc` imports are needed.
    use super::*;

    #[test]
    fn test_safe_simd_vector_creation() {
        let vec = SimdF32x4::new([1.0, 2.0, 3.0, 4.0]);
        assert_eq!(vec.width(), 4);
        assert_eq!(vec.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_safe_simd_arithmetic() {
        let a = SimdF32x4::new([1.0, 2.0, 3.0, 4.0]);
        let b = SimdF32x4::new([5.0, 6.0, 7.0, 8.0]);

        let sum = a + b;
        assert_eq!(sum.as_slice(), &[6.0, 8.0, 10.0, 12.0]);

        let diff = SimdF32x4::new([10.0, 12.0, 14.0, 16.0]) - SimdF32x4::new([1.0, 2.0, 3.0, 4.0]);
        assert_eq!(diff.as_slice(), &[9.0, 10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_lane_access() {
        let vec = SimdF32x4::new([1.0, 2.0, 3.0, 4.0]);
        assert_eq!(vec.extract_lane(0), Some(1.0));
        assert_eq!(vec.extract_lane(1), Some(2.0));
        assert_eq!(vec.extract_lane(3), Some(4.0));
        assert_eq!(vec.extract_lane(4), None);
    }

    #[test]
    fn test_dot_product_safe() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![8.0f32, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        let result = SafeSliceOps::dot_product_safe::<f32, 4>(&a, &b);
        assert!(result.is_some());

        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert_eq!(result.expect("operation should succeed"), expected);
    }

    #[test]
    fn test_dot_product_safe_remainder_and_rejections() {
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = [1.0f32; 6];

        // One full 4-lane chunk (1+2+3+4) plus a 2-element scalar tail (5+6).
        assert_eq!(SafeSliceOps::dot_product_safe::<f32, 4>(&a, &b), Some(21.0));

        // Mismatched lengths and inputs shorter than one vector are rejected.
        assert_eq!(SafeSliceOps::dot_product_safe::<f32, 4>(&a, &b[..5]), None);
        assert_eq!(SafeSliceOps::dot_product_safe::<f32, 4>(&a[..3], &b[..3]), None);
    }

    #[test]
    fn test_process_slice_vectorized_handles_remainder() {
        let data = [1u32, 2, 3, 4, 5, 6];
        let doubled = SafeSliceOps::process_slice_vectorized::<u32, _, 4>(&data, |chunk| {
            ElementWiseOp::<_, u32, 4>::new(|x: u32| x * 2).apply(chunk)
        });
        assert_eq!(doubled, vec![2, 4, 6, 8, 10, 12]);
    }

    #[test]
    fn test_element_wise_operations() {
        let vec = SimdF32x4::new([1.0, 4.0, 9.0, 16.0]);
        let op = ElementWiseOp::new(|x: f32| x.sqrt());
        let result = op.apply(&vec);

        assert_eq!(result.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_horizontal_operations() {
        let vec = SimdF32x4::new([1.0, 2.0, 3.0, 4.0]);

        let sum = OptimizedSimdOp::<f32, 4>::horizontal_sum(&vec);
        assert_eq!(sum, 10.0);

        let max = OptimizedSimdOp::<f32, 4>::horizontal_max(&vec);
        assert_eq!(max, 4.0);
    }

    #[test]
    fn test_safe_math_operations() {
        let data = vec![1.0, 4.0, 9.0, 16.0, 25.0, 36.0];
        let result = SafeMathOps::sqrt_vectorized::<4>(&data);

        // `zip` below would silently ignore a short result, so pin the length first.
        assert_eq!(result.len(), data.len());

        let expected: Vec<f32> = data.iter().map(|x| x.sqrt()).collect();
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!(
                (a - b).abs() < 1e-4,
                "sqrt({}) = {}, expected {}, diff = {}",
                a * a,
                a,
                b,
                (a - b).abs()
            );
        }
    }

    #[test]
    fn test_polynomial_vectorized() {
        // Coefficients are in ascending power order: 1 + 2x + 3x^2.
        let data = [0.0f32, 1.0, 2.0, 3.0, 4.0];
        let result = SafeMathOps::polynomial_vectorized::<4>(&data, &[1.0, 2.0, 3.0]);
        assert_eq!(result, vec![1.0, 6.0, 17.0, 34.0, 57.0]);
    }

    #[test]
    fn test_from_slice_validation() {
        let data = vec![1.0f32, 2.0, 3.0];
        let vec = SimdF32x4::from_slice(&data);
        assert!(vec.is_none());

        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let vec = SimdF32x4::from_slice(&data);
        assert!(vec.is_some());
        assert_eq!(
            vec.expect("slice operation should succeed").as_slice(),
            &[1.0, 2.0, 3.0, 4.0]
        );
    }

    #[test]
    fn test_capability_detection() {
        use super::capabilities::SimdCapable;

        let sse_supported = <capabilities::X86Simd as SimdCapable<4>>::is_supported();
        let best_width = <capabilities::X86Simd as SimdCapable<4>>::best_width();

        // best_width only ever reports one of the known x86 widths, and it
        // can never be narrower than a width the CPU reports as supported.
        assert!(matches!(best_width, 1 | 4 | 8 | 16));
        if sse_supported {
            assert!(best_width >= 4);
        }

        let neon_supported = <capabilities::ArmSimd as SimdCapable<4>>::is_supported();
        assert_eq!(neon_supported, cfg!(target_arch = "aarch64"));
    }

    #[test]
    fn test_zero_cost_abstractions() {
        let a = SimdF32x4::splat(2.0);
        let b = SimdF32x4::splat(3.0);

        let result = OptimizedSimdOp::<f32, 4>::element_wise_multiply(&a, &b);
        assert_eq!(result.as_slice(), &[6.0, 6.0, 6.0, 6.0]);
    }
}