Skip to main content

oxicuda_ssl/ssl/
data2vec_v2.rs

1//! Data2Vec struct — student/teacher encoder pair with EMA update.
2//!
3//! This module provides a struct-based wrapper that manages the full Data2Vec
4//! training loop: a student encoder, an EMA teacher state, and the loss
5//! computation via the functional API in [`crate::masked::data2vec`].
6//!
7//! ## Architecture
8//! Both student and teacher are `n_layers`-deep encoders where each layer is a
9//! single linear transform followed by ReLU, operating token-by-token over a
10//! sequence of shape `[n_patches × d_model]`.
11//!
12//! Teacher weights are stored flat in [`Data2VecState::teacher_params`] in the
13//! same layout as the student parameters:
14//! ```text
15//! [ layer_0_w (d_model × d_model) | layer_0_b (d_model) |
16//!   layer_1_w (d_model × d_model) | layer_1_b (d_model) | ... ]
17//! ```
18//!
19//! ## EMA update
20//! ```text
21//! teacher ← ema_decay · teacher + (1 − ema_decay) · student
22//! ```
23//!
24//! Reference: "data2vec: A General Framework for Self-supervised Learning in
25//! Speech, Vision and Language", Baevski et al., ICML 2022.
26
27use crate::error::{SslError, SslResult};
28use crate::handle::LcgRng;
29use crate::masked::data2vec::{Data2VecConfig, Data2VecState, data2vec_loss};
30
31// ─── Configuration ────────────────────────────────────────────────────────────
32
/// Configuration for the struct-based [`Data2VecModel`].
///
/// The fields are consumed by [`Data2VecModel::new`] (layer shapes),
/// [`Data2VecModel::loss`] (forwarded into `Data2VecConfig`) and
/// [`Data2VecModel::ema_update`] (teacher decay).
#[derive(Debug, Clone)]
pub struct Data2VecModelConfig {
    /// Width of every token / patch embedding. Each encoder layer is a square
    /// `d_model × d_model` transform, so this must be non-zero.
    pub d_model: usize,
    /// Depth of the encoder, i.e. how many linear+ReLU layers are stacked
    /// (must be >= 1).
    pub n_layers: usize,
    /// Decay factor applied to the teacher in the EMA update (e.g. 0.999).
    pub ema_decay: f32,
    /// Fraction of tokens that is masked. Within this module the value is only
    /// forwarded to `Data2VecConfig::mask_ratio` when the loss is computed;
    /// the mask itself is supplied by the caller of [`Data2VecModel::loss`].
    pub mask_ratio: f32,
    /// How many of the top teacher layer outputs are averaged to form the
    /// target (forwarded to `Data2VecConfig::top_k_average`).
    pub k_top_layers: usize,
}

impl Default for Data2VecModelConfig {
    /// A small configuration: two 64-wide layers, EMA decay `0.999`,
    /// mask ratio `0.65` and a single top layer used as the target.
    fn default() -> Self {
        let d_model = 64;
        let n_layers = 2;
        let ema_decay = 0.999;
        let mask_ratio = 0.65;
        let k_top_layers = 1;
        Self {
            d_model,
            n_layers,
            ema_decay,
            mask_ratio,
            k_top_layers,
        }
    }
}
60
61// ─── Data2VecModel ────────────────────────────────────────────────────────────
62
63/// Struct-based Data2Vec model that owns student encoder layers and a teacher
64/// EMA state.
65///
66/// Student weights per layer are stored as:
67/// - `student_w[l]` — `[d_model × d_model]` weight matrix (row-major).
68/// - `student_b[l]` — `[d_model]` bias vector.
69///
70/// Teacher weights are stored flat in [`Data2VecState::teacher_params`] in the
71/// same sequential layout (w0, b0, w1, b1, …).
72#[derive(Debug, Clone)]
73pub struct Data2VecModel {
74    /// Per-layer student weight matrices `n_layers × [d_model × d_model]`.
75    student_w: Vec<Vec<f32>>,
76    /// Per-layer student bias vectors `n_layers × [d_model]`.
77    student_b: Vec<Vec<f32>>,
78    /// EMA teacher state (flat parameter vector + step counter).
79    teacher_state: Data2VecState,
80    /// Configuration used to create this model.
81    config: Data2VecModelConfig,
82}
83
84impl Data2VecModel {
85    /// Create a new [`Data2VecModel`] with Kaiming-initialised student layers and
86    /// a teacher state cloned from the initial student parameters.
87    ///
88    /// # Errors
89    /// - [`SslError::InvalidParameter`] when `d_model == 0`.
90    /// - [`SslError::InvalidParameter`] when `n_layers == 0`.
91    pub fn new(config: Data2VecModelConfig, rng: &mut LcgRng) -> SslResult<Self> {
92        if config.d_model == 0 {
93            return Err(SslError::InvalidParameter {
94                name: "d_model".into(),
95                reason: "must be > 0".into(),
96            });
97        }
98        if config.n_layers == 0 {
99            return Err(SslError::InvalidParameter {
100                name: "n_layers".into(),
101                reason: "must be >= 1".into(),
102            });
103        }
104
105        let d = config.d_model;
106        let mut student_w = Vec::with_capacity(config.n_layers);
107        let mut student_b = Vec::with_capacity(config.n_layers);
108
109        for _ in 0..config.n_layers {
110            let w = kaiming_init(d, d, rng);
111            let b = vec![0.0_f32; d];
112            student_w.push(w);
113            student_b.push(b);
114        }
115
116        // Flatten student params into teacher initial state.
117        let flat_params = flatten_params(&student_w, &student_b);
118        let teacher_state = Data2VecState::new(&flat_params);
119
120        Ok(Self {
121            student_w,
122            student_b,
123            teacher_state,
124            config,
125        })
126    }
127
128    /// Encode a sequence of patch embeddings with the **student** encoder.
129    ///
130    /// Each layer applies: `token ← ReLU(W · token + b)` independently per token.
131    ///
132    /// # Arguments
133    /// * `x`        — flat `[n_patches × d_model]` input (row-major).
134    /// * `n_patches` — number of tokens in the sequence.
135    ///
136    /// # Errors
137    /// [`SslError::DimensionMismatch`] when `x.len() != n_patches * d_model`.
138    pub fn encode_student(&self, x: &[f32], n_patches: usize) -> SslResult<Vec<f32>> {
139        let d = self.config.d_model;
140        let expected = n_patches * d;
141        if x.len() != expected {
142            return Err(SslError::DimensionMismatch {
143                expected,
144                got: x.len(),
145            });
146        }
147        apply_encoder_layers(
148            x,
149            n_patches,
150            d,
151            &self.student_w,
152            &self.student_b,
153            self.config.n_layers,
154        )
155    }
156
157    /// Encode a sequence of patch embeddings with the **teacher** encoder.
158    ///
159    /// Uses the teacher weight matrices stored in [`Data2VecState::teacher_params`].
160    ///
161    /// # Arguments
162    /// * `x`        — flat `[n_patches × d_model]` input (row-major).
163    /// * `n_patches` — number of tokens in the sequence.
164    ///
165    /// # Errors
166    /// [`SslError::DimensionMismatch`] when `x.len() != n_patches * d_model`.
167    pub fn encode_teacher(&self, x: &[f32], n_patches: usize) -> SslResult<Vec<f32>> {
168        let d = self.config.d_model;
169        let expected = n_patches * d;
170        if x.len() != expected {
171            return Err(SslError::DimensionMismatch {
172                expected,
173                got: x.len(),
174            });
175        }
176        let (teacher_w, teacher_b) =
177            unflatten_params(self.teacher_state.teacher(), d, self.config.n_layers)?;
178        apply_encoder_layers(
179            x,
180            n_patches,
181            d,
182            &teacher_w,
183            &teacher_b,
184            self.config.n_layers,
185        )
186    }
187
188    /// Compute the Data2Vec loss for a masked input.
189    ///
190    /// 1. Encodes with the student encoder → `student_repr [n_patches × d_model]`.
191    /// 2. Encodes with the teacher encoder → `teacher_repr [n_patches × d_model]`.
192    /// 3. Computes Huber loss at masked positions via [`data2vec_loss`].
193    ///
194    /// # Arguments
195    /// * `x`        — flat `[n_patches × d_model]` input.
196    /// * `mask`     — `[n_patches]` boolean; `true` = masked position.
197    /// * `n_patches` — number of tokens.
198    ///
199    /// # Errors
200    /// Propagates dimension and config errors.
201    pub fn loss(&self, x: &[f32], mask: &[bool], n_patches: usize) -> SslResult<f32> {
202        let d = self.config.d_model;
203        let student_repr = self.encode_student(x, n_patches)?;
204        let teacher_repr = self.encode_teacher(x, n_patches)?;
205        let d2v_config = Data2VecConfig {
206            mask_ratio: self.config.mask_ratio,
207            momentum: self.config.ema_decay,
208            top_k_average: self.config.k_top_layers,
209            ..Data2VecConfig::default()
210        };
211        let result = data2vec_loss(
212            &student_repr,
213            &teacher_repr,
214            mask,
215            n_patches,
216            d,
217            &d2v_config,
218        )?;
219        Ok(result.loss)
220    }
221
222    /// Apply the EMA update: `teacher ← ema_decay · teacher + (1 − ema_decay) · student`.
223    ///
224    /// # Errors
225    /// - [`SslError::InvalidMomentum`] when `ema_decay` is not in `[0, 1]`.
226    /// - [`SslError::DimensionMismatch`] when param shapes mismatch (should not
227    ///   occur in normal usage).
228    pub fn ema_update(&mut self) -> SslResult<()> {
229        let flat_student = flatten_params(&self.student_w, &self.student_b);
230        self.teacher_state
231            .update_teacher(&flat_student, self.config.ema_decay)
232    }
233
234    /// Return the token/patch embedding dimension.
235    #[inline]
236    #[must_use]
237    pub fn d_model(&self) -> usize {
238        self.config.d_model
239    }
240}
241
242// ─── Internal helpers ────────────────────────────────────────────────────────
243
244/// Kaiming (He) normal weight init: `scale = sqrt(2 / fan_in)`.
245fn kaiming_init(out_dim: usize, in_dim: usize, rng: &mut LcgRng) -> Vec<f32> {
246    let scale = (2.0_f32 / in_dim as f32).sqrt();
247    let mut w = vec![0.0_f32; out_dim * in_dim];
248    rng.fill_normal(&mut w);
249    for v in w.iter_mut() {
250        *v *= scale;
251    }
252    w
253}
254
/// Row-major affine transform followed by ReLU for a single token.
///
/// `out[i] = max(0, b[i] + Σ_j w[i·in_dim + j] · x[j])` for `i` in `0..out_dim`.
/// The ReLU is always applied; there is no flag to disable it.
///
/// # Arguments
/// * `w` — row-major `[out_dim × in_dim]` weight matrix.
/// * `b` — `[out_dim]` bias vector.
/// * `x` — `[in_dim]` input token.
///
/// # Panics
/// Panics on out-of-bounds slicing when `w`, `b` or `x` hold fewer elements
/// than `out_dim` / `in_dim` require.
fn linear_relu(w: &[f32], b: &[f32], x: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
    (0..out_dim)
        .map(|i| {
            let row = &w[i * in_dim..(i + 1) * in_dim];
            // Start from the bias and accumulate left to right so the
            // floating-point summation order is fixed.
            let pre_activation = row
                .iter()
                .zip(&x[..in_dim])
                .fold(b[i], |acc, (&weight, &input)| acc + weight * input);
            pre_activation.max(0.0)
        })
        .collect()
}
270
/// Pack per-layer weights and biases into one contiguous `Vec<f32>`.
///
/// Output layout: `[w0_flat, b0, w1_flat, b1, …]`. Layers are paired
/// positionally, so a trailing weight or bias without a partner is left out.
fn flatten_params(ws: &[Vec<f32>], bs: &[Vec<f32>]) -> Vec<f32> {
    let weight_len: usize = ws.iter().map(Vec::len).sum();
    let bias_len: usize = bs.iter().map(Vec::len).sum();
    let mut packed = Vec::with_capacity(weight_len + bias_len);
    ws.iter().zip(bs).for_each(|(layer_w, layer_b)| {
        packed.extend_from_slice(layer_w);
        packed.extend_from_slice(layer_b);
    });
    packed
}
284
285/// Per-layer weight and bias vectors `(weights, biases)`.
286type LayerParams = (Vec<Vec<f32>>, Vec<Vec<f32>>);
287
288/// Inverse of [`flatten_params`]: reconstruct per-layer weight/bias vectors.
289///
290/// # Errors
291/// [`SslError::DimensionMismatch`] when the flat slice is shorter than expected.
292fn unflatten_params(flat: &[f32], d_model: usize, n_layers: usize) -> SslResult<LayerParams> {
293    let w_size = d_model * d_model;
294    let b_size = d_model;
295    let layer_size = w_size + b_size;
296    let expected = n_layers * layer_size;
297    if flat.len() < expected {
298        return Err(SslError::DimensionMismatch {
299            expected,
300            got: flat.len(),
301        });
302    }
303    let mut ws = Vec::with_capacity(n_layers);
304    let mut bs = Vec::with_capacity(n_layers);
305    let mut offset = 0;
306    for _ in 0..n_layers {
307        ws.push(flat[offset..offset + w_size].to_vec());
308        offset += w_size;
309        bs.push(flat[offset..offset + b_size].to_vec());
310        offset += b_size;
311    }
312    Ok((ws, bs))
313}
314
315/// Apply `n_layers` linear+ReLU layers to every token in `x [n_patches × d_model]`.
316fn apply_encoder_layers(
317    x: &[f32],
318    n_patches: usize,
319    d_model: usize,
320    ws: &[Vec<f32>],
321    bs: &[Vec<f32>],
322    n_layers: usize,
323) -> SslResult<Vec<f32>> {
324    // Work token-by-token; keep results in a flat buffer.
325    let mut current = x.to_vec();
326    for l in 0..n_layers {
327        let w = &ws[l];
328        let b = &bs[l];
329        let mut next = Vec::with_capacity(n_patches * d_model);
330        for t in 0..n_patches {
331            let start = t * d_model;
332            let token = &current[start..start + d_model];
333            next.extend_from_slice(&linear_relu(w, b, token, d_model, d_model));
334        }
335        current = next;
336    }
337    Ok(current)
338}
339
340// ─── Tests ────────────────────────────────────────────────────────────────────
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;
    use crate::masked::data2vec::data2vec_mask;

    /// Build a default-configured model from a seeded RNG.
    fn make_model(seed: u64) -> Data2VecModel {
        let mut rng = LcgRng::new(seed);
        Data2VecModel::new(Data2VecModelConfig::default(), &mut rng)
            .expect("default config should build a model")
    }

    /// Deterministic pseudo-random vector of length `n`.
    fn random_vec(n: usize, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut v = vec![0.0_f32; n];
        rng.fill_normal(&mut v);
        v
    }

    /// Deterministic mask over `n_patches` positions.
    fn make_mask(n_patches: usize, mask_ratio: f32, seed: u64) -> Vec<bool> {
        let mut rng = LcgRng::new(seed);
        data2vec_mask(n_patches, mask_ratio, &mut rng).expect("data2vec_mask should succeed")
    }

    #[test]
    fn encode_student_shape() {
        let m = make_model(1);
        let n_patches = 8;
        let d = m.d_model();
        let x = random_vec(n_patches * d, 2);
        let out = m
            .encode_student(&x, n_patches)
            .expect("encode_student should succeed");
        assert_eq!(
            out.len(),
            n_patches * d,
            "student output must have len == n_patches * d_model"
        );
    }

    #[test]
    fn encode_teacher_shape() {
        let m = make_model(3);
        let n_patches = 8;
        let d = m.d_model();
        let x = random_vec(n_patches * d, 4);
        let out = m
            .encode_teacher(&x, n_patches)
            .expect("encode_teacher should succeed");
        assert_eq!(
            out.len(),
            n_patches * d,
            "teacher output must have len == n_patches * d_model"
        );
    }

    #[test]
    fn encode_student_rejects_wrong_len() {
        let m = make_model(21);
        let n_patches = 4;
        let want = n_patches * m.d_model();
        let x = vec![0.0_f32; want - 1];
        let result = m.encode_student(&x, n_patches);
        assert!(
            matches!(
                result,
                Err(SslError::DimensionMismatch { expected, got })
                    if expected == want && got == want - 1
            ),
            "short input must yield DimensionMismatch with the exact lengths"
        );
    }

    #[test]
    fn encode_teacher_rejects_wrong_len() {
        let m = make_model(22);
        let n_patches = 4;
        let want = n_patches * m.d_model();
        let x = vec![0.0_f32; want + 1];
        let result = m.encode_teacher(&x, n_patches);
        assert!(
            matches!(
                result,
                Err(SslError::DimensionMismatch { expected, got })
                    if expected == want && got == want + 1
            ),
            "long input must yield DimensionMismatch with the exact lengths"
        );
    }

    #[test]
    fn loss_finite() {
        let m = make_model(5);
        let n_patches = 8;
        let d = m.d_model();
        let x = random_vec(n_patches * d, 6);
        let mask = make_mask(n_patches, 0.5, 7);
        let l = m.loss(&x, &mask, n_patches).expect("loss should succeed");
        assert!(l.is_finite(), "loss must be finite, got {l}");
    }

    #[test]
    fn loss_nonneg() {
        let m = make_model(8);
        let n_patches = 8;
        let d = m.d_model();
        let x = random_vec(n_patches * d, 9);
        let mask = make_mask(n_patches, 0.5, 10);
        let l = m.loss(&x, &mask, n_patches).expect("loss should succeed");
        assert!(l >= 0.0, "Huber loss must be >= 0, got {l}");
    }

    #[test]
    fn ema_update_changes_teacher() {
        let mut m = make_model(11);
        let teacher_before = m.teacher_state.teacher_params.clone();
        // Shift the first student layer so the student is guaranteed to
        // differ from the teacher snapshot taken at construction.
        for v in m.student_w[0].iter_mut() {
            *v += 1.0;
        }
        m.ema_update().expect("ema_update should succeed");
        let teacher_after = &m.teacher_state.teacher_params;
        let diff: f32 = teacher_before
            .iter()
            .zip(teacher_after.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-8,
            "teacher must change after ema_update when student differs, diff={diff}"
        );
    }

    #[test]
    fn ema_update_preserves_student() {
        let mut m = make_model(12);
        let student_w_before: Vec<Vec<f32>> = m.student_w.clone();
        let student_b_before: Vec<Vec<f32>> = m.student_b.clone();
        m.ema_update().expect("ema_update should succeed");
        assert_eq!(
            m.student_w, student_w_before,
            "student weights must not change during ema_update"
        );
        assert_eq!(
            m.student_b, student_b_before,
            "student biases must not change during ema_update"
        );
    }

    #[test]
    fn d_model_0_error() {
        let mut rng = LcgRng::new(13);
        let result = Data2VecModel::new(
            Data2VecModelConfig {
                d_model: 0,
                ..Data2VecModelConfig::default()
            },
            &mut rng,
        );
        // Pin the variant, not just "some error", so a wrong error kind fails.
        assert!(
            matches!(result, Err(SslError::InvalidParameter { .. })),
            "d_model=0 must return Err(InvalidParameter)"
        );
    }

    #[test]
    fn n_layers_1_works() {
        let mut rng = LcgRng::new(14);
        let m = Data2VecModel::new(
            Data2VecModelConfig {
                n_layers: 1,
                ..Data2VecModelConfig::default()
            },
            &mut rng,
        )
        .expect("single-layer config should build a model");
        let n_patches = 4;
        let x = random_vec(n_patches * m.d_model(), 15);
        let out = m
            .encode_student(&x, n_patches)
            .expect("encode_student should succeed");
        assert_eq!(out.len(), n_patches * m.d_model());
    }

    #[test]
    fn different_x_different_encode() {
        let m = make_model(16);
        let n_patches = 4;
        let d = m.d_model();
        let x1 = random_vec(n_patches * d, 17);
        let x2 = random_vec(n_patches * d, 18);
        let e1 = m
            .encode_student(&x1, n_patches)
            .expect("encode_student should succeed");
        let e2 = m
            .encode_student(&x2, n_patches)
            .expect("encode_student should succeed");
        let diff: f32 = e1.iter().zip(e2.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-6,
            "different inputs must produce different encodings, diff={diff}"
        );
    }

    #[test]
    fn n_layers_0_error() {
        let mut rng = LcgRng::new(19);
        let result = Data2VecModel::new(
            Data2VecModelConfig {
                n_layers: 0,
                ..Data2VecModelConfig::default()
            },
            &mut rng,
        );
        // Pin the variant, not just "some error", so a wrong error kind fails.
        assert!(
            matches!(result, Err(SslError::InvalidParameter { .. })),
            "n_layers=0 must return Err(InvalidParameter)"
        );
    }

    #[test]
    fn d_model_accessor() {
        let m = make_model(20);
        assert_eq!(m.d_model(), 64);
    }
}