// ferrolearn_linear/linear_svc.rs
1//! Linear Support Vector Classifier.
2//!
3//! This module provides [`LinearSVC`], an optimized linear SVM that operates
4//! directly in the primal space without the overhead of a kernel function.
5//! It uses coordinate descent on the L2-regularized hinge or squared-hinge
6//! loss.
7//!
8//! Unlike [`SVC`](crate::svm::SVC) with a [`LinearKernel`](crate::svm::LinearKernel),
9//! `LinearSVC` avoids computing and caching the full kernel matrix, making it
10//! significantly faster for high-dimensional data.
11//!
12//! # Examples
13//!
14//! ```
15//! use ferrolearn_linear::linear_svc::LinearSVC;
16//! use ferrolearn_core::{Fit, Predict};
17//! use ndarray::{array, Array2};
18//!
19//! let x = Array2::from_shape_vec((6, 2), vec![
20//!     1.0, 1.0, 1.0, 2.0, 2.0, 1.0,
21//!     5.0, 5.0, 5.0, 6.0, 6.0, 5.0,
22//! ]).unwrap();
23//! let y = array![0usize, 0, 0, 1, 1, 1];
24//!
25//! let model = LinearSVC::<f64>::new();
26//! let fitted = model.fit(&x, &y).unwrap();
27//! let preds = fitted.predict(&x).unwrap();
28//! assert_eq!(preds.len(), 6);
29//! ```
30
31use ferrolearn_core::error::FerroError;
32use ferrolearn_core::introspection::{HasClasses, HasCoefficients};
33use ferrolearn_core::traits::{Fit, Predict};
34use ndarray::{Array1, Array2, ScalarOperand};
35use num_traits::Float;
36
/// Loss function for [`LinearSVC`].
///
/// `SquaredHinge` is the default (see [`LinearSVC::new`]); it is smooth,
/// which suits the Newton-style coordinate updates used by the solver.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinearSVCLoss {
    /// Standard hinge loss: `max(0, 1 - y * f(x))`.
    Hinge,
    /// Squared hinge loss: `max(0, 1 - y * f(x))^2`.
    SquaredHinge,
}
45
/// Linear Support Vector Classifier (primal formulation).
///
/// Solves the L2-regularized hinge or squared-hinge loss via coordinate
/// descent in the primal. Supports binary and multiclass (one-vs-rest)
/// classification.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct LinearSVC<F> {
    /// Inverse regularization strength. Larger values weight the data-fit
    /// (loss) term more heavily relative to the `||w||^2` regularizer, i.e.
    /// they penalize misclassification more strongly. Must be positive.
    pub c: F,
    /// Maximum number of coordinate descent iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the change in weight vector.
    pub tol: F,
    /// Loss function to use.
    pub loss: LinearSVCLoss,
}
67
68impl<F: Float> LinearSVC<F> {
69    /// Create a new `LinearSVC` with default settings.
70    ///
71    /// Defaults: `C = 1.0`, `max_iter = 1000`, `tol = 1e-4`,
72    /// `loss = SquaredHinge`.
73    #[must_use]
74    pub fn new() -> Self {
75        Self {
76            c: F::one(),
77            max_iter: 1000,
78            tol: F::from(1e-4).unwrap(),
79            loss: LinearSVCLoss::SquaredHinge,
80        }
81    }
82
83    /// Set the regularization parameter C.
84    #[must_use]
85    pub fn with_c(mut self, c: F) -> Self {
86        self.c = c;
87        self
88    }
89
90    /// Set the maximum number of iterations.
91    #[must_use]
92    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
93        self.max_iter = max_iter;
94        self
95    }
96
97    /// Set the convergence tolerance.
98    #[must_use]
99    pub fn with_tol(mut self, tol: F) -> Self {
100        self.tol = tol;
101        self
102    }
103
104    /// Set the loss function.
105    #[must_use]
106    pub fn with_loss(mut self, loss: LinearSVCLoss) -> Self {
107        self.loss = loss;
108        self
109    }
110}
111
112impl<F: Float> Default for LinearSVC<F> {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
/// Fitted Linear Support Vector Classifier.
///
/// Stores the learned weight vectors, intercepts, and class labels.
/// For binary classification a single weight vector is stored; for
/// multiclass, one per class (one-vs-rest), in the same order as `classes`.
#[derive(Debug, Clone)]
pub struct FittedLinearSVC<F> {
    /// Weight vectors: one per binary sub-problem.
    /// Binary: `[w]`, Multiclass: `[w_0, w_1, ..., w_{k-1}]`.
    weight_vectors: Vec<Array1<F>>,
    /// Intercept for each sub-problem (parallel to `weight_vectors`).
    intercepts: Vec<F>,
    /// Sorted unique class labels observed during fitting.
    classes: Vec<usize>,
    /// Whether this is a binary problem (exactly 2 classes).
    is_binary: bool,
    /// Number of features the model was fitted on; checked at predict time.
    n_features: usize,
}
137
138impl<F: Float> FittedLinearSVC<F> {
139    /// Returns the weight vectors (one per binary sub-problem).
140    #[must_use]
141    pub fn weight_vectors(&self) -> &[Array1<F>] {
142        &self.weight_vectors
143    }
144
145    /// Returns the intercepts (one per binary sub-problem).
146    #[must_use]
147    pub fn intercepts(&self) -> &[F] {
148        &self.intercepts
149    }
150}
151
152impl<F: Float + ScalarOperand + Send + Sync + 'static> FittedLinearSVC<F> {
153    /// Raw signed distance from the decision boundary. Mirrors sklearn
154    /// `LinearSVC.decision_function`.
155    ///
156    /// Binary: shape `(n_samples, 1)` containing `X @ w + b`.
157    /// Multiclass: shape `(n_samples, n_classes)` of one-vs-rest scores.
158    /// argmax of each row agrees with [`Predict`].
159    ///
160    /// # Errors
161    ///
162    /// Returns [`FerroError::ShapeMismatch`] if the number of features
163    /// does not match the fitted model.
164    pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
165        let n_features = x.ncols();
166        if n_features != self.n_features {
167            return Err(FerroError::ShapeMismatch {
168                expected: vec![self.n_features],
169                actual: vec![n_features],
170                context: "number of features must match fitted model".into(),
171            });
172        }
173        let n_samples = x.nrows();
174        if self.is_binary {
175            let scores = x.dot(&self.weight_vectors[0]) + self.intercepts[0];
176            let mut out = Array2::<F>::zeros((n_samples, 1));
177            for i in 0..n_samples {
178                out[[i, 0]] = scores[i];
179            }
180            Ok(out)
181        } else {
182            let n_classes = self.classes.len();
183            let mut out = Array2::<F>::zeros((n_samples, n_classes));
184            for c in 0..n_classes {
185                for i in 0..n_samples {
186                    out[[i, c]] = x.row(i).dot(&self.weight_vectors[c]) + self.intercepts[c];
187                }
188            }
189            Ok(out)
190        }
191    }
192}
193
194/// Solve a single binary L2-SVM via coordinate descent on the primal.
195///
196/// Minimises `0.5 * ||w||^2 + C * sum_i loss(y_i, w^T x_i + b) / n` where
197/// `y_i ∈ {-1, +1}`. For squared-hinge loss, performs coordinate-wise
198/// Newton updates `w[j] -= f'(w[j]) / f''(w[j])`, which dramatically
199/// outperforms the previous fixed-step (LR=0.01) approach — the prior code
200/// was undertrained by ~30× on 100-D inputs because the Hessian diagonal at
201/// unit-variance features is `1 + 2C` (not 100).
202///
203/// For hinge loss (non-differentiable at the kink) we use a clipped Newton
204/// step with the squared-hinge Hessian as a smooth majorant.
205///
206/// We maintain `decision = X w + b` incrementally rather than recomputing
207/// it on every coordinate update; this is what makes the loop O(n_features
208/// × n_samples) per outer iteration instead of O(n_features^2 × n_samples).
209fn solve_binary_primal<F: Float + 'static>(
210    x: &Array2<F>,
211    y_signed: &Array1<F>,
212    c: F,
213    max_iter: usize,
214    tol: F,
215    loss: LinearSVCLoss,
216) -> (Array1<F>, F) {
217    let (n_samples, n_features) = x.dim();
218    let mut w = Array1::<F>::zeros(n_features);
219    let mut b = F::zero();
220
221    let n_f = F::from(n_samples).unwrap();
222    let two = F::from(2.0).unwrap();
223
224    // decision[i] = X[i, :] @ w + b — maintained incrementally.
225    let mut decision = Array1::<F>::zeros(n_samples);
226
227    for _iter in 0..max_iter {
228        let mut max_change = F::zero();
229
230        // Coordinate-Newton update for each w[j].
231        for j in 0..n_features {
232            // Gradient and Hessian-diagonal contributions.
233            let mut grad = w[j]; // regularizer gradient
234            let mut hess = F::one(); // regularizer hessian diagonal
235
236            for i in 0..n_samples {
237                let margin = y_signed[i] * decision[i];
238                if margin < F::one() {
239                    let xij = x[[i, j]];
240                    match loss {
241                        LinearSVCLoss::Hinge => {
242                            // Use squared-hinge Hessian as smooth majorant; the
243                            // hinge gradient is the subgradient -y_i x_{i,j}.
244                            grad = grad - c / n_f * y_signed[i] * xij;
245                            hess = hess + c / n_f * xij * xij;
246                        }
247                        LinearSVCLoss::SquaredHinge => {
248                            grad = grad - two * c / n_f
249                                * (F::one() - margin) * y_signed[i] * xij;
250                            hess = hess + two * c / n_f * xij * xij;
251                        }
252                    }
253                }
254            }
255
256            // Newton step: dw = -grad / hess. hess >= 1 since regularizer
257            // contributes 1, so it can never be zero.
258            let dw = -(grad / hess);
259            let new_w = w[j] + dw;
260            let change = dw.abs();
261            if change > max_change {
262                max_change = change;
263            }
264
265            // Apply update and refresh decision values: decision += dw * X[:, j].
266            w[j] = new_w;
267            for i in 0..n_samples {
268                decision[i] = decision[i] + dw * x[[i, j]];
269            }
270        }
271
272        // Coordinate-Newton update for the intercept (not regularized).
273        {
274            let mut grad_b = F::zero();
275            let mut hess_b = F::from(1e-12).unwrap(); // tiny ridge for stability
276            for i in 0..n_samples {
277                let margin = y_signed[i] * decision[i];
278                if margin < F::one() {
279                    match loss {
280                        LinearSVCLoss::Hinge => {
281                            grad_b = grad_b - c / n_f * y_signed[i];
282                            hess_b = hess_b + c / n_f;
283                        }
284                        LinearSVCLoss::SquaredHinge => {
285                            grad_b = grad_b - two * c / n_f
286                                * (F::one() - margin) * y_signed[i];
287                            hess_b = hess_b + two * c / n_f;
288                        }
289                    }
290                }
291            }
292            let db = -(grad_b / hess_b);
293            let change = db.abs();
294            if change > max_change {
295                max_change = change;
296            }
297            b = b + db;
298            for i in 0..n_samples {
299                decision[i] = decision[i] + db;
300            }
301        }
302
303        if max_change < tol {
304            break;
305        }
306    }
307
308    (w, b)
309}
310
311impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
312    for LinearSVC<F>
313{
314    type Fitted = FittedLinearSVC<F>;
315    type Error = FerroError;
316
317    /// Fit the linear SVC model using coordinate descent.
318    ///
319    /// # Errors
320    ///
321    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
322    /// - [`FerroError::InvalidParameter`] — `C` not positive.
323    /// - [`FerroError::InsufficientSamples`] — fewer than 2 distinct classes.
324    fn fit(
325        &self,
326        x: &Array2<F>,
327        y: &Array1<usize>,
328    ) -> Result<FittedLinearSVC<F>, FerroError> {
329        let (n_samples, n_features) = x.dim();
330
331        if n_samples != y.len() {
332            return Err(FerroError::ShapeMismatch {
333                expected: vec![n_samples],
334                actual: vec![y.len()],
335                context: "y length must match number of samples in X".into(),
336            });
337        }
338
339        if self.c <= F::zero() {
340            return Err(FerroError::InvalidParameter {
341                name: "C".into(),
342                reason: "must be positive".into(),
343            });
344        }
345
346        let mut classes: Vec<usize> = y.to_vec();
347        classes.sort_unstable();
348        classes.dedup();
349
350        if classes.len() < 2 {
351            return Err(FerroError::InsufficientSamples {
352                required: 2,
353                actual: classes.len(),
354                context: "LinearSVC requires at least 2 distinct classes".into(),
355            });
356        }
357
358        if classes.len() == 2 {
359            // Binary classification.
360            let y_signed: Array1<F> = y.mapv(|label| {
361                if label == classes[1] {
362                    F::one()
363                } else {
364                    -F::one()
365                }
366            });
367
368            let (w, b) = solve_binary_primal(x, &y_signed, self.c, self.max_iter, self.tol, self.loss);
369
370            Ok(FittedLinearSVC {
371                weight_vectors: vec![w],
372                intercepts: vec![b],
373                classes,
374                is_binary: true,
375                n_features,
376            })
377        } else {
378            // Multiclass: one-vs-rest.
379            let mut weight_vectors = Vec::with_capacity(classes.len());
380            let mut intercepts = Vec::with_capacity(classes.len());
381
382            for &cls in &classes {
383                let y_signed: Array1<F> = y.mapv(|label| {
384                    if label == cls {
385                        F::one()
386                    } else {
387                        -F::one()
388                    }
389                });
390
391                let (w, b) =
392                    solve_binary_primal(x, &y_signed, self.c, self.max_iter, self.tol, self.loss);
393                weight_vectors.push(w);
394                intercepts.push(b);
395            }
396
397            Ok(FittedLinearSVC {
398                weight_vectors,
399                intercepts,
400                classes,
401                is_binary: false,
402                n_features,
403            })
404        }
405    }
406}
407
408impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
409    for FittedLinearSVC<F>
410{
411    type Output = Array1<usize>;
412    type Error = FerroError;
413
414    /// Predict class labels for the given feature matrix.
415    ///
416    /// Binary: `sign(X @ w + b)` mapped to class labels.
417    /// Multiclass: argmax of decision values across one-vs-rest classifiers.
418    ///
419    /// # Errors
420    ///
421    /// Returns [`FerroError::ShapeMismatch`] if the number of features
422    /// does not match the fitted model.
423    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
424        let n_features = x.ncols();
425        if n_features != self.n_features {
426            return Err(FerroError::ShapeMismatch {
427                expected: vec![self.n_features],
428                actual: vec![n_features],
429                context: "number of features must match fitted model".into(),
430            });
431        }
432
433        let n_samples = x.nrows();
434        let mut predictions = Array1::<usize>::zeros(n_samples);
435
436        if self.is_binary {
437            let scores = x.dot(&self.weight_vectors[0]) + self.intercepts[0];
438            for i in 0..n_samples {
439                predictions[i] = if scores[i] >= F::zero() {
440                    self.classes[1]
441                } else {
442                    self.classes[0]
443                };
444            }
445        } else {
446            // Multiclass: pick class with highest decision value.
447            for i in 0..n_samples {
448                let mut best_class = 0;
449                let mut best_score = F::neg_infinity();
450                for (c, w) in self.weight_vectors.iter().enumerate() {
451                    let score = x.row(i).dot(w) + self.intercepts[c];
452                    if score > best_score {
453                        best_score = score;
454                        best_class = c;
455                    }
456                }
457                predictions[i] = self.classes[best_class];
458            }
459        }
460
461        Ok(predictions)
462    }
463}
464
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedLinearSVC<F>
{
    /// Returns the coefficient vector of the first (or only) binary sub-problem.
    ///
    /// For multiclass models this is only the one-vs-rest classifier for the
    /// smallest class label (classes are stored sorted); use
    /// [`FittedLinearSVC::weight_vectors`] to access all of them.
    fn coefficients(&self) -> &Array1<F> {
        &self.weight_vectors[0]
    }

    /// Returns the intercept of the first (or only) binary sub-problem.
    fn intercept(&self) -> F {
        self.intercepts[0]
    }
}
478
479impl<F: Float + Send + Sync + ScalarOperand + 'static> HasClasses for FittedLinearSVC<F> {
480    fn classes(&self) -> &[usize] {
481        &self.classes
482    }
483
484    fn n_classes(&self) -> usize {
485        self.classes.len()
486    }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // `new()` must produce the documented defaults.
    #[test]
    fn test_default_constructor() {
        let m = LinearSVC::<f64>::new();
        assert_eq!(m.max_iter, 1000);
        assert!(m.c == 1.0);
        assert_eq!(m.loss, LinearSVCLoss::SquaredHinge);
    }

    // Each `with_*` builder overrides exactly its own field.
    #[test]
    fn test_builder_setters() {
        let m = LinearSVC::<f64>::new()
            .with_c(10.0)
            .with_max_iter(500)
            .with_tol(1e-6)
            .with_loss(LinearSVCLoss::Hinge);
        assert!(m.c == 10.0);
        assert_eq!(m.max_iter, 500);
        assert_eq!(m.loss, LinearSVCLoss::Hinge);
    }

    // Two well-separated clusters: the binary model should recover most
    // training labels with the default squared-hinge loss.
    #[test]
    fn test_binary_classification() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0,
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = LinearSVC::<f64>::new().with_c(1.0).with_max_iter(5000);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected at least 6 correct, got {correct}");
    }

    // Same separable setup, but exercising the (subgradient) hinge loss path.
    #[test]
    fn test_binary_hinge_loss() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 8.0, 8.0, 8.0, 9.0, 9.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LinearSVC::<f64>::new()
            .with_loss(LinearSVCLoss::Hinge)
            .with_max_iter(5000);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 4, "expected at least 4 correct, got {correct}");
    }

    // Three well-separated clusters: exercises the one-vs-rest path and the
    // HasClasses introspection.
    #[test]
    fn test_multiclass_classification() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5,
                10.0, 0.0, 10.5, 0.0, 10.0, 0.5,
                0.0, 10.0, 0.5, 10.0, 0.0, 10.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = LinearSVC::<f64>::new().with_c(10.0).with_max_iter(5000);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes(), &[0, 1, 2]);

        let preds = fitted.predict(&x).unwrap();
        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 7, "expected at least 7 correct, got {correct}");
    }

    // fit() must reject a label vector whose length differs from n_samples.
    #[test]
    fn test_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1]; // Wrong length

        let model = LinearSVC::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // fit() must reject non-positive C (zero and negative).
    #[test]
    fn test_invalid_c() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = LinearSVC::<f64>::new().with_c(0.0);
        assert!(model.fit(&x, &y).is_err());

        let model_neg = LinearSVC::<f64>::new().with_c(-1.0);
        assert!(model_neg.fit(&x, &y).is_err());
    }

    // fit() must reject a training set containing only one class.
    #[test]
    fn test_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = LinearSVC::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // coefficients() exposes a weight vector with one entry per feature.
    #[test]
    fn test_has_coefficients() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 8.0, 8.0, 8.0, 9.0, 9.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LinearSVC::<f64>::new().with_max_iter(5000);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    // predict() must reject input whose feature count differs from training.
    #[test]
    fn test_predict_feature_mismatch() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 8.0, 8.0, 8.0, 9.0, 9.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let fitted = LinearSVC::<f64>::new().with_max_iter(5000).fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }
}