// ferrolearn_linear/linear_regression.rs
//! Ordinary Least Squares linear regression.
//!
//! This module provides [`LinearRegression`], which fits a linear model by
//! solving the least squares problem
//!
//! ```text
//! minimize ||X @ w - y||^2
//! ```
//!
//! using fast Cholesky normal equations, with a fallback to QR
//! decomposition (via `faer`) for ill-conditioned systems.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_linear::LinearRegression;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let model = LinearRegression::<f64>::new();
//! let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
//! let y = array![2.0, 4.0, 6.0, 8.0];
//!
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! ```

25use ferrolearn_core::error::FerroError;
26use ferrolearn_core::introspection::HasCoefficients;
27use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
28use ferrolearn_core::traits::{Fit, Predict};
29use ndarray::{Array1, Array2, Axis, ScalarOperand};
30use num_traits::Float;
31
32use crate::linalg;
33
/// Ordinary least squares linear regression.
///
/// Fitting solves the least squares problem, trying fast Cholesky normal
/// equations first and falling back to QR decomposition for numerical
/// stability. The `fit_intercept` option controls whether a bias
/// (intercept) term is included.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct LinearRegression<F> {
    /// Whether to fit an intercept (bias) term.
    pub fit_intercept: bool,
    // Zero-sized marker tying the (otherwise F-free) config struct to the
    // float type it will be fitted with.
    _marker: std::marker::PhantomData<F>,
}
49
50impl<F: Float> LinearRegression<F> {
51    /// Create a new `LinearRegression` with default settings.
52    ///
53    /// Defaults: `fit_intercept = true`.
54    #[must_use]
55    pub fn new() -> Self {
56        Self {
57            fit_intercept: true,
58            _marker: std::marker::PhantomData,
59        }
60    }
61
62    /// Set whether to fit an intercept term.
63    #[must_use]
64    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
65        self.fit_intercept = fit_intercept;
66        self
67    }
68}
69
impl<F: Float> Default for LinearRegression<F> {
    /// Equivalent to [`LinearRegression::new`]: `fit_intercept = true`.
    fn default() -> Self {
        Self::new()
    }
}
75
/// Fitted ordinary least squares linear regression model.
///
/// Stores the learned coefficients and intercept. Implements [`Predict`]
/// to generate predictions and [`HasCoefficients`] for introspection.
#[derive(Debug, Clone)]
pub struct FittedLinearRegression<F> {
    /// Learned coefficient vector (one per feature).
    coefficients: Array1<F>,
    /// Learned intercept (bias) term; zero when the model was fitted with
    /// `fit_intercept = false`.
    intercept: F,
}
87
impl<F: Float + Send + Sync + ScalarOperand + num_traits::FromPrimitive + 'static>
    Fit<Array2<F>, Array1<F>> for LinearRegression<F>
{
    type Fitted = FittedLinearRegression<F>;
    type Error = FerroError;

    /// Fit the linear regression model.
    ///
    /// Uses the centering trick with Cholesky normal equations for speed.
    /// Falls back to QR decomposition via faer if the normal equations are
    /// ill-conditioned.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of samples in `x`
    /// and `y` differ.
    /// Returns [`FerroError::InsufficientSamples`] if `x` contains no samples.
    /// Any error produced by both the normal-equations solver and the QR
    /// fallback (e.g. a singular system) is propagated to the caller.
    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLinearRegression<F>, FerroError> {
        let (n_samples, _n_features) = x.dim();

        // Validate input shapes: y must contain one target per row of X.
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }

        // An empty design matrix cannot be fitted; rejecting it here also
        // guarantees the mean computations below are well-defined.
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "LinearRegression requires at least one sample".into(),
            });
        }

        if self.fit_intercept {
            // Centering trick: center X and y, solve without intercept column,
            // then recover intercept as y_mean - x_mean . w.
            // This avoids the expensive matrix augmentation + QR path.
            // Both `unwrap`s are safe: n_samples >= 1 is guaranteed above,
            // so the sample-count conversion and the axis mean succeed.
            let n = F::from(n_samples).unwrap();
            let x_mean = x.mean_axis(Axis(0)).unwrap();
            let y_mean = y.sum() / n;

            let x_centered = x - &x_mean;
            let y_centered = y - y_mean;

            // Try fast Cholesky normal equations first, fall back to QR.
            let w = linalg::solve_normal_equations(&x_centered, &y_centered)
                .or_else(|_| linalg::solve_lstsq(&x_centered, &y_centered))?;

            // Recover the intercept so the fitted hyperplane passes through
            // the data's mean point.
            let intercept = y_mean - x_mean.dot(&w);

            Ok(FittedLinearRegression {
                coefficients: w,
                intercept,
            })
        } else {
            // Try fast Cholesky normal equations first, fall back to QR.
            let w = linalg::solve_normal_equations(x, y)
                .or_else(|_| linalg::solve_lstsq(x, y))?;

            Ok(FittedLinearRegression {
                coefficients: w,
                intercept: F::zero(),
            })
        }
    }
}
160
161impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
162    for FittedLinearRegression<F>
163{
164    type Output = Array1<F>;
165    type Error = FerroError;
166
167    /// Predict target values for the given feature matrix.
168    ///
169    /// Computes `X @ coefficients + intercept`.
170    ///
171    /// # Errors
172    ///
173    /// Returns [`FerroError::ShapeMismatch`] if the number of features
174    /// does not match the fitted model.
175    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
176        let n_features = x.ncols();
177        if n_features != self.coefficients.len() {
178            return Err(FerroError::ShapeMismatch {
179                expected: vec![self.coefficients.len()],
180                actual: vec![n_features],
181                context: "number of features must match fitted model".into(),
182            });
183        }
184
185        let preds = x.dot(&self.coefficients) + self.intercept;
186        Ok(preds)
187    }
188}
189
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedLinearRegression<F>
{
    /// Borrow the learned coefficient vector (one entry per feature).
    fn coefficients(&self) -> &Array1<F> {
        &self.coefficients
    }

    /// The learned intercept; zero when fitted with `fit_intercept = false`.
    fn intercept(&self) -> F {
        self.intercept
    }
}
201
202// Pipeline integration for f64.
203impl PipelineEstimator for LinearRegression<f64> {
204    fn fit_pipeline(
205        &self,
206        x: &Array2<f64>,
207        y: &Array1<f64>,
208    ) -> Result<Box<dyn FittedPipelineEstimator>, FerroError> {
209        let fitted = self.fit(x, y)?;
210        Ok(Box::new(fitted))
211    }
212}
213
impl FittedPipelineEstimator for FittedLinearRegression<f64> {
    /// Delegates to the inherent [`Predict`] implementation.
    fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
        self.predict(x)
    }
}
219
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // Single-feature fit recovers slope and intercept exactly on
    // noiseless data, and predictions round-trip the training targets.
    #[test]
    fn test_simple_linear_regression() {
        // y = 2*x + 1
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-10);

        let preds = fitted.predict(&x).unwrap();
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-10);
        }
    }

    // Two-feature fit recovers each coefficient and the intercept.
    #[test]
    fn test_multiple_linear_regression() {
        // y = 1*x1 + 2*x2 + 3
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 1.0, 3.0, 2.0, 4.0, 2.0]).unwrap();
        let y = array![6.0, 7.0, 10.0, 11.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[1], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 3.0, epsilon = 1e-10);
    }

    // With fit_intercept disabled, the intercept stays exactly zero.
    #[test]
    fn test_no_intercept() {
        // y = 2*x (through origin)
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = LinearRegression::<f64>::new().with_fit_intercept(false);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    // fit() rejects a y whose length differs from X's sample count.
    #[test]
    fn test_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0]; // Wrong length

        let model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    // predict() rejects inputs whose feature count differs from the
    // fitted coefficient vector.
    #[test]
    fn test_shape_mismatch_predict() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        // Wrong number of features
        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let result = fitted.predict(&x_bad);
        assert!(result.is_err());
    }

    // Introspection trait exposes a coefficient per feature.
    #[test]
    fn test_has_coefficients() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![2.0, 4.0, 6.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 1);
    }

    // Type-erased pipeline path produces one prediction per sample.
    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // Smoke test for the f32 instantiation of the generic estimator.
    #[test]
    fn test_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0f32, 4.0, 6.0, 8.0]);

        let model = LinearRegression::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}