quantwave_core/indicators/
pivot_points.rs1use crate::indicators::metadata::IndicatorMetadata;
2use crate::traits::Next;
3
4#[derive(Debug, Clone, Default)]
9pub struct PivotPoints {
10 prev_high: Option<f64>,
11 prev_low: Option<f64>,
12 prev_close: Option<f64>,
13}
14
15impl PivotPoints {
16 pub fn new() -> Self {
17 Self::default()
18 }
19}
20
21impl Next<(f64, f64, f64)> for PivotPoints {
22 type Output = (f64, f64, f64, f64, f64);
23
24 fn next(&mut self, (high, low, close): (f64, f64, f64)) -> Self::Output {
25 let (p, r1, s1, r2, s2) = match (self.prev_high, self.prev_low, self.prev_close) {
26 (Some(ph), Some(pl), Some(pc)) => {
27 let p = (ph + pl + pc) / 3.0;
28 let r1 = (p * 2.0) - pl;
29 let s1 = (p * 2.0) - ph;
30 let r2 = p + (ph - pl);
31 let s2 = p - (ph - pl);
32 (p, r1, s1, r2, s2)
33 }
34 _ => (0.0, 0.0, 0.0, 0.0, 0.0), };
36
37 self.prev_high = Some(high);
38 self.prev_low = Some(low);
39 self.prev_close = Some(close);
40
41 (p, r1, s1, r2, s2)
42 }
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use serde::Deserialize;
    use std::fs;
    use std::path::Path;

    /// One gold-standard fixture: parallel input series plus the expected output per bar.
    #[derive(Debug, Deserialize)]
    struct PivotCase {
        high: Vec<f64>,
        low: Vec<f64>,
        close: Vec<f64>,
        expected_p: Vec<f64>,
        expected_r1: Vec<f64>,
        expected_s1: Vec<f64>,
        expected_r2: Vec<f64>,
        expected_s2: Vec<f64>,
    }

    #[test]
    fn test_pivot_points_gold_standard() {
        const FIXTURE: &str = "tests/gold_standard/pivot_points.json";

        let manifest_dir =
            std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR is set by cargo");
        let manifest_path = Path::new(&manifest_dir);
        // The fixture may live in this crate or in the workspace root one level up.
        let path = manifest_path.join(FIXTURE);
        let path = if path.exists() {
            path
        } else {
            manifest_path
                .parent()
                .expect("crate manifest directory has a parent")
                .join(FIXTURE)
        };
        let content = fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read {}: {}", path.display(), e));
        let case: PivotCase =
            serde_json::from_str(&content).expect("pivot_points.json is valid JSON");

        // Every series must cover every bar; otherwise indexing below would panic
        // with an unhelpful out-of-bounds message instead of a fixture error.
        let n = case.high.len();
        for &(name, len) in &[
            ("low", case.low.len()),
            ("close", case.close.len()),
            ("expected_p", case.expected_p.len()),
            ("expected_r1", case.expected_r1.len()),
            ("expected_s1", case.expected_s1.len()),
            ("expected_r2", case.expected_r2.len()),
            ("expected_s2", case.expected_s2.len()),
        ] {
            assert_eq!(len, n, "gold standard series `{}` has the wrong length", name);
        }

        let mut pivot = PivotPoints::new();
        for i in 0..n {
            let (p, r1, s1, r2, s2) = pivot.next((case.high[i], case.low[i], case.close[i]));
            approx::assert_relative_eq!(p, case.expected_p[i], epsilon = 1e-6);
            approx::assert_relative_eq!(r1, case.expected_r1[i], epsilon = 1e-6);
            approx::assert_relative_eq!(s1, case.expected_s1[i], epsilon = 1e-6);
            approx::assert_relative_eq!(r2, case.expected_r2[i], epsilon = 1e-6);
            approx::assert_relative_eq!(s2, case.expected_s2[i], epsilon = 1e-6);
        }
    }

    /// Runs a fresh indicator over `data` and collects every output.
    fn pivot_batch(data: Vec<(f64, f64, f64)>) -> Vec<(f64, f64, f64, f64, f64)> {
        let mut pivot = PivotPoints::new();
        data.into_iter().map(|x| pivot.next(x)).collect()
    }

    proptest! {
        #[test]
        fn test_pivot_points_parity(
            input in prop::collection::vec((0.0f64..100.0, 0.0f64..100.0, 0.0f64..100.0), 1..100)
        ) {
            // Normalise each random triple into a consistent bar: high >= close >= low.
            let adj_input: Vec<(f64, f64, f64)> = input
                .into_iter()
                .map(|(h, l, c)| (h.max(l).max(c), l.min(h).min(c), c))
                .collect();

            let mut pivot = PivotPoints::new();
            let streaming_results: Vec<_> =
                adj_input.iter().map(|&bar| pivot.next(bar)).collect();

            let batch_results = pivot_batch(adj_input);

            prop_assert_eq!(streaming_results.len(), batch_results.len());
            for (s, b) in streaming_results.iter().zip(batch_results.iter()) {
                approx::assert_relative_eq!(s.0, b.0, epsilon = 1e-6);
                approx::assert_relative_eq!(s.1, b.1, epsilon = 1e-6);
                approx::assert_relative_eq!(s.2, b.2, epsilon = 1e-6);
                approx::assert_relative_eq!(s.3, b.3, epsilon = 1e-6);
                approx::assert_relative_eq!(s.4, b.4, epsilon = 1e-6);
            }
        }
    }

    #[test]
    fn test_pivot_points_first_bar_is_zero() {
        let mut pivot = PivotPoints::new();
        assert_eq!(pivot.next((12.0, 8.0, 10.0)), (0.0, 0.0, 0.0, 0.0, 0.0));
    }

    #[test]
    fn test_pivot_points_basic() {
        let mut pivot = PivotPoints::new();
        let (p0, _, _, _, _) = pivot.next((12.0, 8.0, 10.0));
        assert_eq!(p0, 0.0);
        let (p1, r1, s1, r2, s2) = pivot.next((14.0, 9.0, 11.0));
        assert_eq!(p1, 10.0);
        assert_eq!(r1, 12.0);
        assert_eq!(s1, 8.0);
        assert_eq!(r2, 14.0);
        assert_eq!(s2, 6.0);
    }
}
141
142pub const PIVOT_POINTS_METADATA: IndicatorMetadata = IndicatorMetadata {
143 name: "Pivot Points",
144 description: "Pivot Points are used to determine overall trend over different time frames.",
145 params: &[],
146 formula_source: "https://www.investopedia.com/terms/p/pivotpoint.asp",
147 formula_latex: r#"
148\[
149P = \frac{H + L + C}{3}
150\]
151"#,
152 gold_standard_file: "pivot_points.json",
153 category: "Classic",
154};