use core::fmt;
use core::str::FromStr;
use alloy_primitives::U256;
use crate::{Error, error::Result, order::OrderKind};
/// 100% expressed in basis points; the divisor for partner-fee and slippage rates.
const ONE_HUNDRED_BPS: u64 = 10_000;
/// Fixed-point multiplier for protocol-fee basis points: a fee of 1 bps is stored
/// as 100_000 internal units (five decimal places), see [`ProtocolFeeBps`].
const HUNDRED_THOUSANDS: u64 = 100_000;
/// 100% in scaled protocol-fee units (`10_000 bps * 100_000` = 1_000_000_000);
/// the base of the denominators used by `protocol_fee_amount`.
const SCALE: U256 = U256::from_limbs([(ONE_HUNDRED_BPS * HUNDRED_THOUSANDS), 0, 0, 0]);
/// Slippage in basis points (50 bps = 0.5%).
/// NOTE(review): not referenced in this part of the file; presumably the value
/// callers fall back to when no slippage is supplied — confirm.
pub const DEFAULT_SLIPPAGE_BPS: u32 = 50;
/// Protocol fee rate in basis points, held as a fixed-point integer with five
/// decimal places: a fee of 1 bps is stored as 100_000 internal units.
///
/// Values are produced by parsing a decimal string (see the `FromStr` impl).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ProtocolFeeBps(u64);

impl ProtocolFeeBps {
    /// A fee of zero basis points (also the `Default` value).
    pub const ZERO: Self = ProtocolFeeBps(0);

    /// Returns the raw fixed-point value, i.e. basis points multiplied by 100_000.
    pub const fn scaled(self) -> u64 {
        let ProtocolFeeBps(units) = self;
        units
    }
}
impl FromStr for ProtocolFeeBps {
type Err = Error;
fn from_str(raw: &str) -> Result<Self> {
let trimmed = raw.trim();
if trimmed.is_empty() {
return Ok(Self::ZERO);
}
if trimmed.starts_with('-') {
return Err(Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "must be non-negative",
});
}
let (int_part, frac_part) = trimmed
.find('.')
.map_or((trimmed, ""), |i| (&trimmed[..i], &trimmed[i + 1..]));
if int_part.is_empty() && frac_part.is_empty() {
return Err(Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "empty value",
});
}
let int_val: u64 = if int_part.is_empty() {
0
} else {
int_part.parse().map_err(|_| Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "integer part is not a decimal number",
})?
};
let frac_scaled = if frac_part.is_empty() {
0u64
} else {
if !frac_part.bytes().all(|b| b.is_ascii_digit()) {
return Err(Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "fractional part is not a decimal number",
});
}
round_fraction_to_five_digits(frac_part.as_bytes())
};
int_val
.checked_mul(HUNDRED_THOUSANDS)
.and_then(|v| v.checked_add(frac_scaled))
.map(Self)
.ok_or(Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "value too large for internal scale",
})
}
}
impl fmt::Display for ProtocolFeeBps {
    /// Renders the fee as a plain decimal number of basis points with trailing
    /// fractional zeros removed, e.g. 500_000 units -> `"5"`, 30_000 units -> `"0.3"`,
    /// 1 unit -> `"0.00001"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let whole = self.0 / HUNDRED_THOUSANDS;
        let mut fraction = self.0 % HUNDRED_THOUSANDS;
        if fraction == 0 {
            return write!(f, "{whole}");
        }
        // Strip trailing zeros numerically, shrinking the zero-padded width in step.
        // This avoids formatting into a temporary String and trimming it. `fraction`
        // is non-zero here, so the loop stops with `width >= 1`.
        let mut width: usize = 5;
        while fraction % 10 == 0 {
            fraction /= 10;
            width -= 1;
        }
        write!(f, "{whole}.{fraction:0width$}")
    }
}
/// Cost settings supplied for an order.
/// NOTE(review): not consumed in this part of the file; presumably mapped onto
/// `QuoteAmountsParams` by callers — confirm.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct OrderCosts {
/// Partner fee in basis points (same unit as `QuoteAmountsParams::partner_fee_bps`).
pub partner_fee_bps: u32,
/// Slippage tolerance in basis points (same unit as `QuoteAmountsParams::slippage_bps`).
pub slippage_bps: u32,
/// Optional protocol fee rate; presumably replaces a quote-provided rate when
/// `Some` — confirm against callers.
pub protocol_fee_bps_override: Option<ProtocolFeeBps>,
}
/// A sell-token / buy-token amount pair at one stage of the quote breakdown.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Amounts {
/// Amount of the sell token.
pub sell_amount: U256,
/// Amount of the buy token.
pub buy_amount: U256,
}
/// Fee amounts and rates charged on a quote, as produced by [`compute`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QuoteCosts {
/// Network fee in the sell token, copied from `QuoteAmountsParams::fee_amount`.
pub network_fee_in_sell: U256,
/// Network fee converted to the buy token: `buy_amount * fee_amount / sell_amount`,
/// rounded down.
pub network_fee_in_buy: U256,
/// Partner fee, denominated in the surplus leg's token (buy token for sell orders,
/// sell token for buy orders).
pub partner_fee_amount: U256,
/// Partner fee rate in basis points, echoed from the input.
pub partner_fee_bps: u32,
/// Protocol fee, denominated in the surplus leg's token (buy token for sell orders,
/// sell token for buy orders).
pub protocol_fee_amount: U256,
/// Protocol fee rate as `ProtocolFeeBps::scaled` (bps * 100_000); 0 when none was given.
pub protocol_fee_bps_scaled: u64,
}
/// Full breakdown of a quote: the amounts after each successive fee stage plus the
/// costs charged, as produced by [`compute`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QuoteAmountsAndCosts {
/// `true` for sell orders, `false` for buy orders.
pub is_sell: bool,
/// Amounts with protocol and network fees added back (the starting point).
pub before_all_fees: Amounts,
/// `before_all_fees` with the protocol fee applied to the surplus leg.
pub after_protocol_fees: Amounts,
/// Amounts after the network fee is accounted for.
pub after_network_costs: Amounts,
/// `after_network_costs` with the partner fee applied to the surplus leg.
pub after_partner_fees: Amounts,
/// `after_partner_fees` with slippage tolerance applied to the surplus leg.
pub after_slippage: Amounts,
/// The fixed leg from `before_all_fees` combined with the surplus leg from
/// `after_slippage`.
pub amounts_to_sign: Amounts,
/// The individual fee amounts and rates.
pub costs: QuoteCosts,
}
/// Inputs to [`compute`]: the raw quote plus the fee and slippage rates to apply.
#[derive(Clone, Copy, Debug)]
pub struct QuoteAmountsParams {
/// Order side; decides which leg is fixed and which absorbs fees.
pub kind: OrderKind,
/// Quoted sell amount; must be non-zero (it is the divisor of the buy/sell price).
pub sell_amount: U256,
/// Quoted buy amount.
pub buy_amount: U256,
/// Network fee, in the sell token.
pub fee_amount: U256,
/// Partner fee in basis points (divided by 10_000).
pub partner_fee_bps: u32,
/// Slippage tolerance in basis points (divided by 10_000).
pub slippage_bps: u32,
/// Optional protocol fee rate; `None` is treated as zero.
pub protocol_fee_bps: Option<ProtocolFeeBps>,
}
/// Splits a raw quote into the amounts seen at each fee stage and the costs charged.
///
/// Stages are computed in order:
/// `before_all_fees` -> `after_protocol_fees` -> `after_network_costs`
/// -> `after_partner_fees` -> `after_slippage`.
/// `amounts_to_sign` pairs the fixed leg of `before_all_fees` with the surplus leg of
/// `after_slippage`.
///
/// The direction of the fee adjustments depends on the order side:
/// - Sell orders: the sell leg is fixed, and fees are subtracted from the buy leg.
/// - Buy orders: the buy leg is fixed, and fees are added to the sell leg
///   (see `Dir::peel_fee`).
///
/// The rates are applied as follows:
/// - Partner fee: `partner_fee_bps / 10_000` of the `before_all_fees` surplus leg, so
///   its base includes the protocol fee.
/// - Slippage: `slippage_bps / 10_000` of the `after_partner_fees` surplus leg.
///
/// All divisions round down.
///
/// # Errors
///
/// - [`Error::QuoteSellAmountZero`] if `sell_amount` is zero.
/// - [`Error::QuoteFeeMathOverflow`] (tagged with the failing `stage`) if any
///   intermediate step leaves the `U256` range. This includes a fee or slippage larger
///   than the remaining buy leg on sell orders, and a sell-order protocol fee of
///   100% or more.
pub fn compute(params: QuoteAmountsParams) -> Result<QuoteAmountsAndCosts> {
let QuoteAmountsParams {
kind,
sell_amount,
buy_amount,
fee_amount,
partner_fee_bps,
slippage_bps,
protocol_fee_bps,
} = params;
let is_sell = matches!(kind, OrderKind::Sell);
// A missing protocol fee is treated as zero.
let protocol_fee_bps_scaled = protocol_fee_bps.unwrap_or_default().scaled();
// Checked before any mul_div: sell_amount is the divisor of the price conversion below.
if sell_amount.is_zero() {
return Err(Error::QuoteSellAmountZero);
}
// Network fee converted into the buy token at the quote's price (buy / sell).
let network_fee_in_buy = mul_div(
buy_amount,
fee_amount,
sell_amount,
"network_fee_in_buy.mul_div",
)?;
let protocol_fee_amount = protocol_fee_amount(
kind,
sell_amount,
buy_amount,
fee_amount,
protocol_fee_bps_scaled,
)?;
let dir = Dir::new(is_sell);
// Reconstruct the pre-fee amounts.
// - Sell orders: the network fee is added to the sell leg, and both the network
//   fee (in buy token) and the protocol fee are added back to the buy leg.
// - Buy orders: the protocol fee is removed from the sell leg.
let before_all_fees = if is_sell {
let fixed = checked_add(sell_amount, fee_amount, "before_all_fees.sell")?;
let surplus = checked_add(
buy_amount,
network_fee_in_buy,
"before_all_fees.buy_network",
)?;
let surplus = checked_add(surplus, protocol_fee_amount, "before_all_fees.buy_protocol")?;
dir.amounts(fixed, surplus)
} else {
let surplus = checked_sub(
sell_amount,
protocol_fee_amount,
"before_all_fees.sell_protocol",
)?;
dir.amounts(buy_amount, surplus)
};
let after_protocol_fees = dir.amounts(
dir.fixed_of(&before_all_fees),
dir.peel_fee(
dir.surplus_of(&before_all_fees),
protocol_fee_amount,
dir.after_protocol_stage,
)?,
);
// Sell orders: this stage is exactly the quote's own sell/buy amounts.
// Buy orders: the network fee (in sell token) is added to the sell leg.
let after_network_costs = if is_sell {
dir.amounts(sell_amount, buy_amount)
} else {
let fixed = dir.fixed_of(&after_protocol_fees);
let surplus = checked_add(
dir.surplus_of(&after_protocol_fees),
fee_amount,
"after_network_costs.sell",
)?;
dir.amounts(fixed, surplus)
};
// The partner fee is based on the before_all_fees surplus, which includes the
// protocol fee.
let partner_fee_amount = if partner_fee_bps == 0 {
U256::ZERO
} else {
mul_div(
dir.surplus_of(&before_all_fees),
U256::from(partner_fee_bps),
U256::from(ONE_HUNDRED_BPS),
"partner_fee.mul_div",
)?
};
let after_partner_fees = dir.amounts(
dir.fixed_of(&after_network_costs),
dir.peel_fee(
dir.surplus_of(&after_network_costs),
partner_fee_amount,
dir.after_partner_stage,
)?,
);
// Slippage is applied to what remains of the surplus leg after the partner fee.
let surplus_after_partner = dir.surplus_of(&after_partner_fees);
let slip = mul_div(
surplus_after_partner,
U256::from(slippage_bps),
U256::from(ONE_HUNDRED_BPS),
dir.slippage_mul_div_stage,
)?;
let after_slippage = dir.amounts(
dir.fixed_of(&after_partner_fees),
dir.peel_fee(surplus_after_partner, slip, dir.after_slippage_stage)?,
);
// The signed order keeps the full fixed leg (for sell orders: sell_amount + fee_amount)
// and the fully-reduced surplus leg.
let amounts_to_sign = dir.amounts(
dir.fixed_of(&before_all_fees),
dir.surplus_of(&after_slippage),
);
Ok(QuoteAmountsAndCosts {
is_sell,
before_all_fees,
after_protocol_fees,
after_network_costs,
after_partner_fees,
after_slippage,
amounts_to_sign,
costs: QuoteCosts {
network_fee_in_sell: fee_amount,
network_fee_in_buy,
partner_fee_amount,
partner_fee_bps,
protocol_fee_amount,
protocol_fee_bps_scaled,
},
})
}
#[derive(Clone, Copy)]
struct Dir {
is_sell: bool,
after_protocol_stage: &'static str,
after_partner_stage: &'static str,
slippage_mul_div_stage: &'static str,
after_slippage_stage: &'static str,
}
impl Dir {
const fn new(is_sell: bool) -> Self {
if is_sell {
Self {
is_sell,
after_protocol_stage: "after_protocol_fees.buy",
after_partner_stage: "after_partner_fees.buy",
slippage_mul_div_stage: "slippage_buy.mul_div",
after_slippage_stage: "after_slippage.buy",
}
} else {
Self {
is_sell,
after_protocol_stage: "after_protocol_fees.sell",
after_partner_stage: "after_partner_fees.sell",
slippage_mul_div_stage: "slippage_sell.mul_div",
after_slippage_stage: "after_slippage.sell",
}
}
}
const fn amounts(&self, fixed: U256, surplus: U256) -> Amounts {
if self.is_sell {
Amounts {
sell_amount: fixed,
buy_amount: surplus,
}
} else {
Amounts {
sell_amount: surplus,
buy_amount: fixed,
}
}
}
const fn fixed_of(&self, a: &Amounts) -> U256 {
if self.is_sell {
a.sell_amount
} else {
a.buy_amount
}
}
const fn surplus_of(&self, a: &Amounts) -> U256 {
if self.is_sell {
a.buy_amount
} else {
a.sell_amount
}
}
fn peel_fee(&self, surplus: U256, fee: U256, stage: &'static str) -> Result<U256> {
if self.is_sell {
checked_sub(surplus, fee, stage)
} else {
checked_add(surplus, fee, stage)
}
}
}
#[inline]
fn checked_add(a: U256, b: U256, stage: &'static str) -> Result<U256> {
a.checked_add(b)
.ok_or(Error::QuoteFeeMathOverflow { stage })
}
#[inline]
fn checked_sub(a: U256, b: U256, stage: &'static str) -> Result<U256> {
a.checked_sub(b)
.ok_or(Error::QuoteFeeMathOverflow { stage })
}
/// Number of internal fixed-point units per basis point used by [`ProtocolFeeBps`].
///
/// This is the value `ProtocolFeeBps::scaled` returns for a fee of exactly 1 bps.
/// Divide a scaled value by it to recover whole basis points.
#[must_use]
pub const fn protocol_fee_bps_scale() -> u64 {
    HUNDRED_THOUSANDS
}
/// Converts the digits after the decimal point into a five-place fixed-point value,
/// rounding half-up on the sixth digit.
///
/// Missing digits count as zero, so `b"3"` yields 30_000. The result may be 100_000
/// when rounding carries (e.g. `b"999995"`); the caller adds it in scaled space, so
/// the carry reaches the integer part.
///
/// Precondition: every byte is an ASCII digit (the caller validates this first).
fn round_fraction_to_five_digits(frac_digits: &[u8]) -> u64 {
    // Numeric value of the digit at `idx`, or 0 past the end (implicit zero padding).
    let digit = |idx: usize| -> u64 {
        match frac_digits.get(idx) {
            Some(&byte) => u64::from(byte - b'0'),
            None => 0,
        }
    };
    let mut truncated = 0u64;
    for idx in 0..5 {
        truncated = truncated * 10 + digit(idx);
    }
    let round_up = digit(5) >= 5;
    truncated + u64::from(round_up)
}
fn protocol_fee_amount(
kind: OrderKind,
sell_amount: U256,
buy_amount: U256,
fee_amount: U256,
bps_scaled: u64,
) -> Result<U256> {
if bps_scaled == 0 {
return Ok(U256::ZERO);
}
let bps_big = U256::from(bps_scaled);
match kind {
OrderKind::Sell => {
let denom = SCALE
.checked_sub(bps_big)
.ok_or(Error::QuoteFeeMathOverflow {
stage: "protocol_fee.sell_denom",
})?;
if denom.is_zero() {
return Err(Error::QuoteFeeMathOverflow {
stage: "protocol_fee.sell_denom",
});
}
mul_div(buy_amount, bps_big, denom, "protocol_fee.sell_mul_div")
}
OrderKind::Buy => {
let denom = SCALE
.checked_add(bps_big)
.ok_or(Error::QuoteFeeMathOverflow {
stage: "protocol_fee.buy_denom",
})?;
let base = sell_amount
.checked_add(fee_amount)
.ok_or(Error::QuoteFeeMathOverflow {
stage: "protocol_fee.buy_base",
})?;
mul_div(base, bps_big, denom, "protocol_fee.buy_mul_div")
}
}
}
#[inline]
fn mul_div(a: U256, b: U256, c: U256, stage: &'static str) -> Result<U256> {
debug_assert!(
!c.is_zero(),
"mul_div denominator must be non-zero at {stage}"
);
let prod = a
.checked_mul(b)
.ok_or(Error::QuoteFeeMathOverflow { stage })?;
Ok(prod / c)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn matches_ts_pr_867_partner_plus_protocol_fee_divergence() {
        let base = QuoteAmountsParams {
            kind: OrderKind::Sell,
            sell_amount: U256::from(1_000_000_000_000_000_000u128),
            buy_amount: U256::from(2_000_000_000_000_000_000u128),
            fee_amount: U256::ZERO,
            partner_fee_bps: 100,
            slippage_bps: 50,
            protocol_fee_bps: None,
        };
        let without = compute(base).unwrap();
        assert_eq!(
            without.amounts_to_sign.buy_amount,
            U256::from(1_970_100_000_000_000_000u128),
            "buyAmount without protocol fee must match TS PR #867 baseline",
        );
        let with_protocol = compute(QuoteAmountsParams {
            protocol_fee_bps: Some("5".parse().unwrap()),
            ..base
        })
        .unwrap();
        assert_eq!(
            with_protocol.amounts_to_sign.buy_amount,
            U256::from(1_970_090_045_022_511_257u128),
            "buyAmount with protocolFeeBps=5 must match TS PR #867 expected value",
        );
        assert!(
            with_protocol.amounts_to_sign.buy_amount < without.amounts_to_sign.buy_amount,
            "protocol fee enlarges the partner-fee base, so signed buy_amount must drop",
        );
    }

    #[test]
    fn sell_order_no_fees_passes_amounts_through_to_signing() {
        let res = compute(QuoteAmountsParams {
            kind: OrderKind::Sell,
            sell_amount: U256::from(1_000_000_000_000_000_000u128),
            buy_amount: U256::from(2_000_000_000_000_000_000u128),
            fee_amount: U256::ZERO,
            partner_fee_bps: 0,
            slippage_bps: 0,
            protocol_fee_bps: None,
        })
        .unwrap();
        assert_eq!(
            res.amounts_to_sign.sell_amount,
            U256::from(1_000_000_000_000_000_000u128)
        );
        assert_eq!(
            res.amounts_to_sign.buy_amount,
            U256::from(2_000_000_000_000_000_000u128)
        );
        assert_eq!(res.costs.partner_fee_amount, U256::ZERO);
        assert_eq!(res.costs.protocol_fee_amount, U256::ZERO);
    }

    #[test]
    fn sell_order_folds_fee_amount_into_signed_sell() {
        let res = compute(QuoteAmountsParams {
            kind: OrderKind::Sell,
            sell_amount: U256::from(1_000_000_000_000_000_000u128),
            buy_amount: U256::from(2_000_000_000_000_000_000u128),
            fee_amount: U256::from(1_000_000_000_000_000u128),
            partner_fee_bps: 0,
            slippage_bps: 0,
            protocol_fee_bps: None,
        })
        .unwrap();
        assert_eq!(
            res.amounts_to_sign.sell_amount,
            U256::from(1_001_000_000_000_000_000u128),
        );
    }

    #[test]
    fn sell_order_converts_network_fee_into_buy_token() {
        let res = compute(QuoteAmountsParams {
            kind: OrderKind::Sell,
            sell_amount: U256::from(1_000_000_000_000_000_000u128),
            buy_amount: U256::from(2_000_000_000_000_000_000u128),
            fee_amount: U256::from(1_000_000_000_000_000u128),
            partner_fee_bps: 0,
            slippage_bps: 0,
            protocol_fee_bps: None,
        })
        .unwrap();
        // price = buy / sell = 2, so a 0.001 sell-token fee is 0.002 in the buy token.
        assert_eq!(
            res.costs.network_fee_in_sell,
            U256::from(1_000_000_000_000_000u128)
        );
        assert_eq!(
            res.costs.network_fee_in_buy,
            U256::from(2_000_000_000_000_000u128)
        );
        assert_eq!(
            res.before_all_fees.buy_amount,
            U256::from(2_002_000_000_000_000_000u128)
        );
        assert_eq!(
            res.amounts_to_sign.buy_amount,
            U256::from(2_000_000_000_000_000_000u128)
        );
    }

    #[test]
    fn buy_order_applies_slippage_to_sell_side() {
        let res = compute(QuoteAmountsParams {
            kind: OrderKind::Buy,
            sell_amount: U256::from(1_000_000_000_000_000_000u128),
            buy_amount: U256::from(2_000_000_000_000_000_000u128),
            fee_amount: U256::ZERO,
            partner_fee_bps: 0,
            slippage_bps: 50,
            protocol_fee_bps: None,
        })
        .unwrap();
        assert_eq!(
            res.amounts_to_sign.sell_amount,
            U256::from(1_005_000_000_000_000_000u128),
        );
        assert_eq!(
            res.amounts_to_sign.buy_amount,
            U256::from(2_000_000_000_000_000_000u128)
        );
    }

    #[test]
    fn buy_order_protocol_fee_is_taken_from_sell_side() {
        let res = compute(QuoteAmountsParams {
            kind: OrderKind::Buy,
            sell_amount: U256::from(1_000_000_000_000_000_000u128),
            buy_amount: U256::from(2_000_000_000_000_000_000u128),
            fee_amount: U256::ZERO,
            partner_fee_bps: 0,
            slippage_bps: 0,
            protocol_fee_bps: Some("5".parse().unwrap()),
        })
        .unwrap();
        // 1e18 * 500_000 / (1_000_000_000 + 500_000), rounded down.
        assert_eq!(
            res.costs.protocol_fee_amount,
            U256::from(499_750_124_937_531u128)
        );
        assert_eq!(
            res.before_all_fees.sell_amount,
            U256::from(999_500_249_875_062_469u128)
        );
        assert_eq!(
            res.amounts_to_sign.sell_amount,
            U256::from(1_000_000_000_000_000_000u128)
        );
        assert_eq!(
            res.amounts_to_sign.buy_amount,
            U256::from(2_000_000_000_000_000_000u128)
        );
    }

    fn parse(raw: &str) -> Result<u64> {
        raw.parse::<ProtocolFeeBps>().map(ProtocolFeeBps::scaled)
    }

    #[test]
    fn protocol_fee_bps_accepts_decimal_string() {
        assert_eq!(parse("").unwrap(), 0);
        assert_eq!(parse("0").unwrap(), 0);
        assert_eq!(parse("5").unwrap(), 500_000);
        assert_eq!(parse("0.3").unwrap(), 30_000);
        assert_eq!(parse("0.30000").unwrap(), 30_000);
        assert_eq!(parse("0.000005").unwrap(), 1);
        assert_eq!(parse("0.000004").unwrap(), 0);
    }

    #[test]
    fn protocol_fee_bps_accepts_partial_and_padded_forms() {
        assert_eq!(parse(".5").unwrap(), 50_000);
        assert_eq!(parse("5.").unwrap(), 500_000);
        assert_eq!(parse("  0.3  ").unwrap(), 30_000);
    }

    #[test]
    fn protocol_fee_bps_rounding_carries_into_integer_part() {
        assert_eq!(parse("0.999994").unwrap(), 99_999);
        assert_eq!(parse("0.999995").unwrap(), 100_000);
        assert_eq!(parse("1.999995").unwrap(), 200_000);
        assert_eq!(
            "0.999995".parse::<ProtocolFeeBps>().unwrap().to_string(),
            "1"
        );
    }

    #[test]
    fn protocol_fee_bps_rejects_negative_or_garbage() {
        for garbage in ["-1", "abc", "0.x"] {
            assert!(matches!(
                parse(garbage),
                Err(Error::InvalidProtocolFeeBps { .. }),
            ));
        }
    }

    #[test]
    fn protocol_fee_bps_rejects_malformed_decimals_and_overflow() {
        for garbage in [".", "1.2.3", "1 .5", "1e3", "18446744073709551615"] {
            assert!(
                matches!(parse(garbage), Err(Error::InvalidProtocolFeeBps { .. })),
                "accepted {garbage:?}",
            );
        }
    }

    #[test]
    fn protocol_fee_bps_display_renders_wire_decimal() {
        for (raw, rendered) in [("5", "5"), ("0.3", "0.3"), ("0.30000", "0.3"), ("0", "0")] {
            assert_eq!(raw.parse::<ProtocolFeeBps>().unwrap().to_string(), rendered);
        }
        assert_eq!(ProtocolFeeBps::ZERO.to_string(), "0");
        assert_eq!(
            "0.000005".parse::<ProtocolFeeBps>().unwrap().to_string(),
            "0.00001"
        );
    }

    #[test]
    fn rejects_sell_plus_fee_overflow() {
        let err = compute(QuoteAmountsParams {
            kind: OrderKind::Sell,
            sell_amount: U256::MAX,
            buy_amount: U256::from(1u64),
            fee_amount: U256::from(1u64),
            partner_fee_bps: 0,
            slippage_bps: 0,
            protocol_fee_bps: None,
        })
        .unwrap_err();
        assert!(
            matches!(
                err,
                Error::QuoteFeeMathOverflow {
                    stage: "before_all_fees.sell"
                },
            ),
            "got: {err:?}",
        );
    }

    #[test]
    fn rejects_protocol_fee_mul_div_overflow_on_sell() {
        let err = compute(QuoteAmountsParams {
            kind: OrderKind::Sell,
            sell_amount: U256::from(1u64),
            buy_amount: U256::MAX,
            fee_amount: U256::ZERO,
            partner_fee_bps: 0,
            slippage_bps: 0,
            protocol_fee_bps: Some("5".parse().unwrap()),
        })
        .unwrap_err();
        assert!(
            matches!(err, Error::QuoteFeeMathOverflow { .. }),
            "got: {err:?}",
        );
    }

    #[test]
    fn rejects_slippage_above_one_hundred_percent_on_sell() {
        let err = compute(QuoteAmountsParams {
            kind: OrderKind::Sell,
            sell_amount: U256::from(1_000_000_000_000_000_000u128),
            buy_amount: U256::from(1_000_000_000_000_000_000u128),
            fee_amount: U256::ZERO,
            partner_fee_bps: 0,
            slippage_bps: 20_000,
            protocol_fee_bps: None,
        })
        .unwrap_err();
        assert!(
            matches!(
                err,
                Error::QuoteFeeMathOverflow {
                    stage: "after_slippage.buy"
                },
            ),
            "got: {err:?}",
        );
    }

    #[test]
    fn rejects_buy_protocol_fee_underflowing_sell_amount() {
        let err = compute(QuoteAmountsParams {
            kind: OrderKind::Buy,
            sell_amount: U256::from(10u64),
            buy_amount: U256::from(1u64),
            fee_amount: U256::from(1_000u64),
            partner_fee_bps: 0,
            slippage_bps: 0,
            protocol_fee_bps: Some("9999.99999".parse().unwrap()),
        })
        .unwrap_err();
        assert!(
            matches!(
                err,
                Error::QuoteFeeMathOverflow {
                    stage: "before_all_fees.sell_protocol"
                },
            ),
            "got: {err:?}",
        );
    }

    #[test]
    fn rejects_protocol_fee_bps_at_or_above_one_hundred_percent() {
        let err_at = compute(QuoteAmountsParams {
            kind: OrderKind::Sell,
            sell_amount: U256::from(1_000_000_000_000_000_000u128),
            buy_amount: U256::from(1_000_000_000_000_000_000u128),
            fee_amount: U256::ZERO,
            partner_fee_bps: 0,
            slippage_bps: 0,
            protocol_fee_bps: Some("10000".parse().unwrap()),
        })
        .unwrap_err();
        assert!(
            matches!(
                err_at,
                Error::QuoteFeeMathOverflow {
                    stage: "protocol_fee.sell_denom"
                },
            ),
            "got: {err_at:?}",
        );
        let err_above = compute(QuoteAmountsParams {
            kind: OrderKind::Sell,
            sell_amount: U256::from(1_000_000_000_000_000_000u128),
            buy_amount: U256::from(1_000_000_000_000_000_000u128),
            fee_amount: U256::ZERO,
            partner_fee_bps: 0,
            slippage_bps: 0,
            protocol_fee_bps: Some("10001".parse().unwrap()),
        })
        .unwrap_err();
        assert!(
            matches!(
                err_above,
                Error::QuoteFeeMathOverflow {
                    stage: "protocol_fee.sell_denom"
                },
            ),
            "got: {err_above:?}",
        );
    }

    #[test]
    fn zero_sell_amount_is_refused_instead_of_dividing_by_zero() {
        let err = compute(QuoteAmountsParams {
            kind: OrderKind::Sell,
            sell_amount: U256::ZERO,
            buy_amount: U256::from(1_u64),
            fee_amount: U256::ZERO,
            partner_fee_bps: 0,
            slippage_bps: 0,
            protocol_fee_bps: None,
        })
        .unwrap_err();
        assert!(matches!(err, Error::QuoteSellAmountZero));
    }
}