pub struct Data2VecModel { /* private fields */ }
Struct-based Data2Vec model that owns student encoder layers and a teacher EMA state.
Student weights per layer are stored as:
- student_w[l]: [d_model × d_model] weight matrix (row-major).
- student_b[l]: [d_model] bias vector.
Teacher weights are stored flat in Data2VecState::teacher_params in the
same sequential layout (w0, b0, w1, b1, …).
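The flat layout above can be indexed with simple arithmetic. The following is a minimal sketch, assuming every layer contributes exactly d_model × d_model weights followed by d_model biases; `layer_offsets` is a hypothetical helper for illustration and is not part of this crate's API.

```rust
use std::ops::Range;

/// Index ranges of layer `l`'s weight matrix and bias vector inside a flat
/// parameter buffer laid out as (w0, b0, w1, b1, ...).
/// Hypothetical helper: the name and signature are assumptions.
fn layer_offsets(l: usize, d_model: usize) -> (Range<usize>, Range<usize>) {
    // Each layer occupies one [d_model x d_model] matrix plus one [d_model] bias.
    let per_layer = d_model * d_model + d_model;
    let w_start = l * per_layer;
    let b_start = w_start + d_model * d_model;
    (w_start..b_start, b_start..b_start + d_model)
}
```

With d_model = 4, layer 1 would hold its weights at indices 20..36 and its bias at 36..40 under this assumption.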
Implementations§
impl Data2VecModel
pub fn new(config: Data2VecModelConfig, rng: &mut LcgRng) -> SslResult<Self>
Create a new Data2VecModel with Kaiming-initialised student layers and
a teacher state cloned from the initial student parameters.
§Errors
- SslError::InvalidParameter when d_model == 0.
- SslError::InvalidParameter when n_layers == 0.
pub fn encode_student(&self, x: &[f32], n_patches: usize) -> SslResult<Vec<f32>>
Encode a sequence of patch embeddings with the student encoder.
Each layer applies: token ← ReLU(W · token + b) independently per token.
§Arguments
- x: flat [n_patches × d_model] input (row-major).
- n_patches: number of tokens in the sequence.
§Errors
SslError::DimensionMismatch when x.len() != n_patches * d_model.
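The per-token layer rule token ← ReLU(W · token + b) can be sketched as a standalone function. This is an illustration only, assuming that each row of the row-major W produces one output feature; `relu_layer` is a hypothetical name, and the crate's actual implementation may differ (for example, it returns SslResult rather than panicking on bad shapes).

```rust
/// Apply one layer, token <- ReLU(W * token + b), independently to each token.
/// `x` is row-major [n_patches x d_model], `w` is row-major [d_model x d_model],
/// `b` is [d_model]. Hypothetical helper, not the crate's API.
fn relu_layer(x: &[f32], w: &[f32], b: &[f32], d_model: usize) -> Vec<f32> {
    assert!(d_model > 0, "d_model must be non-zero");
    assert_eq!(w.len(), d_model * d_model, "weight shape mismatch");
    assert_eq!(b.len(), d_model, "bias shape mismatch");
    assert_eq!(x.len() % d_model, 0, "input is not a whole number of tokens");

    let mut out = Vec::with_capacity(x.len());
    for token in x.chunks_exact(d_model) {
        // Assumption: row i of W yields output feature i.
        for (row, bias) in w.chunks_exact(d_model).zip(b) {
            let dot: f32 = row.iter().zip(token).map(|(wi, ti)| wi * ti).sum();
            out.push((dot + bias).max(0.0)); // ReLU
        }
    }
    out
}
```

Stacking n_layers such calls, each fed the previous output, would give an encoder of the shape described above.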
pub fn encode_teacher(&self, x: &[f32], n_patches: usize) -> SslResult<Vec<f32>>
Encode a sequence of patch embeddings with the teacher encoder.
Uses the teacher weight matrices stored in Data2VecState::teacher_params.
§Arguments
- x: flat [n_patches × d_model] input (row-major).
- n_patches: number of tokens in the sequence.
§Errors
SslError::DimensionMismatch when x.len() != n_patches * d_model.
pub fn loss(&self, x: &[f32], mask: &[bool], n_patches: usize) -> SslResult<f32>
Compute the Data2Vec loss for a masked input.
- Encodes with the student encoder → student_repr [n_patches × d_model].
- Encodes with the teacher encoder → teacher_repr [n_patches × d_model].
- Computes Huber loss at masked positions via data2vec_loss.
§Arguments
- x: flat [n_patches × d_model] input.
- mask: [n_patches] boolean; true = masked position.
- n_patches: number of tokens.
§Errors
Propagates dimension and config errors.
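To make the final step concrete, here is a rough sketch of a Huber loss restricted to masked positions. It is not the crate's data2vec_loss: the `delta` threshold, the averaging over masked elements, and the name `masked_huber` are all assumptions chosen for illustration.

```rust
/// Mean Huber loss between student and teacher representations, counting
/// only tokens where `mask[i]` is true. Hypothetical stand-in for the
/// crate's data2vec_loss; `delta` and the averaging scheme are assumptions.
fn masked_huber(
    student: &[f32],
    teacher: &[f32],
    mask: &[bool],
    d_model: usize,
    delta: f32,
) -> f32 {
    assert!(d_model > 0, "d_model must be non-zero");
    assert_eq!(student.len(), teacher.len(), "representation shape mismatch");
    assert_eq!(student.len(), mask.len() * d_model, "mask length mismatch");

    let mut total = 0.0f32;
    let mut count = 0usize;
    for (i, &masked) in mask.iter().enumerate() {
        if !masked {
            continue; // unmasked tokens do not contribute
        }
        let s = &student[i * d_model..(i + 1) * d_model];
        let t = &teacher[i * d_model..(i + 1) * d_model];
        for (a, b) in s.iter().zip(t) {
            let r = (a - b).abs();
            // Quadratic near zero, linear beyond `delta`.
            total += if r <= delta { 0.5 * r * r } else { delta * (r - 0.5 * delta) };
            count += 1;
        }
    }
    // Assumption: an all-false mask yields zero loss.
    if count == 0 { 0.0 } else { total / count as f32 }
}
```

Restricting the loss to masked positions means the student is only scored where it had to predict the teacher's representation without seeing the input.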
pub fn ema_update(&mut self) -> SslResult<()>
Apply the EMA update: teacher ← ema_decay · teacher + (1 − ema_decay) · student.
§Errors
- SslError::InvalidMomentum when ema_decay is not in [0, 1].
- SslError::DimensionMismatch when param shapes mismatch (should not occur in normal usage).
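The update rule can be sketched on plain slices. This is a minimal illustration, assuming student and teacher parameters are both available as flat buffers of equal length; `ema_step` and its String error type are hypothetical and stand in for the crate's method and SslError variants.

```rust
/// In-place EMA update: teacher <- ema_decay * teacher + (1 - ema_decay) * student.
/// Hypothetical helper; the real method returns SslResult<()> instead.
fn ema_step(teacher: &mut [f32], student: &[f32], ema_decay: f32) -> Result<(), String> {
    // Rejects values outside [0, 1]; NaN also fails this check.
    if !(0.0..=1.0).contains(&ema_decay) {
        return Err(format!("ema_decay {ema_decay} is not in [0, 1]"));
    }
    if teacher.len() != student.len() {
        return Err(format!(
            "parameter length mismatch: teacher {} vs student {}",
            teacher.len(),
            student.len()
        ));
    }
    for (t, s) in teacher.iter_mut().zip(student) {
        *t = ema_decay * *t + (1.0 - ema_decay) * *s;
    }
    Ok(())
}
```

An ema_decay close to 1 makes the teacher track the student slowly, while 0 copies the student outright.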
Trait Implementations§
impl Clone for Data2VecModel
fn clone(&self) -> Data2VecModel
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.