// diskann_wide/arch/x86_64/v4/conversion.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6// x86 intrinsics
7use std::arch::x86_64::*;
8
9use super::{
10    f16x8_::f16x8, f16x16_::f16x16, f32x8_::f32x8, f32x16_::f32x16, i8x16_::i8x16, i8x32_::i8x32,
11    i8x64_::i8x64, i16x8_::i16x8, i16x16_::i16x16, i16x32_::i16x32, i32x8_::i32x8, u8x16_::u8x16,
12    u8x32_::u8x32, u8x64_::u8x64, u32x8_::u32x8, u32x16_::u32x16,
13};
14use crate::{SIMDCast, SIMDReinterpret, SIMDVector, arch::x86_64::v3, helpers};
15
16/////////////////
17// Conversions //
18/////////////////
19
20impl From<f16x8> for f32x8 {
21    #[inline(always)]
22    fn from(value: f16x8) -> f32x8 {
23        f32x8::from_underlying(value.arch(), v3::f32x8::from(value.retarget()).0)
24    }
25}
26
27impl From<f16x16> for f32x16 {
28    #[inline(always)]
29    fn from(value: f16x16) -> f32x16 {
30        // SAFETY: `_mm512_cvtph_ps` requires AVX512F - implied by `V4`.
31        let cvt = unsafe { _mm512_cvtph_ps(value.to_underlying()) };
32        f32x16::from_underlying(value.arch(), cvt)
33    }
34}
35
36impl SIMDCast<f32> for f16x8 {
37    type Cast = f32x8;
38    fn simd_cast(self) -> f32x8 {
39        self.into()
40    }
41}
42
43impl SIMDCast<f32> for f16x16 {
44    type Cast = f32x16;
45    fn simd_cast(self) -> f32x16 {
46        self.into()
47    }
48}
49
50impl SIMDCast<half::f16> for f32x8 {
51    type Cast = f16x8;
52    fn simd_cast(self) -> f16x8 {
53        f16x8::from_underlying(self.arch(), self.retarget().simd_cast().0)
54    }
55}
56
57impl SIMDCast<half::f16> for f32x16 {
58    type Cast = f16x16;
59    fn simd_cast(self) -> f16x16 {
60        // SAFETY: `_mm512_cvtps_ph` requires AVX512F - implied by `V4`.
61        let cvt = unsafe {
62            _mm512_cvtps_ph(
63                self.to_underlying(),
64                _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC,
65            )
66        };
67        f16x16::from_underlying(self.arch(), cvt)
68    }
69}
70
// i8 to i16
//
// Lossless widening: each signed byte lane is sign-extended to a 16-bit lane,
// doubling the register width (256 -> 512 bits in the second case).
helpers::unsafe_map_conversion!(i8x16, i16x16, _mm256_cvtepi8_epi16, "avx2");
helpers::unsafe_map_conversion!(i8x32, i16x32, _mm512_cvtepi8_epi16, "avx512bw");

// u8 to i16
//
// Lossless widening: each unsigned byte lane is zero-extended to a 16-bit
// lane; every `u8` value is representable as an `i16`.
helpers::unsafe_map_conversion!(u8x16, i16x16, _mm256_cvtepu8_epi16, "avx2");
helpers::unsafe_map_conversion!(u8x32, i16x32, _mm512_cvtepu8_epi16, "avx512bw");

// i32 to f32
//
// Numeric cast (not a lossless conversion): large-magnitude integers may
// round when represented as single-precision floats.
helpers::unsafe_map_cast!(i32x8 => (f32, f32x8), _mm256_cvtepi32_ps, "avx");

// i16 to u8/i8 narrowing casts. The `cvtepi16_epi8` intrinsics keep the low
// byte of each 16-bit lane (truncation, not saturation), matching Rust's
// `as` semantics for integer narrowing.
helpers::unsafe_map_cast!(i16x16 => (u8, u8x16), _mm256_cvtepi16_epi8, "avx512bw,avx512vl");
helpers::unsafe_map_cast!(i16x16 => (i8, i8x16), _mm256_cvtepi16_epi8, "avx512bw,avx512vl");
helpers::unsafe_map_cast!(i16x32 => (u8, u8x32), _mm512_cvtepi16_epi8, "avx512bw");
helpers::unsafe_map_cast!(i16x32 => (i8, i8x32), _mm512_cvtepi16_epi8, "avx512bw");
86
87//////////////////
88// Reinterprets //
89//////////////////
90
91impl SIMDReinterpret<i16x16> for u32x8 {
92    fn reinterpret_simd(self) -> i16x16 {
93        i16x16(self.0)
94    }
95}
96
97impl SIMDReinterpret<i16x32> for u32x16 {
98    fn reinterpret_simd(self) -> i16x32 {
99        i16x32(self.0)
100    }
101}
102
103impl SIMDReinterpret<u8x64> for u32x16 {
104    fn reinterpret_simd(self) -> u8x64 {
105        u8x64(self.0)
106    }
107}
108
109impl SIMDReinterpret<i8x64> for u32x16 {
110    fn reinterpret_simd(self) -> i8x64 {
111        i8x64(self.0)
112    }
113}
114
115impl SIMDReinterpret<u32x16> for u8x64 {
116    fn reinterpret_simd(self) -> u32x16 {
117        u32x16(self.0)
118    }
119}
120
121impl SIMDReinterpret<u32x16> for i8x64 {
122    fn reinterpret_simd(self) -> u32x16 {
123        u32x16(self.0)
124    }
125}
126
127impl SIMDReinterpret<u8x16> for i16x8 {
128    fn reinterpret_simd(self) -> u8x16 {
129        u8x16(self.0)
130    }
131}
132
133impl SIMDReinterpret<i16x8> for u8x16 {
134    fn reinterpret_simd(self) -> i16x8 {
135        i16x8(self.0)
136    }
137}
138
139///////////
140// Tests //
141///////////
142
#[cfg(test)]
mod test_x86_conversions {
    use super::*;
    use crate::{arch::x86_64::V4, test_utils};

    // Lossless Conversions
    //
    // Each invocation expands to a randomized round-trip test for the given
    // conversion; the hex literal is presumably a PRNG seed for reproducible
    // inputs - TODO confirm against `test_utils::ops`.
    //
    // NOTE(review): the f16 cases are disabled under Miri, presumably because
    // the vendor intrinsics they exercise are not interpretable there.
    #[cfg(not(miri))]
    test_utils::ops::test_lossless_convert!(
        f16x8 => f32x8, 0xa998182f02ff4d0d, V4::new_checked_uncached()
    );

    #[cfg(not(miri))]
    test_utils::ops::test_lossless_convert!(
        f16x16 => f32x16, 0xe6ab583cbb1b06e0, V4::new_checked_uncached()
    );

    test_utils::ops::test_lossless_convert!(
        i8x16 => i16x16, 0x84602159fb122584, V4::new_checked_uncached()
    );
    test_utils::ops::test_lossless_convert!(
        i8x32 => i16x32, 0xa9e19910dabee638, V4::new_checked_uncached()
    );
    test_utils::ops::test_lossless_convert!(
        u8x16 => i16x16, 0x5ba4b69df84ca568, V4::new_checked_uncached()
    );
    test_utils::ops::test_lossless_convert!(
        u8x32 => i16x32, 0xb42af810c6768193, V4::new_checked_uncached()
    );

    // Numeric Casts
    //
    // Same scheme as above, but for potentially lossy `SIMDCast`
    // implementations (float narrowing, integer truncation).
    test_utils::ops::test_cast!(f16x8 => f32x8, 0x37314659b022466a, V4::new_checked_uncached());
    test_utils::ops::test_cast!(f16x16 => f32x16, 0x1aa5762d788d7749, V4::new_checked_uncached());

    test_utils::ops::test_cast!(f32x8 => f16x8, 0x8386cb0a7091cc3b, V4::new_checked_uncached());
    test_utils::ops::test_cast!(f32x16 => f16x16, 0xb3cbae34def475df, V4::new_checked_uncached());

    test_utils::ops::test_cast!(i32x8 => f32x8, 0xde4fbf25c554b29e, V4::new_checked_uncached());

    test_utils::ops::test_cast!(i16x16 => u8x16, 0x0f81df9e640b0269, V4::new_checked_uncached());
    test_utils::ops::test_cast!(i16x16 => i8x16, 0x4ab1546b9d0e4046, V4::new_checked_uncached());

    test_utils::ops::test_cast!(i16x32 => u8x32, 0xf2c00ea1a1b5c380, V4::new_checked_uncached());
    test_utils::ops::test_cast!(i16x32 => i8x32, 0x6090af7cb2847dd5, V4::new_checked_uncached());
}