use std::fmt;
use rust_decimal::Decimal;
use crate::MetricsError;
/// How aggressively the raw Kelly fraction is applied.
///
/// The multiplier for each variant is returned by `KellyMode::scale`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KellyMode {
    /// Use the raw Kelly fraction unchanged.
    Full,
    /// Use half of the raw Kelly fraction.
    Half,
    /// Use a quarter of the raw Kelly fraction.
    Quarter,
}
impl KellyMode {
pub fn scale(&self) -> Decimal {
match self {
Self::Full => Decimal::ONE,
Self::Half => Decimal::new(5, 1), Self::Quarter => Decimal::new(25, 2), }
}
}
impl fmt::Display for KellyMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Full => write!(f, "Full"),
Self::Half => write!(f, "Half"),
Self::Quarter => write!(f, "Quarter"),
}
}
}
/// Fraction of capital suggested by the Kelly criterion, wrapped as a newtype over `Decimal`.
///
/// The inner value is private, so instances are built inside this module.
/// `compute_kelly_fraction` only ever produces values in `[0, 1]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct KellyFraction(Decimal);
impl KellyFraction {
pub fn zero() -> Self {
Self(Decimal::ZERO)
}
pub fn as_decimal(&self) -> Decimal {
self.0
}
pub fn is_zero(&self) -> bool {
self.0 == Decimal::ZERO
}
}
impl fmt::Display for KellyFraction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:.4}", self.0)
}
}
/// Computes the (optionally fractional) Kelly bet size for a win/lose strategy.
///
/// Uses the classic formula `f* = p - q / b`, then multiplies it by `mode.scale()`. Here:
/// - `p` is `win_rate`;
/// - `q = 1 - p`;
/// - `b` is `avg_win_loss_ratio`.
///
/// # Parameters
/// - `win_rate`: probability of a winning trade, expected in `[0, 1]`. Values outside that
///   range are clamped into it before use.
/// - `avg_win_loss_ratio`: average win divided by average (absolute) loss.
/// - `mode`: full, half, or quarter Kelly.
///
/// # Returns
/// Returns `KellyFraction::zero()` in either of these cases:
/// - `avg_win_loss_ratio` is not positive;
/// - the raw Kelly fraction is not positive (no edge).
///
/// Otherwise returns the scaled fraction, which lies in `(0, mode.scale()]` and never
/// exceeds 1.
pub fn compute_kelly_fraction(
    win_rate: Decimal,
    avg_win_loss_ratio: Decimal,
    mode: KellyMode,
) -> KellyFraction {
    if avg_win_loss_ratio <= Decimal::ZERO {
        return KellyFraction::zero();
    }
    // A probability outside [0, 1] is meaningless. Clamping keeps `loss_rate` in [0, 1], which:
    // - bounds the quotient below (the smallest positive Decimal is 1e-28), so Decimal's
    //   panicking `/` cannot overflow;
    // - keeps the result from exceeding `mode.scale()`.
    let win_rate = win_rate.max(Decimal::ZERO).min(Decimal::ONE);
    let loss_rate = Decimal::ONE - win_rate;
    let raw_fraction = win_rate - loss_rate / avg_win_loss_ratio;
    if raw_fraction <= Decimal::ZERO {
        return KellyFraction::zero();
    }
    // With win_rate in [0, 1] the raw fraction is at most 1. The cap is kept as a safeguard.
    let scaled = raw_fraction * mode.scale();
    KellyFraction(scaled.min(Decimal::ONE))
}
/// Derives the inputs for `compute_kelly_fraction` from a list of per-trade P&L values.
///
/// # Returns
/// `(win_rate_fraction, avg_win_loss_ratio, trade_count)`, where:
/// - `win_rate_fraction` is `crate::win_rate` divided by 100 (that helper is treated as
///   returning a percentage);
/// - `avg_win_loss_ratio` is `crate::avg_win` divided by the absolute value of
///   `crate::avg_loss`;
/// - `trade_count` is `trade_pnls.len()`.
///
/// # Errors
/// - `MetricsError::InsufficientData` if `trade_pnls` is empty.
/// - `MetricsError::DivisionByZero` in either of these cases:
///   - `crate::avg_loss` reports insufficient data (no losing trades);
///   - the average loss is zero or too small to divide by.
/// - Any other error from `crate::win_rate`, `crate::avg_win` or `crate::avg_loss` is
///   propagated unchanged. NOTE(review): e.g. `avg_win` presumably errors when there are no
///   winning trades — confirm.
pub fn compute_kelly_inputs(
    trade_pnls: &[Decimal],
) -> Result<(Decimal, Decimal, usize), MetricsError> {
    if trade_pnls.is_empty() {
        return Err(MetricsError::InsufficientData {
            required: 1,
            actual: 0,
        });
    }
    let wr_pct = crate::win_rate(trade_pnls)?;
    let wr_fraction = wr_pct / Decimal::from(100);
    let avg_w = crate::avg_win(trade_pnls)?;
    let avg_l = match crate::avg_loss(trade_pnls) {
        Ok(v) => v.abs(),
        Err(MetricsError::InsufficientData { .. }) => {
            return Err(MetricsError::DivisionByZero {
                context: "no losing trades — cannot compute win/loss ratio",
            });
        }
        Err(e) => return Err(e),
    };
    // Decimal's `/` operator panics on a zero divisor (or on overflow). Surface that case
    // as an error instead of crashing the caller.
    let ratio = avg_w.checked_div(avg_l).ok_or(MetricsError::DivisionByZero {
        context: "average loss is zero or too small — cannot compute win/loss ratio",
    })?;
    Ok((wr_fraction, ratio, trade_pnls.len()))
}
// Unit tests for this module are loaded from `kelly_tests.rs` and compiled only in test builds.
#[cfg(test)]
#[path = "kelly_tests.rs"]
mod tests;