/// Precomputed constants for dividing `u64` values by one fixed divisor
/// with a multiplication and a shift instead of a hardware division.
///
/// Values are built by [`PartialReciprocal::new`], which rejects the
/// divisors 0, 1 and `u64::MAX`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct PartialReciprocal {
    /// Factor the (possibly incremented) dividend is multiplied by; only
    /// the high 64 bits of the 128-bit product are kept.
    multiplier: u64,
    /// Right shift applied to the high half of the product.
    shift: u8,
    /// Amount (0 or 1) added to the dividend before multiplying.
    increment: u8,
}
/// Precomputed constants for dividing `u64` values by one fixed, non-zero
/// divisor with a 128-bit multiply-add followed by a shift.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Reciprocal {
    /// Factor the dividend is multiplied by (as a 128-bit product).
    multiplier: u64,
    /// Constant added to the 128-bit product before its high half is taken.
    summand: u64,
    /// Right shift applied to the high half of the sum.
    shift: u8,
}
impl PartialReciprocal {
    /// Precomputes the constants needed to divide `u64` values by `d`.
    ///
    /// Returns `None` when `d` is 0, 1 or `u64::MAX`; every other divisor
    /// produces `Some`.
    ///
    /// The function also carries sanity assertions on its intermediate
    /// values.
    pub fn new(d: u64) -> Option<PartialReciprocal> {
        if matches!(d, 0 | 1 | u64::MAX) {
            return None;
        }

        // Index of the highest set bit, i.e. floor(log2(d)).
        let log2_floor = 63 - d.leading_zeros();
        assert!(d >= 1u64 << log2_floor);

        if d.is_power_of_two() {
            assert!(d == 1u64 << log2_floor);
            assert!(log2_floor >= 1);
            // For d = 2^k, the high half of x * 2^(64 - k) is already
            // x >> k, so no extra shift or increment is needed.
            return Some(PartialReciprocal {
                multiplier: 1u64 << (64 - log2_floor),
                shift: 0,
                increment: 0,
            });
        }

        let shift = log2_floor as u8;
        // Divide 2^(64 + floor(log2(d))) by `d`; the quotient is the
        // candidate multiplier and the remainder picks the rounding.
        let wide_d = u128::from(d);
        let numerator = 1u128 << (64 + log2_floor);
        let quotient = numerator / wide_d;
        let remainder = numerator % wide_d;
        assert!(quotient <= u128::from(u64::MAX));

        let (multiplier, increment) = if (remainder as u64) <= d / 2 {
            // Keep the rounded-down quotient and compensate by adding one
            // to the dividend when the reciprocal is applied.
            (quotient as u64, 1)
        } else {
            // Round the quotient up instead; the `+ 1` must still fit in
            // 64 bits.
            assert!(quotient < u128::from(u64::MAX));
            ((quotient + 1) as u64, 0)
        };

        Some(PartialReciprocal {
            multiplier,
            shift,
            increment,
        })
    }

    /// Divides `x` by the divisor this value was built for.
    ///
    /// Forwards to [`PartialReciprocal::apply_overflowing`] when compiled
    /// for `x86_64` and to [`PartialReciprocal::apply_saturating`] on every
    /// other architecture.
    #[inline]
    #[must_use]
    pub fn apply(self, x: u64) -> u64 {
        if cfg!(target_arch = "x86_64") {
            self.apply_overflowing(x)
        } else {
            self.apply_saturating(x)
        }
    }

    /// Divides `x` by the divisor, clamping the incremented dividend at
    /// `u64::MAX` with a saturating addition.
    #[inline]
    #[must_use]
    pub fn apply_saturating(self, x: u64) -> u64 {
        let dividend = x.saturating_add(u64::from(self.increment));
        let product = u128::from(self.multiplier) * u128::from(dividend);
        ((product >> 64) as u64) >> self.shift
    }

    /// Divides `x` by the divisor, detecting a wrapped increment through
    /// the overflow flag of the addition.
    #[inline]
    #[must_use]
    pub fn apply_overflowing(self, x: u64) -> u64 {
        let (sum, carried) = x.overflowing_add(u64::from(self.increment));
        // If the addition wrapped, subtracting the carry (with wrap-around)
        // turns the wrapped sum back into `u64::MAX`.
        let dividend = sum.wrapping_sub(u64::from(carried));
        let product = u128::from(self.multiplier) * u128::from(dividend);
        ((product >> 64) as u64) >> self.shift
    }
}
impl Reciprocal {
    /// Precomputes the constants needed to divide `u64` values by `d`.
    ///
    /// Returns `None` only when `d` is 0. Divisors accepted by
    /// [`PartialReciprocal::new`] reuse its constants; the two divisors it
    /// rejects, 1 and `u64::MAX`, get hard-coded constants.
    pub fn new(d: u64) -> Option<Reciprocal> {
        if d == 0 {
            return None;
        }

        let (multiplier, summand, shift) = match PartialReciprocal::new(d) {
            // Adding `multiplier * increment` to the product is the same as
            // incrementing the dividend before multiplying.
            Some(partial) => (
                partial.multiplier,
                partial.multiplier * u64::from(partial.increment),
                partial.shift,
            ),
            None => {
                assert!(d == 1 || d == u64::MAX);
                if d == 1 {
                    (u64::MAX, u64::MAX, 0)
                } else {
                    (1, 1, 0)
                }
            }
        };

        Some(Reciprocal {
            multiplier,
            summand,
            shift,
        })
    }

    /// Divides `x` by the divisor this value was built for.
    ///
    /// Computes `x * multiplier + summand` in 128 bits, keeps the high 64
    /// bits and shifts them right by `shift`.
    #[inline]
    #[must_use]
    pub fn apply(&self, x: u64) -> u64 {
        let sum = u128::from(x) * u128::from(self.multiplier) + u128::from(self.summand);
        ((sum >> 64) as u64) >> self.shift
    }
}
#[cfg(test)]
mod tests {
    /// Largest offset probed on each side of every dividend of interest.
    const PROBE_RANGE: u64 = 1u64 << 12;

    /// Builds both reciprocal flavours for `d` and compares their results
    /// with the native `/` operator around a set of interesting dividends.
    fn check(d: u64) {
        let partial = crate::PartialReciprocal::new(d);
        let recip = crate::Reciprocal::new(d);

        // Only a zero divisor may be rejected by both constructors.
        if partial.is_none() && recip.is_none() {
            assert!(d == 0);
            return;
        }

        assert!(d > 0);
        assert_ne!(recip, None);
        if d != 1 && d != u64::MAX {
            assert_ne!(partial, None);
        }

        // Compares every available implementation with `x / d`.
        let verify = |x: u64| {
            let expected = x / d;
            if let Some(p) = partial {
                assert_eq!(p.apply_saturating(x), expected, "d={}, x={}", d, x);
                assert_eq!(p.apply_overflowing(x), expected, "d={}, x={}", d, x);
            }
            if let Some(r) = recip {
                assert_eq!(r.apply(x), expected, "d={}, x={}", d, x);
                #[cfg(feature = "nightly")]
                assert_eq!(r.apply_branchfree(x), expected, "d={}, x={}", d, x);
            }
        };

        // Largest multiple of `d` that fits in a `u64`.
        let center = d * (u64::MAX / d);
        // Dividends probed on both sides, with wrap-around.
        let anchors = [
            d,
            center.wrapping_sub(d),
            center,
            u64::MAX.wrapping_sub(d),
        ];

        for offset in 0..=PROBE_RANGE {
            verify(offset);
            verify(u64::MAX - offset);
            for &anchor in anchors.iter() {
                verify(anchor.wrapping_add(offset));
                verify(anchor.wrapping_sub(offset));
            }
        }
    }

    /// Divisors at the extremes of the `u64` range.
    #[test]
    fn check_edge_cases() {
        let edge_divisors = [0, 1, 2, u64::MAX - 1, u64::MAX];
        edge_divisors.iter().copied().for_each(check);
    }

    /// Every power of two representable in a `u64`.
    #[test]
    fn check_powers_of_two() {
        for exponent in 0..64u32 {
            check(1u64 << exponent);
        }
    }

    /// The 256 smallest and the 256 largest divisors.
    #[test]
    fn test_small_divisors() {
        for offset in 0..256u64 {
            check(offset);
            check(u64::MAX - offset);
        }
    }

    /// Divisors with at most two bits set, and their bitwise complements.
    #[test]
    fn test_sparse_divisors() {
        for low in 0..64u32 {
            let low_bit = 1u64 << low;
            for high in low..64 {
                let d = low_bit | (1u64 << high);
                check(d);
                check(!d);
            }
        }
    }

    /// Divisors within 8 of a power of two, on either side.
    #[test]
    fn test_near_powers_of_two() {
        for exponent in 0..64u32 {
            let power = 1u64 << exponent;
            for offset in 1..=8u64 {
                check(power.wrapping_sub(offset));
                check(power.wrapping_add(offset));
            }
        }
    }

    /// Divisors around `po2 - po2 / 8` and `po2 + po2 / 2` for each power
    /// of two `po2`.
    #[test]
    fn test_powers_of_two_and_half() {
        for exponent in 0..64u32 {
            let power = 1u64 << exponent;
            let half = power / 2;
            let below = power.wrapping_sub(half / 4);
            let above = power.wrapping_add(half);

            check(below);
            check(above);
            for offset in 1..=8u64 {
                check(below.wrapping_sub(offset));
                check(below.wrapping_add(offset));
                check(above.wrapping_sub(offset));
                check(above.wrapping_add(offset));
            }
        }
    }

    /// The factors whose product is `u64::MAX`.
    #[test]
    fn test_factors_of_u64_max() {
        const FACTORS: [u64; 7] = [3, 5, 17, 257, 641, 65537, 6700417];
        let product: u64 = FACTORS.iter().product();
        assert_eq!(product, u64::MAX);
        FACTORS.iter().copied().for_each(check);
    }
}