Skip to main content

indicators/trend/
parabolic_sar.rs

1//! Parabolic SAR (Stop and Reverse).
2//!
3//! Python source: `indicators/other/parabolic_sar.py :: class ParabolicSARIndicator`
4//!
5//! # Python algorithm (to port)
6//! ```python
7//! sar[i] = prev_sar + af * (ep - prev_sar)
8//! # Uptrend: new high → bump af; close < sar → reverse to downtrend
9//! # Downtrend: new low → bump af; close > sar → reverse to uptrend
10//! ```
11//!
12//! Output column: `"PSAR"`.
13
14use std::collections::HashMap;
15
16use crate::error::IndicatorError;
17use crate::indicator::{Indicator, IndicatorOutput};
18use crate::registry::param_f64;
19use crate::types::Candle;
20
/// Acceleration-factor settings for [`ParabolicSar`].
///
/// The defaults (`step = 0.02`, `max_step = 0.2`) match the Python
/// implementation this indicator was ported from.
#[derive(Debug, Clone)]
pub struct PsarParams {
    /// Starting acceleration factor, and the increment applied each time a
    /// new extreme point is recorded.  Python default: 0.02.
    pub step: f64,
    /// Cap on the acceleration factor.  Python default: 0.2.
    pub max_step: f64,
}

impl Default for PsarParams {
    /// Returns `step = 0.02` and `max_step = 0.2`.
    fn default() -> Self {
        let step = 0.02;
        let max_step = 0.2;
        Self { step, max_step }
    }
}
36
37#[derive(Debug, Clone)]
38pub struct ParabolicSar {
39    pub params: PsarParams,
40}
41
42impl ParabolicSar {
43    pub fn new(params: PsarParams) -> Self {
44        Self { params }
45    }
46}
47
48impl Default for ParabolicSar {
49    fn default() -> Self {
50        Self::new(PsarParams::default())
51    }
52}
53
54impl Indicator for ParabolicSar {
55    fn name(&self) -> &'static str {
56        "ParabolicSAR"
57    }
58    fn required_len(&self) -> usize {
59        2
60    }
61    fn required_columns(&self) -> &[&'static str] {
62        &["high", "low"]
63    }
64
65    /// Ports the iterative SAR state machine.
66    ///
67    /// Initialisation matches the Python source exactly:
68    /// - `sar[0] = 0.0`  (Python uses `np.zeros(len(data))`)
69    /// - `ep  = candles[0].low`
70    /// - `af  = step`
71    /// - trend starts as uptrend (`1`)
72    ///
73    /// The `sar[0] = 0.0` cold-start means the first computed value
74    /// `sar[1] = af * ep` is typically well below market price, but this
75    /// is intentional — it replicates the Python behaviour and the SAR
76    /// converges toward price rapidly.
77    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
78        self.check_len(candles)?;
79
80        let n = candles.len();
81        let step = self.params.step;
82        let max_step = self.params.max_step;
83
84        let mut sar = vec![0.0f64; n];
85        let mut trend: i8 = 1; // 1 = uptrend, -1 = downtrend
86        let mut ep = candles[0].low;
87        let mut af = step;
88
89        for i in 1..n {
90            let prev_sar = sar[i - 1];
91            sar[i] = prev_sar + af * (ep - prev_sar);
92
93            if trend == 1 {
94                if candles[i].high > ep {
95                    ep = candles[i].high;
96                    af = (af + step).min(max_step);
97                }
98                if candles[i].low < sar[i] {
99                    trend = -1;
100                    sar[i] = ep;
101                    ep = candles[i].low;
102                    af = step;
103                }
104            } else {
105                if candles[i].low < ep {
106                    ep = candles[i].low;
107                    af = (af + step).min(max_step);
108                }
109                if candles[i].high > sar[i] {
110                    trend = 1;
111                    sar[i] = ep;
112                    ep = candles[i].high;
113                    af = step;
114                }
115            }
116        }
117
118        Ok(IndicatorOutput::from_pairs([("PSAR".to_string(), sar)]))
119    }
120}
121
122pub fn factory<S: ::std::hash::BuildHasher>(
123    params: &HashMap<String, String, S>,
124) -> Result<Box<dyn Indicator>, IndicatorError> {
125    Ok(Box::new(ParabolicSar::new(PsarParams {
126        step: param_f64(params, "step", 0.02)?,
127        max_step: param_f64(params, "max_step", 0.2)?,
128    })))
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Synthetic bars whose highs rise by 0.1 and lows fall by 0.05 per bar.
    fn candles(n: usize) -> Vec<Candle> {
        (0..n)
            .map(|i| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: 10.0,
                high: 10.0 + i as f64 * 0.1,
                low: 10.0 - i as f64 * 0.05,
                close: 10.0,
                volume: 100.0,
            })
            .collect()
    }

    /// One bar with the given extremes; the remaining fields are filler.
    fn bar(i: usize, high: f64, low: f64) -> Candle {
        Candle {
            time: i64::try_from(i).expect("time index fits i64"),
            open: low,
            high,
            low,
            close: low,
            volume: 100.0,
        }
    }

    #[test]
    fn psar_output_column() {
        let out = ParabolicSar::default().calculate(&candles(10)).unwrap();
        assert!(out.get("PSAR").is_some());
    }

    #[test]
    fn psar_correct_length() {
        let bars = candles(20);
        let out = ParabolicSar::default().calculate(&bars).unwrap();
        assert_eq!(out.get("PSAR").unwrap().len(), 20);
    }

    #[test]
    fn psar_values_finite() {
        let out = ParabolicSar::default().calculate(&candles(50)).unwrap();
        for &v in out.get("PSAR").unwrap() {
            assert!(v.is_finite(), "non-finite SAR: {v}");
        }
    }

    #[test]
    fn psar_cold_start_follows_recurrence() {
        // Seed: sar = 0, ep = low[0] = 10.0, af = 0.02. Every high is a new
        // extreme, so af grows by 0.02 per bar and no low pierces the SAR.
        let out = ParabolicSar::default().calculate(&candles(4)).unwrap();
        let vals = out.get("PSAR").unwrap();
        let expected = [0.0, 0.2, 0.596, 1.172_24];
        for (i, &want) in expected.iter().enumerate() {
            assert!(
                (vals[i] - want).abs() < 1e-9,
                "PSAR[{i}] = {}, want {want}",
                vals[i]
            );
        }
    }

    #[test]
    fn psar_reversal_jumps_to_prior_extreme() {
        let bars = [
            bar(0, 10.0, 9.0),
            bar(1, 11.0, 10.0),
            bar(2, 12.0, 11.0),
            bar(3, 5.0, 0.5), // low pierces the rising SAR
            bar(4, 1.0, 0.4),
        ];
        let out = ParabolicSar::default().calculate(&bars).unwrap();
        let vals = out.get("PSAR").unwrap();
        // On the reversal bar the SAR becomes the uptrend's extreme point.
        assert_eq!(vals[3], 12.0);
        // Next bar, af reset to step: 12.0 + 0.02 * (0.5 - 12.0) = 11.77.
        assert!((vals[4] - 11.77).abs() < 1e-9, "PSAR[4] = {}", vals[4]);
    }

    #[test]
    fn psar_af_bounded() {
        // Gap-up bars: every high is a new extreme and every low equals the
        // previous high. The SAR stays strictly below the previous extreme, so
        // it is never pierced. The af applied on bar i can then be recovered
        // as (sar[i] - sar[i-1]) / (ep - sar[i-1]).
        let params = PsarParams::default();
        let bars: Vec<Candle> = (0..30)
            .map(|i| bar(i, 10.0 + i as f64, 9.0 + i as f64))
            .collect();
        let out = ParabolicSar::default().calculate(&bars).unwrap();
        let vals = out.get("PSAR").unwrap();
        for i in 1..bars.len() {
            let ep = if i == 1 { bars[0].low } else { bars[i - 1].high };
            let af = (vals[i] - vals[i - 1]) / (ep - vals[i - 1]);
            let expected = (params.step * i as f64).min(params.max_step);
            assert!(
                af <= params.max_step + 1e-9,
                "bar {i}: af {af} exceeds max_step"
            );
            assert!((af - expected).abs() < 1e-9, "bar {i}: af {af}, want {expected}");
        }
    }

    #[test]
    fn factory_creates_psar() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "ParabolicSAR");
    }
}