use std::collections::BTreeMap;
use nautilus_core::UnixNanos;
use crate::statistic::PortfolioStatistic;
/// Portfolio statistic measuring the maximum drawdown of a returns series.
///
/// The struct carries no state; the computation lives in its `PortfolioStatistic`
/// implementation, which reports the drawdown as a non-positive fraction.
/// With the `python` feature enabled it is also exported as a pyo3 class.
#[repr(C)]
#[derive(Debug, Clone, Default)]
#[cfg_attr(
feature = "python",
pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.analysis", from_py_object)
)]
#[cfg_attr(
feature = "python",
pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.analysis")
)]
pub struct MaxDrawdown {}
impl MaxDrawdown {
    /// Creates a new [`MaxDrawdown`] statistic.
    ///
    /// The type holds no configuration, so this is equivalent to [`MaxDrawdown::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
impl PortfolioStatistic for MaxDrawdown {
    type Item = f64;

    fn name(&self) -> String {
        "Max Drawdown".to_string()
    }

    /// Computes the maximum peak-to-trough decline of the equity curve implied by `returns`.
    ///
    /// The curve starts at `1.0` and is compounded as `equity *= 1.0 + r` for each return, in
    /// ascending timestamp order (the order in which `BTreeMap` yields its values). At every step
    /// the drawdown is `(peak - equity) / peak`, where `peak` is the running maximum of the curve
    /// and never drops below the starting value of `1.0`.
    ///
    /// # Returns
    ///
    /// The largest drawdown as a non-positive fraction, e.g. `-0.2` for a 20% decline.
    /// `Some(0.0)` (positive zero) is returned for an empty series, or when the curve never falls
    /// below a prior peak. This implementation never returns `None`.
    ///
    /// NOTE(review): a NaN return makes every later equity value NaN. Those NaN drawdowns are
    /// ignored, so the result then reflects only the returns before the first NaN.
    fn calculate_from_returns(&self, returns: &BTreeMap<UnixNanos, f64>) -> Option<Self::Item> {
        if returns.is_empty() {
            return Some(0.0);
        }

        let mut cumulative = 1.0_f64;
        let mut running_max = 1.0_f64;
        let mut max_drawdown = 0.0_f64;

        for &ret in returns.values() {
            cumulative *= 1.0 + ret;
            // `f64::max` returns the non-NaN operand, so a NaN value never replaces the peak
            // or the worst drawdown seen so far.
            running_max = running_max.max(cumulative);
            max_drawdown = max_drawdown.max((running_max - cumulative) / running_max);
        }

        // Report losses as negative fractions. Negating a zero drawdown would yield `-0.0`
        // (displayed as "-0"), so return a positive zero in that case instead.
        if max_drawdown > 0.0 {
            Some(-max_drawdown)
        } else {
            Some(0.0)
        }
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Absolute tolerance for comparing against hand-computed drawdowns. It is far tighter
    /// than the gaps between the expected values used below.
    const TOLERANCE: f64 = 1e-9;

    /// Builds a returns series keyed by consecutive timestamps `0, 1, 2, ...` (nanoseconds).
    fn create_returns(values: &[f64]) -> BTreeMap<UnixNanos, f64> {
        values
            .iter()
            .copied()
            .enumerate()
            .map(|(i, v)| (UnixNanos::from(i as u64), v))
            .collect()
    }

    /// Computes the statistic for `values` and unwraps the result.
    fn max_drawdown_of(values: &[f64]) -> f64 {
        MaxDrawdown::new()
            .calculate_from_returns(&create_returns(values))
            .expect("max drawdown should be computed for any returns series")
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    #[rstest]
    fn test_name() {
        let stat = MaxDrawdown::new();
        assert_eq!(stat.name(), "Max Drawdown");
    }

    #[rstest]
    fn test_empty_returns() {
        let stat = MaxDrawdown::new();
        let returns = BTreeMap::new();
        let result = stat.calculate_from_returns(&returns);
        assert_eq!(result, Some(0.0));
    }

    #[rstest]
    fn test_no_drawdown() {
        assert_eq!(max_drawdown_of(&[0.01, 0.02, 0.01, 0.015]), 0.0);
    }

    #[rstest]
    fn test_simple_drawdown() {
        // 1.0 -> 1.10 -> 0.99: a 10% decline from the 1.10 peak.
        assert_close(max_drawdown_of(&[0.10, -0.10]), -0.10);
    }

    #[rstest]
    fn test_multiple_drawdowns() {
        // Peaks at 1.10 and 1.485; the trough at 1.188 is 20% below the second peak.
        assert_close(max_drawdown_of(&[0.10, -0.10, 0.50, -0.20, 0.10]), -0.20);
    }

    #[rstest]
    fn test_initial_loss() {
        // The starting value 1.0 is the peak: 1.0 -> 0.60 -> 0.54 is a 46% decline.
        assert_close(max_drawdown_of(&[-0.40, -0.10]), -0.46);
    }

    #[rstest]
    fn test_single_loss() {
        assert_close(max_drawdown_of(&[-0.25]), -0.25);
    }

    #[rstest]
    fn test_drawdown_measured_from_latest_peak() {
        // 1.0 -> 2.0 -> 1.0: back at the starting value, but 50% below the 2.0 peak.
        assert_close(max_drawdown_of(&[1.0, -0.5]), -0.5);
    }

    #[rstest]
    fn test_recovery_keeps_worst_drawdown() {
        // 1.0 -> 0.5 -> 1.0 -> 1.5: a full recovery does not erase the earlier 50% drawdown.
        assert_close(max_drawdown_of(&[-0.5, 1.0, 0.5]), -0.5);
    }

    #[rstest]
    fn test_total_loss() {
        // 1.0 -> 1.5 -> 0.0 -> 0.0: a -100% return wipes out the equity curve.
        assert_close(max_drawdown_of(&[0.5, -1.0, 0.2]), -1.0);
    }
}