1use ferrolearn_core::error::FerroError;
32use ferrolearn_core::introspection::{HasClasses, HasCoefficients};
33use ferrolearn_core::traits::{Fit, Predict};
34use ndarray::{Array1, Array2, ScalarOperand};
35use num_traits::Float;
36
/// Loss function minimized by [`LinearSVC`]'s primal solver.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinearSVCLoss {
    /// Standard hinge loss, `max(0, 1 - y * f(x))`.
    Hinge,
    /// Squared hinge loss, `max(0, 1 - y * f(x))^2` — the default.
    SquaredHinge,
}
45
/// Linear support-vector classifier configured via builder-style setters.
#[derive(Debug, Clone)]
pub struct LinearSVC<F> {
    /// Regularization strength `C`; must be positive. Larger values weight
    /// the loss term more heavily relative to the L2 penalty on the weights.
    pub c: F,
    /// Maximum number of optimizer sweeps over the coordinates.
    pub max_iter: usize,
    /// Convergence tolerance on the largest per-sweep parameter change.
    pub tol: F,
    /// Hinge or squared-hinge loss (see [`LinearSVCLoss`]).
    pub loss: LinearSVCLoss,
}
67
68impl<F: Float> LinearSVC<F> {
69 #[must_use]
74 pub fn new() -> Self {
75 Self {
76 c: F::one(),
77 max_iter: 1000,
78 tol: F::from(1e-4).unwrap(),
79 loss: LinearSVCLoss::SquaredHinge,
80 }
81 }
82
83 #[must_use]
85 pub fn with_c(mut self, c: F) -> Self {
86 self.c = c;
87 self
88 }
89
90 #[must_use]
92 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
93 self.max_iter = max_iter;
94 self
95 }
96
97 #[must_use]
99 pub fn with_tol(mut self, tol: F) -> Self {
100 self.tol = tol;
101 self
102 }
103
104 #[must_use]
106 pub fn with_loss(mut self, loss: LinearSVCLoss) -> Self {
107 self.loss = loss;
108 self
109 }
110}
111
impl<F: Float> Default for LinearSVC<F> {
    /// Equivalent to [`LinearSVC::new`].
    fn default() -> Self {
        Self::new()
    }
}
117
/// A trained [`LinearSVC`] model produced by `fit`.
#[derive(Debug, Clone)]
pub struct FittedLinearSVC<F> {
    /// One weight vector per one-vs-rest classifier; a single entry for
    /// binary problems.
    weight_vectors: Vec<Array1<F>>,
    /// Intercept paired with each weight vector.
    intercepts: Vec<F>,
    /// Sorted, deduplicated class labels observed during fitting.
    classes: Vec<usize>,
    /// True when exactly two classes were seen (single-hyperplane path).
    is_binary: bool,
    /// Number of feature columns the model was fitted on.
    n_features: usize,
}
137
138impl<F: Float> FittedLinearSVC<F> {
139 #[must_use]
141 pub fn weight_vectors(&self) -> &[Array1<F>] {
142 &self.weight_vectors
143 }
144
145 #[must_use]
147 pub fn intercepts(&self) -> &[F] {
148 &self.intercepts
149 }
150}
151
152impl<F: Float + ScalarOperand + Send + Sync + 'static> FittedLinearSVC<F> {
153 pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
165 let n_features = x.ncols();
166 if n_features != self.n_features {
167 return Err(FerroError::ShapeMismatch {
168 expected: vec![self.n_features],
169 actual: vec![n_features],
170 context: "number of features must match fitted model".into(),
171 });
172 }
173 let n_samples = x.nrows();
174 if self.is_binary {
175 let scores = x.dot(&self.weight_vectors[0]) + self.intercepts[0];
176 let mut out = Array2::<F>::zeros((n_samples, 1));
177 for i in 0..n_samples {
178 out[[i, 0]] = scores[i];
179 }
180 Ok(out)
181 } else {
182 let n_classes = self.classes.len();
183 let mut out = Array2::<F>::zeros((n_samples, n_classes));
184 for c in 0..n_classes {
185 for i in 0..n_samples {
186 out[[i, c]] = x.row(i).dot(&self.weight_vectors[c]) + self.intercepts[c];
187 }
188 }
189 Ok(out)
190 }
191 }
192}
193
/// Trains a single binary max-margin separator by Newton-style coordinate
/// descent on the primal objective `0.5 * ||w||^2 + C/n * sum(loss_i)`.
///
/// `y_signed` must hold +1/-1 labels, one per row of `x`. Returns the
/// weight vector and intercept `(w, b)`.
fn solve_binary_primal<F: Float + 'static>(
    x: &Array2<F>,
    y_signed: &Array1<F>,
    c: F,
    max_iter: usize,
    tol: F,
    loss: LinearSVCLoss,
) -> (Array1<F>, F) {
    let (n_samples, n_features) = x.dim();
    let mut w = Array1::<F>::zeros(n_features);
    let mut b = F::zero();

    let n_f = F::from(n_samples).unwrap();
    let two = F::from(2.0).unwrap();

    // Cached raw scores w.x_i + b, updated incrementally after every
    // coordinate step so margins never need a full recomputation.
    let mut decision = Array1::<F>::zeros(n_samples);

    for _iter in 0..max_iter {
        let mut max_change = F::zero();

        // One Newton step per weight coordinate; grad/hess start from the
        // L2 regularizer's contribution (gradient w[j], curvature 1).
        for j in 0..n_features {
            let mut grad = w[j]; let mut hess = F::one(); for i in 0..n_samples {
                let margin = y_signed[i] * decision[i];
                // Only margin-violating samples contribute to the loss term.
                if margin < F::one() {
                    let xij = x[[i, j]];
                    match loss {
                        LinearSVCLoss::Hinge => {
                            grad = grad - c / n_f * y_signed[i] * xij;
                            // Hinge has zero second derivative; this term is a
                            // stabilizing surrogate for the Newton step.
                            hess = hess + c / n_f * xij * xij;
                        }
                        LinearSVCLoss::SquaredHinge => {
                            grad = grad - two * c / n_f
                                * (F::one() - margin) * y_signed[i] * xij;
                            hess = hess + two * c / n_f * xij * xij;
                        }
                    }
                }
            }

            let dw = -(grad / hess);
            let new_w = w[j] + dw;
            let change = dw.abs();
            if change > max_change {
                max_change = change;
            }

            w[j] = new_w;
            // Keep the cached scores in sync with the coordinate update.
            for i in 0..n_samples {
                decision[i] = decision[i] + dw * x[[i, j]];
            }
        }

        // Newton step for the intercept (not L2-regularized).
        {
            let mut grad_b = F::zero();
            // Tiny floor keeps the division defined when no sample violates
            // the margin and hinge contributes no curvature.
            let mut hess_b = F::from(1e-12).unwrap(); for i in 0..n_samples {
                let margin = y_signed[i] * decision[i];
                if margin < F::one() {
                    match loss {
                        LinearSVCLoss::Hinge => {
                            grad_b = grad_b - c / n_f * y_signed[i];
                            hess_b = hess_b + c / n_f;
                        }
                        LinearSVCLoss::SquaredHinge => {
                            grad_b = grad_b - two * c / n_f
                                * (F::one() - margin) * y_signed[i];
                            hess_b = hess_b + two * c / n_f;
                        }
                    }
                }
            }
            let db = -(grad_b / hess_b);
            let change = db.abs();
            if change > max_change {
                max_change = change;
            }
            b = b + db;
            for i in 0..n_samples {
                decision[i] = decision[i] + db;
            }
        }

        // Converged once no parameter moved more than `tol` this sweep.
        if max_change < tol {
            break;
        }
    }

    (w, b)
}
310
311impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
312 for LinearSVC<F>
313{
314 type Fitted = FittedLinearSVC<F>;
315 type Error = FerroError;
316
317 fn fit(
325 &self,
326 x: &Array2<F>,
327 y: &Array1<usize>,
328 ) -> Result<FittedLinearSVC<F>, FerroError> {
329 let (n_samples, n_features) = x.dim();
330
331 if n_samples != y.len() {
332 return Err(FerroError::ShapeMismatch {
333 expected: vec![n_samples],
334 actual: vec![y.len()],
335 context: "y length must match number of samples in X".into(),
336 });
337 }
338
339 if self.c <= F::zero() {
340 return Err(FerroError::InvalidParameter {
341 name: "C".into(),
342 reason: "must be positive".into(),
343 });
344 }
345
346 let mut classes: Vec<usize> = y.to_vec();
347 classes.sort_unstable();
348 classes.dedup();
349
350 if classes.len() < 2 {
351 return Err(FerroError::InsufficientSamples {
352 required: 2,
353 actual: classes.len(),
354 context: "LinearSVC requires at least 2 distinct classes".into(),
355 });
356 }
357
358 if classes.len() == 2 {
359 let y_signed: Array1<F> = y.mapv(|label| {
361 if label == classes[1] {
362 F::one()
363 } else {
364 -F::one()
365 }
366 });
367
368 let (w, b) = solve_binary_primal(x, &y_signed, self.c, self.max_iter, self.tol, self.loss);
369
370 Ok(FittedLinearSVC {
371 weight_vectors: vec![w],
372 intercepts: vec![b],
373 classes,
374 is_binary: true,
375 n_features,
376 })
377 } else {
378 let mut weight_vectors = Vec::with_capacity(classes.len());
380 let mut intercepts = Vec::with_capacity(classes.len());
381
382 for &cls in &classes {
383 let y_signed: Array1<F> = y.mapv(|label| {
384 if label == cls {
385 F::one()
386 } else {
387 -F::one()
388 }
389 });
390
391 let (w, b) =
392 solve_binary_primal(x, &y_signed, self.c, self.max_iter, self.tol, self.loss);
393 weight_vectors.push(w);
394 intercepts.push(b);
395 }
396
397 Ok(FittedLinearSVC {
398 weight_vectors,
399 intercepts,
400 classes,
401 is_binary: false,
402 n_features,
403 })
404 }
405 }
406}
407
408impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
409 for FittedLinearSVC<F>
410{
411 type Output = Array1<usize>;
412 type Error = FerroError;
413
414 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
424 let n_features = x.ncols();
425 if n_features != self.n_features {
426 return Err(FerroError::ShapeMismatch {
427 expected: vec![self.n_features],
428 actual: vec![n_features],
429 context: "number of features must match fitted model".into(),
430 });
431 }
432
433 let n_samples = x.nrows();
434 let mut predictions = Array1::<usize>::zeros(n_samples);
435
436 if self.is_binary {
437 let scores = x.dot(&self.weight_vectors[0]) + self.intercepts[0];
438 for i in 0..n_samples {
439 predictions[i] = if scores[i] >= F::zero() {
440 self.classes[1]
441 } else {
442 self.classes[0]
443 };
444 }
445 } else {
446 for i in 0..n_samples {
448 let mut best_class = 0;
449 let mut best_score = F::neg_infinity();
450 for (c, w) in self.weight_vectors.iter().enumerate() {
451 let score = x.row(i).dot(w) + self.intercepts[c];
452 if score > best_score {
453 best_score = score;
454 best_class = c;
455 }
456 }
457 predictions[i] = self.classes[best_class];
458 }
459 }
460
461 Ok(predictions)
462 }
463}
464
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedLinearSVC<F>
{
    /// Weight vector of the first fitted hyperplane.
    ///
    /// NOTE(review): for multiclass models this exposes only the first
    /// class's one-vs-rest coefficients; `weight_vectors()` returns all of
    /// them. Presumably the trait contract expects a single vector — verify
    /// against `HasCoefficients` users.
    fn coefficients(&self) -> &Array1<F> {
        &self.weight_vectors[0]
    }

    /// Intercept of the first fitted hyperplane; same multiclass caveat as
    /// `coefficients`.
    fn intercept(&self) -> F {
        self.intercepts[0]
    }
}
478
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasClasses for FittedLinearSVC<F> {
    /// Sorted distinct labels seen during fitting.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct labels seen during fitting.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
488
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// `new()` must produce the documented defaults.
    #[test]
    fn test_default_constructor() {
        let m = LinearSVC::<f64>::new();
        assert_eq!(m.max_iter, 1000);
        // assert_eq! reports both values on failure, unlike assert!(a == b).
        assert_eq!(m.c, 1.0);
        assert_eq!(m.tol, 1e-4);
        assert_eq!(m.loss, LinearSVCLoss::SquaredHinge);
    }

    /// Every builder setter must be reflected in the resulting config.
    #[test]
    fn test_builder_setters() {
        let m = LinearSVC::<f64>::new()
            .with_c(10.0)
            .with_max_iter(500)
            .with_tol(1e-6)
            .with_loss(LinearSVCLoss::Hinge);
        assert_eq!(m.c, 10.0);
        assert_eq!(m.max_iter, 500);
        assert_eq!(m.tol, 1e-6);
        assert_eq!(m.loss, LinearSVCLoss::Hinge);
    }

    /// Two well-separated clusters should be classified mostly correctly.
    #[test]
    fn test_binary_classification() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0,
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = LinearSVC::<f64>::new().with_c(1.0).with_max_iter(5000);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected at least 6 correct, got {correct}");
    }

    /// Hinge loss must also separate an easy binary problem.
    #[test]
    fn test_binary_hinge_loss() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 8.0, 8.0, 8.0, 9.0, 9.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LinearSVC::<f64>::new()
            .with_loss(LinearSVCLoss::Hinge)
            .with_max_iter(5000);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 4, "expected at least 4 correct, got {correct}");
    }

    /// One-vs-rest path: three clusters at distinct corners.
    #[test]
    fn test_multiclass_classification() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5,
                10.0, 0.0, 10.5, 0.0, 10.0, 0.5,
                0.0, 10.0, 0.5, 10.0, 0.0, 10.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = LinearSVC::<f64>::new().with_c(10.0).with_max_iter(5000);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes(), &[0, 1, 2]);

        let preds = fitted.predict(&x).unwrap();
        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 7, "expected at least 7 correct, got {correct}");
    }

    /// `fit` must reject `y` whose length differs from `x.nrows()`.
    #[test]
    fn test_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1];
        let model = LinearSVC::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    /// `fit` must reject zero or negative `C`.
    #[test]
    fn test_invalid_c() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = LinearSVC::<f64>::new().with_c(0.0);
        assert!(model.fit(&x, &y).is_err());

        let model_neg = LinearSVC::<f64>::new().with_c(-1.0);
        assert!(model_neg.fit(&x, &y).is_err());
    }

    /// A single distinct label is not a classification problem.
    #[test]
    fn test_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = LinearSVC::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    /// `HasCoefficients` exposes a weight vector sized like the features.
    #[test]
    fn test_has_coefficients() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 8.0, 8.0, 8.0, 9.0, 9.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = LinearSVC::<f64>::new().with_max_iter(5000);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    /// `predict` must reject inputs with the wrong number of features.
    #[test]
    fn test_predict_feature_mismatch() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 8.0, 8.0, 8.0, 9.0, 9.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let fitted = LinearSVC::<f64>::new().with_max_iter(5000).fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }
}