use alloy_primitives::{I256, U256};
use std::collections::HashMap;
use crate::core::{get_p, p_oracle_up};
use crate::swap::{self, SwapState};
/// Errors returned by the [`LlammaPool`] quoting methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolError {
    /// The requested `(i, j)` token pair was not `(0, 1)` or `(1, 0)`.
    InvalidIndex,
    /// An arithmetic helper reported failure while computing a quote or price.
    MathError,
}

impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the static message first, then emit it with a single write.
        let message = match self {
            PoolError::InvalidIndex => "invalid token index: must be 0 or 1",
            PoolError::MathError => "math error during computation",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PoolError {}
/// Snapshot of a LLAMMA pool's parameters and per-band reserves, used to quote
/// swaps ([`LlammaPool::get_amount_out`]) and the active band's price
/// ([`LlammaPool::spot_price`]) without touching chain state.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct LlammaPool {
/// The pool's `A` parameter; passed to the swap and price math.
pub a: U256,
/// `A - 1`, stored precomputed (the tests build it as `a - 1`).
/// NOTE(review): not derived from `a` here — callers must keep the two consistent.
pub a_minus_1: U256,
/// Base price handed to `p_oracle_up` when computing a band's upper oracle price.
pub base_price: U256,
/// Log-ratio handed to `p_oracle_up`. Presumably ln(A / (A - 1)) in signed fixed
/// point — TODO confirm the scale `p_oracle_up` expects.
pub log_a_ratio: I256,
/// Passed through to the swap math; the tests set it to (A / (A - 1))^50 in WAD.
pub max_oracle_dn_pow: U256,
/// Not read by `get_amount_out` or `spot_price`; the tests set it to 1.005 in WAD.
pub sqrt_band_ratio: U256,
/// Multiplier for token 0 (borrowed) amounts: inputs are multiplied by it before
/// the swap math and outputs are divided by it afterwards.
pub borrowed_precision: U256,
/// Same as `borrowed_precision`, for token 1 (collateral).
pub collateral_precision: U256,
/// Swap fee passed to the swap math (the tests use `WAD / 1000`).
pub fee: U256,
/// Band currently being traded. Swaps start here, and `spot_price` reads its reserves.
pub active_band: i64,
/// Lower end of the band range handed to the swap math.
pub min_band: i64,
/// Upper end of the band range handed to the swap math.
pub max_band: i64,
/// Per-band `x` reserves keyed by band index; presumably the borrowed token — confirm.
/// Missing bands are read as zero by `spot_price`.
pub bands_x: HashMap<i64, U256>,
/// Per-band `y` reserves keyed by band index; presumably the collateral token — confirm.
/// Missing bands are read as zero by `spot_price`.
pub bands_y: HashMap<i64, U256>,
/// Current oracle price, passed to the swap and price math.
pub p_oracle: U256,
/// Oracle-related fee passed to the swap math (zero in the tests).
pub oracle_fee: U256,
/// Flag forwarded unchanged to the swap math's `static_antifee` setting.
pub static_antifee: bool,
}
impl LlammaPool {
    /// Creates a pool from raw parameters and band state.
    ///
    /// No validation is performed: every argument is stored verbatim in the
    /// field of the same name.
    // The constructor mirrors the struct field-for-field, so the argument count is
    // inherent to the data being modelled rather than a design smell.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        a: U256,
        a_minus_1: U256,
        base_price: U256,
        log_a_ratio: I256,
        max_oracle_dn_pow: U256,
        sqrt_band_ratio: U256,
        borrowed_precision: U256,
        collateral_precision: U256,
        fee: U256,
        active_band: i64,
        min_band: i64,
        max_band: i64,
        bands_x: HashMap<i64, U256>,
        bands_y: HashMap<i64, U256>,
        p_oracle: U256,
        oracle_fee: U256,
        static_antifee: bool,
    ) -> Self {
        Self {
            a,
            a_minus_1,
            base_price,
            log_a_ratio,
            max_oracle_dn_pow,
            sqrt_band_ratio,
            borrowed_precision,
            collateral_precision,
            fee,
            active_band,
            min_band,
            max_band,
            bands_x,
            bands_y,
            p_oracle,
            oracle_fee,
            static_antifee,
        }
    }

    /// Quotes how much of token `j` a swap of `dx` units of token `i` returns.
    ///
    /// Token 0 is the borrowed token and token 1 is the collateral token.
    /// `dx` is multiplied by the input token's precision before the swap math
    /// runs. The swap's output is then divided by the output token's precision.
    /// A zero `dx` returns `Ok(0)` without running the swap math.
    ///
    /// # Errors
    ///
    /// * [`PoolError::InvalidIndex`] if `(i, j)` is not `(0, 1)` or `(1, 0)`.
    /// * [`PoolError::MathError`] if any of the following fail:
    ///   - scaling `dx` overflows `U256`;
    ///   - the output precision is zero;
    ///   - the upper-oracle-price computation fails;
    ///   - the swap computation fails.
    pub fn get_amount_out(&self, i: usize, j: usize, dx: U256) -> Result<U256, PoolError> {
        if !matches!((i, j), (0, 1) | (1, 0)) {
            return Err(PoolError::InvalidIndex);
        }
        if dx.is_zero() {
            return Ok(U256::ZERO);
        }
        // Token 0 (borrowed) in, token 1 (collateral) out; the swap math calls this
        // direction `pump`.
        let pump = i == 0;
        let (in_precision, out_precision) = if pump {
            (self.borrowed_precision, self.collateral_precision)
        } else {
            (self.collateral_precision, self.borrowed_precision)
        };
        // Checked: a plain U256 `*` on a huge `dx` overflows (wrapping to a wrong
        // amount or panicking) instead of reporting an error.
        let scaled_in = dx.checked_mul(in_precision).ok_or(PoolError::MathError)?;
        let p_o_up = self.compute_p_oracle_up(self.active_band)?;
        let state = SwapState {
            a: self.a,
            a_minus_1: self.a_minus_1,
            fee: self.fee,
            bands_x: &self.bands_x,
            bands_y: &self.bands_y,
            active_band: self.active_band,
            min_band: self.min_band,
            max_band: self.max_band,
            p_oracle: self.p_oracle,
            oracle_fee: self.oracle_fee,
            p_oracle_up_active: p_o_up,
            max_oracle_dn_pow: self.max_oracle_dn_pow,
            in_precision,
            out_precision,
            static_antifee: self.static_antifee,
        };
        let result = swap::calc_swap_out(pump, scaled_in, &state).ok_or(PoolError::MathError)?;
        // Checked: a zero precision would otherwise panic with division by zero.
        result
            .out_amount
            .checked_div(out_precision)
            .ok_or(PoolError::MathError)
    }

    /// Returns the active band's price as computed by `get_p`.
    ///
    /// The inputs to `get_p` are the band's `x`/`y` reserves, the oracle price, and
    /// the band's `p_oracle_up` value. A band missing from either reserve map is
    /// treated as holding zero.
    ///
    /// # Errors
    ///
    /// [`PoolError::MathError`] if the `p_oracle_up` or `get_p` computation fails.
    pub fn spot_price(&self) -> Result<U256, PoolError> {
        let p_o_up = self.compute_p_oracle_up(self.active_band)?;
        let x = self
            .bands_x
            .get(&self.active_band)
            .copied()
            .unwrap_or(U256::ZERO);
        let y = self
            .bands_y
            .get(&self.active_band)
            .copied()
            .unwrap_or(U256::ZERO);
        get_p(x, y, self.p_oracle, p_o_up, self.a, self.a_minus_1).ok_or(PoolError::MathError)
    }

    /// Computes `p_oracle_up` for band `n` from this pool's base price and log-ratio.
    /// A failed computation becomes [`PoolError::MathError`].
    fn compute_p_oracle_up(&self, n: i64) -> Result<U256, PoolError> {
        p_oracle_up(n, self.base_price, self.log_a_ratio).ok_or(PoolError::MathError)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::WAD;

    /// `n` whole units in WAD (1e18) fixed point.
    fn wad(n: u64) -> U256 {
        WAD * U256::from(n)
    }

    /// Builds a pool with these settings:
    /// - bands span `-10..=10`, with the active band at 0;
    /// - both token precisions are 1;
    /// - the fee is `WAD / 1000`;
    /// - the oracle fee is zero and `static_antifee` is off.
    fn test_pool(
        a: u64,
        base_price: U256,
        bands_x: HashMap<i64, U256>,
        bands_y: HashMap<i64, U256>,
        p_oracle: U256,
    ) -> LlammaPool {
        let a_val = U256::from(a);
        let a_minus_1 = U256::from(a - 1);
        // NOTE(review): ln(100/99) scaled by 1e18 is 10_050_335_853_501_441, but this
        // literal is ~1000x smaller. Every test below trades only band 0. Confirm the
        // scale `p_oracle_up` expects before relying on other bands.
        let log_a_ratio = I256::try_from(10_050_335_853_501i128).unwrap();
        // (A / (A - 1))^50, compounded one step at a time in WAD fixed point.
        let max_oracle_dn_pow = (0..50).fold(WAD, |acc, _| acc * a_val / a_minus_1);
        // 1.005 in WAD.
        let sqrt_band_ratio = WAD * U256::from(10050u64) / U256::from(10000u64);
        let unit_precision = U256::from(1u64);
        LlammaPool::new(
            a_val,
            a_minus_1,
            base_price,
            log_a_ratio,
            max_oracle_dn_pow,
            sqrt_band_ratio,
            unit_precision,
            unit_precision,
            WAD / U256::from(1000u64),
            0,
            -10,
            10,
            bands_x,
            bands_y,
            p_oracle,
            U256::ZERO,
            false,
        )
    }

    /// A = 100 pool priced at 2000 with no liquidity in any band.
    fn empty_pool() -> LlammaPool {
        test_pool(100, wad(2000), HashMap::new(), HashMap::new(), wad(2000))
    }

    /// A = 100 pool priced at 2000. Band 0 holds `x_units` / `y_units` whole WAD units.
    fn pool_with_band0(x_units: u64, y_units: u64) -> LlammaPool {
        let bands_x = HashMap::from([(0i64, wad(x_units))]);
        let bands_y = HashMap::from([(0i64, wad(y_units))]);
        test_pool(100, wad(2000), bands_x, bands_y, wad(2000))
    }

    #[test]
    fn get_amount_out_invalid_index() {
        let pool = empty_pool();
        for (i, j) in [(0, 0), (2, 0)] {
            assert_eq!(
                pool.get_amount_out(i, j, WAD).unwrap_err(),
                PoolError::InvalidIndex
            );
        }
    }

    #[test]
    fn get_amount_out_zero_returns_zero() {
        let pool = empty_pool();
        for (i, j) in [(0, 1), (1, 0)] {
            assert_eq!(pool.get_amount_out(i, j, U256::ZERO).unwrap(), U256::ZERO);
        }
    }

    #[test]
    fn get_amount_out_no_liquidity_returns_zero() {
        let quoted = empty_pool().get_amount_out(0, 1, wad(100)).unwrap();
        assert_eq!(quoted, U256::ZERO);
    }

    #[test]
    fn get_amount_out_pump_produces_output() {
        let pool = pool_with_band0(100, 10);
        let dy = pool.get_amount_out(0, 1, wad(100)).unwrap();
        assert!(dy > U256::ZERO, "should get collateral output, got {dy}");
    }

    #[test]
    fn get_amount_out_dump_produces_output() {
        let pool = pool_with_band0(10000, 5);
        let dx = pool.get_amount_out(1, 0, WAD).unwrap();
        assert!(dx > U256::ZERO, "should get borrowed output, got {dx}");
    }

    #[test]
    fn spot_price_returns_nonzero() {
        let price = pool_with_band0(1000, 5).spot_price().unwrap();
        assert!(price > U256::ZERO, "spot price should be > 0, got {price}");
    }

    #[test]
    fn get_amount_out_larger_input_gives_more_output() {
        let pool = pool_with_band0(1000, 50);
        let small_dy = pool.get_amount_out(0, 1, wad(10)).unwrap();
        let large_dy = pool.get_amount_out(0, 1, wad(100)).unwrap();
        assert!(
            large_dy > small_dy,
            "larger input should give more output: {small_dy} vs {large_dy}"
        );
    }
}