use crate::indicators::trend::Ema;
use crate::indicators::validate_period;
use crate::indicators::{Candle, Indicator, IndicatorError};
/// Moving Average Convergence Divergence (MACD) indicator.
///
/// Keeps three [`Ema`]s:
/// - a fast and a slow EMA of the input, whose difference is the MACD line;
/// - a signal EMA computed over that MACD line.
///
/// The histogram is the MACD value minus the signal value.
#[derive(Debug)]
pub struct Macd {
    /// Period of the fast EMA; `Macd::new` requires it to be strictly below `slow_period`.
    fast_period: usize,
    /// Period of the slow EMA.
    slow_period: usize,
    /// Period of the EMA applied to the MACD line.
    signal_period: usize,
    /// EMA over the input with `fast_period`.
    fast_ema: Ema,
    /// EMA over the input with `slow_period`.
    slow_ema: Ema,
    /// EMA over the MACD line with `signal_period`.
    signal_ema: Ema,
    /// Last MACD value produced, or `None` after construction / `reset_state`.
    current_macd: Option<f64>,
    /// Last signal value produced, or `None` after construction / `reset_state`.
    current_signal: Option<f64>,
    /// Last histogram value produced, or `None` after construction / `reset_state`.
    current_histogram: Option<f64>,
}
/// One MACD output point.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MacdResult {
    /// Fast EMA minus slow EMA.
    pub macd: f64,
    /// EMA of the MACD line.
    pub signal: f64,
    /// `macd - signal`.
    pub histogram: f64,
}
impl Macd {
pub fn new(
fast_period: usize,
slow_period: usize,
signal_period: usize,
) -> Result<Self, IndicatorError> {
validate_period(fast_period, 1)?;
validate_period(slow_period, 1)?;
validate_period(signal_period, 1)?;
if fast_period >= slow_period {
return Err(IndicatorError::InvalidParameter(
"Slow period must be greater than fast period".to_string(),
));
}
Ok(Self {
fast_period,
slow_period,
signal_period,
fast_ema: Ema::new(fast_period)?,
slow_ema: Ema::new(slow_period)?,
signal_ema: Ema::new(signal_period)?,
current_macd: None,
current_signal: None,
current_histogram: None,
})
}
pub fn reset_state(&mut self) {
<Ema as Indicator<f64, f64>>::reset(&mut self.fast_ema);
<Ema as Indicator<f64, f64>>::reset(&mut self.slow_ema);
<Ema as Indicator<f64, f64>>::reset(&mut self.signal_ema);
self.current_macd = None;
self.current_signal = None;
self.current_histogram = None;
}
}
impl Indicator<f64, MacdResult> for Macd {
    /// Computes the whole MACD series for `data` at once.
    ///
    /// Emits one [`MacdResult`] per input point from index
    /// `slow_period + signal_period - 2` onward. That is
    /// `data.len() - slow_period - signal_period + 2` results.
    ///
    /// After a non-empty result, the `current_*` fields hold the last emitted point.
    ///
    /// The alignment below assumes `Ema::calculate` returns `len - period + 1` values,
    /// the first one belonging to input index `period - 1`.
    ///
    /// # Errors
    ///
    /// - Returns [`IndicatorError::InsufficientData`] when `data` has fewer than
    ///   `slow_period + signal_period - 1` points.
    /// - Propagates errors from the underlying EMAs.
    fn calculate(&mut self, data: &[f64]) -> Result<Vec<MacdResult>, IndicatorError> {
        let required = self.slow_period + self.signal_period - 1;
        if data.len() < required {
            return Err(IndicatorError::InsufficientData(format!(
                "At least {} data points required for MACD({},{},{})",
                required, self.fast_period, self.slow_period, self.signal_period
            )));
        }

        let fast_ema_values: Vec<f64> = self.fast_ema.calculate(data)?;
        let slow_ema_values: Vec<f64> = self.slow_ema.calculate(data)?;

        // The fast series starts `slow_period - fast_period` inputs earlier than the
        // slow one. Skip that lead so both values refer to the same input index.
        // (Indexing with `i - (slow - fast)` misaligned them and could run off the end.)
        let lead = self.slow_period - self.fast_period;
        let macd_line: Vec<f64> = fast_ema_values
            .iter()
            .skip(lead)
            .zip(slow_ema_values.iter())
            .map(|(fast, slow)| fast - slow)
            .collect();

        // The signal series starts `signal_period - 1` points into the MACD line.
        let signal_values: Vec<f64> = self.signal_ema.calculate(&macd_line)?;
        let result: Vec<MacdResult> = macd_line
            .iter()
            .skip(self.signal_period - 1)
            .zip(signal_values.iter())
            .map(|(&macd, &signal)| MacdResult {
                macd,
                signal,
                histogram: macd - signal,
            })
            .collect();

        if let Some(last) = result.last() {
            self.current_macd = Some(last.macd);
            self.current_signal = Some(last.signal);
            self.current_histogram = Some(last.histogram);
        }
        Ok(result)
    }

    /// Feeds one value into the streaming MACD.
    ///
    /// Returns `Ok(None)` while the indicator is warming up:
    /// - until both the fast and slow EMAs yield a value, and then
    /// - until the signal EMA yields a value.
    ///
    /// No MACD value is fabricated during warm-up, so the signal EMA only ever sees
    /// real MACD values. `current_macd` is updated as soon as a MACD value exists;
    /// `current_signal`/`current_histogram` are updated once a full result is produced.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying EMAs.
    fn next(&mut self, value: f64) -> Result<Option<MacdResult>, IndicatorError> {
        // Both EMAs must consume every input, even while the other is still warming up.
        let fast: Option<f64> = self.fast_ema.next(value)?;
        let slow: Option<f64> = self.slow_ema.next(value)?;
        let macd = match (fast, slow) {
            (Some(fast), Some(slow)) => fast - slow,
            _ => return Ok(None),
        };
        self.current_macd = Some(macd);

        let signal: f64 = match self.signal_ema.next(macd)? {
            Some(signal) => signal,
            None => return Ok(None),
        };
        let histogram = macd - signal;
        self.current_signal = Some(signal);
        self.current_histogram = Some(histogram);

        Ok(Some(MacdResult {
            macd,
            signal,
            histogram,
        }))
    }

    /// Clears all streaming state (delegates to `Macd::reset_state`).
    fn reset(&mut self) {
        self.reset_state();
    }
}
impl Indicator<Candle, MacdResult> for Macd {
fn calculate(&mut self, data: &[Candle]) -> Result<Vec<MacdResult>, IndicatorError> {
let close_prices: Vec<f64> = data.iter().map(|candle| candle.close).collect();
self.calculate(&close_prices)
}
fn next(&mut self, candle: Candle) -> Result<Option<MacdResult>, IndicatorError> {
let close_price = candle.close;
self.next(close_price)
}
fn reset(&mut self) {
self.reset_state();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Strictly rising series `2, 4, ..., 2 * count`.
    fn rising_prices(count: u32) -> Vec<f64> {
        (1..=count).map(|i| f64::from(i) * 2.0).collect()
    }

    /// Wraps each price in a candle closing at that price; timestamps count up from
    /// `first_timestamp`.
    fn candles_from(prices: &[f64], first_timestamp: u64) -> Vec<Candle> {
        prices
            .iter()
            .zip(first_timestamp..)
            .map(|(&close, timestamp)| Candle {
                timestamp,
                open: close - 0.5,
                high: close + 1.0,
                low: close - 1.0,
                close,
                volume: 1000.0,
            })
            .collect()
    }

    #[test]
    fn test_macd_new() {
        assert!(Macd::new(12, 26, 9).is_ok());
        let rejected = [(26, 12, 9), (0, 26, 9), (12, 0, 9), (12, 26, 0)];
        for &(fast, slow, signal) in rejected.iter() {
            assert!(Macd::new(fast, slow, signal).is_err());
        }
    }

    #[test]
    fn test_macd_calculation() {
        let mut macd = Macd::new(3, 6, 2).unwrap();
        let prices = rising_prices(20);
        let min_points = macd.slow_period + macd.signal_period - 1;
        assert!(prices.len() >= min_points);

        let result = macd.calculate(&prices).unwrap();
        assert_eq!(result.len(), prices.len() - min_points + 1);
        for point in result.iter() {
            assert!(point.macd != 0.0);
            assert!(point.signal != 0.0);
        }
        assert!(result.last().unwrap().macd > 0.0);
    }

    #[test]
    fn test_macd_next() {
        let mut macd = Macd::new(3, 6, 2).unwrap();
        rising_prices(15).into_iter().for_each(|price| {
            macd.next(price).unwrap();
        });

        let output = macd
            .next(32.0)
            .unwrap()
            .expect("MACD should produce a value after warm-up");
        assert!(output.macd > 0.0);
        assert!(output.signal > 0.0);
    }

    #[test]
    fn test_macd_reset() {
        let mut macd = Macd::new(3, 6, 2).unwrap();
        for price in rising_prices(10) {
            macd.next(price).unwrap();
        }

        macd.reset_state();
        assert!(macd.current_macd.is_none());
        assert!(macd.current_signal.is_none());
        assert!(macd.current_histogram.is_none());
    }

    #[test]
    fn test_macd_with_candles() {
        let mut macd = Macd::new(3, 6, 2).unwrap();
        let candles = candles_from(&rising_prices(20), 1);

        let result = macd.calculate(&candles).unwrap();
        let expected_len = candles.len() + 2 - macd.slow_period - macd.signal_period;
        assert_eq!(result.len(), expected_len);
        assert!(result.last().unwrap().macd > 0.0);
    }

    #[test]
    fn test_macd_implementations_produce_same_results() {
        let mut macd_f64 = Macd::new(3, 6, 2).unwrap();
        let mut macd_candle = Macd::new(3, 6, 2).unwrap();
        let prices = rising_prices(20);
        let candles = candles_from(&prices, 0);

        let result_f64 = macd_f64.calculate(&prices).unwrap();
        let result_candle = macd_candle.calculate(&candles).unwrap();
        assert_eq!(result_f64.len(), result_candle.len());

        let near = |a: f64, b: f64| (a - b).abs() < 0.000001;
        for (i, (lhs, rhs)) in result_f64.iter().zip(&result_candle).enumerate() {
            assert!(
                near(lhs.macd, rhs.macd),
                "MACD values differ at index {}: {} vs {}",
                i,
                lhs.macd,
                rhs.macd
            );
            assert!(
                near(lhs.signal, rhs.signal),
                "Signal values differ at index {}: {} vs {}",
                i,
                lhs.signal,
                rhs.signal
            );
            assert!(
                near(lhs.histogram, rhs.histogram),
                "Histogram values differ at index {}: {} vs {}",
                i,
                lhs.histogram,
                rhs.histogram
            );
        }
    }
}