use ethnum::U256;
/// Number of fractional bits in the fixed-point representation used by this module
/// (values are `u128` in Q64.64 format).
pub const SCALE_OFFSET: u8 = 64;
/// The value `1.0` in Q64.64 fixed point, i.e. `2^SCALE_OFFSET`.
pub const ONE: u128 = 1u128 << SCALE_OFFSET;
/// Exclusive upper bound (`2^19`) on the exponent magnitude accepted by `pow`;
/// only the low 19 bits of an exponent are ever examined.
pub const MAX_EXPONENTIAL: u32 = 1 << 19;
/// Decimal scaling factor `10^12`.
/// NOTE(review): not used in the visible section — confirm its meaning at the call sites.
pub const PRECISION: u128 = 10u128.pow(12);
/// Direction in which an inexact division result is rounded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Rounding {
    /// Discard any remainder (floor, for the unsigned values used here).
    Down,
    /// Add one to the quotient whenever the division leaves a non-zero remainder (ceiling).
    Up,
}
/// Computes `x * y / denominator` using a 256-bit intermediate product, rounded per `rounding`.
///
/// Returns `None` if `denominator` is zero or if the rounded quotient does not fit in a `u128`.
pub fn mul_div(x: u128, y: u128, denominator: u128, rounding: Rounding) -> Option<u128> {
    if denominator == 0 {
        return None;
    }
    // The product of two u128 values always fits in 256 bits.
    let product = U256::from(x).checked_mul(U256::from(y))?;
    let (floor, leftover) = product.div_rem(U256::from(denominator));
    // Only round up when the division was inexact.
    let rounded = if matches!(rounding, Rounding::Up) && leftover > U256::ZERO {
        floor.checked_add(U256::ONE)?
    } else {
        floor
    };
    rounded.try_into().ok()
}
/// Computes `(x * y) >> offset` (i.e. `x * y / 2^offset`), rounded per `rounding`.
///
/// Returns `None` when `offset >= 128` (the divisor `2^offset` is not representable as a
/// `u128`) or when `mul_div` fails.
pub fn mul_shr(x: u128, y: u128, offset: u8, rounding: Rounding) -> Option<u128> {
    1u128
        .checked_shl(u32::from(offset))
        .and_then(|divisor| mul_div(x, y, divisor, rounding))
}
/// Computes `(x << offset) / y` (i.e. `x * 2^offset / y`) in 256-bit intermediate precision,
/// rounded per `rounding`.
///
/// Returns `None` when `y == 0`, when `x << offset` does not fit in 256 bits, or when the
/// rounded quotient does not fit in a `u128`.
pub fn shl_div(x: u128, y: u128, offset: u8, rounding: Rounding) -> Option<u128> {
    if y == 0 {
        return None;
    }
    // `U256::checked_shl` only rejects shift amounts >= 256; it silently discards bits shifted
    // past the top, which would yield a wrong but plausible-looking quotient. `x` occupies
    // `128 - x.leading_zeros()` bits, so the shift is lossless iff
    // `offset <= 128 + x.leading_zeros()`. When it is not, the exact quotient is at least
    // `2^256 / y > 2^128`, so it could not fit in a u128 anyway.
    if u32::from(offset) > 128 + x.leading_zeros() {
        return None;
    }
    let shifted = U256::from(x).checked_shl(offset.into())?;
    let (quotient, remainder) = shifted.div_rem(U256::from(y));
    let quotient = match rounding {
        Rounding::Up if remainder > U256::ZERO => quotient.checked_add(U256::ONE)?,
        Rounding::Up | Rounding::Down => quotient,
    };
    quotient.try_into().ok()
}
pub fn pow(base: u128, exp: i32) -> Option<u128> {
let mut invert = exp.is_negative();
if exp == 0 {
return Some(ONE);
}
let exp: u32 = if invert { exp.unsigned_abs() } else { exp as u32 };
if exp >= MAX_EXPONENTIAL {
return None;
}
let mut squared_base = base;
let mut result = ONE;
if squared_base >= result {
squared_base = u128::MAX.checked_div(squared_base)?;
invert = !invert;
}
if exp & 0x1 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x2 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x4 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x8 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x10 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x20 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x40 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x80 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x100 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x200 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x400 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x800 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x1000 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x2000 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x4000 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x8000 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x10000 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x20000 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
squared_base = (squared_base.checked_mul(squared_base)?) >> SCALE_OFFSET;
if exp & 0x40000 > 0 {
result = (result.checked_mul(squared_base)?) >> SCALE_OFFSET;
}
if result == 0 {
return None;
}
if invert {
result = u128::MAX.checked_div(result)?;
}
Some(result)
}
/// Runs `mul_shr` and then narrows the `u128` result into `T`.
///
/// # Errors
///
/// Returns `AmmMathError::Overflow` if `mul_shr` returns `None` or if the result does not
/// fit in `T`.
pub fn safe_mul_shr_cast<T: TryFrom<u128>>(
    x: u128,
    y: u128,
    offset: u8,
    rounding: Rounding,
) -> Result<T, crate::AmmMathError> {
    mul_shr(x, y, offset, rounding)
        .and_then(|raw| T::try_from(raw).ok())
        .ok_or(crate::AmmMathError::Overflow)
}
/// Runs `shl_div` and then narrows the `u128` result into `T`.
///
/// # Errors
///
/// - `AmmMathError::DivisionByZero` if `y == 0`. This matches `safe_mul_div_cast`, rather than
///   letting `shl_div`'s `None` be misreported as an overflow.
/// - `AmmMathError::Overflow` if `shl_div` otherwise returns `None`, or if the result does
///   not fit in `T`.
pub fn safe_shl_div_cast<T: TryFrom<u128>>(
    x: u128,
    y: u128,
    offset: u8,
    rounding: Rounding,
) -> Result<T, crate::AmmMathError> {
    if y == 0 {
        return Err(crate::AmmMathError::DivisionByZero);
    }
    let value = shl_div(x, y, offset, rounding).ok_or(crate::AmmMathError::Overflow)?;
    T::try_from(value).map_err(|_| crate::AmmMathError::Overflow)
}
/// Runs `mul_div` and then narrows the `u128` result into `T`.
///
/// # Errors
///
/// - `AmmMathError::DivisionByZero` if `denominator == 0`.
/// - `AmmMathError::Overflow` if `mul_div` returns `None` or if the result does not fit in `T`.
pub fn safe_mul_div_cast<T: TryFrom<u128>>(
    x: u128,
    y: u128,
    denominator: u128,
    rounding: Rounding,
) -> Result<T, crate::AmmMathError> {
    match denominator {
        0 => Err(crate::AmmMathError::DivisionByZero),
        _ => mul_div(x, y, denominator, rounding)
            .and_then(|raw| T::try_from(raw).ok())
            .ok_or(crate::AmmMathError::Overflow),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mul_shr_known_value() {
        // 1.0 * 1.0 in Q64.64 is 1.0.
        let x = 1u128 << 64;
        let y = 1u128 << 64;
        let result = mul_shr(x, y, 64, Rounding::Down).unwrap();
        assert_eq!(result, 1u128 << 64);
    }

    #[test]
    fn test_mul_shr_zero_operand() {
        assert_eq!(mul_shr(0, 1_000_000, 8, Rounding::Down).unwrap(), 0);
        assert_eq!(mul_shr(1_000_000, 0, 8, Rounding::Down).unwrap(), 0);
    }

    #[test]
    fn test_mul_shr_rounding_down() {
        // 100 * 200 / 256 = 78.125
        let result = mul_shr(100, 200, 8, Rounding::Down).unwrap();
        assert_eq!(result, 78);
    }

    #[test]
    fn test_mul_shr_rounding_up() {
        let result = mul_shr(100, 200, 8, Rounding::Up).unwrap();
        assert_eq!(result, 79);
    }

    #[test]
    fn test_mul_shr_exact_division() {
        let down = mul_shr(256, 256, 8, Rounding::Down).unwrap();
        let up = mul_shr(256, 256, 8, Rounding::Up).unwrap();
        assert_eq!(down, 256);
        assert_eq!(up, 256);
    }

    #[test]
    fn test_mul_div_denominator_zero_returns_none() {
        assert!(mul_div(100, 200, 0, Rounding::Down).is_none());
    }

    #[test]
    fn test_mul_div_exact() {
        assert_eq!(mul_div(10, 20, 5, Rounding::Down).unwrap(), 40);
        assert_eq!(mul_div(10, 20, 5, Rounding::Up).unwrap(), 40);
    }

    #[test]
    fn test_mul_div_rounding() {
        // 10 * 3 / 4 = 7.5
        assert_eq!(mul_div(10, 3, 4, Rounding::Down).unwrap(), 7);
        assert_eq!(mul_div(10, 3, 4, Rounding::Up).unwrap(), 8);
    }

    #[test]
    fn test_shl_div_zero_divisor() {
        assert!(shl_div(100, 0, 8, Rounding::Down).is_none());
    }

    #[test]
    fn test_shl_div_exact() {
        assert_eq!(shl_div(4, 2, 8, Rounding::Down).unwrap(), 512);
        assert_eq!(shl_div(4, 2, 8, Rounding::Up).unwrap(), 512);
    }

    #[test]
    fn test_shl_div_rounding() {
        // (3 << 8) / 4 = 192 exactly; (5 << 8) / 3 = 426.67
        assert_eq!(shl_div(3, 4, 8, Rounding::Down).unwrap(), 192);
        assert_eq!(shl_div(5, 3, 8, Rounding::Down).unwrap(), 426);
        assert_eq!(shl_div(5, 3, 8, Rounding::Up).unwrap(), 427);
    }

    #[test]
    fn test_shl_div_identity() {
        let x = 42u128;
        let result = shl_div(x, ONE, 64, Rounding::Down).unwrap();
        assert_eq!(result, x);
    }

    #[test]
    fn test_pow_zero_exponent_returns_one() {
        let base = ONE + 100;
        assert_eq!(pow(base, 0), Some(ONE));
    }

    #[test]
    fn test_pow_one_exponent() {
        // base ≈ 1.001 (ten basis points above 1.0)
        let bps = (ONE / 10_000) * 10;
        let base = ONE + bps;
        let result = pow(base, 1).unwrap();
        let decimal = result as f64 / ONE as f64;
        let expected = 1.001_f64;
        assert!(
            (decimal - expected).abs() < 1e-6,
            "pow(base, 1) = {} expected {}",
            decimal,
            expected,
        );
    }

    #[test]
    fn test_pow_negative_exponent_below_one() {
        let bps = (ONE / 10_000) * 10;
        let base = ONE + bps;
        let result = pow(base, -1).unwrap();
        assert!(result < ONE, "base^(-1) should be < 1.0");
    }

    #[test]
    fn test_pow_exceeds_max_exponential() {
        let base = ONE + 1;
        assert!(pow(base, MAX_EXPONENTIAL as i32).is_none());
    }

    #[test]
    fn test_pow_known_value() {
        // base ≈ 1.0001 (one basis point above 1.0); 1.0001^100 ≈ 1.0100496621.
        // Note this is distinct from e^0.01 ≈ 1.0100501671 (≈5e-7 relative gap), so the
        // tolerance must be well below 5e-7 for the check to be meaningful.
        let bps = ONE / 10_000;
        let base = ONE + bps;
        let result = pow(base, 100).unwrap();
        let decimal = result as f64 / ONE as f64;
        let expected = 1.0001_f64.powi(100);
        let rel_err = (decimal - expected).abs() / expected;
        assert!(rel_err < 1e-9, "decimal={decimal} expected={expected}");
    }

    #[test]
    fn test_pow_exact_half_powers() {
        // Powers of 0.5 are exactly representable in Q64.64, so no rounding is involved.
        let half = ONE / 2;
        assert_eq!(pow(half, 1), Some(half));
        assert_eq!(pow(half, 2), Some(ONE / 4));
        assert_eq!(pow(half, 3), Some(ONE / 8));
    }

    #[test]
    fn test_safe_mul_shr_cast_u64() {
        let result: u64 = safe_mul_shr_cast(256, 256, 8, Rounding::Down).unwrap();
        assert_eq!(result, 256);
    }

    #[test]
    fn test_safe_mul_shr_cast_overflow_to_u64() {
        // u128::MAX * 2 >> 1 == u128::MAX, which fits in u128 but not in u64.
        let result: Result<u64, _> = safe_mul_shr_cast(u128::MAX, 2, 1, Rounding::Down);
        assert!(result.is_err());
    }

    #[test]
    fn test_safe_shl_div_cast_u64() {
        let result: u64 = safe_shl_div_cast(4, 2, 8, Rounding::Down).unwrap();
        assert_eq!(result, 512);
    }

    #[test]
    fn test_safe_shl_div_cast_zero_divisor() {
        let result: Result<u64, _> = safe_shl_div_cast(4, 0, 8, Rounding::Down);
        assert!(result.is_err());
    }

    #[test]
    fn test_safe_mul_div_cast_u64() {
        let result: u64 = safe_mul_div_cast(10, 20, 5, Rounding::Down).unwrap();
        assert_eq!(result, 40);
    }

    #[test]
    fn test_safe_mul_div_cast_div_by_zero() {
        let result: Result<u64, _> = safe_mul_div_cast(10, 20, 0, Rounding::Down);
        assert_eq!(result, Err(crate::AmmMathError::DivisionByZero));
    }
}