// diskann_wide/arch/aarch64/f16x8_.rs
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */

6use crate::{
7    Emulated,
8    arch::Scalar,
9    constant::Const,
10    traits::{SIMDMask, SIMDVector},
11};
12
13use half::f16;
14
// AArch64 masks
16use super::{
17    Neon, f16x4, f32x8,
18    macros::{self, AArchLoadStore, AArchSplat},
19    masks::mask16x8,
20    u16x8,
21};
22
// AArch64 intrinsics
24use std::arch::aarch64::*;
25
/////////////////////
// 16-bit floating //
/////////////////////

// Define the 8-lane `f16` register type: backed by a `uint16x8_t` (each `f16` lane is
// held as its raw `u16` bit pattern), with `mask16x8` as the mask type, 8 lanes, gated
// on the `Neon` architecture token.
macros::aarch64_define_register!(f16x8, uint16x8_t, mask16x8, f16, 8, Neon);
// Generate split/join between one `f16x8` and two `f16x4` halves using the NEON
// low-half / high-half / combine intrinsics for 16-bit lanes.
macros::aarch64_splitjoin!(f16x8, f16x4, vget_low_u16, vget_high_u16, vcombine_u16);

33impl AArchSplat for f16x8 {
34    #[inline(always)]
35    fn aarch_splat(_: Neon, value: f16) -> Self {
36        // SAFETY: Allowed by the `Neon` architecture.
37        Self(unsafe { vmovq_n_u16(value.to_bits()) })
38    }
39
40    #[inline(always)]
41    fn aarch_default(arch: Neon) -> Self {
42        Self::aarch_splat(arch, f16::default())
43    }
44}
45
impl AArchLoadStore for f16x8 {
    // Unmasked full-width load of 8 `f16` values (loaded as raw `u16` bits).
    #[inline(always)]
    unsafe fn load_simd(_: Neon, ptr: *const f16) -> Self {
        // SAFETY: Pointer access safety inherited from the caller. Allowed by the `Neon`
        // architecture.
        Self(unsafe { vld1q_u16(ptr.cast::<u16>()) })
    }

    // Masked load, delegated to the scalar `Emulated` implementation and then repacked
    // into a NEON register via `from_array`.
    #[inline(always)]
    unsafe fn load_simd_masked_logical(arch: Neon, ptr: *const f16, mask: Self::Mask) -> Self {
        // SAFETY: Pointer access safety inherited from the caller.
        let e = unsafe {
            Emulated::<f16, 8>::load_simd_masked_logical(Scalar, ptr, mask.bitmask().as_scalar())
        };
        Self::from_array(arch, e.to_array())
    }

    // Load only the first `first` lanes; reuses the `u16x8` implementation since `f16`
    // and `u16` lanes share the same 2-byte representation.
    #[inline(always)]
    unsafe fn load_simd_first(arch: Neon, ptr: *const f16, first: usize) -> Self {
        // SAFETY: f16 and u16 share the same 2-byte representation. Pointer access
        // inherited from caller.
        Self(unsafe {
            <u16x8 as AArchLoadStore>::load_simd_first(arch, ptr.cast::<u16>(), first).0
        })
    }

    // Unmasked full-width store of 8 `f16` values (stored as raw `u16` bits).
    #[inline(always)]
    unsafe fn store_simd(self, ptr: *mut <Self as SIMDVector>::Scalar) {
        // SAFETY: Pointer access safety inherited from the caller. Use of the instruction
        // is allowed by the `Neon` architecture.
        unsafe { vst1q_u16(ptr.cast::<u16>(), self.0) }
    }

    // Masked store: round-trip through the scalar `Emulated` vector, which performs the
    // per-lane masked write.
    #[inline(always)]
    unsafe fn store_simd_masked_logical(self, ptr: *mut f16, mask: Self::Mask) {
        let e = Emulated::<f16, 8>::from_array(Scalar, self.to_array());
        // SAFETY: Pointer access safety inherited from the caller.
        unsafe { e.store_simd_masked_logical(ptr, mask.bitmask().as_scalar()) }
    }

    // Store only the first `first` lanes, via the scalar `Emulated` fallback.
    #[inline(always)]
    unsafe fn store_simd_first(self, ptr: *mut f16, first: usize) {
        let e = Emulated::<f16, 8>::from_array(Scalar, self.to_array());
        // SAFETY: Pointer access safety inherited from the caller.
        unsafe { e.store_simd_first(ptr, first) }
    }
}

//------------//
// Conversion //
//------------//

98impl crate::SIMDCast<f32> for f16x8 {
99    type Cast = f32x8;
100
101    #[inline(always)]
102    fn simd_cast(self) -> f32x8 {
103        self.into()
104    }
105}
106
///////////
// Tests //
///////////

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{arch::aarch64::test_neon, test_utils};

    // Shared load harness; runs only when NEON is detected.
    #[test]
    fn miri_test_load() {
        let Some(arch) = test_neon() else { return };
        test_utils::test_load_simd::<f16, 8, f16x8>(arch);
    }

    // Shared store harness; runs only when NEON is detected.
    #[test]
    fn miri_test_store() {
        let Some(arch) = test_neon() else { return };
        test_utils::test_store_simd::<f16, 8, f16x8>(arch);
    }

    // constructors
    #[test]
    fn test_constructors() {
        let Some(arch) = test_neon() else { return };
        test_utils::ops::test_splat::<f16, 8, f16x8>(arch);

        // `default` must agree lane-for-lane with splatting the default scalar.
        let defaulted = f16x8::default(arch).to_array();
        let splatted = f16x8::splat(arch, f16::default()).to_array();
        assert_eq!(defaulted, splatted);
    }

    test_utils::ops::test_splitjoin!(f16x8 => f16x4, 0xa4d00a4d04293967, test_neon());

    // Conversions
    test_utils::ops::test_cast!(f16x8 => f32x8, 0x37314659b022466a, test_neon());
}