Skip to main content

libsvm_rs/
probability.rs

//! Probability estimation functions for SVM models.
//!
//! Provides Platt scaling (`sigmoid_train`/`sigmoid_predict`), multiclass
//! probability estimation, and density-based probability for one-class SVM.
//! Matches the original LIBSVM's probability routines (svm.cpp:1714–2096).
6
7use crate::predict::predict_values;
8use crate::train::svm_train;
9use crate::types::{SvmModel, SvmParameter, SvmProblem};
10use crate::util::c_rand;
11#[cfg(feature = "rayon")]
12use rayon::prelude::*;
13
// ─── Platt scaling ───────────────────────────────────────────────────
15
/// Fit the Platt-scaling sigmoid parameters `(A, B)`.
///
/// Minimises the regularised negative log-likelihood of
/// `P(y = 1 | f) = 1 / (1 + exp(A·f + B))` over the given decision values
/// using Newton's method on `H + σI` with a backtracking (Armijo) line
/// search, following Lin, Lin & Weng (2007). Labels `> 0` count as
/// positive, all others as negative; the hard 0/1 targets are replaced by
/// the smoothed targets `(N₊+1)/(N₊+2)` and `1/(N₋+2)`.
/// Matches LIBSVM's `sigmoid_train`.
///
/// Iteration stops when both gradient components fall below `1e-5`, when
/// the line search finds no acceptable step of at least `1e-10`, or after
/// 100 Newton steps.
///
/// # Panics
///
/// Panics if `labels` is shorter than `dec_values` (the two slices are
/// read in lockstep by index).
pub fn sigmoid_train(dec_values: &[f64], labels: &[f64]) -> (f64, f64) {
    const MAX_ITER: usize = 100;
    const MIN_STEP: f64 = 1e-10; // smallest step the line search will try
    const SIGMA: f64 = 1e-12; // ridge on the Hessian diagonal (strict PD)
    const EPS: f64 = 1e-5; // gradient tolerance

    let n = dec_values.len();

    let (n_pos, n_neg) = labels.iter().fold((0.0_f64, 0.0_f64), |(pos, neg), &y| {
        if y > 0.0 {
            (pos + 1.0, neg)
        } else {
            (pos, neg + 1.0)
        }
    });

    let hi_target = (n_pos + 1.0) / (n_pos + 2.0);
    let lo_target = 1.0 / (n_neg + 2.0);
    let targets: Vec<f64> = labels
        .iter()
        .map(|&y| if y > 0.0 { hi_target } else { lo_target })
        .collect();

    // Negative log-likelihood at (slope, offset). Each branch keeps the
    // argument of `exp` non-positive so it cannot overflow.
    let objective = |slope: f64, offset: f64| -> f64 {
        (0..n).fold(0.0, |acc, i| {
            let f_apb = dec_values[i] * slope + offset;
            let term = if f_apb >= 0.0 {
                targets[i] * f_apb + (1.0 + (-f_apb).exp()).ln()
            } else {
                (targets[i] - 1.0) * f_apb + (1.0 + f_apb.exp()).ln()
            };
            acc + term
        })
    };

    // Starting point: A = 0, B = log of the (smoothed) class-prior ratio.
    let mut a = 0.0;
    let mut b = ((n_neg + 1.0) / (n_pos + 1.0)).ln();
    let mut fval = objective(a, b);

    for _ in 0..MAX_ITER {
        // Gradient g and regularised Hessian H' = H + σI at (a, b).
        let (mut h11, mut h22, mut h21) = (SIGMA, SIGMA, 0.0);
        let (mut g1, mut g2) = (0.0, 0.0);
        for (&f, &target) in dec_values.iter().zip(&targets) {
            let f_apb = f * a + b;
            let (p, q) = if f_apb >= 0.0 {
                let e = (-f_apb).exp();
                (e / (1.0 + e), 1.0 / (1.0 + e))
            } else {
                let e = f_apb.exp();
                (1.0 / (1.0 + e), e / (1.0 + e))
            };
            let d2 = p * q;
            h11 += f * f * d2;
            h22 += d2;
            h21 += f * d2;
            let d1 = target - p;
            g1 += f * d1;
            g2 += d1;
        }

        if g1.abs() < EPS && g2.abs() < EPS {
            break;
        }

        // Newton direction −H'⁻¹g via the closed-form 2×2 inverse.
        let det = h11 * h22 - h21 * h21;
        let da = -(h22 * g1 - h21 * g2) / det;
        let db = -(-h21 * g1 + h11 * g2) / det;
        let gd = g1 * da + g2 * db;

        // Halve the step until the sufficient-decrease condition holds.
        let mut step = 1.0;
        let mut accepted = false;
        while step >= MIN_STEP {
            let (new_a, new_b) = (a + step * da, b + step * db);
            let new_fval = objective(new_a, new_b);
            if new_fval < fval + 0.0001 * step * gd {
                a = new_a;
                b = new_b;
                fval = new_fval;
                accepted = true;
                break;
            }
            step /= 2.0;
        }

        // Line search exhausted without progress: stop.
        if !accepted {
            break;
        }
    }

    (a, b)
}
130
/// Evaluate the Platt sigmoid `P(y = 1 | f) = 1 / (1 + exp(A·f + B))`.
///
/// The formula is rearranged on the sign of `z = A·f + B` so that `exp` is
/// only applied to a non-positive argument, which avoids overflow for large
/// magnitudes. Matches LIBSVM's `sigmoid_predict`.
pub fn sigmoid_predict(decision_value: f64, a: f64, b: f64) -> f64 {
    let z = decision_value * a + b;
    if z >= 0.0 {
        // 1 / (1 + e^z) rewritten as e^{-z} / (1 + e^{-z}).
        let e = (-z).exp();
        e / (1.0 + e)
    } else {
        1.0 / (1.0 + z.exp())
    }
}
143
// ─── Multiclass probability ──────────────────────────────────────────
145
/// Combine pairwise class probabilities into one distribution over `k` classes.
///
/// `r[i][j]` (for `i != j`) is the estimated probability of class `i` given
/// that the sample belongs to class `i` or `j`; diagonal entries are never
/// read. The first `k` entries of `p` are overwritten with the class
/// probabilities found by the Wu–Lin–Weng fixed-point iteration. The loop
/// stops once every `|(Qp)_t − pᵀQp|` is below `0.005 / k`, or after
/// `max(100, k)` sweeps.
///
/// Matches LIBSVM's `multiclass_probability`.
///
/// # Panics
///
/// Panics if `r` is smaller than `k × k` or `p` has fewer than `k` entries.
// Index loops mirror the reference implementation and update `p`/`qp` in place.
#[allow(clippy::needless_range_loop)]
pub fn multiclass_probability(k: usize, r: &[Vec<f64>], p: &mut [f64]) {
    let max_iter = 100.max(k);
    let eps = 0.005 / k as f64;

    // Q[t][t] = Σ_{j≠t} r[j][t]²,  Q[t][j] = −r[j][t]·r[t][j]  (symmetric).
    let mut q = vec![vec![0.0; k]; k];
    for t in 0..k {
        let mut diag = 0.0;
        for j in (0..k).filter(|&j| j != t) {
            diag += r[j][t] * r[j][t];
            q[t][j] = if j < t { q[j][t] } else { -r[j][t] * r[t][j] };
        }
        q[t][t] = diag;
    }

    // Start from the uniform distribution.
    p[..k].fill(1.0 / k as f64);

    let mut qp = vec![0.0; k];
    for _ in 0..max_iter {
        // Recompute Qp and pᵀQp from scratch each sweep for numerical accuracy.
        let mut p_qp = 0.0;
        for t in 0..k {
            qp[t] = q[t]
                .iter()
                .zip(p.iter())
                .fold(0.0, |acc, (&q_tj, &p_j)| acc + q_tj * p_j);
            p_qp += p[t] * qp[t];
        }

        let max_error = qp
            .iter()
            .map(|&v| (v - p_qp).abs())
            .fold(0.0, f64::max);
        if max_error < eps {
            break;
        }

        // Coordinate sweep: move p[t], renormalise p, and update Qp and
        // pᵀQp incrementally to match.
        for t in 0..k {
            let diff = (p_qp - qp[t]) / q[t][t];
            p[t] += diff;
            let scale = 1.0 + diff;
            p_qp = (p_qp + diff * (diff * q[t][t] + 2.0 * qp[t])) / scale / scale;
            for j in 0..k {
                qp[j] = (qp[j] + diff * q[t][j]) / scale;
                p[j] /= scale;
            }
        }
    }
}
210
// ─── Binary SVC probability via internal CV ──────────────────────────
212
213fn evaluate_binary_svc_probability_fold(
214    prob: &SvmProblem,
215    param: &SvmParameter,
216    cp: f64,
217    cn: f64,
218    perm: &[usize],
219    begin: usize,
220    end: usize,
221) -> Vec<f64> {
222    let l = prob.labels.len();
223
224    // Build training sub-problem (exclude held-out fold)
225    let mut sub_instances = Vec::with_capacity(l - (end - begin));
226    let mut sub_labels = Vec::with_capacity(l - (end - begin));
227
228    for &pi in &perm[..begin] {
229        sub_instances.push(prob.instances[pi].clone());
230        sub_labels.push(prob.labels[pi]);
231    }
232    for &pi in &perm[end..l] {
233        sub_instances.push(prob.instances[pi].clone());
234        sub_labels.push(prob.labels[pi]);
235    }
236
237    // Count classes in training set
238    let p_count = sub_labels.iter().filter(|&&y| y > 0.0).count();
239    let n_count = sub_labels.len() - p_count;
240
241    if p_count == 0 && n_count == 0 {
242        vec![0.0; end - begin]
243    } else if p_count > 0 && n_count == 0 {
244        vec![1.0; end - begin]
245    } else if p_count == 0 && n_count > 0 {
246        vec![-1.0; end - begin]
247    } else {
248        let mut subparam = param.clone();
249        subparam.probability = false;
250        subparam.c = 1.0;
251        subparam.weight = vec![(1, cp), (-1, cn)];
252
253        let subprob = SvmProblem {
254            labels: sub_labels,
255            instances: sub_instances,
256        };
257        #[cfg(feature = "rayon")]
258        let submodel = crate::with_suppressed_info(|| svm_train(&subprob, &subparam));
259        #[cfg(not(feature = "rayon"))]
260        let submodel = svm_train(&subprob, &subparam);
261
262        (begin..end)
263            .map(|j| {
264                let mut dv = [0.0];
265                predict_values(&submodel, &prob.instances[perm[j]], &mut dv);
266                // Sign correction: ensure +1/−1 ordering
267                dv[0] * submodel.label[0] as f64
268            })
269            .collect()
270    }
271}
272
273/// Estimate Platt scaling parameters for a binary sub-problem.
274///
275/// Performs 5-fold CV internally: trains on 4 folds with class weights
276/// (cp, cn), collects decision values on the held-out fold, then fits
277/// a sigmoid via `sigmoid_train`.
278///
279/// Matches LIBSVM's `svm_binary_svc_probability`.
280pub fn svm_binary_svc_probability(
281    prob: &SvmProblem,
282    param: &SvmParameter,
283    cp: f64,
284    cn: f64,
285) -> (f64, f64) {
286    let l = prob.labels.len();
287    let nr_fold = 5.min(l.max(1));
288    let mut perm: Vec<usize> = (0..l).collect();
289    let mut dec_values = vec![0.0; l];
290
291    // Random shuffle (Fisher-Yates)
292    for i in 0..l {
293        let j = i + c_rand() % (l - i);
294        perm.swap(i, j);
295    }
296
297    #[cfg(feature = "rayon")]
298    {
299        let fold_predictions: Vec<Vec<f64>> = (0..nr_fold)
300            .into_par_iter()
301            .map(|fold| {
302                let begin = fold * l / nr_fold;
303                let end = (fold + 1) * l / nr_fold;
304                evaluate_binary_svc_probability_fold(prob, param, cp, cn, &perm, begin, end)
305            })
306            .collect();
307
308        for (fold, predictions) in fold_predictions.into_iter().enumerate() {
309            let begin = fold * l / nr_fold;
310            let end = (fold + 1) * l / nr_fold;
311            for (j, prediction) in (begin..end).zip(predictions.into_iter()) {
312                dec_values[perm[j]] = prediction;
313            }
314        }
315    }
316
317    #[cfg(not(feature = "rayon"))]
318    {
319        for fold in 0..nr_fold {
320            let begin = fold * l / nr_fold;
321            let end = (fold + 1) * l / nr_fold;
322            let predictions =
323                evaluate_binary_svc_probability_fold(prob, param, cp, cn, &perm, begin, end);
324
325            for (j, prediction) in (begin..end).zip(predictions.into_iter()) {
326                dec_values[perm[j]] = prediction;
327            }
328        }
329    }
330
331    sigmoid_train(&dec_values, &prob.labels)
332}
333
// ─── One-class probability ───────────────────────────────────────────
335
/// Map a one-class SVM decision value to a probability via density marks.
///
/// The marks are treated as ascending bin edges, as produced by
/// [`svm_one_class_probability`]. With `n` marks the result is:
///
/// * `0.5` when there are no marks;
/// * `0.001` when `dec_value` is below the first mark;
/// * `0.999` when it is above the last mark;
/// * otherwise `i / n`, where `i ≥ 1` is the index of the first mark after
///   the first one that is strictly greater than `dec_value`. If no such
///   mark exists (e.g. `dec_value` equals the last mark), the result is `0.999`.
///
/// Matches LIBSVM's `predict_one_class_probability`.
pub fn predict_one_class_probability(prob_density_marks: &[f64], dec_value: f64) -> f64 {
    let nr_marks = prob_density_marks.len();
    let (first, last) = match (prob_density_marks.first(), prob_density_marks.last()) {
        (Some(&first), Some(&last)) => (first, last),
        _ => return 0.5,
    };

    if dec_value < first {
        0.001
    } else if dec_value > last {
        0.999
    } else {
        prob_density_marks[1..]
            .iter()
            .position(|&mark| dec_value < mark)
            .map_or(0.999, |offset| (offset + 1) as f64 / nr_marks as f64)
    }
}
368
369/// Estimate probability density marks for one-class SVM.
370///
371/// Predicts all training instances, sorts decision values, bins into
372/// 10 density marks. Returns `None` if fewer than 5 positive or 5
373/// negative decision values.
374///
375/// Matches LIBSVM's `svm_one_class_probability`.
376pub fn svm_one_class_probability(prob: &SvmProblem, model: &SvmModel) -> Option<Vec<f64>> {
377    let l = prob.labels.len();
378    let mut dec_values = vec![0.0; l];
379
380    for (dv_slot, instance) in dec_values.iter_mut().zip(prob.instances.iter()) {
381        let mut dv = [0.0];
382        predict_values(model, instance, &mut dv);
383        *dv_slot = dv[0];
384    }
385
386    dec_values.sort_by(f64::total_cmp);
387
388    // Find first index with dec_value >= 0  (= neg_counter in C++)
389    let mut neg_counter = 0usize;
390    for (i, &dv) in dec_values.iter().enumerate() {
391        if dv >= 0.0 {
392            neg_counter = i;
393            break;
394        }
395    }
396    let pos_counter = l - neg_counter;
397
398    let nr_marks: usize = 10;
399    let mid = nr_marks / 2; // 5
400
401    if neg_counter < mid || pos_counter < mid {
402        crate::info(&format!(
403            "WARNING: number of positive or negative decision values <{}; \
404             too few to do a probability estimation.\n",
405            mid
406        ));
407        return None;
408    }
409
410    let mut tmp_marks = vec![0.0; nr_marks + 1];
411
412    for i in 0..mid {
413        tmp_marks[i] = dec_values[i * neg_counter / mid];
414    }
415    tmp_marks[mid] = 0.0;
416    for i in (mid + 1)..=nr_marks {
417        tmp_marks[i] = dec_values[neg_counter - 1 + (i - mid) * pos_counter / mid];
418    }
419
420    let mut marks = vec![0.0; nr_marks];
421    for i in 0..nr_marks {
422        marks[i] = (tmp_marks[i] + tmp_marks[i + 1]) / 2.0;
423    }
424
425    Some(marks)
426}
427
// ─── SVR probability ─────────────────────────────────────────────────
429
430/// Estimate Laplace scale parameter for SVR probability.
431///
432/// Performs 5-fold CV to get residuals, computes MAE, then applies
433/// outlier rejection (exclude |residual| > 5·√(2·mae²)) and
434/// recomputes. Returns the final MAE (= σ of Laplace distribution).
435///
436/// Matches LIBSVM's `svm_svr_probability`.
437pub fn svm_svr_probability(prob: &SvmProblem, param: &SvmParameter) -> f64 {
438    let l = prob.labels.len();
439    let nr_fold = 5;
440
441    let mut newparam = param.clone();
442    newparam.probability = false;
443    let ymv = crate::cross_validation::svm_cross_validation(prob, &newparam, nr_fold);
444
445    // Compute residuals and initial MAE
446    let mut ymv_residuals: Vec<f64> = Vec::with_capacity(l);
447    let mut mae = 0.0;
448    for (&label, &pred) in prob.labels.iter().zip(ymv.iter()) {
449        let r = label - pred;
450        ymv_residuals.push(r);
451        mae += r.abs();
452    }
453    mae /= l as f64;
454
455    // Outlier rejection
456    let std_val = (2.0 * mae * mae).sqrt();
457    let mut count = 0usize;
458    mae = 0.0;
459    for &residual in &ymv_residuals {
460        if residual.abs() > 5.0 * std_val {
461            count += 1;
462        } else {
463            mae += residual.abs();
464        }
465    }
466    mae /= (l - count) as f64;
467
468    crate::info(&format!(
469        "Prob. model for test data: target value = predicted value + z,\n\
470         z: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= {:.6}\n",
471        mae
472    ));
473
474    mae
475}
476
// ─── Tests ───────────────────────────────────────────────────────────
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_problem;
    use crate::train::svm_train;
    use crate::types::{KernelType, SvmNode, SvmParameter, SvmProblem, SvmType};
    use std::path::PathBuf;

    /// Directory holding the shared sample datasets (`heart_scale`,
    /// `housing_scale`), two levels above this crate's manifest directory.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    #[test]
    fn sigmoid_predict_symmetric() {
        let p = sigmoid_predict(0.0, 0.0, 0.0);
        assert!((p - 0.5).abs() < 1e-10);
    }

    #[test]
    fn sigmoid_predict_stable() {
        let p1 = sigmoid_predict(1000.0, 1.0, 0.0);
        assert!(p1.is_finite() && (0.0..=1.0).contains(&p1));
        // A·f + B = 1000, so P(y = 1) must be (numerically) zero.
        assert!(p1 < 1e-12);

        let p2 = sigmoid_predict(-1000.0, 1.0, 0.0);
        assert!(p2.is_finite() && (0.0..=1.0).contains(&p2));
        assert!(p2 > 1.0 - 1e-12);
    }

    #[test]
    fn sigmoid_train_basic() {
        let dec = vec![1.0, 2.0, -1.0, -2.0, 0.5];
        let lab = vec![1.0, 1.0, -1.0, -1.0, 1.0];
        let (a, b) = sigmoid_train(&dec, &lab);
        assert!(a.is_finite());
        assert!(b.is_finite());
        // Positive labels sit at larger decision values, so the fitted
        // sigmoid must increase with f, i.e. A < 0.
        assert!(a < 0.0, "expected A < 0, got {}", a);
        assert!(sigmoid_predict(2.0, a, b) > sigmoid_predict(-2.0, a, b));
    }

    #[test]
    fn sigmoid_train_empty_input_returns_initial_point() {
        // No samples: B = ln(1/1) = 0 and the zero gradient stops immediately.
        assert_eq!(sigmoid_train(&[], &[]), (0.0, 0.0));
    }

    #[test]
    fn multiclass_prob_sums_to_one() {
        let k = 3;
        let r = vec![
            vec![0.0, 0.6, 0.5],
            vec![0.4, 0.0, 0.7],
            vec![0.5, 0.3, 0.0],
        ];
        let mut p = vec![0.0; k];
        multiclass_probability(k, &r, &mut p);

        let sum: f64 = p.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "probabilities sum to {}, expected ~1.0",
            sum
        );
        for &pi in &p {
            assert!(pi > 0.0, "probability should be positive, got {}", pi);
        }
    }

    #[test]
    fn multiclass_prob_two_classes_recovers_pairwise_estimate() {
        // With two classes the coupled solution is exactly (r01, r10).
        let r = vec![vec![0.0, 0.7], vec![0.3, 0.0]];
        let mut p = vec![0.0; 2];
        multiclass_probability(2, &r, &mut p);
        assert!((p[0] - 0.7).abs() < 1e-2, "got {:?}", p);
        assert!((p[1] - 0.3).abs() < 1e-2, "got {:?}", p);
    }

    #[test]
    fn predict_one_class_prob_boundaries() {
        let marks = vec![-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9];
        assert!((predict_one_class_probability(&marks, -1.0) - 0.001).abs() < 1e-10);
        assert!((predict_one_class_probability(&marks, 1.0) - 0.999).abs() < 1e-10);
        let mid = predict_one_class_probability(&marks, 0.0);
        assert!(mid > 0.0 && mid < 1.0);
        // 0.0 lies just below marks[5] = 0.1, i.e. bin 5 of 10.
        assert!((mid - 0.5).abs() < 1e-12);
        assert!((predict_one_class_probability(&marks, -0.8) - 0.1).abs() < 1e-12);
        assert!((predict_one_class_probability(&marks, 0.8) - 0.9).abs() < 1e-12);
        // Exactly on the last mark: no strictly larger mark exists.
        assert!((predict_one_class_probability(&marks, 0.9) - 0.999).abs() < 1e-12);
    }

    #[test]
    fn predict_one_class_probability_empty_marks_is_half() {
        let p = predict_one_class_probability(&[], 0.25);
        assert!((p - 0.5).abs() < 1e-12);
    }

    #[test]
    fn binary_svc_probability_is_finite() {
        let prob = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            ..Default::default()
        };

        let (a, b) = svm_binary_svc_probability(&prob, &param, 1.0, 1.0);
        assert!(a.is_finite());
        assert!(b.is_finite());
    }

    #[test]
    fn one_class_probability_marks_generated_for_large_problem() {
        let prob = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::OneClass,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            nu: 0.5,
            ..Default::default()
        };
        let model = svm_train(&prob, &param);

        let marks = svm_one_class_probability(&prob, &model).expect("expected marks");
        assert_eq!(marks.len(), 10);
        for pair in marks.windows(2) {
            assert!(pair[0] <= pair[1]);
        }
    }

    #[test]
    fn one_class_probability_returns_none_when_too_few_samples() {
        let prob = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0],
            instances: vec![
                vec![SvmNode {
                    index: 1,
                    value: 0.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 1.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 2.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 3.0,
                }],
            ],
        };
        let param = SvmParameter {
            svm_type: SvmType::OneClass,
            kernel_type: KernelType::Linear,
            nu: 0.5,
            ..Default::default()
        };
        let model = svm_train(&prob, &param);
        assert!(svm_one_class_probability(&prob, &model).is_none());
    }

    #[test]
    fn svr_probability_is_positive_and_finite() {
        let prob = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            p: 0.1,
            ..Default::default()
        };

        let sigma = svm_svr_probability(&prob, &param);
        assert!(sigma.is_finite());
        assert!(sigma > 0.0);
    }
}