1use ferrolearn_core::error::FerroError;
35use ferrolearn_core::introspection::{HasClasses, HasCoefficients};
36use ferrolearn_core::traits::{Fit, Predict};
37use ndarray::{Array1, Array2, Axis, ScalarOperand};
38use num_traits::{Float, FromPrimitive};
39
40use crate::linalg;
41
/// Hyperparameters for a ridge-regularized linear classifier.
///
/// This type only carries configuration; fitting it produces a
/// `FittedRidgeClassifier` that holds the learned weights.
#[derive(Clone, Debug)]
pub struct RidgeClassifier<F> {
    /// Regularization strength forwarded to the ridge solver.
    /// Fitting rejects negative values.
    pub alpha: F,
    /// When `true`, fitting centers the data and learns an intercept;
    /// when `false`, the intercept is fixed at zero.
    pub fit_intercept: bool,
}
57
58impl<F: Float> RidgeClassifier<F> {
59 #[must_use]
63 pub fn new() -> Self {
64 Self {
65 alpha: F::one(),
66 fit_intercept: true,
67 }
68 }
69
70 #[must_use]
72 pub fn with_alpha(mut self, alpha: F) -> Self {
73 self.alpha = alpha;
74 self
75 }
76
77 #[must_use]
79 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
80 self.fit_intercept = fit_intercept;
81 self
82 }
83}
84
85impl<F: Float> Default for RidgeClassifier<F> {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91#[derive(Debug, Clone)]
95pub struct FittedRidgeClassifier<F> {
96 coef_matrix: Array2<F>,
99 intercept_vec: Array1<F>,
101 coefficients: Array1<F>,
103 intercept: F,
105 classes: Vec<usize>,
107 is_binary: bool,
109 n_features: usize,
111}
112
113impl<F: Float> FittedRidgeClassifier<F> {
114 #[must_use]
116 pub fn coef_matrix(&self) -> &Array2<F> {
117 &self.coef_matrix
118 }
119
120 #[must_use]
122 pub fn intercept_vec(&self) -> &Array1<F> {
123 &self.intercept_vec
124 }
125}
126
127impl<F: Float + ndarray::ScalarOperand + Send + Sync + 'static> FittedRidgeClassifier<F> {
128 pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
139 let n_features = x.ncols();
140 if n_features != self.n_features {
141 return Err(FerroError::ShapeMismatch {
142 expected: vec![self.n_features],
143 actual: vec![n_features],
144 context: "number of features must match fitted model".into(),
145 });
146 }
147 Ok(x.dot(&self.coef_matrix) + &self.intercept_vec)
148 }
149}
150
151impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static>
152 Fit<Array2<F>, Array1<usize>> for RidgeClassifier<F>
153{
154 type Fitted = FittedRidgeClassifier<F>;
155 type Error = FerroError;
156
157 fn fit(
166 &self,
167 x: &Array2<F>,
168 y: &Array1<usize>,
169 ) -> Result<FittedRidgeClassifier<F>, FerroError> {
170 let (n_samples, n_features) = x.dim();
171
172 if n_samples != y.len() {
173 return Err(FerroError::ShapeMismatch {
174 expected: vec![n_samples],
175 actual: vec![y.len()],
176 context: "y length must match number of samples in X".into(),
177 });
178 }
179
180 if self.alpha < F::zero() {
181 return Err(FerroError::InvalidParameter {
182 name: "alpha".into(),
183 reason: "must be non-negative".into(),
184 });
185 }
186
187 let mut classes: Vec<usize> = y.to_vec();
188 classes.sort_unstable();
189 classes.dedup();
190
191 if classes.len() < 2 {
192 return Err(FerroError::InsufficientSamples {
193 required: 2,
194 actual: classes.len(),
195 context: "RidgeClassifier requires at least 2 distinct classes".into(),
196 });
197 }
198
199 if n_samples == 0 {
200 return Err(FerroError::InsufficientSamples {
201 required: 1,
202 actual: 0,
203 context: "RidgeClassifier requires at least one sample".into(),
204 });
205 }
206
207 let is_binary = classes.len() == 2;
208
209 let n_targets = if is_binary { 1 } else { classes.len() };
211 let mut y_indicator = Array2::<F>::zeros((n_samples, n_targets));
212
213 if is_binary {
214 for i in 0..n_samples {
216 y_indicator[[i, 0]] = if y[i] == classes[1] {
217 F::one()
218 } else {
219 -F::one()
220 };
221 }
222 } else {
223 for i in 0..n_samples {
225 let ci = classes.iter().position(|&c| c == y[i]).unwrap();
226 y_indicator[[i, ci]] = F::one();
227 }
228 }
229
230 let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
232 let x_mean = x
233 .mean_axis(Axis(0))
234 .ok_or_else(|| FerroError::NumericalInstability {
235 message: "failed to compute column means".into(),
236 })?;
237 let y_mean = y_indicator
238 .mean_axis(Axis(0))
239 .ok_or_else(|| FerroError::NumericalInstability {
240 message: "failed to compute target means".into(),
241 })?;
242 let x_c = x - &x_mean;
243 let y_c = &y_indicator - &y_mean;
244 (x_c, y_c, Some(x_mean), Some(y_mean))
245 } else {
246 (x.clone(), y_indicator.clone(), None, None)
247 };
248
249 let mut coef_matrix = Array2::<F>::zeros((n_features, n_targets));
251 for t in 0..n_targets {
252 let y_col = y_work.column(t).to_owned();
253 let w = linalg::solve_ridge(&x_work, &y_col, self.alpha)?;
254 for j in 0..n_features {
255 coef_matrix[[j, t]] = w[j];
256 }
257 }
258
259 let intercept_vec = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
261 let xm_dot = xm.dot(&coef_matrix);
262 ym - &xm_dot
263 } else {
264 Array1::<F>::zeros(n_targets)
265 };
266
267 let coefficients = coef_matrix.column(0).to_owned();
268 let intercept = intercept_vec[0];
269
270 Ok(FittedRidgeClassifier {
271 coef_matrix,
272 intercept_vec,
273 coefficients,
274 intercept,
275 classes,
276 is_binary,
277 n_features,
278 })
279 }
280}
281
282impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
283 for FittedRidgeClassifier<F>
284{
285 type Output = Array1<usize>;
286 type Error = FerroError;
287
288 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
297 let n_features = x.ncols();
298 if n_features != self.n_features {
299 return Err(FerroError::ShapeMismatch {
300 expected: vec![self.n_features],
301 actual: vec![n_features],
302 context: "number of features must match fitted model".into(),
303 });
304 }
305
306 let n_samples = x.nrows();
307 let mut predictions = Array1::<usize>::zeros(n_samples);
308
309 let scores = x.dot(&self.coef_matrix) + &self.intercept_vec;
311
312 if self.is_binary {
313 for i in 0..n_samples {
314 predictions[i] = if scores[[i, 0]] >= F::zero() {
315 self.classes[1]
316 } else {
317 self.classes[0]
318 };
319 }
320 } else {
321 for i in 0..n_samples {
322 let mut best_class = 0;
323 let mut best_score = scores[[i, 0]];
324 for c in 1..self.classes.len() {
325 if scores[[i, c]] > best_score {
326 best_score = scores[[i, c]];
327 best_class = c;
328 }
329 }
330 predictions[i] = self.classes[best_class];
331 }
332 }
333
334 Ok(predictions)
335 }
336}
337
338impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
339 for FittedRidgeClassifier<F>
340{
341 fn coefficients(&self) -> &Array1<F> {
342 &self.coefficients
343 }
344
345 fn intercept(&self) -> F {
346 self.intercept
347 }
348}
349
350impl<F: Float + Send + Sync + ScalarOperand + 'static> HasClasses for FittedRidgeClassifier<F> {
351 fn classes(&self) -> &[usize] {
352 &self.classes
353 }
354
355 fn n_classes(&self) -> usize {
356 self.classes.len()
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Two well-separated clusters of three points each (6 samples, 2 features).
    fn two_cluster_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 8.0, 8.0, 8.0, 9.0, 9.0, 8.0],
        )
        .unwrap();
        (x, array![0, 0, 0, 1, 1, 1])
    }

    /// Three well-separated clusters of three points each (9 samples, 2 features).
    fn three_cluster_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, //
                10.0, 0.0, 10.5, 0.0, 10.0, 0.5, //
                0.0, 10.0, 0.5, 10.0, 0.0, 10.5,
            ],
        )
        .unwrap();
        (x, array![0, 0, 0, 1, 1, 1, 2, 2, 2])
    }

    /// Counts positions where prediction and ground truth agree.
    fn n_correct(preds: &Array1<usize>, truth: &Array1<usize>) -> usize {
        preds
            .iter()
            .zip(truth.iter())
            .filter(|(p, t)| p == t)
            .count()
    }

    #[test]
    fn test_default_constructor() {
        let m = RidgeClassifier::<f64>::new();
        assert_eq!(m.alpha, 1.0);
        assert!(m.fit_intercept);
    }

    #[test]
    fn test_builder() {
        let m = RidgeClassifier::<f64>::new()
            .with_alpha(0.5)
            .with_fit_intercept(false);
        assert_eq!(m.alpha, 0.5);
        assert!(!m.fit_intercept);
    }

    #[test]
    fn test_binary_classification() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, //
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y: Array1<usize> = array![0, 0, 0, 0, 1, 1, 1, 1];

        let fitted = RidgeClassifier::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct = n_correct(&preds, &y);
        assert!(correct >= 6, "expected at least 6 correct, got {correct}");
    }

    #[test]
    fn test_multiclass_classification() {
        let (x, y) = three_cluster_data();

        let model = RidgeClassifier::<f64>::new().with_alpha(0.1);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes(), &[0, 1, 2]);

        let preds = fitted.predict(&x).unwrap();
        let correct = n_correct(&preds, &y);
        assert!(correct >= 7, "expected at least 7 correct, got {correct}");
    }

    #[test]
    fn test_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y: Array1<usize> = array![0, 1];

        let model = RidgeClassifier::<f64>::new();
        assert!(matches!(
            model.fit(&x, &y),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_negative_alpha() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y: Array1<usize> = array![0, 0, 1, 1];

        let model = RidgeClassifier::<f64>::new().with_alpha(-1.0);
        assert!(matches!(
            model.fit(&x, &y),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y: Array1<usize> = array![0, 0, 0];

        let model = RidgeClassifier::<f64>::new();
        assert!(matches!(
            model.fit(&x, &y),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_empty_input() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = RidgeClassifier::<f64>::new();
        assert!(matches!(
            model.fit(&x, &y),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_has_coefficients() {
        let (x, y) = two_cluster_data();

        let fitted = RidgeClassifier::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_has_classes() {
        let (x, y) = two_cluster_data();

        let fitted = RidgeClassifier::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let (x, y) = two_cluster_data();
        let fitted = RidgeClassifier::<f64>::new().fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(matches!(
            fitted.predict(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_decision_function_feature_mismatch() {
        let (x, y) = two_cluster_data();
        let fitted = RidgeClassifier::<f64>::new().fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(matches!(
            fitted.decision_function(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_decision_function_shapes() {
        let (x2, y2) = two_cluster_data();
        let binary = RidgeClassifier::<f64>::new().fit(&x2, &y2).unwrap();
        assert_eq!(binary.decision_function(&x2).unwrap().dim(), (6, 1));

        let (x3, y3) = three_cluster_data();
        let multi = RidgeClassifier::<f64>::new().fit(&x3, &y3).unwrap();
        assert_eq!(multi.decision_function(&x3).unwrap().dim(), (9, 3));
    }

    #[test]
    fn test_predict_agrees_with_decision_function() {
        let (x, y) = two_cluster_data();
        let fitted = RidgeClassifier::<f64>::new().fit(&x, &y).unwrap();

        let scores = fitted.decision_function(&x).unwrap();
        let preds = fitted.predict(&x).unwrap();
        for (i, &pred) in preds.iter().enumerate() {
            let expected = if scores[[i, 0]] >= 0.0 { 1 } else { 0 };
            assert_eq!(pred, expected, "row {i}");
        }
    }

    #[test]
    fn test_alpha_zero() {
        let (x, y) = two_cluster_data();

        let model = RidgeClassifier::<f64>::new().with_alpha(0.0);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }
}