//! Loss wrappers for global (`WeightedLoss`) and per-sample
//! (`SampleWeightedLoss`) loss weighting.

use ndarray::Array1;

use super::LossFn;
use crate::Tensor;
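
/// Scales the output of an inner [`LossFn`] by a fixed weight.
///
/// Useful for combining several objectives into a single scalar, e.g.
/// `total = w1 * task_loss + w2 * auxiliary_loss`.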
pub struct WeightedLoss {
    inner: Box<dyn LossFn>,
    weight: f32,
}

impl WeightedLoss {
    /// Creates a weighted wrapper around `inner` with the given scale factor.
    pub fn new(inner: Box<dyn LossFn>, weight: f32) -> Self {
        Self { inner, weight }
    }

    /// Creates a wrapper with weight `1.0`, i.e. a pass-through.
    pub fn unweighted(inner: Box<dyn LossFn>) -> Self {
        Self::new(inner, 1.0)
    }

    /// Returns the current weight.
    pub fn weight(&self) -> f32 {
        self.weight
    }

    /// Updates the weight, e.g. for loss-weight schedules during training.
    pub fn set_weight(&mut self, weight: f32) {
        self.weight = weight;
    }
}

impl LossFn for WeightedLoss {
    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
        let inner_loss = self.inner.forward(predictions, targets);
        // Fast path: a weight of (effectively) 1.0 changes nothing, so the
        // inner loss can be returned as-is, backward op included.
        if (self.weight - 1.0).abs() < 1e-7 {
            return inner_loss;
        }
        let weighted_val = inner_loss.data()[0] * self.weight;
        let mut weighted_loss = Tensor::from_vec(vec![weighted_val], true);

        use crate::autograd::BackwardOp;
        use std::rc::Rc;

        struct WeightedBackward {
            inner_backward: Option<Rc<dyn BackwardOp>>,
            pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
            weight: f32,
        }

        impl BackwardOp for WeightedBackward {
            fn backward(&self) {
                // Snapshot whatever gradient was accumulated before this op,
                // then let the inner loss run its backward pass.
                let before = self.pred_grad_cell.borrow().clone();
                if let Some(ref inner) = self.inner_backward {
                    inner.backward();
                }
                // Scale only the contribution the inner pass just added, so
                // that d(w * L)/d pred = w * dL/d pred. (A more general engine
                // would scale the upstream gradient instead.)
                let mut grad = self.pred_grad_cell.borrow_mut();
                if let Some(after) = grad.as_mut() {
                    *after = match &before {
                        Some(b) => b + &((&*after - b) * self.weight),
                        None => &*after * self.weight,
                    };
                }
            }
        }

        if predictions.requires_grad() {
            weighted_loss.set_backward_op(Rc::new(WeightedBackward {
                inner_backward: inner_loss.backward_op(),
                pred_grad_cell: predictions.grad_cell(),
                weight: self.weight,
            }));
        }
        weighted_loss
    }

    fn name(&self) -> &'static str {
        "Weighted"
    }
}
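
/// Applies per-sample weights to a squared-error loss, e.g. to emphasize
/// hard or rare examples.
///
/// Note: `forward_weighted` computes a weighted MSE directly; the wrapped
/// `inner` loss is stored for API symmetry but not consulted yet.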
pub struct SampleWeightedLoss {
    #[allow(dead_code)]
    inner: Box<dyn LossFn>,
}

impl SampleWeightedLoss {
    /// Wraps `inner` (currently unused; see the struct-level note).
    pub fn new(inner: Box<dyn LossFn>) -> Self {
        Self { inner }
    }
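
    /// Computes a per-sample weighted squared error:
    /// `loss = (1/n) * sum_i w_i * (p_i - t_i)^2`, with gradient
    /// `d(loss)/d(p_i) = 2 * w_i * (p_i - t_i) / n`.
    ///
    /// # Panics
    ///
    /// Panics if `weights.len() != predictions.len()`.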
    pub fn forward_weighted(
        &self,
        predictions: &Tensor,
        targets: &Tensor,
        weights: &[f32],
    ) -> Tensor {
        assert_eq!(
            predictions.len(),
            weights.len(),
            "Weights must match predictions length"
        );
        let diff = predictions.data() - targets.data();
        let n = predictions.len() as f32;
        // loss = (1/n) * sum_i w_i * d_i^2, where d_i = p_i - t_i.
        let weighted_loss: f32 =
            diff.iter().zip(weights.iter()).map(|(&d, &w)| w * d * d).sum::<f32>() / n;
        let mut loss = Tensor::from_vec(vec![weighted_loss], true);
        // d(loss)/d(p_i) = 2 * w_i * d_i / n.
        let grad: Array1<f32> =
            diff.iter().zip(weights.iter()).map(|(&d, &w)| 2.0 * w * d / n).collect();

        use crate::autograd::BackwardOp;
        use std::rc::Rc;

        struct SampleWeightedBackward {
            pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
            grad: Array1<f32>,
        }

        impl BackwardOp for SampleWeightedBackward {
            fn backward(&self) {
                // Accumulate into any existing gradient rather than
                // overwriting, so repeated backward passes add up.
                let mut pred_grad = self.pred_grad_cell.borrow_mut();
                if let Some(existing) = pred_grad.as_mut() {
                    *existing = &*existing + &self.grad;
                } else {
                    *pred_grad = Some(self.grad.clone());
                }
            }
        }

        if predictions.requires_grad() {
            loss.set_backward_op(Rc::new(SampleWeightedBackward {
                pred_grad_cell: predictions.grad_cell(),
                grad,
            }));
        }
        loss
    }
}

impl LossFn for SampleWeightedLoss {
    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
        // Uniform weights reduce to a plain MSE.
        let weights = vec![1.0; predictions.len()];
        self.forward_weighted(predictions, targets, &weights)
    }

    fn name(&self) -> &'static str {
        "SampleWeighted"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::MSELoss;
    use approx::assert_relative_eq;

    #[test]
    fn test_weighted_loss_scales_value() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.5);
        let unweighted = MSELoss;
        let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5, 3.5], false);
        let weighted = loss_fn.forward(&pred, &target);
        let base = unweighted.forward(&pred, &target);
        assert_relative_eq!(weighted.data()[0], base.data()[0] * 1.5, epsilon = 1e-5);
    }

    #[test]
    fn test_weighted_loss_unit_weight() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.0);
        let unweighted = MSELoss;
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);
        let weighted = loss_fn.forward(&pred, &target);
        let base = unweighted.forward(&pred, &target);
        assert_relative_eq!(weighted.data()[0], base.data()[0], epsilon = 1e-5);
    }

    #[test]
    fn test_weighted_loss_zero_weight() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 0.0);
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![10.0, 20.0], false);
        let loss = loss_fn.forward(&pred, &target);
        assert_relative_eq!(loss.data()[0], 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_weighted_loss_methods() {
        let mut loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.5);
        assert_eq!(loss_fn.weight(), 1.5);
        assert_eq!(loss_fn.name(), "Weighted");
        loss_fn.set_weight(2.0);
        assert_eq!(loss_fn.weight(), 2.0);
    }

    #[test]
    fn test_weighted_loss_unweighted() {
        let loss_fn = WeightedLoss::unweighted(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);
        let loss = loss_fn.forward(&pred, &target);
        assert_eq!(loss_fn.weight(), 1.0);
        assert!(loss.data()[0] > 0.0);
    }

    #[test]
    fn test_weighted_no_grad() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 1.5);
        let pred = Tensor::from_vec(vec![1.0, 2.0], false);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);
        let loss = loss_fn.forward(&pred, &target);
        assert!(loss.data()[0] > 0.0);
    }

    #[test]
    fn test_weighted_backward_with_grad() {
        let loss_fn = WeightedLoss::new(Box::new(MSELoss), 2.0);
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![0.0, 0.0], false);
        let loss = loss_fn.forward(&pred, &target);
        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }
        let grad = pred.grad();
        assert!(grad.is_some());
    }
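
    // Checks the property implemented by `WeightedBackward`: the gradient of
    // `weight * loss` should be `weight` times the inner loss's own gradient.
    // Compared against a second, unweighted tensor so no assumption is made
    // about the exact MSE gradient formula.
    #[test]
    fn test_weighted_backward_scales_gradient() {
        let pred_weighted = Tensor::from_vec(vec![1.0, 2.0], true);
        let pred_base = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![0.0, 0.0], false);
        let weighted = WeightedLoss::new(Box::new(MSELoss), 2.0).forward(&pred_weighted, &target);
        let base = MSELoss.forward(&pred_base, &target);
        if let Some(op) = weighted.backward_op() {
            op.backward();
        }
        if let Some(op) = base.backward_op() {
            op.backward();
        }
        let grad_weighted = pred_weighted.grad().expect("weighted gradient");
        let grad_base = pred_base.grad().expect("base gradient");
        assert_relative_eq!(grad_weighted[0], 2.0 * grad_base[0], epsilon = 1e-5);
        assert_relative_eq!(grad_weighted[1], 2.0 * grad_base[1], epsilon = 1e-5);
    }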

    #[test]
    fn test_sample_weighted_loss_uniform() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let target = Tensor::from_vec(vec![1.5, 2.5, 3.5], false);
        let loss = loss_fn.forward(&pred, &target);
        let mse_loss = MSELoss.forward(&pred, &target);
        assert_relative_eq!(loss.data()[0], mse_loss.data()[0], epsilon = 1e-5);
    }

    #[test]
    fn test_sample_weighted_loss_custom_weights() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![0.0, 0.0], true);
        let target = Tensor::from_vec(vec![1.0, 1.0], false);
        let weights = vec![2.0, 0.0];
        // loss = (2.0 * 1 + 0.0 * 1) / 2 = 1.0
        let loss = loss_fn.forward_weighted(&pred, &target, &weights);
        assert_relative_eq!(loss.data()[0], 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_sample_weighted_loss_gradient() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![0.0, 0.0], true);
        let target = Tensor::from_vec(vec![1.0, 1.0], false);
        let weights = vec![2.0, 1.0];
        let loss = loss_fn.forward_weighted(&pred, &target, &weights);
        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }
        // grad_i = 2 * w_i * (p_i - t_i) / n = [-2.0, -1.0]
        let grad = pred.grad().expect("gradient should be available");
        assert_relative_eq!(grad[0], -2.0, epsilon = 1e-5);
        assert_relative_eq!(grad[1], -1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_sample_weighted_citl_reweight() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![0.0, 0.0, 0.0], true);
        let target = Tensor::from_vec(vec![1.0, 1.0, 1.0], false);
        // Upweighting the first two samples should raise the loss above the
        // uniform baseline.
        let weights = vec![1.5, 1.5, 1.0];
        let weighted_loss = loss_fn.forward_weighted(&pred, &target, &weights);
        let uniform = loss_fn.forward(&pred, &target);
        assert!(weighted_loss.data()[0] > uniform.data()[0]);
    }

    #[test]
    fn test_sample_weighted_no_grad() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0], false);
        let target = Tensor::from_vec(vec![1.5, 2.5], false);
        let weights = vec![1.0, 2.0];
        let loss = loss_fn.forward_weighted(&pred, &target, &weights);
        assert!(loss.data()[0] > 0.0);
    }
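
    // Zero sample weights should zero out both the loss and the gradient, per
    // the `2 * w_i * d_i / n` gradient formula in `forward_weighted`.
    #[test]
    fn test_sample_weighted_zero_weights() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![5.0, 5.0], false);
        let weights = vec![0.0, 0.0];
        let loss = loss_fn.forward_weighted(&pred, &target, &weights);
        assert_relative_eq!(loss.data()[0], 0.0, epsilon = 1e-6);
        if let Some(op) = loss.backward_op() {
            op.backward();
        }
        let grad = pred.grad().expect("gradient should be available");
        assert_relative_eq!(grad[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(grad[1], 0.0, epsilon = 1e-6);
    }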

    #[test]
    #[should_panic(expected = "Weights must match")]
    fn test_sample_weighted_mismatched_weights() {
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        let weights = vec![1.0, 1.0];
        loss_fn.forward_weighted(&pred, &target, &weights);
    }

    #[test]
    fn test_gradient_accumulation_sample_weighted() {
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![0.0, 0.0], false);
        let weights = vec![1.0, 1.5];
        let loss_fn = SampleWeightedLoss::new(Box::new(MSELoss));
        let loss1 = loss_fn.forward_weighted(&pred, &target, &weights);
        if let Some(op) = loss1.backward_op() {
            op.backward();
        }
        let loss2 = loss_fn.forward_weighted(&pred, &target, &weights);
        if let Some(op) = loss2.backward_op() {
            op.backward();
        }
        // Each pass contributes 2 * w_i * d_i / n = [1.0, 3.0]; two passes
        // should accumulate to exactly twice that.
        let grad = pred.grad().expect("gradient should be available");
        assert_relative_eq!(grad[0], 2.0, epsilon = 1e-5);
        assert_relative_eq!(grad[1], 6.0, epsilon = 1e-5);
    }
}