1use crate::float_trait::Float;
2use conv::*;
3use std::fmt::Debug;
4
5pub trait ErrorFunction<T>: Clone + Debug
11where
12 T: ErfFloat,
13{
14 fn erf(x: T) -> T;
15
16 fn normal_cdf(x: T, mean: T, sigma: T) -> T {
17 T::half() * (T::one() + Self::erf((x - mean) / sigma * T::FRAC_1_SQRT_2()))
18 }
19
20 fn max_dx_nonunity_normal_cdf(sigma: T) -> T;
21
22 fn min_dx_nonzero_normal_cdf(sigma: T) -> T {
23 -Self::max_dx_nonunity_normal_cdf(sigma)
24 }
25}
26
/// Marker type selecting the error-function implementation that delegates
/// to the `libm_erf` method of the float type.
#[derive(Clone, Copy, Debug)]
pub struct ExactErf;
30
31impl<T> ErrorFunction<T> for ExactErf
32where
33 T: ErfFloat,
34{
35 fn erf(x: T) -> T {
36 x.libm_erf()
37 }
38
39 fn max_dx_nonunity_normal_cdf(sigma: T) -> T {
40 T::SQRT_2_ERFINV_UNITY_MINUS_EPS * sigma
41 }
42}
43
/// Marker type selecting the table-based approximate error function
/// (`erf_eps_1over1e3` of the float type).
///
/// The unit test in this file checks that approximation against the libm
/// implementation with an absolute tolerance of `7e-4`.
#[derive(Clone, Copy, Debug)]
pub struct Eps1Over1e3Erf;
49
50impl<T> ErrorFunction<T> for Eps1Over1e3Erf
51where
52 T: ErfFloat,
53{
54 fn erf(x: T) -> T {
55 x.erf_eps_1over1e3()
56 }
57
58 fn max_dx_nonunity_normal_cdf(sigma: T) -> T {
59 T::SQRT_2_MAX_X_FOR_ERF_EPS_1OVER1E3 * sigma
60 }
61}
62
63pub trait ErfFloat: Float + ApproxInto<usize, RoundToZero> + num_traits::Float {
65 const SQRT_2_ERFINV_UNITY_MINUS_EPS: Self;
66
67 fn libm_erf(self) -> Self;
68
69 const SQRT_2_MAX_X_FOR_ERF_EPS_1OVER1E3: Self;
70 const X_FOR_ERF_EPS_1OVER1E3: [Self; 64];
71 const INVERSED_DX_FOR_ERF_EPS_1OVER1E3: Self;
72 const Y_FOR_ERF_EPS_1OVER1E3: [Self; 64];
73 fn erf_eps_1over1e3(self) -> Self {
74 match self {
75 _ if self < Self::X_FOR_ERF_EPS_1OVER1E3[0] => -Self::one(),
76 _ if self >= Self::X_FOR_ERF_EPS_1OVER1E3[63] => Self::one(),
77 x => {
78 let idx =
79 (x - Self::X_FOR_ERF_EPS_1OVER1E3[0]) * Self::INVERSED_DX_FOR_ERF_EPS_1OVER1E3;
80 let alpha = idx.fract();
81 let idx: usize = idx.approx_by::<RoundToZero>().unwrap();
82 Self::Y_FOR_ERF_EPS_1OVER1E3[idx] * (Self::one() - alpha)
83 + Self::Y_FOR_ERF_EPS_1OVER1E3[idx + 1] * alpha
84 }
85 }
86 }
87}
88
89#[allow(clippy::excessive_precision)]
90impl ErfFloat for f32 {
91 const SQRT_2_ERFINV_UNITY_MINUS_EPS: Self = 5.294704084854598;
92
93 fn libm_erf(self) -> Self {
94 libm::erff(self)
95 }
96
97 const SQRT_2_MAX_X_FOR_ERF_EPS_1OVER1E3: Self = 3.389783571270326;
98 const X_FOR_ERF_EPS_1OVER1E3: [Self; 64] = [
99 -2.39693895,
100 -2.32084565,
101 -2.24475235,
102 -2.16865905,
103 -2.09256575,
104 -2.01647245,
105 -1.94037915,
106 -1.86428585,
107 -1.78819255,
108 -1.71209925,
109 -1.63600595,
110 -1.55991265,
111 -1.48381935,
112 -1.40772605,
113 -1.33163275,
114 -1.25553945,
115 -1.17944615,
116 -1.10335285,
117 -1.02725955,
118 -0.95116625,
119 -0.87507295,
120 -0.79897965,
121 -0.72288635,
122 -0.64679305,
123 -0.57069975,
124 -0.49460645,
125 -0.41851315,
126 -0.34241985,
127 -0.26632655,
128 -0.19023325,
129 -0.11413995,
130 -0.03804665,
131 0.03804665,
132 0.11413995,
133 0.19023325,
134 0.26632655,
135 0.34241985,
136 0.41851315,
137 0.49460645,
138 0.57069975,
139 0.64679305,
140 0.72288635,
141 0.79897965,
142 0.87507295,
143 0.95116625,
144 1.02725955,
145 1.10335285,
146 1.17944615,
147 1.25553945,
148 1.33163275,
149 1.40772605,
150 1.48381935,
151 1.55991265,
152 1.63600595,
153 1.71209925,
154 1.78819255,
155 1.86428585,
156 1.94037915,
157 2.01647245,
158 2.09256575,
159 2.16865905,
160 2.24475235,
161 2.32084565,
162 2.39693895,
163 ];
164 const INVERSED_DX_FOR_ERF_EPS_1OVER1E3: Self = 13.141761468984605;
165 const Y_FOR_ERF_EPS_1OVER1E3: [Self; 64] = [
166 -0.99930052,
167 -0.99896989,
168 -0.99849936,
169 -0.99783743,
170 -0.99691696,
171 -0.9956517,
172 -0.99393249,
173 -0.99162334,
174 -0.98855749,
175 -0.98453378,
176 -0.97931372,
177 -0.97261948,
178 -0.96413348,
179 -0.9534999,
180 -0.94032851,
181 -0.92420128,
182 -0.90468204,
183 -0.88132908,
184 -0.85371082,
185 -0.82142392,
186 -0.78411334,
187 -0.74149338,
188 -0.69336849,
189 -0.6396527,
190 -0.58038613,
191 -0.51574736,
192 -0.44606033,
193 -0.37179495,
194 -0.29356079,
195 -0.21209374,
196 -0.12823602,
197 -0.04291034,
198 0.04291034,
199 0.12823602,
200 0.21209374,
201 0.29356079,
202 0.37179495,
203 0.44606033,
204 0.51574736,
205 0.58038613,
206 0.6396527,
207 0.69336849,
208 0.74149338,
209 0.78411334,
210 0.82142392,
211 0.85371082,
212 0.88132908,
213 0.90468204,
214 0.92420128,
215 0.94032851,
216 0.9534999,
217 0.96413348,
218 0.97261948,
219 0.97931372,
220 0.98453378,
221 0.98855749,
222 0.99162334,
223 0.99393249,
224 0.9956517,
225 0.99691696,
226 0.99783743,
227 0.99849936,
228 0.99896989,
229 0.99930052,
230 ];
231}
232
233impl ErfFloat for f64 {
234 const SQRT_2_ERFINV_UNITY_MINUS_EPS: Self = 8.20953615160139;
235
236 fn libm_erf(self) -> Self {
237 libm::erf(self)
238 }
239
240 const SQRT_2_MAX_X_FOR_ERF_EPS_1OVER1E3: Self = 3.389783571270326;
241 const X_FOR_ERF_EPS_1OVER1E3: [Self; 64] = [
242 -2.39693895,
243 -2.32084565,
244 -2.24475235,
245 -2.16865905,
246 -2.09256575,
247 -2.01647245,
248 -1.94037915,
249 -1.86428585,
250 -1.78819255,
251 -1.71209925,
252 -1.63600595,
253 -1.55991265,
254 -1.48381935,
255 -1.40772605,
256 -1.33163275,
257 -1.25553945,
258 -1.17944615,
259 -1.10335285,
260 -1.02725955,
261 -0.95116625,
262 -0.87507295,
263 -0.79897965,
264 -0.72288635,
265 -0.64679305,
266 -0.57069975,
267 -0.49460645,
268 -0.41851315,
269 -0.34241985,
270 -0.26632655,
271 -0.19023325,
272 -0.11413995,
273 -0.03804665,
274 0.03804665,
275 0.11413995,
276 0.19023325,
277 0.26632655,
278 0.34241985,
279 0.41851315,
280 0.49460645,
281 0.57069975,
282 0.64679305,
283 0.72288635,
284 0.79897965,
285 0.87507295,
286 0.95116625,
287 1.02725955,
288 1.10335285,
289 1.17944615,
290 1.25553945,
291 1.33163275,
292 1.40772605,
293 1.48381935,
294 1.55991265,
295 1.63600595,
296 1.71209925,
297 1.78819255,
298 1.86428585,
299 1.94037915,
300 2.01647245,
301 2.09256575,
302 2.16865905,
303 2.24475235,
304 2.32084565,
305 2.39693895,
306 ];
307 const INVERSED_DX_FOR_ERF_EPS_1OVER1E3: Self = 13.141761468984605;
308 const Y_FOR_ERF_EPS_1OVER1E3: [Self; 64] = [
309 -0.99930052,
310 -0.99896989,
311 -0.99849936,
312 -0.99783743,
313 -0.99691696,
314 -0.9956517,
315 -0.99393249,
316 -0.99162334,
317 -0.98855749,
318 -0.98453378,
319 -0.97931372,
320 -0.97261948,
321 -0.96413348,
322 -0.9534999,
323 -0.94032851,
324 -0.92420128,
325 -0.90468204,
326 -0.88132908,
327 -0.85371082,
328 -0.82142392,
329 -0.78411334,
330 -0.74149338,
331 -0.69336849,
332 -0.6396527,
333 -0.58038613,
334 -0.51574736,
335 -0.44606033,
336 -0.37179495,
337 -0.29356079,
338 -0.21209374,
339 -0.12823602,
340 -0.04291034,
341 0.04291034,
342 0.12823602,
343 0.21209374,
344 0.29356079,
345 0.37179495,
346 0.44606033,
347 0.51574736,
348 0.58038613,
349 0.6396527,
350 0.69336849,
351 0.74149338,
352 0.78411334,
353 0.82142392,
354 0.85371082,
355 0.88132908,
356 0.90468204,
357 0.92420128,
358 0.94032851,
359 0.9534999,
360 0.96413348,
361 0.97261948,
362 0.97931372,
363 0.98453378,
364 0.98855749,
365 0.99162334,
366 0.99393249,
367 0.9956517,
368 0.99691696,
369 0.99783743,
370 0.99849936,
371 0.99896989,
372 0.99930052,
373 ];
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array1;

    /// Compares the table-based approximation with the libm error function
    /// on 2^20 evenly spaced `f32` points covering `[-5, 5]`, requiring an
    /// absolute difference of at most `7e-4` at every point.
    #[test]
    fn erf_eps_1over1e3() {
        let grid: Array1<f32> = Array1::linspace(-5.0, 5.0, 1 << 20);
        let reference = grid.mapv(|value| value.libm_erf());
        let approximated = grid.mapv(|value| value.erf_eps_1over1e3());
        assert_abs_diff_eq!(
            approximated.as_slice().unwrap(),
            reference.as_slice().unwrap(),
            epsilon = 7e-4,
        );
    }
}