Standard loss functions such as mse_loss(), cross_entropy_with_logits_loss(), and more.

Functions

Binary cross entropy with logits, computed in a numerically stable way.
Cross entropy loss. This computes: -(logits.log_softmax() * target_probs).sum(-1).mean()
Huber Loss uses absolute error when the error is higher than beta, and squared error when the error is lower than beta.
KL Divergence loss. This computes (target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean()
Mean absolute error. This computes (pred - targ).abs().mean()
Mean Squared Error. This computes (pred - targ).square().mean().
Root mean squared error. This computes (pred - targ).square().mean().sqrt()
Smooth L1 loss (closely related to Huber loss) uses absolute error when the error is higher than beta, and squared error when the error is lower than beta.
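
The formulas above can be sketched in plain Rust on `f64` slices, which may help when checking expected values by hand. This is an illustrative sketch only, not this module's implementation: the slice-based signatures and the `mean` and `log_softmax` helpers are hypothetical, the function names merely echo the ones listed above, and the exact Huber and smooth L1 scaling follows the commonly used convention (smooth L1 equals Huber divided by beta), which is an assumption about what these functions compute.

```rust
/// Hypothetical helper: arithmetic mean of an iterator of values.
fn mean(values: impl Iterator<Item = f64>) -> f64 {
    let (sum, n) = values.fold((0.0, 0usize), |(s, n), v| (s + v, n + 1));
    sum / n as f64
}

/// (pred - targ).square().mean()
pub fn mse_loss(pred: &[f64], targ: &[f64]) -> f64 {
    mean(pred.iter().zip(targ).map(|(p, t)| (p - t).powi(2)))
}

/// (pred - targ).abs().mean()
pub fn mae_loss(pred: &[f64], targ: &[f64]) -> f64 {
    mean(pred.iter().zip(targ).map(|(p, t)| (p - t).abs()))
}

/// (pred - targ).square().mean().sqrt()
pub fn rmse_loss(pred: &[f64], targ: &[f64]) -> f64 {
    mse_loss(pred, targ).sqrt()
}

/// Assumed convention: 0.5 * e^2 when |e| < beta, else beta * (|e| - 0.5 * beta).
pub fn huber_loss(pred: &[f64], targ: &[f64], beta: f64) -> f64 {
    mean(pred.iter().zip(targ).map(|(p, t)| {
        let e = (p - t).abs();
        if e < beta { 0.5 * e * e } else { beta * (e - 0.5 * beta) }
    }))
}

/// Assumed convention: the Huber loss above divided by beta.
pub fn smooth_l1_loss(pred: &[f64], targ: &[f64], beta: f64) -> f64 {
    huber_loss(pred, targ, beta) / beta
}

/// Stable form of binary cross entropy on logits:
/// max(x, 0) - x * y + ln(1 + exp(-|x|)), which never exponentiates a
/// large positive number.
pub fn bce_with_logits_loss(logits: &[f64], targ: &[f64]) -> f64 {
    mean(logits.iter().zip(targ).map(|(&x, &y)| {
        x.max(0.0) - x * y + (-x.abs()).exp().ln_1p()
    }))
}

/// Hypothetical helper: log-softmax of one row, shifted by the row maximum
/// so that exp() cannot overflow.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum = logits.iter().map(|x| (x - max).exp()).sum::<f64>().ln();
    logits.iter().map(|x| x - max - log_sum).collect()
}

/// -(logits.log_softmax() * target_probs).sum(-1).mean(), one row per sample.
pub fn cross_entropy_with_logits_loss(logits: &[Vec<f64>], target_probs: &[Vec<f64>]) -> f64 {
    mean(logits.iter().zip(target_probs).map(|(row, probs)| {
        -log_softmax(row).iter().zip(probs).map(|(l, p)| l * p).sum::<f64>()
    }))
}
```

Under these assumptions, Huber and smooth L1 coincide when beta is 1 and differ only by the factor 1 / beta otherwise. A call such as `bce_with_logits_loss(&[1000.0], &[1.0])` stays finite because the exponential is only ever taken of a non-positive number.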