

Function data2vec_loss 

pub fn data2vec_loss(
    student_pred: &[f32],
    teacher_repr: &[f32],
    mask: &[bool],
    n_tokens: usize,
    dim: usize,
    config: &Data2VecConfig,
) -> SslResult<Data2VecResult>

Compute the data2vec loss for a single sample.

Implements the full algorithm:

  1. Validate shapes.
  2. Optionally L2-normalise teacher representations across masked tokens per feature dimension.
  3. Compute the mean Huber loss between the student predictions and the (optionally normalised) teacher targets, at masked positions only.
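
As a rough illustration of steps 2 and 3, the sketch below is a standalone reading of the description above, not this crate's implementation. The names `huber` and `masked_huber_sketch`, the `delta` and `normalise` parameters, and the `1e-6` floor on the norm are all assumptions made for the example; the real settings live in `Data2VecConfig`, whose fields are not shown here.

```rust
/// Huber loss for one residual; `delta` is the quadratic-to-linear threshold.
fn huber(residual: f32, delta: f32) -> f32 {
    let a = residual.abs();
    if a <= delta {
        0.5 * a * a
    } else {
        delta * (a - 0.5 * delta)
    }
}

/// Hypothetical sketch: mean Huber loss over masked positions of row-major
/// `[n_tokens x dim]` buffers. Returns `None` if nothing is masked or `dim == 0`.
/// Assumes the shapes were already validated (step 1).
fn masked_huber_sketch(
    student: &[f32],
    teacher: &[f32],
    mask: &[bool],
    dim: usize,
    delta: f32,
    normalise: bool,
) -> Option<f32> {
    // Indices of the masked tokens.
    let masked: Vec<usize> = mask
        .iter()
        .enumerate()
        .filter_map(|(t, &m)| if m { Some(t) } else { None })
        .collect();
    if masked.is_empty() || dim == 0 {
        return None;
    }

    // Step 2 (optional): for each feature dimension, the L2 norm of the
    // teacher values taken across the masked tokens.
    let mut scale = vec![1.0f32; dim];
    if normalise {
        for d in 0..dim {
            let sum_sq: f32 = masked
                .iter()
                .map(|&t| teacher[t * dim + d] * teacher[t * dim + d])
                .sum();
            // Assumed floor to avoid division by zero.
            scale[d] = sum_sq.sqrt().max(1e-6);
        }
    }

    // Step 3: accumulate the Huber loss at masked positions only, then average.
    let mut total = 0.0f32;
    for &t in &masked {
        for d in 0..dim {
            let target = teacher[t * dim + d] / scale[d];
            total += huber(student[t * dim + d] - target, delta);
        }
    }
    Some(total / (masked.len() * dim) as f32)
}
```

Unmasked rows never enter the sum, so their contents do not affect the result; whether the actual function averages over elements or over tokens is not stated in this page, and the sketch averages over elements.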

§Arguments

  • student_pred: [n_tokens × dim] student predictions (row-major).
  • teacher_repr: [n_tokens × dim] teacher representations (row-major).
  • mask: [n_tokens] boolean vector; true = masked position.
  • n_tokens, dim: number of tokens and number of features (channels) per token.
  • config: data2vec hyper-parameters.
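
To make the expected layout concrete, here is a minimal sketch of how the slice lengths relate to `n_tokens` and `dim`, and how one token's features are addressed in a row-major buffer. `check_shapes` and `token_row` are hypothetical helpers written for this example, and the `String` error is a stand-in: the crate's own validation and the error type behind `SslResult` are not shown in this page.

```rust
/// Hypothetical shape check: both flat buffers must hold `n_tokens * dim`
/// values and the mask must hold `n_tokens` flags.
fn check_shapes(
    student_len: usize,
    teacher_len: usize,
    mask_len: usize,
    n_tokens: usize,
    dim: usize,
) -> Result<(), String> {
    let expected = n_tokens
        .checked_mul(dim)
        .ok_or_else(|| "n_tokens * dim overflows usize".to_string())?;
    if student_len != expected {
        return Err(format!("student_pred: expected {expected} values, got {student_len}"));
    }
    if teacher_len != expected {
        return Err(format!("teacher_repr: expected {expected} values, got {teacher_len}"));
    }
    if mask_len != n_tokens {
        return Err(format!("mask: expected {n_tokens} flags, got {mask_len}"));
    }
    Ok(())
}

/// Row-major layout: feature `d` of token `t` sits at index `t * dim + d`,
/// so token `t` occupies the contiguous range `t * dim .. (t + 1) * dim`.
fn token_row(buf: &[f32], t: usize, dim: usize) -> &[f32] {
    &buf[t * dim..(t + 1) * dim]
}
```

For example, with `n_tokens = 3` and `dim = 2`, both buffers hold 6 values and `token_row(buf, 1, 2)` returns the elements at indices 2 and 3.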

§Errors