Skip to main content

indicators/trend/
linear_regression.rs

1//! Linear Regression Slope.
2//!
3//! Python source: `indicators/other/linear_regression.py :: class LinearRegressionIndicator`
4//!
5//! # Python algorithm (to port)
6//! ```python
7//! X = np.arange(self.period)
8//! slopes = data["Close"].rolling(window=self.period).apply(
9//!     lambda y: np.polyfit(X, y, 1)[0], raw=True
10//! )
11//! ```
12//!
13//! Output column: `"LR_slope_{period}"`.
14
15use std::collections::HashMap;
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22#[derive(Debug, Clone)]
23pub struct LrParams {
24    /// Rolling window.  Python default: 14.
25    pub period: usize,
26    /// Price field.  Python default: close.
27    pub column: PriceColumn,
28}
29impl Default for LrParams {
30    fn default() -> Self {
31        Self {
32            period: 14,
33            column: PriceColumn::Close,
34        }
35    }
36}
37
38#[derive(Debug, Clone)]
39pub struct LinearRegression {
40    pub params: LrParams,
41}
42
43impl LinearRegression {
44    pub fn new(params: LrParams) -> Self {
45        Self { params }
46    }
47    pub fn with_period(period: usize) -> Self {
48        Self::new(LrParams {
49            period,
50            ..Default::default()
51        })
52    }
53    fn output_key(&self) -> String {
54        format!("LR_slope_{}", self.params.period)
55    }
56
57    /// OLS slope: `sum((x - x_mean)(y - y_mean)) / sum((x - x_mean)^2)`
58    /// where `x = 0..period`.
59    fn ols_slope(y: &[f64]) -> f64 {
60        let n = y.len() as f64;
61        let x_mean = (n - 1.0) / 2.0;
62        let y_mean: f64 = y.iter().sum::<f64>() / n;
63        let mut num = 0.0f64;
64        let mut den = 0.0f64;
65        for (i, &yi) in y.iter().enumerate() {
66            let xi = i as f64 - x_mean;
67            num += xi * (yi - y_mean);
68            den += xi * xi;
69        }
70        if den == 0.0 { 0.0 } else { num / den }
71    }
72}
73
74impl Indicator for LinearRegression {
75    fn name(&self) -> &'static str {
76        "LinearRegression"
77    }
78    fn required_len(&self) -> usize {
79        self.params.period
80    }
81    fn required_columns(&self) -> &[&'static str] {
82        &["close"]
83    }
84
85    /// Ports `rolling(window=period).apply(lambda y: np.polyfit(X, y, 1)[0])`.
86    ///
87    /// OLS slope = `Σ(xᵢ − x̄)(yᵢ − ȳ) / Σ(xᵢ − x̄)²` where `xᵢ = i`.
88    /// This is algebraically identical to `np.polyfit`'s degree-1 coefficient.
89    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
90        self.check_len(candles)?;
91
92        let prices = self.params.column.extract(candles);
93        let n = prices.len();
94        let p = self.params.period;
95        let mut values = vec![f64::NAN; n];
96
97        for i in (p - 1)..n {
98            values[i] = Self::ols_slope(&prices[(i + 1 - p)..=i]);
99        }
100
101        Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
102    }
103}
104
105pub fn factory<S: ::std::hash::BuildHasher>(
106    params: &HashMap<String, String, S>,
107) -> Result<Box<dyn Indicator>, IndicatorError> {
108    Ok(Box::new(LinearRegression::new(LrParams {
109        period: param_usize(params, "period", 14)?,
110        ..Default::default()
111    })))
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open = high = low = close) with sequential
    /// timestamps and unit volume.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn lr_perfect_line_slope_one() {
        // y = x → slope should be 1.0
        let closes: Vec<f64> = (0_i32..14).map(f64::from).collect();
        let out = LinearRegression::with_period(14)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_14").unwrap();
        assert!((vals[13] - 1.0).abs() < 1e-9, "got {}", vals[13]);
    }

    #[test]
    fn lr_constant_slope_zero() {
        let closes = vec![5.0f64; 14];
        let out = LinearRegression::with_period(14)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_14").unwrap();
        assert!(vals[13].abs() < 1e-9);
    }

    #[test]
    fn lr_leading_nans() {
        // Every one of the first `period - 1` outputs lacks a full window and
        // must be NaN; every later output of a perfect line equals its slope.
        let closes: Vec<f64> = (0_i32..20).map(f64::from).collect();
        let out = LinearRegression::with_period(14)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_14").unwrap();
        assert_eq!(vals.len(), closes.len());
        assert!(vals[..13].iter().copied().all(f64::is_nan));
        assert!(vals[13..].iter().all(|v| (v - 1.0).abs() < 1e-9));
    }

    #[test]
    fn lr_noisy_series_matches_polyfit() {
        // For a 3-bar window the OLS slope reduces to (y[2] - y[0]) / 2;
        // e.g. np.polyfit([0, 1, 2], [1, 3, 2], 1)[0] == 0.5.
        let closes = [1.0, 3.0, 2.0, 5.0, 4.0];
        let out = LinearRegression::with_period(3)
            .calculate(&candles(&closes))
            .unwrap();
        let vals = out.get("LR_slope_3").unwrap();
        assert!(vals[..2].iter().copied().all(f64::is_nan));
        let expected = [0.5, 1.0, 1.0];
        for (got, want) in vals[2..].iter().zip(expected) {
            assert!((got - want).abs() < 1e-9, "got {got}, want {want}");
        }
    }

    #[test]
    fn factory_creates_lr() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "LinearRegression");
    }

    #[test]
    fn factory_honours_period_param() {
        let mut params = HashMap::new();
        params.insert("period".to_string(), "5".to_string());
        let indicator = factory(&params).unwrap();
        assert_eq!(indicator.required_len(), 5);
    }
}