Skip to main content

libsvm_rs/
probability.rs

1//! Probability estimation functions for SVM models.
2//!
3//! Provides Platt scaling (`sigmoid_train`/`sigmoid_predict`), multiclass
4//! probability estimation, and density-based probability for one-class SVM.
5//! Matches the original LIBSVM's probability routines (svm.cpp:1714–2096).
6
7use crate::predict::predict_values;
8use crate::train::svm_train;
9use crate::types::{SvmModel, SvmParameter, SvmProblem};
10use std::sync::{Mutex, OnceLock};
11
12// ─── RNG helper ──────────────────────────────────────────────────────
13
#[cfg(target_os = "macos")]
fn c_rand() -> usize {
    // Process-wide RNG seed, lazily initialised to 1 (libc's srand default).
    static STATE: OnceLock<Mutex<u32>> = OnceLock::new();
    let state = STATE.get_or_init(|| Mutex::new(1));
    // Recover the seed even if a previous holder panicked.
    let mut seed = state.lock().unwrap_or_else(|poisoned| poisoned.into_inner());

    // macOS/BSD libc rand(): Park-Miller (MINSTD) with modulus 2^31 − 1,
    // computed via Schrage's decomposition so no intermediate overflows.
    let quotient = *seed / 127_773;
    let remainder = *seed % 127_773;
    let candidate = 16_807_i64 * i64::from(remainder) - 2_836_i64 * i64::from(quotient);
    *seed = if candidate > 0 {
        candidate as u32
    } else {
        (candidate + 2_147_483_647) as u32
    };
    *seed as usize
}
31
#[cfg(not(target_os = "macos"))]
fn c_rand() -> usize {
    // Process-wide 64-bit LCG state, lazily initialised to 1.
    static STATE: OnceLock<Mutex<u64>> = OnceLock::new();
    let state = STATE.get_or_init(|| Mutex::new(1));
    // Recover the state even if a previous holder panicked.
    let mut seed = state.lock().unwrap_or_else(|poisoned| poisoned.into_inner());

    // Deterministic fallback for non-macOS builds: Knuth's MMIX LCG
    // constants; return the upper 31 bits, which are far better mixed
    // than the low-order ones of a power-of-two-modulus LCG.
    *seed = seed
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    (*seed >> 33) as usize
}
44
45// ─── Platt scaling ───────────────────────────────────────────────────
46
/// Train Platt scaling parameters (A, B) via Newton's method.
///
/// Given decision values and labels (+1/−1), fits the sigmoid
/// P(y=1|f) = 1/(1+exp(A*f+B)) using the algorithm of Lin, Lin &
/// Weng (2007). Matches LIBSVM's `sigmoid_train`.
pub fn sigmoid_train(dec_values: &[f64], labels: &[f64]) -> (f64, f64) {
    const MAX_ITER: usize = 100; // Newton iterations
    const MIN_STEP: f64 = 1e-10; // smallest line-search step
    const SIGMA: f64 = 1e-12; // Hessian ridge term (H' = H + σI)
    const EPS: f64 = 1e-5; // gradient convergence threshold

    // Class priors from the label signs.
    let mut prior1: f64 = 0.0;
    let mut prior0: f64 = 0.0;
    for &y in labels {
        if y > 0.0 {
            prior1 += 1.0;
        } else {
            prior0 += 1.0;
        }
    }

    // Smoothed target probabilities (avoid exact 0/1 targets).
    let hi_target = (prior1 + 1.0) / (prior1 + 2.0);
    let lo_target = 1.0 / (prior0 + 2.0);
    let targets: Vec<f64> = labels
        .iter()
        .map(|&y| if y > 0.0 { hi_target } else { lo_target })
        .collect();

    // Negative log-likelihood of the sigmoid fit, evaluated in a
    // numerically stable form (branch on the sign of A*f+B).
    let objective = |a: f64, b: f64| -> f64 {
        dec_values
            .iter()
            .zip(&targets)
            .map(|(&f, &t)| {
                let f_apb = f * a + b;
                if f_apb >= 0.0 {
                    t * f_apb + (1.0 + (-f_apb).exp()).ln()
                } else {
                    (t - 1.0) * f_apb + (1.0 + f_apb.exp()).ln()
                }
            })
            .sum()
    };

    // Initial point and objective value.
    let mut a = 0.0;
    let mut b = ((prior0 + 1.0) / (prior1 + 1.0)).ln();
    let mut fval = objective(a, b);

    for _ in 0..MAX_ITER {
        // Gradient g and ridge-regularised Hessian H' = H + σI.
        let mut h11 = SIGMA;
        let mut h22 = SIGMA;
        let mut h21 = 0.0;
        let mut g1 = 0.0;
        let mut g2 = 0.0;
        for (&f, &t) in dec_values.iter().zip(&targets) {
            let f_apb = f * a + b;
            let (p, q) = if f_apb >= 0.0 {
                let e = (-f_apb).exp();
                (e / (1.0 + e), 1.0 / (1.0 + e))
            } else {
                let e = f_apb.exp();
                (1.0 / (1.0 + e), e / (1.0 + e))
            };
            let d2 = p * q;
            h11 += f * f * d2;
            h22 += d2;
            h21 += f * d2;
            let d1 = t - p;
            g1 += f * d1;
            g2 += d1;
        }

        // Converged once both gradient components are small.
        if g1.abs() < EPS && g2.abs() < EPS {
            break;
        }

        // Newton direction −H'⁻¹g via the explicit 2×2 inverse.
        let det = h11 * h22 - h21 * h21;
        let da = -(h22 * g1 - h21 * g2) / det;
        let db = -(-h21 * g1 + h11 * g2) / det;
        let gd = g1 * da + g2 * db;

        // Backtracking line search (Armijo condition, step halving).
        let mut stepsize = 1.0;
        while stepsize >= MIN_STEP {
            let new_a = a + stepsize * da;
            let new_b = b + stepsize * db;
            let newf = objective(new_a, new_b);
            if newf < fval + 0.0001 * stepsize * gd {
                a = new_a;
                b = new_b;
                fval = newf;
                break;
            }
            stepsize /= 2.0;
        }

        // Line search failed to find any acceptable step: give up.
        if stepsize < MIN_STEP {
            break;
        }
    }

    (a, b)
}
161
/// Numerically stable sigmoid prediction.
///
/// Returns P(y=1|f) = 1/(1+exp(A*f+B)). The exponential is always
/// taken of a non-positive quantity, so neither branch can overflow
/// for large |A*f+B|. Matches LIBSVM's `sigmoid_predict`.
pub fn sigmoid_predict(decision_value: f64, a: f64, b: f64) -> f64 {
    let f_apb = decision_value * a + b;
    // exp of −|f_apb|: shared by both branches of the stable formula.
    let e = (-f_apb.abs()).exp();
    if f_apb >= 0.0 {
        e / (1.0 + e)
    } else {
        1.0 / (1.0 + e)
    }
}
174
175// ─── Multiclass probability ──────────────────────────────────────────
176
/// Solve multiclass probabilities from pairwise estimates.
///
/// Given `k` classes and a k×k matrix `r` of pairwise probabilities
/// (r\[i\]\[j\] = P(class i | class i or j)), fills `p` with class
/// probabilities using the Wu-Lin-Weng iterative method.
///
/// Matches LIBSVM's `multiclass_probability`.
#[allow(clippy::needless_range_loop)]
pub fn multiclass_probability(k: usize, r: &[Vec<f64>], p: &mut [f64]) {
    let max_iter = 100.max(k);
    let eps = 0.005 / k as f64;

    // Q matrix of the coupled system: the diagonal collects Σ_j r_jt²,
    // the off-diagonals hold −r_jt·r_tj (mirrored below the diagonal).
    let mut q = vec![vec![0.0; k]; k];
    for t in 0..k {
        for j in 0..t {
            q[t][t] += r[j][t] * r[j][t];
            q[t][j] = q[j][t];
        }
        for j in (t + 1)..k {
            q[t][t] += r[j][t] * r[j][t];
            q[t][j] = -r[j][t] * r[t][j];
        }
    }

    // Start from the uniform distribution.
    p.iter_mut().take(k).for_each(|pi| *pi = 1.0 / k as f64);

    let mut qp = vec![0.0; k];
    for _ in 0..max_iter {
        // qp = Q·p and the scalar pᵀQp.
        let mut p_qp = 0.0;
        for t in 0..k {
            qp[t] = q[t].iter().zip(p.iter()).map(|(&qv, &pv)| qv * pv).sum();
            p_qp += p[t] * qp[t];
        }

        // KKT-style stopping rule: each component of Qp close to pᵀQp.
        let max_error = (0..k).map(|t| (qp[t] - p_qp).abs()).fold(0.0_f64, f64::max);
        if max_error < eps {
            break;
        }

        // One fixed-point sweep, renormalising p, qp and pᵀQp in place.
        for t in 0..k {
            let diff = (-qp[t] + p_qp) / q[t][t];
            p[t] += diff;
            p_qp = (p_qp + diff * (diff * q[t][t] + 2.0 * qp[t])) / (1.0 + diff) / (1.0 + diff);
            for j in 0..k {
                qp[j] = (qp[j] + diff * q[t][j]) / (1.0 + diff);
                p[j] /= 1.0 + diff;
            }
        }
    }
}
241
242// ─── Binary SVC probability via internal CV ──────────────────────────
243
244/// Estimate Platt scaling parameters for a binary sub-problem.
245///
246/// Performs 5-fold CV internally: trains on 4 folds with class weights
247/// (cp, cn), collects decision values on the held-out fold, then fits
248/// a sigmoid via `sigmoid_train`.
249///
250/// Matches LIBSVM's `svm_binary_svc_probability`.
251pub fn svm_binary_svc_probability(
252    prob: &SvmProblem,
253    param: &SvmParameter,
254    cp: f64,
255    cn: f64,
256) -> (f64, f64) {
257    let l = prob.labels.len();
258    let nr_fold = 5.min(l.max(1));
259    let mut perm: Vec<usize> = (0..l).collect();
260    let mut dec_values = vec![0.0; l];
261
262    // Random shuffle (Fisher-Yates)
263    for i in 0..l {
264        let j = i + c_rand() % (l - i);
265        perm.swap(i, j);
266    }
267
268    for fold in 0..nr_fold {
269        let begin = fold * l / nr_fold;
270        let end = (fold + 1) * l / nr_fold;
271
272        // Build training sub-problem (exclude held-out fold)
273        let mut sub_instances = Vec::with_capacity(l - (end - begin));
274        let mut sub_labels = Vec::with_capacity(l - (end - begin));
275
276        for &pi in &perm[..begin] {
277            sub_instances.push(prob.instances[pi].clone());
278            sub_labels.push(prob.labels[pi]);
279        }
280        for &pi in &perm[end..l] {
281            sub_instances.push(prob.instances[pi].clone());
282            sub_labels.push(prob.labels[pi]);
283        }
284
285        // Count classes in training set
286        let p_count = sub_labels.iter().filter(|&&y| y > 0.0).count();
287        let n_count = sub_labels.len() - p_count;
288
289        if p_count == 0 && n_count == 0 {
290            for j in begin..end {
291                dec_values[perm[j]] = 0.0;
292            }
293        } else if p_count > 0 && n_count == 0 {
294            for j in begin..end {
295                dec_values[perm[j]] = 1.0;
296            }
297        } else if p_count == 0 && n_count > 0 {
298            for j in begin..end {
299                dec_values[perm[j]] = -1.0;
300            }
301        } else {
302            let mut subparam = param.clone();
303            subparam.probability = false;
304            subparam.c = 1.0;
305            subparam.weight = vec![(1, cp), (-1, cn)];
306
307            let subprob = SvmProblem {
308                labels: sub_labels,
309                instances: sub_instances,
310            };
311            let submodel = svm_train(&subprob, &subparam);
312
313            for j in begin..end {
314                let mut dv = [0.0];
315                predict_values(&submodel, &prob.instances[perm[j]], &mut dv);
316                // Sign correction: ensure +1/−1 ordering
317                dec_values[perm[j]] = dv[0] * submodel.label[0] as f64;
318            }
319        }
320    }
321
322    sigmoid_train(&dec_values, &prob.labels)
323}
324
325// ─── One-class probability ───────────────────────────────────────────
326
/// Predict probability for one-class SVM from density marks.
///
/// Bin-lookup in precomputed density marks (10 entries). Returns a
/// probability estimate in (0, 1).
///
/// Matches LIBSVM's `predict_one_class_probability`.
pub fn predict_one_class_probability(prob_density_marks: &[f64], dec_value: f64) -> f64 {
    let nr_marks = prob_density_marks.len();
    if nr_marks == 0 {
        // No density information available: answer neutrally.
        return 0.5;
    }

    // Clamp values outside the observed density range.
    if dec_value < prob_density_marks[0] {
        return 0.001;
    }
    if dec_value > prob_density_marks[nr_marks - 1] {
        return 0.999;
    }

    // The first interior mark strictly above the value determines the bin.
    match prob_density_marks[1..]
        .iter()
        .position(|&mark| dec_value < mark)
    {
        Some(offset) => (offset + 1) as f64 / nr_marks as f64,
        // dec_value equals the last mark: fall back to the top estimate.
        None => 0.999,
    }
}
359
360/// Estimate probability density marks for one-class SVM.
361///
362/// Predicts all training instances, sorts decision values, bins into
363/// 10 density marks. Returns `None` if fewer than 5 positive or 5
364/// negative decision values.
365///
366/// Matches LIBSVM's `svm_one_class_probability`.
367pub fn svm_one_class_probability(prob: &SvmProblem, model: &SvmModel) -> Option<Vec<f64>> {
368    let l = prob.labels.len();
369    let mut dec_values = vec![0.0; l];
370
371    for (dv_slot, instance) in dec_values.iter_mut().zip(prob.instances.iter()) {
372        let mut dv = [0.0];
373        predict_values(model, instance, &mut dv);
374        *dv_slot = dv[0];
375    }
376
377    dec_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
378
379    // Find first index with dec_value >= 0  (= neg_counter in C++)
380    let mut neg_counter = 0usize;
381    for (i, &dv) in dec_values.iter().enumerate() {
382        if dv >= 0.0 {
383            neg_counter = i;
384            break;
385        }
386    }
387    let pos_counter = l - neg_counter;
388
389    let nr_marks: usize = 10;
390    let mid = nr_marks / 2; // 5
391
392    if neg_counter < mid || pos_counter < mid {
393        crate::info(&format!(
394            "WARNING: number of positive or negative decision values <{}; \
395             too few to do a probability estimation.\n",
396            mid
397        ));
398        return None;
399    }
400
401    let mut tmp_marks = vec![0.0; nr_marks + 1];
402
403    for i in 0..mid {
404        tmp_marks[i] = dec_values[i * neg_counter / mid];
405    }
406    tmp_marks[mid] = 0.0;
407    for i in (mid + 1)..=nr_marks {
408        tmp_marks[i] = dec_values[neg_counter - 1 + (i - mid) * pos_counter / mid];
409    }
410
411    let mut marks = vec![0.0; nr_marks];
412    for i in 0..nr_marks {
413        marks[i] = (tmp_marks[i] + tmp_marks[i + 1]) / 2.0;
414    }
415
416    Some(marks)
417}
418
419// ─── SVR probability ─────────────────────────────────────────────────
420
421/// Estimate Laplace scale parameter for SVR probability.
422///
423/// Performs 5-fold CV to get residuals, computes MAE, then applies
424/// outlier rejection (exclude |residual| > 5·√(2·mae²)) and
425/// recomputes. Returns the final MAE (= σ of Laplace distribution).
426///
427/// Matches LIBSVM's `svm_svr_probability`.
428pub fn svm_svr_probability(prob: &SvmProblem, param: &SvmParameter) -> f64 {
429    let l = prob.labels.len();
430    let nr_fold = 5;
431
432    let mut newparam = param.clone();
433    newparam.probability = false;
434    let ymv = crate::cross_validation::svm_cross_validation(prob, &newparam, nr_fold);
435
436    // Compute residuals and initial MAE
437    let mut ymv_residuals: Vec<f64> = Vec::with_capacity(l);
438    let mut mae = 0.0;
439    for (&label, &pred) in prob.labels.iter().zip(ymv.iter()) {
440        let r = label - pred;
441        ymv_residuals.push(r);
442        mae += r.abs();
443    }
444    mae /= l as f64;
445
446    // Outlier rejection
447    let std_val = (2.0 * mae * mae).sqrt();
448    let mut count = 0usize;
449    mae = 0.0;
450    for &residual in &ymv_residuals {
451        if residual.abs() > 5.0 * std_val {
452            count += 1;
453        } else {
454            mae += residual.abs();
455        }
456    }
457    mae /= (l - count) as f64;
458
459    crate::info(&format!(
460        "Prob. model for test data: target value = predicted value + z,\n\
461         z: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= {:.6}\n",
462        mae
463    ));
464
465    mae
466}
467
468// ─── Tests ───────────────────────────────────────────────────────────
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_problem;
    use crate::train::svm_train;
    use crate::types::{KernelType, SvmNode, SvmParameter, SvmProblem, SvmType};
    use std::path::PathBuf;

    // Shared fixture directory: <workspace>/data, two levels above this
    // crate's manifest.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    // With A = B = 0 the sigmoid reduces to 1/(1+e^0) = 0.5.
    #[test]
    fn sigmoid_predict_symmetric() {
        let p = sigmoid_predict(0.0, 0.0, 0.0);
        assert!((p - 0.5).abs() < 1e-10);
    }

    // Huge |A*f+B| must not overflow exp(): results stay finite in [0, 1].
    #[test]
    fn sigmoid_predict_stable() {
        let p1 = sigmoid_predict(1000.0, 1.0, 0.0);
        assert!(p1.is_finite() && (0.0..=1.0).contains(&p1));

        let p2 = sigmoid_predict(-1000.0, 1.0, 0.0);
        assert!(p2.is_finite() && (0.0..=1.0).contains(&p2));
    }

    // Newton fit on a tiny near-separable set should yield finite (A, B).
    #[test]
    fn sigmoid_train_basic() {
        let dec = vec![1.0, 2.0, -1.0, -2.0, 0.5];
        let lab = vec![1.0, 1.0, -1.0, -1.0, 1.0];
        let (a, b) = sigmoid_train(&dec, &lab);
        assert!(a.is_finite());
        assert!(b.is_finite());
    }

    // Wu-Lin-Weng coupling must produce a proper probability distribution.
    #[test]
    fn multiclass_prob_sums_to_one() {
        let k = 3;
        let r = vec![
            vec![0.0, 0.6, 0.5],
            vec![0.4, 0.0, 0.7],
            vec![0.5, 0.3, 0.0],
        ];
        let mut p = vec![0.0; k];
        multiclass_probability(k, &r, &mut p);

        let sum: f64 = p.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "probabilities sum to {}, expected ~1.0",
            sum
        );
        for &pi in &p {
            assert!(pi > 0.0, "probability should be positive, got {}", pi);
        }
    }

    // Values outside the mark range clamp to 0.001 / 0.999; interior
    // values map to a bin estimate strictly inside (0, 1).
    #[test]
    fn predict_one_class_prob_boundaries() {
        let marks = vec![-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9];
        assert!((predict_one_class_probability(&marks, -1.0) - 0.001).abs() < 1e-10);
        assert!((predict_one_class_probability(&marks, 1.0) - 0.999).abs() < 1e-10);
        let mid = predict_one_class_probability(&marks, 0.0);
        assert!(mid > 0.0 && mid < 1.0);
    }

    // No marks at all must yield the neutral 0.5 answer.
    #[test]
    fn predict_one_class_probability_empty_marks_is_half() {
        let p = predict_one_class_probability(&[], 0.25);
        assert!((p - 0.5).abs() < 1e-12);
    }

    // End-to-end Platt fit on real data: the internal 5-fold CV must
    // produce finite sigmoid parameters.
    #[test]
    fn binary_svc_probability_is_finite() {
        let prob = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            ..Default::default()
        };

        let (a, b) = svm_binary_svc_probability(&prob, &param, 1.0, 1.0);
        assert!(a.is_finite());
        assert!(b.is_finite());
    }

    // heart_scale is large enough to produce all 10 density marks, and
    // sorting guarantees they come out non-decreasing.
    #[test]
    fn one_class_probability_marks_generated_for_large_problem() {
        let prob = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::OneClass,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            nu: 0.5,
            ..Default::default()
        };
        let model = svm_train(&prob, &param);

        let marks = svm_one_class_probability(&prob, &model).expect("expected marks");
        assert_eq!(marks.len(), 10);
        for pair in marks.windows(2) {
            assert!(pair[0] <= pair[1]);
        }
    }

    // Four samples cannot supply 5 positive and 5 negative decision
    // values, so mark estimation must refuse with None.
    #[test]
    fn one_class_probability_returns_none_when_too_few_samples() {
        let prob = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0],
            instances: vec![
                vec![SvmNode {
                    index: 1,
                    value: 0.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 1.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 2.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 3.0,
                }],
            ],
        };
        let param = SvmParameter {
            svm_type: SvmType::OneClass,
            kernel_type: KernelType::Linear,
            nu: 0.5,
            ..Default::default()
        };
        let model = svm_train(&prob, &param);
        assert!(svm_one_class_probability(&prob, &model).is_none());
    }

    // The Laplace scale fitted from SVR residuals must be a positive,
    // finite sigma.
    #[test]
    fn svr_probability_is_positive_and_finite() {
        let prob = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            p: 0.1,
            ..Default::default()
        };

        let sigma = svm_svr_probability(&prob, &param);
        assert!(sigma.is_finite());
        assert!(sigma > 0.0);
    }
}