use crate::{FloatElement, Tensor, TensorElement};
use torsh_core::error::{Result, TorshError};
/// Specifies how an element-wise loss tensor is collapsed into a final value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Reduction {
    /// Keep the per-element losses unreduced.
    None,
    /// Average the losses over all elements (the conventional default).
    #[default]
    Mean,
    /// Sum the losses over all elements.
    Sum,
}
impl<T: FloatElement> Tensor<T> {
    /// Mean squared error loss against `target`, reduced with [`Reduction::Mean`].
    pub fn mse_loss(&self, target: &Self) -> Result<Self> {
        self.mse_loss_with_reduction(target, Reduction::Mean)
    }

    /// Mean squared error: `(pred - target)^2` element-wise, then reduced.
    ///
    /// # Errors
    /// Returns `TorshError::ShapeMismatch` when `self` and `target` differ in shape.
    pub fn mse_loss_with_reduction(&self, target: &Self, reduction: Reduction) -> Result<Self> {
        self.validate_same_shape(target)?;
        let loss_tensor = self.elementwise_loss(target, |pred, targ| {
            let diff = pred - targ;
            diff * diff
        })?;
        apply_reduction(&loss_tensor, reduction)
    }

    /// L1 (mean absolute error) loss against `target`, reduced with [`Reduction::Mean`].
    pub fn l1_loss(&self, target: &Self) -> Result<Self> {
        self.l1_loss_with_reduction(target, Reduction::Mean)
    }

    /// L1 loss: `|pred - target|` element-wise, then reduced.
    ///
    /// # Errors
    /// Returns `TorshError::ShapeMismatch` when `self` and `target` differ in shape.
    pub fn l1_loss_with_reduction(&self, target: &Self, reduction: Reduction) -> Result<Self> {
        self.validate_same_shape(target)?;
        let loss_tensor =
            self.elementwise_loss(target, |pred, targ| Self::abs_val(pred - targ))?;
        apply_reduction(&loss_tensor, reduction)
    }

    /// Huber loss against `target` with threshold `delta`, reduced with [`Reduction::Mean`].
    pub fn huber_loss(&self, target: &Self, delta: f64) -> Result<Self> {
        self.huber_loss_with_reduction(target, delta, Reduction::Mean)
    }

    /// Huber loss: quadratic (`0.5 * diff^2`) where `|diff| < delta`, linear
    /// (`delta * (|diff| - 0.5 * delta)`) elsewhere, then reduced.
    ///
    /// # Errors
    /// Returns `TorshError::ShapeMismatch` when `self` and `target` differ in shape.
    pub fn huber_loss_with_reduction(&self, target: &Self, delta: f64, reduction: Reduction) -> Result<Self> {
        self.validate_same_shape(target)?;
        // Fall back to 1.0 / 0.0 when the f64 constants cannot be represented in T;
        // for the float element types used in practice these conversions succeed.
        let delta_t = T::from_f64(delta).unwrap_or_else(|| <T as TensorElement>::one());
        let half = T::from_f64(0.5).unwrap_or_else(|| <T as TensorElement>::zero());
        let loss_tensor = self.elementwise_loss(target, |pred, targ| {
            let diff = pred - targ;
            let abs_diff = Self::abs_val(diff);
            if abs_diff < delta_t {
                // Quadratic region: 0.5 * diff^2
                half * diff * diff
            } else {
                // Linear region: delta * (|diff| - 0.5 * delta)
                delta_t * (abs_diff - half * delta_t)
            }
        })?;
        apply_reduction(&loss_tensor, reduction)
    }

    /// Binary cross-entropy loss against `target`, reduced with [`Reduction::Mean`].
    /// Expects predictions to be probabilities in `[0, 1]`.
    pub fn bce_loss(&self, target: &Self) -> Result<Self> {
        self.bce_loss_with_reduction(target, Reduction::Mean)
    }

    /// Binary cross-entropy: `-(t * ln(p) + (1 - t) * ln(1 - p))` element-wise,
    /// then reduced. Predictions are clamped into `[eps, 1 - eps]` so `ln`
    /// stays finite at the boundaries.
    ///
    /// # Errors
    /// Returns `TorshError::ShapeMismatch` when `self` and `target` differ in shape.
    pub fn bce_loss_with_reduction(&self, target: &Self, reduction: Reduction) -> Result<Self> {
        self.validate_same_shape(target)?;
        let one = <T as TensorElement>::one();
        // Clamp away from 0 and 1; 1e-8 fits f64/f32, 1e-7 is the coarser fallback.
        let eps = T::from_f64(1e-8)
            .unwrap_or_else(|| T::from_f64(1e-7).expect("f64 conversion should succeed"));
        let loss_tensor = self.elementwise_loss(target, |pred, targ| {
            let pred_clamped = if pred < eps {
                eps
            } else if pred > one - eps {
                one - eps
            } else {
                pred
            };
            let log_pred = pred_clamped.ln();
            let log_one_minus_pred = (one - pred_clamped).ln();
            -(targ * log_pred + (one - targ) * log_one_minus_pred)
        })?;
        apply_reduction(&loss_tensor, reduction)
    }

    /// Negative log-likelihood loss over class-index targets, reduced with
    /// [`Reduction::Mean`]. Expects `self` to hold log-probabilities.
    pub fn nll_loss(&self, target: &Tensor<i64>) -> Result<Self> {
        self.nll_loss_with_reduction(target, Reduction::Mean)
    }

    /// Negative log-likelihood: for each row `i`, picks `-self[i, target[i]]`,
    /// then reduces the per-sample losses.
    ///
    /// # Errors
    /// - `InvalidShape` unless `self` is 2-D `(N, C)` and `target` is 1-D `(N,)`.
    /// - `ShapeMismatch` when the batch dimensions disagree.
    /// - `InvalidArgument` when a target class index falls outside `[0, C)`.
    pub fn nll_loss_with_reduction(&self, target: &Tensor<i64>, reduction: Reduction) -> Result<Self> {
        let self_shape = self.shape();
        let target_shape = target.shape();
        if self_shape.ndim() != 2 {
            return Err(TorshError::InvalidShape(
                "NLL loss expects 2D input tensor (N, C)".to_string()
            ));
        }
        if target_shape.ndim() != 1 {
            return Err(TorshError::InvalidShape(
                "NLL loss expects 1D target tensor (N,)".to_string()
            ));
        }
        let batch_size = self_shape.dims()[0];
        let num_classes = self_shape.dims()[1];
        if target_shape.dims()[0] != batch_size {
            return Err(TorshError::ShapeMismatch {
                expected: vec![batch_size],
                got: target_shape.dims().to_vec(),
            });
        }
        let self_data = self.data()?;
        let target_data = target.data()?;
        let mut losses = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            let target_class = target_data[i];
            // Reject out-of-range class indices before indexing into the row.
            if target_class < 0 || target_class as usize >= num_classes {
                return Err(TorshError::InvalidArgument(
                    format!("Target class {} out of range [0, {})", target_class, num_classes)
                ));
            }
            // Row-major layout: row i starts at i * num_classes.
            let log_prob_idx = i * num_classes + target_class as usize;
            let log_prob = self_data[log_prob_idx];
            losses.push(-log_prob);
        }
        let loss_tensor = Self::from_data(losses, vec![batch_size], self.device)?;
        apply_reduction(&loss_tensor, reduction)
    }

    /// Cross-entropy loss over raw logits, reduced with [`Reduction::Mean`].
    pub fn cross_entropy(&self, target: &Tensor<i64>) -> Result<Self> {
        self.cross_entropy_with_reduction(target, Reduction::Mean)
    }

    /// Cross-entropy = log-softmax over the last dimension followed by NLL.
    pub fn cross_entropy_with_reduction(&self, target: &Tensor<i64>, reduction: Reduction) -> Result<Self> {
        let log_probs = self.log_softmax(-1)?;
        log_probs.nll_loss_with_reduction(target, reduction)
    }

    /// Returns `ShapeMismatch` unless `self` and `target` have identical shapes.
    fn validate_same_shape(&self, target: &Self) -> Result<()> {
        if self.shape() != target.shape() {
            return Err(TorshError::ShapeMismatch {
                expected: self.shape().dims().to_vec(),
                got: target.shape().dims().to_vec(),
            });
        }
        Ok(())
    }

    /// Applies `f` to each (prediction, target) pair and wraps the results in
    /// a tensor with `self`'s shape and device. Callers must have validated
    /// that the shapes match.
    fn elementwise_loss<F>(&self, target: &Self, f: F) -> Result<Self>
    where
        F: Fn(T, T) -> T,
    {
        let self_data = self.data()?;
        let target_data = target.data()?;
        let values: Vec<T> = self_data
            .iter()
            .zip(target_data.iter())
            .map(|(&pred, &targ)| f(pred, targ))
            .collect();
        Self::from_data(values, self.shape().dims().to_vec(), self.device)
    }

    /// Absolute value computed without requiring an `abs` method on `T`.
    fn abs_val(x: T) -> T {
        if x >= <T as TensorElement>::zero() { x } else { -x }
    }
}
/// Collapses an element-wise loss tensor according to `reduction`:
/// `None` returns it untouched, `Sum` and `Mean` produce a single-element tensor.
fn apply_reduction<T: FloatElement>(loss_tensor: &Tensor<T>, reduction: Reduction) -> Result<Tensor<T>> {
    // No reduction requested: hand back a clone of the per-element losses.
    if matches!(reduction, Reduction::None) {
        return Ok(loss_tensor.clone());
    }
    let data = loss_tensor.data()?;
    let total = data
        .iter()
        .fold(<T as TensorElement>::zero(), |acc, &value| acc + value);
    let reduced = match reduction {
        Reduction::Sum => total,
        Reduction::Mean => {
            // Element count converted into T; falls back to 1 if the
            // conversion is unrepresentable (never the case for real floats).
            let count = T::from_f64(data.len() as f64)
                .unwrap_or_else(|| <T as TensorElement>::one());
            total / count
        }
        Reduction::None => unreachable!("handled by the early return above"),
    };
    Tensor::from_data(vec![reduced], vec![1], loss_tensor.device)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// Builds a 1-D f32 CPU tensor for the tests below.
    fn vec_f32(values: Vec<f32>) -> Tensor<f32> {
        let len = values.len();
        Tensor::from_data(values, vec![len], DeviceType::Cpu).expect("tensor creation failed")
    }

    #[test]
    fn test_mse_loss() {
        // Differences are -0.5, -0.5, 0.5 -> squared 0.25 each -> mean 0.25.
        let preds = vec_f32(vec![1.0, 2.0, 3.0]);
        let truth = vec_f32(vec![1.5, 2.5, 2.5]);
        let loss = preds.mse_loss(&truth).expect("mse_loss failed");
        let values = loss.data().expect("data retrieval failed");
        assert!((values[0] - 0.25).abs() < 1e-6);
    }

    #[test]
    fn test_l1_loss() {
        // Absolute differences are 0.5 each -> mean 0.5.
        let preds = vec_f32(vec![1.0, 2.0, 3.0]);
        let truth = vec_f32(vec![1.5, 2.5, 2.5]);
        let loss = preds.l1_loss(&truth).expect("l1_loss failed");
        let values = loss.data().expect("data retrieval failed");
        assert!((values[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_huber_loss() {
        // With delta = 1: |diff| = 0.5, 0.5 fall in the quadratic region
        // (0.125 each) and |diff| = 3 in the linear region (1 * (3 - 0.5) = 2.5),
        // so the mean is 2.75 / 3 ~= 0.9166667.
        let preds = vec_f32(vec![1.0, 2.0, 5.0]);
        let truth = vec_f32(vec![1.5, 2.5, 2.0]);
        let loss = preds.huber_loss(&truth, 1.0).expect("huber_loss failed");
        let values = loss.data().expect("data retrieval failed");
        assert!((values[0] - 0.91666667).abs() < 1e-6);
    }

    #[test]
    fn test_bce_loss() {
        // Confident, mostly-correct predictions: loss is small, positive, finite.
        let preds = vec_f32(vec![0.8, 0.2, 0.9]);
        let labels = vec_f32(vec![1.0, 0.0, 1.0]);
        let loss = preds.bce_loss(&labels).expect("bce_loss failed");
        let values = loss.data().expect("data retrieval failed");
        assert!(values[0] > 0.0);
        assert!(values[0].is_finite());
    }

    #[test]
    fn test_nll_loss() {
        // Picks -log_probs[0, 0] = 0.5 and -log_probs[1, 1] = 0.3 -> mean 0.4.
        let log_probs = Tensor::from_data(
            vec![-0.5f32, -1.0, -2.0, -1.5, -0.3, -3.0],
            vec![2, 3],
            DeviceType::Cpu
        ).expect("tensor creation failed");
        let class_ids = Tensor::from_data(vec![0i64, 1], vec![2], DeviceType::Cpu)
            .expect("tensor creation failed");
        let loss = log_probs.nll_loss(&class_ids).expect("nll_loss failed");
        let values = loss.data().expect("data retrieval failed");
        assert!((values[0] - 0.4).abs() < 1e-6);
    }

    #[test]
    fn test_cross_entropy() {
        // Logits + class indices: only sanity-check positivity and finiteness.
        let logits = Tensor::from_data(
            vec![1.0f32, 2.0, 0.5, 0.8, 3.0, 0.2],
            vec![2, 3],
            DeviceType::Cpu
        ).expect("tensor creation failed");
        let class_ids = Tensor::from_data(vec![1i64, 1], vec![2], DeviceType::Cpu)
            .expect("tensor creation failed");
        let loss = logits.cross_entropy(&class_ids).expect("cross_entropy failed");
        let values = loss.data().expect("data retrieval failed");
        assert!(values[0] > 0.0);
        assert!(values[0].is_finite());
    }

    #[test]
    fn test_reduction_modes() {
        let preds = vec_f32(vec![1.0, 2.0, 3.0]);
        let truth = vec_f32(vec![1.5, 2.5, 2.5]);
        // None keeps one loss value per element.
        let unreduced = preds
            .mse_loss_with_reduction(&truth, Reduction::None)
            .expect("mse_loss_with_reduction failed");
        assert_eq!(unreduced.data().expect("data retrieval failed").len(), 3);
        // Sum and Mean both collapse to a single value.
        let summed = preds
            .mse_loss_with_reduction(&truth, Reduction::Sum)
            .expect("mse_loss_with_reduction failed");
        let sum_values = summed.data().expect("data retrieval failed");
        assert_eq!(sum_values.len(), 1);
        let averaged = preds.mse_loss(&truth).expect("mse_loss failed");
        let mean_values = averaged.data().expect("data retrieval failed");
        assert_eq!(mean_values.len(), 1);
        // mean * element count must equal the sum.
        assert!((mean_values[0] * 3.0 - sum_values[0]).abs() < 1e-6);
    }
}