// libsvm_rs/probability.rs
1//! Probability estimation functions for SVM models.
2//!
3//! Provides Platt scaling (`sigmoid_train`/`sigmoid_predict`), multiclass
4//! probability estimation, and density-based probability for one-class SVM.
5//! Matches the original LIBSVM's probability routines (svm.cpp:1714–2096).
6
7use crate::predict::predict_values;
8use crate::train::svm_train;
9use crate::types::{SvmModel, SvmParameter, SvmProblem};
10use std::sync::{Mutex, OnceLock};
11
12// ─── RNG helper ──────────────────────────────────────────────────────
13
#[cfg(target_os = "macos")]
fn c_rand() -> usize {
    static STATE: OnceLock<Mutex<u32>> = OnceLock::new();
    let mutex = STATE.get_or_init(|| Mutex::new(1));
    let mut seed = mutex
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    // Park–Miller "minimal standard" generator (what macOS/BSD libc rand()
    // uses): next = 16807 * seed mod (2^31 − 1), evaluated with Schrage's
    // split (q = 127773, r = 2836) so intermediates fit in 64 bits.
    let quotient = *seed / 127_773;
    let remainder = *seed % 127_773;
    let candidate = 16_807_i64 * i64::from(remainder) - 2_836_i64 * i64::from(quotient);
    *seed = if candidate > 0 {
        candidate as u32
    } else {
        (candidate + 2_147_483_647) as u32
    };
    *seed as usize
}
33
#[cfg(not(target_os = "macos"))]
fn c_rand() -> usize {
    static STATE: OnceLock<Mutex<u64>> = OnceLock::new();
    let mutex = STATE.get_or_init(|| Mutex::new(1));
    let mut seed = mutex
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    // Deterministic fallback for non-macOS builds: a 64-bit LCG with
    // Knuth's MMIX constants. The high bits are the best-distributed,
    // so return bits 33..64 (a 31-bit value, like libc rand()).
    let next = seed
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    *seed = next;
    (next >> 33) as usize
}
48
49// ─── Platt scaling ───────────────────────────────────────────────────
50
/// Train Platt scaling parameters (A, B) via Newton's method.
///
/// Given decision values and labels (+1/−1), fits the sigmoid
/// P(y=1|f) = 1/(1+exp(A*f+B)) using the algorithm of Lin, Lin &
/// Weng (2007). Matches LIBSVM's `sigmoid_train`.
pub fn sigmoid_train(dec_values: &[f64], labels: &[f64]) -> (f64, f64) {
    let l = dec_values.len();

    // Class priors: counts of positive / negative labels.
    let prior1 = labels.iter().filter(|&&y| y > 0.0).count() as f64;
    let prior0 = labels.len() as f64 - prior1;

    let max_iter = 100; // maximal number of Newton iterations
    let min_step = 1e-10; // minimal step taken in the line search
    let sigma = 1e-12; // Hessian regularization (H' = H + σI)
    let eps = 1e-5; // gradient norm stopping tolerance

    // Soft target probabilities (avoid exact 0/1, per Platt's note).
    let hi_target = (prior1 + 1.0) / (prior1 + 2.0);
    let lo_target = 1.0 / (prior0 + 2.0);
    let t: Vec<f64> = labels
        .iter()
        .map(|&y| if y > 0.0 { hi_target } else { lo_target })
        .collect();

    // Negative log-likelihood of the sigmoid fit, branching on the sign
    // of A*f+B for numerical stability. Shared by the initial objective
    // and the line search (previously duplicated verbatim).
    let objective = |a: f64, b: f64| -> f64 {
        dec_values
            .iter()
            .zip(t.iter())
            .map(|(&f, &ti)| {
                let f_apb = f * a + b;
                if f_apb >= 0.0 {
                    ti * f_apb + (1.0 + (-f_apb).exp()).ln()
                } else {
                    (ti - 1.0) * f_apb + (1.0 + f_apb.exp()).ln()
                }
            })
            .sum()
    };

    // Initial point and objective value.
    let mut a = 0.0;
    let mut b = ((prior0 + 1.0) / (prior1 + 1.0)).ln();
    let mut fval = objective(a, b);

    for _iter in 0..max_iter {
        // Gradient and regularized Hessian (H' = H + σI)
        let mut h11 = sigma;
        let mut h22 = sigma;
        let mut h21 = 0.0;
        let mut g1 = 0.0;
        let mut g2 = 0.0;

        for i in 0..l {
            let f_apb = dec_values[i] * a + b;
            // p = P(y=1|f), q = 1 − p, computed stably on either sign.
            let (p, q) = if f_apb >= 0.0 {
                let e = (-f_apb).exp();
                (e / (1.0 + e), 1.0 / (1.0 + e))
            } else {
                let e = f_apb.exp();
                (1.0 / (1.0 + e), e / (1.0 + e))
            };
            let d2 = p * q;
            h11 += dec_values[i] * dec_values[i] * d2;
            h22 += d2;
            h21 += dec_values[i] * d2;
            let d1 = t[i] - p;
            g1 += dec_values[i] * d1;
            g2 += d1;
        }

        // Converged when the gradient is (component-wise) small.
        if g1.abs() < eps && g2.abs() < eps {
            break;
        }

        // Newton direction: −H'⁻¹ g, via the 2×2 closed-form inverse.
        let det = h11 * h22 - h21 * h21;
        let da = -(h22 * g1 - h21 * g2) / det;
        let db = -(-h21 * g1 + h11 * g2) / det;
        let gd = g1 * da + g2 * db;

        // Backtracking line search with Armijo-style sufficient decrease.
        let mut stepsize = 1.0;
        while stepsize >= min_step {
            let new_a = a + stepsize * da;
            let new_b = b + stepsize * db;
            let newf = objective(new_a, new_b);

            if newf < fval + 0.0001 * stepsize * gd {
                a = new_a;
                b = new_b;
                fval = newf;
                break;
            }
            stepsize /= 2.0;
        }

        // Line search failed to find any acceptable step: give up.
        if stepsize < min_step {
            break;
        }
    }

    (a, b)
}
165
/// Numerically stable sigmoid prediction.
///
/// Returns P(y=1|f) = 1/(1+exp(A*f+B)). The two algebraically equal
/// branches keep the exponent non-positive so `exp` never overflows.
/// Matches LIBSVM's `sigmoid_predict`.
pub fn sigmoid_predict(decision_value: f64, a: f64, b: f64) -> f64 {
    let z = a * decision_value + b;
    if z < 0.0 {
        1.0 / (1.0 + z.exp())
    } else {
        let e = (-z).exp();
        e / (1.0 + e)
    }
}
178
179// ─── Multiclass probability ──────────────────────────────────────────
180
/// Solve multiclass probabilities from pairwise estimates.
///
/// Given `k` classes and a k×k matrix `r` of pairwise probabilities
/// (r\[i\]\[j\] = P(class i | class i or j)), fills `p` with class
/// probabilities using the Wu-Lin-Weng iterative method.
///
/// Matches LIBSVM's `multiclass_probability`.
#[allow(clippy::needless_range_loop)]
pub fn multiclass_probability(k: usize, r: &[Vec<f64>], p: &mut [f64]) {
    let max_iter = 100.max(k);
    let eps = 0.005 / k as f64;

    // Q[t][t] = Σ_{j≠t} r[j][t]²,  Q[t][j] = −r[j][t]·r[t][j]  (symmetric).
    let mut q = vec![vec![0.0_f64; k]; k];
    for t in 0..k {
        for j in 0..t {
            q[t][t] += r[j][t] * r[j][t];
            q[t][j] = q[j][t];
        }
        for j in (t + 1)..k {
            q[t][t] += r[j][t] * r[j][t];
            q[t][j] = -r[j][t] * r[t][j];
        }
    }

    // Start from the uniform distribution.
    p.iter_mut().for_each(|pt| *pt = 1.0 / k as f64);

    let mut qp = vec![0.0_f64; k];
    for _iter in 0..max_iter {
        // qp = Q·p, and the scalar pᵀQp.
        let mut p_qp = 0.0;
        for t in 0..k {
            let mut row = 0.0;
            for j in 0..k {
                row += q[t][j] * p[j];
            }
            qp[t] = row;
            p_qp += p[t] * row;
        }

        // Converged when every component satisfies |(Qp)_t − pᵀQp| < eps.
        let worst = qp
            .iter()
            .fold(0.0_f64, |acc, &v| acc.max((v - p_qp).abs()));
        if worst < eps {
            break;
        }

        // One sweep of the fixed-point update, renormalizing p (and the
        // cached qp / pᵀQp) after each coordinate step.
        for t in 0..k {
            let diff = (-qp[t] + p_qp) / q[t][t];
            p[t] += diff;
            p_qp = (p_qp + diff * (diff * q[t][t] + 2.0 * qp[t])) / (1.0 + diff) / (1.0 + diff);
            for j in 0..k {
                qp[j] = (qp[j] + diff * q[t][j]) / (1.0 + diff);
                p[j] /= 1.0 + diff;
            }
        }
    }
}
245
246// ─── Binary SVC probability via internal CV ──────────────────────────
247
248/// Estimate Platt scaling parameters for a binary sub-problem.
249///
250/// Performs 5-fold CV internally: trains on 4 folds with class weights
251/// (cp, cn), collects decision values on the held-out fold, then fits
252/// a sigmoid via `sigmoid_train`.
253///
254/// Matches LIBSVM's `svm_binary_svc_probability`.
255pub fn svm_binary_svc_probability(
256    prob: &SvmProblem,
257    param: &SvmParameter,
258    cp: f64,
259    cn: f64,
260) -> (f64, f64) {
261    let l = prob.labels.len();
262    let nr_fold = 5.min(l.max(1));
263    let mut perm: Vec<usize> = (0..l).collect();
264    let mut dec_values = vec![0.0; l];
265
266    // Random shuffle (Fisher-Yates)
267    for i in 0..l {
268        let j = i + c_rand() % (l - i);
269        perm.swap(i, j);
270    }
271
272    for fold in 0..nr_fold {
273        let begin = fold * l / nr_fold;
274        let end = (fold + 1) * l / nr_fold;
275
276        // Build training sub-problem (exclude held-out fold)
277        let mut sub_instances = Vec::with_capacity(l - (end - begin));
278        let mut sub_labels = Vec::with_capacity(l - (end - begin));
279
280        for &pi in &perm[..begin] {
281            sub_instances.push(prob.instances[pi].clone());
282            sub_labels.push(prob.labels[pi]);
283        }
284        for &pi in &perm[end..l] {
285            sub_instances.push(prob.instances[pi].clone());
286            sub_labels.push(prob.labels[pi]);
287        }
288
289        // Count classes in training set
290        let p_count = sub_labels.iter().filter(|&&y| y > 0.0).count();
291        let n_count = sub_labels.len() - p_count;
292
293        if p_count == 0 && n_count == 0 {
294            for j in begin..end {
295                dec_values[perm[j]] = 0.0;
296            }
297        } else if p_count > 0 && n_count == 0 {
298            for j in begin..end {
299                dec_values[perm[j]] = 1.0;
300            }
301        } else if p_count == 0 && n_count > 0 {
302            for j in begin..end {
303                dec_values[perm[j]] = -1.0;
304            }
305        } else {
306            let mut subparam = param.clone();
307            subparam.probability = false;
308            subparam.c = 1.0;
309            subparam.weight = vec![(1, cp), (-1, cn)];
310
311            let subprob = SvmProblem {
312                labels: sub_labels,
313                instances: sub_instances,
314            };
315            let submodel = svm_train(&subprob, &subparam);
316
317            for j in begin..end {
318                let mut dv = [0.0];
319                predict_values(&submodel, &prob.instances[perm[j]], &mut dv);
320                // Sign correction: ensure +1/−1 ordering
321                dec_values[perm[j]] = dv[0] * submodel.label[0] as f64;
322            }
323        }
324    }
325
326    sigmoid_train(&dec_values, &prob.labels)
327}
328
329// ─── One-class probability ───────────────────────────────────────────
330
/// Predict probability for one-class SVM from density marks.
///
/// Looks up `dec_value` among the sorted density marks and returns the
/// bin position as a probability in (0, 1): values below the first mark
/// map to 0.001, values above the last to 0.999.
///
/// Matches LIBSVM's `predict_one_class_probability`.
pub fn predict_one_class_probability(prob_density_marks: &[f64], dec_value: f64) -> f64 {
    let nr_marks = prob_density_marks.len();

    // No density information at all: return a noncommittal estimate.
    let (&first, &last) = match (prob_density_marks.first(), prob_density_marks.last()) {
        (Some(f), Some(l)) => (f, l),
        _ => return 0.5,
    };

    if dec_value < first {
        return 0.001;
    }
    if dec_value > last {
        return 0.999;
    }

    // Return the index of the first interior mark exceeding dec_value,
    // scaled to a probability.
    for i in 1..nr_marks {
        if dec_value < prob_density_marks[i] {
            return i as f64 / nr_marks as f64;
        }
    }

    // dec_value coincides with the last mark (not strictly beyond it).
    0.999
}
363
364/// Estimate probability density marks for one-class SVM.
365///
366/// Predicts all training instances, sorts decision values, bins into
367/// 10 density marks. Returns `None` if fewer than 5 positive or 5
368/// negative decision values.
369///
370/// Matches LIBSVM's `svm_one_class_probability`.
371pub fn svm_one_class_probability(prob: &SvmProblem, model: &SvmModel) -> Option<Vec<f64>> {
372    let l = prob.labels.len();
373    let mut dec_values = vec![0.0; l];
374
375    for (dv_slot, instance) in dec_values.iter_mut().zip(prob.instances.iter()) {
376        let mut dv = [0.0];
377        predict_values(model, instance, &mut dv);
378        *dv_slot = dv[0];
379    }
380
381    dec_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
382
383    // Find first index with dec_value >= 0  (= neg_counter in C++)
384    let mut neg_counter = 0usize;
385    for (i, &dv) in dec_values.iter().enumerate() {
386        if dv >= 0.0 {
387            neg_counter = i;
388            break;
389        }
390    }
391    let pos_counter = l - neg_counter;
392
393    let nr_marks: usize = 10;
394    let mid = nr_marks / 2; // 5
395
396    if neg_counter < mid || pos_counter < mid {
397        crate::info(&format!(
398            "WARNING: number of positive or negative decision values <{}; \
399             too few to do a probability estimation.\n",
400            mid
401        ));
402        return None;
403    }
404
405    let mut tmp_marks = vec![0.0; nr_marks + 1];
406
407    for i in 0..mid {
408        tmp_marks[i] = dec_values[i * neg_counter / mid];
409    }
410    tmp_marks[mid] = 0.0;
411    for i in (mid + 1)..=nr_marks {
412        tmp_marks[i] = dec_values[neg_counter - 1 + (i - mid) * pos_counter / mid];
413    }
414
415    let mut marks = vec![0.0; nr_marks];
416    for i in 0..nr_marks {
417        marks[i] = (tmp_marks[i] + tmp_marks[i + 1]) / 2.0;
418    }
419
420    Some(marks)
421}
422
423// ─── SVR probability ─────────────────────────────────────────────────
424
425/// Estimate Laplace scale parameter for SVR probability.
426///
427/// Performs 5-fold CV to get residuals, computes MAE, then applies
428/// outlier rejection (exclude |residual| > 5·√(2·mae²)) and
429/// recomputes. Returns the final MAE (= σ of Laplace distribution).
430///
431/// Matches LIBSVM's `svm_svr_probability`.
432pub fn svm_svr_probability(prob: &SvmProblem, param: &SvmParameter) -> f64 {
433    let l = prob.labels.len();
434    let nr_fold = 5;
435
436    let mut newparam = param.clone();
437    newparam.probability = false;
438    let ymv = crate::cross_validation::svm_cross_validation(prob, &newparam, nr_fold);
439
440    // Compute residuals and initial MAE
441    let mut ymv_residuals: Vec<f64> = Vec::with_capacity(l);
442    let mut mae = 0.0;
443    for (&label, &pred) in prob.labels.iter().zip(ymv.iter()) {
444        let r = label - pred;
445        ymv_residuals.push(r);
446        mae += r.abs();
447    }
448    mae /= l as f64;
449
450    // Outlier rejection
451    let std_val = (2.0 * mae * mae).sqrt();
452    let mut count = 0usize;
453    mae = 0.0;
454    for &residual in &ymv_residuals {
455        if residual.abs() > 5.0 * std_val {
456            count += 1;
457        } else {
458            mae += residual.abs();
459        }
460    }
461    mae /= (l - count) as f64;
462
463    crate::info(&format!(
464        "Prob. model for test data: target value = predicted value + z,\n\
465         z: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= {:.6}\n",
466        mae
467    ));
468
469    mae
470}
471
472// ─── Tests ───────────────────────────────────────────────────────────
473
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_problem;
    use crate::train::svm_train;
    use crate::types::{KernelType, SvmNode, SvmParameter, SvmProblem, SvmType};
    use std::path::PathBuf;

    /// Path to the shared test-data directory (two levels above the
    /// crate manifest, then `data/`).
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    // Zero coefficients and zero decision value must give P = 0.5.
    #[test]
    fn sigmoid_predict_symmetric() {
        let p = sigmoid_predict(0.0, 0.0, 0.0);
        assert!((p - 0.5).abs() < 1e-10);
    }

    // Extreme arguments must not overflow exp(): result stays in [0, 1].
    #[test]
    fn sigmoid_predict_stable() {
        let p1 = sigmoid_predict(1000.0, 1.0, 0.0);
        assert!(p1.is_finite() && (0.0..=1.0).contains(&p1));

        let p2 = sigmoid_predict(-1000.0, 1.0, 0.0);
        assert!(p2.is_finite() && (0.0..=1.0).contains(&p2));
    }

    // Platt fit on a small separable set should converge to finite (A, B).
    #[test]
    fn sigmoid_train_basic() {
        let dec = vec![1.0, 2.0, -1.0, -2.0, 0.5];
        let lab = vec![1.0, 1.0, -1.0, -1.0, 1.0];
        let (a, b) = sigmoid_train(&dec, &lab);
        assert!(a.is_finite());
        assert!(b.is_finite());
    }

    // Wu-Lin-Weng coupling must produce a valid probability distribution.
    #[test]
    fn multiclass_prob_sums_to_one() {
        let k = 3;
        let r = vec![
            vec![0.0, 0.6, 0.5],
            vec![0.4, 0.0, 0.7],
            vec![0.5, 0.3, 0.0],
        ];
        let mut p = vec![0.0; k];
        multiclass_probability(k, &r, &mut p);

        let sum: f64 = p.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "probabilities sum to {}, expected ~1.0",
            sum
        );
        for &pi in &p {
            assert!(pi > 0.0, "probability should be positive, got {}", pi);
        }
    }

    // Below the first mark -> 0.001; above the last -> 0.999; interior
    // values map into (0, 1).
    #[test]
    fn predict_one_class_prob_boundaries() {
        let marks = vec![-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9];
        assert!((predict_one_class_probability(&marks, -1.0) - 0.001).abs() < 1e-10);
        assert!((predict_one_class_probability(&marks, 1.0) - 0.999).abs() < 1e-10);
        let mid = predict_one_class_probability(&marks, 0.0);
        assert!(mid > 0.0 && mid < 1.0);
    }

    // With no density marks the estimate falls back to 0.5.
    #[test]
    fn predict_one_class_probability_empty_marks_is_half() {
        let p = predict_one_class_probability(&[], 0.25);
        assert!((p - 0.5).abs() < 1e-12);
    }

    // End-to-end Platt estimation on the heart_scale dataset.
    #[test]
    fn binary_svc_probability_is_finite() {
        let prob = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            ..Default::default()
        };

        let (a, b) = svm_binary_svc_probability(&prob, &param, 1.0, 1.0);
        assert!(a.is_finite());
        assert!(b.is_finite());
    }

    // A large one-class problem should yield 10 nondecreasing marks.
    #[test]
    fn one_class_probability_marks_generated_for_large_problem() {
        let prob = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::OneClass,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            nu: 0.5,
            ..Default::default()
        };
        let model = svm_train(&prob, &param);

        let marks = svm_one_class_probability(&prob, &model).expect("expected marks");
        assert_eq!(marks.len(), 10);
        for pair in marks.windows(2) {
            assert!(pair[0] <= pair[1]);
        }
    }

    // Four samples cannot produce 5 positive and 5 negative decision
    // values, so estimation must be refused.
    #[test]
    fn one_class_probability_returns_none_when_too_few_samples() {
        let prob = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0],
            instances: vec![
                vec![SvmNode {
                    index: 1,
                    value: 0.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 1.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 2.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 3.0,
                }],
            ],
        };
        let param = SvmParameter {
            svm_type: SvmType::OneClass,
            kernel_type: KernelType::Linear,
            nu: 0.5,
            ..Default::default()
        };
        let model = svm_train(&prob, &param);
        assert!(svm_one_class_probability(&prob, &model).is_none());
    }

    // Laplace scale estimate on a regression dataset must be a positive
    // finite number.
    #[test]
    fn svr_probability_is_positive_and_finite() {
        let prob = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            p: 0.1,
            ..Default::default()
        };

        let sigma = svm_svr_probability(&prob, &param);
        assert!(sigma.is_finite());
        assert!(sigma > 0.0);
    }
}