/// A precomputed divider that replaces `u64` division by a fixed, non-zero divisor
/// with a multiplication and shifts.
///
/// Build one with [`DividerU64::divide_by`], then apply it with
/// [`DividerU64::divide`] or the `/` operator (`n / divider`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DividerU64 {
    /// The divisor's 64-bit magic number is precise enough on its own:
    /// the quotient is `mullhi(magic, n) >> shift`.
    Fast {
        magic: u64,
        shift: u8,
    },
    /// The divisor is `2^k`; the payload is `k` and the quotient is `n >> k`.
    BitShift(u8),
    /// The divisor needs a 65-bit magic number. Only its low 64 bits are stored;
    /// the implicit top bit is restored by an add-and-halve step at division time.
    General {
        magic_low: u64,
        shift: u8,
    },
}
/// Returns the high 64 bits of the full 128-bit product `x * y`.
#[inline(always)]
fn libdivide_mullhi_u64(x: u64, y: u64) -> u64 {
    // A u64 * u64 product always fits in u128, so this multiplication cannot overflow.
    let product = u128::from(x) * u128::from(y);
    (product >> 64) as u64
}
/// Returns `floor(log2(n))`, i.e. the index of the most significant set bit of `n`.
///
/// # Panics
///
/// Panics if `n` is zero.
#[inline(always)]
fn floor_log2(n: u64) -> u8 {
    assert_ne!(n, 0);
    // With n != 0, leading_zeros() is at most 63, so this never underflows
    // and the result (0..=63) always fits in a u8.
    let highest_bit_index = 63 - n.leading_zeros();
    highest_bit_index as u8
}
impl DividerU64 {
    /// Returns a [`DividerU64::BitShift`] when `divisor` is a power of two, `None` otherwise.
    fn power_of_2_division(divisor: u64) -> Option<DividerU64> {
        // `u64::is_power_of_two` is false for 0, so zero is rejected here as well.
        if divisor.is_power_of_two() {
            Some(DividerU64::BitShift(floor_log2(divisor)))
        } else {
            None
        }
    }

    /// Tries to build a [`DividerU64::Fast`] divider whose magic number fits in 64 bits.
    ///
    /// Returns `None` for powers of two, and for divisors whose rounded-up 64-bit magic
    /// number fails the precision check below.
    fn fast_path(divisor: u64) -> Option<DividerU64> {
        if divisor.is_power_of_two() {
            return None;
        }
        let shift: u8 = floor_log2(divisor);
        let wide_divisor = u128::from(divisor);
        // 2^(64 + shift); shift <= 63, so the exponent is at most 127 and fits in u128.
        let numerator = 1u128 << (shift + 64);
        let quotient: u128 = numerator / wide_divisor;
        let reminder: u64 = (numerator % wide_divisor) as u64;
        assert!(reminder > 0 && reminder < divisor);
        // Using `quotient + 1` as the magic number overshoots by `divisor - reminder`.
        // libdivide accepts it only when that error stays below 2^shift
        // (the sufficient condition from Granlund & Montgomery, PLDI '94).
        if divisor - reminder < (1u64 << shift) {
            Some(DividerU64::Fast {
                magic: (quotient as u64) + 1u64,
                shift,
            })
        } else {
            None
        }
    }

    /// Builds a [`DividerU64::General`] divider for divisors rejected by the other paths.
    ///
    /// The ideal magic number `ceil(2^(64 + p) / divisor)` needs 65 bits; only its low
    /// 64 bits are kept, and [`DividerU64::divide`] re-adds the implicit top bit.
    fn general_path(divisor: u64) -> DividerU64 {
        // p = ceil(log2(divisor)) for the non-power-of-two divisors that reach this path.
        let p: u8 = 64u8 - (divisor.leading_zeros() as u8);
        let wide_divisor = u128::from(divisor);
        // half = 2^(63 + p). The numerator 2^(64 + p) is formed as `half + (half - divisor)`
        // so that p = 64 never materializes 2^128, which would overflow u128.
        let half = 1u128 << (63 + p);
        let magic = 2 + (half + (half - wide_divisor)) / wide_divisor;
        DividerU64::General {
            // Intentional truncation: the 2^64 bit is implicit.
            magic_low: magic as u64,
            shift: p - 1,
        }
    }

    /// Precomputes a divider for `divisor`.
    ///
    /// Powers of two become a plain shift. Other divisors first try the 64-bit magic-number
    /// fast path and fall back to the general 65-bit magic number.
    ///
    /// # Panics
    ///
    /// Panics if `divisor` is zero.
    pub fn divide_by(divisor: u64) -> DividerU64 {
        assert!(divisor > 0u64);
        if let Some(shift_divider) = Self::power_of_2_division(divisor) {
            return shift_divider;
        }
        match Self::fast_path(divisor) {
            Some(fast_divider) => fast_divider,
            None => Self::general_path(divisor),
        }
    }

    /// Returns `n / divisor`, where `divisor` is the value this divider was built from.
    #[inline(always)]
    pub fn divide(&self, n: u64) -> u64 {
        match *self {
            DividerU64::BitShift(shift) => n >> shift,
            DividerU64::Fast { magic, shift } => libdivide_mullhi_u64(magic, n) >> shift,
            DividerU64::General { magic_low, shift } => {
                let high = libdivide_mullhi_u64(magic_low, n);
                // high <= n, so `n - high` cannot underflow. This evaluates
                // floor((n + high) / 2) without needing a 65-bit intermediate sum.
                let halved_sum = ((n - high) >> 1).wrapping_add(high);
                halved_sum >> shift
            }
        }
    }
}
impl core::ops::Div<DividerU64> for u64 {
type Output = u64;
#[inline(always)]
fn div(self, denom: DividerU64) -> Self::Output {
denom.divide(self)
}
}
#[cfg(test)]
mod tests {
    use super::DividerU64;

    #[test]
    fn test_divide_op() {
        let divider = DividerU64::divide_by(2);
        let res = 4u64 / divider;
        assert_eq!(res, 2);
        let res = 8u64 / divider;
        assert_eq!(res, 4);
        // The `/` operator must also work for the non-shift variants.
        let general = DividerU64::divide_by(7);
        assert_eq!(100u64 / general, 14);
        let fast = DividerU64::divide_by(11);
        assert_eq!(100u64 / fast, 9);
    }

    #[test]
    fn test_divide_by_4() {
        let divider = DividerU64::divide_by(4);
        assert!(matches!(divider, DividerU64::BitShift(2)));
    }

    #[test]
    fn test_powers_of_two_use_bit_shift() {
        for k in 0u8..64 {
            assert_eq!(DividerU64::divide_by(1u64 << k), DividerU64::BitShift(k));
        }
    }

    #[test]
    #[should_panic]
    fn test_divide_by_zero_panics() {
        let _ = DividerU64::divide_by(0);
    }

    #[test]
    fn test_divide_by_7() {
        let divider = DividerU64::divide_by(7);
        assert!(matches!(divider, DividerU64::General { .. }));
    }

    #[test]
    fn test_divide_by_11() {
        let divider = DividerU64::divide_by(11);
        assert_eq!(
            divider,
            DividerU64::Fast {
                magic: 13415813871788764812,
                shift: 3
            }
        );
    }

    #[test]
    fn test_floor_log2() {
        for i in [1, 2, 3, 4, 10, 15, 16, 31, 32, 33, u64::MAX] {
            let log_i = super::floor_log2(i);
            let lower_bound = 1 << log_i;
            let upper_bound = lower_bound - 1 + lower_bound;
            assert!(lower_bound <= i);
            assert!(upper_bound >= i);
        }
    }

    #[test]
    fn test_libdivide() {
        // Small divisors, a few arbitrary ones, divisors at the top of the u64 range
        // (which exercise the p = 64 case of the general path), and every power of two.
        let divisors = (1u64..100u64)
            .chain(vec![
                2048,
                234234131223u64,
                (1u64 << 63) + 1,
                u64::MAX - 1,
                u64::MAX,
            ])
            .chain((0..64).map(|i| 1u64 << i));
        // Parenthesized on purpose: `1 << 43 + 1` would parse as `1 << 44`.
        let special_numerators: Vec<u64> = vec![
            2048,
            234234131223u64,
            1 << 43,
            (1 << 43) + 1,
            u64::MAX - 1,
            u64::MAX,
        ];
        for d in divisors {
            let divider = DividerU64::divide_by(d);
            // Probe around the largest multiple of `d`, where rounding errors would show.
            let top_multiple = u64::MAX / d * d;
            let boundary = [
                top_multiple - 1,
                top_multiple,
                top_multiple.saturating_add(1),
            ];
            for i in (0u64..10_000)
                .chain(special_numerators.iter().copied())
                .chain(boundary.iter().copied())
            {
                assert_eq!(divider.divide(i), i / d, "n = {}, d = {}", i, d);
            }
        }
    }
}