use super::{Error, Notional, ParamKind};
use rust_decimal::Decimal;
use std::fmt::{Display, Formatter};
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Leverage(u16);
impl Leverage {
pub const SCALE: u16 = 10;
pub const MIN: u16 = 1;
pub const MAX: u16 = 3000;
pub const STEP: f32 = 0.1;
const MIN_RAW: u16 = Self::MIN * Self::SCALE;
const MAX_RAW: u16 = Self::MAX * Self::SCALE;
pub fn from_raw(raw: u16) -> Result<Self, Error> {
if !(Self::MIN_RAW..=Self::MAX_RAW).contains(&raw) {
return Err(Error::InvalidLeverage);
}
Ok(Self(raw))
}
pub const fn raw(&self) -> u16 {
self.0
}
pub fn from_u16(multiplier: u16) -> Result<Self, Error> {
let raw = multiplier.checked_mul(Self::SCALE).ok_or(Error::Overflow {
param: ParamKind::Leverage,
})?;
Self::from_raw(raw)
}
pub fn from_f64(multiplier: f64) -> Result<Self, Error> {
if !multiplier.is_finite() {
return Err(Error::InvalidLeverage);
}
let scaled = multiplier * f64::from(Self::SCALE);
let rounded = scaled.round();
if (scaled - rounded).abs() > 1e-9 {
return Err(Error::InvalidLeverage);
}
let raw_i64 = rounded as i64;
if raw_i64 < i64::from(Self::MIN_RAW) || raw_i64 > i64::from(Self::MAX_RAW) {
return Err(Error::InvalidLeverage);
}
Self::from_raw(raw_i64 as u16)
}
pub fn value(&self) -> f32 {
f32::from(self.0) / f32::from(Self::SCALE)
}
pub fn calculate_margin_required(&self, notional: Notional) -> Result<Notional, Error> {
let raw = Decimal::from(self.raw());
let scale = Decimal::from(Self::SCALE);
let numerator = notional
.to_decimal()
.checked_mul(scale)
.ok_or(Error::Overflow {
param: ParamKind::Notional,
})?;
let margin = numerator.checked_div(raw).ok_or(Error::Overflow {
param: ParamKind::Notional,
})?;
Ok(Notional::new_unchecked(margin))
}
}
impl Display for Leverage {
fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
let integer = self.0 / Self::SCALE;
let fractional = self.0 % Self::SCALE;
if fractional == 0 {
write!(formatter, "{integer}")
} else {
write!(formatter, "{integer}.{fractional}")
}
}
}
#[cfg(test)]
mod tests {
    use super::Leverage;
    use crate::param::{Error, Notional, ParamKind};

    #[test]
    fn from_u16_creates_valid_leverage() {
        let lev = Leverage::from_u16(100).expect("leverage must be valid");
        assert_eq!(lev.value(), 100.0);
    }

    #[test]
    fn from_u16_scales_value() {
        // The multiplier must be stored scaled by `SCALE`, not as-is.
        let lev = Leverage::from_u16(100).expect("leverage must be valid");
        assert_eq!(lev.raw(), 100 * Leverage::SCALE);
        assert_eq!(lev.raw(), 1000);
    }

    #[test]
    fn from_raw_and_raw_roundtrip() {
        let lev = Leverage::from_raw(1005).expect("leverage must be valid");
        assert_eq!(lev.raw(), 1005);
        assert_eq!(lev.value(), 100.5);
    }

    #[test]
    fn from_raw_supports_fractional_table() {
        let cases = [
            (11_u16, 1.1_f32),
            (1005_u16, 100.5_f32),
            (29999_u16, 2999.9_f32),
        ];
        for (raw, expected) in cases {
            let lev = Leverage::from_raw(raw).expect("leverage must be valid");
            assert_eq!(lev.value(), expected);
            assert_eq!(lev.raw(), raw);
        }
    }

    #[test]
    fn from_raw_boundaries_table() {
        let cases = [(10_u16, 1.0_f32), (30000_u16, 3000.0_f32)];
        for (raw, expected) in cases {
            let lev = Leverage::from_raw(raw).expect("boundary leverage must be valid");
            assert_eq!(lev.value(), expected);
            assert_eq!(lev.raw(), raw);
        }
    }

    #[test]
    fn from_raw_rejects_invalid_range_table() {
        let cases = [0_u16, 9_u16, 30001_u16, u16::MAX];
        for raw in cases {
            assert_eq!(Leverage::from_raw(raw), Err(Error::InvalidLeverage));
        }
    }

    #[test]
    fn from_u16_rejects_zero() {
        assert_eq!(Leverage::from_u16(0), Err(Error::InvalidLeverage));
    }

    #[test]
    fn from_u16_rejects_values_above_business_limit() {
        assert_eq!(Leverage::from_u16(3001), Err(Error::InvalidLeverage));
    }

    #[test]
    fn from_u16_reports_overflow() {
        // 7000 * SCALE exceeds u16::MAX, so the failure is an overflow, not a range error.
        assert_eq!(
            Leverage::from_u16(7000),
            Err(Error::Overflow {
                param: ParamKind::Leverage
            })
        );
    }

    #[test]
    fn from_float_creates_fractional_values_table() {
        let cases = [
            (1.1_f64, 1.1_f32),
            (100.5_f64, 100.5_f32),
            (2999.9_f64, 2999.9_f32),
        ];
        for (input, expected) in cases {
            let leverage = Leverage::from_f64(input).expect("fractional leverage must be valid");
            assert_eq!(leverage.value(), expected);
        }
    }

    #[test]
    fn from_float_rejects_invalid_step_or_range_table() {
        let cases = [
            0.0_f64,
            0.9_f64,
            1.11_f64,
            3000.1_f64,
            f64::NAN,
            f64::INFINITY,
        ];
        for input in cases {
            assert_eq!(Leverage::from_f64(input), Err(Error::InvalidLeverage));
        }
    }

    #[test]
    fn from_float_rejects_negative_and_huge_magnitudes_table() {
        let cases = [-1.0_f64, -0.0_f64, -3000.0_f64, 1e300_f64, f64::NEG_INFINITY];
        for input in cases {
            assert_eq!(Leverage::from_f64(input), Err(Error::InvalidLeverage));
        }
    }

    #[test]
    fn boundaries_are_valid_in_table() {
        let cases = [(Leverage::MIN, 1.0_f32), (Leverage::MAX, 3000.0_f32)];
        for (input, expected) in cases {
            let leverage = Leverage::from_u16(input).expect("boundary leverage must be valid");
            assert_eq!(leverage.value(), expected);
        }
    }

    #[test]
    fn calculate_margin_required_calculates_expected_value() {
        let lev = Leverage::from_u16(100).expect("leverage must be valid");
        let notional = Notional::from_str("1000").expect("notional must be valid");
        assert_eq!(
            lev.calculate_margin_required(notional)
                .expect("margin must be valid")
                .to_string(),
            "10"
        );
    }

    #[test]
    fn display_omits_trailing_fractional_zeroes() {
        assert_eq!(
            Leverage::from_u16(100)
                .expect("leverage must be valid")
                .to_string(),
            "100"
        );
        assert_eq!(
            Leverage::from_f64(100.5)
                .expect("leverage must be valid")
                .to_string(),
            "100.5"
        );
        assert_eq!(
            Leverage::from_u16(2500)
                .expect("leverage must be valid")
                .to_string(),
            "2500"
        );
    }

    #[test]
    fn display_renders_fractional_and_boundary_values_table() {
        let cases = [
            (10_u16, "1"),
            (11_u16, "1.1"),
            (29999_u16, "2999.9"),
            (30000_u16, "3000"),
        ];
        for (raw, expected) in cases {
            let lev = Leverage::from_raw(raw).expect("leverage must be valid");
            assert_eq!(lev.to_string(), expected);
        }
    }

    #[test]
    fn supports_max_business_leverage() {
        let leverage = Leverage::from_u16(3000).expect("leverage must be valid");
        assert_eq!(leverage.value(), 3000.0);
    }

    #[test]
    fn from_f64_raw_from_raw_value_is_symmetric() {
        let cases = [1.1_f64, 100.5_f64, 2999.9_f64, 1.0_f64, 3000.0_f64];
        for input in cases {
            let from_float = Leverage::from_f64(input).expect("leverage must be valid");
            let from_raw =
                Leverage::from_raw(from_float.raw()).expect("raw leverage must be valid");
            assert_eq!(from_raw.raw(), from_float.raw());
            assert_eq!(from_raw.value(), from_float.value());
        }
    }
}