Skip to main content

ferrolearn_linear/
ridge_classifier.rs

//! Ridge Classifier.
//!
//! This module provides [`RidgeClassifier`], which applies Ridge regression
//! to classification tasks by converting class labels into a binary indicator
//! matrix and fitting a multivariate Ridge regression.
//!
//! For binary classification, the indicator matrix has a single column
//! (`{-1, +1}`), and the predicted class is determined by the sign of the
//! single decision value. For multiclass, it has one column per class
//! (one-hot encoding), and the predicted class is the one with the highest
//! decision value (`argmax(X @ coef + intercept)`).
//!
//! This approach is significantly faster than logistic regression for
//! large datasets while often achieving competitive accuracy.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_linear::ridge_classifier::RidgeClassifier;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//!     1.0, 1.0, 1.0, 2.0, 2.0, 1.0,
//!     5.0, 5.0, 5.0, 6.0, 6.0, 5.0,
//! ]).unwrap();
//! let y = array![0usize, 0, 0, 1, 1, 1];
//!
//! let model = RidgeClassifier::<f64>::new();
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 6);
//! ```
33
34use ferrolearn_core::error::FerroError;
35use ferrolearn_core::introspection::{HasClasses, HasCoefficients};
36use ferrolearn_core::traits::{Fit, Predict};
37use ndarray::{Array1, Array2, Axis, ScalarOperand};
38use num_traits::{Float, FromPrimitive};
39
40use crate::linalg;
41
/// Classifier built on Ridge (L2-regularized least-squares) regression.
///
/// Class labels are converted to a binary indicator matrix, and a Ridge
/// regression is fitted against that matrix.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Clone, Debug)]
pub struct RidgeClassifier<F> {
    /// L2 penalty weight; the larger the value, the stronger the regularization.
    pub alpha: F,
    /// When `true`, an intercept (bias) term is fitted alongside the coefficients.
    pub fit_intercept: bool,
}
57
58impl<F: Float> RidgeClassifier<F> {
59    /// Create a new `RidgeClassifier` with default settings.
60    ///
61    /// Defaults: `alpha = 1.0`, `fit_intercept = true`.
62    #[must_use]
63    pub fn new() -> Self {
64        Self {
65            alpha: F::one(),
66            fit_intercept: true,
67        }
68    }
69
70    /// Set the regularization strength.
71    #[must_use]
72    pub fn with_alpha(mut self, alpha: F) -> Self {
73        self.alpha = alpha;
74        self
75    }
76
77    /// Set whether to fit an intercept term.
78    #[must_use]
79    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
80        self.fit_intercept = fit_intercept;
81        self
82    }
83}
84
85impl<F: Float> Default for RidgeClassifier<F> {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
/// Fitted Ridge Classifier.
///
/// Produced by fitting a [`RidgeClassifier`]. Stores the learned coefficient
/// matrix, intercept vector, and class labels, plus a copy of the first
/// target's coefficients and intercept for the `HasCoefficients` view.
#[derive(Debug, Clone)]
pub struct FittedRidgeClassifier<F> {
    /// Coefficient matrix, shape `(n_features, n_targets)`.
    /// For binary, `n_targets = 1`; for multiclass, one column per class.
    coef_matrix: Array2<F>,
    /// Intercept vector, one entry per target (all zeros when the model was
    /// fitted without an intercept).
    intercept_vec: Array1<F>,
    /// For `HasCoefficients`: owned copy of the first column of `coef_matrix`.
    coefficients: Array1<F>,
    /// For `HasCoefficients`: first element of `intercept_vec`.
    intercept: F,
    /// Sorted unique class labels seen during fitting.
    classes: Vec<usize>,
    /// Whether this is a binary problem (exactly two classes).
    is_binary: bool,
    /// Number of features the model was fitted on; inputs to prediction
    /// must have this many columns.
    n_features: usize,
}
112
113impl<F: Float> FittedRidgeClassifier<F> {
114    /// Returns the full coefficient matrix, shape `(n_features, n_targets)`.
115    #[must_use]
116    pub fn coef_matrix(&self) -> &Array2<F> {
117        &self.coef_matrix
118    }
119
120    /// Returns the intercept vector.
121    #[must_use]
122    pub fn intercept_vec(&self) -> &Array1<F> {
123        &self.intercept_vec
124    }
125}
126
127impl<F: Float + ndarray::ScalarOperand + Send + Sync + 'static> FittedRidgeClassifier<F> {
128    /// Raw `X @ coef + intercept` per class. Mirrors sklearn
129    /// `RidgeClassifier.decision_function`.
130    ///
131    /// Returns shape `(n_samples, n_classes)`. argmax of each row agrees
132    /// with [`Predict`].
133    ///
134    /// # Errors
135    ///
136    /// Returns [`FerroError::ShapeMismatch`] if the number of features
137    /// does not match the fitted model.
138    pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
139        let n_features = x.ncols();
140        if n_features != self.n_features {
141            return Err(FerroError::ShapeMismatch {
142                expected: vec![self.n_features],
143                actual: vec![n_features],
144                context: "number of features must match fitted model".into(),
145            });
146        }
147        Ok(x.dot(&self.coef_matrix) + &self.intercept_vec)
148    }
149}
150
151impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static>
152    Fit<Array2<F>, Array1<usize>> for RidgeClassifier<F>
153{
154    type Fitted = FittedRidgeClassifier<F>;
155    type Error = FerroError;
156
157    /// Fit the Ridge Classifier by converting labels to a binary indicator
158    /// matrix and solving multivariate Ridge regression.
159    ///
160    /// # Errors
161    ///
162    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
163    /// - [`FerroError::InvalidParameter`] — negative alpha.
164    /// - [`FerroError::InsufficientSamples`] — fewer than 2 classes.
165    fn fit(
166        &self,
167        x: &Array2<F>,
168        y: &Array1<usize>,
169    ) -> Result<FittedRidgeClassifier<F>, FerroError> {
170        let (n_samples, n_features) = x.dim();
171
172        if n_samples != y.len() {
173            return Err(FerroError::ShapeMismatch {
174                expected: vec![n_samples],
175                actual: vec![y.len()],
176                context: "y length must match number of samples in X".into(),
177            });
178        }
179
180        if self.alpha < F::zero() {
181            return Err(FerroError::InvalidParameter {
182                name: "alpha".into(),
183                reason: "must be non-negative".into(),
184            });
185        }
186
187        let mut classes: Vec<usize> = y.to_vec();
188        classes.sort_unstable();
189        classes.dedup();
190
191        if classes.len() < 2 {
192            return Err(FerroError::InsufficientSamples {
193                required: 2,
194                actual: classes.len(),
195                context: "RidgeClassifier requires at least 2 distinct classes".into(),
196            });
197        }
198
199        if n_samples == 0 {
200            return Err(FerroError::InsufficientSamples {
201                required: 1,
202                actual: 0,
203                context: "RidgeClassifier requires at least one sample".into(),
204            });
205        }
206
207        let is_binary = classes.len() == 2;
208
209        // Build indicator matrix Y.
210        let n_targets = if is_binary { 1 } else { classes.len() };
211        let mut y_indicator = Array2::<F>::zeros((n_samples, n_targets));
212
213        if is_binary {
214            // Binary: encode as {-1, +1}.
215            for i in 0..n_samples {
216                y_indicator[[i, 0]] = if y[i] == classes[1] {
217                    F::one()
218                } else {
219                    -F::one()
220                };
221            }
222        } else {
223            // Multiclass: one-hot.
224            for i in 0..n_samples {
225                let ci = classes.iter().position(|&c| c == y[i]).unwrap();
226                y_indicator[[i, ci]] = F::one();
227            }
228        }
229
230        // Center data if fit_intercept.
231        let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
232            let x_mean = x
233                .mean_axis(Axis(0))
234                .ok_or_else(|| FerroError::NumericalInstability {
235                    message: "failed to compute column means".into(),
236                })?;
237            let y_mean = y_indicator
238                .mean_axis(Axis(0))
239                .ok_or_else(|| FerroError::NumericalInstability {
240                    message: "failed to compute target means".into(),
241                })?;
242            let x_c = x - &x_mean;
243            let y_c = &y_indicator - &y_mean;
244            (x_c, y_c, Some(x_mean), Some(y_mean))
245        } else {
246            (x.clone(), y_indicator.clone(), None, None)
247        };
248
249        // Solve Ridge for each target column.
250        let mut coef_matrix = Array2::<F>::zeros((n_features, n_targets));
251        for t in 0..n_targets {
252            let y_col = y_work.column(t).to_owned();
253            let w = linalg::solve_ridge(&x_work, &y_col, self.alpha)?;
254            for j in 0..n_features {
255                coef_matrix[[j, t]] = w[j];
256            }
257        }
258
259        // Compute intercepts.
260        let intercept_vec = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
261            let xm_dot = xm.dot(&coef_matrix);
262            ym - &xm_dot
263        } else {
264            Array1::<F>::zeros(n_targets)
265        };
266
267        let coefficients = coef_matrix.column(0).to_owned();
268        let intercept = intercept_vec[0];
269
270        Ok(FittedRidgeClassifier {
271            coef_matrix,
272            intercept_vec,
273            coefficients,
274            intercept,
275            classes,
276            is_binary,
277            n_features,
278        })
279    }
280}
281
282impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
283    for FittedRidgeClassifier<F>
284{
285    type Output = Array1<usize>;
286    type Error = FerroError;
287
288    /// Predict class labels for the given feature matrix.
289    ///
290    /// Computes `X @ coef_matrix + intercept_vec` and takes `argmax` per row.
291    ///
292    /// # Errors
293    ///
294    /// Returns [`FerroError::ShapeMismatch`] if the number of features
295    /// does not match the fitted model.
296    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
297        let n_features = x.ncols();
298        if n_features != self.n_features {
299            return Err(FerroError::ShapeMismatch {
300                expected: vec![self.n_features],
301                actual: vec![n_features],
302                context: "number of features must match fitted model".into(),
303            });
304        }
305
306        let n_samples = x.nrows();
307        let mut predictions = Array1::<usize>::zeros(n_samples);
308
309        // Compute decision values: X @ coef_matrix + intercept_vec.
310        let scores = x.dot(&self.coef_matrix) + &self.intercept_vec;
311
312        if self.is_binary {
313            for i in 0..n_samples {
314                predictions[i] = if scores[[i, 0]] >= F::zero() {
315                    self.classes[1]
316                } else {
317                    self.classes[0]
318                };
319            }
320        } else {
321            for i in 0..n_samples {
322                let mut best_class = 0;
323                let mut best_score = scores[[i, 0]];
324                for c in 1..self.classes.len() {
325                    if scores[[i, c]] > best_score {
326                        best_score = scores[[i, c]];
327                        best_class = c;
328                    }
329                }
330                predictions[i] = self.classes[best_class];
331            }
332        }
333
334        Ok(predictions)
335    }
336}
337
338impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
339    for FittedRidgeClassifier<F>
340{
341    fn coefficients(&self) -> &Array1<F> {
342        &self.coefficients
343    }
344
345    fn intercept(&self) -> F {
346        self.intercept
347    }
348}
349
350impl<F: Float + Send + Sync + ScalarOperand + 'static> HasClasses for FittedRidgeClassifier<F> {
351    fn classes(&self) -> &[usize] {
352        &self.classes
353    }
354
355    fn n_classes(&self) -> usize {
356        self.classes.len()
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Two clusters of three 2-D points each, labelled 0 and 1.
    fn two_cluster_data() -> (Array2<f64>, Array1<usize>) {
        let features = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 8.0, 8.0, 8.0, 9.0, 9.0, 8.0],
        )
        .unwrap();
        let labels = array![0, 0, 0, 1, 1, 1];
        (features, labels)
    }

    /// Number of positions where the prediction equals the expected label.
    fn count_matches(predicted: &Array1<usize>, expected: &Array1<usize>) -> usize {
        predicted
            .iter()
            .zip(expected.iter())
            .filter(|(p, a)| p == a)
            .count()
    }

    #[test]
    fn test_default_constructor() {
        let classifier = RidgeClassifier::<f64>::new();
        assert!(classifier.alpha == 1.0);
        assert!(classifier.fit_intercept);
    }

    #[test]
    fn test_builder() {
        let classifier = RidgeClassifier::<f64>::new()
            .with_alpha(0.5)
            .with_fit_intercept(false);
        assert!(classifier.alpha == 0.5);
        assert!(!classifier.fit_intercept);
    }

    #[test]
    fn test_binary_classification() {
        let features = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, //
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let labels: Array1<usize> = array![0, 0, 0, 0, 1, 1, 1, 1];

        let fitted = RidgeClassifier::<f64>::new()
            .fit(&features, &labels)
            .unwrap();
        let preds = fitted.predict(&features).unwrap();

        let correct = count_matches(&preds, &labels);
        assert!(correct >= 6, "expected at least 6 correct, got {correct}");
    }

    #[test]
    fn test_multiclass_classification() {
        let features = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, //
                10.0, 0.0, 10.5, 0.0, 10.0, 0.5, //
                0.0, 10.0, 0.5, 10.0, 0.0, 10.5,
            ],
        )
        .unwrap();
        let labels: Array1<usize> = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let fitted = RidgeClassifier::<f64>::new()
            .with_alpha(0.1)
            .fit(&features, &labels)
            .unwrap();

        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes(), &[0, 1, 2]);

        let preds = fitted.predict(&features).unwrap();
        let correct = count_matches(&preds, &labels);
        assert!(correct >= 7, "expected at least 7 correct, got {correct}");
    }

    #[test]
    fn test_shape_mismatch() {
        let features = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        // Only two labels for three samples.
        let labels = array![0, 1];

        let outcome = RidgeClassifier::<f64>::new().fit(&features, &labels);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_negative_alpha() {
        let features = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let labels = array![0, 0, 1, 1];

        let outcome = RidgeClassifier::<f64>::new()
            .with_alpha(-1.0)
            .fit(&features, &labels);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_single_class_error() {
        let features = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let labels = array![0, 0, 0];

        let outcome = RidgeClassifier::<f64>::new().fit(&features, &labels);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_has_coefficients() {
        let (features, labels) = two_cluster_data();

        let fitted = RidgeClassifier::<f64>::new()
            .fit(&features, &labels)
            .unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_has_classes() {
        let (features, labels) = two_cluster_data();

        let fitted = RidgeClassifier::<f64>::new()
            .fit(&features, &labels)
            .unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let (features, labels) = two_cluster_data();

        let fitted = RidgeClassifier::<f64>::new()
            .fit(&features, &labels)
            .unwrap();

        // One column instead of the two the model was fitted on.
        let wrong_width = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&wrong_width).is_err());
    }

    #[test]
    fn test_alpha_zero() {
        let (features, labels) = two_cluster_data();

        let fitted = RidgeClassifier::<f64>::new()
            .with_alpha(0.0)
            .fit(&features, &labels)
            .unwrap();
        let preds = fitted.predict(&features).unwrap();
        assert_eq!(preds.len(), 6);
    }
}