1use crate::predict::predict_values;
8use crate::train::svm_train;
9use crate::types::{SvmModel, SvmParameter, SvmProblem};
10
/// Advance a 64-bit linear congruential generator (Knuth's MMIX
/// multiplier/increment) and return the high bits as a `usize`.
///
/// Only the top 31 bits of the state are returned, since the high bits
/// of an LCG have the longest period / best statistical quality.
fn rng_next(state: &mut u64) -> usize {
    const MUL: u64 = 6364136223846793005;
    const INC: u64 = 1442695040888963407;
    *state = state.wrapping_mul(MUL).wrapping_add(INC);
    (*state >> 33) as usize
}
19
/// Platt scaling: fit sigmoid parameters `(a, b)` so that
/// `P(y = +1 | f) = 1 / (1 + exp(a * f + b))` for decision values `f`.
///
/// Uses Newton's method with backtracking line search on the regularized
/// maximum-likelihood objective (Lin/Lin/Weng's improved algorithm).
/// `dec_values` and `labels` are parallel slices; labels are +/-1.
pub fn sigmoid_train(dec_values: &[f64], labels: &[f64]) -> (f64, f64) {
    // Count positive/negative training labels.
    let prior1 = labels.iter().filter(|&&y| y > 0.0).count() as f64;
    let prior0 = labels.len() as f64 - prior1;

    const MAX_ITER: usize = 100; // maximum Newton iterations
    const MIN_STEP: f64 = 1e-10; // smallest line-search step
    const SIGMA: f64 = 1e-12; // Hessian ridge, keeps it positive definite
    const EPS: f64 = 1e-5; // gradient stopping tolerance

    // Soft target values (never exactly 0 or 1, so the log-likelihood
    // stays finite).
    let hi_target = (prior1 + 1.0) / (prior1 + 2.0);
    let lo_target = 1.0 / (prior0 + 2.0);
    let t: Vec<f64> = labels
        .iter()
        .map(|&y| if y > 0.0 { hi_target } else { lo_target })
        .collect();

    // Negative log-likelihood at parameters (a, b), using the branch
    // whose exp() argument is non-positive so it cannot overflow.
    let objective = |a: f64, b: f64| -> f64 {
        dec_values.iter().zip(&t).fold(0.0, |acc, (&x, &ti)| {
            let f_apb = x * a + b;
            acc + if f_apb >= 0.0 {
                ti * f_apb + (1.0 + (-f_apb).exp()).ln()
            } else {
                (ti - 1.0) * f_apb + (1.0 + f_apb.exp()).ln()
            }
        })
    };

    // Initial point: a = 0, b chosen from the class priors.
    let mut a = 0.0;
    let mut b = ((prior0 + 1.0) / (prior1 + 1.0)).ln();
    let mut fval = objective(a, b);

    for _ in 0..MAX_ITER {
        // Gradient (g1, g2) and ridge-regularized 2x2 Hessian.
        let mut h11 = SIGMA;
        let mut h22 = SIGMA;
        let mut h21 = 0.0;
        let mut g1 = 0.0;
        let mut g2 = 0.0;
        for (&x, &ti) in dec_values.iter().zip(&t) {
            let f_apb = x * a + b;
            // p = sigmoid(-f_apb), q = 1 - p, computed stably.
            let (p, q) = if f_apb >= 0.0 {
                let e = (-f_apb).exp();
                (e / (1.0 + e), 1.0 / (1.0 + e))
            } else {
                let e = f_apb.exp();
                (1.0 / (1.0 + e), e / (1.0 + e))
            };
            let d2 = p * q;
            h11 += x * x * d2;
            h22 += d2;
            h21 += x * d2;
            let d1 = ti - p;
            g1 += x * d1;
            g2 += d1;
        }

        // Converged once the gradient is (near) zero.
        if g1.abs() < EPS && g2.abs() < EPS {
            break;
        }

        // Newton direction: solve H * [da, db] = -[g1, g2] by Cramer's rule.
        let det = h11 * h22 - h21 * h21;
        let da = -(h22 * g1 - h21 * g2) / det;
        let db = -(-h21 * g1 + h11 * g2) / det;
        let gd = g1 * da + g2 * db;

        // Backtracking line search with an Armijo sufficient-decrease test.
        let mut stepsize = 1.0;
        while stepsize >= MIN_STEP {
            let new_a = a + stepsize * da;
            let new_b = b + stepsize * db;
            let newf = objective(new_a, new_b);
            if newf < fval + 0.0001 * stepsize * gd {
                a = new_a;
                b = new_b;
                fval = newf;
                break;
            }
            stepsize /= 2.0;
        }

        // Line search failed: give up with the current parameters.
        if stepsize < MIN_STEP {
            break;
        }
    }

    (a, b)
}
136
/// Map one decision value through the fitted sigmoid:
/// `P = 1 / (1 + exp(a * decision_value + b))`.
///
/// The two branches are algebraically equal; the one whose `exp()`
/// argument is non-positive is chosen so large |f_apb| cannot overflow.
pub fn sigmoid_predict(decision_value: f64, a: f64, b: f64) -> f64 {
    let f_apb = decision_value * a + b;
    if f_apb >= 0.0 {
        let e = (-f_apb).exp();
        e / (1.0 + e)
    } else {
        1.0 / (1.0 + f_apb.exp())
    }
}
149
/// Couple pairwise class probabilities `r[i][j] = P(y=i | y in {i,j})`
/// into a single multiclass distribution `p` (method 2 of Wu, Lin &
/// Weng, JMLR 2004), solved by coordinate descent on the simplex.
///
/// `r` is a k-by-k matrix; `p` receives the k class probabilities.
pub fn multiclass_probability(k: usize, r: &[Vec<f64>], p: &mut [f64]) {
    let max_iter = 100.max(k);
    let eps = 0.005 / k as f64;

    // Build Q: Q[t][t] = sum_{j != t} r[j][t]^2, Q[t][j] = -r[j][t]*r[t][j].
    let mut q = vec![vec![0.0; k]; k];
    for t in 0..k {
        q[t][t] = 0.0;
        for j in 0..t {
            q[t][t] += r[j][t] * r[j][t];
            q[t][j] = q[j][t]; // symmetric entry already computed
        }
        for j in (t + 1)..k {
            q[t][t] += r[j][t] * r[j][t];
            q[t][j] = -r[j][t] * r[t][j];
        }
    }

    // Start from the uniform distribution.
    for t in 0..k {
        p[t] = 1.0 / k as f64;
    }

    let mut qp = vec![0.0; k];
    for _ in 0..max_iter {
        // qp = Q p and p_qp = p' Q p.
        let mut p_qp = 0.0;
        for t in 0..k {
            qp[t] = (0..k).map(|j| q[t][j] * p[j]).sum();
            p_qp += p[t] * qp[t];
        }

        // Optimality: every component of Qp equals p'Qp at a solution.
        let max_error = qp.iter().fold(0.0_f64, |m, &v| m.max((v - p_qp).abs()));
        if max_error < eps {
            break;
        }

        // Coordinate-descent step on each p[t], renormalizing p and
        // updating qp / p_qp incrementally instead of recomputing.
        for t in 0..k {
            let diff = (-qp[t] + p_qp) / q[t][t];
            p[t] += diff;
            p_qp = (p_qp + diff * (diff * q[t][t] + 2.0 * qp[t]))
                / (1.0 + diff)
                / (1.0 + diff);
            for j in 0..k {
                qp[j] = (qp[j] + diff * q[t][j]) / (1.0 + diff);
                p[j] /= 1.0 + diff;
            }
        }
    }
}
217
218pub fn svm_binary_svc_probability(
228 prob: &SvmProblem,
229 param: &SvmParameter,
230 cp: f64,
231 cn: f64,
232) -> (f64, f64) {
233 let l = prob.labels.len();
234 let nr_fold = 5;
235 let mut perm: Vec<usize> = (0..l).collect();
236 let mut dec_values = vec![0.0; l];
237
238 let mut rng: u64 = 1;
240 for i in 0..l {
241 let j = i + rng_next(&mut rng) % (l - i);
242 perm.swap(i, j);
243 }
244
245 for fold in 0..nr_fold {
246 let begin = fold * l / nr_fold;
247 let end = (fold + 1) * l / nr_fold;
248
249 let mut sub_instances = Vec::with_capacity(l - (end - begin));
251 let mut sub_labels = Vec::with_capacity(l - (end - begin));
252
253 for j in 0..begin {
254 sub_instances.push(prob.instances[perm[j]].clone());
255 sub_labels.push(prob.labels[perm[j]]);
256 }
257 for j in end..l {
258 sub_instances.push(prob.instances[perm[j]].clone());
259 sub_labels.push(prob.labels[perm[j]]);
260 }
261
262 let p_count = sub_labels.iter().filter(|&&y| y > 0.0).count();
264 let n_count = sub_labels.len() - p_count;
265
266 if p_count == 0 && n_count == 0 {
267 for j in begin..end {
268 dec_values[perm[j]] = 0.0;
269 }
270 } else if p_count > 0 && n_count == 0 {
271 for j in begin..end {
272 dec_values[perm[j]] = 1.0;
273 }
274 } else if p_count == 0 && n_count > 0 {
275 for j in begin..end {
276 dec_values[perm[j]] = -1.0;
277 }
278 } else {
279 let mut subparam = param.clone();
280 subparam.probability = false;
281 subparam.c = 1.0;
282 subparam.weight = vec![(1, cp), (-1, cn)];
283
284 let subprob = SvmProblem {
285 labels: sub_labels,
286 instances: sub_instances,
287 };
288 let submodel = svm_train(&subprob, &subparam);
289
290 for j in begin..end {
291 let mut dv = [0.0];
292 predict_values(&submodel, &prob.instances[perm[j]], &mut dv);
293 dec_values[perm[j]] = dv[0] * submodel.label[0] as f64;
295 }
296 }
297 }
298
299 sigmoid_train(&dec_values, &prob.labels)
300}
301
/// Convert a one-class decision value to a coarse probability using the
/// model's density marks.
///
/// Returns 0.001 below the lowest mark, 0.999 above the highest, and
/// `i / nr_marks` for the i-th bucket in between; 0.5 when no marks
/// exist (no density information at all).
pub fn predict_one_class_probability(prob_density_marks: &[f64], dec_value: f64) -> f64 {
    let nr_marks = prob_density_marks.len();
    if nr_marks == 0 {
        return 0.5;
    }

    if dec_value < prob_density_marks[0] {
        return 0.001;
    }
    if dec_value > prob_density_marks[nr_marks - 1] {
        return 0.999;
    }

    // First mark (scanning from index 1) exceeding dec_value fixes the
    // bucket; a value equal to the last mark falls through to 0.999.
    match prob_density_marks[1..].iter().position(|&m| dec_value < m) {
        Some(i) => (i + 1) as f64 / nr_marks as f64,
        None => 0.999,
    }
}
331
332pub fn svm_one_class_probability(
340 prob: &SvmProblem,
341 model: &SvmModel,
342) -> Option<Vec<f64>> {
343 let l = prob.labels.len();
344 let mut dec_values = vec![0.0; l];
345
346 for (i, instance) in prob.instances.iter().enumerate() {
347 let mut dv = [0.0];
348 predict_values(model, instance, &mut dv);
349 dec_values[i] = dv[0];
350 }
351
352 dec_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
353
354 let mut neg_counter = 0usize;
356 for i in 0..l {
357 if dec_values[i] >= 0.0 {
358 neg_counter = i;
359 break;
360 }
361 }
362 let pos_counter = l - neg_counter;
363
364 let nr_marks: usize = 10;
365 let mid = nr_marks / 2; if neg_counter < mid || pos_counter < mid {
368 eprintln!(
369 "WARNING: number of positive or negative decision values <{}; \
370 too few to do a probability estimation.",
371 mid
372 );
373 return None;
374 }
375
376 let mut tmp_marks = vec![0.0; nr_marks + 1];
377
378 for i in 0..mid {
379 tmp_marks[i] = dec_values[i * neg_counter / mid];
380 }
381 tmp_marks[mid] = 0.0;
382 for i in (mid + 1)..=nr_marks {
383 tmp_marks[i] = dec_values[neg_counter - 1 + (i - mid) * pos_counter / mid];
384 }
385
386 let mut marks = vec![0.0; nr_marks];
387 for i in 0..nr_marks {
388 marks[i] = (tmp_marks[i] + tmp_marks[i + 1]) / 2.0;
389 }
390
391 Some(marks)
392}
393
394pub fn svm_svr_probability(prob: &SvmProblem, param: &SvmParameter) -> f64 {
404 let l = prob.labels.len();
405 let nr_fold = 5;
406
407 let mut newparam = param.clone();
408 newparam.probability = false;
409 let ymv = crate::cross_validation::svm_cross_validation(prob, &newparam, nr_fold);
410
411 let mut ymv_residuals: Vec<f64> = Vec::with_capacity(l);
413 let mut mae = 0.0;
414 for i in 0..l {
415 let r = prob.labels[i] - ymv[i];
416 ymv_residuals.push(r);
417 mae += r.abs();
418 }
419 mae /= l as f64;
420
421 let std_val = (2.0 * mae * mae).sqrt();
423 let mut count = 0usize;
424 mae = 0.0;
425 for i in 0..l {
426 if ymv_residuals[i].abs() > 5.0 * std_val {
427 count += 1;
428 } else {
429 mae += ymv_residuals[i].abs();
430 }
431 }
432 mae /= (l - count) as f64;
433
434 eprintln!(
435 "Prob. model for test data: target value = predicted value + z,\n\
436 z: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= {}",
437 mae
438 );
439
440 mae
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    // At (dec, a, b) = (0, 0, 0) the sigmoid argument is 0, so the
    // predicted probability must be exactly 1/2.
    #[test]
    fn sigmoid_predict_symmetric() {
        let p = sigmoid_predict(0.0, 0.0, 0.0);
        assert!((p - 0.5).abs() < 1e-10);
    }

    // Large-magnitude arguments must not overflow exp(): the branch
    // selection keeps the exponent non-positive in both directions.
    #[test]
    fn sigmoid_predict_stable() {
        let p1 = sigmoid_predict(1000.0, 1.0, 0.0);
        assert!(p1.is_finite() && p1 >= 0.0 && p1 <= 1.0);

        let p2 = sigmoid_predict(-1000.0, 1.0, 0.0);
        assert!(p2.is_finite() && p2 >= 0.0 && p2 <= 1.0);
    }

    // Platt scaling on a small, mostly-separable sample should converge
    // to finite sigmoid parameters.
    #[test]
    fn sigmoid_train_basic() {
        let dec = vec![1.0, 2.0, -1.0, -2.0, 0.5];
        let lab = vec![1.0, 1.0, -1.0, -1.0, 1.0];
        let (a, b) = sigmoid_train(&dec, &lab);
        assert!(a.is_finite());
        assert!(b.is_finite());
    }

    // The coupled multiclass estimate must be a valid probability
    // distribution: components positive and summing to 1.
    #[test]
    fn multiclass_prob_sums_to_one() {
        let k = 3;
        let r = vec![
            vec![0.0, 0.6, 0.5],
            vec![0.4, 0.0, 0.7],
            vec![0.5, 0.3, 0.0],
        ];
        let mut p = vec![0.0; k];
        multiclass_probability(k, &r, &mut p);

        let sum: f64 = p.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "probabilities sum to {}, expected ~1.0",
            sum
        );
        for &pi in &p {
            assert!(pi > 0.0, "probability should be positive, got {}", pi);
        }
    }

    // Below the lowest mark -> 0.001, above the highest -> 0.999, and
    // interior values land strictly between.
    #[test]
    fn predict_one_class_prob_boundaries() {
        let marks = vec![-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9];
        assert!((predict_one_class_probability(&marks, -1.0) - 0.001).abs() < 1e-10);
        assert!((predict_one_class_probability(&marks, 1.0) - 0.999).abs() < 1e-10);
        let mid = predict_one_class_probability(&marks, 0.0);
        assert!(mid > 0.0 && mid < 1.0);
    }
}