indicators/trend/
linear_regression.rs1use std::collections::HashMap;
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22#[derive(Debug, Clone)]
23pub struct LrParams {
24 pub period: usize,
26 pub column: PriceColumn,
28}
29impl Default for LrParams {
30 fn default() -> Self {
31 Self {
32 period: 14,
33 column: PriceColumn::Close,
34 }
35 }
36}
37
38#[derive(Debug, Clone)]
39pub struct LinearRegression {
40 pub params: LrParams,
41}
42
43impl LinearRegression {
44 pub fn new(params: LrParams) -> Self {
45 Self { params }
46 }
47 pub fn with_period(period: usize) -> Self {
48 Self::new(LrParams {
49 period,
50 ..Default::default()
51 })
52 }
53 fn output_key(&self) -> String {
54 format!("LR_slope_{}", self.params.period)
55 }
56
57 fn ols_slope(y: &[f64]) -> f64 {
60 let n = y.len() as f64;
61 let x_mean = (n - 1.0) / 2.0;
62 let y_mean: f64 = y.iter().sum::<f64>() / n;
63 let mut num = 0.0f64;
64 let mut den = 0.0f64;
65 for (i, &yi) in y.iter().enumerate() {
66 let xi = i as f64 - x_mean;
67 num += xi * (yi - y_mean);
68 den += xi * xi;
69 }
70 if den == 0.0 { 0.0 } else { num / den }
71 }
72}
73
74impl Indicator for LinearRegression {
75 fn name(&self) -> &'static str {
76 "LinearRegression"
77 }
78 fn required_len(&self) -> usize {
79 self.params.period
80 }
81 fn required_columns(&self) -> &[&'static str] {
82 &["close"]
83 }
84
85 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
90 self.check_len(candles)?;
91
92 let prices = self.params.column.extract(candles);
93 let n = prices.len();
94 let p = self.params.period;
95 let mut values = vec![f64::NAN; n];
96
97 for i in (p - 1)..n {
98 values[i] = Self::ols_slope(&prices[(i + 1 - p)..=i]);
99 }
100
101 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
102 }
103}
104
105pub fn factory<S: ::std::hash::BuildHasher>(
106 params: &HashMap<String, String, S>,
107) -> Result<Box<dyn Indicator>, IndicatorError> {
108 Ok(Box::new(LinearRegression::new(LrParams {
109 period: param_usize(params, "period", 14)?,
110 ..Default::default()
111 })))
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open = high = low = close) with unit volume,
    /// using the index as the timestamp.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn lr_perfect_line_slope_one() {
        let closes: Vec<f64> = (0..14).map(|x| x as f64).collect();
        let out = LinearRegression::with_period(14)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_14").unwrap();
        assert!((vals[13] - 1.0).abs() < 1e-9, "got {}", vals[13]);
    }

    #[test]
    fn lr_constant_slope_zero() {
        let closes = vec![5.0f64; 14];
        let out = LinearRegression::with_period(14)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_14").unwrap();
        assert!(vals[13].abs() < 1e-9);
    }

    #[test]
    fn lr_leading_nans() {
        let closes: Vec<f64> = (0..20).map(|x| x as f64).collect();
        let out = LinearRegression::with_period(14)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_14").unwrap();
        // Exactly the first period - 1 = 13 entries lack a full window...
        assert!((0..13).all(|i| vals[i].is_nan()));
        // ...and every later entry is the slope of the unit-step line.
        assert!((13..20).all(|i| (vals[i] - 1.0).abs() < 1e-9));
    }

    #[test]
    fn lr_descending_line_negative_slope() {
        let closes = [10.0, 8.0, 6.0, 4.0, 2.0];
        let out = LinearRegression::with_period(3)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_3").unwrap();
        assert!(vals[0].is_nan() && vals[1].is_nan());
        assert!((2..5).all(|i| (vals[i] + 2.0).abs() < 1e-9));
    }

    #[test]
    fn lr_non_linear_window_matches_hand_computed_ols() {
        // x = [0, 1, 2], y = [1, 2, 4]:
        // Σ(x - x̄)(y - ȳ) = 3 and Σ(x - x̄)² = 2, so the slope is 1.5.
        let closes = [1.0, 2.0, 4.0];
        let out = LinearRegression::with_period(3)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_3").unwrap();
        assert!((vals[2] - 1.5).abs() < 1e-9, "got {}", vals[2]);
    }

    #[test]
    fn lr_period_one_reports_zero_slope() {
        // A single-point window has no spread in x, so the slope falls back to 0.
        let closes = [3.0, 7.0];
        let out = LinearRegression::with_period(1)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_1").unwrap();
        assert_eq!(vals[0], 0.0);
        assert_eq!(vals[1], 0.0);
    }

    #[test]
    fn factory_creates_lr() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "LinearRegression");
    }
}