use crate::indicators::trend::Ema;
use crate::indicators::validate_period;
use crate::indicators::{Candle, Indicator, IndicatorError};
#[derive(Debug)]
pub struct Macd {
fast_period: usize,
slow_period: usize,
signal_period: usize,
fast_ema: Ema,
slow_ema: Ema,
signal_ema: Ema,
current_macd: Option<f64>,
current_signal: Option<f64>,
current_histogram: Option<f64>,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MacdResult {
pub macd: f64,
pub signal: f64,
pub histogram: f64,
}
impl Macd {
pub fn new(
fast_period: usize,
slow_period: usize,
signal_period: usize,
) -> Result<Self, IndicatorError> {
validate_period(fast_period, 1)?;
validate_period(slow_period, 1)?;
validate_period(signal_period, 1)?;
if fast_period >= slow_period {
return Err(IndicatorError::InvalidParameter(
"Slow period must be greater than fast period".to_string(),
));
}
Ok(Self {
fast_period,
slow_period,
signal_period,
fast_ema: Ema::new(fast_period)?,
slow_ema: Ema::new(slow_period)?,
signal_ema: Ema::new(signal_period)?,
current_macd: None,
current_signal: None,
current_histogram: None,
})
}
pub fn reset_state(&mut self) {
<Ema as Indicator<f64, f64>>::reset(&mut self.fast_ema);
<Ema as Indicator<f64, f64>>::reset(&mut self.slow_ema);
<Ema as Indicator<f64, f64>>::reset(&mut self.signal_ema);
self.current_macd = None;
self.current_signal = None;
self.current_histogram = None;
}
}
impl Indicator<f64, MacdResult> for Macd {
fn calculate(&mut self, data: &[f64]) -> Result<Vec<MacdResult>, IndicatorError> {
if data.is_empty() {
return Err(IndicatorError::InsufficientData(format!(
"At least 1 data point required for MACD({},{},{})",
self.fast_period, self.slow_period, self.signal_period,
)));
}
self.reset_state();
let mut result = Vec::with_capacity(data.len());
for &v in data {
if let Some(r) = <Self as Indicator<f64, MacdResult>>::next(self, v)? {
result.push(r);
}
}
Ok(result)
}
fn next(&mut self, value: f64) -> Result<Option<MacdResult>, IndicatorError> {
let fast_ema = self.fast_ema.next(value)?.unwrap_or(value);
let slow_ema = self.slow_ema.next(value)?.unwrap_or(value);
let macd = fast_ema - slow_ema;
self.current_macd = Some(macd);
let signal = if let Some(signal_value) = self.signal_ema.next(macd)? {
signal_value
} else {
macd
};
self.current_signal = Some(signal);
let histogram = macd - signal;
self.current_histogram = Some(histogram);
if self.current_macd.is_some()
&& self.current_signal.is_some()
&& self.current_histogram.is_some()
{
return Ok(Some(MacdResult {
macd,
signal,
histogram,
}));
}
Ok(None)
}
fn reset(&mut self) {
self.reset_state();
}
}
impl Indicator<Candle, MacdResult> for Macd {
fn calculate(&mut self, data: &[Candle]) -> Result<Vec<MacdResult>, IndicatorError> {
let close_prices: Vec<f64> = data.iter().map(|candle| candle.close).collect();
self.calculate(&close_prices)
}
fn next(&mut self, candle: Candle) -> Result<Option<MacdResult>, IndicatorError> {
let close_price = candle.close;
self.next(close_price)
}
fn reset(&mut self) {
self.reset_state();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_macd_new() {
assert!(Macd::new(12, 26, 9).is_ok());
assert!(Macd::new(26, 12, 9).is_err());
assert!(Macd::new(0, 26, 9).is_err());
assert!(Macd::new(12, 0, 9).is_err());
assert!(Macd::new(12, 26, 0).is_err());
}
#[test]
fn test_macd_calculation() {
let mut macd = Macd::new(3, 6, 2).unwrap();
let prices: Vec<f64> = (1..=20).map(|i| i as f64 * 2.0).collect();
assert!(prices.len() >= macd.slow_period + macd.signal_period - 1);
let result = macd.calculate(&prices).unwrap();
assert_eq!(result.len(), prices.len());
assert!((result[0].histogram).abs() < 1e-12);
assert!(result.last().unwrap().macd > 0.0);
for r in &result {
assert!((r.histogram - (r.macd - r.signal)).abs() < 1e-12);
}
}
#[test]
fn test_macd_calculate_matches_streaming() {
let prices: Vec<f64> = (1..=40).map(|i| i as f64).collect();
let mut batch = Macd::new(12, 26, 9).unwrap();
let batch_out = batch.calculate(&prices).unwrap();
let mut stream = Macd::new(12, 26, 9).unwrap();
let stream_out: Vec<_> = prices
.iter()
.filter_map(|&p| stream.next(p).unwrap())
.collect();
assert_eq!(batch_out, stream_out);
}
#[test]
fn test_macd_next() {
let mut macd = Macd::new(3, 6, 2).unwrap();
for i in 1..=15 {
let price = i as f64 * 2.0;
macd.next(price).unwrap();
}
let result = macd.next(32.0).unwrap();
assert!(result.is_some());
let output = result.unwrap();
assert!(output.macd > 0.0);
assert!(output.signal > 0.0);
}
#[test]
fn test_macd_reset() {
let mut macd = Macd::new(3, 6, 2).unwrap();
for i in 1..=10 {
macd.next(i as f64 * 2.0).unwrap();
}
macd.reset_state();
assert!(macd.current_macd.is_none());
assert!(macd.current_signal.is_none());
assert!(macd.current_histogram.is_none());
}
#[test]
fn test_macd_with_candles() {
let mut macd = Macd::new(3, 6, 2).unwrap();
let mut candles = Vec::new();
for i in 1..=20 {
let price = i as f64 * 2.0;
candles.push(Candle {
timestamp: i as u64,
open: price - 0.5,
high: price + 1.0,
low: price - 1.0,
close: price,
volume: 1000.0,
});
}
let result = macd.calculate(&candles).unwrap();
assert_eq!(result.len(), candles.len());
assert!(result.last().unwrap().macd > 0.0);
}
#[test]
fn test_macd_implementations_produce_same_results() {
let mut macd_f64 = Macd::new(3, 6, 2).unwrap();
let mut macd_candle = Macd::new(3, 6, 2).unwrap();
let prices: Vec<f64> = (1..=20).map(|i| i as f64 * 2.0).collect();
let candles: Vec<Candle> = prices
.iter()
.enumerate()
.map(|(i, &price)| Candle {
timestamp: i as u64,
open: price - 0.5,
high: price + 1.0,
low: price - 1.0,
close: price,
volume: 1000.0,
})
.collect();
let result_f64 = macd_f64.calculate(&prices).unwrap();
let result_candle = macd_candle.calculate(&candles).unwrap();
assert_eq!(result_f64.len(), result_candle.len());
for (i, (out_f64, out_candle)) in result_f64.iter().zip(result_candle.iter()).enumerate() {
assert!(
(out_f64.macd - out_candle.macd).abs() < 0.000001,
"MACD values differ at index {}: {} vs {}",
i,
out_f64.macd,
out_candle.macd
);
assert!(
(out_f64.signal - out_candle.signal).abs() < 0.000001,
"Signal values differ at index {}: {} vs {}",
i,
out_f64.signal,
out_candle.signal
);
assert!(
(out_f64.histogram - out_candle.histogram).abs() < 0.000001,
"Histogram values differ at index {}: {} vs {}",
i,
out_f64.histogram,
out_candle.histogram
);
}
}
}