use super::Metric;
use scirs2_core::ndarray::{Array, Axis, Ix1, Ix2, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
/// Tracks the mean of the per-batch loss values passed to `Metric::update`.
#[derive(Debug, Clone)]
pub struct LossMetric<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> {
// Sum of every loss value received since the last reset.
total_loss: F,
// Number of `update` calls that carried `Some(loss)`.
num_batches: usize,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> Default
for LossMetric<F>
{
// Delegates to `LossMetric::new`.
fn default() -> Self {
Self::new()
}
// NOTE(review): this copy of the file is missing closing braces and some repeated
// lines (e.g. the `Default` impl above is never closed before the next `impl`);
// confirm structure against the upstream source.
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> LossMetric<F> {
/// Creates an empty metric: zero accumulated loss and zero batches.
pub fn new() -> Self {
Self {
total_loss: F::zero(),
num_batches: 0,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> Metric<F>
// Adds `loss` to the running total when it is provided; predictions and targets
// are not inspected. Calls with `None` do not count as a batch.
fn update(
// NOTE(review): `self_predictions` looks like a garbled `self, _predictions`
// receiver/parameter pair — confirm against upstream.
&mut self_predictions: &Array<F, IxDyn>, _targets: &Array<F, IxDyn>,
loss: Option<F>,
) {
if let Some(loss) = loss {
self.total_loss = self.total_loss + loss;
self.num_batches += 1;
// Clears the accumulated loss and batch count.
fn reset(&mut self) {
self.total_loss = F::zero();
self.num_batches = 0;
// Mean loss per recorded batch, or zero when nothing has been recorded yet.
fn result(&self) -> F {
if self.num_batches > 0 {
self.total_loss / F::from(self.num_batches).expect("Failed to convert to float")
} else {
F::zero()
// Metric identifier reported to callers.
fn name(&self) -> &str {
"loss"
/// Classification accuracy: fraction of samples whose predicted class index
/// matches the target class index.
pub struct AccuracyMetric<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync>
// Samples classified correctly since the last reset.
correct: usize,
// Samples seen since the last reset.
total: usize,
_phantom: PhantomData<F>,
for AccuracyMetric<F>
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> AccuracyMetric<F> {
correct: 0,
// NOTE(review): initializer says `phantom` but the field above is `_phantom`;
// one of them is wrong (or this line is garbled) — confirm.
total: 0, phantom: PhantomData,
// update: reshapes predictions/targets to 2-D (rows, columns), derives one class
// index per row, and counts rows where the two indices agree. `_loss` is unused.
predictions: &Array<F, IxDyn>,
targets: &Array<F, IxDyn>, _loss: Option<F>,
let preds = predictions.clone();
let targets = targets.clone();
// Rank > 2: flatten everything after the first axis into a single column axis.
let preds_2d = if preds.ndim() > 2 {
let batch_size = preds.shape()[0];
// NOTE(review): integer division panics when `batch_size` is 0 (empty batch).
let total_classes = preds.len() / batch_size;
preds
.into_shape_with_order(IxDyn(&[batch_size, total_classes]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
// NOTE(review): 1-D predictions become a single column, so the argmax below is
// always 0 — binary probability outputs are never thresholded here.
} else if preds.ndim() == 1 {
.clone()
.into_shape_with_order(IxDyn(&[preds.len(), 1]))
preds.into_dimensionality::<Ix2>().expect("Operation failed")
};
// Targets get the same rank normalisation as predictions.
let targets_2d = if targets.ndim() > 2 {
let batch_size = targets.shape()[0];
let total_classes = targets.len() / batch_size;
targets
} else if targets.ndim() == 1 {
.into_shape_with_order(IxDyn(&[targets.len(), 1]))
targets.into_dimensionality::<Ix2>().expect("Operation failed")
// Row-wise argmax; strict `>` means the first of several equal maxima wins.
let pred_classes = preds_2d.map_axis(Axis(1), |row| {
let mut max_idx = 0;
// NOTE(review): `row[0]` panics if the rows have zero columns.
let mut max_val = row[0];
for (i, &val) in row.iter().enumerate().skip(1) {
if val > max_val {
max_idx = i;
max_val = val;
}
}
F::from(max_idx).expect("Failed to convert to float")
});
// Multi-column targets are reduced by argmax; a single column is used directly
// as the class index.
let target_classes = if targets_2d.shape()[1] > 1 {
targets_2d.map_axis(Axis(1), |row| {
let mut max_idx = 0;
let mut max_val = row[0];
for (i, &val) in row.iter().enumerate().skip(1) {
if val > max_val {
max_idx = i;
max_val = val;
}
F::from(max_idx).expect("Failed to convert to float")
})
targets_2d.index_axis(Axis(1), 0).to_owned()
// Class indices are floats, so equality is tested with a 1e-6 tolerance.
for (pred, target) in pred_classes.iter().zip(target_classes.iter()) {
if (*pred - *target).abs() < F::from(1e-6).expect("Failed to convert constant to float") {
self.correct += 1;
// Every prediction row counts toward the denominator.
self.total += pred_classes.len();
// reset: clears both counters.
self.correct = 0;
self.total = 0;
// result: correct / total, computed in f64 and converted back to F.
if self.total > 0 {
F::from(self.correct as f64 / self.total as f64).expect("Failed to convert to float")
"accuracy"
/// Precision = true positives / (true positives + false positives), accumulated
/// across batches.
pub struct PrecisionMetric<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync>
// True positives since the last reset.
tp: usize,
// False positives since the last reset.
fp: usize,
// Score at or above which a single-output prediction counts as positive.
threshold: F,
for PrecisionMetric<F>
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> PrecisionMetric<F> {
tp: 0,
fp: 0,
// Default decision threshold is 0.5.
threshold: F::from(0.5).expect("Failed to convert constant to float"),
/// Creates a metric that uses `threshold` for single-output (binary) predictions.
pub fn with_threshold(threshold: F) -> Self {
threshold,
// Binary path: one score per sample (1-D input or trailing axis of size 1).
// NOTE(review): `predictions.ndim() - 1` is evaluated first and underflows (panics)
// for a 0-d array; testing `ndim() == 1` first would not help for ndim 0 either.
if predictions.shape()[predictions.ndim() - 1] == 1 || predictions.ndim() == 1 {
let preds = predictions
.unwrap_or_else(|_| {
predictions
.clone()
.into_shape_with_order(IxDyn(&[predictions.len(), 1]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
});
let targets = targets
targets
.into_shape_with_order(IxDyn(&[targets.len(), 1]))
for (pred, target) in preds.iter().zip(targets.iter()) {
let pred_class = if *pred >= self.threshold { 1 } else { 0 };
// NOTE(review): targets are binarised at a fixed 0.5, independent of `threshold`.
let target_class = if *target >= F::from(0.5).expect("Failed to convert constant to float") {
1
} else {
0
};
if pred_class == 1 && target_class == 1 {
self.tp += 1;
} else if pred_class == 1 && target_class == 0 {
self.fp += 1;
// Multi-class path: reshape to 2-D and compare argmax class indices.
let preds = predictions.clone();
let targets = targets.clone();
let preds_2d = if preds.ndim() > 2 {
let batch_size = preds.shape()[0];
let total_classes = preds.len() / batch_size;
preds
.into_shape_with_order(IxDyn(&[batch_size, total_classes]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
} else {
preds.into_dimensionality::<Ix2>().expect("Operation failed")
};
let targets_2d = if targets.ndim() > 2 {
let batch_size = targets.shape()[0];
let total_classes = targets.len() / batch_size;
targets
targets.into_dimensionality::<Ix2>().expect("Operation failed")
let pred_classes = preds_2d.map_axis(Axis(1), |row| {
max_idx
});
let target_classes = if targets_2d.shape()[1] > 1 {
targets_2d.map_axis(Axis(1), |row| {
let mut max_idx = 0;
let mut max_val = row[0];
for (i, &val) in row.iter().enumerate().skip(1) {
if val > max_val {
max_idx = i;
max_val = val;
}
max_idx
})
targets_2d
.index_axis(Axis(1), 0)
// NOTE(review): negative or NaN targets silently become class 0 here.
.mapv(|x| x.to_usize().unwrap_or(0))
// One-vs-rest counts are summed over all classes into a single tp/fp pair
// (micro-averaging). Each row predicts exactly one class, so as written the
// multi-class result equals accuracy. Two arrays are allocated per class.
let num_classes = preds_2d.shape()[1];
for c in 0..num_classes {
let class_preds = pred_classes.mapv(|x| if x == c { 1 } else { 0 });
let class_targets = target_classes.mapv(|x| if x == c { 1 } else { 0 });
for (pred, target) in class_preds.iter().zip(class_targets.iter()) {
if *pred == 1 && *target == 1 {
self.tp += 1;
} else if *pred == 1 && *target == 0 {
self.fp += 1;
// reset: clears the counters.
self.tp = 0;
self.fp = 0;
// result: tp / (tp + fp), or zero when nothing was predicted positive.
if self.tp + self.fp > 0 {
F::from(self.tp as f64 / (self.tp + self.fp) as f64).expect("Operation failed")
"precision"
/// Recall = true positives / (true positives + false negatives), accumulated
/// across batches.
// NOTE(review): most of this item is missing from this copy; it appears to mirror
// PrecisionMetric's update logic with `fn_` counted instead of `fp` — confirm.
pub struct RecallMetric<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> {
// False negatives since the last reset (`fn` is a keyword, hence the underscore).
fn_: usize,
for RecallMetric<F>
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> RecallMetric<F> {
fn_: 0,
// Binary path: a missed positive.
} else if pred_class == 0 && target_class == 1 {
self.fn_ += 1;
// Multi-class one-vs-rest path: a missed positive for class `c`.
} else if *pred == 0 && *target == 1 {
self.fn_ += 1;
self.fn_ = 0;
// result: tp / (tp + fn_), or zero when no positives were seen.
if self.tp + self.fn_ > 0 {
F::from(self.tp as f64 / (self.tp + self.fn_) as f64).expect("Operation failed")
"recall"
/// F1 score: harmonic mean of an internal precision metric and recall metric.
pub struct F1ScoreMetric<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> {
precision: PrecisionMetric<F>,
recall: RecallMetric<F>,
for F1ScoreMetric<F>
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> F1ScoreMetric<F> {
// Default construction: both sub-metrics use their default threshold.
precision: PrecisionMetric::new(),
recall: RecallMetric::new(),
// Threshold construction: the same threshold is forwarded to both sub-metrics.
precision: PrecisionMetric::with_threshold(threshold),
recall: RecallMetric::with_threshold(threshold),
// update: forwards the batch to both sub-metrics. Each one repeats the same
// clone/reshape/argmax work independently.
self.precision.update(predictions, targets, None);
self.recall.update(predictions, targets, None);
// reset: resets both sub-metrics.
self.precision.reset();
self.recall.reset();
// result: 2PR / (P + R) over the aggregated counts, or zero when P + R is zero.
let precision = self.precision.result();
let recall = self.recall.result();
if precision + recall > F::zero() {
let two = F::from(2.0).expect("Failed to convert constant to float");
(two * precision * recall) / (precision + recall)
"f1_score"
/// Mean squared error over every element seen since the last reset.
pub struct MeanSquaredErrorMetric<
F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync,
> {
// Running sum of (prediction - target)^2.
sum_squared_error: F,
// Number of prediction elements seen.
count: usize,
for MeanSquaredErrorMetric<F>
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync>
MeanSquaredErrorMetric<F>
sum_squared_error: F::zero(),
count: 0,
// update: flattens both arrays to 1-D and accumulates element-wise squared error.
let preds_flat = predictions
.clone()
.into_shape_with_order(IxDyn(&[predictions.len()]))
.expect("Operation failed")
.into_dimensionality::<Ix1>()
.expect("Operation failed");
let targets_flat = targets
.into_shape_with_order(IxDyn(&[targets.len()]))
// `zip` stops at the shorter array.
for (pred, target) in preds_flat.iter().zip(targets_flat.iter()) {
let error = *pred - *target;
self.sum_squared_error = self.sum_squared_error + error * error;
// NOTE(review): the count uses the prediction length even though `zip` may have
// summed fewer terms; mismatched lengths would bias the mean downward.
self.count += preds_flat.len();
// reset: clears the sum and count.
self.sum_squared_error = F::zero();
self.count = 0;
// result: sum / count, or zero when nothing has been seen.
if self.count > 0 {
self.sum_squared_error / F::from(self.count).expect("Failed to convert to float")
"mean_squared_error"
/// Mean absolute error over every element seen since the last reset.
// NOTE(review): most of this item is missing from this copy; it appears to mirror
// MeanSquaredErrorMetric with |error| instead of error^2 — confirm.
pub struct MeanAbsoluteErrorMetric<
// Running sum of |prediction - target|.
sum_absolute_error: F,
for MeanAbsoluteErrorMetric<F>
MeanAbsoluteErrorMetric<F>
sum_absolute_error: F::zero(),
let error = (*pred - *target).abs();
self.sum_absolute_error = self.sum_absolute_error + error;
self.sum_absolute_error = F::zero();
// result: sum / count.
self.sum_absolute_error / F::from(self.count).expect("Failed to convert to float")
"mean_absolute_error"
/// Coefficient of determination, R^2 = 1 - SSE / SST.
pub struct RSquaredMetric<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync>
// Running sum of (target - mean)^2.
sum_squared_total: F,
// Target mean used for SST; fixed after the first update.
mean: F,
// True until the first batch has set `mean`.
first_update: bool,
for RSquaredMetric<F>
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> RSquaredMetric<F> {
sum_squared_total: F::zero(),
mean: F::zero(),
first_update: true,
// NOTE(review): the mean is computed from the first batch only, so SST for later
// batches is measured against that batch's mean and R^2 is an approximation.
if self.first_update {
let mut sum = F::zero();
for &target in targets_flat.iter() {
sum = sum + target;
// NOTE(review): an empty first batch makes this 0/0 (NaN), and because
// `first_update` is cleared below, the NaN mean persists until reset.
self.mean = sum / F::from(targets_flat.len()).expect("Operation failed");
self.first_update = false;
let diff_from_mean = *target - self.mean;
self.sum_squared_total = self.sum_squared_total + diff_from_mean * diff_from_mean;
// reset: clears SST, the mean, and re-arms the first-batch mean computation.
self.sum_squared_total = F::zero();
self.mean = F::zero();
self.first_update = true;
// result: 1 - SSE/SST when data has been seen and SST is positive.
if self.count > 0 && self.sum_squared_total > F::zero() {
F::one() - (self.sum_squared_error / self.sum_squared_total)
"r_squared"
/// Area under the ROC curve, computed from all scores and labels buffered since
/// the last reset.
pub struct AUCMetric<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> {
// Positive-class scores, one per sample, in arrival order.
scores: Vec<F>,
// Labels aligned with `scores`; a value > 0 is treated as positive.
labels: Vec<F>,
for AUCMetric<F>
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> AUCMetric<F> {
// New metric starts with empty buffers.
scores: Vec::new(),
labels: Vec::new(),
fn compute_auc(&self) -> F {
if self.scores.is_empty() || self.labels.is_empty() {
return F::zero();
let mut pairs: Vec<(F, F)> = self
.scores
.iter()
.cloned()
.zip(self.labels.iter().cloned())
.collect();
pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Operation failed"));
let num_pos = self.labels.iter().filter(|&&l| l > F::zero()).count();
let num_neg = self.labels.len() - num_pos;
if num_pos == 0 || num_neg == 0 {
let mut sum_ranks = F::zero();
let mut pos_count = 0;
for (i, (_, label)) in pairs.iter().enumerate() {
if *label > F::zero() {
sum_ranks = sum_ranks + F::from(i + 1).expect("Failed to convert to float");
pos_count += 1;
let pos_count = F::from(pos_count).expect("Failed to convert to float");
let num_pos = F::from(num_pos).expect("Failed to convert to float");
let num_neg = F::from(num_neg).expect("Failed to convert to float");
(sum_ranks - (pos_count * (pos_count + F::one())) / F::from(2.0).expect("Failed to convert constant to float"))
/ (num_pos * num_neg)
// update: extracts one positive-class score per sample. A two-column input
// contributes column 1; a single-column or 1-D input is used as-is.
let preds = if predictions.ndim() == 2 && predictions.shape()[1] == 2 {
let mut probs = Vec::with_capacity(predictions.shape()[0]);
for i in 0..predictions.shape()[0] {
probs.push(predictions[[i, 1]]);
probs
} else if (predictions.ndim() == 2 && predictions.shape()[1] == 1)
|| predictions.ndim() == 1
{
predictions.iter().cloned().collect()
// Any other prediction shape is ignored: the batch is not recorded.
return;
// Labels use the same shape rules as the scores (column 1 for two columns).
let labels = if targets.ndim() == 2 && targets.shape()[1] == 2 {
let mut labs = Vec::with_capacity(targets.shape()[0]);
for i in 0..targets.shape()[0] {
labs.push(targets[[i, 1]]);
labs
} else if (targets.ndim() == 2 && targets.shape()[1] == 1) || targets.ndim() == 1 {
targets.iter().cloned().collect()
// Buffers grow with every sample until `reset`; memory is O(samples seen).
self.scores.extend(preds);
self.labels.extend(labels);
// reset: drops all buffered scores and labels.
self.scores.clear();
self.labels.clear();
// result: AUC over everything buffered so far.
self.compute_auc()
"auc"