Skip to main content

entrenar/train/loss/
traits.rs

//! Loss function trait.
2
3use crate::Tensor;
4
5/// Trait for loss functions
6pub trait LossFn {
7    /// Compute loss given predictions and targets
8    ///
9    /// Returns a scalar loss value and sets up gradients for backpropagation
10    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor;
11
12    /// Name of the loss function
13    fn name(&self) -> &str;
14}