simdvec/
lib.rs

1#![feature(portable_simd)]
2
3use core::hash::{Hash, Hasher};
4use core::slice;
5use std::{
6  mem::MaybeUninit,
7  ops::*,
8  simd::{num::SimdInt, LaneCount, Simd, SimdElement, SupportedLaneCount},
9};
10
11// non-resizable SIMD processing vectors
12// TODO: rewrite impls to use SimdInt and SimdUint traits
13#[derive(Default, Clone, Debug)]
14pub struct SimdVec<T, const LANES: usize>
15where
16  LaneCount<LANES>: SupportedLaneCount,
17  T: SimdElement,
18{
19  buf: Vec<Simd<T, LANES>>,
20  // len: usize,
21}
22
23impl<T, const N: usize> SimdVec<T, N>
24where
25  LaneCount<N>: SupportedLaneCount,
26  T: SimdElement,
27{
28  pub const LANES: usize = N;
29
30  pub fn iter(&self) -> slice::Iter<'_, Simd<T, N>> {
31    self.buf.iter()
32  }
33
34  pub fn iter_mut(&mut self) -> slice::IterMut<'_, Simd<T, N>> {
35    self.buf.iter_mut()
36  }
37
38  /// Provides upper bound on the number of elements that can be stored in the vector
39  pub fn capacity(&self) -> usize {
40    self.buf.len() * Self::LANES
41  }
42}
43
44impl<T, const N: usize> SimdVec<T, N>
45where
46  LaneCount<N>: SupportedLaneCount,
47  T: SimdElement + Default,
48{
49  pub fn with_capacity(capacity: usize) -> Self {
50    let size = (capacity + Self::LANES - 1) / Self::LANES;
51
52    Self {
53      buf: vec![Simd::default(); size],
54      // len: capacity,
55    }
56  }
57
58  pub fn with_capacity_value(capacity: usize, value: T) -> Self {
59    let size = (capacity + Self::LANES - 1) / Self::LANES;
60
61    Self {
62      buf: vec![Simd::splat(value); size],
63    }
64  }
65
66  pub fn into_vec(self) -> Vec<T> {
67    self.into()
68  }
69}
70
71impl<T, const N: usize> From<&[T]> for SimdVec<T, N>
72where
73  LaneCount<N>: SupportedLaneCount,
74  T: SimdElement + Default,
75{
76  fn from(slice: &[T]) -> Self {
77    let capacity = (slice.len() + Self::LANES - 1) / Self::LANES;
78
79    let mut buf = vec![MaybeUninit::uninit(); capacity];
80    // let len = slice.len();
81
82    let mut slice_iter = buf.iter_mut().zip(slice.chunks(Self::LANES)).peekable();
83
84    while let Some((buf, slice)) = slice_iter.next() {
85      let el = if slice_iter.peek().is_some() {
86        Simd::from_slice(slice)
87      } else {
88        Simd::load_or_default(slice)
89      };
90
91      buf.write(el);
92    }
93
94    let buf = unsafe { core::mem::transmute(buf) };
95
96    Self { buf }
97  }
98}
99
100impl<T, const N: usize, const U: usize> From<[T; U]> for SimdVec<T, N>
101where
102  LaneCount<N>: SupportedLaneCount,
103  T: SimdElement + Default,
104{
105  fn from(slice: [T; U]) -> Self {
106    Self::from(&slice as &[T])
107  }
108}
109
110impl<T, const N: usize> From<Vec<T>> for SimdVec<T, N>
111where
112  LaneCount<N>: SupportedLaneCount,
113  T: SimdElement + Default,
114{
115  fn from(vec: Vec<T>) -> Self {
116    Self::from(vec.as_slice())
117  }
118}
119
120impl<T, const N: usize> From<SimdVec<T, N>> for Vec<T>
121where
122  LaneCount<N>: SupportedLaneCount,
123  T: SimdElement + Default,
124{
125  fn from(vec: SimdVec<T, N>) -> Self {
126    let capacity = vec.capacity();
127    // let slice: &[T] = unsafe { core::mem::transmute(vec.buf.as_slice()) };
128    // let mut vec = Vec::from(slice);
129    // unsafe { vec.set_len(capacity) };
130    // assert_eq!(capacity, vec.len());
131    // vec
132    let mut vec: Vec<T> = unsafe { core::mem::transmute(vec.buf) };
133    unsafe { vec.set_len(capacity) };
134    vec
135  }
136}
137
138impl<T, const N: usize> FromIterator<Simd<T, N>> for SimdVec<T, N>
139where
140  LaneCount<N>: SupportedLaneCount,
141  T: SimdElement,
142{
143  fn from_iter<I: IntoIterator<Item = Simd<T, N>>>(iter: I) -> Self {
144    Self {
145      buf: Vec::from_iter(iter),
146    }
147  }
148}
149
150impl<'a, T, const N: usize> FromIterator<&'a Simd<T, N>> for SimdVec<T, N>
151where
152  LaneCount<N>: SupportedLaneCount,
153  T: SimdElement,
154{
155  fn from_iter<I: IntoIterator<Item = &'a Simd<T, N>>>(iter: I) -> Self {
156    Self {
157      buf: Vec::from_iter(iter.into_iter().map(|m| m.clone())),
158    }
159  }
160}
161
162impl<T, const N: usize, I: slice::SliceIndex<[T]>> Index<I> for SimdVec<T, N>
163where
164  LaneCount<N>: SupportedLaneCount,
165  T: SimdElement,
166{
167  type Output = I::Output;
168
169  fn index(&self, index: I) -> &Self::Output {
170    let slice = unsafe {
171      slice::from_raw_parts::<T>(core::mem::transmute(self.buf.as_ptr()), self.capacity())
172    };
173    slice.index(index)
174  }
175}
176
177impl<T, const N: usize, I: slice::SliceIndex<[T]>> IndexMut<I> for SimdVec<T, N>
178where
179  LaneCount<N>: SupportedLaneCount,
180  T: SimdElement,
181{
182  fn index_mut(&mut self, index: I) -> &mut Self::Output {
183    let slice = unsafe {
184      slice::from_raw_parts_mut::<T>(core::mem::transmute(self.buf.as_mut_ptr()), self.capacity())
185    };
186    slice.index_mut(index)
187  }
188}
189
190impl<'a, T, const N: usize> PartialEq for SimdVec<T, N>
191where
192  LaneCount<N>: SupportedLaneCount,
193  T: SimdElement + PartialEq,
194{
195  fn eq(&self, other: &Self) -> bool {
196    self.iter().zip(other.iter()).all(|(a, b)| a == b)
197  }
198}
199
200impl<'a, T, const N: usize> Eq for SimdVec<T, N>
201where
202  LaneCount<N>: SupportedLaneCount,
203  T: SimdElement + Eq,
204{
205}
206
207impl<'a, T, const N: usize> Hash for SimdVec<T, N>
208where
209  LaneCount<N>: SupportedLaneCount,
210  T: SimdElement + PartialEq,
211  Vec<Simd<T, N>>: Hash,
212  [T]: Hash,
213{
214  #[inline]
215  fn hash<H: Hasher>(&self, state: &mut H) {
216    // self.buf.hash(state);
217    // let vec: &[T] = unsafe { core::mem::transmute(self.buf.as_slice()) };
218    let vec: &[T] =
219      unsafe { slice::from_raw_parts(core::mem::transmute(self.buf.as_ptr()), self.capacity()) };
220    // unsafe { vec.set_len(self.capacity()) };
221    vec.hash(state);
222  }
223}
224
225macro_rules! deref_ops {
226  ($($trait:ident: fn $call:ident),*) => {
227    $(
228      // deref left hand side
229      impl<T, const N: usize> $trait<SimdVec<T, N>> for &SimdVec<T, N>
230      where
231        LaneCount<N>: SupportedLaneCount,
232        T: SimdElement + $trait,
233        Simd<T, N>: $trait<Output = Simd<T, N>>,
234      {
235        type Output = SimdVec<T, N>;
236
237        fn $call(self, rhs: SimdVec<T, N>) -> Self::Output {
238          (*self)
239            .iter()
240            .zip(rhs.iter())
241            .map(|(a, b)| a.$call(b))
242            .collect()
243        }
244      }
245
246      // deref right hand ride
247      impl<T, const N: usize> $trait<&SimdVec<T, N>> for SimdVec<T, N>
248      where
249        LaneCount<N>: SupportedLaneCount,
250        T: SimdElement + $trait,
251        Simd<T, N>: $trait<Output = Simd<T, N>>,
252      {
253        type Output = SimdVec<T, N>;
254
255        fn $call(self, rhs: &SimdVec<T, N>) -> Self::Output {
256          self
257            .iter()
258            .zip(rhs.iter())
259            .map(|(a, b)| a.$call(b))
260            .collect()
261        }
262      }
263
264      // deref both sides
265      impl<'lhs, 'rhs, T, const N: usize> $trait<&'rhs SimdVec<T, N>> for &'lhs SimdVec<T, N>
266      where
267        LaneCount<N>: SupportedLaneCount,
268        T: SimdElement + $trait,
269        Simd<T, N>: $trait<Output = Simd<T, N>>,
270      {
271        type Output = SimdVec<T, N>;
272
273        fn $call(self, rhs: &'rhs SimdVec<T, N>) -> Self::Output {
274          (*self)
275            .iter()
276            .zip(rhs.iter())
277            .map(|(a, b)| a.$call(b))
278            .collect()
279        }
280      }
281
282      // both sides are owned
283      impl<T, const N: usize> $trait<SimdVec<T, N>> for SimdVec<T, N>
284      where
285        LaneCount<N>: SupportedLaneCount,
286        T: SimdElement + $trait,
287        Simd<T, N>: $trait<Output = Simd<T, N>>,
288      {
289        type Output = SimdVec<T, N>;
290
291        fn $call(self, rhs: SimdVec<T, N>) -> Self::Output {
292          self
293            .iter()
294            .zip(rhs.iter())
295            .map(|(a, b)| a.$call(b))
296            .collect()
297        }
298      }
299    )*
300  };
301}
302
303macro_rules! unary_ops {
304  ($($trait:ident: fn $call:ident),*) => {
305    $(
306      impl<T, const N: usize> $trait for SimdVec<T, N>
307      where
308        LaneCount<N>: SupportedLaneCount,
309        T: SimdElement + $trait,
310        Simd<T, N>: $trait<Output = Simd<T, N>>,
311      {
312        type Output = SimdVec<T, N>;
313
314        fn $call(self) -> Self::Output {
315          self
316            .iter()
317            .map(|a| a.$call())
318            .collect()
319        }
320      }
321    )*
322  };
323}
324
325macro_rules! propagate_ops {
326  ($(fn $call:ident),*) => {
327    $(
328      impl<T, const N: usize> SimdVec<T, N>
329      where
330        LaneCount<N>: SupportedLaneCount,
331        T: SimdElement,
332        Simd<T, N>: SimdInt
333      {
334        pub fn $call(&self) -> Self {
335          self
336            .iter()
337            .map(|a| a.$call())
338            .collect()
339        }
340      }
341    )*
342  };
343}
344
345macro_rules! scalar_ops {
346  ($($trait:ident: fn $call:ident => $op:tt),*) => {
347    $(
348      impl<T, const N: usize> SimdVec<T, N>
349      where
350        LaneCount<N>: SupportedLaneCount,
351        T: SimdElement + $trait<Output = T> + Default,
352      {
353        pub fn $call(&self) -> T
354        where
355          Simd<T, N>: SimdInt<Scalar = T>,
356        {
357          self
358            .buf
359            .iter()
360            .fold(Default::default(), |acc, x| acc $op x.$call())
361        }
362      }
363    )*
364  };
365}
366
367deref_ops! {
368  Add: fn add,
369  Mul: fn mul,
370  Sub: fn sub,
371  Div: fn div,
372  Rem: fn rem,
373  BitAnd: fn bitand,
374  BitOr: fn bitor,
375  BitXor: fn bitxor,
376  Shl: fn shl,
377  Shr: fn shr
378}
379
380unary_ops! {
381  Not: fn not,
382  Neg: fn neg
383}
384
385propagate_ops! {
386  fn abs,
387  fn saturating_abs,
388  fn saturating_neg,
389  // fn is_positive,
390  // fn is_negative,
391  fn signum,
392  // fn reduce_max,
393  // fn reduce_min,
394  fn swap_bytes,
395  fn reverse_bits
396  // fn leading_zeros,
397  // fn trailing_zeros
398  // fn leading_ones,
399  // fn trailing_ones
400}
401
402scalar_ops! {
403  Add: fn reduce_sum => +,
404  Mul: fn reduce_product => *,
405  // Max: fn reduce_max => max,
406  // Min: fn reduce_min => min,
407  BitAnd: fn reduce_and => &,
408  BitOr: fn reduce_or => |,
409  BitXor: fn reduce_xor => ^
410}
411
412#[cfg(test)]
413mod tests {
414  use super::*;
415
416  #[test]
417  fn simple_inst() {
418    let simd_vec: SimdVec<i8, 16> = SimdVec::from([1, 2, 3, 4]);
419    assert_eq!((&simd_vec).reduce_sum(), 10);
420
421    // alternative creation
422    let simd_vec: SimdVec<i8, 16> = [1, 2, 3, 4].into();
423    assert_eq!(simd_vec.reduce_sum(), 10);
424  }
425
426  #[test]
427  fn sum() {
428    let vec1: SimdVec<i8, 8> = SimdVec::from([1, 2, 3]);
429    let vec2: SimdVec<i8, 8> = SimdVec::from([
430      5, 10, 15, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4,
431      5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9,
432    ]);
433    assert_eq!(
434      vec1 + vec2,
435      SimdVec::from([
436        6, 12, 18, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4,
437        5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9
438      ])
439    );
440  }
441}