use ethnum::U256;
use crate::AmmMathError;
/// Computes `a * b / c`, rounding the quotient down.
///
/// Both factors are widened to [`U256`] before multiplying. The product of
/// two `u128` values always fits in 256 bits, so the multiplication itself
/// cannot wrap; only the final quotient is checked against the `u128` range.
///
/// # Errors
///
/// * [`AmmMathError::DivisionByZero`] if `c` is zero.
/// * [`AmmMathError::Overflow`] if the quotient is larger than `u128::MAX`.
pub fn mul_div_floor(a: u128, b: u128, c: u128) -> Result<u128, AmmMathError> {
    if c == 0 {
        return Err(AmmMathError::DivisionByZero);
    }

    let product = U256::from(a) * U256::from(b);
    let quotient = product / U256::from(c);

    // Narrow back to `u128` only after confirming the value is in range.
    if quotient <= U256::from(u128::MAX) {
        Ok(quotient.as_u128())
    } else {
        Err(AmmMathError::Overflow)
    }
}
/// Computes `a * b / c`, rounding the quotient up.
///
/// The multiplication is carried out in [`U256`], so the intermediate product
/// of two `u128` values cannot wrap. Rounding up is done by adding `c - 1` to
/// the product before dividing.
///
/// # Errors
///
/// * [`AmmMathError::DivisionByZero`] if `c` is zero.
/// * [`AmmMathError::Overflow`] if the rounded quotient is larger than
///   `u128::MAX`.
pub fn mul_div_ceil(a: u128, b: u128, c: u128) -> Result<u128, AmmMathError> {
    if c == 0 {
        return Err(AmmMathError::DivisionByZero);
    }

    let divisor = U256::from(c);
    // `c` is non-zero here, so `divisor - 1` cannot underflow.
    let round_up_bias = divisor - U256::from(1u128);
    // The sum stays below 2^256: the product is at most (2^128 - 1)^2 and the
    // bias is at most 2^128 - 2.
    let biased_product = U256::from(a) * U256::from(b) + round_up_bias;
    let quotient = biased_product / divisor;

    // Narrow back to `u128` only after confirming the value is in range.
    if quotient <= U256::from(u128::MAX) {
        Ok(quotient.as_u128())
    } else {
        Err(AmmMathError::Overflow)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `mul_div_floor` and `mul_div_ceil`.
    //!
    //! Each test covers one behavior for both rounding directions where that
    //! makes sense: inexact and exact division, wide intermediates, the
    //! `u128::MAX` boundary, and both error variants.

    use super::*;

    /// Inexact quotients are truncated toward zero.
    #[test]
    fn test_mul_div_floor_basic() {
        assert_eq!(mul_div_floor(10, 20, 3).unwrap(), 66);
        assert_eq!(mul_div_floor(100, 3, 7).unwrap(), 42);
        assert_eq!(mul_div_floor(u128::MAX, 1, 2).unwrap(), u128::MAX / 2);
        assert_eq!(mul_div_floor(1, 1, u128::MAX).unwrap(), 0);
    }

    /// Inexact quotients are bumped up to the next integer.
    #[test]
    fn test_mul_div_ceil_basic() {
        assert_eq!(mul_div_ceil(10, 20, 3).unwrap(), 67);
        assert_eq!(mul_div_ceil(100, 3, 7).unwrap(), 43);
        // `u128::MAX` is odd, so halving it leaves a remainder to round up.
        assert_eq!(mul_div_ceil(u128::MAX, 1, 2).unwrap(), u128::MAX / 2 + 1);
        assert_eq!(mul_div_ceil(1, 1, u128::MAX).unwrap(), 1);
    }

    /// When the division is exact, both rounding modes agree.
    #[test]
    fn test_mul_div_exact() {
        assert_eq!(mul_div_floor(10, 20, 4).unwrap(), 50);
        assert_eq!(mul_div_ceil(10, 20, 4).unwrap(), 50);
    }

    /// The intermediate product (2^160) exceeds `u128`, but the quotient fits.
    #[test]
    fn test_mul_div_large_numbers() {
        let a: u128 = 1u128 << 100;
        let b: u128 = 1u128 << 60;
        let c: u128 = 1u128 << 80;
        assert_eq!(mul_div_floor(a, b, c).unwrap(), 1u128 << 80);
        assert_eq!(mul_div_ceil(a, b, c).unwrap(), 1u128 << 80);
    }

    /// A quotient of exactly `u128::MAX` is still representable.
    #[test]
    fn test_mul_div_max_boundary() {
        assert_eq!(mul_div_floor(u128::MAX, 1, 1).unwrap(), u128::MAX);
        assert_eq!(mul_div_ceil(u128::MAX, 1, 1).unwrap(), u128::MAX);
        assert_eq!(
            mul_div_floor(u128::MAX, u128::MAX, u128::MAX).unwrap(),
            u128::MAX
        );
        assert_eq!(
            mul_div_ceil(u128::MAX, u128::MAX, u128::MAX).unwrap(),
            u128::MAX
        );
    }

    /// A zero divisor is rejected, even when the numerator is also zero.
    #[test]
    fn test_mul_div_division_by_zero() {
        assert_eq!(mul_div_floor(10, 20, 0), Err(AmmMathError::DivisionByZero));
        assert_eq!(mul_div_ceil(10, 20, 0), Err(AmmMathError::DivisionByZero));
        assert_eq!(mul_div_floor(0, 0, 0), Err(AmmMathError::DivisionByZero));
        assert_eq!(mul_div_ceil(0, 0, 0), Err(AmmMathError::DivisionByZero));
    }

    /// Quotients above `u128::MAX` are reported as `Overflow`.
    #[test]
    fn test_mul_div_overflow() {
        assert_eq!(
            mul_div_floor(u128::MAX, u128::MAX, 1),
            Err(AmmMathError::Overflow)
        );
        assert_eq!(mul_div_floor(u128::MAX, 2, 1), Err(AmmMathError::Overflow));
        assert_eq!(
            mul_div_ceil(u128::MAX, u128::MAX, 1),
            Err(AmmMathError::Overflow)
        );
        assert_eq!(mul_div_ceil(u128::MAX, 2, 1), Err(AmmMathError::Overflow));
    }

    /// A zero factor yields zero for both rounding modes.
    #[test]
    fn test_mul_div_zero_inputs() {
        assert_eq!(mul_div_floor(0, 100, 1).unwrap(), 0);
        assert_eq!(mul_div_floor(100, 0, 1).unwrap(), 0);
        assert_eq!(mul_div_ceil(0, 100, 1).unwrap(), 0);
        assert_eq!(mul_div_ceil(100, 0, 1).unwrap(), 0);
    }
}