indicators/trend/
parabolic_sar.rs1use std::collections::HashMap;
15
16use crate::error::IndicatorError;
17use crate::indicator::{Indicator, IndicatorOutput};
18use crate::registry::param_f64;
19use crate::types::Candle;
20
/// Tunable parameters for the Parabolic SAR calculation.
#[derive(Debug, Clone, PartialEq)]
pub struct PsarParams {
    /// Starting acceleration factor, and the increment added to it each time
    /// a new extreme point is recorded.
    pub step: f64,
    /// Upper bound the acceleration factor is clamped to.
    pub max_step: f64,
}

impl Default for PsarParams {
    /// Returns `step = 0.02` and `max_step = 0.2`.
    fn default() -> Self {
        Self {
            step: 0.02,
            max_step: 0.2,
        }
    }
}
36
37#[derive(Debug, Clone)]
38pub struct ParabolicSar {
39 pub params: PsarParams,
40}
41
42impl ParabolicSar {
43 pub fn new(params: PsarParams) -> Self {
44 Self { params }
45 }
46 pub fn default() -> Self {
47 Self::new(PsarParams::default())
48 }
49}
50
51impl Indicator for ParabolicSar {
52 fn name(&self) -> &str {
53 "ParabolicSAR"
54 }
55 fn required_len(&self) -> usize {
56 2
57 }
58 fn required_columns(&self) -> &[&'static str] {
59 &["high", "low"]
60 }
61
62 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
64 self.check_len(candles)?;
65
66 let n = candles.len();
67 let step = self.params.step;
68 let max_step = self.params.max_step;
69
70 let mut sar = vec![0.0f64; n];
71 let mut trend: i8 = 1; let mut ep = candles[0].low;
73 let mut af = step;
74
75 for i in 1..n {
77 let prev_sar = sar[i - 1];
78 sar[i] = prev_sar + af * (ep - prev_sar);
79
80 if trend == 1 {
81 if candles[i].high > ep {
82 ep = candles[i].high;
83 af = (af + step).min(max_step);
84 }
85 if candles[i].low < sar[i] {
86 trend = -1;
87 sar[i] = ep;
88 ep = candles[i].low;
89 af = step;
90 }
91 } else {
92 if candles[i].low < ep {
93 ep = candles[i].low;
94 af = (af + step).min(max_step);
95 }
96 if candles[i].high > sar[i] {
97 trend = 1;
98 sar[i] = ep;
99 ep = candles[i].high;
100 af = step;
101 }
102 }
103 }
104
105 Ok(IndicatorOutput::from_pairs([("PSAR".to_string(), sar)]))
106 }
107}
108
109pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
110 Ok(Box::new(ParabolicSar::new(PsarParams {
111 step: param_f64(params, "step", 0.02)?,
112 max_step: param_f64(params, "max_step", 0.2)?,
113 })))
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` synthetic bars: `high` rises by 0.1 and `low` falls by 0.05
    /// per bar, while open/close stay at 10.0 and volume at 100.0.
    fn candles(n: usize) -> Vec<Candle> {
        let mut bars = Vec::with_capacity(n);
        for i in 0..n {
            bars.push(Candle {
                time: i as i64,
                open: 10.0,
                high: 10.0 + i as f64 * 0.1,
                low: 10.0 - i as f64 * 0.05,
                close: 10.0,
                volume: 100.0,
            });
        }
        bars
    }

    /// The result exposes a column named `"PSAR"`.
    #[test]
    fn psar_output_column() {
        let bars = candles(10);
        let out = ParabolicSar::default().calculate(&bars).unwrap();
        assert!(out.get("PSAR").is_some());
    }

    /// The output column has one value per input bar.
    #[test]
    fn psar_correct_length() {
        let bars = candles(20);
        let out = ParabolicSar::default().calculate(&bars).unwrap();
        let column = out.get("PSAR").unwrap();
        assert_eq!(column.len(), 20);
    }

    /// Every produced SAR value is finite.
    #[test]
    fn psar_af_bounded() {
        let bars = candles(50);
        let out = ParabolicSar::default().calculate(&bars).unwrap();
        let vals = out.get("PSAR").unwrap();
        for &v in vals {
            assert!(v.is_finite(), "non-finite SAR: {v}");
        }
    }

    /// An empty parameter map still yields the indicator.
    #[test]
    fn factory_creates_psar() {
        let no_params = HashMap::new();
        let indicator = factory(&no_params).unwrap();
        assert_eq!(indicator.name(), "ParabolicSAR");
    }
}