Skip to main content

locus_core/simd/
math.rs

1//! SIMD optimized mathematical kernels for Fast-Math.
2
3use multiversion::multiversion;
4
5/// Compute 1.0 / w using SIMD reciprocal estimation with Newton-Raphson refinement.
6///
7/// This is significantly faster than standard floating-point division.
8/// $w_{inv} = w_{inv} \cdot (2.0 - w \cdot w_{inv})$
9#[multiversion(targets(
10    "x86_64+avx2+bmi1+bmi2+popcnt+lzcnt",
11    "x86_64+avx512f+avx512bw+avx512dq+avx512vl",
12    "aarch64+neon"
13))]
14#[must_use]
15pub(crate) fn rcp_nr(w: f32) -> f32 {
16    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
17    // SAFETY: SSE/AVX intrinsics are safe on x86_64 with avx2.
18    unsafe {
19        use std::arch::x86_64::*;
20        let w_vec = _mm_set_ss(w);
21        let rcp = _mm_rcp_ss(w_vec);
22
23        // Newton-Raphson: r1 = r0 * (2.0 - w * r0)
24        let two = _mm_set_ss(2.0);
25        let prod = _mm_mul_ss(w_vec, rcp);
26        let diff = _mm_sub_ss(two, prod);
27        let res = _mm_mul_ss(rcp, diff);
28
29        return _mm_cvtss_f32(res);
30    }
31
32    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
33    #[allow(unsafe_code)]
34    // SAFETY: NEON intrinsics are safe on aarch64 with neon feature.
35    unsafe {
36        use std::arch::aarch64::*;
37        // Load f32 into a D-register (float32x2_t)
38        let v_w = vdupq_n_f32(w);
39        let res_vec = vrecpeq_f32(v_w);
40        let res_vec = vmulq_f32(res_vec, vrecpsq_f32(v_w, res_vec));
41        return vgetq_lane_f32(res_vec, 0);
42    }
43
44    #[cfg(not(any(
45        all(target_arch = "x86_64", target_feature = "avx2"),
46        all(target_arch = "aarch64", target_feature = "neon")
47    )))]
48    {
49        1.0 / w
50    }
51}
52
/// Perform bilinear interpolation using integer fixed-point arithmetic.
///
/// Only the fractional parts of `x` and `y` are used. They are quantized to 16 bits, and the
/// four bilinear weights are kept at full 32-bit precision. With that precision the weights
/// sum to exactly `1.0` (2^32), so a uniform patch reproduces its value exactly and the
/// result never exceeds the brightest input pixel. The exact weighted sum is rounded down.
///
/// Pixels `p00`, `p10`, `p01`, `p11` are the top-left, top-right, bottom-left and
/// bottom-right neighbours respectively.
///
/// Negative coordinates have a negative `fract()`. The saturating float-to-int cast clamps
/// that to 0, so such inputs behave as if the offset were 0. NOTE(review): callers
/// presumably pass non-negative coordinates — confirm.
#[must_use]
#[allow(clippy::cast_sign_loss, dead_code)]
pub(crate) fn bilinear_interpolate_fixed(x: f32, y: f32, p00: u8, p10: u8, p01: u8, p11: u8) -> u8 {
    // Fractional offsets in 0.16 fixed point (0..=0xFFFF).
    let fx = u64::from(((x.fract() * 65536.0) as u32) & 0xFFFF);
    let fy = u64::from(((y.fract() * 65536.0) as u32) & 0xFFFF);

    // Complementary offsets in 0.16 fixed point (1..=0x10000).
    let inv_x = 0x1_0000 - fx;
    let inv_y = 0x1_0000 - fy;

    // Accumulate with exact 0.32 weights: (inv_x + fx) * (inv_y + fy) == 2^32.
    // Truncating each weight to 16 bits first would drop up to 3 units of weight and bias
    // the result low (a uniform 200 patch would come out as 199).
    let acc = u64::from(p00) * inv_x * inv_y
        + u64::from(p10) * fx * inv_y
        + u64::from(p01) * inv_x * fy
        + u64::from(p11) * fx * fy;

    // acc <= 255 * 2^32 < 2^40, so there is no overflow and acc >> 32 <= 255 fits in u8.
    (acc >> 32) as u8
}
80
/// Approximate the error function `erf(x)` with Abramowitz & Stegun formula 7.1.26.
///
/// The formula's stated maximum absolute error is 1.5e-7 over the whole real line.
/// `erf(0)` is returned as exactly `0.0`, and odd symmetry `erf(-x) == -erf(x)` holds exactly
/// because the magnitude is computed from `|x|` and only the sign is flipped.
///
/// This is a pure, stateless function shared by the quad-refinement and decoder stages.
#[must_use]
pub(crate) fn erf_approx(x: f64) -> f64 {
    // A&S 7.1.26 parameters; coefficients listed from a5 down to a1 for Horner evaluation.
    const P: f64 = 0.327_591_1;
    const COEFFS: [f64; 5] = [
        1.061_405_429,  // a5
        -1.453_152_027, // a4
        1.421_413_741,  // a3
        -0.284_496_736, // a2
        0.254_829_592,  // a1
    ];

    if x == 0.0 {
        return 0.0;
    }

    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    // ((((a5*t + a4)*t + a3)*t + a2)*t + a1)
    let poly = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
    let magnitude = 1.0 - poly * t * (-ax * ax).exp();

    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
107
/// Vectorized error function approximation over 4 lanes (AVX2 `__m256d`).
///
/// Evaluates the same Abramowitz & Stegun 7.1.26 formula as `erf_approx` on four `f64`
/// lanes at once, with the Horner chain in FMA instructions. This avoids unpacking to
/// scalar in the Gauss-Newton loop. FMA rounds once where the scalar code rounds twice, so
/// lanes may differ from `erf_approx` in the last few ulps.
///
/// Lanes holding `±0.0` return exactly `+0.0`, matching the early return in `erf_approx`.
/// Without that mask the polynomial would leave a residue of about `±1e-9`, because the
/// coefficients sum to 0.999999999 at `t == 1`.
///
/// On targets without AVX2+FMA, a scalar fallback taking `[f64; 4]` is compiled instead.
///
/// # Safety
///
/// Uses AVX, AVX2 and FMA intrinsics. This function only exists when those features are
/// enabled for the whole crate (see the `cfg` below), so the executing CPU is assumed to
/// support them.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
#[allow(unsafe_code)]
#[must_use]
pub(crate) unsafe fn erf_approx_v4(x: std::arch::x86_64::__m256d) -> std::arch::x86_64::__m256d {
    use std::arch::x86_64::*;

    // Split x into sign bit and magnitude.
    let sign_mask = _mm256_set1_pd(-0.0);
    let sign_bits = _mm256_and_pd(x, sign_mask);
    let abs_x = _mm256_andnot_pd(sign_mask, x);

    // Lanes where x == ±0.0 (ordered compare, so NaN lanes are never matched).
    let zero_lanes = _mm256_cmp_pd(x, _mm256_setzero_pd(), _CMP_EQ_OQ);

    // Abramowitz and Stegun constants (formula 7.1.26)
    let a1 = _mm256_set1_pd(0.254_829_592);
    let a2 = _mm256_set1_pd(-0.284_496_736);
    let a3 = _mm256_set1_pd(1.421_413_741);
    let a4 = _mm256_set1_pd(-1.453_152_027);
    let a5 = _mm256_set1_pd(1.061_405_429);
    let p = _mm256_set1_pd(0.327_591_1);
    let one = _mm256_set1_pd(1.0);

    // t = 1.0 / (1.0 + p * |x|)
    let t = _mm256_div_pd(one, _mm256_fmadd_pd(p, abs_x, one));

    // Horner's method: poly = ((((a5*t + a4)*t + a3)*t + a2)*t + a1)
    let poly = _mm256_fmadd_pd(a5, t, a4);
    let poly = _mm256_fmadd_pd(poly, t, a3);
    let poly = _mm256_fmadd_pd(poly, t, a2);
    let poly = _mm256_fmadd_pd(poly, t, a1);

    // exp(-x^2): there is no `_mm256_exp_pd` intrinsic in std::arch, so the four
    // exponentials are evaluated in scalar and re-packed. The polynomial above stays vectorized.
    let neg_x2 = _mm256_mul_pd(abs_x, abs_x);
    let neg_x2 = _mm256_xor_pd(neg_x2, sign_mask); // negate

    // SAFETY: `__m256d` and `[f64; 4]` have the same size and every bit pattern is valid
    // for both.
    let neg_x2_arr: [f64; 4] = std::mem::transmute(neg_x2);
    let exp_vals = _mm256_set_pd(
        neg_x2_arr[3].exp(),
        neg_x2_arr[2].exp(),
        neg_x2_arr[1].exp(),
        neg_x2_arr[0].exp(),
    );

    // y = 1.0 - poly * t * exp(-x^2)
    let y = _mm256_fnmadd_pd(_mm256_mul_pd(poly, t), exp_vals, one);

    // y is non-negative for every input. OR-ing in x's sign bit therefore negates exactly
    // the lanes where x is negative (equivalent to XOR here).
    let signed = _mm256_or_pd(y, sign_bits);

    // Force ±0.0 input lanes to +0.0, as the scalar `erf_approx` does.
    _mm256_andnot_pd(zero_lanes, signed)
}
167
168/// Scalar fallback for `erf_approx_v4` on non-AVX2 targets.
169///
170/// Evaluates 4 `f64` values independently using the scalar `erf_approx`.
171#[cfg(not(all(
172    target_arch = "x86_64",
173    target_feature = "avx2",
174    target_feature = "fma"
175)))]
176#[must_use]
177#[allow(dead_code)]
178pub(crate) fn erf_approx_v4(x: [f64; 4]) -> [f64; 4] {
179    [
180        erf_approx(x[0]),
181        erf_approx(x[1]),
182        erf_approx(x[2]),
183        erf_approx(x[3]),
184    ]
185}
186
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// One Newton-Raphson step should land within 1e-4 of exact division.
    #[test]
    fn test_rcp_nr_precision() {
        let probes: [f32; 5] = [1.0, 2.0, 10.0, 0.5, 123.456];
        for w in probes {
            let (expected, actual) = (1.0 / w, rcp_nr(w));
            let diff = (expected - actual).abs();
            assert!(
                diff < 1e-4,
                "rcp_nr({w}) failed: expected {expected}, got {actual}, diff {diff}"
            );
        }
    }

    /// Exact zero, odd symmetry, and saturation towards ±1.
    #[test]
    fn test_erf_approx_properties() {
        assert_eq!(erf_approx(0.0), 0.0);

        [0.1, 0.5, 1.0, 2.0, 5.0]
            .iter()
            .for_each(|&x: &f64| assert!((erf_approx(-x) + erf_approx(x)).abs() < 1e-15));

        // (input, limit, tolerance)
        let tails: [(f64, f64, f64); 3] = [
            (10.0, 1.0, 1e-7),
            (-10.0, -1.0, 1e-7),
            (100.0, 1.0, 1e-15),
        ];
        for (x, limit, tol) in tails {
            assert!((erf_approx(x) - limit).abs() < tol);
        }
    }

    /// Compare against reference erf values within the A&S error bound.
    #[test]
    fn test_erf_approx_accuracy() {
        let reference: [(f64, f64); 3] = [
            (0.5, 0.520_499_877_813_046_5),
            (1.0, 0.842_700_792_949_714_8),
            (2.0, 0.995_322_265_018_952_7),
        ];

        reference.iter().for_each(|&(x, expected)| {
            let actual = erf_approx(x);
            let diff = (actual - expected).abs();
            assert!(
                diff < 1.5e-7,
                "erf_approx({x}) error {diff} exceeds tolerance 1.5e-7"
            );
        });
    }

    /// Every lane of the 4-wide variant must agree with the scalar function.
    #[test]
    fn test_erf_approx_v4_matches_scalar() {
        let inputs = [0.5, -1.0, 2.0, -0.3];

        // SAFETY: AVX2+FMA checked by cfg.
        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx2",
            target_feature = "fma"
        ))]
        let lanes: [f64; 4] = unsafe {
            use std::arch::x86_64::*;
            let packed = _mm256_set_pd(inputs[3], inputs[2], inputs[1], inputs[0]);
            std::mem::transmute(erf_approx_v4(packed))
        };

        #[cfg(not(all(
            target_arch = "x86_64",
            target_feature = "avx2",
            target_feature = "fma"
        )))]
        let lanes = erf_approx_v4(inputs);

        for (i, (&input, &got)) in inputs.iter().zip(lanes.iter()).enumerate() {
            let scalar = erf_approx(input);
            let diff = (got - scalar).abs();
            assert!(
                diff < 1e-15,
                "erf_approx_v4 lane {i}: expected {scalar}, got {}, diff {diff}",
                got
            );
        }
    }

    /// Table of ((x, y), [p00, p10, p01, p11], expected).
    #[test]
    fn test_bilinear_fixed() {
        let cases: [((f32, f32), [u8; 4], u8); 3] = [
            // Center of 4 pixels: average
            ((0.5, 0.5), [100, 200, 100, 200], 150),
            // Top-left: p00
            ((0.0, 0.0), [100, 200, 50, 250], 100),
            // Bottom-right: p11, rounded down
            ((0.999, 0.999), [100, 200, 50, 250], 249),
        ];
        for ((x, y), [p00, p10, p01, p11], expected) in cases {
            assert_eq!(bilinear_interpolate_fixed(x, y, p00, p10, p01, p11), expected);
        }
    }
}