Skip to main content

diskann_wide/arch/aarch64/
i32x4_.rs

1/*
2 * Copyright (c) Microsoft Corporation. All rights reserved.
3 * Licensed under the MIT license.
4 */
5
6use crate::{
7    Emulated, SIMDAbs, SIMDCast, SIMDDotProduct, SIMDMask, SIMDMulAdd, SIMDPartialEq,
8    SIMDPartialOrd, SIMDSelect, SIMDSumTree, SIMDVector, constant::Const, helpers,
9};
10
// AArch64 architecture marker, sibling vector types, helper macros, and masks
12use super::{
13    Neon, f32x4, i8x8, i8x16, i16x8, internal,
14    macros::{self, AArchLoadStore, AArchSplat},
15    masks::mask32x4,
16    u8x8, u8x16,
17};
18
19// AArch64 intrinsics
20use std::arch::{aarch64::*, asm};
21
22///////////////////
23// 32-bit signed //
24///////////////////
25
26macros::aarch64_define_register!(i32x4, int32x4_t, mask32x4, i32, 4, Neon);
27macros::aarch64_define_splat!(i32x4, vmovq_n_s32);
28macros::aarch64_define_loadstore!(i32x4, vld1q_s32, internal::load_first::i32x4, vst1q_s32, 4);
29
30helpers::unsafe_map_binary_op!(i32x4, std::ops::Add, add, vaddq_s32, "neon");
31helpers::unsafe_map_binary_op!(i32x4, std::ops::Sub, sub, vsubq_s32, "neon");
32helpers::unsafe_map_binary_op!(i32x4, std::ops::Mul, mul, vmulq_s32, "neon");
33helpers::unsafe_map_unary_op!(i32x4, SIMDAbs, abs_simd, vabsq_s32, "neon");
34macros::aarch64_define_fma!(i32x4, vmlaq_s32);
35
36macros::aarch64_define_cmp!(
37    i32x4,
38    vceqq_s32,
39    (vmvnq_u32),
40    vcltq_s32,
41    vcleq_s32,
42    vcgtq_s32,
43    vcgeq_s32
44);
45macros::aarch64_define_bitops!(
46    i32x4,
47    vmvnq_s32,
48    vandq_s32,
49    vorrq_s32,
50    veorq_s32,
51    (
52        vshlq_s32,
53        32,
54        vnegq_s32,
55        vminq_u32,
56        vreinterpretq_s32_u32,
57        vreinterpretq_u32_s32
58    ),
59    (u32, i32, vmovq_n_s32),
60);
61
62impl SIMDSumTree for i32x4 {
63    #[inline(always)]
64    fn sum_tree(self) -> i32 {
65        if cfg!(miri) {
66            self.emulated().sum_tree()
67        } else {
68            // SAFETY: Allowed by the `Neon` architecture.
69            unsafe { vaddvq_s32(self.0) }
70        }
71    }
72}
73
74impl SIMDSelect<i32x4> for mask32x4 {
75    #[inline(always)]
76    fn select(self, x: i32x4, y: i32x4) -> i32x4 {
77        // SAFETY: Allowed by the `Neon` architecture.
78        i32x4(unsafe { vbslq_s32(self.0, x.0, y.0) })
79    }
80}
81
82impl SIMDDotProduct<i16x8> for i32x4 {
83    #[inline(always)]
84    fn dot_simd(self, left: i16x8, right: i16x8) -> Self {
85        if cfg!(miri) {
86            use crate::AsSIMD;
87            self.emulated()
88                .dot_simd(left.emulated(), right.emulated())
89                .as_simd(self.arch())
90        } else {
91            let left = left.0;
92            let right = right.0;
93            // SAFETY: Allowed by the `Neon` architecture.
94            unsafe {
95                let lo: int32x4_t = vmull_s16(vget_low_s16(left), vget_low_s16(right));
96                let hi: int32x4_t = vmull_high_s16(left, right);
97                Self(vaddq_s32(self.0, vpaddq_s32(lo, hi)))
98            }
99        }
100    }
101}
102
103impl SIMDDotProduct<u8x16, i8x16> for i32x4 {
104    #[inline(always)]
105    fn dot_simd(self, left: u8x16, right: i8x16) -> Self {
106        if cfg!(miri) {
107            use crate::AsSIMD;
108            self.emulated()
109                .dot_simd(left.emulated(), right.emulated())
110                .as_simd(self.arch())
111        } else {
112            use crate::SplitJoin;
113
114            // SAFETY: The intrinsics used here are allowed by the implicit `Neon` architecture.
115            unsafe {
116                let left = left.split();
117                let right = right.split();
118
119                let left_evens: i16x8 = u8x8(vuzp1_u8(left.lo.0, left.hi.0)).into();
120                let left_odds: i16x8 = u8x8(vuzp2_u8(left.lo.0, left.hi.0)).into();
121
122                let right_evens: i16x8 = i8x8(vuzp1_s8(right.lo.0, right.hi.0)).into();
123                let right_odds: i16x8 = i8x8(vuzp2_s8(right.lo.0, right.hi.0)).into();
124
125                self.dot_simd(left_evens, right_evens)
126                    .dot_simd(left_odds, right_odds)
127            }
128        }
129    }
130}
131
132impl SIMDDotProduct<i8x16, u8x16> for i32x4 {
133    #[inline(always)]
134    fn dot_simd(self, left: i8x16, right: u8x16) -> Self {
135        self.dot_simd(right, left)
136    }
137}
138
139impl SIMDDotProduct<i8x16, i8x16> for i32x4 {
140    #[inline(always)]
141    fn dot_simd(self, left: i8x16, right: i8x16) -> Self {
142        if cfg!(miri) {
143            use crate::AsSIMD;
144            self.emulated()
145                .dot_simd(left.emulated(), right.emulated())
146                .as_simd(self.arch())
147        } else {
148            // SAFETY: Instantiating `Neon` implies `dotprod`.
149            //
150            // We need this wrapper to allow compilation of the underlying ASM when compiling
151            // without the `dotprod` feature globally enabled.
152            #[target_feature(enable = "dotprod")]
153            unsafe fn sdot(mut s: int32x4_t, x: int8x16_t, y: int8x16_t) -> int32x4_t {
154                // SAFETY: The `Neon` architecture implies `dotprod`, allowing us to use
155                // this intrinsic.
156                unsafe {
157                    asm!(
158                        "sdot {0:v}.4s, {1:v}.16b, {2:v}.16b",
159                        inout(vreg) s,
160                        in(vreg) x,
161                        in(vreg) y,
162                        options(pure, nomem, nostack)
163                    );
164                }
165
166                s
167            }
168
169            // SAFETY: The `Neon` architecture guarantees the `dotprod` feature.
170            Self::from_underlying(self.arch(), unsafe { sdot(self.0, left.0, right.0) })
171        }
172    }
173}
174
175//-------------//
176// Conversions //
177//-------------//
178
179helpers::unsafe_map_cast!(
180    i32x4 => (f32, f32x4),
181    vcvtq_f32_s32,
182    "neon"
183);
184
185///////////
186// Tests //
187///////////
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils};

    // Every test is a no-op when `test_neon()` yields no architecture handle.

    // ---- Load / store ----

    #[test]
    fn miri_test_load() {
        let Some(arch) = test_neon() else {
            return;
        };
        test_utils::test_load_simd::<i32, 4, i32x4>(arch);
    }

    #[test]
    fn miri_test_store() {
        let Some(arch) = test_neon() else {
            return;
        };
        test_utils::test_store_simd::<i32, 4, i32x4>(arch);
    }

    // ---- Constructors ----

    #[test]
    fn test_constructors() {
        let Some(arch) = test_neon() else {
            return;
        };
        test_utils::ops::test_splat::<i32, 4, i32x4>(arch);
    }

    // ---- Arithmetic ----

    test_utils::ops::test_add!(i32x4, 0x3017fd73c99cc633, test_neon());
    test_utils::ops::test_sub!(i32x4, 0xfc627f10b5f8db8a, test_neon());
    test_utils::ops::test_mul!(i32x4, 0x0f4caa80eceaa523, test_neon());
    test_utils::ops::test_fma!(i32x4, 0xb8f702ba85375041, test_neon());
    test_utils::ops::test_abs!(i32x4, 0xb8f702ba85375041, test_neon());

    // ---- Comparisons ----

    test_utils::ops::test_cmp!(i32x4, 0x941757bd5cc641a1, test_neon());

    // ---- Bit operations and masked selection ----

    test_utils::ops::test_bitops!(i32x4, 0xd62d8de09f82ed4e, test_neon());
    test_utils::ops::test_select!(i32x4, 0xd62d8de09f82ed4e, test_neon());

    // ---- Dot products, one per operand-type pairing ----

    test_utils::dot_product::test_dot_product!((i16x8, i16x8) => i32x4, 0x145f89b446c03ff1, test_neon());

    test_utils::dot_product::test_dot_product!((u8x16, i8x16) => i32x4, 0x145f89b446c03ff1, test_neon());

    test_utils::dot_product::test_dot_product!((i8x16, u8x16) => i32x4, 0x145f89b446c03ff1, test_neon());

    test_utils::dot_product::test_dot_product!((i8x16, i8x16) => i32x4, 0x145f89b446c03ff1, test_neon());

    // ---- Reductions ----

    test_utils::ops::test_sumtree!(i32x4, 0xb9ac82ab23a855da, test_neon());

    // ---- Conversions ----

    test_utils::ops::test_cast!(i32x4 => f32x4, 0xba8fe343fc9dbeff, test_neon());
}