// diskann_wide/arch/aarch64/f16x4_.rs

/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */

use crate::{
    Emulated,
    arch::Scalar,
    constant::Const,
    traits::{SIMDMask, SIMDVector},
};

use half::f16;

// AArch64 masks
use super::{
    Neon, internal,
    macros::{self, AArchLoadStore, AArchSplat},
    masks::mask16x4,
};

// AArch64 intrinsics
use std::arch::aarch64::*;

/////////////////////
// 16-bit floating //
/////////////////////

// Define the `f16x4` register type: 4 lanes of `f16` carried as their raw bit
// patterns in a NEON `uint16x4_t`, with `mask16x4` as the lane mask and the
// `Neon` token gating construction. NOTE(review): the exact generated API
// (`from_array`, `to_array`, `splat`, `default`, ...) lives in the macro —
// confirm against `macros::aarch64_define_register!`.
macros::aarch64_define_register!(f16x4, uint16x4_t, mask16x4, f16, 4, Neon);

31impl AArchSplat for f16x4 {
32    #[inline(always)]
33    fn aarch_splat(_: Neon, value: f16) -> Self {
34        // SAFETY: Allowed by the `Neon` architecture.
35        Self(unsafe { vmov_n_u16(value.to_bits()) })
36    }
37
38    #[inline(always)]
39    fn aarch_default(arch: Neon) -> Self {
40        Self::aarch_splat(arch, f16::default())
41    }
42}
43
44impl AArchLoadStore for f16x4 {
45    #[inline(always)]
46    unsafe fn load_simd(_: Neon, ptr: *const f16) -> Self {
47        // SAFETY: Allowed by the `Neon` architecture.
48        Self(unsafe { vld1_u16(ptr.cast::<u16>()) })
49    }
50
51    #[inline(always)]
52    unsafe fn load_simd_masked_logical(arch: Neon, ptr: *const f16, mask: Self::Mask) -> Self {
53        // SAFETY: Pointer access safety inherited from the caller.
54        let e = unsafe {
55            Emulated::<f16, 4>::load_simd_masked_logical(Scalar, ptr, mask.bitmask().as_scalar())
56        };
57        Self::from_array(arch, e.to_array())
58    }
59
60    #[inline(always)]
61    unsafe fn load_simd_first(arch: Neon, ptr: *const f16, first: usize) -> Self {
62        // SAFETY: f16 and u16 share the same 2-byte representation. Pointer access
63        // inherited from caller.
64        Self(unsafe { internal::load_first::u16x4(arch, ptr.cast::<u16>(), first) })
65    }
66
67    #[inline(always)]
68    unsafe fn store_simd(self, ptr: *mut <Self as SIMDVector>::Scalar) {
69        // SAFETY: Pointer access safety inherited from the caller. Use of the instruction
70        // is allowed by the `Neon` architecture.
71        unsafe { vst1_u16(ptr.cast::<u16>(), self.0) }
72    }
73
74    #[inline(always)]
75    unsafe fn store_simd_masked_logical(self, ptr: *mut f16, mask: Self::Mask) {
76        let e = Emulated::<f16, 4>::from_array(Scalar, self.to_array());
77        // SAFETY: Pointer access safety inherited from the caller.
78        unsafe { e.store_simd_masked_logical(ptr, mask.bitmask().as_scalar()) }
79    }
80
81    #[inline(always)]
82    unsafe fn store_simd_first(self, ptr: *mut f16, first: usize) {
83        let e = Emulated::<f16, 4>::from_array(Scalar, self.to_array());
84        // SAFETY: Pointer access safety inherited from the caller.
85        unsafe { e.store_simd_first(ptr, first) }
86    }
87}
88
///////////
// Tests //
///////////

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{arch::aarch64::test_neon, test_utils};

    #[test]
    fn miri_test_load() {
        // Runs only when the NEON architecture token is available.
        if let Some(neon) = test_neon() {
            test_utils::test_load_simd::<f16, 4, f16x4>(neon);
        }
    }

    #[test]
    fn miri_test_store() {
        if let Some(neon) = test_neon() {
            test_utils::test_store_simd::<f16, 4, f16x4>(neon);
        }
    }

    // constructors
    #[test]
    fn test_constructors() {
        if let Some(neon) = test_neon() {
            test_utils::ops::test_splat::<f16, 4, f16x4>(neon);

            // `default` must agree with splatting the default scalar.
            let from_default = f16x4::default(neon).to_array();
            let from_splat = f16x4::splat(neon, f16::default()).to_array();
            assert_eq!(from_default, from_splat);
        }
    }
}