use scirs2_core::ndarray::{ArrayD, IxDyn, Zip};
use std::collections::HashMap;
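/// Errors returned by the tensor loss functions in this module.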
#[derive(Debug, Clone)]
pub enum TensorLossError {
ShapeMismatch {
expected: Vec<usize>,
got: Vec<usize>,
},
InvalidTarget(String),
DivisionByZero,
EmptyInput,
InvalidConfig(String),
}
impl std::fmt::Display for TensorLossError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ShapeMismatch { expected, got } => {
write!(f, "shape mismatch: expected {:?}, got {:?}", expected, got)
}
Self::InvalidTarget(msg) => write!(f, "invalid target: {}", msg),
Self::DivisionByZero => write!(f, "division by zero encountered"),
Self::EmptyInput => write!(f, "input tensor is empty"),
Self::InvalidConfig(msg) => write!(f, "invalid configuration: {}", msg),
}
}
}
impl std::error::Error for TensorLossError {}
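/// How element-wise losses are reduced to a scalar. `None` skips reduction
/// and returns the per-element loss tensor instead.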
#[derive(Debug, Clone, PartialEq)]
pub enum LossReduction {
Mean,
Sum,
None,
}
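/// Result of a loss computation. With `LossReduction::None` the scalar `loss`
/// is left at `0.0` and the per-element values are returned in `loss_tensor`;
/// otherwise `loss` holds the reduced value and `loss_tensor` is `None`.
/// `grad` is populated only when `TensorLossConfig::compute_grad` is set.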
#[derive(Debug, Clone)]
pub struct TensorLossOutput {
pub loss: f64,
pub loss_tensor: Option<ArrayD<f64>>,
pub grad: Option<ArrayD<f64>>,
}
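/// Common interface implemented by every loss in this module. An illustrative
/// sketch of calling it through the trait (assumes `scirs2_core` re-exports
/// `ndarray`'s `arr1`):
///
/// ```ignore
/// use scirs2_core::ndarray::arr1;
///
/// let pred = arr1(&[0.9, 0.1]).into_dyn();
/// let target = arr1(&[1.0, 0.0]).into_dyn();
/// let loss: &dyn TensorLoss = &TensorMseLoss::new();
/// let out = loss.compute(&pred, &target).unwrap();
/// println!("{} = {}", loss.name(), out.loss);
/// ```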
pub trait TensorLoss: std::fmt::Debug {
fn compute(
&self,
pred: &ArrayD<f64>,
target: &ArrayD<f64>,
) -> Result<TensorLossOutput, TensorLossError>;
fn name(&self) -> &'static str;
}
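/// Shared configuration: the reduction mode, whether to compute gradients,
/// and the epsilon used to stabilize logarithms and divisions.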
#[derive(Debug, Clone)]
pub struct TensorLossConfig {
pub reduction: LossReduction,
pub compute_grad: bool,
pub epsilon: f64,
}
impl Default for TensorLossConfig {
fn default() -> Self {
Self {
reduction: LossReduction::Mean,
compute_grad: true,
epsilon: 1e-8,
}
}
}
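/// Checks that both tensors are non-empty and share a shape, returning the
/// element count on success.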
fn validate_shapes(pred: &ArrayD<f64>, target: &ArrayD<f64>) -> Result<usize, TensorLossError> {
let n = pred.len();
if n == 0 {
return Err(TensorLossError::EmptyInput);
}
if pred.shape() != target.shape() {
return Err(TensorLossError::ShapeMismatch {
expected: pred.shape().to_vec(),
got: target.shape().to_vec(),
});
}
Ok(n)
}
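/// Applies the requested reduction to the per-element loss tensor. Note that
/// the gradient is passed through unreduced: only `TensorMseLoss` folds the
/// `Mean` scaling factor into its gradient.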
fn apply_reduction(
loss_elem: ArrayD<f64>,
grad_elem: Option<ArrayD<f64>>,
reduction: &LossReduction,
n: usize,
) -> TensorLossOutput {
match reduction {
LossReduction::None => TensorLossOutput {
loss: 0.0,
loss_tensor: Some(loss_elem),
grad: grad_elem,
},
LossReduction::Sum => {
let loss = loss_elem.sum();
TensorLossOutput {
loss,
loss_tensor: None,
grad: grad_elem,
}
}
LossReduction::Mean => {
let loss = loss_elem.sum() / n as f64;
TensorLossOutput {
loss,
loss_tensor: None,
grad: grad_elem,
}
}
}
}
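/// Mean squared error: `(pred - target)^2` per element, with gradient
/// `2 * (pred - target)` (scaled by `1/n` under `Mean` reduction). An
/// illustrative sketch, mirroring the unit tests below:
///
/// ```ignore
/// let out = TensorMseLoss::new()
///     .compute(&arr1(&[1.0, 2.0]).into_dyn(), &arr1(&[0.0, 0.0]).into_dyn())
///     .unwrap();
/// assert!((out.loss - 2.5).abs() < 1e-10); // mean of [1.0, 4.0]
/// ```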
#[derive(Debug, Clone)]
pub struct TensorMseLoss {
pub config: TensorLossConfig,
}
impl TensorMseLoss {
pub fn new() -> Self {
Self {
config: TensorLossConfig::default(),
}
}
pub fn with_config(config: TensorLossConfig) -> Self {
Self { config }
}
}
impl Default for TensorMseLoss {
fn default() -> Self {
Self::new()
}
}
impl TensorLoss for TensorMseLoss {
fn name(&self) -> &'static str {
"mse"
}
fn compute(
&self,
pred: &ArrayD<f64>,
target: &ArrayD<f64>,
) -> Result<TensorLossOutput, TensorLossError> {
let n = validate_shapes(pred, target)?;
let diff = pred - target;
let loss_elem = diff.mapv(|x| x * x);
let grad = if self.config.compute_grad {
let scale = match self.config.reduction {
LossReduction::Mean => 2.0 / n as f64,
LossReduction::Sum | LossReduction::None => 2.0,
};
Some(diff.mapv(|x| x * scale))
} else {
None
};
Ok(apply_reduction(loss_elem, grad, &self.config.reduction, n))
}
}
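/// Binary cross-entropy on probabilities in `[0, 1]`:
/// `-(t ln p + (1 - t) ln(1 - p))`, with predictions clamped to
/// `[eps, 1 - eps]` for numerical stability.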
#[derive(Debug, Clone)]
pub struct TensorBCELoss {
pub config: TensorLossConfig,
}
impl TensorBCELoss {
pub fn new() -> Self {
Self {
config: TensorLossConfig::default(),
}
}
}
impl Default for TensorBCELoss {
fn default() -> Self {
Self::new()
}
}
impl TensorLoss for TensorBCELoss {
fn name(&self) -> &'static str {
"bce"
}
fn compute(
&self,
pred: &ArrayD<f64>,
target: &ArrayD<f64>,
) -> Result<TensorLossOutput, TensorLossError> {
let n = validate_shapes(pred, target)?;
let eps = self.config.epsilon;
let p = pred.mapv(|x| x.clamp(eps, 1.0 - eps));
let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
let mut grad_elem = if self.config.compute_grad {
Some(ArrayD::zeros(IxDyn(pred.shape())))
} else {
None
};
Zip::from(&mut loss_elem)
.and(&p)
.and(target)
.for_each(|l, &pi, &ti| {
*l = -(ti * pi.ln() + (1.0 - ti) * (1.0 - pi).ln());
});
if let Some(ref mut g) = grad_elem {
Zip::from(g).and(&p).and(target).for_each(|gi, &pi, &ti| {
*gi = -(ti / pi - (1.0 - ti) / (1.0 - pi));
});
}
Ok(apply_reduction(
loss_elem,
grad_elem,
&self.config.reduction,
n,
))
}
}
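/// Categorical cross-entropy `-Σ t ln(p)`. With `apply_softmax` the input is
/// treated as logits and softmaxed first (over the whole flattened tensor),
/// and the gradient simplifies to `p - t`. `label_smoothing` mixes the target
/// with a uniform distribution: `t' = t (1 - ls) + ls / n`.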
#[derive(Debug, Clone)]
pub struct TensorCrossEntropyLoss {
pub config: TensorLossConfig,
pub label_smoothing: f64,
pub apply_softmax: bool,
}
impl TensorCrossEntropyLoss {
pub fn new() -> Self {
Self {
config: TensorLossConfig::default(),
label_smoothing: 0.0,
apply_softmax: false,
}
}
}
impl Default for TensorCrossEntropyLoss {
fn default() -> Self {
Self::new()
}
}
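/// Numerically stable softmax over the *entire* flattened tensor; callers
/// needing a per-row softmax must apply it per axis themselves.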
fn softmax_flat(logits: &ArrayD<f64>) -> ArrayD<f64> {
let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let shifted = logits.mapv(|x| (x - max_val).exp());
let sum = shifted.sum();
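    // After the max-shift the largest element maps to exp(0) = 1, so `sum`
    // is at least 1 for non-empty input; this guard is purely defensive.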
if sum == 0.0 {
shifted
} else {
shifted.mapv(|x| x / sum)
}
}
impl TensorLoss for TensorCrossEntropyLoss {
fn name(&self) -> &'static str {
"cross_entropy"
}
fn compute(
&self,
pred: &ArrayD<f64>,
target: &ArrayD<f64>,
) -> Result<TensorLossOutput, TensorLossError> {
let n = validate_shapes(pred, target)?;
let eps = self.config.epsilon;
let k = n as f64;
let p = if self.apply_softmax {
softmax_flat(pred)
} else {
pred.clone()
};
let t_smooth = if self.label_smoothing > 0.0 {
let ls = self.label_smoothing;
target.mapv(|ti| ti * (1.0 - ls) + ls / k)
} else {
target.clone()
};
let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
Zip::from(&mut loss_elem)
.and(&p)
.and(&t_smooth)
.for_each(|l, &pi, &ti| {
*l = -(ti * (pi + eps).ln());
});
        let grad = if self.config.compute_grad {
            if self.apply_softmax {
                // With softmax fused in, the gradient w.r.t. the logits
                // collapses to the familiar `p - t` form.
                Some(&p - &t_smooth)
            } else {
                let mut g = ArrayD::zeros(IxDyn(pred.shape()));
                Zip::from(&mut g)
                    .and(&p)
                    .and(&t_smooth)
                    .for_each(|gi, &pi, &ti| {
                        *gi = -ti / (pi + eps);
                    });
                Some(g)
            }
        } else {
            None
        };
Ok(apply_reduction(loss_elem, grad, &self.config.reduction, n))
}
}
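/// Focal loss `-α_t (1 - p_t)^γ ln(p_t)` for binary targets, where
/// `p_t = p` for positives and `1 - p` for negatives. `gamma` down-weights
/// well-classified examples; `alpha`, if set, weights the positive class.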
#[derive(Debug, Clone)]
pub struct TensorFocalLoss {
pub config: TensorLossConfig,
pub gamma: f64,
pub alpha: Option<f64>,
}
impl TensorFocalLoss {
pub fn new() -> Self {
Self {
config: TensorLossConfig::default(),
gamma: 2.0,
alpha: None,
}
}
pub fn with_gamma(gamma: f64) -> Self {
Self {
config: TensorLossConfig::default(),
gamma,
alpha: None,
}
}
}
impl Default for TensorFocalLoss {
fn default() -> Self {
Self::new()
}
}
impl TensorLoss for TensorFocalLoss {
fn name(&self) -> &'static str {
"focal"
}
fn compute(
&self,
pred: &ArrayD<f64>,
target: &ArrayD<f64>,
) -> Result<TensorLossOutput, TensorLossError> {
let n = validate_shapes(pred, target)?;
let eps = self.config.epsilon;
        let gamma = self.gamma;
        if gamma < 0.0 {
            return Err(TensorLossError::InvalidConfig(format!(
                "gamma must be non-negative, got {}",
                gamma
            )));
        }
let p = pred.mapv(|x| x.clamp(eps, 1.0 - eps));
let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
let mut grad_elem = if self.config.compute_grad {
Some(ArrayD::zeros(IxDyn(pred.shape())))
} else {
None
};
Zip::from(&mut loss_elem)
.and(&p)
.and(target)
.for_each(|l, &pi, &ti| {
let p_t = if ti > 0.5 { pi } else { 1.0 - pi };
let modulator = (1.0 - p_t).powf(gamma);
let weight = match self.alpha {
Some(a) => {
if ti > 0.5 {
a
} else {
1.0 - a
}
}
None => 1.0,
};
*l = -weight * modulator * (p_t + eps).ln();
});
if let Some(ref mut g) = grad_elem {
Zip::from(g).and(&p).and(target).for_each(|gi, &pi, &ti| {
let p_t = if ti > 0.5 { pi } else { 1.0 - pi };
let sign = if ti > 0.5 { 1.0_f64 } else { -1.0_f64 };
let modulator = (1.0 - p_t).powf(gamma);
let weight = match self.alpha {
Some(a) => {
if ti > 0.5 {
a
} else {
1.0 - a
}
}
None => 1.0,
};
                // d/dp_t of -w (1 - p_t)^γ ln(p_t) is
                // w (γ (1 - p_t)^{γ-1} ln(p_t) - (1 - p_t)^γ / p_t);
                // the chain rule through p_t contributes `sign`. The leading
                // sign matches the BCE gradient above (γ = 0, t = 1 gives -1/p).
                let term1 = if gamma > 0.0 {
                    gamma * (1.0 - p_t).powf(gamma - 1.0) * (p_t + eps).ln()
                } else {
                    0.0
                };
                let term2 = modulator / (p_t + eps);
                *gi = weight * (term1 - term2) * sign;
});
}
Ok(apply_reduction(
loss_elem,
grad_elem,
&self.config.reduction,
n,
))
}
}
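/// Huber loss in its δ-scaled (smooth-L1) form: `0.5 d² / δ` for `|d| < δ`
/// and `|d| - 0.5 δ` otherwise, which is continuous at `|d| = δ`.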
#[derive(Debug, Clone)]
pub struct TensorHuberLoss {
pub config: TensorLossConfig,
pub delta: f64,
}
impl TensorHuberLoss {
pub fn new() -> Self {
Self {
config: TensorLossConfig::default(),
delta: 1.0,
}
}
pub fn with_delta(delta: f64) -> Self {
Self {
config: TensorLossConfig::default(),
delta,
}
}
}
impl Default for TensorHuberLoss {
fn default() -> Self {
Self::new()
}
}
impl TensorLoss for TensorHuberLoss {
fn name(&self) -> &'static str {
"huber"
}
fn compute(
&self,
pred: &ArrayD<f64>,
target: &ArrayD<f64>,
) -> Result<TensorLossOutput, TensorLossError> {
let n = validate_shapes(pred, target)?;
let delta = self.delta;
if delta <= 0.0 {
return Err(TensorLossError::InvalidConfig(format!(
"delta must be positive, got {}",
delta
)));
}
let diff = pred - target;
let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
let mut grad_elem = if self.config.compute_grad {
Some(ArrayD::zeros(IxDyn(pred.shape())))
} else {
None
};
Zip::from(&mut loss_elem).and(&diff).for_each(|l, &d| {
let abs_d = d.abs();
if abs_d < delta {
*l = 0.5 * d * d / delta;
} else {
*l = abs_d - 0.5 * delta;
}
});
if let Some(ref mut g) = grad_elem {
Zip::from(g).and(&diff).for_each(|gi, &d| {
let abs_d = d.abs();
let sign = if d > 0.0 {
1.0
} else if d < 0.0 {
-1.0
} else {
0.0
};
*gi = sign * (abs_d / delta).min(1.0);
});
}
Ok(apply_reduction(
loss_elem,
grad_elem,
&self.config.reduction,
n,
))
}
}
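/// Kullback-Leibler divergence `Σ t (ln t - ln p)`; both inputs are expected
/// to be probability distributions. Terms where `t <= eps` contribute zero.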
#[derive(Debug, Clone)]
pub struct TensorKLDivLoss {
pub config: TensorLossConfig,
}
impl TensorKLDivLoss {
pub fn new() -> Self {
Self {
config: TensorLossConfig::default(),
}
}
}
impl Default for TensorKLDivLoss {
fn default() -> Self {
Self::new()
}
}
impl TensorLoss for TensorKLDivLoss {
fn name(&self) -> &'static str {
"kl_div"
}
fn compute(
&self,
pred: &ArrayD<f64>,
target: &ArrayD<f64>,
) -> Result<TensorLossOutput, TensorLossError> {
let n = validate_shapes(pred, target)?;
let eps = self.config.epsilon;
let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
let mut grad_elem = if self.config.compute_grad {
Some(ArrayD::zeros(IxDyn(pred.shape())))
} else {
None
};
Zip::from(&mut loss_elem)
.and(pred)
.and(target)
.for_each(|l, &pi, &ti| {
if ti > eps {
let p_safe = pi.max(eps);
*l = ti * (ti.ln() - p_safe.ln());
}
});
if let Some(ref mut g) = grad_elem {
Zip::from(g).and(pred).and(target).for_each(|gi, &pi, &ti| {
if ti > eps {
                    // d/dp of t (ln t - ln p) is -t / p, stabilized the same
                    // way as the loss branch above.
                    *gi = -ti / pi.max(eps);
}
});
}
Ok(apply_reduction(
loss_elem,
grad_elem,
&self.config.reduction,
n,
))
}
}
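/// Cosine embedding loss `1 - cos(pred, target)`, treating each tensor as a
/// single flattened vector. The result is inherently a scalar, so `Mean` and
/// `Sum` reductions coincide.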
#[derive(Debug, Clone)]
pub struct TensorCosineEmbeddingLoss {
pub config: TensorLossConfig,
}
impl TensorCosineEmbeddingLoss {
pub fn new() -> Self {
Self {
config: TensorLossConfig::default(),
}
}
}
impl Default for TensorCosineEmbeddingLoss {
fn default() -> Self {
Self::new()
}
}
impl TensorLoss for TensorCosineEmbeddingLoss {
fn name(&self) -> &'static str {
"cosine_embedding"
}
fn compute(
&self,
pred: &ArrayD<f64>,
target: &ArrayD<f64>,
) -> Result<TensorLossOutput, TensorLossError> {
let n = validate_shapes(pred, target)?;
let eps = self.config.epsilon;
let dot: f64 = pred.iter().zip(target.iter()).map(|(p, t)| p * t).sum();
let norm_p: f64 = pred.iter().map(|x| x * x).sum::<f64>().sqrt();
let norm_t: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
let denom = norm_p * norm_t + eps;
let similarity = dot / denom;
let scalar_loss = 1.0 - similarity;
let grad = if self.config.compute_grad {
let mut g = ArrayD::zeros(IxDyn(pred.shape()));
let norm_p_sq = norm_p * norm_p + eps;
Zip::from(&mut g)
.and(pred)
.and(target)
.for_each(|gi, &pi, &ti| {
let d_sim = ti / denom - dot * pi / (norm_p_sq * denom);
*gi = -d_sim;
});
Some(g)
} else {
None
};
match self.config.reduction {
LossReduction::None => {
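                // The cosine loss is a single scalar; for `None` reduction it
                // is spread uniformly so the returned tensor sums back to it.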
let loss_tensor = ArrayD::from_elem(IxDyn(pred.shape()), scalar_loss / n as f64);
Ok(TensorLossOutput {
loss: 0.0,
loss_tensor: Some(loss_tensor),
grad,
})
}
LossReduction::Mean | LossReduction::Sum => Ok(TensorLossOutput {
loss: scalar_loss,
loss_tensor: None,
grad,
}),
}
}
}
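/// Name-keyed registry of boxed `TensorLoss` implementations. An illustrative
/// sketch (assumes the same `arr1` helper as above):
///
/// ```ignore
/// let reg = TensorLossRegistry::with_all_defaults();
/// let pred = arr1(&[0.5, 0.5]).into_dyn();
/// let target = arr1(&[1.0, 0.0]).into_dyn();
/// let out = reg.compute("huber", &pred, &target).unwrap();
/// assert!(reg.contains("mse") && out.loss > 0.0);
/// ```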
#[derive(Debug)]
pub struct TensorLossRegistry {
losses: HashMap<String, Box<dyn TensorLoss>>,
}
impl TensorLossRegistry {
pub fn new() -> Self {
Self {
losses: HashMap::new(),
}
}
pub fn with_all_defaults() -> Self {
let mut reg = Self::new();
reg.register("mse", Box::new(TensorMseLoss::new()));
reg.register("bce", Box::new(TensorBCELoss::new()));
reg.register("cross_entropy", Box::new(TensorCrossEntropyLoss::new()));
reg.register("focal", Box::new(TensorFocalLoss::new()));
reg.register("huber", Box::new(TensorHuberLoss::new()));
reg.register("kl_div", Box::new(TensorKLDivLoss::new()));
reg.register(
"cosine_embedding",
Box::new(TensorCosineEmbeddingLoss::new()),
);
reg
}
pub fn register(&mut self, name: impl Into<String>, loss: Box<dyn TensorLoss>) {
self.losses.insert(name.into(), loss);
}
pub fn compute(
&self,
name: &str,
pred: &ArrayD<f64>,
target: &ArrayD<f64>,
) -> Result<TensorLossOutput, TensorLossError> {
let loss = self.losses.get(name).ok_or_else(|| {
TensorLossError::InvalidConfig(format!("no loss registered under name '{}'", name))
})?;
loss.compute(pred, target)
}
pub fn names(&self) -> Vec<&str> {
self.losses.keys().map(|s| s.as_str()).collect()
}
pub fn contains(&self, name: &str) -> bool {
self.losses.contains_key(name)
}
}
impl Default for TensorLossRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::arr1;
fn to_arrayd(v: Vec<f64>) -> ArrayD<f64> {
arr1(&v).into_dyn()
}
#[test]
fn test_mse_zero_loss_identical_arrays() {
let a = to_arrayd(vec![1.0, 2.0, 3.0]);
let loss = TensorMseLoss::new().compute(&a, &a).unwrap();
assert!(
(loss.loss).abs() < 1e-10,
"identical arrays should yield zero loss"
);
}
#[test]
fn test_mse_loss_value_correct() {
let pred = to_arrayd(vec![1.0, 2.0]);
let target = to_arrayd(vec![0.0, 0.0]);
let out = TensorMseLoss::new().compute(&pred, &target).unwrap();
assert!((out.loss - 2.5).abs() < 1e-10);
}
#[test]
fn test_mse_gradient_shape() {
let pred = to_arrayd(vec![1.0, 2.0, 3.0]);
let target = to_arrayd(vec![0.0, 0.0, 0.0]);
let out = TensorMseLoss::new().compute(&pred, &target).unwrap();
let grad = out.grad.unwrap();
assert_eq!(grad.shape(), pred.shape());
}
#[test]
fn test_mse_gradient_direction() {
let pred = to_arrayd(vec![3.0, 2.0]);
let target = to_arrayd(vec![1.0, 1.0]);
let out = TensorMseLoss::new().compute(&pred, &target).unwrap();
let grad = out.grad.unwrap();
for &g in grad.iter() {
assert!(g > 0.0, "gradient should be positive when pred > target");
}
}
#[test]
fn test_bce_perfect_prediction_near_zero() {
let pred = to_arrayd(vec![0.9999, 0.0001]);
let target = to_arrayd(vec![1.0, 0.0]);
let out = TensorBCELoss::new().compute(&pred, &target).unwrap();
assert!(out.loss < 1e-3, "near-perfect predictions → near-zero loss");
}
#[test]
fn test_bce_gradient_shape() {
let pred = to_arrayd(vec![0.5, 0.7]);
let target = to_arrayd(vec![1.0, 0.0]);
let out = TensorBCELoss::new().compute(&pred, &target).unwrap();
let grad = out.grad.unwrap();
assert_eq!(grad.shape(), pred.shape());
}
#[test]
fn test_cross_entropy_uniform_target() {
let eps = 1e-8_f64;
let p = 1.0_f64 / 3.0;
let pred = to_arrayd(vec![p; 3]);
let target = to_arrayd(vec![p; 3]);
let out = TensorCrossEntropyLoss::new()
.compute(&pred, &target)
.unwrap();
let expected = -(p * (p + eps).ln());
assert!(
(out.loss - expected).abs() < 1e-6,
"expected {}, got {}",
expected,
out.loss
);
}
#[test]
fn test_cross_entropy_label_smoothing() {
let pred = to_arrayd(vec![0.9, 0.05, 0.05]);
let target = to_arrayd(vec![1.0, 0.0, 0.0]);
let no_smooth = TensorCrossEntropyLoss::new()
.compute(&pred, &target)
.unwrap();
let with_smooth = TensorCrossEntropyLoss {
label_smoothing: 0.1,
..TensorCrossEntropyLoss::new()
}
.compute(&pred, &target)
.unwrap();
assert!(
(no_smooth.loss - with_smooth.loss).abs() > 1e-6,
"label smoothing should change the loss"
);
}
#[test]
fn test_focal_gamma_zero_equals_bce() {
let pred = to_arrayd(vec![0.7, 0.3, 0.8]);
let target = to_arrayd(vec![1.0, 0.0, 1.0]);
let focal = TensorFocalLoss::with_gamma(0.0)
.compute(&pred, &target)
.unwrap();
let bce = TensorBCELoss::new().compute(&pred, &target).unwrap();
assert!(
(focal.loss - bce.loss).abs() < 1e-6,
"focal(gamma=0) ≈ BCE, got focal={} bce={}",
focal.loss,
bce.loss
);
}
#[test]
fn test_focal_high_confidence_downweighted() {
let pred_high = to_arrayd(vec![0.99]);
let pred_low = to_arrayd(vec![0.6]);
let target = to_arrayd(vec![1.0]);
        let focal = TensorFocalLoss::new();
        let out_high = focal.compute(&pred_high, &target).unwrap();
let out_low = focal.compute(&pred_low, &target).unwrap();
assert!(
out_high.loss < out_low.loss,
"high-confidence correct prediction should be downweighted"
);
}
#[test]
fn test_huber_small_error_quadratic() {
let pred = to_arrayd(vec![0.5]);
let target = to_arrayd(vec![0.0]);
let out = TensorHuberLoss::new().compute(&pred, &target).unwrap();
assert!((out.loss - 0.125).abs() < 1e-10);
}
#[test]
fn test_huber_large_error_linear() {
let pred = to_arrayd(vec![2.0]);
let target = to_arrayd(vec![0.0]);
let out = TensorHuberLoss::new().compute(&pred, &target).unwrap();
assert!((out.loss - 1.5).abs() < 1e-10);
}
#[test]
fn test_kl_div_identical_distributions_zero() {
let p = to_arrayd(vec![0.3, 0.5, 0.2]);
let out = TensorKLDivLoss::new().compute(&p, &p).unwrap();
assert!(out.loss.abs() < 1e-6);
}
#[test]
fn test_kl_div_gradient_shape() {
let pred = to_arrayd(vec![0.3, 0.5, 0.2]);
let target = to_arrayd(vec![0.4, 0.4, 0.2]);
let out = TensorKLDivLoss::new().compute(&pred, &target).unwrap();
let grad = out.grad.unwrap();
assert_eq!(grad.shape(), pred.shape());
}
#[test]
fn test_cosine_parallel_loss_zero() {
let pred = to_arrayd(vec![1.0, 0.0, 0.0]);
        let target = to_arrayd(vec![2.0, 0.0, 0.0]);
        let out = TensorCosineEmbeddingLoss::new()
.compute(&pred, &target)
.unwrap();
assert!(out.loss.abs() < 1e-6, "parallel vectors → loss ≈ 0");
}
#[test]
fn test_cosine_orthogonal_loss_one() {
let pred = to_arrayd(vec![1.0, 0.0]);
let target = to_arrayd(vec![0.0, 1.0]);
let out = TensorCosineEmbeddingLoss::new()
.compute(&pred, &target)
.unwrap();
assert!(
(out.loss - 1.0).abs() < 1e-6,
"orthogonal vectors → loss ≈ 1"
);
}
#[test]
fn test_reduction_sum_vs_mean() {
let pred = to_arrayd(vec![1.0, 2.0, 3.0]);
let target = to_arrayd(vec![0.0, 0.0, 0.0]);
let mean_loss = TensorMseLoss::with_config(TensorLossConfig {
reduction: LossReduction::Mean,
..Default::default()
})
.compute(&pred, &target)
.unwrap();
let sum_loss = TensorMseLoss::with_config(TensorLossConfig {
reduction: LossReduction::Sum,
..Default::default()
})
.compute(&pred, &target)
.unwrap();
assert!(
(sum_loss.loss - mean_loss.loss).abs() > 1e-6,
"sum != mean for non-unit arrays"
);
}
#[test]
fn test_reduction_none_returns_tensor() {
let pred = to_arrayd(vec![1.0, 2.0]);
let target = to_arrayd(vec![0.0, 0.0]);
let out = TensorMseLoss::with_config(TensorLossConfig {
reduction: LossReduction::None,
..Default::default()
})
.compute(&pred, &target)
.unwrap();
assert!(
out.loss_tensor.is_some(),
"None reduction should return a loss tensor"
);
let lt = out.loss_tensor.unwrap();
assert_eq!(lt.shape(), pred.shape());
}
#[test]
fn test_registry_with_all_defaults() {
let reg = TensorLossRegistry::with_all_defaults();
assert_eq!(
reg.names().len(),
7,
"registry should contain 7 built-in losses"
);
for name in &[
"mse",
"bce",
"cross_entropy",
"focal",
"huber",
"kl_div",
"cosine_embedding",
] {
assert!(reg.contains(name), "missing: {}", name);
}
}
#[test]
fn test_registry_compute_by_name() {
let reg = TensorLossRegistry::with_all_defaults();
let pred = to_arrayd(vec![0.5, 0.5]);
let target = to_arrayd(vec![1.0, 0.0]);
let out = reg.compute("bce", &pred, &target).unwrap();
assert!(
out.loss > 0.0,
"BCE of non-perfect prediction should be positive"
);
}
}