use alloy_primitives::{I256, U256};
use crate::constants::WAD;
/// Integer square root: returns `floor(sqrt(x))`.
///
/// Uses Newton's (Babylonian) iteration starting from `ceil(x / 2)`, which is
/// never below `floor(sqrt(x))` for `x >= 1`, so the iterates shrink until they
/// stop decreasing; the last value reached is the floor square root.
pub fn sqrt_int(x: U256) -> U256 {
    if x.is_zero() {
        return U256::ZERO;
    }
    // ceil(x / 2), computed without `x + 1`: that addition overflows for
    // `x == U256::MAX`, and a wrapped guess of 0 would divide by zero below.
    let mut z = (x >> 1) + (x & U256::from(1u64));
    let mut y = x;
    while z < y {
        y = z;
        z = (x / z + z) >> 1;
    }
    y
}
/// Fixed-point exponential: returns `e^(x / 1e18)` scaled by `1e18` (WAD).
///
/// * Returns `Some(0)` when `x <= -41_446_531_673_892_822_313`
///   (≈ `ln(1e-18) * 1e18`).
/// * Returns `None` when `x >= 135_305_999_368_893_231_589`
///   (≈ `ln((2^255 - 1) / 1e18) * 1e18`).
///
/// All intermediate arithmetic uses wrapping operations. Only the input
/// bounds above keep the intermediates in range, so do not relax them without
/// re-checking every step.
///
/// NOTE(review): the constants and structure appear to match the `wad_exp`
/// routine of Curve's LLAMMA (derived from solmate's `expWad`). Confirm
/// against upstream before changing any constant.
pub fn wad_exp(x: I256) -> Option<U256> {
// Lower cut-off: the result would be below one WAD unit.
if x <= I256::try_from(-41_446_531_673_892_822_313i128).unwrap() {
return Some(U256::ZERO);
}
// Upper cut-off: the result would not fit the representable range.
if x >= I256::from_raw(U256::from_be_slice(
&135_305_999_368_893_231_589u128.to_be_bytes(),
)) {
return None; }
// Builds a non-negative I256 constant from a u128 literal.
let c = |v: u128| -> I256 { I256::from_raw(U256::from(v)) };
let two_96: I256 = c(1u128 << 96);
let wad: I256 = c(1_000_000_000_000_000_000);
// Re-express x from the 1e18 basis in a 2^96 fixed-point basis.
let mut x: I256 = x.wrapping_mul(two_96).wrapping_div(wad);
// ln(2) in 2^96 fixed point (54916777467707473351141471128 / 2^96 ≈ 0.693147).
let ln2_96 = c(54_916_777_467_707_473_351_141_471_128);
// Range reduction: k ≈ x / ln(2), computed as (x*2^96/ln2 + 2^95) / 2^96.
// NOTE(review): if I256 `wrapping_div` truncates toward zero (like Rust's
// primitive signed division), this is not round-to-nearest for negative x,
// and the reduced x below may fall outside ±ln(2)/2. Confirm the
// approximation's accuracy on that wider range.
let k: I256 = x
.wrapping_mul(two_96)
.wrapping_div(ln2_96)
.wrapping_add(I256::from_raw(U256::from(2u64).pow(U256::from(95))))
.wrapping_div(two_96);
// x := x - k*ln(2), so that e^(original x) = e^(reduced x) * 2^k.
x = x.wrapping_sub(k.wrapping_mul(ln2_96));
// Numerator polynomial p, evaluated in the 2^96 basis.
let mut y: I256 = x.wrapping_add(c(1_346_386_616_545_796_478_920_950_773_328));
y = y
.wrapping_mul(x)
.wrapping_div(two_96)
.wrapping_add(c(57_155_421_227_552_351_082_224_309_758_442));
let mut p: I256 = y
.wrapping_add(x)
.wrapping_sub(c(94_201_549_194_550_492_254_356_042_504_812));
p = p
.wrapping_mul(y)
.wrapping_div(two_96)
.wrapping_add(c(28_719_021_644_029_726_153_956_944_680_412_240));
// The last step multiplies by x without dividing by 2^96, and the constant
// is pre-shifted by 96 bits, so p carries an extra 2^96 factor.
p = p
.wrapping_mul(x)
.wrapping_add(c(4_385_272_521_454_847_904_659_076_985_693_276).wrapping_shl(96));
// Denominator polynomial q, evaluated in the 2^96 basis.
let mut q: I256 = x.wrapping_sub(c(2_855_989_394_907_223_263_936_484_059_900));
q = q
.wrapping_mul(x)
.wrapping_div(two_96)
.wrapping_add(c(50_020_603_652_535_783_019_961_831_881_945));
q = q
.wrapping_mul(x)
.wrapping_div(two_96)
.wrapping_sub(c(533_845_033_583_426_703_283_633_433_725_380));
q = q
.wrapping_mul(x)
.wrapping_div(two_96)
.wrapping_add(c(3_604_857_256_930_695_427_073_651_918_091_429));
q = q
.wrapping_mul(x)
.wrapping_div(two_96)
.wrapping_sub(c(14_423_608_567_350_463_180_887_372_962_807_573));
q = q
.wrapping_mul(x)
.wrapping_div(two_96)
.wrapping_add(c(26_449_188_498_355_588_339_934_803_723_976_023));
// Because p carries the extra 2^96 factor, r = p / q lands in the 2^96 basis.
let r: I256 = p.wrapping_div(q);
// Final scale factor. Per the upstream algorithm (TODO confirm), it folds
// the rational approximation's normalisation and the 2^96 -> 1e18 basis
// change into one multiplier.
let scale = U256::from_str_radix("29d9dc38563c32e5c2f6dc192ee70ef65f9978af3", 16).unwrap();
// Right-shifting by (195 - k) applies the 2^k factor from range reduction.
let shift: I256 = I256::try_from(195).unwrap().wrapping_sub(k);
// Assumes 0 <= shift < 256 given the input bounds above — TODO confirm.
let shift_u: usize = shift.as_i64() as usize;
let r_uint: U256 = r.into_raw();
// The overflow flag is discarded; the input bounds are what prevent overflow.
let (product, _) = r_uint.overflowing_mul(scale);
let result_uint: U256 = product >> shift_u;
// Round trip through I256 and back: leaves the value unchanged.
let result = I256::from_raw(result_uint);
Some(result.into_raw())
}
/// Upper oracle price of band `n`: `base_price * exp(-n * log_a_ratio)`.
///
/// The result is in the same units as `base_price`. `log_a_ratio` and the
/// exponential factor are WAD-scaled.
///
/// Returns `None` when any of these holds:
/// * `-n * log_a_ratio` overflows `I256`;
/// * [`wad_exp`] reports overflow;
/// * the exponential factor is `<= 1000`, i.e. too little precision is left
///   in the multiplier;
/// * `base_price * factor` overflows `U256`.
pub fn p_oracle_up(n: i64, base_price: U256, log_a_ratio: I256) -> Option<U256> {
    // Widen before negating: `-n` on `i64` overflows for `n == i64::MIN`.
    let neg_n = I256::try_from(-i128::from(n)).ok()?;
    let power = neg_n.checked_mul(log_a_ratio)?;
    let exp_result = wad_exp(power)?;
    if exp_result <= U256::from(1000u64) {
        return None;
    }
    Some(base_price.checked_mul(exp_result)? / WAD)
}
/// Lower oracle price of band `n`, which equals the upper oracle price of band `n + 1`.
///
/// Returns `None` when `n + 1` overflows `i64`, or whenever [`p_oracle_up`] does.
pub fn p_oracle_down(n: i64, base_price: U256, log_a_ratio: I256) -> Option<U256> {
    p_oracle_up(n.checked_add(1)?, base_price, log_a_ratio)
}
/// WAD-scaled dynamic fee based on how far the oracle price `p_o` lies outside
/// the current-price range of the band whose upper oracle price is `p_o_up`.
///
/// The range is `[p_o^3 / p_o_up^2, p_o^3 / p_o_up^2 * (A / (A - 1))^2]`, with
/// each step floored.
/// * Inside the range the fee is zero.
/// * Below the lower edge, the fee is proportional to the distance relative to
///   the lower edge.
/// * Above the upper edge, the fee is proportional to the distance relative to
///   `p_o`.
///
/// In both non-zero cases the result lies in `[0, WAD / 4]`.
///
/// # Panics
///
/// Panics on division by zero if `p_o_up` or `a_minus_1` is zero. Products are
/// not overflow-checked.
pub fn get_dynamic_fee(p_o: U256, p_o_up: U256, a: U256, a_minus_1: U256) -> U256 {
    let max_fee = WAD / U256::from(4u64);
    let lower = p_o * p_o / p_o_up * p_o / p_o_up;
    let upper = lower * a / a_minus_1 * a / a_minus_1;
    match (p_o < lower, p_o > upper) {
        (true, _) => (lower - p_o) * max_fee / lower,
        (false, true) => (p_o - upper) * max_fee / p_o,
        (false, false) => U256::ZERO,
    }
}
/// Computes `y0` for a band holding `x` (first asset) and `y` (second asset).
///
/// With `b = p_o_up * (A - 1) * x / p_o + A * p_o^2 / p_o_up * y / WAD`:
/// * When both `x` and `y` are non-zero, `y0` is the positive root of
///   `A * p_o * y0^2 - b * y0 - x * y = 0` (WAD fixed-point), i.e.
///   `(b + sqrt(b^2 + 4 * A * p_o * y / WAD * x)) * WAD / (2 * A * p_o)`.
/// * Otherwise `y0 = b * WAD / (A * p_o)`.
///
/// Returns `None` when any of these holds:
/// * `p_o` is zero;
/// * a required divisor (`p_o_up` when `y != 0`, or `A * p_o`) is zero;
/// * any intermediate product or sum overflows `U256`.
pub fn get_y0(x: U256, y: U256, p_o: U256, p_o_up: U256, a: U256, a_minus_1: U256) -> Option<U256> {
    if p_o.is_zero() {
        return None;
    }
    let mut b = U256::ZERO;
    if !x.is_zero() {
        // p_o_up * (A - 1) * x / p_o   (p_o is non-zero here)
        b = p_o_up.checked_mul(a_minus_1)?.checked_mul(x)? / p_o;
    }
    if !y.is_zero() {
        // A * p_o^2 / p_o_up * y / WAD
        let y_term = a
            .checked_mul(p_o)?
            .checked_mul(p_o)?
            .checked_div(p_o_up)?
            .checked_mul(y)?
            / WAD;
        b = b.checked_add(y_term)?;
    }
    if !x.is_zero() && !y.is_zero() {
        // Discriminant: b^2 + (4 * A * p_o * y / WAD) * x
        let four_a_p_o_y = U256::from(4u64)
            .checked_mul(a)?
            .checked_mul(p_o)?
            .checked_mul(y)?
            / WAD;
        let d = b.checked_mul(b)?.checked_add(four_a_p_o_y.checked_mul(x)?)?;
        let denom = U256::from(2u64).checked_mul(a)?.checked_mul(p_o)?;
        b.checked_add(sqrt_int(d))?.checked_mul(WAD)?.checked_div(denom)
    } else {
        let denom = a.checked_mul(p_o)?;
        b.checked_mul(WAD)?.checked_div(denom)
    }
}
/// Current price of a band holding `x` (first asset) and `y` (second asset).
///
/// Cases:
/// * `x == 0 && y == 0`: `p_o^3 / p_o_up^2 * A / (A - 1)`, a point between the
///   band's current-price edges.
/// * `x == 0`: `p_o^3 / p_o_up^2`.
/// * `y == 0`: `p_o^3 / p_o_down^2`, where `p_o_down = p_o_up * (A - 1) / A`.
/// * Otherwise: `(f + x * WAD) / (g + y)`, with `y0` from [`get_y0`],
///   `f = A * y0 * p_o / p_o_up * p_o` and `g = (A - 1) * y0 * p_o_up / p_o`.
///
/// Every division is floored in left-to-right order.
///
/// Returns `None` when any of these holds:
/// * `p_o_up` is zero;
/// * a derived divisor (`p_o_down`, `A`, `A - 1`) is zero;
/// * [`get_y0`] fails;
/// * an intermediate result overflows `U256`.
pub fn get_p(x: U256, y: U256, p_o: U256, p_o_up: U256, a: U256, a_minus_1: U256) -> Option<U256> {
    if p_o_up.is_zero() {
        return None;
    }
    // p_o^3 / p^2, floored stepwise; None if `p` is zero or a product overflows.
    let cube_over_square = |p: U256| -> Option<U256> {
        p_o.checked_mul(p_o)?
            .checked_div(p)?
            .checked_mul(p_o)?
            .checked_div(p)
    };
    if x.is_zero() {
        let p_current_down = cube_over_square(p_o_up)?;
        if y.is_zero() {
            return p_current_down.checked_mul(a)?.checked_div(a_minus_1);
        }
        return Some(p_current_down);
    }
    if y.is_zero() {
        let p_o_down = p_o_up.checked_mul(a_minus_1)?.checked_div(a)?;
        return cube_over_square(p_o_down);
    }
    let y0 = get_y0(x, y, p_o, p_o_up, a, a_minus_1)?;
    let f = a
        .checked_mul(y0)?
        .checked_mul(p_o)?
        .checked_div(p_o_up)?
        .checked_mul(p_o)?;
    // get_y0 already rejected p_o == 0, so this division cannot fail on zero.
    let g = a_minus_1.checked_mul(y0)?.checked_mul(p_o_up)?.checked_div(p_o)?;
    let numerator = f.checked_add(x.checked_mul(WAD)?)?;
    numerator.checked_div(g.checked_add(y)?)
}
#[cfg(test)]
mod tests {
    // Unit tests for the fixed-point helpers above. All values are WAD (1e18) scaled.
    use super::*;

    #[test]
    fn sqrt_int_basic() {
        assert_eq!(sqrt_int(U256::ZERO), U256::ZERO);
        assert_eq!(sqrt_int(U256::from(1u64)), U256::from(1u64));
        assert_eq!(sqrt_int(U256::from(2u64)), U256::from(1u64));
        assert_eq!(sqrt_int(U256::from(3u64)), U256::from(1u64));
        assert_eq!(sqrt_int(U256::from(4u64)), U256::from(2u64));
        assert_eq!(sqrt_int(U256::from(9u64)), U256::from(3u64));
        assert_eq!(sqrt_int(U256::from(10u64)), U256::from(3u64));
        assert_eq!(sqrt_int(U256::from(15u64)), U256::from(3u64));
        assert_eq!(sqrt_int(U256::from(16u64)), U256::from(4u64));
    }

    #[test]
    fn sqrt_int_large() {
        let root = U256::from(10u64).pow(U256::from(18));
        let x = U256::from(10u64).pow(U256::from(36));
        assert_eq!(sqrt_int(x), root);
        // One below a perfect square floors to the previous integer.
        assert_eq!(sqrt_int(x - U256::from(1u64)), root - U256::from(1u64));
    }

    #[test]
    fn wad_exp_zero() {
        let result = wad_exp(I256::ZERO).unwrap();
        assert_eq!(result, WAD);
    }

    #[test]
    fn wad_exp_one() {
        let result = wad_exp(I256::try_from(WAD).unwrap()).unwrap();
        let expected_low = U256::from(2_718_281_828_000_000_000u128);
        let expected_high = U256::from(2_718_281_829_000_000_000u128);
        assert!(
            result >= expected_low && result <= expected_high,
            "e^1 = {result}"
        );
    }

    #[test]
    fn wad_exp_negative_large_returns_zero() {
        let result = wad_exp(I256::try_from(-50_000_000_000_000_000_000i128).unwrap()).unwrap();
        assert_eq!(result, U256::ZERO);
    }

    #[test]
    fn wad_exp_overflow_returns_none() {
        let result = wad_exp(I256::try_from(136_000_000_000_000_000_000i128).unwrap());
        assert!(result.is_none());
    }

    #[test]
    fn wad_exp_lower_threshold_is_inclusive() {
        let x = I256::try_from(-41_446_531_673_892_822_313i128).unwrap();
        assert_eq!(wad_exp(x), Some(U256::ZERO));
    }

    #[test]
    fn wad_exp_upper_threshold_is_exclusive() {
        let at = I256::try_from(135_305_999_368_893_231_589i128).unwrap();
        let below = I256::try_from(135_305_999_368_893_231_588i128).unwrap();
        assert!(wad_exp(at).is_none());
        assert!(wad_exp(below).is_some());
    }

    #[test]
    fn wad_exp_negative_one() {
        let neg_wad = -I256::try_from(WAD).unwrap();
        let result = wad_exp(neg_wad).unwrap();
        let expected_low = U256::from(367_879_441_000_000_000u128);
        let expected_high = U256::from(367_879_442_000_000_000u128);
        assert!(
            result >= expected_low && result <= expected_high,
            "e^(-1) = {result}"
        );
    }

    #[test]
    fn wad_exp_two() {
        let two_wad = I256::try_from(2_000_000_000_000_000_000i128).unwrap();
        let result = wad_exp(two_wad).unwrap();
        let expected_low = U256::from(7_389_056_098_000_000_000u128);
        let expected_high = U256::from(7_389_056_099_000_000_000u128);
        assert!(
            result >= expected_low && result <= expected_high,
            "e^2 = {result}"
        );
    }

    #[test]
    fn wad_exp_ten() {
        let ten_wad = I256::try_from(10_000_000_000_000_000_000i128).unwrap();
        let result = wad_exp(ten_wad).unwrap();
        let expected_low = U256::from(22_026_465_794_000_000_000_000u128);
        let expected_high = U256::from(22_026_465_795_000_000_000_000u128);
        assert!(
            result >= expected_low && result <= expected_high,
            "e^10 = {result}"
        );
    }

    #[test]
    fn wad_exp_negative_two() {
        let neg_two_wad = I256::try_from(-2_000_000_000_000_000_000i128).unwrap();
        let result = wad_exp(neg_two_wad).unwrap();
        let expected_low = U256::from(135_335_283_000_000_000u128);
        let expected_high = U256::from(135_335_284_000_000_000u128);
        assert!(
            result >= expected_low && result <= expected_high,
            "e^(-2) = {result}"
        );
    }

    #[test]
    fn wad_exp_negative_ten() {
        let neg_ten_wad = I256::try_from(-10_000_000_000_000_000_000i128).unwrap();
        let result = wad_exp(neg_ten_wad).unwrap();
        let expected_low = U256::from(45_399_929_000_000u128);
        let expected_high = U256::from(45_399_930_000_000u128);
        assert!(
            result >= expected_low && result <= expected_high,
            "e^(-10) = {result}"
        );
    }

    #[test]
    fn wad_exp_half() {
        let half_wad = I256::try_from(500_000_000_000_000_000i128).unwrap();
        let result = wad_exp(half_wad).unwrap();
        let expected_low = U256::from(1_648_721_270_000_000_000u128);
        let expected_high = U256::from(1_648_721_271_000_000_000u128);
        assert!(
            result >= expected_low && result <= expected_high,
            "e^0.5 = {result}"
        );
    }

    #[test]
    fn wad_exp_monotonic() {
        let a = I256::try_from(-5_000_000_000_000_000_000i128).unwrap();
        let b = I256::try_from(-1_000_000_000_000_000_000i128).unwrap();
        let c_val = I256::ZERO;
        let d = I256::try_from(1_000_000_000_000_000_000i128).unwrap();
        let e = I256::try_from(5_000_000_000_000_000_000i128).unwrap();
        let ra = wad_exp(a).unwrap();
        let rb = wad_exp(b).unwrap();
        let rc = wad_exp(c_val).unwrap();
        let rd = wad_exp(d).unwrap();
        let re = wad_exp(e).unwrap();
        assert!(ra < rb, "e^(-5) < e^(-1): {ra} vs {rb}");
        assert!(rb < rc, "e^(-1) < e^(0): {rb} vs {rc}");
        assert!(rc < rd, "e^(0) < e^(1): {rc} vs {rd}");
        assert!(rd < re, "e^(1) < e^(5): {rd} vs {re}");
    }

    #[test]
    fn wad_exp_boundary_negative() {
        let x = I256::try_from(-41_000_000_000_000_000_000i128).unwrap();
        let result = wad_exp(x).unwrap();
        assert!(result > U256::ZERO, "e^(-41e18) should be > 0: {result}");
        assert!(
            result < U256::from(1000u64),
            "e^(-41e18) should be tiny: {result}"
        );
    }

    #[test]
    fn wad_exp_boundary_zero_threshold() {
        let x = I256::try_from(-42_139_678_854_452_767_551i128).unwrap();
        let result = wad_exp(x).unwrap();
        assert_eq!(result, U256::ZERO);
    }

    #[test]
    fn wad_exp_product_rule() {
        let a = I256::try_from(2_000_000_000_000_000_000i128).unwrap();
        let b = I256::try_from(3_000_000_000_000_000_000i128).unwrap();
        let a_plus_b = I256::try_from(5_000_000_000_000_000_000i128).unwrap();
        let ea = wad_exp(a).unwrap();
        let eb = wad_exp(b).unwrap();
        let eab = wad_exp(a_plus_b).unwrap();
        let product = ea * eb / WAD;
        let diff = if product > eab {
            product - eab
        } else {
            eab - product
        };
        // Allow a relative error of 1e-5.
        let max_err = eab / U256::from(100_000u64);
        assert!(
            diff <= max_err,
            "product rule: e^2 * e^3 = {product}, e^5 = {eab}, diff = {diff}"
        );
    }

    #[test]
    fn get_dynamic_fee_in_range_returns_zero() {
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        let p_o_up = WAD * U256::from(3000u64);
        // Lower edge of the band: the upper current-price edge lands exactly on p_o.
        let p_o = p_o_up * U256::from(99u64) / U256::from(100u64);
        assert_eq!(get_dynamic_fee(p_o, p_o_up, a, a_minus_1), U256::ZERO);
        // Middle of the band.
        let p_o_mid = WAD * U256::from(2985u64);
        assert_eq!(get_dynamic_fee(p_o_mid, p_o_up, a, a_minus_1), U256::ZERO);
    }

    #[test]
    fn get_dynamic_fee_oracle_below_p_current_down() {
        // p_o = 2 * p_o_up => lower edge = p_o^3 / p_o_up^2 = 24000e18 > p_o,
        // so fee = (24000 - 6000) / 24000 * WAD / 4 = 0.1875e18 exactly.
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        let p_o_up = WAD * U256::from(3000u64);
        let p_o = WAD * U256::from(6000u64);
        assert_eq!(
            get_dynamic_fee(p_o, p_o_up, a, a_minus_1),
            U256::from(187_500_000_000_000_000u64)
        );
    }

    #[test]
    fn get_dynamic_fee_oracle_above_p_current_up() {
        // p_o = p_o_up / 2 => both band edges (~375e18 and ~382.6e18) are below p_o.
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        let p_o_up = WAD * U256::from(3000u64);
        let p_o = WAD * U256::from(1500u64);
        let fee = get_dynamic_fee(p_o, p_o_up, a, a_minus_1);
        assert!(
            fee > U256::ZERO && fee <= WAD / U256::from(4u64),
            "fee = {fee}"
        );
    }

    #[test]
    fn get_y0_zero_amounts() {
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        let p_o = WAD * U256::from(3000u64);
        let p_o_up = WAD * U256::from(3010u64);
        let y0 = get_y0(U256::ZERO, U256::ZERO, p_o, p_o_up, a, a_minus_1).unwrap();
        assert_eq!(y0, U256::ZERO);
    }

    #[test]
    fn get_y0_only_x() {
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        let p_o = WAD * U256::from(3000u64);
        let p_o_up = WAD * U256::from(3010u64);
        let x = WAD * U256::from(1000u64);
        let y0 = get_y0(x, U256::ZERO, p_o, p_o_up, a, a_minus_1).unwrap();
        assert!(y0 > U256::ZERO);
    }

    #[test]
    fn get_y0_only_y() {
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        let p_o = WAD * U256::from(3000u64);
        let p_o_up = WAD * U256::from(3010u64);
        let y = WAD;
        let y0 = get_y0(U256::ZERO, y, p_o, p_o_up, a, a_minus_1).unwrap();
        assert!(y0 > U256::ZERO);
    }

    #[test]
    fn get_y0_both_xy() {
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        let p_o = WAD * U256::from(3000u64);
        let p_o_up = WAD * U256::from(3010u64);
        let x = WAD * U256::from(1000u64);
        let y = WAD;
        let y0 = get_y0(x, y, p_o, p_o_up, a, a_minus_1).unwrap();
        assert!(y0 > U256::ZERO);
    }

    #[test]
    fn get_y0_rejects_zero_p_o() {
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        assert!(get_y0(WAD, WAD, U256::ZERO, WAD, a, a_minus_1).is_none());
    }

    #[test]
    fn get_p_rejects_zero_p_o_up() {
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        assert!(get_p(WAD, WAD, WAD, U256::ZERO, a, a_minus_1).is_none());
    }

    #[test]
    fn get_p_both_zero_returns_midband() {
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        let p_o = WAD * U256::from(3000u64);
        let p_o_up = WAD * U256::from(3010u64);
        let price = get_p(U256::ZERO, U256::ZERO, p_o, p_o_up, a, a_minus_1).unwrap();
        assert!(price > U256::ZERO);
    }

    #[test]
    fn get_p_only_y_returns_p_current_down() {
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        let p_o = WAD * U256::from(3000u64);
        let p_o_up = WAD * U256::from(3010u64);
        let price = get_p(U256::ZERO, WAD, p_o, p_o_up, a, a_minus_1).unwrap();
        assert!(price > U256::ZERO);
    }

    #[test]
    fn get_p_only_x_returns_p_current_up() {
        let a = U256::from(100u64);
        let a_minus_1 = U256::from(99u64);
        let p_o = WAD * U256::from(3000u64);
        let p_o_up = WAD * U256::from(3010u64);
        let price = get_p(
            WAD * U256::from(1000u64),
            U256::ZERO,
            p_o,
            p_o_up,
            a,
            a_minus_1,
        )
        .unwrap();
        assert!(price > U256::ZERO);
    }
}