Skip to main content

oxicuda_ssl/non_contrastive/
dense_cl.rs

1//! DenseCL / PixPro — pixel-level dense contrastive losses.
2//!
3//! Wang et al. 2021: "Dense Contrastive Learning for Self-Supervised Visual
4//! Pre-Training." Unlike global SSL methods (SimCLR/MoCo), DenseCL contrasts
5//! local feature-map regions, producing representations far more useful for
6//! dense prediction tasks (detection, segmentation).
7//!
8//! # Algorithm Overview
9//!
10//! **Global branch** (standard MoCo InfoNCE):
11//! ```text
12//! L_global = -log[ exp(q_g · k_g / τ) / (exp(q_g · k_g / τ) + Σ_n exp(q_g · n / τ)) ]
13//! ```
14//!
15//! **Dense branch** (DenseCL):
16//! 1. For each query position `i` in `[H*W]`, find best-matching key position:
17//!    `j*(i) = argmax_j cosine_sim(f_q[i], f_k[j])`
18//! 2. Dense InfoNCE using query positions as negatives:
//!    `L_dense = -1/(H*W) · Σ_i log[ exp(sim(q_i, k_{j*(i)})/τ) / (exp(sim(q_i, k_{j*(i)})/τ) + Σ_l exp(sim(q_i, n_l)/τ)) ]`
20//!
21//! **Combined**: `L = (1 - λ) · L_global + λ · L_dense`
22//!
//! **PixPro variant**: Instead of InfoNCE, applies similarity-weighted propagation to the key
//! features, then minimises the cosine distance between the query features and the propagated keys.
25
26use crate::error::{SslError, SslResult};
27
28// ─── Configuration ────────────────────────────────────────────────────────────
29
/// Configuration for the DenseCL combined loss ([`dense_cl_loss`]).
#[derive(Debug, Clone)]
pub struct DenseCLConfig {
    /// Temperature τ shared by the global and dense InfoNCE terms (default: 0.2).
    /// Must be finite and `> 0`.
    pub temperature: f32,
    /// Weight λ of the dense branch, in `[0, 1]` (default: 0.5). The combined loss is
    /// `(1 - λ) · L_global + λ · L_dense`; a branch whose weight is zero is skipped.
    pub lambda_dense: f32,
    /// Intended number of negative samples per query position (default: 256).
    ///
    /// NOTE: [`dense_cl_loss`] does not currently read this field — its dense branch
    /// contrasts each position against *every* query position of the sample.
    pub n_negatives_per_pos: usize,
    /// Number of top-scoring key positions averaged (then re-normalised) to form each
    /// positive key (default: 1 → plain argmax). `0` is treated as `1`.
    pub correspondence_topk: usize,
    /// Numerical epsilon for L2-normalisation (default: 1e-8): rows whose norm does
    /// not exceed it are left unnormalised.
    pub eps: f32,
}
45
46impl Default for DenseCLConfig {
47    fn default() -> Self {
48        Self {
49            temperature: 0.2,
50            lambda_dense: 0.5,
51            n_negatives_per_pos: 256,
52            correspondence_topk: 1,
53            eps: 1e-8,
54        }
55    }
56}
57
/// Configuration for the PixPro loss ([`pixpro_loss`]).
#[derive(Debug, Clone)]
pub struct PixProConfig {
    /// Temperature τ of the softmax used in the similarity-weighted propagation
    /// (default: 0.2). Must be finite and `> 0`.
    pub temperature: f32,
    /// Number of propagation iterations (default: 1). [`pixpro_loss`] treats `0`
    /// as `1`, so at least one propagation step always runs.
    pub propagation_iters: usize,
    /// Numerical epsilon for L2-normalisation (default: 1e-8): rows whose norm does
    /// not exceed it are left unnormalised.
    pub eps: f32,
}
68
69impl Default for PixProConfig {
70    fn default() -> Self {
71        Self {
72            temperature: 0.2,
73            propagation_iters: 1,
74            eps: 1e-8,
75        }
76    }
77}
78
79// ─── Output types ─────────────────────────────────────────────────────────────
80
/// Detailed output from [`dense_cl_loss`].
#[derive(Debug, Clone)]
pub struct DenseCLResult {
    /// Combined loss: `(1 - λ) · L_global + λ · L_dense`.
    pub total_loss: f32,
    /// Global InfoNCE component. Reported as `0.0` when `lambda_dense == 1` (the
    /// branch is skipped) or when the negative queue is empty.
    pub global_loss: f32,
    /// Dense InfoNCE component. Reported as `0.0` when `lambda_dense == 0` (the
    /// branch is skipped).
    pub dense_loss: f32,
    /// Diagnostic: mean cosine similarity between each (re-normalised) query
    /// position and its single best-matching key position — always argmax-based,
    /// regardless of `correspondence_topk`.
    pub mean_correspondence_sim: f32,
    /// Number of spatial positions (`H*W`) processed.
    pub n_positions: usize,
}
95
96// ─── Internal helpers ─────────────────────────────────────────────────────────
97
/// Scale each `d`-wide row of the row-major buffer `data` to unit L2 norm, in place.
///
/// A row whose norm does not exceed `eps` (an all-zero row, or one whose norm is
/// NaN) is left untouched, so no division by ~0 and no new NaN is introduced.
/// `n` is not consulted: rows are taken straight from `data` in chunks of `d`, so a
/// trailing partial chunk is normalised as a short row. Panics if `d == 0`.
#[inline]
fn l2_normalise_rows_inplace(data: &mut [f32], n: usize, d: usize, eps: f32) {
    // The row count is implied by `data.len() / d`.
    let _ = n;
    data.chunks_mut(d).for_each(|row| {
        let norm = row.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm > eps {
            let scale = 1.0 / norm;
            row.iter_mut().for_each(|v| *v *= scale);
        }
    });
}
113
114/// L2-normalise every row into a **new** allocation, leaving `src` untouched.
115#[inline]
116fn l2_normalise_clone(src: &[f32], n: usize, d: usize, eps: f32) -> Vec<f32> {
117    let mut out = src.to_vec();
118    l2_normalise_rows_inplace(&mut out, n, d, eps);
119    out
120}
121
/// Inner product `Σ a[i] · b[i]`.
///
/// Elements are paired with `zip`, so if the lengths differ the longer slice's
/// tail is ignored; callers are expected to pass equal-length rows.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(&x, &y)| x * y).sum::<f32>()
}
127
/// Numerically stable `ln Σ_i exp(vals[i])`, accumulated in `f64`.
///
/// The maximum is subtracted before exponentiating so large logits cannot
/// overflow. Non-finite maxima are resolved explicitly, because the shift would
/// otherwise evaluate `∞ − ∞ = NaN`:
/// - empty slice, or every entry `-∞` → `-∞` (log of a zero sum);
/// - some entry `+∞` → `+∞`;
/// - any `NaN` entry → `NaN`.
#[inline]
fn log_sum_exp(vals: &[f32]) -> f64 {
    if vals.is_empty() {
        return f64::NEG_INFINITY;
    }
    // `f32::max` ignores NaN operands, so this is the largest non-NaN entry
    // (or -∞ when there is none).
    let max_v = vals.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if max_v.is_infinite() {
        return if vals.iter().any(|v| v.is_nan()) {
            f64::NAN
        } else {
            f64::from(max_v)
        };
    }
    let sum: f64 = vals.iter().map(|&v| f64::from(v - max_v).exp()).sum();
    f64::from(max_v) + sum.ln()
}
138
139/// Validate `temperature > 0` and finite.
140#[inline]
141fn check_temperature(t: f32) -> SslResult<()> {
142    if !(t.is_finite() && t > 0.0) {
143        return Err(SslError::InvalidTemperature { temp: t });
144    }
145    Ok(())
146}
147
148/// Validate `spatial_size >= 1` and `dense_dim >= 1`.
149#[inline]
150fn check_spatial_dense(spatial_size: usize, dense_dim: usize) -> SslResult<()> {
151    if dense_dim == 0 {
152        return Err(SslError::InvalidFeatureDim);
153    }
154    if spatial_size == 0 {
155        return Err(SslError::EmptyInput);
156    }
157    Ok(())
158}
159
160// ─── Correspondence finding ───────────────────────────────────────────────────
161
162/// Find the best-matching key position for each query position.
163///
164/// Computes the full pairwise cosine similarity matrix `[HW, HW]` and returns
165/// `argmax_j sim(f_q[i], f_k[j])` for every query position `i`.
166///
167/// - `query_dense`: `[HW, C]` L2-normalised dense query features (row-major).
168/// - `key_dense`:   `[HW, C]` L2-normalised dense key features (row-major).
169/// - `spatial_size`: `H*W`.
170/// - `dense_dim`:    `C`.
171///
172/// Returns a `Vec<usize>` of length `spatial_size` mapping query index → key index.
173pub fn dense_correspondence(
174    query_dense: &[f32],
175    key_dense: &[f32],
176    spatial_size: usize,
177    dense_dim: usize,
178) -> Vec<usize> {
179    // Inputs are assumed L2-normalised; no re-normalisation here so the function
180    // remains O(HW * C) without allocating extra buffers.
181    let mut corr = Vec::with_capacity(spatial_size);
182    for i in 0..spatial_size {
183        let q_row = &query_dense[i * dense_dim..(i + 1) * dense_dim];
184        let mut best_j = 0usize;
185        let mut best_s = f32::NEG_INFINITY;
186        for j in 0..spatial_size {
187            let k_row = &key_dense[j * dense_dim..(j + 1) * dense_dim];
188            let s = dot(q_row, k_row);
189            if s > best_s {
190                best_s = s;
191                best_j = j;
192            }
193        }
194        corr.push(best_j);
195    }
196    corr
197}
198
199// ─── Top-k correspondence (internal) ─────────────────────────────────────────
200
201/// Build a top-k correspondence-averaged positive-key matrix `[HW, C]`.
202///
203/// For `k == 1` this is identical to `dense_correspondence` + gather.
204/// For `k > 1` we average the top-k matched key vectors and re-normalise.
205fn dense_correspondence_topk(
206    query_dense_norm: &[f32],
207    key_dense_norm: &[f32],
208    spatial_size: usize,
209    dense_dim: usize,
210    topk: usize,
211    eps: f32,
212) -> Vec<f32> {
213    let k = topk.max(1);
214    let mut pos_keys = vec![0.0_f32; spatial_size * dense_dim];
215
216    // Temporary buffer for (similarity, index) pairs – reused per query row.
217    let mut sims: Vec<(f32, usize)> = Vec::with_capacity(spatial_size);
218
219    for i in 0..spatial_size {
220        let q_row = &query_dense_norm[i * dense_dim..(i + 1) * dense_dim];
221        sims.clear();
222        for j in 0..spatial_size {
223            let k_row = &key_dense_norm[j * dense_dim..(j + 1) * dense_dim];
224            sims.push((dot(q_row, k_row), j));
225        }
226        // Partial sort: move top-k largest to front using selection.
227        let take = k.min(spatial_size);
228        for t in 0..take {
229            let mut best_idx = t;
230            for u in (t + 1)..sims.len() {
231                if sims[u].0 > sims[best_idx].0 {
232                    best_idx = u;
233                }
234            }
235            sims.swap(t, best_idx);
236        }
237        // Accumulate the averaged positive key.
238        let out_row = &mut pos_keys[i * dense_dim..(i + 1) * dense_dim];
239        for &(_, kj) in sims.iter().take(take) {
240            let k_row = &key_dense_norm[kj * dense_dim..(kj + 1) * dense_dim];
241            for (o, &v) in out_row.iter_mut().zip(k_row.iter()) {
242                *o += v;
243            }
244        }
245        // Re-normalise the averaged vector.
246        let norm: f32 = out_row.iter().map(|v| v * v).sum::<f32>().sqrt();
247        if norm > eps {
248            let inv = 1.0 / norm;
249            for v in out_row.iter_mut() {
250                *v *= inv;
251            }
252        }
253    }
254    pos_keys
255}
256
257// ─── Dense InfoNCE ────────────────────────────────────────────────────────────
258
259/// Dense InfoNCE loss for a single image's spatial positions.
260///
261/// For each query position `i` in `[0, HW)`:
262/// - Positive: the pre-computed `pos_keys[i]` (the best-matched key feature).
263/// - Negatives: all rows of `all_query` (the concatenated `[B*HW, C]` queries
264///   from the entire batch, minus self if present).
265///
266/// The numerically-stable form (log-sum-exp) is used throughout.
267///
268/// # Errors
269/// - [`SslError::InvalidTemperature`] if `temperature <= 0` or non-finite.
270/// - [`SslError::EmptyInput`] if inputs are empty.
271pub fn dense_infonce(
272    query: &[f32],
273    pos_keys: &[f32],
274    all_query: &[f32],
275    spatial_size: usize,
276    batch_size: usize,
277    dense_dim: usize,
278    temperature: f32,
279) -> SslResult<f32> {
280    if spatial_size == 0 || dense_dim == 0 || batch_size == 0 {
281        return Err(SslError::EmptyInput);
282    }
283    check_temperature(temperature)?;
284
285    let hw_total = spatial_size * batch_size;
286    if all_query.len() != hw_total * dense_dim {
287        return Err(SslError::DimensionMismatch {
288            expected: hw_total * dense_dim,
289            got: all_query.len(),
290        });
291    }
292    if query.len() != spatial_size * dense_dim {
293        return Err(SslError::DimensionMismatch {
294            expected: spatial_size * dense_dim,
295            got: query.len(),
296        });
297    }
298    if pos_keys.len() != spatial_size * dense_dim {
299        return Err(SslError::DimensionMismatch {
300            expected: spatial_size * dense_dim,
301            got: pos_keys.len(),
302        });
303    }
304
305    let inv_t = 1.0_f32 / temperature;
306    let mut total_loss = 0.0_f64;
307
308    // Pre-compute logits for all negatives once for each query position.
309    // Each negative is a row of `all_query`; we include all of them
310    // (self-contrast is allowed as a slightly conservative approximation —
311    // exact self-exclusion would require tracking which row corresponds to this
312    // image, but that is caller-managed).
313    for i in 0..spatial_size {
314        let q_row = &query[i * dense_dim..(i + 1) * dense_dim];
315        let p_row = &pos_keys[i * dense_dim..(i + 1) * dense_dim];
316
317        // Positive logit.
318        let pos_logit = dot(q_row, p_row) * inv_t;
319
320        // Negative logits (all rows of all_query including current image).
321        let mut neg_logits: Vec<f32> = Vec::with_capacity(hw_total);
322        for l in 0..hw_total {
323            let n_row = &all_query[l * dense_dim..(l + 1) * dense_dim];
324            neg_logits.push(dot(q_row, n_row) * inv_t);
325        }
326
327        // log-sum-exp over negatives.
328        let log_z_neg = log_sum_exp(&neg_logits);
329
330        // log-sum-exp over {positive} ∪ {negatives}.
331        let mut all_logits = neg_logits;
332        all_logits.push(pos_logit);
333        let log_z_all = log_sum_exp(&all_logits);
334
335        let _ = log_z_neg;
336        // InfoNCE: -log[exp(pos) / Σ_{all}]  =  log_z_all - pos_logit
337        total_loss += log_z_all - (pos_logit as f64);
338    }
339
340    Ok((total_loss / spatial_size as f64) as f32)
341}
342
343// ─── Global InfoNCE (MoCo-style, single query vs single positive + queue) ─────
344
345/// Single-query MoCo-style InfoNCE used as the global branch of DenseCL.
346///
347/// `query_global` and `key_global` are `[D]` L2-normalised vectors.
348/// `queue` is `[Q, D]` (negatives, row-major).
349fn global_infonce_single(
350    query_global: &[f32],
351    key_global: &[f32],
352    queue: &[f32],
353    global_dim: usize,
354    temperature: f32,
355    eps: f32,
356) -> f32 {
357    let inv_t = 1.0_f32 / temperature;
358
359    // Normalise defensively.
360    let q = l2_normalise_clone(query_global, 1, global_dim, eps);
361    let k = l2_normalise_clone(key_global, 1, global_dim, eps);
362
363    let pos_logit = dot(&q, &k) * inv_t;
364
365    if queue.is_empty() {
366        // No negatives: loss is zero (cannot compute denominator meaningfully).
367        return 0.0;
368    }
369    let n_neg = queue.len() / global_dim;
370    let mut logits: Vec<f32> = Vec::with_capacity(n_neg + 1);
371    logits.push(pos_logit);
372    for kn in 0..n_neg {
373        let k_row = &queue[kn * global_dim..(kn + 1) * global_dim];
374        logits.push(dot(&q, k_row) * inv_t);
375    }
376    let log_z = log_sum_exp(&logits);
377    (log_z - pos_logit as f64) as f32
378}
379
380// ─── DenseCL combined loss ────────────────────────────────────────────────────
381
382/// DenseCL combined loss (global InfoNCE + dense InfoNCE).
383///
384/// Implements Wang 2021 §3.2. The global branch mirrors MoCo; the dense branch
385/// contrasts spatial positions using correspondence-found positives.
386///
387/// # Parameters
388/// - `query_global`: `[D]` L2-normalised global query embedding.
389/// - `key_global`:   `[D]` L2-normalised global key embedding.
390/// - `query_dense`:  `[HW, C]` L2-normalised dense query feature map (row-major).
391/// - `key_dense`:    `[HW, C]` L2-normalised dense key feature map (row-major).
392/// - `neg_queue`:    `[Q, D]` global negative queue (may be empty → global loss = 0).
393/// - `spatial_size`: `H*W`.
394/// - `global_dim`:   `D`.
395/// - `dense_dim`:    `C`.
396/// - `config`:       [`DenseCLConfig`].
397///
398/// # Errors
399/// - [`SslError::InvalidTemperature`] if `config.temperature <= 0`.
400/// - [`SslError::InvalidFeatureDim`] if `global_dim == 0` or `dense_dim == 0`.
401/// - [`SslError::EmptyInput`] if `spatial_size == 0`.
402/// - [`SslError::DimensionMismatch`] on shape inconsistencies.
403/// - [`SslError::InvalidParameter`] if `lambda_dense ∉ [0, 1]`.
404pub fn dense_cl_loss(
405    query_global: &[f32],
406    key_global: &[f32],
407    query_dense: &[f32],
408    key_dense: &[f32],
409    neg_queue: &[f32],
410    spatial_size: usize,
411    global_dim: usize,
412    dense_dim: usize,
413    config: &DenseCLConfig,
414) -> SslResult<DenseCLResult> {
415    // ── Validate ──────────────────────────────────────────────────────────────
416    check_temperature(config.temperature)?;
417    check_spatial_dense(spatial_size, dense_dim)?;
418
419    if global_dim == 0 {
420        return Err(SslError::InvalidFeatureDim);
421    }
422    if !(config.lambda_dense.is_finite()
423        && config.lambda_dense >= 0.0
424        && config.lambda_dense <= 1.0)
425    {
426        return Err(SslError::InvalidParameter {
427            name: "lambda_dense".to_string(),
428            reason: "must be in [0, 1]".to_string(),
429        });
430    }
431    if query_global.len() != global_dim {
432        return Err(SslError::DimensionMismatch {
433            expected: global_dim,
434            got: query_global.len(),
435        });
436    }
437    if key_global.len() != global_dim {
438        return Err(SslError::DimensionMismatch {
439            expected: global_dim,
440            got: key_global.len(),
441        });
442    }
443    if query_dense.len() != spatial_size * dense_dim {
444        return Err(SslError::DimensionMismatch {
445            expected: spatial_size * dense_dim,
446            got: query_dense.len(),
447        });
448    }
449    if key_dense.len() != spatial_size * dense_dim {
450        return Err(SslError::DimensionMismatch {
451            expected: spatial_size * dense_dim,
452            got: key_dense.len(),
453        });
454    }
455
456    // ── Normalise inputs ──────────────────────────────────────────────────────
457    let q_norm = l2_normalise_clone(query_dense, spatial_size, dense_dim, config.eps);
458    let k_norm = l2_normalise_clone(key_dense, spatial_size, dense_dim, config.eps);
459
460    // ── Global InfoNCE ────────────────────────────────────────────────────────
461    let global_loss = if config.lambda_dense < 1.0 {
462        global_infonce_single(
463            query_global,
464            key_global,
465            neg_queue,
466            global_dim,
467            config.temperature,
468            config.eps,
469        )
470    } else {
471        0.0
472    };
473
474    // ── Correspondence finding ────────────────────────────────────────────────
475    // Build the positive-key matrix `[HW, C]` using top-k correspondence.
476    let pos_keys = dense_correspondence_topk(
477        &q_norm,
478        &k_norm,
479        spatial_size,
480        dense_dim,
481        config.correspondence_topk,
482        config.eps,
483    );
484
485    // Diagnostic: average cosine similarity of matched pairs.
486    let mut sum_sim = 0.0_f64;
487    let corr_map = dense_correspondence(&q_norm, &k_norm, spatial_size, dense_dim);
488    for i in 0..spatial_size {
489        let q_row = &q_norm[i * dense_dim..(i + 1) * dense_dim];
490        let j = corr_map[i];
491        let k_row = &k_norm[j * dense_dim..(j + 1) * dense_dim];
492        sum_sim += dot(q_row, k_row) as f64;
493    }
494    let mean_correspondence_sim = (sum_sim / spatial_size as f64) as f32;
495
496    // ── Dense InfoNCE ─────────────────────────────────────────────────────────
497    let dense_loss = if config.lambda_dense > 0.0 {
498        // Use all query positions within this single sample as negatives
499        // (batch_size = 1 for the public API; callers may concatenate across
500        // images to supply a richer negative set via `dense_infonce` directly).
501        dense_infonce(
502            &q_norm,
503            &pos_keys,
504            &q_norm,
505            spatial_size,
506            1, // batch_size = 1 (single image)
507            dense_dim,
508            config.temperature,
509        )?
510    } else {
511        0.0
512    };
513
514    // ── Combine ───────────────────────────────────────────────────────────────
515    let lambda = config.lambda_dense;
516    let total_loss = (1.0 - lambda) * global_loss + lambda * dense_loss;
517
518    Ok(DenseCLResult {
519        total_loss,
520        global_loss,
521        dense_loss,
522        mean_correspondence_sim,
523        n_positions: spatial_size,
524    })
525}
526
527// ─── PixPro ───────────────────────────────────────────────────────────────────
528
529/// PixPro dense loss — Xie et al. 2021.
530///
531/// Propagates key features using similarity-weighted averaging of neighbouring
532/// positions (without spatial graph — full cross-attention style propagation),
533/// then computes the average negative cosine distance between the query and the
534/// propagated key features. Each propagation step is:
535///
536/// ```text
537/// w(i, j) = softmax_j( sim(f_k[i], f_k[j]) / τ )
538/// f_k_prop[i] = Σ_j w(i, j) · f_k[j]
539/// ```
540///
541/// After all propagation iterations the result is L2-normalised, then:
542/// ```text
543/// L = 1/(HW) · Σ_i [ 1 − f_q[i] · f_k_prop[i] ]
544/// ```
545///
546/// The loss is in `[0, 2]` (cosine similarity ∈ [-1, 1]).
547///
548/// # Errors
549/// - [`SslError::InvalidTemperature`] if `config.temperature <= 0`.
550/// - [`SslError::EmptyInput`] if `spatial_size == 0` or `dense_dim == 0`.
551/// - [`SslError::DimensionMismatch`] on shape mismatch.
552pub fn pixpro_loss(
553    query_dense: &[f32],
554    key_dense: &[f32],
555    spatial_size: usize,
556    dense_dim: usize,
557    config: &PixProConfig,
558) -> SslResult<f32> {
559    check_temperature(config.temperature)?;
560    check_spatial_dense(spatial_size, dense_dim)?;
561
562    if query_dense.len() != spatial_size * dense_dim {
563        return Err(SslError::DimensionMismatch {
564            expected: spatial_size * dense_dim,
565            got: query_dense.len(),
566        });
567    }
568    if key_dense.len() != spatial_size * dense_dim {
569        return Err(SslError::DimensionMismatch {
570            expected: spatial_size * dense_dim,
571            got: key_dense.len(),
572        });
573    }
574
575    let q_norm = l2_normalise_clone(query_dense, spatial_size, dense_dim, config.eps);
576    let mut k_prop = l2_normalise_clone(key_dense, spatial_size, dense_dim, config.eps);
577
578    let iters = config.propagation_iters.max(1);
579    for _ in 0..iters {
580        k_prop = pixpro_propagate_once(
581            &k_prop,
582            spatial_size,
583            dense_dim,
584            config.temperature,
585            config.eps,
586        );
587    }
588
589    // Cosine loss: 1/(HW) · Σ_i [1 - sim(q_i, k_prop_i)]
590    let mut total = 0.0_f64;
591    for i in 0..spatial_size {
592        let q_row = &q_norm[i * dense_dim..(i + 1) * dense_dim];
593        let k_row = &k_prop[i * dense_dim..(i + 1) * dense_dim];
594        let sim = dot(q_row, k_row) as f64;
595        total += 1.0 - sim;
596    }
597    let loss = (total / spatial_size as f64) as f32;
598
599    if !loss.is_finite() {
600        return Err(SslError::NanEncountered {
601            location: "pixpro_loss",
602        });
603    }
604
605    Ok(loss)
606}
607
608/// One step of PixPro feature propagation.
609///
610/// For every position `i`:
611/// ```text
612/// w(i,j) = softmax_j( k_prop[i] · k_prop[j] / τ )
613/// out[i]  = Σ_j w(i,j) · k_prop[j]
614/// ```
615/// Output is L2-normalised.
616fn pixpro_propagate_once(
617    k: &[f32],
618    spatial_size: usize,
619    dense_dim: usize,
620    temperature: f32,
621    eps: f32,
622) -> Vec<f32> {
623    let inv_t = 1.0_f32 / temperature;
624    let mut out = vec![0.0_f32; spatial_size * dense_dim];
625
626    for i in 0..spatial_size {
627        let k_i = &k[i * dense_dim..(i + 1) * dense_dim];
628        // Compute raw scores and softmax weights.
629        let mut scores: Vec<f32> = (0..spatial_size)
630            .map(|j| {
631                let k_j = &k[j * dense_dim..(j + 1) * dense_dim];
632                dot(k_i, k_j) * inv_t
633            })
634            .collect();
635        // Numerically stable softmax.
636        let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
637        let mut sum_exp = 0.0_f32;
638        for s in scores.iter_mut() {
639            *s = (*s - max_s).exp();
640            sum_exp += *s;
641        }
642        if sum_exp > eps {
643            let inv_sum = 1.0 / sum_exp;
644            for s in scores.iter_mut() {
645                *s *= inv_sum;
646            }
647        }
648        // Weighted sum into output.
649        let out_i = &mut out[i * dense_dim..(i + 1) * dense_dim];
650        for (j, &w) in scores.iter().enumerate() {
651            let k_j = &k[j * dense_dim..(j + 1) * dense_dim];
652            for (o, &kv) in out_i.iter_mut().zip(k_j.iter()) {
653                *o += w * kv;
654            }
655        }
656    }
657    // Re-normalise.
658    l2_normalise_rows_inplace(&mut out, spatial_size, dense_dim, eps);
659    out
660}
661
662// ─── Tests ────────────────────────────────────────────────────────────────────
663
664#[cfg(test)]
665mod tests {
666    use super::*;
667
668    /// Simple LCG RNG matching the project convention (no `rand` crate).
669    struct Lcg {
670        state: u64,
671    }
672    impl Lcg {
673        fn new(seed: u64) -> Self {
674            Self { state: seed }
675        }
676        fn next_f32(&mut self) -> f32 {
677            self.state = self
678                .state
679                .wrapping_mul(6_364_136_223_846_793_005)
680                .wrapping_add(1_442_695_040_888_963_407);
681            (self.state >> 33) as f32 / (u32::MAX as f32 + 1.0)
682        }
683        fn fill(&mut self, buf: &mut [f32]) {
684            for v in buf.iter_mut() {
685                *v = self.next_f32() - 0.5;
686            }
687        }
688    }
689
690    fn rand_unit(n: usize, d: usize, seed: u64, eps: f32) -> Vec<f32> {
691        let mut rng = Lcg::new(seed);
692        let mut buf = vec![0.0_f32; n * d];
693        rng.fill(&mut buf);
694        l2_normalise_rows_inplace(&mut buf, n, d, eps);
695        buf
696    }
697
698    // ── Test 1: total_loss is finite and non-negative ─────────────────────────
699    #[test]
700    fn total_loss_finite_nonnegative() {
701        let hw = 4;
702        let d = 8;
703        let c = 8;
704        let cfg = DenseCLConfig::default();
705        let qg = rand_unit(1, d, 1, cfg.eps);
706        let kg = rand_unit(1, d, 2, cfg.eps);
707        let qd = rand_unit(hw, c, 3, cfg.eps);
708        let kd = rand_unit(hw, c, 4, cfg.eps);
709        let queue = rand_unit(16, d, 5, cfg.eps);
710
711        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg)
712            .expect("dense_cl_loss should succeed");
713        assert!(res.total_loss.is_finite(), "total_loss not finite");
714        assert!(
715            res.total_loss >= 0.0,
716            "total_loss negative: {}",
717            res.total_loss
718        );
719    }
720
721    // ── Test 2: lambda_dense=0 → total_loss == global_loss ───────────────────
722    #[test]
723    fn lambda_zero_gives_global_only() {
724        let hw = 4;
725        let d = 8;
726        let c = 8;
727        let cfg = DenseCLConfig {
728            lambda_dense: 0.0,
729            ..Default::default()
730        };
731
732        let qg = rand_unit(1, d, 10, cfg.eps);
733        let kg = rand_unit(1, d, 11, cfg.eps);
734        let qd = rand_unit(hw, c, 12, cfg.eps);
735        let kd = rand_unit(hw, c, 13, cfg.eps);
736        let queue = rand_unit(8, d, 14, cfg.eps);
737
738        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg)
739            .expect("dense_cl_loss should succeed");
740        assert!(
741            (res.total_loss - res.global_loss).abs() < 1e-5,
742            "total={} global={}",
743            res.total_loss,
744            res.global_loss
745        );
746    }
747
748    // ── Test 3: lambda_dense=1 → total_loss == dense_loss ────────────────────
749    #[test]
750    fn lambda_one_gives_dense_only() {
751        let hw = 4;
752        let d = 8;
753        let c = 8;
754        let cfg = DenseCLConfig {
755            lambda_dense: 1.0,
756            ..Default::default()
757        };
758
759        let qg = rand_unit(1, d, 20, cfg.eps);
760        let kg = rand_unit(1, d, 21, cfg.eps);
761        let qd = rand_unit(hw, c, 22, cfg.eps);
762        let kd = rand_unit(hw, c, 23, cfg.eps);
763        let queue = rand_unit(8, d, 24, cfg.eps);
764
765        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg)
766            .expect("dense_cl_loss should succeed");
767        assert!(
768            (res.total_loss - res.dense_loss).abs() < 1e-5,
769            "total={} dense={}",
770            res.total_loss,
771            res.dense_loss
772        );
773    }
774
775    // ── Test 4: correspondence map length == spatial_size ────────────────────
776    #[test]
777    fn correspondence_map_length_equals_spatial_size() {
778        let hw = 9;
779        let c = 6;
780        let qd = rand_unit(hw, c, 30, 1e-8);
781        let kd = rand_unit(hw, c, 31, 1e-8);
782        let corr = dense_correspondence(&qd, &kd, hw, c);
783        assert_eq!(corr.len(), hw);
784    }
785
786    // ── Test 5: all correspondence indices ∈ [0, spatial_size) ───────────────
787    #[test]
788    fn correspondence_indices_in_range() {
789        let hw = 16;
790        let c = 8;
791        let qd = rand_unit(hw, c, 40, 1e-8);
792        let kd = rand_unit(hw, c, 41, 1e-8);
793        let corr = dense_correspondence(&qd, &kd, hw, c);
794        for &idx in &corr {
795            assert!(idx < hw, "index {idx} out of [0, {hw})");
796        }
797    }
798
799    // ── Test 6: mean_correspondence_sim ∈ [-1, 1] ────────────────────────────
800    #[test]
801    fn mean_correspondence_sim_in_range() {
802        let hw = 6;
803        let d = 4;
804        let c = 4;
805        let cfg = DenseCLConfig::default();
806        let qg = rand_unit(1, d, 50, cfg.eps);
807        let kg = rand_unit(1, d, 51, cfg.eps);
808        let qd = rand_unit(hw, c, 52, cfg.eps);
809        let kd = rand_unit(hw, c, 53, cfg.eps);
810        let queue = rand_unit(4, d, 54, cfg.eps);
811
812        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg)
813            .expect("dense_cl_loss should succeed");
814        assert!(
815            res.mean_correspondence_sim >= -1.0 - 1e-5 && res.mean_correspondence_sim <= 1.0 + 1e-5,
816            "mean_corr_sim = {}",
817            res.mean_correspondence_sim
818        );
819    }
820
821    // ── Test 7: identical query and key → mean_correspondence_sim ≈ 1 ─────────
822    #[test]
823    fn identical_query_key_max_correspondence() {
824        let hw = 5;
825        let d = 4;
826        let c = 4;
827        let cfg = DenseCLConfig {
828            lambda_dense: 1.0,
829            ..Default::default()
830        };
831
832        let qg = rand_unit(1, d, 60, cfg.eps);
833        let kg = qg.clone();
834        let qd = rand_unit(hw, c, 62, cfg.eps);
835        let kd = qd.clone();
836        let queue: Vec<f32> = vec![];
837
838        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg)
839            .expect("dense_cl_loss should succeed");
840        assert!(
841            res.mean_correspondence_sim > 0.99,
842            "expected ~1.0, got {}",
843            res.mean_correspondence_sim
844        );
845    }
846
847    // ── Test 8: dense_infonce finite for random inputs ────────────────────────
848    #[test]
849    fn dense_infonce_finite_random() {
850        let hw = 8;
851        let c = 6;
852        let batch = 2;
853        let q = rand_unit(hw, c, 70, 1e-8);
854        let pk = rand_unit(hw, c, 71, 1e-8);
855        let all_q = rand_unit(hw * batch, c, 72, 1e-8);
856        let loss = dense_infonce(&q, &pk, &all_q, hw, batch, c, 0.2)
857            .expect("dense_infonce should succeed");
858        assert!(loss.is_finite(), "loss = {loss}");
859    }
860
861    // ── Test 9: pixpro_loss finite and in [0, 4] ─────────────────────────────
862    #[test]
863    fn pixpro_loss_finite_and_bounded() {
864        let hw = 6;
865        let c = 8;
866        let cfg = PixProConfig::default();
867        let qd = rand_unit(hw, c, 80, cfg.eps);
868        let kd = rand_unit(hw, c, 81, cfg.eps);
869        let loss = pixpro_loss(&qd, &kd, hw, c, &cfg).expect("pixpro_loss should succeed");
870        assert!(loss.is_finite(), "loss not finite");
871        // cosine loss in [0, 2], so mean ∈ [0, 2] ≤ 4.
872        assert!(loss >= 0.0, "loss = {loss} < 0");
873        assert!(loss <= 4.0, "loss = {loss} > 4");
874    }
875
876    // ── Test 10: invalid temperature → error ──────────────────────────────────
877    #[test]
878    fn invalid_temperature_returns_error() {
879        let hw = 4;
880        let d = 4;
881        let c = 4;
882        let cfg = DenseCLConfig {
883            temperature: 0.0,
884            ..Default::default()
885        };
886
887        let qg = rand_unit(1, d, 90, 1e-8);
888        let kg = rand_unit(1, d, 91, 1e-8);
889        let qd = rand_unit(hw, c, 92, 1e-8);
890        let kd = rand_unit(hw, c, 93, 1e-8);
891        let queue = rand_unit(4, d, 94, 1e-8);
892
893        assert!(dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg).is_err());
894
895        let px_cfg = PixProConfig {
896            temperature: 0.0,
897            ..Default::default()
898        };
899        assert!(pixpro_loss(&qd, &kd, hw, c, &px_cfg).is_err());
900    }
901
902    // ── Test 11: spatial_size=1 → both losses work ───────────────────────────
903    #[test]
904    fn single_spatial_position_works() {
905        let hw = 1;
906        let d = 8;
907        let c = 8;
908        let cfg = DenseCLConfig::default();
909
910        let qg = rand_unit(1, d, 100, cfg.eps);
911        let kg = rand_unit(1, d, 101, cfg.eps);
912        let qd = rand_unit(hw, c, 102, cfg.eps);
913        let kd = rand_unit(hw, c, 103, cfg.eps);
914        let queue = rand_unit(4, d, 104, cfg.eps);
915
916        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg)
917            .expect("dense_cl_loss should succeed");
918        assert!(res.total_loss.is_finite());
919        assert_eq!(res.n_positions, 1);
920
921        let px_cfg = PixProConfig::default();
922        let pl = pixpro_loss(&qd, &kd, hw, c, &px_cfg).expect("pixpro_loss should succeed");
923        assert!(pl.is_finite());
924    }
925
926    // ── Test 12: larger batch provides more negatives (monotone test) ─────────
927    #[test]
928    fn larger_batch_size_more_negatives() {
929        // More negatives → denominator grows → loss should be >= single-batch.
930        // We verify that the function runs correctly for batch_size > 1 without
931        // error, and that the loss is finite.
932        let hw = 4;
933        let c = 6;
934        let q = rand_unit(hw, c, 110, 1e-8);
935        let pk = rand_unit(hw, c, 111, 1e-8);
936
937        let batch_small = 1usize;
938        let all_q_small = rand_unit(hw * batch_small, c, 112, 1e-8);
939        let l_small = dense_infonce(&q, &pk, &all_q_small, hw, batch_small, c, 0.2)
940            .expect("dense_infonce should succeed");
941
942        let batch_large = 4usize;
943        let all_q_large = rand_unit(hw * batch_large, c, 113, 1e-8);
944        let l_large = dense_infonce(&q, &pk, &all_q_large, hw, batch_large, c, 0.2)
945            .expect("dense_infonce should succeed");
946
947        assert!(l_small.is_finite());
948        assert!(l_large.is_finite());
949        // Both losses are non-negative.
950        assert!(l_small >= 0.0);
951        assert!(l_large >= 0.0);
952    }
953
954    // ── Test 13: global+dense linear combination correctness ──────────────────
955    #[test]
956    fn linear_combination_matches_components() {
957        let hw = 4;
958        let d = 8;
959        let c = 8;
960        let cfg = DenseCLConfig {
961            lambda_dense: 0.3,
962            ..Default::default()
963        };
964
965        let qg = rand_unit(1, d, 120, cfg.eps);
966        let kg = rand_unit(1, d, 121, cfg.eps);
967        let qd = rand_unit(hw, c, 122, cfg.eps);
968        let kd = rand_unit(hw, c, 123, cfg.eps);
969        let queue = rand_unit(8, d, 124, cfg.eps);
970
971        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg)
972            .expect("dense_cl_loss should succeed");
973
974        let expected = 0.7 * res.global_loss + 0.3 * res.dense_loss;
975        assert!(
976            (res.total_loss - expected).abs() < 1e-5,
977            "total={} expected={}",
978            res.total_loss,
979            expected
980        );
981    }
982
983    // ── Test 14: pixpro with multiple propagation iterations ─────────────────
984    #[test]
985    fn pixpro_multi_iter_finite() {
986        let hw = 8;
987        let c = 6;
988        let cfg = PixProConfig {
989            temperature: 0.1,
990            propagation_iters: 3,
991            eps: 1e-8,
992        };
993        let qd = rand_unit(hw, c, 130, cfg.eps);
994        let kd = rand_unit(hw, c, 131, cfg.eps);
995        let loss = pixpro_loss(&qd, &kd, hw, c, &cfg).expect("pixpro_loss should succeed");
996        assert!(loss.is_finite());
997        assert!((0.0..=4.0).contains(&loss));
998    }
999
1000    // ── Test 15: DimensionMismatch detected ───────────────────────────────────
1001    #[test]
1002    fn dimension_mismatch_detected() {
1003        let hw = 4;
1004        let d = 8;
1005        let c = 8;
1006        let cfg = DenseCLConfig::default();
1007
1008        // query_dense too short
1009        let qg = rand_unit(1, d, 140, cfg.eps);
1010        let kg = rand_unit(1, d, 141, cfg.eps);
1011        let qd_bad = rand_unit(hw - 1, c, 142, cfg.eps); // wrong shape
1012        let kd = rand_unit(hw, c, 143, cfg.eps);
1013        let queue = rand_unit(4, d, 144, cfg.eps);
1014
1015        let res = dense_cl_loss(&qg, &kg, &qd_bad, &kd, &queue, hw, d, c, &cfg);
1016        assert!(res.is_err());
1017    }
1018}