Skip to main content

quantwave_core/indicators/
pivot_points.rs

1use crate::indicators::metadata::IndicatorMetadata;
2use crate::traits::Next;
3
/// Standard Pivot Points.
///
/// Support and resistance levels for the current period, derived from the
/// high, low, and close of the *previous* period. Each bar fed in is stored
/// and only influences the levels reported on the following call; before any
/// earlier bar exists, all levels are reported as `0.0`.
///
/// Output tuple order: `(P, R1, S1, R2, S2)`.
#[derive(Clone, Debug, Default)]
pub struct PivotPoints {
    /// High of the previously seen bar; `None` until the first bar arrives.
    prev_high: Option<f64>,
    /// Low of the previously seen bar; `None` until the first bar arrives.
    prev_low: Option<f64>,
    /// Close of the previously seen bar; `None` until the first bar arrives.
    prev_close: Option<f64>,
}
14
15impl PivotPoints {
16    pub fn new() -> Self {
17        Self::default()
18    }
19}
20
21impl Next<(f64, f64, f64)> for PivotPoints {
22    type Output = (f64, f64, f64, f64, f64);
23
24    fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
25        let (p, r1, s1, r2, s2) = match (self.prev_high, self.prev_low, self.prev_close) {
26            (Some(ph), Some(pl), Some(pc)) => {
27                let p = (ph + pl + pc) / 3.0;
28                let r1 = (p * 2.0) - pl;
29                let s1 = (p * 2.0) - ph;
30                let r2 = p + (ph - pl);
31                let s2 = p - (ph - pl);
32                (p, r1, s1, r2, s2)
33            }
34            _ => (0.0, 0.0, 0.0, 0.0, 0.0), // Warmup
35        };
36
37        self.prev_high = Some(high);
38        self.prev_low = Some(low);
39        self.prev_close = Some(close);
40
41        (p, r1, s1, r2, s2)
42    }
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use serde::Deserialize;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// Fixture location relative to the crate root (or, as a fallback, its parent directory).
    const GOLD_STANDARD_REL: &str = "tests/gold_standard/pivot_points.json";

    /// Slack for level-ordering checks. `(H + L + C) / 3` is rounded, so `P` can
    /// land an ulp outside `[L, H]` and flip a strict comparison between adjacent levels.
    const ORDER_EPS: f64 = 1e-9;

    /// One gold-standard fixture: per-bar input series and the expected output
    /// series. All series must have the same length.
    #[derive(Debug, Deserialize)]
    struct PivotCase {
        high: Vec<f64>,
        low: Vec<f64>,
        close: Vec<f64>,
        expected_p: Vec<f64>,
        expected_r1: Vec<f64>,
        expected_s1: Vec<f64>,
        expected_r2: Vec<f64>,
        expected_s2: Vec<f64>,
    }

    /// Resolves the fixture path. Prefers `<crate>/tests/...` and falls back to
    /// `<crate parent>/tests/...`.
    fn gold_standard_path() -> PathBuf {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
            .expect("CARGO_MANIFEST_DIR should be set when run via `cargo test`");
        let manifest_path = Path::new(&manifest_dir);
        let local = manifest_path.join(GOLD_STANDARD_REL);
        if local.exists() {
            local
        } else {
            manifest_path
                .parent()
                .expect("crate manifest directory has a parent")
                .join(GOLD_STANDARD_REL)
        }
    }

    #[test]
    fn test_pivot_points_gold_standard() {
        let path = gold_standard_path();
        let content = fs::read_to_string(&path)
            .unwrap_or_else(|err| panic!("failed to read {}: {}", path.display(), err));
        let case: PivotCase = serde_json::from_str(&content)
            .unwrap_or_else(|err| panic!("failed to parse {}: {}", path.display(), err));

        // Validate the fixture shape up front. Otherwise a short series panics
        // with a bare index-out-of-bounds, and an empty one passes vacuously.
        let n = case.high.len();
        assert!(n > 0, "fixture {} contains no bars", path.display());
        let series_lengths = [
            ("low", case.low.len()),
            ("close", case.close.len()),
            ("expected_p", case.expected_p.len()),
            ("expected_r1", case.expected_r1.len()),
            ("expected_s1", case.expected_s1.len()),
            ("expected_r2", case.expected_r2.len()),
            ("expected_s2", case.expected_s2.len()),
        ];
        for &(name, len) in series_lengths.iter() {
            assert_eq!(
                len, n,
                "fixture series `{}` has {} entries but `high` has {}",
                name, len, n
            );
        }

        let mut pivot = PivotPoints::new();
        for i in 0..n {
            let (p, r1, s1, r2, s2) = pivot.next((case.high[i], case.low[i], case.close[i]));
            approx::assert_relative_eq!(p, case.expected_p[i], epsilon = 1e-6);
            approx::assert_relative_eq!(r1, case.expected_r1[i], epsilon = 1e-6);
            approx::assert_relative_eq!(s1, case.expected_s1[i], epsilon = 1e-6);
            approx::assert_relative_eq!(r2, case.expected_r2[i], epsilon = 1e-6);
            approx::assert_relative_eq!(s2, case.expected_s2[i], epsilon = 1e-6);
        }
    }

    /// Runs a fresh indicator over `data` and collects every output.
    fn pivot_batch(data: Vec<(f64, f64, f64)>) -> Vec<(f64, f64, f64, f64, f64)> {
        let mut pivot = PivotPoints::new();
        data.into_iter().map(|x| pivot.next(x)).collect()
    }

    /// Turns arbitrary triples into plausible bars. `high` is the max of the
    /// three draws, `low` is the min, and `close` is kept as drawn, so
    /// `low <= close <= high` holds.
    fn normalize_bars(raw: Vec<(f64, f64, f64)>) -> Vec<(f64, f64, f64)> {
        raw.into_iter()
            .map(|(h, l, c)| (h.max(l).max(c), l.min(h).min(c), c))
            .collect()
    }

    proptest! {
        #[test]
        fn test_pivot_points_parity(
            input in prop::collection::vec((0.0f64..100.0, 0.0f64..100.0, 0.0f64..100.0), 1..100)
        ) {
            let bars = normalize_bars(input);

            let mut pivot = PivotPoints::new();
            let streaming_results: Vec<_> = bars.iter().map(|&bar| pivot.next(bar)).collect();
            let batch_results = pivot_batch(bars);

            prop_assert_eq!(streaming_results.len(), batch_results.len());
            for (s, b) in streaming_results.iter().zip(batch_results.iter()) {
                approx::assert_relative_eq!(s.0, b.0, epsilon = 1e-6);
                approx::assert_relative_eq!(s.1, b.1, epsilon = 1e-6);
                approx::assert_relative_eq!(s.2, b.2, epsilon = 1e-6);
                approx::assert_relative_eq!(s.3, b.3, epsilon = 1e-6);
                approx::assert_relative_eq!(s.4, b.4, epsilon = 1e-6);
            }
        }

        #[test]
        fn test_pivot_points_lag_and_ordering(
            input in prop::collection::vec((0.0f64..100.0, 0.0f64..100.0, 0.0f64..100.0), 1..100)
        ) {
            let bars = normalize_bars(input);
            let out = pivot_batch(bars.clone());

            // No previous bar exists on the first call.
            prop_assert_eq!(out[0], (0.0, 0.0, 0.0, 0.0, 0.0));

            // out[i + 1] must be derived from bars[i] alone, and with
            // low <= close <= high the levels must be ordered S2 <= S1 <= P <= R1 <= R2.
            for (&(ph, pl, pc), &(p, r1, s1, r2, s2)) in bars.iter().zip(out.iter().skip(1)) {
                prop_assert_eq!(p, (ph + pl + pc) / 3.0);
                prop_assert!(s2 <= s1 + ORDER_EPS, "S2 {} > S1 {}", s2, s1);
                prop_assert!(s1 <= p + ORDER_EPS, "S1 {} > P {}", s1, p);
                prop_assert!(p <= r1 + ORDER_EPS, "P {} > R1 {}", p, r1);
                prop_assert!(r1 <= r2 + ORDER_EPS, "R1 {} > R2 {}", r1, r2);
            }
        }
    }

    #[test]
    fn test_pivot_points_basic() {
        let mut pivot = PivotPoints::new();
        // First bar: warmup, no previous period to derive levels from.
        let (p0, _, _, _, _) = pivot.next((12.0, 8.0, 10.0));
        assert_eq!(p0, 0.0);
        // Second bar: levels come from the first bar (H=12, L=8, C=10).
        let (p1, r1, s1, r2, s2) = pivot.next((14.0, 9.0, 11.0));
        assert_eq!(p1, 10.0); // (12+8+10)/3
        assert_eq!(r1, 12.0); // 20 - 8
        assert_eq!(s1, 8.0); // 20 - 12
        assert_eq!(r2, 14.0); // 10 + 4
        assert_eq!(s2, 6.0); // 10 - 4
    }
}
141
142pub const PIVOT_POINTS_METADATA: IndicatorMetadata = IndicatorMetadata {
143    name: "Pivot Points",
144    description: "Pivot Points are used to determine overall trend over different time frames.",
145    params: &[],
146    formula_source: "https://www.investopedia.com/terms/p/pivotpoint.asp",
147    formula_latex: r#"
148\[
149P = \frac{H + L + C}{3}
150\]
151"#,
152    gold_standard_file: "pivot_points.json",
153    category: "Classic",
154};