// diskann_wide/arch/aarch64/f32x4_.rs
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
5
6use half::f16;
7
8use crate::{
9    Emulated, SIMDAbs, SIMDMask, SIMDMinMax, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDSelect,
10    SIMDSumTree, SIMDVector, constant::Const, helpers,
11};
12
13// AArch64 masks
14use super::{
15    Neon, f16x4, f32x2, internal,
16    macros::{self, AArchLoadStore, AArchSplat},
17    masks::mask32x4,
18};
19
20// AArch64 intrinsics
21use std::arch::{aarch64::*, asm};
22
/////////////////////
// 32-bit floating //
/////////////////////

// Define the `f32x4` register type: 4 `f32` lanes backed by the NEON
// `float32x4_t` register, with `mask32x4` as its mask type and `Neon` as the
// architecture token gating construction.
macros::aarch64_define_register!(f32x4, float32x4_t, mask32x4, f32, 4, Neon);
// Broadcast a single `f32` to all 4 lanes via `vmovq_n_f32`.
macros::aarch64_define_splat!(f32x4, vmovq_n_f32);
// Full loads/stores via `vld1q_f32`/`vst1q_f32`; partial loads go through the
// `internal::load_first` helper (lane count 4).
macros::aarch64_define_loadstore!(f32x4, vld1q_f32, internal::load_first::f32x4, vst1q_f32, 4);
// Split into low/high `f32x2` halves and rejoin with `vcombine_f32`.
macros::aarch64_splitjoin!(f32x4, f32x2, vget_low_f32, vget_high_f32, vcombine_f32);

// Lane-wise arithmetic: each macro maps a std op trait onto the matching NEON
// intrinsic, guarded by the "neon" target feature.
helpers::unsafe_map_binary_op!(f32x4, std::ops::Add, add, vaddq_f32, "neon");
helpers::unsafe_map_binary_op!(f32x4, std::ops::Sub, sub, vsubq_f32, "neon");
helpers::unsafe_map_binary_op!(f32x4, std::ops::Mul, mul, vmulq_f32, "neon");
helpers::unsafe_map_unary_op!(f32x4, SIMDAbs, abs_simd, vabsq_f32, "neon");
// Fused multiply-add via `vfmaq_f32` (single rounding step).
macros::aarch64_define_fma!(f32x4, vfmaq_f32);
37
38impl SIMDMinMax for f32x4 {
39    #[inline(always)]
40    fn min_simd(self, rhs: Self) -> Self {
41        // SAFETY: `vminnmq_f32` requires "neon", implied by the `Neon` architecture.
42        Self(unsafe { vminnmq_f32(self.0, rhs.0) })
43    }
44
45    #[inline(always)]
46    fn min_simd_standard(self, rhs: Self) -> Self {
47        // SAFETY: `vminnmq_f32` requires "neon", implied by the `Neon` architecture.
48        Self(unsafe { vminnmq_f32(self.0, rhs.0) })
49    }
50
51    #[inline(always)]
52    fn max_simd(self, rhs: Self) -> Self {
53        // SAFETY: `vmaxnmq_f32` requires "neon", implied by the `Neon` architecture.
54        Self(unsafe { vmaxnmq_f32(self.0, rhs.0) })
55    }
56
57    #[inline(always)]
58    fn max_simd_standard(self, rhs: Self) -> Self {
59        // SAFETY: `vmaxnmq_f32` requires "neon", implied by the `Neon` architecture.
60        Self(unsafe { vmaxnmq_f32(self.0, rhs.0) })
61    }
62}
63
// Lane-wise comparisons producing `mask32x4`. The parenthesized `(vmvnq_u32)`
// argument is a bitwise-NOT, presumably used to synthesize `!=` from the
// `vceqq_f32` equality result (NEON has no direct not-equal comparison) —
// confirm against the macro definition.
macros::aarch64_define_cmp!(
    f32x4,
    vceqq_f32,
    (vmvnq_u32),
    vcltq_f32,
    vcleq_f32,
    vcgtq_f32,
    vcgeq_f32
);
73
74impl SIMDSumTree for f32x4 {
75    #[inline(always)]
76    fn sum_tree(self) -> f32 {
77        // Miri does not support `vaddv_f32`.
78        if cfg!(miri) {
79            self.emulated().sum_tree()
80        } else {
81            // NOTE: `vaddvq` does not do a tree reduction, so we need to do a bit of work
82            // manually.
83            let x = self.to_underlying();
84            // SAFETY: Allowed by the implicit `Neon` architecture.
85            unsafe {
86                let low = vget_low_f32(x);
87                let high = vget_high_f32(x);
88                vaddv_f32(vadd_f32(low, high))
89            }
90        }
91    }
92}
93
94impl SIMDSelect<f32x4> for mask32x4 {
95    #[inline(always)]
96    fn select(self, x: f32x4, y: f32x4) -> f32x4 {
97        // SAFETY: Allowed by the implicit `Neon` architecture.
98        f32x4(unsafe { vbslq_f32(self.0, x.0, y.0) })
99    }
100}
101
//------------//
// Conversion //
//------------//

// Rust does not expose any of the f16 style intrinsics, so we need to drop down straight
// into inline assembly.
impl From<f16x4> for f32x4 {
    /// Widen 4 half-precision lanes to 4 single-precision lanes.
    /// Every `f16` is exactly representable as `f32`, so this is lossless.
    #[inline(always)]
    fn from(value: f16x4) -> f32x4 {
        if cfg!(miri) {
            // Miri cannot execute inline assembly; convert lane-by-lane in software.
            Self::from_array(value.arch(), value.to_array().map(crate::cast_f16_to_f32))
        } else {
            let raw = value.0;
            let result: float32x4_t;
            // SAFETY: The instruction we are running is available with the `neon` platform,
            // just not exposed by Rust's intrinsics.
            unsafe {
                asm!(
                    // FCVTL widens the 4 f16 lanes in the low half of the source
                    // vector register into 4 f32 lanes of the destination.
                    "fcvtl {0:v}.4s, {1:v}.4h",
                    out(vreg) result,
                    in(vreg) raw,
                    // pure/nomem/nostack: register-to-register with no side
                    // effects, so the compiler may CSE or reorder the asm.
                    options(pure, nomem, nostack)
                );
            }
            Self(result)
        }
    }
}
130
// Narrowing conversion f32 -> f16; lossy in general (values are rounded to a
// representable f16 — rounding mode follows the hardware default; the miri
// path must match `crate::cast_f32_to_f16`, TODO confirm).
impl crate::SIMDCast<f16> for f32x4 {
    type Cast = f16x4;
    /// Narrow 4 single-precision lanes to 4 half-precision lanes.
    #[inline(always)]
    fn simd_cast(self) -> f16x4 {
        if cfg!(miri) {
            // Miri cannot execute inline assembly; convert lane-by-lane in software.
            f16x4::from_array(self.arch(), self.to_array().map(crate::cast_f32_to_f16))
        } else {
            let raw = self.0;
            // `f16x4` carries its lanes as a raw `uint16x4_t` bit pattern, as
            // shown by `from_underlying` below.
            let result: uint16x4_t;
            // SAFETY: The instruction we are running is available with the `neon` platform,
            // just not exposed by Rust's intrinsics.
            unsafe {
                asm!(
                    // FCVTN narrows 4 f32 lanes into 4 f16 lanes placed in the
                    // low half of the destination vector register.
                    "fcvtn {0:v}.4h, {1:v}.4s",
                    out(vreg) result,
                    in(vreg) raw,
                    // pure/nomem/nostack: register-to-register with no side effects.
                    options(pure, nomem, nostack)
                );
            }
            f16x4::from_underlying(self.arch(), result)
        }
    }
}
154
///////////
// Tests //
///////////

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils};

    // Every test is gated on `test_neon()`, which returns `Some(arch)` only
    // when NEON is usable, so the suite is skipped cleanly elsewhere.
    #[test]
    fn miri_test_load() {
        if let Some(arch) = test_neon() {
            test_utils::test_load_simd::<f32, 4, f32x4>(arch);
        }
    }

    #[test]
    fn miri_test_store() {
        if let Some(arch) = test_neon() {
            test_utils::test_store_simd::<f32, 4, f32x4>(arch);
        }
    }

    // constructors
    #[test]
    fn test_constructors() {
        if let Some(arch) = test_neon() {
            test_utils::ops::test_splat::<f32, 4, f32x4>(arch);
        }
    }

    // Ops
    // The 64-bit hex constant in each macro is presumably an RNG seed for the
    // randomized comparison against the reference implementation — confirm
    // against `test_utils::ops`.
    test_utils::ops::test_add!(f32x4, 0xcd7a8fea9a3fb727, test_neon());
    test_utils::ops::test_sub!(f32x4, 0x3f6562c94c923238, test_neon());
    test_utils::ops::test_mul!(f32x4, 0x07e48666c0fc564c, test_neon());
    test_utils::ops::test_fma!(f32x4, 0xcfde9d031302cf2c, test_neon());
    test_utils::ops::test_abs!(f32x4, 0xb8f702ba85375041, test_neon());
    test_utils::ops::test_minmax!(f32x4, 0x6d7fc8ed6d852187, test_neon());
    test_utils::ops::test_splitjoin!(f32x4 => f32x2, 0xa4d00a4d04293967, test_neon());

    test_utils::ops::test_cmp!(f32x4, 0xc4f468b224622326, test_neon());
    test_utils::ops::test_select!(f32x4, 0xef24013b8578637c, test_neon());

    test_utils::ops::test_sumtree!(f32x4, 0x828bd890a470dc4d, test_neon());

    // Conversions
    // f16 -> f32 is exact, hence "lossless_convert"; f32 -> f16 rounds, hence
    // the weaker "cast" check.
    test_utils::ops::test_lossless_convert!(f16x4 => f32x4, 0xecba3008eae54ce7, test_neon());

    test_utils::ops::test_cast!(f32x4 => f16x4, 0xba8fe343fc9dbeff, test_neon());
}