1use ferrolearn_core::error::FerroError;
26use ferrolearn_core::introspection::HasCoefficients;
27use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
28use ferrolearn_core::traits::{Fit, Predict};
29use ndarray::{Array1, Array2, Axis, ScalarOperand};
30use num_traits::{Float, FromPrimitive};
31
32use crate::linalg;
33
/// Unfitted configuration for least-squares linear regression.
///
/// Build one with [`LinearRegression::new`] (or [`Default`]) and call `fit` to obtain a
/// [`FittedLinearRegression`].
#[derive(Clone, Debug)]
pub struct LinearRegression<F> {
    /// When `true`, an intercept (bias) term is estimated; when `false`, the fitted
    /// intercept is fixed at zero.
    pub fit_intercept: bool,
    /// Associates the configuration with the float type `F` without storing an `F`.
    _marker: std::marker::PhantomData<F>,
}
49
50impl<F: Float> LinearRegression<F> {
51 #[must_use]
55 pub fn new() -> Self {
56 Self {
57 fit_intercept: true,
58 _marker: std::marker::PhantomData,
59 }
60 }
61
62 #[must_use]
64 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
65 self.fit_intercept = fit_intercept;
66 self
67 }
68}
69
70impl<F: Float> Default for LinearRegression<F> {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76#[derive(Debug, Clone)]
81pub struct FittedLinearRegression<F> {
82 coefficients: Array1<F>,
84 intercept: F,
86}
87
88impl<F: Float + Send + Sync + ScalarOperand + num_traits::FromPrimitive + 'static>
89 Fit<Array2<F>, Array1<F>> for LinearRegression<F>
90{
91 type Fitted = FittedLinearRegression<F>;
92 type Error = FerroError;
93
94 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLinearRegression<F>, FerroError> {
108 let (n_samples, _n_features) = x.dim();
109
110 if n_samples != y.len() {
112 return Err(FerroError::ShapeMismatch {
113 expected: vec![n_samples],
114 actual: vec![y.len()],
115 context: "y length must match number of samples in X".into(),
116 });
117 }
118
119 if n_samples == 0 {
120 return Err(FerroError::InsufficientSamples {
121 required: 1,
122 actual: 0,
123 context: "LinearRegression requires at least one sample".into(),
124 });
125 }
126
127 if self.fit_intercept {
128 let n = F::from(n_samples).unwrap();
132 let x_mean = x.mean_axis(Axis(0)).unwrap();
133 let y_mean = y.sum() / n;
134
135 let x_centered = x - &x_mean;
136 let y_centered = y - y_mean;
137
138 let w = linalg::solve_normal_equations(&x_centered, &y_centered)
140 .or_else(|_| linalg::solve_lstsq(&x_centered, &y_centered))?;
141
142 let intercept = y_mean - x_mean.dot(&w);
143
144 Ok(FittedLinearRegression {
145 coefficients: w,
146 intercept,
147 })
148 } else {
149 let w = linalg::solve_normal_equations(x, y).or_else(|_| linalg::solve_lstsq(x, y))?;
151
152 Ok(FittedLinearRegression {
153 coefficients: w,
154 intercept: F::zero(),
155 })
156 }
157 }
158}
159
160impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
161 for FittedLinearRegression<F>
162{
163 type Output = Array1<F>;
164 type Error = FerroError;
165
166 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
175 let n_features = x.ncols();
176 if n_features != self.coefficients.len() {
177 return Err(FerroError::ShapeMismatch {
178 expected: vec![self.coefficients.len()],
179 actual: vec![n_features],
180 context: "number of features must match fitted model".into(),
181 });
182 }
183
184 let preds = x.dot(&self.coefficients) + self.intercept;
185 Ok(preds)
186 }
187}
188
189impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
190 for FittedLinearRegression<F>
191{
192 fn coefficients(&self) -> &Array1<F> {
193 &self.coefficients
194 }
195
196 fn intercept(&self) -> F {
197 self.intercept
198 }
199}
200
201impl<F> PipelineEstimator<F> for LinearRegression<F>
203where
204 F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
205{
206 fn fit_pipeline(
207 &self,
208 x: &Array2<F>,
209 y: &Array1<F>,
210 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
211 let fitted = self.fit(x, y)?;
212 Ok(Box::new(fitted))
213 }
214}
215
216impl<F> FittedPipelineEstimator<F> for FittedLinearRegression<F>
217where
218 F: Float + ScalarOperand + Send + Sync + 'static,
219{
220 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
221 self.predict(x)
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    #[test]
    fn test_simple_linear_regression() {
        // y = 2x + 1
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-10);

        let preds = fitted.predict(&x).unwrap();
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_multiple_linear_regression() {
        // y = 1*x1 + 2*x2 + 3
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 1.0, 3.0, 2.0, 4.0, 2.0]).unwrap();
        let y = array![6.0, 7.0, 10.0, 11.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[1], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_no_intercept() {
        // y = 2x, fitted through the origin
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = LinearRegression::<f64>::new().with_fit_intercept(false);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_default_fits_intercept() {
        let model: LinearRegression<f64> = LinearRegression::default();
        assert!(model.fit_intercept);
    }

    #[test]
    fn test_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_empty_input_rejected() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(matches!(
            result,
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_shape_mismatch_predict() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let result = fitted.predict(&x_bad);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_has_coefficients() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![2.0, 4.0, 6.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 1);
    }

    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0f32, 4.0, 6.0, 8.0]);

        let model = LinearRegression::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 4);
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-4);
        }
    }
}