Standard loss functions such as mse_loss(), cross_entropy_with_logits_loss(), and more.

Functions

Binary cross entropy with logits, computed in a numerically stable way.
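To illustrate what "numerically stable" can mean here, the sketch below uses one common formulation, max(x, 0) - x * t + ln(1 + exp(-|x|)), which never exponentiates a large positive number. It works on plain slices using only the standard library; the function name `bce_with_logits` is hypothetical and this is not necessarily how the library implements it.

```rust
/// Binary cross entropy with logits, averaged over all elements.
/// Hypothetical helper on plain slices, not the library's API.
fn bce_with_logits(logits: &[f64], targets: &[f64]) -> f64 {
    assert_eq!(logits.len(), targets.len());
    assert!(!logits.is_empty());
    let total: f64 = logits
        .iter()
        .zip(targets)
        .map(|(&x, &t)| {
            // Equivalent to -(t * ln(sigmoid(x)) + (1 - t) * ln(1 - sigmoid(x))),
            // rearranged so that exp() only ever sees a non-positive argument.
            x.max(0.0) - x * t + (-x.abs()).exp().ln_1p()
        })
        .sum();
    total / logits.len() as f64
}
```

With a logit of 0 and a target of 1 this evaluates to ln 2, and logits as extreme as ±1000 still produce finite values.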

Cross entropy loss. This computes: -(logits.log_softmax() * target_probs).sum(-1).mean()
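As a rough illustration of the expression above, the following standard-library sketch treats each inner `Vec` as one row of logits (the last axis) and averages the per-row sums over the batch. The names `log_softmax` and `cross_entropy_with_logits` are hypothetical stand-ins on plain vectors, not the library's tensor API.

```rust
/// Log-softmax over one row, using the max-subtraction trick for stability.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = logits.iter().map(|&x| (x - max).exp()).sum::<f64>().ln() + max;
    logits.iter().map(|&x| x - log_sum_exp).collect()
}

/// -(log_softmax(logits) * target_probs).sum(-1).mean() on a batch of rows.
fn cross_entropy_with_logits(logits: &[Vec<f64>], target_probs: &[Vec<f64>]) -> f64 {
    assert_eq!(logits.len(), target_probs.len());
    assert!(!logits.is_empty());
    let total: f64 = logits
        .iter()
        .zip(target_probs)
        .map(|(row, probs)| {
            assert_eq!(row.len(), probs.len());
            let log_probs = log_softmax(row);
            // Sum over the last axis, then negate.
            -log_probs.iter().zip(probs).map(|(l, p)| l * p).sum::<f64>()
        })
        .sum();
    // Mean over the batch axis.
    total / logits.len() as f64
}
```

For uniform logits over n classes and a one-hot target, the result is ln n.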

Huber loss. Uses absolute error when the magnitude of the error is greater than beta, and squared error when it is less than beta.
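A minimal sketch of this piecewise rule, assuming the textbook Huber definition: 0.5 * e² when |e| < beta, and beta * (|e| - 0.5 * beta) otherwise, averaged over all elements. The function name `huber` and the exact scaling are assumptions; the library's version may differ in detail.

```rust
/// Textbook Huber loss on plain slices, averaged over all elements.
/// Hypothetical helper; `beta` is the threshold between the two regimes.
fn huber(pred: &[f64], targ: &[f64], beta: f64) -> f64 {
    assert_eq!(pred.len(), targ.len());
    assert!(!pred.is_empty());
    let total: f64 = pred
        .iter()
        .zip(targ)
        .map(|(&p, &t)| {
            let err = (t - p).abs();
            if err < beta {
                // Quadratic region: small errors are squared.
                0.5 * err * err
            } else {
                // Linear region: large errors grow only linearly,
                // offset so the two pieces meet at err == beta.
                beta * (err - 0.5 * beta)
            }
        })
        .sum();
    total / pred.len() as f64
}
```

The linear branch is what makes the loss less sensitive to outliers than mean squared error.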

KL Divergence loss. This computes (target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean()
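The expression above can be sketched with the standard library alone, again treating each inner `Vec` as one row along the last axis. The names are hypothetical, and one choice here is an assumption rather than something the source states: terms with a target probability of exactly 0 are treated as contributing 0, following the usual 0 · ln 0 = 0 convention, instead of evaluating `0.0 * ln(0.0)`.

```rust
/// Log-softmax over one row, using the max-subtraction trick for stability.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = logits.iter().map(|&x| (x - max).exp()).sum::<f64>().ln() + max;
    logits.iter().map(|&x| x - log_sum_exp).collect()
}

/// (target_probs * (ln(target_probs) - log_softmax(logits))).sum(-1).mean()
fn kl_div_with_logits(logits: &[Vec<f64>], target_probs: &[Vec<f64>]) -> f64 {
    assert_eq!(logits.len(), target_probs.len());
    assert!(!logits.is_empty());
    let total: f64 = logits
        .iter()
        .zip(target_probs)
        .map(|(row, probs)| {
            assert_eq!(row.len(), probs.len());
            let log_q = log_softmax(row);
            probs
                .iter()
                .zip(&log_q)
                .map(|(&p, &lq)| {
                    // Assumption: skip zero-probability targets (0 * ln 0 := 0).
                    if p > 0.0 { p * (p.ln() - lq) } else { 0.0 }
                })
                .sum::<f64>()
        })
        .sum();
    // Mean over the batch axis.
    total / logits.len() as f64
}
```

When the softmax of the logits equals the target distribution the result is 0, which is the minimum of the KL divergence.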

Mean absolute error. This computes (&targ - pred).abs().mean()

Mean Squared Error. This computes (&targ - pred).square().mean().

Root mean squared error. This computes (&targ - pred).square().mean().sqrt()
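The three elementwise losses above differ only in what is applied to the difference before and after averaging. The sketch below shows them side by side on plain slices; the names `mae`, `mse`, and `rmse` are hypothetical and stand in for the tensor expressions quoted above.

```rust
/// Mean absolute error: mean(|targ - pred|).
fn mae(pred: &[f64], targ: &[f64]) -> f64 {
    assert_eq!(pred.len(), targ.len());
    assert!(!pred.is_empty());
    let total: f64 = pred.iter().zip(targ).map(|(&p, &t)| (t - p).abs()).sum();
    total / pred.len() as f64
}

/// Mean squared error: mean((targ - pred)^2).
fn mse(pred: &[f64], targ: &[f64]) -> f64 {
    assert_eq!(pred.len(), targ.len());
    assert!(!pred.is_empty());
    let total: f64 = pred.iter().zip(targ).map(|(&p, &t)| (t - p) * (t - p)).sum();
    total / pred.len() as f64
}

/// Root mean squared error: sqrt(mse), which is in the same units as the data.
fn rmse(pred: &[f64], targ: &[f64]) -> f64 {
    mse(pred, targ).sqrt()
}
```

For predictions `[1, 2, 3]` and targets `[1, 4, 7]`, the differences are 0, 2, and 4, so the mean absolute error is 2 and the mean squared error is 20/3.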

Smooth L1 loss (closely related to Huber loss). Uses absolute error when the magnitude of the error is greater than beta, and squared error when it is less than beta.
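A minimal sketch of this rule, assuming the commonly used smooth L1 definition: 0.5 * e² / beta when |e| < beta, and |e| - 0.5 * beta otherwise, averaged over all elements. Under that definition, smooth L1 equals the textbook Huber loss divided by beta, which is one way to read "closely related". The function name `smooth_l1` and the exact scaling are assumptions, not confirmed details of the library.

```rust
/// Smooth L1 loss on plain slices, averaged over all elements.
/// Hypothetical helper; `beta` must be positive.
fn smooth_l1(pred: &[f64], targ: &[f64], beta: f64) -> f64 {
    assert_eq!(pred.len(), targ.len());
    assert!(!pred.is_empty());
    assert!(beta > 0.0);
    let total: f64 = pred
        .iter()
        .zip(targ)
        .map(|(&p, &t)| {
            let err = (t - p).abs();
            if err < beta {
                // Quadratic region, scaled by 1 / beta.
                0.5 * err * err / beta
            } else {
                // Linear region with slope 1, offset so the pieces meet at err == beta.
                err - 0.5 * beta
            }
        })
        .sum();
    total / pred.len() as f64
}
```

With beta equal to 1 the two definitions coincide; for other values the linear part of smooth L1 keeps slope 1 while the linear part of Huber has slope beta.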