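//! Documentation: fee tiers (`FeeTier`) and the 64.64 fixed-point helpers used to
//! compute taker fees, maker rebates and referrer rebates.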
use num_enum::{IntoPrimitive, TryFromPrimitive};
use std::convert::TryInto;

#[cfg(test)]
use proptest_derive::Arbitrary;

/// Fee tier of an account, selected from its SRM / MSRM holdings by
/// `FeeTier::from_srm_and_msrm_balances`.
///
/// The enum is `#[repr(u8)]` and derives `IntoPrimitive` / `TryFromPrimitive`,
/// so each variant converts to and from its `u8` discriminant (declaration
/// order, starting at 0).
#[derive(Copy, Clone, IntoPrimitive, TryFromPrimitive, Debug)]
#[cfg_attr(test, derive(Arbitrary))]
#[repr(u8)]
pub enum FeeTier {
    /// Fallback tier when no holding threshold is met; taker rate 22 bps.
    Base,
    /// `srm_held >= 100 * 1_000_000`; taker rate 20 bps.
    SRM2,
    /// `srm_held >= 1_000 * 1_000_000`; taker rate 18 bps.
    SRM3,
    /// `srm_held >= 10_000 * 1_000_000`; taker rate 16 bps.
    SRM4,
    /// `srm_held >= 100_000 * 1_000_000`; taker rate 14 bps.
    SRM5,
    /// `srm_held >= 1_000_000 * 1_000_000`; taker rate 12 bps.
    SRM6,
    /// `msrm_held >= 1` (checked before any SRM threshold); taker rate 10 bps
    /// and the higher maker rebate.
    MSRM,
}

/// Unsigned fixed-point number backed by a `u128`: the upper 64 bits are the
/// integer part and the lower 64 bits are the fractional part.
#[repr(transparent)]
#[derive(Copy, Clone)]
struct U64F64(u128);

impl U64F64 {
    /// Number of fractional bits in the raw representation.
    const FRAC_BITS: u32 = 64;

    /// The value 1.0, i.e. a raw value of `1 << 64`.
    const ONE: Self = Self(1 << Self::FRAC_BITS);

    /// Adds two fixed-point values by adding their raw representations.
    ///
    /// Uses the plain `+` operator, so overflow behaves as the build's
    /// overflow-check setting dictates.
    #[inline(always)]
    const fn add(self, other: U64F64) -> U64F64 {
        let raw_sum = self.0 + other.0;
        Self(raw_sum)
    }

    /// Divides the raw representations and returns the truncated quotient.
    ///
    /// Both operands carry the same 2^64 scale, so the scale cancels and the
    /// result is a plain integer rather than a fixed-point value.
    /// Panics if `other` is zero.
    #[inline(always)]
    const fn div(self, other: U64F64) -> u128 {
        let numerator = self.0;
        let denominator = other.0;
        numerator / denominator
    }

    /// Multiplies this value by an integer, keeping the fixed-point scale.
    #[inline(always)]
    const fn mul_u64(self, other: u64) -> U64F64 {
        let factor = other as u128;
        Self(self.0 * factor)
    }

    /// Integer part: the upper 64 bits of the raw value.
    #[inline(always)]
    const fn floor(self) -> u64 {
        (self.0 >> Self::FRAC_BITS) as u64
    }

    /// Fractional part: the lower 64 bits of the raw value.
    #[inline(always)]
    const fn frac_part(self) -> u64 {
        // Mask off everything above the fractional bits before narrowing.
        let frac_mask = Self::ONE.0 - 1;
        (self.0 & frac_mask) as u64
    }

    /// Converts an integer into fixed-point form with a zero fraction.
    #[inline(always)]
    const fn from_int(n: u64) -> Self {
        Self((n as u128) << Self::FRAC_BITS)
    }
}

/// Converts a rate in basis points into a `U64F64` fraction (`bps / 10_000`).
///
/// The integer division truncates, so the result is never above the exact
/// rate and may be one unit in the last place below it.
#[inline(always)]
const fn fee_bps(bps: u64) -> U64F64 {
    const BPS_PER_UNIT: u128 = 10_000;
    // Scale to 64.64 fixed point first so the division keeps 64 fractional bits.
    let scaled = (bps as u128) << 64;
    U64F64(scaled / BPS_PER_UNIT)
}

/// Rebate rate for `bps` basis points: the truncated `fee_bps(bps)` value plus
/// one unit in the last place (2^-64).
#[inline(always)]
const fn rebate_bps(bps: u64) -> U64F64 {
    let truncated = fee_bps(bps).0;
    // Bump the raw value by one so the rate sits just above the truncated one.
    U64F64(truncated + 1)
}

impl FeeTier {
    #[inline]
    pub fn from_srm_and_msrm_balances(srm_held: u64, msrm_held: u64) -> FeeTier {
        let one_srm = 1_000_000;
        match () {
            () if msrm_held >= 1 => FeeTier::MSRM,
            () if srm_held >= one_srm * 1_000_000 => FeeTier::SRM6,
            () if srm_held >= one_srm * 100_000 => FeeTier::SRM5,
            () if srm_held >= one_srm * 10_000 => FeeTier::SRM4,
            () if srm_held >= one_srm * 1_000 => FeeTier::SRM3,
            () if srm_held >= one_srm * 100 => FeeTier::SRM2,
            () => FeeTier::Base,
        }
    }

    #[inline]
    pub fn maker_rebate(self, pc_qty: u64) -> u64 {
        use FeeTier::*;
        let rate: U64F64 = match self {
            MSRM => rebate_bps(5),
            Base | SRM2 | SRM3 | SRM4 | SRM5 | SRM6 => rebate_bps(3),
        };
        rate.mul_u64(pc_qty).floor()
    }

    fn taker_rate(self) -> U64F64 {
        use FeeTier::*;
        match self {
            Base => fee_bps(22),
            SRM2 => fee_bps(20),
            SRM3 => fee_bps(18),
            SRM4 => fee_bps(16),
            SRM5 => fee_bps(14),
            SRM6 => fee_bps(12),
            MSRM => fee_bps(10),
        }
    }

    #[inline]
    pub fn taker_fee(self, pc_qty: u64) -> u64 {
        let rate = self.taker_rate();
        let exact_fee: U64F64 = rate.mul_u64(pc_qty);
        exact_fee.floor() + ((exact_fee.frac_part() != 0) as u64)
    }

    #[inline]
    pub fn remove_taker_fee(self, pc_qty_incl_fee: u64) -> u64 {
        let rate = self.taker_rate();
        U64F64::from_int(pc_qty_incl_fee)
            .div(U64F64::ONE.add(rate))
            .try_into()
            .unwrap()
    }
}

/// Referrer's share of a fee amount: one fifth of `amount`, rounded down.
#[inline]
pub fn referrer_rebate(amount: u64) -> u64 {
    const REFERRER_SHARE_DIVISOR: u64 = 5;
    amount / REFERRER_SHARE_DIVISOR
}

// Property-based tests for the fee arithmetic; compiled only in test builds.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // For any taker tier, maker tier and non-zero quantity, the taker fee must
        // strictly exceed the maker rebate plus the referrer rebate.
        #[test]
        fn positive_net_fees(tt: FeeTier, mt: FeeTier, qty in 1..=std::u64::MAX) {
            let fee = tt.taker_fee(qty);
            let rebate = mt.maker_rebate(qty) + referrer_rebate(fee);
            assert!(fee > rebate);
            // Net fee scaled by 10_000 so it can be compared against 3 * qty,
            // i.e. against 3 bps of the quantity.
            // NOTE(review): despite the `_u64f64` suffixes these are plain u128
            // integers, not `U64F64` values.
            let net_bps_u64f64 = (fee - rebate) as u128 * 10_000;
            let three_bps = (qty as u128) * 3;
            // Slack of 2^32 added to the left-hand side of the comparison below.
            let dust_qty_u64f64 = 1 << 32;
            assert!(net_bps_u64f64 + dust_qty_u64f64 > three_bps, "{:x}, {:x}, {:x}", qty, net_bps_u64f64, three_bps);
        }

        // `fee_bps(bps)` multiplied back by 10_000 must floor to `bps` when the
        // product is exact, and to `bps - 1` when a fractional part remains.
        #[test]
        fn fee_bps_approx(bps in 1..100u64) {
            let rate = fee_bps(bps);
            let rate_bps: U64F64 = rate.mul_u64(10_000);
            let rate_bps_int: u64 = rate_bps.floor();
            let rate_bps_frac: u64 = rate_bps.frac_part();
            let inexact = rate_bps_frac != 0;
            assert!(rate_bps_int == bps - (inexact as u64));
        }

        // The amount `remove_taker_fee` strips from a quantity must equal the
        // taker fee on the remaining quantity, or exceed it by exactly one.
        #[test]
        fn market_order_cannot_cheat(tier: FeeTier, qty: u64) {
            let qty_without_fees = tier.remove_taker_fee(qty);
            // Widen to i128 so the subtraction below cannot overflow.
            let required_fee = tier.taker_fee(qty_without_fees) as i128;
            let actual_fee = qty as i128 - qty_without_fees as i128;
            assert!([required_fee + 1, required_fee].contains(&actual_fee),
                    "actual_fee = {}, required_fee = {}",
                    actual_fee, required_fee);
        }

        // Adding the taker fee and then removing it must return the original
        // quantity to within one unit. The quantity range is capped at
        // `u64::MAX >> 1`.
        #[test]
        fn test_add_remove_fees(tier: FeeTier, qty in 1..=(std::u64::MAX >> 1)) {
            let qty_with_fees = qty + tier.taker_fee(qty);
            let qty2 = tier.remove_taker_fee(qty_with_fees);
            assert!([-1, 0, 1].contains(&(qty as i128 - qty2 as i128)))
        }
    }
}