// libsvm_rs/probability.rs

1//! Probability estimation functions for SVM models.
2//!
3//! Provides Platt scaling (`sigmoid_train`/`sigmoid_predict`), multiclass
4//! probability estimation, and density-based probability for one-class SVM.
5//! Matches the original LIBSVM's probability routines (svm.cpp:1714–2096).
6
7use crate::predict::predict_values;
8use crate::train::svm_train;
9use crate::types::{SvmModel, SvmParameter, SvmProblem};
10
11// ─── RNG helper ──────────────────────────────────────────────────────
12
/// Advance a 64-bit linear congruential generator (Knuth's MMIX
/// multiplier/increment) and return the upper bits as a `usize`.
///
/// Only the high 31 bits are returned (`>> 33`), since the low bits of
/// an LCG have poor statistical quality.
fn rng_next(state: &mut u64) -> usize {
    const MUL: u64 = 6364136223846793005;
    const INC: u64 = 1442695040888963407;
    let next = state.wrapping_mul(MUL).wrapping_add(INC);
    *state = next;
    (next >> 33) as usize
}
19
20// ─── Platt scaling ───────────────────────────────────────────────────
21
/// Train Platt scaling parameters (A, B) via Newton's method.
///
/// Given decision values and labels (+1/−1), fits the sigmoid
/// P(y=1|f) = 1/(1+exp(A*f+B)) using the algorithm of Lin, Lin &
/// Weng (2007). Matches LIBSVM's `sigmoid_train`.
///
/// Targets use the (N+1)/(N+2) smoothing from Platt's formulation so
/// the fitted probabilities never saturate at exactly 0 or 1.
pub fn sigmoid_train(dec_values: &[f64], labels: &[f64]) -> (f64, f64) {
    let l = dec_values.len();

    // Class counts: prior1 = #positive labels, prior0 = #negative.
    let mut prior1: f64 = 0.0;
    let mut prior0: f64 = 0.0;
    for &y in labels {
        if y > 0.0 {
            prior1 += 1.0;
        } else {
            prior0 += 1.0;
        }
    }

    let max_iter = 100; // maximum Newton iterations
    let min_step = 1e-10; // smallest allowed line-search step
    let sigma = 1e-12; // Hessian ridge for numerical stability
    let eps = 1e-5; // gradient stopping tolerance

    // Smoothed regression targets (avoid hard 0/1 labels).
    let hi_target = (prior1 + 1.0) / (prior1 + 2.0);
    let lo_target = 1.0 / (prior0 + 2.0);

    let t: Vec<f64> = labels
        .iter()
        .map(|&y| if y > 0.0 { hi_target } else { lo_target })
        .collect();

    // Initial point: A = 0 and B chosen so the sigmoid reproduces the
    // class prior when all decision values are zero.
    let mut a = 0.0;
    let mut b = ((prior0 + 1.0) / (prior1 + 1.0)).ln();

    // Initial objective: cross-entropy, written so the exp() argument
    // is non-positive in both branches (no overflow).
    let mut fval = 0.0;
    for i in 0..l {
        let f_apb = dec_values[i] * a + b;
        if f_apb >= 0.0 {
            fval += t[i] * f_apb + (1.0 + (-f_apb).exp()).ln();
        } else {
            fval += (t[i] - 1.0) * f_apb + (1.0 + f_apb.exp()).ln();
        }
    }

    for _iter in 0..max_iter {
        // Gradient g and ridged Hessian (H' = H + σI); the σ ridge keeps
        // H' positive definite even on separable data.
        let mut h11 = sigma;
        let mut h22 = sigma;
        let mut h21 = 0.0;
        let mut g1 = 0.0;
        let mut g2 = 0.0;

        for i in 0..l {
            let f_apb = dec_values[i] * a + b;
            // p = sigmoid output, q = 1 − p, computed overflow-free by
            // branching on the sign of the exponent.
            let (p, q) = if f_apb >= 0.0 {
                let e = (-f_apb).exp();
                (e / (1.0 + e), 1.0 / (1.0 + e))
            } else {
                let e = f_apb.exp();
                (1.0 / (1.0 + e), e / (1.0 + e))
            };
            let d2 = p * q;
            h11 += dec_values[i] * dec_values[i] * d2;
            h22 += d2;
            h21 += dec_values[i] * d2;
            let d1 = t[i] - p;
            g1 += dec_values[i] * d1;
            g2 += d1;
        }

        // Converged when the gradient is (near) zero.
        if g1.abs() < eps && g2.abs() < eps {
            break;
        }

        // Newton direction: −H'⁻¹ g, solved with the explicit 2×2 inverse.
        let det = h11 * h22 - h21 * h21;
        let da = -(h22 * g1 - h21 * g2) / det;
        let db = -(-h21 * g1 + h11 * g2) / det;
        let gd = g1 * da + g2 * db; // directional derivative (< 0 for descent)

        // Backtracking line search with step halving and an Armijo-style
        // sufficient-decrease test.
        let mut stepsize = 1.0;
        while stepsize >= min_step {
            let new_a = a + stepsize * da;
            let new_b = b + stepsize * db;

            // Objective at the candidate point (same stable form as above).
            let mut newf = 0.0;
            for i in 0..l {
                let f_apb = dec_values[i] * new_a + new_b;
                if f_apb >= 0.0 {
                    newf += t[i] * f_apb + (1.0 + (-f_apb).exp()).ln();
                } else {
                    newf += (t[i] - 1.0) * f_apb + (1.0 + f_apb.exp()).ln();
                }
            }

            if newf < fval + 0.0001 * stepsize * gd {
                // Sufficient decrease achieved: accept the step.
                a = new_a;
                b = new_b;
                fval = newf;
                break;
            }
            stepsize /= 2.0;
        }

        // Line search failed to find an acceptable step: stop with the
        // current (A, B), matching LIBSVM's behavior.
        if stepsize < min_step {
            break;
        }
    }

    (a, b)
}
136
/// Numerically stable sigmoid prediction.
///
/// Returns P(y=1|f) = 1/(1+exp(A*f+B)), branching on the sign of
/// A*f+B so the argument of `exp` is always non-positive, which
/// avoids overflow. Matches LIBSVM's `sigmoid_predict`.
pub fn sigmoid_predict(decision_value: f64, a: f64, b: f64) -> f64 {
    let f_apb = decision_value * a + b;
    // 1/(1+e^x) == e^(-x)/(1+e^(-x)); pick the form whose exponent <= 0.
    if f_apb >= 0.0 {
        // Hoist the shared exp() instead of evaluating it twice.
        let e = (-f_apb).exp(); // e in (0, 1]
        e / (1.0 + e)
    } else {
        1.0 / (1.0 + f_apb.exp())
    }
}
149
150// ─── Multiclass probability ──────────────────────────────────────────
151
/// Solve multiclass probabilities from pairwise estimates.
///
/// Given `k` classes and a k×k matrix `r` of pairwise probabilities
/// (r\[i\]\[j\] = P(class i | class i or j)), fills `p` with class
/// probabilities using the Wu-Lin-Weng iterative method.
///
/// Matches LIBSVM's `multiclass_probability`.
pub fn multiclass_probability(k: usize, r: &[Vec<f64>], p: &mut [f64]) {
    let max_iter = 100.max(k); // at least one sweep per class
    let eps = 0.005 / k as f64; // stopping tolerance scales with k

    // Build the Q matrix of the quadratic program:
    //   Q[t][t] = Σ_{j≠t} r[j][t]²,  Q[t][j] = −r[j][t]·r[t][j].
    // The lower triangle is copied from the already-computed upper one.
    let mut q_mat = vec![vec![0.0; k]; k];
    for t in 0..k {
        q_mat[t][t] = 0.0;
        for j in 0..t {
            q_mat[t][t] += r[j][t] * r[j][t];
            q_mat[t][j] = q_mat[j][t];
        }
        for j in (t + 1)..k {
            q_mat[t][t] += r[j][t] * r[j][t];
            q_mat[t][j] = -r[j][t] * r[t][j];
        }
    }

    // Start from the uniform distribution.
    for t in 0..k {
        p[t] = 1.0 / k as f64;
    }

    let mut qp = vec![0.0; k]; // Q·p, maintained incrementally below

    for _iter in 0..max_iter {
        // Recompute qp = Q·p and the scalar p_qp = pᵀQp from scratch
        // once per sweep (the in-sweep updates below are incremental).
        let mut p_qp = 0.0;
        for t in 0..k {
            qp[t] = 0.0;
            for j in 0..k {
                qp[t] += q_mat[t][j] * p[j];
            }
            p_qp += p[t] * qp[t];
        }

        // KKT residual: at the optimum every component of Qp equals pᵀQp.
        let mut max_error = 0.0;
        for t in 0..k {
            let error = (qp[t] - p_qp).abs();
            if error > max_error {
                max_error = error;
            }
        }
        if max_error < eps {
            break;
        }

        // One Gauss–Seidel sweep: adjust each p[t], renormalize, and
        // update qp / p_qp incrementally in O(k) per coordinate instead
        // of recomputing Q·p (order of operations matters here).
        for t in 0..k {
            let diff = (-qp[t] + p_qp) / q_mat[t][t];
            p[t] += diff;
            p_qp = (p_qp + diff * (diff * q_mat[t][t] + 2.0 * qp[t]))
                / (1.0 + diff)
                / (1.0 + diff);
            for j in 0..k {
                qp[j] = (qp[j] + diff * q_mat[t][j]) / (1.0 + diff);
                p[j] /= 1.0 + diff;
            }
        }
    }
}
217
218// ─── Binary SVC probability via internal CV ──────────────────────────
219
220/// Estimate Platt scaling parameters for a binary sub-problem.
221///
222/// Performs 5-fold CV internally: trains on 4 folds with class weights
223/// (cp, cn), collects decision values on the held-out fold, then fits
224/// a sigmoid via `sigmoid_train`.
225///
226/// Matches LIBSVM's `svm_binary_svc_probability`.
227pub fn svm_binary_svc_probability(
228    prob: &SvmProblem,
229    param: &SvmParameter,
230    cp: f64,
231    cn: f64,
232) -> (f64, f64) {
233    let l = prob.labels.len();
234    let nr_fold = 5;
235    let mut perm: Vec<usize> = (0..l).collect();
236    let mut dec_values = vec![0.0; l];
237
238    // Random shuffle (Fisher-Yates)
239    let mut rng: u64 = 1;
240    for i in 0..l {
241        let j = i + rng_next(&mut rng) % (l - i);
242        perm.swap(i, j);
243    }
244
245    for fold in 0..nr_fold {
246        let begin = fold * l / nr_fold;
247        let end = (fold + 1) * l / nr_fold;
248
249        // Build training sub-problem (exclude held-out fold)
250        let mut sub_instances = Vec::with_capacity(l - (end - begin));
251        let mut sub_labels = Vec::with_capacity(l - (end - begin));
252
253        for j in 0..begin {
254            sub_instances.push(prob.instances[perm[j]].clone());
255            sub_labels.push(prob.labels[perm[j]]);
256        }
257        for j in end..l {
258            sub_instances.push(prob.instances[perm[j]].clone());
259            sub_labels.push(prob.labels[perm[j]]);
260        }
261
262        // Count classes in training set
263        let p_count = sub_labels.iter().filter(|&&y| y > 0.0).count();
264        let n_count = sub_labels.len() - p_count;
265
266        if p_count == 0 && n_count == 0 {
267            for j in begin..end {
268                dec_values[perm[j]] = 0.0;
269            }
270        } else if p_count > 0 && n_count == 0 {
271            for j in begin..end {
272                dec_values[perm[j]] = 1.0;
273            }
274        } else if p_count == 0 && n_count > 0 {
275            for j in begin..end {
276                dec_values[perm[j]] = -1.0;
277            }
278        } else {
279            let mut subparam = param.clone();
280            subparam.probability = false;
281            subparam.c = 1.0;
282            subparam.weight = vec![(1, cp), (-1, cn)];
283
284            let subprob = SvmProblem {
285                labels: sub_labels,
286                instances: sub_instances,
287            };
288            let submodel = svm_train(&subprob, &subparam);
289
290            for j in begin..end {
291                let mut dv = [0.0];
292                predict_values(&submodel, &prob.instances[perm[j]], &mut dv);
293                // Sign correction: ensure +1/−1 ordering
294                dec_values[perm[j]] = dv[0] * submodel.label[0] as f64;
295            }
296        }
297    }
298
299    sigmoid_train(&dec_values, &prob.labels)
300}
301
302// ─── One-class probability ───────────────────────────────────────────
303
/// Predict probability for one-class SVM from density marks.
///
/// Locates `dec_value` among the precomputed density marks (normally
/// 10 entries) and returns the corresponding bin-based estimate,
/// always strictly inside (0, 1).
///
/// Matches LIBSVM's `predict_one_class_probability`.
pub fn predict_one_class_probability(prob_density_marks: &[f64], dec_value: f64) -> f64 {
    let nr_marks = prob_density_marks.len();

    // Without any marks there is nothing to compare against.
    let (lowest, highest) = match (prob_density_marks.first(), prob_density_marks.last()) {
        (Some(&lo), Some(&hi)) => (lo, hi),
        _ => return 0.5,
    };

    // Clamp values outside the observed density range.
    if dec_value < lowest {
        return 0.001;
    }
    if dec_value > highest {
        return 0.999;
    }

    // First interior mark exceeding the decision value determines the bin.
    match prob_density_marks[1..].iter().position(|&m| dec_value < m) {
        Some(offset) => (offset + 1) as f64 / nr_marks as f64,
        None => 0.999,
    }
}
331
332/// Estimate probability density marks for one-class SVM.
333///
334/// Predicts all training instances, sorts decision values, bins into
335/// 10 density marks. Returns `None` if fewer than 5 positive or 5
336/// negative decision values.
337///
338/// Matches LIBSVM's `svm_one_class_probability`.
339pub fn svm_one_class_probability(
340    prob: &SvmProblem,
341    model: &SvmModel,
342) -> Option<Vec<f64>> {
343    let l = prob.labels.len();
344    let mut dec_values = vec![0.0; l];
345
346    for (i, instance) in prob.instances.iter().enumerate() {
347        let mut dv = [0.0];
348        predict_values(model, instance, &mut dv);
349        dec_values[i] = dv[0];
350    }
351
352    dec_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
353
354    // Find first index with dec_value >= 0  (= neg_counter in C++)
355    let mut neg_counter = 0usize;
356    for i in 0..l {
357        if dec_values[i] >= 0.0 {
358            neg_counter = i;
359            break;
360        }
361    }
362    let pos_counter = l - neg_counter;
363
364    let nr_marks: usize = 10;
365    let mid = nr_marks / 2; // 5
366
367    if neg_counter < mid || pos_counter < mid {
368        crate::info(&format!(
369            "WARNING: number of positive or negative decision values <{}; \
370             too few to do a probability estimation.\n",
371            mid
372        ));
373        return None;
374    }
375
376    let mut tmp_marks = vec![0.0; nr_marks + 1];
377
378    for i in 0..mid {
379        tmp_marks[i] = dec_values[i * neg_counter / mid];
380    }
381    tmp_marks[mid] = 0.0;
382    for i in (mid + 1)..=nr_marks {
383        tmp_marks[i] = dec_values[neg_counter - 1 + (i - mid) * pos_counter / mid];
384    }
385
386    let mut marks = vec![0.0; nr_marks];
387    for i in 0..nr_marks {
388        marks[i] = (tmp_marks[i] + tmp_marks[i + 1]) / 2.0;
389    }
390
391    Some(marks)
392}
393
394// ─── SVR probability ─────────────────────────────────────────────────
395
396/// Estimate Laplace scale parameter for SVR probability.
397///
398/// Performs 5-fold CV to get residuals, computes MAE, then applies
399/// outlier rejection (exclude |residual| > 5·√(2·mae²)) and
400/// recomputes. Returns the final MAE (= σ of Laplace distribution).
401///
402/// Matches LIBSVM's `svm_svr_probability`.
403pub fn svm_svr_probability(prob: &SvmProblem, param: &SvmParameter) -> f64 {
404    let l = prob.labels.len();
405    let nr_fold = 5;
406
407    let mut newparam = param.clone();
408    newparam.probability = false;
409    let ymv = crate::cross_validation::svm_cross_validation(prob, &newparam, nr_fold);
410
411    // Compute residuals and initial MAE
412    let mut ymv_residuals: Vec<f64> = Vec::with_capacity(l);
413    let mut mae = 0.0;
414    for i in 0..l {
415        let r = prob.labels[i] - ymv[i];
416        ymv_residuals.push(r);
417        mae += r.abs();
418    }
419    mae /= l as f64;
420
421    // Outlier rejection
422    let std_val = (2.0 * mae * mae).sqrt();
423    let mut count = 0usize;
424    mae = 0.0;
425    for i in 0..l {
426        if ymv_residuals[i].abs() > 5.0 * std_val {
427            count += 1;
428        } else {
429            mae += ymv_residuals[i].abs();
430        }
431    }
432    mae /= (l - count) as f64;
433
434    crate::info(&format!(
435        "Prob. model for test data: target value = predicted value + z,\n\
436         z: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= {:.6}\n",
437        mae
438    ));
439
440    mae
441}
442
443// ─── Tests ───────────────────────────────────────────────────────────
444
#[cfg(test)]
mod tests {
    use super::*;

    // With A = B = 0 the sigmoid argument is 0, so P must be exactly 1/2.
    #[test]
    fn sigmoid_predict_symmetric() {
        let p = sigmoid_predict(0.0, 0.0, 0.0);
        assert!((p - 0.5).abs() < 1e-10);
    }

    // Extreme decision values must not overflow exp(); the result stays
    // a finite probability in [0, 1].
    #[test]
    fn sigmoid_predict_stable() {
        let p1 = sigmoid_predict(1000.0, 1.0, 0.0);
        assert!(p1.is_finite() && (0.0..=1.0).contains(&p1));

        let p2 = sigmoid_predict(-1000.0, 1.0, 0.0);
        assert!(p2.is_finite() && (0.0..=1.0).contains(&p2));
    }

    // Newton's method should converge to finite (A, B) on a small
    // well-separated sample.
    #[test]
    fn sigmoid_train_basic() {
        let dec = vec![1.0, 2.0, -1.0, -2.0, 0.5];
        let lab = vec![1.0, 1.0, -1.0, -1.0, 1.0];
        let (a, b) = sigmoid_train(&dec, &lab);
        assert!(a.is_finite());
        assert!(b.is_finite());
    }

    // The Wu-Lin-Weng coupling must produce a proper distribution:
    // entries positive and summing to 1.
    #[test]
    fn multiclass_prob_sums_to_one() {
        let k = 3;
        let r = vec![
            vec![0.0, 0.6, 0.5],
            vec![0.4, 0.0, 0.7],
            vec![0.5, 0.3, 0.0],
        ];
        let mut p = vec![0.0; k];
        multiclass_probability(k, &r, &mut p);

        let sum: f64 = p.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "probabilities sum to {}, expected ~1.0",
            sum
        );
        for &pi in &p {
            assert!(pi > 0.0, "probability should be positive, got {}", pi);
        }
    }

    // Values outside the mark range clamp to 0.001 / 0.999; interior
    // values map to an open-interval bin probability.
    #[test]
    fn predict_one_class_prob_boundaries() {
        let marks = vec![-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9];
        assert!((predict_one_class_probability(&marks, -1.0) - 0.001).abs() < 1e-10);
        assert!((predict_one_class_probability(&marks, 1.0) - 0.999).abs() < 1e-10);
        let mid = predict_one_class_probability(&marks, 0.0);
        assert!(mid > 0.0 && mid < 1.0);
    }
}