Function dfdx::losses::kl_div_with_logits_loss
pub fn kl_div_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T>
KL Divergence loss.
This computes `(target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean()`.
This function calls log_softmax(logits) internally, so make sure logits is not already the output of softmax() or log_softmax().
Arguments
- logits: The un-normalized output from a model; log_softmax() is called on it inside this function.
- target_probs: Target containing probability vectors, NOT class indices.
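To make the formula concrete, here is a minimal plain-Rust sketch of the same computation on nested `Vec`s rather than dfdx tensors. This is not dfdx's implementation (which operates on `Tensor`s and records gradients on the tape `T`); the function names here are illustrative, and the convention that a zero target probability contributes `0 * ln(0) = 0` is an assumption.

```rust
/// Numerically stable log-softmax over one probability vector:
/// log_softmax(x)_i = x_i - log(sum_j exp(x_j)), shifted by max(x).
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp: f64 =
        logits.iter().map(|&x| (x - max).exp()).sum::<f64>().ln() + max;
    logits.iter().map(|&x| x - log_sum_exp).collect()
}

/// (target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean()
/// computed per row, then averaged over the batch dimension.
fn kl_div_with_logits(logits: &[Vec<f64>], target_probs: &[Vec<f64>]) -> f64 {
    let per_example: Vec<f64> = logits
        .iter()
        .zip(target_probs)
        .map(|(l, p)| {
            let ls = log_softmax(l);
            p.iter()
                .zip(&ls)
                // Assumed convention: a zero target probability contributes 0.
                .map(|(&pi, &li)| if pi > 0.0 { pi * (pi.ln() - li) } else { 0.0 })
                .sum::<f64>()
        })
        .collect();
    per_example.iter().sum::<f64>() / per_example.len() as f64
}

fn main() {
    // One batch row of raw logits and one target probability vector.
    let logits = vec![vec![1.0, 2.0, 3.0]];
    let target = vec![vec![0.2, 0.3, 0.5]];
    // Small positive value: KL divergence is non-negative.
    println!("{}", kl_div_with_logits(&logits, &target));
}
```

Passing raw logits matters: if the distributions match exactly (e.g. `logits = target_probs.log()`, so the internal log_softmax recovers `target_probs`), the loss is zero, which is the sanity check KL divergence should satisfy.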