1use crate::error::{SslError, SslResult};
22use crate::handle::LcgRng;
23
/// Hyper-parameters for the data2vec-style masked-prediction objective.
///
/// The type is `PartialEq` so that two configurations can be compared
/// directly (for example with `assert_eq!` or when detecting config changes).
#[derive(Debug, Clone, PartialEq)]
pub struct Data2VecConfig {
    /// Fraction of tokens to mask. Not read by the loss functions in this
    /// file; presumably the caller forwards it to `data2vec_mask` (confirm
    /// against the call site).
    pub mask_ratio: f32,
    /// Teacher EMA momentum. Not read by the loss functions in this file;
    /// presumably forwarded to `Data2VecState::update_teacher` by the caller.
    pub momentum: f32,
    /// Threshold passed to `huber_loss`: residuals with magnitude below
    /// `beta` are penalised quadratically, the rest linearly.
    pub beta: f32,
    /// When `true`, `data2vec_loss` rescales the gathered teacher targets
    /// with `normalize_teacher_targets` before computing the loss.
    pub normalize_targets: bool,
    /// Number of top teacher layers to average.
    /// NOTE(review): no code in this file reads this field; confirm where it
    /// is consumed.
    pub top_k_average: usize,
}
42
43impl Default for Data2VecConfig {
44 fn default() -> Self {
45 Self {
46 mask_ratio: 0.65,
47 momentum: 0.999,
48 beta: 2.0,
49 normalize_targets: true,
50 top_k_average: 1,
51 }
52 }
53}
54
/// Outcome of a single-sample loss evaluation.
///
/// The type is `PartialEq` so a whole result can be asserted in one
/// comparison instead of field by field.
#[derive(Debug, Clone, PartialEq)]
pub struct Data2VecResult {
    /// Mean loss over the elements of the masked tokens (`0.0` when no token
    /// is masked).
    pub loss: f32,
    /// Number of tokens whose mask entry was `true`.
    pub n_masked: usize,
    /// Top-1 accuracy slot. NOTE(review): `data2vec_loss` in this file always
    /// writes `0.0` here; treat it as a placeholder until it is computed.
    pub accuracy_at_1: f32,
}
67
/// Mutable training state: the teacher's parameters and an update counter.
///
/// The type is `PartialEq` so snapshots of the state can be compared
/// directly (for example to check that a rejected update changed nothing).
#[derive(Debug, Clone, PartialEq)]
pub struct Data2VecState {
    /// Flat teacher parameter vector, updated in place by
    /// `Data2VecState::update_teacher`.
    pub teacher_params: Vec<f32>,
    /// Number of successful teacher updates applied so far.
    pub step: usize,
}
84
85impl Data2VecState {
86 #[must_use]
91 pub fn new(online_params: &[f32]) -> Self {
92 Self {
93 teacher_params: online_params.to_vec(),
94 step: 0,
95 }
96 }
97
98 pub fn update_teacher(&mut self, online_params: &[f32], momentum: f32) -> SslResult<()> {
107 if !(momentum.is_finite() && (0.0..=1.0).contains(&momentum)) {
108 return Err(SslError::InvalidMomentum { momentum });
109 }
110 if self.teacher_params.len() != online_params.len() {
111 return Err(SslError::DimensionMismatch {
112 expected: self.teacher_params.len(),
113 got: online_params.len(),
114 });
115 }
116 let one_minus_m = 1.0 - momentum;
117 for (t, &o) in self.teacher_params.iter_mut().zip(online_params.iter()) {
118 *t = momentum * *t + one_minus_m * o;
119 }
120 self.step += 1;
121 Ok(())
122 }
123
124 #[must_use]
126 #[inline]
127 pub fn teacher(&self) -> &[f32] {
128 &self.teacher_params
129 }
130}
131
/// Mean smooth-L1 (Huber-style) loss between `predictions` and `targets`.
///
/// For each residual `x = prediction - target`:
/// * if `|x| < beta` the contribution is `0.5 * x^2 / beta`;
/// * otherwise it is `|x| - beta / 2`.
///
/// Contributions are accumulated in `f64` and the mean is returned as `f32`.
///
/// Returns `0.0` when the slices are empty or have different lengths; no
/// error is reported for a length mismatch.
#[must_use]
pub fn huber_loss(predictions: &[f32], targets: &[f32], beta: f32) -> f32 {
    let count = predictions.len();
    if count == 0 || count != targets.len() {
        return 0.0;
    }

    let threshold = f64::from(beta);
    let linear_offset = threshold / 2.0;
    let quadratic_scale = 1.0 / threshold;

    let mut accumulated = 0.0_f64;
    for (&pred, &target) in predictions.iter().zip(targets) {
        // The subtraction is done in f32 and only then widened.
        let residual = f64::from(pred - target);
        let magnitude = residual.abs();
        accumulated += if magnitude < threshold {
            0.5 * residual * residual * quadratic_scale
        } else {
            magnitude - linear_offset
        };
    }

    (accumulated / count as f64) as f32
}
167
/// Rescales `targets` in place so that every feature column has
/// (approximately) unit root-mean-square across tokens.
///
/// `targets` is read as a row-major `n_tokens x dim` matrix (element
/// `(i, d)` lives at `i * dim + d`). For each feature `d` the RMS over all
/// tokens is computed and the column is multiplied by `1 / (rms + 1e-8)`.
/// The mean is not subtracted.
///
/// The call is a silent no-op when `n_tokens` or `dim` is zero, or when
/// `targets.len() != n_tokens * dim`.
pub fn normalize_teacher_targets(targets: &mut [f32], n_tokens: usize, dim: usize) {
    const EPS: f32 = 1e-8;

    let shape_ok = n_tokens != 0 && dim != 0 && targets.len() == n_tokens * dim;
    if !shape_ok {
        return;
    }

    let token_count = n_tokens as f32;
    for feature in 0..dim {
        // Walk one column: start at `feature`, then hop a full row each step.
        let energy = targets
            .iter()
            .skip(feature)
            .step_by(dim)
            .fold(0.0_f32, |acc, &value| acc + value * value);

        let rms = (energy / token_count).sqrt();
        let scale = 1.0 / (rms + EPS);

        targets
            .iter_mut()
            .skip(feature)
            .step_by(dim)
            .for_each(|value| *value *= scale);
    }
}
198
199pub fn data2vec_mask(n_tokens: usize, mask_ratio: f32, rng: &mut LcgRng) -> SslResult<Vec<bool>> {
211 if n_tokens == 0 {
212 return Err(SslError::EmptyInput);
213 }
214 if !(mask_ratio.is_finite() && (0.0..1.0).contains(&mask_ratio)) {
215 return Err(SslError::InvalidMaskRatio { ratio: mask_ratio });
216 }
217 let n_mask = (n_tokens as f32 * mask_ratio) as usize;
218 let mut indices: Vec<usize> = (0..n_tokens).collect();
220 rng.shuffle(&mut indices);
221 let mut mask = vec![false; n_tokens];
222 for &idx in indices.iter().take(n_mask) {
223 mask[idx] = true;
224 }
225 Ok(mask)
226}
227
228pub fn data2vec_loss(
253 student_pred: &[f32],
254 teacher_repr: &[f32],
255 mask: &[bool],
256 n_tokens: usize,
257 dim: usize,
258 config: &Data2VecConfig,
259) -> SslResult<Data2VecResult> {
260 if n_tokens == 0 || dim == 0 {
262 return Err(SslError::EmptyInput);
263 }
264 let expected = n_tokens * dim;
265 if student_pred.len() != expected {
266 return Err(SslError::DimensionMismatch {
267 expected,
268 got: student_pred.len(),
269 });
270 }
271 if teacher_repr.len() != expected {
272 return Err(SslError::DimensionMismatch {
273 expected,
274 got: teacher_repr.len(),
275 });
276 }
277 if mask.len() != n_tokens {
278 return Err(SslError::DimensionMismatch {
279 expected: n_tokens,
280 got: mask.len(),
281 });
282 }
283
284 let masked_indices: Vec<usize> = (0..n_tokens).filter(|&i| mask[i]).collect();
286 let n_masked = masked_indices.len();
287
288 if n_masked == 0 {
289 return Ok(Data2VecResult {
291 loss: 0.0,
292 n_masked: 0,
293 accuracy_at_1: 0.0,
294 });
295 }
296
297 let mut teacher_masked = Vec::with_capacity(n_masked * dim);
301 let mut student_masked = Vec::with_capacity(n_masked * dim);
302 for &i in &masked_indices {
303 let start = i * dim;
304 let end = start + dim;
305 teacher_masked.extend_from_slice(&teacher_repr[start..end]);
306 student_masked.extend_from_slice(&student_pred[start..end]);
307 }
308
309 if config.normalize_targets {
311 normalize_teacher_targets(&mut teacher_masked, n_masked, dim);
312 }
313
314 let loss = huber_loss(&student_masked, &teacher_masked, config.beta);
316
317 Ok(Data2VecResult {
318 loss,
319 n_masked,
320 accuracy_at_1: 0.0,
321 })
322}
323
324pub fn data2vec_batch_loss(
338 student_preds: &[f32],
339 teacher_reprs: &[f32],
340 masks: &[bool],
341 batch_size: usize,
342 n_tokens: usize,
343 dim: usize,
344 config: &Data2VecConfig,
345) -> SslResult<f32> {
346 if batch_size == 0 {
347 return Err(SslError::EmptyInput);
348 }
349 let sample_len = n_tokens * dim;
350 let expected_feat = batch_size * sample_len;
351 let expected_mask = batch_size * n_tokens;
352
353 if student_preds.len() != expected_feat {
354 return Err(SslError::DimensionMismatch {
355 expected: expected_feat,
356 got: student_preds.len(),
357 });
358 }
359 if teacher_reprs.len() != expected_feat {
360 return Err(SslError::DimensionMismatch {
361 expected: expected_feat,
362 got: teacher_reprs.len(),
363 });
364 }
365 if masks.len() != expected_mask {
366 return Err(SslError::DimensionMismatch {
367 expected: expected_mask,
368 got: masks.len(),
369 });
370 }
371
372 let mut total_loss = 0.0_f64;
373 for b in 0..batch_size {
374 let feat_start = b * sample_len;
375 let feat_end = feat_start + sample_len;
376 let mask_start = b * n_tokens;
377 let mask_end = mask_start + n_tokens;
378
379 let result = data2vec_loss(
380 &student_preds[feat_start..feat_end],
381 &teacher_reprs[feat_start..feat_end],
382 &masks[mask_start..mask_end],
383 n_tokens,
384 dim,
385 config,
386 )?;
387 total_loss += result.loss as f64;
388 }
389 Ok((total_loss / batch_size as f64) as f32)
390}
391
// Unit tests for the configuration defaults, the loss helpers, mask
// generation, the teacher state and the single-sample / batch loss functions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// `Default` must yield the documented values for every config field.
    #[test]
    fn config_defaults() {
        let cfg = Data2VecConfig::default();
        assert!((cfg.mask_ratio - 0.65).abs() < 1e-7);
        assert!((cfg.momentum - 0.999).abs() < 1e-7);
        assert!((cfg.beta - 2.0).abs() < 1e-7);
        assert!(cfg.normalize_targets);
        assert_eq!(cfg.top_k_average, 1);
    }

    /// Residual below `beta` takes the quadratic branch:
    /// |x| = 0.5 < 2.0, so the loss is 0.5 * 0.5^2 / 2.0.
    #[test]
    fn huber_loss_small_error() {
        let pred = vec![0.5_f32];
        let tgt = vec![0.0_f32];
        let loss = huber_loss(&pred, &tgt, 2.0);
        // 0.5 * x^2 / beta with x = 0.5, beta = 2.0
        let expected = 0.5_f32 * 0.25_f32 / 2.0_f32;
        assert!(
            (loss - expected).abs() < 1e-6,
            "loss={loss} expected={expected}"
        );
    }

    /// Residual at or above `beta` takes the linear branch:
    /// |x| = 3.0 >= 2.0, so the loss is 3.0 - 2.0 / 2 = 2.0.
    #[test]
    fn huber_loss_large_error() {
        let pred = vec![3.0_f32];
        let tgt = vec![0.0_f32];
        let loss = huber_loss(&pred, &tgt, 2.0);
        assert!((loss - 2.0).abs() < 1e-6, "loss={loss}");
    }

    /// Identical predictions and targets give zero loss.
    #[test]
    fn huber_loss_zero() {
        let v = vec![1.5_f32, -0.7, 3.2, 0.0];
        let loss = huber_loss(&v, &v, 2.0);
        assert!(loss.abs() < 1e-7, "loss={loss}");
    }

    /// 100 tokens at ratio 0.65 must mask exactly 65 positions.
    #[test]
    fn mask_exact_ratio() {
        let mut rng = LcgRng::new(42);
        let mask = data2vec_mask(100, 0.65, &mut rng).expect("data2vec_mask should succeed");
        let n_masked = mask.iter().filter(|&&v| v).count();
        assert_eq!(n_masked, 65, "expected 65 masked, got {n_masked}");
    }

    /// The mask always has one entry per token.
    #[test]
    fn mask_length() {
        let mut rng = LcgRng::new(7);
        let mask = data2vec_mask(196, 0.75, &mut rng).expect("data2vec_mask should succeed");
        assert_eq!(mask.len(), 196);
    }

    /// With identical student and teacher buffers and normalization off, the
    /// loss is zero and only the two flagged tokens are counted.
    #[test]
    fn data2vec_loss_only_masked() {
        let n_tokens = 10;
        let dim = 4;
        let mut mask = vec![false; n_tokens];
        // Mask exactly two tokens.
        mask[3] = true;
        mask[7] = true;

        // The same buffer is used as both student prediction and teacher target.
        let repr: Vec<f32> = (0..n_tokens * dim).map(|i| (i as f32) * 0.1).collect();
        let cfg = Data2VecConfig {
            normalize_targets: false,
            ..Data2VecConfig::default()
        };
        let result = data2vec_loss(&repr, &repr, &mask, n_tokens, dim, &cfg)
            .expect("data2vec_loss should succeed");
        assert!(result.loss.abs() < 1e-6, "loss={}", result.loss);
        assert_eq!(result.n_masked, 2);
        assert!((result.accuracy_at_1 - 0.0).abs() < 1e-7);
    }

    /// An all-false mask yields a zero loss and a zero masked count.
    #[test]
    fn data2vec_loss_no_masked_tokens() {
        let n_tokens = 8;
        let dim = 3;
        let mask = vec![false; n_tokens];
        let v = vec![0.0_f32; n_tokens * dim];
        let cfg = Data2VecConfig::default();
        let result = data2vec_loss(&v, &v, &mask, n_tokens, dim, &cfg)
            .expect("data2vec_loss should succeed");
        assert_eq!(result.n_masked, 0);
        assert!(result.loss.abs() < 1e-7);
    }

    /// Large constant targets are scaled down by normalization.
    #[test]
    fn normalize_targets_reduces_large_values() {
        let n_tokens = 4;
        let dim = 2;
        let mut targets = vec![100.0_f32; n_tokens * dim];
        normalize_teacher_targets(&mut targets, n_tokens, dim);
        for &v in &targets {
            // Each column has RMS 100, so every value should land near 1.
            assert!(v.abs() < 2.0, "value after norm={v}");
        }
    }

    /// A fresh state copies the online parameters and starts at step 0.
    #[test]
    fn state_init_matches_online() {
        let online = vec![0.1_f32, 0.5, -0.3, 1.2];
        let state = Data2VecState::new(&online);
        assert_eq!(state.teacher(), online.as_slice());
        assert_eq!(state.step, 0);
    }

    /// Momentum 0 replaces the teacher with the online parameters.
    #[test]
    fn state_update_closer_to_online_m0() {
        let teacher_init = vec![1.0_f32, 2.0, 3.0];
        let online = vec![10.0_f32, 20.0, 30.0];
        let mut state = Data2VecState::new(&teacher_init);
        state
            .update_teacher(&online, 0.0)
            .expect("update_teacher should succeed");
        for (&t, &o) in state.teacher().iter().zip(online.iter()) {
            // teacher = 0 * teacher + 1 * online
            assert!((t - o).abs() < 1e-6, "teacher={t} online={o}");
        }
        assert_eq!(state.step, 1);
    }

    /// Momentum 1 leaves the teacher unchanged.
    #[test]
    fn state_update_m1_unchanged() {
        let teacher_init = vec![5.0_f32, -3.0, 0.7];
        let online = vec![0.0_f32, 0.0, 0.0];
        let mut state = Data2VecState::new(&teacher_init);
        let expected = state.teacher().to_vec();
        state
            .update_teacher(&online, 1.0)
            .expect("update_teacher should succeed");
        for (&t, &e) in state.teacher().iter().zip(expected.iter()) {
            // teacher = 1 * teacher + 0 * online
            assert!((t - e).abs() < 1e-6, "teacher={t} expected={e}");
        }
    }

    /// A batch of one sample must give the same loss as the single-sample path.
    #[test]
    fn batch_loss_matches_single() {
        let n_tokens = 6;
        let dim = 4;
        let mut rng = LcgRng::new(99);

        // Buffers are filled by `LcgRng::fill_normal` (defined outside this file).
        let mut student = vec![0.0_f32; n_tokens * dim];
        let mut teacher = vec![0.0_f32; n_tokens * dim];
        rng.fill_normal(&mut student);
        rng.fill_normal(&mut teacher);

        let mask = data2vec_mask(n_tokens, 0.5, &mut rng).expect("data2vec_mask should succeed");

        let cfg = Data2VecConfig::default();

        let single = data2vec_loss(&student, &teacher, &mask, n_tokens, dim, &cfg)
            .expect("value should be present")
            .loss;
        let batch = data2vec_batch_loss(&student, &teacher, &mask, 1, n_tokens, dim, &cfg)
            .expect("data2vec_batch_loss should succeed");

        assert!(
            (single - batch).abs() < 1e-5,
            "single={single} batch={batch}"
        );
    }

    /// The batch loss over random inputs is finite and non-negative.
    #[test]
    fn batch_loss_finite() {
        let batch_size = 4;
        let n_tokens = 16;
        let dim = 8;
        let mut rng = LcgRng::new(1337);

        let total_feat = batch_size * n_tokens * dim;
        let mut student = vec![0.0_f32; total_feat];
        let mut teacher = vec![0.0_f32; total_feat];
        rng.fill_normal(&mut student);
        rng.fill_normal(&mut teacher);

        // One independently drawn mask per sample, concatenated.
        let mut masks = Vec::with_capacity(batch_size * n_tokens);
        for _ in 0..batch_size {
            masks.extend(
                data2vec_mask(n_tokens, 0.65, &mut rng).expect("data2vec_mask should succeed"),
            );
        }

        let cfg = Data2VecConfig::default();
        let loss = data2vec_batch_loss(&student, &teacher, &masks, batch_size, n_tokens, dim, &cfg)
            .expect("value should be present");

        assert!(loss.is_finite(), "loss={loss}");
        assert!(loss >= 0.0, "loss={loss}");
    }

    /// Ratios outside `[0, 1)` and NaN are rejected.
    #[test]
    fn mask_invalid_ratio_errors() {
        let mut rng = LcgRng::new(1);
        // 1.0 is excluded: the valid interval is half-open.
        assert!(data2vec_mask(10, 1.0, &mut rng).is_err());
        assert!(data2vec_mask(10, -0.1, &mut rng).is_err());
        assert!(data2vec_mask(10, f32::NAN, &mut rng).is_err());
    }

    /// Momentum outside `[0, 1]` and NaN are rejected.
    #[test]
    fn state_update_rejects_invalid_momentum() {
        let mut state = Data2VecState::new(&[1.0_f32, 2.0]);
        let online = vec![3.0_f32, 4.0];
        assert!(state.update_teacher(&online, 1.5).is_err());
        assert!(state.update_teacher(&online, -0.1).is_err());
        assert!(state.update_teacher(&online, f32::NAN).is_err());
    }

    /// Degenerate shapes must not panic; there is nothing else to assert.
    #[test]
    fn normalize_teacher_targets_empty_noop() {
        let mut v: Vec<f32> = vec![];
        // n_tokens == 0
        normalize_teacher_targets(&mut v, 0, 4);
        let mut v2 = vec![1.0_f32; 8];
        // dim == 0
        normalize_teacher_targets(&mut v2, 4, 0);
    }

    /// Each buffer with a wrong length is reported as an error.
    #[test]
    fn data2vec_loss_shape_errors() {
        let n = 4;
        let d = 3;
        let cfg = Data2VecConfig::default();
        let good = vec![0.0_f32; n * d];
        let short = vec![0.0_f32; n * d - 1];
        let mask = vec![true; n];
        // Student buffer too short.
        assert!(data2vec_loss(&short, &good, &mask, n, d, &cfg).is_err());
        // Teacher buffer too short.
        assert!(data2vec_loss(&good, &short, &mask, n, d, &cfg).is_err());
        // Mask too short.
        let bad_mask = vec![true; n - 1];
        assert!(data2vec_loss(&good, &good, &bad_mask, n, d, &cfg).is_err());
    }
}