use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tycho_simulation::tycho_common::Bytes;
/// Basis-point denominator: 10_000 bps make up the whole (100%).
///
/// Dividing a fee-unit scale by this value gives the number of fee units in
/// one basis point.
pub const LEGACY_BPS_DENOMINATOR: u64 = 10_000;
/// Fee-unit scale used by the fallback configuration (one bps == 10_000 units).
const FALLBACK_MAX_FEE_UNITS: u64 = 100_000_000;
/// Fallback on-output fee in fee units: 1_000 / 100_000_000 == 0.1 bps.
const FALLBACK_FEE_ON_OUTPUT: u32 = 1_000;
/// Fee rates resolved for one client, expressed in fee units on a scale of
/// `max_fee_units`.
///
/// Because [`FeeRates::fee_units_per_bps`] is `max_fee_units / 10_000`, a rate
/// equal to `max_fee_units` corresponds to 10_000 bps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FeeRates {
    /// Output fee rate, in fee units.
    on_output: u32,
    /// Client-fee rate, in fee units.
    on_client_fee: u32,
    /// The fee-unit scale the two rates above are expressed against.
    max_fee_units: u64,
}

impl FeeRates {
    /// Builds a rate set from raw fee-unit values; no validation is performed.
    pub fn new(on_output: u32, on_client_fee: u32, max_fee_units: u64) -> Self {
        FeeRates { max_fee_units, on_client_fee, on_output }
    }

    /// Output fee rate, in fee units.
    pub fn on_output(&self) -> u32 {
        self.on_output
    }

    /// Client-fee rate, in fee units.
    pub fn on_client_fee(&self) -> u32 {
        self.on_client_fee
    }

    /// The fee-unit scale these rates are expressed against.
    pub fn max_fee_units(&self) -> u64 {
        self.max_fee_units
    }

    /// Number of fee units in one basis point.
    ///
    /// Uses truncating integer division by [`LEGACY_BPS_DENOMINATOR`], so the
    /// result is `0` whenever `max_fee_units < 10_000`.
    pub fn fee_units_per_bps(&self) -> u64 {
        let scale = self.max_fee_units();
        scale / LEGACY_BPS_DENOMINATOR
    }

    /// `max_fee_units²`, widened to `u128`.
    ///
    /// This cannot overflow, since `u64::MAX²` is below `u128::MAX`.
    /// NOTE(review): presumably used as the denominator when two fee-unit
    /// rates are multiplied together; confirm with callers.
    pub fn max_fee_units_squared(&self) -> u128 {
        let wide = u128::from(self.max_fee_units());
        wide * wide
    }
}
/// Router-wide fee configuration: default rates plus per-client overrides.
///
/// All rates are in fee units on the `max_fee_units` scale.
#[derive(Debug, Clone)]
pub struct RouterFees {
    /// Fee-unit scale shared by every rate resolved from this configuration.
    max_fee_units: u64,
    /// Output fee rate for clients without an override.
    default_fee_on_output: u32,
    /// Client-fee rate for clients without an override.
    default_fee_on_client_fee: u32,
    /// Per-client overrides as `(on_output, on_client_fee)` pairs.
    custom_fees: HashMap<Bytes, (u32, u32)>,
}

impl RouterFees {
    /// Creates a configuration from a scale, default rates and per-client
    /// `(on_output, on_client_fee)` overrides.
    pub fn new(
        max_fee_units: u64,
        default_fee_on_output: u32,
        default_fee_on_client_fee: u32,
        custom_fees: HashMap<Bytes, (u32, u32)>,
    ) -> Self {
        RouterFees {
            custom_fees,
            default_fee_on_client_fee,
            default_fee_on_output,
            max_fee_units,
        }
    }

    /// Built-in configuration used when nothing else is configured.
    ///
    /// It uses the `FALLBACK_MAX_FEE_UNITS` scale, `FALLBACK_FEE_ON_OUTPUT` on
    /// output, no client fee, and no per-client overrides.
    pub fn fallback() -> Self {
        RouterFees {
            max_fee_units: FALLBACK_MAX_FEE_UNITS,
            default_fee_on_output: FALLBACK_FEE_ON_OUTPUT,
            default_fee_on_client_fee: 0,
            custom_fees: HashMap::new(),
        }
    }

    /// The fee-unit scale of this configuration.
    pub fn max_fee_units(&self) -> u64 {
        self.max_fee_units
    }

    /// Resolves the rates that apply to `client`.
    ///
    /// A per-client override takes precedence. Otherwise the default rates are
    /// used. Either way the result carries this configuration's scale.
    pub fn fees_for(&self, client: &Bytes) -> FeeRates {
        let (on_output, on_client_fee) = match self.custom_fees.get(client) {
            Some(&overridden) => overridden,
            None => (self.default_fee_on_output, self.default_fee_on_client_fee),
        };
        FeeRates::new(on_output, on_client_fee, self.max_fee_units)
    }

    /// Number of clients that have a fee override.
    pub fn custom_client_count(&self) -> usize {
        self.custom_fees.len()
    }
}
/// Shared, lock-protected handle to a [`RouterFees`] configuration.
///
/// Cloning the handle clones the inner `Arc`. All clones therefore observe
/// the same configuration, and a [`SharedRouterFees::set`] through one is
/// visible through the others.
#[derive(Debug, Clone)]
pub struct SharedRouterFees(Arc<RwLock<RouterFees>>);

impl Default for SharedRouterFees {
    /// Starts out holding [`RouterFees::fallback`].
    fn default() -> Self {
        let initial = RouterFees::fallback();
        SharedRouterFees(Arc::new(RwLock::new(initial)))
    }
}

impl SharedRouterFees {
    /// Returns an owned copy of the current configuration.
    ///
    /// The read lock is held only for the duration of the clone.
    ///
    /// # Panics
    /// Panics if the lock has been poisoned.
    pub fn snapshot(&self) -> RouterFees {
        let guard = self.0.read().expect("router fees lock poisoned");
        (*guard).clone()
    }

    /// Replaces the whole configuration with `fees`.
    ///
    /// # Panics
    /// Panics if the lock has been poisoned.
    pub fn set(&self, fees: RouterFees) {
        let mut guard = self.0.write().expect("router fees lock poisoned");
        *guard = fees;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fee-unit scale used throughout these tests.
    const SCALE: u64 = 100_000_000;

    /// Builds a 20-byte client identifier with every byte set to `byte`.
    fn client(byte: u8) -> Bytes {
        Bytes::from(vec![byte; 20])
    }

    #[test]
    fn test_fees_for_unknown_client() {
        let router = RouterFees::new(SCALE, 100_000, 20_000_000, HashMap::new());
        let expected = FeeRates::new(100_000, 20_000_000, SCALE);
        assert_eq!(router.fees_for(&client(0xAA)), expected);
    }

    #[test]
    fn test_fees_for_known_client() {
        let mut overrides = HashMap::new();
        overrides.insert(client(0xAA), (50_000u32, 10_000_000u32));
        let router = RouterFees::new(SCALE, 100_000, 20_000_000, overrides);

        // 0xAA has an override; 0xBB falls back to the defaults.
        let cases = [
            (0xAA, FeeRates::new(50_000, 10_000_000, SCALE)),
            (0xBB, FeeRates::new(100_000, 20_000_000, SCALE)),
        ];
        for (byte, expected) in cases {
            assert_eq!(router.fees_for(&client(byte)), expected);
        }
    }

    #[test]
    fn test_fallback_is_point_one_bps_on_output() {
        let rates = RouterFees::fallback().fees_for(&client(0xAA));
        assert_eq!(
            (rates.on_output(), rates.on_client_fee(), rates.max_fee_units()),
            (1_000, 0, 100_000_000)
        );
    }

    #[test]
    fn test_shared_router_fees_set_overrides() {
        let shared = SharedRouterFees::default();
        let before = shared.snapshot().fees_for(&client(0xAA));
        assert_eq!(before.on_output(), 1_000);

        shared.set(RouterFees::new(SCALE, 1, 2, HashMap::new()));

        let after = shared.snapshot();
        assert_eq!(after.max_fee_units(), SCALE);
        assert_eq!(after.fees_for(&client(0xAA)), FeeRates::new(1, 2, SCALE));
    }
}