// SPDX-License-Identifier: MIT OR Apache-2.0
//! Linear SVM classifier and regressor via Pegasos-style SGD.
//!
//! [`LinearSVC`] uses hinge loss with an L2 penalty for classification.
//! [`LinearSVR`] uses ε-insensitive loss with an L2 penalty for regression.
//! Both solve the SVM objective by sub-gradient descent in the style of
//! the Pegasos algorithm: the classifier takes full-batch sub-gradient
//! steps, while the regressor updates per sample.
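//!
//! Each [`LinearSVC`] binary sub-problem minimises the regularised hinge
//! objective below (restating the `total_loss` tracked in `pegasos_train`),
//! with `λ = 1 / (C·n)`:
//!
//! ```text
//! min over (w, b):  (λ/2)·‖w‖² + (1/n)·Σᵢ max(0, 1 − yᵢ·(w·xᵢ + b))
//! ```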

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::sparse::{CscMatrix, CsrMatrix};
use crate::weights::{compute_sample_weights, ClassWeight};

// ─────────────────────────────────────────────────────────────────
// LinearSVC
// ─────────────────────────────────────────────────────────────────

/// Linear Support Vector Classifier.
///
/// Uses Pegasos-style SGD to minimise hinge loss with L2
/// regularisation. Every problem, including the binary case, is
/// decomposed one-vs-rest: one weight vector is trained per class,
/// and prediction is the argmax of the decision-function scores.
///
/// # Example
///
/// ```
/// use scry_learn::dataset::Dataset;
/// use scry_learn::svm::LinearSVC;
///
/// let features = vec![
///     vec![0.0, 0.0, 10.0, 10.0],
///     vec![0.0, 0.0, 10.0, 10.0],
/// ];
/// let target = vec![0.0, 0.0, 1.0, 1.0];
/// let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");
///
/// let mut svc = LinearSVC::new().c(1.0).max_iter(500);
/// svc.fit(&data).unwrap();
///
/// let preds = svc.predict(&[vec![1.0, 1.0]]).unwrap();
/// assert_eq!(preds[0] as usize, 0);
/// ```
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct LinearSVC {
    c: f64,
    max_iter: usize,
    tol: f64,
    class_weight: ClassWeight,
    probability: bool,
    /// One weight vector per class (OVR). Each vector has length
    /// `n_features + 1` (last element is the bias).
    weights: Vec<Vec<f64>>,
    /// Platt scaling parameters (A, B) per OVR model.
    platt_params: Vec<(f64, f64)>,
    n_classes: usize,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl LinearSVC {
    /// Create a new `LinearSVC` with default parameters.
    ///
    /// Defaults: `C = 1.0`, `max_iter = 1000`, `tol = 1e-4`.
    pub fn new() -> Self {
        Self {
            c: 1.0,
            max_iter: 1000,
            tol: crate::constants::DEFAULT_TOL,
            class_weight: ClassWeight::Uniform,
            probability: false,
            weights: Vec::new(),
            platt_params: Vec::new(),
            n_classes: 0,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Set the regularisation parameter `C`.
    ///
    /// Larger values penalise misclassification more (narrower margin,
    /// closer fit to the training data).
    pub fn c(mut self, c: f64) -> Self {
        self.c = c;
        self
    }

    /// Set the maximum number of SGD epochs.
    pub fn max_iter(mut self, n: usize) -> Self {
        self.max_iter = n;
        self
    }

    /// Set convergence tolerance on the max weight change per epoch.
    pub fn tol(mut self, t: f64) -> Self {
        self.tol = t;
        self
    }

    /// Set class weighting strategy for imbalanced datasets.
    pub fn class_weight(mut self, cw: ClassWeight) -> Self {
        self.class_weight = cw;
        self
    }

    /// Enable Platt scaling for probability estimates.
    ///
    /// When `true`, [`predict_proba`](Self::predict_proba) returns
    /// calibrated class probabilities after fitting.
    pub fn probability(mut self, enable: bool) -> Self {
        self.probability = enable;
        self
    }

    /// Train the SVM on the given dataset.
    ///
    /// Uses Pegasos-style SGD with one-vs-rest decomposition (one
    /// binary sub-problem per class, including the binary case).
    /// Auto-dispatches to sparse kernels when the dataset uses
    /// sparse storage.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if let Some(csc) = data.sparse_csc() {
            return self.fit_sparse(csc, &data.target);
        }
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.c <= 0.0 || !self.c.is_finite() {
            return Err(ScryLearnError::InvalidParameter(
                "C must be finite and positive".into(),
            ));
        }

        self.n_classes = data.n_classes();
        let sample_weights = compute_sample_weights(&data.target, &self.class_weight);

        // One-vs-rest: train one binary sub-problem per class.
        // For each class k the binary target is +1 / -1.
        self.weights = Vec::with_capacity(self.n_classes);
        self.platt_params = Vec::with_capacity(self.n_classes);

        for cls in 0..self.n_classes {
            let binary_target: Vec<f64> = data
                .target
                .iter()
                .map(|&t| if t as usize == cls { 1.0 } else { -1.0 })
                .collect();

            let w = pegasos_train(
                &data.features,
                &binary_target,
                &sample_weights,
                m,
                n,
                self.c,
                self.max_iter,
                self.tol,
            );

            // Platt scaling: fit sigmoid on decision values.
            let ab = if self.probability {
                let dvals: Vec<f64> = (0..n)
                    .map(|i| {
                        let mut score = w[m]; // bias
                        for (j, feat_col) in data.features.iter().enumerate().take(m) {
                            score += w[j] * feat_col[i];
                        }
                        score
                    })
                    .collect();
                platt_fit(&dvals, &binary_target)
            } else {
                (0.0, 0.0)
            };
            self.platt_params.push(ab);
            self.weights.push(w);
        }

        self.fitted = true;
        Ok(())
    }

    /// Train on sparse data (CSC format).
    fn fit_sparse(&mut self, csc: &CscMatrix, target: &[f64]) -> Result<()> {
        let csr = csc.to_csr();
        let n = csr.n_rows();
        let m = csc.n_cols();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.c <= 0.0 || !self.c.is_finite() {
            return Err(ScryLearnError::InvalidParameter(
                "C must be finite and positive".into(),
            ));
        }

        // Derive the class count as (max label + 1).
        self.n_classes = {
            let mut max_class = 0usize;
            for &t in target {
                let c = t as usize;
                if c > max_class {
                    max_class = c;
                }
            }
            max_class + 1
        };
        let sample_weights = compute_sample_weights(target, &self.class_weight);

        self.weights = Vec::with_capacity(self.n_classes);
        self.platt_params = Vec::with_capacity(self.n_classes);

        for cls in 0..self.n_classes {
            let binary_target: Vec<f64> = target
                .iter()
                .map(|&t| if t as usize == cls { 1.0 } else { -1.0 })
                .collect();

            let w = pegasos_train_sparse(
                &csr,
                &binary_target,
                &sample_weights,
                m,
                n,
                self.c,
                self.max_iter,
                self.tol,
            );

            // Platt scaling: fit sigmoid on decision values.
            let ab = if self.probability {
                let dvals: Vec<f64> = (0..n)
                    .map(|i| {
                        let row = csr.row(i);
                        let mut score = w[m]; // bias
                        for (col, val) in row.iter() {
                            score += w[col] * val;
                        }
                        score
                    })
                    .collect();
                platt_fit(&dvals, &binary_target)
            } else {
                (0.0, 0.0)
            };
            self.platt_params.push(ab);
            self.weights.push(w);
        }

        self.fitted = true;
        Ok(())
    }

    /// Predict class labels from sparse input (CSR format).
    pub fn predict_sparse(&self, csr: &CsrMatrix) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = csr.n_rows();
        let mut preds = Vec::with_capacity(n);
        for i in 0..n {
            let row = csr.row(i);
            let mut best_cls = 0usize;
            let mut best_score = f64::NEG_INFINITY;
            for (cls, w) in self.weights.iter().enumerate() {
                let m = w.len() - 1;
                let mut score = w[m]; // bias
                for (col, val) in row.iter() {
                    if col < m {
                        score += w[col] * val;
                    }
                }
                if score > best_score {
                    best_score = score;
                    best_cls = cls;
                }
            }
            preds.push(best_cls as f64);
        }
        Ok(preds)
    }

    /// Predict class labels for the given row-major feature matrix.
    ///
    /// Returns the class whose OVR decision function is largest.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        let scores = self.decision_function(features)?;
        Ok(scores
            .into_iter()
            .map(|row| {
                row.iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map_or(0.0, |(idx, _)| idx as f64)
            })
            .collect())
    }

    /// Compute the raw decision function score for each class.
    ///
    /// Returns `scores[sample][class]` = `w · x + b` for each OVR
    /// sub-problem.
    ///
    /// # Example
    ///
    /// ```
    /// use scry_learn::dataset::Dataset;
    /// use scry_learn::svm::LinearSVC;
    ///
    /// let features = vec![
    ///     vec![0.0, 0.0, 10.0, 10.0],
    ///     vec![0.0, 0.0, 10.0, 10.0],
    /// ];
    /// let target = vec![0.0, 0.0, 1.0, 1.0];
    /// let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");
    ///
    /// let mut svc = LinearSVC::new();
    /// svc.fit(&data).unwrap();
    ///
    /// let scores = svc.decision_function(&[vec![1.0, 1.0]]).unwrap();
    /// assert_eq!(scores[0].len(), 2); // two classes
    /// ```
    pub fn decision_function(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        Ok(features
            .iter()
            .map(|row| {
                self.weights
                    .iter()
                    .map(|w| {
                        let m = w.len() - 1;
                        let mut score = w[m]; // bias
                        for (j, &x) in row.iter().enumerate().take(m) {
                            score += w[j] * x;
                        }
                        score
                    })
                    .collect()
            })
            .collect())
    }

    /// Predict class probabilities using Platt scaling.
    ///
    /// Requires `.probability(true)` to have been set before fitting.
    /// Returns `probabilities[sample][class]` normalised to sum to 1.
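    ///
    /// # Example
    ///
    /// A minimal sketch mirroring `test_linear_svc_predict_proba` below:
    ///
    /// ```
    /// use scry_learn::dataset::Dataset;
    /// use scry_learn::svm::LinearSVC;
    ///
    /// let features = vec![
    ///     vec![0.0, 0.0, 10.0, 10.0],
    ///     vec![0.0, 0.0, 10.0, 10.0],
    /// ];
    /// let target = vec![0.0, 0.0, 1.0, 1.0];
    /// let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");
    ///
    /// let mut svc = LinearSVC::new().probability(true);
    /// svc.fit(&data).unwrap();
    ///
    /// let proba = svc.predict_proba(&[vec![1.0, 1.0]]).unwrap();
    /// let sum: f64 = proba[0].iter().sum();
    /// assert!((sum - 1.0).abs() < 1e-6);
    /// ```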
    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        if !self.probability {
            return Err(ScryLearnError::InvalidParameter(
                "call .probability(true) before fit to enable predict_proba".into(),
            ));
        }
        let scores = self.decision_function(features)?;
        Ok(scores
            .into_iter()
            .map(|row| {
                let raw: Vec<f64> = row
                    .iter()
                    .zip(self.platt_params.iter())
                    .map(|(&dv, &(a, b))| platt_predict(dv, a, b))
                    .collect();
                let sum: f64 = raw.iter().sum();
                if sum > f64::EPSILON {
                    raw.iter().map(|&p| p / sum).collect()
                } else {
                    vec![1.0 / raw.len() as f64; raw.len()]
                }
            })
            .collect())
    }
}

impl Default for LinearSVC {
    fn default() -> Self {
        Self::new()
    }
}

// ─────────────────────────────────────────────────────────────────
// LinearSVR
// ─────────────────────────────────────────────────────────────────

/// Linear Support Vector Regressor.
///
/// Uses ε-insensitive loss with L2 penalty, solved by SGD.
/// Predictions within `epsilon` of the true value incur no loss.
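///
/// Concretely, each sample contributes `max(0, |w·x + b − y| − ε)` on
/// top of the shared L2 penalty on `w`; the three-way case split on
/// the residual in `fit` below is the sub-gradient of this loss.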
///
/// # Example
///
/// ```
/// use scry_learn::dataset::Dataset;
/// use scry_learn::svm::LinearSVR;
///
/// let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
/// let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
/// let data = Dataset::new(features, target, vec!["x".into()], "y");
///
/// let mut svr = LinearSVR::new().c(1.0).epsilon(0.1);
/// svr.fit(&data).unwrap();
///
/// let preds = svr.predict(&[vec![3.0]]).unwrap();
/// assert!((preds[0] - 6.0).abs() < 1.0);
/// ```
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct LinearSVR {
    c: f64,
    epsilon: f64,
    max_iter: usize,
    tol: f64,
    /// `w[0..m]` = feature weights, `w[m]` = bias.
    weights: Vec<f64>,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl LinearSVR {
    /// Create a new `LinearSVR` with default parameters.
    ///
    /// Defaults: `C = 1.0`, `epsilon = 0.1`, `max_iter = 1000`, `tol = 1e-4`.
    pub fn new() -> Self {
        Self {
            c: 1.0,
            epsilon: 0.1,
            max_iter: 1000,
            tol: crate::constants::DEFAULT_TOL,
            weights: Vec::new(),
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Set the regularisation parameter `C`.
    pub fn c(mut self, c: f64) -> Self {
        self.c = c;
        self
    }

    /// Set the epsilon tube width.
    ///
    /// Predictions within `epsilon` of the true value incur zero loss.
    pub fn epsilon(mut self, e: f64) -> Self {
        self.epsilon = e;
        self
    }

    /// Set the maximum number of SGD epochs.
    pub fn max_iter(mut self, n: usize) -> Self {
        self.max_iter = n;
        self
    }

    /// Set convergence tolerance on the max weight change per epoch.
    pub fn tol(mut self, t: f64) -> Self {
        self.tol = t;
        self
    }

    /// Train the SVR on the given dataset.
    ///
    /// Auto-dispatches to sparse kernels when the dataset uses sparse storage.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if let Some(csc) = data.sparse_csc() {
            return self.fit_sparse(csc, &data.target);
        }
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.c <= 0.0 || !self.c.is_finite() {
            return Err(ScryLearnError::InvalidParameter(
                "C must be finite and positive".into(),
            ));
        }

        let lambda = 1.0 / (self.c * n as f64);
        // w has m+1 elements: m feature weights + 1 bias.
        let mut w = vec![0.0; m + 1];
        let mut t = 1.0_f64;

        let mut prev_w = w.clone();

        for _epoch in 0..self.max_iter {
            for i in 0..n {
                let eta = 1.0 / (lambda * t);
                t += 1.0;

                // Compute prediction: w·x + b. The zip stops at the m
                // feature columns, so the bias slot is not multiplied in.
                let mut pred = w[m]; // bias
                for (wj, feat_col) in w.iter().zip(data.features.iter()) {
                    pred += wj * feat_col[i];
                }

                let residual = pred - data.target[i];

                // Sub-gradient of the ε-insensitive loss.
                let sign = if residual > self.epsilon {
                    1.0
                } else if residual < -self.epsilon {
                    -1.0
                } else {
                    0.0
                };

                // L2 decay + step for the feature weights; plain step for the bias.
                for (wj, feat_col) in w.iter_mut().zip(data.features.iter()) {
                    *wj = (1.0 - eta * lambda) * *wj - eta * sign * feat_col[i];
                }
                w[m] -= eta * sign;
            }

            let max_delta = w
                .iter()
                .zip(prev_w.iter())
                .map(|(a, b)| (a - b).abs())
                .fold(0.0_f64, f64::max);
            if max_delta < self.tol {
                break;
            }
            prev_w.copy_from_slice(&w);
        }

        self.weights = w;
        self.fitted = true;
        Ok(())
    }

    /// Train on sparse data (CSC format).
    fn fit_sparse(&mut self, csc: &CscMatrix, target: &[f64]) -> Result<()> {
        let csr = csc.to_csr();
        let n = csr.n_rows();
        let m = csc.n_cols();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.c <= 0.0 || !self.c.is_finite() {
            return Err(ScryLearnError::InvalidParameter(
                "C must be finite and positive".into(),
            ));
        }

        let lambda = 1.0 / (self.c * n as f64);
        let mut w = vec![0.0; m + 1];
        let mut t = 1.0_f64;
        let mut prev_w = w.clone();

        for _epoch in 0..self.max_iter {
            for i in 0..n {
                let eta = 1.0 / (lambda * t);
                t += 1.0;

                let row = csr.row(i);
                let mut pred = w[m]; // bias
                for (col, val) in row.iter() {
                    pred += w[col] * val;
                }

                let residual = pred - target[i];
                let sign = if residual > self.epsilon {
                    1.0
                } else if residual < -self.epsilon {
                    -1.0
                } else {
                    0.0
                };

                // Regularise all weights, then update non-zero entries.
                let decay = 1.0 - eta * lambda;
                for wj in w.iter_mut().take(m) {
                    *wj *= decay;
                }
                for (col, val) in row.iter() {
                    w[col] -= eta * sign * val;
                }
                w[m] -= eta * sign;
            }

            let max_delta = w
                .iter()
                .zip(prev_w.iter())
                .map(|(a, b)| (a - b).abs())
                .fold(0.0_f64, f64::max);
            if max_delta < self.tol {
                break;
            }
            prev_w.copy_from_slice(&w);
        }

        self.weights = w;
        self.fitted = true;
        Ok(())
    }

    /// Predict continuous target values.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let m = self.weights.len() - 1;
        Ok(features
            .iter()
            .map(|row| {
                let mut pred = self.weights[m]; // bias
                for (j, &x) in row.iter().enumerate().take(m) {
                    pred += self.weights[j] * x;
                }
                pred
            })
            .collect())
    }

    /// Predict continuous target values from sparse input (CSR format).
    pub fn predict_sparse(&self, csr: &CsrMatrix) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let m = self.weights.len() - 1;
        let n = csr.n_rows();
        let mut preds = Vec::with_capacity(n);
        for i in 0..n {
            let row = csr.row(i);
            let mut pred = self.weights[m]; // bias
            for (col, val) in row.iter() {
                if col < m {
                    pred += self.weights[col] * val;
                }
            }
            preds.push(pred);
        }
        Ok(preds)
    }
}

impl Default for LinearSVR {
    fn default() -> Self {
        Self::new()
    }
}

// ─────────────────────────────────────────────────────────────────
// Pegasos SGD helper (used by LinearSVC)
// ─────────────────────────────────────────────────────────────────

/// Train a single binary SVM via sub-gradient descent.
///
/// Uses the full-batch sub-gradient with a decaying learning rate.
/// `binary_target` values are +1 / -1.
/// Returns weight vector of length `m + 1` (last = bias).
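///
/// Each epoch applies one step of `w ← w − η_e · (∇hinge / n + λ·w)`
/// with `η_e = 1 / (1 + PEGASOS_LR_DECAY · e)` (this restates the loop
/// body below); the bias component receives no L2 term.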
#[allow(clippy::too_many_arguments)]
fn pegasos_train(
    features: &[Vec<f64>],  // [n_features][n_samples] (column-major)
    binary_target: &[f64],  // [n_samples], +1/-1
    sample_weights: &[f64], // [n_samples]
    m: usize,               // n_features
    n: usize,               // n_samples
    c: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let lambda = 1.0 / (c * n as f64);
    let mut w = vec![0.0; m + 1]; // w[0..m] = features, w[m] = bias
    let mut best_w = w.clone();
    let mut best_loss = f64::INFINITY;

    let mut prev_w = w.clone();

    for epoch in 0..max_iter {
        // Learning rate decays with the epoch index.
        let eta = 1.0 / (1.0 + crate::constants::PEGASOS_LR_DECAY * epoch as f64);

        // Batch sub-gradient.
        let mut grad = vec![0.0; m + 1];
        let mut hinge_loss = 0.0_f64;

        for i in 0..n {
            let mut score = w[m]; // bias
            for j in 0..m {
                score += w[j] * features[j][i];
            }

            let y = binary_target[i];
            let sw = sample_weights[i];
            let margin = y * score;

            if margin < 1.0 {
                let loss_contrib = sw * (1.0 - margin);
                hinge_loss += loss_contrib;
                for j in 0..m {
                    grad[j] -= sw * y * features[j][i];
                }
                grad[m] -= sw * y;
            }
        }

        // Average hinge gradient + L2 penalty on weights (not bias).
        for j in 0..m {
            grad[j] = grad[j] / n as f64 + lambda * w[j];
        }
        grad[m] /= n as f64;

        // Update.
        for j in 0..=m {
            w[j] -= eta * grad[j];
        }

        // Track best weights by total loss.
        let total_loss =
            hinge_loss / n as f64 + 0.5 * lambda * w.iter().take(m).map(|x| x * x).sum::<f64>();
        if total_loss < best_loss {
            best_loss = total_loss;
            best_w.copy_from_slice(&w);
        }

        // Convergence: max weight change.
        let max_delta = w
            .iter()
            .zip(prev_w.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f64, f64::max);
        if max_delta < tol {
            break;
        }
        prev_w.copy_from_slice(&w);
    }

    best_w
}

// ─────────────────────────────────────────────────────────────────
// Sparse Pegasos SGD helper (used by LinearSVC sparse path)
// ─────────────────────────────────────────────────────────────────

/// Train a single binary SVM on sparse data via batch sub-gradient descent.
///
/// Mirrors `pegasos_train` but operates on a CSR matrix for efficient
/// row access. Returns weight vector of length `m + 1` (last = bias).
#[allow(clippy::too_many_arguments)]
fn pegasos_train_sparse(
    csr: &CsrMatrix,
    binary_target: &[f64],
    sample_weights: &[f64],
    m: usize,
    n: usize,
    c: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let lambda = 1.0 / (c * n as f64);
    let mut w = vec![0.0; m + 1];
    let mut best_w = w.clone();
    let mut best_loss = f64::INFINITY;
    let mut prev_w = w.clone();

    for epoch in 0..max_iter {
        let eta = 1.0 / (1.0 + crate::constants::PEGASOS_LR_DECAY * epoch as f64);

        let mut grad = vec![0.0; m + 1];
        let mut hinge_loss = 0.0_f64;

        for i in 0..n {
            let row = csr.row(i);
            let mut score = w[m]; // bias
            for (col, val) in row.iter() {
                score += w[col] * val;
            }

            let y = binary_target[i];
            let sw = sample_weights[i];
            let margin = y * score;

            if margin < 1.0 {
                hinge_loss += sw * (1.0 - margin);
                for (col, val) in row.iter() {
                    grad[col] -= sw * y * val;
                }
                grad[m] -= sw * y;
            }
        }

        // Average hinge gradient + L2 penalty.
        for j in 0..m {
            grad[j] = grad[j] / n as f64 + lambda * w[j];
        }
        grad[m] /= n as f64;

        for j in 0..=m {
            w[j] -= eta * grad[j];
        }

        let total_loss =
            hinge_loss / n as f64 + 0.5 * lambda * w.iter().take(m).map(|x| x * x).sum::<f64>();
        if total_loss < best_loss {
            best_loss = total_loss;
            best_w.copy_from_slice(&w);
        }

        let max_delta = w
            .iter()
            .zip(prev_w.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f64, f64::max);
        if max_delta < tol {
            break;
        }
        prev_w.copy_from_slice(&w);
    }

    best_w
}

// ─────────────────────────────────────────────────────────────────
// Platt scaling (shared with kernel.rs)
// ─────────────────────────────────────────────────────────────────

/// Fit Platt sigmoid parameters (A, B) on decision values.
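///
/// Fits `P(y = 1 | f) = 1 / (1 + exp(A·f + B))` by Newton's method on
/// the cross-entropy with the regularised Hessian below, using Platt's
/// smoothed targets `(n₊ + 1) / (n₊ + 2)` and `1 / (n₋ + 2)`.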
fn platt_fit(decision_values: &[f64], labels: &[f64]) -> (f64, f64) {
    let n = decision_values.len();
    if n == 0 {
        return (0.0, 0.0);
    }

    let n_pos = labels.iter().filter(|&&y| y > 0.0).count() as f64;
    let n_neg = n as f64 - n_pos;

    let t_pos = (n_pos + 1.0) / (n_pos + 2.0);
    let t_neg = 1.0 / (n_neg + 2.0);
    let targets: Vec<f64> = labels
        .iter()
        .map(|&y| if y > 0.0 { t_pos } else { t_neg })
        .collect();

    let mut a = 0.0_f64;
    let mut b = ((n_neg + 1.0) / (n_pos + 1.0)).ln();
    let sigma = crate::constants::PLATT_HESSIAN_REG;

    for _ in 0..100 {
        let mut g1 = 0.0_f64;
        let mut g2 = 0.0_f64;
        let mut h11 = sigma;
        let mut h22 = sigma;
        let mut h21 = 0.0_f64;

        for i in 0..n {
            let fval = decision_values[i] * a + b;
            let p = 1.0 / (1.0 + (-fval).exp());
            let d = p - targets[i];
            let s = p * (1.0 - p);

            g1 += d * decision_values[i];
            g2 += d;
            h11 += s * decision_values[i] * decision_values[i];
            h22 += s;
            h21 += s * decision_values[i];
        }

        let det = h11 * h22 - h21 * h21;
        if det.abs() < crate::constants::PLATT_SINGULAR_DET {
            break;
        }
        let da = -(h22 * g1 - h21 * g2) / det;
        let db = -(h11 * g2 - h21 * g1) / det;

        if da.abs() < crate::constants::PLATT_CONVERGENCE
            && db.abs() < crate::constants::PLATT_CONVERGENCE
        {
            break;
        }

        a += da;
        b += db;
    }

    (a, b)
}

/// Predict probability from a single decision value via Platt sigmoid.
#[inline]
fn platt_predict(dv: f64, a: f64, b: f64) -> f64 {
    1.0 / (1.0 + (a * dv + b).exp())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_svc_binary() {
        // Two linearly separable clusters.
        let features = vec![
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut svc = LinearSVC::new().c(1.0).max_iter(500);
        svc.fit(&data).unwrap();

        let preds = svc.predict(&[vec![1.0, 1.0], vec![9.0, 9.0]]).unwrap();
        assert_eq!(preds[0] as usize, 0);
        assert_eq!(preds[1] as usize, 1);
    }
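
    // Added as an illustrative check (not in the original suite): it
    // exercises the one-vs-rest multiclass path described in the
    // `LinearSVC` docs, assuming class labels are encoded as 0.0, 1.0, 2.0.
    #[test]
    fn test_linear_svc_multiclass_ovr() {
        // Three well-separated clusters (column-major features).
        let features = vec![
            vec![0.0, 0.0, 10.0, 10.0, 20.0, 20.0],
            vec![0.0, 0.0, 10.0, 10.0, 20.0, 20.0],
        ];
        let target = vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut svc = LinearSVC::new().c(1.0).max_iter(500);
        svc.fit(&data).unwrap();

        // One OVR weight vector per class => one decision score per class.
        let scores = svc.decision_function(&[vec![10.0, 10.0]]).unwrap();
        assert_eq!(scores[0].len(), 3);
    }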

    #[test]
    fn test_linear_svc_decision_function() {
        let features = vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]];
        let target = vec![0.0, 0.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut svc = LinearSVC::new();
        svc.fit(&data).unwrap();

        let scores = svc.decision_function(&[vec![1.0, 1.0]]).unwrap();
        assert_eq!(scores[0].len(), 2);
    }

    #[test]
    fn test_linear_svc_not_fitted() {
        let svc = LinearSVC::new();
        assert!(svc.predict(&[vec![1.0]]).is_err());
        assert!(svc.decision_function(&[vec![1.0]]).is_err());
    }

    #[test]
    fn test_linear_svc_invalid_c() {
        let features = vec![vec![1.0]];
        let target = vec![0.0];
        let data = Dataset::new(features, target, vec!["x".into()], "class");

        let mut svc = LinearSVC::new().c(-1.0);
        assert!(svc.fit(&data).is_err());
    }

    #[test]
    fn test_linear_svr_simple() {
        // y = 2x
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]];
        let target = vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut svr = LinearSVR::new().c(10.0).epsilon(0.1).max_iter(2000);
        svr.fit(&data).unwrap();

        let preds = svr.predict(&[vec![3.0], vec![5.0]]).unwrap();
        assert!(
            (preds[0] - 6.0).abs() < 2.0,
            "Expected ~6.0, got {}",
            preds[0]
        );
        assert!(
            (preds[1] - 10.0).abs() < 2.0,
            "Expected ~10.0, got {}",
            preds[1]
        );
    }

    #[test]
    fn test_linear_svr_not_fitted() {
        let svr = LinearSVR::new();
        assert!(svr.predict(&[vec![1.0]]).is_err());
    }

    #[test]
    fn test_linear_svc_predict_proba() {
        let features = vec![
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut svc = LinearSVC::new().c(1.0).max_iter(500).probability(true);
        svc.fit(&data).unwrap();

        let proba = svc
            .predict_proba(&[vec![1.0, 1.0], vec![9.0, 9.0]])
            .unwrap();
        for row in &proba {
            let sum: f64 = row.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probabilities should sum to 1, got {sum}"
            );
            for &p in row {
                assert!((0.0..=1.0).contains(&p), "probability out of range: {p}");
            }
        }
    }

    #[test]
    fn test_linear_svc_predict_proba_not_enabled() {
        let features = vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]];
        let target = vec![0.0, 0.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut svc = LinearSVC::new();
        svc.fit(&data).unwrap();
        assert!(svc.predict_proba(&[vec![1.0, 1.0]]).is_err());
    }
}