pub fn data2vec_batch_loss(
    student_preds: &[f32],
    teacher_reprs: &[f32],
    masks: &[bool],
    batch_size: usize,
    n_tokens: usize,
    dim: usize,
    config: &Data2VecConfig,
) -> SslResult<f32>
Compute the mean data2vec loss over a batch of samples.
Buffers are laid out batch-first:
- student_preds: [batch_size × n_tokens × dim]
- teacher_reprs: [batch_size × n_tokens × dim]
- masks: [batch_size × n_tokens] boolean
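Under this layout, channel d of token t in sample b sits at flat offset (b × n_tokens + t) × dim + d in both feature buffers, and sample b's mask occupies masks[b × n_tokens..(b + 1) × n_tokens]. The sketch below shows how one sample's views can be borrowed. `sample_views` and `feature_index` are hypothetical helpers, not part of this crate's API, and they assume the buffer lengths have already been validated (out-of-range slicing panics).

```rust
/// Hypothetical helper: flat offset of channel `d` of token `t` in sample `b`
/// for a batch-first [batch_size × n_tokens × dim] buffer.
fn feature_index(b: usize, t: usize, d: usize, n_tokens: usize, dim: usize) -> usize {
    (b * n_tokens + t) * dim + d
}

/// Hypothetical helper: borrow sample `b` out of the batch-first buffers.
/// Assumes the buffer lengths already match the documented layout;
/// slicing panics otherwise.
fn sample_views<'a>(
    student_preds: &'a [f32],
    teacher_reprs: &'a [f32],
    masks: &'a [bool],
    b: usize,
    n_tokens: usize,
    dim: usize,
) -> (&'a [f32], &'a [f32], &'a [bool]) {
    // Each sample occupies n_tokens * dim contiguous floats.
    let feat = n_tokens * dim;
    let feat_range = b * feat..(b + 1) * feat;
    // Each sample occupies n_tokens contiguous mask flags.
    let mask_range = b * n_tokens..(b + 1) * n_tokens;
    (
        &student_preds[feat_range.clone()],
        &teacher_reprs[feat_range],
        &masks[mask_range],
    )
}
```

For example, with batch_size = 2, n_tokens = 2 and dim = 3, sample 1 covers floats 6..12 and mask flags 2..4.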
Each sample’s loss is computed independently with
data2vec_loss and the results are averaged.
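The per-sample averaging can be sketched as follows. This is an illustration under stated assumptions, not this crate's implementation: `mean_over_batch` is hypothetical, the `per_sample` closure stands in for data2vec_loss (whose real signature also takes a &Data2VecConfig and returns SslResult<f32>), errors are plain `String`s, and rejecting an empty batch is an assumed choice that avoids computing 0/0.

```rust
/// Hypothetical: unweighted mean of per-sample losses over a batch-first batch.
/// `per_sample` is a stand-in for data2vec_loss (assumption).
fn mean_over_batch<F>(
    student_preds: &[f32],
    teacher_reprs: &[f32],
    masks: &[bool],
    batch_size: usize,
    n_tokens: usize,
    dim: usize,
    mut per_sample: F,
) -> Result<f32, String>
where
    F: FnMut(&[f32], &[f32], &[bool]) -> Result<f32, String>,
{
    // Assumed behaviour: reject an empty batch instead of returning NaN.
    if batch_size == 0 {
        return Err("batch_size must be > 0".to_string());
    }
    let feat = n_tokens * dim; // floats per sample
    let mut total = 0.0f32;
    for b in 0..batch_size {
        // Borrow sample b; panics if a buffer is shorter than the layout requires.
        let s = &student_preds[b * feat..(b + 1) * feat];
        let t = &teacher_reprs[b * feat..(b + 1) * feat];
        let m = &masks[b * n_tokens..(b + 1) * n_tokens];
        // `?` propagates the first per-sample error unchanged.
        total += per_sample(s, t, m)?;
    }
    // Every sample contributes equally to the mean.
    Ok(total / batch_size as f32)
}
```

Because each sample's loss is weighted equally, a sample with few masked tokens counts as much as one with many; whether data2vec_batch_loss behaves the same way beyond "the results are averaged" is not stated here.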
Errors
Propagates any error returned by data2vec_loss for an individual sample, and also
returns an error if the additional shape checks on the batch dimension fail.
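The batch-level shape checks could look like the sketch below, which compares each buffer length against the layout documented above. `ShapeError` and `check_batch_shapes` are hypothetical; this crate reports errors through SslResult, whose error type is not shown on this page. `checked_mul` guards against overflow when computing the expected lengths.

```rust
/// Hypothetical error type for this sketch only.
#[derive(Debug, PartialEq)]
enum ShapeError {
    /// A buffer's length does not match the batch-first layout.
    BufferLength { name: &'static str, expected: usize, actual: usize },
    /// batch_size * n_tokens * dim does not fit in usize.
    Overflow,
}

/// Hypothetical: validate buffer lengths against
/// [batch_size × n_tokens × dim] and [batch_size × n_tokens].
fn check_batch_shapes(
    student_len: usize,
    teacher_len: usize,
    mask_len: usize,
    batch_size: usize,
    n_tokens: usize,
    dim: usize,
) -> Result<(), ShapeError> {
    // Expected lengths, with overflow reported as an error instead of wrapping.
    let mask_expected = batch_size.checked_mul(n_tokens).ok_or(ShapeError::Overflow)?;
    let feat_expected = mask_expected.checked_mul(dim).ok_or(ShapeError::Overflow)?;
    for (name, expected, actual) in [
        ("student_preds", feat_expected, student_len),
        ("teacher_reprs", feat_expected, teacher_len),
        ("masks", mask_expected, mask_len),
    ] {
        if expected != actual {
            return Err(ShapeError::BufferLength { name, expected, actual });
        }
    }
    Ok(())
}
```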