use crate::{
    Emulated, SIMDAbs, SIMDCast, SIMDDotProduct, SIMDMask, SIMDMulAdd, SIMDPartialEq,
    SIMDPartialOrd, SIMDSelect, SIMDSumTree, SIMDVector, constant::Const, helpers,
};

use super::{
    Neon, f32x4, i8x8, i8x16, i16x8, internal,
    macros::{self, AArchLoadStore, AArchSplat},
    masks::mask32x4,
    u8x8, u8x16,
};

use std::arch::{aarch64::*, asm};
21
// Define the `i32x4` register type: 4 lanes of `i32` backed by the NEON
// `int32x4_t` register, with `mask32x4` as its comparison-mask type and
// `Neon` as its architecture token.
macros::aarch64_define_register!(i32x4, int32x4_t, mask32x4, i32, 4, Neon);
// Broadcast a single `i32` into all 4 lanes via `vmovq_n_s32`.
macros::aarch64_define_splat!(i32x4, vmovq_n_s32);
// Full-width loads/stores map to `vld1q_s32`/`vst1q_s32`; partial loads of
// fewer than 4 elements go through the shared `internal::load_first` helper.
macros::aarch64_define_loadstore!(i32x4, vld1q_s32, internal::load_first::i32x4, vst1q_s32, 4);

// Lane-wise arithmetic, mapped directly onto the corresponding NEON
// intrinsics (all gated on the "neon" target feature).
helpers::unsafe_map_binary_op!(i32x4, std::ops::Add, add, vaddq_s32, "neon");
helpers::unsafe_map_binary_op!(i32x4, std::ops::Sub, sub, vsubq_s32, "neon");
helpers::unsafe_map_binary_op!(i32x4, std::ops::Mul, mul, vmulq_s32, "neon");
helpers::unsafe_map_unary_op!(i32x4, SIMDAbs, abs_simd, vabsq_s32, "neon");
// Fused multiply-add via `vmlaq_s32` (multiply-accumulate on i32 lanes).
macros::aarch64_define_fma!(i32x4, vmlaq_s32);

// Lane-wise comparisons producing `mask32x4`. The parenthesized `vmvnq_u32`
// presumably synthesizes "not equal" as the complement of `vceqq_s32`, since
// NEON has no direct `vcne` intrinsic — see the macro definition to confirm.
macros::aarch64_define_cmp!(
    i32x4,
    vceqq_s32,
    (vmvnq_u32),
    vcltq_s32,
    vcleq_s32,
    vcgtq_s32,
    vcgeq_s32
);
// Bitwise not/and/or/xor plus shift support. NOTE(review): the first tuple
// appears to wire up shifting — `vshlq_s32` with a 32-bit lane width,
// `vnegq_s32` to express right shifts as negative shift counts, `vminq_u32`
// to clamp shift amounts, and the two `vreinterpretq` casts to move between
// the signed and unsigned views; the second tuple supplies the scalar splat
// used for shift amounts. Confirm against the macro definition.
macros::aarch64_define_bitops!(
    i32x4,
    vmvnq_s32,
    vandq_s32,
    vorrq_s32,
    veorq_s32,
    (
        vshlq_s32,
        32,
        vnegq_s32,
        vminq_u32,
        vreinterpretq_s32_u32,
        vreinterpretq_u32_s32
    ),
    (u32, i32, vmovq_n_s32),
);
61
62impl SIMDSumTree for i32x4 {
63 #[inline(always)]
64 fn sum_tree(self) -> i32 {
65 if cfg!(miri) {
66 self.emulated().sum_tree()
67 } else {
68 unsafe { vaddvq_s32(self.0) }
70 }
71 }
72}
73
74impl SIMDSelect<i32x4> for mask32x4 {
75 #[inline(always)]
76 fn select(self, x: i32x4, y: i32x4) -> i32x4 {
77 i32x4(unsafe { vbslq_s32(self.0, x.0, y.0) })
79 }
80}
81
82impl SIMDDotProduct<i16x8> for i32x4 {
83 #[inline(always)]
84 fn dot_simd(self, left: i16x8, right: i16x8) -> Self {
85 if cfg!(miri) {
86 use crate::AsSIMD;
87 self.emulated()
88 .dot_simd(left.emulated(), right.emulated())
89 .as_simd(self.arch())
90 } else {
91 let left = left.0;
92 let right = right.0;
93 unsafe {
95 let lo: int32x4_t = vmull_s16(vget_low_s16(left), vget_low_s16(right));
96 let hi: int32x4_t = vmull_high_s16(left, right);
97 Self(vaddq_s32(self.0, vpaddq_s32(lo, hi)))
98 }
99 }
100 }
101}
102
103impl SIMDDotProduct<u8x16, i8x16> for i32x4 {
104 #[inline(always)]
105 fn dot_simd(self, left: u8x16, right: i8x16) -> Self {
106 if cfg!(miri) {
107 use crate::AsSIMD;
108 self.emulated()
109 .dot_simd(left.emulated(), right.emulated())
110 .as_simd(self.arch())
111 } else {
112 use crate::SplitJoin;
113
114 unsafe {
116 let left = left.split();
117 let right = right.split();
118
119 let left_evens: i16x8 = u8x8(vuzp1_u8(left.lo.0, left.hi.0)).into();
120 let left_odds: i16x8 = u8x8(vuzp2_u8(left.lo.0, left.hi.0)).into();
121
122 let right_evens: i16x8 = i8x8(vuzp1_s8(right.lo.0, right.hi.0)).into();
123 let right_odds: i16x8 = i8x8(vuzp2_s8(right.lo.0, right.hi.0)).into();
124
125 self.dot_simd(left_evens, right_evens)
126 .dot_simd(left_odds, right_odds)
127 }
128 }
129 }
130}
131
132impl SIMDDotProduct<i8x16, u8x16> for i32x4 {
133 #[inline(always)]
134 fn dot_simd(self, left: i8x16, right: u8x16) -> Self {
135 self.dot_simd(right, left)
136 }
137}
138
impl SIMDDotProduct<i8x16, i8x16> for i32x4 {
    // Signed 8-bit dot product accumulated into i32 lanes using the AArch64
    // `sdot` instruction (requires the "dotprod" target feature).
    #[inline(always)]
    fn dot_simd(self, left: i8x16, right: i8x16) -> Self {
        if cfg!(miri) {
            // Miri cannot execute inline assembly; use the emulated path.
            use crate::AsSIMD;
            self.emulated()
                .dot_simd(left.emulated(), right.emulated())
                .as_simd(self.arch())
        } else {
            // NOTE(review): `sdot` is emitted via inline asm rather than an
            // intrinsic — presumably because the matching stdarch intrinsic
            // is not usable here; confirm before switching to `vdotq_s32`.
            // The `target_feature` attribute records the dotprod requirement
            // in this helper's unsafe-call contract.
            #[target_feature(enable = "dotprod")]
            unsafe fn sdot(mut s: int32x4_t, x: int8x16_t, y: int8x16_t) -> int32x4_t {
                unsafe {
                    // s.4s += per-lane dot products of the four 4-byte groups
                    // of x and y. `pure` + `nomem`: output depends only on
                    // register inputs, enabling CSE of repeated calls.
                    asm!(
                        "sdot {0:v}.4s, {1:v}.16b, {2:v}.16b",
                        inout(vreg) s,
                        in(vreg) x,
                        in(vreg) y,
                        options(pure, nomem, nostack)
                    );
                }

                s
            }

            // SAFETY(review): assumes any target that constructs this type's
            // arch token also supports dotprod — TODO confirm that invariant.
            Self::from_underlying(self.arch(), unsafe { sdot(self.0, left.0, right.0) })
        }
    }
}
174
// Lane-wise conversion `i32x4 -> f32x4`, implemented with `vcvtq_f32_s32`
// (signed-integer-to-float conversion), gated on the "neon" target feature.
helpers::unsafe_map_cast!(
    i32x4 => (f32, f32x4),
    vcvtq_f32_s32,
    "neon"
);
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils};

    // Every test is gated on `test_neon()` returning an arch token, so the
    // suite is skipped on targets without NEON support.

    #[test]
    fn miri_test_load() {
        if let Some(arch) = test_neon() {
            test_utils::test_load_simd::<i32, 4, i32x4>(arch);
        }
    }

    #[test]
    fn miri_test_store() {
        if let Some(arch) = test_neon() {
            test_utils::test_store_simd::<i32, 4, i32x4>(arch);
        }
    }

    #[test]
    fn test_constructors() {
        if let Some(arch) = test_neon() {
            test_utils::ops::test_splat::<i32, 4, i32x4>(arch);
        }
    }

    // Operator tests generated by the shared test macros. The hex literals
    // are fixed RNG seeds, kept stable for reproducibility — do not change.
    test_utils::ops::test_add!(i32x4, 0x3017fd73c99cc633, test_neon());
    test_utils::ops::test_sub!(i32x4, 0xfc627f10b5f8db8a, test_neon());
    test_utils::ops::test_mul!(i32x4, 0x0f4caa80eceaa523, test_neon());
    test_utils::ops::test_fma!(i32x4, 0xb8f702ba85375041, test_neon());
    test_utils::ops::test_abs!(i32x4, 0xb8f702ba85375041, test_neon());

    test_utils::ops::test_cmp!(i32x4, 0x941757bd5cc641a1, test_neon());

    test_utils::ops::test_bitops!(i32x4, 0xd62d8de09f82ed4e, test_neon());
    test_utils::ops::test_select!(i32x4, 0xd62d8de09f82ed4e, test_neon());

    // One dot-product test per accumulator/operand combination implemented
    // above (all four share the same seed).
    test_utils::dot_product::test_dot_product!(
        (i16x8, i16x8) => i32x4,
        0x145f89b446c03ff1,
        test_neon()
    );

    test_utils::dot_product::test_dot_product!(
        (u8x16, i8x16) => i32x4,
        0x145f89b446c03ff1,
        test_neon()
    );

    test_utils::dot_product::test_dot_product!(
        (i8x16, u8x16) => i32x4,
        0x145f89b446c03ff1,
        test_neon()
    );

    test_utils::dot_product::test_dot_product!(
        (i8x16, i8x16) => i32x4,
        0x145f89b446c03ff1,
        test_neon()
    );

    test_utils::ops::test_sumtree!(i32x4, 0xb9ac82ab23a855da, test_neon());

    test_utils::ops::test_cast!(i32x4 => f32x4, 0xba8fe343fc9dbeff, test_neon());
}