// fin_primitives/signals/indicators/linear_deviation.rs
use std::collections::VecDeque;

use rust_decimal::Decimal;

use crate::error::FinError;
use crate::signals::{BarInput, Signal, SignalValue};

8pub struct LinearDeviation {
29 name: String,
30 period: usize,
31 history: VecDeque<Decimal>,
32}
33
34impl LinearDeviation {
35 pub fn new(name: impl Into<String>, period: usize) -> Result<Self, FinError> {
40 if period < 2 {
41 return Err(FinError::InvalidPeriod(period));
42 }
43 Ok(Self {
44 name: name.into(),
45 period,
46 history: VecDeque::with_capacity(period),
47 })
48 }
49}
50
51impl Signal for LinearDeviation {
52 fn name(&self) -> &str { &self.name }
53
54 fn update(&mut self, bar: &BarInput) -> Result<SignalValue, FinError> {
55 self.history.push_back(bar.close);
56 if self.history.len() > self.period {
57 self.history.pop_front();
58 }
59 if self.history.len() < self.period {
60 return Ok(SignalValue::Unavailable);
61 }
62
63 let n = self.period as i64;
64 let sum_x = Decimal::from(n * (n - 1) / 2);
67 let sum_x2 = Decimal::from(n * (n - 1) * (2 * n - 1) / 6);
68 let sum_y: Decimal = self.history.iter().sum();
69 let sum_xy: Decimal = self.history.iter().enumerate()
70 .map(|(i, &y)| Decimal::from(i as i64) * y)
71 .sum();
72
73 let n_dec = Decimal::from(n);
74 let denom = n_dec * sum_x2 - sum_x * sum_x;
75 if denom.is_zero() {
76 return Ok(SignalValue::Unavailable);
77 }
78
79 let slope = (n_dec * sum_xy - sum_x * sum_y)
80 .checked_div(denom)
81 .ok_or(FinError::ArithmeticOverflow)?;
82 let intercept = (sum_y - slope * sum_x)
83 .checked_div(n_dec)
84 .ok_or(FinError::ArithmeticOverflow)?;
85
86 let x_last = Decimal::from(n - 1);
88 let linreg_val = slope * x_last + intercept;
89
90 let close = bar.close;
91 if close.is_zero() {
92 return Ok(SignalValue::Unavailable);
93 }
94
95 let dev = (close - linreg_val)
96 .checked_div(close)
97 .ok_or(FinError::ArithmeticOverflow)?
98 * Decimal::from(100u32);
99
100 Ok(SignalValue::Scalar(dev))
101 }
102
103 fn is_ready(&self) -> bool {
104 self.history.len() >= self.period
105 }
106
107 fn period(&self) -> usize { self.period }
108
109 fn reset(&mut self) {
110 self.history.clear();
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ohlcv::OhlcvBar;
    use crate::types::{NanoTimestamp, Price, Quantity, Symbol};
    use rust_decimal_macros::dec;

    /// Builds a flat bar (open = high = low = close) at price `c`.
    fn bar(c: &str) -> OhlcvBar {
        let p = Price::new(c.parse().unwrap()).unwrap();
        OhlcvBar {
            symbol: Symbol::new("X").unwrap(),
            open: p,
            high: p,
            low: p,
            close: p,
            volume: Quantity::zero(),
            ts_open: NanoTimestamp::new(0),
            ts_close: NanoTimestamp::new(1),
            tick_count: 1,
        }
    }

    #[test]
    fn test_ld_invalid_period() {
        assert!(LinearDeviation::new("l", 0).is_err());
        assert!(LinearDeviation::new("l", 1).is_err());
    }

    #[test]
    fn test_ld_unavailable_early() {
        let mut ld = LinearDeviation::new("l", 3).unwrap();
        assert_eq!(ld.update_bar(&bar("100")).unwrap(), SignalValue::Unavailable);
        assert_eq!(ld.update_bar(&bar("101")).unwrap(), SignalValue::Unavailable);
    }

    #[test]
    fn test_ld_perfectly_on_line_is_zero() {
        let mut ld = LinearDeviation::new("l", 3).unwrap();
        ld.update_bar(&bar("100")).unwrap();
        ld.update_bar(&bar("101")).unwrap();
        if let SignalValue::Scalar(v) = ld.update_bar(&bar("102")).unwrap() {
            assert!(v.abs() < dec!(0.001), "on-line deviation should be ~0: {v}");
        } else {
            panic!("expected Scalar");
        }
    }

    #[test]
    fn test_ld_above_line_positive() {
        let mut ld = LinearDeviation::new("l", 3).unwrap();
        ld.update_bar(&bar("100")).unwrap();
        ld.update_bar(&bar("100")).unwrap();
        if let SignalValue::Scalar(v) = ld.update_bar(&bar("110")).unwrap() {
            assert!(v > dec!(0), "above line → positive deviation: {v}");
        } else {
            panic!("expected Scalar");
        }
    }

    #[test]
    fn test_ld_below_line_negative() {
        let mut ld = LinearDeviation::new("l", 3).unwrap();
        ld.update_bar(&bar("100")).unwrap();
        ld.update_bar(&bar("100")).unwrap();
        if let SignalValue::Scalar(v) = ld.update_bar(&bar("90")).unwrap() {
            assert!(v < dec!(0), "below line → negative deviation: {v}");
        } else {
            panic!("expected Scalar");
        }
    }

    #[test]
    fn test_ld_magnitude_matches_regression() {
        // Window [100, 100, 110]: slope = 5, intercept = 295/3, so the
        // endpoint is 325/3 ≈ 108.3333. The deviation is
        // 100 · (110 − 325/3) / 110 = 100/66 ≈ 1.515152 %.
        let mut ld = LinearDeviation::new("l", 3).unwrap();
        ld.update_bar(&bar("100")).unwrap();
        ld.update_bar(&bar("100")).unwrap();
        match ld.update_bar(&bar("110")).unwrap() {
            SignalValue::Scalar(v) => {
                assert!((v - dec!(1.515152)).abs() < dec!(0.00001), "got {v}");
            }
            other => panic!("expected Scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_ld_window_slides() {
        // After the 4th bar the window must be [100, 110, 120], which is
        // exactly linear. A zero deviation proves the oldest close was evicted.
        let mut ld = LinearDeviation::new("l", 3).unwrap();
        for p in &["100", "100", "110"] {
            ld.update_bar(&bar(p)).unwrap();
        }
        match ld.update_bar(&bar("120")).unwrap() {
            SignalValue::Scalar(v) => {
                assert!(v.abs() < dec!(0.001), "slid window should be on-line: {v}");
            }
            other => panic!("expected Scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_ld_reset() {
        let mut ld = LinearDeviation::new("l", 3).unwrap();
        for p in &["100", "101", "102"] {
            ld.update_bar(&bar(p)).unwrap();
        }
        assert!(ld.is_ready());
        ld.reset();
        assert!(!ld.is_ready());
        // After a reset the window has to refill before producing values.
        assert_eq!(ld.update_bar(&bar("100")).unwrap(), SignalValue::Unavailable);
    }
}