use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2};

/// Matrix scaling calibration: applies a full affine map `W * logits + b`
/// to the logits before the softmax.
#[derive(Debug, Clone)]
pub struct MatrixScaler {
    /// Learned class-by-class weight matrix (set after fitting).
    weight_matrix: Option<Array2<f64>>,
    /// Learned per-class bias vector (set after fitting).
    bias_vector: Option<Array1<f64>>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
    /// L2 penalty applied to off-diagonal weights during fitting.
    regularization: f64,
}

impl MatrixScaler {
    /// Creates an unfitted scaler with the default regularization (0.01).
    pub fn new() -> Self {
        Self {
            weight_matrix: None,
            bias_vector: None,
            fitted: false,
            regularization: 0.01,
        }
    }

    /// Creates an unfitted scaler with a custom off-diagonal L2 penalty.
    pub fn with_regularization(regularization: f64) -> Self {
        Self {
            weight_matrix: None,
            bias_vector: None,
            fitted: false,
            regularization,
        }
    }

    /// Fits `W` and `b` by minimizing the regularized NLL with
    /// finite-difference gradient descent.
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        if n_samples < n_classes * 2 {
            return Err(MLError::InvalidInput(format!(
                "Need at least {} samples for {} classes (matrix calibration)",
                n_classes * 2,
                n_classes
            )));
        }
        // Start from the identity map (i.e., no recalibration).
        let mut weight_matrix = Array2::eye(n_classes);
        let mut bias_vector = Array1::zeros(n_classes);
        let learning_rate = 0.001;
        let max_iter = 300;
        let tolerance = 1e-7;
        let mut prev_loss = f64::INFINITY;
        for _iter in 0..max_iter {
            let (nll, reg_term) =
                self.compute_nll_with_reg(logits, labels, &weight_matrix, &bias_vector)?;
            let total_loss = nll + reg_term;
            if (prev_loss - total_loss).abs() < tolerance {
                break;
            }
            prev_loss = total_loss;
            // Forward-difference approximation of the gradient.
            let epsilon = 1e-6;
            let mut weight_grads = Array2::zeros((n_classes, n_classes));
            let mut bias_grads = Array1::zeros(n_classes);
            for i in 0..n_classes {
                for j in 0..n_classes {
                    let mut weight_plus = weight_matrix.clone();
                    weight_plus[(i, j)] += epsilon;
                    let (nll_plus, reg_plus) =
                        self.compute_nll_with_reg(logits, labels, &weight_plus, &bias_vector)?;
                    weight_grads[(i, j)] = (nll_plus + reg_plus - total_loss) / epsilon;
                }
            }
            for j in 0..n_classes {
                let mut bias_plus = bias_vector.clone();
                bias_plus[j] += epsilon;
                let (nll_plus, reg_plus) =
                    self.compute_nll_with_reg(logits, labels, &weight_matrix, &bias_plus)?;
                bias_grads[j] = (nll_plus + reg_plus - total_loss) / epsilon;
            }
            weight_matrix = &weight_matrix - &weight_grads.mapv(|g| learning_rate * g);
            bias_vector = &bias_vector - &bias_grads.mapv(|g| learning_rate * g);
            // Keep the diagonal positive so the map cannot collapse a class.
            for i in 0..n_classes {
                weight_matrix[(i, i)] = weight_matrix[(i, i)].max(0.01);
            }
            let grad_norm = weight_grads.iter().map(|&g| g * g).sum::<f64>().sqrt()
                + bias_grads.iter().map(|&g| g * g).sum::<f64>().sqrt();
            if grad_norm < tolerance {
                break;
            }
        }
        self.weight_matrix = Some(weight_matrix);
        self.bias_vector = Some(bias_vector);
        self.fitted = true;
        Ok(())
    }

    /// Computes the mean NLL plus the L2 penalty on off-diagonal weights.
    fn compute_nll_with_reg(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        weight_matrix: &Array2<f64>,
        bias_vector: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        for i in 0..n_samples {
            let logits_row = logits.row(i);
            // Apply the affine map: scaled = W * logits + b.
            let mut scaled_logits = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut val = bias_vector[j];
                for k in 0..n_classes {
                    val += weight_matrix[(j, k)] * logits_row[k];
                }
                scaled_logits[j] = val;
            }
            // Numerically stable softmax: subtract the row maximum first.
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        nll /= n_samples as f64;
        // Penalize only off-diagonal entries; the diagonal is free to rescale.
        let mut reg_term = 0.0;
        for i in 0..n_classes {
            for j in 0..n_classes {
                if i != j {
                    reg_term += weight_matrix[(i, j)].powi(2);
                }
            }
        }
        reg_term *= self.regularization;
        Ok((nll, reg_term))
    }

    /// Transforms raw logits into calibrated probabilities.
    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let weight_matrix = self
            .weight_matrix
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Weight matrix not initialized".to_string()))?;
        let bias_vector = self
            .bias_vector
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Bias vector not initialized".to_string()))?;
        let n_classes = logits.ncols();
        // Guard against a class-count mismatch, which would otherwise panic
        // on the indexing below.
        if n_classes != weight_matrix.ncols() {
            return Err(MLError::InvalidInput(format!(
                "Expected {} classes, got {}",
                weight_matrix.ncols(),
                n_classes
            )));
        }
        let mut calibrated_probs = Array2::zeros((logits.nrows(), n_classes));
        for i in 0..logits.nrows() {
            let logits_row = logits.row(i);
            let mut scaled_logits = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut val = bias_vector[j];
                for k in 0..n_classes {
                    val += weight_matrix[(j, k)] * logits_row[k];
                }
                scaled_logits[j] = val;
            }
            // Stable softmax over the scaled logits.
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..n_classes {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }

    /// Convenience wrapper: fit on `(logits, labels)` and transform `logits`.
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }

    /// Returns `(W, b)` if fitted, otherwise `None`.
    pub fn parameters(&self) -> Option<(Array2<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.weight_matrix.as_ref()?.clone(),
                self.bias_vector.as_ref()?.clone(),
            ))
        } else {
            None
        }
    }

    /// Rough conditioning diagnostic for the fitted weight matrix. Note this
    /// returns the Frobenius norm as a cheap proxy, not the true condition
    /// number (which would require a singular value decomposition).
    pub fn condition_number(&self) -> Option<f64> {
        if !self.fitted {
            return None;
        }
        let w = self.weight_matrix.as_ref()?;
        let norm = w.iter().map(|&x| x * x).sum::<f64>().sqrt();
        Some(norm)
    }
}
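
// A minimal usage sketch for `MatrixScaler`, written as a test. The synthetic
// logits and labels below are illustrative assumptions, not data from any
// real model; the assertions only check invariants guaranteed by the code
// above (softmax rows sum to one, parameters exist after fitting).
#[cfg(test)]
mod matrix_scaler_example {
    use super::*;

    #[test]
    fn fit_transform_yields_row_stochastic_output() {
        // 6 samples x 2 classes satisfies the `n_samples >= 2 * n_classes` check.
        let logits = Array2::from_shape_vec(
            (6, 2),
            vec![2.0, -1.0, 1.5, -0.5, -1.0, 2.0, -0.5, 1.5, 2.5, -2.0, -2.0, 2.5],
        )
        .unwrap();
        let labels = Array1::from(vec![0usize, 0, 1, 1, 0, 1]);
        let mut scaler = MatrixScaler::new();
        let probs = scaler.fit_transform(&logits, &labels).unwrap();
        // Every calibrated row is a probability distribution over the classes.
        for i in 0..probs.nrows() {
            let sum: f64 = probs.row(i).iter().sum();
            assert!((sum - 1.0).abs() < 1e-9);
        }
        assert!(scaler.parameters().is_some());
    }
}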

/// Isotonic regression calibration via pool-adjacent-violators (PAV).
#[derive(Debug, Clone)]
pub struct IsotonicRegression {
    /// Breakpoints on the score axis (sorted ascending).
    x_thresholds: Vec<f64>,
    /// Calibrated values at each breakpoint (non-decreasing).
    y_thresholds: Vec<f64>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}

impl IsotonicRegression {
    /// Creates an unfitted regressor.
    pub fn new() -> Self {
        Self {
            x_thresholds: Vec::new(),
            y_thresholds: Vec::new(),
            fitted: false,
        }
    }

    /// Fits the step function with the pool-adjacent-violators algorithm:
    /// sort by score, then repeatedly merge adjacent points that violate
    /// monotonicity into weighted averages.
    pub fn fit(&mut self, scores: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if scores.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Scores and labels must have same length".to_string(),
            ));
        }
        let n = scores.len();
        if n < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Sort (score, label) pairs by ascending score.
        let mut pairs: Vec<(f64, f64)> = scores
            .iter()
            .zip(labels.iter())
            .map(|(&s, &l)| (s, l as f64))
            .collect();
        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
        let mut x = Vec::with_capacity(n);
        let mut y = Vec::with_capacity(n);
        let mut weights = Vec::with_capacity(n);
        for (score, label) in pairs {
            x.push(score);
            y.push(label);
            weights.push(1.0);
        }
        // Pool adjacent violators: whenever y[i] > y[i + 1], replace the pair
        // with its weighted mean and re-check against the previous block.
        let mut i = 0;
        while i < y.len() - 1 {
            if y[i] > y[i + 1] {
                let w1 = weights[i];
                let w2 = weights[i + 1];
                let total_weight = w1 + w2;
                y[i] = (y[i] * w1 + y[i + 1] * w2) / total_weight;
                weights[i] = total_weight;
                y.remove(i + 1);
                x.remove(i + 1);
                weights.remove(i + 1);
                if i > 0 {
                    i -= 1;
                }
            } else {
                i += 1;
            }
        }
        self.x_thresholds = x;
        self.y_thresholds = y;
        self.fitted = true;
        Ok(())
    }

    /// Maps scores through the fitted step function, interpolating linearly
    /// between breakpoints and clamping outside the fitted range.
    pub fn transform(&self, scores: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Regressor must be fitted before transform".to_string(),
            ));
        }
        let mut calibrated = Array1::zeros(scores.len());
        for (i, &score) in scores.iter().enumerate() {
            // Find the first threshold not below the score.
            let pos = self
                .x_thresholds
                .binary_search_by(|&x| x.partial_cmp(&score).unwrap_or(std::cmp::Ordering::Less))
                .unwrap_or_else(|e| e);
            if pos == 0 {
                calibrated[i] = self.y_thresholds[0];
            } else if pos >= self.x_thresholds.len() {
                calibrated[i] = self.y_thresholds.last().copied().unwrap_or(0.0);
            } else {
                let x0 = self.x_thresholds[pos - 1];
                let x1 = self.x_thresholds[pos];
                let y0 = self.y_thresholds[pos - 1];
                let y1 = self.y_thresholds[pos];
                if (x1 - x0).abs() < 1e-10 {
                    calibrated[i] = (y0 + y1) / 2.0;
                } else {
                    let alpha = (score - x0) / (x1 - x0);
                    calibrated[i] = y0 + alpha * (y1 - y0);
                }
            }
        }
        Ok(calibrated)
    }

    /// Convenience wrapper: fit on `(scores, labels)` and transform `scores`.
    pub fn fit_transform(
        &mut self,
        scores: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(scores, labels)?;
        self.transform(scores)
    }
}
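
// A minimal usage sketch for `IsotonicRegression`, written as a test with
// assumed synthetic data; it checks the defining PAV property that calibrated
// outputs are non-decreasing in the input score.
#[cfg(test)]
mod isotonic_regression_example {
    use super::*;

    #[test]
    fn calibrated_scores_are_monotone() {
        let scores = Array1::from(vec![0.1, 0.3, 0.4, 0.6, 0.8, 0.9]);
        let labels = Array1::from(vec![0usize, 0, 1, 0, 1, 1]);
        let mut iso = IsotonicRegression::new();
        let calibrated = iso.fit_transform(&scores, &labels).unwrap();
        // Inputs are sorted ascending, so outputs must be non-decreasing.
        for i in 1..calibrated.len() {
            assert!(calibrated[i - 1] <= calibrated[i] + 1e-12);
        }
    }
}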

/// Bayesian binning into quantiles (BBQ): equal-frequency bins with a
/// Beta(0.5, 0.5) prior on each bin's positive rate.
#[derive(Debug, Clone)]
pub struct BayesianBinningQuantiles {
    /// Number of quantile bins.
    n_bins: usize,
    /// Bin boundaries over [0, 1] (set after fitting).
    bin_edges: Option<Vec<f64>>,
    /// Beta posterior alpha (prior + positive count) per bin.
    alphas: Option<Array1<f64>>,
    /// Beta posterior beta (prior + negative count) per bin.
    betas: Option<Array1<f64>>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}

impl BayesianBinningQuantiles {
    /// Creates an unfitted calibrator with `n_bins` quantile bins.
    pub fn new(n_bins: usize) -> Self {
        Self {
            n_bins,
            bin_edges: None,
            alphas: None,
            betas: None,
            fitted: false,
        }
    }

    /// Fits quantile bin edges and per-bin Beta posteriors from binary labels.
    pub fn fit(&mut self, probabilities: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if probabilities.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Probabilities and labels must have same length".to_string(),
            ));
        }
        let n_samples = probabilities.len();
        if n_samples < self.n_bins {
            return Err(MLError::InvalidInput(format!(
                "Need at least {} samples for {} bins, got {}",
                self.n_bins, self.n_bins, n_samples
            )));
        }
        // Place interior edges at empirical quantiles of the scores.
        let mut sorted_probs = probabilities.to_vec();
        sorted_probs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let mut bin_edges = vec![0.0];
        for i in 1..self.n_bins {
            let quantile_idx = (i as f64 / self.n_bins as f64 * n_samples as f64) as usize;
            let quantile_idx = quantile_idx.min(sorted_probs.len() - 1);
            bin_edges.push(sorted_probs[quantile_idx]);
        }
        bin_edges.push(1.0);
        // Count positives and negatives per bin.
        let mut bin_positives = vec![0.0; self.n_bins];
        let mut bin_negatives = vec![0.0; self.n_bins];
        for (i, &prob) in probabilities.iter().enumerate() {
            let bin_idx = self.find_bin(&bin_edges, prob);
            let label = labels[i];
            if label == 1 {
                bin_positives[bin_idx] += 1.0;
            } else {
                bin_negatives[bin_idx] += 1.0;
            }
        }
        // Jeffreys prior Beta(0.5, 0.5) plus the observed counts.
        let prior_alpha = 0.5;
        let prior_beta = 0.5;
        let mut alphas = Array1::zeros(self.n_bins);
        let mut betas = Array1::zeros(self.n_bins);
        for i in 0..self.n_bins {
            alphas[i] = prior_alpha + bin_positives[i];
            betas[i] = prior_beta + bin_negatives[i];
        }
        self.bin_edges = Some(bin_edges);
        self.alphas = Some(alphas);
        self.betas = Some(betas);
        self.fitted = true;
        Ok(())
    }

    /// Returns the index of the half-open bin `[edge_i, edge_{i+1})` that
    /// contains `prob`; values at or above the last edge fall in the last bin.
    fn find_bin(&self, bin_edges: &[f64], prob: f64) -> usize {
        for i in 0..bin_edges.len() - 1 {
            if prob >= bin_edges[i] && prob < bin_edges[i + 1] {
                return i;
            }
        }
        bin_edges.len() - 2
    }

    /// Replaces each probability with its bin's posterior mean
    /// `alpha / (alpha + beta)`.
    pub fn transform(&self, probabilities: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Calibrator must be fitted before transform".to_string(),
            ));
        }
        let bin_edges = self
            .bin_edges
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Bin edges not initialized".to_string()))?;
        let alphas = self
            .alphas
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Alphas not initialized".to_string()))?;
        let betas = self
            .betas
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Betas not initialized".to_string()))?;
        let mut calibrated = Array1::zeros(probabilities.len());
        for (i, &prob) in probabilities.iter().enumerate() {
            let bin_idx = self.find_bin(bin_edges, prob);
            let alpha = alphas[bin_idx];
            let beta = betas[bin_idx];
            calibrated[i] = alpha / (alpha + beta);
        }
        Ok(calibrated)
    }

    /// Convenience wrapper: fit on `(probabilities, labels)` and transform.
    pub fn fit_transform(
        &mut self,
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(probabilities, labels)?;
        self.transform(probabilities)
    }

    /// Returns `(posterior_mean, lower, upper)` per input, where the bounds
    /// form a Wilson-style score interval at the requested confidence level.
    pub fn predict_with_uncertainty(
        &self,
        probabilities: &Array1<f64>,
        confidence: f64,
    ) -> Result<Vec<(f64, f64, f64)>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Calibrator must be fitted before prediction".to_string(),
            ));
        }
        if confidence <= 0.0 || confidence >= 1.0 {
            return Err(MLError::InvalidInput(
                "Confidence must be between 0 and 1".to_string(),
            ));
        }
        let bin_edges = self
            .bin_edges
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Bin edges not initialized".to_string()))?;
        let alphas = self
            .alphas
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Alphas not initialized".to_string()))?;
        let betas = self
            .betas
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Betas not initialized".to_string()))?;
        // Two-sided normal critical value for the requested confidence
        // (e.g., ~1.96 at 95%). Previously this was hardcoded to 1.96, which
        // silently ignored the `confidence` argument.
        let z = Self::two_sided_z(confidence);
        let mut results = Vec::new();
        for &prob in probabilities.iter() {
            let bin_idx = self.find_bin(bin_edges, prob);
            let alpha = alphas[bin_idx];
            let beta = betas[bin_idx];
            let mean = alpha / (alpha + beta);
            let n = alpha + beta - 1.0;
            let p = alpha / (alpha + beta);
            if n > 0.0 {
                // Wilson score interval around the posterior mean.
                let denominator = 1.0 + z * z / n;
                let center = (p + z * z / (2.0 * n)) / denominator;
                let margin = z * (p * (1.0 - p) / n + z * z / (4.0 * n * n)).sqrt() / denominator;
                let lower = (center - margin).max(0.0);
                let upper = (center + margin).min(1.0);
                results.push((mean, lower, upper));
            } else {
                // Too little evidence in the bin; fall back to the full range.
                results.push((mean, 0.0, 1.0));
            }
        }
        Ok(results)
    }

    /// Approximates the two-sided standard-normal critical value `z` with
    /// `P(|Z| <= z) = confidence`, using Winitzki's inverse-erf approximation
    /// (absolute error on the order of 1e-3, ample for these intervals).
    fn two_sided_z(confidence: f64) -> f64 {
        let a = 0.147;
        let ln_term = (1.0 - confidence * confidence).ln();
        let first = 2.0 / (std::f64::consts::PI * a) + ln_term / 2.0;
        let erf_inv = ((first * first - ln_term / a).sqrt() - first).sqrt();
        std::f64::consts::SQRT_2 * erf_inv
    }

    /// Returns the configured number of bins.
    pub fn n_bins(&self) -> usize {
        self.n_bins
    }

    /// Returns `(bin_edges, alphas, betas)` if fitted, otherwise `None`.
    pub fn bin_statistics(&self) -> Option<(Vec<f64>, Array1<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.bin_edges.as_ref()?.clone(),
                self.alphas.as_ref()?.clone(),
                self.betas.as_ref()?.clone(),
            ))
        } else {
            None
        }
    }
}
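
// A minimal usage sketch for `BayesianBinningQuantiles`, written as a test
// with assumed synthetic data. The Beta(0.5, 0.5) prior keeps every posterior
// mean strictly inside (0, 1), and each Wilson-style interval should bracket
// its bin's posterior mean.
#[cfg(test)]
mod bayesian_binning_example {
    use super::*;

    #[test]
    fn posterior_means_and_intervals_are_consistent() {
        let probs = Array1::from(vec![0.05, 0.2, 0.35, 0.5, 0.65, 0.8, 0.9, 0.95]);
        let labels = Array1::from(vec![0usize, 0, 0, 1, 1, 1, 1, 1]);
        let mut bbq = BayesianBinningQuantiles::new(3);
        let calibrated = bbq.fit_transform(&probs, &labels).unwrap();
        for &p in calibrated.iter() {
            assert!(p > 0.0 && p < 1.0);
        }
        let intervals = bbq.predict_with_uncertainty(&probs, 0.95).unwrap();
        for (mean, lower, upper) in intervals {
            assert!(lower <= mean && mean <= upper);
        }
    }
}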

/// Platt scaling: fits a sigmoid `1 / (1 + exp(-(a * s + b)))` over raw scores.
#[derive(Debug, Clone)]
pub struct PlattScaler {
    /// Sigmoid slope.
    a: f64,
    /// Sigmoid intercept.
    b: f64,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}

impl PlattScaler {
    /// Creates an unfitted scaler (identity slope, zero intercept).
    pub fn new() -> Self {
        Self {
            a: 1.0,
            b: 0.0,
            fitted: false,
        }
    }

    /// Fits `a` and `b` by maximizing the Bernoulli log-likelihood with
    /// coordinate-wise Newton updates (first the slope, then the intercept).
    pub fn fit(&mut self, scores: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if scores.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Scores and labels must have same length".to_string(),
            ));
        }
        let n = scores.len();
        if n < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // 0/1 targets for the logistic likelihood.
        let t: Vec<f64> = labels.iter().map(|&l| if l == 1 { 1.0 } else { 0.0 }).collect();
        let mut a = 0.0;
        // Initialize the intercept at the log-odds of a Laplace-smoothed
        // positive-class prior.
        let n_pos = labels.iter().filter(|&&l| l == 1).count() as f64;
        let prior_pos = (n_pos + 1.0) / (n as f64 + 2.0);
        let mut b = (prior_pos / (1.0 - prior_pos)).ln();
        // Newton steps on the slope `a` with `b` held fixed.
        for _ in 0..100 {
            let mut fval = 0.0;
            let mut fpp = 0.0;
            for i in 0..n {
                let fapb = scores[i] * a + b;
                let p = 1.0 / (1.0 + (-fapb).exp());
                fval += scores[i] * (t[i] - p);
                fpp += scores[i] * scores[i] * p * (1.0 - p);
            }
            if fpp.abs() < 1e-12 {
                break;
            }
            let delta = fval / fpp;
            a += delta;
            if delta.abs() < 1e-7 {
                break;
            }
        }
        // Newton steps on the intercept `b` with `a` held fixed.
        for _ in 0..100 {
            let mut fval = 0.0;
            let mut fpp = 0.0;
            for i in 0..n {
                let fapb = scores[i] * a + b;
                let p = 1.0 / (1.0 + (-fapb).exp());
                fval += t[i] - p;
                fpp += p * (1.0 - p);
            }
            if fpp.abs() < 1e-12 {
                break;
            }
            let delta = fval / fpp;
            b += delta;
            if delta.abs() < 1e-7 {
                break;
            }
        }
        self.a = a;
        self.b = b;
        self.fitted = true;
        Ok(())
    }

    /// Applies the fitted sigmoid to raw scores.
    pub fn transform(&self, scores: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let probs = scores.mapv(|s| {
            let fapb = s * self.a + self.b;
            1.0 / (1.0 + (-fapb).exp())
        });
        Ok(probs)
    }

    /// Convenience wrapper: fit on `(scores, labels)` and transform `scores`.
    pub fn fit_transform(
        &mut self,
        scores: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(scores, labels)?;
        self.transform(scores)
    }

    /// Returns `(a, b)` if fitted, otherwise `None`.
    pub fn parameters(&self) -> Option<(f64, f64)> {
        if self.fitted {
            Some((self.a, self.b))
        } else {
            None
        }
    }
}
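
// A minimal usage sketch for `PlattScaler`, written as a test. The scores and
// labels are assumed synthetic data with some class overlap so the sigmoid
// fit stays finite; the assertions only check that every output is a valid
// probability.
#[cfg(test)]
mod platt_scaler_example {
    use super::*;

    #[test]
    fn sigmoid_outputs_are_valid_probabilities() {
        let scores = Array1::from(vec![-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0]);
        let labels = Array1::from(vec![0usize, 0, 1, 0, 1, 0, 1, 1]);
        let mut platt = PlattScaler::new();
        let probs = platt.fit_transform(&scores, &labels).unwrap();
        for &p in probs.iter() {
            assert!((0.0..=1.0).contains(&p));
        }
        assert!(platt.parameters().is_some());
    }
}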

/// Temperature scaling: divides all logits by a single learned temperature
/// before the softmax.
#[derive(Debug, Clone)]
pub struct TemperatureScaler {
    /// Learned temperature (> 0); 1.0 leaves the logits unchanged.
    temperature: f64,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}

impl TemperatureScaler {
    /// Creates an unfitted scaler with temperature 1.0.
    pub fn new() -> Self {
        Self {
            temperature: 1.0,
            fitted: false,
        }
    }

    /// Fits the temperature by a coarse grid search followed by
    /// finite-difference gradient descent on the NLL.
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        if n_samples < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Coarse grid search to seed the optimizer.
        let mut best_temp = 1.0;
        let mut best_nll = f64::INFINITY;
        for t_candidate in [0.1, 0.5, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0] {
            let nll = self.compute_nll(logits, labels, t_candidate)?;
            if nll < best_nll {
                best_nll = nll;
                best_temp = t_candidate;
            }
        }
        // Local refinement with a forward-difference gradient.
        let mut temperature = best_temp;
        let learning_rate = 0.01;
        for _ in 0..100 {
            let nll_current = self.compute_nll(logits, labels, temperature)?;
            let nll_plus = self.compute_nll(logits, labels, temperature + 0.01)?;
            let gradient = (nll_plus - nll_current) / 0.01;
            let new_temp = temperature - learning_rate * gradient;
            if new_temp <= 0.01 {
                break;
            }
            temperature = new_temp;
            if gradient.abs() < 1e-5 {
                break;
            }
        }
        self.temperature = temperature.max(0.01);
        self.fitted = true;
        Ok(())
    }

    /// Mean negative log-likelihood of the labels under temperature-scaled
    /// softmax probabilities.
    fn compute_nll(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        temperature: f64,
    ) -> Result<f64> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        for i in 0..n_samples {
            let scaled_logits = logits.row(i).mapv(|x| x / temperature);
            // Numerically stable softmax.
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        Ok(nll / n_samples as f64)
    }

    /// Transforms raw logits into calibrated probabilities.
    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let scaled_logits = logits.row(i).mapv(|x| x / self.temperature);
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }

    /// Convenience wrapper: fit on `(logits, labels)` and transform `logits`.
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }

    /// Returns the fitted temperature, or `None` before fitting.
    pub fn temperature(&self) -> Option<f64> {
        if self.fitted {
            Some(self.temperature)
        } else {
            None
        }
    }
}
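
// A minimal usage sketch for `TemperatureScaler`, written as a test with
// assumed synthetic logits. Dividing by a positive temperature never changes
// a row's argmax, so the test only checks row sums and that a positive
// temperature was learned.
#[cfg(test)]
mod temperature_scaler_example {
    use super::*;

    #[test]
    fn calibrated_rows_sum_to_one() {
        let logits = Array2::from_shape_vec(
            (4, 3),
            vec![4.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 4.0, 3.0, 1.0, 0.0],
        )
        .unwrap();
        let labels = Array1::from(vec![0usize, 1, 2, 0]);
        let mut scaler = TemperatureScaler::new();
        let probs = scaler.fit_transform(&logits, &labels).unwrap();
        for i in 0..probs.nrows() {
            let sum: f64 = probs.row(i).iter().sum();
            assert!((sum - 1.0).abs() < 1e-9);
        }
        assert!(scaler.temperature().unwrap() > 0.0);
    }
}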

/// Vector scaling: a per-class weight and bias applied to the logits before
/// the softmax (a diagonal restriction of matrix scaling).
#[derive(Debug, Clone)]
pub struct VectorScaler {
    /// Per-class multiplicative weights (set after fitting).
    weights: Option<Array1<f64>>,
    /// Per-class additive biases (set after fitting).
    biases: Option<Array1<f64>>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}

impl VectorScaler {
    /// Creates an unfitted scaler.
    pub fn new() -> Self {
        Self {
            weights: None,
            biases: None,
            fitted: false,
        }
    }

    /// Fits per-class weights and biases by finite-difference gradient
    /// descent on the NLL.
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        if n_samples < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Start from the identity scaling.
        let mut weights = Array1::ones(n_classes);
        let mut biases = Array1::zeros(n_classes);
        let learning_rate = 0.01;
        let max_iter = 200;
        let tolerance = 1e-6;
        let mut prev_nll = f64::INFINITY;
        for _iter in 0..max_iter {
            let nll = self.compute_nll_vec(logits, labels, &weights, &biases)?;
            if (prev_nll - nll).abs() < tolerance {
                break;
            }
            prev_nll = nll;
            // Forward-difference gradients for each weight and bias.
            let epsilon = 1e-6;
            let mut weight_grads = Array1::zeros(n_classes);
            let mut bias_grads = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut weights_plus = weights.clone();
                weights_plus[j] += epsilon;
                let nll_plus = self.compute_nll_vec(logits, labels, &weights_plus, &biases)?;
                weight_grads[j] = (nll_plus - nll) / epsilon;
                let mut biases_plus = biases.clone();
                biases_plus[j] += epsilon;
                let nll_plus = self.compute_nll_vec(logits, labels, &weights, &biases_plus)?;
                bias_grads[j] = (nll_plus - nll) / epsilon;
            }
            weights = &weights - &weight_grads.mapv(|g| learning_rate * g);
            biases = &biases - &bias_grads.mapv(|g| learning_rate * g);
            // Keep the weights positive so class ordering cannot flip.
            weights.mapv_inplace(|w| w.max(0.01));
            if weight_grads.iter().all(|&g| g.abs() < tolerance)
                && bias_grads.iter().all(|&g| g.abs() < tolerance)
            {
                break;
            }
        }
        self.weights = Some(weights);
        self.biases = Some(biases);
        self.fitted = true;
        Ok(())
    }

    /// Mean NLL under element-wise scaled logits `w * logits + b`.
    fn compute_nll_vec(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        weights: &Array1<f64>,
        biases: &Array1<f64>,
    ) -> Result<f64> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        for i in 0..n_samples {
            let scaled_logits = logits.row(i).to_owned() * weights + biases;
            // Numerically stable softmax.
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        Ok(nll / n_samples as f64)
    }

    /// Transforms raw logits into calibrated probabilities.
    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let weights = self
            .weights
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Weights not initialized".to_string()))?;
        let biases = self
            .biases
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Biases not initialized".to_string()))?;
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let scaled_logits = logits.row(i).to_owned() * weights + biases;
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }

    /// Convenience wrapper: fit on `(logits, labels)` and transform `logits`.
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }

    /// Returns `(weights, biases)` if fitted, otherwise `None`.
    pub fn parameters(&self) -> Option<(Array1<f64>, Array1<f64>)> {
        if self.fitted {
            Some((self.weights.as_ref()?.clone(), self.biases.as_ref()?.clone()))
        } else {
            None
        }
    }
}
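
// A minimal usage sketch for `VectorScaler`, written as a test with assumed
// synthetic logits. It checks that calibrated rows are distributions and that
// the learned per-class weights respect the positivity clamp in `fit`.
#[cfg(test)]
mod vector_scaler_example {
    use super::*;

    #[test]
    fn per_class_scaling_produces_valid_distributions() {
        let logits = Array2::from_shape_vec(
            (4, 2),
            vec![2.0, -2.0, 1.0, -1.0, -1.5, 1.5, -2.0, 2.0],
        )
        .unwrap();
        let labels = Array1::from(vec![0usize, 0, 1, 1]);
        let mut scaler = VectorScaler::new();
        let probs = scaler.fit_transform(&logits, &labels).unwrap();
        for i in 0..probs.nrows() {
            let sum: f64 = probs.row(i).iter().sum();
            assert!((sum - 1.0).abs() < 1e-9);
        }
        let (weights, _biases) = scaler.parameters().unwrap();
        assert!(weights.iter().all(|&w| w >= 0.01));
    }
}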