use crate::utils::{adc, msb};
use core::{convert::TryFrom, u64};
/// Joins two 64-bit limbs into one `u128`, with `hi` as the upper half.
const fn val_2(lo: u64, hi: u64) -> u128 {
    let upper = (hi as u128) << 64;
    upper | lo as u128
}
/// Full-width product of two limbs. The result always fits in 128 bits,
/// so this can never overflow.
const fn mul_2(a: u64, b: u64) -> u128 {
    let wide_a = a as u128;
    wide_a * b as u128
}
/// Divides the 128-bit value `hi:lo` by `d` and returns `(quotient, remainder)`.
///
/// The caller must ensure `hi < d`, which implies `d > 0`. That bound makes
/// the quotient fit in a `u64`. Both conditions are checked only in debug builds.
// Both casts at the end are lossless: `hi < d` keeps the quotient below 2^64,
// and the remainder is below `d`.
#[allow(clippy::cast_possible_truncation)]
fn divrem_2by1(lo: u64, hi: u64, d: u64) -> (u64, u64) {
    debug_assert!(d > 0);
    debug_assert!(d > hi);
    let dividend = val_2(lo, hi);
    let wide_d = u128::from(d);
    let quotient = dividend / wide_d;
    let remainder = dividend % wide_d;
    debug_assert!(quotient < val_2(0, 1));
    debug_assert!(
        mul_2(u64::try_from(quotient).unwrap(), d)
            + val_2(u64::try_from(remainder).unwrap(), 0)
            == dividend
    );
    debug_assert!(remainder < wide_d);
    (quotient as u64, remainder as u64)
}
/// Divides the little-endian multi-limb `numerator` in place by the single
/// limb `divisor`. The quotient is left in `numerator` and the remainder is
/// returned.
///
/// `divisor` must be non-zero; this is checked only in debug builds.
pub(crate) fn divrem_nby1(numerator: &mut [u64], divisor: u64) -> u64 {
    debug_assert!(divisor > 0);
    // Schoolbook long division from the most significant limb down. The
    // running remainder is always below `divisor`, which meets the `hi < d`
    // precondition of `divrem_2by1`.
    numerator.iter_mut().rev().fold(0, |carry, limb| {
        let (quotient, remainder) = divrem_2by1(*limb, carry, divisor);
        *limb = quotient;
        remainder
    })
}
/// Returns `floor(N / D)`, where `N = n[2]:n[1]:n[0]` and `D = d[1]:d[0]`
/// (limbs are little-endian).
///
/// Preconditions, checked only in debug builds:
/// * `D` is normalized, i.e. the top bit of `d[1]` is set.
/// * `n[2]:n[1] < D`, which guarantees the quotient fits in one `u64`.
fn div_3by2(n: &[u64; 3], d: &[u64; 2]) -> u64 {
    let [n0, n1, n2] = *n;
    let [d0, d1] = *d;
    debug_assert!(d1 >> 63 == 1);
    debug_assert!(val_2(n1, n2) < val_2(d0, d1));

    if n2 == d1 {
        // `n2:n1 < d1:d0` together with `n2 == d1` forces `n1 < d0`.
        debug_assert!(n1 < d0);
        // Here the quotient is 2^64 - 1 or 2^64 - 2. `deficit` is
        // 2^64 * D - N. The quotient is 2^64 - 1 exactly when
        // (2^64 - 1) * D <= N, i.e. when `deficit <= D`.
        let deficit = val_2(0, d0) - val_2(n0, n1);
        return if deficit > val_2(d0, d1) {
            0xffff_ffff_ffff_fffe_u64
        } else {
            0xffff_ffff_ffff_ffff_u64
        };
    }

    // Estimate from the top limbs only. Because `d1` is normalized, this
    // estimate is at most two above the true quotient.
    let (q, r) = divrem_2by1(n1, n2, d1);
    // `q * D > N` is equivalent to `q * d0 > r:n0`.
    if mul_2(q, d0) <= val_2(n0, r) {
        return q;
    }
    let q = q - 1;
    let (r, carried) = r.overflowing_add(d1);
    // If the new remainder carried out, `r:n0 >= 2^128 > q * d0`, so `q` is
    // already exact. Otherwise re-run the test once more.
    if carried || mul_2(q, d0) <= val_2(n0, r) {
        q
    } else {
        q - 1
    }
}
/// Divides `numerator` by `divisor` in place using Knuth's Algorithm D
/// (TAOCP Vol. 2, §4.3.1).
///
/// Both slices hold little-endian 64-bit limbs. Let `n = divisor.len()`.
/// On return, `numerator[..n]` holds the remainder and `numerator[n..]`
/// holds the quotient (`numerator.len() - n` limbs).
///
/// Preconditions, checked only in debug builds:
/// * `divisor.len() >= 2` (single-limb divisors belong in `divrem_nby1`);
/// * `numerator.len() > divisor.len()`;
/// * the most significant limb of `divisor` is non-zero;
/// * the most significant limb of `numerator` is zero. It provides room for
///   the bits shifted out during normalization.
///
/// Note: `divisor` is left in its normalized (left-shifted) form on return.
pub(crate) fn divrem_nbym(numerator: &mut [u64], divisor: &mut [u64]) {
    debug_assert!(divisor.len() >= 2);
    debug_assert!(numerator.len() > divisor.len());
    debug_assert!(*divisor.last().unwrap() > 0);
    debug_assert!(*numerator.last().unwrap() == 0);
    let n = divisor.len();
    let m = numerator.len() - n - 1;

    // Normalize: shift both operands left until the divisor's top bit is set.
    // `div_3by2` requires this, and it bounds the quotient estimate's error.
    let shift = divisor[n - 1].leading_zeros();
    if shift > 0 {
        numerator[n + m] = numerator[n + m - 1] >> (64 - shift);
        for i in (1..n + m).rev() {
            numerator[i] <<= shift;
            numerator[i] |= numerator[i - 1] >> (64 - shift);
        }
        numerator[0] <<= shift;
        for i in (1..n).rev() {
            divisor[i] <<= shift;
            divisor[i] |= divisor[i - 1] >> (64 - shift);
        }
        divisor[0] <<= shift;
    }

    // Produce one quotient limb per iteration, most significant first. On
    // entry the top `n` limbs of the window `numerator[j..=j + n]` are below
    // `divisor`.
    for j in (0..=m).rev() {
        // Estimate the quotient limb from the window's top three limbs and the
        // divisor's top two limbs. Those top two window limbs can *equal* the
        // divisor's (only the top `n` limbs are known to be smaller). In that
        // case the 3-by-2 quotient would be >= 2^64, violating the
        // precondition of `div_3by2`. The true limb is then exactly 2^64 - 1,
        // so use that directly (Knuth, step D3).
        let top_limbs_equal =
            numerator[j + n] == divisor[n - 1] && numerator[j + n - 1] == divisor[n - 2];
        let mut qhat = if top_limbs_equal {
            0xffff_ffff_ffff_ffff_u64
        } else {
            div_3by2(
                &[numerator[j + n - 2], numerator[j + n - 1], numerator[j + n]],
                &[divisor[n - 2], divisor[n - 1]],
            )
        };

        // Multiply and subtract: window -= qhat * divisor.
        let mut borrow = 0;
        for (limb, &digit) in numerator[j..j + n].iter_mut().zip(divisor.iter()) {
            let (diff, next_borrow) = msb(*limb, qhat, digit, borrow);
            *limb = diff;
            borrow = next_borrow;
        }

        // The estimate can be at most one too large. If the subtraction went
        // negative, add the divisor back once and decrement the estimate.
        if numerator[j + n] < borrow {
            let mut carry = 0;
            for (limb, &digit) in numerator[j..j + n].iter_mut().zip(divisor.iter()) {
                let (sum, next_carry) = adc(*limb, digit, carry);
                *limb = sum;
                carry = next_carry;
            }
            qhat -= 1;
            debug_assert_eq!(numerator[j + n].wrapping_sub(borrow).wrapping_add(carry), 0);
        } else {
            debug_assert_eq!(numerator[j + n].wrapping_sub(borrow), 0);
        }
        // The top window limb is now zero; reuse its slot for the quotient limb.
        numerator[j + n] = qhat;
    }

    // Undo the normalization on the remainder (the low `n` limbs).
    if shift > 0 {
        for i in 0..(n - 1) {
            numerator[i] >>= shift;
            numerator[i] |= numerator[i + 1] << (64 - shift);
        }
        numerator[n - 1] >>= shift;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::u256::U256;
    use quickcheck_macros::quickcheck;
    use zkp_macros_decl::u256h;

    /// Limb with only the most significant bit set (the smallest normalized limb).
    const HALF: u64 = 1_u64 << 63;
    /// Limb with every bit set.
    const FULL: u64 = u64::max_value();

    #[test]
    fn div_3by2_tests() {
        // Both cases take the `n[2] == d[1]` path, where the quotient is
        // either `FULL` or `FULL - 1`.
        let divisor = [FULL, HALF];
        assert_eq!(div_3by2(&[FULL, FULL - 1, HALF], &divisor), FULL);
        assert_eq!(div_3by2(&[0, 0, HALF], &divisor), FULL - 1);
    }

    #[test]
    fn test_divrem_4by3() {
        let mut numerator = [40, 31, 79, 84, 0];
        let mut divisor = [53, 12, 12];
        divrem_nbym(&mut numerator, &mut divisor);
        // The remainder occupies the low `divisor.len()` limbs; the quotient
        // occupies the rest.
        let (remainder, quotient) = numerator.split_at(divisor.len());
        assert_eq!(remainder, [93, 0xffff_ffff_ffff_feb8, 6]);
        assert_eq!(quotient, [u64::max_value(), 6]);
    }

    #[allow(clippy::unreadable_literal)]
    #[test]
    fn test_divrem_8by4() {
        let expected_remainder = [
            0x1dfa5a7ea5191b33_u64,
            0xb5aeb3f9ad5e294e_u64,
            0xfc710038c13e4eed_u64,
            0x000000000000000b_u64,
        ];
        let expected_quotient = [
            0x9128464e61d6b5b3_u64,
            0xd9eea4fc30c5ac6c_u64,
            0x944a2d832d5a6a08_u64,
            0x22f06722e8d883b1_u64,
            0x0000000000000000_u64,
        ];
        let mut numerator = [
            0x9c2bcebfa9cca2c6_u64,
            0x274e154bb5e24f7a_u64,
            0xe1442d5d3842be2b_u64,
            0xf18f5adfd420853f_u64,
            0x04ed6127eba3b594_u64,
            0xc5c179973cdb1663_u64,
            0x7d7f67780bb268ff_u64,
            0x0000000000000003_u64,
            0x0000000000000000_u64,
        ];
        let mut divisor = [
            0x0181880b078ab6a1_u64,
            0x62d67f6b7b0bda6b_u64,
            0x92b1840f9c792ded_u64,
            0x0000000000000019_u64,
        ];
        divrem_nbym(&mut numerator, &mut divisor);
        let (remainder, quotient) = numerator.split_at(divisor.len());
        assert_eq!(remainder, expected_remainder);
        assert_eq!(quotient, expected_quotient);
    }

    #[test]
    fn test_divrem_4by4() {
        let a = u256h!("6f1480e63854afa41868b9a7d418e9c64edef514135f5899e72530a3d4e91ea3");
        let b = u256h!("3ba5ddaec5090ef0b87126f34ee28533ffb08af4108f9aeaa62b65900d2a62bb");
        // The test expects a quotient of 1, so the remainder is simply a - b.
        let r = a.clone() - &b;
        let mut numerator = [a.c0, a.c1, a.c2, a.c3, 0];
        let mut divisor = [b.c0, b.c1, b.c2, b.c3];
        divrem_nbym(&mut numerator, &mut divisor);
        let (remainder, quotient) = numerator.split_at(divisor.len());
        assert_eq!(remainder, [r.c0, r.c1, r.c2, r.c3]);
        assert_eq!(quotient, [1]);
    }

    #[quickcheck]
    fn div_3by2_correct(q: u64, d0: u64, d1: u64) -> bool {
        // Force the divisor to be normalized, as `div_3by2` requires.
        let d1 = d1 | (1 << 63);
        // An exact multiple q * D always satisfies `n[2]:n[1] < D`.
        let n = U256::from_limbs(d0, d1, 0, 0) * &U256::from(q);
        debug_assert!(n.c3 == 0);
        div_3by2(&[n.c0, n.c1, n.c2], &[d0, d1]) == q
    }
}