use crate::indicators::IndicatorError;
use crate::signals::{Signal, SignalEvent};
use std::collections::VecDeque;
/// A confirmed swing pivot: the price and oscillator readings captured at one bar.
#[derive(Clone, Copy, Debug)]
struct Pivot {
    /// Price at the pivot bar.
    price: f64,
    /// Oscillator value at the pivot bar.
    osc: f64,
    /// Zero-based index of the pivot bar, counted from the first input after
    /// construction or the last reset.
    bar: usize,
}
/// Streaming detector for price/oscillator divergence at confirmed swing pivots.
///
/// Feed it `(price, oscillator)` pairs through [`Signal::next`].
#[derive(Debug)]
pub struct Divergence {
    /// Bars required on each side of a candidate bar to confirm it as a pivot.
    lookback: usize,
    /// Minimum bar distance between two same-kind pivots for them to be compared.
    min_distance: usize,
    /// Sliding window holding the most recent `(price, oscillator)` pairs.
    window: VecDeque<(f64, f64)>,
    /// Most recently confirmed swing high, if any.
    last_high: Option<Pivot>,
    /// Most recently confirmed swing low, if any.
    last_low: Option<Pivot>,
    /// Number of inputs consumed since construction or the last reset.
    seen: usize,
}
impl Divergence {
    /// Creates a divergence detector that confirms a swing pivot once
    /// `lookback` bars on each side of it have been seen.
    ///
    /// The minimum bar distance between two compared pivots defaults to
    /// `lookback + 1`; override it with [`Divergence::with_min_distance`].
    ///
    /// # Errors
    ///
    /// Returns [`IndicatorError::InvalidParameter`] if `lookback` is zero, or if
    /// it is so large that the `2 * lookback + 1` bar window cannot be sized or
    /// allocated.
    pub fn new(lookback: usize) -> Result<Self, IndicatorError> {
        if lookback == 0 {
            return Err(IndicatorError::InvalidParameter(
                "Divergence lookback must be at least 1".to_string(),
            ));
        }
        let too_large =
            || IndicatorError::InvalidParameter("Divergence lookback is too large".to_string());
        // Checked so an oversized lookback is rejected here rather than
        // overflowing: debug builds would panic, and a wrapped size in release
        // builds would break the window indexing in `next`. If
        // `2 * lookback + 1` fits, `lookback + 1` below fits as well.
        let window_len = lookback
            .checked_mul(2)
            .and_then(|n| n.checked_add(1))
            .ok_or_else(too_large)?;
        // Fallible reservation: an unallocatable window becomes an error instead
        // of a capacity-overflow panic or an allocation abort.
        let mut window = VecDeque::new();
        window
            .try_reserve_exact(window_len)
            .map_err(|_| too_large())?;
        Ok(Self {
            lookback,
            min_distance: lookback + 1,
            window,
            last_high: None,
            last_low: None,
            seen: 0,
        })
    }

    /// Sets the minimum number of bars that must separate a newly confirmed
    /// pivot from the previous pivot of the same kind for the two to be
    /// compared. Closer pivots still replace the previous reference pivot but
    /// produce `Hold`.
    #[must_use]
    pub fn with_min_distance(mut self, min_distance: usize) -> Self {
        self.min_distance = min_distance;
        self
    }

    /// Bar index of the most recently confirmed swing high, if any.
    ///
    /// Bars are numbered from zero starting at the first input after
    /// construction or the last reset. A pivot is only confirmed `lookback`
    /// bars after it occurs.
    pub fn last_swing_high_bar(&self) -> Option<usize> {
        self.last_high.map(|p| p.bar)
    }

    /// Bar index of the most recently confirmed swing low, if any.
    ///
    /// Bars are numbered from zero starting at the first input after
    /// construction or the last reset. A pivot is only confirmed `lookback`
    /// bars after it occurs.
    pub fn last_swing_low_bar(&self) -> Option<usize> {
        self.last_low.map(|p| p.bar)
    }
}
impl Signal for Divergence {
type Input = (f64, f64);
fn next(&mut self, (price, osc): (f64, f64)) -> Option<SignalEvent> {
self.seen += 1;
self.window.push_back((price, osc));
let cap = 2 * self.lookback + 1;
if self.window.len() < cap {
return None; }
if self.window.len() > cap {
self.window.pop_front();
}
let center_bar = self.seen - 1 - self.lookback;
let (cp, co) = self.window[self.lookback];
let is_high = self
.window
.iter()
.enumerate()
.all(|(i, &(p, _))| i == self.lookback || p < cp);
let is_low = self
.window
.iter()
.enumerate()
.all(|(i, &(p, _))| i == self.lookback || p > cp);
if is_high {
let new = Pivot {
price: cp,
osc: co,
bar: center_bar,
};
let event = self.last_high.and_then(|prev| {
if center_bar.saturating_sub(prev.bar) < self.min_distance {
return None;
}
if new.price > prev.price && new.osc < prev.osc {
Some(SignalEvent::Short)
} else {
None
}
});
self.last_high = Some(new);
return Some(event.unwrap_or(SignalEvent::Hold));
}
if is_low {
let new = Pivot {
price: cp,
osc: co,
bar: center_bar,
};
let event = self.last_low.and_then(|prev| {
if center_bar.saturating_sub(prev.bar) < self.min_distance {
return None;
}
if new.price < prev.price && new.osc > prev.osc {
Some(SignalEvent::Long)
} else {
None
}
});
self.last_low = Some(new);
return Some(event.unwrap_or(SignalEvent::Hold));
}
Some(SignalEvent::Hold)
}
fn reset(&mut self) {
self.window.clear();
self.last_high = None;
self.last_low = None;
self.seen = 0;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Swing low at bar 2 (8.0 / 25.0), swing high at bar 5, then a lower price
    /// low with a higher oscillator low at bar 8 (6.0 / 30.0): bullish divergence.
    const BULLISH_SERIES: [(f64, f64); 11] = [
        (10.0, 50.0),
        (9.0, 35.0),
        (8.0, 25.0),
        (9.0, 35.0),
        (10.0, 45.0),
        (12.0, 60.0),
        (10.0, 50.0),
        (8.0, 40.0),
        (6.0, 30.0),
        (7.0, 38.0),
        (9.0, 50.0),
    ];

    /// Feeds `series` through `div` and collects every emitted event;
    /// warm-up `None`s are skipped.
    fn run(div: &mut Divergence, series: &[(f64, f64)]) -> Vec<SignalEvent> {
        series.iter().filter_map(|&x| div.next(x)).collect()
    }

    #[test]
    fn validates_lookback() {
        assert!(Divergence::new(0).is_err());
        assert!(Divergence::new(1).is_ok());
    }

    #[test]
    fn warmup_emits_none() {
        let mut d = Divergence::new(2).unwrap();
        for i in 0..4 {
            let v = i as f64;
            assert!(d.next((v, v)).is_none(), "premature emission at bar {i}");
        }
        assert!(d.next((4.0, 4.0)).is_some());
    }

    #[test]
    fn bullish_divergence_emits_long() {
        let mut d = Divergence::new(2).unwrap();
        let events = run(&mut d, &BULLISH_SERIES);
        assert!(
            events.iter().any(|e| matches!(e, SignalEvent::Long)),
            "no Long event in {events:?}"
        );
        assert!(events.iter().all(|e| !matches!(e, SignalEvent::Short)));
    }

    #[test]
    fn bearish_divergence_emits_short() {
        let mut d = Divergence::new(2).unwrap();
        let series = [
            (10.0, 50.0),
            (11.0, 55.0),
            (12.0, 60.0),
            (11.0, 55.0),
            (10.0, 50.0),
            (8.0, 40.0),
            (10.0, 45.0),
            (12.0, 48.0),
            (14.0, 50.0),
            (13.0, 47.0),
            (11.0, 40.0),
        ];
        let events = run(&mut d, &series);
        assert!(
            events.iter().any(|e| matches!(e, SignalEvent::Short)),
            "no Short event in {events:?}"
        );
        assert!(events.iter().all(|e| !matches!(e, SignalEvent::Long)));
    }

    #[test]
    fn matching_trends_dont_fire() {
        let mut d = Divergence::new(2).unwrap();
        let series = [
            (10.0, 50.0),
            (11.0, 55.0),
            (12.0, 60.0),
            (11.0, 55.0),
            (10.0, 50.0),
            (8.0, 40.0),
            (10.0, 50.0),
            (12.0, 60.0),
            (14.0, 70.0),
            (13.0, 65.0),
            (11.0, 55.0),
        ];
        let events = run(&mut d, &series);
        assert!(
            events.iter().all(|e| matches!(e, SignalEvent::Hold)),
            "expected only Hold, got {events:?}",
        );
    }

    #[test]
    fn min_distance_suppresses_close_pivots() {
        // The swing lows in BULLISH_SERIES are 6 bars apart (bars 2 and 8).
        let mut d = Divergence::new(2).unwrap().with_min_distance(7);
        let events = run(&mut d, &BULLISH_SERIES);
        assert!(
            events.iter().all(|e| matches!(e, SignalEvent::Hold)),
            "expected only Hold, got {events:?}",
        );

        // Exactly at the minimum distance the pivots are compared again.
        let mut d = Divergence::new(2).unwrap().with_min_distance(6);
        let events = run(&mut d, &BULLISH_SERIES);
        assert!(
            events.iter().any(|e| matches!(e, SignalEvent::Long)),
            "no Long event in {events:?}"
        );
    }

    #[test]
    fn reports_last_swing_bars() {
        let mut d = Divergence::new(2).unwrap();
        assert_eq!(d.last_swing_high_bar(), None);
        assert_eq!(d.last_swing_low_bar(), None);

        run(&mut d, &BULLISH_SERIES);
        assert_eq!(d.last_swing_high_bar(), Some(5));
        assert_eq!(d.last_swing_low_bar(), Some(8));

        d.reset();
        assert_eq!(d.last_swing_high_bar(), None);
        assert_eq!(d.last_swing_low_bar(), None);
    }

    #[test]
    fn reset_clears_state() {
        let mut d = Divergence::new(1).unwrap();
        for i in 0..5 {
            d.next((i as f64, i as f64));
        }
        d.reset();
        assert!(d.next((10.0, 10.0)).is_none());
        assert!(d.next((11.0, 11.0)).is_none());
        assert!(d.next((12.0, 12.0)).is_some());
    }
}