pub fn data2vec_loss(
    student_pred: &[f32],
    teacher_repr: &[f32],
    mask: &[bool],
    n_tokens: usize,
    dim: usize,
    config: &Data2VecConfig,
) -> SslResult<Data2VecResult>
Compute the data2vec loss for a single sample.
Implements the full algorithm:
- Validate shapes.
- Optionally L2-normalise teacher representations across masked tokens per feature dimension.
- Compute mean Huber loss between student predictions and normalised teacher targets at masked positions only.
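The steps above can be sketched as a stand-alone function. This is a minimal illustration and not the crate's implementation: `masked_huber_loss`, `huber`, and the `normalise` and `delta` parameters are hypothetical names standing in for whatever `Data2VecConfig` actually exposes. It also assumes that the normalisation divides each feature by its L2 norm taken over the masked tokens, that the Huber penalty uses the standard threshold form, and that the mean is taken over all masked elements.

```rust
/// Standard Huber penalty for one residual `r` with threshold `delta`.
fn huber(r: f32, delta: f32) -> f32 {
    let a = r.abs();
    if a <= delta {
        0.5 * r * r
    } else {
        delta * (a - 0.5 * delta)
    }
}

/// Hypothetical sketch of the masked loss; inputs are assumed to be
/// already validated row-major [n_tokens x dim] buffers.
fn masked_huber_loss(
    student: &[f32],
    teacher: &[f32],
    mask: &[bool],
    dim: usize,
    normalise: bool, // assumed flag, not a documented config field
    delta: f32,      // assumed Huber threshold, not a documented config field
) -> f32 {
    // Indices of the masked tokens.
    let masked: Vec<usize> = mask
        .iter()
        .enumerate()
        .filter_map(|(t, &m)| if m { Some(t) } else { None })
        .collect();
    if masked.is_empty() || dim == 0 {
        return 0.0; // all-visible input: zero loss
    }

    // Per-feature scale: L2 norm over masked tokens (assumed reading).
    let mut scale = vec![1.0f32; dim];
    if normalise {
        for d in 0..dim {
            let sq: f32 = masked.iter().map(|&t| teacher[t * dim + d].powi(2)).sum();
            scale[d] = sq.sqrt().max(1e-6); // guard against division by zero
        }
    }

    // Mean Huber penalty over every masked element.
    let mut total = 0.0f32;
    for &t in &masked {
        for d in 0..dim {
            let target = teacher[t * dim + d] / scale[d];
            total += huber(student[t * dim + d] - target, delta);
        }
    }
    total / (masked.len() * dim) as f32
}
```

Visible (unmasked) tokens never enter the sum, so they affect neither the normalisation statistics nor the loss in this sketch.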
§Arguments
- student_pred: [n_tokens × dim] student predictions (row-major).
- teacher_repr: [n_tokens × dim] teacher representations (row-major).
- mask: [n_tokens] boolean vector; true = masked position.
- n_tokens, dim: spatial and channel dimensions.
- config: data2vec hyper-parameters.
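To make the row-major layout concrete, the following sketch shows how per-token rows map to a flat buffer and back. The helper names `flatten_row_major` and `at` are hypothetical and are not part of the crate; they only illustrate the indexing rule `t * dim + d` that a row-major [n_tokens × dim] buffer implies.

```rust
/// Flatten per-token rows into one row-major buffer of length n_tokens * dim.
/// Assumes every row has the same length.
fn flatten_row_major(rows: &[Vec<f32>]) -> Vec<f32> {
    rows.iter().flat_map(|r| r.iter().copied()).collect()
}

/// Read feature `d` of token `t` from a row-major [n_tokens x dim] buffer.
fn at(buf: &[f32], dim: usize, t: usize, d: usize) -> f32 {
    buf[t * dim + d]
}
```

With two tokens of three features each, token 1 starts at offset 3 of the flat buffer, and the matching mask has one entry per token, for example `[false, true]`.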
§Errors
- SslError::EmptyInput when n_tokens == 0 or dim == 0.
- SslError::DimensionMismatch when any buffer has the wrong length.

An input with no masked tokens is not treated as an error: the function returns a loss of 0.0 inside Data2VecResult, since the caller may legitimately supply an all-visible batch during warm-up.
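The shape checks behind these errors might look like the sketch below. `SketchError` and `validate_shapes` are hypothetical stand-ins: the real `SslError` variants may carry different fields, and the order in which the crate performs its checks is an assumption.

```rust
/// Stand-in for the crate's error type; the variant payloads are assumed.
#[derive(Debug, PartialEq)]
enum SketchError {
    EmptyInput,
    DimensionMismatch { expected: usize, got: usize },
}

/// Check buffer lengths against the declared n_tokens and dim.
fn validate_shapes(
    student_len: usize,
    teacher_len: usize,
    mask_len: usize,
    n_tokens: usize,
    dim: usize,
) -> Result<(), SketchError> {
    if n_tokens == 0 || dim == 0 {
        return Err(SketchError::EmptyInput);
    }
    // Both prediction and target buffers must hold n_tokens * dim values.
    let expected = n_tokens * dim;
    for got in [student_len, teacher_len] {
        if got != expected {
            return Err(SketchError::DimensionMismatch { expected, got });
        }
    }
    // The mask holds one flag per token.
    if mask_len != n_tokens {
        return Err(SketchError::DimensionMismatch {
            expected: n_tokens,
            got: mask_len,
        });
    }
    Ok(())
}
```

Validating lengths up front lets the loss loop index the buffers without further bounds reasoning.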