pub mod auc;
pub mod classification;
pub mod conformal;
pub mod conformal_pid;
pub mod ewma;
pub mod importance;
pub mod kappa;
pub mod platt_scaling;
pub mod quantile_tracker;
pub mod regression;
pub mod rolling;
pub mod streaming_metric;
pub mod strongly_adaptive;
pub mod temperature_scaling;
pub use classification::ClassificationMetrics;
pub use conformal::StepSchedule;
pub use conformal_pid::ConformalPID;
pub use importance::FeatureImportance;
pub use platt_scaling::OnlinePlattScaling;
pub use quantile_tracker::StreamingQuantileTracker;
pub use regression::RegressionMetrics;
pub use streaming_metric::{
Accuracy, LogLoss, MetricUnion, Pinball, StreamingMetric, MAE, MSE, R2, RMSE,
};
pub use strongly_adaptive::StronglyAdaptiveConformal;
pub use temperature_scaling::OnlineTemperatureScaling;
/// Bundles a regression tracker and a classification tracker behind one handle.
///
/// The two halves are independent: regression updates only touch the
/// regression tracker, classification updates only touch the classification
/// tracker, and [`MetricSet::reset`] clears both.
#[derive(Debug, Clone)]
pub struct MetricSet {
    /// Regression statistics, fed by [`MetricSet::update_regression`].
    reg: RegressionMetrics,
    /// Classification statistics, fed by [`MetricSet::update_classification`].
    cls: ClassificationMetrics,
}

impl MetricSet {
    /// Builds a set whose regression and classification trackers are both
    /// freshly constructed via their own `new` constructors.
    pub fn new() -> Self {
        let reg = RegressionMetrics::new();
        let cls = ClassificationMetrics::new();
        Self { reg, cls }
    }

    /// Records one `(target, prediction)` observation in the regression tracker.
    ///
    /// The classification tracker is not touched.
    pub fn update_regression(&mut self, target: f64, prediction: f64) {
        let tracker = &mut self.reg;
        tracker.update(target, prediction);
    }

    /// Records one classification observation in the classification tracker.
    ///
    /// `target` and `predicted` are class labels; `predicted_proba` is passed
    /// through unchanged to the classification tracker.
    /// The regression tracker is not touched.
    pub fn update_classification(&mut self, target: usize, predicted: usize, predicted_proba: f64) {
        let tracker = &mut self.cls;
        tracker.update(target, predicted, predicted_proba);
    }

    /// Read-only view of the accumulated regression statistics.
    pub fn regression(&self) -> &RegressionMetrics {
        &self.reg
    }

    /// Read-only view of the accumulated classification statistics.
    pub fn classification(&self) -> &ClassificationMetrics {
        &self.cls
    }

    /// Resets both trackers by calling each component's own `reset`.
    pub fn reset(&mut self) {
        // Exhaustive destructuring: adding a field forces this method to be revisited.
        let Self { reg, cls } = self;
        reg.reset();
        cls.reset();
    }
}

impl Default for MetricSet {
    /// Equivalent to [`MetricSet::new`].
    fn default() -> Self {
        MetricSet::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that neither half of `ms` has recorded any samples.
    fn assert_empty(ms: &MetricSet) {
        assert_eq!(ms.regression().n_samples(), 0);
        assert_eq!(ms.classification().n_samples(), 0);
    }

    #[test]
    fn metric_set_new_is_empty() {
        assert_empty(&MetricSet::new());
    }

    #[test]
    fn metric_set_update_regression() {
        let mut ms = MetricSet::new();
        for (target, prediction) in [(5.0, 3.0), (2.0, 2.0)] {
            ms.update_regression(target, prediction);
        }

        let reg = ms.regression();
        assert_eq!(reg.n_samples(), 2);
        assert!(reg.mae() > 0.0);
        // Regression updates must leave the classification side untouched.
        assert_eq!(ms.classification().n_samples(), 0);
    }

    #[test]
    fn metric_set_update_classification() {
        let mut ms = MetricSet::new();
        for (target, predicted, proba) in [(1, 1, 0.9), (0, 0, 0.1)] {
            ms.update_classification(target, predicted, proba);
        }

        let cls = ms.classification();
        assert_eq!(cls.n_samples(), 2);
        assert_eq!(cls.accuracy(), 1.0);
        // Classification updates must leave the regression side untouched.
        assert_eq!(ms.regression().n_samples(), 0);
    }

    #[test]
    fn metric_set_reset() {
        let mut ms = MetricSet::new();
        ms.update_regression(1.0, 2.0);
        ms.update_classification(1, 0, 0.3);

        ms.reset();
        assert_empty(&ms);
    }

    #[test]
    fn metric_set_default() {
        assert_empty(&MetricSet::default());
    }
}