use alloy_primitives::U256;
use std::collections::HashMap;
use crate::constants::{MAX_SKIP_TICKS_UINT, MAX_TICKS_UINT, WAD};
use crate::core::{get_dynamic_fee, get_y0};
/// Outcome of [`calc_swap_out`]: how much input was used, how much output it yields,
/// which bands were traversed, and the input-side balances left in them.
#[derive(Debug, Clone, Default)]
pub struct DetailedTrade {
/// Input actually consumed, rounded up to a multiple of the input precision.
pub in_amount: U256,
/// Output produced, rounded down to a multiple of the output precision.
pub out_amount: U256,
/// First band holding liquidity that the walk reached; stays 0 if none was reached.
pub n1: i64,
/// Band where the walk stopped (starts at the active band and moves one band per step).
pub n2: i64,
/// Input-side balance of each band from `n1` onward, in walk order (x balances when
/// pumping, y balances when dumping), after the trade has been applied.
pub ticks_in: Vec<U256>,
/// Output-side balance left in the final band when the remaining input fit there;
/// stays 0 when the swap ended without finding such a band.
pub last_tick_j: U256,
}
/// Read-only snapshot of the AMM state that [`calc_swap_out`] walks over.
pub struct SwapState<'a> {
/// Band-spacing parameter `a`: moving up one band multiplies the current
/// `p_oracle_up` by `a_minus_1 / a`, moving down by `a / a_minus_1`.
pub a: U256,
/// Companion to `a`; expected to equal `a - 1` (the tests build it that way) — TODO confirm for callers.
pub a_minus_1: U256,
/// Base fee as a WAD-scaled fraction; the effective base fee is `max(fee, oracle_fee)`.
pub fee: U256,
/// x-side balance per band index; bands missing from the map count as zero.
pub bands_x: &'a HashMap<i64, U256>,
/// y-side balance per band index; bands missing from the map count as zero.
pub bands_y: &'a HashMap<i64, U256>,
/// Band where the walk starts.
pub active_band: i64,
/// Lowest band a dump (`pump == false`) may walk down to.
pub min_band: i64,
/// Highest band a pump (`pump == true`) may walk up to.
pub max_band: i64,
/// Oracle price; used as a divisor throughout the swap math.
pub p_oracle: U256,
/// Second fee input; the effective base fee is `max(fee, oracle_fee)`.
pub oracle_fee: U256,
/// `p_oracle_up` of `active_band`, i.e. the starting value rescaled at each band step.
pub p_oracle_up_active: U256,
/// Bound on `p_oracle_up * WAD / p_oracle`: dumps stop once that ratio exceeds it,
/// pumps once it drops below `WAD * WAD / max_oracle_dn_pow`.
pub max_oracle_dn_pow: U256,
/// The consumed input is rounded up to a multiple of this value.
pub in_precision: U256,
/// The produced output is rounded down to a multiple of this value.
pub out_precision: U256,
/// When true, one antifee derived from the base fee is used for every band; otherwise
/// each band with liquidity uses `max(get_dynamic_fee(..), base fee)`.
pub static_antifee: bool,
}
pub fn calc_swap_out(pump: bool, in_amount: U256, state: &SwapState) -> Option<DetailedTrade> {
let min_band = state.min_band;
let max_band = state.max_band;
let mut out = DetailedTrade {
n2: state.active_band,
..Default::default()
};
let mut p_o_up = state.p_oracle_up_active;
let mut x = *state.bands_x.get(&out.n2).unwrap_or(&U256::ZERO);
let mut y = *state.bands_y.get(&out.n2).unwrap_or(&U256::ZERO);
let mut in_amount_left = in_amount;
let fee = if state.fee > state.oracle_fee {
state.fee
} else {
state.oracle_fee
};
let mut j: u64 = MAX_TICKS_UINT;
let p_o = state.p_oracle;
let a = state.a;
let a_minus_1 = state.a_minus_1;
let static_antifee_val = if state.static_antifee {
let capped = if fee < WAD - U256::from(1u64) {
fee
} else {
WAD - U256::from(1u64)
};
Some(WAD * WAD / (WAD - capped))
} else {
None
};
let max_iter = MAX_TICKS_UINT + MAX_SKIP_TICKS_UINT;
for i in 0..max_iter {
let y0;
let mut f = U256::ZERO;
let mut g = U256::ZERO;
let mut inv = U256::ZERO;
let mut dynamic_fee = fee;
if x > U256::ZERO || y > U256::ZERO {
if j == MAX_TICKS_UINT {
out.n1 = out.n2;
j = 0;
}
y0 = get_y0(x, y, p_o, p_o_up, a, a_minus_1)?;
f = a * y0 * p_o / p_o_up * p_o / WAD;
g = a_minus_1 * y0 * p_o_up / p_o;
inv = (f + x) * (g + y);
if !state.static_antifee {
let df = get_dynamic_fee(p_o, p_o_up, a, a_minus_1);
dynamic_fee = if df > fee { df } else { fee };
}
}
let antifee = if let Some(val) = static_antifee_val {
val
} else {
let capped_fee = if dynamic_fee < WAD - U256::from(1u64) {
dynamic_fee
} else {
WAD - U256::from(1u64)
};
WAD * WAD / (WAD - capped_fee)
};
if j != MAX_TICKS_UINT {
let tick = if pump { x } else { y };
out.ticks_in.push(tick);
}
let p_ratio = p_o_up * WAD / p_o;
if pump {
if y != U256::ZERO && g != U256::ZERO {
let x_dest = (inv / g - f) - x;
let dx = x_dest * antifee / WAD;
if dx >= in_amount_left {
let x_dest = in_amount_left * WAD / antifee;
let rem = inv / (f + (x + x_dest)) - g + U256::from(1u64);
out.last_tick_j = if rem < y { rem } else { y };
x += in_amount_left;
out.out_amount += y - out.last_tick_j;
out.ticks_in[j as usize] = x;
out.in_amount = in_amount;
break;
} else {
let dx = if dx > U256::from(1u64) {
dx
} else {
U256::from(1u64)
};
in_amount_left -= dx;
out.ticks_in[j as usize] = x + dx;
out.in_amount += dx;
out.out_amount += y;
}
}
if i != max_iter - 1 {
if out.n2 == max_band {
break;
}
if j == MAX_TICKS_UINT - 1 {
break;
}
if p_ratio < WAD * WAD / state.max_oracle_dn_pow {
break;
}
out.n2 += 1;
p_o_up = p_o_up * a_minus_1 / a;
x = U256::ZERO;
y = *state.bands_y.get(&out.n2).unwrap_or(&U256::ZERO);
}
} else {
if x != U256::ZERO && f != U256::ZERO {
let y_dest = (inv / f - g) - y;
let dy = y_dest * antifee / WAD;
if dy >= in_amount_left {
let y_dest = in_amount_left * WAD / antifee;
let rem = inv / (g + (y + y_dest)) - f + U256::from(1u64);
out.last_tick_j = if rem < x { rem } else { x };
y += in_amount_left;
out.out_amount += x - out.last_tick_j;
out.ticks_in[j as usize] = y;
out.in_amount = in_amount;
break;
} else {
let dy = if dy > U256::from(1u64) {
dy
} else {
U256::from(1u64)
};
in_amount_left -= dy;
out.ticks_in[j as usize] = y + dy;
out.in_amount += dy;
out.out_amount += x;
}
}
if i != max_iter - 1 {
if out.n2 == min_band {
break;
}
if j == MAX_TICKS_UINT - 1 {
break;
}
if p_ratio > state.max_oracle_dn_pow {
break;
}
out.n2 -= 1;
p_o_up = p_o_up * a / a_minus_1;
x = *state.bands_x.get(&out.n2).unwrap_or(&U256::ZERO);
y = U256::ZERO;
}
}
if j != MAX_TICKS_UINT {
j += 1;
}
}
let in_prec = state.in_precision;
let out_prec = state.out_precision;
out.in_amount = (out.in_amount + in_prec - U256::from(1u64)) / in_prec * in_prec;
out.out_amount = out.out_amount / out_prec * out_prec;
Some(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SwapState` that borrows the caller-owned band maps.
    ///
    /// Bands `-10..=10` are tradable, the oracle fee is zero, both precisions are 1, the fee
    /// is dynamic (`static_antifee == false`), and `max_oracle_dn_pow` is
    /// `WAD * (a / (a - 1))^50`, applied one integer step at a time.
    fn make_state<'a>(
        a: u64,
        fee_bps: u64,
        active_band: i64,
        bands_x: &'a HashMap<i64, U256>,
        bands_y: &'a HashMap<i64, U256>,
        p_oracle: U256,
        p_oracle_up_active: U256,
    ) -> SwapState<'a> {
        let a_val = U256::from(a);
        let a_minus_1 = U256::from(a - 1);
        let max_oracle_dn_pow = (0..50).fold(WAD, |pow, _| pow * a_val / a_minus_1);
        SwapState {
            a: a_val,
            a_minus_1,
            fee: U256::from(fee_bps) * WAD / U256::from(10000u64),
            bands_x,
            bands_y,
            active_band,
            min_band: -10,
            max_band: 10,
            p_oracle,
            oracle_fee: U256::ZERO,
            p_oracle_up_active,
            max_oracle_dn_pow,
            in_precision: U256::from(1u64),
            out_precision: U256::from(1u64),
            static_antifee: false,
        }
    }

    #[test]
    fn calc_swap_out_zero_input_returns_zero_output() {
        let mut bx = HashMap::new();
        let mut by = HashMap::new();
        bx.insert(0i64, WAD * U256::from(1000u64));
        by.insert(0i64, WAD);
        let p_o = WAD * U256::from(3000u64);
        let p_o_up = WAD * U256::from(3010u64);
        let state = make_state(100, 10, 0, &bx, &by, p_o, p_o_up);
        let result = calc_swap_out(true, U256::ZERO, &state).unwrap();
        assert_eq!(result.out_amount, U256::ZERO);
        assert_eq!(result.in_amount, U256::ZERO);
    }

    #[test]
    fn calc_swap_out_pump_produces_output() {
        let mut bx = HashMap::new();
        let mut by = HashMap::new();
        bx.insert(0i64, WAD * U256::from(100u64));
        by.insert(0i64, WAD * U256::from(10u64));
        let p_o = WAD * U256::from(2000u64);
        let p_o_up = WAD * U256::from(2010u64);
        let state = make_state(100, 10, 0, &bx, &by, p_o, p_o_up);
        let dx = WAD * U256::from(100u64);
        let result = calc_swap_out(true, dx, &state).unwrap();
        assert!(result.out_amount > U256::ZERO, "should get some output");
        assert!(result.in_amount > U256::ZERO, "should consume some input");
        assert!(result.in_amount <= dx, "should not consume more than input");
    }

    #[test]
    fn calc_swap_out_dump_produces_output() {
        let mut bx = HashMap::new();
        let mut by = HashMap::new();
        bx.insert(0i64, WAD * U256::from(10000u64));
        by.insert(0i64, WAD * U256::from(5u64));
        let p_o = WAD * U256::from(2000u64);
        let p_o_up = WAD * U256::from(2010u64);
        let state = make_state(100, 10, 0, &bx, &by, p_o, p_o_up);
        let dy = WAD;
        let result = calc_swap_out(false, dy, &state).unwrap();
        assert!(result.out_amount > U256::ZERO, "should get some output");
        assert!(result.in_amount > U256::ZERO, "should consume some input");
    }

    #[test]
    fn calc_swap_out_no_liquidity_returns_zero() {
        let bx = HashMap::new();
        let by = HashMap::new();
        let p_o = WAD * U256::from(2000u64);
        let p_o_up = WAD * U256::from(2010u64);
        let state = make_state(100, 10, 0, &bx, &by, p_o, p_o_up);
        let result = calc_swap_out(true, WAD * U256::from(100u64), &state).unwrap();
        assert_eq!(result.out_amount, U256::ZERO);
    }
}