use crate::indicators::utils::validate_data_length;
use crate::indicators::{Candle, Indicator, IndicatorError};
/// Streaming Parabolic SAR (stop-and-reverse) indicator.
///
/// Candles are processed one at a time. The first candle after construction or a reset only
/// primes the state and yields no value. From the second candle on, each emitted value is the
/// SAR that was in force for that candle (derived from the candles before it); the SAR for the
/// following candle is then prepared from the candle just seen.
#[derive(Debug, Clone)]
pub struct Sar {
    /// Acceleration factor used at the start of every trend.
    af_start: f64,
    /// Increment applied to the acceleration factor whenever a new extreme point is set.
    af_step: f64,
    /// Ceiling for the acceleration factor.
    af_max: f64,
    /// `true` while tracking an uptrend (SAR below price), `false` for a downtrend.
    long: bool,
    /// SAR to report for the next candle.
    sar: f64,
    /// Extreme point: highest high of the current uptrend or lowest low of the downtrend.
    ep: f64,
    /// Current acceleration factor.
    af: f64,
    /// High of the most recently processed candle.
    prev_high: f64,
    /// Low of the most recently processed candle.
    prev_low: f64,
    /// Candles processed since construction or the last reset (saturates instead of wrapping).
    seen: usize,
}

impl Sar {
    /// Creates a Parabolic SAR with the given acceleration-factor schedule.
    ///
    /// * `af_start` - acceleration factor at the start of every trend.
    /// * `af_step` - increment applied whenever the trend sets a new extreme point.
    /// * `af_max` - ceiling for the acceleration factor.
    ///
    /// # Errors
    ///
    /// Returns `IndicatorError::InvalidParameter` if any factor is NaN or infinite, is not
    /// strictly positive, or if `af_start` or `af_step` exceeds `af_max`.
    pub fn new(af_start: f64, af_step: f64, af_max: f64) -> Result<Self, IndicatorError> {
        // NaN compares false against everything, so it would slip through the range checks
        // below; reject non-finite factors explicitly.
        if !(af_start.is_finite() && af_step.is_finite() && af_max.is_finite()) {
            return Err(IndicatorError::InvalidParameter(
                "Parabolic SAR factors must be finite".to_string(),
            ));
        }
        if af_start <= 0.0 || af_step <= 0.0 || af_max <= 0.0 {
            return Err(IndicatorError::InvalidParameter(
                "Parabolic SAR factors must be positive".to_string(),
            ));
        }
        if af_start > af_max || af_step > af_max {
            return Err(IndicatorError::InvalidParameter(
                "af_start and af_step must be <= af_max".to_string(),
            ));
        }
        Ok(Self {
            af_start,
            af_step,
            af_max,
            long: true,
            sar: 0.0,
            ep: 0.0,
            af: af_start,
            prev_high: 0.0,
            prev_low: 0.0,
            seen: 0,
        })
    }

    /// Creates a Parabolic SAR with the canonical factors: 0.02 start, 0.02 step, 0.20 max.
    pub fn default_params() -> Self {
        Self::new(0.02, 0.02, 0.20).expect("canonical params are valid")
    }

    /// Clears all streaming state so the next candle is treated as the first one.
    ///
    /// The configured acceleration factors are kept.
    pub fn reset_state(&mut self) {
        self.long = true;
        self.sar = 0.0;
        self.ep = 0.0;
        self.af = self.af_start;
        self.prev_high = 0.0;
        self.prev_low = 0.0;
        self.seen = 0;
    }

    /// Feeds one candle and returns the SAR in force for it, or `None` for the priming candle.
    fn step(&mut self, candle: Candle) -> Option<f64> {
        self.seen = self.seen.saturating_add(1);
        if self.seen == 1 {
            // A single bar is not enough to choose a direction; just remember its range.
            self.prev_high = candle.high;
            self.prev_low = candle.low;
            return None;
        }
        if self.seen == 2 {
            // Start long when this close is at or above the previous bar's midpoint. The
            // initial stop sits at the previous bar's opposite extreme.
            self.long = candle.close >= (self.prev_high + self.prev_low) / 2.0;
            if self.long {
                self.sar = self.prev_low;
                self.ep = candle.high.max(self.prev_high);
            } else {
                self.sar = self.prev_high;
                self.ep = candle.low.min(self.prev_low);
            }
            self.af = self.af_start;
            let out = self.sar;
            self.advance(candle);
            return Some(out);
        }
        let out = self.sar;
        let reversed = if self.long {
            candle.low < self.sar
        } else {
            candle.high > self.sar
        };
        if reversed {
            self.long = !self.long;
            // The new stop starts at the previous trend's extreme point, but a SAR must never
            // lie inside the current or previous bar's range. Without this clamp, an outside
            // bar that both sets a new extreme and penetrates the stop leaves the new SAR
            // inside its own range, which can trigger a spurious reversal on the next bar.
            self.sar = if self.long {
                self.ep.min(self.prev_low).min(candle.low)
            } else {
                self.ep.max(self.prev_high).max(candle.high)
            };
            self.af = self.af_start;
            self.ep = if self.long { candle.high } else { candle.low };
            self.prev_high = candle.high;
            self.prev_low = candle.low;
            return Some(out);
        }
        self.advance(candle);
        Some(out)
    }

    /// Folds `candle` into the current trend: updates the extreme point and acceleration
    /// factor, then computes the SAR for the next candle, clamped so it never enters the
    /// range of this candle or the one before it.
    fn advance(&mut self, candle: Candle) {
        // A new extreme in the trend direction moves the EP and accelerates the SAR.
        let new_extreme = if self.long {
            candle.high > self.ep
        } else {
            candle.low < self.ep
        };
        if new_extreme {
            self.ep = if self.long { candle.high } else { candle.low };
            self.af = (self.af + self.af_step).min(self.af_max);
        }
        let projected = self.sar + self.af * (self.ep - self.sar);
        self.sar = if self.long {
            projected.min(self.prev_low).min(candle.low)
        } else {
            projected.max(self.prev_high).max(candle.high)
        };
        self.prev_high = candle.high;
        self.prev_low = candle.low;
    }
}
impl Indicator<Candle, f64> for Sar {
    /// Recomputes the SAR over `data` starting from a clean state.
    ///
    /// The first candle only primes the indicator, so every later candle contributes one
    /// value. The streaming state is left positioned after the last candle.
    fn calculate(&mut self, data: &[Candle]) -> Result<Vec<f64>, IndicatorError> {
        // NOTE(review): assumes validate_data_length errors for fewer than 2 candles; the
        // capacity computation below relies on that.
        validate_data_length(data, 2)?;
        self.reset_state();
        let mut sars = Vec::with_capacity(data.len() - 1);
        sars.extend(data.iter().filter_map(|&candle| self.step(candle)));
        Ok(sars)
    }

    /// Feeds a single candle; yields `None` while the indicator is still priming.
    fn next(&mut self, value: Candle) -> Result<Option<f64>, IndicatorError> {
        let sar = self.step(value);
        Ok(sar)
    }

    /// Discards streaming state while keeping the configured factors.
    fn reset(&mut self) {
        self.reset_state()
    }

    fn name(&self) -> &'static str {
        "Sar"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` candles whose midpoint moves by `slope` per bar; each bar spans
    /// `mid ± 1` and closes a quarter point above its midpoint.
    fn ramp(n: usize, slope: f64) -> Vec<Candle> {
        (0..n)
            .map(|i| {
                let mid = i as f64 * slope;
                Candle {
                    timestamp: i as u64,
                    open: mid,
                    high: mid + 1.0,
                    low: mid - 1.0,
                    close: mid + 0.25,
                    volume: 1.0,
                }
            })
            .collect()
    }

    #[test]
    fn validates_factors() {
        assert!(Sar::new(0.0, 0.02, 0.20).is_err());
        assert!(Sar::new(0.02, 0.0, 0.20).is_err());
        assert!(Sar::new(0.02, 0.02, 0.0).is_err());
        assert!(Sar::new(0.30, 0.02, 0.20).is_err());
        assert!(Sar::new(0.02, 0.30, 0.20).is_err());
        assert!(Sar::new(0.02, 0.02, 0.20).is_ok());
    }

    #[test]
    fn first_bar_emits_nothing_second_emits() {
        let mut sar = Sar::default_params();
        let c0 = Candle {
            timestamp: 0,
            open: 10.0,
            high: 11.0,
            low: 9.0,
            close: 10.0,
            volume: 1.0,
        };
        let c1 = Candle {
            timestamp: 1,
            open: 10.5,
            high: 12.0,
            low: 10.0,
            close: 11.5,
            volume: 1.0,
        };
        assert!(sar.next(c0).unwrap().is_none());
        assert!(sar.next(c1).unwrap().is_some());
    }

    #[test]
    fn calculate_emits_one_value_per_candle_after_the_first() {
        let mut sar = Sar::default_params();
        let candles = ramp(30, 1.0);
        let out = sar.calculate(&candles).unwrap();
        assert_eq!(out.len(), candles.len() - 1);
    }

    #[test]
    fn uptrend_keeps_sar_below_price() {
        let mut sar = Sar::default_params();
        let candles = ramp(30, 1.0);
        let out = sar.calculate(&candles).unwrap();
        // out[k] is the SAR for candles[k + 1]; the seed value (k = 0) is skipped.
        for (k, (&s, c)) in out.iter().zip(&candles[1..]).enumerate().skip(1) {
            assert!(s <= c.low + 1e-9, "bar {} SAR {} > low {}", k + 1, s, c.low);
        }
    }

    #[test]
    fn downtrend_keeps_sar_above_price() {
        let mut sar = Sar::default_params();
        let candles = ramp(30, -1.0);
        let out = sar.calculate(&candles).unwrap();
        // out[k] is the SAR for candles[k + 1]; the seed value (k = 0) is skipped.
        for (k, (&s, c)) in out.iter().zip(&candles[1..]).enumerate().skip(1) {
            assert!(
                s >= c.high - 1e-9,
                "bar {} SAR {} < high {}",
                k + 1,
                s,
                c.high
            );
        }
    }

    #[test]
    fn reversal_flips_sar_to_other_side() {
        let mut up = ramp(15, 1.0);
        for i in 0..10 {
            let mid = (15 - i) as f64 * 1.0 - 5.0;
            up.push(Candle {
                timestamp: (15 + i) as u64,
                open: mid,
                high: mid + 1.0,
                low: mid - 1.0,
                close: mid - 0.5,
                volume: 1.0,
            });
        }
        let mut sar = Sar::default_params();
        let out = sar.calculate(&up).unwrap();
        let last = *out.last().unwrap();
        let last_close = up.last().unwrap().close;
        assert!(
            last > last_close,
            "expected SAR above price after flip, SAR={last} close={last_close}"
        );
    }

    #[test]
    fn streaming_matches_batch_and_reset_replays() {
        // An uptrend followed by a gap into a downtrend, so at least one reversal occurs.
        let mut candles = ramp(15, 1.0);
        candles.extend(ramp(15, -1.0).into_iter().map(|mut c| {
            c.timestamp += 15;
            c
        }));

        let mut batch_sar = Sar::default_params();
        let batch = batch_sar.calculate(&candles).unwrap();

        let mut stream_sar = Sar::default_params();
        let streamed: Vec<f64> = candles
            .iter()
            .filter_map(|&c| stream_sar.next(c).unwrap())
            .collect();
        assert_eq!(batch, streamed);

        // After reset() the indicator must behave exactly like a fresh one.
        stream_sar.reset();
        let replayed: Vec<f64> = candles
            .iter()
            .filter_map(|&c| stream_sar.next(c).unwrap())
            .collect();
        assert_eq!(batch, replayed);
    }
}