// quantrs2_ml/utils/calibration/types.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2};

/// Matrix Scaling - full affine transformation for maximum calibration flexibility
/// Uses full weight matrix W and bias vector b: calibrated = softmax(W @ logits + b)
/// More expressive than vector scaling but requires more data to avoid overfitting
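///
/// # Example
///
/// A minimal usage sketch (`ignore`d doc test; the logits and labels below are
/// made-up values chosen only to satisfy the `n_samples >= 2 * n_classes` check):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let logits = Array2::from_shape_vec(
///     (4, 2),
///     vec![2.0, 0.5, 0.1, 1.5, 3.0, 0.2, 0.3, 2.2],
/// ).expect("shape matches data length");
/// let labels = Array1::from_vec(vec![0usize, 1, 0, 1]);
///
/// let mut scaler = MatrixScaler::new();
/// let probs = scaler.fit_transform(&logits, &labels).expect("calibration should succeed");
/// // Each output row is softmax-normalized, so it sums to 1.
/// assert!((probs.row(0).sum() - 1.0).abs() < 1e-9);
/// ```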
#[derive(Debug, Clone)]
pub struct MatrixScaler {
    /// Weight matrix (n_classes × n_classes)
    weight_matrix: Option<Array2<f64>>,
    /// Bias vector (n_classes)
    bias_vector: Option<Array1<f64>>,
    /// Whether the scaler has been fitted
    fitted: bool,
    /// Regularization strength (L2 penalty on off-diagonal elements)
    regularization: f64,
}

impl MatrixScaler {
    /// Create a new matrix scaler
    pub fn new() -> Self {
        Self {
            weight_matrix: None,
            bias_vector: None,
            fitted: false,
            regularization: 0.01,
        }
    }

    /// Create a matrix scaler with custom regularization
    pub fn with_regularization(regularization: f64) -> Self {
        Self {
            weight_matrix: None,
            bias_vector: None,
            fitted: false,
            regularization,
        }
    }

    /// Fit the matrix scaler to logits and true labels
    /// Uses finite-difference gradient descent with L2 regularization on off-diagonal weights
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        if n_samples < n_classes * 2 {
            return Err(MLError::InvalidInput(format!(
                "Need at least {} samples for {} classes (matrix calibration)",
                n_classes * 2,
                n_classes
            )));
        }
        // Initialize at the identity transform (no recalibration).
        let mut weight_matrix = Array2::eye(n_classes);
        let mut bias_vector = Array1::zeros(n_classes);
        let learning_rate = 0.001;
        let max_iter = 300;
        let tolerance = 1e-7;
        let mut prev_nll = f64::INFINITY;
        for _iter in 0..max_iter {
            let (nll, reg_term) =
                self.compute_nll_with_reg(logits, labels, &weight_matrix, &bias_vector)?;
            let total_loss = nll + reg_term;
            if (prev_nll - total_loss).abs() < tolerance {
                break;
            }
            prev_nll = total_loss;
            // Forward-difference numerical gradients (costs n_classes^2 + n_classes
            // extra loss evaluations per iteration).
            let epsilon = 1e-6;
            let mut weight_grads = Array2::zeros((n_classes, n_classes));
            let mut bias_grads = Array1::zeros(n_classes);
            for i in 0..n_classes {
                for j in 0..n_classes {
                    let mut weight_plus = weight_matrix.clone();
                    weight_plus[(i, j)] += epsilon;
                    let (nll_plus, reg_plus) =
                        self.compute_nll_with_reg(logits, labels, &weight_plus, &bias_vector)?;
                    weight_grads[(i, j)] = (nll_plus + reg_plus - total_loss) / epsilon;
                }
            }
            for j in 0..n_classes {
                let mut bias_plus = bias_vector.clone();
                bias_plus[j] += epsilon;
                let (nll_plus, reg_plus) =
                    self.compute_nll_with_reg(logits, labels, &weight_matrix, &bias_plus)?;
                bias_grads[j] = (nll_plus + reg_plus - total_loss) / epsilon;
            }
            weight_matrix = &weight_matrix - &weight_grads.mapv(|g| learning_rate * g);
            bias_vector = &bias_vector - &bias_grads.mapv(|g| learning_rate * g);
            // Keep diagonal entries positive so the transform cannot collapse a class.
            for i in 0..n_classes {
                weight_matrix[(i, i)] = weight_matrix[(i, i)].max(0.01);
            }
            let grad_norm = weight_grads.iter().map(|&g| g * g).sum::<f64>().sqrt()
                + bias_grads.iter().map(|&g| g * g).sum::<f64>().sqrt();
            if grad_norm < tolerance {
                break;
            }
        }
        self.weight_matrix = Some(weight_matrix);
        self.bias_vector = Some(bias_vector);
        self.fitted = true;
        Ok(())
    }

    /// Compute NLL with L2 regularization on off-diagonal weights
    fn compute_nll_with_reg(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        weight_matrix: &Array2<f64>,
        bias_vector: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        for i in 0..n_samples {
            let logits_row = logits.row(i);
            // scaled = W @ logits + b
            let mut scaled_logits = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut val = bias_vector[j];
                for k in 0..n_classes {
                    val += weight_matrix[(j, k)] * logits_row[k];
                }
                scaled_logits[j] = val;
            }
            // Numerically stable softmax via the max-subtraction trick.
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        nll /= n_samples as f64;
        // L2 penalty on off-diagonal weights discourages cross-class mixing.
        let mut reg_term = 0.0;
        for i in 0..n_classes {
            for j in 0..n_classes {
                if i != j {
                    reg_term += weight_matrix[(i, j)].powi(2);
                }
            }
        }
        reg_term *= self.regularization;
        Ok((nll, reg_term))
    }

    /// Transform logits to calibrated probabilities using matrix scaling
    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let weight_matrix = self
            .weight_matrix
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Weight matrix not initialized".to_string()))?;
        let bias_vector = self
            .bias_vector
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Bias vector not initialized".to_string()))?;
        let n_classes = logits.ncols();
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let logits_row = logits.row(i);
            let mut scaled_logits = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut val = bias_vector[j];
                for k in 0..n_classes {
                    val += weight_matrix[(j, k)] * logits_row[k];
                }
                scaled_logits[j] = val;
            }
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }

    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }

    /// Get the fitted parameters (weight_matrix, bias_vector)
    pub fn parameters(&self) -> Option<(Array2<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.weight_matrix.as_ref()?.clone(),
                self.bias_vector.as_ref()?.clone(),
            ))
        } else {
            None
        }
    }

    /// Get the Frobenius norm of the weight matrix (for diagnostics)
    /// Serves as a cheap conditioning proxy (an exact condition number would need
    /// an SVD); higher values indicate potential numerical instability
    pub fn condition_number(&self) -> Option<f64> {
        if !self.fitted {
            return None;
        }
        let w = self.weight_matrix.as_ref()?;
        // Frobenius norm: sqrt of the sum of squared entries.
        let norm = w.iter().map(|&x| x * x).sum::<f64>().sqrt();
        Some(norm)
    }
}

/// Isotonic Regression - non-parametric calibration using monotonic transformation
/// More flexible than Platt scaling but requires more data
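///
/// # Example
///
/// A minimal usage sketch (`ignore`d doc test; scores and labels are made-up):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let scores = Array1::from_vec(vec![0.1, 0.4, 0.35, 0.8]);
/// let labels = Array1::from_vec(vec![0usize, 0, 1, 1]);
///
/// let mut iso = IsotonicRegression::new();
/// let calibrated = iso.fit_transform(&scores, &labels).expect("calibration should succeed");
/// // The fitted mapping is monotonic non-decreasing in the score.
/// assert!(calibrated[0] <= calibrated[3]);
/// ```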
#[derive(Debug, Clone)]
pub struct IsotonicRegression {
    /// X values (decision scores)
    x_thresholds: Vec<f64>,
    /// Y values (calibrated probabilities)
    y_thresholds: Vec<f64>,
    /// Whether the regressor has been fitted
    fitted: bool,
}

impl IsotonicRegression {
    /// Create a new isotonic regression calibrator
    pub fn new() -> Self {
        Self {
            x_thresholds: Vec::new(),
            y_thresholds: Vec::new(),
            fitted: false,
        }
    }

    /// Fit isotonic regression to scores and labels
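    /// Uses the pool-adjacent-violators (PAV) algorithm. A tiny worked case
    /// (hypothetical numbers): sorted labels [0, 1, 0, 1] violate monotonicity
    /// at the middle pair, so PAV pools (1, 0) into one block with value 0.5,
    /// yielding the non-decreasing fit [0, 0.5, 0.5, 1].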
    pub fn fit(&mut self, scores: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if scores.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Scores and labels must have same length".to_string(),
            ));
        }
        let n = scores.len();
        if n < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Sort (score, label) pairs by score.
        let mut pairs: Vec<(f64, f64)> = scores
            .iter()
            .zip(labels.iter())
            .map(|(&s, &l)| (s, l as f64))
            .collect();
        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
        let mut x = Vec::new();
        let mut y = Vec::new();
        let mut weights = Vec::new();
        for (score, label) in pairs {
            x.push(score);
            y.push(label);
            weights.push(1.0);
        }
        // Pool adjacent violators: merge neighboring blocks whenever the left
        // block's mean exceeds the right one's, then step back and re-check.
        let mut i = 0;
        while i < y.len() - 1 {
            if y[i] > y[i + 1] {
                let w1 = weights[i];
                let w2 = weights[i + 1];
                let total_weight = w1 + w2;
                y[i] = (y[i] * w1 + y[i + 1] * w2) / total_weight;
                weights[i] = total_weight;
                y.remove(i + 1);
                x.remove(i + 1);
                weights.remove(i + 1);
                if i > 0 {
                    i -= 1;
                }
            } else {
                i += 1;
            }
        }
        self.x_thresholds = x;
        self.y_thresholds = y;
        self.fitted = true;
        Ok(())
    }

    /// Transform decision scores to calibrated probabilities
    pub fn transform(&self, scores: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Regressor must be fitted before transform".to_string(),
            ));
        }
        let mut calibrated = Array1::zeros(scores.len());
        for (i, &score) in scores.iter().enumerate() {
            // Locate the insertion point, then linearly interpolate between the
            // surrounding thresholds (clamping at both ends).
            let pos = self
                .x_thresholds
                .binary_search_by(|&x| x.partial_cmp(&score).unwrap_or(std::cmp::Ordering::Less))
                .unwrap_or_else(|e| e);
            if pos == 0 {
                calibrated[i] = self.y_thresholds[0];
            } else if pos >= self.x_thresholds.len() {
                calibrated[i] = self.y_thresholds.last().copied().unwrap_or(0.0);
            } else {
                let x0 = self.x_thresholds[pos - 1];
                let x1 = self.x_thresholds[pos];
                let y0 = self.y_thresholds[pos - 1];
                let y1 = self.y_thresholds[pos];
                if (x1 - x0).abs() < 1e-10 {
                    calibrated[i] = (y0 + y1) / 2.0;
                } else {
                    let alpha = (score - x0) / (x1 - x0);
                    calibrated[i] = y0 + alpha * (y1 - y0);
                }
            }
        }
        Ok(calibrated)
    }

    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        scores: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(scores, labels)?;
        self.transform(scores)
    }
}

/// Bayesian Binning into Quantiles (BBQ) - sophisticated histogram-based calibration
/// Bins predictions into quantiles and learns a Bayesian posterior for each bin
/// Uses a Beta distribution for robust probability estimation with uncertainty quantification
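///
/// # Example
///
/// A minimal usage sketch (`ignore`d doc test; the inputs are made-up binary
/// predictions, and three bins keeps this toy data above the sample-count check):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let probs = Array1::from_vec(vec![0.2, 0.9, 0.4, 0.7, 0.1, 0.8]);
/// let labels = Array1::from_vec(vec![0usize, 1, 0, 1, 0, 1]);
///
/// let mut bbq = BayesianBinningQuantiles::new(3);
/// let calibrated = bbq.fit_transform(&probs, &labels).expect("calibration should succeed");
/// // Every output is a Beta posterior mean, so it lies strictly inside (0, 1).
/// assert!(calibrated.iter().all(|&p| p > 0.0 && p < 1.0));
/// ```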
#[derive(Debug, Clone)]
pub struct BayesianBinningQuantiles {
    /// Number of bins
    n_bins: usize,
    /// Bin edges (quantile thresholds)
    bin_edges: Option<Vec<f64>>,
    /// Alpha parameters for Beta distribution in each bin
    alphas: Option<Array1<f64>>,
    /// Beta parameters for Beta distribution in each bin
    betas: Option<Array1<f64>>,
    /// Whether the calibrator has been fitted
    fitted: bool,
}

impl BayesianBinningQuantiles {
    /// Create a new BBQ calibrator with specified number of bins
    pub fn new(n_bins: usize) -> Self {
        Self {
            n_bins,
            bin_edges: None,
            alphas: None,
            betas: None,
            fitted: false,
        }
    }

    /// Fit the BBQ calibrator to probabilities and true labels
    pub fn fit(&mut self, probabilities: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if probabilities.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Probabilities and labels must have same length".to_string(),
            ));
        }
        let n_samples = probabilities.len();
        if n_samples < self.n_bins {
            return Err(MLError::InvalidInput(format!(
                "Need at least {} samples for {} bins, got {}",
                self.n_bins, self.n_bins, n_samples
            )));
        }
        // Place bin edges at empirical quantiles of the predicted probabilities.
        let mut sorted_probs = probabilities.to_vec();
        sorted_probs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let mut bin_edges = vec![0.0];
        for i in 1..self.n_bins {
            let quantile_idx = (i as f64 / self.n_bins as f64 * n_samples as f64) as usize;
            let quantile_idx = quantile_idx.min(sorted_probs.len() - 1);
            bin_edges.push(sorted_probs[quantile_idx]);
        }
        bin_edges.push(1.0);
        // Count positives and negatives falling into each bin.
        let mut bin_positives = vec![0.0; self.n_bins];
        let mut bin_negatives = vec![0.0; self.n_bins];
        for (i, &prob) in probabilities.iter().enumerate() {
            let bin_idx = self.find_bin(&bin_edges, prob);
            let label = labels[i];
            if label == 1 {
                bin_positives[bin_idx] += 1.0;
            } else {
                bin_negatives[bin_idx] += 1.0;
            }
        }
        // Jeffreys prior Beta(0.5, 0.5) plus the observed counts gives the
        // posterior Beta(alpha, beta) for each bin.
        let prior_alpha = 0.5;
        let prior_beta = 0.5;
        let mut alphas = Array1::zeros(self.n_bins);
        let mut betas = Array1::zeros(self.n_bins);
        for i in 0..self.n_bins {
            alphas[i] = prior_alpha + bin_positives[i];
            betas[i] = prior_beta + bin_negatives[i];
        }
        self.bin_edges = Some(bin_edges);
        self.alphas = Some(alphas);
        self.betas = Some(betas);
        self.fitted = true;
        Ok(())
    }

    /// Find which bin a probability belongs to
    fn find_bin(&self, bin_edges: &[f64], prob: f64) -> usize {
        for i in 0..bin_edges.len() - 1 {
            if prob >= bin_edges[i] && prob < bin_edges[i + 1] {
                return i;
            }
        }
        // prob == 1.0 (or anything at/above the last edge) falls into the final bin.
        bin_edges.len() - 2
    }

    /// Transform probabilities to calibrated probabilities using BBQ
    pub fn transform(&self, probabilities: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Calibrator must be fitted before transform".to_string(),
            ));
        }
        let bin_edges = self
            .bin_edges
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Bin edges not initialized".to_string()))?;
        let alphas = self
            .alphas
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Alphas not initialized".to_string()))?;
        let betas = self
            .betas
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Betas not initialized".to_string()))?;
        let mut calibrated = Array1::zeros(probabilities.len());
        for (i, &prob) in probabilities.iter().enumerate() {
            let bin_idx = self.find_bin(bin_edges, prob);
            let alpha = alphas[bin_idx];
            let beta = betas[bin_idx];
            // Posterior mean of Beta(alpha, beta).
            calibrated[i] = alpha / (alpha + beta);
        }
        Ok(calibrated)
    }

    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(probabilities, labels)?;
        self.transform(probabilities)
    }

    /// Get calibrated probability with uncertainty bounds (credible interval)
    /// Returns (mean, lower_bound, upper_bound) for given confidence level
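    /// (Approximation note: the bound is a Wilson-style normal approximation to
    /// the Beta posterior with z fixed at 1.96, i.e. roughly a 95% interval.)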
    pub fn predict_with_uncertainty(
        &self,
        probabilities: &Array1<f64>,
        confidence: f64,
    ) -> Result<Vec<(f64, f64, f64)>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Calibrator must be fitted before prediction".to_string(),
            ));
        }
        if confidence <= 0.0 || confidence >= 1.0 {
            return Err(MLError::InvalidInput(
                "Confidence must be between 0 and 1".to_string(),
            ));
        }
        let bin_edges = self
            .bin_edges
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Bin edges not initialized".to_string()))?;
        let alphas = self
            .alphas
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Alphas not initialized".to_string()))?;
        let betas = self
            .betas
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Betas not initialized".to_string()))?;
        let mut results = Vec::new();
        for &prob in probabilities.iter() {
            let bin_idx = self.find_bin(bin_edges, prob);
            let alpha = alphas[bin_idx];
            let beta = betas[bin_idx];
            let mean = alpha / (alpha + beta);
            // Effective sample size and point estimate for the Wilson-style bound.
            let n = alpha + beta - 1.0;
            let p = alpha / (alpha + beta);
            if n > 0.0 {
                // NOTE: z is fixed at 1.96 (~95%); `confidence` is validated but
                // other levels are not yet mapped to their z-scores.
                let z = 1.96;
                let denominator = 1.0 + z * z / n;
                let center = (p + z * z / (2.0 * n)) / denominator;
                let margin = z * (p * (1.0 - p) / n + z * z / (4.0 * n * n)).sqrt() / denominator;
                let lower = (center - margin).max(0.0);
                let upper = (center + margin).min(1.0);
                results.push((mean, lower, upper));
            } else {
                results.push((mean, 0.0, 1.0));
            }
        }
        Ok(results)
    }

    /// Get the number of bins
    pub fn n_bins(&self) -> usize {
        self.n_bins
    }

    /// Get bin statistics (edges, alphas, betas)
    pub fn bin_statistics(&self) -> Option<(Vec<f64>, Array1<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.bin_edges.as_ref()?.clone(),
                self.alphas.as_ref()?.clone(),
                self.betas.as_ref()?.clone(),
            ))
        } else {
            None
        }
    }
}

/// Platt Scaling - fits a logistic regression on decision scores
/// Calibrates binary classifier outputs to produce better probability estimates
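///
/// # Example
///
/// A minimal usage sketch (`ignore`d doc test; made-up decision scores for a
/// binary problem with labels in {0, 1}):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let scores = Array1::from_vec(vec![-2.0, -0.5, 0.3, 1.8]);
/// let labels = Array1::from_vec(vec![0usize, 0, 1, 1]);
///
/// let mut platt = PlattScaler::new();
/// let probs = platt.fit_transform(&scores, &labels).expect("calibration should succeed");
/// // For this positively correlated data the fitted sigmoid is increasing,
/// // so higher scores map to higher probabilities.
/// assert!(probs[0] < probs[3]);
/// ```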
#[derive(Debug, Clone)]
pub struct PlattScaler {
    /// Slope parameter of logistic function
    a: f64,
    /// Intercept parameter of logistic function
    b: f64,
    /// Whether the scaler has been fitted
    fitted: bool,
}

impl PlattScaler {
    /// Create a new Platt scaler
    pub fn new() -> Self {
        Self {
            a: 1.0,
            b: 0.0,
            fitted: false,
        }
    }

    /// Fit the Platt scaler to decision scores and true labels
    /// Maximizes the sigmoid likelihood via coordinate-wise Newton updates
    pub fn fit(&mut self, scores: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if scores.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Scores and labels must have same length".to_string(),
            ));
        }
        let n = scores.len();
        if n < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Map labels to +/-1 targets.
        let y: Array1<f64> = labels
            .iter()
            .map(|&l| if l == 1 { 1.0 } else { -1.0 })
            .collect();
        let mut a = 0.0;
        // Initialize the intercept from the smoothed class prior (log-odds).
        let n_pos = labels.iter().filter(|&&l| l == 1).count() as f64;
        let prior_pos = (n_pos + 1.0) / (n as f64 + 2.0);
        let mut b = (prior_pos / (1.0 - prior_pos)).ln();
        // Newton updates for the slope a, holding b fixed.
        for _ in 0..100 {
            let mut fval = 0.0;
            let mut fpp = 0.0;
            for i in 0..n {
                let fapb = scores[i] * a + b;
                let p = 1.0 / (1.0 + (-fapb).exp());
                let t = if y[i] > 0.0 { 1.0 } else { 0.0 };
                fval += scores[i] * (t - p);
                fpp += scores[i] * scores[i] * p * (1.0 - p);
            }
            if fpp.abs() < 1e-12 {
                break;
            }
            let delta = fval / fpp;
            a += delta;
            if delta.abs() < 1e-7 {
                break;
            }
        }
        // Newton updates for the intercept b, holding a fixed.
        for _ in 0..100 {
            let mut fval = 0.0;
            let mut fpp = 0.0;
            for i in 0..n {
                let fapb = scores[i] * a + b;
                let p = 1.0 / (1.0 + (-fapb).exp());
                let t = if y[i] > 0.0 { 1.0 } else { 0.0 };
                fval += t - p;
                fpp += p * (1.0 - p);
            }
            if fpp.abs() < 1e-12 {
                break;
            }
            let delta = fval / fpp;
            b += delta;
            if delta.abs() < 1e-7 {
                break;
            }
        }
        self.a = a;
        self.b = b;
        self.fitted = true;
        Ok(())
    }

    /// Transform decision scores to calibrated probabilities
    pub fn transform(&self, scores: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let probs = scores.mapv(|s| {
            let fapb = s * self.a + self.b;
            1.0 / (1.0 + (-fapb).exp())
        });
        Ok(probs)
    }

    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        scores: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(scores, labels)?;
        self.transform(scores)
    }

    /// Get the fitted parameters
    pub fn parameters(&self) -> Option<(f64, f64)> {
        if self.fitted {
            Some((self.a, self.b))
        } else {
            None
        }
    }
}

/// Temperature Scaling - simple and effective multi-class calibration
/// Scales logits by a single learned temperature parameter
/// Particularly effective for neural network outputs
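///
/// # Example
///
/// A minimal usage sketch (`ignore`d doc test; made-up overconfident logits,
/// half of them mislabeled so the fitted temperature should come out above 1
/// and soften the probabilities):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let logits = Array2::from_shape_vec(
///     (4, 2),
///     vec![8.0, 0.0, 0.0, 8.0, 7.0, 1.0, 1.0, 7.0],
/// ).expect("shape matches data length");
/// let labels = Array1::from_vec(vec![0usize, 1, 1, 0]);
///
/// let mut scaler = TemperatureScaler::new();
/// let probs = scaler.fit_transform(&logits, &labels).expect("calibration should succeed");
/// let t = scaler.temperature().expect("fitted");
/// assert!(t > 0.0);
/// ```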
#[derive(Debug, Clone)]
pub struct TemperatureScaler {
    /// Temperature parameter (T > 0)
    temperature: f64,
    /// Whether the scaler has been fitted
    fitted: bool,
}

impl TemperatureScaler {
    /// Create a new temperature scaler
    pub fn new() -> Self {
        Self {
            temperature: 1.0,
            fitted: false,
        }
    }

    /// Fit the temperature scaler to logits and true labels
    /// Minimizes the NLL via a coarse grid search followed by finite-difference gradient descent
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        if n_samples < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Coarse grid search for a good starting temperature.
        let mut best_temp = 1.0;
        let mut best_nll = f64::INFINITY;
        for t_candidate in [0.1, 0.5, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0] {
            let nll = self.compute_nll(logits, labels, t_candidate)?;
            if nll < best_nll {
                best_nll = nll;
                best_temp = t_candidate;
            }
        }
        // Refine with forward-difference gradient descent on the NLL.
        let mut temperature = best_temp;
        let learning_rate = 0.01;
        for _ in 0..100 {
            let nll_current = self.compute_nll(logits, labels, temperature)?;
            let nll_plus = self.compute_nll(logits, labels, temperature + 0.01)?;
            let gradient = (nll_plus - nll_current) / 0.01;
            let new_temp = temperature - learning_rate * gradient;
            if new_temp <= 0.01 {
                break;
            }
            temperature = new_temp;
            if gradient.abs() < 1e-5 {
                break;
            }
        }
        self.temperature = temperature.max(0.01);
        self.fitted = true;
        Ok(())
    }

    /// Compute negative log-likelihood for given temperature
    fn compute_nll(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        temperature: f64,
    ) -> Result<f64> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        for i in 0..n_samples {
            let scaled_logits = logits.row(i).mapv(|x| x / temperature);
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        Ok(nll / n_samples as f64)
    }

    /// Transform logits to calibrated probabilities using temperature scaling
    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let scaled_logits = logits.row(i).mapv(|x| x / self.temperature);
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }

    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }

    /// Get the fitted temperature parameter
    pub fn temperature(&self) -> Option<f64> {
        if self.fitted {
            Some(self.temperature)
        } else {
            None
        }
    }
}

/// Vector Scaling - extension of temperature scaling with class-specific parameters
/// Uses diagonal weight matrix and bias vector for more flexible calibration
/// Particularly effective when different classes have different calibration needs
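///
/// # Example
///
/// A minimal usage sketch (`ignore`d doc test; made-up three-class logits where
/// one class tends to be overconfident):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let logits = Array2::from_shape_vec(
///     (3, 3),
///     vec![5.0, 1.0, 0.5, 0.2, 2.0, 1.0, 0.5, 0.5, 3.0],
/// ).expect("shape matches data length");
/// let labels = Array1::from_vec(vec![0usize, 1, 2]);
///
/// let mut scaler = VectorScaler::new();
/// let probs = scaler.fit_transform(&logits, &labels).expect("calibration should succeed");
/// let (weights, biases) = scaler.parameters().expect("fitted");
/// // One weight and one bias per class.
/// assert_eq!(weights.len(), 3);
/// ```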
#[derive(Debug, Clone)]
pub struct VectorScaler {
    /// Diagonal weight matrix (one parameter per class)
    weights: Option<Array1<f64>>,
    /// Bias vector (one parameter per class)
    biases: Option<Array1<f64>>,
    /// Whether the scaler has been fitted
    fitted: bool,
}

impl VectorScaler {
    /// Create a new vector scaler
    pub fn new() -> Self {
        Self {
            weights: None,
            biases: None,
            fitted: false,
        }
    }

    /// Fit the vector scaler to logits and true labels
    /// Minimizes the negative log-likelihood with finite-difference gradient descent
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        if n_samples < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Initialize at the identity transform (unit weights, zero biases).
        let mut weights = Array1::ones(n_classes);
        let mut biases = Array1::zeros(n_classes);
        let learning_rate = 0.01;
        let max_iter = 200;
        let tolerance = 1e-6;
        let mut prev_nll = f64::INFINITY;
        for _iter in 0..max_iter {
            let nll = self.compute_nll_vec(logits, labels, &weights, &biases)?;
            if (prev_nll - nll).abs() < tolerance {
                break;
            }
            prev_nll = nll;
            // Forward-difference numerical gradients for each class parameter.
            let epsilon = 1e-6;
            let mut weight_grads = Array1::zeros(n_classes);
            let mut bias_grads = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut weights_plus = weights.clone();
                weights_plus[j] += epsilon;
                let nll_plus = self.compute_nll_vec(logits, labels, &weights_plus, &biases)?;
                weight_grads[j] = (nll_plus - nll) / epsilon;
                let mut biases_plus = biases.clone();
                biases_plus[j] += epsilon;
                let nll_plus = self.compute_nll_vec(logits, labels, &weights, &biases_plus)?;
                bias_grads[j] = (nll_plus - nll) / epsilon;
            }
            weights = &weights - &weight_grads.mapv(|g| learning_rate * g);
            biases = &biases - &bias_grads.mapv(|g| learning_rate * g);
            // Keep per-class weights positive to preserve logit ordering.
            weights.mapv_inplace(|w| w.max(0.01));
            if weight_grads.iter().all(|&g| g.abs() < tolerance)
                && bias_grads.iter().all(|&g| g.abs() < tolerance)
            {
                break;
            }
        }
        self.weights = Some(weights);
        self.biases = Some(biases);
        self.fitted = true;
        Ok(())
    }

    /// Compute negative log-likelihood for given weights and biases
    fn compute_nll_vec(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        weights: &Array1<f64>,
        biases: &Array1<f64>,
    ) -> Result<f64> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        for i in 0..n_samples {
            // scaled = diag(w) @ logits + b, applied elementwise per class.
            let scaled_logits = logits.row(i).to_owned() * weights + biases;
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        Ok(nll / n_samples as f64)
    }

    /// Transform logits to calibrated probabilities using vector scaling
    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let weights = self
            .weights
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Weights not initialized".to_string()))?;
        let biases = self
            .biases
            .as_ref()
            .ok_or_else(|| MLError::InvalidInput("Biases not initialized".to_string()))?;
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let scaled_logits = logits.row(i).to_owned() * weights + biases;
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }

    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }

    /// Get the fitted parameters (weights, biases)
    pub fn parameters(&self) -> Option<(Array1<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.weights.as_ref()?.clone(),
                self.biases.as_ref()?.clone(),
            ))
        } else {
            None
        }
    }
}