entrenar/train/loss/traits.rs
1//! Loss function trait
2
3use crate::Tensor;
4
5/// Trait for loss functions
6pub trait LossFn {
7 /// Compute loss given predictions and targets
8 ///
9 /// Returns a scalar loss value and sets up gradients for backpropagation
10 fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor;
11
12 /// Name of the loss function
13 fn name(&self) -> &str;
14}