//! burn_dragon_language 0.5.0
//!
//! Language modeling components for burn_dragon.
use burn::nn::loss::CrossEntropyLossConfig;
use burn::tensor::backend::Backend;
use burn::tensor::{Int, Tensor};

/// Computes the cross-entropy loss for a batch of sequence predictions.
///
/// The three dimensions of `logits` are read as `[batch, time, vocab]` and the
/// first two are merged, so each position is scored as its own classification
/// over `vocab` entries. `targets` is flattened the same way into
/// `batch * time` class indices.
///
/// The loss module is built from `CrossEntropyLossConfig::new()` with no
/// overrides, on the device that holds `logits`.
///
/// Returns the rank-1 tensor produced by the loss module's `forward`.
///
/// NOTE(review): `targets` presumably must hold exactly `batch * time`
/// elements for the reshape to succeed; confirm against Burn's `reshape`
/// contract.
pub fn language_model_loss<B: Backend>(
    logits: Tensor<B, 3>,
    targets: Tensor<B, 2, Int>,
) -> Tensor<B, 1> {
    // Read the device first; the reshape below consumes `logits`.
    let device = logits.device();
    let [batch_size, seq_len, vocab_size] = logits.shape().dims();
    let positions = batch_size * seq_len;

    // Merge batch and time so the loss sees one row of logits per position.
    let flat_logits = logits.reshape([positions, vocab_size]);
    let flat_targets = targets.reshape([positions]);

    let criterion = CrossEntropyLossConfig::new().init::<B>(&device);
    criterion.forward(flat_logits, flat_targets)
}