use burn::tensor::{backend::Backend, Tensor};
/// Pairwise sigmoid (SigLIP-style) loss over a square logits matrix.
///
/// Entry `(i, j)` of `logits` is scored against a label of `+1` when `i == j`
/// and `-1` otherwise. The per-entry loss is `softplus(-label * logit)`, and the
/// result is the mean over all `batch * batch` entries, returned as a
/// one-element tensor.
///
/// # Panics
///
/// Panics if `logits` is not square. The label matrix is always
/// `[batch, batch]`, so a non-square input would otherwise fail inside the
/// element-wise multiply or — presumably, when one dimension is 1 — be
/// broadcast silently and produce a meaningless loss.
pub fn siglip_loss<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 1> {
    let [batch, cols] = logits.dims();
    assert_eq!(
        batch, cols,
        "siglip_loss expects a square [batch, batch] logits matrix, got [{batch}, {cols}]"
    );

    let labels = eye_pm1::<B>(batch, logits.device());
    // label * logit: large and positive when the prediction agrees with the label.
    let signed_logits = labels * logits;
    // softplus(-y * s) == -ln(sigmoid(y * s)).
    softplus(signed_logits.neg()).mean()
}
/// Numerically stable element-wise softplus, `ln(1 + e^x)`.
///
/// Computed as `max(x, 0) + ln(1 + e^(-|x|))`. The exponential's argument is
/// never positive, so it cannot overflow: a direct `exp(x)` already becomes
/// infinite in `f32` for `x` above roughly 88.7, which would turn the whole
/// loss into `inf`. Using `log1p` also keeps precision for very negative `x`,
/// where `1 + e^x` would round to exactly 1.
fn softplus<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    // max(x, 0): the dominant term for positive inputs.
    let positive_part = x.clone().clamp_min(0.0f32);
    // ln(1 + e^(-|x|)): a bounded correction in (0, ln 2].
    let correction = x.abs().neg().exp().log1p();
    positive_part + correction
}
/// Builds the `n`×`n` label matrix: `+1` on the diagonal, `-1` everywhere else.
///
/// Computed as `2 * I + (-1)`, where `I` is the identity from `eye_float`.
fn eye_pm1<B: Backend>(n: usize, device: B::Device) -> Tensor<B, 2> {
    let identity = eye_float::<B>(n, &device);
    let background = Tensor::<B, 2>::full([n, n], -1.0f32, &device);
    identity.mul_scalar(2.0f32) + background
}
/// Builds an `n`×`n` float identity matrix on `device`.
///
/// The values are assembled in a row-major host buffer, uploaded as a 1-D
/// tensor, and reshaped to `[n, n]`.
fn eye_float<B: Backend>(n: usize, device: &B::Device) -> Tensor<B, 2> {
    let mut cells = vec![0.0f32; n * n];
    // In row-major order the diagonal entries sit `n + 1` slots apart.
    cells.iter_mut().step_by(n + 1).for_each(|cell| *cell = 1.0f32);
    Tensor::<B, 1>::from_floats(cells.as_slice(), device).reshape([n, n])
}
/// Averages `siglip_loss` evaluated on `logits` and on its transpose.
///
/// Returns a one-element tensor holding `(loss(logits) + loss(logitsᵀ)) / 2`.
///
/// NOTE(review): the label matrix used by `siglip_loss` is symmetric and the
/// reduction is a mean over every entry, so the two terms are mathematically
/// equal; confirm whether a direction-specific loss was intended.
pub fn siglip_loss_symmetric<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 1> {
    let transposed = logits.clone().transpose();
    let forward = siglip_loss(logits);
    let backward = siglip_loss(transposed);
    (forward + backward).div_scalar(2.0f32)
}
/// Fraction of rows whose diagonal ("ground truth") score ranks within the top `k`.
///
/// Row `i` counts as a hit when fewer than `k` entries in that row are
/// strictly greater than `logits[i][i]`; ties therefore resolve in favour of
/// the ground truth. The hit count is divided by the number of rows.
///
/// Rows without a diagonal entry (row index ≥ number of columns) count as
/// misses. An empty tensor yields `0.0` rather than `NaN`.
///
/// # Panics
///
/// Panics if the tensor data cannot be read back as `f32` after conversion,
/// which would indicate a backend bug rather than bad input.
pub fn recall_at_k<B: Backend>(logits: Tensor<B, 2>, k: usize) -> f32 {
    let [batch, cols] = logits.dims();
    if batch == 0 || cols == 0 {
        // Avoid 0/0 and zero-length chunking.
        return 0.0;
    }

    // Convert explicitly so backends with a non-f32 float element still work,
    // instead of silently falling back to an empty buffer.
    let data: Vec<f32> = logits
        .into_data()
        .convert::<f32>()
        .to_vec::<f32>()
        .expect("tensor data converted to f32 must be readable as f32");

    // Rows are `cols` wide in the row-major buffer, which may differ from `batch`.
    let correct = data
        .chunks_exact(cols)
        .enumerate()
        .filter(|&(i, row)| {
            row.get(i).is_some_and(|&gt_score| {
                let rank = row.iter().filter(|&&s| s > gt_score).count();
                rank < k
            })
        })
        .count();

    correct as f32 / batch as f32
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::ndarray::NdArrayDevice;
    use burn::backend::NdArray;

    type B = NdArray;

    /// Builds an `n`×`n` tensor holding `on_diag` on the diagonal and `off_diag` elsewhere.
    fn diagonal_matrix(
        n: usize,
        on_diag: f32,
        off_diag: f32,
        device: &NdArrayDevice,
    ) -> Tensor<B, 2> {
        let mut values = vec![off_diag; n * n];
        for i in 0..n {
            values[i * n + i] = on_diag;
        }
        Tensor::<B, 1>::from_floats(values.as_slice(), device).reshape([n, n])
    }

    /// Strongly separated logits (±100) should drive the loss close to zero.
    #[test]
    fn test_siglip_loss_perfect() {
        let device = NdArrayDevice::default();
        let logits = diagonal_matrix(4, 100.0f32, -100.0f32, &device);

        let loss: f32 = siglip_loss(logits).into_scalar();

        assert!(loss < 0.01, "Near-perfect logits should give small loss, got {loss}");
    }

    /// A fixed ramp of uninformative logits must give a finite, positive loss.
    #[test]
    fn test_siglip_loss_random() {
        let device = NdArrayDevice::default();
        let ramp: Vec<f32> = (0..16).map(|i| i as f32 * 0.1).collect();
        let logits = Tensor::<B, 1>::from_floats(ramp.as_slice(), &device).reshape([4, 4]);

        let loss: f32 = siglip_loss(logits).into_scalar();

        assert!(loss > 0.0, "Loss must be positive for random logits");
        assert!(!loss.is_nan(), "Loss must not be NaN");
    }

    /// An identity logits matrix ranks every ground truth first.
    #[test]
    fn test_recall_at_k() {
        let device = NdArrayDevice::default();
        let logits = diagonal_matrix(4, 1.0f32, 0.0f32, &device);

        let r1 = recall_at_k(logits, 1);

        assert!((r1 - 1.0).abs() < 1e-5, "Perfect logits → Recall@1 = 1.0, got {r1}");
    }

    /// Spot-checks the label matrix: diagonal entries are +1, off-diagonal -1.
    #[test]
    fn test_eye_pm1() {
        let device = NdArrayDevice::default();
        let labels = eye_pm1::<B>(3, device);

        let data: Vec<f32> = labels.into_data().to_vec::<f32>().unwrap();

        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], -1.0);
        assert_eq!(data[4], 1.0);
    }
}