use rust_decimal::Decimal;
use crate::types::OHLC;
use super::{MarketRegime, RegimeState, RegimeConfig};
/// Common interface implemented by market-regime detection strategies.
///
/// The `Send + Sync` supertraits allow a detector to be moved to, and shared
/// between, threads.
pub trait RegimeDetector: Send + Sync {
    /// Detects the regime from a whole slice of candles using `config`.
    ///
    /// The conditions under which `None` is returned are implementation-defined.
    fn detect(&mut self, data: &[OHLC], config: &RegimeConfig) -> Option<RegimeState>;

    /// Feeds a single new candle to the detector and returns the resulting
    /// state, if the implementation produces one.
    fn update(&mut self, candle: &OHLC, config: &RegimeConfig) -> Option<RegimeState>;

    /// Computes a confidence score for `regime` given `data`.
    ///
    /// NOTE(review): the scale of the returned value (e.g. 0..=1) is not fixed
    /// by this trait — confirm against implementations.
    fn calculate_confidence(&self, regime: MarketRegime, data: &[OHLC]) -> Decimal;

    /// Clears any internal state held by the detector.
    fn reset(&mut self);

    /// Human-readable identifier of the detector.
    fn name(&self) -> &str;

    /// Returns `true` when `data_points` reaches `config.lookback_period`.
    fn has_sufficient_data(&self, data_points: usize, config: &RegimeConfig) -> bool {
        config.lookback_period <= data_points
    }
}
/// Stateless numeric helpers shared by regime detectors.
///
/// All arithmetic is performed in [`Decimal`]. Degenerate inputs (empty
/// slices, too few candles, zero denominators) produce a neutral value such as
/// zero or [`MarketRegime::Sideways`].
pub mod helpers {
    use rust_decimal::Decimal;
    use crate::types::OHLC;
    use super::MarketRegime;

    /// Simple return from `price_old` to `price_new`:
    /// `(price_new - price_old) / price_old`.
    ///
    /// Returns zero when `price_old` is zero, where the ratio is undefined.
    pub fn calculate_return(price_old: Decimal, price_new: Decimal) -> Decimal {
        if price_old == Decimal::ZERO {
            return Decimal::ZERO;
        }
        (price_new - price_old) / price_old
    }

    /// Close-to-close return for every consecutive pair of candles.
    ///
    /// The result has `data.len() - 1` elements, or is empty when fewer than
    /// two candles are supplied.
    pub fn calculate_returns(data: &[OHLC]) -> Vec<Decimal> {
        // `windows(2)` yields nothing for slices shorter than two, so short
        // input naturally produces an empty vector.
        data.windows(2)
            .map(|pair| calculate_return(pair[0].close, pair[1].close))
            .collect()
    }

    /// Return from the first candle's close to the last candle's close.
    ///
    /// Returns zero when fewer than two candles are supplied.
    pub fn calculate_cumulative_return(data: &[OHLC]) -> Decimal {
        match data {
            [first, .., last] => calculate_return(first.close, last.close),
            _ => Decimal::ZERO,
        }
    }

    /// Arithmetic mean of `returns`; zero for an empty slice.
    pub fn calculate_average_return(returns: &[Decimal]) -> Decimal {
        if returns.is_empty() {
            return Decimal::ZERO;
        }
        let sum: Decimal = returns.iter().sum();
        sum / Decimal::from(returns.len())
    }

    /// Sample standard deviation of `returns`.
    ///
    /// Returns zero when fewer than two returns are supplied.
    pub fn calculate_volatility(returns: &[Decimal]) -> Decimal {
        if returns.len() < 2 {
            return Decimal::ZERO;
        }
        let mean = calculate_average_return(returns);
        let sum_sq_dev: Decimal = returns
            .iter()
            .map(|&r| {
                let dev = r - mean;
                dev * dev
            })
            .sum();
        // Bessel's correction: divide by n - 1 for the sample variance.
        let variance = sum_sq_dev / Decimal::from(returns.len() - 1);
        sqrt_approximation(variance)
    }

    /// Square root of `value` computed with Newton's method in `Decimal`.
    ///
    /// Returns zero for zero or negative input.
    ///
    /// The iteration starts at `max(value, 1)`, which is never below the true
    /// root. From there Newton's method decreases monotonically. The loop ends
    /// as soon as an iterate stops decreasing, i.e. when `Decimal` precision is
    /// exhausted.
    ///
    /// A relative stopping rule is needed because an absolute tolerance would
    /// leave small roots badly inaccurate. The variance of daily returns is a
    /// typical example of such a small input.
    pub fn sqrt_approximation(value: Decimal) -> Decimal {
        if value <= Decimal::ZERO {
            return Decimal::ZERO;
        }
        let mut x = value.max(Decimal::ONE);
        loop {
            let next = (x + value / x) / Decimal::TWO;
            if next >= x {
                return x;
            }
            x = next;
        }
    }

    /// Classifies the trend by comparing short and long simple moving averages
    /// of the most recent closes.
    ///
    /// Let `d` be the gap between the short and long SMA, relative to the long
    /// SMA. The result is:
    /// - [`MarketRegime::Bull`] when `d` exceeds +1%;
    /// - [`MarketRegime::Bear`] when `d` is below -1%;
    /// - [`MarketRegime::Sideways`] otherwise.
    ///
    /// `Sideways` is also returned when:
    /// - either period is zero;
    /// - `data` has fewer candles than the longer period;
    /// - the long SMA is zero, so the relative gap is undefined.
    pub fn identify_trend(data: &[OHLC], short_period: usize, long_period: usize) -> MarketRegime {
        // Guard against zero periods, out-of-range slices and later division by zero.
        if short_period == 0 || long_period == 0 || data.len() < short_period.max(long_period) {
            return MarketRegime::Sideways;
        }

        let short_ma = calculate_sma(&data[data.len() - short_period..]);
        let long_ma = calculate_sma(&data[data.len() - long_period..]);
        if long_ma == Decimal::ZERO {
            return MarketRegime::Sideways;
        }

        let difference = (short_ma - long_ma) / long_ma;
        // 1% relative divergence between the two averages.
        let threshold = Decimal::new(1, 2);
        if difference > threshold {
            MarketRegime::Bull
        } else if difference < -threshold {
            MarketRegime::Bear
        } else {
            MarketRegime::Sideways
        }
    }

    /// Simple moving average of the closes in `data`; zero for an empty slice.
    pub fn calculate_sma(data: &[OHLC]) -> Decimal {
        if data.is_empty() {
            return Decimal::ZERO;
        }
        let sum: Decimal = data.iter().map(|candle| candle.close).sum();
        sum / Decimal::from(data.len())
    }

    /// Exponential moving average of closes.
    ///
    /// The average is seeded with the first close and uses smoothing factor
    /// `alpha = 2 / (period + 1)`. Returns zero for an empty slice.
    ///
    /// A `period` of zero is treated as one. That gives `alpha = 1`, so the
    /// result is the latest close; otherwise alpha would be 2, which is not a
    /// valid smoothing factor.
    pub fn calculate_ema(data: &[OHLC], period: usize) -> Decimal {
        match data.split_first() {
            None => Decimal::ZERO,
            Some((seed, rest)) => {
                // The `+ 1` is done in Decimal so that huge `period` values
                // cannot overflow `usize`.
                let alpha = Decimal::TWO / (Decimal::from(period.max(1)) + Decimal::ONE);
                rest.iter().fold(seed.close, |ema, candle| {
                    alpha * candle.close + (Decimal::ONE - alpha) * ema
                })
            }
        }
    }

    /// Counts consecutive pairs whose high (first value) or low (second value)
    /// strictly rose.
    pub fn count_higher_highs_lows(data: &[OHLC]) -> (usize, usize) {
        data.windows(2).fold((0, 0), |(highs, lows), pair| {
            (
                highs + usize::from(pair[1].high > pair[0].high),
                lows + usize::from(pair[1].low > pair[0].low),
            )
        })
    }

    /// Counts consecutive pairs whose high (first value) or low (second value)
    /// strictly fell.
    pub fn count_lower_highs_lows(data: &[OHLC]) -> (usize, usize) {
        data.windows(2).fold((0, 0), |(highs, lows), pair| {
            (
                highs + usize::from(pair[1].high < pair[0].high),
                lows + usize::from(pair[1].low < pair[0].low),
            )
        })
    }
}
#[cfg(test)]
mod tests {
    use super::helpers::*;
    use super::MarketRegime;
    use crate::types::OHLC;
    use rust_decimal::Decimal;

    /// Builds a fixture candle via `OHLC::new`.
    ///
    /// - 1st and 4th arguments: `close`.
    /// - 2nd argument: `close + 1`.
    /// - 3rd argument: `close - 1`.
    /// - Last two arguments: fixed at `1000` and `1000000`.
    ///
    /// NOTE(review): presumably the order is (open, high, low, close, ...);
    /// confirm against `OHLC::new`.
    fn create_test_candle(close: i64) -> OHLC {
        let mid = Decimal::from(close);
        let (above, below) = (mid + Decimal::ONE, mid - Decimal::ONE);
        OHLC::new(mid, above, below, mid, 1000, 1000000)
    }

    #[test]
    fn test_calculate_return() {
        let hundred = Decimal::from(100);
        assert_eq!(calculate_return(hundred, Decimal::from(110)), Decimal::new(1, 1));
        assert_eq!(calculate_return(hundred, Decimal::from(90)), Decimal::new(-1, 1));
    }

    #[test]
    fn test_calculate_returns() {
        let candles: Vec<OHLC> = [100, 110, 105].iter().copied().map(create_test_candle).collect();
        let returns = calculate_returns(&candles);
        assert_eq!(returns.len(), 2);
        assert_eq!(returns[0], Decimal::new(1, 1));
    }

    #[test]
    fn test_sqrt_approximation() {
        let tolerance = Decimal::new(1, 3);
        let cases = [(4, Decimal::TWO), (9, Decimal::from(3))];
        for &(input, expected) in cases.iter() {
            let root = sqrt_approximation(Decimal::from(input));
            assert!((root - expected).abs() < tolerance);
        }
    }

    #[test]
    fn test_calculate_sma() {
        let candles: Vec<OHLC> = [100, 110, 120].iter().copied().map(create_test_candle).collect();
        assert_eq!(calculate_sma(&candles), Decimal::from(110));
    }

    #[test]
    fn test_identify_trend() {
        let rising: Vec<OHLC> = (0..20).map(|step| create_test_candle(100 + step * 2)).collect();
        assert_eq!(identify_trend(&rising, 5, 10), MarketRegime::Bull);
    }
}