use crate::error::{Error, Result};
use crate::ohlcv::Candle;
use crate::traits::Indicator;
/// Direction of the position the stop is currently trailing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Trend {
/// Long phase: `update` keeps the stop at or below the recent lows and
/// `signed` reports it unchanged (positive for positive prices).
Up,
/// Short phase: `update` keeps the stop at or above the recent highs and
/// `signed` reports it negated.
Down,
}
/// Acceleration-factor settings for one direction (long or short).
#[derive(Debug, Clone, Copy)]
struct Accel {
// Value the running acceleration factor is set to when a phase in this
// direction starts (on seeding and right after a reversal).
init: f64,
// Amount added to the acceleration factor each time a new extreme point
// is recorded in this direction.
step: f64,
// Upper bound the acceleration factor is capped at after each step.
max: f64,
}
impl Accel {
    /// Checks the three acceleration settings and hands `self` back when they are usable.
    ///
    /// # Errors
    ///
    /// * [`Error::NonPositiveMultiplier`] when any of `init`, `step` or `max` is
    ///   NaN, infinite, zero or negative.
    /// * [`Error::InvalidPeriod`] when `init` is larger than `max`.
    fn validate(self) -> Result<Self> {
        let Self { init, step, max } = self;

        // A usable setting is a finite, strictly positive number. NaN and the
        // infinities fail the finiteness test, so they are reported the same
        // way as non-positive values.
        let all_usable = [init, step, max]
            .iter()
            .all(|value| value.is_finite() && *value > 0.0);
        if !all_usable {
            return Err(Error::NonPositiveMultiplier);
        }

        // The starting factor may not already exceed its own cap.
        if init > max {
            return Err(Error::InvalidPeriod {
                message: "acceleration init must be <= max",
            });
        }

        Ok(self)
    }
}
/// Streaming state for the indicator that reports itself as `"SAREXT"`.
///
/// Fed one candle at a time through `Indicator::update`; the first candle only
/// seeds the state, every later candle yields a signed stop value.
#[derive(Debug, Clone)]
pub struct SarExt {
// Seeding rule applied on the first candle: positive starts long with the
// stop at this value, negative starts short with the stop at its negation,
// zero starts long with the stop at the first candle's low.
start_value: f64,
// Fraction of the stop's absolute value added (flip to short) or
// subtracted (flip to long) when a reversal happens.
offset_on_reverse: f64,
// Acceleration settings used while the trend is `Trend::Up`.
long: Accel,
// Acceleration settings used while the trend is `Trend::Down`.
short: Accel,
// True once the first candle has seeded the state.
initialised: bool,
// True once `update` has returned at least one value; backs `is_ready`.
has_emitted: bool,
// High of the previously processed candle (NaN before the first candle).
prev_high: f64,
// Low of the previously processed candle (NaN before the first candle).
prev_low: f64,
// Direction currently being trailed.
trend: Trend,
// Most recent unsigned stop value (NaN before the first candle).
sar: f64,
// Extreme point of the current phase: highest high while long, lowest low
// while short (NaN before the first candle).
ep: f64,
// Running acceleration factor applied to `ep - sar` on each update.
af: f64,
}
impl SarExt {
    /// Creates an indicator from explicit settings.
    ///
    /// * `start_value` - seeding rule for the first candle: `> 0` starts long
    ///   with the stop at this value, `< 0` starts short with the stop at its
    ///   negation, `0` starts long with the stop at the first candle's low.
    /// * `offset_on_reverse` - fraction of the stop's absolute value applied
    ///   as extra distance whenever the trend flips.
    /// * `accel_init_long`, `accel_long`, `accel_max_long` - initial value,
    ///   step and cap of the acceleration factor while long.
    /// * `accel_init_short`, `accel_short`, `accel_max_short` - the same
    ///   three settings while short.
    ///
    /// # Errors
    ///
    /// * [`Error::NonPositiveMultiplier`] when `start_value` or
    ///   `offset_on_reverse` is not finite, when `offset_on_reverse` is
    ///   negative, or when any acceleration setting is not a finite positive
    ///   number.
    /// * [`Error::InvalidPeriod`] when an initial acceleration exceeds its cap.
    ///
    /// The long settings are validated before the short ones, so a problem on
    /// the long side is the one reported when both are invalid.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        start_value: f64,
        offset_on_reverse: f64,
        accel_init_long: f64,
        accel_long: f64,
        accel_max_long: f64,
        accel_init_short: f64,
        accel_short: f64,
        accel_max_short: f64,
    ) -> Result<Self> {
        let seed_ok = start_value.is_finite();
        let offset_ok = offset_on_reverse.is_finite() && offset_on_reverse >= 0.0;
        if !(seed_ok && offset_ok) {
            return Err(Error::NonPositiveMultiplier);
        }

        let build = |init: f64, step: f64, max: f64| Accel { init, step, max }.validate();
        let long = build(accel_init_long, accel_long, accel_max_long)?;
        let short = build(accel_init_short, accel_short, accel_max_short)?;

        // The running state is a placeholder until the first candle seeds it.
        let af = long.init;
        Ok(Self {
            start_value,
            offset_on_reverse,
            long,
            short,
            initialised: false,
            has_emitted: false,
            prev_high: f64::NAN,
            prev_low: f64::NAN,
            trend: Trend::Up,
            sar: f64::NAN,
            ep: f64::NAN,
            af,
        })
    }

    /// Creates an indicator with automatic seeding, no reverse offset and
    /// 0.02 / 0.02 / 0.20 acceleration settings for both directions.
    ///
    /// # Panics
    ///
    /// Only if [`SarExt::new`] rejects these fixed settings.
    pub fn classic() -> Self {
        const ACCEL_INIT: f64 = 0.02;
        const ACCEL_STEP: f64 = 0.02;
        const ACCEL_MAX: f64 = 0.20;

        Self::new(
            0.0, 0.0, ACCEL_INIT, ACCEL_STEP, ACCEL_MAX, ACCEL_INIT, ACCEL_STEP, ACCEL_MAX,
        )
        .expect("classic SAREXT params are valid")
    }

    /// Encodes the current trend in the sign of `sar`: unchanged while long,
    /// negated while short.
    fn signed(&self, sar: f64) -> f64 {
        if matches!(self.trend, Trend::Down) {
            -sar
        } else {
            sar
        }
    }
}
impl Indicator for SarExt {
    type Input = Candle;
    type Output = f64;

    /// Feeds one candle and returns the signed stop for it.
    ///
    /// The first candle only seeds the state (direction, stop, extreme point,
    /// acceleration factor) according to `start_value` and yields `None`.
    /// Every later candle yields `Some(v)`, where `v` is the stop while the
    /// trend is long and the negated stop while it is short.
    fn update(&mut self, candle: Candle) -> Option<f64> {
        if !self.initialised {
            self.prev_high = candle.high;
            self.prev_low = candle.low;
            if self.start_value > 0.0 {
                // Caller-supplied long start.
                self.trend = Trend::Up;
                self.sar = self.start_value;
                self.ep = candle.high;
                self.af = self.long.init;
            } else if self.start_value < 0.0 {
                // Caller-supplied short start; the magnitude is the stop.
                self.trend = Trend::Down;
                self.sar = -self.start_value;
                self.ep = candle.low;
                self.af = self.short.init;
            } else {
                // Automatic start: long, stop at the first low.
                self.trend = Trend::Up;
                self.sar = candle.low;
                self.ep = candle.high;
                self.af = self.long.init;
            }
            self.initialised = true;
            return None;
        }

        let prev_h = self.prev_high;
        let prev_l = self.prev_low;

        // Move the stop toward the extreme point, then keep it outside the
        // previous and current bar on the protective side.
        let projected = self.sar + self.af * (self.ep - self.sar);
        let trailing = match self.trend {
            Trend::Up => projected.min(prev_l).min(candle.low),
            Trend::Down => projected.max(prev_h).max(candle.high),
        };

        // The trend flips when the bar touches or crosses the stop.
        let reversed = match self.trend {
            Trend::Up => candle.low <= trailing,
            Trend::Down => candle.high >= trailing,
        };

        let output_sar = if reversed {
            match self.trend {
                Trend::Up => {
                    // Flip to short. The new stop starts at the old extreme
                    // high, but must not sit below the previous or current
                    // high: on an outside bar the current high can exceed the
                    // recorded extreme, which would leave the stop inside the
                    // bar. The reverse offset is applied after the clamp.
                    let stop = self.ep.max(prev_h).max(candle.high);
                    self.trend = Trend::Down;
                    self.ep = candle.low;
                    self.af = self.short.init;
                    stop + stop.abs() * self.offset_on_reverse
                }
                Trend::Down => {
                    // Flip to long; mirror image of the branch above.
                    let stop = self.ep.min(prev_l).min(candle.low);
                    self.trend = Trend::Up;
                    self.ep = candle.high;
                    self.af = self.long.init;
                    stop - stop.abs() * self.offset_on_reverse
                }
            }
        } else {
            // Same trend: a new extreme accelerates the stop, up to the cap.
            match self.trend {
                Trend::Up => {
                    if candle.high > self.ep {
                        self.ep = candle.high;
                        self.af = (self.af + self.long.step).min(self.long.max);
                    }
                }
                Trend::Down => {
                    if candle.low < self.ep {
                        self.ep = candle.low;
                        self.af = (self.af + self.short.step).min(self.short.max);
                    }
                }
            }
            trailing
        };

        self.sar = output_sar;
        self.prev_high = candle.high;
        self.prev_low = candle.low;
        self.has_emitted = true;
        // `signed` reads the trend after any flip, so the sign already
        // reflects the new direction on the reversal bar.
        Some(self.signed(output_sar))
    }

    /// Returns the indicator to its pre-seed state, keeping the settings.
    fn reset(&mut self) {
        self.initialised = false;
        self.has_emitted = false;
        self.prev_high = f64::NAN;
        self.prev_low = f64::NAN;
        self.trend = Trend::Up;
        self.sar = f64::NAN;
        self.ep = f64::NAN;
        self.af = self.long.init;
    }

    /// Number of candles needed before the first value: one to seed, one to emit.
    fn warmup_period(&self) -> usize {
        2
    }

    /// True once `update` has returned at least one value since construction or `reset`.
    fn is_ready(&self) -> bool {
        self.has_emitted
    }

    /// Short identifier of this indicator.
    fn name(&self) -> &'static str {
        "SAREXT"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    /// Builds a candle from high / low / close. The close is also passed as the
    /// first argument (presumably the open - confirm against `Candle::new`),
    /// followed by `1.0` and `0` for the two trailing arguments.
    fn c(h: f64, l: f64, cl: f64) -> Candle {
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// One-point-wide bar centred on `mid`.
    fn unit_bar(mid: f64) -> Candle {
        c(mid + 0.5, mid - 0.5, mid)
    }

    /// `count` unit bars whose midpoints climb from 100 in steps of 1.
    fn rising(count: i32) -> Vec<Candle> {
        (0..count)
            .map(|step| unit_bar(100.0 + f64::from(step)))
            .collect()
    }

    /// The same bars as [`rising`], in reverse order.
    fn falling(count: i32) -> Vec<Candle> {
        let mut bars = rising(count);
        bars.reverse();
        bars
    }

    #[test]
    fn rejects_invalid_params() {
        let attempts = [
            // zero initial acceleration, long side
            SarExt::new(0.0, 0.0, 0.0, 0.02, 0.2, 0.02, 0.02, 0.2),
            // zero initial acceleration, short side
            SarExt::new(0.0, 0.0, 0.02, 0.02, 0.2, 0.0, 0.02, 0.2),
            // initial acceleration above its cap
            SarExt::new(0.0, 0.0, 0.30, 0.02, 0.2, 0.02, 0.02, 0.2),
            // NaN acceleration
            SarExt::new(0.0, 0.0, f64::NAN, 0.02, 0.2, 0.02, 0.02, 0.2),
            // infinite step
            SarExt::new(0.0, 0.0, 0.02, 0.02, 0.2, 0.02, f64::INFINITY, 0.2),
            // NaN start value
            SarExt::new(f64::NAN, 0.0, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2),
            // negative reverse offset
            SarExt::new(0.0, -1.0, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2),
        ];
        for (case, attempt) in attempts.iter().enumerate() {
            assert!(attempt.is_err(), "invalid parameter set #{case} was accepted");
        }
    }

    #[test]
    fn accessors_and_metadata() {
        let sar = SarExt::classic();
        assert_eq!(sar.warmup_period(), 2);
        assert_eq!(sar.name(), "SAREXT");
        assert!(!sar.is_ready());
    }

    #[test]
    fn seed_returns_none_then_emits() {
        let mut sar = SarExt::classic();

        let seeded = sar.update(c(11.0, 9.0, 10.0));
        assert_eq!(seeded, None);
        assert!(!sar.is_ready());

        let emitted = sar.update(c(12.0, 10.0, 11.0));
        assert!(emitted.is_some());
        assert!(sar.is_ready());
    }

    #[test]
    fn uptrend_is_positive_and_below_lows() {
        let candles = rising(40);
        let outputs = SarExt::classic().batch(&candles);

        for (i, out) in outputs.iter().enumerate() {
            let Some(value) = *out else { continue };
            assert!(
                value > 0.0 && value <= candles[i].low + 1e-9,
                "long-phase SAREXT must be positive and below the low"
            );
        }
    }

    #[test]
    fn downtrend_is_negative_and_above_highs() {
        let candles = falling(40);
        let outputs = SarExt::classic().batch(&candles);

        // The first five bars are not checked.
        for (i, out) in outputs.iter().enumerate().skip(5) {
            let Some(value) = *out else { continue };
            assert!(
                value < 0.0 && -value >= candles[i].high - 1e-9,
                "short-phase SAREXT must be negative and above the high"
            );
        }
    }

    #[test]
    fn positive_start_value_begins_long() {
        let mut sar = SarExt::new(95.0, 0.0, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2).unwrap();
        assert_eq!(sar.update(c(101.0, 99.0, 100.0)), None);

        let first = sar.update(c(102.0, 100.0, 101.0)).unwrap();
        assert!(first > 0.0);
    }

    #[test]
    fn negative_start_value_begins_short() {
        let mut sar = SarExt::new(-105.0, 0.0, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2).unwrap();
        assert_eq!(sar.update(c(101.0, 99.0, 100.0)), None);

        let first = sar.update(c(100.0, 98.0, 99.0)).unwrap();
        assert!(first < 0.0);
    }

    #[test]
    fn offset_on_reverse_pushes_sar_further() {
        // Six bars sliding down from 100 in steps of 2, then six climbing
        // from 88 in steps of 2, so the series contains a reversal.
        let slide = (0..6).map(|i: i32| 100.0 - f64::from(i) * 2.0);
        let climb = (0..6).map(|i: i32| 88.0 + f64::from(i) * 2.0);
        let candles: Vec<Candle> = slide
            .chain(climb)
            .map(|mid| c(mid + 1.0, mid - 1.0, mid))
            .collect();

        let run = |offset: f64| {
            SarExt::new(0.0, offset, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2)
                .unwrap()
                .batch(&candles)
        };
        let plain = run(0.0);
        let offset = run(0.1);
        assert_ne!(plain, offset);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..60)
            .map(|i: i32| 100.0 + (f64::from(i) * 0.3).sin() * 8.0)
            .map(|mid| c(mid + 1.0, mid - 1.0, mid))
            .collect();

        let batched = SarExt::classic().batch(&candles);

        let mut streamer = SarExt::classic();
        let streamed: Vec<_> = candles.iter().map(|bar| streamer.update(*bar)).collect();

        assert_eq!(batched, streamed);
    }

    #[test]
    fn reset_allows_clean_reuse() {
        let candles = rising(40);
        let mut sar = SarExt::classic();

        let first_pass = sar.batch(&candles);
        assert!(sar.is_ready());

        sar.reset();
        assert!(!sar.is_ready());

        let second_pass = sar.batch(&candles);
        assert_eq!(first_pass, second_pass);
    }
}