use crate::math::quantities::{BaseLots, BasisPoints, Constant, ScalarBounds, WrapperNum};
/// One row of a leverage schedule.
///
/// A position whose size is at most `upper_bound_size` is covered by this tier (or by an
/// earlier one); between two tiers the values are linearly interpolated by `LeverageTiers`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LeverageTier {
    /// Largest position size, in base lots, that this tier covers.
    pub upper_bound_size: BaseLots,
    /// Leverage value reported exactly at `upper_bound_size`.
    pub max_leverage: Constant,
    /// Risk factor for limit orders at this tier, in basis points.
    /// NOTE(review): how callers apply this factor is not visible here — confirm.
    pub limit_order_risk_factor: BasisPoints,
}
/// Fixed-size schedule of exactly four [`LeverageTier`]s.
///
/// The tiers are expected to be ordered by increasing `upper_bound_size`; that ordering is
/// checked by `LeverageTiers::new` but not by `LeverageTiers::new_unchecked`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LeverageTiers {
    tiers: [LeverageTier; 4],
}
impl LeverageTiers {
    /// Builds a schedule after running [`LeverageTiers::validate`] on `tiers`.
    ///
    /// # Errors
    /// Returns the message of the first validation rule that `tiers` violates.
    pub fn new(tiers: [LeverageTier; 4]) -> Result<Self, &'static str> {
        Self::validate(&tiers).map(|()| Self { tiers })
    }

    /// Wraps `tiers` as-is, skipping validation.
    pub const fn new_unchecked(tiers: [LeverageTier; 4]) -> Self {
        Self { tiers }
    }

    /// Checks every pair of adjacent tiers, in order.
    ///
    /// For each pair (earlier, later):
    /// - neither `upper_bound_size` may be zero;
    /// - the later `upper_bound_size` must be strictly larger;
    /// - the later `max_leverage` must not be larger;
    /// - the later `limit_order_risk_factor` must not be smaller.
    ///
    /// # Errors
    /// Returns a static message naming the first rule that fails.
    pub fn validate(tiers: &[LeverageTier; 4]) -> Result<(), &'static str> {
        tiers.windows(2).try_for_each(|pair| {
            let (lower, upper) = (&pair[0], &pair[1]);
            if upper.upper_bound_size == BaseLots::ZERO
                || lower.upper_bound_size == BaseLots::ZERO
            {
                Err("Leverage tier upper_bound_size cannot be zero")
            } else if upper.upper_bound_size <= lower.upper_bound_size {
                Err("Leverage tiers must have increasing upper_bound_size")
            } else if upper.max_leverage > lower.max_leverage {
                Err("Leverage tiers must have non-increasing max_leverage")
            } else if upper.limit_order_risk_factor < lower.limit_order_risk_factor {
                Err("Leverage tiers must have non-decreasing limit_order_risk_factor")
            } else {
                Ok(())
            }
        })
    }

    /// Index of the first tier whose `upper_bound_size` is at least `position_size`.
    fn covering_tier_index(&self, position_size: BaseLots) -> Option<usize> {
        self.tiers
            .iter()
            .position(|tier| position_size <= tier.upper_bound_size)
    }

    /// Leverage value for a position of `position_size` base lots.
    ///
    /// - Sizes up to the first tier's bound get the first tier's `max_leverage`.
    /// - Sizes inside a later tier are linearly interpolated between the previous tier's
    ///   `(upper_bound_size, max_leverage)` and this tier's.
    /// - Sizes beyond the last tier's bound get `Constant::new(1)`.
    pub fn get_leverage_constant(&self, position_size: BaseLots) -> Constant {
        match self.covering_tier_index(position_size) {
            Some(0) => self.tiers[0].max_leverage,
            Some(idx) => {
                let lower = &self.tiers[idx - 1];
                let upper = &self.tiers[idx];
                interpolate_leverage(
                    lower.upper_bound_size,
                    lower.max_leverage,
                    upper.upper_bound_size,
                    upper.max_leverage,
                    position_size,
                )
            }
            None => Constant::new(1),
        }
    }

    /// Limit-order risk factor for a position of `position_size` base lots.
    ///
    /// Uses the same tier selection and interpolation as
    /// [`LeverageTiers::get_leverage_constant`], applied to `limit_order_risk_factor`.
    /// Sizes beyond the last tier's bound get `BasisPoints::UPPER_BOUND`.
    pub fn get_limit_order_risk_factor(&self, position_size: BaseLots) -> BasisPoints {
        match self.covering_tier_index(position_size) {
            Some(0) => self.tiers[0].limit_order_risk_factor,
            Some(idx) => {
                let lower = &self.tiers[idx - 1];
                let upper = &self.tiers[idx];
                interpolate_limit_order_risk_factor(
                    lower.upper_bound_size,
                    lower.limit_order_risk_factor,
                    upper.upper_bound_size,
                    upper.limit_order_risk_factor,
                    position_size,
                )
            }
            None => BasisPoints::UPPER_BOUND.into(),
        }
    }

    /// Tier at `index`, or `None` when `index` is out of range.
    pub fn get(&self, index: usize) -> Option<&LeverageTier> {
        self.tiers.get(index)
    }

    /// Iterates the tiers in storage order.
    pub fn iter(&self) -> impl Iterator<Item = &LeverageTier> {
        self.tiers.as_slice().iter()
    }

    /// Number of tiers; always 4 because the storage is a fixed-size array.
    pub const fn len(&self) -> usize {
        4
    }

    /// Always `false`: a schedule always holds four tiers.
    pub const fn is_empty(&self) -> bool {
        false
    }
}
impl Default for LeverageTiers {
    /// Default four-tier schedule.
    ///
    /// | upper_bound_size (lots) | max_leverage |
    /// |-------------------------|--------------|
    /// | 1_000_000               | 20           |
    /// | 10_000_000              | 10           |
    /// | 100_000_000             | 5            |
    /// | `u32::MAX`              | 1            |
    ///
    /// Every tier uses a limit-order risk factor of 10_000 basis points. The schedule is
    /// built with `new_unchecked`; these values satisfy `LeverageTiers::validate`.
    fn default() -> Self {
        const RISK_FACTOR_BPS: u64 = 10_000;
        let tier = |upper_bound: u64, leverage: u64| LeverageTier {
            upper_bound_size: BaseLots::new(upper_bound),
            max_leverage: Constant::new(leverage),
            limit_order_risk_factor: BasisPoints::new(RISK_FACTOR_BPS),
        };
        Self::new_unchecked([
            tier(1_000_000, 20),
            tier(10_000_000, 10),
            tier(100_000_000, 5),
            tier(u64::from(u32::MAX), 1),
        ])
    }
}
/// Linearly interpolates between `(x1, y1)` and `(x2, y2)` at `x`, rounding down.
///
/// `x` is clamped to the interval spanned by `x1` and `x2` (in either order), so the result
/// always lies between `y1` and `y2`. When `x1 == x2` or `y1 == y2`, `y1` is returned.
///
/// The result is exactly `floor(y1 + (y2 - y1) * (x - x1) / (x2 - x1))` for all `u64`
/// inputs. Integer arithmetic widened to `u128` is used instead of `f64` for two reasons:
/// - float rounding can land just below an exact integer before truncation
///   (e.g. `0.57 * 100.0 == 56.99999999999999`);
/// - `f64` cannot represent every `u64`.
fn interpolate_u64(x1: u64, y1: u64, x2: u64, y2: u64, x: u64) -> u64 {
    if x1 == x2 || y1 == y2 {
        return y1;
    }
    // Progress from x1 towards x2 as the exact fraction `num / den`, clamped to [0, 1].
    // `den` is non-zero because x1 != x2.
    let (num, den) = if x1 < x2 {
        (x.clamp(x1, x2) - x1, x2 - x1)
    } else {
        (x1 - x.clamp(x2, x1), x1 - x2)
    };
    let (num, den) = (u128::from(num), u128::from(den));
    if y1 < y2 {
        // floor(y1 + d) == y1 + floor(d) for d >= 0.
        let rise = u128::from(y2 - y1) * num / den;
        // rise <= y2 - y1 since num <= den, so the cast is lossless and the sum cannot overflow.
        y1 + rise as u64
    } else {
        // floor(y1 - d) == y1 - ceil(d) for d >= 0.
        let scaled = u128::from(y1 - y2) * num;
        let fall = scaled / den + u128::from(scaled % den != 0);
        // fall <= y1 - y2 since num <= den, so the cast is lossless and the subtraction cannot underflow.
        y1 - fall as u64
    }
}
/// Linearly interpolates a leverage value between `(x1, y1)` and `(x2, y2)` at size `x`.
///
/// Typed wrapper over `interpolate_u64`, which performs the clamping and rounding on the
/// raw inner values.
fn interpolate_leverage(
    x1: BaseLots,
    y1: Constant,
    x2: BaseLots,
    y2: Constant,
    x: BaseLots,
) -> Constant {
    let (from_size, to_size, size) = (x1.as_inner(), x2.as_inner(), x.as_inner());
    let (from_leverage, to_leverage) = (y1.as_inner(), y2.as_inner());
    Constant::new(interpolate_u64(from_size, from_leverage, to_size, to_leverage, size))
}
/// Linearly interpolates a limit-order risk factor between `(x1, y1)` and `(x2, y2)` at
/// size `x`.
///
/// Typed wrapper over `interpolate_u64`, which performs the clamping and rounding on the
/// raw inner values.
fn interpolate_limit_order_risk_factor(
    x1: BaseLots,
    y1: BasisPoints,
    x2: BaseLots,
    y2: BasisPoints,
    x: BaseLots,
) -> BasisPoints {
    let (from_size, to_size, size) = (x1.as_inner(), x2.as_inner(), x.as_inner());
    let (from_bps, to_bps) = (y1.as_inner(), y2.as_inner());
    BasisPoints::new(interpolate_u64(from_size, from_bps, to_size, to_bps, size))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compact constructor for test tiers.
    fn tier(upper_bound_size: u64, max_leverage: u64, risk_factor_bps: u64) -> LeverageTier {
        LeverageTier {
            upper_bound_size: BaseLots::new(upper_bound_size),
            max_leverage: Constant::new(max_leverage),
            limit_order_risk_factor: BasisPoints::new(risk_factor_bps),
        }
    }

    #[test]
    fn test_default_tiers_are_valid() {
        let tiers = LeverageTiers::default();
        assert!(LeverageTiers::validate(&tiers.tiers).is_ok());
    }

    #[test]
    fn test_interpolation_at_boundaries() {
        let tiers = LeverageTiers::default();
        // Below and at the first bound: first tier's leverage, no interpolation.
        let leverage = tiers.get_leverage_constant(BaseLots::new(1));
        assert_eq!(leverage, Constant::new(20));
        let leverage = tiers.get_leverage_constant(BaseLots::new(1_000_000));
        assert_eq!(leverage, Constant::new(20));
        let leverage = tiers.get_leverage_constant(BaseLots::new(10_000_000));
        assert_eq!(leverage, Constant::new(10));
    }

    #[test]
    fn test_interpolation_between_tiers() {
        let tiers = LeverageTiers::default();
        // Exactly halfway between (1_000_000, 20) and (10_000_000, 10).
        let mid_point = (1_000_000 + 10_000_000) / 2;
        let leverage = tiers.get_leverage_constant(BaseLots::new(mid_point));
        assert_eq!(leverage, Constant::new(15));
    }

    #[test]
    fn test_exceeds_all_tiers() {
        let tiers = LeverageTiers::default();
        let leverage = tiers.get_leverage_constant(BaseLots::new(u32::MAX as u64 + 1));
        assert_eq!(leverage, Constant::new(1));
    }

    #[test]
    fn test_default_limit_order_risk_factor() {
        let tiers = LeverageTiers::default();
        // Every default tier has the same factor, so interpolation leaves it unchanged.
        for size in [1, 1_000_000, 5_500_000, 100_000_000] {
            assert_eq!(
                tiers.get_limit_order_risk_factor(BaseLots::new(size)),
                BasisPoints::new(10_000)
            );
        }
    }

    #[test]
    fn test_validation_increasing_sizes() {
        let invalid_tiers = [
            tier(10_000, 20, 5_000),
            tier(5_000, 10, 5_000),
            tier(20_000, 5, 7_500),
            tier(100_000, 1, 10_000),
        ];
        assert_eq!(
            LeverageTiers::new(invalid_tiers).unwrap_err(),
            "Leverage tiers must have increasing upper_bound_size"
        );
    }

    #[test]
    fn test_validation_non_increasing_leverage() {
        let invalid_tiers = [
            tier(1_000, 10, 5_000),
            tier(10_000, 20, 5_000),
            tier(20_000, 5, 7_500),
            tier(100_000, 1, 10_000),
        ];
        assert_eq!(
            LeverageTiers::new(invalid_tiers).unwrap_err(),
            "Leverage tiers must have non-increasing max_leverage"
        );
    }

    #[test]
    fn test_validation_zero_upper_bound() {
        let invalid_tiers = [
            tier(0, 20, 5_000),
            tier(10_000, 10, 5_000),
            tier(20_000, 5, 7_500),
            tier(100_000, 1, 10_000),
        ];
        assert_eq!(
            LeverageTiers::new(invalid_tiers).unwrap_err(),
            "Leverage tier upper_bound_size cannot be zero"
        );
    }

    #[test]
    fn test_validation_non_decreasing_risk_factor() {
        let invalid_tiers = [
            tier(1_000, 20, 7_500),
            tier(10_000, 10, 5_000),
            tier(20_000, 5, 7_500),
            tier(100_000, 1, 10_000),
        ];
        assert_eq!(
            LeverageTiers::new(invalid_tiers).unwrap_err(),
            "Leverage tiers must have non-decreasing limit_order_risk_factor"
        );
    }

    #[test]
    fn test_validation_accepts_well_ordered_tiers() {
        let valid_tiers = [
            tier(1_000, 20, 5_000),
            tier(10_000, 10, 5_000),
            tier(20_000, 5, 7_500),
            tier(100_000, 1, 10_000),
        ];
        assert!(LeverageTiers::new(valid_tiers).is_ok());
    }
}