Skip to main content

carbon_simd/
lib.rs

//! Traits and types for SIMD.
2//!
3//! # Example
4//! ```
5//! use carbon_simd::*;
6//!
7//! let mut left = [1, 2, 3, 4, 5, 6, 7, 8];
8//! let right = [7, 6, 5, 4, 3, 2, 1, 0];
9//!
10//! let mut left_simd = SimdMut::new(&mut left);
11//! let right_simd = SimdRef::new(&right);
12//!
13//! left_simd += &right_simd;
14//!
15//! assert_eq!(left, [8, 8, 8, 8, 8, 8, 8, 8]);
16//!
17//! ```
18//!
19
20#[cfg(target_arch = "x86_64")]
21mod x86_64;
22
23use core::ops::*;
24use num_traits::Float;
25use num_traits::Num;
26
27/// SIMD vector trait
28pub trait Simd<T: SimdElement>: Deref<Target = [T]> {}
29
/// A trait for types that may be used as SIMD vector elements.
///
/// # Safety
///
/// The safe wrappers in this crate (the methods and operators of `SimdMut`)
/// check only `Self::is_available()` and the slice lengths before calling the
/// `unsafe` functions below. An implementation must therefore guarantee that:
///
/// * `VECTOR_LEN` is non-zero (it is used as a divisor) and is exactly the
///   number of elements that `load` reads and `store` writes;
/// * `load_partial` reads and `store_partial` writes no more than `len`
///   elements;
/// * every `unsafe` function is sound to call whenever `is_available()`
///   returns `true` and the pointer requirements documented on it are met.
pub unsafe trait SimdElement: Sized {
    /// Raw SIMD vector type, like `__m256i`.
    type Vector: Sized + Copy;
    /// Number of elements held by one `Vector`.
    const VECTOR_LEN: usize;

    /// Returns whether the SIMD functions of this implementation can be used.
    fn is_available() -> bool;

    /// Loads `Self::VECTOR_LEN` values starting at `src` into a raw SIMD vector.
    ///
    /// # Safety
    ///
    /// `Self::is_available()` must return true and `src` must be valid for
    /// reads of `Self::VECTOR_LEN` elements.
    unsafe fn load(src: *const Self) -> Self::Vector;

    /// Loads the first `len` values starting at `src` into a raw SIMD vector.
    ///
    /// The safe wrappers in this crate use it for the tail of a slice and
    /// always pass `0 < len < Self::VECTOR_LEN`.
    ///
    /// # Safety
    ///
    /// `Self::is_available()` must return true and `src` must be valid for
    /// reads of `len` elements.
    unsafe fn load_partial(src: *const Self, len: usize) -> Self::Vector;

    /// Stores all `Self::VECTOR_LEN` elements of `src` starting at `dst`.
    ///
    /// # Safety
    ///
    /// `Self::is_available()` must return true and `dst` must be valid for
    /// writes of `Self::VECTOR_LEN` elements.
    unsafe fn store(dst: *mut Self, src: Self::Vector);

    /// Stores the first `len` elements of `src` starting at `dst`.
    ///
    /// # Safety
    ///
    /// `Self::is_available()` must return true and `dst` must be valid for
    /// writes of `len` elements.
    unsafe fn store_partial(dst: *mut Self, src: Self::Vector, len: usize);

    /// Creates a raw SIMD vector filled with `value`.
    ///
    /// # Safety
    ///
    /// `Self::is_available()` must return true.
    unsafe fn set(value: Self) -> Self::Vector;
}
65
66/// A trait for type that may be used as numeric SIMD vector elements.
67pub unsafe trait SimdNumElement: SimdElement + Num {
68    /// Adds `left` and `right`.
69    /// # Safety
70    /// Make sure `Self::is_available()` returns true.
71    unsafe fn add(left: Self::Vector, right: Self::Vector) -> Self::Vector;
72
73    /// Substructs `left` and `right`.
74    /// # Safety
75    /// Make sure `Self::is_available()` returns true.
76    unsafe fn sub(left: Self::Vector, right: Self::Vector) -> Self::Vector;
77
78    /// Multiples `left` and `right`.
79    /// # Safety
80    /// Make sure `Self::is_available()` returns true.
81    unsafe fn mul(left: Self::Vector, right: Self::Vector) -> Self::Vector;
82
83    /// Divides `left` by `right`, then return the result.
84    /// # Safety
85    /// Make sure `Self::is_available()` returns true.
86    unsafe fn div(left: Self::Vector, right: Self::Vector) -> Self::Vector;
87
88    /// Calculates `a * b + c`. You should overload this when the architecture has dedicated operator.
89    /// # Safety
90    /// Make sure `Self::is_available()` returns true.
91    #[inline(always)]
92    unsafe fn fma(a: Self::Vector, b: Self::Vector, c: Self::Vector) -> Self::Vector {
93        unsafe { <Self as SimdNumElement>::add(<Self as SimdNumElement>::mul(a, b), c) }
94    }
95
96    /// Raises a number to an integer power.
97    /// # Safety
98    /// Make sure `Self::is_available()` returns true.
99    #[inline(always)]
100    unsafe fn powi(x: Self::Vector, n: i32) -> Self::Vector {
101        let mut result = unsafe { Self::set(Self::one()) };
102        if n < 0 {
103            for _ in 0..-n {
104                result = unsafe { <Self as SimdNumElement>::div(result, x) };
105            }
106        }
107        if 0 < n {
108            for _ in 0..n {
109                result = unsafe { <Self as SimdNumElement>::mul(result, x) };
110            }
111        }
112
113        result
114    }
115}
116
117/// A trait for types that may be used as floating point SIMD vector elements.
118pub unsafe trait SimdFloatingElement: SimdNumElement + Float {
119    /// Returns sqrt of `x`.
120    /// # Safety
121    /// Make sure `Self::is_available()` returns true.
122    unsafe fn sqrt(x: Self::Vector) -> Self::Vector;
123
124    /// Returns exp(x).
125    /// # Safety
126    /// Make sure `Self::is_available()` returns true.
127    unsafe fn exp(x: Self::Vector) -> Self::Vector;
128
129    /// Returns tanh of `x`.
130    /// # Safety
131    /// Make sure `Self::is_available()` returns true.
132    unsafe fn tanh(x: Self::Vector) -> Self::Vector;
133}
134
135/// A trait for types that may be used as integer SIMD vector elements.
136pub unsafe trait SimdIntegerElement: SimdNumElement + Num {
137    /// Calculate and of `left` and `right`.
138    /// # Safety
139    /// Make sure `Self::is_available()` returns true.
140    unsafe fn and(left: Self::Vector, right: Self::Vector) -> Self::Vector;
141
142    /// Calculates or of `left` and `right`.
143    /// # Safety
144    /// Make sure `Self::is_available()` returns true.
145    unsafe fn or(left: Self::Vector, right: Self::Vector) -> Self::Vector;
146
147    /// Calculates not of `left` and `right`.
148    /// # Safety
149    /// Make sure `Self::is_available()` returns true.
150    unsafe fn not(left: Self::Vector) -> Self::Vector;
151
152    /// Calculates xor of `left` and `right`.
153    /// # Safety
154    /// Make sure `Self::is_available()` returns true.
155    unsafe fn xor(left: Self::Vector, right: Self::Vector) -> Self::Vector;
156}
157
158/// Mutable SIMD wrapper structure
159#[repr(transparent)]
160#[derive(Debug, PartialEq, Clone, Copy)]
161pub struct SimdRef<'a, T: SimdElement>(&'a [T]);
162
163impl<'a, T: SimdElement> SimdRef<'a, T> {
164    /// Creates new `SimdRef<T>`.
165    pub fn new(slice: &'a [T]) -> Self {
166        Self(slice)
167    }
168}
169
170impl<'a, T: SimdElement> Deref for SimdRef<'a, T> {
171    type Target = [T];
172
173    fn deref(&self) -> &[T] {
174        &self.0
175    }
176}
177
178impl<'a, T: SimdElement> Simd<T> for SimdRef<'a, T> {}
179
180/// Mutable SIMD wrapper structure
181#[derive(Debug, PartialEq)]
182#[repr(transparent)]
183pub struct SimdMut<'a, T: SimdElement>(&'a mut [T]);
184
185impl<'a, T: SimdElement> SimdMut<'a, T> {
186    /// Creates new `SimdMut<T>`.
187    pub fn new(slice: &'a mut [T]) -> Self {
188        Self(slice)
189    }
190}
191
192impl<'a, T: SimdNumElement> SimdMut<'a, T> {
193    /// Raises a number to an integer power.
194    pub fn powi(&mut self, n: i32) {
195        if !T::is_available() {
196            panic!("simd is not available");
197        }
198
199        let len = self.len();
200        let x = self.as_mut_ptr();
201
202        unsafe {
203            for i in 0..len / T::VECTOR_LEN {
204                let offset = i * T::VECTOR_LEN;
205                let x_vector = T::load(x.add(offset));
206                let result_vector = T::powi(x_vector, n);
207                T::store(x.add(offset), result_vector);
208            }
209
210            let remaining = len % T::VECTOR_LEN;
211            if remaining != 0 {
212                let offset = len - remaining;
213                let x_vector = T::load_partial(x.add(offset), remaining);
214                let result_vector = T::powi(x_vector, n);
215                T::store(x.add(offset), result_vector);
216            }
217        }
218    }
219}
220
221impl<'a, T: SimdFloatingElement> SimdMut<'a, T> {
222    /// Calculates square root.
223    pub fn sqrt(&mut self) {
224        if !T::is_available() {
225            panic!("simd is not available");
226        }
227
228        let len = self.len();
229        let x = self.as_mut_ptr();
230
231        unsafe {
232            for i in 0..len / T::VECTOR_LEN {
233                let offset = i * T::VECTOR_LEN;
234                let x_vector = T::load(x.add(offset));
235                let result_vector = <T as SimdFloatingElement>::sqrt(x_vector);
236                T::store(x.add(offset), result_vector);
237            }
238
239            let remaining = len % T::VECTOR_LEN;
240            if remaining != 0 {
241                let offset = len - remaining;
242                let x_vector = T::load_partial(x.add(offset), remaining);
243                let result_vector = <T as SimdFloatingElement>::sqrt(x_vector);
244                T::store(x.add(offset), result_vector);
245            }
246        }
247    }
248
249    /// Calculates `e^self`
250    pub fn exp(&mut self) {
251        if !T::is_available() {
252            panic!("simd is not available");
253        }
254
255        let len = self.len();
256        let x = self.as_mut_ptr();
257
258        unsafe {
259            for i in 0..len / T::VECTOR_LEN {
260                let offset = i * T::VECTOR_LEN;
261                let x_vector = T::load(x.add(offset));
262                let result_vector = <T as SimdFloatingElement>::exp(x_vector);
263                T::store(x.add(offset), result_vector);
264            }
265
266            let remaining = len % T::VECTOR_LEN;
267            if remaining != 0 {
268                let offset = len - remaining;
269                let x_vector = T::load_partial(x.add(offset), remaining);
270                let result_vector = <T as SimdFloatingElement>::exp(x_vector);
271                T::store(x.add(offset), result_vector);
272            }
273        }
274    }
275
276    /// Calculates `tanh`
277    pub fn tanh(&mut self) {
278        if !T::is_available() {
279            panic!("simd is not available");
280        }
281
282        let len = self.len();
283        let x = self.as_mut_ptr();
284
285        unsafe {
286            for i in 0..len / T::VECTOR_LEN {
287                let offset = i * T::VECTOR_LEN;
288                let x_vector = T::load(x.add(offset));
289                let result_vector = <T as SimdFloatingElement>::tanh(x_vector);
290                T::store(x.add(offset), result_vector);
291            }
292
293            let remaining = len % T::VECTOR_LEN;
294            if remaining != 0 {
295                let offset = len - remaining;
296                let x_vector = T::load_partial(x.add(offset), remaining);
297                let result_vector = <T as SimdFloatingElement>::tanh(x_vector);
298                T::store(x.add(offset), result_vector);
299            }
300        }
301    }
302}
303
304impl<'a, T: SimdElement + SimdNumElement, R: Simd<T>> AddAssign<&R> for SimdMut<'a, T> {
305    fn add_assign(&mut self, rhs: &R) {
306        if !T::is_available() {
307            panic!("simd is not available");
308        }
309
310        let len = self.len().min(rhs.len());
311        let left = self.as_mut_ptr();
312        let right = rhs.as_ptr();
313
314        unsafe {
315            for i in 0..len / T::VECTOR_LEN {
316                let offset = i * T::VECTOR_LEN;
317                let left_vector = T::load(left.add(offset));
318                let right_vector = T::load(right.add(offset));
319                let result_vector = <T as SimdNumElement>::add(left_vector, right_vector);
320                T::store(left.add(offset), result_vector);
321            }
322
323            let remaining = len % T::VECTOR_LEN;
324            if remaining != 0 {
325                let offset = len - remaining;
326                let left_vector = T::load_partial(left.add(offset), remaining);
327                let right_vector = T::load_partial(right.add(offset), remaining);
328                let result_vector = <T as SimdNumElement>::add(left_vector, right_vector);
329                T::store_partial(left.add(offset), result_vector, remaining);
330            }
331        }
332    }
333}
334
335impl<'a, T: SimdElement + SimdNumElement, R: Simd<T>> SubAssign<&R> for SimdMut<'a, T> {
336    fn sub_assign(&mut self, rhs: &R) {
337        let len = self.len().min(rhs.len());
338        let left = self.as_mut_ptr();
339        let right = rhs.as_ptr();
340
341        if !T::is_available() {
342            panic!("simd is not available");
343        }
344
345        unsafe {
346            for i in 0..len / T::VECTOR_LEN {
347                let offset = i * T::VECTOR_LEN;
348                let left_vector = T::load(left.add(offset));
349                let right_vector = T::load(right.add(offset));
350                let result_vector = <T as SimdNumElement>::sub(left_vector, right_vector);
351                T::store(left.add(offset), result_vector);
352            }
353
354            let remaining = len % T::VECTOR_LEN;
355            if remaining != 0 {
356                let offset = len - remaining;
357                let left_vector = T::load_partial(left.add(offset), remaining);
358                let right_vector = T::load_partial(right.add(offset), remaining);
359                let result_vector = <T as SimdNumElement>::sub(left_vector, right_vector);
360                T::store_partial(left.add(offset), result_vector, remaining);
361            }
362        }
363    }
364}
365
366impl<'a, T: SimdElement + SimdNumElement, R: Simd<T>> MulAssign<&R> for SimdMut<'a, T> {
367    fn mul_assign(&mut self, rhs: &R) {
368        let len = self.len().min(rhs.len());
369        let left = self.as_mut_ptr();
370        let right = rhs.as_ptr();
371
372        if !T::is_available() {
373            panic!("simd is not available");
374        }
375
376        unsafe {
377            for i in 0..len / T::VECTOR_LEN {
378                let offset = i * T::VECTOR_LEN;
379                let left_vector = T::load(left.add(offset));
380                let right_vector = T::load(right.add(offset));
381                let result_vector = <T as SimdNumElement>::mul(left_vector, right_vector);
382                T::store(left.add(offset), result_vector);
383            }
384
385            let remaining = len % T::VECTOR_LEN;
386            if remaining != 0 {
387                let offset = len - remaining;
388                let left_vector = T::load_partial(left.add(offset), remaining);
389                let right_vector = T::load_partial(right.add(offset), remaining);
390                let result_vector = <T as SimdNumElement>::mul(left_vector, right_vector);
391                T::store_partial(left.add(offset), result_vector, remaining);
392            }
393        }
394    }
395}
396
397impl<'a, T: SimdElement + SimdNumElement, R: Simd<T>> DivAssign<&R> for SimdMut<'a, T> {
398    fn div_assign(&mut self, rhs: &R) {
399        let len = self.len().min(rhs.len());
400        let left = self.as_mut_ptr();
401        let right = rhs.as_ptr();
402
403        if !T::is_available() {
404            panic!("simd is not available");
405        }
406
407        unsafe {
408            for i in 0..len / T::VECTOR_LEN {
409                let offset = i * T::VECTOR_LEN;
410                let left_vector = T::load(left.add(offset));
411                let right_vector = T::load(right.add(offset));
412                let result_vector = <T as SimdNumElement>::div(left_vector, right_vector);
413                T::store(left.add(offset), result_vector);
414            }
415
416            let remaining = len % T::VECTOR_LEN;
417            if remaining != 0 {
418                let offset = len - remaining;
419                let left_vector = T::load_partial(left.add(offset), remaining);
420                let right_vector = T::load_partial(right.add(offset), remaining);
421                let result_vector = <T as SimdNumElement>::div(left_vector, right_vector);
422                T::store_partial(left.add(offset), result_vector, remaining);
423            }
424        }
425    }
426}
427
428impl<'a, T: SimdElement> Deref for SimdMut<'a, T> {
429    type Target = [T];
430
431    fn deref(&self) -> &[T] {
432        &self.0
433    }
434}
435
436impl<'a, T: SimdElement> DerefMut for SimdMut<'a, T> {
437    fn deref_mut(&mut self) -> &mut [T] {
438        &mut self.0
439    }
440}
441
442impl<'a, T: SimdElement> Simd<T> for SimdMut<'a, T> {}