use crate::AmmMathError;
/// Fee-rate denominator in basis points: a rate of `MAX_FEE_BPS` bps is a 100% fee.
const MAX_FEE_BPS: u16 = 10_000;
/// Fee growth accumulated between `tick_lower` and `tick_upper`, derived from the global
/// accumulator and the per-tick "outside" snapshots (same scheme as Uniswap v3's
/// `getFeeGrowthInside`).
///
/// A tick's outside value counts growth *below* the tick while the current tick is at or above
/// it, and growth *above* it otherwise. Whichever side is not directly stored is recovered as
/// the complement `global - outside`.
///
/// All arithmetic is modulo 2^128 (wrapping); only differences between two readings of the
/// result are meaningful, so an intermediate "negative" value is expected, not an error.
pub fn calculate_fee_growth_inside(
    fee_growth_global: u128,
    fee_growth_outside_lower: u128,
    fee_growth_outside_upper: u128,
    tick_lower: i32,
    tick_upper: i32,
    tick_current: i32,
) -> u128 {
    // Growth strictly below the lower tick.
    let below = if tick_current < tick_lower {
        fee_growth_global.wrapping_sub(fee_growth_outside_lower)
    } else {
        fee_growth_outside_lower
    };

    // Growth at or above the upper tick.
    let above = if tick_current >= tick_upper {
        fee_growth_global.wrapping_sub(fee_growth_outside_upper)
    } else {
        fee_growth_outside_upper
    };

    // global - below - above, expressed as one subtraction (identical in wrapping arithmetic).
    fee_growth_global.wrapping_sub(below.wrapping_add(above))
}
/// Fees owed to a position since its last fee-growth snapshot.
///
/// Returns `floor(delta * liquidity / 2^64)` where
/// `delta = fee_growth_inside_current - fee_growth_inside_last` computed with wrapping
/// subtraction, so an accumulator that has wrapped past `u128::MAX` still yields the correct
/// difference. The division by 2^64 implies the fee-growth values are Q64-scaled fixed point
/// (presumably Q64.64 — confirm against the code that accumulates them).
///
/// If the exact result does not fit in `u128`, the value saturates to `u128::MAX`.
pub fn tokens_owed(
    fee_growth_inside_current: u128,
    fee_growth_inside_last: u128,
    liquidity: u128,
) -> u128 {
    let delta = fee_growth_inside_current.wrapping_sub(fee_growth_inside_last);
    match delta.checked_mul(liquidity) {
        // Fast path: the full product fits in 128 bits.
        Some(product) => product >> 64,
        // Slow path: form the 256-bit product from 64-bit limbs and shift it down.
        None => mul_u128_shr64(delta, liquidity),
    }
}

/// Returns `floor(a * b / 2^64)` computed over the full 256-bit product, saturating to
/// `u128::MAX` when the quotient needs more than 128 bits.
fn mul_u128_shr64(a: u128, b: u128) -> u128 {
    let mask: u128 = u64::MAX as u128;
    let a_lo = a & mask;
    let a_hi = a >> 64;
    let b_lo = b & mask;
    let b_hi = b >> 64;

    // Four 64x64 -> 128-bit partial products; none of them can overflow.
    let lo_lo = a_lo * b_lo;
    let lo_hi = a_lo * b_hi;
    let hi_lo = a_hi * b_lo;
    let hi_hi = a_hi * b_hi;

    // Bits 64..128 of the product, plus any carry into bit 128. Each term is < 2^64,
    // so the sum is < 3 * 2^64 and cannot overflow.
    let mid_sum = (lo_lo >> 64) + (lo_hi & mask) + (hi_lo & mask);
    let carry = mid_sum >> 64;

    // Bits 128..256 of the product, i.e. floor(a * b / 2^128) < 2^128. Every partial sum
    // here is bounded by that final value, so these additions cannot overflow either.
    let upper = hi_hi + (lo_hi >> 64) + (hi_lo >> 64) + carry;

    // The result is (upper << 64) | low 64 bits of mid_sum, which fits in u128 only if upper
    // fits in 64 bits. This must be checked explicitly: `checked_shl` rejects only shift
    // amounts >= 128 and silently discards shifted-out bits, so it cannot detect overflow.
    if upper > mask {
        return u128::MAX;
    }
    (upper << 64) | (mid_sum & mask)
}
/// Fee charged on an input amount, rounded down.
///
/// Computes `floor(input_amount * fee_rate_bps / MAX_FEE_BPS)` in 128-bit arithmetic.
/// Because `fee_rate_bps <= MAX_FEE_BPS`, the fee never exceeds `input_amount`.
///
/// # Errors
///
/// Returns `AmmMathError::InvalidFeeRate` when `fee_rate_bps > MAX_FEE_BPS`.
pub fn fee_amount_from_input(input_amount: u64, fee_rate_bps: u16) -> Result<u64, AmmMathError> {
    if fee_rate_bps <= MAX_FEE_BPS {
        let scaled = u128::from(input_amount) * u128::from(fee_rate_bps);
        // Lossless narrowing: the quotient is at most `input_amount`, which is a u64.
        Ok((scaled / u128::from(MAX_FEE_BPS)) as u64)
    } else {
        Err(AmmMathError::InvalidFeeRate(fee_rate_bps))
    }
}
/// Fee that must be added on top of a desired net amount, rounded up.
///
/// Solves for the fee such that `net + fee` charged at `fee_rate_bps` leaves `output_amount`:
/// `ceil(output_amount * fee_rate_bps / (MAX_FEE_BPS - fee_rate_bps))`.
///
/// # Errors
///
/// - `AmmMathError::InvalidFeeRate` when `fee_rate_bps >= MAX_FEE_BPS`; the denominator
///   would be zero or negative.
/// - `AmmMathError::Overflow` when the fee does not fit in `u64`.
pub fn fee_amount_from_output(output_amount: u64, fee_rate_bps: u16) -> Result<u64, AmmMathError> {
    // Share of the gross amount left after the fee; must be strictly positive.
    let Some(remaining_bps) = MAX_FEE_BPS.checked_sub(fee_rate_bps).filter(|&r| r > 0) else {
        return Err(AmmMathError::InvalidFeeRate(fee_rate_bps));
    };
    let scaled = u128::from(output_amount) * u128::from(fee_rate_bps);
    let fee = scaled.div_ceil(u128::from(remaining_bps));
    u64::try_from(fee).map_err(|_| AmmMathError::Overflow)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Independent closed form for `calculate_fee_growth_inside`, one case per position of the
    /// current tick relative to `[tick_lower, tick_upper)`. Assumes `tick_lower < tick_upper`.
    fn expected_fee_growth_inside(
        global: u128,
        outside_lower: u128,
        outside_upper: u128,
        tick_lower: i32,
        tick_upper: i32,
        tick_current: i32,
    ) -> u128 {
        if tick_current < tick_lower {
            outside_lower.wrapping_sub(outside_upper)
        } else if tick_current < tick_upper {
            global.wrapping_sub(outside_lower).wrapping_sub(outside_upper)
        } else {
            outside_upper.wrapping_sub(outside_lower)
        }
    }

    /// Exact reference for `floor(delta * liquidity / 2^64)` when `liquidity < 2^64`.
    /// Splitting `delta` into 64-bit halves keeps every intermediate product below 2^128.
    fn reference_tokens_owed_small_liquidity(delta: u128, liquidity: u128) -> u128 {
        let mask = u64::MAX as u128;
        assert!(liquidity <= mask);
        (delta >> 64) * liquidity + (((delta & mask) * liquidity) >> 64)
    }

    #[test]
    fn test_fee_from_input_30bps() {
        assert_eq!(fee_amount_from_input(1_000_000, 30).unwrap(), 3000);
    }

    #[test]
    fn test_fee_from_input_zero() {
        assert_eq!(fee_amount_from_input(1_000_000, 0).unwrap(), 0);
    }

    #[test]
    fn test_fee_from_input_100_percent() {
        assert_eq!(fee_amount_from_input(1_000_000, 10000).unwrap(), 1_000_000);
    }

    #[test]
    fn test_fee_from_input_invalid() {
        assert!(fee_amount_from_input(1_000_000, 10001).is_err());
    }

    #[test]
    fn test_fee_from_output_30bps() {
        assert_eq!(fee_amount_from_output(1_000_000, 30).unwrap(), 3010);
    }

    #[test]
    fn test_fee_from_output_zero() {
        assert_eq!(fee_amount_from_output(1_000_000, 0).unwrap(), 0);
    }

    #[test]
    fn test_fee_from_output_100_percent_invalid() {
        assert!(fee_amount_from_output(1_000_000, 10000).is_err());
    }

    #[test]
    fn test_fee_from_output_rounding() {
        assert_eq!(fee_amount_from_output(100, 30).unwrap(), 1);
    }

    #[test]
    fn test_fee_from_input_large_amount() {
        let amount = u64::MAX;
        let fee = fee_amount_from_input(amount, 30).unwrap();
        assert!(fee > 0);
        let expected = (u64::MAX as u128) * 30 / 10000;
        assert_eq!(fee, expected as u64);
    }

    #[test]
    fn test_fee_roundtrip_consistency() {
        let input = 1_000_000u64;
        let bps = 30u16;
        let fee = fee_amount_from_input(input, bps).unwrap();
        let net = input - fee;
        let reverse_fee = fee_amount_from_output(net, bps).unwrap();
        assert!(reverse_fee >= fee);
    }

    #[test]
    fn test_fee_growth_inside_current_in_range() {
        let inside = calculate_fee_growth_inside(100, 10, 20, -10, 10, 0);
        assert_eq!(inside, 70);
    }

    #[test]
    fn test_fee_growth_inside_current_below() {
        let inside = calculate_fee_growth_inside(100, 10, 20, 5, 10, 0);
        assert_eq!(inside, u128::MAX - 9);
    }

    #[test]
    fn test_fee_growth_inside_current_above() {
        let inside = calculate_fee_growth_inside(100, 10, 20, -10, 5, 10);
        assert_eq!(inside, 10);
    }

    #[test]
    fn test_fee_growth_wrapping() {
        let global = 50u128;
        let lower_outside = 200u128;
        let inside = calculate_fee_growth_inside(global, lower_outside, 0, -10, 10, 0);
        assert_eq!(inside, global.wrapping_sub(lower_outside));
    }

    #[test]
    fn test_fee_growth_all_zero() {
        let inside = calculate_fee_growth_inside(0, 0, 0, -10, 10, 0);
        assert_eq!(inside, 0);
    }

    #[test]
    fn test_fee_growth_current_at_lower() {
        let inside = calculate_fee_growth_inside(100, 10, 20, 0, 10, 0);
        assert_eq!(inside, 70);
    }

    #[test]
    fn test_fee_growth_current_at_upper() {
        let inside = calculate_fee_growth_inside(100, 10, 20, -10, 5, 5);
        assert_eq!(inside, 10);
    }

    #[test]
    fn test_tokens_owed_basic() {
        let owed = tokens_owed(1u128 << 64, 0, 1_000_000);
        assert_eq!(owed, 1_000_000);
    }

    #[test]
    fn test_tokens_owed_wrapping() {
        let current = 10u128;
        let last = u128::MAX - 4;
        let owed = tokens_owed(current, last, 1u128 << 64);
        assert_eq!(owed, 15);
    }

    #[test]
    fn test_tokens_owed_zero_liquidity() {
        let owed = tokens_owed(100, 0, 0);
        assert_eq!(owed, 0);
    }

    #[test]
    fn test_tokens_owed_no_change() {
        let owed = tokens_owed(500, 500, 1_000_000);
        assert_eq!(owed, 0);
    }

    #[test]
    fn fuzz_fee_growth_inside_wrapping_identity() {
        use rand::Rng;
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let fee_growth_global: u128 = rng.random();
            let fee_growth_outside_lower: u128 = rng.random();
            let fee_growth_outside_upper: u128 = rng.random();
            let tick_lower: i32 = rng.random_range(-443636..0);
            let tick_upper: i32 = rng.random_range(1..=443636);
            let tick_current: i32 = rng.random_range(tick_lower..=tick_upper);
            let inside = calculate_fee_growth_inside(
                fee_growth_global,
                fee_growth_outside_lower,
                fee_growth_outside_upper,
                tick_lower,
                tick_upper,
                tick_current,
            );
            let expected = expected_fee_growth_inside(
                fee_growth_global,
                fee_growth_outside_lower,
                fee_growth_outside_upper,
                tick_lower,
                tick_upper,
                tick_current,
            );
            assert_eq!(
                inside, expected,
                "tick_lower={tick_lower}, tick_upper={tick_upper}, tick_current={tick_current}"
            );
        }
    }

    #[test]
    fn fuzz_fee_growth_inside_all_positions() {
        use rand::Rng;
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let fee_growth_global: u128 = rng.random();
            let fee_growth_outside_lower: u128 = rng.random();
            let fee_growth_outside_upper: u128 = rng.random();
            let tick_lower: i32 = rng.random_range(-443636..0);
            let tick_upper: i32 = rng.random_range(1..=443636);
            for &tick_current in &[
                tick_lower - 1,
                tick_lower,
                (tick_lower + tick_upper) / 2,
                tick_upper - 1,
                tick_upper,
            ] {
                let inside = calculate_fee_growth_inside(
                    fee_growth_global,
                    fee_growth_outside_lower,
                    fee_growth_outside_upper,
                    tick_lower,
                    tick_upper,
                    tick_current,
                );
                let expected = expected_fee_growth_inside(
                    fee_growth_global,
                    fee_growth_outside_lower,
                    fee_growth_outside_upper,
                    tick_lower,
                    tick_upper,
                    tick_current,
                );
                assert_eq!(
                    inside, expected,
                    "tick_lower={tick_lower}, tick_upper={tick_upper}, \
                     tick_current={tick_current}"
                );
            }
        }
    }

    #[test]
    fn fuzz_tokens_owed_wrapping_no_panic() {
        use rand::Rng;
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let current: u128 = rng.random();
            let last: u128 = rng.random();
            let liquidity: u128 = rng.random();
            let _owed = tokens_owed(current, last, liquidity);

            // With liquidity below 2^64 the exact quotient always fits in u128, so the result
            // must match the reference exactly (this also exercises the 256-bit slow path).
            let small_liquidity = u128::from(rng.random::<u64>());
            let delta = current.wrapping_sub(last);
            assert_eq!(
                tokens_owed(current, last, small_liquidity),
                reference_tokens_owed_small_liquidity(delta, small_liquidity),
                "current={current}, last={last}, liquidity={small_liquidity}"
            );
        }
    }

    #[test]
    fn fuzz_fee_amount_from_input_valid_range() {
        use rand::Rng;
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let amount: u64 = rng.random();
            let bps: u16 = rng.random_range(0..=MAX_FEE_BPS);
            let fee = fee_amount_from_input(amount, bps).unwrap();
            assert!(fee <= amount, "fee exceeds input: amount={amount}, bps={bps}, fee={fee}");
            // Floor property: fee * D <= amount * bps < (fee + 1) * D.
            let scaled = u128::from(amount) * u128::from(bps);
            let denominator = u128::from(MAX_FEE_BPS);
            let fee_wide = u128::from(fee);
            assert!(
                fee_wide * denominator <= scaled && scaled < (fee_wide + 1) * denominator,
                "fee is not the floor: amount={amount}, bps={bps}, fee={fee}"
            );
        }
    }

    #[test]
    fn fuzz_fee_amount_from_output_valid_range() {
        use rand::Rng;
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let amount: u64 = rng.random();
            let bps: u16 = rng.random_range(0..MAX_FEE_BPS);
            let numerator = u128::from(amount) * u128::from(bps);
            let denominator = u128::from(MAX_FEE_BPS - bps);
            match fee_amount_from_output(amount, bps) {
                Ok(fee) => {
                    // Ceiling property: fee is the smallest value with fee * D >= numerator.
                    let fee = u128::from(fee);
                    assert!(
                        fee * denominator >= numerator,
                        "fee rounds below exact: amount={amount}, bps={bps}, fee={fee}"
                    );
                    assert!(
                        fee == 0 || (fee - 1) * denominator < numerator,
                        "fee is not the ceiling: amount={amount}, bps={bps}, fee={fee}"
                    );
                }
                Err(_) => assert!(
                    numerator.div_ceil(denominator) > u128::from(u64::MAX),
                    "unexpected error for representable fee: amount={amount}, bps={bps}"
                ),
            }
            // bps == 0 is drawn only ~1 in 10_000 times above, so check it every iteration.
            assert_eq!(
                fee_amount_from_output(amount, 0).unwrap(),
                0,
                "zero bps should yield zero fee"
            );
        }
    }

    #[test]
    fn fuzz_fee_input_output_consistency() {
        use rand::Rng;
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let input: u64 = rng.random_range(1..=1_000_000_000u64);
            let bps: u16 = rng.random_range(1..=9999);
            let fee_in = fee_amount_from_input(input, bps).unwrap();
            let net = input - fee_in;
            if net == 0 {
                continue;
            }
            let fee_out = fee_amount_from_output(net, bps).unwrap();
            assert!(
                fee_out >= fee_in,
                "output fee < input fee: input={input}, bps={bps}, fee_in={fee_in}, net={net}, \
                 fee_out={fee_out}"
            );
        }
    }
}