1use ferrolearn_core::error::FerroError;
26use ferrolearn_core::introspection::HasCoefficients;
27use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
28use ferrolearn_core::traits::{Fit, Predict};
29use ndarray::{Array1, Array2, Axis, ScalarOperand};
30use num_traits::Float;
31
32use crate::linalg;
33
/// Unfitted linear-regression estimator, generic over the float type `F`.
///
/// The struct holds configuration only. Fitting it produces a
/// `FittedLinearRegression` that carries the learned parameters.
#[derive(Clone, Debug)]
pub struct LinearRegression<F> {
    /// When `true`, fitting centers the data and estimates an intercept;
    /// when `false`, the fitted intercept is fixed at zero.
    pub fit_intercept: bool,
    /// Zero-sized marker that ties the estimator to `F` without storing a value.
    _marker: std::marker::PhantomData<F>,
}
49
50impl<F: Float> LinearRegression<F> {
51 #[must_use]
55 pub fn new() -> Self {
56 Self {
57 fit_intercept: true,
58 _marker: std::marker::PhantomData,
59 }
60 }
61
62 #[must_use]
64 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
65 self.fit_intercept = fit_intercept;
66 self
67 }
68}
69
70impl<F: Float> Default for LinearRegression<F> {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76#[derive(Debug, Clone)]
81pub struct FittedLinearRegression<F> {
82 coefficients: Array1<F>,
84 intercept: F,
86}
87
88impl<F: Float + Send + Sync + ScalarOperand + num_traits::FromPrimitive + 'static>
89 Fit<Array2<F>, Array1<F>> for LinearRegression<F>
90{
91 type Fitted = FittedLinearRegression<F>;
92 type Error = FerroError;
93
94 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLinearRegression<F>, FerroError> {
108 let (n_samples, _n_features) = x.dim();
109
110 if n_samples != y.len() {
112 return Err(FerroError::ShapeMismatch {
113 expected: vec![n_samples],
114 actual: vec![y.len()],
115 context: "y length must match number of samples in X".into(),
116 });
117 }
118
119 if n_samples == 0 {
120 return Err(FerroError::InsufficientSamples {
121 required: 1,
122 actual: 0,
123 context: "LinearRegression requires at least one sample".into(),
124 });
125 }
126
127 if self.fit_intercept {
128 let n = F::from(n_samples).unwrap();
132 let x_mean = x.mean_axis(Axis(0)).unwrap();
133 let y_mean = y.sum() / n;
134
135 let x_centered = x - &x_mean;
136 let y_centered = y - y_mean;
137
138 let w = linalg::solve_normal_equations(&x_centered, &y_centered)
140 .or_else(|_| linalg::solve_lstsq(&x_centered, &y_centered))?;
141
142 let intercept = y_mean - x_mean.dot(&w);
143
144 Ok(FittedLinearRegression {
145 coefficients: w,
146 intercept,
147 })
148 } else {
149 let w = linalg::solve_normal_equations(x, y)
151 .or_else(|_| linalg::solve_lstsq(x, y))?;
152
153 Ok(FittedLinearRegression {
154 coefficients: w,
155 intercept: F::zero(),
156 })
157 }
158 }
159}
160
161impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
162 for FittedLinearRegression<F>
163{
164 type Output = Array1<F>;
165 type Error = FerroError;
166
167 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
176 let n_features = x.ncols();
177 if n_features != self.coefficients.len() {
178 return Err(FerroError::ShapeMismatch {
179 expected: vec![self.coefficients.len()],
180 actual: vec![n_features],
181 context: "number of features must match fitted model".into(),
182 });
183 }
184
185 let preds = x.dot(&self.coefficients) + self.intercept;
186 Ok(preds)
187 }
188}
189
190impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
191 for FittedLinearRegression<F>
192{
193 fn coefficients(&self) -> &Array1<F> {
194 &self.coefficients
195 }
196
197 fn intercept(&self) -> F {
198 self.intercept
199 }
200}
201
202impl PipelineEstimator for LinearRegression<f64> {
204 fn fit_pipeline(
205 &self,
206 x: &Array2<f64>,
207 y: &Array1<f64>,
208 ) -> Result<Box<dyn FittedPipelineEstimator>, FerroError> {
209 let fitted = self.fit(x, y)?;
210 Ok(Box::new(fitted))
211 }
212}
213
214impl FittedPipelineEstimator for FittedLinearRegression<f64> {
215 fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
216 self.predict(x)
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// A single-feature data set generated by `y = 2x + 1` is recovered, and
    /// predicting on the training inputs reproduces the targets.
    #[test]
    fn test_simple_linear_regression() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-10);

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), y.len());
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-10);
        }
    }

    /// Two features with `y = x1 + 2*x2 + 3`.
    #[test]
    fn test_multiple_linear_regression() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 1.0, 3.0, 2.0, 4.0, 2.0]).unwrap();
        let y = array![6.0, 7.0, 10.0, 11.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
        assert_relative_eq!(fitted.coefficients()[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[1], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 3.0, epsilon = 1e-10);
    }

    /// With the intercept disabled the model is forced through the origin.
    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = LinearRegression::<f64>::new().with_fit_intercept(false);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    /// `fit` rejects a target vector whose length differs from the row count,
    /// and reports both lengths in the error.
    #[test]
    fn test_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = LinearRegression::<f64>::new();
        match model.fit(&x, &y) {
            Err(FerroError::ShapeMismatch {
                expected, actual, ..
            }) => {
                assert_eq!(expected, vec![3]);
                assert_eq!(actual, vec![2]);
            }
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }

    /// `predict` rejects input whose column count differs from the number of
    /// fitted coefficients.
    #[test]
    fn test_shape_mismatch_predict() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        match fitted.predict(&x_bad) {
            Err(FerroError::ShapeMismatch {
                expected, actual, ..
            }) => {
                assert_eq!(expected, vec![2]);
                assert_eq!(actual, vec![3]);
            }
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }

    /// Fitting on zero samples is reported as `InsufficientSamples` rather
    /// than reaching the solver.
    #[test]
    fn test_empty_input() {
        let x = Array2::<f64>::zeros((0, 1));
        let y = Array1::<f64>::zeros(0);

        let model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(matches!(
            result,
            Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                ..
            })
        ));
    }

    /// The `HasCoefficients` accessors expose the fitted parameters.
    #[test]
    fn test_has_coefficients() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![2.0, 4.0, 6.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 1);
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    /// Fitting and predicting through the pipeline traits gives the same
    /// numbers as the direct API would on this exactly-linear data.
    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();

        assert_eq!(preds.len(), 4);
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-10);
        }
    }

    /// The estimator is usable with `f32`; tolerance is loosened to match
    /// single-precision rounding.
    #[test]
    fn test_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0f32, 4.0, 6.0, 8.0]);

        let model = LinearRegression::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 4);
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-3);
        }
    }
}