1#![feature(portable_simd)]
2
3use core::hash::{Hash, Hasher};
4use core::slice;
5use std::{
6 mem::MaybeUninit,
7 ops::*,
8 simd::{num::SimdInt, LaneCount, Simd, SimdElement, SupportedLaneCount},
9};
10
11#[derive(Default, Clone, Debug)]
14pub struct SimdVec<T, const LANES: usize>
15where
16 LaneCount<LANES>: SupportedLaneCount,
17 T: SimdElement,
18{
19 buf: Vec<Simd<T, LANES>>,
20 }
22
23impl<T, const N: usize> SimdVec<T, N>
24where
25 LaneCount<N>: SupportedLaneCount,
26 T: SimdElement,
27{
28 pub const LANES: usize = N;
29
30 pub fn iter(&self) -> slice::Iter<'_, Simd<T, N>> {
31 self.buf.iter()
32 }
33
34 pub fn iter_mut(&mut self) -> slice::IterMut<'_, Simd<T, N>> {
35 self.buf.iter_mut()
36 }
37
38 pub fn capacity(&self) -> usize {
40 self.buf.len() * Self::LANES
41 }
42}
43
44impl<T, const N: usize> SimdVec<T, N>
45where
46 LaneCount<N>: SupportedLaneCount,
47 T: SimdElement + Default,
48{
49 pub fn with_capacity(capacity: usize) -> Self {
50 let size = (capacity + Self::LANES - 1) / Self::LANES;
51
52 Self {
53 buf: vec![Simd::default(); size],
54 }
56 }
57
58 pub fn with_capacity_value(capacity: usize, value: T) -> Self {
59 let size = (capacity + Self::LANES - 1) / Self::LANES;
60
61 Self {
62 buf: vec![Simd::splat(value); size],
63 }
64 }
65
66 pub fn into_vec(self) -> Vec<T> {
67 self.into()
68 }
69}
70
71impl<T, const N: usize> From<&[T]> for SimdVec<T, N>
72where
73 LaneCount<N>: SupportedLaneCount,
74 T: SimdElement + Default,
75{
76 fn from(slice: &[T]) -> Self {
77 let capacity = (slice.len() + Self::LANES - 1) / Self::LANES;
78
79 let mut buf = vec![MaybeUninit::uninit(); capacity];
80 let mut slice_iter = buf.iter_mut().zip(slice.chunks(Self::LANES)).peekable();
83
84 while let Some((buf, slice)) = slice_iter.next() {
85 let el = if slice_iter.peek().is_some() {
86 Simd::from_slice(slice)
87 } else {
88 Simd::load_or_default(slice)
89 };
90
91 buf.write(el);
92 }
93
94 let buf = unsafe { core::mem::transmute(buf) };
95
96 Self { buf }
97 }
98}
99
100impl<T, const N: usize, const U: usize> From<[T; U]> for SimdVec<T, N>
101where
102 LaneCount<N>: SupportedLaneCount,
103 T: SimdElement + Default,
104{
105 fn from(slice: [T; U]) -> Self {
106 Self::from(&slice as &[T])
107 }
108}
109
110impl<T, const N: usize> From<Vec<T>> for SimdVec<T, N>
111where
112 LaneCount<N>: SupportedLaneCount,
113 T: SimdElement + Default,
114{
115 fn from(vec: Vec<T>) -> Self {
116 Self::from(vec.as_slice())
117 }
118}
119
120impl<T, const N: usize> From<SimdVec<T, N>> for Vec<T>
121where
122 LaneCount<N>: SupportedLaneCount,
123 T: SimdElement + Default,
124{
125 fn from(vec: SimdVec<T, N>) -> Self {
126 let capacity = vec.capacity();
127 let mut vec: Vec<T> = unsafe { core::mem::transmute(vec.buf) };
133 unsafe { vec.set_len(capacity) };
134 vec
135 }
136}
137
138impl<T, const N: usize> FromIterator<Simd<T, N>> for SimdVec<T, N>
139where
140 LaneCount<N>: SupportedLaneCount,
141 T: SimdElement,
142{
143 fn from_iter<I: IntoIterator<Item = Simd<T, N>>>(iter: I) -> Self {
144 Self {
145 buf: Vec::from_iter(iter),
146 }
147 }
148}
149
150impl<'a, T, const N: usize> FromIterator<&'a Simd<T, N>> for SimdVec<T, N>
151where
152 LaneCount<N>: SupportedLaneCount,
153 T: SimdElement,
154{
155 fn from_iter<I: IntoIterator<Item = &'a Simd<T, N>>>(iter: I) -> Self {
156 Self {
157 buf: Vec::from_iter(iter.into_iter().map(|m| m.clone())),
158 }
159 }
160}
161
162impl<T, const N: usize, I: slice::SliceIndex<[T]>> Index<I> for SimdVec<T, N>
163where
164 LaneCount<N>: SupportedLaneCount,
165 T: SimdElement,
166{
167 type Output = I::Output;
168
169 fn index(&self, index: I) -> &Self::Output {
170 let slice = unsafe {
171 slice::from_raw_parts::<T>(core::mem::transmute(self.buf.as_ptr()), self.capacity())
172 };
173 slice.index(index)
174 }
175}
176
177impl<T, const N: usize, I: slice::SliceIndex<[T]>> IndexMut<I> for SimdVec<T, N>
178where
179 LaneCount<N>: SupportedLaneCount,
180 T: SimdElement,
181{
182 fn index_mut(&mut self, index: I) -> &mut Self::Output {
183 let slice = unsafe {
184 slice::from_raw_parts_mut::<T>(core::mem::transmute(self.buf.as_mut_ptr()), self.capacity())
185 };
186 slice.index_mut(index)
187 }
188}
189
190impl<'a, T, const N: usize> PartialEq for SimdVec<T, N>
191where
192 LaneCount<N>: SupportedLaneCount,
193 T: SimdElement + PartialEq,
194{
195 fn eq(&self, other: &Self) -> bool {
196 self.iter().zip(other.iter()).all(|(a, b)| a == b)
197 }
198}
199
200impl<'a, T, const N: usize> Eq for SimdVec<T, N>
201where
202 LaneCount<N>: SupportedLaneCount,
203 T: SimdElement + Eq,
204{
205}
206
207impl<'a, T, const N: usize> Hash for SimdVec<T, N>
208where
209 LaneCount<N>: SupportedLaneCount,
210 T: SimdElement + PartialEq,
211 Vec<Simd<T, N>>: Hash,
212 [T]: Hash,
213{
214 #[inline]
215 fn hash<H: Hasher>(&self, state: &mut H) {
216 let vec: &[T] =
219 unsafe { slice::from_raw_parts(core::mem::transmute(self.buf.as_ptr()), self.capacity()) };
220 vec.hash(state);
222 }
223}
224
225macro_rules! deref_ops {
226 ($($trait:ident: fn $call:ident),*) => {
227 $(
228 impl<T, const N: usize> $trait<SimdVec<T, N>> for &SimdVec<T, N>
230 where
231 LaneCount<N>: SupportedLaneCount,
232 T: SimdElement + $trait,
233 Simd<T, N>: $trait<Output = Simd<T, N>>,
234 {
235 type Output = SimdVec<T, N>;
236
237 fn $call(self, rhs: SimdVec<T, N>) -> Self::Output {
238 (*self)
239 .iter()
240 .zip(rhs.iter())
241 .map(|(a, b)| a.$call(b))
242 .collect()
243 }
244 }
245
246 impl<T, const N: usize> $trait<&SimdVec<T, N>> for SimdVec<T, N>
248 where
249 LaneCount<N>: SupportedLaneCount,
250 T: SimdElement + $trait,
251 Simd<T, N>: $trait<Output = Simd<T, N>>,
252 {
253 type Output = SimdVec<T, N>;
254
255 fn $call(self, rhs: &SimdVec<T, N>) -> Self::Output {
256 self
257 .iter()
258 .zip(rhs.iter())
259 .map(|(a, b)| a.$call(b))
260 .collect()
261 }
262 }
263
264 impl<'lhs, 'rhs, T, const N: usize> $trait<&'rhs SimdVec<T, N>> for &'lhs SimdVec<T, N>
266 where
267 LaneCount<N>: SupportedLaneCount,
268 T: SimdElement + $trait,
269 Simd<T, N>: $trait<Output = Simd<T, N>>,
270 {
271 type Output = SimdVec<T, N>;
272
273 fn $call(self, rhs: &'rhs SimdVec<T, N>) -> Self::Output {
274 (*self)
275 .iter()
276 .zip(rhs.iter())
277 .map(|(a, b)| a.$call(b))
278 .collect()
279 }
280 }
281
282 impl<T, const N: usize> $trait<SimdVec<T, N>> for SimdVec<T, N>
284 where
285 LaneCount<N>: SupportedLaneCount,
286 T: SimdElement + $trait,
287 Simd<T, N>: $trait<Output = Simd<T, N>>,
288 {
289 type Output = SimdVec<T, N>;
290
291 fn $call(self, rhs: SimdVec<T, N>) -> Self::Output {
292 self
293 .iter()
294 .zip(rhs.iter())
295 .map(|(a, b)| a.$call(b))
296 .collect()
297 }
298 }
299 )*
300 };
301}
302
303macro_rules! unary_ops {
304 ($($trait:ident: fn $call:ident),*) => {
305 $(
306 impl<T, const N: usize> $trait for SimdVec<T, N>
307 where
308 LaneCount<N>: SupportedLaneCount,
309 T: SimdElement + $trait,
310 Simd<T, N>: $trait<Output = Simd<T, N>>,
311 {
312 type Output = SimdVec<T, N>;
313
314 fn $call(self) -> Self::Output {
315 self
316 .iter()
317 .map(|a| a.$call())
318 .collect()
319 }
320 }
321 )*
322 };
323}
324
325macro_rules! propagate_ops {
326 ($(fn $call:ident),*) => {
327 $(
328 impl<T, const N: usize> SimdVec<T, N>
329 where
330 LaneCount<N>: SupportedLaneCount,
331 T: SimdElement,
332 Simd<T, N>: SimdInt
333 {
334 pub fn $call(&self) -> Self {
335 self
336 .iter()
337 .map(|a| a.$call())
338 .collect()
339 }
340 }
341 )*
342 };
343}
344
345macro_rules! scalar_ops {
346 ($($trait:ident: fn $call:ident => $op:tt),*) => {
347 $(
348 impl<T, const N: usize> SimdVec<T, N>
349 where
350 LaneCount<N>: SupportedLaneCount,
351 T: SimdElement + $trait<Output = T> + Default,
352 {
353 pub fn $call(&self) -> T
354 where
355 Simd<T, N>: SimdInt<Scalar = T>,
356 {
357 self
358 .buf
359 .iter()
360 .fold(Default::default(), |acc, x| acc $op x.$call())
361 }
362 }
363 )*
364 };
365}
366
367deref_ops! {
368 Add: fn add,
369 Mul: fn mul,
370 Sub: fn sub,
371 Div: fn div,
372 Rem: fn rem,
373 BitAnd: fn bitand,
374 BitOr: fn bitor,
375 BitXor: fn bitxor,
376 Shl: fn shl,
377 Shr: fn shr
378}
379
380unary_ops! {
381 Not: fn not,
382 Neg: fn neg
383}
384
385propagate_ops! {
386 fn abs,
387 fn saturating_abs,
388 fn saturating_neg,
389 fn signum,
392 fn swap_bytes,
395 fn reverse_bits
396 }
401
402scalar_ops! {
403 Add: fn reduce_sum => +,
404 Mul: fn reduce_product => *,
405 BitAnd: fn reduce_and => &,
408 BitOr: fn reduce_or => |,
409 BitXor: fn reduce_xor => ^
410}
411
412#[cfg(test)]
413mod tests {
414 use super::*;
415
416 #[test]
417 fn simple_inst() {
418 let simd_vec: SimdVec<i8, 16> = SimdVec::from([1, 2, 3, 4]);
419 assert_eq!((&simd_vec).reduce_sum(), 10);
420
421 let simd_vec: SimdVec<i8, 16> = [1, 2, 3, 4].into();
423 assert_eq!(simd_vec.reduce_sum(), 10);
424 }
425
426 #[test]
427 fn sum() {
428 let vec1: SimdVec<i8, 8> = SimdVec::from([1, 2, 3]);
429 let vec2: SimdVec<i8, 8> = SimdVec::from([
430 5, 10, 15, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4,
431 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9,
432 ]);
433 assert_eq!(
434 vec1 + vec2,
435 SimdVec::from([
436 6, 12, 18, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4,
437 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9
438 ])
439 );
440 }
441}