1use crate::error::{SslError, SslResult};
25use crate::handle::LcgRng;
26
/// Hyper-parameters for the BEiT-style masked-token objective and its VQ codebook.
///
/// Every field is public, so a value assembled with a struct literal is *not*
/// range-checked; use [`BeitConfig::new`] when validation is wanted.
#[derive(Debug, Clone)]
pub struct BeitConfig {
    /// Codebook size (number of discrete codes). `new` rejects `0`.
    pub n_codes: usize,
    /// Length of each code vector. `new` rejects `0`.
    pub code_dim: usize,
    /// Fraction of patches to mask. `new` requires a finite value in `[0, 1)`.
    pub mask_ratio: f32,
    /// EMA momentum. `new` requires a finite value in `[0, 1]`.
    pub ema_momentum: f32,
    /// Commitment-loss weight. `new` requires a finite value `>= 0`.
    pub commitment_weight: f32,
    /// Softmax temperature; logits are divided by it in `beit_loss`.
    /// Must be finite and `> 0`.
    pub temperature: f32,
    /// Small positive constant. `new` requires a finite value `> 0`.
    pub eps: f32,
}

impl Default for BeitConfig {
    /// Returns the stock configuration: 8192 codes of width 256, 40 % masking,
    /// EMA momentum 0.999, commitment weight 0.25, temperature 1.0, eps 1e-6.
    fn default() -> Self {
        const N_CODES: usize = 8192;
        const CODE_DIM: usize = 256;
        const MASK_RATIO: f32 = 0.4;
        const EMA_MOMENTUM: f32 = 0.999;
        const COMMITMENT_WEIGHT: f32 = 0.25;
        const TEMPERATURE: f32 = 1.0;
        const EPS: f32 = 1e-6;

        BeitConfig {
            n_codes: N_CODES,
            code_dim: CODE_DIM,
            mask_ratio: MASK_RATIO,
            ema_momentum: EMA_MOMENTUM,
            commitment_weight: COMMITMENT_WEIGHT,
            temperature: TEMPERATURE,
            eps: EPS,
        }
    }
}
61
62impl BeitConfig {
63 pub fn new(
70 n_codes: usize,
71 code_dim: usize,
72 mask_ratio: f32,
73 ema_momentum: f32,
74 commitment_weight: f32,
75 temperature: f32,
76 eps: f32,
77 ) -> SslResult<Self> {
78 if n_codes == 0 {
79 return Err(SslError::InvalidParameter {
80 name: "n_codes".into(),
81 reason: "must be > 0".into(),
82 });
83 }
84 if code_dim == 0 {
85 return Err(SslError::InvalidParameter {
86 name: "code_dim".into(),
87 reason: "must be > 0".into(),
88 });
89 }
90 if !(mask_ratio.is_finite() && (0.0..1.0).contains(&mask_ratio)) {
91 return Err(SslError::InvalidMaskRatio { ratio: mask_ratio });
92 }
93 if !(ema_momentum.is_finite() && (0.0..=1.0).contains(&ema_momentum)) {
94 return Err(SslError::InvalidMomentum {
95 momentum: ema_momentum,
96 });
97 }
98 if !(commitment_weight.is_finite() && commitment_weight >= 0.0) {
99 return Err(SslError::InvalidParameter {
100 name: "commitment_weight".into(),
101 reason: "must be finite and >= 0".into(),
102 });
103 }
104 if !(temperature.is_finite() && temperature > 0.0) {
105 return Err(SslError::InvalidTemperature { temp: temperature });
106 }
107 if !(eps.is_finite() && eps > 0.0) {
108 return Err(SslError::InvalidParameter {
109 name: "eps".into(),
110 reason: "must be finite and > 0".into(),
111 });
112 }
113 Ok(Self {
114 n_codes,
115 code_dim,
116 mask_ratio,
117 ema_momentum,
118 commitment_weight,
119 temperature,
120 eps,
121 })
122 }
123}
124
/// Vector-quantisation codebook together with the running statistics used by
/// the EMA update in `vq_update_codebook`.
///
/// All buffers are flat `Vec<f32>`; code `k` occupies the slice
/// `[k * code_dim, (k + 1) * code_dim)` of `embeddings` and `ema_sum`.
#[derive(Debug, Clone)]
pub struct VqCodebook {
    /// Code vectors, `n_codes * code_dim` values laid out code after code.
    pub embeddings: Vec<f32>,
    /// Number of codes in the table.
    pub n_codes: usize,
    /// Length of each code vector.
    pub code_dim: usize,
    /// Decay factor applied to the running statistics in `vq_update_codebook`
    /// (`vq_codebook_init` sets it to 0.999).
    pub ema_momentum: f32,
    /// Weight `beta` used by `vq_encode`, which scales each squared distance
    /// by `1 + beta` (`vq_codebook_init` sets it to 0.25).
    pub commitment_weight: f32,
    /// Running per-code assignment counts, one entry per code
    /// (`vq_codebook_init` starts every entry at 1.0).
    pub ema_counts: Vec<f32>,
    /// Running per-code sums of assigned vectors, same layout as `embeddings`
    /// (`vq_codebook_init` starts it as a copy of `embeddings`).
    pub ema_sum: Vec<f32>,
}
155
156pub fn vq_codebook_init(
166 n_codes: usize,
167 code_dim: usize,
168 rng: &mut LcgRng,
169) -> SslResult<VqCodebook> {
170 if n_codes == 0 {
171 return Err(SslError::InvalidParameter {
172 name: "n_codes".into(),
173 reason: "must be > 0".into(),
174 });
175 }
176 if code_dim == 0 {
177 return Err(SslError::InvalidParameter {
178 name: "code_dim".into(),
179 reason: "must be > 0".into(),
180 });
181 }
182 let total = n_codes * code_dim;
183 let mut embeddings = vec![0.0_f32; total];
184 rng.fill_normal(&mut embeddings);
185 let scale = 1.0 / (code_dim as f32).sqrt();
186 for v in &mut embeddings {
187 *v *= scale;
188 }
189 let ema_counts = vec![1.0_f32; n_codes];
191 let ema_sum = embeddings.clone();
192 Ok(VqCodebook {
193 embeddings,
194 n_codes,
195 code_dim,
196 ema_momentum: 0.999,
197 commitment_weight: 0.25,
198 ema_counts,
199 ema_sum,
200 })
201}
202
203pub fn vq_encode(
226 codebook: &VqCodebook,
227 embeddings: &[f32],
228 n_patches: usize,
229 code_dim: usize,
230) -> SslResult<(Vec<usize>, Vec<f32>, f32)> {
231 if n_patches == 0 || code_dim == 0 {
232 return Err(SslError::EmptyInput);
233 }
234 let expected = n_patches * code_dim;
235 if embeddings.len() != expected {
236 return Err(SslError::DimensionMismatch {
237 expected,
238 got: embeddings.len(),
239 });
240 }
241 if codebook.n_codes == 0 {
242 return Err(SslError::EmptyInput);
243 }
244
245 let k = codebook.n_codes;
246 let c = code_dim;
247 let beta = codebook.commitment_weight;
248
249 let mut indices = Vec::with_capacity(n_patches);
250 let mut quantized_z = Vec::with_capacity(n_patches * c);
251 let mut vq_loss_acc = 0.0_f64;
252
253 for i in 0..n_patches {
254 let z = &embeddings[i * c..(i + 1) * c];
255
256 let mut best_k = 0usize;
258 let mut best_dist = f64::MAX;
259
260 for ki in 0..k {
261 let e_k = &codebook.embeddings[ki * c..(ki + 1) * c];
262 let dist: f64 = z
263 .iter()
264 .zip(e_k.iter())
265 .map(|(&zi, &eki)| {
266 let d = (zi - eki) as f64;
267 d * d
268 })
269 .sum();
270 if dist < best_dist {
271 best_dist = dist;
272 best_k = ki;
273 }
274 }
275
276 indices.push(best_k);
277
278 let e_star = &codebook.embeddings[best_k * c..(best_k + 1) * c];
280 quantized_z.extend_from_slice(e_star);
281
282 vq_loss_acc += best_dist * (1.0 + beta as f64);
286 }
287
288 let vq_loss = (vq_loss_acc / n_patches as f64) as f32;
289 Ok((indices, quantized_z, vq_loss))
290}
291
292pub fn vq_update_codebook(
309 codebook: &mut VqCodebook,
310 embeddings: &[f32],
311 indices: &[usize],
312 n_patches: usize,
313) -> SslResult<()> {
314 if n_patches == 0 {
315 return Err(SslError::EmptyInput);
316 }
317 let c = codebook.code_dim;
318 let k = codebook.n_codes;
319 let expected_emb = n_patches * c;
320 if embeddings.len() != expected_emb {
321 return Err(SslError::DimensionMismatch {
322 expected: expected_emb,
323 got: embeddings.len(),
324 });
325 }
326 if indices.len() != n_patches {
327 return Err(SslError::DimensionMismatch {
328 expected: n_patches,
329 got: indices.len(),
330 });
331 }
332 for &idx in indices {
334 if idx >= k {
335 return Err(SslError::InvalidParameter {
336 name: "index".into(),
337 reason: format!("codebook index {idx} out of range [0, {k})"),
338 });
339 }
340 }
341
342 let m = codebook.ema_momentum;
343 let one_minus_m = 1.0 - m;
344
345 let mut batch_counts = vec![0.0_f32; k];
347 let mut batch_sums = vec![0.0_f32; k * c];
348
349 for (i, &ki) in indices.iter().enumerate() {
350 batch_counts[ki] += 1.0;
351 let z = &embeddings[i * c..(i + 1) * c];
352 let sum_slice = &mut batch_sums[ki * c..(ki + 1) * c];
353 for (s, &zi) in sum_slice.iter_mut().zip(z.iter()) {
354 *s += zi;
355 }
356 }
357
358 for ki in 0..k {
360 codebook.ema_counts[ki] = m * codebook.ema_counts[ki] + one_minus_m * batch_counts[ki];
361 let count = codebook.ema_counts[ki].max(1e-6); let sum_slice = &mut codebook.ema_sum[ki * c..(ki + 1) * c];
363 let batch_sum_slice = &batch_sums[ki * c..(ki + 1) * c];
364 for (s, &bs) in sum_slice.iter_mut().zip(batch_sum_slice.iter()) {
365 *s = m * (*s) + one_minus_m * bs;
366 }
367 let inv_count = 1.0 / count;
369 let emb_slice = &mut codebook.embeddings[ki * c..(ki + 1) * c];
370 let ema_sum_slice = &codebook.ema_sum[ki * c..(ki + 1) * c];
371 for (e, &es) in emb_slice.iter_mut().zip(ema_sum_slice.iter()) {
372 *e = es * inv_count;
373 }
374 }
375
376 Ok(())
377}
378
/// Loss value and codebook diagnostics produced by `beit_loss`.
#[derive(Debug, Clone)]
pub struct BeitResult {
    /// Mean negative log-probability of the target token over masked patches;
    /// `0.0` when no patch is masked.
    pub beit_loss: f32,
    /// Quantisation loss component; `beit_loss` always sets this to `0.0`.
    pub vq_loss: f32,
    /// `beit_loss + vq_loss`.
    pub total_loss: f32,
    /// Number of `true` entries in the mask.
    pub n_masked: usize,
    /// Fraction of the `n_codes` codes that occur at least once among the
    /// target token indices (masked or not).
    pub codebook_usage: f32,
    /// `exp(entropy)` of the empirical token distribution over all patches,
    /// clamped to `[1, n_codes]`.
    pub perplexity: f32,
}
397
398pub fn beit_loss(
417 student_logits: &[f32],
418 token_indices: &[usize],
419 mask: &[bool],
420 n_patches: usize,
421 n_codes: usize,
422 config: &BeitConfig,
423) -> SslResult<BeitResult> {
424 if n_codes == 0 {
425 return Err(SslError::InvalidParameter {
426 name: "n_codes".into(),
427 reason: "must be > 0".into(),
428 });
429 }
430 if n_patches == 0 {
431 return Err(SslError::EmptyInput);
432 }
433 if !(config.temperature.is_finite() && config.temperature > 0.0) {
434 return Err(SslError::InvalidTemperature {
435 temp: config.temperature,
436 });
437 }
438
439 let expected_logits = n_patches * n_codes;
440 if student_logits.len() != expected_logits {
441 return Err(SslError::DimensionMismatch {
442 expected: expected_logits,
443 got: student_logits.len(),
444 });
445 }
446 if token_indices.len() != n_patches {
447 return Err(SslError::DimensionMismatch {
448 expected: n_patches,
449 got: token_indices.len(),
450 });
451 }
452 if mask.len() != n_patches {
453 return Err(SslError::DimensionMismatch {
454 expected: n_patches,
455 got: mask.len(),
456 });
457 }
458
459 for &qi in token_indices {
461 if qi >= n_codes {
462 return Err(SslError::InvalidParameter {
463 name: "token_index".into(),
464 reason: format!("token index {qi} out of range [0, {n_codes})"),
465 });
466 }
467 }
468
469 let tau = config.temperature;
470 let n_masked = mask.iter().filter(|&&m| m).count();
471
472 let mut beit_loss_acc = 0.0_f64;
474
475 let mut code_freq = vec![0.0_f64; n_codes];
477
478 for i in 0..n_patches {
479 let qi = token_indices[i];
480 let logits = &student_logits[i * n_codes..(i + 1) * n_codes];
481
482 code_freq[qi] += 1.0;
485
486 if !mask[i] {
487 continue; }
489
490 let mut max_v = f32::NEG_INFINITY;
492 for &lv in logits {
493 let scaled = lv / tau;
494 if scaled > max_v {
495 max_v = scaled;
496 }
497 }
498 let mut sum_exp = 0.0_f64;
499 let mut exp_qi = 0.0_f64;
500 for (k, &lv) in logits.iter().enumerate() {
501 let e = ((lv / tau - max_v) as f64).exp();
502 sum_exp += e;
503 if k == qi {
504 exp_qi = e;
505 }
506 }
507 let log_prob = (exp_qi / sum_exp.max(1e-30)).max(1e-30_f64).ln();
508 beit_loss_acc -= log_prob;
509 }
510
511 let beit_loss_val = if n_masked == 0 {
512 0.0_f32
513 } else {
514 (beit_loss_acc / n_masked as f64) as f32
515 };
516
517 let total_assignments = n_patches as f64;
519 let n_used = code_freq.iter().filter(|&&f| f > 0.0).count();
520 let codebook_usage = n_used as f32 / n_codes as f32;
521
522 let mut entropy = 0.0_f64;
524 for &freq in &code_freq {
525 if freq > 0.0 {
526 let p = freq / total_assignments;
527 entropy -= p * p.ln();
528 }
529 }
530 let perplexity = entropy.exp().clamp(1.0, n_codes as f64) as f32;
531
532 let vq_loss_val = 0.0_f32;
539 let total_loss = beit_loss_val + vq_loss_val;
540
541 Ok(BeitResult {
542 beit_loss: beit_loss_val,
543 vq_loss: vq_loss_val,
544 total_loss,
545 n_masked,
546 codebook_usage,
547 perplexity,
548 })
549}
550
551pub fn beit_block_mask(
580 n_patches: usize,
581 patch_grid_h: usize,
582 patch_grid_w: usize,
583 mask_ratio: f32,
584 rng: &mut LcgRng,
585) -> SslResult<Vec<bool>> {
586 if patch_grid_h == 0 || patch_grid_w == 0 {
587 return Err(SslError::EmptyInput);
588 }
589 if !(mask_ratio.is_finite() && (0.0..1.0).contains(&mask_ratio)) {
590 return Err(SslError::InvalidMaskRatio { ratio: mask_ratio });
591 }
592 let grid_total = patch_grid_h * patch_grid_w;
593 if n_patches != grid_total {
594 return Err(SslError::InvalidParameter {
595 name: "n_patches".into(),
596 reason: format!(
597 "n_patches ({n_patches}) must equal patch_grid_h * patch_grid_w ({grid_total})"
598 ),
599 });
600 }
601
602 let target_masked = (n_patches as f32 * mask_ratio).floor() as usize;
603 let mut mask = vec![false; n_patches];
604 let mut n_masked = 0usize;
605
606 if target_masked == 0 {
607 return Ok(mask);
608 }
609
610 const ASPECT_RATIOS: [f32; 7] = [0.3, 0.5, 0.75, 1.0, 1.33, 2.0, 3.0];
612
613 let min_area = (n_patches as f32 * 0.05).ceil() as usize;
615 let min_area = min_area.max(1);
616 let max_area = (n_patches as f32 * 0.30).ceil() as usize;
617 let max_area = max_area.max(min_area);
618
619 let max_iters = (target_masked * 16 + 1).max(200);
621 let mut iters = 0usize;
622
623 while n_masked < target_masked && iters < max_iters {
624 iters += 1;
625
626 let area_range = max_area - min_area + 1;
628 let area = min_area + rng.next_usize(area_range);
629
630 let ratio_idx = rng.next_usize(ASPECT_RATIOS.len());
632 let ar = ASPECT_RATIOS[ratio_idx];
633
634 let bh_f = (area as f32 / ar).sqrt();
636 let bw_f = (area as f32 * ar).sqrt();
637 let bh = (bh_f.round() as usize).clamp(1, patch_grid_h);
638 let bw = (bw_f.round() as usize).clamp(1, patch_grid_w);
639
640 let r0 = if patch_grid_h > bh {
642 rng.next_usize(patch_grid_h - bh + 1)
643 } else {
644 0
645 };
646 let c0 = if patch_grid_w > bw {
647 rng.next_usize(patch_grid_w - bw + 1)
648 } else {
649 0
650 };
651
652 for r in r0..r0 + bh {
654 for c in c0..c0 + bw {
655 let idx = r * patch_grid_w + c;
656 if !mask[idx] {
657 mask[idx] = true;
658 n_masked += 1;
659 if n_masked >= target_masked {
661 break;
662 }
663 }
664 }
665 if n_masked >= target_masked {
666 break;
667 }
668 }
669 }
670
671 Ok(mask)
672}
673
#[cfg(test)]
mod tests {
    use super::*;

    // ── vq_codebook_init ────────────────────────────────────────────────

    /// All buffers of a fresh codebook have the sizes implied by (K, C).
    #[test]
    fn vq_codebook_init_correct_shape() {
        let mut rng = LcgRng::new(1);
        let cb = vq_codebook_init(64, 32, &mut rng).expect("vq_codebook_init should succeed");
        assert_eq!(cb.embeddings.len(), 64 * 32);
        assert_eq!(cb.n_codes, 64);
        assert_eq!(cb.code_dim, 32);
        assert_eq!(cb.ema_counts.len(), 64);
        assert_eq!(cb.ema_sum.len(), 64 * 32);
    }

    /// Initial embeddings and running sums contain no NaN or infinity.
    #[test]
    fn vq_codebook_init_entries_finite() {
        let mut rng = LcgRng::new(2);
        let cb = vq_codebook_init(16, 8, &mut rng).expect("vq_codebook_init should succeed");
        assert!(cb.embeddings.iter().all(|v| v.is_finite()));
        assert!(cb.ema_sum.iter().all(|v| v.is_finite()));
    }

    /// A codebook with zero codes is rejected.
    #[test]
    fn vq_codebook_init_rejects_zero_codes() {
        let mut rng = LcgRng::new(3);
        assert!(vq_codebook_init(0, 32, &mut rng).is_err());
    }

    /// A codebook with zero-length codes is rejected.
    #[test]
    fn vq_codebook_init_rejects_zero_dim() {
        let mut rng = LcgRng::new(4);
        assert!(vq_codebook_init(16, 0, &mut rng).is_err());
    }

    // ── vq_encode ───────────────────────────────────────────────────────

    /// Every returned assignment is a valid code index.
    #[test]
    fn vq_encode_indices_in_range() {
        let mut rng = LcgRng::new(5);
        let k = 32;
        let c = 8;
        let cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
        let n = 20;
        let mut emb = vec![0.0_f32; n * c];
        rng.fill_normal(&mut emb);
        let (indices, _, _) = vq_encode(&cb, &emb, n, c).expect("vq_encode should succeed");
        assert_eq!(indices.len(), n);
        for &idx in &indices {
            assert!(idx < k, "index {idx} out of range");
        }
    }

    /// The loss is built from squared distances and therefore non-negative.
    #[test]
    fn vq_encode_vq_loss_non_negative() {
        let mut rng = LcgRng::new(6);
        let k = 16;
        let c = 4;
        let cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
        let n = 10;
        let mut emb = vec![0.0_f32; n * c];
        rng.fill_normal(&mut emb);
        let (_, _, vq_loss) = vq_encode(&cb, &emb, n, c).expect("vq_encode should succeed");
        assert!(vq_loss >= 0.0, "vq_loss = {vq_loss} should be >= 0");
    }

    /// The quantized buffer and the index list have the expected lengths.
    #[test]
    fn vq_encode_quantized_shape() {
        let mut rng = LcgRng::new(7);
        let k = 8;
        let c = 6;
        let cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
        let n = 5;
        let mut emb = vec![0.0_f32; n * c];
        rng.fill_normal(&mut emb);
        let (indices, quantized, _) = vq_encode(&cb, &emb, n, c).expect("vq_encode should succeed");
        assert_eq!(quantized.len(), n * c);
        assert_eq!(indices.len(), n);
    }

    /// When the input coincides with a code, that exact match must be selected.
    #[test]
    fn vq_encode_exact_match_selected() {
        let mut rng = LcgRng::new(8);
        let k = 8;
        let c = 4;
        let mut cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
        // Force code 3 to be the all-zero vector.
        for v in &mut cb.embeddings[3 * c..4 * c] {
            *v = 0.0;
        }
        // Query with the all-zero vector: code 3 is at distance exactly 0.
        let emb = vec![0.0_f32; c];
        let (indices, quantized, vq_loss) =
            vq_encode(&cb, &emb, 1, c).expect("vq_encode should succeed");
        assert!(indices[0] < k);
        // The winner must itself be a zero-distance code. Checking the vector
        // rather than `indices[0] == 3` stays correct even if an earlier code
        // happened to be all-zero as well.
        assert_eq!(quantized.len(), c);
        assert!(
            quantized.iter().all(|&q| q == 0.0),
            "expected the all-zero code to be selected, got {quantized:?}"
        );
        assert!(
            vq_loss.abs() < 1e-12,
            "exact match should give zero loss, got {vq_loss}"
        );
    }

    // ── vq_update_codebook ──────────────────────────────────────────────

    /// After an update where every patch equals [1,1,1] and is assigned to
    /// code 0, each component of code 0 must be closer to 1 than before.
    #[test]
    fn vq_update_codebook_ema_moves_toward_assigned() {
        let mut rng = LcgRng::new(9);
        let k = 4;
        let c = 3;
        let mut cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
        // Low momentum makes the movement clearly visible after one step.
        cb.ema_momentum = 0.5;

        let orig_code0: Vec<f32> = cb.embeddings[0..c].to_vec();

        let n = 5;
        let emb = vec![1.0_f32; n * c];
        let indices = vec![0usize; n];
        vq_update_codebook(&mut cb, &emb, &indices, n).expect("vq_update_codebook should succeed");

        let updated_code0: Vec<f32> = cb.embeddings[0..c].to_vec();
        for (orig, updated) in orig_code0.iter().zip(updated_code0.iter()) {
            let dist_before = (orig - 1.0).abs();
            let dist_after = (updated - 1.0).abs();
            assert!(
                dist_after < dist_before || dist_before < 1e-6,
                "EMA update did not move code 0 toward [1,1,1]: orig={orig} updated={updated}"
            );
        }
    }

    // ── beit_loss ───────────────────────────────────────────────────────

    /// Random logits give a finite total loss and a non-negative CE term.
    #[test]
    fn beit_loss_finite_and_non_negative() {
        let mut rng = LcgRng::new(10);
        let n = 16;
        let k = 8;
        let cfg = BeitConfig {
            n_codes: k,
            code_dim: 4,
            ..BeitConfig::default()
        };
        let mut logits = vec![0.0_f32; n * k];
        rng.fill_normal(&mut logits);
        let indices: Vec<usize> = (0..n).map(|i| i % k).collect();
        let mask: Vec<bool> = (0..n).map(|i| i % 2 == 0).collect();
        let result =
            beit_loss(&logits, &indices, &mask, n, k, &cfg).expect("beit_loss should succeed");
        assert!(result.total_loss.is_finite(), "total_loss should be finite");
        assert!(result.beit_loss >= 0.0, "beit_loss should be >= 0");
    }

    /// `n_masked` equals the number of `true` entries in the mask.
    #[test]
    fn beit_loss_n_masked_matches_mask() {
        let n = 20;
        let k = 4;
        let cfg = BeitConfig {
            n_codes: k,
            ..BeitConfig::default()
        };
        let logits = vec![1.0_f32; n * k];
        let indices = vec![0usize; n];
        // Exactly the first 7 patches are masked.
        let mask: Vec<bool> = (0..n).map(|i| i < 7).collect();
        let result =
            beit_loss(&logits, &indices, &mask, n, k, &cfg).expect("beit_loss should succeed");
        assert_eq!(result.n_masked, 7);
    }

    /// With nothing masked the loss is exactly zero.
    #[test]
    fn beit_loss_all_unmasked_returns_zero() {
        let n = 8;
        let k = 4;
        let cfg = BeitConfig {
            n_codes: k,
            ..BeitConfig::default()
        };
        let logits = vec![0.5_f32; n * k];
        let indices = vec![0usize; n];
        let mask = vec![false; n];
        let result =
            beit_loss(&logits, &indices, &mask, n, k, &cfg).expect("beit_loss should succeed");
        assert_eq!(result.n_masked, 0);
        assert!(
            result.beit_loss.abs() < 1e-7,
            "expected 0 loss, got {}",
            result.beit_loss
        );
    }

    /// Codebook usage is a fraction in [0, 1].
    #[test]
    fn beit_loss_codebook_usage_in_range() {
        let mut rng = LcgRng::new(11);
        let n = 12;
        let k = 16;
        let cfg = BeitConfig {
            n_codes: k,
            ..BeitConfig::default()
        };
        let mut logits = vec![0.0_f32; n * k];
        rng.fill_normal(&mut logits);
        let indices: Vec<usize> = (0..n).map(|_| rng.next_usize(k)).collect();
        let mask = vec![true; n];
        let result =
            beit_loss(&logits, &indices, &mask, n, k, &cfg).expect("beit_loss should succeed");
        assert!(
            (0.0..=1.0).contains(&result.codebook_usage),
            "codebook_usage = {}",
            result.codebook_usage
        );
    }

    /// Perplexity lies in [1, K].
    #[test]
    fn beit_loss_perplexity_in_range() {
        let mut rng = LcgRng::new(12);
        let n = 32;
        let k = 16;
        let cfg = BeitConfig {
            n_codes: k,
            ..BeitConfig::default()
        };
        let mut logits = vec![0.0_f32; n * k];
        rng.fill_normal(&mut logits);
        // Cycle through all codes so the token distribution is uniform.
        let indices: Vec<usize> = (0..n).map(|i| i % k).collect();
        let mask = vec![true; n];
        let result =
            beit_loss(&logits, &indices, &mask, n, k, &cfg).expect("beit_loss should succeed");
        assert!(
            result.perplexity >= 1.0 && result.perplexity <= k as f32 + 1e-4,
            "perplexity = {} out of [1, {}]",
            result.perplexity,
            k
        );
    }

    /// `n_codes == 0` is rejected before any other validation.
    #[test]
    fn beit_loss_rejects_zero_n_codes() {
        let logits = vec![1.0_f32; 4];
        let indices = vec![0usize; 4];
        let mask = vec![true; 4];
        let cfg = BeitConfig::default();
        assert!(beit_loss(&logits, &indices, &mask, 4, 0, &cfg).is_err());
    }

    // ── beit_block_mask ─────────────────────────────────────────────────

    /// The mask has one entry per patch.
    #[test]
    fn beit_block_mask_correct_length() {
        let mut rng = LcgRng::new(13);
        let h = 14;
        let w = 14;
        let n = h * w;
        let mask = beit_block_mask(n, h, w, 0.4, &mut rng).expect("beit_block_mask should succeed");
        assert_eq!(mask.len(), n);
    }

    /// A ratio of zero masks nothing.
    #[test]
    fn beit_block_mask_zero_ratio_all_unmasked() {
        let mut rng = LcgRng::new(14);
        let h = 8;
        let w = 8;
        let n = h * w;
        let mask = beit_block_mask(n, h, w, 0.0, &mut rng).expect("beit_block_mask should succeed");
        assert!(mask.iter().all(|&v| !v));
    }

    /// Ratios above 1, below 0, or NaN are rejected.
    #[test]
    fn beit_block_mask_rejects_invalid_ratio() {
        let mut rng = LcgRng::new(15);
        assert!(beit_block_mask(16, 4, 4, 1.1, &mut rng).is_err());
        assert!(beit_block_mask(16, 4, 4, -0.1, &mut rng).is_err());
        assert!(beit_block_mask(16, 4, 4, f32::NAN, &mut rng).is_err());
    }

    /// The number of masked patches never exceeds the target and reaches at
    /// least half of it.
    #[test]
    fn beit_block_mask_approx_ratio() {
        let mut rng = LcgRng::new(16);
        let h = 14;
        let w = 14;
        let n = h * w;
        let ratio = 0.4_f32;
        let mask =
            beit_block_mask(n, h, w, ratio, &mut rng).expect("beit_block_mask should succeed");
        let n_masked = mask.iter().filter(|&&v| v).count();
        let target = (n as f32 * ratio).floor() as usize;
        // Filling stops as soon as the target is hit, so overshoot is impossible.
        assert!(
            n_masked <= target,
            "n_masked ({n_masked}) > target ({target}): block stopped early but should not over-shoot"
        );
        // The attempt budget may run out, so only a loose lower bound is checked.
        assert!(
            n_masked >= target / 2,
            "too few patches masked: {n_masked} vs target {target}"
        );
    }

    /// Larger batch: shapes, loss and every assignment are valid.
    #[test]
    fn vq_encode_batch_all_valid_assignments() {
        let mut rng = LcgRng::new(17);
        let k = 32;
        let c = 16;
        let cb = vq_codebook_init(k, c, &mut rng).expect("vq_codebook_init should succeed");
        let n = 50;
        let mut emb = vec![0.0_f32; n * c];
        rng.fill_normal(&mut emb);
        let (indices, quantized, vq_loss) =
            vq_encode(&cb, &emb, n, c).expect("vq_encode should succeed");
        assert_eq!(indices.len(), n);
        assert_eq!(quantized.len(), n * c);
        assert!(vq_loss.is_finite() && vq_loss >= 0.0);
        for &idx in &indices {
            assert!(idx < k, "assignment {idx} out of [0, {k})");
        }
    }
}