use crate::error::{Error, Result};
use crate::indicators::ema::Ema;
use crate::traits::Indicator;
/// One MACD reading produced by [`MacdIndicator`].
///
/// All three values refer to the same input sample.
#[derive(Clone, Copy)]
#[derive(Debug, PartialEq)]
pub struct MacdOutput {
    /// MACD line: fast EMA minus slow EMA.
    pub macd: f64,
    /// Signal line: EMA of the MACD line.
    pub signal: f64,
    /// `macd - signal`.
    pub histogram: f64,
}
/// Streaming MACD (Moving Average Convergence/Divergence) indicator.
///
/// Built from three EMAs: a fast and a slow EMA over the input, plus a signal
/// EMA over their difference.
#[derive(Clone)]
#[derive(Debug)]
pub struct MacdIndicator {
    /// EMA over the input with the short period.
    fast: Ema,
    /// EMA over the input with the long period.
    slow: Ema,
    /// EMA over the MACD line (`fast - slow`).
    signal_ema: Ema,
    /// Period the `fast` EMA was built with.
    fast_period: usize,
    /// Period the `slow` EMA was built with.
    slow_period: usize,
    /// Period the `signal_ema` was built with.
    signal_period: usize,
    /// Most recent emitted output; `None` until the first full reading.
    last: Option<MacdOutput>,
}
impl MacdIndicator {
    /// Creates a MACD indicator with the given fast, slow, and signal EMA periods.
    ///
    /// # Errors
    ///
    /// * [`Error::PeriodZero`] if any of the three periods is zero.
    /// * [`Error::InvalidPeriod`] if `fast` is not strictly smaller than `slow`.
    /// * Any error returned by [`Ema::new`] is propagated unchanged.
    pub fn new(fast: usize, slow: usize, signal: usize) -> Result<Self> {
        let has_zero_period = [fast, slow, signal].contains(&0);
        if has_zero_period {
            return Err(Error::PeriodZero);
        }
        // The fast EMA must track price more tightly than the slow one.
        if slow <= fast {
            return Err(Error::InvalidPeriod {
                message: "fast period must be strictly less than slow period",
            });
        }

        let fast_ema = Ema::new(fast)?;
        let slow_ema = Ema::new(slow)?;
        let signal_ema = Ema::new(signal)?;

        Ok(Self {
            fast: fast_ema,
            slow: slow_ema,
            signal_ema,
            fast_period: fast,
            slow_period: slow,
            signal_period: signal,
            last: None,
        })
    }

    /// Creates the conventional 12 / 26 / 9 MACD.
    ///
    /// # Panics
    ///
    /// Only if [`MacdIndicator::new`] rejects 12 / 26 / 9. Those periods pass
    /// every check visible in `new`, so this is treated as an invariant.
    pub fn classic() -> Self {
        const FAST: usize = 12;
        const SLOW: usize = 26;
        const SIGNAL: usize = 9;
        Self::new(FAST, SLOW, SIGNAL).expect("classic MACD periods are valid")
    }

    /// Returns the configured periods as `(fast, slow, signal)`.
    pub const fn periods(&self) -> (usize, usize, usize) {
        let Self {
            fast_period,
            slow_period,
            signal_period,
            ..
        } = self;
        (*fast_period, *slow_period, *signal_period)
    }

    /// Returns the most recently emitted output.
    ///
    /// Returns `None` if nothing has been emitted yet.
    pub const fn value(&self) -> Option<MacdOutput> {
        self.last
    }
}
impl Indicator for MacdIndicator {
    type Input = f64;
    type Output = MacdOutput;

    /// Feeds one input sample and returns the MACD reading for it, if any.
    ///
    /// A non-finite sample (NaN or ±infinity) is ignored. No EMA is touched,
    /// and the previously emitted output (possibly `None`) is returned again.
    ///
    /// Otherwise both price EMAs are advanced. `None` is returned while either
    /// price EMA or the signal EMA has not yet produced a value. Once all three
    /// have produced a value, the new reading is cached and returned.
    fn update(&mut self, input: f64) -> Option<MacdOutput> {
        if !input.is_finite() {
            return self.last;
        }

        // Advance both EMAs unconditionally so neither falls behind the other.
        let fast_value = self.fast.update(input);
        let slow_value = self.slow.update(input);
        let (fast_ema, slow_ema) = fast_value.zip(slow_value)?;

        let line = fast_ema - slow_ema;
        // The signal EMA may still be warming up; in that case nothing is cached.
        let signal = self.signal_ema.update(line)?;

        let reading = MacdOutput {
            macd: line,
            signal,
            histogram: line - signal,
        };
        self.last = Some(reading);
        self.last
    }

    /// Resets all three EMAs and clears the cached output.
    fn reset(&mut self) {
        self.fast.reset();
        self.slow.reset();
        self.signal_ema.reset();
        self.last = None;
    }

    /// Number of samples consumed before the first output is emitted.
    ///
    /// Assumes each EMA needs `period` samples before it emits (TODO confirm
    /// against `Ema`). The slow EMA then gates the MACD line, and the signal EMA
    /// needs `signal_period` MACD values on top of that. The two windows overlap
    /// by one sample.
    fn warmup_period(&self) -> usize {
        self.signal_period + self.slow_period - 1
    }

    /// `true` once at least one output has been emitted since construction or reset.
    fn is_ready(&self) -> bool {
        matches!(self.last, Some(_))
    }

    fn name(&self) -> &'static str {
        "MACD"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_fast_geq_slow() {
        assert!(matches!(
            MacdIndicator::new(26, 12, 9),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            MacdIndicator::new(12, 12, 9),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn rejects_zero_periods() {
        assert!(matches!(
            MacdIndicator::new(0, 26, 9),
            Err(Error::PeriodZero)
        ));
        assert!(matches!(
            MacdIndicator::new(12, 0, 9),
            Err(Error::PeriodZero)
        ));
        assert!(matches!(
            MacdIndicator::new(12, 26, 0),
            Err(Error::PeriodZero)
        ));
    }

    #[test]
    fn first_emission_matches_warmup_period() {
        let prices: Vec<f64> = (1..=60).map(f64::from).collect();
        let mut macd = MacdIndicator::classic();
        let out = macd.batch(&prices);
        let warmup = macd.warmup_period();
        for x in out.iter().take(warmup - 1) {
            assert!(x.is_none(), "expected None within warmup");
        }
        assert!(
            out[warmup - 1].is_some(),
            "expected first emission at index warmup_period - 1 = {}",
            warmup - 1
        );
        // Once warm, every subsequent finite sample must produce an output.
        assert!(
            out[warmup..].iter().all(Option::is_some),
            "expected an output for every sample after warmup"
        );
    }

    #[test]
    fn histogram_equals_macd_minus_signal() {
        let prices: Vec<f64> = (1..=80).map(|i| f64::from(i) * 0.5).collect();
        let mut macd = MacdIndicator::classic();
        for v in macd.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(v.histogram, v.macd - v.signal, epsilon = 1e-12);
        }
    }

    #[test]
    fn constant_series_yields_zero_macd_eventually() {
        let mut macd = MacdIndicator::classic();
        let out = macd.batch(&[100.0_f64; 200]);
        let last = out.iter().rev().flatten().next().expect("emits a value");
        assert_relative_eq!(last.macd, 0.0, epsilon = 1e-9);
        assert_relative_eq!(last.signal, 0.0, epsilon = 1e-9);
        assert_relative_eq!(last.histogram, 0.0, epsilon = 1e-9);
    }

    #[test]
    fn rising_series_macd_positive_then_signal_catches_up() {
        let prices: Vec<f64> = (1..=200).map(f64::from).collect();
        let mut macd = MacdIndicator::classic();
        let out = macd.batch(&prices);
        let last = out.iter().rev().flatten().next().expect("emits a value");
        assert!(last.macd > 0.0, "rising series must yield positive MACD");
        // On a steady linear trend the MACD line settles to a constant, so the
        // signal EMA converges onto it and the histogram decays toward zero.
        assert!(last.signal > 0.0, "signal must follow MACD above zero");
        assert_relative_eq!(last.histogram, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=100)
            .map(|i| (f64::from(i) * 0.4).cos() * 10.0)
            .collect();
        let mut a = MacdIndicator::classic();
        let mut b = MacdIndicator::classic();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn non_finite_input_is_ignored() {
        let prices: Vec<f64> = (1..=60).map(f64::from).collect();
        let mut clean = MacdIndicator::classic();
        let expected = clean.batch(&prices);

        // Interleave NaN / infinity between real samples: they must echo the
        // last output and leave the EMA state untouched.
        let mut noisy = MacdIndicator::classic();
        for (p, want) in prices.iter().zip(&expected) {
            let prev = noisy.value();
            assert_eq!(noisy.update(f64::NAN), prev);
            assert_eq!(noisy.update(f64::INFINITY), prev);
            assert_eq!(noisy.update(f64::NEG_INFINITY), prev);
            assert_eq!(noisy.update(*p), *want);
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut macd = MacdIndicator::classic();
        macd.batch(&(1..=80).map(f64::from).collect::<Vec<_>>());
        assert!(macd.is_ready());
        macd.reset();
        assert!(!macd.is_ready());
        assert_eq!(macd.update(1.0), None);
    }
}