burn_tensor/tensor/loss/mod.rs

use crate::backend::Backend;
use crate::{Tensor, activation};
3
4/// Computes the log softmax cross entropy between logits and target probabilities.
5///
6/// # Arguments
7///
8/// * `logits` - The logits.
9/// * `target_probs` - The target probabilities.
10///
11/// # Returns
12///
13/// The log softmax cross entropy.
14pub fn cross_entropy_with_logits<B: Backend, const D: usize>(
15    logits: Tensor<B, D>,
16    target_probs: Tensor<B, D>,
17) -> Tensor<B, 1> {
18    let tensor = activation::log_softmax(logits, D - 1);
19    let tensor = tensor.mul(target_probs);
20    let tensor = tensor.sum_dim(D - 1);
21
22    tensor.mean().neg()
23}