

Function data2vec_batch_loss 

pub fn data2vec_batch_loss(
    student_preds: &[f32],
    teacher_reprs: &[f32],
    masks: &[bool],
    batch_size: usize,
    n_tokens: usize,
    dim: usize,
    config: &Data2VecConfig,
) -> SslResult<f32>

Compute the mean data2vec loss over a batch of samples.

Buffers are laid out batch-first:

  • student_preds : [batch_size × n_tokens × dim]
  • teacher_reprs : [batch_size × n_tokens × dim]
  • masks : [batch_size × n_tokens] boolean
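In the batch-first layout, sample b occupies a contiguous block of n_tokens × dim floats in each feature buffer and n_tokens flags in masks. The sketch below shows how a sample's sub-slices and a single element's flat offset could be located. The helper names sample_slices and feature_index are hypothetical illustrations and are not part of this crate's API.

```rust
/// Hypothetical helper (not part of this crate): returns the student,
/// teacher, and mask sub-slices that belong to sample `b`.
fn sample_slices<'a>(
    student_preds: &'a [f32],
    teacher_reprs: &'a [f32],
    masks: &'a [bool],
    b: usize,
    n_tokens: usize,
    dim: usize,
) -> (&'a [f32], &'a [f32], &'a [bool]) {
    // Each sample owns `n_tokens * dim` contiguous floats in the feature buffers...
    let feat = n_tokens * dim;
    let f0 = b * feat;
    // ...and `n_tokens` contiguous flags in the mask buffer.
    let m0 = b * n_tokens;
    (
        &student_preds[f0..f0 + feat],
        &teacher_reprs[f0..f0 + feat],
        &masks[m0..m0 + n_tokens],
    )
}

/// Hypothetical helper: flat offset of element (b, t, d) in a
/// [batch_size × n_tokens × dim] row-major buffer.
fn feature_index(b: usize, t: usize, d: usize, n_tokens: usize, dim: usize) -> usize {
    (b * n_tokens + t) * dim + d
}
```

With batch_size = 2, n_tokens = 3 and dim = 2, sample 1 starts at float offset 6 and at mask offset 3, and element (1, 2, 1) sits at offset 11.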

Each sample’s loss is computed independently with data2vec_loss and the results are averaged.
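A minimal sketch of this per-sample-then-average reduction follows. It assumes a stand-in per-sample loss, sample_loss, that computes the mean squared error over masked tokens only. This stand-in is hypothetical: the crate's real data2vec_loss takes a Data2VecConfig and may use a different distance (for example smooth L1). The String error type and the empty-batch check are also assumptions made for illustration.

```rust
/// Hypothetical stand-in for `data2vec_loss` (NOT the crate's function):
/// mean squared error between student and teacher vectors, averaged over
/// the masked tokens of a single sample.
fn sample_loss(student: &[f32], teacher: &[f32], mask: &[bool], dim: usize) -> Result<f32, String> {
    let mut sum = 0.0f32;
    let mut count = 0usize;
    for (t, &masked) in mask.iter().enumerate() {
        if !masked {
            continue; // only masked positions contribute to the loss
        }
        let s = &student[t * dim..(t + 1) * dim];
        let y = &teacher[t * dim..(t + 1) * dim];
        let sq: f32 = s.iter().zip(y).map(|(a, b)| (a - b) * (a - b)).sum();
        sum += sq / dim as f32;
        count += 1;
    }
    if count == 0 {
        return Err("sample has no masked tokens".to_string());
    }
    Ok(sum / count as f32)
}

/// Sketch of the batch reduction: each sample's loss is computed
/// independently and the results are averaged; the first per-sample
/// error is propagated with `?`.
fn batch_loss(
    student: &[f32],
    teacher: &[f32],
    masks: &[bool],
    batch_size: usize,
    n_tokens: usize,
    dim: usize,
) -> Result<f32, String> {
    if batch_size == 0 {
        return Err("empty batch".to_string()); // avoid dividing by zero
    }
    let feat = n_tokens * dim;
    let mut total = 0.0f32;
    for b in 0..batch_size {
        total += sample_loss(
            &student[b * feat..(b + 1) * feat],
            &teacher[b * feat..(b + 1) * feat],
            &masks[b * n_tokens..(b + 1) * n_tokens],
            dim,
        )?;
    }
    Ok(total / batch_size as f32)
}
```

Averaging per-sample losses (rather than pooling all masked tokens across the batch) gives every sample equal weight regardless of how many of its tokens are masked. Whether the crate's own normalization matches this stand-in is not stated here.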

§Errors

Propagates any error returned by data2vec_loss, and additionally returns an error when the inputs fail extra shape checks on the batch dimension.
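The exact batch-dimension checks are not listed on this page. The sketch below shows one plausible form, assuming the buffer lengths are compared against the layout described above and that an empty batch is rejected. The ShapeError enum and check_batch_shapes function are hypothetical; the real error type behind SslResult is not shown here.

```rust
/// Hypothetical error type for illustration only.
#[derive(Debug, PartialEq)]
enum ShapeError {
    EmptyBatch,
    Overflow,
    LengthMismatch { buffer: &'static str, expected: usize, actual: usize },
}

/// Hypothetical shape validation for batch-first buffers:
/// feature buffers must hold batch_size * n_tokens * dim floats and
/// `masks` must hold batch_size * n_tokens flags.
fn check_batch_shapes(
    student_preds: &[f32],
    teacher_reprs: &[f32],
    masks: &[bool],
    batch_size: usize,
    n_tokens: usize,
    dim: usize,
) -> Result<(), ShapeError> {
    if batch_size == 0 {
        return Err(ShapeError::EmptyBatch);
    }
    // checked_mul guards against usize overflow for absurd dimensions.
    let mask_len = batch_size.checked_mul(n_tokens).ok_or(ShapeError::Overflow)?;
    let feat_len = mask_len.checked_mul(dim).ok_or(ShapeError::Overflow)?;
    for (buffer, actual, expected) in [
        ("student_preds", student_preds.len(), feat_len),
        ("teacher_reprs", teacher_reprs.len(), feat_len),
        ("masks", masks.len(), mask_len),
    ] {
        if actual != expected {
            return Err(ShapeError::LengthMismatch { buffer, expected, actual });
        }
    }
    Ok(())
}
```

Running checks like these once, before the per-sample loop, lets a mismatched batch fail with a descriptive error instead of panicking on an out-of-range slice.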