use crate::libraries::big_num::{U128, U256};
/// Scale a value by the ratio `num / denom`, i.e. compute `self * num / denom`.
///
/// The implementations in this file form the product in a type twice as wide
/// as the operands (`U128` for `u64`, `U256` for `U128`), so `self * num`
/// itself cannot overflow. They panic when `denom` is zero.
pub trait MulDiv<Rhs = Self> {
    /// Type of the scaled result.
    type Output;

    /// Returns `floor(self * num / denom)`, or `None` if that value does not
    /// fit in `Output`.
    fn mul_div_floor(self, num: Rhs, denom: Rhs) -> Option<Self::Output>;

    /// Returns `ceil(self * num / denom)`, or `None` if that value does not
    /// fit in `Output`.
    fn mul_div_ceil(self, num: Rhs, denom: Rhs) -> Option<Self::Output>;
}
/// Widening conversion into [`U256`].
///
/// The `U128` implementation zero-extends, so the numeric value is preserved.
pub trait Upcast {
    /// Returns `self` widened to a [`U256`].
    fn as_u256(self) -> U256;
}
impl Upcast for U128 {
fn as_u256(self) -> U256 {
U256([self.0[0], self.0[1], 0, 0])
}
}
/// Narrowing conversion into [`U128`].
///
/// The `U256` implementation keeps only the two low limbs and does NOT check
/// the discarded ones. Callers must range-check the value first, as the
/// `MulDiv for U128` implementation does.
///
/// NOTE(review): the receiver is `self` by value. If the big-num `U256` also
/// has an inherent `as_u128(&self)`, method lookup on an owned `U256` tries
/// by-value receivers before auto-referencing. That ordering is what makes
/// `x.as_u128()` pick this trait method, so confirm before changing the
/// receiver.
pub trait Downcast {
    /// Returns the low 128 bits of `self` as a [`U128`].
    fn as_u128(self) -> U128;
}
impl Downcast for U256 {
fn as_u128(self) -> U128 {
U128([self.0[0], self.0[1]])
}
}
impl MulDiv for u64 {
    type Output = u64;

    /// Computes `floor(self * num / denom)` with a `U128` intermediate, so the
    /// product of two `u64`s never overflows.
    ///
    /// Returns `None` when the quotient is larger than `u64::MAX`.
    ///
    /// # Panics
    ///
    /// Panics if `denom` is zero.
    fn mul_div_floor(self, num: Self, denom: Self) -> Option<Self::Output> {
        assert_ne!(denom, 0);
        let dividend = U128::from(self) * U128::from(num);
        let quotient = dividend / U128::from(denom);
        if quotient <= U128::from(u64::MAX) {
            Some(quotient.as_u64())
        } else {
            None
        }
    }

    /// Computes `ceil(self * num / denom)` with a `U128` intermediate.
    ///
    /// Returns `None` when the rounded-up quotient is larger than `u64::MAX`.
    ///
    /// # Panics
    ///
    /// Panics if `denom` is zero.
    fn mul_div_ceil(self, num: Self, denom: Self) -> Option<Self::Output> {
        assert_ne!(denom, 0);
        // Adding `denom - 1` before the floor division rounds any non-zero
        // remainder up. The sum stays below 2^128:
        // (2^64 - 1)^2 + (2^64 - 2) = 2^128 - 2^64 - 1.
        let bias = U128::from(denom - 1);
        let dividend = U128::from(self) * U128::from(num) + bias;
        let quotient = dividend / U128::from(denom);
        if quotient <= U128::from(u64::MAX) {
            Some(quotient.as_u64())
        } else {
            None
        }
    }
}
impl MulDiv for U128 {
    type Output = U128;

    /// Computes `floor(self * num / denom)` through a `U256` intermediate, so
    /// the 128x128-bit product never overflows.
    ///
    /// Returns `None` when the quotient is larger than `U128::MAX`.
    ///
    /// # Panics
    ///
    /// Panics if `denom` is zero.
    fn mul_div_floor(self, num: Self, denom: Self) -> Option<Self::Output> {
        assert_ne!(denom, U128::default());
        let wide_product = self.as_u256() * num.as_u256();
        let quotient = wide_product / denom.as_u256();
        // Range-check before narrowing: `as_u128` truncates without checking.
        if quotient <= U128::MAX.as_u256() {
            Some(quotient.as_u128())
        } else {
            None
        }
    }

    /// Computes `ceil(self * num / denom)` through a `U256` intermediate.
    ///
    /// Returns `None` when the rounded-up quotient is larger than `U128::MAX`.
    ///
    /// # Panics
    ///
    /// Panics if `denom` is zero.
    fn mul_div_ceil(self, num: Self, denom: Self) -> Option<Self::Output> {
        assert_ne!(denom, U128::default());
        // `denom - 1` cannot underflow after the assert above. Adding it
        // before the floor division rounds any remainder up, and the sum stays
        // below 2^256: (2^128 - 1)^2 + (2^128 - 2) < 2^256.
        let bias = (denom - 1).as_u256();
        let wide_product = self.as_u256() * num.as_u256();
        let quotient = (wide_product + bias) / denom.as_u256();
        if quotient <= U128::MAX.as_u256() {
            Some(quotient.as_u128())
        } else {
            None
        }
    }
}
#[cfg(test)]
mod muldiv_u64_tests {
    use super::*;
    use quickcheck::{quickcheck, Arbitrary, Gen};

    /// A `u64` that is guaranteed to be non-zero, used as the divisor.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct NonZero(u64);

    impl Arbitrary for NonZero {
        fn arbitrary<G: Gen>(g: &mut G) -> Self {
            // Rejection sampling: redraw until the value is non-zero.
            let mut candidate = u64::arbitrary(g);
            while candidate == 0 {
                candidate = u64::arbitrary(g);
            }
            NonZero(candidate)
        }
    }

    /// Reference model for `mul_div_*` on `u64`.
    ///
    /// Divides the exact `U128` product `val * num` by `den`. The quotient is
    /// bumped by one when `round_up` is set and the division leaves a
    /// remainder. Returns `None` when the result does not fit in a `u64`.
    fn reference(val: u64, num: u64, den: u64, round_up: bool) -> Option<u64> {
        let product = U128::from(val) * U128::from(num);
        let divisor = U128::from(den);
        let mut quotient = product / divisor;
        if round_up && product % divisor != U128::default() {
            quotient += U128::from(1);
        }
        if quotient > U128::from(u64::MAX) {
            None
        } else {
            Some(quotient.as_u64())
        }
    }

    quickcheck! {
        // Floor variant agrees with the reference model.
        fn scale_floor(val: u64, num: u64, den: NonZero) -> bool {
            val.mul_div_floor(num, den.0) == reference(val, num, den.0, false)
        }

        // Ceil variant agrees with the reference model.
        fn scale_ceil(val: u64, num: u64, den: NonZero) -> bool {
            val.mul_div_ceil(num, den.0) == reference(val, num, den.0, true)
        }
    }
}
#[cfg(test)]
mod muldiv_u128_tests {
    use super::*;
    use quickcheck::{quickcheck, Arbitrary, Gen};

    /// A `U128` that is guaranteed to be non-zero, used as the divisor.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct NonZero(U128);

    impl Arbitrary for NonZero {
        fn arbitrary<G: Gen>(g: &mut G) -> Self {
            // Rejection-sample the general `U128` generator until it yields a
            // non-zero value.
            loop {
                let v = U128::arbitrary(g);
                if v != U128::default() {
                    return NonZero(v);
                }
            }
        }
    }

    /// Any `U128`, zero included.
    ///
    /// Only the divisor must be non-zero (see `NonZero`). The multiplicand and
    /// multiplier may legitimately be zero, and filtering zero out here would
    /// leave that case untested.
    impl Arbitrary for U128 {
        fn arbitrary<G: Gen>(g: &mut G) -> Self {
            U128::from(u128::arbitrary(g))
        }
    }

    /// Reference model for `mul_div_*` on `U128`.
    ///
    /// Divides the exact `U256` product `val * num` by `den`. The quotient is
    /// bumped by one when `round_up` is set and the division leaves a
    /// remainder. Returns `None` when the result exceeds `U128::MAX`.
    fn reference(val: U128, num: U128, den: U128, round_up: bool) -> Option<U128> {
        let product = val.as_u256() * num.as_u256();
        let divisor = den.as_u256();
        let mut quotient = product / divisor;
        if round_up && product % divisor != U256::default() {
            quotient += U256::from(1);
        }
        if quotient > U128::MAX.as_u256() {
            None
        } else {
            Some(quotient.as_u128())
        }
    }

    quickcheck! {
        // Floor variant agrees with the reference model.
        fn scale_floor(val: U128, num: U128, den: NonZero) -> bool {
            val.mul_div_floor(num, den.0) == reference(val, num, den.0, false)
        }

        // Ceil variant agrees with the reference model.
        fn scale_ceil(val: U128, num: U128, den: NonZero) -> bool {
            val.mul_div_ceil(num, den.0) == reference(val, num, den.0, true)
        }
    }
}
#[cfg(test)]
mod uniswap_tests {
    // Test vectors modelled on Uniswap's `mulDiv` / `mulDivRoundingUp` tests,
    // rescaled to `fixed_point_32::Q32` units and a `u64` output.
    // Assumes Q32 == 1 << 32 (TODO confirm against fixed_point_32).
    //
    // Each overflow case pins the `Option::unwrap` panic message. It therefore
    // only passes when overflow is reported as `None`, not when something else
    // panics first (e.g. an arithmetic overflow inside the implementation or a
    // failed precondition assert in the test itself).
    use super::*;
    use crate::libraries::fixed_point_32;

    mod mul_div {
        use super::*;

        #[test]
        #[should_panic]
        fn reverts_if_denominator_is_zero() {
            fixed_point_32::Q32.mul_div_floor(5, 0);
        }

        #[test]
        #[should_panic]
        fn reverts_if_denominator_is_zero_and_numerator_overflows() {
            fixed_point_32::Q32.mul_div_floor(fixed_point_32::Q32, 0);
        }

        #[test]
        #[should_panic(expected = "called `Option::unwrap()` on a `None` value")]
        fn reverts_if_output_overflows_u64() {
            fixed_point_32::Q32
                .mul_div_floor(fixed_point_32::Q32, 1)
                .unwrap();
        }

        #[test]
        #[should_panic(expected = "called `Option::unwrap()` on a `None` value")]
        fn reverts_on_overflow_with_all_max_inputs() {
            u64::MAX.mul_div_floor(u64::MAX, u64::MAX - 1).unwrap();
        }

        #[test]
        fn all_max_inputs() {
            assert_eq!(
                u64::MAX.mul_div_floor(u64::MAX, u64::MAX).unwrap(),
                u64::MAX
            );
        }

        #[test]
        fn accurate_without_phantom_overflow() {
            let result = fixed_point_32::Q32 / 3;
            assert_eq!(
                fixed_point_32::Q32
                    .mul_div_floor(
                        50 * fixed_point_32::Q32 / 100,
                        150 * fixed_point_32::Q32 / 100
                    )
                    .unwrap(),
                result
            );
        }

        #[test]
        fn accurate_with_phantom_overflow() {
            let result = 4375 * fixed_point_32::Q32 / 1000;
            assert_eq!(
                fixed_point_32::Q32
                    .mul_div_floor(35 * fixed_point_32::Q32, 8 * fixed_point_32::Q32)
                    .unwrap(),
                result
            );
        }

        #[test]
        fn accurate_with_phantom_overflow_and_repeating_decimal() {
            let result = fixed_point_32::Q32 / 3;
            assert_eq!(
                fixed_point_32::Q32
                    .mul_div_floor(1000 * fixed_point_32::Q32, 3000 * fixed_point_32::Q32)
                    .unwrap(),
                result
            );
        }
    }

    mod mul_div_rounding_up {
        use super::*;

        #[test]
        #[should_panic]
        fn reverts_if_denominator_is_zero() {
            fixed_point_32::Q32.mul_div_ceil(5, 0);
        }

        #[test]
        #[should_panic]
        fn reverts_if_denominator_is_zero_and_numerator_overflows() {
            fixed_point_32::Q32.mul_div_ceil(fixed_point_32::Q32, 0);
        }

        #[test]
        #[should_panic(expected = "called `Option::unwrap()` on a `None` value")]
        fn reverts_if_output_overflows_u64() {
            fixed_point_32::Q32
                .mul_div_ceil(fixed_point_32::Q32, 1)
                .unwrap();
        }

        #[test]
        #[should_panic(expected = "called `Option::unwrap()` on a `None` value")]
        fn reverts_on_overflow_with_all_max_inputs() {
            u64::MAX.mul_div_ceil(u64::MAX, u64::MAX - 1).unwrap();
        }

        #[test]
        #[should_panic(expected = "called `Option::unwrap()` on a `None` value")]
        fn reverts_if_muldiv_overflows_64_bits_after_rounding_up() {
            // a * b == 2^65 - 1 (2^65 - 1 factors as 31 * 8191 * a), so
            // a * b / 2 == u64::MAX + 0.5. The floor still fits in a u64, but
            // the ceiling does not.
            let a: u64 = 145295143558111;
            let b: u64 = 31 * 8191;
            // Precondition: fails with an assertion message rather than the
            // expected unwrap message, so a wrong floor result cannot satisfy
            // `should_panic`.
            assert_eq!(a.mul_div_floor(b, 2), Some(u64::MAX));
            a.mul_div_ceil(b, 2).unwrap();
        }

        #[test]
        fn all_max_inputs() {
            assert_eq!(
                u64::MAX.mul_div_ceil(u64::MAX, u64::MAX).unwrap(),
                u64::MAX
            );
        }

        #[test]
        fn accurate_without_phantom_overflow() {
            let result = fixed_point_32::Q32 / 3 + 1;
            assert_eq!(
                fixed_point_32::Q32
                    .mul_div_ceil(
                        50 * fixed_point_32::Q32 / 100,
                        150 * fixed_point_32::Q32 / 100
                    )
                    .unwrap(),
                result
            );
        }

        #[test]
        fn accurate_with_phantom_overflow() {
            let result = 4375 * fixed_point_32::Q32 / 1000;
            assert_eq!(
                fixed_point_32::Q32
                    .mul_div_ceil(35 * fixed_point_32::Q32, 8 * fixed_point_32::Q32)
                    .unwrap(),
                result
            );
        }

        #[test]
        fn accurate_with_phantom_overflow_and_repeating_decimal() {
            let result = fixed_point_32::Q32 / 3 + 1;
            assert_eq!(
                fixed_point_32::Q32
                    .mul_div_ceil(1000 * fixed_point_32::Q32, 3000 * fixed_point_32::Q32)
                    .unwrap(),
                result
            );
        }
    }
}