Skip to main content

indicators/trend/
linear_regression.rs

//! Linear Regression Slope.
//!
//! Python source: `indicators/other/linear_regression.py :: class LinearRegressionIndicator`
//!
//! # Python reference algorithm (ported by `LinearRegression::calculate`)
//! ```python
//! X = np.arange(self.period)
//! slopes = data["Close"].rolling(window=self.period).apply(
//!     lambda y: np.polyfit(X, y, 1)[0], raw=True
//! )
//! ```
//!
//! Output column: `"LR_slope_{period}"`; its first `period - 1` entries are NaN
//! because the rolling window is not yet full.
14
15use std::collections::HashMap;
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22#[derive(Debug, Clone)]
23pub struct LrParams {
24    /// Rolling window.  Python default: 14.
25    pub period: usize,
26    /// Price field.  Python default: close.
27    pub column: PriceColumn,
28}
29impl Default for LrParams {
30    fn default() -> Self {
31        Self {
32            period: 14,
33            column: PriceColumn::Close,
34        }
35    }
36}
37
38#[derive(Debug, Clone)]
39pub struct LinearRegression {
40    pub params: LrParams,
41}
42
43impl LinearRegression {
44    pub fn new(params: LrParams) -> Self {
45        Self { params }
46    }
47    pub fn with_period(period: usize) -> Self {
48        Self::new(LrParams {
49            period,
50            ..Default::default()
51        })
52    }
53    fn output_key(&self) -> String {
54        format!("LR_slope_{}", self.params.period)
55    }
56
57    /// OLS slope: `sum((x - x_mean)(y - y_mean)) / sum((x - x_mean)^2)`
58    /// where `x = 0..period`.
59    fn ols_slope(y: &[f64]) -> f64 {
60        let n = y.len() as f64;
61        let x_mean = (n - 1.0) / 2.0;
62        let y_mean: f64 = y.iter().sum::<f64>() / n;
63        let mut num = 0.0f64;
64        let mut den = 0.0f64;
65        for (i, &yi) in y.iter().enumerate() {
66            let xi = i as f64 - x_mean;
67            num += xi * (yi - y_mean);
68            den += xi * xi;
69        }
70        if den == 0.0 { 0.0 } else { num / den }
71    }
72}
73
74impl Indicator for LinearRegression {
75    fn name(&self) -> &str {
76        "LinearRegression"
77    }
78    fn required_len(&self) -> usize {
79        self.params.period
80    }
81    fn required_columns(&self) -> &[&'static str] {
82        &["close"]
83    }
84
85    /// TODO: port Python rolling `np.polyfit` slope.
86    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
87        self.check_len(candles)?;
88
89        let prices = self.params.column.extract(candles);
90        let n = prices.len();
91        let p = self.params.period;
92        let mut values = vec![f64::NAN; n];
93
94        // TODO: implement rolling OLS slope (matches np.polyfit(X, y, 1)[0]).
95        for i in (p - 1)..n {
96            values[i] = Self::ols_slope(&prices[(i + 1 - p)..=i]);
97        }
98
99        Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
100    }
101}
102
103pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
104    Ok(Box::new(LinearRegression::new(LrParams {
105        period: param_usize(params, "period", 14)?,
106        ..Default::default()
107    })))
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Flat candles (open = high = low = close) with sequential timestamps and
    /// unit volume, one per closing price.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        let mut bars = Vec::with_capacity(closes.len());
        for (i, &c) in closes.iter().enumerate() {
            bars.push(Candle {
                time: i as i64,
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            });
        }
        bars
    }

    /// Runs a 14-bar regression over `closes`.
    fn run_period_14(closes: &[f64]) -> IndicatorOutput {
        LinearRegression::with_period(14)
            .calculate(&candles(closes))
            .unwrap()
    }

    #[test]
    fn lr_perfect_line_slope_one() {
        // Closes 0, 1, …, 13 lie on y = x, so the fitted slope is 1.
        let closes: Vec<f64> = (0..14).map(f64::from).collect();
        let out = run_period_14(&closes);
        let vals = out.get("LR_slope_14").unwrap();
        let last = vals[13];
        assert!((last - 1.0).abs() < 1e-9, "got {}", last);
    }

    #[test]
    fn lr_constant_slope_zero() {
        let out = run_period_14(&[5.0f64; 14]);
        let vals = out.get("LR_slope_14").unwrap();
        assert!(vals[13].abs() < 1e-9);
    }

    #[test]
    fn lr_leading_nans() {
        let closes: Vec<f64> = (0..20).map(f64::from).collect();
        let out = run_period_14(&closes);
        let vals = out.get("LR_slope_14").unwrap();
        // Warm-up rows are NaN; the first full window produces a value.
        assert!(vals[0].is_nan());
        assert!(!vals[13].is_nan());
    }

    #[test]
    fn factory_creates_lr() {
        let indicator = factory(&HashMap::new()).unwrap();
        assert_eq!(indicator.name(), "LinearRegression");
    }
}