Standard loss functions such as mse_loss(), cross_entropy_with_logits_loss(), and more.
Functions
binary_cross_entropy_with_logits_loss: Binary cross entropy with logits, computed in a numerically stable way.
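A minimal sketch of the stable formulation on plain f32 slices (the helper name and slice-based signature are illustrative, not this crate's tensor API). For a logit x and target probability z it uses max(x, 0) - x * z + ln(1 + exp(-|x|)), which avoids overflow in exp for large |x|:

```rust
/// Illustrative helper (not the crate's API): numerically stable binary
/// cross entropy with logits, averaged over all elements.
fn bce_with_logits(logits: &[f32], target_probs: &[f32]) -> f32 {
    assert_eq!(logits.len(), target_probs.len());
    let sum: f32 = logits
        .iter()
        .zip(target_probs)
        // max(x, 0) - x * z + ln(1 + exp(-|x|))
        .map(|(&x, &z)| x.max(0.0) - x * z + (-x.abs()).exp().ln_1p())
        .sum();
    sum / logits.len() as f32
}
```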
cross_entropy_with_logits_loss: Cross entropy loss. This computes: -(logits.log_softmax() * target_probs).sum(-1).mean()
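A minimal sketch of the same computation for a 2D batch on plain slices (helper names are illustrative, not this crate's tensor API). log_softmax subtracts the row max before exponentiating for stability; the loss sums over classes and averages over rows, matching sum(-1).mean():

```rust
/// Illustrative helper (not the crate's API): numerically stable log-softmax.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the max before exponentiating so exp never overflows.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = logits.iter().map(|&x| (x - max).exp()).sum::<f32>().ln() + max;
    logits.iter().map(|&x| x - log_sum_exp).collect()
}

/// -(log_softmax(logits) * target_probs) summed over classes, averaged over rows.
fn cross_entropy_with_logits(logits: &[Vec<f32>], target_probs: &[Vec<f32>]) -> f32 {
    let total: f32 = logits
        .iter()
        .zip(target_probs)
        .map(|(l, t)| {
            -log_softmax(l).iter().zip(t.iter()).map(|(&lp, &p)| lp * p).sum::<f32>()
        })
        .sum();
    total / logits.len() as f32
}
```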
huber_loss: Huber loss. Uses absolute error when the error is higher than beta, and squared error when the error is lower than beta.
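A minimal sketch on plain slices (the helper is illustrative, not this crate's API), assuming the common convention of 0.5 * e^2 inside the threshold and beta * (|e| - 0.5 * beta) outside, which keeps the two pieces continuous at |e| = beta:

```rust
/// Illustrative helper (not the crate's API): mean Huber loss, assuming
/// 0.5 * e^2 when |e| < beta, and beta * (|e| - 0.5 * beta) otherwise.
fn huber(pred: &[f32], targ: &[f32], beta: f32) -> f32 {
    let sum: f32 = pred
        .iter()
        .zip(targ)
        .map(|(&p, &t)| {
            let e = p - t;
            if e.abs() < beta {
                0.5 * e * e // squared error for small errors
            } else {
                beta * (e.abs() - 0.5 * beta) // absolute (linear) error for large errors
            }
        })
        .sum();
    sum / pred.len() as f32
}
```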
kl_div_with_logits_loss: KL divergence loss. This computes: (target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean()
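A minimal sketch for a 2D batch on plain slices (helper names are illustrative, not this crate's API). Terms where the target probability is zero are treated as 0, following the usual 0 * ln(0) = 0 convention:

```rust
/// Illustrative helper (not the crate's API): numerically stable log-softmax.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = logits.iter().map(|&x| (x - max).exp()).sum::<f32>().ln() + max;
    logits.iter().map(|&x| x - log_sum_exp).collect()
}

/// KL(target_probs || softmax(logits)), summed over classes, averaged over rows.
fn kl_div_with_logits(logits: &[Vec<f32>], target_probs: &[Vec<f32>]) -> f32 {
    let total: f32 = logits
        .iter()
        .zip(target_probs)
        .map(|(l, t)| {
            log_softmax(l)
                .iter()
                .zip(t.iter())
                // Treat 0 * ln(0) as 0 to avoid NaN for zero target probabilities.
                .map(|(&lp, &p)| if p > 0.0 { p * (p.ln() - lp) } else { 0.0 })
                .sum::<f32>()
        })
        .sum();
    total / logits.len() as f32
}
```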
mae_loss: Mean absolute error. This computes: (pred - targ).abs().mean()
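A minimal sketch on plain slices (illustrative helper, not this crate's API):

```rust
/// Illustrative helper (not the crate's API): mean absolute error.
fn mae(pred: &[f32], targ: &[f32]) -> f32 {
    let sum: f32 = pred.iter().zip(targ).map(|(&p, &t)| (p - t).abs()).sum();
    sum / pred.len() as f32
}
```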
mse_loss: Mean squared error. This computes: (pred - targ).square().mean()
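A minimal sketch on plain slices (illustrative helper, not this crate's API):

```rust
/// Illustrative helper (not the crate's API): mean squared error.
fn mse(pred: &[f32], targ: &[f32]) -> f32 {
    let sum: f32 = pred.iter().zip(targ).map(|(&p, &t)| (p - t) * (p - t)).sum();
    sum / pred.len() as f32
}
```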
rmse_loss: Root mean squared error. This computes: (pred - targ).square().mean().sqrt()
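A minimal sketch on plain slices (illustrative helper, not this crate's API); it is just the square root of the mean squared error:

```rust
/// Illustrative helper (not the crate's API): root mean squared error.
fn rmse(pred: &[f32], targ: &[f32]) -> f32 {
    let mse: f32 = pred.iter().zip(targ).map(|(&p, &t)| (p - t) * (p - t)).sum::<f32>()
        / pred.len() as f32;
    mse.sqrt()
}
```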
smooth_l1_loss: Smooth L1 loss (closely related to Huber loss). Uses absolute error when the error is higher than beta, and squared error when the error is lower than beta.
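A minimal sketch on plain slices (illustrative helper, not this crate's API), assuming the common PyTorch-style scaling in which smooth L1 equals the Huber loss divided by beta; the exact scaling convention here is an assumption, not confirmed by the text above:

```rust
/// Illustrative helper (not the crate's API): mean smooth L1 loss, assuming
/// 0.5 * e^2 / beta when |e| < beta, and |e| - 0.5 * beta otherwise
/// (i.e. the Huber loss divided by beta).
fn smooth_l1(pred: &[f32], targ: &[f32], beta: f32) -> f32 {
    let sum: f32 = pred
        .iter()
        .zip(targ)
        .map(|(&p, &t)| {
            let e = (p - t).abs();
            if e < beta { 0.5 * e * e / beta } else { e - 0.5 * beta }
        })
        .sum();
    sum / pred.len() as f32
}
```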