lance_linalg/simd/i32.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::fmt::Formatter;
use std::ops::{Add, AddAssign, Mul, Sub, SubAssign};

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "loongarch64")]
use std::arch::loongarch64::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "loongarch64")]
use std::mem::transmute;

use super::SIMD;
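// `i32x8` is a logical vector of eight `i32` lanes. The backing storage
// differs per architecture: one 256-bit AVX2 register (`__m256i`) on
// x86_64, a pair of 128-bit NEON registers (`int32x4x2_t`) on aarch64,
// and one 256-bit LASX vector (`v8i32`) on loongarch64.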
#[allow(non_camel_case_types)]
#[cfg(target_arch = "x86_64")]
#[derive(Clone, Copy)]
pub struct i32x8(pub(crate) __m256i);

#[allow(non_camel_case_types)]
#[cfg(target_arch = "aarch64")]
#[derive(Clone, Copy)]
pub struct i32x8(int32x4x2_t);

#[allow(non_camel_case_types)]
#[cfg(target_arch = "loongarch64")]
#[derive(Clone, Copy)]
pub struct i32x8(v8i32);
impl std::fmt::Debug for i32x8 {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Spill the lanes into a temporary array so they can be formatted.
        let mut arr = [0; 8];
        unsafe {
            self.store_unaligned(arr.as_mut_ptr());
        }
        write!(f, "i32x8({:?})", arr)
    }
}

impl From<&[i32]> for i32x8 {
    fn from(value: &[i32]) -> Self {
        // No length check: the slice must contain at least 8 elements.
        unsafe { Self::load_unaligned(value.as_ptr()) }
    }
}

impl From<&[i32; 8]> for i32x8 {
    fn from(value: &[i32; 8]) -> Self {
        unsafe { Self::load_unaligned(value.as_ptr()) }
    }
}

impl SIMD<i32, 8> for i32x8 {
    #[inline]
    fn splat(val: i32) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_set1_epi32(val))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(int32x4x2_t(vdupq_n_s32(val), vdupq_n_s32(val)))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvreplgr2vr_w(val))
        }
    }

    #[inline]
    fn zeros() -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_setzero_si256())
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self::splat(0)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self::splat(0)
        }
    }

    #[inline]
    unsafe fn load(ptr: *const i32) -> Self {
        // Identical to `load_unaligned`: no alignment is assumed on any
        // architecture.
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_loadu_si256(ptr as *const __m256i))
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self(vld1q_s32_x2(ptr))
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self(transmute(lasx_xvld::<0>(transmute(ptr))))
        }
    }

    #[inline]
    unsafe fn load_unaligned(ptr: *const i32) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_loadu_si256(ptr as *const __m256i))
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self(vld1q_s32_x2(ptr))
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self(transmute(lasx_xvld::<0>(transmute(ptr))))
        }
    }

    #[inline]
    unsafe fn store(&self, ptr: *mut i32) {
        self.store_unaligned(ptr)
    }

    unsafe fn store_unaligned(&self, ptr: *mut i32) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            _mm256_storeu_si256(ptr as *mut __m256i, self.0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            vst1q_s32_x2(ptr, self.0)
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            lasx_xvst::<0>(transmute(self.0), transmute(ptr))
        }
    }

    fn reduce_sum(&self) -> i32 {
        #[cfg(target_arch = "x86_64")]
        {
            // Scalar fallback: AVX2 lacks a single full horizontal
            // reduction for packed i32.
            self.as_array().iter().sum()
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            // Add the two halves, then reduce across the remaining lanes.
            let sum = vaddq_s32(self.0 .0, self.0 .1);
            vaddvq_s32(sum)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            self.as_array().iter().sum()
        }
    }

    fn reduce_min(&self) -> i32 {
        // Not yet implemented on any architecture.
        todo!()
    }

    fn min(&self, rhs: &Self) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_min_epi32(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(int32x4x2_t(
                vminq_s32(self.0 .0, rhs.0 .0),
                vminq_s32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvmin_w(self.0, rhs.0))
        }
    }

    fn find(&self, val: i32) -> Option<i32> {
        // Returns the index of the first lane equal to `val`, if any.
        #[cfg(target_arch = "x86_64")]
        unsafe {
            for i in 0..8 {
                if self.as_array().get_unchecked(i) == &val {
                    return Some(i as i32);
                }
            }
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            // Compare both halves against the target, spill the masks, and
            // scan for the first non-zero lane.
            let tgt = vdupq_n_s32(val);
            let mut arr = [0; 8];
            let mask1 = vceqq_s32(self.0 .0, tgt);
            let mask2 = vceqq_s32(self.0 .1, tgt);
            vst1q_u32(arr.as_mut_ptr(), mask1);
            vst1q_u32(arr.as_mut_ptr().add(4), mask2);
            for i in 0..8 {
                if arr.get_unchecked(i) != &0 {
                    return Some(i as i32);
                }
            }
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            for i in 0..8 {
                if self.as_array().get_unchecked(i) == &val {
                    return Some(i as i32);
                }
            }
        }
        None
    }
}

impl Add for i32x8 {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_add_epi32(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(int32x4x2_t(
                vaddq_s32(self.0 .0, rhs.0 .0),
                vaddq_s32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvadd_w(self.0, rhs.0))
        }
    }
}

impl AddAssign for i32x8 {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            self.0 = _mm256_add_epi32(self.0, rhs.0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vaddq_s32(self.0 .0, rhs.0 .0);
            self.0 .1 = vaddq_s32(self.0 .1, rhs.0 .1);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvadd_w(self.0, rhs.0);
        }
    }
}

impl Sub for i32x8 {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: Self) -> Self::Output {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_sub_epi32(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(int32x4x2_t(
                vsubq_s32(self.0 .0, rhs.0 .0),
                vsubq_s32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvsub_w(self.0, rhs.0))
        }
    }
}

impl SubAssign for i32x8 {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            self.0 = _mm256_sub_epi32(self.0, rhs.0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vsubq_s32(self.0 .0, rhs.0 .0);
            self.0 .1 = vsubq_s32(self.0 .1, rhs.0 .1);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvsub_w(self.0, rhs.0);
        }
    }
}

impl Mul for i32x8 {
    type Output = Self;

    #[inline]
    fn mul(self, rhs: Self) -> Self::Output {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            // `_mm256_mullo_epi32` multiplies all eight lanes and keeps the
            // low 32 bits of each product; `_mm256_mul_epi32` would only
            // multiply the even lanes into 64-bit results.
            Self(_mm256_mullo_epi32(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(int32x4x2_t(
                vmulq_s32(self.0 .0, rhs.0 .0),
                vmulq_s32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvmul_w(self.0, rhs.0))
        }
    }
}

#[cfg(test)]
mod tests {
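    use super::*;

    // A minimal smoke-test sketch added during editing; it only uses the
    // constructors and methods defined in this file and assumes the build
    // enables the matching SIMD target features (AVX2, NEON, or LASX).
    #[test]
    fn test_arithmetic_and_find() {
        let a = i32x8::from(&[1, 2, 3, 4, 5, 6, 7, 8]);
        let b = i32x8::splat(10);

        // (1 + 10) + (2 + 10) + ... + (8 + 10) = 36 + 80 = 116
        assert_eq!((a + b).reduce_sum(), 116);

        // Element-wise multiply doubles each lane, so the sum is 2 * 36.
        assert_eq!((a * i32x8::splat(2)).reduce_sum(), 72);

        // `find` returns the index of the first matching lane.
        assert_eq!(a.find(5), Some(4));
        assert_eq!(a.find(42), None);
    }
}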