use std::{convert::TryFrom, fmt, num::NonZeroU64, ops::Add};
use serde::{Deserialize, Serialize, de};
use super::{
SATS_PER_BTC,
error::{MarginValidationError, TradeValidationError},
leverage::Leverage,
price::Price,
quantity::Quantity,
trade::TradeSide,
};
/// A margin amount stored as a whole number.
///
/// The calculations below multiply by `SATS_PER_BTC`, so the wrapped integer is
/// presumably a satoshi count. Every constructor in this impl, and the
/// `TryFrom` conversions, keep the value at or above [`Margin::MIN`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Margin(u64);

impl Margin {
    /// Smallest valid margin (`1`).
    pub const MIN: Self = Self(1);

    /// Creates a margin from any value convertible to `f64`, clamping rather than failing.
    ///
    /// The input is rounded to the nearest integer. Negative values and NaN end
    /// up at [`Margin::MIN`]; values beyond `u64::MAX` saturate to `u64::MAX`.
    pub fn bounded<T>(value: T) -> Self
    where
        T: Into<f64>,
    {
        let as_f64: f64 = value.into();
        // `f64::max` returns the non-NaN operand, so NaN collapses to 0.0 here,
        // and the saturating `as` cast clamps anything above `u64::MAX`.
        let rounded = as_f64.round().max(0.0) as u64;
        Self(rounded.max(Self::MIN.0))
    }

    /// Returns the margin as a `u64`.
    pub fn as_u64(&self) -> u64 {
        self.0
    }

    /// Returns the margin as an `i64`, saturating at `i64::MAX`.
    ///
    /// A plain `as` cast would wrap margins above `i64::MAX` to negative
    /// numbers (e.g. `u64::MAX` would become `-1`).
    pub fn as_i64(&self) -> i64 {
        i64::try_from(self.0).unwrap_or(i64::MAX)
    }

    /// Returns the margin as an `f64` (values above 2^53 may be rounded).
    pub fn as_f64(&self) -> f64 {
        self.0 as f64
    }

    /// Computes the margin for a position of `quantity` opened at `price` with `leverage`.
    ///
    /// Evaluates `quantity * (SATS_PER_BTC / (price * leverage))` and rounds the
    /// result up to the next whole unit.
    ///
    /// # Panics
    ///
    /// Panics if the rounded-up result is below [`Margin::MIN`], i.e. if the
    /// computed value is zero, negative, or NaN.
    pub fn calculate(quantity: Quantity, price: Price, leverage: Leverage) -> Self {
        let margin = quantity.as_f64() * (SATS_PER_BTC / (price.as_f64() * leverage.as_f64()));
        Self::try_from(margin.ceil() as u64).expect("must result in valid `Margin`")
    }

    /// Estimates the margin implied by a liquidation price.
    ///
    /// Uses `|1 / liquidation - 1 / price| * SATS_PER_BTC * quantity`, rounded
    /// up to the next whole unit.
    ///
    /// # Errors
    ///
    /// * [`TradeValidationError::LiquidationNotBelowPriceForLong`] if `side` is
    ///   [`TradeSide::Buy`] and `liquidation >= price`.
    /// * [`TradeValidationError::LiquidationNotAbovePriceForShort`] if `side` is
    ///   [`TradeSide::Sell`] and `liquidation <= price`.
    ///
    /// # Panics
    ///
    /// Panics if, despite the checks above, the inverse-price gap is not
    /// strictly positive in floating point, or if the rounded result is not a
    /// valid `Margin`.
    pub fn est_from_liquidation_price(
        side: TradeSide,
        quantity: Quantity,
        price: Price,
        liquidation: Price,
    ) -> Result<Self, TradeValidationError> {
        match side {
            TradeSide::Buy if liquidation >= price => {
                return Err(TradeValidationError::LiquidationNotBelowPriceForLong {
                    liquidation,
                    price,
                });
            }
            TradeSide::Sell if liquidation <= price => {
                return Err(TradeValidationError::LiquidationNotAbovePriceForShort {
                    liquidation,
                    price,
                });
            }
            _ => {}
        }

        let inverse_price = 1.0 / price.as_f64();
        let inverse_liquidation = 1.0 / liquidation.as_f64();
        // `b` is the gap between the inverse prices, oriented per side so that a
        // correctly placed liquidation price yields a positive value.
        let b = match side {
            TradeSide::Buy => inverse_liquidation - inverse_price,
            TradeSide::Sell => inverse_price - inverse_liquidation,
        };
        assert!(b > 0.0, "'b' must be positive from validations above");

        // Unrounded margin; it is rounded up to a whole unit below.
        let raw_margin = b * SATS_PER_BTC * quantity.as_f64();
        let margin =
            Margin::try_from(raw_margin.ceil() as u64).expect("must be valid `Margin`");
        Ok(margin)
    }
}
impl Add for Margin {
    type Output = Self;

    /// Adds two margins.
    ///
    /// # Panics
    ///
    /// Panics if the sum overflows `u64`, regardless of build profile. A plain
    /// `+` panics only when overflow checks are enabled and otherwise wraps.
    /// A wrapped sum could even be `0`, which is below [`Margin::MIN`].
    fn add(self, other: Self) -> Self::Output {
        let sum = self
            .0
            .checked_add(other.0)
            .expect("`Margin` addition overflowed `u64`");
        Margin(sum)
    }
}
impl From<Margin> for u64 {
fn from(value: Margin) -> Self {
value.0
}
}
impl From<Margin> for i64 {
fn from(value: Margin) -> Self {
value.0 as i64
}
}
impl From<Margin> for f64 {
fn from(value: Margin) -> Self {
value.0 as f64
}
}
impl From<NonZeroU64> for Margin {
fn from(value: NonZeroU64) -> Self {
Margin(value.get())
}
}
impl TryFrom<u8> for Margin {
type Error = MarginValidationError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
Self::try_from(value as u64)
}
}
impl TryFrom<u16> for Margin {
type Error = MarginValidationError;
fn try_from(value: u16) -> Result<Self, Self::Error> {
Self::try_from(value as u64)
}
}
impl TryFrom<u32> for Margin {
type Error = MarginValidationError;
fn try_from(value: u32) -> Result<Self, Self::Error> {
Self::try_from(value as u64)
}
}
impl TryFrom<u64> for Margin {
    type Error = MarginValidationError;

    /// Accepts any `u64` at or above [`Margin::MIN`].
    ///
    /// # Errors
    ///
    /// Returns [`MarginValidationError::TooLow`] carrying the rejected value
    /// when `value` is below [`Margin::MIN`].
    fn try_from(value: u64) -> Result<Self, Self::Error> {
        if value >= Self::MIN.0 {
            Ok(Self(value))
        } else {
            Err(MarginValidationError::TooLow { value })
        }
    }
}
impl TryFrom<i8> for Margin {
type Error = MarginValidationError;
fn try_from(value: i8) -> Result<Self, Self::Error> {
Self::try_from(value.max(0) as u64)
}
}
impl TryFrom<i16> for Margin {
type Error = MarginValidationError;
fn try_from(value: i16) -> Result<Self, Self::Error> {
Self::try_from(value.max(0) as u64)
}
}
impl TryFrom<i32> for Margin {
type Error = MarginValidationError;
fn try_from(value: i32) -> Result<Self, Self::Error> {
Self::try_from(value.max(0) as u64)
}
}
impl TryFrom<i64> for Margin {
type Error = MarginValidationError;
fn try_from(value: i64) -> Result<Self, Self::Error> {
Self::try_from(value.max(0) as u64)
}
}
impl TryFrom<usize> for Margin {
    type Error = MarginValidationError;

    /// Widens a `usize` to `u64` and validates it as a margin.
    fn try_from(value: usize) -> Result<Self, Self::Error> {
        let widened = value as u64;
        <Self as TryFrom<u64>>::try_from(widened)
    }
}
impl TryFrom<isize> for Margin {
    type Error = MarginValidationError;

    /// Clamps negative `isize` values to zero (which then fails validation)
    /// and widens everything else to `u64`.
    fn try_from(value: isize) -> Result<Self, Self::Error> {
        let clamped = if value < 0 { 0 } else { value as u64 };
        Self::try_from(clamped)
    }
}
impl TryFrom<f32> for Margin {
    type Error = MarginValidationError;

    /// Widens to `f64` (losslessly) and defers to the `f64` conversion.
    fn try_from(value: f32) -> Result<Self, Self::Error> {
        let widened = f64::from(value);
        <Self as TryFrom<f64>>::try_from(widened)
    }
}
impl TryFrom<f64> for Margin {
    type Error = MarginValidationError;

    /// Converts a whole, finite number into a [`Margin`].
    ///
    /// Negative whole numbers are clamped to `0` before the `u64` validation,
    /// so they are rejected as too low.
    ///
    /// # Errors
    ///
    /// * [`MarginValidationError::NotFinite`] if `value` is NaN or infinite.
    /// * [`MarginValidationError::NotAnInteger`] if `value` has a fractional part.
    /// * Whatever the `u64` conversion returns for values below [`Margin::MIN`].
    fn try_from(value: f64) -> Result<Self, Self::Error> {
        // Finiteness must be checked first: `fract()` of NaN or ±inf is NaN,
        // which would otherwise be misreported as `NotAnInteger`.
        if !value.is_finite() {
            return Err(MarginValidationError::NotFinite);
        }
        if value.fract() != 0.0 {
            return Err(MarginValidationError::NotAnInteger { value });
        }
        // NOTE(review): finite values above `u64::MAX` saturate to `u64::MAX`
        // via the `as` cast rather than being rejected.
        Self::try_from(value.max(0.) as u64)
    }
}
impl fmt::Display for Margin {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl Serialize for Margin {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_u64(self.0)
}
}
impl<'de> Deserialize<'de> for Margin {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let margin_u64 = u64::deserialize(deserializer)?;
Margin::try_from(margin_u64).map_err(|e| de::Error::custom(e.to_string()))
}
}
#[cfg(test)]
mod tests {
    use super::super::trade::util as trade_util;
    use super::*;

    /// Feeds the liquidation price estimated for `leverage` back into
    /// `Margin::est_from_liquidation_price`. Asserts the result is within 999
    /// of `Margin::calculate` for the same inputs.
    fn assert_liquidation_round_trip(
        side: TradeSide,
        quantity: Quantity,
        entry_price: Price,
        leverage: Leverage,
    ) {
        let liquidation_price =
            trade_util::estimate_liquidation_price(side, quantity, entry_price, leverage);
        let estimated =
            Margin::est_from_liquidation_price(side, quantity, entry_price, liquidation_price)
                .expect("should calculate valid margin");
        let direct = Margin::calculate(quantity, entry_price, leverage);
        let difference = (estimated.as_i64() - direct.as_i64()).abs();
        assert!(
            difference <= 999,
            "Margin difference too large: calculated {} vs expected {}",
            estimated.as_u64(),
            direct.as_u64()
        );
    }

    #[test]
    fn test_calculate_margin() {
        let quantity = Quantity::try_from(5).unwrap();
        let price = Price::try_from(95000).unwrap();

        // (leverage, expected margin) for a fixed quantity and price.
        let expectations = [(1.0, 5264), (2.0, 2632), (50.0, 106), (100.0, 53)];
        for &(raw_leverage, expected_margin) in expectations.iter() {
            let leverage = Leverage::try_from(raw_leverage).unwrap();
            assert_eq!(
                Margin::calculate(quantity, price, leverage).as_u64(),
                expected_margin
            );
        }

        // Extremes of the input domains.
        assert_eq!(
            Margin::calculate(Quantity::MIN, Price::MAX, Leverage::MAX),
            Margin::MIN
        );
        assert_eq!(
            Margin::calculate(Quantity::MAX, Price::MIN, Leverage::MIN).as_u64(),
            50_000_000_000_000
        );
    }

    #[test]
    fn test_margin_from_liquidation_price_calculation() {
        let quantity = Quantity::try_from(1_000).unwrap();
        let entry_price = Price::try_from(100_000).unwrap();

        let scenarios = [
            (TradeSide::Buy, Leverage::MIN),
            (TradeSide::Buy, Leverage::MAX),
            (TradeSide::Sell, Leverage::MIN),
        ];
        for &(side, leverage) in scenarios.iter() {
            assert_liquidation_round_trip(side, quantity, entry_price, leverage);
        }
    }
}