// deep_time/math/sqrt.rs

1#![allow(clippy::indexing_slicing)]
2#![allow(clippy::excessive_precision)]
3#![allow(clippy::approx_constant)]
4#![allow(clippy::eq_op)]
5
6use crate::Real;
7
/// Seed table for the initial reciprocal-square-root estimate used by `sqrt`.
///
/// `sqrt` indexes it with `(ix >> 46) % 128`: the lowest bit of the biased
/// exponent followed by the six leading mantissa bits. Each entry is a 16-bit
/// fixed-point fraction (entry / 2^16) approximating `1/sqrt(m)`:
/// - indices `0..64` are used when the reduced argument `m` lies in `[2, 4)`
///   (entry 0 is ~0.704, close to 1/sqrt(2));
/// - indices `64..128` are used when `m` lies in `[1, 2)`
///   (entry 64 is ~0.996, close to 1/sqrt(1)).
///
/// Within each half the entries strictly decrease, as `1/sqrt` does.
const RSQRT_TAB: [u16; 128] = [
    // m in [2, 4)
    0xb451, 0xb2f0, 0xb196, 0xb044, 0xaef9, 0xadb6, 0xac79, 0xab43,
    0xaa14, 0xa8eb, 0xa7c8, 0xa6aa, 0xa592, 0xa480, 0xa373, 0xa26b,
    0xa168, 0xa06a, 0x9f70, 0x9e7b, 0x9d8a, 0x9c9d, 0x9bb5, 0x9ad1,
    0x99f0, 0x9913, 0x983a, 0x9765, 0x9693, 0x95c4, 0x94f8, 0x9430,
    0x936b, 0x92a9, 0x91ea, 0x912e, 0x9075, 0x8fbe, 0x8f0a, 0x8e59,
    0x8daa, 0x8cfe, 0x8c54, 0x8bac, 0x8b07, 0x8a64, 0x89c4, 0x8925,
    0x8889, 0x87ee, 0x8756, 0x86c0, 0x862b, 0x8599, 0x8508, 0x8479,
    0x83ec, 0x8361, 0x82d8, 0x8250, 0x81c9, 0x8145, 0x80c2, 0x8040,
    // m in [1, 2)
    0xff02, 0xfd0e, 0xfb25, 0xf947, 0xf773, 0xf5aa, 0xf3ea, 0xf234,
    0xf087, 0xeee3, 0xed47, 0xebb3, 0xea27, 0xe8a3, 0xe727, 0xe5b2,
    0xe443, 0xe2dc, 0xe17a, 0xe020, 0xdecb, 0xdd7d, 0xdc34, 0xdaf1,
    0xd9b3, 0xd87b, 0xd748, 0xd61a, 0xd4f1, 0xd3cd, 0xd2ad, 0xd192,
    0xd07b, 0xcf69, 0xce5b, 0xcd51, 0xcc4a, 0xcb48, 0xca4a, 0xc94f,
    0xc858, 0xc764, 0xc674, 0xc587, 0xc49d, 0xc3b7, 0xc2d4, 0xc1f4,
    0xc116, 0xc03c, 0xbf65, 0xbe90, 0xbdbe, 0xbcef, 0xbc23, 0xbb59,
    0xba91, 0xb9cc, 0xb90a, 0xb84a, 0xb78c, 0xb6d0, 0xb617, 0xb560,
];
21
/// Returns the upper 32 bits of the full 64-bit product `a * b`.
///
/// Reading both operands as fixed-point fractions, this is a truncating
/// fixed-point multiply. The product of two `u32` values always fits in a
/// `u64`, so no overflow is possible.
#[inline]
const fn mul32(a: u32, b: u32) -> u32 {
    let product = a as u64 * b as u64;
    (product >> 32) as u32
}
26
/// Approximates the upper 64 bits of the 128-bit product `a * b` using
/// 32-bit halves.
///
/// The low×low partial product is ignored, and the two cross products are
/// truncated separately before being added. The result therefore never
/// exceeds the exact `(a * b) >> 64` and is at most 2 below it. The `sqrt`
/// error analysis is written for exactly this truncating behaviour.
#[inline]
const fn mul64(a: u64, b: u64) -> u64 {
    const LOW_HALF: u64 = 0xffff_ffff;
    let (a_hi, a_lo) = (a >> 32, a & LOW_HALF);
    let (b_hi, b_lo) = (b >> 32, b & LOW_HALF);

    let high = a_hi.wrapping_mul(b_hi);
    let cross_ab = a_hi.wrapping_mul(b_lo) >> 32;
    let cross_ba = a_lo.wrapping_mul(b_hi) >> 32;

    high.wrapping_add(cross_ab).wrapping_add(cross_ba)
}
37
38/// Computes sqrt(x) using the table-driven Goldschmidt iteration
39/// from musl libc. Correctly rounded to nearest-even for all Real inputs.
40/// const, no std, no alloc friendly.
41pub const fn sqrt(x: Real) -> Real {
42    let mut ix = x.to_bits();
43    let mut top = ix >> 52;
44
45    // Special cases: subnormal, inf, nan, negative, zero
46    if top.wrapping_sub(0x001) >= 0x7fe {
47        if ix << 1 == 0 {
48            return x; // ±0.0
49        }
50        if ix == 0x7ff0_0000_0000_0000 {
51            return x; // +inf
52        }
53        if ix > 0x7ff0_0000_0000_0000 {
54            // negative or NaN → quiet NaN, preserve sign bit for -inf/-num
55            let nan_bits = 0x7ff8_0000_0000_0000 | (ix & 0x8000_0000_0000_0000);
56            return Real::from_bits(nan_bits);
57        }
58        // Subnormal: normalize by multiplying by 2^52
59        let scale = Real::from_bits(0x4330_0000_0000_0000); // 2^52
60        ix = (x * scale).to_bits();
61        top = (ix >> 52).wrapping_sub(52);
62    }
63
64    let even = top & 1;
65    let mut m = (ix << 11) | 0x8000_0000_0000_0000u64;
66    if even != 0 {
67        m >>= 1;
68    }
69    let top = (top.wrapping_add(0x3ff)) >> 1; // result exponent (biased)
70
71    // Table-driven initial reciprocal sqrt estimate + Goldschmidt iterations
72    // All vars u64 to match C closely; mul32/mul64 return u64 for simplicity
73    let three: u64 = 0xc000_0000;
74    let i = ((ix >> 46) % 128) as usize;
75    let mut r: u64 = (RSQRT_TAB[i] as u64) << 16;
76
77    let mut s: u64 = mul32((m >> 32) as u32, r as u32) as u64;
78    let mut d: u64 = mul32(s as u32, r as u32) as u64;
79    let mut u: u64 = three - d;
80    r = (mul32(r as u32, u as u32) << 1) as u64;
81    s = (mul32(s as u32, u as u32) << 1) as u64;
82
83    d = mul32(s as u32, r as u32) as u64;
84    u = three - d;
85    r = (mul32(r as u32, u as u32) << 1) as u64;
86
87    r <<= 32;
88    s = mul64(m, r);
89    d = mul64(s, r);
90    u = (three << 32) - d;
91    s = mul64(s, u);
92
93    // Final adjustment and rounding decision
94    s = (s - 2) >> 9;
95
96    let d0 = (m << 42).wrapping_sub(s.wrapping_mul(s));
97    let d1 = s.wrapping_sub(d0);
98    let _d2 = d1.wrapping_add(s).wrapping_add(1);
99
100    if (d1 >> 63) != 0 {
101        s = s.wrapping_add(1);
102    }
103    s &= 0x000f_ffff_ffff_ffff;
104    s |= top << 52;
105
106    Real::from_bits(s)
107}
108
109const SPLIT: Real = 134217728. + 1.; // 0x1p27 + 1 === (2 ^ 27) + 1
110
111const fn sq(x: Real) -> (Real, Real) {
112    let xc: Real = x * SPLIT;
113    let xh: Real = x - xc + xc;
114    let xl: Real = x - xh;
115    let hi = x * x;
116    let lo = xh * xh - hi + 2. * xh * xl + xl * xl;
117    (hi, lo)
118}
119
120pub const fn hypot(mut x: Real, mut y: Real) -> Real {
121    let x1p700 = Real::from_bits(0x6bb0000000000000); // 0x1p700 === 2 ^ 700
122    let x1p_700 = Real::from_bits(0x1430000000000000); // 0x1p-700 === 2 ^ -700
123
124    let mut uxi = x.to_bits();
125    let mut uyi = y.to_bits();
126    let uti;
127    let mut z: Real;
128
129    /* arrange |x| >= |y| */
130    uxi &= -1i64 as u64 >> 1;
131    uyi &= -1i64 as u64 >> 1;
132    if uxi < uyi {
133        uti = uxi;
134        uxi = uyi;
135        uyi = uti;
136    }
137
138    /* special cases */
139    let ex: i64 = (uxi >> 52) as i64;
140    let ey: i64 = (uyi >> 52) as i64;
141    x = Real::from_bits(uxi);
142    y = Real::from_bits(uyi);
143    /* note: hypot(inf,nan) == inf */
144    if ey == 0x7ff {
145        return y;
146    }
147    if ex == 0x7ff || uyi == 0 {
148        return x;
149    }
150    /* note: hypot(x,y) ~= x + y*y/x/2 with inexact for small y/x */
151    /* 64 difference is enough for ld80 double_t */
152    if ex - ey > 64 {
153        return x + y;
154    }
155
156    /* precise sqrt argument in nearest rounding mode without overflow */
157    /* xh*xh must not overflow and xl*xl must not underflow in sq */
158    z = 1.;
159    if ex > 0x3ff + 510 {
160        z = x1p700;
161        x *= x1p_700;
162        y *= x1p_700;
163    } else if ey < 0x3ff - 450 {
164        z = x1p_700;
165        x *= x1p700;
166        y *= x1p700;
167    }
168    let (hx, lx) = sq(x);
169    let (hy, ly) = sq(y);
170    z * sqrt(ly + lx + hy + hx)
171}
172
#[cfg(all(test, feature = "std"))]
mod sqrt_tests {
    use super::sqrt;
    use std::{f64, vec, vec::Vec};

    /// Advances the deterministic 64-bit LCG shared by the randomized tests
    /// and returns the new state.
    fn lcg_next(state: &mut u64) -> u64 {
        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        *state
    }

    #[test]
    fn test_special_cases() {
        // Zeros come back unchanged, including the sign of -0.0.
        assert_eq!(sqrt(0.0).to_bits(), 0.0f64.to_bits());
        assert_eq!(sqrt(-0.0).to_bits(), (-0.0f64).to_bits());
        assert!(sqrt(f64::INFINITY).is_infinite() && sqrt(f64::INFINITY) > 0.0);
        assert!(sqrt(f64::NEG_INFINITY).is_nan());
        assert!(sqrt(-1.0).is_nan());
        assert!(sqrt(f64::NAN).is_nan());
        // A signaling NaN input must still produce a NaN.
        assert!(sqrt(f64::from_bits(0x7ff0_0000_0000_0001)).is_nan());
    }

    #[test]
    fn test_perfect_squares() {
        // Squares of small integers are exact, so their roots must be exact too.
        for i in 0..1000u32 {
            let x = (i * i) as f64;
            assert_eq!(sqrt(x), i as f64, "sqrt({}) is not exactly {}", x, i);
        }
    }

    #[test]
    fn test_random_vs_std() {
        // 5M deterministic LCG samples in [1, 4). A random exponent bit picks
        // [1, 2) or [2, 4), so both halves of RSQRT_TAB are exercised.
        let mut failures = 0u32;
        let mut first_mismatch: Option<u64> = None;
        let mut state: u64 = 0x123456789abcdef0;
        for _ in 0..5_000_000 {
            let rnd = lcg_next(&mut state);
            let exponent = 0x3ff + (rnd >> 63); // 2^0 or 2^1
            let bits = (rnd & 0x000f_ffff_ffff_ffff) | (exponent << 52);
            let val = f64::from_bits(bits);
            if sqrt(val).to_bits() != val.sqrt().to_bits() {
                failures += 1;
                first_mismatch = first_mismatch.or(Some(bits));
            }
        }
        assert_eq!(
            failures, 0,
            "Found {} mismatches in 5M random normals [1,4); first input bits: {:x?}",
            failures, first_mismatch
        );
    }

    #[test]
    fn test_subnormals_random() {
        // 100k random positive subnormals (biased exponent 0) exercise the
        // normalization path.
        let mut failures = 0u32;
        let mut first_mismatch: Option<u64> = None;
        let mut state: u64 = 0xdeadbeefcafebabe;
        for _ in 0..100_000 {
            // Keep only the 52 mantissa bits: sign and exponent are cleared.
            let bits = lcg_next(&mut state) & 0x000f_ffff_ffff_ffff;
            if bits == 0 {
                continue; // +0.0 is covered by test_special_cases
            }
            let val = f64::from_bits(bits);
            if sqrt(val).to_bits() != val.sqrt().to_bits() {
                failures += 1;
                first_mismatch = first_mismatch.or(Some(bits));
            }
        }
        assert_eq!(
            failures, 0,
            "Found {} mismatches in 100k random subnormals; first input bits: {:x?}",
            failures, first_mismatch
        );
    }

    #[test]
    fn test_boundaries() {
        // Critical boundaries: min/max normal, subnormal boundary, overflow edge, powers of 2.
        // Bit patterns are used instead of `powi`: `2.0f64.powi(-1074)` can
        // evaluate to 0.0 at runtime because 2^1074 overflows before the
        // reciprocal is taken.
        let boundaries: [f64; 8] = [
            f64::MIN_POSITIVE,                         // 2^-1022 (smallest normal)
            f64::from_bits(0x0020_0000_0000_0000),     // 2^-1021
            f64::from_bits(0x000f_ffff_ffff_ffff),     // largest subnormal
            f64::from_bits(1),                         // smallest positive subnormal (2^-1074)
            f64::MAX,                                  // largest finite, ~1.8e308
            f64::from_bits(0x7fe0_0000_0000_0000),     // 2^1023 (largest power of two)
            f64::from_bits(0x7fdf_ffff_ffff_ffff),     // largest value below 2^1023
            2.0f64.powi(-1022) * (1.0 + f64::EPSILON), // just above min normal
        ];
        for &x in &boundaries {
            assert_eq!(sqrt(x).to_bits(), x.sqrt().to_bits(), "Boundary mismatch for {:e}", x);
        }

        // sqrt(x*x) must recover x (to within rounding) whenever x*x is normal.
        let moderate = [1e-150, 0.1, 3.0, 1e150];
        for &x in boundaries.iter().chain(moderate.iter()) {
            let xx = x * x;
            if xx.is_normal() {
                let rel = (sqrt(xx) - x).abs() / x;
                assert!(rel < 1e-14, "sqrt(x*x) not close to x for {}", x);
            }
        }
    }

    #[test]
    fn test_known_hard_cases() {
        // Known hard-to-round / exact / boundary cases — all must match std bit-exactly
        let cases: &[f64] = &[
            2.0,
            0.5,
            4.0,
            9.0,
            0.0,
            f64::INFINITY,
            1.0e-300,                              // very small normal
            f64::from_bits(0x0010_0000_0000_0001), // just above min normal
            1.0 + f64::EPSILON,                    // next after 1.0
            f64::from_bits(0x7fefffffffffffff),    // largest finite
        ];
        for &x in cases {
            // Bit-exact check against Rust std.
            assert_eq!(sqrt(x).to_bits(), x.sqrt().to_bits(), "Bit mismatch for {:e}", x);
        }
    }

    /// Next representable value towards +inf.
    fn next_up(x: f64) -> f64 {
        if x.is_nan() || x == f64::INFINITY {
            return x;
        }
        if x == 0.0 {
            return f64::from_bits(1);
        }
        let bits = x.to_bits();
        if x > 0.0 {
            f64::from_bits(bits + 1)
        } else {
            f64::from_bits(bits - 1)
        }
    }

    /// Next representable value towards -inf.
    fn next_down(x: f64) -> f64 {
        if x.is_nan() || x == f64::NEG_INFINITY {
            return x;
        }
        if x == 0.0 {
            // Covers both +0.0 and -0.0.
            return f64::from_bits(0x8000_0000_0000_0001);
        }
        let bits = x.to_bits();
        if x > 0.0 {
            f64::from_bits(bits - 1)
        } else {
            f64::from_bits(bits + 1)
        }
    }

    #[test]
    fn test_powers_of_two() {
        // All representable powers of 2 (even exponents must be exact, odd use std)
        for exp in -1074i32..=1023 {
            let x = if exp >= -1022 {
                2.0f64.powi(exp)
            } else {
                // subnormal 2^exp = 2^(exp + 1074) * 2^-1074
                f64::from_bits(1u64 << (exp + 1074))
            };
            if !x.is_finite() || x == 0.0 {
                continue;
            }
            let r1 = sqrt(x);
            let r2 = x.sqrt();
            assert_eq!(
                r1.to_bits(),
                r2.to_bits(),
                "Power-of-2 mismatch for 2^{}",
                exp
            );
            // For even exponents, result should be exactly 2^(exp/2) when representable
            if exp % 2 == 0 {
                let expected_exp = exp / 2;
                if expected_exp >= -1022 {
                    let expected = 2.0f64.powi(expected_exp);
                    assert_eq!(
                        r1.to_bits(),
                        expected.to_bits(),
                        "Even power-of-2 not exact for 2^{}",
                        exp
                    );
                }
            }
        }
    }

    #[test]
    fn test_nextafter_edges() {
        // nextUp / nextDown around critical points (0, 1, min_normal, max)
        let mut edges: Vec<f64> = vec![
            f64::from_bits(1),                     // smallest positive subnormal
            f64::from_bits(0x0000_0000_0000_0002), // next subnormal
            next_down(f64::MIN_POSITIVE),          // largest subnormal
            f64::MIN_POSITIVE,                     // smallest normal
            next_up(f64::MIN_POSITIVE),
            next_down(1.0),
            1.0,
            next_up(1.0),
            next_down(f64::MAX),
            f64::MAX,
        ];
        // A negative edge (largest-magnitude negative subnormal) must produce NaN.
        edges.push(next_up(-f64::MIN_POSITIVE));
        for &x in &edges {
            let (r1, r2) = (sqrt(x), x.sqrt());
            if x < 0.0 {
                // The default NaN's bit pattern is target-specific (x86-64
                // sets the sign bit, AArch64 does not), so only NaN-ness is
                // checked for negative inputs.
                assert!(
                    r1.is_nan() && r2.is_nan(),
                    "negative edge {:e} did not produce NaN",
                    x
                );
            } else {
                assert_eq!(
                    r1.to_bits(),
                    r2.to_bits(),
                    "nextafter edge mismatch for {:e} (bits {:016x})",
                    x,
                    x.to_bits()
                );
            }
        }
    }

    #[test]
    fn test_negative_subnormals() {
        // All negative subnormals must produce NaN (sign bit set in result)
        let mut state: u64 = 0xfeedface_deadbeef;
        for _ in 0..10_000 {
            let bits = (lcg_next(&mut state) & 0x000f_ffff_ffff_ffff) | 0x8000_0000_0000_0000;
            let val = f64::from_bits(bits);
            if val == 0.0 {
                continue;
            }
            let r = sqrt(val);
            assert!(
                r.is_nan(),
                "Negative subnormal did not produce NaN: {:e}",
                val
            );
            // The sign bit should be set (negative NaN).
            assert!(
                r.to_bits() & 0x8000_0000_0000_0000 != 0,
                "NaN sign bit not set for negative subnormal"
            );
        }
    }
}