pub struct TemperatureScaler { /* private fields */ }
Temperature scaling: a simple and effective multi-class calibration method. It scales logits by a single learned temperature parameter and is particularly effective for neural network outputs.
Implementations§
Source§impl TemperatureScaler
impl TemperatureScaler
Sourcepub fn new() -> Self
pub fn new() -> Self
Create a new temperature scaler
Examples found in repository?
examples/calibration_demo.rs (line 171)
139fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
140 // Generate multi-class logits (4 classes, 8 samples)
141 let logits = array![
142 [5.0, 1.0, 0.5, 0.0], // Overconfident for class 0
143 [1.0, 5.0, 0.5, 0.0], // Overconfident for class 1
144 [0.5, 1.0, 5.0, 0.0], // Overconfident for class 2
145 [0.0, 0.5, 1.0, 5.0], // Overconfident for class 3
146 [3.0, 2.0, 1.0, 0.5], // Moderately confident for class 0
147 [1.0, 3.0, 2.0, 0.5], // Moderately confident for class 1
148 [0.5, 1.0, 3.0, 2.0], // Moderately confident for class 2
149 [0.5, 0.5, 1.0, 3.0], // Moderately confident for class 3
150 ];
151 let labels = array![0, 1, 2, 3, 0, 1, 2, 3];
152
153 println!(" Input: 4-class classification with 8 samples");
154 println!(" Logits shape: {}×{}\n", logits.nrows(), logits.ncols());
155
156 // Compute uncalibrated softmax for comparison
157 let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
158 for i in 0..logits.nrows() {
159 let max_logit = logits
160 .row(i)
161 .iter()
162 .copied()
163 .fold(f64::NEG_INFINITY, f64::max);
164 let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
165 for j in 0..logits.ncols() {
166 uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
167 }
168 }
169
170 // Fit temperature scaler
171 let mut scaler = TemperatureScaler::new();
172 scaler.fit(&logits, &labels)?;
173
174 // Get fitted temperature
175 if let Some(temp) = scaler.temperature() {
176 println!(" Fitted temperature: T = {temp:.4}");
177 println!(
178 " Interpretation: {}",
179 if temp > 1.0 {
180 "Model is overconfident (T > 1 reduces confidence)"
181 } else if temp < 1.0 {
182 "Model is underconfident (T < 1 increases confidence)"
183 } else {
184 "Model is well-calibrated (T ≈ 1)"
185 }
186 );
187 }
188
189 // Transform to calibrated probabilities
190 let calibrated_probs = scaler.transform(&logits)?;
191
192 println!("\n Comparison (first 4 samples):");
193 println!(
194 " {:<8} | {:<20} | {:<20}",
195 "Sample", "Uncalibrated Max P", "Calibrated Max P"
196 );
197 println!(" {}", "-".repeat(60));
198
199 for i in 0..4 {
200 let uncal_max = uncalibrated_probs
201 .row(i)
202 .iter()
203 .copied()
204 .fold(f64::NEG_INFINITY, f64::max);
205 let cal_max = calibrated_probs
206 .row(i)
207 .iter()
208 .copied()
209 .fold(f64::NEG_INFINITY, f64::max);
210 println!(" Sample {i:<2} | {uncal_max:.4} | {cal_max:.4}");
211 }
212
213 // Compute predictions
214 let mut correct = 0;
215 for i in 0..calibrated_probs.nrows() {
216 let pred = calibrated_probs
217 .row(i)
218 .iter()
219 .enumerate()
220 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
221 .map(|(idx, _)| idx)
222 .unwrap();
223 if pred == labels[i] {
224 correct += 1;
225 }
226 }
227
228 let accuracy = correct as f64 / labels.len() as f64;
229 println!("\n Calibrated accuracy: {:.2}%", accuracy * 100.0);
230
231 Ok(())
232}Sourcepub fn fit(
&mut self,
logits: &Array2<f64>,
labels: &Array1<usize>,
) -> Result<()>
pub fn fit( &mut self, logits: &Array2<f64>, labels: &Array1<usize>, ) -> Result<()>
Fit the temperature scaler to logits and true labels. The temperature is learned by minimizing the negative log-likelihood.
Examples found in repository?
examples/calibration_demo.rs (line 172)
139fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
140 // Generate multi-class logits (4 classes, 8 samples)
141 let logits = array![
142 [5.0, 1.0, 0.5, 0.0], // Overconfident for class 0
143 [1.0, 5.0, 0.5, 0.0], // Overconfident for class 1
144 [0.5, 1.0, 5.0, 0.0], // Overconfident for class 2
145 [0.0, 0.5, 1.0, 5.0], // Overconfident for class 3
146 [3.0, 2.0, 1.0, 0.5], // Moderately confident for class 0
147 [1.0, 3.0, 2.0, 0.5], // Moderately confident for class 1
148 [0.5, 1.0, 3.0, 2.0], // Moderately confident for class 2
149 [0.5, 0.5, 1.0, 3.0], // Moderately confident for class 3
150 ];
151 let labels = array![0, 1, 2, 3, 0, 1, 2, 3];
152
153 println!(" Input: 4-class classification with 8 samples");
154 println!(" Logits shape: {}×{}\n", logits.nrows(), logits.ncols());
155
156 // Compute uncalibrated softmax for comparison
157 let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
158 for i in 0..logits.nrows() {
159 let max_logit = logits
160 .row(i)
161 .iter()
162 .copied()
163 .fold(f64::NEG_INFINITY, f64::max);
164 let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
165 for j in 0..logits.ncols() {
166 uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
167 }
168 }
169
170 // Fit temperature scaler
171 let mut scaler = TemperatureScaler::new();
172 scaler.fit(&logits, &labels)?;
173
174 // Get fitted temperature
175 if let Some(temp) = scaler.temperature() {
176 println!(" Fitted temperature: T = {temp:.4}");
177 println!(
178 " Interpretation: {}",
179 if temp > 1.0 {
180 "Model is overconfident (T > 1 reduces confidence)"
181 } else if temp < 1.0 {
182 "Model is underconfident (T < 1 increases confidence)"
183 } else {
184 "Model is well-calibrated (T ≈ 1)"
185 }
186 );
187 }
188
189 // Transform to calibrated probabilities
190 let calibrated_probs = scaler.transform(&logits)?;
191
192 println!("\n Comparison (first 4 samples):");
193 println!(
194 " {:<8} | {:<20} | {:<20}",
195 "Sample", "Uncalibrated Max P", "Calibrated Max P"
196 );
197 println!(" {}", "-".repeat(60));
198
199 for i in 0..4 {
200 let uncal_max = uncalibrated_probs
201 .row(i)
202 .iter()
203 .copied()
204 .fold(f64::NEG_INFINITY, f64::max);
205 let cal_max = calibrated_probs
206 .row(i)
207 .iter()
208 .copied()
209 .fold(f64::NEG_INFINITY, f64::max);
210 println!(" Sample {i:<2} | {uncal_max:.4} | {cal_max:.4}");
211 }
212
213 // Compute predictions
214 let mut correct = 0;
215 for i in 0..calibrated_probs.nrows() {
216 let pred = calibrated_probs
217 .row(i)
218 .iter()
219 .enumerate()
220 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
221 .map(|(idx, _)| idx)
222 .unwrap();
223 if pred == labels[i] {
224 correct += 1;
225 }
226 }
227
228 let accuracy = correct as f64 / labels.len() as f64;
229 println!("\n Calibrated accuracy: {:.2}%", accuracy * 100.0);
230
231 Ok(())
232}Sourcepub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>>
pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>>
Transform logits to calibrated probabilities using temperature scaling
Examples found in repository?
examples/calibration_demo.rs (line 190)
139fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
140 // Generate multi-class logits (4 classes, 8 samples)
141 let logits = array![
142 [5.0, 1.0, 0.5, 0.0], // Overconfident for class 0
143 [1.0, 5.0, 0.5, 0.0], // Overconfident for class 1
144 [0.5, 1.0, 5.0, 0.0], // Overconfident for class 2
145 [0.0, 0.5, 1.0, 5.0], // Overconfident for class 3
146 [3.0, 2.0, 1.0, 0.5], // Moderately confident for class 0
147 [1.0, 3.0, 2.0, 0.5], // Moderately confident for class 1
148 [0.5, 1.0, 3.0, 2.0], // Moderately confident for class 2
149 [0.5, 0.5, 1.0, 3.0], // Moderately confident for class 3
150 ];
151 let labels = array![0, 1, 2, 3, 0, 1, 2, 3];
152
153 println!(" Input: 4-class classification with 8 samples");
154 println!(" Logits shape: {}×{}\n", logits.nrows(), logits.ncols());
155
156 // Compute uncalibrated softmax for comparison
157 let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
158 for i in 0..logits.nrows() {
159 let max_logit = logits
160 .row(i)
161 .iter()
162 .copied()
163 .fold(f64::NEG_INFINITY, f64::max);
164 let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
165 for j in 0..logits.ncols() {
166 uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
167 }
168 }
169
170 // Fit temperature scaler
171 let mut scaler = TemperatureScaler::new();
172 scaler.fit(&logits, &labels)?;
173
174 // Get fitted temperature
175 if let Some(temp) = scaler.temperature() {
176 println!(" Fitted temperature: T = {temp:.4}");
177 println!(
178 " Interpretation: {}",
179 if temp > 1.0 {
180 "Model is overconfident (T > 1 reduces confidence)"
181 } else if temp < 1.0 {
182 "Model is underconfident (T < 1 increases confidence)"
183 } else {
184 "Model is well-calibrated (T ≈ 1)"
185 }
186 );
187 }
188
189 // Transform to calibrated probabilities
190 let calibrated_probs = scaler.transform(&logits)?;
191
192 println!("\n Comparison (first 4 samples):");
193 println!(
194 " {:<8} | {:<20} | {:<20}",
195 "Sample", "Uncalibrated Max P", "Calibrated Max P"
196 );
197 println!(" {}", "-".repeat(60));
198
199 for i in 0..4 {
200 let uncal_max = uncalibrated_probs
201 .row(i)
202 .iter()
203 .copied()
204 .fold(f64::NEG_INFINITY, f64::max);
205 let cal_max = calibrated_probs
206 .row(i)
207 .iter()
208 .copied()
209 .fold(f64::NEG_INFINITY, f64::max);
210 println!(" Sample {i:<2} | {uncal_max:.4} | {cal_max:.4}");
211 }
212
213 // Compute predictions
214 let mut correct = 0;
215 for i in 0..calibrated_probs.nrows() {
216 let pred = calibrated_probs
217 .row(i)
218 .iter()
219 .enumerate()
220 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
221 .map(|(idx, _)| idx)
222 .unwrap();
223 if pred == labels[i] {
224 correct += 1;
225 }
226 }
227
228 let accuracy = correct as f64 / labels.len() as f64;
229 println!("\n Calibrated accuracy: {:.2}%", accuracy * 100.0);
230
231 Ok(())
232}Sourcepub fn fit_transform(
&mut self,
logits: &Array2<f64>,
labels: &Array1<usize>,
) -> Result<Array2<f64>>
pub fn fit_transform( &mut self, logits: &Array2<f64>, labels: &Array1<usize>, ) -> Result<Array2<f64>>
Fit and transform in one step
Sourcepub fn temperature(&self) -> Option<f64>
pub fn temperature(&self) -> Option<f64>
Get the fitted temperature parameter
Examples found in repository?
examples/calibration_demo.rs (line 175)
139fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
140 // Generate multi-class logits (4 classes, 8 samples)
141 let logits = array![
142 [5.0, 1.0, 0.5, 0.0], // Overconfident for class 0
143 [1.0, 5.0, 0.5, 0.0], // Overconfident for class 1
144 [0.5, 1.0, 5.0, 0.0], // Overconfident for class 2
145 [0.0, 0.5, 1.0, 5.0], // Overconfident for class 3
146 [3.0, 2.0, 1.0, 0.5], // Moderately confident for class 0
147 [1.0, 3.0, 2.0, 0.5], // Moderately confident for class 1
148 [0.5, 1.0, 3.0, 2.0], // Moderately confident for class 2
149 [0.5, 0.5, 1.0, 3.0], // Moderately confident for class 3
150 ];
151 let labels = array![0, 1, 2, 3, 0, 1, 2, 3];
152
153 println!(" Input: 4-class classification with 8 samples");
154 println!(" Logits shape: {}×{}\n", logits.nrows(), logits.ncols());
155
156 // Compute uncalibrated softmax for comparison
157 let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
158 for i in 0..logits.nrows() {
159 let max_logit = logits
160 .row(i)
161 .iter()
162 .copied()
163 .fold(f64::NEG_INFINITY, f64::max);
164 let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
165 for j in 0..logits.ncols() {
166 uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
167 }
168 }
169
170 // Fit temperature scaler
171 let mut scaler = TemperatureScaler::new();
172 scaler.fit(&logits, &labels)?;
173
174 // Get fitted temperature
175 if let Some(temp) = scaler.temperature() {
176 println!(" Fitted temperature: T = {temp:.4}");
177 println!(
178 " Interpretation: {}",
179 if temp > 1.0 {
180 "Model is overconfident (T > 1 reduces confidence)"
181 } else if temp < 1.0 {
182 "Model is underconfident (T < 1 increases confidence)"
183 } else {
184 "Model is well-calibrated (T ≈ 1)"
185 }
186 );
187 }
188
189 // Transform to calibrated probabilities
190 let calibrated_probs = scaler.transform(&logits)?;
191
192 println!("\n Comparison (first 4 samples):");
193 println!(
194 " {:<8} | {:<20} | {:<20}",
195 "Sample", "Uncalibrated Max P", "Calibrated Max P"
196 );
197 println!(" {}", "-".repeat(60));
198
199 for i in 0..4 {
200 let uncal_max = uncalibrated_probs
201 .row(i)
202 .iter()
203 .copied()
204 .fold(f64::NEG_INFINITY, f64::max);
205 let cal_max = calibrated_probs
206 .row(i)
207 .iter()
208 .copied()
209 .fold(f64::NEG_INFINITY, f64::max);
210 println!(" Sample {i:<2} | {uncal_max:.4} | {cal_max:.4}");
211 }
212
213 // Compute predictions
214 let mut correct = 0;
215 for i in 0..calibrated_probs.nrows() {
216 let pred = calibrated_probs
217 .row(i)
218 .iter()
219 .enumerate()
220 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
221 .map(|(idx, _)| idx)
222 .unwrap();
223 if pred == labels[i] {
224 correct += 1;
225 }
226 }
227
228 let accuracy = correct as f64 / labels.len() as f64;
229 println!("\n Calibrated accuracy: {:.2}%", accuracy * 100.0);
230
231 Ok(())
232}Trait Implementations§
Source§impl Clone for TemperatureScaler
impl Clone for TemperatureScaler
Source§fn clone(&self) -> TemperatureScaler
fn clone(&self) -> TemperatureScaler
Returns a duplicate of the value. Read more
1.0.0 · Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`. Read more
Source§impl Debug for TemperatureScaler
impl Debug for TemperatureScaler
Auto Trait Implementations§
impl Freeze for TemperatureScaler
impl RefUnwindSafe for TemperatureScaler
impl Send for TemperatureScaler
impl Sync for TemperatureScaler
impl Unpin for TemperatureScaler
impl UnwindSafe for TemperatureScaler
Blanket Implementations§
Source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> CloneToUninit for Twhere
T: Clone,
impl<T> CloneToUninit for Twhere
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§impl<T> Pointable for T
impl<T> Pointable for T
Source§impl<SS, SP> SupersetOf<SS> for SPwhere
SS: SubsetOf<SP>,
impl<SS, SP> SupersetOf<SS> for SPwhere
SS: SubsetOf<SP>,
Source§fn to_subset(&self) -> Option<SS>
fn to_subset(&self) -> Option<SS>
The inverse inclusion map: attempts to construct `self` from the equivalent element of its superset. Read more
Source§fn is_in_subset(&self) -> bool
fn is_in_subset(&self) -> bool
Checks if `self` is actually part of its subset `T` (and can be converted to it).
Source§fn to_subset_unchecked(&self) -> SS
fn to_subset_unchecked(&self) -> SS
Use with care! Same as `self.to_subset` but without any property checks. Always succeeds.
Source§fn from_subset(element: &SS) -> SP
fn from_subset(element: &SS) -> SP
The inclusion map: converts `self` to the equivalent element of its superset.