pub struct TemperatureScaler { /* private fields */ }

Expand description
Temperature Scaling — a simple and effective multi-class calibration method. Scales logits by a single learned temperature parameter. Particularly effective for neural network outputs.
Implementations§
Source§impl TemperatureScaler
impl TemperatureScaler
Source§
pub fn new() -> Self
pub fn new() -> Self
Create a new temperature scaler
Examples found in repository?
examples/calibration_demo.rs (line 178)
146fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
147 // Generate multi-class logits (4 classes, 8 samples)
148 let logits = array![
149 [5.0, 1.0, 0.5, 0.0], // Overconfident for class 0
150 [1.0, 5.0, 0.5, 0.0], // Overconfident for class 1
151 [0.5, 1.0, 5.0, 0.0], // Overconfident for class 2
152 [0.0, 0.5, 1.0, 5.0], // Overconfident for class 3
153 [3.0, 2.0, 1.0, 0.5], // Moderately confident for class 0
154 [1.0, 3.0, 2.0, 0.5], // Moderately confident for class 1
155 [0.5, 1.0, 3.0, 2.0], // Moderately confident for class 2
156 [0.5, 0.5, 1.0, 3.0], // Moderately confident for class 3
157 ];
158 let labels = array![0, 1, 2, 3, 0, 1, 2, 3];
159
160 println!(" Input: 4-class classification with 8 samples");
161 println!(" Logits shape: {}×{}\n", logits.nrows(), logits.ncols());
162
163 // Compute uncalibrated softmax for comparison
164 let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
165 for i in 0..logits.nrows() {
166 let max_logit = logits
167 .row(i)
168 .iter()
169 .copied()
170 .fold(f64::NEG_INFINITY, f64::max);
171 let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
172 for j in 0..logits.ncols() {
173 uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
174 }
175 }
176
177 // Fit temperature scaler
178 let mut scaler = TemperatureScaler::new();
179 scaler.fit(&logits, &labels)?;
180
181 // Get fitted temperature
182 if let Some(temp) = scaler.temperature() {
183 println!(" Fitted temperature: T = {temp:.4}");
184 println!(
185 " Interpretation: {}",
186 if temp > 1.0 {
187 "Model is overconfident (T > 1 reduces confidence)"
188 } else if temp < 1.0 {
189 "Model is underconfident (T < 1 increases confidence)"
190 } else {
191 "Model is well-calibrated (T ≈ 1)"
192 }
193 );
194 }
195
196 // Transform to calibrated probabilities
197 let calibrated_probs = scaler.transform(&logits)?;
198
199 println!("\n Comparison (first 4 samples):");
200 println!(
201 " {:<8} | {:<20} | {:<20}",
202 "Sample", "Uncalibrated Max P", "Calibrated Max P"
203 );
204 println!(" {}", "-".repeat(60));
205
206 for i in 0..4 {
207 let uncal_max = uncalibrated_probs
208 .row(i)
209 .iter()
210 .copied()
211 .fold(f64::NEG_INFINITY, f64::max);
212 let cal_max = calibrated_probs
213 .row(i)
214 .iter()
215 .copied()
216 .fold(f64::NEG_INFINITY, f64::max);
217 println!(" Sample {i:<2} | {uncal_max:.4} | {cal_max:.4}");
218 }
219
220 // Compute predictions
221 let mut correct = 0;
222 for i in 0..calibrated_probs.nrows() {
223 let pred = calibrated_probs
224 .row(i)
225 .iter()
226 .enumerate()
227 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
228 .map(|(idx, _)| idx)
229 .unwrap();
230 if pred == labels[i] {
231 correct += 1;
232 }
233 }
234
235 let accuracy = correct as f64 / labels.len() as f64;
236 println!("\n Calibrated accuracy: {:.2}%", accuracy * 100.0);
237
238 Ok(())
239}

Source§
pub fn fit(
&mut self,
logits: &Array2<f64>,
labels: &Array1<usize>,
) -> Result<()>
pub fn fit( &mut self, logits: &Array2<f64>, labels: &Array1<usize>, ) -> Result<()>
Fit the temperature scaler to logits and true labels. Uses negative log-likelihood minimization.
Examples found in repository?
examples/calibration_demo.rs (line 179)
146fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
147 // Generate multi-class logits (4 classes, 8 samples)
148 let logits = array![
149 [5.0, 1.0, 0.5, 0.0], // Overconfident for class 0
150 [1.0, 5.0, 0.5, 0.0], // Overconfident for class 1
151 [0.5, 1.0, 5.0, 0.0], // Overconfident for class 2
152 [0.0, 0.5, 1.0, 5.0], // Overconfident for class 3
153 [3.0, 2.0, 1.0, 0.5], // Moderately confident for class 0
154 [1.0, 3.0, 2.0, 0.5], // Moderately confident for class 1
155 [0.5, 1.0, 3.0, 2.0], // Moderately confident for class 2
156 [0.5, 0.5, 1.0, 3.0], // Moderately confident for class 3
157 ];
158 let labels = array![0, 1, 2, 3, 0, 1, 2, 3];
159
160 println!(" Input: 4-class classification with 8 samples");
161 println!(" Logits shape: {}×{}\n", logits.nrows(), logits.ncols());
162
163 // Compute uncalibrated softmax for comparison
164 let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
165 for i in 0..logits.nrows() {
166 let max_logit = logits
167 .row(i)
168 .iter()
169 .copied()
170 .fold(f64::NEG_INFINITY, f64::max);
171 let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
172 for j in 0..logits.ncols() {
173 uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
174 }
175 }
176
177 // Fit temperature scaler
178 let mut scaler = TemperatureScaler::new();
179 scaler.fit(&logits, &labels)?;
180
181 // Get fitted temperature
182 if let Some(temp) = scaler.temperature() {
183 println!(" Fitted temperature: T = {temp:.4}");
184 println!(
185 " Interpretation: {}",
186 if temp > 1.0 {
187 "Model is overconfident (T > 1 reduces confidence)"
188 } else if temp < 1.0 {
189 "Model is underconfident (T < 1 increases confidence)"
190 } else {
191 "Model is well-calibrated (T ≈ 1)"
192 }
193 );
194 }
195
196 // Transform to calibrated probabilities
197 let calibrated_probs = scaler.transform(&logits)?;
198
199 println!("\n Comparison (first 4 samples):");
200 println!(
201 " {:<8} | {:<20} | {:<20}",
202 "Sample", "Uncalibrated Max P", "Calibrated Max P"
203 );
204 println!(" {}", "-".repeat(60));
205
206 for i in 0..4 {
207 let uncal_max = uncalibrated_probs
208 .row(i)
209 .iter()
210 .copied()
211 .fold(f64::NEG_INFINITY, f64::max);
212 let cal_max = calibrated_probs
213 .row(i)
214 .iter()
215 .copied()
216 .fold(f64::NEG_INFINITY, f64::max);
217 println!(" Sample {i:<2} | {uncal_max:.4} | {cal_max:.4}");
218 }
219
220 // Compute predictions
221 let mut correct = 0;
222 for i in 0..calibrated_probs.nrows() {
223 let pred = calibrated_probs
224 .row(i)
225 .iter()
226 .enumerate()
227 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
228 .map(|(idx, _)| idx)
229 .unwrap();
230 if pred == labels[i] {
231 correct += 1;
232 }
233 }
234
235 let accuracy = correct as f64 / labels.len() as f64;
236 println!("\n Calibrated accuracy: {:.2}%", accuracy * 100.0);
237
238 Ok(())
239}

Source§
pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>>
pub fn transform(&self, logits: &Array2<f64>) -> Result<Array2<f64>>
Transform logits to calibrated probabilities using temperature scaling
Examples found in repository?
examples/calibration_demo.rs (line 197)
146fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
147 // Generate multi-class logits (4 classes, 8 samples)
148 let logits = array![
149 [5.0, 1.0, 0.5, 0.0], // Overconfident for class 0
150 [1.0, 5.0, 0.5, 0.0], // Overconfident for class 1
151 [0.5, 1.0, 5.0, 0.0], // Overconfident for class 2
152 [0.0, 0.5, 1.0, 5.0], // Overconfident for class 3
153 [3.0, 2.0, 1.0, 0.5], // Moderately confident for class 0
154 [1.0, 3.0, 2.0, 0.5], // Moderately confident for class 1
155 [0.5, 1.0, 3.0, 2.0], // Moderately confident for class 2
156 [0.5, 0.5, 1.0, 3.0], // Moderately confident for class 3
157 ];
158 let labels = array![0, 1, 2, 3, 0, 1, 2, 3];
159
160 println!(" Input: 4-class classification with 8 samples");
161 println!(" Logits shape: {}×{}\n", logits.nrows(), logits.ncols());
162
163 // Compute uncalibrated softmax for comparison
164 let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
165 for i in 0..logits.nrows() {
166 let max_logit = logits
167 .row(i)
168 .iter()
169 .copied()
170 .fold(f64::NEG_INFINITY, f64::max);
171 let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
172 for j in 0..logits.ncols() {
173 uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
174 }
175 }
176
177 // Fit temperature scaler
178 let mut scaler = TemperatureScaler::new();
179 scaler.fit(&logits, &labels)?;
180
181 // Get fitted temperature
182 if let Some(temp) = scaler.temperature() {
183 println!(" Fitted temperature: T = {temp:.4}");
184 println!(
185 " Interpretation: {}",
186 if temp > 1.0 {
187 "Model is overconfident (T > 1 reduces confidence)"
188 } else if temp < 1.0 {
189 "Model is underconfident (T < 1 increases confidence)"
190 } else {
191 "Model is well-calibrated (T ≈ 1)"
192 }
193 );
194 }
195
196 // Transform to calibrated probabilities
197 let calibrated_probs = scaler.transform(&logits)?;
198
199 println!("\n Comparison (first 4 samples):");
200 println!(
201 " {:<8} | {:<20} | {:<20}",
202 "Sample", "Uncalibrated Max P", "Calibrated Max P"
203 );
204 println!(" {}", "-".repeat(60));
205
206 for i in 0..4 {
207 let uncal_max = uncalibrated_probs
208 .row(i)
209 .iter()
210 .copied()
211 .fold(f64::NEG_INFINITY, f64::max);
212 let cal_max = calibrated_probs
213 .row(i)
214 .iter()
215 .copied()
216 .fold(f64::NEG_INFINITY, f64::max);
217 println!(" Sample {i:<2} | {uncal_max:.4} | {cal_max:.4}");
218 }
219
220 // Compute predictions
221 let mut correct = 0;
222 for i in 0..calibrated_probs.nrows() {
223 let pred = calibrated_probs
224 .row(i)
225 .iter()
226 .enumerate()
227 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
228 .map(|(idx, _)| idx)
229 .unwrap();
230 if pred == labels[i] {
231 correct += 1;
232 }
233 }
234
235 let accuracy = correct as f64 / labels.len() as f64;
236 println!("\n Calibrated accuracy: {:.2}%", accuracy * 100.0);
237
238 Ok(())
239}

Source§
pub fn fit_transform(
&mut self,
logits: &Array2<f64>,
labels: &Array1<usize>,
) -> Result<Array2<f64>>
pub fn fit_transform( &mut self, logits: &Array2<f64>, labels: &Array1<usize>, ) -> Result<Array2<f64>>
Fit and transform in one step
Source§
pub fn temperature(&self) -> Option<f64>
pub fn temperature(&self) -> Option<f64>
Get the fitted temperature parameter
Examples found in repository?
examples/calibration_demo.rs (line 182)
146fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
147 // Generate multi-class logits (4 classes, 8 samples)
148 let logits = array![
149 [5.0, 1.0, 0.5, 0.0], // Overconfident for class 0
150 [1.0, 5.0, 0.5, 0.0], // Overconfident for class 1
151 [0.5, 1.0, 5.0, 0.0], // Overconfident for class 2
152 [0.0, 0.5, 1.0, 5.0], // Overconfident for class 3
153 [3.0, 2.0, 1.0, 0.5], // Moderately confident for class 0
154 [1.0, 3.0, 2.0, 0.5], // Moderately confident for class 1
155 [0.5, 1.0, 3.0, 2.0], // Moderately confident for class 2
156 [0.5, 0.5, 1.0, 3.0], // Moderately confident for class 3
157 ];
158 let labels = array![0, 1, 2, 3, 0, 1, 2, 3];
159
160 println!(" Input: 4-class classification with 8 samples");
161 println!(" Logits shape: {}×{}\n", logits.nrows(), logits.ncols());
162
163 // Compute uncalibrated softmax for comparison
164 let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
165 for i in 0..logits.nrows() {
166 let max_logit = logits
167 .row(i)
168 .iter()
169 .copied()
170 .fold(f64::NEG_INFINITY, f64::max);
171 let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
172 for j in 0..logits.ncols() {
173 uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
174 }
175 }
176
177 // Fit temperature scaler
178 let mut scaler = TemperatureScaler::new();
179 scaler.fit(&logits, &labels)?;
180
181 // Get fitted temperature
182 if let Some(temp) = scaler.temperature() {
183 println!(" Fitted temperature: T = {temp:.4}");
184 println!(
185 " Interpretation: {}",
186 if temp > 1.0 {
187 "Model is overconfident (T > 1 reduces confidence)"
188 } else if temp < 1.0 {
189 "Model is underconfident (T < 1 increases confidence)"
190 } else {
191 "Model is well-calibrated (T ≈ 1)"
192 }
193 );
194 }
195
196 // Transform to calibrated probabilities
197 let calibrated_probs = scaler.transform(&logits)?;
198
199 println!("\n Comparison (first 4 samples):");
200 println!(
201 " {:<8} | {:<20} | {:<20}",
202 "Sample", "Uncalibrated Max P", "Calibrated Max P"
203 );
204 println!(" {}", "-".repeat(60));
205
206 for i in 0..4 {
207 let uncal_max = uncalibrated_probs
208 .row(i)
209 .iter()
210 .copied()
211 .fold(f64::NEG_INFINITY, f64::max);
212 let cal_max = calibrated_probs
213 .row(i)
214 .iter()
215 .copied()
216 .fold(f64::NEG_INFINITY, f64::max);
217 println!(" Sample {i:<2} | {uncal_max:.4} | {cal_max:.4}");
218 }
219
220 // Compute predictions
221 let mut correct = 0;
222 for i in 0..calibrated_probs.nrows() {
223 let pred = calibrated_probs
224 .row(i)
225 .iter()
226 .enumerate()
227 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
228 .map(|(idx, _)| idx)
229 .unwrap();
230 if pred == labels[i] {
231 correct += 1;
232 }
233 }
234
235 let accuracy = correct as f64 / labels.len() as f64;
236 println!("\n Calibrated accuracy: {:.2}%", accuracy * 100.0);
237
238 Ok(())
239}

Trait Implementations§
Source§impl Clone for TemperatureScaler
impl Clone for TemperatureScaler
Source§fn clone(&self) -> TemperatureScaler
fn clone(&self) -> TemperatureScaler
Returns a duplicate of the value. Read more
1.0.0 · Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source. Read more

Source§
impl Debug for TemperatureScaler
impl Debug for TemperatureScaler
Auto Trait Implementations§
impl Freeze for TemperatureScaler
impl RefUnwindSafe for TemperatureScaler
impl Send for TemperatureScaler
impl Sync for TemperatureScaler
impl Unpin for TemperatureScaler
impl UnwindSafe for TemperatureScaler
Blanket Implementations§
Source§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> CloneToUninit for T
where
    T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more

Source§
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more

Source§
impl<T> Pointable for T
impl<T> Pointable for T
Source§
impl<SS, SP> SupersetOf<SS> for SP
where
    SS: SubsetOf<SP>,
impl<SS, SP> SupersetOf<SS> for SP
where
    SS: SubsetOf<SP>,
Source§fn to_subset(&self) -> Option<SS>
fn to_subset(&self) -> Option<SS>
The inverse inclusion map: attempts to construct self from the equivalent element of its superset. Read more

Source§
fn is_in_subset(&self) -> bool
fn is_in_subset(&self) -> bool
Checks if self is actually part of its subset T (and can be converted to it).

Source§
fn to_subset_unchecked(&self) -> SS
fn to_subset_unchecked(&self) -> SS
Use with care! Same as self.to_subset but without any property checks. Always succeeds.

Source§
fn from_subset(element: &SS) -> SP
fn from_subset(element: &SS) -> SP
The inclusion map: converts self to the equivalent element of its superset.