// ferrolearn_linear/linear_regression.rs
//! Ordinary Least Squares linear regression.
//!
//! This module provides [`LinearRegression`], which fits a linear model
//! by solving the least squares problem
//!
//! ```text
//! minimize ||X @ w - y||^2
//! ```
//!
//! using Cholesky-factored normal equations for speed, with a fallback to
//! QR decomposition (via `faer`) when the normal equations are
//! ill-conditioned.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_linear::LinearRegression;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let model = LinearRegression::<f64>::new();
//! let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
//! let y = array![2.0, 4.0, 6.0, 8.0];
//!
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! ```

25use ferrolearn_core::error::FerroError;
26use ferrolearn_core::introspection::HasCoefficients;
27use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
28use ferrolearn_core::traits::{Fit, Predict};
29use ndarray::{Array1, Array2, Axis, ScalarOperand};
30use num_traits::Float;
31
32use crate::linalg;
33
/// Ordinary least squares (OLS) linear regression estimator.
///
/// Fits coefficients by minimising the squared residual norm
/// `||X w - y||^2`, using a numerically stable decomposition under the
/// hood. The [`fit_intercept`](Self::fit_intercept) flag controls whether
/// a bias (intercept) term is estimated alongside the per-feature weights.
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct LinearRegression<F> {
    /// When `true` (the default), an intercept term is estimated.
    pub fit_intercept: bool,
    _marker: std::marker::PhantomData<F>,
}
49
50impl<F: Float> LinearRegression<F> {
51    /// Create a new `LinearRegression` with default settings.
52    ///
53    /// Defaults: `fit_intercept = true`.
54    #[must_use]
55    pub fn new() -> Self {
56        Self {
57            fit_intercept: true,
58            _marker: std::marker::PhantomData,
59        }
60    }
61
62    /// Set whether to fit an intercept term.
63    #[must_use]
64    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
65        self.fit_intercept = fit_intercept;
66        self
67    }
68}
69
70impl<F: Float> Default for LinearRegression<F> {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76/// Fitted ordinary least squares linear regression model.
77///
78/// Stores the learned coefficients and intercept. Implements [`Predict`]
79/// to generate predictions and [`HasCoefficients`] for introspection.
80#[derive(Debug, Clone)]
81pub struct FittedLinearRegression<F> {
82    /// Learned coefficient vector (one per feature).
83    coefficients: Array1<F>,
84    /// Learned intercept (bias) term.
85    intercept: F,
86}
87
impl<F: Float + Send + Sync + ScalarOperand + num_traits::FromPrimitive + 'static>
    Fit<Array2<F>, Array1<F>> for LinearRegression<F>
{
    type Fitted = FittedLinearRegression<F>;
    type Error = FerroError;

    /// Fit the linear regression model.
    ///
    /// Uses the centering trick with Cholesky normal equations for speed.
    /// Falls back to QR decomposition via faer if the normal equations are
    /// ill-conditioned.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of samples in `x`
    /// and `y` differ.
    /// Returns [`FerroError::InsufficientSamples`] if `x` contains zero
    /// samples. NOTE(review): under-determined inputs (more features than
    /// samples) are *not* rejected here; they are handed to the solvers,
    /// which are expected to report failure — confirm against `linalg`.
    /// Returns [`FerroError::NumericalInstability`] if the system is singular.
    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLinearRegression<F>, FerroError> {
        let (n_samples, _n_features) = x.dim();

        // Validate input shapes.
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }

        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "LinearRegression requires at least one sample".into(),
            });
        }

        if self.fit_intercept {
            // Centering trick: center X and y, solve without intercept column,
            // then recover intercept as y_mean - x_mean . w.
            // This avoids the expensive matrix augmentation + QR path.
            //
            // F::from(usize) is Some for f32/f64; unwrap is safe for the
            // supported scalar types.
            let n = F::from(n_samples).unwrap();
            // n_samples > 0 was checked above, so mean_axis over Axis(0)
            // returns Some and this unwrap cannot fail.
            let x_mean = x.mean_axis(Axis(0)).unwrap();
            let y_mean = y.sum() / n;

            let x_centered = x - &x_mean;
            let y_centered = y - y_mean;

            // Try fast Cholesky normal equations first, fall back to QR.
            let w = linalg::solve_normal_equations(&x_centered, &y_centered)
                .or_else(|_| linalg::solve_lstsq(&x_centered, &y_centered))?;

            // Recover the intercept from the means of the uncentered data.
            let intercept = y_mean - x_mean.dot(&w);

            Ok(FittedLinearRegression {
                coefficients: w,
                intercept,
            })
        } else {
            // Try fast Cholesky normal equations first, fall back to QR.
            let w = linalg::solve_normal_equations(x, y).or_else(|_| linalg::solve_lstsq(x, y))?;

            Ok(FittedLinearRegression {
                coefficients: w,
                // No intercept requested: fix the bias at exactly zero.
                intercept: F::zero(),
            })
        }
    }
}
159
160impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
161    for FittedLinearRegression<F>
162{
163    type Output = Array1<F>;
164    type Error = FerroError;
165
166    /// Predict target values for the given feature matrix.
167    ///
168    /// Computes `X @ coefficients + intercept`.
169    ///
170    /// # Errors
171    ///
172    /// Returns [`FerroError::ShapeMismatch`] if the number of features
173    /// does not match the fitted model.
174    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
175        let n_features = x.ncols();
176        if n_features != self.coefficients.len() {
177            return Err(FerroError::ShapeMismatch {
178                expected: vec![self.coefficients.len()],
179                actual: vec![n_features],
180                context: "number of features must match fitted model".into(),
181            });
182        }
183
184        let preds = x.dot(&self.coefficients) + self.intercept;
185        Ok(preds)
186    }
187}
188
189impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
190    for FittedLinearRegression<F>
191{
192    fn coefficients(&self) -> &Array1<F> {
193        &self.coefficients
194    }
195
196    fn intercept(&self) -> F {
197        self.intercept
198    }
199}
200
201// Pipeline integration for f64.
202impl PipelineEstimator<f64> for LinearRegression<f64> {
203    fn fit_pipeline(
204        &self,
205        x: &Array2<f64>,
206        y: &Array1<f64>,
207    ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
208        let fitted = self.fit(x, y)?;
209        Ok(Box::new(fitted))
210    }
211}
212
213impl FittedPipelineEstimator<f64> for FittedLinearRegression<f64> {
214    fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
215        self.predict(x)
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    #[test]
    fn test_simple_linear_regression() {
        // Exact fit of y = 2*x + 1.
        let features = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let targets = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = LinearRegression::<f64>::new()
            .fit(&features, &targets)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-10);

        let predictions = fitted.predict(&features).unwrap();
        for (predicted, &expected) in predictions.iter().zip(targets.iter()) {
            assert_relative_eq!(*predicted, expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_multiple_linear_regression() {
        // Exact fit of y = 1*x1 + 2*x2 + 3 over two features.
        let features =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 1.0, 3.0, 2.0, 4.0, 2.0]).unwrap();
        let targets = array![6.0, 7.0, 10.0, 11.0];

        let fitted = LinearRegression::<f64>::new()
            .fit(&features, &targets)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[1], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_no_intercept() {
        // y = 2*x through the origin; intercept must stay exactly zero.
        let features = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let targets = array![2.0, 4.0, 6.0, 8.0];

        let fitted = LinearRegression::<f64>::new()
            .with_fit_intercept(false)
            .fit(&features, &targets)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_shape_mismatch_fit() {
        // y is shorter than X's sample count, so fit must fail.
        let features = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let targets = array![1.0, 2.0];

        let outcome = LinearRegression::<f64>::new().fit(&features, &targets);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_shape_mismatch_predict() {
        let features = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = array![1.0, 2.0, 3.0];

        let fitted = LinearRegression::<f64>::new()
            .fit(&features, &targets)
            .unwrap();

        // Predicting with three features against a two-feature model fails.
        let wrong_width = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&wrong_width).is_err());
    }

    #[test]
    fn test_has_coefficients() {
        let features = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let targets = array![2.0, 4.0, 6.0];

        let fitted = LinearRegression::<f64>::new()
            .fit(&features, &targets)
            .unwrap();

        // One input feature implies one learned coefficient.
        assert_eq!(fitted.coefficients().len(), 1);
    }

    #[test]
    fn test_pipeline_integration() {
        let features = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let targets = array![3.0, 5.0, 7.0, 9.0];

        // Fit and predict through the boxed pipeline trait objects.
        let boxed = LinearRegression::<f64>::new()
            .fit_pipeline(&features, &targets)
            .unwrap();
        let predictions = boxed.predict_pipeline(&features).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_f32_support() {
        // The estimator is generic over the scalar; exercise the f32 path.
        let features = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let targets = Array1::from_vec(vec![2.0f32, 4.0, 6.0, 8.0]);

        let fitted = LinearRegression::<f32>::new()
            .fit(&features, &targets)
            .unwrap();
        let predictions = fitted.predict(&features).unwrap();
        assert_eq!(predictions.len(), 4);
    }
}