use alloy_primitives::U256;
use super::PoolProfiler;
/// Tuning parameters for the binary search that sizes a trade for a target price impact.
#[derive(Debug, Clone)]
pub struct EstimationConfig {
    /// Whether the search may double its upper bound (bounded by `max_bound_expansions`).
    pub enable_adaptive_bounds: bool,
    /// Maximum number of times the upper bound may be doubled during one search.
    pub max_bound_expansions: u32,
    /// Accepted absolute difference, in basis points, between simulated and target impact.
    pub tolerance_bps: u32,
    /// Hard cap on the number of search iterations.
    pub max_iterations: u32,
}

impl Default for EstimationConfig {
    /// Adaptive bounds enabled with up to 10 expansions, 1 bps tolerance, 50 iterations.
    fn default() -> Self {
        Self {
            max_iterations: 50,
            tolerance_bps: 1,
            max_bound_expansions: 10,
            enable_adaptive_bounds: true,
        }
    }
}
/// Detailed outcome of a size-for-impact search, as populated by
/// `size_for_impact_bps_detailed`.
#[derive(Debug, Clone)]
#[cfg_attr(
feature = "python",
pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model", from_py_object)
)]
#[cfg_attr(
feature = "python",
pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.model")
)]
pub struct SizeForImpactResult {
/// Requested price impact, in basis points.
pub target_impact_bps: u32,
/// Estimated input size (the lower bound of the search bracket when it stopped).
pub size: U256,
/// Simulated slippage at `size`, in basis points (0 when `size` is zero).
pub actual_impact_bps: u32,
/// Swap direction flag passed through to the swap simulation
/// (presumably token0 -> token1 when `true`; confirm against `swap_exact_in`).
pub zero_for_one: bool,
/// Number of search iterations executed.
pub iterations: u32,
/// `true` when a simulated slippage landed within the configured tolerance of the target.
pub converged: bool,
/// Number of times the search doubled its upper bound.
pub expansion_count: u32,
/// Upper bound from `estimate_max_size_for_impact` before the search began.
pub initial_high: U256,
/// Final lower bound of the search bracket (same value as `size`).
pub final_low: U256,
/// Final upper bound of the search bracket.
pub final_high: U256,
}
impl SizeForImpactResult {
    /// Returns `true` when the achieved impact differs from the target by at most
    /// `tolerance_bps` basis points (in either direction).
    pub fn within_tolerance(&self, tolerance_bps: u32) -> bool {
        self.target_impact_bps.abs_diff(self.actual_impact_bps) <= tolerance_bps
    }

    /// Returns how close the achieved impact is to the target, as a percentage.
    ///
    /// Computed as `100 - relative_error_percent`, with the error capped at 100 so the
    /// result never drops below `0.0`. A zero target is reported as `100.0`.
    pub fn accuracy_percent(&self) -> f64 {
        match self.target_impact_bps {
            0 => 100.0,
            target => {
                let error_bps = f64::from(self.actual_impact_bps.abs_diff(target));
                let relative_error_pct = error_bps / f64::from(target) * 100.0;
                100.0 - relative_error_pct.min(100.0)
            }
        }
    }
}
/// Internal outcome of `binary_search_for_size`: the final bracket plus bookkeeping.
#[derive(Debug, Clone)]
struct BinarySearchState {
/// Lower bound of the bracket. It is only ever assigned a size whose simulated slippage
/// was below the target or within tolerance of it. Callers receive it as the size.
low: U256,
/// Upper bound of the bracket when the search stopped.
high: U256,
/// Upper bound from `estimate_max_size_for_impact` before the search began.
initial_high: U256,
/// Number of loop iterations executed.
iterations: u32,
/// Number of times the upper bound was doubled by adaptive expansion.
expansions: u32,
/// `true` when a simulated slippage landed within `tolerance_bps` of the target.
converged: bool,
/// Slippage (bps) measured at `low` when the search converged; `None` otherwise.
final_slippage_bps: Option<u32>,
}
/// Returns a heuristic upper bound on the input size that produces roughly `impact_bps`
/// of price impact. The search uses it as the initial high end of its bracket.
///
/// The estimate is a linear approximation from the pool's active liquidity and current
/// Q96 sqrt price. It is doubled and clamped to `[1_000_000, 10^33]`. When the pool has
/// no active liquidity, or its sqrt price is zero, `1_000_000` is returned as a fallback.
///
/// Intermediate products use saturating arithmetic. `liquidity * sqrt_price * impact` can
/// exceed 256 bits for large liquidity and high prices; it then clamps to the upper limit
/// instead of overflowing.
pub fn estimate_max_size_for_impact(
    profiler: &PoolProfiler,
    impact_bps: u32,
    zero_for_one: bool,
) -> U256 {
    let min_val = U256::from(1_000_000);
    let max_val = U256::from(1_000_000_000_000_000_000_000_000_000_000u128);

    let liquidity = profiler.get_active_liquidity();
    let sqrt_price = U256::from(profiler.state.price_sqrt_ratio_x96);
    // A zero sqrt price would make the zero-for-one formula divide by zero.
    if liquidity == 0 || sqrt_price.is_zero() {
        return min_val;
    }

    let q96 = U256::from(1u128) << 96;
    let liquidity_u256 = U256::from(liquidity);
    let impact_ratio = U256::from(impact_bps);
    let bps_denominator = U256::from(10000);

    let base = if zero_for_one {
        liquidity_u256.saturating_mul(q96).saturating_mul(impact_ratio)
            / sqrt_price.saturating_mul(bps_denominator)
    } else {
        // If the numerator saturates, the true quotient exceeds U256::MAX / (2^96 * 10^4),
        // which is far above `max_val`, so clamping below still yields the right answer.
        liquidity_u256.saturating_mul(sqrt_price).saturating_mul(impact_ratio)
            / (q96 * bps_denominator)
    };

    base.saturating_mul(U256::from(2)).clamp(min_val, max_val)
}
/// Simulates an exact-input swap of `size` and returns its slippage in basis points.
///
/// Calls `profiler.check_if_initialized()` first. A zero `size` short-circuits to
/// `Ok(0)` without running a simulation.
///
/// # Errors
///
/// Returns an error if the swap simulation fails, if trade info cannot be computed for
/// the pool's tokens, or if slippage cannot be derived from the trade info.
pub fn slippage_for_size_bps(
    profiler: &PoolProfiler,
    size: U256,
    zero_for_one: bool,
) -> anyhow::Result<u32> {
    profiler.check_if_initialized();
    if size.is_zero() {
        return Ok(0);
    }

    let pool = &profiler.pool;
    let mut quote = profiler.swap_exact_in(size, zero_for_one, None)?;
    quote.calculate_trade_info(&pool.token0, &pool.token1)?;

    match quote.trade_info.as_ref() {
        Some(info) => info.get_slippage_bps(),
        None => Err(anyhow::anyhow!("Trade info not initialized")),
    }
}
/// Binary-searches for the input size whose simulated slippage matches `impact_bps`
/// within `config.tolerance_bps`.
///
/// The search starts from the bracket `[0, estimate_max_size_for_impact(..)]`. A size
/// whose simulation fails is treated as too large.
///
/// When `config.enable_adaptive_bounds` is set, the upper bound may be doubled, at most
/// `config.max_bound_expansions` times. This happens when the lower bound climbs into the
/// top 20% of the bracket and the upper bound has not yet been shown to overshoot the
/// target.
///
/// The search stops on convergence, after `config.max_iterations` iterations, or when
/// the bracket has collapsed so that no untested size remains.
///
/// # Errors
///
/// Returns an error if `impact_bps` is zero or exceeds 10000 (100%).
fn binary_search_for_size(
    profiler: &PoolProfiler,
    impact_bps: u32,
    zero_for_one: bool,
    config: &EstimationConfig,
) -> anyhow::Result<BinarySearchState> {
    if impact_bps == 0 {
        anyhow::bail!("Impact must be greater than zero");
    }
    if impact_bps > 10000 {
        anyhow::bail!("Impact cannot exceed 100% (10000 bps)");
    }
    profiler.check_if_initialized();

    let two = U256::from(2);
    let mut low = U256::ZERO;
    let mut high = estimate_max_size_for_impact(profiler, impact_bps, zero_for_one);
    let initial_high = high;
    // `high` is only a heuristic until a simulation at that size has overshot the
    // target (or failed). Only an unverified upper bound is worth expanding.
    let mut high_is_estimate = true;
    let mut iterations = 0;
    let mut expansions = 0;
    let mut converged = false;
    let mut final_slippage_bps = None;

    while iterations < config.max_iterations {
        iterations += 1;
        let mid = (low + high) / two;
        // `mid == low` means the bracket has collapsed to adjacent values. Re-simulating
        // `low` would only repeat a known result, so stop instead of spinning.
        if mid.is_zero() || mid == low {
            break;
        }

        let slippage_mid = match slippage_for_size_bps(profiler, mid, zero_for_one) {
            Ok(slippage) => slippage,
            Err(e) => {
                // Best-effort: a size the simulation cannot handle is treated as too large.
                log::debug!(
                    "Swap simulation failed for size {mid}, treating it as an upper bound: {e}"
                );
                high = mid;
                high_is_estimate = false;
                continue;
            }
        };

        if slippage_mid.abs_diff(impact_bps) <= config.tolerance_bps {
            low = mid;
            final_slippage_bps = Some(slippage_mid);
            converged = true;
            break;
        }

        if slippage_mid < impact_bps {
            low = mid;
            // The lower bound has reached the top 20% of a bracket whose upper end was
            // never verified, so the answer may lie above it.
            let near_upper_bound = high - low <= high / U256::from(5);
            if config.enable_adaptive_bounds
                && high_is_estimate
                && near_upper_bound
                && expansions < config.max_bound_expansions
            {
                high *= two;
                expansions += 1;
                log::debug!(
                    "Expanding upper bound (expansion {}/{}): new high={}",
                    expansions,
                    config.max_bound_expansions,
                    high
                );
            }
        } else {
            high = mid;
            high_is_estimate = false;
        }
    }

    // Warn on any non-converged exit. Keying on the iteration count alone misfired when
    // convergence happened on the final iteration.
    if !converged {
        log::warn!(
            "Binary search did not converge after {iterations} iterations, returning conservative estimate"
        );
    }

    Ok(BinarySearchState {
        low,
        high,
        initial_high,
        iterations,
        expansions,
        converged,
        final_slippage_bps,
    })
}
/// Searches for the input size whose simulated slippage matches `impact_bps` within
/// `config.tolerance_bps`, and returns only that size.
///
/// If the search does not converge, the lower bound of the final bracket is returned.
/// That is the largest size probed whose slippage stayed below the target. Use
/// `size_for_impact_bps_detailed` to also get convergence and bracket diagnostics.
///
/// # Errors
///
/// Returns an error if `impact_bps` is zero or exceeds 10000 (100%).
pub fn size_for_impact_bps(
    profiler: &PoolProfiler,
    impact_bps: u32,
    zero_for_one: bool,
    config: &EstimationConfig,
) -> anyhow::Result<U256> {
    binary_search_for_size(profiler, impact_bps, zero_for_one, config).map(|state| state.low)
}
pub fn size_for_impact_bps_detailed(
profiler: &PoolProfiler,
impact_bps: u32,
zero_for_one: bool,
config: &EstimationConfig,
) -> anyhow::Result<SizeForImpactResult> {
let state = binary_search_for_size(profiler, impact_bps, zero_for_one, config)?;
let actual_impact = if let Some(slippage) = state.final_slippage_bps {
slippage
} else if state.low.is_zero() {
0
} else {
slippage_for_size_bps(profiler, state.low, zero_for_one)?
};
Ok(SizeForImpactResult {
target_impact_bps: impact_bps,
size: state.low,
actual_impact_bps: actual_impact,
zero_for_one,
iterations: state.iterations,
converged: state.converged,
expansion_count: state.expansions,
initial_high: state.initial_high,
final_low: state.low,
final_high: state.high,
})
}