// rten_simd/arch/generic.rs

1use std::array;
2use std::mem::transmute;
3
4use crate::ops::{
5    Concat, Extend, FloatOps, IntOps, Interleave, MaskOps, NarrowSaturate, NumOps, SignedIntOps,
6    ToFloat,
7};
8use crate::{Isa, Mask, Simd};
9
/// Number of 32-bit lanes in one vector of the generic ISA.
///
/// Vectors are 128 bits wide; types with narrower lanes use a multiple of
/// this count (for example `LEN_X32 * 2` lanes of 16 bits).
const LEN_X32: usize = 128 / 32;
12
// Declares a vector type `$simd` that wraps an array of `$len` lanes of type
// `$elem`, together with lane-wise helper methods and a conversion from the
// backing array.
macro_rules! simd_type {
    ($simd:ident, $elem:ty, $len:expr) => {
        #[repr(align(16))]
        #[derive(Copy, Clone, Debug)]
        pub struct $simd([$elem; $len]);

        impl $simd {
            /// Run `op` on every lane and collect the results into any type
            /// that can be built from an array of `$len` results.
            #[allow(unused)]
            #[inline]
            fn map<U, F, R>(self, op: F) -> R
            where
                F: Fn($elem) -> U,
                R: From<[U; $len]>,
            {
                R::from(self.0.map(op))
            }

            /// Run `op` on each pair of corresponding lanes of `self` and `y`
            /// and collect the results into any type that can be built from
            /// an array of `$len` results.
            #[allow(unused)]
            #[inline]
            fn map_with<U, F, R>(self, y: Self, op: F) -> R
            where
                F: Fn($elem, $elem) -> U,
                R: From<[U; $len]>,
            {
                let lanes: [U; $len] = array::from_fn(|lane| op(self.0[lane], y.0[lane]));
                R::from(lanes)
            }
        }

        impl From<[$elem; $len]> for $simd {
            /// Wrap an array of lanes as a vector.
            fn from(lanes: [$elem; $len]) -> Self {
                Self(lanes)
            }
        }
    };
}
48
49// Define SIMD vector types.
50simd_type!(F32x4, f32, LEN_X32);
51simd_type!(I32x4, i32, LEN_X32);
52simd_type!(I16x8, i16, LEN_X32 * 2);
53simd_type!(I8x16, i8, LEN_X32 * 4);
54simd_type!(U8x16, u8, LEN_X32 * 4);
55simd_type!(U16x8, u16, LEN_X32 * 2);
56simd_type!(U32x4, u32, LEN_X32);
57
58// Define mask vector types. `Mn` is a mask for a vector with n-bit lanes.
59simd_type!(M32, i32, LEN_X32);
60simd_type!(M16, i16, LEN_X32 * 2);
61simd_type!(M8, i8, LEN_X32 * 4);
62
/// Handle for the generic ISA.
///
/// The private unit field stops code outside this module from building the
/// value with a struct literal; use [`GenericIsa::new`] or `Default` instead.
#[derive(Copy, Clone)]
pub struct GenericIsa {
    _private: (),
}

impl GenericIsa {
    /// Create a new `GenericIsa` value.
    pub fn new() -> Self {
        Self { _private: () }
    }
}

impl Default for GenericIsa {
    /// Same as [`GenericIsa::new`].
    fn default() -> Self {
        GenericIsa::new()
    }
}
79
80// Safety: Instructions used by generic ISA are always supported.
81unsafe impl Isa for GenericIsa {
82    type M32 = M32;
83    type M16 = M16;
84    type M8 = M8;
85    type F32 = F32x4;
86    type I32 = I32x4;
87    type I16 = I16x8;
88    type I8 = I8x16;
89    type U8 = U8x16;
90    type U16 = U16x8;
91    type U32 = U32x4;
92    type Bits = I32x4;
93
94    fn f32(self) -> impl FloatOps<f32, Simd = Self::F32, Int = Self::I32> {
95        self
96    }
97
98    fn i32(
99        self,
100    ) -> impl SignedIntOps<i32, Simd = Self::I32>
101    + NarrowSaturate<i32, i16, Output = Self::I16>
102    + Concat<i32>
103    + ToFloat<i32, Output = Self::F32> {
104        self
105    }
106
107    fn i16(
108        self,
109    ) -> impl SignedIntOps<i16, Simd = Self::I16>
110    + NarrowSaturate<i16, u8, Output = Self::U8>
111    + Extend<i16, Output = Self::I32>
112    + Interleave<i16> {
113        self
114    }
115
116    fn i8(
117        self,
118    ) -> impl SignedIntOps<i8, Simd = Self::I8> + Extend<i8, Output = Self::I16> + Interleave<i8>
119    {
120        self
121    }
122
123    fn u8(
124        self,
125    ) -> impl IntOps<u8, Simd = Self::U8> + Extend<u8, Output = Self::U16> + Interleave<u8> {
126        self
127    }
128
129    fn u16(self) -> impl IntOps<u16, Simd = Self::U16> {
130        self
131    }
132
133    fn m32(self) -> impl MaskOps<Self::M32> {
134        self
135    }
136
137    fn m16(self) -> impl MaskOps<Self::M16> {
138        self
139    }
140
141    fn m8(self) -> impl MaskOps<Self::M8> {
142        self
143    }
144}
145
// Generates the `NumOps` methods whose implementation is the same for every
// element type: each operation is applied lane by lane to the backing arrays.
//
// Expands inside an `impl NumOps<$elem> for GenericIsa` block. The bitwise
// operations (`and`, `or`, `not`, `xor`) are not generated here.
macro_rules! simd_ops_common {
    ($simd:ident, $elem:ty, $len:expr, $mask:ident) => {
        /// Number of lanes in a vector.
        #[inline]
        fn len(self) -> usize {
            $len
        }

        /// Mask whose first `n` lanes are set (all bits one) and whose
        /// remaining lanes are zero.
        #[inline]
        fn first_n_mask(self, n: usize) -> $mask {
            $mask(std::array::from_fn(|lane| if lane < n { !0 } else { 0 }))
        }

        /// Load the lanes enabled in `mask` from `ptr`. Disabled lanes keep
        /// the value produced by `NumOps::zero` and are never read.
        ///
        /// Safety: `ptr.add(i)` must be valid for reads for every lane `i`
        /// whose mask entry is non-zero.
        #[inline]
        unsafe fn load_ptr_mask(
            self,
            ptr: *const <$simd as Simd>::Elem,
            mask: <$simd as Simd>::Mask,
        ) -> $simd {
            let mut lanes = <Self as NumOps<$elem>>::zero(self).0;
            for (i, (dst, &flag)) in lanes.iter_mut().zip(mask.0.iter()).enumerate() {
                if flag != 0 {
                    *dst = *ptr.add(i);
                }
            }
            $simd(lanes)
        }

        /// Store the lanes of `x` enabled in `mask` to `ptr`. Memory for
        /// disabled lanes is not written.
        ///
        /// Safety: `ptr.add(i)` must be valid for writes for every lane `i`
        /// whose mask entry is non-zero.
        #[inline]
        unsafe fn store_ptr_mask(
            self,
            x: $simd,
            ptr: *mut <$simd as Simd>::Elem,
            mask: <$simd as Simd>::Mask,
        ) {
            for (i, (&value, &flag)) in x.0.iter().zip(mask.0.iter()).enumerate() {
                if flag != 0 {
                    *ptr.add(i) = value;
                }
            }
        }

        /// Lane-wise `x + y`.
        #[inline]
        fn add(self, x: $simd, y: $simd) -> $simd {
            x.map_with(y, |a, b| a + b)
        }

        /// Lane-wise `x - y`.
        #[inline]
        fn sub(self, x: $simd, y: $simd) -> $simd {
            x.map_with(y, |a, b| a - b)
        }

        /// Lane-wise `x * y`.
        #[inline]
        fn mul(self, x: $simd, y: $simd) -> $simd {
            x.map_with(y, |a, b| a * b)
        }

        /// Lane-wise `a * b + c`, computed as a separate multiply and add.
        #[inline]
        fn mul_add(self, a: $simd, b: $simd, c: $simd) -> $simd {
            $simd(array::from_fn(|lane| a.0[lane] * b.0[lane] + c.0[lane]))
        }

        /// Mask with a lane set wherever `x == y`.
        #[inline]
        fn eq(self, x: $simd, y: $simd) -> $mask {
            x.map_with(y, |a, b| if a == b { !0 } else { 0 })
        }

        /// Mask with a lane set wherever `x >= y`.
        #[inline]
        fn ge(self, x: $simd, y: $simd) -> $mask {
            x.map_with(y, |a, b| if a >= b { !0 } else { 0 })
        }

        /// Mask with a lane set wherever `x > y`.
        #[inline]
        fn gt(self, x: $simd, y: $simd) -> $mask {
            x.map_with(y, |a, b| if a > b { !0 } else { 0 })
        }

        /// Lane-wise minimum, using the element type's `min` method.
        #[inline]
        fn min(self, x: $simd, y: $simd) -> $simd {
            x.map_with(y, |a, b| a.min(b))
        }

        /// Lane-wise maximum, using the element type's `max` method.
        #[inline]
        fn max(self, x: $simd, y: $simd) -> $simd {
            x.map_with(y, |a, b| a.max(b))
        }

        /// Vector with every lane equal to `x`.
        #[inline]
        fn splat(self, x: $elem) -> $simd {
            $simd([x; $len])
        }

        /// Load a full vector from `ptr`.
        ///
        /// Safety: `ptr` must be valid for reading `$len` consecutive elements.
        #[inline]
        unsafe fn load_ptr(self, ptr: *const $elem) -> $simd {
            $simd(array::from_fn(|lane| *ptr.add(lane)))
        }

        /// Per lane, take the value from `x` where `mask` is set and from `y`
        /// where it is clear.
        #[inline]
        fn select(self, x: $simd, y: $simd, mask: <$simd as Simd>::Mask) -> $simd {
            $simd(array::from_fn(|lane| {
                if mask.0[lane] == 0 {
                    y.0[lane]
                } else {
                    x.0[lane]
                }
            }))
        }

        /// Store every lane of `x` to `ptr`.
        ///
        /// Safety: `ptr` must be valid for writing `$len` consecutive elements.
        #[inline]
        unsafe fn store_ptr(self, x: $simd, ptr: *mut $elem) {
            for (lane, &value) in x.0.iter().enumerate() {
                *ptr.add(lane) = value;
            }
        }
    };
}
262
// Generates the lane-wise bitwise methods (`and`, `or`, `not`, `xor`) for a
// vector type whose elements support the `&`, `|`, `!` and `^` operators.
macro_rules! simd_int_ops_common {
    ($simd:ty) => {
        /// Lane-wise bitwise AND.
        #[inline]
        fn and(self, x: $simd, y: $simd) -> $simd {
            x.map_with(y, |lhs, rhs| lhs & rhs)
        }

        /// Lane-wise bitwise OR.
        #[inline]
        fn or(self, x: $simd, y: $simd) -> $simd {
            x.map_with(y, |lhs, rhs| lhs | rhs)
        }

        /// Lane-wise bitwise NOT.
        #[inline]
        fn not(self, x: $simd) -> $simd {
            x.map(|lane| !lane)
        }

        /// Lane-wise bitwise XOR.
        #[inline]
        fn xor(self, x: $simd, y: $simd) -> $simd {
            x.map_with(y, |lhs, rhs| lhs ^ rhs)
        }
    };
}
286
287unsafe impl NumOps<f32> for GenericIsa {
288    type Simd = F32x4;
289
290    simd_ops_common!(F32x4, f32, 4, M32);
291
292    #[inline]
293    fn and(self, x: F32x4, y: F32x4) -> F32x4 {
294        x.map_with(y, |x, y| f32::from_bits(x.to_bits() & y.to_bits()))
295    }
296
297    #[inline]
298    fn not(self, x: F32x4) -> F32x4 {
299        x.map(|x| f32::from_bits(!x.to_bits()))
300    }
301
302    #[inline]
303    fn or(self, x: F32x4, y: F32x4) -> F32x4 {
304        x.map_with(y, |x, y| f32::from_bits(x.to_bits() | y.to_bits()))
305    }
306
307    #[inline]
308    fn xor(self, x: F32x4, y: F32x4) -> F32x4 {
309        x.map_with(y, |x, y| f32::from_bits(x.to_bits() ^ y.to_bits()))
310    }
311}
312
313impl FloatOps<f32> for GenericIsa {
314    type Int = <Self as Isa>::I32;
315
316    #[inline]
317    fn div(self, x: F32x4, y: F32x4) -> F32x4 {
318        x.map_with(y, |x, y| x / y)
319    }
320
321    #[inline]
322    fn round_ties_even(self, x: F32x4) -> F32x4 {
323        x.map(|x| x.round_ties_even())
324    }
325
326    #[inline]
327    fn neg(self, x: F32x4) -> F32x4 {
328        x.map(|x| -x)
329    }
330
331    #[inline]
332    fn abs(self, x: F32x4) -> F32x4 {
333        x.map(|x| x.abs())
334    }
335
336    #[inline]
337    fn to_int_trunc(self, x: F32x4) -> Self::Int {
338        x.map(|x| x as i32)
339    }
340
341    #[inline]
342    fn to_int_round(self, x: F32x4) -> Self::Int {
343        x.map(|x| x.round_ties_even() as i32)
344    }
345}
346
// Implements `NumOps` and `IntOps` for the integer vector type `$simd` with
// `$len` lanes of type `$elem` and mask type `$mask`.
macro_rules! impl_simd_int_ops {
    ($simd:ident, $elem:ty, $len:expr, $mask:ident) => {
        unsafe impl NumOps<$elem> for GenericIsa {
            type Simd = $simd;

            // Methods shared with the float impl, then the bitwise methods.
            simd_ops_common!($simd, $elem, $len, $mask);
            simd_int_ops_common!($simd);
        }

        impl IntOps<$elem> for GenericIsa {
            /// Shift every lane left by `SHIFT` bits.
            #[inline]
            fn shift_left<const SHIFT: i32>(self, x: $simd) -> $simd {
                x.map(|lane| lane << SHIFT)
            }

            /// Shift every lane right by `SHIFT` bits, using the element
            /// type's `>>` operator.
            #[inline]
            fn shift_right<const SHIFT: i32>(self, x: $simd) -> $simd {
                x.map(|lane| lane >> SHIFT)
            }
        }
    };
}
369
370macro_rules! impl_simd_signed_int_ops {
371    ($simd:ident, $elem:ty, $len:expr, $mask:ident) => {
372        impl_simd_int_ops!($simd, $elem, $len, $mask);
373
374        impl SignedIntOps<$elem> for GenericIsa {
375            #[inline]
376            fn neg(self, x: $simd) -> $simd {
377                x.map(|x| -x)
378            }
379        }
380    };
381}
382
383impl_simd_signed_int_ops!(I32x4, i32, 4, M32);
384impl_simd_signed_int_ops!(I16x8, i16, 8, M16);
385impl_simd_signed_int_ops!(I8x16, i8, 16, M8);
386
387macro_rules! impl_extend {
388    ($src:ty, $elem:ty, $dst:ty) => {
389        impl Extend<$elem> for GenericIsa {
390            type Output = $dst;
391
392            fn extend(self, x: $src) -> ($dst, $dst) {
393                let extended = x.0.map(|x| x as <$dst as Simd>::Elem);
394                let low = array::from_fn(|i| extended[i]);
395                let high = array::from_fn(|i| extended[i + extended.len() / 2]);
396                (low.into(), high.into())
397            }
398        }
399    };
400}
401impl_extend!(I8x16, i8, I16x8);
402impl_extend!(I16x8, i16, I32x4);
403impl_extend!(U8x16, u8, U16x8);
404
405macro_rules! impl_concat {
406    ($elem:ty, $simd:ty) => {
407        impl Concat<$elem> for GenericIsa {
408            fn concat_low(self, a: $simd, b: $simd) -> $simd {
409                let half_len = a.0.len() / 2;
410                array::from_fn(|i| {
411                    if i < half_len {
412                        a.0[i]
413                    } else {
414                        b.0[i - half_len]
415                    }
416                })
417                .into()
418            }
419
420            fn concat_high(self, a: $simd, b: $simd) -> $simd {
421                let half_len = a.0.len() / 2;
422                array::from_fn(|i| {
423                    if i < half_len {
424                        a.0[half_len + i]
425                    } else {
426                        b.0[i]
427                    }
428                })
429                .into()
430            }
431        }
432    };
433}
434
435impl_concat!(i32, I32x4);
436
437macro_rules! impl_interleave {
438    ($elem:ty, $simd:ty) => {
439        impl Interleave<$elem> for GenericIsa {
440            fn interleave_low(self, a: $simd, b: $simd) -> $simd {
441                array::from_fn(|i| if i % 2 == 0 { a.0[i / 2] } else { b.0[i / 2] }).into()
442            }
443
444            fn interleave_high(self, a: $simd, b: $simd) -> $simd {
445                let start = a.0.len() / 2;
446                array::from_fn(|i| {
447                    if i % 2 == 0 {
448                        a.0[start + i / 2]
449                    } else {
450                        b.0[start + i / 2]
451                    }
452                })
453                .into()
454            }
455        }
456    };
457}
458impl_interleave!(i8, I8x16);
459impl_interleave!(i16, I16x8);
460impl_interleave!(u8, U8x16);
461
462impl_simd_int_ops!(U8x16, u8, 16, M8);
463impl_simd_int_ops!(U16x8, u16, 8, M16);
464
465impl ToFloat<i32> for GenericIsa {
466    type Output = F32x4;
467
468    fn to_float(self, x: I32x4) -> Self::Output {
469        F32x4(x.0.map(|x| x as f32))
470    }
471}
472
/// Conversion of a single element to the narrower type `T`, saturating at
/// the bounds of `T` instead of wrapping.
trait NarrowSaturateElem<T> {
    fn narrow_saturate(self) -> T;
}

impl NarrowSaturateElem<i16> for i32 {
    fn narrow_saturate(self) -> i16 {
        let (floor, ceil) = (i32::from(i16::MIN), i32::from(i16::MAX));
        // After limiting to `[floor, ceil]` the cast cannot truncate.
        self.max(floor).min(ceil) as i16
    }
}

impl NarrowSaturateElem<u8> for i16 {
    fn narrow_saturate(self) -> u8 {
        let (floor, ceil) = (i16::from(u8::MIN), i16::from(u8::MAX));
        // After limiting to `[floor, ceil]` the cast cannot truncate.
        self.max(floor).min(ceil) as u8
    }
}
488
489macro_rules! impl_narrow {
490    ($from:ident, $from_elem:ty, $to:ident, $to_elem:ty) => {
491        impl NarrowSaturate<$from_elem, $to_elem> for GenericIsa {
492            type Output = $to;
493
494            fn narrow_saturate(self, lo: $from, hi: $from) -> $to {
495                let mid = lo.0.len() / 2;
496                let xs = array::from_fn(|i| {
497                    let x = if i < mid { lo.0[i] } else { hi.0[i] };
498                    x.narrow_saturate()
499                });
500                $to(xs)
501            }
502        }
503    };
504}
505impl_narrow!(I32x4, i32, I16x8, i16);
506impl_narrow!(I16x8, i16, U8x16, u8);
507
508macro_rules! impl_mask {
509    ($mask:ident, $len:expr) => {
510        impl Mask for $mask {
511            type Array = [bool; $len];
512
513            #[inline]
514            fn to_array(self) -> Self::Array {
515                let array = self.0;
516                array::from_fn(|i| array[i] != 0)
517            }
518        }
519
520        unsafe impl MaskOps<$mask> for GenericIsa {
521            #[inline]
522            fn and(self, x: $mask, y: $mask) -> $mask {
523                let xs = array::from_fn(|i| x.0[i] & y.0[i]);
524                $mask(xs)
525            }
526
527            #[inline]
528            fn any(self, x: $mask) -> bool {
529                x.0.iter().any(|x| *x != 0)
530            }
531
532            #[inline]
533            fn all(self, x: $mask) -> bool {
534                x.0.iter().all(|x| *x != 0)
535            }
536        }
537    };
538}
539
540impl_mask!(M32, LEN_X32);
541impl_mask!(M16, LEN_X32 * 2);
542impl_mask!(M8, LEN_X32 * 4);
543
544macro_rules! impl_simd {
545    ($simd:ty, $elem:ty, $mask:ty, $len:expr) => {
546        impl Simd for $simd {
547            type Mask = $mask;
548            type Elem = $elem;
549            type Array = [$elem; $len];
550            type Isa = GenericIsa;
551
552            #[inline]
553            fn to_bits(self) -> <Self::Isa as Isa>::Bits {
554                #[allow(clippy::useless_transmute)]
555                I32x4(unsafe { transmute::<[$elem; $len], [i32; LEN_X32]>(self.0) })
556            }
557
558            #[inline]
559            fn from_bits(bits: <Self::Isa as Isa>::Bits) -> Self {
560                #[allow(clippy::useless_transmute)]
561                Self(unsafe { transmute::<[i32; LEN_X32], [$elem; $len]>(bits.0) })
562            }
563
564            #[inline]
565            fn to_array(self) -> Self::Array {
566                self.0
567            }
568        }
569    };
570}
571
572impl_simd!(F32x4, f32, M32, 4);
573impl_simd!(I32x4, i32, M32, 4);
574impl_simd!(I16x8, i16, M16, 8);
575impl_simd!(I8x16, i8, M8, 16);
576impl_simd!(U8x16, u8, M8, 16);
577impl_simd!(U16x8, u16, M16, 8);
578impl_simd!(U32x4, u32, M32, 4);