//! Probability calibration: matrix, vector, and temperature scaling,
//! Platt scaling, isotonic regression, and BBQ-style quantile binning.

use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2};

/// Matrix calibration: learns a full `K x K` weight matrix and a bias
/// vector that map raw logits to calibrated logits, generalizing
/// temperature and vector scaling.
#[derive(Debug, Clone)]
pub struct MatrixScaler {
    /// Learned `K x K` weight matrix (`None` until fitted).
    weight_matrix: Option<Array2<f64>>,
    /// Learned per-class bias vector (`None` until fitted).
    bias_vector: Option<Array1<f64>>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
    /// L2 penalty applied to off-diagonal weights to discourage
    /// excessive class mixing.
    regularization: f64,
}

impl MatrixScaler {
    /// Creates a scaler with the default off-diagonal regularization (0.01).
    pub fn new() -> Self {
        Self {
            weight_matrix: None,
            bias_vector: None,
            fitted: false,
            regularization: 0.01,
        }
    }

    /// Creates a scaler with a custom off-diagonal regularization strength.
    pub fn with_regularization(regularization: f64) -> Self {
        Self {
            weight_matrix: None,
            bias_vector: None,
            fitted: false,
            regularization,
        }
    }

    /// Fits the weight matrix and bias vector by minimizing the
    /// regularized negative log-likelihood with finite-difference
    /// gradient descent.
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        if n_samples < n_classes * 2 {
            return Err(MLError::InvalidInput(format!(
                "Need at least {} samples for {} classes (matrix calibration)",
                n_classes * 2,
                n_classes
            )));
        }
        // Start from the identity map, i.e. no recalibration.
        let mut weight_matrix = Array2::eye(n_classes);
        let mut bias_vector = Array1::zeros(n_classes);
        let learning_rate = 0.001;
        let max_iter = 300;
        let tolerance = 1e-7;
        let mut prev_nll = f64::INFINITY;
        for _iter in 0..max_iter {
            let (nll, reg_term) =
                self.compute_nll_with_reg(logits, labels, &weight_matrix, &bias_vector)?;
            let total_loss = nll + reg_term;
            if (prev_nll - total_loss).abs() < tolerance {
                break;
            }
            prev_nll = total_loss;
            // Forward-difference gradients. This costs O(K^2) loss
            // evaluations per iteration, so it is only practical for a
            // modest number of classes.
            let epsilon = 1e-6;
            let mut weight_grads = Array2::zeros((n_classes, n_classes));
            let mut bias_grads = Array1::zeros(n_classes);
            for i in 0..n_classes {
                for j in 0..n_classes {
                    let mut weight_plus = weight_matrix.clone();
                    weight_plus[(i, j)] += epsilon;
                    let (nll_plus, reg_plus) =
                        self.compute_nll_with_reg(logits, labels, &weight_plus, &bias_vector)?;
                    weight_grads[(i, j)] = (nll_plus + reg_plus - total_loss) / epsilon;
                }
            }
            for j in 0..n_classes {
                let mut bias_plus = bias_vector.clone();
                bias_plus[j] += epsilon;
                let (nll_plus, reg_plus) =
                    self.compute_nll_with_reg(logits, labels, &weight_matrix, &bias_plus)?;
                bias_grads[j] = (nll_plus + reg_plus - total_loss) / epsilon;
            }
            weight_matrix = &weight_matrix - &weight_grads.mapv(|g| learning_rate * g);
            bias_vector = &bias_vector - &bias_grads.mapv(|g| learning_rate * g);
            // Keep the diagonal positive so class scores are never inverted.
            for i in 0..n_classes {
                weight_matrix[(i, i)] = weight_matrix[(i, i)].max(0.01);
            }
            let grad_norm = weight_grads.iter().map(|&g| g * g).sum::<f64>().sqrt()
                + bias_grads.iter().map(|&g| g * g).sum::<f64>().sqrt();
            if grad_norm < tolerance {
                break;
            }
        }
        self.weight_matrix = Some(weight_matrix);
        self.bias_vector = Some(bias_vector);
        self.fitted = true;
        Ok(())
    }

    /// Returns `(mean NLL, regularization term)` for the given parameters.
    fn compute_nll_with_reg(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        weight_matrix: &Array2<f64>,
        bias_vector: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        for i in 0..n_samples {
            let logits_row = logits.row(i);
            // Affine map: z' = W z + b.
            let mut scaled_logits = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut val = bias_vector[j];
                for k in 0..n_classes {
                    val += weight_matrix[(j, k)] * logits_row[k];
                }
                scaled_logits[j] = val;
            }
            // Numerically stable softmax: shift by the max logit.
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        nll /= n_samples as f64;
        // L2 penalty on the off-diagonal entries only.
        let mut reg_term = 0.0;
        for i in 0..n_classes {
            for j in 0..n_classes {
                if i != j {
                    reg_term += weight_matrix[(i, j)].powi(2);
                }
            }
        }
        reg_term *= self.regularization;
        Ok((nll, reg_term))
    }

    /// Applies the fitted affine map followed by a softmax, returning
    /// calibrated class probabilities (one row per sample).
    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let weight_matrix = self.weight_matrix.as_ref().unwrap();
        let bias_vector = self.bias_vector.as_ref().unwrap();
        let n_classes = logits.ncols();
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let logits_row = logits.row(i);
            let mut scaled_logits = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut val = bias_vector[j];
                for k in 0..n_classes {
                    val += weight_matrix[(j, k)] * logits_row[k];
                }
                scaled_logits[j] = val;
            }
            // Numerically stable softmax.
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }

    /// Fits the scaler on the given data, then transforms it.
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }

    /// Returns the fitted `(weight_matrix, bias_vector)`, or `None` if
    /// the scaler has not been fitted yet.
    pub fn parameters(&self) -> Option<(Array2<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.weight_matrix.as_ref().unwrap().clone(),
                self.bias_vector.as_ref().unwrap().clone(),
            ))
        } else {
            None
        }
    }

    /// Returns the Frobenius norm of the fitted weight matrix, a cheap
    /// proxy for how far it has drifted from the identity. Note that
    /// despite the name this is not a true condition number, which would
    /// require the ratio of the largest to smallest singular value.
    pub fn condition_number(&self) -> Option<f64> {
        if !self.fitted {
            return None;
        }
        let w = self.weight_matrix.as_ref().unwrap();
        let norm = w.iter().map(|&x| x * x).sum::<f64>().sqrt();
        Some(norm)
    }
}
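
// A minimal usage sketch for `MatrixScaler`, kept as a test module so it
// compiles with the crate. The logits and labels are synthetic values
// added for illustration; they are not part of the original module.
#[cfg(test)]
mod matrix_scaler_example {
    use super::*;

    #[test]
    fn fit_and_transform_produces_probability_rows() {
        // Two classes and four samples satisfy the
        // `n_samples >= 2 * n_classes` requirement of `fit`.
        let logits = Array2::from_shape_vec(
            (4, 2),
            vec![2.0, -1.0, 1.5, -0.5, -1.0, 2.0, -0.5, 1.5],
        )
        .unwrap();
        let labels = Array1::from_vec(vec![0usize, 0, 1, 1]);
        let mut scaler = MatrixScaler::new();
        let probs = scaler.fit_transform(&logits, &labels).unwrap();
        // Each row is a softmax output, so it sums to one.
        for i in 0..probs.nrows() {
            assert!((probs.row(i).sum() - 1.0).abs() < 1e-9);
        }
    }
}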

/// Isotonic regression calibration: fits a non-decreasing map from
/// scores to empirical frequencies with the pool-adjacent-violators
/// (PAV) algorithm, evaluated by linear interpolation at transform time.
#[derive(Debug, Clone)]
pub struct IsotonicRegression {
    /// Score breakpoints, in increasing order.
    x_thresholds: Vec<f64>,
    /// Calibrated values at the corresponding breakpoints.
    y_thresholds: Vec<f64>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}

impl IsotonicRegression {
    /// Creates an empty, unfitted regressor.
    pub fn new() -> Self {
        Self {
            x_thresholds: Vec::new(),
            y_thresholds: Vec::new(),
            fitted: false,
        }
    }

    /// Fits the calibration map with pool-adjacent-violators: sort by
    /// score, then repeatedly merge any adjacent pair that violates
    /// monotonicity into its weighted average.
    pub fn fit(&mut self, scores: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if scores.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Scores and labels must have same length".to_string(),
            ));
        }
        let n = scores.len();
        if n < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Sort (score, label) pairs by score.
        let mut pairs: Vec<(f64, f64)> = scores
            .iter()
            .zip(labels.iter())
            .map(|(&s, &l)| (s, l as f64))
            .collect();
        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
        let mut x = Vec::new();
        let mut y = Vec::new();
        let mut weights = Vec::new();
        for (score, label) in pairs {
            x.push(score);
            y.push(label);
            weights.push(1.0);
        }
        // Pool adjacent violators: merge y[i] > y[i + 1] into a weighted
        // average, then step back to re-check the previous pair, since a
        // merge can create a new violation to its left.
        let mut i = 0;
        while i < y.len() - 1 {
            if y[i] > y[i + 1] {
                let w1 = weights[i];
                let w2 = weights[i + 1];
                let total_weight = w1 + w2;
                y[i] = (y[i] * w1 + y[i + 1] * w2) / total_weight;
                weights[i] = total_weight;
                // The merged block keeps the left endpoint's score.
                y.remove(i + 1);
                x.remove(i + 1);
                weights.remove(i + 1);
                if i > 0 {
                    i -= 1;
                }
            } else {
                i += 1;
            }
        }
        self.x_thresholds = x;
        self.y_thresholds = y;
        self.fitted = true;
        Ok(())
    }

    /// Maps new scores through the fitted step function, interpolating
    /// linearly between breakpoints and clamping outside their range.
    pub fn transform(&self, scores: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Regressor must be fitted before transform".to_string(),
            ));
        }
        let mut calibrated = Array1::zeros(scores.len());
        for (i, &score) in scores.iter().enumerate() {
            // Find the first breakpoint >= score (insertion point).
            let pos = self
                .x_thresholds
                .binary_search_by(|&x| x.partial_cmp(&score).unwrap_or(std::cmp::Ordering::Less))
                .unwrap_or_else(|e| e);
            if pos == 0 {
                calibrated[i] = self.y_thresholds[0];
            } else if pos >= self.x_thresholds.len() {
                calibrated[i] = *self.y_thresholds.last().unwrap();
            } else {
                let x0 = self.x_thresholds[pos - 1];
                let x1 = self.x_thresholds[pos];
                let y0 = self.y_thresholds[pos - 1];
                let y1 = self.y_thresholds[pos];
                if (x1 - x0).abs() < 1e-10 {
                    calibrated[i] = (y0 + y1) / 2.0;
                } else {
                    let alpha = (score - x0) / (x1 - x0);
                    calibrated[i] = y0 + alpha * (y1 - y0);
                }
            }
        }
        Ok(calibrated)
    }

    pub fn fit_transform(
        &mut self,
        scores: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(scores, labels)?;
        self.transform(scores)
    }
}
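
// A minimal sketch of the pool-adjacent-violators behavior on synthetic
// data (added for illustration): after fitting, calibrated outputs are
// non-decreasing in the input score.
#[cfg(test)]
mod isotonic_example {
    use super::*;

    #[test]
    fn calibrated_outputs_are_monotone_in_the_score() {
        // The pair (0.35, 1), (0.4, 0) violates monotonicity and gets
        // pooled into a single block with value 0.5.
        let scores = Array1::from_vec(vec![0.1, 0.35, 0.4, 0.8]);
        let labels = Array1::from_vec(vec![0usize, 1, 0, 1]);
        let mut iso = IsotonicRegression::new();
        let calibrated = iso.fit_transform(&scores, &labels).unwrap();
        let mut pairs: Vec<(f64, f64)> = scores
            .iter()
            .cloned()
            .zip(calibrated.iter().cloned())
            .collect();
        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        for w in pairs.windows(2) {
            assert!(w[0].1 <= w[1].1 + 1e-12);
        }
    }
}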

/// Bayesian Binning into Quantiles (BBQ)-style calibration for binary
/// problems: probabilities are grouped into quantile bins, and each
/// bin's calibrated value is the posterior mean of a Beta(0.5, 0.5)
/// prior updated with the bin's positive/negative counts.
#[derive(Debug, Clone)]
pub struct BayesianBinningQuantiles {
    /// Number of quantile bins.
    n_bins: usize,
    /// Bin edges in `[0, 1]` (`None` until fitted).
    bin_edges: Option<Vec<f64>>,
    /// Per-bin posterior Beta alpha parameters.
    alphas: Option<Array1<f64>>,
    /// Per-bin posterior Beta beta parameters.
    betas: Option<Array1<f64>>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}

impl BayesianBinningQuantiles {
    /// Creates a calibrator with the given number of quantile bins.
    pub fn new(n_bins: usize) -> Self {
        Self {
            n_bins,
            bin_edges: None,
            alphas: None,
            betas: None,
            fitted: false,
        }
    }

    /// Fits quantile bin edges and per-bin Beta posteriors from
    /// predicted probabilities and binary labels.
    pub fn fit(&mut self, probabilities: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if probabilities.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Probabilities and labels must have same length".to_string(),
            ));
        }
        let n_samples = probabilities.len();
        if n_samples < self.n_bins {
            return Err(MLError::InvalidInput(format!(
                "Need at least {} samples for {} bins, got {}",
                self.n_bins, self.n_bins, n_samples
            )));
        }
        // Quantile-based bin edges so each bin holds roughly the same
        // number of samples.
        let mut sorted_probs = probabilities.to_vec();
        sorted_probs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let mut bin_edges = vec![0.0];
        for i in 1..self.n_bins {
            let quantile_idx = (i as f64 / self.n_bins as f64 * n_samples as f64) as usize;
            let quantile_idx = quantile_idx.min(sorted_probs.len() - 1);
            bin_edges.push(sorted_probs[quantile_idx]);
        }
        bin_edges.push(1.0);
        // Count positives and negatives per bin.
        let mut bin_positives = vec![0.0; self.n_bins];
        let mut bin_negatives = vec![0.0; self.n_bins];
        for (i, &prob) in probabilities.iter().enumerate() {
            let bin_idx = self.find_bin(&bin_edges, prob);
            let label = labels[i];
            if label == 1 {
                bin_positives[bin_idx] += 1.0;
            } else {
                bin_negatives[bin_idx] += 1.0;
            }
        }
        // Jeffreys prior Beta(0.5, 0.5) plus the observed counts.
        let prior_alpha = 0.5;
        let prior_beta = 0.5;
        let mut alphas = Array1::zeros(self.n_bins);
        let mut betas = Array1::zeros(self.n_bins);
        for i in 0..self.n_bins {
            alphas[i] = prior_alpha + bin_positives[i];
            betas[i] = prior_beta + bin_negatives[i];
        }
        self.bin_edges = Some(bin_edges);
        self.alphas = Some(alphas);
        self.betas = Some(betas);
        self.fitted = true;
        Ok(())
    }

    /// Returns the index of the half-open bin `[edge_i, edge_{i+1})`
    /// containing `prob`, falling back to the last bin so that
    /// `prob == 1.0` is still assigned.
    fn find_bin(&self, bin_edges: &[f64], prob: f64) -> usize {
        for i in 0..bin_edges.len() - 1 {
            if prob >= bin_edges[i] && prob < bin_edges[i + 1] {
                return i;
            }
        }
        bin_edges.len() - 2
    }

    pub fn transform(&self, probabilities: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Calibrator must be fitted before transform".to_string(),
            ));
        }
        let bin_edges = self.bin_edges.as_ref().unwrap();
        let alphas = self.alphas.as_ref().unwrap();
        let betas = self.betas.as_ref().unwrap();
        let mut calibrated = Array1::zeros(probabilities.len());
        for (i, &prob) in probabilities.iter().enumerate() {
            let bin_idx = self.find_bin(bin_edges, prob);
            let alpha = alphas[bin_idx];
            let beta = betas[bin_idx];
            calibrated[i] = alpha / (alpha + beta);
        }
        Ok(calibrated)
    }

    pub fn fit_transform(
        &mut self,
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(probabilities, labels)?;
        self.transform(probabilities)
    }
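
    /// Approximate inverse standard-normal CDF (probit) via the
    /// Beasley-Springer-Moro rational approximation. This helper is an
    /// addition to the original module so that `predict_with_uncertainty`
    /// can honor arbitrary confidence levels instead of a hard-coded
    /// 95% z-score.
    fn normal_quantile(p: f64) -> f64 {
        debug_assert!(p > 0.0 && p < 1.0);
        const A: [f64; 4] = [2.50662823884, -18.61500062529, 41.39119773534, -25.44106049637];
        const B: [f64; 4] = [-8.47351093090, 23.08336743743, -21.06224101826, 3.13082909833];
        const C: [f64; 9] = [
            0.3374754822726147,
            0.9761690190917186,
            0.1607979714918209,
            0.0276438810333863,
            0.0038405729373609,
            0.0003951896511919,
            0.0000321767881768,
            0.0000002888167364,
            0.0000003960315187,
        ];
        let y = p - 0.5;
        if y.abs() < 0.42 {
            // Central region: rational approximation in y^2.
            let r = y * y;
            y * (((A[3] * r + A[2]) * r + A[1]) * r + A[0])
                / ((((B[3] * r + B[2]) * r + B[1]) * r + B[0]) * r + 1.0)
        } else {
            // Tails: polynomial in ln(-ln(r)), mirrored for the lower tail.
            let r = if y > 0.0 { 1.0 - p } else { p };
            let s = (-r.ln()).ln();
            let mut x = C[8];
            for &c in C.iter().rev().skip(1) {
                x = x * s + c;
            }
            if y < 0.0 {
                -x
            } else {
                x
            }
        }
    }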

    /// Returns `(mean, lower, upper)` for each input probability, where
    /// the bounds form a Wilson score interval on the bin's positive
    /// rate at the requested confidence level.
    pub fn predict_with_uncertainty(
        &self,
        probabilities: &Array1<f64>,
        confidence: f64,
    ) -> Result<Vec<(f64, f64, f64)>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Calibrator must be fitted before prediction".to_string(),
            ));
        }
        if confidence <= 0.0 || confidence >= 1.0 {
            return Err(MLError::InvalidInput(
                "Confidence must be between 0 and 1".to_string(),
            ));
        }
        let bin_edges = self.bin_edges.as_ref().unwrap();
        let alphas = self.alphas.as_ref().unwrap();
        let betas = self.betas.as_ref().unwrap();
        // Two-sided interval: half the excluded mass in each tail, with
        // the matching z-score (~1.96 for 95% confidence).
        let upper_quantile = 1.0 - (1.0 - confidence) / 2.0;
        let z = Self::normal_quantile(upper_quantile);
        let mut results = Vec::new();
        for &prob in probabilities.iter() {
            let bin_idx = self.find_bin(bin_edges, prob);
            let alpha = alphas[bin_idx];
            let beta = betas[bin_idx];
            let mean = alpha / (alpha + beta);
            // Effective count: subtract the unit of prior mass
            // (alpha_0 + beta_0 = 1) added during fitting.
            let n = alpha + beta - 1.0;
            let p = mean;
            if n > 0.0 {
                // Wilson score interval, clamped to [0, 1].
                let denominator = 1.0 + z * z / n;
                let center = (p + z * z / (2.0 * n)) / denominator;
                let margin =
                    z * (p * (1.0 - p) / n + z * z / (4.0 * n * n)).sqrt() / denominator;
                let lower = (center - margin).max(0.0);
                let upper = (center + margin).min(1.0);
                results.push((mean, lower, upper));
            } else {
                // An empty bin carries only the prior; return a vacuous interval.
                results.push((mean, 0.0, 1.0));
            }
        }
        Ok(results)
    }

    /// Returns the configured number of bins.
    pub fn n_bins(&self) -> usize {
        self.n_bins
    }

    /// Returns the fitted `(bin_edges, alphas, betas)`, or `None` if the
    /// calibrator has not been fitted yet.
    pub fn bin_statistics(&self) -> Option<(Vec<f64>, Array1<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.bin_edges.as_ref().unwrap().clone(),
                self.alphas.as_ref().unwrap().clone(),
                self.betas.as_ref().unwrap().clone(),
            ))
        } else {
            None
        }
    }
}
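
// A minimal sketch of BBQ-style calibration on synthetic binary data
// (added for illustration). With the Beta(0.5, 0.5) prior, each bin maps
// to (positives + 0.5) / (count + 1).
#[cfg(test)]
mod bbq_example {
    use super::*;

    #[test]
    fn calibrated_values_are_valid_probabilities() {
        let probs = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.6, 0.7, 0.9]);
        let labels = Array1::from_vec(vec![0usize, 0, 1, 0, 1, 1]);
        let mut bbq = BayesianBinningQuantiles::new(3);
        let calibrated = bbq.fit_transform(&probs, &labels).unwrap();
        assert!(calibrated.iter().all(|&p| (0.0..=1.0).contains(&p)));
    }
}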

/// Platt scaling: fits a sigmoid `p = 1 / (1 + exp(-(a * s + b)))`
/// mapping raw scores (e.g., SVM margins) to probabilities.
#[derive(Debug, Clone)]
pub struct PlattScaler {
    /// Sigmoid slope.
    a: f64,
    /// Sigmoid intercept.
    b: f64,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}

impl PlattScaler {
    /// Creates an unfitted scaler with identity-like defaults
    /// (`a = 1.0`, `b = 0.0`).
    pub fn new() -> Self {
        Self {
            a: 1.0,
            b: 0.0,
            fitted: false,
        }
    }

    /// Fits the sigmoid by maximizing the Bernoulli log-likelihood with
    /// two one-dimensional Newton passes (first over the slope `a`, then
    /// over the intercept `b`). This is a simplification of Platt's
    /// original method, which optimizes both parameters jointly against
    /// smoothed targets.
    pub fn fit(&mut self, scores: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if scores.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Scores and labels must have same length".to_string(),
            ));
        }
        let n = scores.len();
        if n < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Encode labels as +/-1; only the sign is consulted below.
        let y: Array1<f64> = labels
            .iter()
            .map(|&l| if l == 1 { 1.0 } else { -1.0 })
            .collect();
        let mut a = 0.0;
        // Start the intercept at the smoothed prior log-odds.
        let n_pos = labels.iter().filter(|&&l| l == 1).count() as f64;
        let prior_pos = (n_pos + 1.0) / (n as f64 + 2.0);
        let mut b = (prior_pos / (1.0 - prior_pos)).ln();
        // Newton iterations on the slope: `fval` is the log-likelihood
        // gradient and `fpp` the magnitude of its second derivative.
        for _ in 0..100 {
            let mut fval = 0.0;
            let mut fpp = 0.0;
            for i in 0..n {
                let fapb = scores[i] * a + b;
                let p = 1.0 / (1.0 + (-fapb).exp());
                let t = if y[i] > 0.0 { 1.0 } else { 0.0 };
                fval += scores[i] * (t - p);
                fpp += scores[i] * scores[i] * p * (1.0 - p);
            }
            if fpp.abs() < 1e-12 {
                break;
            }
            let delta = fval / fpp;
            a += delta;
            if delta.abs() < 1e-7 {
                break;
            }
        }
        // Newton iterations on the intercept, with the slope held fixed.
        for _ in 0..100 {
            let mut fval = 0.0;
            let mut fpp = 0.0;
            for i in 0..n {
                let fapb = scores[i] * a + b;
                let p = 1.0 / (1.0 + (-fapb).exp());
                let t = if y[i] > 0.0 { 1.0 } else { 0.0 };
                fval += t - p;
                fpp += p * (1.0 - p);
            }
            if fpp.abs() < 1e-12 {
                break;
            }
            let delta = fval / fpp;
            b += delta;
            if delta.abs() < 1e-7 {
                break;
            }
        }
        self.a = a;
        self.b = b;
        self.fitted = true;
        Ok(())
    }

    /// Applies the fitted sigmoid to new scores.
    pub fn transform(&self, scores: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let probs = scores.mapv(|s| {
            let fapb = s * self.a + self.b;
            1.0 / (1.0 + (-fapb).exp())
        });
        Ok(probs)
    }

    /// Fits on the given data, then transforms it.
    pub fn fit_transform(
        &mut self,
        scores: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(scores, labels)?;
        self.transform(scores)
    }

    /// Returns the fitted `(a, b)` sigmoid parameters, or `None` if unfitted.
    pub fn parameters(&self) -> Option<(f64, f64)> {
        if self.fitted {
            Some((self.a, self.b))
        } else {
            None
        }
    }
}
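
// A minimal sketch of Platt scaling on synthetic, separable margins
// (added for illustration). The fitted sigmoid has a positive slope on
// this data, so calibrated outputs preserve the score ordering.
#[cfg(test)]
mod platt_example {
    use super::*;

    #[test]
    fn calibrated_outputs_preserve_score_order() {
        let scores = Array1::from_vec(vec![-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]);
        let labels = Array1::from_vec(vec![0usize, 0, 0, 1, 1, 1]);
        let mut platt = PlattScaler::new();
        let probs = platt.fit_transform(&scores, &labels).unwrap();
        for w in probs.to_vec().windows(2) {
            assert!(w[0] <= w[1]);
        }
    }
}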

/// Temperature scaling: divides all logits by a single learned
/// temperature `T` before the softmax. `T > 1` softens over-confident
/// predictions; `T < 1` sharpens under-confident ones.
#[derive(Debug, Clone)]
pub struct TemperatureScaler {
    /// Learned temperature (1.0 until fitted).
    temperature: f64,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}

impl TemperatureScaler {
    /// Creates an unfitted scaler with `temperature = 1.0`.
    pub fn new() -> Self {
        Self {
            temperature: 1.0,
            fitted: false,
        }
    }

    /// Fits the temperature in two stages: a coarse grid search over
    /// candidate values, then finite-difference gradient descent on the
    /// NLL starting from the best candidate.
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        if n_samples < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Stage 1: coarse grid search for a good starting point.
        let mut best_temp = 1.0;
        let mut best_nll = f64::INFINITY;
        for t_candidate in [0.1, 0.5, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0] {
            let nll = self.compute_nll(logits, labels, t_candidate)?;
            if nll < best_nll {
                best_nll = nll;
                best_temp = t_candidate;
            }
        }
        // Stage 2: local refinement with a forward-difference gradient.
        let mut temperature = best_temp;
        let learning_rate = 0.01;
        for _ in 0..100 {
            let nll_current = self.compute_nll(logits, labels, temperature)?;
            let nll_plus = self.compute_nll(logits, labels, temperature + 0.01)?;
            let gradient = (nll_plus - nll_current) / 0.01;
            let new_temp = temperature - learning_rate * gradient;
            // Stop rather than step into a degenerate temperature.
            if new_temp <= 0.01 {
                break;
            }
            temperature = new_temp;
            if gradient.abs() < 1e-5 {
                break;
            }
        }
        self.temperature = temperature.max(0.01);
        self.fitted = true;
        Ok(())
    }

    fn compute_nll(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        temperature: f64,
    ) -> Result<f64> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        for i in 0..n_samples {
            let scaled_logits = logits.row(i).mapv(|x| x / temperature);
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        Ok(nll / n_samples as f64)
    }

    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let scaled_logits = logits.row(i).mapv(|x| x / self.temperature);
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }

    /// Fits on the given data, then transforms it.
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }

    /// Returns the fitted temperature, or `None` if unfitted.
    pub fn temperature(&self) -> Option<f64> {
        if self.fitted {
            Some(self.temperature)
        } else {
            None
        }
    }
}
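
// A minimal sketch of temperature scaling on synthetic over-confident
// logits (added for illustration): outputs remain valid distributions
// and the learned temperature is strictly positive.
#[cfg(test)]
mod temperature_example {
    use super::*;

    #[test]
    fn transform_rows_are_distributions() {
        let logits = Array2::from_shape_vec(
            (4, 2),
            vec![4.0, -4.0, 3.0, -3.0, -4.0, 4.0, -3.0, 3.0],
        )
        .unwrap();
        let labels = Array1::from_vec(vec![0usize, 0, 1, 1]);
        let mut scaler = TemperatureScaler::new();
        let probs = scaler.fit_transform(&logits, &labels).unwrap();
        assert!(scaler.temperature().unwrap() > 0.0);
        for i in 0..probs.nrows() {
            assert!((probs.row(i).sum() - 1.0).abs() < 1e-9);
        }
    }
}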

/// Vector scaling: a restriction of matrix scaling to a diagonal weight
/// matrix, learning one multiplicative weight and one bias per class.
#[derive(Debug, Clone)]
pub struct VectorScaler {
    /// Per-class multiplicative weights (`None` until fitted).
    weights: Option<Array1<f64>>,
    /// Per-class additive biases (`None` until fitted).
    biases: Option<Array1<f64>>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}

impl VectorScaler {
    /// Creates an unfitted scaler.
    pub fn new() -> Self {
        Self {
            weights: None,
            biases: None,
            fitted: false,
        }
    }

    /// Fits the per-class weights and biases by minimizing the NLL with
    /// finite-difference gradient descent (O(K) loss evaluations per
    /// iteration, versus O(K^2) for full matrix scaling).
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        if n_samples < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        let mut weights = Array1::ones(n_classes);
        let mut biases = Array1::zeros(n_classes);
        let learning_rate = 0.01;
        let max_iter = 200;
        let tolerance = 1e-6;
        let mut prev_nll = f64::INFINITY;
        for _iter in 0..max_iter {
            let nll = self.compute_nll_vec(logits, labels, &weights, &biases)?;
            if (prev_nll - nll).abs() < tolerance {
                break;
            }
            prev_nll = nll;
            // Forward-difference gradients in each coordinate.
            let epsilon = 1e-6;
            let mut weight_grads = Array1::zeros(n_classes);
            let mut bias_grads = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut weights_plus = weights.clone();
                weights_plus[j] += epsilon;
                let nll_plus = self.compute_nll_vec(logits, labels, &weights_plus, &biases)?;
                weight_grads[j] = (nll_plus - nll) / epsilon;
                let mut biases_plus = biases.clone();
                biases_plus[j] += epsilon;
                let nll_plus = self.compute_nll_vec(logits, labels, &weights, &biases_plus)?;
                bias_grads[j] = (nll_plus - nll) / epsilon;
            }
            weights = &weights - &weight_grads.mapv(|g| learning_rate * g);
            biases = &biases - &bias_grads.mapv(|g| learning_rate * g);
            // Keep the weights positive so class scores are never inverted.
            weights.mapv_inplace(|w| w.max(0.01));
            if weight_grads.iter().all(|&g| g.abs() < tolerance)
                && bias_grads.iter().all(|&g| g.abs() < tolerance)
            {
                break;
            }
        }
        self.weights = Some(weights);
        self.biases = Some(biases);
        self.fitted = true;
        Ok(())
    }

    fn compute_nll_vec(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        weights: &Array1<f64>,
        biases: &Array1<f64>,
    ) -> Result<f64> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        for i in 0..n_samples {
            let scaled_logits = logits.row(i).to_owned() * weights + biases;
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        Ok(nll / n_samples as f64)
    }

    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let weights = self.weights.as_ref().unwrap();
        let biases = self.biases.as_ref().unwrap();
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let scaled_logits = logits.row(i).to_owned() * weights + biases;
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }

    /// Fits on the given data, then transforms it.
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }

    /// Returns the fitted `(weights, biases)`, or `None` if unfitted.
    pub fn parameters(&self) -> Option<(Array1<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.weights.as_ref().unwrap().clone(),
                self.biases.as_ref().unwrap().clone(),
            ))
        } else {
            None
        }
    }
}
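
// A minimal sketch of vector scaling on synthetic data (added for
// illustration): one weight and one bias per class are learned, and the
// transformed rows are valid probability distributions.
#[cfg(test)]
mod vector_scaler_example {
    use super::*;

    #[test]
    fn learns_per_class_parameters() {
        let logits = Array2::from_shape_vec(
            (4, 2),
            vec![2.0, -1.0, 1.5, -0.5, -1.0, 2.0, -0.5, 1.5],
        )
        .unwrap();
        let labels = Array1::from_vec(vec![0usize, 0, 1, 1]);
        let mut scaler = VectorScaler::new();
        let probs = scaler.fit_transform(&logits, &labels).unwrap();
        let (weights, biases) = scaler.parameters().unwrap();
        assert_eq!(weights.len(), 2);
        assert_eq!(biases.len(), 2);
        for i in 0..probs.nrows() {
            assert!((probs.row(i).sum() - 1.0).abs() < 1e-9);
        }
    }
}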