use alloy_primitives::U256;
use crate::{Error, error::Result, order::OrderKind};
/// Number of basis points in 100%.
const ONE_HUNDRED_BPS: u64 = 10_000;
/// Fixed-point scale for fractional protocol-fee basis points: a protocol fee of
/// `x` bps is stored as `x * 100_000`, i.e. with five decimal places of precision.
const HUNDRED_THOUSANDS: u64 = 100_000;
/// A sell/buy amount pair describing the quote at one stage of the fee pipeline
/// computed by `compute`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Amounts {
/// Amount of the sell token.
pub sell_amount: U256,
/// Amount of the buy token.
pub buy_amount: U256,
}
/// Fee breakdown produced alongside the per-stage amounts of a quote.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QuoteCosts {
/// Network fee in sell-token units (the quote's `fee_amount`, passed through unchanged).
pub network_fee_in_sell: U256,
/// Network fee converted to buy-token units at the quote's price:
/// `buy_amount * fee_amount / sell_amount`, rounded down.
pub network_fee_in_buy: U256,
/// Partner fee: `partner_fee_bps` of the `before_all_fees` buy amount for sell
/// orders, or of the `before_all_fees` sell amount for buy orders (rounded down).
pub partner_fee_amount: U256,
/// Partner fee rate in basis points, as supplied by the caller.
pub partner_fee_bps: u32,
/// Protocol fee: in buy-token units for sell orders, sell-token units for buy orders.
pub protocol_fee_amount: U256,
/// Protocol fee rate in basis points multiplied by 100_000 (see `protocol_fee_bps_scale`).
pub protocol_fee_bps_scaled: u64,
}
/// Every intermediate amount of a quote, from the reference amounts before any fee
/// is deducted down to the amounts that go into the signed order, plus the fees.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QuoteAmountsAndCosts {
/// `true` for sell orders, `false` for buy orders.
pub is_sell: bool,
/// Reference amounts before the protocol, network and partner fees are deducted.
pub before_all_fees: Amounts,
/// `before_all_fees` with the protocol fee applied.
pub after_protocol_fees: Amounts,
/// `after_protocol_fees` with the network fee applied.
pub after_network_costs: Amounts,
/// `after_network_costs` with the partner fee applied.
pub after_partner_fees: Amounts,
/// `after_partner_fees` with the slippage tolerance applied.
pub after_slippage: Amounts,
/// Amounts for the order to sign: sell orders use the `before_all_fees` sell amount
/// and the `after_slippage` buy amount; buy orders use the `after_slippage` sell
/// amount and the `before_all_fees` buy amount.
pub amounts_to_sign: Amounts,
/// Fee breakdown.
pub costs: QuoteCosts,
}
/// Inputs to `compute`: a raw quote plus the fee and slippage settings to apply.
#[derive(Clone, Copy, Debug)]
pub struct QuoteAmountsParams<'a> {
/// Whether the order sells or buys a fixed amount.
pub kind: OrderKind,
/// Quoted sell amount, not including `fee_amount` (which `compute` adds on top
/// where needed). Must be non-zero.
pub sell_amount: U256,
/// Quoted buy amount.
pub buy_amount: U256,
/// Network fee, denominated in the sell token.
pub fee_amount: U256,
/// Partner fee in basis points (10_000 = 100%).
pub partner_fee_bps: u32,
/// Slippage tolerance in basis points (10_000 = 100%).
pub slippage_bps: u32,
/// Protocol fee in basis points as a decimal string (e.g. `"0.3"`). `None` or a
/// blank string means no protocol fee. Digits beyond the fifth decimal place are
/// rounded half-up.
pub protocol_fee_bps: Option<&'a str>,
}
pub fn compute(params: QuoteAmountsParams<'_>) -> Result<QuoteAmountsAndCosts> {
let QuoteAmountsParams {
kind,
sell_amount,
buy_amount,
fee_amount,
partner_fee_bps,
slippage_bps,
protocol_fee_bps,
} = params;
let is_sell = matches!(kind, OrderKind::Sell);
let protocol_fee_bps_scaled = parse_protocol_fee_bps(protocol_fee_bps)?;
if sell_amount.is_zero() {
return Err(Error::QuoteSellAmountZero);
}
let network_fee_in_buy = mul_div(
buy_amount,
fee_amount,
sell_amount,
"network_fee_in_buy.mul_div",
)?;
let protocol_fee_amount = protocol_fee_amount(
kind,
sell_amount,
buy_amount,
fee_amount,
protocol_fee_bps_scaled,
)?;
let before_all_fees = if is_sell {
Amounts {
sell_amount: sell_amount.checked_add(fee_amount).ok_or(
Error::QuoteFeeMathOverflow {
stage: "before_all_fees.sell",
},
)?,
buy_amount: buy_amount
.checked_add(network_fee_in_buy)
.ok_or(Error::QuoteFeeMathOverflow {
stage: "before_all_fees.buy_network",
})?
.checked_add(protocol_fee_amount)
.ok_or(Error::QuoteFeeMathOverflow {
stage: "before_all_fees.buy_protocol",
})?,
}
} else {
Amounts {
sell_amount: sell_amount.checked_sub(protocol_fee_amount).ok_or(
Error::QuoteFeeMathOverflow {
stage: "before_all_fees.sell_protocol",
},
)?,
buy_amount,
}
};
let after_protocol_fees = if is_sell {
Amounts {
sell_amount: before_all_fees.sell_amount,
buy_amount: before_all_fees
.buy_amount
.checked_sub(protocol_fee_amount)
.ok_or(Error::QuoteFeeMathOverflow {
stage: "after_protocol_fees.buy",
})?,
}
} else {
Amounts {
sell_amount,
buy_amount: before_all_fees.buy_amount,
}
};
let after_network_costs = if is_sell {
Amounts {
sell_amount,
buy_amount,
}
} else {
Amounts {
sell_amount: sell_amount.checked_add(fee_amount).ok_or(
Error::QuoteFeeMathOverflow {
stage: "after_network_costs.sell",
},
)?,
buy_amount: after_protocol_fees.buy_amount,
}
};
let surplus_base = if is_sell {
before_all_fees.buy_amount
} else {
before_all_fees.sell_amount
};
let partner_fee_amount = if partner_fee_bps == 0 {
U256::ZERO
} else {
mul_div(
surplus_base,
U256::from(partner_fee_bps),
U256::from(ONE_HUNDRED_BPS),
"partner_fee.mul_div",
)?
};
let after_partner_fees = if is_sell {
Amounts {
sell_amount: after_network_costs.sell_amount,
buy_amount: after_network_costs
.buy_amount
.checked_sub(partner_fee_amount)
.ok_or(Error::QuoteFeeMathOverflow {
stage: "after_partner_fees.buy",
})?,
}
} else {
Amounts {
sell_amount: after_network_costs
.sell_amount
.checked_add(partner_fee_amount)
.ok_or(Error::QuoteFeeMathOverflow {
stage: "after_partner_fees.sell",
})?,
buy_amount: after_network_costs.buy_amount,
}
};
let after_slippage = if is_sell {
let slip = mul_div(
after_partner_fees.buy_amount,
U256::from(slippage_bps),
U256::from(ONE_HUNDRED_BPS),
"slippage_buy.mul_div",
)?;
Amounts {
sell_amount: after_partner_fees.sell_amount,
buy_amount: after_partner_fees.buy_amount.checked_sub(slip).ok_or(
Error::QuoteFeeMathOverflow {
stage: "after_slippage.buy",
},
)?,
}
} else {
let slip = mul_div(
after_partner_fees.sell_amount,
U256::from(slippage_bps),
U256::from(ONE_HUNDRED_BPS),
"slippage_sell.mul_div",
)?;
Amounts {
sell_amount: after_partner_fees.sell_amount.checked_add(slip).ok_or(
Error::QuoteFeeMathOverflow {
stage: "after_slippage.sell",
},
)?,
buy_amount: after_partner_fees.buy_amount,
}
};
let amounts_to_sign = if is_sell {
Amounts {
sell_amount: before_all_fees.sell_amount,
buy_amount: after_slippage.buy_amount,
}
} else {
Amounts {
sell_amount: after_slippage.sell_amount,
buy_amount: before_all_fees.buy_amount,
}
};
Ok(QuoteAmountsAndCosts {
is_sell,
before_all_fees,
after_protocol_fees,
after_network_costs,
after_partner_fees,
after_slippage,
amounts_to_sign,
costs: QuoteCosts {
network_fee_in_sell: fee_amount,
network_fee_in_buy,
partner_fee_amount,
partner_fee_bps,
protocol_fee_amount,
protocol_fee_bps_scaled,
},
})
}
/// Scale factor applied to protocol-fee basis points in
/// `QuoteCosts::protocol_fee_bps_scaled`: 100_000 units per basis point.
pub const fn protocol_fee_bps_scale() -> u64 {
HUNDRED_THOUSANDS
}
fn parse_protocol_fee_bps(input: Option<&str>) -> Result<u64> {
let Some(raw) = input else {
return Ok(0);
};
let trimmed = raw.trim();
if trimmed.is_empty() {
return Ok(0);
}
if trimmed.starts_with('-') {
return Err(Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "must be non-negative",
});
}
let (int_part, frac_part) = trimmed
.find('.')
.map_or((trimmed, ""), |i| (&trimmed[..i], &trimmed[i + 1..]));
if int_part.is_empty() && frac_part.is_empty() {
return Err(Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "empty value",
});
}
let int_val: u64 = if int_part.is_empty() {
0
} else {
int_part.parse().map_err(|_| Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "integer part is not a decimal number",
})?
};
let frac_scaled = if frac_part.is_empty() {
0u64
} else {
if !frac_part.bytes().all(|b| b.is_ascii_digit()) {
return Err(Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "fractional part is not a decimal number",
});
}
let mut padded = String::with_capacity(6);
padded.push_str(frac_part);
while padded.len() < 6 {
padded.push('0');
}
let five: u64 = padded[..5]
.parse()
.map_err(|_| Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "fractional part overflows internal scale",
})?;
let round_digit = padded.as_bytes()[5] - b'0';
if round_digit >= 5 { five + 1 } else { five }
};
int_val
.checked_mul(HUNDRED_THOUSANDS)
.and_then(|v| v.checked_add(frac_scaled))
.ok_or(Error::InvalidProtocolFeeBps {
value: raw.to_owned(),
reason: "value too large for internal scale",
})
}
/// Back-solves the protocol fee from quote amounts that are already net of it.
///
/// `bps_scaled` is the fee in units of 1/100_000 of a basis point, so 100% is
/// `ONE_HUNDRED_BPS * HUNDRED_THOUSANDS`. A zero rate yields a zero fee.
///
/// * Sell orders: `fee = buy_amount * bps / (100% - bps)`, in buy-token units.
/// * Buy orders: `fee = (sell_amount + fee_amount) * bps / (100% + bps)`, in
///   sell-token units.
///
/// # Errors
///
/// `QuoteFeeMathOverflow` with:
/// * stage `protocol_fee.sell_denom` for a sell-order rate at or above 100%;
/// * stage `protocol_fee.buy_base` if `sell_amount + fee_amount` overflows;
/// * the `mul_div` stage if the product overflows.
fn protocol_fee_amount(
    kind: OrderKind,
    sell_amount: U256,
    buy_amount: U256,
    fee_amount: U256,
    bps_scaled: u64,
) -> Result<U256> {
    if bps_scaled == 0 {
        return Ok(U256::ZERO);
    }
    let fee_units = U256::from(bps_scaled);
    // 100% in the same fixed-point units as `bps_scaled`; 10^9 comfortably fits a u64.
    let full_scale = U256::from(ONE_HUNDRED_BPS * HUNDRED_THOUSANDS);
    let (base, denom, stage) = match kind {
        OrderKind::Sell => {
            // A rate of 100% or more leaves no valid denominator.
            let denom = full_scale
                .checked_sub(fee_units)
                .filter(|d| !d.is_zero())
                .ok_or(Error::QuoteFeeMathOverflow {
                    stage: "protocol_fee.sell_denom",
                })?;
            (buy_amount, denom, "protocol_fee.sell_mul_div")
        }
        OrderKind::Buy => {
            let denom = full_scale
                .checked_add(fee_units)
                .ok_or(Error::QuoteFeeMathOverflow {
                    stage: "protocol_fee.buy_denom",
                })?;
            let gross_sell = sell_amount
                .checked_add(fee_amount)
                .ok_or(Error::QuoteFeeMathOverflow {
                    stage: "protocol_fee.buy_base",
                })?;
            (gross_sell, denom, "protocol_fee.buy_mul_div")
        }
    };
    mul_div(base, fee_units, denom, stage)
}
#[inline]
fn mul_div(a: U256, b: U256, c: U256, stage: &'static str) -> Result<U256> {
if c.is_zero() {
return Ok(U256::ZERO);
}
let prod = a
.checked_mul(b)
.ok_or(Error::QuoteFeeMathOverflow { stage })?;
Ok(prod / c)
}
// Unit tests for the quote pipeline and the protocol-fee string parser.
#[cfg(test)]
mod tests {
use super::*;
// Without a protocol fee: partner fee = 1% of 2e18 = 2e16, leaving 1.98e18;
// 0.5% slippage removes 9.9e15, giving 1.9701e18. With protocolFeeBps = "5"
// (500_000 scaled), the fee 2e18 * 500_000 / 999_500_000 is added to the
// partner-fee base, so the signed buy amount drops slightly.
#[test]
fn matches_ts_pr_867_partner_plus_protocol_fee_divergence() {
let base = QuoteAmountsParams {
kind: OrderKind::Sell,
sell_amount: U256::from(1_000_000_000_000_000_000u128),
buy_amount: U256::from(2_000_000_000_000_000_000u128),
fee_amount: U256::ZERO,
partner_fee_bps: 100,
slippage_bps: 50,
protocol_fee_bps: None,
};
let without = compute(base).unwrap();
assert_eq!(
without.amounts_to_sign.buy_amount,
U256::from(1_970_100_000_000_000_000u128),
"buyAmount without protocol fee must match TS PR #867 baseline",
);
let with_protocol = compute(QuoteAmountsParams {
protocol_fee_bps: Some("5"),
..base
})
.unwrap();
assert_eq!(
with_protocol.amounts_to_sign.buy_amount,
U256::from(1_970_090_045_022_511_257u128),
"buyAmount with protocolFeeBps=5 must match TS PR #867 expected value",
);
assert!(
with_protocol.amounts_to_sign.buy_amount < without.amounts_to_sign.buy_amount,
"protocol fee enlarges the partner-fee base, so signed buy_amount must drop",
);
}
// With every fee and slippage at zero, the quoted amounts are signed unchanged.
#[test]
fn sell_order_no_fees_passes_amounts_through_to_signing() {
let res = compute(QuoteAmountsParams {
kind: OrderKind::Sell,
sell_amount: U256::from(1_000_000_000_000_000_000u128),
buy_amount: U256::from(2_000_000_000_000_000_000u128),
fee_amount: U256::ZERO,
partner_fee_bps: 0,
slippage_bps: 0,
protocol_fee_bps: None,
})
.unwrap();
assert_eq!(
res.amounts_to_sign.sell_amount,
U256::from(1_000_000_000_000_000_000u128)
);
assert_eq!(
res.amounts_to_sign.buy_amount,
U256::from(2_000_000_000_000_000_000u128)
);
assert_eq!(res.costs.partner_fee_amount, U256::ZERO);
assert_eq!(res.costs.protocol_fee_amount, U256::ZERO);
}
// Sell orders sign `sell_amount + fee_amount`: 1e18 + 1e15.
#[test]
fn sell_order_folds_fee_amount_into_signed_sell() {
let res = compute(QuoteAmountsParams {
kind: OrderKind::Sell,
sell_amount: U256::from(1_000_000_000_000_000_000u128),
buy_amount: U256::from(2_000_000_000_000_000_000u128),
fee_amount: U256::from(1_000_000_000_000_000u128),
partner_fee_bps: 0,
slippage_bps: 0,
protocol_fee_bps: None,
})
.unwrap();
assert_eq!(
res.amounts_to_sign.sell_amount,
U256::from(1_001_000_000_000_000_000u128),
);
}
// Buy orders add slippage to the sell side (1e18 + 0.5% = 1.005e18) and keep
// the buy amount fixed.
#[test]
fn buy_order_applies_slippage_to_sell_side() {
let res = compute(QuoteAmountsParams {
kind: OrderKind::Buy,
sell_amount: U256::from(1_000_000_000_000_000_000u128),
buy_amount: U256::from(2_000_000_000_000_000_000u128),
fee_amount: U256::ZERO,
partner_fee_bps: 0,
slippage_bps: 50,
protocol_fee_bps: None,
})
.unwrap();
assert_eq!(
res.amounts_to_sign.sell_amount,
U256::from(1_005_000_000_000_000_000u128),
);
assert_eq!(
res.amounts_to_sign.buy_amount,
U256::from(2_000_000_000_000_000_000u128)
);
}
// Scaled result = bps * 100_000, with half-up rounding on the sixth decimal.
#[test]
fn protocol_fee_bps_accepts_decimal_string() {
assert_eq!(parse_protocol_fee_bps(None).unwrap(), 0);
assert_eq!(parse_protocol_fee_bps(Some("")).unwrap(), 0);
assert_eq!(parse_protocol_fee_bps(Some("0")).unwrap(), 0);
assert_eq!(parse_protocol_fee_bps(Some("5")).unwrap(), 500_000);
assert_eq!(parse_protocol_fee_bps(Some("0.3")).unwrap(), 30_000);
assert_eq!(parse_protocol_fee_bps(Some("0.30000")).unwrap(), 30_000);
assert_eq!(parse_protocol_fee_bps(Some("0.000005")).unwrap(), 1);
assert_eq!(parse_protocol_fee_bps(Some("0.000004")).unwrap(), 0);
}
#[test]
fn protocol_fee_bps_rejects_negative_or_garbage() {
assert!(matches!(
parse_protocol_fee_bps(Some("-1")),
Err(Error::InvalidProtocolFeeBps { .. }),
));
assert!(matches!(
parse_protocol_fee_bps(Some("abc")),
Err(Error::InvalidProtocolFeeBps { .. }),
));
assert!(matches!(
parse_protocol_fee_bps(Some("0.x")),
Err(Error::InvalidProtocolFeeBps { .. }),
));
}
// U256::MAX + 1 overflows on the very first addition of the sell pipeline.
#[test]
fn rejects_sell_plus_fee_overflow() {
let err = compute(QuoteAmountsParams {
kind: OrderKind::Sell,
sell_amount: U256::MAX,
buy_amount: U256::from(1u64),
fee_amount: U256::from(1u64),
partner_fee_bps: 0,
slippage_bps: 0,
protocol_fee_bps: None,
})
.unwrap_err();
assert!(
matches!(
err,
Error::QuoteFeeMathOverflow {
stage: "before_all_fees.sell"
},
),
"got: {err:?}",
);
}
// buy_amount = U256::MAX times the scaled 5 bps (500_000) overflows inside the
// protocol-fee mul_div.
#[test]
fn rejects_protocol_fee_mul_div_overflow_on_sell() {
let err = compute(QuoteAmountsParams {
kind: OrderKind::Sell,
sell_amount: U256::from(1u64),
buy_amount: U256::MAX,
fee_amount: U256::ZERO,
partner_fee_bps: 0,
slippage_bps: 0,
protocol_fee_bps: Some("5"),
})
.unwrap_err();
assert!(
matches!(err, Error::QuoteFeeMathOverflow { .. }),
"got: {err:?}",
);
}
// 200% slippage on a sell order exceeds the buy amount and underflows.
#[test]
fn rejects_slippage_above_one_hundred_percent_on_sell() {
let err = compute(QuoteAmountsParams {
kind: OrderKind::Sell,
sell_amount: U256::from(1_000_000_000_000_000_000u128),
buy_amount: U256::from(1_000_000_000_000_000_000u128),
fee_amount: U256::ZERO,
partner_fee_bps: 0,
slippage_bps: 20_000,
protocol_fee_bps: None,
})
.unwrap_err();
assert!(
matches!(
err,
Error::QuoteFeeMathOverflow {
stage: "after_slippage.buy"
},
),
"got: {err:?}",
);
}
// The buy-order protocol fee is computed on sell + fee (1010), about 504 here,
// which exceeds sell_amount = 10 and underflows the subtraction.
#[test]
fn rejects_buy_protocol_fee_underflowing_sell_amount() {
let err = compute(QuoteAmountsParams {
kind: OrderKind::Buy,
sell_amount: U256::from(10u64),
buy_amount: U256::from(1u64),
fee_amount: U256::from(1_000u64),
partner_fee_bps: 0,
slippage_bps: 0,
protocol_fee_bps: Some("9999.99999"),
})
.unwrap_err();
assert!(
matches!(
err,
Error::QuoteFeeMathOverflow {
stage: "before_all_fees.sell_protocol"
},
),
"got: {err:?}",
);
}
// 10_000 bps makes the sell denominator zero and 10_001 bps makes it negative;
// both are reported at the same stage.
#[test]
fn rejects_protocol_fee_bps_at_or_above_one_hundred_percent() {
let err_at = compute(QuoteAmountsParams {
kind: OrderKind::Sell,
sell_amount: U256::from(1_000_000_000_000_000_000u128),
buy_amount: U256::from(1_000_000_000_000_000_000u128),
fee_amount: U256::ZERO,
partner_fee_bps: 0,
slippage_bps: 0,
protocol_fee_bps: Some("10000"),
})
.unwrap_err();
assert!(
matches!(
err_at,
Error::QuoteFeeMathOverflow {
stage: "protocol_fee.sell_denom"
},
),
"got: {err_at:?}",
);
let err_above = compute(QuoteAmountsParams {
kind: OrderKind::Sell,
sell_amount: U256::from(1_000_000_000_000_000_000u128),
buy_amount: U256::from(1_000_000_000_000_000_000u128),
fee_amount: U256::ZERO,
partner_fee_bps: 0,
slippage_bps: 0,
protocol_fee_bps: Some("10001"),
})
.unwrap_err();
assert!(
matches!(
err_above,
Error::QuoteFeeMathOverflow {
stage: "protocol_fee.sell_denom"
},
),
"got: {err_above:?}",
);
}
// sell_amount is the divisor of the network-fee conversion, so zero is rejected up front.
#[test]
fn zero_sell_amount_is_refused_instead_of_dividing_by_zero() {
let err = compute(QuoteAmountsParams {
kind: OrderKind::Sell,
sell_amount: U256::ZERO,
buy_amount: U256::from(1_u64),
fee_amount: U256::ZERO,
partner_fee_bps: 0,
slippage_bps: 0,
protocol_fee_bps: None,
})
.unwrap_err();
assert!(matches!(err, Error::QuoteSellAmountZero));
}
}