use super::Metric;
use scirs2_core::ndarray::{Array, Axis, Ix1, Ix2, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
/// Running mean of the scalar loss values handed to `Metric::update`.
///
/// Every update that carries `Some(loss)` counts as one batch; the result is the
/// arithmetic mean of those losses, or zero before any loss has been recorded.
#[derive(Debug, Clone)]
pub struct LossMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Sum of all losses recorded since the last reset.
    total_loss: F,
    /// Number of updates that supplied a loss value.
    num_batches: usize,
}
impl<F> Default for LossMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    fn default() -> Self {
        Self {
            total_loss: F::zero(),
            num_batches: 0,
        }
    }
}
impl<F> LossMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Creates an accumulator with no recorded losses.
    pub fn new() -> Self {
        Self::default()
    }
}
impl<F> Metric<F> for LossMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Records `loss` as one more batch; predictions and targets are not inspected,
    /// and an update without a loss value leaves the state untouched.
    fn update(
        &mut self,
        _predictions: &Array<F, IxDyn>,
        _targets: &Array<F, IxDyn>,
        loss: Option<F>,
    ) {
        let Some(batch_loss) = loss else {
            return;
        };
        self.total_loss += batch_loss;
        self.num_batches += 1;
    }
    /// Forgets every recorded loss.
    fn reset(&mut self) {
        self.num_batches = 0;
        self.total_loss = F::zero();
    }
    /// Mean loss per recorded batch, or zero when nothing has been recorded.
    fn result(&self) -> F {
        match self.num_batches {
            0 => F::zero(),
            batches => self.total_loss / F::from(batches).expect("Failed to convert to float"),
        }
    }
    fn name(&self) -> &str {
        "loss"
    }
}
/// Classification accuracy: the fraction of evaluated samples whose predicted
/// class matches the target class, accumulated across batches.
#[derive(Debug, Clone)]
pub struct AccuracyMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Number of samples classified correctly since the last reset.
    correct: usize,
    /// Number of samples evaluated since the last reset.
    total: usize,
    /// Binds the metric to the float type `F` without storing a value of it.
    _phantom: PhantomData<F>,
}
impl<F> Default for AccuracyMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    fn default() -> Self {
        Self {
            correct: 0,
            total: 0,
            _phantom: PhantomData,
        }
    }
}
impl<F> AccuracyMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Creates an accuracy counter with no samples evaluated.
    pub fn new() -> Self {
        Self::default()
    }
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync> Metric<F>
for AccuracyMetric<F>
{
fn update(
&mut self,
predictions: &Array<F, IxDyn>,
targets: &Array<F, IxDyn>,
_loss: Option<F>,
) {
let preds = predictions.clone();
let targets = targets.clone();
let preds_2d: Array<F, Ix2> = if preds.ndim() > 2 {
let batch_size = preds.shape()[0];
let total_classes = preds.len() / batch_size;
preds
.into_shape_with_order(IxDyn(&[batch_size, total_classes]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed")
} else if preds.ndim() == 1 {
let n = preds.len();
preds
.into_shape_with_order(IxDyn(&[n, 1]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed")
} else {
preds
.into_dimensionality::<Ix2>()
.expect("Operation failed")
};
let targets_2d: Array<F, Ix2> = if targets.ndim() > 2 {
let batch_size = targets.shape()[0];
let total_classes = targets.len() / batch_size;
targets
.into_shape_with_order(IxDyn(&[batch_size, total_classes]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed")
} else if targets.ndim() == 1 {
let n = targets.len();
targets
.into_shape_with_order(IxDyn(&[n, 1]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed")
} else {
targets
.into_dimensionality::<Ix2>()
.expect("Operation failed")
};
let pred_classes = preds_2d.map_axis(Axis(1), |row| {
let mut max_idx = 0;
let mut max_val = row[0];
for (i, &val) in row.iter().enumerate().skip(1) {
if val > max_val {
max_idx = i;
max_val = val;
}
}
F::from(max_idx).expect("Failed to convert to float")
});
let target_classes = if targets_2d.shape()[1] > 1 {
targets_2d.map_axis(Axis(1), |row| {
let mut max_idx = 0;
let mut max_val = row[0];
for (i, &val) in row.iter().enumerate().skip(1) {
if val > max_val {
max_idx = i;
max_val = val;
}
}
F::from(max_idx).expect("Failed to convert to float")
})
} else {
targets_2d.index_axis(Axis(1), 0).to_owned()
};
for (pred, target) in pred_classes.iter().zip(target_classes.iter()) {
if (*pred - *target).abs() < F::from(1e-6).expect("Failed to convert constant to float")
{
self.correct += 1;
}
}
self.total += pred_classes.len();
}
fn reset(&mut self) {
self.correct = 0;
self.total = 0;
}
fn result(&self) -> F {
if self.total > 0 {
F::from(self.correct as f64 / self.total as f64).expect("Failed to convert to float")
} else {
F::zero()
}
}
fn name(&self) -> &str {
"accuracy"
}
}
/// Precision, `TP / (TP + FP)`, accumulated across batches.
#[derive(Debug, Clone)]
pub struct PrecisionMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// True positives counted since the last reset.
    tp: usize,
    /// False positives counted since the last reset.
    fp: usize,
    /// Score at or above which a binary prediction is treated as positive.
    threshold: F,
}
impl<F> Default for PrecisionMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    fn default() -> Self {
        Self::new()
    }
}
impl<F> PrecisionMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Creates a precision metric with the default decision threshold of `0.5`.
    pub fn new() -> Self {
        let default_threshold = F::from(0.5).expect("Failed to convert constant to float");
        Self::with_threshold(default_threshold)
    }
    /// Creates a precision metric with a custom decision threshold for binary scores.
    pub fn with_threshold(threshold: F) -> Self {
        Self {
            threshold,
            tp: 0,
            fp: 0,
        }
    }
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync> Metric<F>
for PrecisionMetric<F>
{
fn update(
&mut self,
predictions: &Array<F, IxDyn>,
targets: &Array<F, IxDyn>,
_loss: Option<F>,
) {
if predictions.shape()[predictions.ndim() - 1] == 1 || predictions.ndim() == 1 {
let n = predictions.len();
let preds = predictions
.clone()
.into_shape_with_order(IxDyn(&[n, 1]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed");
let m = targets.len();
let tgts = targets
.clone()
.into_shape_with_order(IxDyn(&[m, 1]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed");
for (pred, target) in preds.iter().zip(tgts.iter()) {
let pred_class = if *pred >= self.threshold { 1 } else { 0 };
let target_class =
if *target >= F::from(0.5).expect("Failed to convert constant to float") {
1
} else {
0
};
if pred_class == 1 && target_class == 1 {
self.tp += 1;
} else if pred_class == 1 && target_class == 0 {
self.fp += 1;
}
}
} else {
let preds = predictions.clone();
let targets = targets.clone();
let preds_2d: Array<F, Ix2> = if preds.ndim() > 2 {
let batch_size = preds.shape()[0];
let total_classes = preds.len() / batch_size;
preds
.into_shape_with_order(IxDyn(&[batch_size, total_classes]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed")
} else {
preds
.into_dimensionality::<Ix2>()
.expect("Operation failed")
};
let targets_2d: Array<F, Ix2> = if targets.ndim() > 2 {
let batch_size = targets.shape()[0];
let total_classes = targets.len() / batch_size;
targets
.into_shape_with_order(IxDyn(&[batch_size, total_classes]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed")
} else {
targets
.into_dimensionality::<Ix2>()
.expect("Operation failed")
};
let pred_classes = preds_2d.map_axis(Axis(1), |row| {
let mut max_idx = 0usize;
let mut max_val = row[0];
for (i, &val) in row.iter().enumerate().skip(1) {
if val > max_val {
max_idx = i;
max_val = val;
}
}
max_idx
});
let target_classes: Array<usize, _> = if targets_2d.shape()[1] > 1 {
targets_2d.map_axis(Axis(1), |row| {
let mut max_idx = 0usize;
let mut max_val = row[0];
for (i, &val) in row.iter().enumerate().skip(1) {
if val > max_val {
max_idx = i;
max_val = val;
}
}
max_idx
})
} else {
targets_2d
.index_axis(Axis(1), 0)
.mapv(|x| x.to_usize().unwrap_or(0))
};
let num_classes = preds_2d.shape()[1];
for c in 0..num_classes {
let class_preds = pred_classes.mapv(|x| if x == c { 1usize } else { 0 });
let class_targets = target_classes.mapv(|x| if x == c { 1usize } else { 0 });
for (pred, target) in class_preds.iter().zip(class_targets.iter()) {
if *pred == 1 && *target == 1 {
self.tp += 1;
} else if *pred == 1 && *target == 0 {
self.fp += 1;
}
}
}
}
}
fn reset(&mut self) {
self.tp = 0;
self.fp = 0;
}
fn result(&self) -> F {
if self.tp + self.fp > 0 {
F::from(self.tp as f64 / (self.tp + self.fp) as f64).expect("Operation failed")
} else {
F::zero()
}
}
fn name(&self) -> &str {
"precision"
}
}
/// Recall, `TP / (TP + FN)`, accumulated across batches.
#[derive(Debug, Clone)]
pub struct RecallMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// True positives counted since the last reset.
    tp: usize,
    /// False negatives counted since the last reset (`fn` is a keyword).
    fn_: usize,
    /// Score at or above which a binary prediction is treated as positive.
    threshold: F,
}
impl<F> Default for RecallMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    fn default() -> Self {
        Self::new()
    }
}
impl<F> RecallMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Creates a recall metric with the default decision threshold of `0.5`.
    pub fn new() -> Self {
        let default_threshold = F::from(0.5).expect("Failed to convert constant to float");
        Self::with_threshold(default_threshold)
    }
    /// Creates a recall metric with a custom decision threshold for binary scores.
    pub fn with_threshold(threshold: F) -> Self {
        Self {
            threshold,
            tp: 0,
            fn_: 0,
        }
    }
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync> Metric<F>
for RecallMetric<F>
{
fn update(
&mut self,
predictions: &Array<F, IxDyn>,
targets: &Array<F, IxDyn>,
_loss: Option<F>,
) {
if predictions.shape()[predictions.ndim() - 1] == 1 || predictions.ndim() == 1 {
let n = predictions.len();
let preds = predictions
.clone()
.into_shape_with_order(IxDyn(&[n, 1]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed");
let m = targets.len();
let tgts = targets
.clone()
.into_shape_with_order(IxDyn(&[m, 1]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed");
for (pred, target) in preds.iter().zip(tgts.iter()) {
let pred_class = if *pred >= self.threshold { 1 } else { 0 };
let target_class =
if *target >= F::from(0.5).expect("Failed to convert constant to float") {
1
} else {
0
};
if pred_class == 1 && target_class == 1 {
self.tp += 1;
} else if pred_class == 0 && target_class == 1 {
self.fn_ += 1;
}
}
} else {
let preds = predictions.clone();
let targets_arr = targets.clone();
let preds_2d: Array<F, Ix2> = if preds.ndim() > 2 {
let batch_size = preds.shape()[0];
let total_classes = preds.len() / batch_size;
preds
.into_shape_with_order(IxDyn(&[batch_size, total_classes]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed")
} else {
preds
.into_dimensionality::<Ix2>()
.expect("Operation failed")
};
let targets_2d: Array<F, Ix2> = if targets_arr.ndim() > 2 {
let batch_size = targets_arr.shape()[0];
let total_classes = targets_arr.len() / batch_size;
targets_arr
.into_shape_with_order(IxDyn(&[batch_size, total_classes]))
.expect("Operation failed")
.into_dimensionality::<Ix2>()
.expect("Operation failed")
} else {
targets_arr
.into_dimensionality::<Ix2>()
.expect("Operation failed")
};
let pred_classes = preds_2d.map_axis(Axis(1), |row| {
let mut max_idx = 0usize;
let mut max_val = row[0];
for (i, &val) in row.iter().enumerate().skip(1) {
if val > max_val {
max_idx = i;
max_val = val;
}
}
max_idx
});
let target_classes: Array<usize, _> = if targets_2d.shape()[1] > 1 {
targets_2d.map_axis(Axis(1), |row| {
let mut max_idx = 0usize;
let mut max_val = row[0];
for (i, &val) in row.iter().enumerate().skip(1) {
if val > max_val {
max_idx = i;
max_val = val;
}
}
max_idx
})
} else {
targets_2d
.index_axis(Axis(1), 0)
.mapv(|x| x.to_usize().unwrap_or(0))
};
let num_classes = preds_2d.shape()[1];
for c in 0..num_classes {
let class_preds = pred_classes.mapv(|x| if x == c { 1usize } else { 0 });
let class_targets = target_classes.mapv(|x| if x == c { 1usize } else { 0 });
for (pred, target) in class_preds.iter().zip(class_targets.iter()) {
if *pred == 1 && *target == 1 {
self.tp += 1;
} else if *pred == 0 && *target == 1 {
self.fn_ += 1;
}
}
}
}
}
fn reset(&mut self) {
self.tp = 0;
self.fn_ = 0;
}
fn result(&self) -> F {
if self.tp + self.fn_ > 0 {
F::from(self.tp as f64 / (self.tp + self.fn_) as f64).expect("Operation failed")
} else {
F::zero()
}
}
fn name(&self) -> &str {
"recall"
}
}
/// F1 score: the harmonic mean of a [`PrecisionMetric`] and a [`RecallMetric`]
/// that are fed the same batches.
#[derive(Debug, Clone)]
pub struct F1ScoreMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Precision accumulator.
    precision: PrecisionMetric<F>,
    /// Recall accumulator.
    recall: RecallMetric<F>,
}
impl<F> Default for F1ScoreMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    fn default() -> Self {
        Self::new()
    }
}
impl<F> F1ScoreMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Creates an F1 metric whose components use their default thresholds.
    pub fn new() -> Self {
        let precision = PrecisionMetric::new();
        let recall = RecallMetric::new();
        Self { precision, recall }
    }
    /// Creates an F1 metric whose precision and recall share `threshold`.
    pub fn with_threshold(threshold: F) -> Self {
        let precision = PrecisionMetric::with_threshold(threshold);
        let recall = RecallMetric::with_threshold(threshold);
        Self { precision, recall }
    }
}
impl<F> Metric<F> for F1ScoreMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Feeds the batch to both component metrics; the loss value is ignored.
    fn update(
        &mut self,
        predictions: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
        _loss: Option<F>,
    ) {
        let Self { precision, recall } = self;
        precision.update(predictions, targets, None);
        recall.update(predictions, targets, None);
    }
    /// Resets both component metrics.
    fn reset(&mut self) {
        let Self { precision, recall } = self;
        precision.reset();
        recall.reset();
    }
    /// `2 * P * R / (P + R)`, or zero unless `P + R` is strictly positive.
    fn result(&self) -> F {
        let p = self.precision.result();
        let r = self.recall.result();
        let sum = p + r;
        if !(sum > F::zero()) {
            return F::zero();
        }
        let two = F::from(2.0).expect("Failed to convert constant to float");
        two * p * r / sum
    }
    fn name(&self) -> &str {
        "f1_score"
    }
}
/// Mean squared error between predictions and targets, accumulated across batches.
#[derive(Debug, Clone)]
pub struct MeanSquaredErrorMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Running sum of squared differences.
    sum_squared_error: F,
    /// Number of elements accumulated so far.
    count: usize,
}
impl<F> Default for MeanSquaredErrorMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    fn default() -> Self {
        Self {
            sum_squared_error: F::zero(),
            count: 0,
        }
    }
}
impl<F> MeanSquaredErrorMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Creates an empty accumulator.
    pub fn new() -> Self {
        Self::default()
    }
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync> Metric<F>
    for MeanSquaredErrorMetric<F>
{
    /// Adds the squared error of every prediction/target pair in the batch.
    ///
    /// Both arrays are traversed in logical (row-major) order and paired
    /// element-wise. This works for any memory layout; there is no reshape that
    /// could fail. If the sizes differ, the surplus elements of the longer array
    /// are ignored and are not counted in the mean.
    fn update(
        &mut self,
        predictions: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
        _loss: Option<F>,
    ) {
        for (&pred, &target) in predictions.iter().zip(targets.iter()) {
            let error = pred - target;
            self.sum_squared_error += error * error;
            // Count only elements that actually contributed to the sum.
            self.count += 1;
        }
    }
    fn reset(&mut self) {
        self.sum_squared_error = F::zero();
        self.count = 0;
    }
    /// Mean of the accumulated squared errors, or zero before any data.
    fn result(&self) -> F {
        if self.count > 0 {
            self.sum_squared_error / F::from(self.count).expect("Failed to convert to float")
        } else {
            F::zero()
        }
    }
    fn name(&self) -> &str {
        "mean_squared_error"
    }
}
/// Mean absolute error between predictions and targets, accumulated across batches.
#[derive(Debug, Clone)]
pub struct MeanAbsoluteErrorMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Running sum of absolute differences.
    sum_absolute_error: F,
    /// Number of elements accumulated so far.
    count: usize,
}
impl<F> Default for MeanAbsoluteErrorMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    fn default() -> Self {
        Self {
            sum_absolute_error: F::zero(),
            count: 0,
        }
    }
}
impl<F> MeanAbsoluteErrorMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Creates an empty accumulator.
    pub fn new() -> Self {
        Self::default()
    }
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync> Metric<F>
    for MeanAbsoluteErrorMetric<F>
{
    /// Adds the absolute error of every prediction/target pair in the batch.
    ///
    /// Both arrays are traversed in logical (row-major) order and paired
    /// element-wise. This works for any memory layout; there is no reshape that
    /// could fail. If the sizes differ, the surplus elements of the longer array
    /// are ignored and are not counted in the mean.
    fn update(
        &mut self,
        predictions: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
        _loss: Option<F>,
    ) {
        for (&pred, &target) in predictions.iter().zip(targets.iter()) {
            self.sum_absolute_error += (pred - target).abs();
            // Count only elements that actually contributed to the sum.
            self.count += 1;
        }
    }
    fn reset(&mut self) {
        self.sum_absolute_error = F::zero();
        self.count = 0;
    }
    /// Mean of the accumulated absolute errors, or zero before any data.
    fn result(&self) -> F {
        if self.count > 0 {
            self.sum_absolute_error / F::from(self.count).expect("Failed to convert to float")
        } else {
            F::zero()
        }
    }
    fn name(&self) -> &str {
        "mean_absolute_error"
    }
}
/// Coefficient of determination, `R^2 = 1 - SSE / SST`, accumulated across batches.
///
/// `SSE` is the sum of squared residuals `(prediction - target)^2`. `SST` is the
/// sum of squared deviations of every target seen so far from the mean of all
/// those targets. The mean and `SST` are maintained incrementally with Welford's
/// algorithm. The result therefore equals a single pass over the concatenation
/// of all batches, rather than being anchored to the first batch's mean.
#[derive(Debug, Clone)]
pub struct RSquaredMetric<
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
> {
    /// Sum of squared residuals.
    sum_squared_error: F,
    /// Sum of squared deviations of the targets from `mean` (Welford's `M2`).
    sum_squared_total: F,
    /// Number of prediction/target pairs seen.
    count: usize,
    /// Running mean of all targets seen.
    mean: F,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync> Default
    for RSquaredMetric<F>
{
    fn default() -> Self {
        Self::new()
    }
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync>
    RSquaredMetric<F>
{
    /// Creates an empty accumulator.
    pub fn new() -> Self {
        Self {
            sum_squared_error: F::zero(),
            sum_squared_total: F::zero(),
            count: 0,
            mean: F::zero(),
        }
    }
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync> Metric<F>
    for RSquaredMetric<F>
{
    /// Folds every prediction/target pair of the batch into SSE, the running
    /// target mean and SST.
    ///
    /// Arrays are paired element-wise in logical (row-major) order; surplus
    /// elements of the longer array are ignored. Empty batches leave the state
    /// unchanged.
    fn update(
        &mut self,
        predictions: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
        _loss: Option<F>,
    ) {
        for (&pred, &target) in predictions.iter().zip(targets.iter()) {
            let residual = pred - target;
            self.sum_squared_error += residual * residual;

            // Welford: update the running mean, then M2 using the deltas to the old
            // and new mean. M2 is exactly SST about the current global mean.
            self.count += 1;
            let n = F::from(self.count).expect("Failed to convert to float");
            let delta = target - self.mean;
            self.mean += delta / n;
            self.sum_squared_total += delta * (target - self.mean);
        }
    }
    fn reset(&mut self) {
        self.sum_squared_error = F::zero();
        self.sum_squared_total = F::zero();
        self.count = 0;
        self.mean = F::zero();
    }
    /// `1 - SSE / SST`, or zero when there is no data or the targets are constant.
    fn result(&self) -> F {
        if self.count > 0 && self.sum_squared_total > F::zero() {
            F::one() - (self.sum_squared_error / self.sum_squared_total)
        } else {
            F::zero()
        }
    }
    fn name(&self) -> &str {
        "r_squared"
    }
}
/// Area under the ROC curve for binary classification, computed from every
/// score and label recorded since the last reset.
#[derive(Debug, Clone)]
pub struct AUCMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    /// Recorded positive-class scores.
    scores: Vec<F>,
    /// Recorded labels, paired index-by-index with `scores` when the AUC is computed.
    labels: Vec<F>,
}
impl<F> Default for AUCMetric<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync,
{
    fn default() -> Self {
        Self {
            scores: Vec::new(),
            labels: Vec::new(),
        }
    }
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync>
    AUCMetric<F>
{
    /// Creates an accumulator with no recorded scores or labels.
    pub fn new() -> Self {
        Self {
            scores: Vec::new(),
            labels: Vec::new(),
        }
    }

    /// ROC AUC via the Mann-Whitney U statistic:
    /// `AUC = (R_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)`,
    /// where `R_pos` is the sum of the ascending ranks of the positive samples.
    ///
    /// * Scores and labels are paired index-by-index; unpaired entries are ignored.
    /// * A label greater than zero marks a positive sample.
    /// * Tied scores receive the average of their ranks, so ties contribute `0.5`
    ///   per (positive, negative) pair.
    /// * NaN scores cannot be ranked and are skipped.
    ///
    /// Returns zero when either class is absent.
    fn compute_auc(&self) -> F {
        let mut pairs: Vec<(F, bool)> = self
            .scores
            .iter()
            .zip(self.labels.iter())
            .filter(|(score, _)| !score.is_nan())
            .map(|(&score, &label)| (score, label > F::zero()))
            .collect();
        let num_pos = pairs.iter().filter(|(_, positive)| *positive).count();
        let num_neg = pairs.len() - num_pos;
        if num_pos == 0 || num_neg == 0 {
            return F::zero();
        }
        // Ascending order: rank 1 is the lowest score. NaNs were removed above,
        // so `partial_cmp` is total here and the fallback never triggers.
        pairs.sort_unstable_by(|a, b| {
            a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
        });

        let two = F::from(2.0).expect("Failed to convert constant to float");
        let mut pos_rank_sum = F::zero();
        let mut start = 0;
        while start < pairs.len() {
            let mut end = start + 1;
            while end < pairs.len() && pairs[end].0 == pairs[start].0 {
                end += 1;
            }
            // Positions start..end hold ranks start+1 ..= end; share their average.
            let avg_rank = F::from(start + 1 + end).expect("Failed to convert to float") / two;
            let tied_pos = pairs[start..end].iter().filter(|(_, positive)| *positive).count();
            pos_rank_sum += avg_rank * F::from(tied_pos).expect("Failed to convert to float");
            start = end;
        }

        let num_pos_f = F::from(num_pos).expect("Failed to convert to float");
        let num_neg_f = F::from(num_neg).expect("Failed to convert to float");
        (pos_rank_sum - num_pos_f * (num_pos_f + F::one()) / two) / (num_pos_f * num_neg_f)
    }
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Display + Send + Sync> Metric<F>
for AUCMetric<F>
{
fn update(
&mut self,
predictions: &Array<F, IxDyn>,
targets: &Array<F, IxDyn>,
_loss: Option<F>,
) {
let preds = if predictions.ndim() == 2 && predictions.shape()[1] == 2 {
let mut probs = Vec::with_capacity(predictions.shape()[0]);
for i in 0..predictions.shape()[0] {
probs.push(predictions[[i, 1]]);
}
probs
} else if (predictions.ndim() == 2 && predictions.shape()[1] == 1)
|| predictions.ndim() == 1
{
predictions.iter().cloned().collect()
} else {
return;
};
let labels = if targets.ndim() == 2 && targets.shape()[1] == 2 {
let mut labs = Vec::with_capacity(targets.shape()[0]);
for i in 0..targets.shape()[0] {
labs.push(targets[[i, 1]]);
}
labs
} else if (targets.ndim() == 2 && targets.shape()[1] == 1) || targets.ndim() == 1 {
targets.iter().cloned().collect()
} else {
return;
};
self.scores.extend(preds);
self.labels.extend(labels);
}
fn reset(&mut self) {
self.scores.clear();
self.labels.clear();
}
fn result(&self) -> F {
self.compute_auc()
}
fn name(&self) -> &str {
"auc"
}
}