use crate::compat::Instant;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use serde::{Deserialize, Serialize};
use crate::error::CorpFinanceError;
use crate::types::*;
use crate::CorpFinanceResult;
/// Inputs to [`calculate_kelly`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KellyInput {
/// Probability of a winning outcome (`p`). `calculate_kelly` rejects values
/// outside the open interval (0, 1).
pub win_probability: Rate,
/// Win/loss ratio, used as `b` in the Kelly formula `p - q / b`.
/// `calculate_kelly` rejects values that are not strictly positive.
pub win_loss_ratio: Decimal,
/// Multiplier applied to the full Kelly fraction (e.g. 0.5 for "half Kelly").
/// `calculate_kelly` rejects values outside (0, 1].
pub kelly_fraction: Rate,
/// Portfolio value used to turn the final fraction into a position amount.
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub portfolio_value: Option<Money>,
/// Optional upper bound on the final position fraction.
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_position_pct: Option<Rate>,
}
/// Result payload produced by [`calculate_kelly`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KellyOutput {
/// Full Kelly fraction `p - q / b`, clamped to the range [0, 1].
pub full_kelly_pct: Rate,
/// Clamped full Kelly multiplied by `kelly_fraction`, then limited to
/// `max_position_pct` when that cap is supplied and exceeded.
pub fractional_kelly_pct: Rate,
/// `portfolio_value * fractional_kelly_pct`; `None` when no portfolio value
/// was supplied.
pub recommended_position: Option<Money>,
/// Edge computed as `p * b - q`. Not clamped, so it can be negative.
pub edge: Rate,
/// `p * ln(1 + f * b) + q * ln(1 - f)` evaluated at the final fraction `f`;
/// zero when `f` is not strictly between 0 and 1.
pub growth_rate: Rate,
}
/// Sizes a position with the Kelly criterion, optionally scaled down and capped.
///
/// With `p = win_probability`, `q = 1 - p` and `b = win_loss_ratio`:
///
/// * full Kelly is `p - q / b`, clamped to `[0, 1]`;
/// * the fractional Kelly is the clamped value times `kelly_fraction`, then
///   limited to `max_position_pct` when that cap is supplied and exceeded;
/// * `edge` is `p * b - q` (reported unclamped);
/// * `growth_rate` is `p * ln(1 + f * b) + q * ln(1 - f)` at the final
///   fraction `f`, or zero when `f` is not strictly between 0 and 1;
/// * `recommended_position` is `portfolio_value * f` when a portfolio value
///   is supplied.
///
/// A warning is attached when the unclamped full Kelly is not positive and
/// when the cap is applied. The result is wrapped by `with_metadata` together
/// with the echoed inputs, the warnings and the elapsed time in microseconds.
///
/// # Errors
///
/// Returns `CorpFinanceError::InvalidInput` when:
///
/// * `win_probability` is not strictly between 0 and 1;
/// * `win_loss_ratio` is not strictly positive;
/// * `kelly_fraction` is not in `(0, 1]`;
/// * `max_position_pct` is supplied and negative.
pub fn calculate_kelly(input: &KellyInput) -> CorpFinanceResult<ComputationOutput<KellyOutput>> {
    let start = Instant::now();
    let mut warnings: Vec<String> = Vec::new();

    if input.win_probability <= Decimal::ZERO || input.win_probability >= Decimal::ONE {
        return Err(CorpFinanceError::InvalidInput {
            field: "win_probability".into(),
            reason: "Must be between 0 and 1 (exclusive)".into(),
        });
    }
    if input.win_loss_ratio <= Decimal::ZERO {
        return Err(CorpFinanceError::InvalidInput {
            field: "win_loss_ratio".into(),
            reason: "Must be positive".into(),
        });
    }
    if input.kelly_fraction <= Decimal::ZERO || input.kelly_fraction > Decimal::ONE {
        return Err(CorpFinanceError::InvalidInput {
            field: "kelly_fraction".into(),
            reason: "Must be between 0 (exclusive) and 1 (inclusive)".into(),
        });
    }
    // A negative cap would replace the sizing with a negative fraction and
    // produce a negative recommended position, so reject it up front.
    if let Some(max_pct) = input.max_position_pct {
        if max_pct < Decimal::ZERO {
            return Err(CorpFinanceError::InvalidInput {
                field: "max_position_pct".into(),
                reason: "Must be non-negative".into(),
            });
        }
    }

    let p = input.win_probability;
    let q = Decimal::ONE - p;
    let b = input.win_loss_ratio;

    let full_kelly = p - q / b;
    let edge = p * b - q;

    if full_kelly <= Decimal::ZERO {
        warnings.push("Negative edge: Kelly recommends no position".into());
    }

    // Never recommend a short (negative) or leveraged (> 100%) position.
    let clamped_kelly = full_kelly.max(Decimal::ZERO).min(Decimal::ONE);
    let mut fractional_kelly = clamped_kelly * input.kelly_fraction;

    if let Some(max_pct) = input.max_position_pct {
        if fractional_kelly > max_pct {
            warnings.push(format!(
                "Fractional Kelly {fractional_kelly} capped at max_position_pct {max_pct}"
            ));
            fractional_kelly = max_pct;
        }
    }

    let recommended_position = input.portfolio_value.map(|pv| pv * fractional_kelly);

    let growth_rate = if fractional_kelly > Decimal::ZERO && fractional_kelly < Decimal::ONE {
        // With 0 < f < 1 and b > 0 (validated above), 1 + f * b >= 1 and
        // 1 - f > 0, so both logarithm arguments are strictly positive.
        p * ln_decimal(Decimal::ONE + fractional_kelly * b)
            + q * ln_decimal(Decimal::ONE - fractional_kelly)
    } else {
        Decimal::ZERO
    };

    let output = KellyOutput {
        full_kelly_pct: clamped_kelly,
        fractional_kelly_pct: fractional_kelly,
        recommended_position,
        edge,
        growth_rate,
    };

    let elapsed = start.elapsed().as_micros() as u64;
    Ok(with_metadata(
        "Kelly Criterion Position Sizing",
        &serde_json::json!({
            "win_probability": input.win_probability.to_string(),
            "win_loss_ratio": input.win_loss_ratio.to_string(),
            "kelly_fraction": input.kelly_fraction.to_string(),
        }),
        warnings,
        elapsed,
        output,
    ))
}
/// Natural logarithm of a positive `Decimal`, computed without floating point.
///
/// The argument is scaled by powers of two into `[0.5, 2]`, giving
/// `ln(x) = ln(val) + k * ln(2)`. `ln(val)` is then evaluated with the series
/// `2 * (u + u^3 / 3 + u^5 / 5 + ...)` where `u = (val - 1) / (val + 1)`.
/// After scaling `|u| <= 1/3`, and the series is truncated after the `u^41`
/// term.
///
/// Returns zero for `x <= 0`, where the logarithm is undefined, instead of
/// signalling an error; callers are expected to check the domain themselves.
fn ln_decimal(x: Decimal) -> Decimal {
    if x <= Decimal::ZERO {
        return Decimal::ZERO;
    }
    if x == Decimal::ONE {
        return Decimal::ZERO;
    }

    // ln(2) rounded to 28 decimal places, the finest scale `Decimal` can hold.
    // A shorter constant would cap the accuracy of every range-reduced result
    // (k != 0) well below what the series itself delivers.
    let ln2 = dec!(0.6931471805599453094172321215);
    let two = dec!(2);
    let half = dec!(0.5);

    // Range reduction: x = val * 2^k with val in [0.5, 2].
    let mut val = x;
    let mut k: i64 = 0;
    while val > two {
        val /= two;
        k += 1;
    }
    while val < half {
        val *= two;
        k -= 1;
    }

    let u = (val - Decimal::ONE) / (val + Decimal::ONE);
    let u2 = u * u;

    // `term` holds the current odd power of u; `sum` accumulates u^(2n+1) / (2n+1).
    let mut term = u;
    let mut sum = u;
    for n in 1..=20 {
        term *= u2;
        let denom = Decimal::from(2 * n + 1);
        sum += term / denom;
    }

    two * sum + Decimal::from(k) * ln2
}
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Assembles a `KellyInput` from positional arguments so each test can
    /// state its scenario on a single line.
    fn kelly_input(
        win_probability: Rate,
        win_loss_ratio: Decimal,
        kelly_fraction: Rate,
        portfolio_value: Option<Money>,
        max_position_pct: Option<Rate>,
    ) -> KellyInput {
        KellyInput {
            win_probability,
            win_loss_ratio,
            kelly_fraction,
            portfolio_value,
            max_position_pct,
        }
    }

    /// Even odds of winning at a 2:1 payoff give a 25% full Kelly.
    #[test]
    fn test_basic_kelly() {
        let scenario = kelly_input(dec!(0.5), dec!(2), dec!(1.0), None, None);
        let outcome = calculate_kelly(&scenario).unwrap();
        assert_eq!(outcome.result.full_kelly_pct, dec!(0.25));
    }

    /// Half Kelly halves the fraction and sizes the position off the portfolio.
    #[test]
    fn test_fractional_kelly() {
        let scenario = kelly_input(dec!(0.5), dec!(2), dec!(0.5), Some(dec!(100000)), None);
        let outcome = calculate_kelly(&scenario).unwrap();
        assert_eq!(outcome.result.fractional_kelly_pct, dec!(0.125));
        assert_eq!(outcome.result.recommended_position, Some(dec!(12500)));
    }

    /// Edge is `p * b - q`: 0.6 * 1.5 - 0.4 = 0.5.
    #[test]
    fn test_edge_calculation() {
        let scenario = kelly_input(dec!(0.6), dec!(1.5), dec!(1.0), None, None);
        let outcome = calculate_kelly(&scenario).unwrap();
        assert_eq!(outcome.result.edge, dec!(0.5));
    }

    /// A losing proposition clamps full Kelly to zero but reports the raw edge.
    #[test]
    fn test_negative_edge() {
        let scenario = kelly_input(dec!(0.3), dec!(1), dec!(1.0), None, None);
        let outcome = calculate_kelly(&scenario).unwrap();
        assert_eq!(outcome.result.full_kelly_pct, Decimal::ZERO);
        assert!(outcome.result.edge < Decimal::ZERO);
    }

    /// The final fraction never exceeds the supplied cap.
    #[test]
    fn test_max_position_cap() {
        let scenario = kelly_input(
            dec!(0.6),
            dec!(3.0),
            dec!(1.0),
            Some(dec!(100000)),
            Some(dec!(0.10)),
        );
        let outcome = calculate_kelly(&scenario).unwrap();
        assert!(outcome.result.fractional_kelly_pct <= dec!(0.10));
    }

    /// A favourable bet sized at half Kelly has positive expected log growth.
    #[test]
    fn test_growth_rate_positive() {
        let scenario = kelly_input(dec!(0.6), dec!(2), dec!(0.5), None, None);
        let outcome = calculate_kelly(&scenario).unwrap();
        assert!(outcome.result.growth_rate > Decimal::ZERO);
    }

    /// Probabilities above 1 are rejected.
    #[test]
    fn test_invalid_probability() {
        let scenario = kelly_input(dec!(1.5), dec!(2), dec!(1.0), None, None);
        assert!(calculate_kelly(&scenario).is_err());
    }

    /// Negative win/loss ratios are rejected.
    #[test]
    fn test_invalid_ratio() {
        let scenario = kelly_input(dec!(0.5), dec!(-1), dec!(1.0), None, None);
        assert!(calculate_kelly(&scenario).is_err());
    }
}