Function dfdx::losses::cross_entropy_with_logits_loss
pub fn cross_entropy_with_logits_loss<T>(
    logits: T,
    target_probs: T::NoTape
) -> <T as Reduce<AllAxes>>::Reduced
where
    T: Reduce<AllAxes> + Reduce<<<T as HasArrayType>::Array as HasLastAxis>::LastAxis>,
Cross entropy loss.
This computes: -(logits.log_softmax() * target_probs).sum(-1).mean()
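Written out as a formula, the expression above is shown below. The [N, C] shape is an assumption made for illustration: z denotes the logits and t denotes target_probs, with N rows of C classes. For a 1-D tensor, N = 1 and the mean leaves the value unchanged.

$$
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C} t_{i,c}\left(z_{i,c} - \log\sum_{k=1}^{C} e^{z_{i,k}}\right)
$$

When every target row sums to 1, this reduces to the mean over rows of log-sum-exp(z_i) minus the sum over c of t_{i,c} z_{i,c}.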
This will call log_softmax(logits), so make sure logits is not already the output of softmax() or log_softmax().
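The following minimal plain-Rust sketch shows why the logits must be raw. It has no dfdx dependency. `log_softmax` and `softmax` are hypothetical helpers written for illustration, not dfdx's API. The sketch computes a numerically stable log-softmax, then shows what happens when softmax output is passed through it a second time.

```rust
/// Hypothetical helper (not dfdx): numerically stable log-softmax over one row.
/// Subtracting the row maximum before exponentiating avoids overflow.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&z| (z - max).exp()).sum();
    let log_sum_exp = max + sum_exp.ln();
    logits.iter().map(|&z| z - log_sum_exp).collect()
}

/// Hypothetical helper (not dfdx): softmax, used here only to show the pitfall.
fn softmax(logits: &[f32]) -> Vec<f32> {
    log_softmax(logits).into_iter().map(f32::exp).collect()
}

fn main() {
    let logits = [-1.0_f32, -0.5];
    // Correct: log-softmax applied once to the raw logits.
    let correct = log_softmax(&logits);
    // Pitfall: the inputs were already softmax probabilities, so the
    // loss's internal log-softmax normalizes a second time.
    let double = log_softmax(&softmax(&logits));
    println!("correct: {:?}", correct);
    println!("double:  {:?}", double);
}
```

The two results differ. Softmax output lies in (0, 1), so normalizing it again typically pushes the distribution toward uniform, and any loss computed from it is wrong.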
Arguments
logits: The un-normalized output from a model. log_softmax() is called in this function.
target_probs: Target containing probability vectors, NOT class indices.
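If the labels are class indices, they must be converted into probability vectors first. The following minimal sketch uses `one_hot`, a hypothetical helper that is not part of dfdx. It assumes fixed-size arrays like the ones the example passes to Tensor1D::new.

```rust
/// Hypothetical helper (not dfdx): turn a class index into a one-hot
/// probability vector suitable for use as `target_probs`.
fn one_hot<const C: usize>(class: usize) -> [f32; C] {
    assert!(class < C, "class index {} out of range for {} classes", class, C);
    let mut probs = [0.0_f32; C];
    probs[class] = 1.0;
    probs
}

fn main() {
    // Class index 2 out of 4 classes becomes [0, 0, 1, 0].
    let target = one_hot::<4>(2);
    assert_eq!(target, [0.0, 0.0, 1.0, 0.0]);
    // A valid target row sums to 1.
    let total: f32 = target.iter().sum();
    assert!((total - 1.0).abs() < 1e-6);
}
```

Because the argument accepts full distributions, soft targets such as [0.5, 0.5] in the example are also valid, provided each row sums to 1.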
Example
use dfdx::prelude::*;

let logits = Tensor1D::new([-1.0, -0.5]);
let target_probs = Tensor1D::new([0.5, 0.5]);
let loss = cross_entropy_with_logits_loss(logits.traced(), target_probs);
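The example's value can be checked by hand. The following sketch is plain Rust with no dfdx dependency. `cross_entropy` and `log_softmax` are hypothetical stand-ins written for illustration, not dfdx functions, and the [N][C] array layout is an assumption. The sketch follows the documented formula -(log_softmax(logits) * target_probs).sum(-1).mean(), so if dfdx implements that formula as described, the call above should produce about the same value.

```rust
/// Hypothetical helper (not dfdx): numerically stable log-softmax over one row.
fn log_softmax(row: &[f32]) -> Vec<f32> {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = max + row.iter().map(|&z| (z - max).exp()).sum::<f32>().ln();
    row.iter().map(|&z| z - log_sum_exp).collect()
}

/// Hypothetical stand-in (not dfdx): cross entropy for N rows of C classes.
fn cross_entropy<const N: usize, const C: usize>(
    logits: &[[f32; C]; N],
    target_probs: &[[f32; C]; N],
) -> f32 {
    let mut total = 0.0_f32;
    for (z, t) in logits.iter().zip(target_probs.iter()) {
        // Sum over the last axis: sum_c t_c * log_softmax(z)_c
        let row: f32 = log_softmax(z).iter().zip(t.iter()).map(|(l, p)| l * p).sum();
        total += row;
    }
    // Mean over the remaining (batch) axis, then negate.
    -total / N as f32
}

fn main() {
    // Same values as the example, treated as a batch of one row.
    let loss = cross_entropy(&[[-1.0, -0.5]], &[[0.5, 0.5]]);
    println!("loss = {loss:.4}"); // prints "loss = 0.7241"
}
```

As a cross-check, the target rows sum to 1, so the loss equals log-sum-exp(z) minus the sum of t times z: about -0.0259 + 0.75 = 0.7241.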