1use ferrolearn_core::error::FerroError;
26use ferrolearn_core::introspection::HasCoefficients;
27use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
28use ferrolearn_core::traits::{Fit, Predict};
29use ndarray::{Array1, Array2, Axis, ScalarOperand};
30use num_traits::Float;
31
32use crate::linalg;
33
/// Configuration of a linear regression estimator that has not been fitted yet.
///
/// The scalar type `F` is only recorded at the type level; no value of `F`
/// is stored in the estimator itself.
#[derive(Debug, Clone)]
pub struct LinearRegression<F> {
    /// Whether fitting should estimate an intercept term in addition to the
    /// per-feature coefficients.
    pub fit_intercept: bool,
    /// Zero-sized marker that ties the estimator to its scalar type `F`.
    _marker: std::marker::PhantomData<F>,
}
49
50impl<F: Float> LinearRegression<F> {
51 #[must_use]
55 pub fn new() -> Self {
56 Self {
57 fit_intercept: true,
58 _marker: std::marker::PhantomData,
59 }
60 }
61
62 #[must_use]
64 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
65 self.fit_intercept = fit_intercept;
66 self
67 }
68}
69
70impl<F: Float> Default for LinearRegression<F> {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76#[derive(Debug, Clone)]
81pub struct FittedLinearRegression<F> {
82 coefficients: Array1<F>,
84 intercept: F,
86}
87
88impl<F: Float + Send + Sync + ScalarOperand + num_traits::FromPrimitive + 'static>
89 Fit<Array2<F>, Array1<F>> for LinearRegression<F>
90{
91 type Fitted = FittedLinearRegression<F>;
92 type Error = FerroError;
93
94 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLinearRegression<F>, FerroError> {
108 let (n_samples, _n_features) = x.dim();
109
110 if n_samples != y.len() {
112 return Err(FerroError::ShapeMismatch {
113 expected: vec![n_samples],
114 actual: vec![y.len()],
115 context: "y length must match number of samples in X".into(),
116 });
117 }
118
119 if n_samples == 0 {
120 return Err(FerroError::InsufficientSamples {
121 required: 1,
122 actual: 0,
123 context: "LinearRegression requires at least one sample".into(),
124 });
125 }
126
127 if self.fit_intercept {
128 let n = F::from(n_samples).unwrap();
132 let x_mean = x.mean_axis(Axis(0)).unwrap();
133 let y_mean = y.sum() / n;
134
135 let x_centered = x - &x_mean;
136 let y_centered = y - y_mean;
137
138 let w = linalg::solve_normal_equations(&x_centered, &y_centered)
140 .or_else(|_| linalg::solve_lstsq(&x_centered, &y_centered))?;
141
142 let intercept = y_mean - x_mean.dot(&w);
143
144 Ok(FittedLinearRegression {
145 coefficients: w,
146 intercept,
147 })
148 } else {
149 let w = linalg::solve_normal_equations(x, y).or_else(|_| linalg::solve_lstsq(x, y))?;
151
152 Ok(FittedLinearRegression {
153 coefficients: w,
154 intercept: F::zero(),
155 })
156 }
157 }
158}
159
160impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
161 for FittedLinearRegression<F>
162{
163 type Output = Array1<F>;
164 type Error = FerroError;
165
166 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
175 let n_features = x.ncols();
176 if n_features != self.coefficients.len() {
177 return Err(FerroError::ShapeMismatch {
178 expected: vec![self.coefficients.len()],
179 actual: vec![n_features],
180 context: "number of features must match fitted model".into(),
181 });
182 }
183
184 let preds = x.dot(&self.coefficients) + self.intercept;
185 Ok(preds)
186 }
187}
188
189impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
190 for FittedLinearRegression<F>
191{
192 fn coefficients(&self) -> &Array1<F> {
193 &self.coefficients
194 }
195
196 fn intercept(&self) -> F {
197 self.intercept
198 }
199}
200
201impl PipelineEstimator<f64> for LinearRegression<f64> {
203 fn fit_pipeline(
204 &self,
205 x: &Array2<f64>,
206 y: &Array1<f64>,
207 ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
208 let fitted = self.fit(x, y)?;
209 Ok(Box::new(fitted))
210 }
211}
212
213impl FittedPipelineEstimator<f64> for FittedLinearRegression<f64> {
214 fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
215 self.predict(x)
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// y = 2x + 1 on a single feature is recovered exactly.
    #[test]
    fn test_simple_linear_regression() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-10);

        let preds = fitted.predict(&x).unwrap();
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-10);
        }
    }

    /// y = x1 + 2*x2 + 3 on two features is recovered exactly.
    #[test]
    fn test_multiple_linear_regression() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 1.0, 3.0, 2.0, 4.0, 2.0]).unwrap();
        let y = array![6.0, 7.0, 10.0, 11.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[1], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 3.0, epsilon = 1e-10);
    }

    /// With `fit_intercept == false` the intercept is exactly zero.
    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = LinearRegression::<f64>::new().with_fit_intercept(false);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    /// A target vector shorter than the sample count is rejected with
    /// `ShapeMismatch` specifically, not just any error.
    #[test]
    fn test_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    /// Fitting on zero samples is rejected with `InsufficientSamples`.
    #[test]
    fn test_empty_input_fit() {
        let x = Array2::<f64>::zeros((0, 1));
        let y = Array1::<f64>::zeros(0);

        let model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(matches!(
            result,
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    /// Predicting with the wrong number of feature columns is rejected with
    /// `ShapeMismatch`.
    #[test]
    fn test_shape_mismatch_predict() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let result = fitted.predict(&x_bad);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    /// The `HasCoefficients` accessors expose the fitted parameters.
    #[test]
    fn test_has_coefficients() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![2.0, 4.0, 6.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 1);
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    /// The pipeline adapters fit and predict the same values as the direct
    /// API.
    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = LinearRegression::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();

        assert_eq!(preds.len(), 4);
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-10);
        }
    }

    /// The estimator is usable with `f32` scalars.
    #[test]
    fn test_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0f32, 4.0, 6.0, 8.0]);

        let model = LinearRegression::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 4);
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-4);
        }
    }
}