use crate::bar_indicators::bar_indicator_id::BarIndicatorId;
use crate::bar_indicators::average::MovingAverageType;
use crate::bar_indicators::indicator_value::IndicatorValue;
use std::hash::{Hash, Hasher};
/// Chooses which component of an indicator reading a key refers to.
///
/// The discriminants are spelled out explicitly because the value is
/// observable: `IndicatorKey`'s `Hash` impl feeds `output as u8` to the hasher.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(u8)]
pub enum OutputSelector {
    /// The indicator's primary value; this is the `Default` selector.
    #[default]
    Main = 0,
    /// The component returned by `IndicatorValue::upper`.
    Upper = 1,
    /// The component returned by `IndicatorValue::lower`.
    Lower = 2,
    /// The component returned by `IndicatorValue::middle`.
    Middle = 3,
    /// The second field of an `IndicatorValue::Double`.
    Second = 4,
    /// The component returned by `IndicatorValue::macd_signal`.
    Signal = 5,
    /// The component returned by `IndicatorValue::macd_histogram`.
    Histogram = 6,
}
impl OutputSelector {
    /// Pulls the component named by this selector out of `value`.
    ///
    /// Absent optional components come back as `f64::NAN`, with one exception:
    /// `Middle` falls back to `value.main()` when no middle component exists.
    /// `Second` is only available on `IndicatorValue::Double`; every other
    /// value shape yields `f64::NAN`.
    #[inline]
    pub fn extract(&self, value: &IndicatorValue) -> f64 {
        match *self {
            Self::Main => value.main(),
            Self::Middle => value.middle().unwrap_or(value.main()),
            Self::Upper => value.upper().unwrap_or(f64::NAN),
            Self::Lower => value.lower().unwrap_or(f64::NAN),
            Self::Signal => value.macd_signal().unwrap_or(f64::NAN),
            Self::Histogram => value.macd_histogram().unwrap_or(f64::NAN),
            Self::Second => {
                if let IndicatorValue::Double(_, second) = value {
                    *second
                } else {
                    f64::NAN
                }
            }
        }
    }

    /// Short lowercase label for this selector.
    ///
    /// `IndicatorKey::to_legacy_string` appends this label for non-main outputs.
    pub fn short_name(&self) -> &'static str {
        match *self {
            Self::Main => "main",
            Self::Upper => "upper",
            Self::Lower => "lower",
            Self::Middle => "middle",
            Self::Second => "second",
            Self::Signal => "signal",
            Self::Histogram => "hist",
        }
    }
}
/// Compact, `Copy`-able identifier for one output of one configured indicator.
///
/// Equality compares every field. `Hash` is implemented by hand rather than
/// derived.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct IndicatorKey {
    /// Which indicator this key refers to.
    pub indicator_id: BarIndicatorId,
    /// The indicator's period parameter (the `20` in `SMA_20`).
    pub period: u16,
    /// Optional moving-average type; `None` when the key carries none.
    pub ma_type: Option<MovingAverageType>,
    /// Which component of the indicator's value the key selects.
    pub output: OutputSelector,
    /// Hash of any additional parameters not captured by the fields above.
    /// The constructors start it at `0`.
    pub param_hash: u64,
}
impl IndicatorKey {
    /// Creates a key selecting the main output of an indicator.
    ///
    /// `param_hash` starts at `0`; use [`IndicatorKey::with_param_hash`] to set it.
    #[inline]
    pub fn new(indicator_id: BarIndicatorId, period: u16, ma_type: Option<MovingAverageType>) -> Self {
        Self::with_output(indicator_id, period, ma_type, OutputSelector::Main)
    }

    /// Creates a key selecting a specific `output` of an indicator.
    ///
    /// `param_hash` starts at `0`.
    #[inline]
    pub fn with_output(
        indicator_id: BarIndicatorId,
        period: u16,
        ma_type: Option<MovingAverageType>,
        output: OutputSelector,
    ) -> Self {
        Self {
            indicator_id,
            period,
            ma_type,
            output,
            param_hash: 0,
        }
    }

    /// Returns this key with `param_hash` replaced by `hash`.
    #[inline]
    pub fn with_param_hash(mut self, hash: u64) -> Self {
        self.param_hash = hash;
        self
    }

    /// Returns the same key, but selecting `OutputSelector::Main`.
    ///
    /// Every other field, including `param_hash`, is preserved.
    #[inline]
    pub fn base_key(&self) -> Self {
        Self {
            output: OutputSelector::Main,
            ..*self
        }
    }

    /// Returns `true` when this key selects `OutputSelector::Main`.
    #[inline]
    pub fn is_main_output(&self) -> bool {
        self.output == OutputSelector::Main
    }

    /// Formats the key in the legacy underscore-separated form
    /// `IND[_MA]_PERIOD[_OUTPUT]`, e.g. `SMA_20` or `ATR_Wilder_14`.
    ///
    /// - `_OUTPUT` is `OutputSelector::short_name` and is omitted for the main
    ///   output.
    /// - `param_hash` is NOT encoded, so keys that differ only in `param_hash`
    ///   render identically.
    /// - The indicator name is the upper-cased `Debug` text of `indicator_id`.
    ///   This assumes that text matches the names accepted by
    ///   `parse_indicator_id` (true for a derived `Debug`); confirm for any
    ///   variant not listed there.
    pub fn to_legacy_string(&self) -> String {
        let indicator_name = format!("{:?}", self.indicator_id).to_uppercase();
        let base = match self.ma_type {
            Some(ma_type) => format!(
                "{}_{}_{}",
                indicator_name,
                Self::legacy_ma_name(ma_type),
                self.period
            ),
            None => format!("{}_{}", indicator_name, self.period),
        };
        if self.is_main_output() {
            base
        } else {
            format!("{}_{}", base, self.output.short_name())
        }
    }

    /// Parses a key from the legacy form produced by [`IndicatorKey::to_legacy_string`].
    ///
    /// Accepted `_`-separated layouts:
    /// - `IND_PERIOD`, e.g. `SMA_20`
    /// - `IND_MA_PERIOD`, e.g. `ATR_Wilder_14`
    /// - `IND_PERIOD_OUTPUT`, e.g. `SMA_20_upper`
    /// - `IND_MA_PERIOD_OUTPUT`, e.g. `ATR_Wilder_14_signal`
    ///
    /// Indicator, MA and output names are matched case-insensitively. The
    /// result always has `param_hash == 0`, because the legacy form does not
    /// carry it.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` in any of these cases:
    /// - the input does not have 2 to 4 segments;
    /// - a segment names an unknown indicator, MA type or output;
    /// - the period is not a valid `u16`.
    pub fn from_legacy_string(s: &str) -> Result<Self, String> {
        let parts: Vec<&str> = s.split('_').collect();
        let (indicator_str, ma_str, period_str, output_str) = match parts.as_slice() {
            [ind, period] => (*ind, None, *period, None),
            // Three segments are ambiguous. A numeric second segment is the
            // period (IND_PERIOD_OUTPUT); otherwise it is an MA name
            // (IND_MA_PERIOD). No MA name parses as a number.
            [ind, period, output] if period.parse::<u16>().is_ok() => {
                (*ind, None, *period, Some(*output))
            }
            [ind, ma, period] => (*ind, Some(*ma), *period, None),
            [ind, ma, period, output] => (*ind, Some(*ma), *period, Some(*output)),
            _ => return Err(format!("Invalid key format: {}", s)),
        };
        let indicator_id = Self::parse_indicator_id(indicator_str)?;
        let ma_type = ma_str.map(Self::parse_ma_type).transpose()?;
        let period = period_str
            .parse::<u16>()
            .map_err(|e| format!("Invalid period '{}': {}", period_str, e))?;
        // A missing suffix means the main output (the `Default` selector).
        let output = output_str
            .map(Self::parse_output)
            .transpose()?
            .unwrap_or_default();
        Ok(Self::with_output(indicator_id, period, ma_type, output))
    }

    /// Legacy spelling of a moving-average type, as written by `to_legacy_string`.
    fn legacy_ma_name(ma_type: MovingAverageType) -> &'static str {
        match ma_type {
            MovingAverageType::SMA => "Simple",
            MovingAverageType::EMA => "EMA",
            MovingAverageType::RMA => "Wilder",
            MovingAverageType::WMA => "WMA",
            MovingAverageType::HMA => "HMA",
            MovingAverageType::DEMA => "DEMA",
            MovingAverageType::TEMA => "TEMA",
            MovingAverageType::VWMA => "VWMA",
            MovingAverageType::TMA => "TMA",
            MovingAverageType::VWAP => "VWAP",
            MovingAverageType::AMA => "AMA",
        }
    }

    /// Maps an (upper-cased) indicator name to its id.
    ///
    /// Only the names listed here are recognised.
    fn parse_indicator_id(s: &str) -> Result<BarIndicatorId, String> {
        let s_upper = s.to_uppercase();
        match s_upper.as_str() {
            "SMA" => Ok(BarIndicatorId::Sma),
            "EMA" => Ok(BarIndicatorId::Ema),
            "RMA" => Ok(BarIndicatorId::Rma),
            "WMA" => Ok(BarIndicatorId::Wma),
            "HMA" => Ok(BarIndicatorId::Hma),
            "DEMA" => Ok(BarIndicatorId::Dema),
            "TEMA" => Ok(BarIndicatorId::Tema),
            "ATR" => Ok(BarIndicatorId::Atr),
            "RSI" => Ok(BarIndicatorId::Rsi),
            "MACD" => Ok(BarIndicatorId::Macd),
            "BB" => Ok(BarIndicatorId::Bb),
            "STOCH" => Ok(BarIndicatorId::Stoch),
            "ADX" => Ok(BarIndicatorId::Adx),
            "CCI" => Ok(BarIndicatorId::Cci),
            "MFI" => Ok(BarIndicatorId::Mfi),
            _ => Err(format!("Unknown indicator: {}", s)),
        }
    }

    /// Maps an MA name to its type.
    ///
    /// Accepts both the legacy spellings (`Simple`, `Wilder`) and the short
    /// codes, case-insensitively.
    fn parse_ma_type(s: &str) -> Result<MovingAverageType, String> {
        match s.to_uppercase().as_str() {
            "SIMPLE" | "SMA" => Ok(MovingAverageType::SMA),
            "EMA" => Ok(MovingAverageType::EMA),
            "WILDER" | "RMA" => Ok(MovingAverageType::RMA),
            "WMA" => Ok(MovingAverageType::WMA),
            "HMA" => Ok(MovingAverageType::HMA),
            "DEMA" => Ok(MovingAverageType::DEMA),
            "TEMA" => Ok(MovingAverageType::TEMA),
            "VWMA" => Ok(MovingAverageType::VWMA),
            "TMA" => Ok(MovingAverageType::TMA),
            "VWAP" => Ok(MovingAverageType::VWAP),
            "AMA" => Ok(MovingAverageType::AMA),
            _ => Err(format!("Unknown MA type: {}", s)),
        }
    }

    /// Maps an output suffix back to its selector.
    ///
    /// The suffix is matched against `OutputSelector::short_name`
    /// (case-insensitively), so parsing stays in sync with formatting.
    fn parse_output(s: &str) -> Result<OutputSelector, String> {
        const ALL: [OutputSelector; 7] = [
            OutputSelector::Main,
            OutputSelector::Upper,
            OutputSelector::Lower,
            OutputSelector::Middle,
            OutputSelector::Second,
            OutputSelector::Signal,
            OutputSelector::Histogram,
        ];
        ALL.iter()
            .copied()
            .find(|output| output.short_name().eq_ignore_ascii_case(s))
            .ok_or_else(|| format!("Unknown output: {}", s))
    }
}
impl Hash for IndicatorKey {
    /// Feeds every field compared by `PartialEq` into `state`, so equal keys
    /// always hash equally.
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        (self.indicator_id as u16).hash(state);
        self.period.hash(state);
        // Hash the Option itself (discriminant + payload). Skipping `None`
        // entirely made a `None` key write a shorter byte stream that could be
        // a prefix of a `Some` key's stream. `Hash` implementations are
        // expected to be prefix-free.
        self.ma_type.map(|ma| ma as u8).hash(state);
        (self.output as u8).hash(state);
        self.param_hash.hash(state);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// The Wilder-smoothed ATR(14) key shared by several tests.
    fn atr_wilder_14() -> IndicatorKey {
        IndicatorKey::new(BarIndicatorId::Atr, 14, Some(MovingAverageType::RMA))
    }

    #[test]
    fn test_indicator_key_creation() {
        let key = IndicatorKey::new(BarIndicatorId::Sma, 20, None);
        assert_eq!(
            (key.indicator_id, key.period, key.ma_type),
            (BarIndicatorId::Sma, 20, None)
        );
    }

    #[test]
    fn test_indicator_key_with_ma_type() {
        let key = atr_wilder_14();
        assert_eq!(
            (key.indicator_id, key.period, key.ma_type),
            (BarIndicatorId::Atr, 14, Some(MovingAverageType::RMA))
        );
    }

    #[test]
    fn test_param_hash_distinguishes_secondary_params() {
        let macd_12 = |param_hash: u64| {
            IndicatorKey::new(BarIndicatorId::Macd, 12, None).with_param_hash(param_hash)
        };
        assert_ne!(macd_12(0x1234), macd_12(0x5678));
        assert_eq!(macd_12(0x1234), macd_12(0x1234));
    }

    #[test]
    fn test_indicator_key_equality() {
        let sma_20 = || IndicatorKey::new(BarIndicatorId::Sma, 20, None);
        let ema_20 = IndicatorKey::new(BarIndicatorId::Ema, 20, None);
        assert_eq!(sma_20(), sma_20());
        assert_ne!(sma_20(), ema_20);
    }

    #[test]
    fn test_indicator_key_copy() {
        let original = IndicatorKey::new(BarIndicatorId::Sma, 20, None);
        let copied = original;
        // `original` is still usable after the assignment only because the key is `Copy`.
        assert_eq!(original, copied);
    }

    #[test]
    fn test_to_legacy_string_simple() {
        let rendered = IndicatorKey::new(BarIndicatorId::Sma, 20, None).to_legacy_string();
        assert_eq!(rendered, "SMA_20");
    }

    #[test]
    fn test_to_legacy_string_with_ma_type() {
        assert_eq!(atr_wilder_14().to_legacy_string(), "ATR_Wilder_14");
    }

    #[test]
    fn test_from_legacy_string_simple() {
        let parsed = IndicatorKey::from_legacy_string("SMA_20").unwrap();
        assert_eq!(
            (parsed.indicator_id, parsed.period, parsed.ma_type),
            (BarIndicatorId::Sma, 20, None)
        );
    }

    #[test]
    fn test_from_legacy_string_with_ma_type() {
        let parsed = IndicatorKey::from_legacy_string("ATR_Wilder_14").unwrap();
        assert_eq!(
            (parsed.indicator_id, parsed.period, parsed.ma_type),
            (BarIndicatorId::Atr, 14, Some(MovingAverageType::RMA))
        );
    }

    #[test]
    fn test_roundtrip_conversion() {
        let original = atr_wilder_14();
        let reparsed = IndicatorKey::from_legacy_string(&original.to_legacy_string()).unwrap();
        assert_eq!(original, reparsed);
    }

    #[test]
    fn test_hash_consistency() {
        let key = IndicatorKey::new(BarIndicatorId::Sma, 20, None);
        let map: HashMap<IndicatorKey, i32> = std::iter::once((key, 42)).collect();
        assert_eq!(map.get(&key), Some(&42));
    }

    #[test]
    fn test_memory_size() {
        let key_size = std::mem::size_of::<IndicatorKey>();
        let string_size = std::mem::size_of::<String>();
        assert!(key_size <= 16, "IndicatorKey is too large: {} bytes", key_size);
        assert!(
            key_size < string_size,
            "IndicatorKey ({} bytes) should be smaller than String ({} bytes)",
            key_size,
            string_size
        );
    }
}