1#![allow(non_camel_case_types)]
2
3use bytemuck::{Pod, Zeroable};
4#[cfg(target_arch = "x86")]
5use std::arch::x86::*;
6#[cfg(target_arch = "x86_64")]
7use std::arch::x86_64::*;
8
9use std::{
10 fmt,
11 mem::transmute,
12 ops::{Add, Div, Mul, Sub},
13};
14
/// Four `f32` lanes packed in a single SSE `__m128` register.
///
/// `repr(transparent)` guarantees the same 16-byte layout as `__m128`,
/// which the raw conversions to/from `[f32; 4]` in this module rely on.
#[repr(transparent)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct f32x4(__m128);
18
impl f32x4 {
    /// Builds a vector with `x0` in lane 0 (lowest) through `x3` in lane 3.
    ///
    /// `_mm_set_ps` takes its arguments in high-to-low order, hence the
    /// reversed argument list.
    #[inline(always)]
    pub fn new(x0: f32, x1: f32, x2: f32, x3: f32) -> Self {
        Self(unsafe { _mm_set_ps(x3, x2, x1, x0) })
    }

    /// Converts to the portable (non-SIMD) implementation, preserving lanes.
    pub fn fallback(self) -> super::fallback::f32x4 {
        let this: [f32; 4] = self.into();
        this.into()
    }

    /// Extracts lane 0.
    ///
    /// NOTE(review): `_mm_extract_ps` needs SSE4.1; calling it without that
    /// target feature enabled is UB — confirm the build enables it.
    #[inline(always)]
    pub fn x0(self) -> f32 {
        f32::from_bits(unsafe { _mm_extract_ps::<0>(self.0) } as u32)
    }

    /// Extracts lane 1 (SSE4.1, see `x0`).
    #[inline(always)]
    pub fn x1(self) -> f32 {
        f32::from_bits(unsafe { _mm_extract_ps::<1>(self.0) } as u32)
    }

    /// Extracts lane 2 (SSE4.1, see `x0`).
    #[inline(always)]
    pub fn x2(self) -> f32 {
        f32::from_bits(unsafe { _mm_extract_ps::<2>(self.0) } as u32)
    }

    /// Extracts lane 3 (SSE4.1, see `x0`).
    #[inline(always)]
    pub fn x3(self) -> f32 {
        f32::from_bits(unsafe { _mm_extract_ps::<3>(self.0) } as u32)
    }

    /// Broadcasts `val` into all four lanes.
    #[inline(always)]
    pub fn splat(val: f32) -> Self {
        Self(unsafe { _mm_set1_ps(val) })
    }

    /// The all-zero vector.
    #[inline(always)]
    pub fn zero() -> Self {
        Self(unsafe { _mm_setzero_ps() })
    }

    /// Copies the lanes out as `[x0, x1, x2, x3]`.
    #[inline(always)]
    pub fn to_array(self) -> [f32; 4] {
        self.into()
    }

    /// Lane-wise square root.
    #[inline(always)]
    pub fn sqrt(self) -> Self {
        Self(unsafe { _mm_sqrt_ps(self.0) })
    }

    /// Fused multiply-add: `self * mul + add` with a single rounding per lane.
    ///
    /// NOTE(review): `_mm_fmadd_ps` needs the FMA target feature; without it
    /// enabled this call is UB — confirm the build flags.
    #[inline(always)]
    pub fn mul_add(self, mul: f32x4, add: f32x4) -> Self {
        Self(unsafe { _mm_fmadd_ps(self.0, mul.0, add.0) })
    }

    /// `self + a * b`, fused — `mul_add` with the accumulator as `self`.
    #[inline(always)]
    pub fn add_mul(self, a: f32x4, b: f32x4) -> Self {
        a.mul_add(b, self)
    }

    /// Four-lane dot product.
    ///
    /// The `0b1111_1111` immediate tells `_mm_dp_ps` to multiply all four
    /// lane pairs and broadcast the sum to every result lane; lane 0 is then
    /// extracted. Both intrinsics need SSE4.1 (see `x0`).
    #[inline(always)]
    pub fn dot(self, other: Self) -> f32 {
        let result = unsafe { _mm_extract_ps::<0>(_mm_dp_ps::<0b1111_1111>(self.0, other.0)) };
        f32::from_bits(result as u32)
    }
}
86
87impl fmt::Debug for f32x4 {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 let [x0, x1, x2, x3] = self.to_array();
90 f.debug_tuple("f32x4")
91 .field(&x0)
92 .field(&x1)
93 .field(&x2)
94 .field(&x3)
95 .finish()
96 }
97}
98
99impl Default for f32x4 {
100 #[inline(always)]
101 fn default() -> Self {
102 Self::zero()
103 }
104}
105
106impl PartialEq for f32x4 {
107 #[inline(always)]
108 fn eq(&self, other: &Self) -> bool {
109 let mask = unsafe { _mm_movemask_ps(_mm_cmpeq_ps(self.0, other.0)) };
110 mask == 0b1111
111 }
112}
113
114impl Add<Self> for f32x4 {
115 type Output = Self;
116
117 #[inline(always)]
118 fn add(self, other: Self) -> Self::Output {
119 Self(unsafe { _mm_add_ps(self.0, other.0) })
120 }
121}
122
123impl Sub<Self> for f32x4 {
124 type Output = Self;
125
126 #[inline(always)]
127 fn sub(self, rhs: Self) -> Self::Output {
128 Self(unsafe { _mm_sub_ps(self.0, rhs.0) })
129 }
130}
131
132impl Mul for f32x4 {
133 type Output = Self;
134
135 #[inline(always)]
136 fn mul(self, other: Self) -> Self::Output {
137 Self(unsafe { _mm_mul_ps(self.0, other.0) })
138 }
139}
140
141impl Mul<f32> for f32x4 {
142 type Output = Self;
143
144 #[inline(always)]
145 fn mul(self, rhs: f32) -> Self::Output {
146 self * Self::splat(rhs)
147 }
148}
149
150impl Mul<f32x4> for f32 {
151 type Output = f32x4;
152
153 #[inline(always)]
154 fn mul(self, rhs: f32x4) -> Self::Output {
155 rhs * f32x4::splat(self)
156 }
157}
158
159impl Div<f32x4> for f32x4 {
160 type Output = Self;
161
162 #[inline(always)]
163 fn div(self, rhs: f32x4) -> Self::Output {
164 Self(unsafe { _mm_div_ps(self.0, rhs.0) })
165 }
166}
167
168impl From<[f32; 4]> for f32x4 {
169 #[inline(always)]
170 fn from(arr: [f32; 4]) -> Self {
171 unsafe { transmute(arr) }
175 }
176}
177
178impl From<f32x4> for [f32; 4] {
179 #[inline(always)]
180 fn from(m: f32x4) -> Self {
181 unsafe { transmute(m) }
183 }
184}
185
186impl IntoIterator for f32x4 {
187 type Item = f32;
188 type IntoIter = <[f32; 4] as IntoIterator>::IntoIter;
189
190 fn into_iter(self) -> Self::IntoIter {
191 let vals: [f32; 4] = self.into();
192 vals.into_iter()
193 }
194}
195
/// Approximate linear-to-sRGB transfer function, applied lane-wise.
///
/// The gamma segment is approximated by a linear combination of x^(1/2),
/// x^(1/4) and x^(1/8) (three chained square roots) instead of a true
/// x^(1/2.4) power; inputs <= 0.0031308 use the exact sRGB linear segment
/// `12.92 * x`. Inverse of [`s2l`]. Presumably accurate only for inputs in
/// [0, 1] — TODO confirm intended domain.
#[inline(always)]
pub fn l2s(x0: f32x4) -> f32x4 {
    let x1 = x0.sqrt();
    let x2 = x1.sqrt();
    let x3 = x2.sqrt();
    let high = -0.01848558 * x0 + 0.6445592 * x1 + 0.70994765 * x2 - 0.33605254 * x3;
    // SAFETY/NOTE(review): `_mm_blendv_ps` needs SSE4.1 — confirm the build
    // enables it. blendv selects the second operand (the linear branch)
    // wherever the mask's sign bit is set, i.e. where x0 <= 0.0031308.
    unsafe {
        f32x4(_mm_blendv_ps(
            high.0,
            (x0 * 12.92).0,
            _mm_cmple_ps(x0.0, _mm_set1_ps(0.0031308)),
        ))
    }
}
215
/// Approximate sRGB-to-linear transfer function, applied lane-wise.
///
/// Inverse of [`l2s`]. Inputs <= 0.04045 use the exact sRGB linear segment
/// (the constant 0.07739938080495357 is 1/12.92); larger inputs use a cubic
/// polynomial in the affinely remapped variable `x1`.
#[inline(always)]
pub fn s2l(vs: f32x4) -> f32x4 {
    let x1 = 2.0843103538116825 * vs - f32x4::splat(1.0843103538116827);
    let x2 = x1 * x1;
    let x3 = x2 * x1;
    let vs_high = f32x4::splat(0.23361048543711943)
        + 0.4665843122387033 * x1
        + 0.26901741378006355 * x2
        + 0.031661580753065945 * x3;
    // SAFETY/NOTE(review): `_mm_blendv_ps` needs SSE4.1 — confirm the build
    // enables it. blendv selects the second operand (the linear branch)
    // wherever the mask's sign bit is set, i.e. where vs <= 0.04045.
    unsafe {
        f32x4(_mm_blendv_ps(
            vs_high.0,
            (vs * 0.07739938080495357).0,
            _mm_cmple_ps(vs.0, _mm_set1_ps(0.04045)),
        ))
    }
}
245
/// Builds an immediate for `_mm_shuffle_ps`/`_mm_permute_ps`: output lane `i`
/// takes source lane `xi`. Equivalent to `_MM_SHUFFLE(x3, x2, x1, x0)` with
/// the arguments listed low-lane-first.
///
/// Each index is masked to its low two bits so an out-of-range value cannot
/// bleed into a neighbouring 2-bit lane field of the immediate (previously
/// e.g. `x0 = 4` silently corrupted the `x1` field). Valid indices (0..=3)
/// produce exactly the same mask as before.
pub const fn shuffle_mask(x0: u32, x1: u32, x2: u32, x3: u32) -> i32 {
    (((x3 & 0b11) << 6) | ((x2 & 0b11) << 4) | ((x1 & 0b11) << 2) | (x0 & 0b11)) as i32
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Acceptable per-lane error for the approximate sRGB round trip.
    const EPS: f32 = 0.01;

    /// Asserts every lane of `a` is within `EPS` of the matching lane of `b`.
    fn assert_close(a: f32x4, b: f32x4) {
        for (x, y) in a.into_iter().zip(b) {
            assert!((x - y).abs() < EPS, "lane mismatch: {x} vs {y}");
        }
    }

    #[test]
    fn test_simd() {
        let a = f32x4::new(1.0, 2.0, 3.0, 4.0);
        assert_eq!(a.x0(), 1.0);
        assert_eq!(a.x1(), 2.0);
        assert_eq!(a.x2(), 3.0);
        assert_eq!(a.x3(), 4.0);

        let b = f32x4::new(5.0, 6.0, 7.0, 8.0);
        unsafe {
            assert_eq!(
                f32x4(_mm_shuffle_ps::<{ shuffle_mask(0, 3, 0, 3) }>(a.0, b.0)),
                f32x4::new(1.0, 4.0, 5.0, 8.0),
            );
            assert_eq!(
                f32x4(_mm_permute_ps::<{ shuffle_mask(0, 3, 3, 2) }>(b.0)),
                f32x4::new(5.0, 8.0, 8.0, 7.0)
            );
        }

        // Previously these results were only printed for manual inspection;
        // assert the round trip instead so the test can actually fail.
        let c = f32x4::new(0.001, 0.1, 0.2, 0.7);
        assert_close(s2l(l2s(c)), c);
        // White should map back to (approximately) white.
        assert_close(s2l(f32x4::splat(1.0)), f32x4::splat(1.0));
    }

    #[test]
    fn test_dot() {
        let a = f32x4::new(1.0, 2.0, 3.0, 4.0);
        let b = f32x4::new(5.0, 6.0, 7.0, 8.0);
        assert_eq!(70.0, a.dot(b));
        assert_eq!(70.0, a.fallback().dot(b.fallback()));
    }
}