use crate::error::{SslError, SslResult};
/// Hyper-parameters for the blended global + dense contrastive loss computed
/// by `dense_cl_loss`.
#[derive(Debug, Clone)]
pub struct DenseCLConfig {
    /// Softmax temperature. Logits are multiplied by `1 / temperature`, and the
    /// loss functions reject values that are not finite and strictly positive.
    pub temperature: f32,
    /// Blend weight in `[0, 1]`:
    /// `total = (1 - lambda_dense) * global + lambda_dense * dense`.
    pub lambda_dense: f32,
    /// Number of negatives per positive.
    ///
    /// NOTE(review): no function in this file reads this field — presumably it
    /// is consumed by a caller that builds the negative queue; confirm.
    pub n_negatives_per_pos: usize,
    /// How many of the most similar key positions are summed (then
    /// re-normalised) to form each positive key. `0` is treated as `1`.
    pub correspondence_topk: usize,
    /// Norm threshold: a row is only rescaled when its L2 norm exceeds this.
    pub eps: f32,
}

impl Default for DenseCLConfig {
    /// Temperature `0.2`, equal global/dense weighting, 256 negatives per
    /// positive, top-1 correspondence and `eps = 1e-8`.
    fn default() -> Self {
        DenseCLConfig {
            eps: 1e-8,
            correspondence_topk: 1,
            n_negatives_per_pos: 256,
            lambda_dense: 0.5,
            temperature: 0.2,
        }
    }
}
/// Hyper-parameters for `pixpro_loss`.
#[derive(Debug, Clone)]
pub struct PixProConfig {
    /// Temperature of the softmax over key self-similarities used while
    /// propagating features; must be finite and strictly positive.
    pub temperature: f32,
    /// Number of propagation rounds applied to the key features. `0` is
    /// treated as `1`.
    pub propagation_iters: usize,
    /// Norm threshold: a row is only rescaled when its L2 norm exceeds this.
    pub eps: f32,
}

impl Default for PixProConfig {
    /// Temperature `0.2`, a single propagation round and `eps = 1e-8`.
    fn default() -> Self {
        PixProConfig {
            eps: 1e-8,
            propagation_iters: 1,
            temperature: 0.2,
        }
    }
}
/// Breakdown of the loss returned by `dense_cl_loss`.
#[derive(Debug, Clone)]
pub struct DenseCLResult {
/// Weighted blend `(1 - lambda_dense) * global_loss + lambda_dense * dense_loss`.
pub total_loss: f32,
/// Global (image-level) InfoNCE term; `0.0` when `lambda_dense` is not below
/// `1.0` or when the negative queue is empty.
pub global_loss: f32,
/// Dense (per-position) InfoNCE term; `0.0` when `lambda_dense` is not above `0.0`.
pub dense_loss: f32,
/// Mean, over query positions, of the dot product between the normalised
/// query row and its best-matching normalised key row.
pub mean_correspondence_sim: f32,
/// Number of spatial positions the loss was computed over (`spatial_size`).
pub n_positions: usize,
}
/// Scales every `d`-wide row of `data` to unit L2 norm, in place.
///
/// * Rows whose norm is not strictly greater than `eps` (this includes NaN
///   norms) are left untouched, so near-zero rows are never divided by a tiny
///   number.
/// * If `data.len()` is not a multiple of `d`, the trailing partial row is
///   normalised on its own.
/// * `n` is the nominal row count. It is accepted for call-site symmetry only;
///   the number of rows is derived from `data.len()`.
/// * `d == 0` is a no-op: there are no columns to scale, and chunking with a
///   zero width would otherwise panic.
#[inline]
fn l2_normalise_rows_inplace(data: &mut [f32], n: usize, d: usize, eps: f32) {
    let _ = n;
    // `chunks_mut(0)` panics, so bail out before chunking.
    if d == 0 {
        return;
    }
    for row in data.chunks_mut(d) {
        let norm = row.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm > eps {
            let inv = 1.0 / norm;
            row.iter_mut().for_each(|v| *v *= inv);
        }
    }
}
/// Returns a copy of `src` in which every `d`-wide row has been scaled to unit
/// L2 norm (rows with norm not above `eps` are copied unchanged). `src` itself
/// is not modified.
#[inline]
fn l2_normalise_clone(src: &[f32], n: usize, d: usize, eps: f32) -> Vec<f32> {
    let mut normalised = Vec::with_capacity(src.len());
    normalised.extend_from_slice(src);
    l2_normalise_rows_inplace(&mut normalised, n, d, eps);
    normalised
}
/// Inner product of `a` and `b`.
///
/// The slices are zipped, so if their lengths differ only the common prefix
/// contributes to the result.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
/// Numerically stable `ln(sum_i exp(vals[i]))`, accumulated in `f64`.
///
/// The maximum is subtracted before exponentiating so large logits do not
/// overflow.
///
/// Edge cases:
/// * empty input returns `-inf` (the log of an empty sum);
/// * if the maximum is `-inf` (every value is `-inf`) the result is `-inf`;
/// * if any value is `+inf` the result is `+inf`;
/// * a NaN anywhere in the input yields NaN.
///
/// Without the explicit infinite-maximum branch the shift `v - max` evaluates
/// `inf - inf` and turns the two infinite cases into NaN.
#[inline]
fn log_sum_exp(vals: &[f32]) -> f64 {
    if vals.is_empty() {
        return f64::NEG_INFINITY;
    }
    // `f32::max` ignores NaN operands, so this is the largest non-NaN value
    // (or `-inf` when there is none).
    let max_v = vals.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if max_v.is_infinite() {
        // Shifting by an infinite maximum is ill-defined; answer directly,
        // but keep reporting NaN inputs as NaN.
        if vals.iter().any(|v| v.is_nan()) {
            return f64::NAN;
        }
        return f64::from(max_v);
    }
    let sum: f64 = vals
        .iter()
        .map(|&v| f64::from(v - max_v).exp())
        .sum();
    f64::from(max_v) + sum.ln()
}
/// Accepts only finite, strictly positive temperatures.
///
/// # Errors
/// Returns `SslError::InvalidTemperature` carrying the offending value for
/// zero, negative, infinite or NaN input.
#[inline]
fn check_temperature(t: f32) -> SslResult<()> {
    if t.is_finite() && t > 0.0 {
        Ok(())
    } else {
        Err(SslError::InvalidTemperature { temp: t })
    }
}
/// Rejects degenerate dense-feature shapes.
///
/// # Errors
/// * `SslError::InvalidFeatureDim` when `dense_dim` is zero (checked first);
/// * `SslError::EmptyInput` when `spatial_size` is zero.
#[inline]
fn check_spatial_dense(spatial_size: usize, dense_dim: usize) -> SslResult<()> {
    // Arm order matters: a zero feature dimension takes precedence.
    match (dense_dim, spatial_size) {
        (0, _) => Err(SslError::InvalidFeatureDim),
        (_, 0) => Err(SslError::EmptyInput),
        _ => Ok(()),
    }
}
/// For every query position, finds the key position with the largest dot
/// product.
///
/// Both buffers are read as `spatial_size` consecutive rows of `dense_dim`
/// values. The returned vector has `spatial_size` entries, each an index in
/// `0..spatial_size`. Ties go to the lowest key index, and a query row whose
/// similarities are all NaN maps to index `0`.
///
/// # Panics
/// Panics if either buffer holds fewer than `spatial_size * dense_dim` values.
pub fn dense_correspondence(
    query_dense: &[f32],
    key_dense: &[f32],
    spatial_size: usize,
    dense_dim: usize,
) -> Vec<usize> {
    (0..spatial_size)
        .map(|qi| {
            let q_row = &query_dense[qi * dense_dim..(qi + 1) * dense_dim];
            // Strict `>` keeps the earliest maximum and never picks NaN.
            let (winner, _) = (0..spatial_size).fold(
                (0usize, f32::NEG_INFINITY),
                |(best_idx, best_sim), kj| {
                    let k_row = &key_dense[kj * dense_dim..(kj + 1) * dense_dim];
                    let sim = dot(q_row, k_row);
                    if sim > best_sim {
                        (kj, sim)
                    } else {
                        (best_idx, best_sim)
                    }
                },
            );
            winner
        })
        .collect()
}
/// Builds one positive key per query position from its `topk` most similar
/// key positions.
///
/// For each query row the similarities to all key rows are ranked, the `topk`
/// best key rows (at least one, at most `spatial_size`) are summed in rank
/// order, and the sum is rescaled to unit L2 norm when its norm exceeds `eps`.
/// The result is laid out as `spatial_size` rows of `dense_dim` values.
///
/// The `_norm` parameter names suggest the inputs are already row-normalised;
/// this function does not normalise them itself.
fn dense_correspondence_topk(
    query_dense_norm: &[f32],
    key_dense_norm: &[f32],
    spatial_size: usize,
    dense_dim: usize,
    topk: usize,
    eps: f32,
) -> Vec<f32> {
    // Clamp the requested k into `1..=spatial_size` once, up front.
    let take = topk.max(1).min(spatial_size);
    let mut pos_keys = vec![0.0_f32; spatial_size * dense_dim];
    let mut ranked: Vec<(f32, usize)> = Vec::with_capacity(spatial_size);

    for qi in 0..spatial_size {
        let span = qi * dense_dim..(qi + 1) * dense_dim;
        let q_row = &query_dense_norm[span.clone()];

        ranked.clear();
        ranked.extend((0..spatial_size).map(|kj| {
            let k_row = &key_dense_norm[kj * dense_dim..(kj + 1) * dense_dim];
            (dot(q_row, k_row), kj)
        }));

        // Partial selection sort: only the first `take` slots need ordering.
        for slot in 0..take {
            let mut winner = slot;
            for cand in (slot + 1)..ranked.len() {
                if ranked[cand].0 > ranked[winner].0 {
                    winner = cand;
                }
            }
            ranked.swap(slot, winner);
        }

        let acc = &mut pos_keys[span];
        for &(_, kj) in &ranked[..take] {
            let k_row = &key_dense_norm[kj * dense_dim..(kj + 1) * dense_dim];
            for (dst, &src) in acc.iter_mut().zip(k_row) {
                *dst += src;
            }
        }

        let norm = acc.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm > eps {
            let scale = 1.0 / norm;
            for v in acc.iter_mut() {
                *v *= scale;
            }
        }
    }
    pos_keys
}
/// Mean per-position InfoNCE loss for dense features.
///
/// For every query position `i` the loss is
/// `logsumexp([q_i · n_l / T for all rows n_l of all_query] ++ [q_i · p_i / T]) - q_i · p_i / T`,
/// i.e. the positive logit is part of the denominator. The per-position
/// losses are averaged over `spatial_size`.
///
/// # Arguments
/// * `query` — `spatial_size` rows of `dense_dim` values.
/// * `pos_keys` — one positive key row per query row, same layout as `query`.
/// * `all_query` — `spatial_size * batch_size` rows of `dense_dim` values;
///   every row is used as a negative for every query position.
/// * `temperature` — softmax temperature, finite and `> 0`.
///
/// The inputs are used as given; no normalisation is applied here.
///
/// # Errors
/// * `SslError::EmptyInput` if `spatial_size`, `dense_dim` or `batch_size` is zero;
/// * `SslError::InvalidTemperature` for a non-finite or non-positive temperature;
/// * `SslError::DimensionMismatch` if a buffer length disagrees with the
///   declared shape (checked in the order `all_query`, `query`, `pos_keys`).
pub fn dense_infonce(
    query: &[f32],
    pos_keys: &[f32],
    all_query: &[f32],
    spatial_size: usize,
    batch_size: usize,
    dense_dim: usize,
    temperature: f32,
) -> SslResult<f32> {
    if spatial_size == 0 || dense_dim == 0 || batch_size == 0 {
        return Err(SslError::EmptyInput);
    }
    check_temperature(temperature)?;

    let ensure_len = |got: usize, expected: usize| -> SslResult<()> {
        if got == expected {
            Ok(())
        } else {
            Err(SslError::DimensionMismatch { expected, got })
        }
    };
    let hw_total = spatial_size * batch_size;
    ensure_len(all_query.len(), hw_total * dense_dim)?;
    ensure_len(query.len(), spatial_size * dense_dim)?;
    ensure_len(pos_keys.len(), spatial_size * dense_dim)?;

    let inv_t = 1.0_f32 / temperature;
    // One logits buffer reused for every position: negatives first, then the
    // positive, so the accumulation order inside `log_sum_exp` is stable.
    let mut logits: Vec<f32> = Vec::with_capacity(hw_total + 1);
    let mut total_loss = 0.0_f64;

    // Lengths were validated above, so `chunks_exact` yields exactly
    // `spatial_size` query/positive rows and `hw_total` negative rows.
    let rows = query
        .chunks_exact(dense_dim)
        .zip(pos_keys.chunks_exact(dense_dim));
    for (q_row, p_row) in rows {
        let pos_logit = dot(q_row, p_row) * inv_t;
        logits.clear();
        logits.extend(
            all_query
                .chunks_exact(dense_dim)
                .map(|n_row| dot(q_row, n_row) * inv_t),
        );
        logits.push(pos_logit);
        total_loss += log_sum_exp(&logits) - f64::from(pos_logit);
    }
    Ok((total_loss / spatial_size as f64) as f32)
}
/// InfoNCE loss for a single global query/key pair against a queue of
/// negatives.
///
/// `query_global` and `key_global` are L2-normalised here (one row of
/// `global_dim` values each). `queue` is read as consecutive rows of
/// `global_dim` values and used as-is; a trailing partial row is ignored.
/// The result is `logsumexp([pos, negs...]) - pos` with all logits divided by
/// `temperature`.
///
/// Returns `0.0` when `queue` is empty.
fn global_infonce_single(
    query_global: &[f32],
    key_global: &[f32],
    queue: &[f32],
    global_dim: usize,
    temperature: f32,
    eps: f32,
) -> f32 {
    let scale = 1.0_f32 / temperature;
    let q_unit = l2_normalise_clone(query_global, 1, global_dim, eps);
    let k_unit = l2_normalise_clone(key_global, 1, global_dim, eps);
    let pos_logit = dot(&q_unit, &k_unit) * scale;

    if queue.is_empty() {
        return 0.0;
    }

    // Positive first, then one logit per complete queue row.
    let mut logits = vec![pos_logit];
    logits.extend(
        queue
            .chunks_exact(global_dim)
            .map(|neg_row| dot(&q_unit, neg_row) * scale),
    );
    (log_sum_exp(&logits) - f64::from(pos_logit)) as f32
}
/// Blended global + dense contrastive loss for one query/key pair.
///
/// `total = (1 - lambda_dense) * global + lambda_dense * dense`, where
/// * `global` is an InfoNCE loss between the normalised global query and key
///   with `neg_queue` rows as negatives (skipped, and reported as `0.0`, when
///   `lambda_dense` is `1.0`; also `0.0` when the queue is empty);
/// * `dense` is a per-position InfoNCE loss whose positive for each query
///   position is built from its `correspondence_topk` most similar key
///   positions (skipped, and reported as `0.0`, when `lambda_dense` is `0.0`).
///
/// # Arguments
/// * `query_global`, `key_global` — `global_dim` values each.
/// * `query_dense`, `key_dense` — `spatial_size` rows of `dense_dim` values;
///   both are row-normalised internally.
/// * `neg_queue` — zero or more rows of `global_dim` values, used as given.
///
/// # Errors
/// * `SslError::InvalidTemperature` for a non-finite or non-positive temperature;
/// * `SslError::InvalidFeatureDim` if `dense_dim` or `global_dim` is zero;
/// * `SslError::EmptyInput` if `spatial_size` is zero;
/// * `SslError::InvalidParameter` if `lambda_dense` is outside `[0, 1]` or not finite;
/// * `SslError::DimensionMismatch` if a buffer length disagrees with its
///   declared shape. For `neg_queue`, whose row count is free, `expected` is
///   the largest whole-row length that fits in the supplied buffer.
pub fn dense_cl_loss(
    query_global: &[f32],
    key_global: &[f32],
    query_dense: &[f32],
    key_dense: &[f32],
    neg_queue: &[f32],
    spatial_size: usize,
    global_dim: usize,
    dense_dim: usize,
    config: &DenseCLConfig,
) -> SslResult<DenseCLResult> {
    check_temperature(config.temperature)?;
    check_spatial_dense(spatial_size, dense_dim)?;
    if global_dim == 0 {
        return Err(SslError::InvalidFeatureDim);
    }
    let lambda = config.lambda_dense;
    if !(lambda.is_finite() && (0.0..=1.0).contains(&lambda)) {
        return Err(SslError::InvalidParameter {
            name: "lambda_dense".to_string(),
            reason: "must be in [0, 1]".to_string(),
        });
    }

    let ensure_len = |got: usize, expected: usize| -> SslResult<()> {
        if got == expected {
            Ok(())
        } else {
            Err(SslError::DimensionMismatch { expected, got })
        }
    };
    let dense_len = spatial_size * dense_dim;
    ensure_len(query_global.len(), global_dim)?;
    ensure_len(key_global.len(), global_dim)?;
    ensure_len(query_dense.len(), dense_len)?;
    ensure_len(key_dense.len(), dense_len)?;
    // A ragged queue would otherwise have its trailing partial row silently
    // dropped; reject it like every other mis-shaped input.
    let whole_rows_len = (neg_queue.len() / global_dim) * global_dim;
    ensure_len(neg_queue.len(), whole_rows_len)?;

    let q_norm = l2_normalise_clone(query_dense, spatial_size, dense_dim, config.eps);
    let k_norm = l2_normalise_clone(key_dense, spatial_size, dense_dim, config.eps);

    let global_loss = if lambda < 1.0 {
        global_infonce_single(
            query_global,
            key_global,
            neg_queue,
            global_dim,
            config.temperature,
            config.eps,
        )
    } else {
        0.0
    };

    // Diagnostic only: average similarity between each query position and its
    // single best-matching key position.
    let corr_map = dense_correspondence(&q_norm, &k_norm, spatial_size, dense_dim);
    let sum_sim = q_norm
        .chunks_exact(dense_dim)
        .zip(&corr_map)
        .fold(0.0_f64, |acc, (q_row, &j)| {
            let k_row = &k_norm[j * dense_dim..(j + 1) * dense_dim];
            acc + f64::from(dot(q_row, k_row))
        });
    let mean_correspondence_sim = (sum_sim / spatial_size as f64) as f32;

    let dense_loss = if lambda > 0.0 {
        // The positive keys are only needed for the dense term, so build them
        // here rather than paying for the ranking when the term is disabled.
        let pos_keys = dense_correspondence_topk(
            &q_norm,
            &k_norm,
            spatial_size,
            dense_dim,
            config.correspondence_topk,
            config.eps,
        );
        // NOTE(review): the negatives passed here are the query's own
        // normalised positions (including position `i` itself), not key-side
        // features — confirm this is the intended negative set.
        dense_infonce(
            &q_norm,
            &pos_keys,
            &q_norm,
            spatial_size,
            1,
            dense_dim,
            config.temperature,
        )?
    } else {
        0.0
    };

    let total_loss = (1.0 - lambda) * global_loss + lambda * dense_loss;
    Ok(DenseCLResult {
        total_loss,
        global_loss,
        dense_loss,
        mean_correspondence_sim,
        n_positions: spatial_size,
    })
}
/// Pixel-propagation consistency loss between dense query and key features.
///
/// Both inputs are `spatial_size` rows of `dense_dim` values and are
/// row-normalised internally. The key features are then smoothed by
/// `propagation_iters` rounds (at least one) of similarity-weighted
/// propagation, and the loss is the mean over positions of
/// `1 - dot(query_row, propagated_key_row)`.
///
/// # Errors
/// * `SslError::InvalidTemperature` for a non-finite or non-positive temperature;
/// * `SslError::InvalidFeatureDim` if `dense_dim` is zero;
/// * `SslError::EmptyInput` if `spatial_size` is zero;
/// * `SslError::DimensionMismatch` if `query_dense` (checked first) or
///   `key_dense` does not hold `spatial_size * dense_dim` values;
/// * `SslError::NanEncountered` if the resulting loss is not finite.
pub fn pixpro_loss(
    query_dense: &[f32],
    key_dense: &[f32],
    spatial_size: usize,
    dense_dim: usize,
    config: &PixProConfig,
) -> SslResult<f32> {
    check_temperature(config.temperature)?;
    check_spatial_dense(spatial_size, dense_dim)?;

    let expected_len = spatial_size * dense_dim;
    for got in [query_dense.len(), key_dense.len()] {
        if got != expected_len {
            return Err(SslError::DimensionMismatch {
                expected: expected_len,
                got,
            });
        }
    }

    let q_norm = l2_normalise_clone(query_dense, spatial_size, dense_dim, config.eps);
    let k_start = l2_normalise_clone(key_dense, spatial_size, dense_dim, config.eps);

    // Zero requested rounds still performs one propagation pass.
    let rounds = config.propagation_iters.max(1);
    let k_prop = (0..rounds).fold(k_start, |k_cur, _| {
        pixpro_propagate_once(
            &k_cur,
            spatial_size,
            dense_dim,
            config.temperature,
            config.eps,
        )
    });

    // Lengths are validated, so both sides yield exactly `spatial_size` rows.
    let total = q_norm
        .chunks_exact(dense_dim)
        .zip(k_prop.chunks_exact(dense_dim))
        .fold(0.0_f64, |acc, (q_row, k_row)| {
            acc + (1.0 - f64::from(dot(q_row, k_row)))
        });

    let loss = (total / spatial_size as f64) as f32;
    if loss.is_finite() {
        Ok(loss)
    } else {
        Err(SslError::NanEncountered {
            location: "pixpro_loss",
        })
    }
}
/// One round of similarity-weighted feature propagation.
///
/// For every position `i`, a softmax (with the given `temperature`) over the
/// dot products between row `i` and every row of `k` produces weights, and the
/// output row is the weighted sum of all rows of `k`. Every output row is then
/// rescaled to unit L2 norm when its norm exceeds `eps`.
///
/// `k` is read as `spatial_size` rows of `dense_dim` values; the result has
/// the same layout.
fn pixpro_propagate_once(
    k: &[f32],
    spatial_size: usize,
    dense_dim: usize,
    temperature: f32,
    eps: f32,
) -> Vec<f32> {
    let scale = 1.0_f32 / temperature;
    let mut out = vec![0.0_f32; spatial_size * dense_dim];
    let mut weights: Vec<f32> = Vec::with_capacity(spatial_size);

    for i in 0..spatial_size {
        let anchor = &k[i * dense_dim..(i + 1) * dense_dim];

        // Scaled similarities of position `i` to every position.
        weights.clear();
        for j in 0..spatial_size {
            let other = &k[j * dense_dim..(j + 1) * dense_dim];
            weights.push(dot(anchor, other) * scale);
        }

        // Softmax with the usual max-shift for numerical stability.
        let peak = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut denom = 0.0_f32;
        for w in weights.iter_mut() {
            let e = (*w - peak).exp();
            *w = e;
            denom += e;
        }
        if denom > eps {
            let inv_denom = 1.0 / denom;
            weights.iter_mut().for_each(|w| *w *= inv_denom);
        }

        // Weighted sum of all rows into the output row for position `i`.
        let dst = &mut out[i * dense_dim..(i + 1) * dense_dim];
        for (j, &w) in weights.iter().enumerate() {
            let src = &k[j * dense_dim..(j + 1) * dense_dim];
            for (o, &v) in dst.iter_mut().zip(src) {
                *o += w * v;
            }
        }
    }

    l2_normalise_rows_inplace(&mut out, spatial_size, dense_dim, eps);
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal 64-bit linear congruential generator, so the tests are
    /// deterministic without depending on an RNG crate.
    struct Lcg {
        state: u64,
    }

    impl Lcg {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        /// Advances the generator and returns a sample in `[0, 1]`.
        fn next_f32(&mut self) -> f32 {
            self.state = self
                .state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            // `state >> 33` keeps the top 31 bits, so the matching divisor is
            // 2^31. Dividing by 2^32 capped samples at 0.5, which made every
            // value produced by `fill` non-positive.
            (self.state >> 33) as f32 / (1u64 << 31) as f32
        }

        /// Fills `buf` with samples centred on zero, in `[-0.5, 0.5]`.
        fn fill(&mut self, buf: &mut [f32]) {
            for v in buf.iter_mut() {
                *v = self.next_f32() - 0.5;
            }
        }
    }

    /// `n` pseudo-random rows of width `d`, each scaled to unit L2 norm.
    fn rand_unit(n: usize, d: usize, seed: u64, eps: f32) -> Vec<f32> {
        let mut rng = Lcg::new(seed);
        let mut buf = vec![0.0_f32; n * d];
        rng.fill(&mut buf);
        l2_normalise_rows_inplace(&mut buf, n, d, eps);
        buf
    }

    /// InfoNCE includes the positive in the denominator, so the blended loss
    /// can never be negative.
    #[test]
    fn total_loss_finite_nonnegative() {
        let hw = 4;
        let d = 8;
        let c = 8;
        let cfg = DenseCLConfig::default();
        let qg = rand_unit(1, d, 1, cfg.eps);
        let kg = rand_unit(1, d, 2, cfg.eps);
        let qd = rand_unit(hw, c, 3, cfg.eps);
        let kd = rand_unit(hw, c, 4, cfg.eps);
        let queue = rand_unit(16, d, 5, cfg.eps);
        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg).unwrap();
        assert!(res.total_loss.is_finite(), "total_loss not finite");
        assert!(
            res.total_loss >= 0.0,
            "total_loss negative: {}",
            res.total_loss
        );
    }

    /// With `lambda_dense = 0` the total is exactly the global term.
    #[test]
    fn lambda_zero_gives_global_only() {
        let hw = 4;
        let d = 8;
        let c = 8;
        let cfg = DenseCLConfig {
            lambda_dense: 0.0,
            ..Default::default()
        };
        let qg = rand_unit(1, d, 10, cfg.eps);
        let kg = rand_unit(1, d, 11, cfg.eps);
        let qd = rand_unit(hw, c, 12, cfg.eps);
        let kd = rand_unit(hw, c, 13, cfg.eps);
        let queue = rand_unit(8, d, 14, cfg.eps);
        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg).unwrap();
        assert!(
            (res.total_loss - res.global_loss).abs() < 1e-5,
            "total={} global={}",
            res.total_loss,
            res.global_loss
        );
    }

    /// With `lambda_dense = 1` the total is exactly the dense term.
    #[test]
    fn lambda_one_gives_dense_only() {
        let hw = 4;
        let d = 8;
        let c = 8;
        let cfg = DenseCLConfig {
            lambda_dense: 1.0,
            ..Default::default()
        };
        let qg = rand_unit(1, d, 20, cfg.eps);
        let kg = rand_unit(1, d, 21, cfg.eps);
        let qd = rand_unit(hw, c, 22, cfg.eps);
        let kd = rand_unit(hw, c, 23, cfg.eps);
        let queue = rand_unit(8, d, 24, cfg.eps);
        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg).unwrap();
        assert!(
            (res.total_loss - res.dense_loss).abs() < 1e-5,
            "total={} dense={}",
            res.total_loss,
            res.dense_loss
        );
    }

    #[test]
    fn correspondence_map_length_equals_spatial_size() {
        let hw = 9;
        let c = 6;
        let qd = rand_unit(hw, c, 30, 1e-8);
        let kd = rand_unit(hw, c, 31, 1e-8);
        let corr = dense_correspondence(&qd, &kd, hw, c);
        assert_eq!(corr.len(), hw);
    }

    #[test]
    fn correspondence_indices_in_range() {
        let hw = 16;
        let c = 8;
        let qd = rand_unit(hw, c, 40, 1e-8);
        let kd = rand_unit(hw, c, 41, 1e-8);
        let corr = dense_correspondence(&qd, &kd, hw, c);
        for &idx in &corr {
            assert!(idx < hw, "index {idx} out of [0, {hw})");
        }
    }

    /// Dot products of unit vectors are cosines, so their mean lies in `[-1, 1]`.
    #[test]
    fn mean_correspondence_sim_in_range() {
        let hw = 6;
        let d = 4;
        let c = 4;
        let cfg = DenseCLConfig::default();
        let qg = rand_unit(1, d, 50, cfg.eps);
        let kg = rand_unit(1, d, 51, cfg.eps);
        let qd = rand_unit(hw, c, 52, cfg.eps);
        let kd = rand_unit(hw, c, 53, cfg.eps);
        let queue = rand_unit(4, d, 54, cfg.eps);
        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg).unwrap();
        assert!(
            res.mean_correspondence_sim >= -1.0 - 1e-5 && res.mean_correspondence_sim <= 1.0 + 1e-5,
            "mean_corr_sim = {}",
            res.mean_correspondence_sim
        );
    }

    /// When keys equal queries every position matches itself with similarity ~1.
    #[test]
    fn identical_query_key_max_correspondence() {
        let hw = 5;
        let d = 4;
        let c = 4;
        let cfg = DenseCLConfig {
            lambda_dense: 1.0,
            ..Default::default()
        };
        let qg = rand_unit(1, d, 60, cfg.eps);
        let kg = qg.clone();
        let qd = rand_unit(hw, c, 62, cfg.eps);
        let kd = qd.clone();
        let queue: Vec<f32> = vec![];
        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg).unwrap();
        assert!(
            res.mean_correspondence_sim > 0.99,
            "expected ~1.0, got {}",
            res.mean_correspondence_sim
        );
    }

    #[test]
    fn dense_infonce_finite_random() {
        let hw = 8;
        let c = 6;
        let batch = 2;
        let q = rand_unit(hw, c, 70, 1e-8);
        let pk = rand_unit(hw, c, 71, 1e-8);
        let all_q = rand_unit(hw * batch, c, 72, 1e-8);
        let loss = dense_infonce(&q, &pk, &all_q, hw, batch, c, 0.2).unwrap();
        assert!(loss.is_finite(), "loss = {loss}");
    }

    #[test]
    fn pixpro_loss_finite_and_bounded() {
        let hw = 6;
        let c = 8;
        let cfg = PixProConfig::default();
        let qd = rand_unit(hw, c, 80, cfg.eps);
        let kd = rand_unit(hw, c, 81, cfg.eps);
        let loss = pixpro_loss(&qd, &kd, hw, c, &cfg).unwrap();
        assert!(loss.is_finite(), "loss not finite");
        assert!(loss >= 0.0, "loss = {loss} < 0");
        assert!(loss <= 4.0, "loss = {loss} > 4");
    }

    /// A zero temperature must be rejected by both loss entry points.
    #[test]
    fn invalid_temperature_returns_error() {
        let hw = 4;
        let d = 4;
        let c = 4;
        let cfg = DenseCLConfig {
            temperature: 0.0,
            ..Default::default()
        };
        let qg = rand_unit(1, d, 90, 1e-8);
        let kg = rand_unit(1, d, 91, 1e-8);
        let qd = rand_unit(hw, c, 92, 1e-8);
        let kd = rand_unit(hw, c, 93, 1e-8);
        let queue = rand_unit(4, d, 94, 1e-8);
        assert!(dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg).is_err());
        let px_cfg = PixProConfig {
            temperature: 0.0,
            ..Default::default()
        };
        assert!(pixpro_loss(&qd, &kd, hw, c, &px_cfg).is_err());
    }

    #[test]
    fn single_spatial_position_works() {
        let hw = 1;
        let d = 8;
        let c = 8;
        let cfg = DenseCLConfig::default();
        let qg = rand_unit(1, d, 100, cfg.eps);
        let kg = rand_unit(1, d, 101, cfg.eps);
        let qd = rand_unit(hw, c, 102, cfg.eps);
        let kd = rand_unit(hw, c, 103, cfg.eps);
        let queue = rand_unit(4, d, 104, cfg.eps);
        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg).unwrap();
        assert!(res.total_loss.is_finite());
        assert_eq!(res.n_positions, 1);
        let px_cfg = PixProConfig::default();
        let pl = pixpro_loss(&qd, &kd, hw, c, &px_cfg).unwrap();
        assert!(pl.is_finite());
    }

    /// Both batch sizes must yield a finite, non-negative loss. The two losses
    /// are not compared with each other.
    #[test]
    fn larger_batch_size_more_negatives() {
        let hw = 4;
        let c = 6;
        let q = rand_unit(hw, c, 110, 1e-8);
        let pk = rand_unit(hw, c, 111, 1e-8);
        let batch_small = 1usize;
        let all_q_small = rand_unit(hw * batch_small, c, 112, 1e-8);
        let l_small = dense_infonce(&q, &pk, &all_q_small, hw, batch_small, c, 0.2).unwrap();
        let batch_large = 4usize;
        let all_q_large = rand_unit(hw * batch_large, c, 113, 1e-8);
        let l_large = dense_infonce(&q, &pk, &all_q_large, hw, batch_large, c, 0.2).unwrap();
        assert!(l_small.is_finite());
        assert!(l_large.is_finite());
        assert!(l_small >= 0.0);
        assert!(l_large >= 0.0);
    }

    /// The total must equal the lambda-weighted blend of the reported parts.
    #[test]
    fn linear_combination_matches_components() {
        let hw = 4;
        let d = 8;
        let c = 8;
        let cfg = DenseCLConfig {
            lambda_dense: 0.3,
            ..Default::default()
        };
        let qg = rand_unit(1, d, 120, cfg.eps);
        let kg = rand_unit(1, d, 121, cfg.eps);
        let qd = rand_unit(hw, c, 122, cfg.eps);
        let kd = rand_unit(hw, c, 123, cfg.eps);
        let queue = rand_unit(8, d, 124, cfg.eps);
        let res = dense_cl_loss(&qg, &kg, &qd, &kd, &queue, hw, d, c, &cfg).unwrap();
        let expected = 0.7 * res.global_loss + 0.3 * res.dense_loss;
        assert!(
            (res.total_loss - expected).abs() < 1e-5,
            "total={} expected={}",
            res.total_loss,
            expected
        );
    }

    #[test]
    fn pixpro_multi_iter_finite() {
        let hw = 8;
        let c = 6;
        let cfg = PixProConfig {
            temperature: 0.1,
            propagation_iters: 3,
            eps: 1e-8,
        };
        let qd = rand_unit(hw, c, 130, cfg.eps);
        let kd = rand_unit(hw, c, 131, cfg.eps);
        let loss = pixpro_loss(&qd, &kd, hw, c, &cfg).unwrap();
        assert!(loss.is_finite());
        assert!((0.0..=4.0).contains(&loss));
    }

    /// A dense query buffer that is one row short must be rejected.
    #[test]
    fn dimension_mismatch_detected() {
        let hw = 4;
        let d = 8;
        let c = 8;
        let cfg = DenseCLConfig::default();
        let qg = rand_unit(1, d, 140, cfg.eps);
        let kg = rand_unit(1, d, 141, cfg.eps);
        let qd_bad = rand_unit(hw - 1, c, 142, cfg.eps);
        let kd = rand_unit(hw, c, 143, cfg.eps);
        let queue = rand_unit(4, d, 144, cfg.eps);
        let res = dense_cl_loss(&qg, &kg, &qd_bad, &kd, &queue, hw, d, c, &cfg);
        assert!(res.is_err());
    }
}