sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! SigLIP sigmoid contrastive loss.
//!
//! SigLIP ("Sigmoid Loss for Language-Image Pre-Training", Zhai et al. 2023)
//! replaces the softmax contrastive loss used in CLIP with an independent
//! sigmoid binary cross-entropy over every entry in the similarity matrix.
//! This decouples the loss from the batch size and avoids the normalisation
//! instability of softmax.
//!
//! # Mathematical derivation
//!
//! Given a `(B, B)` similarity matrix `S` produced by
//! `SensorLMModel::similarity_matrix`:
//!
//! ```text
//! S[i,j] = temperature * dot(z_sensor[i], z_text[j]) + bias
//! ```
//!
//! Define the label matrix:
//!
//! ```text
//! y[i,j] = +1   if  i == j   (positive pair)
//! y[i,j] = -1   if  i != j   (negative pair)
//! ```
//!
//! The loss is the average binary cross-entropy over all `B²` pairs:
//!
//! ```text
//! L = -1/B² · Σ_ij log( sigmoid( y[i,j] · S[i,j] ) )
//! ```
//!
//! Because `log(sigmoid(-x)) = log(1 - sigmoid(x))`, this is equivalent to:
//!
//! ```text
//! For positive pairs (i==j): -log(sigmoid(S[i,j]))
//! For negative pairs (i!=j): -log(1 - sigmoid(S[i,j])) = -log(sigmoid(-S[i,j]))
//! ```
//!
//! # Initialisation bias
//!
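//! The reference implementation initialises the learnable bias to `b = −10`
//! (and the log-temperature to `log 10`).  Each row contains one positive
//! and `B − 1` negatives, so the labels are heavily skewed towards negatives;
//! a strongly negative initial bias keeps the starting predictions close to
//! that prior.  This prevents a large initial loss and stabilises training.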
//!
//! # Reference
//!
//! Zhai et al. (2023). _Sigmoid Loss for Language Image Pre-Training_.
//! <https://arxiv.org/abs/2303.15343>

use burn::tensor::{backend::Backend, Tensor};

// ---------------------------------------------------------------------------
// Core loss function
// ---------------------------------------------------------------------------

/// Compute the SigLIP sigmoid contrastive loss from a pre-computed similarity
/// matrix.
///
/// Every entry is treated as an independent binary classification.
/// Diagonal entries are positives (`y = +1`) and off-diagonal entries are
/// negatives (`y = -1`). The per-entry loss is
/// `-log(sigmoid(y * S)) = softplus(-y * S)`, averaged over all `B²` entries.
///
/// # Arguments
///
/// * `logits` – square `(B, B)` similarity matrix
///   `S[i,j] = temperature * dot(z_sensor[i], z_text[j]) + bias`, where row `i`
///   and column `i` form the positive pair.
///
/// # Returns
///
/// Scalar loss `Tensor<B, 1>` (the mean over all entries).
///
/// # Panics
///
/// Panics if `logits` is not square.
pub fn siglip_loss<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 1> {
    let [batch, cols] = logits.dims();
    // Positives are identified by index, so rows and columns must line up.
    // Without this check a mismatched shape could be broadcast against the
    // label matrix or rejected deep inside the backend with an unclear message.
    assert_eq!(
        batch, cols,
        "siglip_loss expects a square (B, B) logits matrix, got ({batch}, {cols})"
    );

    // Label matrix: y[i,j] = +1 if i == j, -1 otherwise.
    let labels = eye_pm1::<B>(batch, logits.device());

    // y[i,j] * S[i,j]: large positive values mean a confident, correct prediction.
    let signed_logits = labels * logits;

    // -log(sigmoid(y*S)) = softplus(-y*S).
    let per_element_loss = softplus(signed_logits.neg());

    // Average over the full (B, B) matrix.
    per_element_loss.mean()
}

/// Numerically stable element-wise softplus, `log(1 + exp(x))`.
///
/// Evaluated as `max(x, 0) + log1p(exp(-|x|))`, which is algebraically equal
/// to `log(1 + exp(x))` for every `x`:
///
/// * The argument passed to `exp` is never positive, so it cannot overflow.
/// * `log1p` keeps precision when `exp(-|x|)` is tiny.
/// * The input is not clamped, so the gradient (`sigmoid(x)`) stays
///   non-zero for large `|x|`.
fn softplus<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    // Computing `exp(x)` directly overflows f32 to +inf once x exceeds ~88.7,
    // which would make the loss infinite for confidently wrong predictions.
    let positive_part = x.clone().clamp_min(0.0f32);
    let tail = x.abs().neg().exp().log1p();
    positive_part + tail
}

/// Build the `(n, n)` SigLIP label matrix on `device`: `+1` on the diagonal,
/// `-1` everywhere else.
fn eye_pm1<B: Backend>(n: usize, device: B::Device) -> Tensor<B, 2> {
    // 2·I − 1 maps identity entries {1, 0} to labels {+1, −1}.
    eye_float::<B>(n, &device)
        .mul_scalar(2.0f32)
        .sub_scalar(1.0f32)
}

/// Build an `(n, n)` float identity matrix on `device`.
fn eye_float<B: Backend>(n: usize, device: &B::Device) -> Tensor<B, 2> {
    // Row-major buffer: entry (r, c) is at index r * n + c, so the diagonal
    // entries are every (n + 1)-th element starting from index 0.
    let mut buffer = vec![0.0f32; n * n];
    buffer
        .iter_mut()
        .step_by(n + 1)
        .for_each(|entry| *entry = 1.0f32);
    Tensor::<B, 1>::from_floats(buffer.as_slice(), device).reshape([n, n])
}

// ---------------------------------------------------------------------------
// Loss variant: symmetric (sensor↔text and text↔sensor averaged)
// ---------------------------------------------------------------------------

/// Symmetric version of the SigLIP loss.
///
/// Averages the sensor-to-text and text-to-sensor contrastive directions:
///
/// ```text
/// L_sym = (L(S) + L(Sᵀ)) / 2
/// ```
///
/// [`siglip_loss`] uses a symmetric ±1 label matrix and averages over every
/// entry, so transposing `S` only permutes the summands. Hence
/// `L(Sᵀ) = L(S)` for *any* square `S`, up to floating-point summation order,
/// even though `S` itself is generally not symmetric. `L_sym` therefore
/// equals [`siglip_loss`]; it is provided for API completeness and parity
/// with CLIP-style symmetric losses.
///
/// # Panics
///
/// Panics if `logits` is not square (see [`siglip_loss`]).
pub fn siglip_loss_symmetric<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 1> {
    let forward = siglip_loss(logits.clone());
    let backward = siglip_loss(logits.transpose());
    (forward + backward).div_scalar(2.0f32)
}

// ---------------------------------------------------------------------------
// Evaluation metrics (no grad)
// ---------------------------------------------------------------------------

/// Compute top-k recall (Recall@k) for cross-modal retrieval.
///
/// Row `i` of `logits` holds the scores of query `i` against every candidate,
/// and its ground-truth positive is column `i`. A query counts as a hit when
/// fewer than `k` candidates score strictly higher than the positive. Ties
/// are therefore resolved in the positive's favour.
///
/// # Arguments
///
/// * `logits` – `(Q, C)` score matrix with `Q <= C` (usually square `(B, B)`).
/// * `k` – cut-off rank; `k = 0` always yields `0.0`.
///
/// # Returns
///
/// Recall@k as a float in `[0, 1]`, or `0.0` for an empty matrix.
///
/// # Panics
///
/// Panics if `logits` has more rows than columns, because some query would
/// then have no ground-truth column.
pub fn recall_at_k<B: Backend>(logits: Tensor<B, 2>, k: usize) -> f32 {
    let [rows, cols] = logits.dims();
    assert!(
        rows <= cols,
        "recall_at_k: need at least as many columns as rows, got ({rows}, {cols})"
    );

    // Convert first: `to_vec::<f32>` fails when the backend's float element
    // is not f32 (e.g. f16/f64). Silently defaulting to an empty Vec would
    // only defer the failure to an out-of-bounds panic.
    let scores: Vec<f32> = logits
        .into_data()
        .convert::<f32>()
        .to_vec::<f32>()
        .expect("tensor data was converted to f32 just above");

    recall_from_scores(&scores, cols, k)
}

/// Recall@k over a row-major score buffer with `cols` columns per row, where
/// the positive for row `i` is column `i`.
fn recall_from_scores(scores: &[f32], cols: usize, k: usize) -> f32 {
    // `chunks_exact(0)` panics, and `cols == 0` implies there are no rows.
    if cols == 0 {
        return 0.0;
    }
    let rows = scores.len() / cols;
    if rows == 0 {
        return 0.0;
    }

    let hits = scores
        .chunks_exact(cols)
        .enumerate()
        .filter(|&(i, row)| {
            let positive = row[i];
            // 0-indexed rank of the positive among this row's candidates.
            let rank = row.iter().filter(|&&s| s > positive).count();
            rank < k
        })
        .count();

    hits as f32 / rows as f32
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::backend::ndarray::NdArrayDevice;

    type B = NdArray;

    /// Build an `(n, n)` logits tensor whose entry `(i, j)` is `score(i, j)`.
    fn square_logits(n: usize, score: impl Fn(usize, usize) -> f32) -> Tensor<B, 2> {
        let device = NdArrayDevice::default();
        let data: Vec<f32> = (0..n * n).map(|idx| score(idx / n, idx % n)).collect();
        Tensor::<B, 1>::from_floats(data.as_slice(), &device).reshape([n, n])
    }

    #[test]
    fn test_siglip_loss_perfect() {
        // With a very large diagonal and very small off-diagonal, loss → 0.
        let logits = square_logits(4, |i, j| if i == j { 100.0 } else { -100.0 });
        let loss: f32 = siglip_loss(logits).into_scalar();
        assert!(loss < 0.01, "Near-perfect logits should give small loss, got {loss}");
    }

    #[test]
    fn test_siglip_loss_random() {
        let logits = square_logits(4, |i, j| (i * 4 + j) as f32 * 0.1);
        let loss: f32 = siglip_loss(logits).into_scalar();
        assert!(loss > 0.0, "Loss must be positive for random logits");
        assert!(!loss.is_nan(), "Loss must not be NaN");
    }

    #[test]
    fn test_siglip_loss_symmetric_matches_forward() {
        // The labels are symmetric and the loss is a full mean, so transposing
        // a non-symmetric S must not change the loss.
        let logits = square_logits(4, |i, j| (i * 4 + j) as f32 * 0.1);
        let forward: f32 = siglip_loss(logits.clone()).into_scalar();
        let symmetric: f32 = siglip_loss_symmetric(logits).into_scalar();
        assert!(
            (forward - symmetric).abs() < 1e-6,
            "Symmetric loss {symmetric} should equal forward loss {forward}"
        );
    }

    #[test]
    fn test_recall_at_k() {
        // Perfect similarity matrix.
        let logits = square_logits(4, |i, j| if i == j { 1.0 } else { 0.0 });
        let r1 = recall_at_k(logits, 1);
        assert!((r1 - 1.0).abs() < 1e-5, "Perfect logits → Recall@1 = 1.0, got {r1}");
    }

    #[test]
    fn test_recall_at_k_ranking() {
        // Row 0's positive (0.5) is beaten by one candidate (0.9): it misses
        // at k = 1 but hits at k = 2. All other rows are perfect.
        let logits = square_logits(4, |i, j| match (i, j) {
            (0, 0) => 0.5,
            (0, 1) => 0.9,
            (0, 2) => 0.1,
            (r, c) if r == c => 1.0,
            _ => 0.0,
        });
        let r0 = recall_at_k(logits.clone(), 0);
        let r1 = recall_at_k(logits.clone(), 1);
        let r2 = recall_at_k(logits, 2);
        assert!(r0.abs() < 1e-6, "Recall@0 must be 0, got {r0}");
        assert!((r1 - 0.75).abs() < 1e-6, "Expected Recall@1 = 0.75, got {r1}");
        assert!((r2 - 1.0).abs() < 1e-6, "Expected Recall@2 = 1.0, got {r2}");
    }

    #[test]
    fn test_eye_pm1() {
        let device = NdArrayDevice::default();
        let labels = eye_pm1::<B>(3, device);
        let data: Vec<f32> = labels.into_data().to_vec::<f32>().unwrap();
        assert_eq!(data.len(), 9);
        // Diagonal must be +1, off-diagonal -1 — check every entry.
        for (idx, &value) in data.iter().enumerate() {
            let (row, col) = (idx / 3, idx % 3);
            let expected = if row == col { 1.0 } else { -1.0 };
            assert_eq!(value, expected, "entry ({row}, {col})");
        }
    }
}