use alloy_primitives::U256;
use rust_decimal::{Decimal, prelude::ToPrimitive as _};
use super::types::QuoteAmountsAndCosts;
/// Default percentage of the network fee to reserve as slippage; presumably the value
/// callers pass as `fee_factor_pct` to `suggest_slippage_bps` (not referenced in this file).
pub const DEFAULT_FEE_SLIPPAGE_FACTOR_PCT: u32 = 50;
/// Default slippage in basis points of the traded volume; presumably the value callers pass
/// as `volume_bps` to `suggest_slippage_bps` (not referenced in this file).
pub const DEFAULT_VOLUME_SLIPPAGE_BPS: u32 = 50;
/// Upper bound applied to suggested slippage: 100 % expressed in basis points.
pub const MAX_SLIPPAGE_BPS: u32 = 100 * 100;
/// Suggests a slippage amount derived from the network fee.
///
/// Returns `floor(fee_amount * multiply_factor_pct / 100)`, i.e. `multiply_factor_pct`
/// percent of `fee_amount`, in the same units as `fee_amount`.
///
/// Writing `fee_amount = 100 * q + r`, the result is computed as
/// `q * pct + floor(r * pct / 100)`, which is exactly the same floor but never needs the
/// full product `fee_amount * pct` to fit in 256 bits. If the true result itself exceeds
/// `U256::MAX` (only possible when `multiply_factor_pct > 100`), it saturates.
#[must_use]
pub fn suggest_slippage_from_fee(fee_amount: U256, multiply_factor_pct: u32) -> U256 {
    let hundred = U256::from(100u32);
    let factor = U256::from(multiply_factor_pct);
    let whole = (fee_amount / hundred).saturating_mul(factor);
    // `fee_amount % 100 < 100`, so this product stays below `100 * u32::MAX`.
    let fraction = (fee_amount % hundred) * factor / hundred;
    whole.saturating_add(fraction)
}
/// Suggests a slippage amount proportional to the traded volume.
///
/// The base amount is `sell_after` for sell orders and `sell_before` for buy orders.
/// Returns `floor(base * volume_bps / 10_000)` in sell-token units (zero for a zero base).
///
/// Writing `base = 10_000 * q + r`, the result is computed as
/// `q * bps + floor(r * bps / 10_000)`, which is exactly the same floor without
/// requiring `base * bps` to fit in 256 bits. It saturates at `U256::MAX` if the true
/// result does not fit (only possible when `volume_bps > 10_000`).
#[must_use]
pub fn suggest_slippage_from_volume(
    sell_before: U256,
    sell_after: U256,
    is_sell: bool,
    volume_bps: u32,
) -> U256 {
    let base = if is_sell { sell_after } else { sell_before };
    let bps_scale = U256::from(10_000u32);
    let factor = U256::from(volume_bps);
    let whole = (base / bps_scale).saturating_mul(factor);
    // `base % 10_000 < 10_000`, so this product stays below `10_000 * u32::MAX`.
    let fraction = (base % bps_scale) * factor / bps_scale;
    whole.saturating_add(fraction)
}
/// Suggests a slippage tolerance, in basis points, for a quote.
///
/// The suggestion is the sum of two amounts in sell-token units:
/// * `fee_factor_pct` percent of the network fee (see [`suggest_slippage_from_fee`]), and
/// * `volume_bps` of the traded amount (see [`suggest_slippage_from_volume`]).
///
/// The sum is expressed relative to the sell amount before network costs, rounded down,
/// capped at [`MAX_SLIPPAGE_BPS`], and finally raised to at least `min_bps`. Because
/// `min_bps` is applied after the cap, a `min_bps` above [`MAX_SLIPPAGE_BPS`] is returned
/// as-is. When the sell amount before network costs is zero, `min_bps` is returned.
///
/// The `U256` arithmetic never overflows. The two components are added with saturation,
/// and the bps scaling falls back to an equivalent division when `total * 10_000` does
/// not fit in 256 bits.
#[must_use]
pub fn suggest_slippage_bps(
    costs: &QuoteAmountsAndCosts,
    fee_factor_pct: u32,
    volume_bps: u32,
    min_bps: u32,
) -> u32 {
    let sell_before = costs.before_network_costs.sell_amount;
    if sell_before.is_zero() {
        return min_bps;
    }

    let fee_component =
        suggest_slippage_from_fee(costs.network_fee.amount_in_sell_currency, fee_factor_pct);
    let vol_component = suggest_slippage_from_volume(
        sell_before,
        costs.after_network_costs.sell_amount,
        costs.is_sell,
        volume_bps,
    );
    // A saturated total (U256::MAX) is >= sell_before, i.e. >= 100 %, so it is capped below
    // anyway; saturating avoids depending on `+` overflow behaviour.
    let total = fee_component.saturating_add(vol_component);

    let bps_scale = U256::from(10_000u32);
    let raw_bps = match total.checked_mul(bps_scale) {
        Some(scaled) => scaled / sell_before,
        // Reached only when `total > U256::MAX / 10_000`. Dividing by
        // `sell_before / 10_000` gives the same ratio up to a relative error below
        // `10_000 / sell_before`. That error is negligible when `sell_before` is this
        // large; when `sell_before` is small, the ratio is far above the cap either way.
        // `max(1)` keeps the divisor non-zero for `sell_before < 10_000`.
        None => total / (sell_before / bps_scale).max(U256::from(1u32)),
    };

    // Anything that does not fit in a u64 is far above the cap.
    let raw_bps: u64 = raw_bps.try_into().unwrap_or(u64::MAX);
    let capped = raw_bps.min(u64::from(MAX_SLIPPAGE_BPS));
    // `capped <= MAX_SLIPPAGE_BPS`, so converting back to u32 cannot fail.
    u32::try_from(capped).unwrap_or(MAX_SLIPPAGE_BPS).max(min_bps)
}
/// Converts a percentage into whole basis points (1 % = 100 bps, so `0.5` becomes `50`).
///
/// The scaled value is rounded with [`Decimal::round`] (banker's rounding, half to even).
/// Negative percentages map to `0`. Percentages whose bps value exceeds `u32::MAX` saturate
/// to `u32::MAX`, including ones so large that `percentage * 100` overflows `Decimal`.
/// Such values are not collapsed to `0` bps, and the function does not panic on them.
#[must_use]
pub fn percentage_to_bps(percentage: Decimal) -> u32 {
    if percentage.is_sign_negative() {
        return 0;
    }
    // `Decimal * Decimal` panics on overflow; `checked_mul` lets us saturate instead.
    match percentage.checked_mul(Decimal::from(100)) {
        // The value is non-negative here, so `to_u32` fails only when above `u32::MAX`.
        Some(scaled) => scaled.round().to_u32().unwrap_or(u32::MAX),
        None => u32::MAX,
    }
}
/// Converts basis points into a percentage: `100` bps becomes `1`, `50` bps becomes `0.5`.
#[must_use]
pub fn bps_to_percentage(bps: u32) -> Decimal {
    let bps_per_percent = Decimal::from(100);
    Decimal::from(bps) / bps_per_percent
}
/// Scales `value` by `percentage` after snapping `percentage` to whole basis points.
///
/// Computes `floor(value * bps / 100)` with `bps = percentage_to_bps(percentage)`, so any
/// precision finer than one basis point is lost.
///
/// NOTE(review): `percentage_to_bps` treats `0.5` as 0.5 % (50 bps), but dividing by `100`
/// here rather than `10_000` makes `apply_percentage(200, 0.5)` return `100`. The input
/// therefore acts as a plain multiplier, not as "percent of value". Confirm which meaning
/// callers expect before changing it; the module's tests currently pin this behaviour.
#[must_use]
pub fn apply_percentage(value: U256, percentage: Decimal) -> U256 {
    let scaled = value * U256::from(percentage_to_bps(percentage));
    scaled / U256::from(100u32)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trading::types::{Amounts, NetworkFee, PartnerFeeCost, ProtocolFeeCost};

    #[test]
    fn slippage_from_fee_50_percent() {
        assert_eq!(suggest_slippage_from_fee(U256::from(1_000u32), 50), U256::from(500u32));
    }

    #[test]
    fn slippage_from_fee_zero_factor() {
        assert_eq!(suggest_slippage_from_fee(U256::from(1_000u32), 0), U256::ZERO);
    }

    #[test]
    fn slippage_from_fee_100_percent() {
        assert_eq!(suggest_slippage_from_fee(U256::from(1_000u32), 100), U256::from(1_000u32));
    }

    #[test]
    fn slippage_from_fee_rounds_down() {
        // 199 * 50 / 100 = 99.5 -> 99
        assert_eq!(suggest_slippage_from_fee(U256::from(199u32), 50), U256::from(99u32));
    }

    #[test]
    fn slippage_from_volume_sell_order() {
        let result =
            suggest_slippage_from_volume(U256::from(10_000u32), U256::from(9_000u32), true, 50);
        assert_eq!(result, U256::from(45u32));
    }

    #[test]
    fn slippage_from_volume_buy_order() {
        let result =
            suggest_slippage_from_volume(U256::from(10_000u32), U256::from(9_000u32), false, 50);
        assert_eq!(result, U256::from(50u32));
    }

    #[test]
    fn slippage_from_volume_rounds_down() {
        // 19_999 * 50 / 10_000 = 99.995 -> 99
        let result = suggest_slippage_from_volume(U256::from(19_999u32), U256::ZERO, false, 50);
        assert_eq!(result, U256::from(99u32));
    }

    #[test]
    fn slippage_from_volume_zero_base() {
        let result = suggest_slippage_from_volume(U256::ZERO, U256::ZERO, true, 50);
        assert_eq!(result, U256::ZERO);
    }

    fn make_costs(
        sell_before: u64,
        sell_after: u64,
        fee: u64,
        is_sell: bool,
    ) -> QuoteAmountsAndCosts {
        QuoteAmountsAndCosts {
            is_sell,
            before_all_fees: Amounts {
                sell_amount: U256::from(sell_before),
                buy_amount: U256::from(100u64),
            },
            before_network_costs: Amounts {
                sell_amount: U256::from(sell_before),
                buy_amount: U256::from(100u64),
            },
            after_network_costs: Amounts {
                sell_amount: U256::from(sell_after),
                buy_amount: U256::from(100u64),
            },
            after_partner_fees: Amounts {
                sell_amount: U256::from(sell_after),
                buy_amount: U256::from(100u64),
            },
            after_slippage: Amounts {
                sell_amount: U256::from(sell_after),
                buy_amount: U256::from(100u64),
            },
            network_fee: NetworkFee {
                amount_in_sell_currency: U256::from(fee),
                amount_in_buy_currency: U256::ZERO,
            },
            partner_fee: PartnerFeeCost { amount: U256::ZERO, bps: 0 },
            protocol_fee: ProtocolFeeCost { amount: U256::ZERO, bps: 0 },
        }
    }

    #[test]
    fn suggest_slippage_bps_basic() {
        // fee: 1_000 * 50 % = 500; volume: 9_000 * 50 bps = 45; 545 / 10_000 -> 545 bps
        let costs = make_costs(10_000, 9_000, 1_000, true);
        assert_eq!(suggest_slippage_bps(&costs, 50, 50, 0), 545);
    }

    #[test]
    fn suggest_slippage_bps_buy_order_uses_sell_before() {
        // fee: 500; volume: 10_000 * 50 bps = 50; 550 / 10_000 -> 550 bps
        let costs = make_costs(10_000, 9_000, 1_000, false);
        assert_eq!(suggest_slippage_bps(&costs, 50, 50, 0), 550);
    }

    #[test]
    fn suggest_slippage_bps_zero_sell_returns_min() {
        let costs = make_costs(0, 0, 0, true);
        let bps = suggest_slippage_bps(&costs, 50, 50, 42);
        assert_eq!(bps, 42);
    }

    #[test]
    fn suggest_slippage_bps_respects_min() {
        // fee: 500; volume: 999_000 * 50 bps = 4_995; 5_495 / 1_000_000 -> 54 bps
        let costs = make_costs(1_000_000, 999_000, 1_000, true);
        assert_eq!(suggest_slippage_bps(&costs, 50, 50, 0), 54);
        assert_eq!(suggest_slippage_bps(&costs, 50, 50, 200), 200);
    }

    #[test]
    fn suggest_slippage_bps_caps_at_max() {
        // fee component (500_000) dwarfs the sell amount (100) -> capped at 100 %
        let costs = make_costs(100, 100, 1_000_000, true);
        assert_eq!(suggest_slippage_bps(&costs, 50, 50, 0), MAX_SLIPPAGE_BPS);
    }

    #[test]
    fn percentage_to_bps_half_percent() {
        assert_eq!(percentage_to_bps(Decimal::new(5, 1)), 50);
    }

    #[test]
    fn percentage_to_bps_one_percent() {
        assert_eq!(percentage_to_bps(Decimal::new(1, 0)), 100);
    }

    #[test]
    fn percentage_to_bps_zero() {
        assert_eq!(percentage_to_bps(Decimal::ZERO), 0);
    }

    #[test]
    fn percentage_to_bps_rounds_to_whole_bps() {
        // 0.1234 % = 12.34 bps -> 12
        assert_eq!(percentage_to_bps(Decimal::new(1_234, 4)), 12);
    }

    #[test]
    fn percentage_to_bps_negative_is_zero() {
        assert_eq!(percentage_to_bps(Decimal::new(-1, 0)), 0);
    }

    #[test]
    fn bps_to_percentage_50() {
        assert_eq!(bps_to_percentage(50), Decimal::new(5, 1));
    }

    #[test]
    fn bps_to_percentage_100() {
        assert_eq!(bps_to_percentage(100), Decimal::new(1, 0));
    }

    #[test]
    fn bps_to_percentage_zero() {
        assert_eq!(bps_to_percentage(0), Decimal::ZERO);
    }

    #[test]
    fn apply_percentage_half_percent() {
        assert_eq!(apply_percentage(U256::from(200u32), Decimal::new(5, 1)), U256::from(100u32));
    }

    #[test]
    fn apply_percentage_zero() {
        assert_eq!(apply_percentage(U256::from(1_000u32), Decimal::ZERO), U256::ZERO);
    }

    #[test]
    fn bps_percentage_roundtrip() {
        for bps in [0, 1, 25, 50, 100, 500, 10_000] {
            let pct = bps_to_percentage(bps);
            assert_eq!(percentage_to_bps(pct), bps);
        }
    }
}