rasterize/simd/x86.rs

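//! x86/x86_64 SIMD implementation of the 4-lane `f32` vector. The
//! intrinsics used here require SSE4.1, FMA, and AVX (e.g. via
//! `-C target-cpu=native`); `super::fallback` holds the portable version.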
#![allow(non_camel_case_types)]

use bytemuck::{Pod, Zeroable};
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

use std::{
    fmt,
    mem::transmute,
    ops::{Add, Div, Mul, Sub},
};

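/// Four packed `f32` lanes stored in a single SSE `__m128` register.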
#[repr(transparent)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct f32x4(__m128);

impl f32x4 {
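    /// Build a vector with `x0` in the lowest lane. Note that
    /// `_mm_set_ps` takes its arguments highest-lane first, hence the
    /// reversed order below.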
    #[inline(always)]
    pub fn new(x0: f32, x1: f32, x2: f32, x3: f32) -> Self {
        Self(unsafe { _mm_set_ps(x3, x2, x1, x0) })
    }

    pub fn fallback(self) -> super::fallback::f32x4 {
        let this: [f32; 4] = self.into();
        this.into()
    }

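    // Lane accessors: `_mm_extract_ps` (SSE4.1) returns the raw lane bits
    // as an `i32`, so we reinterpret them with `f32::from_bits`.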
    #[inline(always)]
    pub fn x0(self) -> f32 {
        f32::from_bits(unsafe { _mm_extract_ps::<0>(self.0) } as u32)
    }

    #[inline(always)]
    pub fn x1(self) -> f32 {
        f32::from_bits(unsafe { _mm_extract_ps::<1>(self.0) } as u32)
    }

    #[inline(always)]
    pub fn x2(self) -> f32 {
        f32::from_bits(unsafe { _mm_extract_ps::<2>(self.0) } as u32)
    }

    #[inline(always)]
    pub fn x3(self) -> f32 {
        f32::from_bits(unsafe { _mm_extract_ps::<3>(self.0) } as u32)
    }

    #[inline(always)]
    pub fn splat(val: f32) -> Self {
        Self(unsafe { _mm_set1_ps(val) })
    }

    #[inline(always)]
    pub fn zero() -> Self {
        Self(unsafe { _mm_setzero_ps() })
    }

    #[inline(always)]
    pub fn to_array(self) -> [f32; 4] {
        self.into()
    }

    #[inline(always)]
    pub fn sqrt(self) -> Self {
        Self(unsafe { _mm_sqrt_ps(self.0) })
    }

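    /// Lane-wise fused multiply-add: `self * mul + add` with a single
    /// rounding step, via `_mm_fmadd_ps` (requires the `fma` target
    /// feature). For example, `[1, 2, 3, 4] * 2 + 1` yields `[3, 5, 7, 9]`.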
    #[inline(always)]
    pub fn mul_add(self, mul: f32x4, add: f32x4) -> Self {
        Self(unsafe { _mm_fmadd_ps(self.0, mul.0, add.0) })
    }

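    /// Fused multiply-add with `self` as the accumulator: `self + a * b`.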
    #[inline(always)]
    pub fn add_mul(self, a: f32x4, b: f32x4) -> Self {
        a.mul_add(b, self)
    }

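    /// Horizontal dot product via `_mm_dp_ps` (SSE4.1): the `0b1111_1111`
    /// immediate multiplies and sums all four lane pairs and broadcasts
    /// the result, which we then read back out of lane 0.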
    #[inline(always)]
    pub fn dot(self, other: Self) -> f32 {
        let result = unsafe { _mm_extract_ps::<0>(_mm_dp_ps::<0b1111_1111>(self.0, other.0)) };
        f32::from_bits(result as u32)
    }
}

impl fmt::Debug for f32x4 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let [x0, x1, x2, x3] = self.to_array();
        f.debug_tuple("f32x4")
            .field(&x0)
            .field(&x1)
            .field(&x2)
            .field(&x3)
            .finish()
    }
}

impl Default for f32x4 {
    #[inline(always)]
    fn default() -> Self {
        Self::zero()
    }
}

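// Lane-wise compare plus movemask: all four sign bits must be set. NaN
// lanes compare unequal, matching scalar `f32` semantics.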
impl PartialEq for f32x4 {
    #[inline(always)]
    fn eq(&self, other: &Self) -> bool {
        let mask = unsafe { _mm_movemask_ps(_mm_cmpeq_ps(self.0, other.0)) };
        mask == 0b1111
    }
}

impl Add<Self> for f32x4 {
    type Output = Self;

    #[inline(always)]
    fn add(self, other: Self) -> Self::Output {
        Self(unsafe { _mm_add_ps(self.0, other.0) })
    }
}

impl Sub<Self> for f32x4 {
    type Output = Self;

    #[inline(always)]
    fn sub(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_sub_ps(self.0, rhs.0) })
    }
}

impl Mul for f32x4 {
    type Output = Self;

    #[inline(always)]
    fn mul(self, other: Self) -> Self::Output {
        Self(unsafe { _mm_mul_ps(self.0, other.0) })
    }
}

impl Mul<f32> for f32x4 {
    type Output = Self;

    #[inline(always)]
    fn mul(self, rhs: f32) -> Self::Output {
        self * Self::splat(rhs)
    }
}

impl Mul<f32x4> for f32 {
    type Output = f32x4;

    #[inline(always)]
    fn mul(self, rhs: f32x4) -> Self::Output {
        rhs * f32x4::splat(self)
    }
}

impl Div<f32x4> for f32x4 {
    type Output = Self;

    #[inline(always)]
    fn div(self, rhs: f32x4) -> Self::Output {
        Self(unsafe { _mm_div_ps(self.0, rhs.0) })
    }
}

impl From<[f32; 4]> for f32x4 {
    #[inline(always)]
    fn from(arr: [f32; 4]) -> Self {
        // Safety: transmuting by value moves the data into a fresh location
        // with the target type's alignment, so widening the required
        // alignment from 4 to 16 bytes is sound here.
        unsafe { transmute(arr) }
    }
}

impl From<f32x4> for [f32; 4] {
    #[inline(always)]
    fn from(m: f32x4) -> Self {
        // We can of course transmute to a lower alignment
        unsafe { transmute(m) }
    }
}

impl IntoIterator for f32x4 {
    type Item = f32;
    type IntoIter = <[f32; 4] as IntoIterator>::IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        let vals: [f32; 4] = self.into();
        vals.into_iter()
    }
}

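/// Approximate linear → sRGB transfer function.
///
/// The non-linear branch of the sRGB encode (`1.055 * x^(1/2.4) - 0.055`)
/// is approximated by a fitted linear combination of x^(1/2), x^(1/4), and
/// x^(1/8), obtained from three chained square roots; inputs at or below
/// the 0.0031308 knee use the exact linear segment `12.92 * x`.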
#[inline(always)]
pub fn l2s(x0: f32x4) -> f32x4 {
    let x1 = x0.sqrt();
    let x2 = x1.sqrt();
    let x3 = x2.sqrt();
    let high = -0.01848558 * x0 + 0.6445592 * x1 + 0.70994765 * x2 - 0.33605254 * x3;
    // The mul_add formulation below is much slower without `-C target-cpu=native`:
    // let high = (-0.01848558 * x0)
    //     .add_mul(f32x4::splat(0.6445592), x1)
    //     .add_mul(f32x4::splat(0.70994765), x2)
    //     .add_mul(f32x4::splat(-0.33605254), x3);
    unsafe {
        // Select the linear segment where x0 <= 0.0031308, else the fit.
        f32x4(_mm_blendv_ps(
            high.0,
            (x0 * 12.92).0,
            _mm_cmple_ps(x0.0, _mm_set1_ps(0.0031308)),
        ))
    }
}

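/// Approximate sRGB → linear transfer function (inverse of [`l2s`]).
///
/// The non-linear branch is a degree-3 polynomial fit (see the numpy
/// derivation below); inputs at or below 0.04045 use the exact linear
/// segment `x / 12.92` (0.07739938… = 1 / 12.92).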
#[inline(always)]
pub fn s2l(vs: f32x4) -> f32x4 {
    // def s2l(value):
    //   if value <= 0.04045:
    //     return value / 12.92
    //   else:
    //     return ((value + 0.055) / 1.055) ** 2.4
    // x = np.linspace(0.04045, 1, 16000)
    // y = np.array([s2l(v) for v in x])
    // np.polynomial.Polynomial.fit(x, y, 3)
    // 𝑥 ↦ 0.23361048543711943 +
    //      0.4665843122387033 * (-1.0843103538116827 + 2.0843103538116825 * 𝑥) +
    //      0.26901741378006355 * (-1.0843103538116827 + 2.0843103538116825 * 𝑥) ^ 2 +
    //      0.031661580753065945 * (-1.0843103538116827 + 2.0843103538116825 * 𝑥) ^ 3
    let x1 = 2.0843103538116825 * vs - f32x4::splat(1.0843103538116827);
    let x2 = x1 * x1;
    let x3 = x2 * x1;
    let vs_high = f32x4::splat(0.23361048543711943)
        + 0.4665843122387033 * x1
        + 0.26901741378006355 * x2
        + 0.031661580753065945 * x3;
    unsafe {
        // Select the linear segment where vs <= 0.04045, else the fit.
        f32x4(_mm_blendv_ps(
            vs_high.0,
            (vs * 0.07739938080495357).0,
            _mm_cmple_ps(vs.0, _mm_set1_ps(0.04045)),
        ))
    }
}

/// Build the 8-bit shuffle/permute immediate used by intrinsics such as
/// `_mm_shuffle_ps` and `_mm_permute_ps`: each argument selects a source
/// lane (0..=3) for the corresponding output lane, packed two bits per
/// lane. `shuffle_mask(0, 1, 2, 3)` is the identity permutation.
pub const fn shuffle_mask(x0: u32, x1: u32, x2: u32, x3: u32) -> i32 {
    ((x3 << 6) | (x2 << 4) | (x1 << 2) | x0) as i32
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd() {
        let a = f32x4::new(1.0, 2.0, 3.0, 4.0);
        assert_eq!(a.x0(), 1.0);
        assert_eq!(a.x1(), 2.0);
        assert_eq!(a.x2(), 3.0);
        assert_eq!(a.x3(), 4.0);

        let b = f32x4::new(5.0, 6.0, 7.0, 8.0);
        unsafe {
            assert_eq!(
                f32x4(_mm_shuffle_ps::<{ shuffle_mask(0, 3, 0, 3) }>(a.0, b.0)),
                f32x4::new(1.0, 4.0, 5.0, 8.0),
            );
            assert_eq!(
                f32x4(_mm_permute_ps::<{ shuffle_mask(0, 3, 3, 2) }>(b.0)),
                f32x4::new(5.0, 8.0, 8.0, 7.0)
            );
        }

        let c = f32x4::new(0.001, 0.1, 0.2, 0.7);
        println!("{c:?}");
        println!("{:?}", l2s(c));
        println!("{:?}", s2l(l2s(c)));
        dbg!(s2l(f32x4::splat(1.0)));
    }

    #[test]
    fn test_dot() {
        let a = f32x4::new(1.0, 2.0, 3.0, 4.0);
        let b = f32x4::new(5.0, 6.0, 7.0, 8.0);
        assert_eq!(70.0, a.dot(b));
        assert_eq!(70.0, a.fallback().dot(b.fallback()));
    }
}