1use ferrolearn_core::error::FerroError;
29use ferrolearn_core::introspection::HasCoefficients;
30use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
31use ferrolearn_core::traits::{Fit, Predict};
32use ndarray::{Array1, Array2, Axis, ScalarOperand};
33use num_traits::{Float, FromPrimitive};
34
35use crate::linalg;
36
/// Ridge regression: least squares with an L2 penalty on the coefficients.
///
/// Build a configuration with [`Ridge::new`] and the `with_*` builder methods,
/// then call [`Fit::fit`] to obtain a [`FittedRidge`].
#[derive(Debug, Clone)]
pub struct Ridge<F> {
    /// Regularization strength forwarded to the ridge solver.
    ///
    /// `fit` rejects negative values. Larger values shrink the learned
    /// coefficients more strongly.
    pub alpha: F,
    /// Whether to estimate an intercept term.
    ///
    /// When `true`, `X` and `y` are mean-centered before solving and the
    /// intercept is recovered from the means. When `false`, the intercept is
    /// fixed at zero.
    pub fit_intercept: bool,
}
53
54impl<F: Float> Ridge<F> {
55 #[must_use]
59 pub fn new() -> Self {
60 Self {
61 alpha: F::one(),
62 fit_intercept: true,
63 }
64 }
65
66 #[must_use]
68 pub fn with_alpha(mut self, alpha: F) -> Self {
69 self.alpha = alpha;
70 self
71 }
72
73 #[must_use]
75 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
76 self.fit_intercept = fit_intercept;
77 self
78 }
79}
80
81impl<F: Float> Default for Ridge<F> {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87#[derive(Debug, Clone)]
92pub struct FittedRidge<F> {
93 coefficients: Array1<F>,
95 intercept: F,
97}
98
99impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
100 for Ridge<F>
101{
102 type Fitted = FittedRidge<F>;
103 type Error = FerroError;
104
105 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedRidge<F>, FerroError> {
115 let (n_samples, _n_features) = x.dim();
116
117 if n_samples != y.len() {
118 return Err(FerroError::ShapeMismatch {
119 expected: vec![n_samples],
120 actual: vec![y.len()],
121 context: "y length must match number of samples in X".into(),
122 });
123 }
124
125 if self.alpha < F::zero() {
126 return Err(FerroError::InvalidParameter {
127 name: "alpha".into(),
128 reason: "must be non-negative".into(),
129 });
130 }
131
132 if n_samples == 0 {
133 return Err(FerroError::InsufficientSamples {
134 required: 1,
135 actual: 0,
136 context: "Ridge requires at least one sample".into(),
137 });
138 }
139
140 if self.fit_intercept {
141 let x_mean = x
143 .mean_axis(Axis(0))
144 .ok_or_else(|| FerroError::NumericalInstability {
145 message: "failed to compute column means".into(),
146 })?;
147 let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
148 message: "failed to compute target mean".into(),
149 })?;
150
151 let x_centered = x - &x_mean;
152 let y_centered = y - y_mean;
153
154 let w = linalg::solve_ridge(&x_centered, &y_centered, self.alpha)?;
155 let intercept = y_mean - x_mean.dot(&w);
156
157 Ok(FittedRidge {
158 coefficients: w,
159 intercept,
160 })
161 } else {
162 let w = linalg::solve_ridge(x, y, self.alpha)?;
163
164 Ok(FittedRidge {
165 coefficients: w,
166 intercept: F::zero(),
167 })
168 }
169 }
170}
171
172impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedRidge<F> {
173 type Output = Array1<F>;
174 type Error = FerroError;
175
176 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
185 let n_features = x.ncols();
186 if n_features != self.coefficients.len() {
187 return Err(FerroError::ShapeMismatch {
188 expected: vec![self.coefficients.len()],
189 actual: vec![n_features],
190 context: "number of features must match fitted model".into(),
191 });
192 }
193
194 let preds = x.dot(&self.coefficients) + self.intercept;
195 Ok(preds)
196 }
197}
198
199impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedRidge<F> {
200 fn coefficients(&self) -> &Array1<F> {
201 &self.coefficients
202 }
203
204 fn intercept(&self) -> F {
205 self.intercept
206 }
207}
208
209impl<F> PipelineEstimator<F> for Ridge<F>
211where
212 F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
213{
214 fn fit_pipeline(
215 &self,
216 x: &Array2<F>,
217 y: &Array1<F>,
218 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
219 let fitted = self.fit(x, y)?;
220 Ok(Box::new(fitted))
221 }
222}
223
224impl<F> FittedPipelineEstimator<F> for FittedRidge<F>
225where
226 F: Float + ScalarOperand + Send + Sync + 'static,
227{
228 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
229 self.predict(x)
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// With `alpha = 0`, ridge reduces to ordinary least squares and recovers
    /// `y = 2x + 1` exactly.
    #[test]
    fn test_ridge_no_regularization() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = Ridge::<f64>::new().with_alpha(0.0).fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-8);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-8);
    }

    /// A stronger penalty must yield a smaller slope magnitude.
    #[test]
    fn test_ridge_shrinks_coefficients() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let slope_magnitude = |alpha: f64| {
            Ridge::<f64>::new()
                .with_alpha(alpha)
                .fit(&x, &y)
                .unwrap()
                .coefficients()[0]
                .abs()
        };

        assert!(slope_magnitude(100.0) < slope_magnitude(0.01));
    }

    /// Without an intercept, the line through the origin `y = 2x` is recovered
    /// and the intercept stays zero.
    #[test]
    fn test_ridge_no_intercept() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = Ridge::<f64>::new()
            .with_alpha(0.0)
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_ridge_negative_alpha() {
        let x = array![[1.0], [2.0], [3.0]];
        let y = array![1.0, 2.0, 3.0];

        assert!(Ridge::<f64>::new().with_alpha(-1.0).fit(&x, &y).is_err());
    }

    #[test]
    fn test_ridge_shape_mismatch() {
        let x = array![[1.0], [2.0], [3.0]];
        let y = array![1.0, 2.0];

        assert!(Ridge::<f64>::new().fit(&x, &y).is_err());
    }

    #[test]
    fn test_ridge_predict() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]];
        let y = array![1.0, 2.0, 3.0, 6.0];

        let fitted = Ridge::<f64>::new().with_alpha(0.01).fit(&x, &y).unwrap();

        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_ridge_pipeline_integration() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![3.0, 5.0, 7.0, 9.0];

        let fitted = Ridge::<f64>::new().fit_pipeline(&x, &y).unwrap();

        assert_eq!(fitted.predict_pipeline(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_ridge_has_coefficients() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];

        let fitted = Ridge::<f64>::new().fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }
}