use alloy::primitives::U256;
use num_bigint::BigUint;
use num_traits::Zero;
use tycho_common::{simulation::errors::SimulationError, Bytes};
use crate::evm::protocol::{
safe_math::{safe_add_u256, safe_div_u256, safe_mul_u256, safe_sub_u256},
u256_num::u256_to_biguint,
};
pub fn get_amount_out(
amount_in: U256,
zero2one: bool,
reserve0: U256,
reserve1: U256,
fee_bps: u32,
decimals0: u8,
decimals1: u8,
) -> Result<U256, SimulationError> {
if amount_in.is_zero() {
return Err(SimulationError::InvalidInput("Amount in cannot be zero".to_string(), None));
}
if reserve0.is_zero() || reserve1.is_zero() {
return Err(SimulationError::RecoverableError("No liquidity".to_string()));
}
let xy = _k(reserve0, reserve1, decimals0, decimals1)?;
let e18 = U256::from(10u128.pow(18));
let decimals0_scale = U256::from(10u128.pow(decimals0 as u32));
let decimals1_scale = U256::from(10u128.pow(decimals1 as u32));
let reserve0_normalized = safe_div_u256(safe_mul_u256(reserve0, e18)?, decimals0_scale)?;
let reserve1_normalized = safe_div_u256(safe_mul_u256(reserve1, e18)?, decimals1_scale)?;
let (reserve_in, reserve_out, decimals_in, decimals_out) = if zero2one {
(reserve0_normalized, reserve1_normalized, decimals0, decimals1)
} else {
(reserve1_normalized, reserve0_normalized, decimals1, decimals0)
};
let fee_amount =
safe_div_u256(safe_mul_u256(amount_in, U256::from(fee_bps))?, U256::from(10000))?;
let amount_in_with_fee = safe_sub_u256(amount_in, fee_amount)?;
let decimals_in_scale = U256::from(10u128.pow(decimals_in as u32));
let amount_in_normalized =
safe_div_u256(safe_mul_u256(amount_in_with_fee, e18)?, decimals_in_scale)?;
let x0 = safe_add_u256(amount_in_normalized, reserve_in)?;
let y_new = _get_y(x0, xy, reserve_out)?;
let y_diff = safe_sub_u256(reserve_out, y_new)?;
let decimals_out_scale = U256::from(10u128.pow(decimals_out as u32));
let amount_out = safe_div_u256(safe_mul_u256(y_diff, decimals_out_scale)?, e18)?;
Ok(amount_out)
}
fn _f(x0: U256, y: U256) -> Result<U256, SimulationError> {
let e18 = U256::from(10u128.pow(18));
let a = safe_div_u256(safe_mul_u256(x0, y)?, e18)?;
let x0_squared = safe_div_u256(safe_mul_u256(x0, x0)?, e18)?;
let y_squared = safe_div_u256(safe_mul_u256(y, y)?, e18)?;
let b = safe_add_u256(x0_squared, y_squared)?;
safe_div_u256(safe_mul_u256(a, b)?, e18)
}
fn _d(x0: U256, y: U256) -> Result<U256, SimulationError> {
let e18 = U256::from(10u128.pow(18));
let y_squared = safe_div_u256(safe_mul_u256(y, y)?, e18)?;
let term1 = safe_div_u256(safe_mul_u256(safe_mul_u256(U256::from(3), x0)?, y_squared)?, e18)?;
let x0_squared = safe_div_u256(safe_mul_u256(x0, x0)?, e18)?;
let term2 = safe_div_u256(safe_mul_u256(x0_squared, x0)?, e18)?;
safe_add_u256(term1, term2)
}
fn _k(x: U256, y: U256, decimals0: u8, decimals1: u8) -> Result<U256, SimulationError> {
let e18 = U256::from(10u128.pow(18));
let decimals0_scale = U256::from(10u128.pow(decimals0 as u32));
let decimals1_scale = U256::from(10u128.pow(decimals1 as u32));
let x = safe_div_u256(safe_mul_u256(x, e18)?, decimals0_scale)?;
let y = safe_div_u256(safe_mul_u256(y, e18)?, decimals1_scale)?;
let a = safe_div_u256(safe_mul_u256(x, y)?, e18)?;
let b = safe_add_u256(
safe_div_u256(safe_mul_u256(x, x)?, e18)?,
safe_div_u256(safe_mul_u256(y, y)?, e18)?,
)?;
safe_div_u256(safe_mul_u256(a, b)?, e18)
}
fn _get_y(x0: U256, xy: U256, mut y: U256) -> Result<U256, SimulationError> {
let e18 = U256::from(10u128.pow(18));
for _ in 0..255 {
let k = _f(x0, y)?;
if k < xy {
let d = _d(x0, y)?;
if d.is_zero() {
return Err(SimulationError::FatalError("Division by zero in _get_y".to_string()));
}
let diff = safe_sub_u256(xy, k)?;
let mut dy = safe_div_u256(safe_mul_u256(diff, e18)?, d)?;
if dy.is_zero() {
if k == xy {
return Ok(y);
}
let y_plus_1 = safe_add_u256(y, U256::from(1))?;
if _f(x0, y_plus_1)? > xy {
return Ok(y_plus_1);
}
dy = U256::from(1);
}
y = safe_add_u256(y, dy)?;
} else {
let d = _d(x0, y)?;
if d.is_zero() {
return Err(SimulationError::FatalError("Division by zero in _get_y".to_string()));
}
let diff = safe_sub_u256(k, xy)?;
let mut dy = safe_div_u256(safe_mul_u256(diff, e18)?, d)?;
if dy.is_zero() {
if k == xy {
return Ok(y);
}
let y_minus_1 = safe_sub_u256(y, U256::from(1))?;
if _f(x0, y_minus_1)? < xy {
return Ok(y);
}
dy = U256::from(1);
}
y = safe_sub_u256(y, dy)?;
}
}
Err(SimulationError::FatalError(
"Failed to converge in _get_y after 255 iterations".to_string(),
))
}
pub fn get_limits(
sell_token: Bytes,
buy_token: Bytes,
reserve0: U256,
reserve1: U256,
decimals0: u8,
decimals1: u8,
) -> Result<(BigUint, BigUint), SimulationError> {
if reserve0.is_zero() || reserve1.is_zero() {
return Ok((BigUint::zero(), BigUint::zero()));
}
let zero_for_one = sell_token < buy_token;
let (reserve_in, reserve_out, decimals_in, decimals_out) = if zero_for_one {
(reserve0, reserve1, decimals0, decimals1)
} else {
(reserve1, reserve0, decimals1, decimals0)
};
let xy = _k(reserve0, reserve1, decimals0, decimals1)?;
let e18 = U256::from(10u128.pow(18));
let decimals_in_scale = U256::from(10u128.pow(decimals_in as u32));
let decimals_out_scale = U256::from(10u128.pow(decimals_out as u32));
let reserve_in_normalized = safe_div_u256(safe_mul_u256(reserve_in, e18)?, decimals_in_scale)?;
let reserve_out_normalized =
safe_div_u256(safe_mul_u256(reserve_out, e18)?, decimals_out_scale)?;
let amount_in_estimate =
safe_div_u256(safe_mul_u256(reserve_in, U256::from(300))?, U256::from(100))?;
let amount_in_normalized =
safe_div_u256(safe_mul_u256(amount_in_estimate, e18)?, decimals_in_scale)?;
let x0 = safe_add_u256(reserve_in_normalized, amount_in_normalized)?;
let y_new = _get_y(x0, xy, reserve_out_normalized)?;
let amount_out_normalized = safe_sub_u256(reserve_out_normalized, y_new)?;
let amount_out = safe_div_u256(safe_mul_u256(amount_out_normalized, decimals_out_scale)?, e18)?;
Ok((u256_to_biguint(amount_in_estimate), u256_to_biguint(amount_out)))
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use num_bigint::BigUint;
use tycho_common::simulation::errors::SimulationError;
use super::*;
#[test]
fn test_get_amount_out() {
assert_eq!(
get_amount_out(
U256::from_str("2000000000000000000").unwrap(),
true,
U256::from_str("2642455102346776307825").unwrap(),
U256::from_str("3320301880379841502303").unwrap(),
5,
18,
18,
)
.unwrap(),
U256::from_str("2004830151166915124").unwrap()
)
}
#[test]
fn test_get_amount_out_zero_input_rejected() {
let err = get_amount_out(
U256::ZERO,
true,
U256::from(1_000_000u64),
U256::from(1_000_000u64),
5,
18,
18,
)
.expect_err("zero input should fail");
assert!(matches!(err, SimulationError::InvalidInput(_, _)));
}
#[test]
fn test_get_amount_out_no_liquidity_rejected() {
let err =
get_amount_out(U256::from(1u64), true, U256::ZERO, U256::from(1_000_000u64), 5, 18, 18)
.expect_err("zero reserve should fail");
assert!(matches!(err, SimulationError::RecoverableError(_)));
}
#[test]
fn test_get_amount_out_higher_fee_reduces_output() {
let reserve0 = U256::from_str("2642455102346776307825").unwrap();
let reserve1 = U256::from_str("3320301880379841502303").unwrap();
let amount_in = U256::from_str("2000000000000000000").unwrap();
let low_fee_out = get_amount_out(amount_in, true, reserve0, reserve1, 5, 18, 18).unwrap();
let high_fee_out =
get_amount_out(amount_in, true, reserve0, reserve1, 100, 18, 18).unwrap();
assert!(high_fee_out < low_fee_out);
}
#[test]
fn test_get_amount_out_reverse_direction() {
let reserve0 = U256::from_str("2642455102346776307825").unwrap();
let reserve1 = U256::from_str("3320301880379841502303").unwrap();
let amount_in = U256::from_str("2000000000000000000").unwrap();
let out = get_amount_out(amount_in, false, reserve0, reserve1, 5, 18, 18).unwrap();
assert!(out > U256::ZERO);
assert!(out < reserve0);
}
#[test]
fn test_get_amount_out_with_different_decimals() {
let reserve0 = U256::from_str("1000000000000000000000000").unwrap();
let reserve1 = U256::from(1_000_000_000_000u64);
let amount_in = U256::from_str("1000000000000000000").unwrap();
let out = get_amount_out(amount_in, true, reserve0, reserve1, 5, 18, 6).unwrap();
assert!(out > U256::ZERO);
assert!(out < reserve1);
}
#[test]
fn test_get_limits_zero_liquidity_returns_zeroes() {
let sell = Bytes::from([0_u8; 20]);
let mut buy_addr = [0_u8; 20];
buy_addr[19] = 1;
let buy = Bytes::from(buy_addr);
let (amount_in, amount_out) =
get_limits(sell, buy, U256::ZERO, U256::from(1_000_000u64), 18, 18)
.expect("zero-liquidity limits should succeed");
assert_eq!(amount_in, BigUint::ZERO);
assert_eq!(amount_out, BigUint::ZERO);
}
#[test]
fn test_get_limits_returns_non_zero_values() {
let sell = Bytes::from([0_u8; 20]);
let mut buy_addr = [0_u8; 20];
buy_addr[19] = 1;
let buy = Bytes::from(buy_addr);
let reserve0 = U256::from_str("2642455102346776307825").unwrap();
let reserve1 = U256::from_str("3320301880379841502303").unwrap();
let (amount_in, amount_out) =
get_limits(sell, buy, reserve0, reserve1, 18, 18).expect("limits should succeed");
assert!(amount_in > BigUint::ZERO);
assert!(amount_out > BigUint::ZERO);
}
#[test]
fn test_get_limits_changes_with_direction() {
let sell0 = Bytes::from([0_u8; 20]);
let mut sell1_addr = [0_u8; 20];
sell1_addr[19] = 1;
let sell1 = Bytes::from(sell1_addr);
let reserve0 = U256::from_str("2642455102346776307825").unwrap();
let reserve1 = U256::from_str("3320301880379841502303").unwrap();
let (zero_to_one_in, zero_to_one_out) =
get_limits(sell0.clone(), sell1.clone(), reserve0, reserve1, 18, 18).unwrap();
let (one_to_zero_in, one_to_zero_out) =
get_limits(sell1, sell0, reserve0, reserve1, 18, 18).unwrap();
assert!(zero_to_one_in > BigUint::ZERO);
assert!(zero_to_one_out > BigUint::ZERO);
assert!(one_to_zero_in > BigUint::ZERO);
assert!(one_to_zero_out > BigUint::ZERO);
assert!(zero_to_one_in != one_to_zero_in || zero_to_one_out != one_to_zero_out);
}
}