Standard loss functions such as mse_loss(), cross_entropy_with_logits_loss(), and more.
Functions
- Binary cross entropy with logits, computed in a numerically stable way.
- Cross entropy loss. This computes: -(logits.log_softmax() * target_probs).sum(-1).mean()
- Huber loss uses absolute error when the error is higher than beta, and squared error when the error is lower than beta.
- KL divergence loss. This computes: (target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean()
- Mean absolute error. This computes: (pred - targ).abs().mean()
- Mean squared error. This computes: (pred - targ).square().mean()
- Root mean square error. This computes: (pred - targ).square().mean().sqrt()
- Smooth L1 loss (closely related to Huber loss) uses absolute error when the error is higher than beta, and squared error when the error is lower than beta.
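The list above says binary cross entropy with logits is computed "in a numerically stable way" without showing how. A common way to get that stability is the rewrite -[t ln sigmoid(x) + (1 - t) ln(1 - sigmoid(x))] = max(x, 0) - x t + ln(1 + e^(-|x|)), which never exponentiates a large positive number. The sketch below assumes that formulation and averages over all elements; it works on plain f64 slices rather than the library's tensor types, and the name `bce_with_logits` is hypothetical.

```rust
/// Hypothetical helper: numerically stable binary cross entropy with logits,
/// averaged over all elements. Assumes the identity
///   -[t*ln(sigmoid(x)) + (1-t)*ln(1-sigmoid(x))] = max(x, 0) - x*t + ln(1 + exp(-|x|)),
/// so exp() is only ever applied to a non-positive argument.
fn bce_with_logits(logits: &[f64], target_probs: &[f64]) -> f64 {
    assert_eq!(logits.len(), target_probs.len());
    let total: f64 = logits
        .iter()
        .zip(target_probs)
        // ln_1p(y) computes ln(1 + y) accurately for small y.
        .map(|(&x, &t)| x.max(0.0) - x * t + (-x.abs()).exp().ln_1p())
        .sum();
    total / logits.len() as f64
}
```

With a logit of 0 and a target of 0.5 the loss is ln 2, and extreme logits such as -1000.0 still produce a finite value (1000.0 for a target of 1.0) instead of overflowing.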
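The cross entropy entry above is written as -(logits.log_softmax() * target_probs).sum(-1).mean(). The following sketch spells out that expression for a 2-D (batch, classes) layout, which is an assumption: sum over the last axis, then mean over the batch. Each row is a Vec<f64>, log-softmax uses the usual max-subtraction trick for stability, and the names `log_softmax` and `cross_entropy_with_logits` are hypothetical stand-ins, not the library's signatures.

```rust
/// Hypothetical helper: log-softmax of one row, stabilized by subtracting the max:
/// log_softmax(x)_i = x_i - (max + ln(sum_j exp(x_j - max))).
fn log_softmax(row: &[f64]) -> Vec<f64> {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + row.iter().map(|&x| (x - max).exp()).sum::<f64>().ln();
    row.iter().map(|&x| x - log_sum_exp).collect()
}

/// Hypothetical helper mirroring -(logits.log_softmax() * target_probs).sum(-1).mean()
/// for a batch of rows: per-row sum over classes, then mean over rows.
fn cross_entropy_with_logits(logits: &[Vec<f64>], target_probs: &[Vec<f64>]) -> f64 {
    assert_eq!(logits.len(), target_probs.len());
    let sum_over_rows: f64 = logits
        .iter()
        .zip(target_probs)
        .map(|(l, t)| log_softmax(l).iter().zip(t).map(|(&lp, &p)| lp * p).sum::<f64>())
        .sum();
    -sum_over_rows / logits.len() as f64
}
```

For two equal logits and a one-hot target the result is ln 2, since the model assigns probability 0.5 to the correct class.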
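The KL divergence entry above computes (target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean(). The sketch below evaluates that expression for a 2-D (batch, classes) layout (an assumption) using plain Rust vectors. One detail worth flagging: evaluated literally, a zero target probability gives 0 * ln 0 = 0 * -inf, which is NaN in IEEE arithmetic, so this sketch skips such terms under the usual convention 0 * ln 0 = 0. The names `log_softmax` and `kl_div_with_logits` are hypothetical.

```rust
/// Hypothetical helper: numerically stable log-softmax of one row.
fn log_softmax(row: &[f64]) -> Vec<f64> {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + row.iter().map(|&x| (x - max).exp()).sum::<f64>().ln();
    row.iter().map(|&x| x - log_sum_exp).collect()
}

/// Hypothetical helper mirroring
/// (target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean().
/// Terms with a zero target probability contribute 0 (convention 0 * ln 0 = 0).
fn kl_div_with_logits(logits: &[Vec<f64>], target_probs: &[Vec<f64>]) -> f64 {
    assert_eq!(logits.len(), target_probs.len());
    let total: f64 = logits
        .iter()
        .zip(target_probs)
        .map(|(l, t)| {
            log_softmax(l)
                .iter()
                .zip(t)
                .filter(|pair| *pair.1 > 0.0) // skip p == 0 to avoid 0 * -inf
                .map(|(&lp, &p)| p * (p.ln() - lp))
                .sum::<f64>()
        })
        .sum();
    total / logits.len() as f64
}
```

When the logits already encode the target distribution the loss is 0; a one-hot target against uniform logits over two classes gives ln 2.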
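The mean absolute error, mean squared error, and root mean square error entries above differ only in how each elementwise difference is transformed before averaging. This sketch writes the three expressions out over plain f64 slices; the function names are hypothetical and the slices stand in for the library's tensors.

```rust
/// Hypothetical helper: (pred - targ).abs().mean()
fn mae(pred: &[f64], targ: &[f64]) -> f64 {
    assert_eq!(pred.len(), targ.len());
    pred.iter().zip(targ).map(|(&p, &t)| (p - t).abs()).sum::<f64>() / pred.len() as f64
}

/// Hypothetical helper: (pred - targ).square().mean()
fn mse(pred: &[f64], targ: &[f64]) -> f64 {
    assert_eq!(pred.len(), targ.len());
    pred.iter().zip(targ).map(|(&p, &t)| (p - t).powi(2)).sum::<f64>() / pred.len() as f64
}

/// Hypothetical helper: (pred - targ).square().mean().sqrt(), i.e. the square root of the MSE.
fn rmse(pred: &[f64], targ: &[f64]) -> f64 {
    mse(pred, targ).sqrt()
}
```

For pred = [1, 2, 3] and targ = [1, 1, 5] the differences are 0, 1, and -2, so MAE = 1, MSE = 5/3, and RMSE = sqrt(5/3). Because MSE squares the differences, the single large error dominates it more than it dominates MAE.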
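The Huber and smooth L1 entries above only say when each switches between squared and absolute error. The exact formulas are not given here, so the sketch below assumes the widely used definitions with d = |pred - targ|: Huber is 0.5 d^2 if d < beta, else beta (d - 0.5 beta); smooth L1 is 0.5 d^2 / beta if d < beta, else d - 0.5 beta. Under these assumed definitions smooth L1 equals Huber divided by beta, which is one concrete sense in which the two are "closely related". Both are averaged over all elements, and the names `huber` and `smooth_l1` are hypothetical.

```rust
/// Hypothetical helper (assumed definition): Huber loss averaged over all elements.
/// Quadratic for |d| < beta, linear with slope beta beyond it; the pieces meet at |d| = beta.
fn huber(pred: &[f64], targ: &[f64], beta: f64) -> f64 {
    assert_eq!(pred.len(), targ.len());
    let total: f64 = pred
        .iter()
        .zip(targ)
        .map(|(&p, &t)| {
            let d = (p - t).abs();
            if d < beta { 0.5 * d * d } else { beta * (d - 0.5 * beta) }
        })
        .sum();
    total / pred.len() as f64
}

/// Hypothetical helper (assumed definition): smooth L1 loss averaged over all elements.
/// Quadratic scaled by 1/beta for |d| < beta, plain absolute error minus 0.5*beta beyond it.
fn smooth_l1(pred: &[f64], targ: &[f64], beta: f64) -> f64 {
    assert_eq!(pred.len(), targ.len());
    let total: f64 = pred
        .iter()
        .zip(targ)
        .map(|(&p, &t)| {
            let d = (p - t).abs();
            if d < beta { 0.5 * d * d / beta } else { d - 0.5 * beta }
        })
        .sum();
    total / pred.len() as f64
}
```

With beta = 1 the two functions coincide; with beta = 2, pred = [0, 3], and targ = [0.5, 0], Huber gives 2.0625 and smooth L1 gives 1.03125, exactly half.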