use super::{NnResult, ReductionMode};
use crate::error::NumRs2Error;
use scirs2_core::ndarray::{
Array, Array1, Array2, ArrayView, ArrayView1, ArrayView2, Axis, ScalarOperand, Zip,
};
use scirs2_core::numeric::Float;
use scirs2_core::simd_ops::SimdUnifiedOps;
/// Mean squared error between two 1-D arrays.
///
/// Computes the element-wise squared difference `(y_pred - y_true)^2` and
/// reduces it according to `reduction`:
///
/// * `ReductionMode::Mean` - sum of squared differences divided by the
///   number of elements (an empty input divides zero by zero).
/// * `ReductionMode::Sum` - sum of squared differences.
/// * `ReductionMode::None` - the return type is a single scalar, so only the
///   squared difference of the first element is returned.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the two inputs differ in length.
/// * `NumRs2Error::InvalidOperation` if `reduction` is `ReductionMode::None`
///   and the inputs are empty.
/// * `NumRs2Error::ConversionError` if the element count cannot be
///   represented as `T`.
pub fn mse_loss<T>(
    y_true: &ArrayView1<T>,
    y_pred: &ArrayView1<T>,
    reduction: ReductionMode,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if y_true.len() != y_pred.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Shape mismatch: y_true has {} elements, y_pred has {}",
            y_true.len(),
            y_pred.len()
        )));
    }

    let diff = y_pred - y_true;
    let squared = &diff * &diff;

    match reduction {
        // Indexing `[0]` would panic on empty input; report an error instead.
        ReductionMode::None => squared.first().copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot apply ReductionMode::None to empty input".to_string(),
            )
        }),
        ReductionMode::Mean => {
            let n = T::from(y_true.len()).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert length".to_string())
            })?;
            Ok(squared.sum() / n)
        }
        ReductionMode::Sum => Ok(squared.sum()),
    }
}
/// Mean squared error between two 2-D arrays of identical shape.
///
/// Computes the element-wise squared difference `(y_pred - y_true)^2` and
/// reduces it according to `reduction`:
///
/// * `ReductionMode::Mean` - sum over all elements divided by the total
///   element count (an empty input divides zero by zero).
/// * `ReductionMode::Sum` - sum over all elements.
/// * `ReductionMode::None` - the return type is a single scalar, so only the
///   squared difference at index `[0, 0]` is returned.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the shapes differ.
/// * `NumRs2Error::InvalidOperation` if `reduction` is `ReductionMode::None`
///   and the inputs contain no elements.
/// * `NumRs2Error::ConversionError` if the element count cannot be
///   represented as `T`.
pub fn mse_loss_2d<T>(
    y_true: &ArrayView2<T>,
    y_pred: &ArrayView2<T>,
    reduction: ReductionMode,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if y_true.shape() != y_pred.shape() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Shape mismatch: y_true shape {:?}, y_pred shape {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let diff = y_pred - y_true;
    let squared = &diff * &diff;

    match reduction {
        // `squared[[0, 0]]` would panic when either dimension is zero;
        // `first()` yields the same element when one exists.
        ReductionMode::None => squared.first().copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot apply ReductionMode::None to empty input".to_string(),
            )
        }),
        ReductionMode::Mean => {
            let n = T::from(squared.len()).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert length".to_string())
            })?;
            Ok(squared.sum() / n)
        }
        ReductionMode::Sum => Ok(squared.sum()),
    }
}
/// Mean absolute error between two 1-D arrays.
///
/// Computes the element-wise absolute difference `|y_pred - y_true|` and
/// reduces it according to `reduction`:
///
/// * `ReductionMode::Mean` - sum of absolute differences divided by the
///   number of elements (an empty input divides zero by zero).
/// * `ReductionMode::Sum` - sum of absolute differences.
/// * `ReductionMode::None` - the return type is a single scalar, so only the
///   absolute difference of the first element is returned.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the two inputs differ in length.
/// * `NumRs2Error::InvalidOperation` if `reduction` is `ReductionMode::None`
///   and the inputs are empty.
/// * `NumRs2Error::ConversionError` if the element count cannot be
///   represented as `T`.
pub fn mae_loss<T>(
    y_true: &ArrayView1<T>,
    y_pred: &ArrayView1<T>,
    reduction: ReductionMode,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if y_true.len() != y_pred.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Shape mismatch: y_true has {} elements, y_pred has {}",
            y_true.len(),
            y_pred.len()
        )));
    }

    let abs_diff = (y_pred - y_true).mapv(|d| d.abs());

    match reduction {
        // Indexing `[0]` would panic on empty input; report an error instead.
        ReductionMode::None => abs_diff.first().copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot apply ReductionMode::None to empty input".to_string(),
            )
        }),
        ReductionMode::Mean => {
            let n = T::from(y_true.len()).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert length".to_string())
            })?;
            Ok(abs_diff.sum() / n)
        }
        ReductionMode::Sum => Ok(abs_diff.sum()),
    }
}
/// Huber loss between two 1-D arrays.
///
/// For each element with difference `d = y_pred - y_true`:
///
/// * if `|d| <= delta`, the loss is `0.5 * d^2`;
/// * otherwise the loss is `delta * (|d| - 0.5 * delta)`.
///
/// The per-element losses are then reduced according to `reduction`:
///
/// * `ReductionMode::Mean` - sum divided by the number of elements (an empty
///   input divides zero by zero).
/// * `ReductionMode::Sum` - sum of the per-element losses.
/// * `ReductionMode::None` - the return type is a single scalar, so only the
///   loss of the first element is returned.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the two inputs differ in length.
/// * `NumRs2Error::InvalidOperation` if `delta` is NaN, zero, or negative,
///   or if `reduction` is `ReductionMode::None` and the inputs are empty.
/// * `NumRs2Error::ConversionError` if `0.5` or the element count cannot be
///   represented as `T`.
pub fn huber_loss<T>(
    y_true: &ArrayView1<T>,
    y_pred: &ArrayView1<T>,
    delta: T,
    reduction: ReductionMode,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if y_true.len() != y_pred.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Shape mismatch: y_true has {} elements, y_pred has {}",
            y_true.len(),
            y_pred.len()
        )));
    }
    // NaN compares false against everything, so it must be rejected
    // explicitly; otherwise every element would silently evaluate to NaN.
    if delta.is_nan() || delta <= T::zero() {
        return Err(NumRs2Error::InvalidOperation(
            "Huber loss delta must be positive".to_string(),
        ));
    }

    let half = T::from(0.5)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert 0.5".to_string()))?;

    // Single pass over the differences; the absolute value is derived per
    // element rather than materialised as a second array.
    let loss = (y_pred - y_true).mapv(|d| {
        let a = d.abs();
        if a <= delta {
            half * d * d
        } else {
            delta * (a - half * delta)
        }
    });

    match reduction {
        // Indexing `[0]` would panic on empty input; report an error instead.
        ReductionMode::None => loss.first().copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot apply ReductionMode::None to empty input".to_string(),
            )
        }),
        ReductionMode::Mean => {
            let n = T::from(y_true.len()).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert length".to_string())
            })?;
            Ok(loss.sum() / n)
        }
        ReductionMode::Sum => Ok(loss.sum()),
    }
}
/// Smooth L1 loss between two 1-D arrays.
///
/// Implemented as [`huber_loss`] with `delta` fixed to one; the result and
/// any error are forwarded unchanged.
pub fn smooth_l1_loss<T>(
    y_true: &ArrayView1<T>,
    y_pred: &ArrayView1<T>,
    reduction: ReductionMode,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    let unit_delta = T::one();
    huber_loss(y_true, y_pred, unit_delta, reduction)
}
/// Binary cross-entropy between targets and predicted probabilities.
///
/// Predictions are clipped to `[1e-7, 1 - 1e-7]` before taking logarithms so
/// that `ln(0)` is never evaluated. The per-element loss is
/// `-(y * ln(p) + (1 - y) * ln(1 - p))`, reduced according to `reduction`:
///
/// * `ReductionMode::Mean` - sum divided by the number of elements (an empty
///   input divides zero by zero).
/// * `ReductionMode::Sum` - sum of the per-element losses.
/// * `ReductionMode::None` - the return type is a single scalar, so only the
///   loss of the first element is returned.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the two inputs differ in length.
/// * `NumRs2Error::InvalidOperation` if `reduction` is `ReductionMode::None`
///   and the inputs are empty.
/// * `NumRs2Error::ConversionError` if the epsilon or the element count
///   cannot be represented as `T`.
pub fn binary_cross_entropy<T>(
    y_true: &ArrayView1<T>,
    y_pred: &ArrayView1<T>,
    reduction: ReductionMode,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if y_true.len() != y_pred.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Shape mismatch: y_true has {} elements, y_pred has {}",
            y_true.len(),
            y_pred.len()
        )));
    }

    let eps = T::from(1e-7)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert epsilon".to_string()))?;
    let one = T::one();

    // Keep probabilities strictly inside (0, 1) so both logarithms are finite.
    let y_pred_clipped = y_pred.mapv(|p| p.max(eps).min(one - eps));
    let loss = Zip::from(y_true)
        .and(&y_pred_clipped)
        .map_collect(|&y, &p| -(y * p.ln() + (one - y) * (one - p).ln()));

    match reduction {
        // Indexing `[0]` would panic on empty input; report an error instead.
        ReductionMode::None => loss.first().copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot apply ReductionMode::None to empty input".to_string(),
            )
        }),
        ReductionMode::Mean => {
            let n = T::from(y_true.len()).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert length".to_string())
            })?;
            Ok(loss.sum() / n)
        }
        ReductionMode::Sum => Ok(loss.sum()),
    }
}
/// Binary cross-entropy computed directly from logits.
///
/// For a logit `x` and target `y`, the per-element loss is
/// `max(x, 0) - x * y + ln(1 + exp(-|x|))`. The exponential argument is never
/// positive, so it cannot overflow. The per-element losses are reduced
/// according to `reduction`:
///
/// * `ReductionMode::Mean` - sum divided by the number of elements (an empty
///   input divides zero by zero).
/// * `ReductionMode::Sum` - sum of the per-element losses.
/// * `ReductionMode::None` - the return type is a single scalar, so only the
///   loss of the first element is returned.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the two inputs differ in length.
/// * `NumRs2Error::InvalidOperation` if `reduction` is `ReductionMode::None`
///   and the inputs are empty.
/// * `NumRs2Error::ConversionError` if the element count cannot be
///   represented as `T`.
pub fn bce_with_logits<T>(
    y_true: &ArrayView1<T>,
    logits: &ArrayView1<T>,
    reduction: ReductionMode,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if y_true.len() != logits.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Shape mismatch: y_true has {} elements, logits has {}",
            y_true.len(),
            logits.len()
        )));
    }

    let loss = Zip::from(y_true).and(logits).map_collect(|&y, &x| {
        let max_val = x.max(T::zero());
        // `ln_1p` keeps the contribution of exp(-|x|) when it is far smaller
        // than one, where `(1 + e).ln()` would round it away.
        max_val - x * y + (-x.abs()).exp().ln_1p()
    });

    match reduction {
        // Indexing `[0]` would panic on empty input; report an error instead.
        ReductionMode::None => loss.first().copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot apply ReductionMode::None to empty input".to_string(),
            )
        }),
        ReductionMode::Mean => {
            let n = T::from(y_true.len()).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert length".to_string())
            })?;
            Ok(loss.sum() / n)
        }
        ReductionMode::Sum => Ok(loss.sum()),
    }
}
/// Categorical cross-entropy between two 2-D arrays of identical shape.
///
/// Predictions are clipped from below at `1e-7` before the logarithm is
/// taken. The loss of each row is `-sum(y_true * ln(y_pred))` along axis 1,
/// and the per-row losses are reduced according to `reduction`:
///
/// * `ReductionMode::Mean` - sum of row losses divided by the number of rows
///   (zero rows divides zero by zero).
/// * `ReductionMode::Sum` - sum of row losses.
/// * `ReductionMode::None` - the return type is a single scalar, so only the
///   loss of the first row is returned.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the shapes differ.
/// * `NumRs2Error::InvalidOperation` if `reduction` is `ReductionMode::None`
///   and the inputs have no rows.
/// * `NumRs2Error::ConversionError` if the epsilon or the row count cannot be
///   represented as `T`.
pub fn categorical_cross_entropy<T>(
    y_true: &ArrayView2<T>,
    y_pred: &ArrayView2<T>,
    reduction: ReductionMode,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if y_true.shape() != y_pred.shape() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Shape mismatch: y_true shape {:?}, y_pred shape {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let eps = T::from(1e-7)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert epsilon".to_string()))?;

    // Clip only from below: the goal is to avoid ln(0).
    let y_pred_clipped = y_pred.mapv(|p| p.max(eps));
    let loss_per_element = y_true * &y_pred_clipped.mapv(|p| p.ln());
    let loss = -loss_per_element.sum_axis(Axis(1));

    match reduction {
        // Indexing `[0]` would panic when there are no rows; report an error.
        ReductionMode::None => loss.first().copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot apply ReductionMode::None to empty input".to_string(),
            )
        }),
        ReductionMode::Mean => {
            let n = T::from(y_true.nrows()).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert batch size".to_string())
            })?;
            Ok(loss.sum() / n)
        }
        ReductionMode::Sum => Ok(loss.sum()),
    }
}
pub fn sparse_categorical_cross_entropy<T>(
y_true: &[usize],
y_pred: &ArrayView2<T>,
reduction: ReductionMode,
) -> NnResult<T>
where
T: Float + SimdUnifiedOps,
{
if y_true.len() != y_pred.nrows() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Batch size mismatch: y_true has {} samples, y_pred has {}",
y_true.len(),
y_pred.nrows()
)));
}
let eps = T::from(1e-7)
.ok_or_else(|| NumRs2Error::ConversionError("Failed to convert epsilon".to_string()))?;
let num_classes = y_pred.ncols();
let mut losses = Vec::with_capacity(y_true.len());
for (i, &class_idx) in y_true.iter().enumerate() {
if class_idx >= num_classes {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Class index {} out of bounds for {} classes",
class_idx, num_classes
)));
}
let pred_prob = y_pred[[i, class_idx]];
let clipped_prob = pred_prob.max(eps);
losses.push(-clipped_prob.ln());
}
match reduction {
ReductionMode::None => Ok(losses[0]),
ReductionMode::Mean => {
let sum = losses.iter().fold(T::zero(), |acc, &x| acc + x);
let n = T::from(losses.len()).ok_or_else(|| {
NumRs2Error::ConversionError("Failed to convert length".to_string())
})?;
Ok(sum / n)
}
ReductionMode::Sum => Ok(losses.iter().fold(T::zero(), |acc, &x| acc + x)),
}
}
pub fn nll_loss<T>(
y_true: &[usize],
log_probs: &ArrayView2<T>,
reduction: ReductionMode,
) -> NnResult<T>
where
T: Float + SimdUnifiedOps,
{
if y_true.len() != log_probs.nrows() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Batch size mismatch: y_true has {} samples, log_probs has {}",
y_true.len(),
log_probs.nrows()
)));
}
let num_classes = log_probs.ncols();
let mut losses = Vec::with_capacity(y_true.len());
for (i, &class_idx) in y_true.iter().enumerate() {
if class_idx >= num_classes {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Class index {} out of bounds for {} classes",
class_idx, num_classes
)));
}
losses.push(-log_probs[[i, class_idx]]);
}
match reduction {
ReductionMode::None => Ok(losses[0]),
ReductionMode::Mean => {
let sum = losses.iter().fold(T::zero(), |acc, &x| acc + x);
let n = T::from(losses.len()).ok_or_else(|| {
NumRs2Error::ConversionError("Failed to convert length".to_string())
})?;
Ok(sum / n)
}
ReductionMode::Sum => Ok(losses.iter().fold(T::zero(), |acc, &x| acc + x)),
}
}
/// Kullback-Leibler divergence term-by-term between two 1-D arrays.
///
/// With `eps = 1e-10`, each element contributes
/// `p * ln((p + eps) / (q + eps))` when `p > eps`, and zero otherwise. The
/// per-element terms are reduced according to `reduction`:
///
/// * `ReductionMode::Mean` - sum divided by the number of elements (an empty
///   input divides zero by zero).
/// * `ReductionMode::Sum` - sum of the per-element terms.
/// * `ReductionMode::None` - the return type is a single scalar, so only the
///   term of the first element is returned.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the two inputs differ in length.
/// * `NumRs2Error::InvalidOperation` if `reduction` is `ReductionMode::None`
///   and the inputs are empty.
/// * `NumRs2Error::ConversionError` if the epsilon or the element count
///   cannot be represented as `T`.
pub fn kl_div_loss<T>(p: &ArrayView1<T>, q: &ArrayView1<T>, reduction: ReductionMode) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if p.len() != q.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Shape mismatch: p has {} elements, q has {}",
            p.len(),
            q.len()
        )));
    }

    let eps = T::from(1e-10)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert epsilon".to_string()))?;

    let kl = Zip::from(p).and(q).map_collect(|&p_val, &q_val| {
        // Entries with (near-)zero mass in `p` contribute nothing.
        if p_val > eps {
            p_val * ((p_val + eps) / (q_val + eps)).ln()
        } else {
            T::zero()
        }
    });

    match reduction {
        // Indexing `[0]` would panic on empty input; report an error instead.
        ReductionMode::None => kl.first().copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot apply ReductionMode::None to empty input".to_string(),
            )
        }),
        ReductionMode::Mean => {
            let n = T::from(p.len()).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert length".to_string())
            })?;
            Ok(kl.sum() / n)
        }
        ReductionMode::Sum => Ok(kl.sum()),
    }
}
/// Hinge loss between two 1-D arrays.
///
/// The per-element loss is `max(0, 1 - y_true * y_pred)`. Labels are not
/// validated (presumably callers pass -1 / +1 - confirm against call sites).
/// The per-element losses are reduced according to `reduction`:
///
/// * `ReductionMode::Mean` - sum divided by the number of elements (an empty
///   input divides zero by zero).
/// * `ReductionMode::Sum` - sum of the per-element losses.
/// * `ReductionMode::None` - the return type is a single scalar, so only the
///   loss of the first element is returned.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the two inputs differ in length.
/// * `NumRs2Error::InvalidOperation` if `reduction` is `ReductionMode::None`
///   and the inputs are empty.
/// * `NumRs2Error::ConversionError` if the element count cannot be
///   represented as `T`.
pub fn hinge_loss<T>(
    y_true: &ArrayView1<T>,
    y_pred: &ArrayView1<T>,
    reduction: ReductionMode,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if y_true.len() != y_pred.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Shape mismatch: y_true has {} elements, y_pred has {}",
            y_true.len(),
            y_pred.len()
        )));
    }

    let one = T::one();
    let zero = T::zero();
    let loss = Zip::from(y_true)
        .and(y_pred)
        .map_collect(|&y, &pred| (one - y * pred).max(zero));

    match reduction {
        // Indexing `[0]` would panic on empty input; report an error instead.
        ReductionMode::None => loss.first().copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot apply ReductionMode::None to empty input".to_string(),
            )
        }),
        ReductionMode::Mean => {
            let n = T::from(y_true.len()).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert length".to_string())
            })?;
            Ok(loss.sum() / n)
        }
        ReductionMode::Sum => Ok(loss.sum()),
    }
}
/// Cosine embedding loss for a single pair of 1-D vectors.
///
/// The cosine similarity is computed as
/// `dot(x1, x2) / (||x1|| * ||x2|| + 1e-8)`; the small constant keeps the
/// division defined when either vector has zero norm.
///
/// * When `y` equals exactly one, the loss is `1 - cos_sim`.
/// * For any other `y`, the loss is `max(0, cos_sim - margin)`.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the two vectors differ in length.
/// * `NumRs2Error::ConversionError` if the epsilon cannot be represented as
///   `T`.
pub fn cosine_embedding_loss<T>(
    x1: &ArrayView1<T>,
    x2: &ArrayView1<T>,
    y: T,
    margin: T,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    let (len1, len2) = (x1.len(), x2.len());
    if len1 != len2 {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Embedding dimension mismatch: x1 has {}, x2 has {}",
            len1, len2
        )));
    }

    let stabilizer = T::from(1e-8)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert epsilon".to_string()))?;

    // Numerator: inner product of the two embeddings.
    let inner = Zip::from(x1)
        .and(x2)
        .fold(T::zero(), |acc, &a, &b| acc + a * b);

    // Denominator: product of the Euclidean norms, padded by the stabilizer.
    let len_a = x1.mapv(|v| v * v).sum().sqrt();
    let len_b = x2.mapv(|v| v * v).sum().sqrt();
    let similarity = inner / ((len_a * len_b) + stabilizer);

    let loss = if y == T::one() {
        T::one() - similarity
    } else {
        (similarity - margin).max(T::zero())
    };
    Ok(loss)
}
/// Binary focal loss between targets and predicted probabilities.
///
/// Predictions are clipped to `[1e-7, 1 - 1e-7]`. For each element:
///
/// * if the target equals exactly one, `p_t = p` and `alpha_t = alpha`;
/// * for any other target, `p_t = 1 - p` and `alpha_t = 1 - alpha`.
///
/// The per-element loss is `-alpha_t * (1 - p_t)^gamma * ln(p_t)`, reduced
/// according to `reduction`:
///
/// * `ReductionMode::Mean` - sum divided by the number of elements (an empty
///   input divides zero by zero).
/// * `ReductionMode::Sum` - sum of the per-element losses.
/// * `ReductionMode::None` - the return type is a single scalar, so only the
///   loss of the first element is returned.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the two inputs differ in length.
/// * `NumRs2Error::InvalidOperation` if `reduction` is `ReductionMode::None`
///   and the inputs are empty.
/// * `NumRs2Error::ConversionError` if the epsilon or the element count
///   cannot be represented as `T`.
pub fn focal_loss<T>(
    y_true: &ArrayView1<T>,
    y_pred: &ArrayView1<T>,
    alpha: T,
    gamma: T,
    reduction: ReductionMode,
) -> NnResult<T>
where
    T: Float + SimdUnifiedOps,
{
    if y_true.len() != y_pred.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Shape mismatch: y_true has {} elements, y_pred has {}",
            y_true.len(),
            y_pred.len()
        )));
    }

    let eps = T::from(1e-7)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert epsilon".to_string()))?;
    let one = T::one();

    // Keep probabilities strictly inside (0, 1) so the logarithm is finite.
    let y_pred_clipped = y_pred.mapv(|p| p.max(eps).min(one - eps));
    let loss = Zip::from(y_true)
        .and(&y_pred_clipped)
        .map_collect(|&y, &p| {
            let is_positive = y == one;
            let p_t = if is_positive { p } else { one - p };
            let alpha_t = if is_positive { alpha } else { one - alpha };
            -alpha_t * (one - p_t).powf(gamma) * p_t.ln()
        });

    match reduction {
        // Indexing `[0]` would panic on empty input; report an error instead.
        ReductionMode::None => loss.first().copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot apply ReductionMode::None to empty input".to_string(),
            )
        }),
        ReductionMode::Mean => {
            let n = T::from(y_true.len()).ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert length".to_string())
            })?;
            Ok(loss.sum() / n)
        }
        ReductionMode::Sum => Ok(loss.sum()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Every prediction is off by 0.1, so the mean squared error is 0.01.
    #[test]
    fn test_mse_loss() {
        let targets = array![1.0, 2.0, 3.0];
        let predictions = array![1.1, 2.1, 3.1];

        let outcome = mse_loss(&targets.view(), &predictions.view(), ReductionMode::Mean);

        assert_abs_diff_eq!(outcome.unwrap(), 0.01, epsilon = 1e-6);
    }

    /// Every prediction is off by 0.1, so the mean absolute error is 0.1.
    #[test]
    fn test_mae_loss() {
        let targets = array![1.0, 2.0, 3.0];
        let predictions = array![1.1, 2.1, 3.1];

        let outcome = mae_loss(&targets.view(), &predictions.view(), ReductionMode::Mean);

        assert_abs_diff_eq!(outcome.unwrap(), 0.1, epsilon = 1e-6);
    }

    /// With delta = 1 the element losses are 0.125, 0.5 and 1.5, whose mean
    /// is 2.125 / 3.
    #[test]
    fn test_huber_loss() {
        let targets = array![0.0, 0.0, 0.0];
        let predictions = array![0.5, 1.0, 2.0];

        let outcome = huber_loss(
            &targets.view(),
            &predictions.view(),
            1.0,
            ReductionMode::Mean,
        );

        assert_abs_diff_eq!(outcome.unwrap(), 0.708333, epsilon = 1e-5);
    }

    /// Confident, correct predictions give a small positive loss.
    #[test]
    fn test_binary_cross_entropy() {
        let targets = array![1.0, 0.0, 1.0];
        let predictions = array![0.9, 0.1, 0.8];

        let value =
            binary_cross_entropy(&targets.view(), &predictions.view(), ReductionMode::Mean)
                .unwrap();

        assert!(value < 0.2);
        assert!(value > 0.0);
    }

    /// Inputs of different lengths are rejected with an error.
    #[test]
    fn test_mse_shape_mismatch() {
        let targets = array![1.0, 2.0];
        let predictions = array![1.0, 2.0, 3.0];

        let outcome = mse_loss(&targets.view(), &predictions.view(), ReductionMode::Mean);

        assert!(outcome.is_err());
    }
}