quantrs2_ml/utils/calibration/
types.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2};

/// Matrix Scaling - full affine transformation for maximum calibration flexibility
/// Uses full weight matrix W and bias vector b: calibrated = softmax(W @ logits + b)
/// More expressive than vector scaling but requires more data to avoid overfitting
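///
/// # Examples
///
/// A minimal usage sketch (the import path is an assumption about how this
/// module is re-exported, so the example is not compiled):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1, Array2};
/// use quantrs2_ml::utils::calibration::MatrixScaler;
///
/// // Four samples, two classes: fit() requires at least n_classes * 2 samples.
/// let logits: Array2<f64> = array![[2.0, 0.1], [0.3, 1.9], [1.5, 0.2], [0.1, 2.2]];
/// let labels: Array1<usize> = array![0, 1, 0, 1];
///
/// let mut scaler = MatrixScaler::with_regularization(0.05);
/// let probs = scaler.fit_transform(&logits, &labels).unwrap();
/// // Each row of the output is a probability distribution.
/// assert!((probs.row(0).sum() - 1.0).abs() < 1e-9);
/// ```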
#[derive(Debug, Clone)]
pub struct MatrixScaler {
    /// Weight matrix (n_classes × n_classes)
    weight_matrix: Option<Array2<f64>>,
    /// Bias vector (n_classes)
    bias_vector: Option<Array1<f64>>,
    /// Whether the scaler has been fitted
    fitted: bool,
    /// Regularization strength (L2 penalty on off-diagonal elements)
    regularization: f64,
}
impl MatrixScaler {
    /// Create a new matrix scaler with the default regularization (0.01)
    pub fn new() -> Self {
        Self::with_regularization(0.01)
    }
    /// Create matrix scaler with custom regularization
    pub fn with_regularization(regularization: f64) -> Self {
        Self {
            weight_matrix: None,
            bias_vector: None,
            fitted: false,
            regularization,
        }
    }
    /// Fit the matrix scaler to logits and true labels
    /// Uses gradient descent with L2 regularization on off-diagonal weights
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        if n_samples < n_classes * 2 {
            return Err(MLError::InvalidInput(format!(
                "Need at least {} samples for {} classes (matrix calibration)",
                n_classes * 2,
                n_classes
            )));
        }
        // Initialize at the identity transform (no recalibration).
        let mut weight_matrix = Array2::eye(n_classes);
        let mut bias_vector = Array1::zeros(n_classes);
        let learning_rate = 0.001;
        let max_iter = 300;
        let tolerance = 1e-7;
        let mut prev_nll = f64::INFINITY;
        for _iter in 0..max_iter {
            let (nll, reg_term) =
                self.compute_nll_with_reg(logits, labels, &weight_matrix, &bias_vector)?;
            let total_loss = nll + reg_term;
            if (prev_nll - total_loss).abs() < tolerance {
                break;
            }
            prev_nll = total_loss;
            // Estimate gradients by forward finite differences on each parameter.
            let epsilon = 1e-6;
            let mut weight_grads = Array2::zeros((n_classes, n_classes));
            let mut bias_grads = Array1::zeros(n_classes);
            for i in 0..n_classes {
                for j in 0..n_classes {
                    let mut weight_plus = weight_matrix.clone();
                    weight_plus[(i, j)] += epsilon;
                    let (nll_plus, reg_plus) =
                        self.compute_nll_with_reg(logits, labels, &weight_plus, &bias_vector)?;
                    weight_grads[(i, j)] = (nll_plus + reg_plus - total_loss) / epsilon;
                }
            }
            for j in 0..n_classes {
                let mut bias_plus = bias_vector.clone();
                bias_plus[j] += epsilon;
                let (nll_plus, reg_plus) =
                    self.compute_nll_with_reg(logits, labels, &weight_matrix, &bias_plus)?;
                bias_grads[j] = (nll_plus + reg_plus - total_loss) / epsilon;
            }
            weight_matrix = &weight_matrix - &weight_grads.mapv(|g| learning_rate * g);
            bias_vector = &bias_vector - &bias_grads.mapv(|g| learning_rate * g);
            // Keep the diagonal positive so the transform never collapses a class.
            for i in 0..n_classes {
                weight_matrix[(i, i)] = weight_matrix[(i, i)].max(0.01);
            }
            let grad_norm = weight_grads.iter().map(|&g| g * g).sum::<f64>().sqrt()
                + bias_grads.iter().map(|&g| g * g).sum::<f64>().sqrt();
            if grad_norm < tolerance {
                break;
            }
        }
        self.weight_matrix = Some(weight_matrix);
        self.bias_vector = Some(bias_vector);
        self.fitted = true;
        Ok(())
    }
    /// Compute NLL with L2 regularization on off-diagonal weights
    fn compute_nll_with_reg(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        weight_matrix: &Array2<f64>,
        bias_vector: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        for i in 0..n_samples {
            let logits_row = logits.row(i);
            // Apply the affine transform: scaled = W @ logits + b.
            let mut scaled_logits = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut val = bias_vector[j];
                for k in 0..n_classes {
                    val += weight_matrix[(j, k)] * logits_row[k];
                }
                scaled_logits[j] = val;
            }
            // Numerically stable softmax: subtract the max logit before exponentiating.
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        nll /= n_samples as f64;
        // L2 penalty on off-diagonal weights discourages cross-class mixing.
        let mut reg_term = 0.0;
        for i in 0..n_classes {
            for j in 0..n_classes {
                if i != j {
                    reg_term += weight_matrix[(i, j)].powi(2);
                }
            }
        }
        reg_term *= self.regularization;
        Ok((nll, reg_term))
    }
    /// Transform logits to calibrated probabilities using matrix scaling
    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let weight_matrix = self.weight_matrix.as_ref().unwrap();
        let bias_vector = self.bias_vector.as_ref().unwrap();
        let n_classes = logits.ncols();
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let logits_row = logits.row(i);
            let mut scaled_logits = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut val = bias_vector[j];
                for k in 0..n_classes {
                    val += weight_matrix[(j, k)] * logits_row[k];
                }
                scaled_logits[j] = val;
            }
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }
    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }
    /// Get the fitted parameters (weight_matrix, bias_vector)
    pub fn parameters(&self) -> Option<(Array2<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.weight_matrix.as_ref().unwrap().clone(),
                self.bias_vector.as_ref().unwrap().clone(),
            ))
        } else {
            None
        }
    }
    /// Get the Frobenius norm of the weight matrix (for diagnostics)
    /// Used as a cheap proxy for conditioning; computing a true condition
    /// number would require a singular value decomposition. Higher values
    /// indicate potential numerical instability.
    pub fn condition_number(&self) -> Option<f64> {
        if !self.fitted {
            return None;
        }
        let w = self.weight_matrix.as_ref().unwrap();
        // Frobenius norm: square root of the sum of squared entries.
        let norm = w.iter().map(|&x| x * x).sum::<f64>().sqrt();
        Some(norm)
    }
}
/// Isotonic Regression - non-parametric calibration using monotonic transformation
/// More flexible than Platt scaling but requires more data
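///
/// # Examples
///
/// A minimal usage sketch (the import path is an assumption, so the example
/// is not compiled):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1};
/// use quantrs2_ml::utils::calibration::IsotonicRegression;
///
/// let scores: Array1<f64> = array![-2.0, -1.0, 0.5, 1.5, 2.5];
/// let labels: Array1<usize> = array![0, 0, 1, 0, 1];
///
/// let mut iso = IsotonicRegression::new();
/// let calibrated = iso.fit_transform(&scores, &labels).unwrap();
/// // The fitted mapping is monotonically non-decreasing in the score.
/// assert!(calibrated[0] <= calibrated[4]);
/// ```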
#[derive(Debug, Clone)]
pub struct IsotonicRegression {
    /// X values (decision scores)
    x_thresholds: Vec<f64>,
    /// Y values (calibrated probabilities)
    y_thresholds: Vec<f64>,
    /// Whether the regressor has been fitted
    fitted: bool,
}
impl IsotonicRegression {
    /// Create a new isotonic regression calibrator
    pub fn new() -> Self {
        Self {
            x_thresholds: Vec::new(),
            y_thresholds: Vec::new(),
            fitted: false,
        }
    }
    /// Fit isotonic regression to scores and labels
    /// Uses the pool-adjacent-violators algorithm (PAVA) to produce a
    /// monotonically non-decreasing step function
    pub fn fit(&mut self, scores: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if scores.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Scores and labels must have same length".to_string(),
            ));
        }
        let n = scores.len();
        if n < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Sort (score, label) pairs by score so violations are adjacent.
        let mut pairs: Vec<(f64, f64)> = scores
            .iter()
            .zip(labels.iter())
            .map(|(&s, &l)| (s, l as f64))
            .collect();
        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
        let mut x = Vec::new();
        let mut y = Vec::new();
        let mut weights = Vec::new();
        for (score, label) in pairs {
            x.push(score);
            y.push(label);
            weights.push(1.0);
        }
        // Pool adjacent violators: merge any decreasing pair into its weighted
        // average, then step back in case the merge created a new violation.
        let mut i = 0;
        while i < y.len() - 1 {
            if y[i] > y[i + 1] {
                let w1 = weights[i];
                let w2 = weights[i + 1];
                let total_weight = w1 + w2;
                y[i] = (y[i] * w1 + y[i + 1] * w2) / total_weight;
                weights[i] = total_weight;
                y.remove(i + 1);
                x.remove(i + 1);
                weights.remove(i + 1);
                if i > 0 {
                    i -= 1;
                }
            } else {
                i += 1;
            }
        }
        self.x_thresholds = x;
        self.y_thresholds = y;
        self.fitted = true;
        Ok(())
    }
    /// Transform decision scores to calibrated probabilities
    pub fn transform(&self, scores: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Regressor must be fitted before transform".to_string(),
            ));
        }
        let mut calibrated = Array1::zeros(scores.len());
        for (i, &score) in scores.iter().enumerate() {
            // Binary search for the insertion point of the score among the thresholds.
            let pos = self
                .x_thresholds
                .binary_search_by(|&x| x.partial_cmp(&score).unwrap_or(std::cmp::Ordering::Less))
                .unwrap_or_else(|e| e);
            if pos == 0 {
                // Below the fitted range: clamp to the first threshold value.
                calibrated[i] = self.y_thresholds[0];
            } else if pos >= self.x_thresholds.len() {
                // Above the fitted range: clamp to the last threshold value.
                calibrated[i] = *self.y_thresholds.last().unwrap();
            } else {
                // Linearly interpolate between the two surrounding thresholds.
                let x0 = self.x_thresholds[pos - 1];
                let x1 = self.x_thresholds[pos];
                let y0 = self.y_thresholds[pos - 1];
                let y1 = self.y_thresholds[pos];
                if (x1 - x0).abs() < 1e-10 {
                    calibrated[i] = (y0 + y1) / 2.0;
                } else {
                    let alpha = (score - x0) / (x1 - x0);
                    calibrated[i] = y0 + alpha * (y1 - y0);
                }
            }
        }
        Ok(calibrated)
    }
    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        scores: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(scores, labels)?;
        self.transform(scores)
    }
}
/// Bayesian Binning into Quantiles (BBQ) - histogram-based calibration
/// Bins predictions into quantiles and learns a Bayesian posterior for each bin
/// Uses a Beta distribution for robust probability estimation with uncertainty
/// quantification. Assumes binary labels: 1 is positive, anything else negative.
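///
/// # Examples
///
/// A minimal usage sketch (the import path is an assumption, so the example
/// is not compiled):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1};
/// use quantrs2_ml::utils::calibration::BayesianBinningQuantiles;
///
/// let probs: Array1<f64> = array![0.1, 0.2, 0.35, 0.6, 0.75, 0.9];
/// let labels: Array1<usize> = array![0, 0, 1, 0, 1, 1];
///
/// let mut bbq = BayesianBinningQuantiles::new(3);
/// let calibrated = bbq.fit_transform(&probs, &labels).unwrap();
/// // Each output is the posterior mean of its bin, strictly inside (0, 1).
/// assert!(calibrated.iter().all(|&p| p > 0.0 && p < 1.0));
/// ```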
#[derive(Debug, Clone)]
pub struct BayesianBinningQuantiles {
    /// Number of bins
    n_bins: usize,
    /// Bin edges (quantile thresholds)
    bin_edges: Option<Vec<f64>>,
    /// Alpha parameters for Beta distribution in each bin
    alphas: Option<Array1<f64>>,
    /// Beta parameters for Beta distribution in each bin
    betas: Option<Array1<f64>>,
    /// Whether the calibrator has been fitted
    fitted: bool,
}
impl BayesianBinningQuantiles {
    /// Create a new BBQ calibrator with specified number of bins
    pub fn new(n_bins: usize) -> Self {
        Self {
            n_bins,
            bin_edges: None,
            alphas: None,
            betas: None,
            fitted: false,
        }
    }
    /// Fit the BBQ calibrator to probabilities and true labels
    pub fn fit(&mut self, probabilities: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if probabilities.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Probabilities and labels must have same length".to_string(),
            ));
        }
        let n_samples = probabilities.len();
        if n_samples < self.n_bins {
            return Err(MLError::InvalidInput(format!(
                "Need at least {} samples for {} bins, got {}",
                self.n_bins, self.n_bins, n_samples
            )));
        }
        // Place bin edges at the empirical quantiles of the predicted
        // probabilities so each bin receives roughly the same number of samples.
        let mut sorted_probs = probabilities.to_vec();
        sorted_probs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let mut bin_edges = vec![0.0];
        for i in 1..self.n_bins {
            let quantile_idx = (i as f64 / self.n_bins as f64 * n_samples as f64) as usize;
            let quantile_idx = quantile_idx.min(sorted_probs.len() - 1);
            bin_edges.push(sorted_probs[quantile_idx]);
        }
        bin_edges.push(1.0);
        // Count positives and negatives per bin (labels other than 1 count as negative).
        let mut bin_positives = vec![0.0; self.n_bins];
        let mut bin_negatives = vec![0.0; self.n_bins];
        for (i, &prob) in probabilities.iter().enumerate() {
            let bin_idx = self.find_bin(&bin_edges, prob);
            let label = labels[i];
            if label == 1 {
                bin_positives[bin_idx] += 1.0;
            } else {
                bin_negatives[bin_idx] += 1.0;
            }
        }
        // Jeffreys prior Beta(0.5, 0.5); the per-bin posterior is
        // Beta(0.5 + positives, 0.5 + negatives).
        let prior_alpha = 0.5;
        let prior_beta = 0.5;
        let mut alphas = Array1::zeros(self.n_bins);
        let mut betas = Array1::zeros(self.n_bins);
        for i in 0..self.n_bins {
            alphas[i] = prior_alpha + bin_positives[i];
            betas[i] = prior_beta + bin_negatives[i];
        }
        self.bin_edges = Some(bin_edges);
        self.alphas = Some(alphas);
        self.betas = Some(betas);
        self.fitted = true;
        Ok(())
    }
    /// Find which bin a probability belongs to
    fn find_bin(&self, bin_edges: &[f64], prob: f64) -> usize {
        for i in 0..bin_edges.len() - 1 {
            if prob >= bin_edges[i] && prob < bin_edges[i + 1] {
                return i;
            }
        }
        // prob == 1.0 falls through the half-open intervals; assign it to the last bin.
        bin_edges.len() - 2
    }
    /// Transform probabilities to calibrated probabilities using BBQ
    pub fn transform(&self, probabilities: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Calibrator must be fitted before transform".to_string(),
            ));
        }
        let bin_edges = self.bin_edges.as_ref().unwrap();
        let alphas = self.alphas.as_ref().unwrap();
        let betas = self.betas.as_ref().unwrap();
        let mut calibrated = Array1::zeros(probabilities.len());
        for (i, &prob) in probabilities.iter().enumerate() {
            let bin_idx = self.find_bin(bin_edges, prob);
            let alpha = alphas[bin_idx];
            let beta = betas[bin_idx];
            // Posterior mean of Beta(alpha, beta).
            calibrated[i] = alpha / (alpha + beta);
        }
        Ok(calibrated)
    }
    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(probabilities, labels)?;
        self.transform(probabilities)
    }
    /// Get calibrated probability with uncertainty bounds
    /// Approximates the Beta-posterior credible interval with a Wilson score
    /// interval at the requested confidence level.
    /// Returns (mean, lower_bound, upper_bound) for each input probability.
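    ///
    /// A sketch of the call shape (assuming `bbq` was fitted as in the
    /// struct-level example; not compiled):
    ///
    /// ```ignore
    /// let intervals = bbq.predict_with_uncertainty(&probs, 0.95).unwrap();
    /// for (mean, lo, hi) in &intervals {
    ///     assert!(lo <= mean && mean <= hi);
    /// }
    /// ```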
    pub fn predict_with_uncertainty(
        &self,
        probabilities: &Array1<f64>,
        confidence: f64,
    ) -> Result<Vec<(f64, f64, f64)>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Calibrator must be fitted before prediction".to_string(),
            ));
        }
        if confidence <= 0.0 || confidence >= 1.0 {
            return Err(MLError::InvalidInput(
                "Confidence must be between 0 and 1".to_string(),
            ));
        }
        let bin_edges = self.bin_edges.as_ref().unwrap();
        let alphas = self.alphas.as_ref().unwrap();
        let betas = self.betas.as_ref().unwrap();
        // Map the confidence level to a standard normal z-score via the
        // Abramowitz & Stegun 26.2.23 rational approximation (|error| < 4.5e-4);
        // for confidence = 0.95 this yields z ≈ 1.96.
        let upper_quantile = 1.0 - (1.0 - confidence) / 2.0;
        let t = (-2.0 * (1.0 - upper_quantile).ln()).sqrt();
        let z = t - (2.515517 + 0.802853 * t + 0.010328 * t * t)
            / (1.0 + 1.432788 * t + 0.189269 * t * t + 0.001308 * t * t * t);
        let mut results = Vec::new();
        for &prob in probabilities.iter() {
            let bin_idx = self.find_bin(bin_edges, prob);
            let alpha = alphas[bin_idx];
            let beta = betas[bin_idx];
            let mean = alpha / (alpha + beta);
            let n = alpha + beta - 1.0;
            let p = alpha / (alpha + beta);
            if n > 0.0 {
                // Wilson score interval around p with effective sample size n.
                let denominator = 1.0 + z * z / n;
                let center = (p + z * z / (2.0 * n)) / denominator;
                let margin = z * (p * (1.0 - p) / n + z * z / (4.0 * n * n)).sqrt() / denominator;
                let lower = (center - margin).max(0.0);
                let upper = (center + margin).min(1.0);
                results.push((mean, lower, upper));
            } else {
                // No effective observations in this bin: fall back to the full range.
                results.push((mean, 0.0, 1.0));
            }
        }
        Ok(results)
    }
    /// Get the number of bins
    pub fn n_bins(&self) -> usize {
        self.n_bins
    }
    /// Get bin statistics (edges, alphas, betas)
    pub fn bin_statistics(&self) -> Option<(Vec<f64>, Array1<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.bin_edges.as_ref().unwrap().clone(),
                self.alphas.as_ref().unwrap().clone(),
                self.betas.as_ref().unwrap().clone(),
            ))
        } else {
            None
        }
    }
}
/// Platt Scaling - fits a logistic regression on decision scores
/// Calibrates binary classifier outputs to produce better probability estimates
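///
/// # Examples
///
/// A minimal usage sketch (the import path is an assumption, so the example
/// is not compiled):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1};
/// use quantrs2_ml::utils::calibration::PlattScaler;
///
/// let scores: Array1<f64> = array![-2.0, -1.0, 1.0, 2.0];
/// let labels: Array1<usize> = array![0, 0, 1, 1];
///
/// let mut platt = PlattScaler::new();
/// let probs = platt.fit_transform(&scores, &labels).unwrap();
/// // Higher scores map to higher calibrated probabilities.
/// assert!(probs[0] < probs[3]);
/// ```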
#[derive(Debug, Clone)]
pub struct PlattScaler {
    /// Slope parameter of logistic function
    a: f64,
    /// Intercept parameter of logistic function
    b: f64,
    /// Whether the scaler has been fitted
    fitted: bool,
}
impl PlattScaler {
    /// Create a new Platt scaler
    pub fn new() -> Self {
        Self {
            a: 1.0,
            b: 0.0,
            fitted: false,
        }
    }
    /// Fit the Platt scaler to decision scores and true labels
    /// Uses maximum likelihood estimation to find optimal sigmoid parameters,
    /// alternating Newton updates for the slope a and the intercept b
    pub fn fit(&mut self, scores: &Array1<f64>, labels: &Array1<usize>) -> Result<()> {
        if scores.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Scores and labels must have same length".to_string(),
            ));
        }
        let n = scores.len();
        if n < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Map labels to +1/-1 targets.
        let y: Array1<f64> = labels
            .iter()
            .map(|&l| if l == 1 { 1.0 } else { -1.0 })
            .collect();
        // Initialize the intercept from the smoothed class prior (Laplace correction).
        let n_pos = labels.iter().filter(|&&l| l == 1).count() as f64;
        let prior_pos = (n_pos + 1.0) / (n as f64 + 2.0);
        let mut a = 0.0;
        let mut b = (prior_pos / (1.0 - prior_pos)).ln();
        // Newton iterations on the slope a with b held fixed.
        for _ in 0..100 {
            let mut fval = 0.0;
            let mut fpp = 0.0;
            for i in 0..n {
                let fapb = scores[i] * a + b;
                let p = 1.0 / (1.0 + (-fapb).exp());
                let t = if y[i] > 0.0 { 1.0 } else { 0.0 };
                fval += scores[i] * (t - p);
                fpp += scores[i] * scores[i] * p * (1.0 - p);
            }
            if fpp.abs() < 1e-12 {
                break;
            }
            let delta = fval / fpp;
            a += delta;
            if delta.abs() < 1e-7 {
                break;
            }
        }
        // Newton iterations on the intercept b with a held fixed.
        for _ in 0..100 {
            let mut fval = 0.0;
            let mut fpp = 0.0;
            for i in 0..n {
                let fapb = scores[i] * a + b;
                let p = 1.0 / (1.0 + (-fapb).exp());
                let t = if y[i] > 0.0 { 1.0 } else { 0.0 };
                fval += t - p;
                fpp += p * (1.0 - p);
            }
            if fpp.abs() < 1e-12 {
                break;
            }
            let delta = fval / fpp;
            b += delta;
            if delta.abs() < 1e-7 {
                break;
            }
        }
        self.a = a;
        self.b = b;
        self.fitted = true;
        Ok(())
    }
    /// Transform decision scores to calibrated probabilities
    pub fn transform(&self, scores: &Array1<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let probs = scores.mapv(|s| {
            let fapb = s * self.a + self.b;
            1.0 / (1.0 + (-fapb).exp())
        });
        Ok(probs)
    }
    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        scores: &Array1<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array1<f64>> {
        self.fit(scores, labels)?;
        self.transform(scores)
    }
    /// Get the fitted parameters
    pub fn parameters(&self) -> Option<(f64, f64)> {
        if self.fitted {
            Some((self.a, self.b))
        } else {
            None
        }
    }
}
/// Temperature Scaling - simple and effective multi-class calibration
/// Scales logits by a single learned temperature parameter
/// Particularly effective for neural network outputs
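///
/// # Examples
///
/// A minimal usage sketch (the import path is an assumption, so the example
/// is not compiled):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1, Array2};
/// use quantrs2_ml::utils::calibration::TemperatureScaler;
///
/// // Overconfident logits: large margins relative to the label noise.
/// let logits: Array2<f64> = array![[4.0, 0.0], [0.0, 4.0], [3.5, 0.5], [0.5, 3.5]];
/// let labels: Array1<usize> = array![0, 1, 0, 0];
///
/// let mut temp = TemperatureScaler::new();
/// let probs = temp.fit_transform(&logits, &labels).unwrap();
/// // The learned temperature is strictly positive; values > 1 soften predictions.
/// assert!(temp.temperature().unwrap() > 0.0);
/// ```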
#[derive(Debug, Clone)]
pub struct TemperatureScaler {
    /// Temperature parameter (T > 0)
    temperature: f64,
    /// Whether the scaler has been fitted
    fitted: bool,
}
impl TemperatureScaler {
    /// Create a new temperature scaler
    pub fn new() -> Self {
        Self {
            temperature: 1.0,
            fitted: false,
        }
    }
    /// Fit the temperature scaler to logits and true labels
    /// Uses negative log-likelihood minimization: a coarse grid search
    /// followed by gradient descent refinement
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        if n_samples < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Coarse grid search over candidate temperatures.
        let mut best_temp = 1.0;
        let mut best_nll = f64::INFINITY;
        for t_candidate in [0.1, 0.5, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0] {
            let nll = self.compute_nll(logits, labels, t_candidate)?;
            if nll < best_nll {
                best_nll = nll;
                best_temp = t_candidate;
            }
        }
        // Refine the best candidate by gradient descent with a
        // forward-difference gradient estimate.
        let mut temperature = best_temp;
        let learning_rate = 0.01;
        for _ in 0..100 {
            let nll_current = self.compute_nll(logits, labels, temperature)?;
            let nll_plus = self.compute_nll(logits, labels, temperature + 0.01)?;
            let gradient = (nll_plus - nll_current) / 0.01;
            let new_temp = temperature - learning_rate * gradient;
            if new_temp <= 0.01 {
                break;
            }
            temperature = new_temp;
            if gradient.abs() < 1e-5 {
                break;
            }
        }
        self.temperature = temperature.max(0.01);
        self.fitted = true;
        Ok(())
    }
    /// Compute negative log-likelihood for given temperature
    fn compute_nll(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        temperature: f64,
    ) -> Result<f64> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        for i in 0..n_samples {
            let scaled_logits = logits.row(i).mapv(|x| x / temperature);
            // Numerically stable softmax: subtract the max logit before exponentiating.
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        Ok(nll / n_samples as f64)
    }
    /// Transform logits to calibrated probabilities using temperature scaling
    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let scaled_logits = logits.row(i).mapv(|x| x / self.temperature);
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }
    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }
    /// Get the fitted temperature parameter
    pub fn temperature(&self) -> Option<f64> {
        if self.fitted {
            Some(self.temperature)
        } else {
            None
        }
    }
}
/// Vector Scaling - extension of temperature scaling with class-specific parameters
/// Uses a diagonal weight matrix and bias vector for more flexible calibration
/// Particularly effective when different classes have different calibration needs
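///
/// # Examples
///
/// A minimal usage sketch (the import path is an assumption, so the example
/// is not compiled):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1, Array2};
/// use quantrs2_ml::utils::calibration::VectorScaler;
///
/// let logits: Array2<f64> = array![[2.0, 0.0], [0.0, 2.0], [1.0, 0.5], [0.2, 1.8]];
/// let labels: Array1<usize> = array![0, 1, 0, 1];
///
/// let mut vs = VectorScaler::new();
/// let probs = vs.fit_transform(&logits, &labels).unwrap();
/// // One weight and one bias are learned per class.
/// let (weights, biases) = vs.parameters().unwrap();
/// assert_eq!(weights.len(), 2);
/// assert_eq!(biases.len(), 2);
/// ```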
#[derive(Debug, Clone)]
pub struct VectorScaler {
    /// Diagonal weight matrix (one parameter per class)
    weights: Option<Array1<f64>>,
    /// Bias vector (one parameter per class)
    biases: Option<Array1<f64>>,
    /// Whether the scaler has been fitted
    fitted: bool,
}
impl VectorScaler {
    /// Create a new vector scaler
    pub fn new() -> Self {
        Self {
            weights: None,
            biases: None,
            fitted: false,
        }
    }
    /// Fit the vector scaler to logits and true labels
    /// Uses negative log-likelihood minimization via gradient descent with
    /// finite-difference gradients
    pub fn fit(&mut self, logits: &Array2<f64>, labels: &Array1<usize>) -> Result<()> {
        if logits.nrows() != labels.len() {
            return Err(MLError::InvalidInput(
                "Logits and labels must have same number of samples".to_string(),
            ));
        }
        let n_samples = logits.nrows();
        let n_classes = logits.ncols();
        if n_samples < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 samples for calibration".to_string(),
            ));
        }
        // Initialize at the identity transform (unit weights, zero biases).
        let mut weights = Array1::ones(n_classes);
        let mut biases = Array1::zeros(n_classes);
        let learning_rate = 0.01;
        let max_iter = 200;
        let tolerance = 1e-6;
        let mut prev_nll = f64::INFINITY;
        for _iter in 0..max_iter {
            let nll = self.compute_nll_vec(logits, labels, &weights, &biases)?;
            if (prev_nll - nll).abs() < tolerance {
                break;
            }
            prev_nll = nll;
            // Estimate per-class gradients by forward finite differences.
            let epsilon = 1e-6;
            let mut weight_grads = Array1::zeros(n_classes);
            let mut bias_grads = Array1::zeros(n_classes);
            for j in 0..n_classes {
                let mut weights_plus = weights.clone();
                weights_plus[j] += epsilon;
                let nll_plus = self.compute_nll_vec(logits, labels, &weights_plus, &biases)?;
                weight_grads[j] = (nll_plus - nll) / epsilon;
                let mut biases_plus = biases.clone();
                biases_plus[j] += epsilon;
                let nll_plus = self.compute_nll_vec(logits, labels, &weights, &biases_plus)?;
                bias_grads[j] = (nll_plus - nll) / epsilon;
            }
            weights = &weights - &weight_grads.mapv(|g| learning_rate * g);
            biases = &biases - &bias_grads.mapv(|g| learning_rate * g);
            // Keep the per-class weights positive so logits are never inverted.
            weights.mapv_inplace(|w| w.max(0.01));
            if weight_grads.iter().all(|&g| g.abs() < tolerance)
                && bias_grads.iter().all(|&g| g.abs() < tolerance)
            {
                break;
            }
        }
        self.weights = Some(weights);
        self.biases = Some(biases);
        self.fitted = true;
        Ok(())
    }
    /// Compute negative log-likelihood for given weights and biases
    fn compute_nll_vec(
        &self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
        weights: &Array1<f64>,
        biases: &Array1<f64>,
    ) -> Result<f64> {
        let mut nll = 0.0;
        let n_samples = logits.nrows();
        for i in 0..n_samples {
            // Element-wise scaling: scaled_j = w_j * logit_j + b_j.
            let scaled_logits = logits.row(i).to_owned() * weights + biases;
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            let true_label = labels[i];
            if true_label >= exp_logits.len() {
                return Err(MLError::InvalidInput(format!(
                    "Label {} out of bounds for {} classes",
                    true_label,
                    exp_logits.len()
                )));
            }
            let prob = exp_logits[true_label] / sum_exp;
            nll -= prob.max(1e-10).ln();
        }
        Ok(nll / n_samples as f64)
    }
    /// Transform logits to calibrated probabilities using vector scaling
    pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::InvalidInput(
                "Scaler must be fitted before transform".to_string(),
            ));
        }
        let weights = self.weights.as_ref().unwrap();
        let biases = self.biases.as_ref().unwrap();
        let mut calibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
        for i in 0..logits.nrows() {
            let scaled_logits = logits.row(i).to_owned() * weights + biases;
            let max_logit = scaled_logits
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_logits: Vec<f64> = scaled_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .collect();
            let sum_exp: f64 = exp_logits.iter().sum();
            for j in 0..logits.ncols() {
                calibrated_probs[(i, j)] = exp_logits[j] / sum_exp;
            }
        }
        Ok(calibrated_probs)
    }
    /// Fit and transform in one step
    pub fn fit_transform(
        &mut self,
        logits: &Array2<f64>,
        labels: &Array1<usize>,
    ) -> Result<Array2<f64>> {
        self.fit(logits, labels)?;
        self.transform(logits)
    }
    /// Get the fitted parameters (weights, biases)
    pub fn parameters(&self) -> Option<(Array1<f64>, Array1<f64>)> {
        if self.fitted {
            Some((
                self.weights.as_ref().unwrap().clone(),
                self.biases.as_ref().unwrap().clone(),
            ))
        } else {
            None
        }
    }
}