pub fn kl_div_with_logits_loss<T>(
    logits: T,
    target_probs: T::NoTape,
) -> <T as Reduce<AllAxes>>::Reduced
where
    T: Reduce<AllAxes> + Reduce<<<T as HasArrayType>::Array as HasLastAxis>::LastAxis>,

KL Divergence loss. This computes (target_probs * (target_probs.log() - logits.log_softmax())).sum(-1).mean(), i.e. the KL divergence between target_probs and softmax(logits), summed over the last axis and then averaged over the remaining axes.

This function calls log_softmax(logits) internally, so make sure logits is not already the output of softmax() or log_softmax().
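
For a single probability vector, sum(-1).mean() reduces to a plain sum, so the formula is easy to check by hand. Below is a minimal, standalone plain-Rust sketch of the same arithmetic (not the library's implementation; the helper name kl_div_with_logits is made up here for illustration):

fn kl_div_with_logits(logits: &[f32], target_probs: &[f32]) -> f32 {
    // log_softmax(x) = x - ln(sum(exp(x))), computed with the usual max-shift for stability
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = logits.iter().map(|l| (l - max).exp()).sum::<f32>().ln() + max;
    // sum over the last axis of target_probs * (target_probs.ln() - log_softmax(logits));
    // note: the 0 * ln(0) = 0 convention is not handled, so zero-probability targets yield NaN
    logits
        .iter()
        .zip(target_probs)
        .map(|(l, p)| p * (p.ln() - (l - log_sum_exp)))
        .sum()
}

fn main() {
    // Same inputs as the Example below; the loss comes out to roughly 0.031
    let loss = kl_div_with_logits(&[-1.0, -0.5], &[0.5, 0.5]);
    assert!((loss - 0.031).abs() < 1e-3);
}

The loss is exactly 0.0 when target_probs equals softmax(logits).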

Arguments

  • logits: The un-normalized output from a model; log_softmax() is called on it inside this function.
  • target_probs: Target distribution containing probability vectors, NOT class indices.

Example

use dfdx::prelude::*;

let logits = Tensor1D::new([-1.0, -0.5]);
let target_probs = Tensor1D::new([0.5, 0.5]);
// traced() attaches a gradient tape so the loss can be backpropagated later
let loss = kl_div_with_logits_loss(logits.traced(), target_probs);
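
Because logits.traced() attaches a gradient tape, the resulting loss can be backpropagated. A hedged follow-up sketch, assuming a backward(loss) entry point exists in this API generation (check the crate's prelude; newer versions expose loss.backward() instead):

// Run backprop on the traced loss; `gradients` then holds d(loss)/d(logits)
// (assumes a `backward` free function from dfdx::prelude of this API generation)
let gradients = backward(loss);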