use crate::Decimal;
use crate::types::error::{MMError, MMResult};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// One drawdown observation appended to a `DrawdownTracker`'s history.
///
/// A record is created by `DrawdownTracker::update` whenever the drawdown it
/// computes exceeds the tracker's recording threshold.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DrawdownRecord {
/// Drawdown as a fraction of the peak: `(peak_equity - trough_equity) / peak_equity`.
pub drawdown: Decimal,
/// Timestamp passed to the `update` call that produced this record.
pub timestamp: u64,
/// Peak equity the drawdown was measured against.
pub peak_equity: Decimal,
/// Equity passed to that `update` call. A record is written on every
/// qualifying update, so this is not necessarily the lowest point of the dip.
pub trough_equity: Decimal,
}
/// Tracks equity against its running peak and reports the resulting drawdown.
///
/// Drawdowns are expressed as fractions of the peak (e.g. `0.2` for 20 %).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DrawdownTracker {
/// Highest equity seen so far (or the value installed by `reset` / `reset_peak`).
peak_equity: Decimal,
/// Equity passed to the most recent `update` (or installed by `reset`).
current_equity: Decimal,
/// Drawdown fraction at which `is_max_drawdown_reached` starts returning `true`;
/// validated by `new` to lie in `(0, 1]`.
max_allowed_drawdown: Decimal,
/// Timestamp at which `peak_equity` was last set; `0` until one is supplied.
peak_timestamp: u64,
/// Largest drawdown fraction computed by `update` since construction or the last `reset`.
max_historical_drawdown: Decimal,
/// Recorded drawdown observations, oldest first.
drawdown_history: Vec<DrawdownRecord>,
/// Length cap checked after each insertion into `drawdown_history`
/// (the oldest record is evicted first); defaults to 1000.
max_history_size: usize,
}
impl DrawdownTracker {
pub fn new(initial_equity: Decimal, max_allowed_drawdown: Decimal) -> MMResult<Self> {
if initial_equity <= Decimal::ZERO {
return Err(MMError::InvalidConfiguration(
"initial_equity must be positive".to_string(),
));
}
if max_allowed_drawdown <= Decimal::ZERO || max_allowed_drawdown > Decimal::ONE {
return Err(MMError::InvalidConfiguration(
"max_allowed_drawdown must be between 0 (exclusive) and 1 (inclusive)".to_string(),
));
}
Ok(Self {
peak_equity: initial_equity,
current_equity: initial_equity,
max_allowed_drawdown,
peak_timestamp: 0,
max_historical_drawdown: Decimal::ZERO,
drawdown_history: Vec::new(),
max_history_size: 1000,
})
}
pub fn with_timestamp(
initial_equity: Decimal,
max_allowed_drawdown: Decimal,
timestamp: u64,
) -> MMResult<Self> {
let mut tracker = Self::new(initial_equity, max_allowed_drawdown)?;
tracker.peak_timestamp = timestamp;
Ok(tracker)
}
#[must_use]
pub fn with_max_history_size(mut self, size: usize) -> Self {
self.max_history_size = size;
self
}
pub fn update(&mut self, equity: Decimal, timestamp: u64) {
self.current_equity = equity;
if equity > self.peak_equity {
self.peak_equity = equity;
self.peak_timestamp = timestamp;
} else if self.peak_equity > Decimal::ZERO {
let drawdown = (self.peak_equity - equity) / self.peak_equity;
if drawdown > self.max_historical_drawdown {
self.max_historical_drawdown = drawdown;
}
if drawdown > Decimal::new(1, 2) {
self.record_drawdown(drawdown, timestamp, equity);
}
}
}
fn record_drawdown(&mut self, drawdown: Decimal, timestamp: u64, trough_equity: Decimal) {
let record = DrawdownRecord {
drawdown,
timestamp,
peak_equity: self.peak_equity,
trough_equity,
};
self.drawdown_history.push(record);
if self.drawdown_history.len() > self.max_history_size {
self.drawdown_history.remove(0);
}
}
#[must_use]
pub fn current_drawdown(&self) -> Decimal {
if self.peak_equity <= Decimal::ZERO {
return Decimal::ZERO;
}
let drawdown = (self.peak_equity - self.current_equity) / self.peak_equity;
drawdown.max(Decimal::ZERO)
}
#[must_use]
pub fn current_drawdown_pct(&self) -> Decimal {
self.current_drawdown() * Decimal::ONE_HUNDRED
}
#[must_use]
pub fn is_max_drawdown_reached(&self) -> bool {
self.current_drawdown() >= self.max_allowed_drawdown
}
#[must_use]
pub fn peak_equity(&self) -> Decimal {
self.peak_equity
}
#[must_use]
pub fn current_equity(&self) -> Decimal {
self.current_equity
}
#[must_use]
pub fn peak_timestamp(&self) -> u64 {
self.peak_timestamp
}
#[must_use]
pub fn max_allowed_drawdown(&self) -> Decimal {
self.max_allowed_drawdown
}
#[must_use]
pub fn max_historical_drawdown(&self) -> Decimal {
self.max_historical_drawdown
}
#[must_use]
pub fn drawdown_history(&self) -> &[DrawdownRecord] {
&self.drawdown_history
}
#[must_use]
pub fn distance_to_max_drawdown(&self) -> Decimal {
let distance = self.max_allowed_drawdown - self.current_drawdown();
distance.max(Decimal::ZERO)
}
#[must_use]
pub fn equity_at_max_drawdown(&self) -> Decimal {
self.peak_equity * (Decimal::ONE - self.max_allowed_drawdown)
}
pub fn reset(&mut self, new_equity: Decimal, timestamp: u64) {
self.peak_equity = new_equity;
self.current_equity = new_equity;
self.peak_timestamp = timestamp;
self.max_historical_drawdown = Decimal::ZERO;
self.drawdown_history.clear();
}
pub fn reset_peak(&mut self, timestamp: u64) {
self.peak_equity = self.current_equity;
self.peak_timestamp = timestamp;
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `DrawdownTracker` and `DrawdownRecord`.
    use super::*;
    use crate::dec;

    #[test]
    fn test_new_valid() {
        let tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.20));
        assert!(tracker.is_ok());

        let tracker = tracker.unwrap();
        assert_eq!(tracker.peak_equity(), dec!(10000.0));
        assert_eq!(tracker.current_equity(), dec!(10000.0));
        assert_eq!(tracker.max_allowed_drawdown(), dec!(0.20));
        assert_eq!(tracker.current_drawdown(), dec!(0.0));
    }

    #[test]
    fn test_new_invalid_equity() {
        // Zero and negative starting equity are both rejected.
        let result = DrawdownTracker::new(dec!(0.0), dec!(0.20));
        assert!(result.is_err());

        let result = DrawdownTracker::new(dec!(-1000.0), dec!(0.20));
        assert!(result.is_err());
    }

    #[test]
    fn test_new_invalid_max_drawdown() {
        // Valid range is (0, 1]: zero, negatives and values above one fail.
        let result = DrawdownTracker::new(dec!(10000.0), dec!(0.0));
        assert!(result.is_err());

        let result = DrawdownTracker::new(dec!(10000.0), dec!(-0.1));
        assert!(result.is_err());

        let result = DrawdownTracker::new(dec!(10000.0), dec!(1.1));
        assert!(result.is_err());

        // The upper bound itself is inclusive.
        let result = DrawdownTracker::new(dec!(10000.0), dec!(1.0));
        assert!(result.is_ok());
    }

    #[test]
    fn test_with_timestamp() {
        let tracker = DrawdownTracker::with_timestamp(dec!(10000.0), dec!(0.20), 12345).unwrap();
        assert_eq!(tracker.peak_timestamp(), 12345);
    }

    #[test]
    fn test_update_new_peak() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.20)).unwrap();

        tracker.update(dec!(11000.0), 1000);

        assert_eq!(tracker.peak_equity(), dec!(11000.0));
        assert_eq!(tracker.current_equity(), dec!(11000.0));
        assert_eq!(tracker.peak_timestamp(), 1000);
        assert_eq!(tracker.current_drawdown(), dec!(0.0));
    }

    #[test]
    fn test_update_drawdown() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.20)).unwrap();

        tracker.update(dec!(9000.0), 1000);

        // The peak is untouched; only the current equity moves.
        assert_eq!(tracker.peak_equity(), dec!(10000.0));
        assert_eq!(tracker.current_equity(), dec!(9000.0));
        assert_eq!(tracker.current_drawdown(), dec!(0.1));
        assert_eq!(tracker.current_drawdown_pct(), dec!(10.0));
    }

    #[test]
    fn test_max_drawdown_reached() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.10)).unwrap();

        // 5 % drawdown: below the 10 % limit.
        tracker.update(dec!(9500.0), 1000);
        assert!(!tracker.is_max_drawdown_reached());

        // Exactly 10 %: the limit is inclusive.
        tracker.update(dec!(9000.0), 2000);
        assert!(tracker.is_max_drawdown_reached());

        // 20 %: beyond the limit.
        tracker.update(dec!(8000.0), 3000);
        assert!(tracker.is_max_drawdown_reached());
    }

    #[test]
    fn test_max_historical_drawdown() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.50)).unwrap();

        tracker.update(dec!(8000.0), 1000);
        assert_eq!(tracker.max_historical_drawdown(), dec!(0.2));

        // A partial recovery does not lower the historical maximum.
        tracker.update(dec!(9000.0), 2000);
        assert_eq!(tracker.max_historical_drawdown(), dec!(0.2));

        // A full recovery followed by a shallower dip keeps it as well.
        tracker.update(dec!(10000.0), 3000);
        tracker.update(dec!(9500.0), 4000);
        assert_eq!(tracker.max_historical_drawdown(), dec!(0.2));

        // A deeper dip raises it.
        tracker.update(dec!(7500.0), 5000);
        assert_eq!(tracker.max_historical_drawdown(), dec!(0.25));
    }

    #[test]
    fn test_distance_to_max_drawdown() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.20)).unwrap();
        assert_eq!(tracker.distance_to_max_drawdown(), dec!(0.20));

        tracker.update(dec!(9000.0), 1000);
        assert_eq!(tracker.distance_to_max_drawdown(), dec!(0.10));

        tracker.update(dec!(8000.0), 2000);
        assert_eq!(tracker.distance_to_max_drawdown(), dec!(0.0));

        // Past the limit the distance is clamped at zero rather than negative.
        tracker.update(dec!(7000.0), 3000);
        assert_eq!(tracker.distance_to_max_drawdown(), dec!(0.0));
    }

    #[test]
    fn test_equity_at_max_drawdown() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.20)).unwrap();
        assert_eq!(tracker.equity_at_max_drawdown(), dec!(8000.0));

        // The stop level follows the peak upwards.
        tracker.update(dec!(12000.0), 1000);
        assert_eq!(tracker.equity_at_max_drawdown(), dec!(9600.0));
    }

    #[test]
    fn test_reset() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.20)).unwrap();
        tracker.update(dec!(8000.0), 1000);
        assert_eq!(tracker.max_historical_drawdown(), dec!(0.2));
        assert!(!tracker.drawdown_history().is_empty());

        tracker.reset(dec!(15000.0), 100000);

        assert_eq!(tracker.peak_equity(), dec!(15000.0));
        assert_eq!(tracker.current_equity(), dec!(15000.0));
        assert_eq!(tracker.peak_timestamp(), 100000);
        assert_eq!(tracker.current_drawdown(), dec!(0.0));
        assert_eq!(tracker.max_historical_drawdown(), dec!(0.0));
        // A reset also discards every recorded drawdown.
        assert!(tracker.drawdown_history().is_empty());
    }

    #[test]
    fn test_reset_peak() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.20)).unwrap();
        tracker.update(dec!(8000.0), 1000);
        assert_eq!(tracker.current_drawdown(), dec!(0.2));

        tracker.reset_peak(2000);

        assert_eq!(tracker.peak_equity(), dec!(8000.0));
        assert_eq!(tracker.current_drawdown(), dec!(0.0));
        // Re-basing the peak keeps the historical maximum.
        assert_eq!(tracker.max_historical_drawdown(), dec!(0.2));
    }

    #[test]
    fn test_drawdown_history() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.50)).unwrap();

        // 0.5 % is below the recording threshold.
        tracker.update(dec!(9950.0), 1000);
        assert!(tracker.drawdown_history().is_empty());

        // 10 % is recorded.
        tracker.update(dec!(9000.0), 2000);
        assert_eq!(tracker.drawdown_history().len(), 1);

        let record = &tracker.drawdown_history()[0];
        assert_eq!(record.drawdown, dec!(0.1));
        assert_eq!(record.peak_equity, dec!(10000.0));
        assert_eq!(record.trough_equity, dec!(9000.0));
        assert_eq!(record.timestamp, 2000);
    }

    #[test]
    fn test_history_pruning() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.50))
            .unwrap()
            .with_max_history_size(3);

        // Five successively deeper drawdowns, stamped 1000, 2000, ..., 5000.
        for (step, timestamp) in (1..=5).zip((1000_u64..).step_by(1000)) {
            let equity = dec!(10000.0) - Decimal::from(step) * dec!(500.0);
            tracker.update(equity, timestamp);
        }

        assert_eq!(tracker.drawdown_history().len(), 3);

        // The oldest records are the ones evicted; the newest three survive in order.
        let retained: Vec<u64> = tracker
            .drawdown_history()
            .iter()
            .map(|record| record.timestamp)
            .collect();
        assert_eq!(retained, vec![3000, 4000, 5000]);
    }

    #[test]
    fn test_recovery_and_new_peak() {
        let mut tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.20)).unwrap();

        tracker.update(dec!(9000.0), 1000);
        assert_eq!(tracker.current_drawdown(), dec!(0.1));

        tracker.update(dec!(9500.0), 2000);
        assert_eq!(tracker.current_drawdown(), dec!(0.05));
        assert_eq!(tracker.peak_equity(), dec!(10000.0));

        tracker.update(dec!(10000.0), 3000);
        assert_eq!(tracker.current_drawdown(), dec!(0.0));

        tracker.update(dec!(11000.0), 4000);
        assert_eq!(tracker.peak_equity(), dec!(11000.0));
        assert_eq!(tracker.peak_timestamp(), 4000);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialization() {
        let tracker = DrawdownTracker::new(dec!(10000.0), dec!(0.20)).unwrap();

        let json = serde_json::to_string(&tracker).unwrap();
        let deserialized: DrawdownTracker = serde_json::from_str(&json).unwrap();

        assert_eq!(tracker.peak_equity(), deserialized.peak_equity());
        assert_eq!(
            tracker.max_allowed_drawdown(),
            deserialized.max_allowed_drawdown()
        );
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_record_serialization() {
        let record = DrawdownRecord {
            drawdown: dec!(0.15),
            timestamp: 12345,
            peak_equity: dec!(10000.0),
            trough_equity: dec!(8500.0),
        };

        let json = serde_json::to_string(&record).unwrap();
        let deserialized: DrawdownRecord = serde_json::from_str(&json).unwrap();

        assert_eq!(record, deserialized);
    }
}