div-pow10 0.1.1

Fast division by powers of 10.
Documentation
/// Divide `n` by a power of ten: `n / 10.pow(i)`.
///
/// Dividend and divisor are both a single 128-bit word.
///
/// Returns `None` if `i > 38`, since `10.pow(39)` no longer fits in a
/// `u128`.
pub fn div_single(n: u128, i: u32) -> Option<u128> {
    match i {
        // `10.pow(0) == 1`: the quotient is the dividend itself.
        0 => Some(n),
        // SAFETY: this arm only matches `1 <= i <= 38`, which satisfies the
        // `i <= 38` requirement of `unchecked_div_single`.
        1..=38 => Some(unsafe { unchecked_div_single(n, i) }),
        _ => None,
    }
}

/// Calculate division: `n / 10.pow(i)`.
///
/// The dividend is 1-word, 128-bit, same as the divisor.
///
/// # Safety
///
/// It's UB if `i > 38`.
pub unsafe fn unchecked_div_single(n: u128, i: u32) -> u128 {
    // `10.pow(0) == 1`, so the quotient is `n` itself. The magic table only
    // holds a placeholder at index 0, and the multiply-and-shift path would
    // produce `n >> 1` for it, so this case must not reach the helper.
    if i == 0 {
        return n;
    }
    // SAFETY: the caller guarantees `i <= 38`, so `i` is a valid index into
    // the helper's 39-entry magic table.
    unsafe { do_unchecked_div_single(n, i, false) }
}

/// Calculate division: `n / 10.pow(i)`, for a dividend that is one bit
/// narrower.
///
/// Unlike [`unchecked_div_single`], the dividend is limited to 127 bits
/// (`n <= 2.pow(127)`). With that spare bit the intermediate sum cannot
/// overflow, which makes this variant slightly faster.
///
/// # Safety
///
/// It's UB if `i == 0` or `i > 38` or `n > 2.pow(127)`.
pub unsafe fn unchecked_div_single_r1b(n: u128, i: u32) -> u128 {
    // Both preconditions are only checked in debug builds.
    debug_assert!(i > 0);
    debug_assert!(n <= 2_u128.pow(127));

    // SAFETY: the caller guarantees `1 <= i <= 38` and `n <= 2.pow(127)`,
    // which is what the reduced-width path of the helper relies on.
    let quotient = unsafe { do_unchecked_div_single(n, i, true) };
    quotient
}

/// Multiply-and-shift kernel shared by the single-word divisions.
///
/// Evaluates `(n + ((n * m) >> 128)) >> l`, where `(m, l)` is the
/// precomputed magic pair for `10.pow(i)`.
///
/// `is_r1b` chooses how the sum is formed: when `true` the sum `high + n` is
/// computed directly (the caller promises it cannot overflow); when `false`
/// an overflow-free averaging form is used instead.
///
/// Row 0 of the table is only a placeholder, so `i == 0` does not produce
/// `n` here.
///
/// # Safety
///
/// `i` must be less than 39, because the table is read without a bounds
/// check.
unsafe fn do_unchecked_div_single(n: u128, i: u32, is_r1b: bool) -> u128 {
    // Each row is `(m, l, l - 1)`.
    //
    // The magics are generated by python code:
    //
    // def gen(d):
    //    l = math.ceil( math.log2(d) )
    //    m = pow(2, 128+l) // d + 1
    //    m = m - pow(2, 128) # make m fit in 128-bit
    //    return (m, l)
    const GM_EXP_MAGICS: [(u128, u32, u32); 39] = [
        (0, u32::MAX, 0),
        (0x9999999999999999999999999999999a, 4, 3),
        (0x47ae147ae147ae147ae147ae147ae148, 7, 6),
        (0x0624dd2f1a9fbe76c8b4395810624dd3, 10, 9),
        (0xa36e2eb1c432ca57a786c226809d4952, 14, 13),
        (0x4f8b588e368f08461f9f01b866e43aa8, 17, 16),
        (0x0c6f7a0b5ed8d36b4c7f349385836220, 20, 19),
        (0xad7f29abcaf485787a6520ec08d2369a, 24, 23),
        (0x5798ee2308c39df9fb841a566d74f87b, 27, 26),
        (0x12e0be826d694b2e62d01511f12a6062, 30, 29),
        (0xb7cdfd9d7bdbab7d6ae6881cb5109a37, 34, 33),
        (0x5fd7fe17964955fdef1ed34a2a73ae92, 37, 36),
        (0x19799812dea11197f27f0f6e885c8ba8, 40, 39),
        (0xc25c268497681c2650cb4be40d60df74, 44, 43),
        (0x6849b86a12b9b01ea70909833de71929, 47, 46),
        (0x203af9ee756159b21f3a6e0297ec1421, 50, 49),
        (0xcd2b297d889bc2b6985d7cd0f3135368, 54, 53),
        (0x70ef54646d496892137dfd73f5a90f86, 57, 56),
        (0x2725dd1d243aba0e75fe645cc4873f9f, 60, 59),
        (0xd83c94fb6d2ac34a5663d3c7a0d865cb, 64, 63),
        (0x79ca10c9242235d511e976394d79eb09, 67, 66),
        (0x2e3b40a0e9b4f7dda7edf82dd794bc07, 70, 69),
        (0xe392010175ee5962a6498d1625bac671, 74, 73),
        (0x82db34012b25144eeb6e0a781e2f0528, 77, 76),
        (0x357c299a88ea76a58924d52ce4f26a86, 80, 79),
        (0xef2d0f5da7dd8aa27507bb7b07ea440a, 84, 83),
        (0x8c240c4aecb13bb52a6c95fc0655033b, 87, 86),
        (0x3ce9a36f23c0fc90eebd44c99eaa68fc, 90, 89),
        (0xfb0f6be50601941b17953adc3110a7f9, 94, 93),
        (0x95a5efea6b34767c12ddc8b027408661, 97, 96),
        (0x4484bfeebc29f863424b06f3529a051b, 100, 99),
        (0x039d66589687f9e901d59f290ee19daf, 103, 102),
        (0x9f623d5a8a732974cfbc31db4b0295e5, 107, 106),
        (0x4c4e977ba1f5bac3d9635b15d59bab1d, 110, 109),
        (0x09d8792fb4c495697ab5e277de16227e, 113, 112),
        (0xa95a5b7f87a0ef0f2abc9d8c9689d0c9, 117, 116),
        (0x54484932d2e725a5bbca17a3aba173d4, 120, 119),
        (0x1039d428a8b8eaeafca1ac82efb45caa, 123, 122),
        (0xb38fb9daa78e44ab2dcf7a6b19209443, 127, 126),
    ];

    debug_assert!(i < 39);
    // SAFETY: the caller guarantees `i < 39`, the length of the table.
    let &(multiplier, shift, shift_less_one) =
        unsafe { GM_EXP_MAGICS.get_unchecked(i as usize) };

    // high := (n * m) >> 128
    let (high, _) = mul2(n, multiplier);

    if !is_r1b {
        // `n + high` may need 129 bits, so fold one bit of the shift into
        // the sum: `high + (n - high) / 2` equals `(n + high) / 2`, and the
        // remaining shift is `l - 1` (third table field).
        return (high + ((n - high) >> 1)) >> shift_less_one;
    }

    // The dividend is narrow enough that the sum cannot overflow.
    (high + n) >> shift
}

/// Calculate `(a * b + c) / 10.pow(i)` and return `(quotient, remainder)`.
///
/// `a * b + c` is formed as a two-word (256-bit) value and then handed to
/// [`div_double`].
///
/// Returns `None` if `i > 38`, or on overflow, i.e. when the high word of
/// `a * b + c` is not smaller than `10.pow(i)`.
pub fn mul_div(a: u128, b: u128, c: u128, i: u32) -> Option<(u128, u128)> {
    let (product_high, product_low) = mul2(a, b);

    // Add `c` to the low word and carry into the high word.
    let (sum_low, carried) = product_low.overflowing_add(c);
    let sum_high = product_high + u128::from(carried);

    div_double(sum_high, sum_low, i)
}

/// Calculate `(a * b + c) / 10.pow(i)` and return `(quotient, remainder)`,
/// without validating the inputs.
///
/// `a * b + c` is formed as a two-word (256-bit) value and then handed to
/// [`unchecked_div_double`].
///
/// # Safety
///
/// It's UB if `i > 38`, or on overflow, i.e. when the high word of
/// `a * b + c` is not smaller than `10.pow(i)`.
pub unsafe fn unchecked_mul_div(a: u128, b: u128, c: u128, i: u32) -> (u128, u128) {
    let (product_high, product_low) = mul2(a, b);

    // Add `c` to the low word and carry into the high word.
    let (sum_low, carried) = product_low.overflowing_add(c);
    let sum_high = product_high + u128::from(carried);

    // SAFETY: the caller's contract is exactly the contract of
    // `unchecked_div_double` applied to `[sum_high, sum_low]`.
    unsafe { unchecked_div_double(sum_high, sum_low, i) }
}

/// Calculate division: `[n_high, n_low] / 10.pow(i)`, returning
/// `(quotient, remainder)`.
///
/// The dividend is two words (256-bit), twice as wide as the divisor.
///
/// Returns `None` if `i > 38` or `n_high >= 10.pow(i)`.
pub fn div_double(n_high: u128, n_low: u128, i: u32) -> Option<(u128, u128)> {
    // An exponent outside the table means `10.pow(i)` is not available.
    let exp = *POWERS.get(i as usize)?;

    if n_high < exp {
        // SAFETY: `i` indexed `POWERS` successfully, so `i <= 38`, and
        // `n_high < 10.pow(i)` was just checked.
        Some(unsafe { unchecked_div_double(n_high, n_low, i) })
    } else {
        None
    }
}

/// Calculate division: `[n_high, n_low] / 10.pow(i)`, returning
/// `(quotient, remainder)`, without validating the inputs.
///
/// The dividend is two words (256-bit), twice as wide as the divisor.
///
/// # Safety
///
/// It's UB if `i > 38` or `n_high >= 10.pow(i)`.
pub unsafe fn unchecked_div_double(n_high: u128, n_low: u128, i: u32) -> (u128, u128) {
    // SAFETY: both helpers have the same contract as this function; the
    // exponent only decides which of the two implementations is used.
    unsafe {
        match i {
            // Exponents up to 19 take the "small" path.
            0..=19 => unchecked_div_double_small(n_high, n_low, i),
            _ => unchecked_div_double_big(n_high, n_low, i),
        }
    }
}

/// Divide the 256-bit dividend `[n_high, n_low]` by `10.pow(i)` using a
/// precomputed reciprocal, and return `(quotient, remainder)`.
///
/// `unchecked_div_double` routes here for `i > 19`.
///
/// # Safety
///
/// `i` must be below 39: `POWERS` and `MG_EXP_MAGICS` are both read with
/// `get_unchecked`. The caller must also guarantee `n_high < 10.pow(i)`;
/// that is only verified by a `debug_assert!`.
unsafe fn unchecked_div_double_big(n_high: u128, n_low: u128, i: u32) -> (u128, u128) {
    debug_assert!((i as usize) < POWERS.len());
    // SAFETY: the caller guarantees `i` is a valid index into `POWERS`.
    let exp = unsafe { *POWERS.get_unchecked(i as usize) };

    // check overflow: with `n_high < exp` the quotient fits in one word
    debug_assert!(n_high < exp);

    // Each row is `(magic, zeros)` for `d = 10.pow(index)`.
    //
    // The magics are generated by python code:
    //
    // def gen(d):
    //     zeros = 128 - d.bit_length()
    //     magic = pow(2, 256) // (d << zeros)
    //     magic = magic - pow(2, 128) # make magic fit in 128-bit
    //     return (magic, zeros)
    //
    // NOTE(review): row 0 has `zeros == 128`, which the shifts below cannot
    // handle; the only caller visible here passes `i > 19`.
    const MG_EXP_MAGICS: [(u128, u32); 39] = [
        (0, 128),
        (0x99999999999999999999999999999999, 124),
        (0x47ae147ae147ae147ae147ae147ae147, 121),
        (0x0624dd2f1a9fbe76c8b4395810624dd2, 118),
        (0xa36e2eb1c432ca57a786c226809d4951, 114),
        (0x4f8b588e368f08461f9f01b866e43aa7, 111),
        (0x0c6f7a0b5ed8d36b4c7f34938583621f, 108),
        (0xad7f29abcaf485787a6520ec08d23699, 104),
        (0x5798ee2308c39df9fb841a566d74f87a, 101),
        (0x12e0be826d694b2e62d01511f12a6061, 98),
        (0xb7cdfd9d7bdbab7d6ae6881cb5109a36, 94),
        (0x5fd7fe17964955fdef1ed34a2a73ae91, 91),
        (0x19799812dea11197f27f0f6e885c8ba7, 88),
        (0xc25c268497681c2650cb4be40d60df73, 84),
        (0x6849b86a12b9b01ea70909833de71928, 81),
        (0x203af9ee756159b21f3a6e0297ec1420, 78),
        (0xcd2b297d889bc2b6985d7cd0f3135367, 74),
        (0x70ef54646d496892137dfd73f5a90f85, 71),
        (0x2725dd1d243aba0e75fe645cc4873f9e, 68),
        (0xd83c94fb6d2ac34a5663d3c7a0d865ca, 64),
        (0x79ca10c9242235d511e976394d79eb08, 61),
        (0x2e3b40a0e9b4f7dda7edf82dd794bc06, 58),
        (0xe392010175ee5962a6498d1625bac670, 54),
        (0x82db34012b25144eeb6e0a781e2f0527, 51),
        (0x357c299a88ea76a58924d52ce4f26a85, 48),
        (0xef2d0f5da7dd8aa27507bb7b07ea4409, 44),
        (0x8c240c4aecb13bb52a6c95fc0655033a, 41),
        (0x3ce9a36f23c0fc90eebd44c99eaa68fb, 38),
        (0xfb0f6be50601941b17953adc3110a7f8, 34),
        (0x95a5efea6b34767c12ddc8b027408660, 31),
        (0x4484bfeebc29f863424b06f3529a051a, 28),
        (0x039d66589687f9e901d59f290ee19dae, 25),
        (0x9f623d5a8a732974cfbc31db4b0295e4, 21),
        (0x4c4e977ba1f5bac3d9635b15d59bab1c, 18),
        (0x09d8792fb4c495697ab5e277de16227d, 15),
        (0xa95a5b7f87a0ef0f2abc9d8c9689d0c8, 11),
        (0x54484932d2e725a5bbca17a3aba173d3, 8),
        (0x1039d428a8b8eaeafca1ac82efb45ca9, 5),
        (0xb38fb9daa78e44ab2dcf7a6b19209442, 1),
    ];

    // algorithm:
    //   zn = n << zeros
    //   q = (((magic * zn) >> 128) + zn) >> 128

    // SAFETY: `i` already indexed `POWERS` above, and this table has the
    // same length (39 entries).
    let &(magic, zeros) = unsafe { MG_EXP_MAGICS.get_unchecked(i as usize) };

    // calc: (z_high, z_low) := n << zeros
    // (the bits shifted out of `n_low` are moved into the bottom of `z_high`)
    let z_high = (n_high << zeros) | (n_low >> (128 - zeros));
    let z_low = n_low << zeros;

    // calc: (m_high, m_low) := (magic * zn) >> 128
    //
    // We should have calculated m1_high as:
    //   let (m1_high, _) = mul2(z_low, magic);
    // But we only do the highest multiplication, which is faster,
    // and right if:
    //   (zn * Rm / exp) + Rs + 2 * 2^96 < 2 * 2^256
    // which is satisfied for the above magics,
    // where Rm is the remainder when calculating magic, and Rs is
    // the remainder of the bit-right-shift in this algorithm.
    let m1_high = (z_low >> 64) * (magic >> 64);
    let (m2_high, m2_low) = mul2(z_high, magic);

    // Add the partial product of the low word into the low word of the
    // full product, carrying into the high word.
    let (m_low, carry) = m2_low.overflowing_add(m1_high);
    let m_high = m2_high + carry as u128;

    // calc: final q
    // Only the carry out of `m_low + z_low` is needed; the low word of the
    // sum is dropped by the final `>> 128`.
    let (_, carry) = m_low.overflowing_add(z_low);
    let q = m_high + z_high + carry as u128;

    // correction by remainder
    // check: n - q * exp < exp
    // `pp` is `q * exp`; subtracting its low word from `n_low` gives the
    // candidate remainder.
    let (pp_high, pp_low) = mul2(q, exp);
    let (r_low, borrow) = n_low.overflowing_sub(pp_low);
    // The high words must cancel, i.e. the candidate remainder fits in the
    // low word.
    debug_assert_eq!(n_high, pp_high + borrow as u128); // 10.pow(38)*2 < MAX

    // The estimate `q` is accepted as is, or bumped by exactly one when the
    // candidate remainder is still at least `exp`.
    if r_low < exp {
        (q, r_low)
    } else {
        (q + 1, r_low - exp)
    }
}

/// Two-word division for exponents where `10.pow(i)` fits in 64 bits
/// (`i <= 19`).
///
/// The dividend is processed as 64-bit digits, schoolbook style: first
/// `n_high` together with the upper half of `n_low`, then the remainder of
/// that step together with the lower half of `n_low`. Each step is done by
/// `crate::bit64::unchecked_div_double`, and the two partial quotients form
/// the upper and lower halves of the result.
///
/// # Safety
///
/// The caller must ensure `i <= 19` and `n_high < 10.pow(i)`; both are only
/// verified by `debug_assert!`.
unsafe fn unchecked_div_double_small(n_high: u128, n_low: u128, i: u32) -> (u128, u128) {
    debug_assert!(i <= 19);
    debug_assert!(n_high < POWERS[i as usize]);

    let low_upper_half = n_low >> 64;
    let low_lower_half = n_low as u64 as u128;

    // Step 1: upper quotient digit.
    let first_dividend = (n_high << 64) | low_upper_half;
    let (q_upper, rem) = unsafe { crate::bit64::unchecked_div_double(first_dividend, i) };

    // Step 2: lower quotient digit; its remainder is the final remainder.
    let second_dividend = ((rem as u128) << 64) | low_lower_half;
    let (q_lower, rem) = unsafe { crate::bit64::unchecked_div_double(second_dividend, i) };

    let quotient = ((q_upper as u128) << 64) | (q_lower as u128);
    (quotient, rem as u128)
}

/// Full 128 x 128 -> 256-bit multiplication: returns `(high, low)` such that
/// `a * b == high * 2^128 + low`.
const fn mul2(a: u128, b: u128) -> (u128, u128) {
    const LOW_MASK: u128 = u64::MAX as u128;

    // Split both operands into 64-bit limbs.
    let a_hi = a >> 64;
    let a_lo = a & LOW_MASK;
    let b_hi = b >> 64;
    let b_lo = b & LOW_MASK;

    // Sum of the two cross products (weight 2^64); it may need 129 bits.
    let (cross, cross_overflowed) = (a_lo * b_hi).overflowing_add(a_hi * b_lo);

    // Low word: low product plus the lower half of the cross sum.
    let (low, low_overflowed) = (a_lo * b_lo).overflowing_add(cross << 64);

    // High word: high product, the upper half of the cross sum, and both
    // carries. A carry out of the cross sum is worth 2^64 in the high word.
    let high = a_hi * b_hi
        + (cross >> 64)
        + ((cross_overflowed as u128) << 64)
        + low_overflowed as u128;

    (high, low)
}

/// `POWERS[k] == 10.pow(k)` for `k` in `0..=38`.
const POWERS: [u128; 39] = {
    // Start from all ones so that entry 0 is already `10.pow(0)`, then build
    // each further entry from the previous one.
    let mut table = [1_u128; 39];
    let mut k = 1;
    while k < 39 {
        table[k] = table[k - 1] * 10;
        k += 1;
    }
    table
};

#[cfg(test)]
mod tests {
    use super::*;

    /// Compares the single-word divisions with native `u128` division on
    /// evenly spaced dividends, for every exponent from 1 to 38.
    #[test]
    fn test_single() {
        // Exponent boundaries: out of range, and the identity case.
        let probe = 123_u128;
        assert_eq!(div_single(probe, 39), None);
        assert_eq!(div_single(probe, 0), Some(probe));

        const COUNT: u128 = 100000;
        const STEP: u128 = u128::MAX / COUNT;
        let r1b_limit = 2_u128.pow(127);

        for i in 1..39_u32 {
            let divisor = POWERS[i as usize];
            for dividend in (0..COUNT).map(|j| j * STEP) {
                let expected = dividend / divisor;
                assert_eq!(div_single(dividend, i), Some(expected));

                // The reduced-width variant only accepts dividends up to 2^127.
                if dividend <= r1b_limit {
                    // SAFETY: `1 <= i <= 38` and `dividend <= 2.pow(127)`.
                    let got = unsafe { unchecked_div_single_r1b(dividend, i) };
                    assert_eq!(got, expected);
                }
            }
        }
    }

    /// Checks the two-word division by rebuilding the dividend from
    /// `quotient * 10.pow(i) + remainder`.
    #[test]
    fn test_double() {
        // Exponent boundaries: out of range, and the identity case.
        let probe = 123_u128;
        assert_eq!(div_double(probe, probe, 39), None);
        assert_eq!(div_double(0, probe, 0), Some((probe, 0)));

        const COUNT: u128 = 1000; // enlarge this for more test
        const LOW_STEP: u128 = u128::MAX / COUNT;

        for i in 1..39_u32 {
            let divisor = POWERS[i as usize];

            // High words must stay below the divisor, so sample that range.
            let high_count = COUNT.min(divisor);
            let high_step = divisor / high_count;

            for high in (0..high_count).map(|j| j * high_step) {
                for low in (0..COUNT).map(|k| k * LOW_STEP) {
                    let (q, r) = div_double(high, low, i).unwrap();

                    // Rebuild `q * divisor + r` as a two-word value.
                    let (back_high, back_low) = mul2(q, divisor);
                    let (back_low, carry) = back_low.overflowing_add(r);
                    let back_high = back_high + carry as u128;

                    assert_eq!(back_low, low);
                    assert_eq!(back_high, high);
                }
            }
        }
    }
}