Module dfdx::losses

Standard loss functions, including mean squared error (MSE), mean absolute error (MAE), cross entropy, and more.

Functions

Cross entropy loss. This computes: -(logits.log_softmax() * target_probs).sum(-1).mean()

KL divergence loss. This computes: (target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean()
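
To make the two logits-based formulas above concrete, here is a minimal plain-Rust sketch that evaluates them on a batch of rows, without dfdx tensors or gradients. The names log_softmax, cross_entropy, and kl_divergence are hypothetical helpers for this illustration, not the dfdx API, and each inner Vec stands in for the last tensor axis. Skipping zero target probabilities in the KL sum is an assumption of this sketch (it avoids the NaN that 0 * ln(0) would produce); the formula as documented does not state how that case is handled.

```rust
/// Numerically stable log-softmax over one row of logits
/// (hypothetical helper, not the dfdx implementation).
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    // Subtract the row maximum before exponentiating to avoid overflow.
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = logits.iter().map(|&x| (x - max).exp()).sum::<f64>().ln() + max;
    logits.iter().map(|&x| x - log_sum_exp).collect()
}

/// -(logits.log_softmax() * target_probs).sum(-1).mean()
/// Each inner Vec is one row; the mean is taken over rows.
fn cross_entropy(logits: &[Vec<f64>], target_probs: &[Vec<f64>]) -> f64 {
    assert_eq!(logits.len(), target_probs.len());
    let total: f64 = logits
        .iter()
        .zip(target_probs)
        .map(|(row, probs)| {
            let log_p = log_softmax(row);
            // Negated sum over the last axis for this row.
            -log_p.iter().zip(probs).map(|(lp, p)| lp * p).sum::<f64>()
        })
        .sum();
    total / logits.len() as f64
}

/// (target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean()
fn kl_divergence(logits: &[Vec<f64>], target_probs: &[Vec<f64>]) -> f64 {
    assert_eq!(logits.len(), target_probs.len());
    let total: f64 = logits
        .iter()
        .zip(target_probs)
        .map(|(row, probs)| {
            let log_p = log_softmax(row);
            probs
                .iter()
                .zip(&log_p)
                // Assumption of this sketch: terms with p == 0 contribute 0.
                .filter(|(p, _)| **p > 0.0)
                .map(|(p, lp)| p * (p.ln() - lp))
                .sum::<f64>()
        })
        .sum();
    total / logits.len() as f64
}
```

With a single row of uniform logits [0.0, 0.0], a one-hot target [1.0, 0.0] gives a cross entropy of ln 2, and a uniform target [0.5, 0.5] gives a KL divergence of 0, since the predicted and target distributions coincide.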

Mean absolute error. This computes: (&targ - pred).abs().mean()

Mean squared error. This computes: (&targ - pred).square().mean()

Root mean squared error. This computes: (&targ - pred).square().mean().sqrt()
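
The three regression losses above differ only in what is applied to the element-wise difference before and after averaging. The following is a minimal sketch on flat f64 slices, assuming the mean is taken over all elements; mae, mse, and rmse are hypothetical names for this illustration rather than the dfdx functions, which operate on tensors.

```rust
/// Mean absolute error: (targ - pred).abs().mean()
fn mae(pred: &[f64], targ: &[f64]) -> f64 {
    assert_eq!(pred.len(), targ.len());
    let sum: f64 = targ.iter().zip(pred).map(|(t, p)| (t - p).abs()).sum();
    sum / pred.len() as f64
}

/// Mean squared error: (targ - pred).square().mean()
fn mse(pred: &[f64], targ: &[f64]) -> f64 {
    assert_eq!(pred.len(), targ.len());
    let sum: f64 = targ.iter().zip(pred).map(|(t, p)| (t - p).powi(2)).sum();
    sum / pred.len() as f64
}

/// Root mean squared error: the square root of the mean squared error.
fn rmse(pred: &[f64], targ: &[f64]) -> f64 {
    mse(pred, targ).sqrt()
}
```

For pred = [1.0, 2.0, 3.0, 4.0] and targ = [3.0, 2.0, 3.0, 6.0], the differences are 2, 0, 0, 2, so mae returns 1.0, mse returns 2.0, and rmse returns the square root of 2. Because squaring weights large differences more heavily, mse and rmse penalize outliers more strongly than mae does.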