1use multiversion::multiversion;
4
5#[multiversion(targets(
10 "x86_64+avx2+bmi1+bmi2+popcnt+lzcnt",
11 "x86_64+avx512f+avx512bw+avx512dq+avx512vl",
12 "aarch64+neon"
13))]
14#[must_use]
15pub(crate) fn rcp_nr(w: f32) -> f32 {
16 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
17 unsafe {
19 use std::arch::x86_64::*;
20 let w_vec = _mm_set_ss(w);
21 let rcp = _mm_rcp_ss(w_vec);
22
23 let two = _mm_set_ss(2.0);
25 let prod = _mm_mul_ss(w_vec, rcp);
26 let diff = _mm_sub_ss(two, prod);
27 let res = _mm_mul_ss(rcp, diff);
28
29 return _mm_cvtss_f32(res);
30 }
31
32 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
33 #[allow(unsafe_code)]
34 unsafe {
36 use std::arch::aarch64::*;
37 let v_w = vdupq_n_f32(w);
39 let res_vec = vrecpeq_f32(v_w);
40 let res_vec = vmulq_f32(res_vec, vrecpsq_f32(v_w, res_vec));
41 return vgetq_lane_f32(res_vec, 0);
42 }
43
44 #[cfg(not(any(
45 all(target_arch = "x86_64", target_feature = "avx2"),
46 all(target_arch = "aarch64", target_feature = "neon")
47 )))]
48 {
49 1.0 / w
50 }
51}
52
/// Bilinearly blends four 8-bit samples using fixed-point weights.
///
/// Only the fractional parts of `x` and `y` are used. Each is quantised to
/// 16 bits (`fx`, `fy` in `0..=0xFFFF`, with `0x10000` representing 1.0) and
/// the samples are weighted as follows:
///
/// * `p00`: `(1 - fx) * (1 - fy)`
/// * `p10`: `fx * (1 - fy)`
/// * `p01`: `(1 - fx) * fy`
/// * `p11`: `fx * fy`
///
/// The weights are kept at full 32-bit precision, so they always sum to
/// exactly `2^32`; only the final accumulator is shifted down. Blending four
/// identical samples therefore returns that sample unchanged. The result is
/// truncated towards zero, not rounded.
///
/// A negative or NaN fractional part becomes `0` because Rust's float to
/// integer `as` cast saturates.
#[must_use]
// cast_sign_loss: the fraction is intentionally cast to an unsigned integer.
// dead_code: presumably this helper is not referenced in every build - confirm.
#[allow(clippy::cast_sign_loss, dead_code)]
pub(crate) fn bilinear_interpolate_fixed(x: f32, y: f32, p00: u8, p10: u8, p01: u8, p11: u8) -> u8 {
    /// Fixed-point representation of 1.0 (16 fractional bits).
    const ONE: u64 = 0x1_0000;

    // The mask keeps the quantised fraction inside 16 bits.
    let fx = u64::from(((x.fract() * 65536.0) as u32) & 0xFFFF);
    let fy = u64::from(((y.fract() * 65536.0) as u32) & 0xFFFF);

    let inv_x = ONE - fx;
    let inv_y = ONE - fy;

    // Each product has 32 fractional bits. The largest possible accumulator is
    // 255 * 2^32, far below the u64 limit.
    let acc = u64::from(p00) * (inv_x * inv_y)
        + u64::from(p10) * (fx * inv_y)
        + u64::from(p01) * (inv_x * fy)
        + u64::from(p11) * (fx * fy);

    // acc >> 32 is at most 255, so the narrowing cast cannot lose information.
    (acc >> 32) as u8
}
80
/// Approximates the error function `erf(x)` for a scalar `f64`.
///
/// The value is computed for `|x|` as
/// `1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)` with
/// `t = 1 / (1 + p*|x|)`, and the sign of `x` is applied afterwards, so the
/// result is odd-symmetric. The coefficients match the classic
/// Abramowitz and Stegun formula 7.1.26.
///
/// An input of zero (either sign) returns exactly `0.0`.
#[must_use]
pub(crate) fn erf_approx(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    /// Leading coefficient (a5) that seeds the Horner evaluation.
    const A5: f64 = 1.061_405_429;
    /// Remaining coefficients in Horner order: a4, a3, a2, a1.
    const LOWER: [f64; 4] = [
        -1.453_152_027,
        1.421_413_741,
        -0.284_496_736,
        0.254_829_592,
    ];

    // The polynomial alone does not evaluate to exactly zero at the origin.
    if x == 0.0 {
        return 0.0;
    }

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    // (((a5*t + a4)*t + a3)*t + a2)*t + a1
    let horner = LOWER.iter().fold(A5, |acc, &coeff| acc * t + coeff);
    let tail = horner * t * (-magnitude * magnitude).exp();
    let positive = 1.0 - tail;

    if x < 0.0 {
        -positive
    } else {
        positive
    }
}
107
/// Evaluates the `erf` approximation on four `f64` lanes packed in a
/// `__m256d`.
///
/// Each lane uses the same formula as the scalar `erf_approx`:
/// `1 - poly(t) * t * exp(-x^2)` with `t = 1 / (1 + p*|x|)`, followed by
/// re-applying the sign bit of the input lane. The polynomial is evaluated
/// with fused multiply-adds, so results may differ from the scalar version in
/// the last bits. The exponential is computed lane by lane with the scalar
/// `f64::exp`.
///
/// Lanes whose input is `+0.0` or `-0.0` return exactly `+0.0`, matching the
/// early return in the scalar `erf_approx`.
///
/// # Safety
///
/// The body uses AVX/AVX2/FMA intrinsics. The function is only compiled when
/// `avx2` and `fma` are statically enabled for an x86_64 target, so the
/// machine running the binary must support those features.
/// NOTE(review): no other caller-side contract is visible in this file -
/// confirm that the `unsafe` marker exists only because of the intrinsics.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
#[must_use]
pub(crate) unsafe fn erf_approx_v4(x: std::arch::x86_64::__m256d) -> std::arch::x86_64::__m256d {
    use std::arch::x86_64::*;

    // -0.0 has only the sign bit set, so it doubles as a sign mask.
    let sign_mask = _mm256_set1_pd(-0.0);
    let sign_bits = _mm256_and_pd(x, sign_mask);
    let abs_x = _mm256_andnot_pd(sign_mask, x);

    let a1 = _mm256_set1_pd(0.254_829_592);
    let a2 = _mm256_set1_pd(-0.284_496_736);
    let a3 = _mm256_set1_pd(1.421_413_741);
    let a4 = _mm256_set1_pd(-1.453_152_027);
    let a5 = _mm256_set1_pd(1.061_405_429);
    let p = _mm256_set1_pd(0.327_591_1);
    let one = _mm256_set1_pd(1.0);

    // t = 1 / (1 + p * |x|)
    let t = _mm256_div_pd(one, _mm256_fmadd_pd(p, abs_x, one));

    // Horner evaluation: (((a5*t + a4)*t + a3)*t + a2)*t + a1
    let poly = _mm256_fmadd_pd(a5, t, a4);
    let poly = _mm256_fmadd_pd(poly, t, a3);
    let poly = _mm256_fmadd_pd(poly, t, a2);
    let poly = _mm256_fmadd_pd(poly, t, a1);

    // -(x^2): square, then flip the sign bit.
    let x2 = _mm256_mul_pd(abs_x, abs_x);
    let neg_x2 = _mm256_xor_pd(x2, sign_mask);

    // Spill to an array so each lane can go through the scalar `exp`.
    let mut neg_x2_arr = [0.0_f64; 4];
    _mm256_storeu_pd(neg_x2_arr.as_mut_ptr(), neg_x2);
    let exp_vals = _mm256_set_pd(
        neg_x2_arr[3].exp(),
        neg_x2_arr[2].exp(),
        neg_x2_arr[1].exp(),
        neg_x2_arr[0].exp(),
    );

    // y = 1 - (poly * t) * exp(-x^2)
    let y = _mm256_fnmadd_pd(_mm256_mul_pd(poly, t), exp_vals, one);
    let signed = _mm256_or_pd(y, sign_bits);

    // The formula leaves a tiny non-zero residue at x == 0. Clear those lanes
    // entirely (sign bit included) so they agree with the scalar version.
    // NEQ_UQ is true for NaN lanes, so NaN inputs still propagate.
    let nonzero = _mm256_cmp_pd::<_CMP_NEQ_UQ>(x, _mm256_setzero_pd());
    _mm256_and_pd(signed, nonzero)
}
167
168#[cfg(not(all(
172 target_arch = "x86_64",
173 target_feature = "avx2",
174 target_feature = "fma"
175)))]
176#[must_use]
177#[allow(dead_code)]
178pub(crate) fn erf_approx_v4(x: [f64; 4]) -> [f64; 4] {
179 [
180 erf_approx(x[0]),
181 erf_approx(x[1]),
182 erf_approx(x[2]),
183 erf_approx(x[3]),
184 ]
185}
186
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Checks every lane of a four-wide result against the scalar
    /// `erf_approx` of the corresponding input.
    fn assert_lanes_match_scalar(inputs: &[f64; 4], lanes: &[f64; 4]) {
        for (i, (&input, &lane)) in inputs.iter().zip(lanes).enumerate() {
            let scalar = erf_approx(input);
            let diff = (lane - scalar).abs();
            assert!(
                diff < 1e-15,
                "erf_approx_v4 lane {i}: expected {scalar}, got {}, diff {diff}",
                lane
            );
        }
    }

    /// `rcp_nr` stays within 1e-4 of the exact reciprocal on sample inputs.
    #[test]
    fn test_rcp_nr_precision() {
        for w in [1.0_f32, 2.0, 10.0, 0.5, 123.456] {
            let expected = w.recip();
            let actual = rcp_nr(w);
            let diff = (expected - actual).abs();
            assert!(
                diff < 1e-4,
                "rcp_nr({w}) failed: expected {expected}, got {actual}, diff {diff}"
            );
        }
    }

    /// Exact zero at the origin, odd symmetry, and saturation towards ±1.
    #[test]
    fn test_erf_approx_properties() {
        // Origin is exact.
        assert_eq!(erf_approx(0.0), 0.0);

        // erf(-x) == -erf(x)
        let samples = [0.1, 0.5, 1.0, 2.0, 5.0];
        for x in samples {
            assert!((erf_approx(-x) + erf_approx(x)).abs() < 1e-15);
        }

        // Large magnitudes approach ±1.
        assert!((erf_approx(10.0) - 1.0).abs() < 1e-7);
        assert!((erf_approx(-10.0) + 1.0).abs() < 1e-7);
        assert!((erf_approx(100.0) - 1.0).abs() < 1e-15);
    }

    /// Reference values are reproduced to within 1.5e-7.
    #[test]
    fn test_erf_approx_accuracy() {
        let cases = [
            (0.5, 0.520_499_877_813_046_5),
            (1.0, 0.842_700_792_949_714_8),
            (2.0, 0.995_322_265_018_952_7),
        ];

        for (x, expected) in cases {
            let diff = (erf_approx(x) - expected).abs();
            assert!(
                diff < 1.5e-7,
                "erf_approx({x}) error {diff} exceeds tolerance 1.5e-7"
            );
        }
    }

    /// Whichever `erf_approx_v4` variant is compiled must agree with the
    /// scalar function lane by lane.
    #[test]
    fn test_erf_approx_v4_matches_scalar() {
        let inputs = [0.5, -1.0, 2.0, -0.3];

        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx2",
            target_feature = "fma"
        ))]
        let lanes: [f64; 4] = {
            use std::arch::x86_64::*;
            // SAFETY: this branch is compiled only when avx2 and fma are
            // statically enabled, and `__m256d` has the same size as
            // `[f64; 4]`.
            unsafe {
                let v = _mm256_set_pd(inputs[3], inputs[2], inputs[1], inputs[0]);
                std::mem::transmute(erf_approx_v4(v))
            }
        };

        #[cfg(not(all(
            target_arch = "x86_64",
            target_feature = "avx2",
            target_feature = "fma"
        )))]
        let lanes: [f64; 4] = erf_approx_v4(inputs);

        assert_lanes_match_scalar(&inputs, &lanes);
    }

    /// Midpoint, exact-corner and near-corner blends.
    #[test]
    fn test_bilinear_fixed() {
        let cases = [
            ((0.5, 0.5), [100, 200, 100, 200], 150),
            ((0.0, 0.0), [100, 200, 50, 250], 100),
            ((0.999, 0.999), [100, 200, 50, 250], 249),
        ];

        for ((x, y), [p00, p10, p01, p11], expected) in cases {
            assert_eq!(
                bilinear_interpolate_fixed(x, y, p00, p10, p01, p11),
                expected
            );
        }
    }
}