indicators/trend/
linear_regression.rs1use std::collections::HashMap;
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22#[derive(Debug, Clone)]
23pub struct LrParams {
24 pub period: usize,
26 pub column: PriceColumn,
28}
29impl Default for LrParams {
30 fn default() -> Self {
31 Self {
32 period: 14,
33 column: PriceColumn::Close,
34 }
35 }
36}
37
38#[derive(Debug, Clone)]
39pub struct LinearRegression {
40 pub params: LrParams,
41}
42
43impl LinearRegression {
44 pub fn new(params: LrParams) -> Self {
45 Self { params }
46 }
47 pub fn with_period(period: usize) -> Self {
48 Self::new(LrParams {
49 period,
50 ..Default::default()
51 })
52 }
53 fn output_key(&self) -> String {
54 format!("LR_slope_{}", self.params.period)
55 }
56
57 fn ols_slope(y: &[f64]) -> f64 {
60 let n = y.len() as f64;
61 let x_mean = (n - 1.0) / 2.0;
62 let y_mean: f64 = y.iter().sum::<f64>() / n;
63 let mut num = 0.0f64;
64 let mut den = 0.0f64;
65 for (i, &yi) in y.iter().enumerate() {
66 let xi = i as f64 - x_mean;
67 num += xi * (yi - y_mean);
68 den += xi * xi;
69 }
70 if den == 0.0 { 0.0 } else { num / den }
71 }
72}
73
74impl Indicator for LinearRegression {
75 fn name(&self) -> &'static str {
76 "LinearRegression"
77 }
78 fn required_len(&self) -> usize {
79 self.params.period
80 }
81 fn required_columns(&self) -> &[&'static str] {
82 &["close"]
83 }
84
85 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
87 self.check_len(candles)?;
88
89 let prices = self.params.column.extract(candles);
90 let n = prices.len();
91 let p = self.params.period;
92 let mut values = vec![f64::NAN; n];
93
94 for i in (p - 1)..n {
96 values[i] = Self::ols_slope(&prices[(i + 1 - p)..=i]);
97 }
98
99 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
100 }
101}
102
103pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
104 Ok(Box::new(LinearRegression::new(LrParams {
105 period: param_usize(params, "period", 14)?,
106 ..Default::default()
107 })))
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open == high == low == close) with sequential
    /// timestamps 0, 1, 2, ... and unit volume.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn lr_perfect_line_slope_one() {
        let closes: Vec<f64> = (0..14).map(|x| x as f64).collect();
        let out = LinearRegression::with_period(14)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_14").unwrap();
        assert!((vals[13] - 1.0).abs() < 1e-9, "got {}", vals[13]);
    }

    #[test]
    fn lr_constant_slope_zero() {
        let closes = vec![5.0f64; 14];
        let out = LinearRegression::with_period(14)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_14").unwrap();
        assert!(vals[13].abs() < 1e-9);
    }

    #[test]
    fn lr_leading_nans() {
        let closes: Vec<f64> = (0..20).map(|x| x as f64).collect();
        let out = LinearRegression::with_period(14)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_14").unwrap();
        assert_eq!(vals.len(), 20);
        // The whole warm-up window (period - 1 values) must be NaN.
        assert!(vals[..13].iter().all(|v| v.is_nan()));
        // Every full window of a unit-slope line must report slope 1.
        assert!(
            vals[13..].iter().all(|v| (v - 1.0).abs() < 1e-9),
            "got {:?}",
            &vals[13..]
        );
    }

    #[test]
    fn lr_descending_line_slope_minus_one() {
        let closes: Vec<f64> = (0..14).map(|x| 10.0 - f64::from(x)).collect();
        let out = LinearRegression::with_period(14)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_14").unwrap();
        assert!((vals[13] + 1.0).abs() < 1e-9, "got {}", vals[13]);
    }

    #[test]
    fn lr_short_window_known_value() {
        // Points (0,1), (1,2), (2,4): x_mean = 1, y_mean = 7/3,
        // sum(dx*dy) = 3, sum(dx^2) = 2, so the slope is 1.5.
        let out = LinearRegression::with_period(3)
            .calculate(&candles(&[1.0, 2.0, 4.0]))
            .unwrap();
        let vals = out.get("LR_slope_3").unwrap();
        assert!(vals[0].is_nan() && vals[1].is_nan());
        assert!((vals[2] - 1.5).abs() < 1e-9, "got {}", vals[2]);
    }

    #[test]
    fn factory_creates_lr() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "LinearRegression");
    }
}