1use crate::error::{SslError, SslResult};
28use crate::handle::LcgRng;
29use crate::masked::data2vec::{Data2VecConfig, Data2VecState, data2vec_loss};
30
/// Hyper-parameters for [`Data2VecModel`].
#[derive(Clone, Debug)]
pub struct Data2VecModelConfig {
    /// Token width; every encoder layer is a square `d_model x d_model` linear map.
    pub d_model: usize,
    /// Number of stacked encoder layers (`Data2VecModel::new` rejects 0).
    pub n_layers: usize,
    /// Decay passed to the teacher's EMA update in `Data2VecModel::ema_update`
    /// (also forwarded as `Data2VecConfig::momentum` when computing the loss).
    pub ema_decay: f32,
    /// Forwarded as `Data2VecConfig::mask_ratio` when computing the loss.
    pub mask_ratio: f32,
    /// Forwarded as `Data2VecConfig::top_k_average` when computing the loss.
    pub k_top_layers: usize,
}
48
49impl Default for Data2VecModelConfig {
50 fn default() -> Self {
51 Self {
52 d_model: 64,
53 n_layers: 2,
54 ema_decay: 0.999,
55 mask_ratio: 0.65,
56 k_top_layers: 1,
57 }
58 }
59}
60
61#[derive(Debug, Clone)]
73pub struct Data2VecModel {
74 student_w: Vec<Vec<f32>>,
76 student_b: Vec<Vec<f32>>,
78 teacher_state: Data2VecState,
80 config: Data2VecModelConfig,
82}
83
84impl Data2VecModel {
85 pub fn new(config: Data2VecModelConfig, rng: &mut LcgRng) -> SslResult<Self> {
92 if config.d_model == 0 {
93 return Err(SslError::InvalidParameter {
94 name: "d_model".into(),
95 reason: "must be > 0".into(),
96 });
97 }
98 if config.n_layers == 0 {
99 return Err(SslError::InvalidParameter {
100 name: "n_layers".into(),
101 reason: "must be >= 1".into(),
102 });
103 }
104
105 let d = config.d_model;
106 let mut student_w = Vec::with_capacity(config.n_layers);
107 let mut student_b = Vec::with_capacity(config.n_layers);
108
109 for _ in 0..config.n_layers {
110 let w = kaiming_init(d, d, rng);
111 let b = vec![0.0_f32; d];
112 student_w.push(w);
113 student_b.push(b);
114 }
115
116 let flat_params = flatten_params(&student_w, &student_b);
118 let teacher_state = Data2VecState::new(&flat_params);
119
120 Ok(Self {
121 student_w,
122 student_b,
123 teacher_state,
124 config,
125 })
126 }
127
128 pub fn encode_student(&self, x: &[f32], n_patches: usize) -> SslResult<Vec<f32>> {
139 let d = self.config.d_model;
140 let expected = n_patches * d;
141 if x.len() != expected {
142 return Err(SslError::DimensionMismatch {
143 expected,
144 got: x.len(),
145 });
146 }
147 apply_encoder_layers(
148 x,
149 n_patches,
150 d,
151 &self.student_w,
152 &self.student_b,
153 self.config.n_layers,
154 )
155 }
156
157 pub fn encode_teacher(&self, x: &[f32], n_patches: usize) -> SslResult<Vec<f32>> {
168 let d = self.config.d_model;
169 let expected = n_patches * d;
170 if x.len() != expected {
171 return Err(SslError::DimensionMismatch {
172 expected,
173 got: x.len(),
174 });
175 }
176 let (teacher_w, teacher_b) =
177 unflatten_params(self.teacher_state.teacher(), d, self.config.n_layers)?;
178 apply_encoder_layers(
179 x,
180 n_patches,
181 d,
182 &teacher_w,
183 &teacher_b,
184 self.config.n_layers,
185 )
186 }
187
188 pub fn loss(&self, x: &[f32], mask: &[bool], n_patches: usize) -> SslResult<f32> {
202 let d = self.config.d_model;
203 let student_repr = self.encode_student(x, n_patches)?;
204 let teacher_repr = self.encode_teacher(x, n_patches)?;
205 let d2v_config = Data2VecConfig {
206 mask_ratio: self.config.mask_ratio,
207 momentum: self.config.ema_decay,
208 top_k_average: self.config.k_top_layers,
209 ..Data2VecConfig::default()
210 };
211 let result = data2vec_loss(
212 &student_repr,
213 &teacher_repr,
214 mask,
215 n_patches,
216 d,
217 &d2v_config,
218 )?;
219 Ok(result.loss)
220 }
221
222 pub fn ema_update(&mut self) -> SslResult<()> {
229 let flat_student = flatten_params(&self.student_w, &self.student_b);
230 self.teacher_state
231 .update_teacher(&flat_student, self.config.ema_decay)
232 }
233
234 #[inline]
236 #[must_use]
237 pub fn d_model(&self) -> usize {
238 self.config.d_model
239 }
240}
241
242fn kaiming_init(out_dim: usize, in_dim: usize, rng: &mut LcgRng) -> Vec<f32> {
246 let scale = (2.0_f32 / in_dim as f32).sqrt();
247 let mut w = vec![0.0_f32; out_dim * in_dim];
248 rng.fill_normal(&mut w);
249 for v in w.iter_mut() {
250 *v *= scale;
251 }
252 w
253}
254
/// Computes `relu(W x + b)` for one token.
///
/// `w` is a flat, row-major `out_dim x in_dim` matrix, so output `i` is
/// `max(0, b[i] + sum_j w[i * in_dim + j] * x[j])`. Terms are accumulated left to right,
/// starting from the bias.
fn linear_relu(w: &[f32], b: &[f32], x: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
    (0..out_dim)
        .map(|i| {
            let weights_row = &w[i * in_dim..(i + 1) * in_dim];
            let inputs = &x[..in_dim];
            let pre_activation = weights_row
                .iter()
                .zip(inputs)
                .fold(b[i], |sum, (wv, xv)| sum + wv * xv);
            pre_activation.max(0.0)
        })
        .collect()
}
270
/// Concatenates per-layer parameters into one flat vector laid out as
/// `[w_0, b_0, w_1, b_1, ...]`.
///
/// Layers are paired with `zip`, so only `min(ws.len(), bs.len())` layers are emitted.
fn flatten_params(ws: &[Vec<f32>], bs: &[Vec<f32>]) -> Vec<f32> {
    let weight_count: usize = ws.iter().map(Vec::len).sum();
    let bias_count: usize = bs.iter().map(Vec::len).sum();
    let mut packed = Vec::with_capacity(weight_count + bias_count);
    ws.iter().zip(bs).for_each(|(weights, biases)| {
        packed.extend_from_slice(weights);
        packed.extend_from_slice(biases);
    });
    packed
}
284
285type LayerParams = (Vec<Vec<f32>>, Vec<Vec<f32>>);
287
288fn unflatten_params(flat: &[f32], d_model: usize, n_layers: usize) -> SslResult<LayerParams> {
293 let w_size = d_model * d_model;
294 let b_size = d_model;
295 let layer_size = w_size + b_size;
296 let expected = n_layers * layer_size;
297 if flat.len() < expected {
298 return Err(SslError::DimensionMismatch {
299 expected,
300 got: flat.len(),
301 });
302 }
303 let mut ws = Vec::with_capacity(n_layers);
304 let mut bs = Vec::with_capacity(n_layers);
305 let mut offset = 0;
306 for _ in 0..n_layers {
307 ws.push(flat[offset..offset + w_size].to_vec());
308 offset += w_size;
309 bs.push(flat[offset..offset + b_size].to_vec());
310 offset += b_size;
311 }
312 Ok((ws, bs))
313}
314
315fn apply_encoder_layers(
317 x: &[f32],
318 n_patches: usize,
319 d_model: usize,
320 ws: &[Vec<f32>],
321 bs: &[Vec<f32>],
322 n_layers: usize,
323) -> SslResult<Vec<f32>> {
324 let mut current = x.to_vec();
326 for l in 0..n_layers {
327 let w = &ws[l];
328 let b = &bs[l];
329 let mut next = Vec::with_capacity(n_patches * d_model);
330 for t in 0..n_patches {
331 let start = t * d_model;
332 let token = ¤t[start..start + d_model];
333 next.extend_from_slice(&linear_relu(w, b, token, d_model, d_model));
334 }
335 current = next;
336 }
337 Ok(current)
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;
    use crate::masked::data2vec::data2vec_mask;

    /// Builds a default-config model seeded with `seed`.
    fn make_model(seed: u64) -> Data2VecModel {
        let mut rng = LcgRng::new(seed);
        Data2VecModel::new(Data2VecModelConfig::default(), &mut rng)
            .expect("value should be present")
    }

    /// Deterministic vector of `n` samples from `LcgRng::fill_normal`.
    fn random_vec(n: usize, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut v = vec![0.0_f32; n];
        rng.fill_normal(&mut v);
        v
    }

    /// Deterministic patch mask from `data2vec_mask`.
    fn make_mask(n_patches: usize, mask_ratio: f32, seed: u64) -> Vec<bool> {
        let mut rng = LcgRng::new(seed);
        data2vec_mask(n_patches, mask_ratio, &mut rng).expect("data2vec_mask should succeed")
    }

    #[test]
    fn encode_student_shape() {
        let m = make_model(1);
        let n_patches = 8;
        let d = m.d_model();
        let x = random_vec(n_patches * d, 2);
        let out = m
            .encode_student(&x, n_patches)
            .expect("encode_student should succeed");
        assert_eq!(
            out.len(),
            n_patches * d,
            "student output must have len == n_patches * d_model"
        );
    }

    #[test]
    fn encode_teacher_shape() {
        let m = make_model(3);
        let n_patches = 8;
        let d = m.d_model();
        let x = random_vec(n_patches * d, 4);
        let out = m
            .encode_teacher(&x, n_patches)
            .expect("encode_teacher should succeed");
        assert_eq!(
            out.len(),
            n_patches * d,
            "teacher output must have len == n_patches * d_model"
        );
    }

    #[test]
    fn loss_finite() {
        let m = make_model(5);
        let n_patches = 8;
        let d = m.d_model();
        let x = random_vec(n_patches * d, 6);
        let mask = make_mask(n_patches, 0.5, 7);
        let l = m.loss(&x, &mask, n_patches).expect("loss should succeed");
        assert!(l.is_finite(), "loss must be finite, got {l}");
    }

    #[test]
    fn loss_nonneg() {
        let m = make_model(8);
        let n_patches = 8;
        let d = m.d_model();
        let x = random_vec(n_patches * d, 9);
        let mask = make_mask(n_patches, 0.5, 10);
        let l = m.loss(&x, &mask, n_patches).expect("loss should succeed");
        assert!(l >= 0.0, "Huber loss must be >= 0, got {l}");
    }

    #[test]
    fn ema_update_changes_teacher() {
        let mut m = make_model(11);
        let teacher_before = m.teacher_state.teacher_params.clone();
        // Perturb the student so it differs from the parameters the teacher was built from.
        for v in m.student_w[0].iter_mut() {
            *v += 1.0;
        }
        m.ema_update().expect("ema_update should succeed");
        let teacher_after = &m.teacher_state.teacher_params;
        let diff: f32 = teacher_before
            .iter()
            .zip(teacher_after.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-8,
            "teacher must change after ema_update when student differs, diff={diff}"
        );
    }

    #[test]
    fn ema_update_preserves_student() {
        let mut m = make_model(12);
        let student_w_before: Vec<Vec<f32>> = m.student_w.clone();
        let student_b_before: Vec<Vec<f32>> = m.student_b.clone();
        m.ema_update().expect("ema_update should succeed");
        assert_eq!(
            m.student_w, student_w_before,
            "student weights must not change during ema_update"
        );
        assert_eq!(
            m.student_b, student_b_before,
            "student biases must not change during ema_update"
        );
    }

    #[test]
    fn d_model_0_error() {
        let mut rng = LcgRng::new(13);
        let result = Data2VecModel::new(
            Data2VecModelConfig {
                d_model: 0,
                ..Data2VecModelConfig::default()
            },
            &mut rng,
        );
        assert!(result.is_err(), "d_model=0 must return Err");
    }

    #[test]
    fn n_layers_1_works() {
        let mut rng = LcgRng::new(14);
        let m = Data2VecModel::new(
            Data2VecModelConfig {
                n_layers: 1,
                ..Data2VecModelConfig::default()
            },
            &mut rng,
        )
        .expect("value should be present");
        let n_patches = 4;
        let x = random_vec(n_patches * m.d_model(), 15);
        let out = m
            .encode_student(&x, n_patches)
            .expect("encode_student should succeed");
        assert_eq!(out.len(), n_patches * m.d_model());
    }

    #[test]
    fn different_x_different_encode() {
        let m = make_model(16);
        let n_patches = 4;
        let d = m.d_model();
        let x1 = random_vec(n_patches * d, 17);
        let x2 = random_vec(n_patches * d, 18);
        let e1 = m
            .encode_student(&x1, n_patches)
            .expect("encode_student should succeed");
        let e2 = m
            .encode_student(&x2, n_patches)
            .expect("encode_student should succeed");
        let diff: f32 = e1.iter().zip(e2.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-6,
            "different inputs must produce different encodings, diff={diff}"
        );
    }

    #[test]
    fn n_layers_0_error() {
        let mut rng = LcgRng::new(19);
        let result = Data2VecModel::new(
            Data2VecModelConfig {
                n_layers: 0,
                ..Data2VecModelConfig::default()
            },
            &mut rng,
        );
        assert!(result.is_err(), "n_layers=0 must return Err");
    }

    #[test]
    fn d_model_accessor() {
        let m = make_model(20);
        assert_eq!(m.d_model(), 64);
    }

    #[test]
    fn encode_student_wrong_len_error() {
        let m = make_model(21);
        let n_patches = 4;
        let x = random_vec(n_patches * m.d_model() - 1, 22);
        assert!(
            m.encode_student(&x, n_patches).is_err(),
            "encode_student must reject x.len() != n_patches * d_model"
        );
    }

    #[test]
    fn encode_teacher_wrong_len_error() {
        let m = make_model(23);
        let n_patches = 4;
        let x = random_vec(n_patches * m.d_model() + 1, 24);
        assert!(
            m.encode_teacher(&x, n_patches).is_err(),
            "encode_teacher must reject x.len() != n_patches * d_model"
        );
    }

    #[test]
    fn loss_wrong_len_error() {
        let m = make_model(25);
        let n_patches = 8;
        let x = random_vec(n_patches * m.d_model() - 3, 26);
        let mask = make_mask(n_patches, 0.5, 27);
        assert!(
            m.loss(&x, &mask, n_patches).is_err(),
            "loss must propagate the input-length error"
        );
    }

    #[test]
    fn encodings_are_nonneg() {
        // Every layer ends in `max(0.0)`, so no encoded value can be negative.
        let m = make_model(28);
        let n_patches = 4;
        let x = random_vec(n_patches * m.d_model(), 29);
        let student = m
            .encode_student(&x, n_patches)
            .expect("encode_student should succeed");
        let teacher = m
            .encode_teacher(&x, n_patches)
            .expect("encode_teacher should succeed");
        assert!(student.iter().chain(teacher.iter()).all(|&v| v >= 0.0));
    }
}