1use crate::error::{SslError, SslResult};
27
/// Hyper-parameters for [`dense_cl_loss`].
#[derive(Debug, Clone, PartialEq)]
pub struct DenseCLConfig {
    /// InfoNCE temperature shared by the global and dense terms; must be
    /// finite and strictly positive.
    pub temperature: f32,
    /// Weight of the dense term, in `[0, 1]`; the global term is weighted by
    /// `1 - lambda_dense`.
    pub lambda_dense: f32,
    /// NOTE(review): not read by any function in this module; the dense
    /// negatives are currently taken from the query feature map itself
    /// (see `dense_cl_loss`). Kept for API compatibility.
    pub n_negatives_per_pos: usize,
    /// Number of most-similar key positions that are summed (then
    /// re-normalised) to form each query position's positive key; `0` is
    /// treated as `1`.
    pub correspondence_topk: usize,
    /// Norm threshold: vectors whose L2 norm is not above this value are left
    /// unnormalised instead of being divided by a near-zero number.
    pub eps: f32,
}

impl Default for DenseCLConfig {
    /// `temperature = 0.2`, `lambda_dense = 0.5`, `n_negatives_per_pos = 256`,
    /// `correspondence_topk = 1`, `eps = 1e-8`.
    fn default() -> Self {
        Self {
            temperature: 0.2,
            lambda_dense: 0.5,
            n_negatives_per_pos: 256,
            correspondence_topk: 1,
            eps: 1e-8,
        }
    }
}
57
/// Hyper-parameters for [`pixpro_loss`].
#[derive(Debug, Clone, PartialEq)]
pub struct PixProConfig {
    /// Softmax temperature used when propagating key features between
    /// positions; must be finite and strictly positive.
    pub temperature: f32,
    /// Number of propagation passes applied to the key map; `0` is treated
    /// as `1`.
    pub propagation_iters: usize,
    /// Small positive threshold guarding divisions by (near-)zero norms and
    /// softmax normalisers.
    pub eps: f32,
}

impl Default for PixProConfig {
    /// `temperature = 0.2`, `propagation_iters = 1`, `eps = 1e-8`.
    fn default() -> Self {
        Self {
            temperature: 0.2,
            propagation_iters: 1,
            eps: 1e-8,
        }
    }
}
78
/// Loss breakdown returned by [`dense_cl_loss`].
#[derive(Debug, Clone, PartialEq)]
pub struct DenseCLResult {
    /// `(1 - lambda_dense) * global_loss + lambda_dense * dense_loss`.
    pub total_loss: f32,
    /// Global (image-level) InfoNCE against the negative queue; `0.0` when the
    /// term is disabled (`lambda_dense == 1`) or the queue is empty.
    pub global_loss: f32,
    /// Mean per-position dense InfoNCE; `0.0` when `lambda_dense == 0`.
    pub dense_loss: f32,
    /// Mean similarity between each normalised query position and its
    /// best-matching normalised key position.
    pub mean_correspondence_sim: f32,
    /// Number of spatial positions in the dense feature maps.
    pub n_positions: usize,
}
95
/// Scales each `d`-wide row of `data` to unit L2 norm, in place.
///
/// `data` is expected to hold `n` rows of `d` values (checked in debug
/// builds). Rows whose norm is not greater than `eps` — all-zero rows, or rows
/// whose norm is NaN — are left untouched rather than divided by a tiny value.
///
/// A zero row width is a no-op: there is nothing to normalise, and
/// `chunks_mut(0)` would otherwise panic.
#[inline]
fn l2_normalise_rows_inplace(data: &mut [f32], n: usize, d: usize, eps: f32) {
    debug_assert_eq!(
        data.len(),
        n * d,
        "l2_normalise_rows_inplace: buffer length does not match n * d"
    );
    if d == 0 {
        return;
    }
    for row in data.chunks_mut(d) {
        let norm = row.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm > eps {
            let inv = 1.0 / norm;
            row.iter_mut().for_each(|v| *v *= inv);
        }
    }
}
113
114#[inline]
116fn l2_normalise_clone(src: &[f32], n: usize, d: usize, eps: f32) -> Vec<f32> {
117 let mut out = src.to_vec();
118 l2_normalise_rows_inplace(&mut out, n, d, eps);
119 out
120}
121
/// Inner product of two vectors.
///
/// If the slices differ in length, the tail of the longer one is ignored.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(&x, &y)| x * y).sum::<f32>()
}
127
/// Numerically stable `ln(sum_i exp(vals[i]))`, accumulated in `f64`.
///
/// Special values:
/// - empty input returns `-inf` (the log of an empty sum);
/// - any NaN in the input returns NaN;
/// - if the largest value is `+inf` the result is `+inf`, and if every value
///   is `-inf` the result is `-inf`. Without these checks the max-shift below
///   would evaluate `inf - inf` and turn a well-defined answer into NaN.
#[inline]
fn log_sum_exp(vals: &[f32]) -> f64 {
    if vals.is_empty() {
        return f64::NEG_INFINITY;
    }
    // `f32::max` ignores NaN, so detect it explicitly to keep it propagating.
    if vals.iter().any(|v| v.is_nan()) {
        return f64::NAN;
    }
    let max_v = vals.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if max_v.is_infinite() {
        return f64::from(max_v);
    }
    // Shift by the maximum so the largest exponent is exp(0) = 1.
    let sum: f64 = vals.iter().map(|&v| f64::from(v - max_v).exp()).sum();
    f64::from(max_v) + sum.ln()
}
138
139#[inline]
141fn check_temperature(t: f32) -> SslResult<()> {
142 if !(t.is_finite() && t > 0.0) {
143 return Err(SslError::InvalidTemperature { temp: t });
144 }
145 Ok(())
146}
147
148#[inline]
150fn check_spatial_dense(spatial_size: usize, dense_dim: usize) -> SslResult<()> {
151 if dense_dim == 0 {
152 return Err(SslError::InvalidFeatureDim);
153 }
154 if spatial_size == 0 {
155 return Err(SslError::EmptyInput);
156 }
157 Ok(())
158}
159
160pub fn dense_correspondence(
174 query_dense: &[f32],
175 key_dense: &[f32],
176 spatial_size: usize,
177 dense_dim: usize,
178) -> Vec<usize> {
179 let mut corr = Vec::with_capacity(spatial_size);
182 for i in 0..spatial_size {
183 let q_row = &query_dense[i * dense_dim..(i + 1) * dense_dim];
184 let mut best_j = 0usize;
185 let mut best_s = f32::NEG_INFINITY;
186 for j in 0..spatial_size {
187 let k_row = &key_dense[j * dense_dim..(j + 1) * dense_dim];
188 let s = dot(q_row, k_row);
189 if s > best_s {
190 best_s = s;
191 best_j = j;
192 }
193 }
194 corr.push(best_j);
195 }
196 corr
197}
198
199fn dense_correspondence_topk(
206 query_dense_norm: &[f32],
207 key_dense_norm: &[f32],
208 spatial_size: usize,
209 dense_dim: usize,
210 topk: usize,
211 eps: f32,
212) -> Vec<f32> {
213 let k = topk.max(1);
214 let mut pos_keys = vec![0.0_f32; spatial_size * dense_dim];
215
216 let mut sims: Vec<(f32, usize)> = Vec::with_capacity(spatial_size);
218
219 for i in 0..spatial_size {
220 let q_row = &query_dense_norm[i * dense_dim..(i + 1) * dense_dim];
221 sims.clear();
222 for j in 0..spatial_size {
223 let k_row = &key_dense_norm[j * dense_dim..(j + 1) * dense_dim];
224 sims.push((dot(q_row, k_row), j));
225 }
226 let take = k.min(spatial_size);
228 for t in 0..take {
229 let mut best_idx = t;
230 for u in (t + 1)..sims.len() {
231 if sims[u].0 > sims[best_idx].0 {
232 best_idx = u;
233 }
234 }
235 sims.swap(t, best_idx);
236 }
237 let out_row = &mut pos_keys[i * dense_dim..(i + 1) * dense_dim];
239 for &(_, kj) in sims.iter().take(take) {
240 let k_row = &key_dense_norm[kj * dense_dim..(kj + 1) * dense_dim];
241 for (o, &v) in out_row.iter_mut().zip(k_row.iter()) {
242 *o += v;
243 }
244 }
245 let norm: f32 = out_row.iter().map(|v| v * v).sum::<f32>().sqrt();
247 if norm > eps {
248 let inv = 1.0 / norm;
249 for v in out_row.iter_mut() {
250 *v *= inv;
251 }
252 }
253 }
254 pos_keys
255}
256
257pub fn dense_infonce(
272 query: &[f32],
273 pos_keys: &[f32],
274 all_query: &[f32],
275 spatial_size: usize,
276 batch_size: usize,
277 dense_dim: usize,
278 temperature: f32,
279) -> SslResult<f32> {
280 if spatial_size == 0 || dense_dim == 0 || batch_size == 0 {
281 return Err(SslError::EmptyInput);
282 }
283 check_temperature(temperature)?;
284
285 let hw_total = spatial_size * batch_size;
286 if all_query.len() != hw_total * dense_dim {
287 return Err(SslError::DimensionMismatch {
288 expected: hw_total * dense_dim,
289 got: all_query.len(),
290 });
291 }
292 if query.len() != spatial_size * dense_dim {
293 return Err(SslError::DimensionMismatch {
294 expected: spatial_size * dense_dim,
295 got: query.len(),
296 });
297 }
298 if pos_keys.len() != spatial_size * dense_dim {
299 return Err(SslError::DimensionMismatch {
300 expected: spatial_size * dense_dim,
301 got: pos_keys.len(),
302 });
303 }
304
305 let inv_t = 1.0_f32 / temperature;
306 let mut total_loss = 0.0_f64;
307
308 for i in 0..spatial_size {
314 let q_row = &query[i * dense_dim..(i + 1) * dense_dim];
315 let p_row = &pos_keys[i * dense_dim..(i + 1) * dense_dim];
316
317 let pos_logit = dot(q_row, p_row) * inv_t;
319
320 let mut neg_logits: Vec<f32> = Vec::with_capacity(hw_total);
322 for l in 0..hw_total {
323 let n_row = &all_query[l * dense_dim..(l + 1) * dense_dim];
324 neg_logits.push(dot(q_row, n_row) * inv_t);
325 }
326
327 let log_z_neg = log_sum_exp(&neg_logits);
329
330 let mut all_logits = neg_logits;
332 all_logits.push(pos_logit);
333 let log_z_all = log_sum_exp(&all_logits);
334
335 let _ = log_z_neg;
336 total_loss += log_z_all - (pos_logit as f64);
338 }
339
340 Ok((total_loss / spatial_size as f64) as f32)
341}
342
343fn global_infonce_single(
350 query_global: &[f32],
351 key_global: &[f32],
352 queue: &[f32],
353 global_dim: usize,
354 temperature: f32,
355 eps: f32,
356) -> f32 {
357 let inv_t = 1.0_f32 / temperature;
358
359 let q = l2_normalise_clone(query_global, 1, global_dim, eps);
361 let k = l2_normalise_clone(key_global, 1, global_dim, eps);
362
363 let pos_logit = dot(&q, &k) * inv_t;
364
365 if queue.is_empty() {
366 return 0.0;
368 }
369 let n_neg = queue.len() / global_dim;
370 let mut logits: Vec<f32> = Vec::with_capacity(n_neg + 1);
371 logits.push(pos_logit);
372 for kn in 0..n_neg {
373 let k_row = &queue[kn * global_dim..(kn + 1) * global_dim];
374 logits.push(dot(&q, k_row) * inv_t);
375 }
376 let log_z = log_sum_exp(&logits);
377 (log_z - pos_logit as f64) as f32
378}
379
380pub fn dense_cl_loss(
405 query_global: &[f32],
406 key_global: &[f32],
407 query_dense: &[f32],
408 key_dense: &[f32],
409 neg_queue: &[f32],
410 spatial_size: usize,
411 global_dim: usize,
412 dense_dim: usize,
413 config: &DenseCLConfig,
414) -> SslResult<DenseCLResult> {
415 check_temperature(config.temperature)?;
417 check_spatial_dense(spatial_size, dense_dim)?;
418
419 if global_dim == 0 {
420 return Err(SslError::InvalidFeatureDim);
421 }
422 if !(config.lambda_dense.is_finite()
423 && config.lambda_dense >= 0.0
424 && config.lambda_dense <= 1.0)
425 {
426 return Err(SslError::InvalidParameter {
427 name: "lambda_dense".to_string(),
428 reason: "must be in [0, 1]".to_string(),
429 });
430 }
431 if query_global.len() != global_dim {
432 return Err(SslError::DimensionMismatch {
433 expected: global_dim,
434 got: query_global.len(),
435 });
436 }
437 if key_global.len() != global_dim {
438 return Err(SslError::DimensionMismatch {
439 expected: global_dim,
440 got: key_global.len(),
441 });
442 }
443 if query_dense.len() != spatial_size * dense_dim {
444 return Err(SslError::DimensionMismatch {
445 expected: spatial_size * dense_dim,
446 got: query_dense.len(),
447 });
448 }
449 if key_dense.len() != spatial_size * dense_dim {
450 return Err(SslError::DimensionMismatch {
451 expected: spatial_size * dense_dim,
452 got: key_dense.len(),
453 });
454 }
455
456 let q_norm = l2_normalise_clone(query_dense, spatial_size, dense_dim, config.eps);
458 let k_norm = l2_normalise_clone(key_dense, spatial_size, dense_dim, config.eps);
459
460 let global_loss = if config.lambda_dense < 1.0 {
462 global_infonce_single(
463 query_global,
464 key_global,
465 neg_queue,
466 global_dim,
467 config.temperature,
468 config.eps,
469 )
470 } else {
471 0.0
472 };
473
474 let pos_keys = dense_correspondence_topk(
477 &q_norm,
478 &k_norm,
479 spatial_size,
480 dense_dim,
481 config.correspondence_topk,
482 config.eps,
483 );
484
485 let mut sum_sim = 0.0_f64;
487 let corr_map = dense_correspondence(&q_norm, &k_norm, spatial_size, dense_dim);
488 for i in 0..spatial_size {
489 let q_row = &q_norm[i * dense_dim..(i + 1) * dense_dim];
490 let j = corr_map[i];
491 let k_row = &k_norm[j * dense_dim..(j + 1) * dense_dim];
492 sum_sim += dot(q_row, k_row) as f64;
493 }
494 let mean_correspondence_sim = (sum_sim / spatial_size as f64) as f32;
495
496 let dense_loss = if config.lambda_dense > 0.0 {
498 dense_infonce(
502 &q_norm,
503 &pos_keys,
504 &q_norm,
505 spatial_size,
506 1, dense_dim,
508 config.temperature,
509 )?
510 } else {
511 0.0
512 };
513
514 let lambda = config.lambda_dense;
516 let total_loss = (1.0 - lambda) * global_loss + lambda * dense_loss;
517
518 Ok(DenseCLResult {
519 total_loss,
520 global_loss,
521 dense_loss,
522 mean_correspondence_sim,
523 n_positions: spatial_size,
524 })
525}
526
527pub fn pixpro_loss(
553 query_dense: &[f32],
554 key_dense: &[f32],
555 spatial_size: usize,
556 dense_dim: usize,
557 config: &PixProConfig,
558) -> SslResult<f32> {
559 check_temperature(config.temperature)?;
560 check_spatial_dense(spatial_size, dense_dim)?;
561
562 if query_dense.len() != spatial_size * dense_dim {
563 return Err(SslError::DimensionMismatch {
564 expected: spatial_size * dense_dim,
565 got: query_dense.len(),
566 });
567 }
568 if key_dense.len() != spatial_size * dense_dim {
569 return Err(SslError::DimensionMismatch {
570 expected: spatial_size * dense_dim,
571 got: key_dense.len(),
572 });
573 }
574
575 let q_norm = l2_normalise_clone(query_dense, spatial_size, dense_dim, config.eps);
576 let mut k_prop = l2_normalise_clone(key_dense, spatial_size, dense_dim, config.eps);
577
578 let iters = config.propagation_iters.max(1);
579 for _ in 0..iters {
580 k_prop = pixpro_propagate_once(
581 &k_prop,
582 spatial_size,
583 dense_dim,
584 config.temperature,
585 config.eps,
586 );
587 }
588
589 let mut total = 0.0_f64;
591 for i in 0..spatial_size {
592 let q_row = &q_norm[i * dense_dim..(i + 1) * dense_dim];
593 let k_row = &k_prop[i * dense_dim..(i + 1) * dense_dim];
594 let sim = dot(q_row, k_row) as f64;
595 total += 1.0 - sim;
596 }
597 let loss = (total / spatial_size as f64) as f32;
598
599 if !loss.is_finite() {
600 return Err(SslError::NanEncountered {
601 location: "pixpro_loss",
602 });
603 }
604
605 Ok(loss)
606}
607
608fn pixpro_propagate_once(
617 k: &[f32],
618 spatial_size: usize,
619 dense_dim: usize,
620 temperature: f32,
621 eps: f32,
622) -> Vec<f32> {
623 let inv_t = 1.0_f32 / temperature;
624 let mut out = vec![0.0_f32; spatial_size * dense_dim];
625
626 for i in 0..spatial_size {
627 let k_i = &k[i * dense_dim..(i + 1) * dense_dim];
628 let mut scores: Vec<f32> = (0..spatial_size)
630 .map(|j| {
631 let k_j = &k[j * dense_dim..(j + 1) * dense_dim];
632 dot(k_i, k_j) * inv_t
633 })
634 .collect();
635 let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
637 let mut sum_exp = 0.0_f32;
638 for s in scores.iter_mut() {
639 *s = (*s - max_s).exp();
640 sum_exp += *s;
641 }
642 if sum_exp > eps {
643 let inv_sum = 1.0 / sum_exp;
644 for s in scores.iter_mut() {
645 *s *= inv_sum;
646 }
647 }
648 let out_i = &mut out[i * dense_dim..(i + 1) * dense_dim];
650 for (j, &w) in scores.iter().enumerate() {
651 let k_j = &k[j * dense_dim..(j + 1) * dense_dim];
652 for (o, &kv) in out_i.iter_mut().zip(k_j.iter()) {
653 *o += w * kv;
654 }
655 }
656 }
657 l2_normalise_rows_inplace(&mut out, spatial_size, dense_dim, eps);
659 out
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// Norm threshold shared by every fixture (matches the configs' default `eps`).
    const EPS: f32 = 1e-8;

    /// Tiny deterministic 64-bit LCG, so the tests need no RNG crate.
    struct Lcg {
        state: u64,
    }

    impl Lcg {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        /// Next sample in `[0, 1)`.
        ///
        /// Uses the top 24 bits of the state: the quotient is then exactly
        /// representable in `f32` and never rounds up to `1.0`.
        fn next_f32(&mut self) -> f32 {
            self.state = self
                .state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.state >> 40) as f32 / (1u32 << 24) as f32
        }

        /// Fills `buf` with samples in `[-0.5, 0.5)`.
        fn fill(&mut self, buf: &mut [f32]) {
            for v in buf.iter_mut() {
                *v = self.next_f32() - 0.5;
            }
        }
    }

    /// `n` rows of `d` pseudo-random values, each row L2-normalised.
    fn rand_unit(n: usize, d: usize, seed: u64, eps: f32) -> Vec<f32> {
        let mut rng = Lcg::new(seed);
        let mut buf = vec![0.0_f32; n * d];
        rng.fill(&mut buf);
        l2_normalise_rows_inplace(&mut buf, n, d, eps);
        buf
    }

    /// Owned inputs for one `dense_cl_loss` call.
    struct DenseInputs {
        qg: Vec<f32>,
        kg: Vec<f32>,
        qd: Vec<f32>,
        kd: Vec<f32>,
        queue: Vec<f32>,
    }

    impl DenseInputs {
        fn run(
            &self,
            hw: usize,
            d: usize,
            c: usize,
            cfg: &DenseCLConfig,
        ) -> SslResult<DenseCLResult> {
            dense_cl_loss(&self.qg, &self.kg, &self.qd, &self.kd, &self.queue, hw, d, c, cfg)
        }
    }

    /// Random unit-norm inputs: global vectors of width `d`, dense maps of
    /// `hw` rows x `c` channels, and a queue of `n_queue` rows. The five
    /// buffers use consecutive seeds starting at `seed`.
    fn dense_inputs(hw: usize, d: usize, c: usize, n_queue: usize, seed: u64) -> DenseInputs {
        DenseInputs {
            qg: rand_unit(1, d, seed, EPS),
            kg: rand_unit(1, d, seed + 1, EPS),
            qd: rand_unit(hw, c, seed + 2, EPS),
            kd: rand_unit(hw, c, seed + 3, EPS),
            queue: rand_unit(n_queue, d, seed + 4, EPS),
        }
    }

    #[test]
    fn total_loss_finite_nonnegative() {
        let (hw, d, c) = (4, 8, 8);
        let cfg = DenseCLConfig::default();
        let res = dense_inputs(hw, d, c, 16, 1)
            .run(hw, d, c, &cfg)
            .expect("dense_cl_loss should succeed");
        assert!(res.total_loss.is_finite(), "total_loss not finite");
        assert!(res.total_loss >= 0.0, "total_loss negative: {}", res.total_loss);
    }

    #[test]
    fn lambda_zero_gives_global_only() {
        let (hw, d, c) = (4, 8, 8);
        let cfg = DenseCLConfig {
            lambda_dense: 0.0,
            ..Default::default()
        };
        let res = dense_inputs(hw, d, c, 8, 10)
            .run(hw, d, c, &cfg)
            .expect("dense_cl_loss should succeed");
        assert!(
            (res.total_loss - res.global_loss).abs() < 1e-5,
            "total={} global={}",
            res.total_loss,
            res.global_loss
        );
    }

    #[test]
    fn lambda_one_gives_dense_only() {
        let (hw, d, c) = (4, 8, 8);
        let cfg = DenseCLConfig {
            lambda_dense: 1.0,
            ..Default::default()
        };
        let res = dense_inputs(hw, d, c, 8, 20)
            .run(hw, d, c, &cfg)
            .expect("dense_cl_loss should succeed");
        assert!(
            (res.total_loss - res.dense_loss).abs() < 1e-5,
            "total={} dense={}",
            res.total_loss,
            res.dense_loss
        );
    }

    #[test]
    fn invalid_lambda_returns_error() {
        let (hw, d, c) = (4, 4, 4);
        let inputs = dense_inputs(hw, d, c, 4, 150);
        for bad in [-0.1_f32, 1.5, f32::NAN] {
            let cfg = DenseCLConfig {
                lambda_dense: bad,
                ..Default::default()
            };
            assert!(
                inputs.run(hw, d, c, &cfg).is_err(),
                "lambda_dense = {bad} was accepted"
            );
        }
    }

    #[test]
    fn correspondence_map_length_equals_spatial_size() {
        let (hw, c) = (9, 6);
        let qd = rand_unit(hw, c, 30, EPS);
        let kd = rand_unit(hw, c, 31, EPS);
        let corr = dense_correspondence(&qd, &kd, hw, c);
        assert_eq!(corr.len(), hw);
    }

    #[test]
    fn correspondence_indices_in_range() {
        let (hw, c) = (16, 8);
        let qd = rand_unit(hw, c, 40, EPS);
        let kd = rand_unit(hw, c, 41, EPS);
        let corr = dense_correspondence(&qd, &kd, hw, c);
        for &idx in &corr {
            assert!(idx < hw, "index {idx} out of [0, {hw})");
        }
    }

    #[test]
    fn topk_one_matches_argmax_correspondence() {
        let (hw, c) = (7, 5);
        let qd = rand_unit(hw, c, 160, EPS);
        let kd = rand_unit(hw, c, 161, EPS);
        let corr = dense_correspondence(&qd, &kd, hw, c);
        let pos = dense_correspondence_topk(&qd, &kd, hw, c, 1, EPS);
        for (i, &j) in corr.iter().enumerate() {
            let got = &pos[i * c..(i + 1) * c];
            let want = &kd[j * c..(j + 1) * c];
            for (g, w) in got.iter().zip(want) {
                assert!((g - w).abs() < 1e-6, "row {i}: {g} vs {w}");
            }
        }
    }

    #[test]
    fn mean_correspondence_sim_in_range() {
        let (hw, d, c) = (6, 4, 4);
        let cfg = DenseCLConfig::default();
        let res = dense_inputs(hw, d, c, 4, 50)
            .run(hw, d, c, &cfg)
            .expect("dense_cl_loss should succeed");
        assert!(
            res.mean_correspondence_sim >= -1.0 - 1e-5
                && res.mean_correspondence_sim <= 1.0 + 1e-5,
            "mean_corr_sim = {}",
            res.mean_correspondence_sim
        );
    }

    #[test]
    fn identical_query_key_max_correspondence() {
        let (hw, d, c) = (5, 4, 4);
        let cfg = DenseCLConfig {
            lambda_dense: 1.0,
            ..Default::default()
        };
        let qg = rand_unit(1, d, 60, EPS);
        let qd = rand_unit(hw, c, 62, EPS);
        // The key views are exact copies of the query views; the queue is empty.
        let inputs = DenseInputs {
            kg: qg.clone(),
            kd: qd.clone(),
            qg,
            qd,
            queue: Vec::new(),
        };
        let res = inputs
            .run(hw, d, c, &cfg)
            .expect("dense_cl_loss should succeed");
        assert!(
            res.mean_correspondence_sim > 0.99,
            "expected ~1.0, got {}",
            res.mean_correspondence_sim
        );
    }

    #[test]
    fn dense_infonce_finite_random() {
        let (hw, c, batch) = (8, 6, 2);
        let q = rand_unit(hw, c, 70, EPS);
        let pk = rand_unit(hw, c, 71, EPS);
        let all_q = rand_unit(hw * batch, c, 72, EPS);
        let loss = dense_infonce(&q, &pk, &all_q, hw, batch, c, 0.2)
            .expect("dense_infonce should succeed");
        assert!(loss.is_finite(), "loss = {loss}");
    }

    #[test]
    fn pixpro_loss_finite_and_bounded() {
        let (hw, c) = (6, 8);
        let cfg = PixProConfig::default();
        let qd = rand_unit(hw, c, 80, cfg.eps);
        let kd = rand_unit(hw, c, 81, cfg.eps);
        let loss = pixpro_loss(&qd, &kd, hw, c, &cfg).expect("pixpro_loss should succeed");
        assert!(loss.is_finite(), "loss not finite");
        assert!(loss >= 0.0, "loss = {loss} < 0");
        // 1 - cos(q, k) for unit vectors never exceeds 2.
        assert!(loss <= 2.0 + 1e-5, "loss = {loss} > 2");
    }

    #[test]
    fn invalid_temperature_returns_error() {
        let (hw, d, c) = (4, 4, 4);
        let cfg = DenseCLConfig {
            temperature: 0.0,
            ..Default::default()
        };
        let inputs = dense_inputs(hw, d, c, 4, 90);
        assert!(inputs.run(hw, d, c, &cfg).is_err());

        let px_cfg = PixProConfig {
            temperature: 0.0,
            ..Default::default()
        };
        assert!(pixpro_loss(&inputs.qd, &inputs.kd, hw, c, &px_cfg).is_err());
    }

    #[test]
    fn single_spatial_position_works() {
        let (hw, d, c) = (1, 8, 8);
        let cfg = DenseCLConfig::default();
        let inputs = dense_inputs(hw, d, c, 4, 100);
        let res = inputs
            .run(hw, d, c, &cfg)
            .expect("dense_cl_loss should succeed");
        assert!(res.total_loss.is_finite());
        assert_eq!(res.n_positions, 1);

        let px_cfg = PixProConfig::default();
        let pl = pixpro_loss(&inputs.qd, &inputs.kd, hw, c, &px_cfg)
            .expect("pixpro_loss should succeed");
        assert!(pl.is_finite());
    }

    #[test]
    fn larger_batch_size_more_negatives() {
        let (hw, c) = (4, 6);
        let q = rand_unit(hw, c, 110, EPS);
        let pk = rand_unit(hw, c, 111, EPS);

        // The same seed makes the large pool start with exactly the rows of
        // the small one, so its negatives are a strict superset.
        let (batch_small, batch_large) = (1usize, 4usize);
        let all_q_small = rand_unit(hw * batch_small, c, 112, EPS);
        let all_q_large = rand_unit(hw * batch_large, c, 112, EPS);
        assert_eq!(&all_q_large[..hw * c], &all_q_small[..]);

        let l_small = dense_infonce(&q, &pk, &all_q_small, hw, batch_small, c, 0.2)
            .expect("dense_infonce should succeed");
        let l_large = dense_infonce(&q, &pk, &all_q_large, hw, batch_large, c, 0.2)
            .expect("dense_infonce should succeed");

        assert!(l_small.is_finite() && l_small >= 0.0, "l_small = {l_small}");
        assert!(l_large.is_finite() && l_large >= 0.0, "l_large = {l_large}");
        // Extra negatives only add positive terms to the partition function.
        assert!(l_large >= l_small, "l_large = {l_large} < l_small = {l_small}");
    }

    #[test]
    fn linear_combination_matches_components() {
        let (hw, d, c) = (4, 8, 8);
        let cfg = DenseCLConfig {
            lambda_dense: 0.3,
            ..Default::default()
        };
        let res = dense_inputs(hw, d, c, 8, 120)
            .run(hw, d, c, &cfg)
            .expect("dense_cl_loss should succeed");

        let expected = 0.7 * res.global_loss + 0.3 * res.dense_loss;
        assert!(
            (res.total_loss - expected).abs() < 1e-5,
            "total={} expected={}",
            res.total_loss,
            expected
        );
    }

    #[test]
    fn pixpro_multi_iter_finite() {
        let (hw, c) = (8, 6);
        let cfg = PixProConfig {
            temperature: 0.1,
            propagation_iters: 3,
            eps: 1e-8,
        };
        let qd = rand_unit(hw, c, 130, cfg.eps);
        let kd = rand_unit(hw, c, 131, cfg.eps);
        let loss = pixpro_loss(&qd, &kd, hw, c, &cfg).expect("pixpro_loss should succeed");
        assert!(loss.is_finite());
        assert!((0.0..=2.0 + 1e-5).contains(&loss), "loss = {loss}");
    }

    #[test]
    fn dimension_mismatch_detected() {
        let (hw, d, c) = (4, 8, 8);
        let cfg = DenseCLConfig::default();
        let inputs = DenseInputs {
            qg: rand_unit(1, d, 140, EPS),
            kg: rand_unit(1, d, 141, EPS),
            // One spatial row short of `hw`.
            qd: rand_unit(hw - 1, c, 142, EPS),
            kd: rand_unit(hw, c, 143, EPS),
            queue: rand_unit(4, d, 144, EPS),
        };
        assert!(inputs.run(hw, d, c, &cfg).is_err());
    }

    #[test]
    fn l2_normalise_rows_units_nonzero_and_keeps_zero_rows() {
        let mut buf = vec![3.0_f32, 4.0, 0.0, 0.0, -2.0, 0.0];
        l2_normalise_rows_inplace(&mut buf, 3, 2, EPS);
        assert!((buf[0] - 0.6).abs() < 1e-6 && (buf[1] - 0.8).abs() < 1e-6);
        assert_eq!(buf[2], 0.0);
        assert_eq!(buf[3], 0.0);
        assert!((buf[4] + 1.0).abs() < 1e-6 && buf[5] == 0.0);
    }

    #[test]
    fn log_sum_exp_matches_naive() {
        let vals = [1.0_f32, 2.0, 3.0];
        let naive = vals
            .iter()
            .map(|&v| f64::from(v).exp())
            .sum::<f64>()
            .ln();
        assert!((log_sum_exp(&vals) - naive).abs() < 1e-12);
        assert_eq!(log_sum_exp(&[]), f64::NEG_INFINITY);
    }
}