use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
/// Strategy for computing the fee charged on a trade.
///
/// Every model receives the same three inputs and decides for itself which of them
/// matter. `quantity` may be negative; the models defined in this module charge on
/// its absolute value.
pub trait FeeModel {
    /// Returns the fee for trading `quantity` at `price` on an instrument whose
    /// contract size is `contract_size`.
    fn compute_fee(&self, price: Decimal, quantity: Decimal, contract_size: Decimal) -> Decimal;
}
/// Fee model that never charges anything.
#[derive(
    Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize,
)]
pub struct ZeroFeeModel;

impl FeeModel for ZeroFeeModel {
    /// Always returns `Decimal::ZERO`; none of the inputs are consulted.
    fn compute_fee(&self, _: Decimal, _: Decimal, _: Decimal) -> Decimal {
        Decimal::ZERO
    }
}
/// Fee model that charges a flat commission for every contract traded.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
pub struct PerContractFeeModel {
    /// Commission charged per unit of traded quantity.
    pub commission_per_contract: Decimal,
}

impl FeeModel for PerContractFeeModel {
    /// Returns `commission_per_contract * |quantity|`.
    ///
    /// Price and contract size are ignored. Because the absolute quantity is used,
    /// buying and selling the same size cost the same.
    fn compute_fee(&self, _: Decimal, quantity: Decimal, _: Decimal) -> Decimal {
        let contracts = quantity.abs();
        self.commission_per_contract * contracts
    }
}
/// Fee model that charges a fixed rate on the traded value `price * |quantity|`.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
pub struct PercentageFeeModel {
    /// Fraction of the traded value charged as a fee (e.g. `0.001` for 0.1%).
    pub rate: Decimal,
}

impl FeeModel for PercentageFeeModel {
    /// Returns `rate * price * |quantity|`.
    ///
    /// `contract_size` is not applied. NOTE(review): if callers pass a price quoted per
    /// underlying unit (rather than per contract), the traded value would need scaling
    /// by `contract_size` — confirm against callers.
    fn compute_fee(&self, price: Decimal, quantity: Decimal, _: Decimal) -> Decimal {
        let unsigned_quantity = quantity.abs();
        // Keep the left-to-right grouping (rate * price) * |quantity| so Decimal
        // rounding is identical for very large or very precise operands.
        self.rate * price * unsigned_quantity
    }
}
/// Serializable choice of fee model.
///
/// Uses serde's default externally tagged representation, e.g.
/// `{"Percentage":{"rate":"0.001"}}` or `{"Zero":null}`. The default is
/// [`FeeModelConfig::Zero`], so a config field marked `#[serde(default)]` may be omitted.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
pub enum FeeModelConfig {
    /// No fees; see [`ZeroFeeModel`].
    Zero(ZeroFeeModel),
    /// Flat commission per contract; see [`PerContractFeeModel`].
    PerContract(PerContractFeeModel),
    /// Rate applied to `price * |quantity|`; see [`PercentageFeeModel`].
    Percentage(PercentageFeeModel),
}

impl Default for FeeModelConfig {
    /// Returns the zero-fee model.
    fn default() -> Self {
        Self::Zero(ZeroFeeModel)
    }
}

impl FeeModel for FeeModelConfig {
    /// Forwards to the model wrapped by the active variant.
    fn compute_fee(&self, price: Decimal, quantity: Decimal, contract_size: Decimal) -> Decimal {
        // Pick the wrapped model once so the argument list is forwarded in one place.
        let selected: &dyn FeeModel = match self {
            Self::Zero(model) => model,
            Self::PerContract(model) => model,
            Self::Percentage(model) => model,
        };
        selected.compute_fee(price, quantity, contract_size)
    }
}
#[cfg(test)]
// Unwrapping is fine here: a malformed literal or a serde failure should fail the test loudly.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Parses a decimal literal (test-only; panics on malformed input).
    fn dec(text: &str) -> Decimal {
        text.parse().unwrap()
    }

    /// Serializes `cfg`, checks the exact JSON, then checks it deserializes back to `cfg`.
    fn assert_json_roundtrip(cfg: FeeModelConfig, expected_json: &str) {
        let encoded = serde_json::to_string(&cfg).unwrap();
        assert_eq!(encoded, expected_json);
        let decoded: FeeModelConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, cfg);
    }

    #[test]
    fn zero_fee_model_always_returns_zero() {
        let cases = [
            (dec("100"), dec("5"), dec("100")),
            (Decimal::ZERO, Decimal::ZERO, Decimal::ONE),
        ];
        for (price, quantity, contract_size) in cases {
            assert_eq!(
                ZeroFeeModel.compute_fee(price, quantity, contract_size),
                Decimal::ZERO
            );
        }
    }

    #[test]
    fn per_contract_fee_charges_by_quantity() {
        let model = PerContractFeeModel {
            commission_per_contract: dec("0.65"),
        };
        let fee = model.compute_fee(dec("100"), dec("10"), dec("100"));
        assert_eq!(fee, dec("6.5"));
    }

    #[test]
    fn per_contract_fee_uses_abs_quantity() {
        let model = PerContractFeeModel {
            commission_per_contract: dec("0.65"),
        };
        let sell_fee = model.compute_fee(dec("100"), dec("-10"), dec("100"));
        let buy_fee = model.compute_fee(dec("100"), dec("10"), dec("100"));
        assert_eq!(sell_fee, buy_fee);
    }

    #[test]
    fn fee_model_config_zero_dispatches() {
        let config = FeeModelConfig::Zero(ZeroFeeModel);
        let fee = config.compute_fee(dec("100"), dec("5"), dec("100"));
        assert_eq!(fee, Decimal::ZERO);
    }

    #[test]
    fn fee_model_config_per_contract_dispatches() {
        let model = PerContractFeeModel {
            commission_per_contract: dec("0.65"),
        };
        let (price, quantity, contract_size) = (dec("100"), dec("10"), dec("100"));
        let via_config =
            FeeModelConfig::PerContract(model).compute_fee(price, quantity, contract_size);
        assert_eq!(via_config, model.compute_fee(price, quantity, contract_size));
    }

    #[test]
    fn fee_model_config_default_is_zero() {
        assert_eq!(FeeModelConfig::default(), FeeModelConfig::Zero(ZeroFeeModel));
    }

    #[test]
    fn percentage_fee_computes_rate_times_notional() {
        let model = PercentageFeeModel { rate: dec("0.001") };
        let fee = model.compute_fee(dec("100"), dec("10"), dec("1"));
        assert_eq!(fee, dec("1"));
    }

    #[test]
    fn percentage_fee_uses_abs_quantity() {
        let model = PercentageFeeModel { rate: dec("0.001") };
        let sell_fee = model.compute_fee(dec("100"), dec("-10"), dec("1"));
        let buy_fee = model.compute_fee(dec("100"), dec("10"), dec("1"));
        assert_eq!(sell_fee, buy_fee);
    }

    #[test]
    fn fee_model_config_percentage_dispatches() {
        let model = PercentageFeeModel { rate: dec("0.001") };
        let (price, quantity, contract_size) = (dec("100"), dec("10"), dec("1"));
        let via_config =
            FeeModelConfig::Percentage(model).compute_fee(price, quantity, contract_size);
        assert_eq!(via_config, model.compute_fee(price, quantity, contract_size));
    }

    #[test]
    fn zero_fee_model_serde_roundtrip() {
        assert_json_roundtrip(FeeModelConfig::Zero(ZeroFeeModel), r#"{"Zero":null}"#);
    }

    #[test]
    fn fee_model_config_default_when_field_omitted() {
        #[derive(Deserialize)]
        struct Wrapper {
            #[serde(default)]
            fee_model: FeeModelConfig,
        }
        let wrapper: Wrapper = serde_json::from_str(r#"{}"#).unwrap();
        assert_eq!(wrapper.fee_model, FeeModelConfig::Zero(ZeroFeeModel));
    }

    #[test]
    fn percentage_fee_model_serde_roundtrip() {
        let config = FeeModelConfig::Percentage(PercentageFeeModel { rate: dec("0.001") });
        assert_json_roundtrip(config, r#"{"Percentage":{"rate":"0.001"}}"#);
    }

    #[test]
    fn per_contract_fee_model_serde_roundtrip() {
        let config = FeeModelConfig::PerContract(PerContractFeeModel {
            commission_per_contract: dec("0.65"),
        });
        assert_json_roundtrip(config, r#"{"PerContract":{"commission_per_contract":"0.65"}}"#);
    }
}