use alloy::primitives::{FixedBytes, Uint, U256};
/// 512-bit unsigned integer used to hold exact products of two `U256` values.
type U512 = Uint<512, 8>;
/// Number of fractional bits in the 128.128 fixed-point representation.
pub const SCALE_OFFSET: u32 = 128;
/// `1.0` in 128.128 fixed point, i.e. `2^128` (ruint limbs are little-endian `u64`s).
pub const SCALE: U256 = U256::from_limbs([0, 0, 1, 0]);
/// Fixed-point precision (1e18) used as the denominator for fee values.
pub const PRECISION: u128 = 1_000_000_000_000_000_000;
/// `PRECISION^2` (1e36). Fits in `u128`, whose maximum is about 3.4e38.
pub const SQUARED_PRECISION: u128 = PRECISION * PRECISION;
/// Cap applied to the total fee: 1e17 / `PRECISION`, i.e. 10%.
pub const MAX_FEE: u128 = 100_000_000_000_000_000;
/// Denominator for quantities expressed in basis points.
pub const BASIS_POINT_MAX: u128 = 10_000;
/// Bin id whose exponent is zero (price `1.0`); `get_exponent` subtracts this from an id.
pub const REAL_ID_SHIFT: i32 = 1 << 23;
/// Returns the per-bin price multiplier `1 + bin_step / BASIS_POINT_MAX` in 128.128 fixed point.
///
/// `bin_step` is in basis points, so e.g. `25` yields `1.0025`. The division rounds down.
pub fn get_base(bin_step: u16) -> U256 {
    let step_in_fixed_point = U256::from(bin_step) << SCALE_OFFSET;
    let increment = step_in_fixed_point / U256::from(BASIS_POINT_MAX);
    SCALE + increment
}
/// Signed exponent of bin `id` relative to the reference bin `REAL_ID_SHIFT` (exponent 0).
///
/// `id` is reinterpreted with `as i32` before the shift is subtracted.
/// NOTE(review): ids at or above `2^31` wrap or overflow here; confirm callers only pass small
/// (presumably 24-bit) ids.
pub fn get_exponent(id: u32) -> i32 {
    let signed_id = id as i32;
    signed_id - REAL_ID_SHIFT
}
/// Price of bin `id` for a pair with the given `bin_step`, as a 128.128 fixed-point number.
///
/// Computed as `get_base(bin_step)` raised to `get_exponent(id)`.
pub fn get_price_from_id(id: u32, bin_step: u16) -> U256 {
    pow_128x128(get_base(bin_step), get_exponent(id))
}
/// Raises the 128.128 fixed-point number `x` to the signed integer power `y`.
///
/// Uses binary (square-and-multiply) exponentiation over the low 20 bits of `|y|`. A base
/// above `SCALE` (1.0) is first replaced by its reciprocal, and the final result is inverted
/// back. This keeps every running value at or below 1.0.
///
/// Returns `SCALE` when `y == 0`. Returns `U256::ZERO` to signal failure in two cases:
/// - `|y| >= 2^20`, which the 20-step loop cannot represent;
/// - the intermediate result underflows to zero.
pub fn pow_128x128(x: U256, y: i32) -> U256 {
    // Number of exponent bits consumed by the square-and-multiply loop.
    const EXPONENT_BITS: u32 = 20;

    if y == 0 {
        return SCALE;
    }
    let abs_y = y.unsigned_abs();
    // Bits at or above EXPONENT_BITS would be skipped by the loop, producing a silently wrong
    // power. Report failure instead, the same way an underflowed result is reported below.
    if abs_y >= 1 << EXPONENT_BITS {
        return U256::ZERO;
    }

    let mut invert = y < 0;
    let mut squared = x;
    if squared > SCALE {
        // U256::MAX / v is (approximately) SCALE^2 / v, the 128.128 reciprocal of v.
        squared = U256::MAX / squared;
        invert = !invert;
    }

    let mut result = SCALE;
    for bit in 0..EXPONENT_BITS {
        if abs_y & (1 << bit) != 0 {
            result = mul_128x128(result, squared);
        }
        squared = mul_128x128(squared, squared);
    }

    if result.is_zero() {
        return U256::ZERO;
    }
    if invert {
        U256::MAX / result
    } else {
        result
    }
}
fn mul_128x128(a: U256, b: U256) -> U256 {
let product: U512 = a.widening_mul(b);
let shifted = product >> SCALE_OFFSET;
let limbs = shifted.as_limbs();
U256::from_limbs([limbs[0], limbs[1], limbs[2], limbs[3]])
}
/// Static fee component: `base_factor * bin_step * 1e10`, in the same units as `MAX_FEE`.
///
/// Cannot overflow: the largest possible value (`65_535^2 * 1e10`) is far below `u128::MAX`.
pub fn get_base_fee(base_factor: u16, bin_step: u16) -> u128 {
    const FEE_MULTIPLIER: u128 = 10_000_000_000;
    let factor_times_step = u128::from(base_factor) * u128::from(bin_step);
    factor_times_step * FEE_MULTIPLIER
}
pub fn get_variable_fee(
volatility_accumulator: u32,
bin_step: u16,
variable_fee_control: u32,
) -> u128 {
if variable_fee_control == 0 {
return 0;
}
let prod = volatility_accumulator as u128 * bin_step as u128;
(prod * prod * variable_fee_control as u128 + 99) / 100
}
/// Total fee: base fee plus variable fee, capped at `MAX_FEE`.
pub fn get_total_fee(
    base_factor: u16,
    bin_step: u16,
    volatility_accumulator: u32,
    variable_fee_control: u32,
) -> u128 {
    let base_fee = get_base_fee(base_factor, bin_step);
    let variable_fee = get_variable_fee(volatility_accumulator, bin_step, variable_fee_control);
    let uncapped = base_fee + variable_fee;
    if uncapped > MAX_FEE {
        MAX_FEE
    } else {
        uncapped
    }
}
/// Fee contained in an amount that already includes fees, rounded up.
///
/// Computes `ceil(amount_with_fees * total_fee / PRECISION)` using 256-bit arithmetic.
///
/// # Panics
/// In debug builds, if `total_fee > MAX_FEE`. Also panics if the quotient does not fit in
/// `u128`, via `to::<u128>()`.
pub fn get_fee_amount_from(amount_with_fees: u128, total_fee: u128) -> u128 {
    debug_assert!(total_fee <= MAX_FEE, "Fee too large");
    let divisor = U256::from(PRECISION);
    let round_up_bias = U256::from(PRECISION - 1);
    let gross = U256::from(amount_with_fees) * U256::from(total_fee);
    ((gross + round_up_bias) / divisor).to::<u128>()
}
/// Fee to charge on top of `amount`, rounded up.
///
/// Computes `ceil(amount * total_fee / (PRECISION - total_fee))` using 256-bit arithmetic.
///
/// # Panics
/// In debug builds, if `total_fee > MAX_FEE`. Also panics if the quotient does not fit in
/// `u128`.
/// NOTE(review): in release builds nothing guards `PRECISION - total_fee` when
/// `total_fee >= PRECISION`.
pub fn get_fee_amount(amount: u128, total_fee: u128) -> u128 {
    debug_assert!(total_fee <= MAX_FEE, "Fee too large");
    let net_share = PRECISION - total_fee;
    let product = U256::from(amount) * U256::from(total_fee);
    let rounded = (product + U256::from(net_share - 1)) / U256::from(net_share);
    rounded.to::<u128>()
}
/// Splits a packed 32-byte word into its two `u128` halves, returned as `(x, y)`.
///
/// The word is read big-endian: the first (high) 16 bytes hold `y` and the last (low)
/// 16 bytes hold `x`.
pub fn decode_amounts(packed: FixedBytes<32>) -> (u128, u128) {
    let (high, low) = packed.as_slice().split_at(16);
    let mut y_bytes = [0u8; 16];
    let mut x_bytes = [0u8; 16];
    y_bytes.copy_from_slice(high);
    x_bytes.copy_from_slice(low);
    (u128::from_be_bytes(x_bytes), u128::from_be_bytes(y_bytes))
}
/// Returns one half of a packed amount pair: `x` (the low half) when `decode_x` is true,
/// otherwise `y` (the high half).
pub fn decode_amount(packed: FixedBytes<32>, decode_x: bool) -> u128 {
    match (decode_x, decode_amounts(packed)) {
        (true, (x, _)) => x,
        (false, (_, y)) => y,
    }
}
/// Computes `floor(x * y / 2^offset)` using an exact 512-bit intermediate product.
///
/// Only the low 256 bits of the shifted product are returned. Any higher bits are dropped
/// without error.
pub fn mul_shift_round_down(x: U256, y: U256, offset: u32) -> U256 {
    let wide: U512 = x.widening_mul(y);
    let [w0, w1, w2, w3, ..] = *(wide >> offset).as_limbs();
    U256::from_limbs([w0, w1, w2, w3])
}
/// Computes `ceil(x * y / 2^offset)`.
///
/// Starts from `mul_shift_round_down`. It adds one when any of the `offset` low bits of the
/// exact 512-bit product are set. For `offset >= 512` no low-bit mask is formed, so the
/// result is never rounded up.
pub fn mul_shift_round_up(x: U256, y: U256, offset: u32) -> U256 {
    let floor = mul_shift_round_down(x, y, offset);
    let exact: U512 = x.widening_mul(y);
    let has_fraction = offset < 512 && {
        let low_bits = (U512::from(1u64) << offset) - U512::from(1u64);
        !(exact & low_bits).is_zero()
    };
    if has_fraction {
        floor + U256::from(1u64)
    } else {
        floor
    }
}
/// Computes `floor((x << offset) / denominator)` in 512-bit arithmetic.
///
/// Returns `U256::ZERO` when `denominator` is zero. Only the low 256 bits of the quotient
/// are returned.
pub fn shift_div_round_down(x: U256, offset: u32, denominator: U256) -> U256 {
    if denominator.is_zero() {
        return U256::ZERO;
    }
    let numerator = U512::from(x) << offset;
    let quotient = numerator / U512::from(denominator);
    let [q0, q1, q2, q3, ..] = *quotient.as_limbs();
    U256::from_limbs([q0, q1, q2, q3])
}
/// Computes `ceil((x << offset) / denominator)` in 512-bit arithmetic.
///
/// Returns `U256::ZERO` when `denominator` is zero. Otherwise adds one to
/// `shift_div_round_down` whenever the division leaves a nonzero remainder.
pub fn shift_div_round_up(x: U256, offset: u32, denominator: U256) -> U256 {
    if denominator.is_zero() {
        return U256::ZERO;
    }
    let quotient = shift_div_round_down(x, offset, denominator);
    let remainder = (U512::from(x) << offset) % U512::from(denominator);
    if remainder.is_zero() {
        quotient
    } else {
        quotient + U256::from(1u64)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scale_constant() {
        assert_eq!(SCALE, U256::from(1u64) << 128);
    }

    #[test]
    fn test_get_base() {
        let base = get_base(1);
        assert!(base > SCALE);
        let diff = base - SCALE;
        let expected_diff = SCALE / U256::from(10000u64);
        assert_eq!(diff, expected_diff);
    }

    #[test]
    fn test_get_base_25() {
        let base = get_base(25);
        let diff = base - SCALE;
        let expected_diff = U256::from(25u64) * SCALE / U256::from(10000u64);
        assert_eq!(diff, expected_diff);
    }

    #[test]
    fn test_get_exponent() {
        let center = REAL_ID_SHIFT as u32;
        assert_eq!(get_exponent(center), 0);
        assert_eq!(get_exponent(center + 5), 5);
        assert_eq!(get_exponent(center - 5), -5);
    }

    #[test]
    fn test_pow_identity() {
        assert_eq!(pow_128x128(get_base(25), 0), SCALE);
    }

    #[test]
    fn test_pow_one() {
        let base = get_base(25);
        let result = pow_128x128(base, 1);
        assert_eq!(result, base);
    }

    #[test]
    fn test_pow_negative_one() {
        let base = get_base(25);
        let result = pow_128x128(base, -1);
        let expected = U256::MAX / base;
        let diff = if result > expected {
            result - expected
        } else {
            expected - result
        };
        assert!(diff <= U256::from(1u64));
    }

    #[test]
    fn test_price_at_center_bin() {
        let price = get_price_from_id(REAL_ID_SHIFT as u32, 25);
        assert_eq!(price, SCALE);
    }

    #[test]
    fn test_price_monotonic() {
        let center = REAL_ID_SHIFT as u32;
        let p1 = get_price_from_id(center, 25);
        let p2 = get_price_from_id(center + 1, 25);
        let p3 = get_price_from_id(center + 10, 25);
        assert!(p2 > p1);
        assert!(p3 > p2);
        let p0 = get_price_from_id(center - 1, 25);
        assert!(p0 < p1);
    }

    #[test]
    fn test_base_fee() {
        let fee = get_base_fee(15, 25);
        assert_eq!(fee, 3_750_000_000_000u128);
    }

    #[test]
    fn test_variable_fee_zero_control() {
        assert_eq!(get_variable_fee(1000, 25, 0), 0);
    }

    #[test]
    fn test_variable_fee_value() {
        // prod = 1000 * 25 = 25_000; ceil(25_000^2 * 100 / 100) = 625_000_000.
        assert_eq!(get_variable_fee(1000, 25, 100), 625_000_000);
        // ceil(1 / 100) rounds up to 1.
        assert_eq!(get_variable_fee(1, 1, 1), 1);
    }

    #[test]
    fn test_total_fee_capped() {
        let fee = get_total_fee(u16::MAX, u16::MAX, u32::MAX, u32::MAX);
        assert!(fee <= MAX_FEE);
    }

    #[test]
    fn test_fee_amount_from() {
        let total_fee = 3_000_000_000_000_000u128;
        let amount_with_fees = 1_000_000_000_000_000_000_000u128;
        let fee = get_fee_amount_from(amount_with_fees, total_fee);
        assert_eq!(fee, 3_000_000_000_000_000_000u128);
    }

    #[test]
    fn test_fee_amount_rounds_up() {
        // 1_000_000 * 0.003 / 0.997 = 3009.03..., rounded up to 3010.
        assert_eq!(get_fee_amount(1_000_000, 3_000_000_000_000_000), 3010);
        assert_eq!(get_fee_amount(1_000_000, 0), 0);
    }

    #[test]
    fn test_decode_amounts() {
        let mut bytes = [0u8; 32];
        bytes[0..16].copy_from_slice(&200u128.to_be_bytes());
        bytes[16..32].copy_from_slice(&100u128.to_be_bytes());
        let packed = FixedBytes::<32>::from(bytes);
        let (x, y) = decode_amounts(packed);
        assert_eq!(x, 100);
        assert_eq!(y, 200);
    }

    #[test]
    fn test_decode_amount_selects_half() {
        let mut bytes = [0u8; 32];
        bytes[0..16].copy_from_slice(&7u128.to_be_bytes());
        bytes[16..32].copy_from_slice(&9u128.to_be_bytes());
        assert_eq!(decode_amount(FixedBytes::<32>::from(bytes), true), 9);
        assert_eq!(decode_amount(FixedBytes::<32>::from(bytes), false), 7);
    }

    #[test]
    fn test_mul_shift_round_down() {
        let result = mul_shift_round_down(SCALE, SCALE, SCALE_OFFSET);
        assert_eq!(result, SCALE);
    }

    #[test]
    fn test_mul_shift_rounding() {
        let one = U256::from(1u64);
        let three = U256::from(3u64);
        assert_eq!(mul_shift_round_down(three, one, 1), one);
        assert_eq!(mul_shift_round_up(three, one, 1), U256::from(2u64));
        assert_eq!(mul_shift_round_up(U256::from(4u64), one, 1), U256::from(2u64));
        assert_eq!(mul_shift_round_up(three, one, 0), three);
    }

    #[test]
    fn test_shift_div_round_down() {
        let result = shift_div_round_down(U256::from(100u64), SCALE_OFFSET, SCALE);
        assert_eq!(result, U256::from(100u64));
    }

    #[test]
    fn test_shift_div_rounding() {
        let one = U256::from(1u64);
        let three = U256::from(3u64);
        assert_eq!(shift_div_round_down(one, 0, three), U256::ZERO);
        assert_eq!(shift_div_round_up(one, 0, three), one);
        assert_eq!(shift_div_round_up(three, 0, three), one);
        assert_eq!(shift_div_round_down(one, 0, U256::ZERO), U256::ZERO);
        assert_eq!(shift_div_round_up(one, 0, U256::ZERO), U256::ZERO);
    }
}