Skip to main content

indicators/trend/
parabolic_sar.rs

1//! Parabolic SAR (Stop and Reverse).
2//!
3//! Python source: `indicators/other/parabolic_sar.py :: class ParabolicSARIndicator`
4//!
5//! # Python algorithm (to port)
6//! ```python
7//! sar[i] = prev_sar + af * (ep - prev_sar)
8//! # Uptrend: new high → bump af; close < sar → reverse to downtrend
9//! # Downtrend: new low → bump af; close > sar → reverse to uptrend
10//! ```
11//!
12//! Output column: `"PSAR"`.
13
14use std::collections::HashMap;
15
16use crate::error::IndicatorError;
17use crate::indicator::{Indicator, IndicatorOutput};
18use crate::registry::param_f64;
19use crate::types::Candle;
20
/// Tuning knobs for the Parabolic SAR acceleration factor (AF).
///
/// The AF starts at `step`, grows by `step` every time the running trend sets a
/// new extreme point, and is capped at `max_step`.
#[derive(Debug, Clone)]
pub struct PsarParams {
    /// Increment applied to the AF on each new extreme.  Python default: 0.02.
    pub step: f64,
    /// Ceiling for the AF.  Python default: 0.2.
    pub max_step: f64,
}

impl Default for PsarParams {
    /// Returns the Python defaults: `step = 0.02`, `max_step = 0.2`.
    fn default() -> Self {
        let step = 0.02;
        let max_step = 0.2;
        Self { step, max_step }
    }
}
36
37#[derive(Debug, Clone)]
38pub struct ParabolicSar {
39    pub params: PsarParams,
40}
41
42impl ParabolicSar {
43    pub fn new(params: PsarParams) -> Self {
44        Self { params }
45    }
46    pub fn default() -> Self {
47        Self::new(PsarParams::default())
48    }
49}
50
51impl Indicator for ParabolicSar {
52    fn name(&self) -> &str {
53        "ParabolicSAR"
54    }
55    fn required_len(&self) -> usize {
56        2
57    }
58    fn required_columns(&self) -> &[&'static str] {
59        &["high", "low"]
60    }
61
62    /// TODO: port Python iterative SAR state machine.
63    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
64        self.check_len(candles)?;
65
66        let n = candles.len();
67        let step = self.params.step;
68        let max_step = self.params.max_step;
69
70        let mut sar = vec![0.0f64; n];
71        let mut trend: i8 = 1; // 1 = uptrend, -1 = downtrend
72        let mut ep = candles[0].low;
73        let mut af = step;
74
75        // TODO: port Python loop exactly.
76        for i in 1..n {
77            let prev_sar = sar[i - 1];
78            sar[i] = prev_sar + af * (ep - prev_sar);
79
80            if trend == 1 {
81                if candles[i].high > ep {
82                    ep = candles[i].high;
83                    af = (af + step).min(max_step);
84                }
85                if candles[i].low < sar[i] {
86                    trend = -1;
87                    sar[i] = ep;
88                    ep = candles[i].low;
89                    af = step;
90                }
91            } else {
92                if candles[i].low < ep {
93                    ep = candles[i].low;
94                    af = (af + step).min(max_step);
95                }
96                if candles[i].high > sar[i] {
97                    trend = 1;
98                    sar[i] = ep;
99                    ep = candles[i].high;
100                    af = step;
101                }
102            }
103        }
104
105        Ok(IndicatorOutput::from_pairs([("PSAR".to_string(), sar)]))
106    }
107}
108
109pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
110    Ok(Box::new(ParabolicSar::new(PsarParams {
111        step: param_f64(params, "step", 0.02)?,
112        max_step: param_f64(params, "max_step", 0.2)?,
113    })))
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` synthetic bars. Each bar's high rises by 0.1 and its low
    /// falls by 0.05 (a widening range); open/close/volume stay flat.
    fn widening_bars(count: usize) -> Vec<Candle> {
        let mut bars = Vec::with_capacity(count);
        for idx in 0..count {
            let offset = idx as f64;
            bars.push(Candle {
                time: idx as i64,
                open: 10.0,
                high: 10.0 + offset * 0.1,
                low: 10.0 - offset * 0.05,
                close: 10.0,
                volume: 100.0,
            });
        }
        bars
    }

    #[test]
    fn psar_output_column() {
        let bars = widening_bars(10);
        let output = ParabolicSar::default().calculate(&bars).unwrap();
        assert!(output.get("PSAR").is_some());
    }

    #[test]
    fn psar_correct_length() {
        let bars = widening_bars(20);
        let output = ParabolicSar::default().calculate(&bars).unwrap();
        let column = output.get("PSAR").unwrap();
        assert_eq!(column.len(), 20);
    }

    #[test]
    fn psar_af_bounded() {
        // An unbounded AF would make SAR diverge; every value must stay finite.
        let bars = widening_bars(50);
        let output = ParabolicSar::default().calculate(&bars).unwrap();
        let column = output.get("PSAR").unwrap();
        if let Some(bad) = column.iter().find(|v| !v.is_finite()) {
            panic!("non-finite SAR: {bad}");
        }
    }

    #[test]
    fn factory_creates_psar() {
        let no_params: HashMap<String, String> = HashMap::new();
        let indicator = factory(&no_params).unwrap();
        assert_eq!(indicator.name(), "ParabolicSAR");
    }
}