use std::collections::HashMap;
use crate::error::IndicatorError;
use crate::functions::{self};
use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
use crate::types::Candle;
/// Tunable inputs for the [`Macd`] indicator.
///
/// All three periods are forwarded unchanged to `functions::macd`; no validation
/// happens at construction time.
#[derive(Debug, Clone)]
pub struct MacdParams {
    /// Look-back length of the fast average.
    pub fast_period: usize,
    /// Look-back length of the slow average.
    pub slow_period: usize,
    /// Look-back length used to smooth the MACD line into the signal line.
    pub signal_period: usize,
    /// Which price field is pulled out of each candle (via `PriceColumn::extract`).
    pub column: PriceColumn,
}
impl Default for MacdParams {
fn default() -> Self {
Self {
fast_period: 12,
slow_period: 26,
signal_period: 9,
column: PriceColumn::Close,
}
}
}
/// Moving Average Convergence/Divergence indicator, configured by [`MacdParams`].
#[derive(Debug, Clone)]
pub struct Macd {
    /// Configuration consulted on every call to `calculate`.
    pub params: MacdParams,
}
impl Macd {
pub fn new(params: MacdParams) -> Self {
Self { params }
}
}
impl Default for Macd {
fn default() -> Self {
Self::new(MacdParams::default())
}
}
impl Indicator for Macd {
    fn name(&self) -> &'static str {
        "MACD"
    }

    /// Minimum number of candles requested before `calculate` is attempted.
    ///
    /// The MACD line cannot settle until the *longer* of the two average periods has
    /// elapsed, and the signal line needs `signal_period` further values on top of
    /// that. Taking `max(fast, slow)` keeps the count correct even when a caller
    /// configures `fast_period > slow_period`. The saturating add avoids an overflow
    /// panic on absurdly large user-supplied periods.
    fn required_len(&self) -> usize {
        let p = &self.params;
        p.fast_period.max(p.slow_period).saturating_add(p.signal_period)
    }

    /// Candle columns this indicator reads.
    fn required_columns(&self) -> &[&'static str] {
        // NOTE(review): always reports "close", even when `params.column` selects a
        // different price column; confirm whether callers rely on this list.
        &["close"]
    }

    /// Computes MACD on the configured price column.
    ///
    /// The output holds three series keyed `MACD_line`, `MACD_signal` and
    /// `MACD_histogram`, taken in that order from the tuple returned by
    /// `functions::macd`.
    ///
    /// # Errors
    ///
    /// Returns the error from `check_len` (presumably when fewer than `required_len`
    /// candles are supplied), or any error propagated from `functions::macd`.
    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
        self.check_len(candles)?;
        let p = &self.params;
        let prices = p.column.extract(candles);
        let (macd_line, signal_line, histogram) =
            functions::macd(&prices, p.fast_period, p.slow_period, p.signal_period)?;
        Ok(IndicatorOutput::from_pairs([
            ("MACD_line".to_string(), macd_line),
            ("MACD_signal".to_string(), signal_line),
            ("MACD_histogram".to_string(), histogram),
        ]))
    }
}
pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
Ok(Box::new(Macd::new(MacdParams {
fast_period: crate::registry::param_usize(params, "fast_period", 12)?,
slow_period: crate::registry::param_usize(params, "slow_period", 26)?,
signal_period: crate::registry::param_usize(params, "signal_period", 9)?,
column: PriceColumn::Close,
})))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open = high = low = close) from `closes`, one per time step,
    /// each with unit volume.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    /// A strictly increasing close series, long enough for the default parameters.
    fn rising_closes() -> Vec<f64> {
        (1..=50).map(f64::from).collect()
    }

    #[test]
    fn macd_insufficient_data() {
        let macd = Macd::default();
        assert!(macd.calculate(&candles(&[1.0; 10])).is_err());
    }

    #[test]
    fn macd_default_required_len() {
        // 26 (slow) + 9 (signal) for the default configuration.
        assert_eq!(Macd::default().required_len(), 35);
    }

    #[test]
    fn macd_output_has_three_columns() {
        let macd = Macd::default();
        let out = macd.calculate(&candles(&rising_closes())).unwrap();
        assert!(out.get("MACD_line").is_some(), "missing MACD_line");
        assert!(out.get("MACD_signal").is_some(), "missing MACD_signal");
        assert!(
            out.get("MACD_histogram").is_some(),
            "missing MACD_histogram"
        );
    }

    #[test]
    fn macd_histogram_is_line_minus_signal() {
        let macd = Macd::default();
        let out = macd.calculate(&candles(&rising_closes())).unwrap();
        let line = out.get("MACD_line").unwrap();
        let signal = out.get("MACD_signal").unwrap();
        let hist = out.get("MACD_histogram").unwrap();
        // Mismatched lengths would let the loop below silently skip or panic.
        assert_eq!(line.len(), signal.len(), "line/signal length mismatch");
        assert_eq!(line.len(), hist.len(), "line/histogram length mismatch");
        for i in 0..line.len() {
            if !line[i].is_nan() && !signal[i].is_nan() {
                assert!((hist[i] - (line[i] - signal[i])).abs() < 1e-9);
            }
        }
    }

    #[test]
    fn factory_creates_macd() {
        let params: HashMap<String, String> = HashMap::new();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "MACD");
    }
}