use crate::error::{Error, Result};
use crate::indicators::dema::Dema;
use crate::indicators::ema::Ema;
use crate::indicators::macd::MacdOutput;
use crate::indicators::sma::Sma;
use crate::indicators::tema::Tema;
use crate::indicators::trima::Trima;
use crate::indicators::wma::Wma;
use crate::traits::Indicator;
/// Kind of moving average that can back one line of [`MacdExt`].
///
/// Each of the fast, slow and signal lines is configured with its own
/// `MaType`; numeric codes are translated by [`MaType::from_code`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MaType {
    /// Selects the `Sma` indicator.
    Sma,
    /// Selects the `Ema` indicator.
    Ema,
    /// Selects the `Wma` indicator.
    Wma,
    /// Selects the `Dema` indicator.
    Dema,
    /// Selects the `Tema` indicator.
    Tema,
    /// Selects the `Trima` indicator.
    Trima,
}
impl MaType {
pub fn from_code(code: u32) -> Result<Self> {
match code {
0 => Ok(Self::Sma),
1 => Ok(Self::Ema),
2 => Ok(Self::Wma),
3 => Ok(Self::Dema),
4 => Ok(Self::Tema),
5 => Ok(Self::Trima),
_ => Err(Error::InvalidPeriod {
message: "unsupported moving-average type code (expected 0..=5)",
}),
}
}
}
/// Private wrapper holding one concrete moving-average indicator.
///
/// There is one variant per [`MaType`], each carrying the matching
/// indicator instance, so the choice made at construction time is
/// resolved with a plain `match` on every call.
#[derive(Clone, Debug)]
enum Ma {
    /// Wraps an [`Sma`] instance.
    Sma(Sma),
    /// Wraps an [`Ema`] instance.
    Ema(Ema),
    /// Wraps a [`Wma`] instance.
    Wma(Wma),
    /// Wraps a [`Dema`] instance.
    Dema(Dema),
    /// Wraps a [`Tema`] instance.
    Tema(Tema),
    /// Wraps a [`Trima`] instance.
    Trima(Trima),
}
impl Ma {
fn new(kind: MaType, period: usize) -> Result<Self> {
Ok(match kind {
MaType::Sma => Self::Sma(Sma::new(period)?),
MaType::Ema => Self::Ema(Ema::new(period)?),
MaType::Wma => Self::Wma(Wma::new(period)?),
MaType::Dema => Self::Dema(Dema::new(period)?),
MaType::Tema => Self::Tema(Tema::new(period)?),
MaType::Trima => Self::Trima(Trima::new(period)?),
})
}
fn update(&mut self, value: f64) -> Option<f64> {
match self {
Self::Sma(m) => m.update(value),
Self::Ema(m) => m.update(value),
Self::Wma(m) => m.update(value),
Self::Dema(m) => m.update(value),
Self::Tema(m) => m.update(value),
Self::Trima(m) => m.update(value),
}
}
fn reset(&mut self) {
match self {
Self::Sma(m) => m.reset(),
Self::Ema(m) => m.reset(),
Self::Wma(m) => m.reset(),
Self::Dema(m) => m.reset(),
Self::Tema(m) => m.reset(),
Self::Trima(m) => m.reset(),
}
}
fn warmup_period(&self) -> usize {
match self {
Self::Sma(m) => m.warmup_period(),
Self::Ema(m) => m.warmup_period(),
Self::Wma(m) => m.warmup_period(),
Self::Dema(m) => m.warmup_period(),
Self::Tema(m) => m.warmup_period(),
Self::Trima(m) => m.warmup_period(),
}
}
}
/// MACD variant whose fast, slow and signal lines each use an
/// independently chosen moving-average kind (see [`MaType`]).
#[derive(Clone, Debug)]
pub struct MacdExt {
    /// Moving average of the raw input, built from the `fast` period.
    fast: Ma,
    /// Moving average of the raw input, built from the `slow` period.
    slow: Ma,
    /// Moving average applied to the `fast - slow` difference.
    signal: Ma,
    /// Set once `update` has produced an output; cleared by `reset`.
    has_emitted: bool,
}
impl MacdExt {
    /// Creates a MACDEXT indicator from three periods and the
    /// moving-average kind to use for each of them.
    ///
    /// # Errors
    ///
    /// * [`Error::PeriodZero`] if any of the three periods is zero.
    /// * [`Error::InvalidPeriod`] if `fast` is not strictly smaller than
    ///   `slow`.
    /// * Whatever error the selected moving-average constructors return.
    pub fn new(
        fast: usize,
        fast_type: MaType,
        slow: usize,
        slow_type: MaType,
        signal: usize,
        signal_type: MaType,
    ) -> Result<Self> {
        // The zero check comes first so it takes precedence over the
        // ordering check when both would fail.
        if [fast, slow, signal].contains(&0) {
            return Err(Error::PeriodZero);
        }
        if slow <= fast {
            return Err(Error::InvalidPeriod {
                message: "fast period must be < slow period",
            });
        }

        let fast = Ma::new(fast_type, fast)?;
        let slow = Ma::new(slow_type, slow)?;
        let signal = Ma::new(signal_type, signal)?;
        Ok(Self {
            fast,
            slow,
            signal,
            has_emitted: false,
        })
    }
}
impl Indicator for MacdExt {
    type Input = f64;
    type Output = MacdOutput;

    /// Feeds one value through the indicator.
    ///
    /// Returns `None` until the fast, slow and signal averages have all
    /// produced a value. After that it returns the MACD line
    /// (`fast - slow`), the signal line (the signal average of the MACD
    /// line) and the histogram (`macd - signal`).
    fn update(&mut self, value: f64) -> Option<MacdOutput> {
        // Both averages are advanced on every input, even while only one of
        // them is producing values, so they stay aligned on the same series.
        let fast_v = self.fast.update(value);
        let slow_v = self.slow.update(value);
        let (Some(fast_v), Some(slow_v)) = (fast_v, slow_v) else {
            return None;
        };

        // The signal average only ever sees MACD values, so it starts
        // warming up once both input averages are emitting.
        let macd = fast_v - slow_v;
        let signal = self.signal.update(macd)?;

        self.has_emitted = true;
        Some(MacdOutput {
            macd,
            signal,
            histogram: macd - signal,
        })
    }

    /// Resets all three averages and clears the readiness flag.
    fn reset(&mut self) {
        self.fast.reset();
        self.slow.reset();
        self.signal.reset();
        self.has_emitted = false;
    }

    /// Combined warm-up of the MACD line and the signal line.
    ///
    /// The MACD line exists only once *both* input averages emit, so its
    /// warm-up is the larger of the two. With mixed moving-average kinds the
    /// fast average may be the slower one to warm up even though its period
    /// is shorter, which is why the slow warm-up alone is not sufficient.
    fn warmup_period(&self) -> usize {
        let line_warmup = self
            .fast
            .warmup_period()
            .max(self.slow.warmup_period());
        // NOTE(review): if `warmup_period` counts the inputs needed for the
        // first output, the exact total would be one less than this sum,
        // because the first MACD value is also the first signal input.
        // Confirm against the `Indicator` trait contract before tightening.
        line_warmup + self.signal.warmup_period()
    }

    /// Reports whether `update` has produced an output since construction
    /// or the last `reset`.
    fn is_ready(&self) -> bool {
        self.has_emitted
    }

    /// Returns the indicator's display name.
    fn name(&self) -> &'static str {
        "MACDEXT"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    /// Every supported moving-average kind, in numeric-code order.
    const TYPES: [MaType; 6] = [
        MaType::Sma,
        MaType::Ema,
        MaType::Wma,
        MaType::Dema,
        MaType::Tema,
        MaType::Trima,
    ];

    /// Codes `0..=5` map onto the kinds in `TYPES` order; `6` is rejected.
    #[test]
    fn from_code_maps_all_supported_types() {
        for (code, expected) in (0u32..).zip(TYPES.iter().copied()) {
            assert_eq!(MaType::from_code(code).unwrap(), expected);
        }
        assert!(MaType::from_code(6).is_err());
    }

    /// A zero period and a fast period above the slow one are both refused.
    #[test]
    fn rejects_invalid_periods() {
        let zero_fast = MacdExt::new(0, MaType::Ema, 26, MaType::Ema, 9, MaType::Ema);
        assert!(matches!(zero_fast, Err(Error::PeriodZero)));

        let inverted = MacdExt::new(26, MaType::Ema, 12, MaType::Ema, 9, MaType::Ema);
        assert!(matches!(inverted, Err(Error::InvalidPeriod { .. })));
    }

    /// A freshly built indicator reports its name, is not ready, and has a
    /// warm-up at least as long as the slow period.
    #[test]
    fn accessors_and_metadata() {
        let macd = MacdExt::new(12, MaType::Ema, 26, MaType::Sma, 9, MaType::Sma).unwrap();
        assert_eq!(macd.name(), "MACDEXT");
        assert!(!macd.is_ready());
        assert!(macd.warmup_period() >= 26);
    }

    /// For each kind, the indicator emits on a sine series, the histogram
    /// equals `macd - signal`, and `reset` clears readiness.
    #[test]
    fn every_ma_type_produces_a_consistent_histogram() {
        let prices: Vec<f64> = (0..120)
            .map(|i| 100.0 + (f64::from(i) * 0.2).sin() * 6.0)
            .collect();

        for t in TYPES.iter().copied() {
            let mut macd = MacdExt::new(5, t, 10, t, 4, t).unwrap();
            let outputs: Vec<Option<MacdOutput>> = macd.batch(&prices);
            let emitted: Vec<MacdOutput> = outputs.into_iter().flatten().collect();

            assert!(!emitted.is_empty(), "{t:?} never emitted");
            for point in &emitted {
                let residual = point.histogram - (point.macd - point.signal);
                assert!(residual.abs() < 1e-9);
            }

            assert!(macd.warmup_period() >= 10);
            assert!(macd.is_ready());
            macd.reset();
            assert!(!macd.is_ready());
        }
    }

    /// Using a different kind for each line still yields output.
    #[test]
    fn mixed_ma_types_per_line() {
        let prices: Vec<f64> = (0..120).map(|i| 100.0 + f64::from(i)).collect();
        let mut macd =
            MacdExt::new(12, MaType::Wma, 26, MaType::Dema, 9, MaType::Trima).unwrap();
        let outputs = macd.batch(&prices);
        let last = outputs.into_iter().flatten().last();
        assert!(last.is_some());
    }
}