/// Strategy for collapsing per-negative loss terms into the scalar a loss returns.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum Reduction {
/// Average the per-negative terms. This is the `Default` variant.
#[default]
Mean,
/// Add the per-negative terms together.
Sum,
/// No aggregation requested. Because the loss API returns a single `f32`,
/// each loss implementation decides how this variant maps onto a scalar.
None,
}
/// Common interface for losses evaluated on one anchor vector, one positive
/// vector and a slice of negative vectors.
///
/// Implementors must be `Send + Sync`.
pub trait Loss: Send + Sync {
/// Returns the scalar loss for the given anchor, positive and negatives.
fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32;
/// Returns the scalar loss together with a gradient vector.
///
/// The implementations visible in this file return one gradient entry per
/// anchor component (`anchor.len()` values).
fn compute_with_gradients(
&self,
anchor: &[f32],
positive: &[f32],
negatives: &[&[f32]],
) -> (f32, Vec<f32>);
}
/// InfoNCE-style contrastive loss over cosine similarities.
pub struct InfoNCELoss {
/// Divisor applied to every cosine similarity before normalisation;
/// the constructor keeps it at or above `0.01`.
temperature: f32,
}
impl InfoNCELoss {
    /// Builds an InfoNCE loss using `temperature` as the similarity divisor.
    ///
    /// Values below `0.01` are raised to `0.01`.
    pub fn new(temperature: f32) -> Self {
        let temperature = temperature.max(0.01);
        Self { temperature }
    }

    /// Cosine similarity between `a` and `b`.
    ///
    /// The dot product runs over the common prefix of the two slices. Each
    /// norm is floored at `1e-8`, so an all-zero vector yields a similarity
    /// of zero instead of a division by zero.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        /// Euclidean norm of `v`, never smaller than `1e-8`.
        fn floored_norm(v: &[f32]) -> f32 {
            v.iter().map(|c| c * c).sum::<f32>().sqrt().max(1e-8)
        }

        let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
        let len_a = floored_norm(a);
        let len_b = floored_norm(b);
        dot / (len_a * len_b)
    }
}
impl Loss for InfoNCELoss {
    /// InfoNCE loss: `-log(exp(s_pos) / (exp(s_pos) + sum_j exp(s_neg_j)))`,
    /// where every `s` is a cosine similarity to the anchor divided by the
    /// temperature.
    ///
    /// Evaluated as `logsumexp(all s) - s_pos` with the maximum similarity
    /// subtracted before exponentiating, so large similarities do not overflow.
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let pos_sim = Self::cosine_similarity(anchor, positive) / self.temperature;
        let neg_sims: Vec<f32> = negatives
            .iter()
            .map(|n| Self::cosine_similarity(anchor, n) / self.temperature)
            .collect();
        let max_sim = neg_sims
            .iter()
            .copied()
            .chain(std::iter::once(pos_sim))
            .fold(f32::NEG_INFINITY, f32::max);
        let sum_exp: f32 =
            neg_sims.iter().map(|s| (s - max_sim).exp()).sum::<f32>() + (pos_sim - max_sim).exp();
        let log_sum_exp = max_sim + sum_exp.ln();
        log_sum_exp - pos_sim
    }

    /// Returns the InfoNCE loss and its gradient with respect to `anchor`.
    ///
    /// The loss is computed with the same log-sum-exp expression as
    /// [`Loss::compute`], so both methods return the same value. Taking
    /// `-ln(softmax weight of the positive)` instead would yield `+inf`
    /// whenever the positive's exponential underflows to zero.
    ///
    /// The gradient has `anchor.len()` entries. Components are paired by
    /// position; if `positive` or a negative is shorter than `anchor`, the
    /// missing components contribute nothing, matching how the similarity
    /// itself is computed.
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let pos_sim = Self::cosine_similarity(anchor, positive) / self.temperature;
        let neg_sims: Vec<f32> = negatives
            .iter()
            .map(|n| Self::cosine_similarity(anchor, n) / self.temperature)
            .collect();
        let max_sim = neg_sims
            .iter()
            .copied()
            .chain(std::iter::once(pos_sim))
            .fold(f32::NEG_INFINITY, f32::max);

        // Shifted exponentials: the largest one is exactly 1, so the total is never zero.
        let pos_exp = (pos_sim - max_sim).exp();
        let neg_exps: Vec<f32> = neg_sims.iter().map(|s| (s - max_sim).exp()).collect();
        let total_exp: f32 = pos_exp + neg_exps.iter().sum::<f32>();
        let pos_weight = pos_exp / total_exp;

        // Stable form; stays finite even when `pos_weight` has underflowed to 0.
        let loss = max_sim + total_exp.ln() - pos_sim;

        let norm_a: f32 = anchor.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let mut gradients = vec![0.0f32; anchor.len()];

        // Adds `coeff * d cos(anchor, other) / d anchor / temperature` to the gradient.
        let mut accumulate = |other: &[f32], coeff: f32| {
            let norm_o: f32 = other.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
            let dot: f32 = anchor.iter().zip(other.iter()).map(|(a, o)| a * o).sum();
            let pair_norm = norm_a * norm_o;
            let cubic_norm = norm_a.powi(3) * norm_o;
            for ((g, &a), &o) in gradients.iter_mut().zip(anchor.iter()).zip(other.iter()) {
                let d_sim = (o / pair_norm) - (a * dot / cubic_norm);
                *g += coeff * d_sim / self.temperature;
            }
        };

        // d loss / d s_pos = softmax(pos) - 1; d loss / d s_neg_j = softmax(neg_j).
        accumulate(positive, pos_weight - 1.0);
        for (&neg, &neg_exp) in negatives.iter().zip(neg_exps.iter()) {
            accumulate(neg, neg_exp / total_exp);
        }

        (loss, gradients)
    }
}
/// Margin-based contrastive loss on Euclidean distances.
///
/// Each negative contributes
/// `max(0, d(anchor, positive) - d(anchor, negative) + margin)`.
pub struct LocalContrastiveLoss {
/// Margin added to the positive-minus-negative distance gap before clamping at zero.
margin: f32,
/// How the per-negative terms are combined into the returned scalar.
reduction: Reduction,
}
impl LocalContrastiveLoss {
    /// Creates a loss with the given `margin` and [`Reduction::Mean`].
    pub fn new(margin: f32) -> Self {
        let reduction = Reduction::Mean;
        Self { margin, reduction }
    }

    /// Builder-style setter: returns the loss with its reduction replaced.
    pub fn with_reduction(mut self, reduction: Reduction) -> Self {
        self.reduction = reduction;
        self
    }

    /// Euclidean distance between `a` and `b`, taken over their common prefix.
    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        let squared_gap: f32 = a
            .iter()
            .zip(b)
            .map(|(x, y)| {
                let diff = x - y;
                diff.powi(2)
            })
            .sum();
        squared_gap.sqrt()
    }
}
impl Loss for LocalContrastiveLoss {
    /// Margin loss over all negatives.
    ///
    /// Each negative contributes
    /// `max(0, d(anchor, positive) - d(anchor, negative) + margin)`.
    /// The terms are combined according to the configured reduction:
    /// * `Mean` – sum of the terms divided by the number of negatives
    ///   (`0.0` when there are none);
    /// * `Sum` – sum of the terms;
    /// * `None` – the term of the first negative, or `0.0` without negatives.
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let d_pos = Self::euclidean_distance(anchor, positive);
        let losses: Vec<f32> = negatives
            .iter()
            .map(|neg| {
                let d_neg = Self::euclidean_distance(anchor, neg);
                (d_pos - d_neg + self.margin).max(0.0)
            })
            .collect();
        match self.reduction {
            Reduction::Mean => losses.iter().sum::<f32>() / losses.len().max(1) as f32,
            Reduction::Sum => losses.iter().sum(),
            Reduction::None => losses.first().copied().unwrap_or(0.0),
        }
    }

    /// Returns the same loss value as [`Loss::compute`] together with its
    /// gradient with respect to `anchor`.
    ///
    /// The gradient is that of the returned value: under `Mean` it is divided
    /// by the number of negatives (not by the number of active terms), and
    /// under `None` only the first negative contributes. Distances at or
    /// below `1e-8` contribute no gradient, which avoids dividing by zero.
    ///
    /// The gradient has `anchor.len()` entries; components are paired by
    /// position, so a shorter `positive`/negative leaves the unmatched
    /// entries untouched instead of panicking.
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let d_pos = Self::euclidean_distance(anchor, positive);
        let mut gradients = vec![0.0f32; anchor.len()];

        // `None` reports only the first negative's term, so only that one may
        // contribute to the gradient.
        let considered = match self.reduction {
            Reduction::None => &negatives[..negatives.len().min(1)],
            Reduction::Mean | Reduction::Sum => negatives,
        };

        let mut total_loss = 0.0f32;
        for neg in considered {
            let d_neg = Self::euclidean_distance(anchor, neg);
            let margin_loss = d_pos - d_neg + self.margin;
            // Inactive (clamped) terms have zero loss and zero gradient.
            if margin_loss > 0.0 {
                total_loss += margin_loss;
                if d_pos > 1e-8 {
                    // d d_pos / d anchor = (anchor - positive) / d_pos
                    for ((g, a), p) in gradients.iter_mut().zip(anchor.iter()).zip(positive.iter())
                    {
                        *g += (a - p) / d_pos;
                    }
                }
                if d_neg > 1e-8 {
                    // d (-d_neg) / d anchor = -(anchor - neg) / d_neg
                    for ((g, a), n) in gradients.iter_mut().zip(anchor.iter()).zip(neg.iter()) {
                        *g -= (a - n) / d_neg;
                    }
                }
            }
        }

        let loss = match self.reduction {
            Reduction::Mean => {
                let count = negatives.len().max(1) as f32;
                gradients.iter_mut().for_each(|g| *g /= count);
                total_loss / count
            }
            Reduction::Sum | Reduction::None => total_loss,
        };
        (loss, gradients)
    }
}
/// Regulariser that penalises uneven per-dimension variance across a batch of vectors.
pub struct SpectralRegularization {
/// Multiplier applied to the computed penalty.
weight: f32,
}
impl SpectralRegularization {
    /// Creates a regulariser whose penalty is scaled by `weight`.
    pub fn new(weight: f32) -> Self {
        Self { weight }
    }

    /// Penalty for a batch of vectors: `weight` times the variance, across
    /// dimensions, of the per-dimension (population) variances.
    ///
    /// The dimensionality is taken from the first vector. Returns `0.0` for
    /// an empty batch and for zero-dimensional vectors (the latter would
    /// otherwise evaluate `0.0 / 0.0` and yield NaN).
    ///
    /// # Panics
    ///
    /// Panics if any vector is shorter than the first one.
    pub fn compute_batch(&self, embeddings: &[&[f32]]) -> f32 {
        let dim = match embeddings.first() {
            Some(first) if !first.is_empty() => first.len(),
            _ => return 0.0,
        };
        let n = embeddings.len() as f32;

        // Per-dimension population variance, computed once and reused below.
        let variances: Vec<f32> = (0..dim)
            .map(|d| {
                let mean = embeddings.iter().map(|e| e[d]).sum::<f32>() / n;
                embeddings
                    .iter()
                    .map(|e| (e[d] - mean).powi(2))
                    .sum::<f32>()
                    / n
            })
            .collect();

        let avg_var = variances.iter().sum::<f32>() / dim as f32;
        let var_of_var = variances
            .iter()
            .map(|v| (v - avg_var).powi(2))
            .sum::<f32>()
            / dim as f32;
        self.weight * var_of_var
    }
}
impl Loss for SpectralRegularization {
    /// Applies the batch penalty to the anchor, the positive and all
    /// negatives, treated as one batch in that order.
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let batch: Vec<&[f32]> = std::iter::once(anchor)
            .chain(std::iter::once(positive))
            .chain(negatives.iter().copied())
            .collect();
        self.compute_batch(&batch)
    }

    /// Returns the penalty from [`Loss::compute`] paired with a gradient of
    /// `anchor.len()` zeros; no gradient of the penalty is computed here.
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let penalty = self.compute(anchor, positive, negatives);
        let zero_gradient = vec![0.0f32; anchor.len()];
        (penalty, zero_gradient)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each owned vector as a slice so it can be passed as `negatives`.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(|row| row.as_slice()).collect()
    }

    /// InfoNCE on a near-aligned positive and two orthogonal negatives is non-negative.
    #[test]
    fn test_infonce_loss() {
        let criterion = InfoNCELoss::new(0.07);
        let anchor: Vec<f32> = vec![1.0, 0.0, 0.0];
        let positive: Vec<f32> = vec![0.9, 0.1, 0.0];
        let negatives: Vec<Vec<f32>> = vec![vec![0.0, 1.0, 0.0], vec![0.0, 0.0, 1.0]];
        let neg_refs = as_slices(&negatives);

        let value = criterion.compute(&anchor, &positive, &neg_refs);

        assert!(value >= 0.0);
    }

    /// The gradient variant returns a non-negative loss and one entry per anchor component.
    #[test]
    fn test_infonce_gradients() {
        let criterion = InfoNCELoss::new(0.1);
        let anchor: Vec<f32> = vec![0.5; 64];
        let positive: Vec<f32> = vec![0.6; 64];
        let negatives: Vec<Vec<f32>> = vec![vec![0.1; 64]; 5];
        let neg_refs = as_slices(&negatives);

        let (value, gradient) = criterion.compute_with_gradients(&anchor, &positive, &neg_refs);

        assert!(value >= 0.0);
        assert_eq!(gradient.len(), 64);
    }

    /// The margin loss is non-negative for a close positive and two distant negatives.
    #[test]
    fn test_local_contrastive() {
        let criterion = LocalContrastiveLoss::new(1.0);
        let anchor: Vec<f32> = vec![0.0, 0.0];
        let positive: Vec<f32> = vec![0.1, 0.0];
        let negatives: Vec<Vec<f32>> = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let neg_refs = as_slices(&negatives);

        let value = criterion.compute(&anchor, &positive, &neg_refs);

        assert!(value >= 0.0);
    }

    /// The batch penalty is non-negative for ten constant 32-dimensional vectors.
    #[test]
    fn test_spectral_regularization() {
        let regularizer = SpectralRegularization::new(0.01);
        let embeddings: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.1; 32]).collect();
        let emb_refs = as_slices(&embeddings);

        let penalty = regularizer.compute_batch(&emb_refs);

        assert!(penalty >= 0.0);
    }
}