/// InfoNCE contrastive loss for a single anchor.
///
/// The anchor is scored against one positive and any number of negatives.
/// Each score is `cosine_similarity(anchor, x) / temperature`. The result is the
/// negative log-softmax probability assigned to the positive:
///
/// `loss = -s_pos + ln(exp(s_pos) + Σ exp(s_neg))`
///
/// The log-sum-exp subtracts the maximum score before exponentiating. This stops
/// small temperatures from overflowing `exp`.
///
/// With no negatives the positive is the only candidate. For a finite similarity
/// the loss is then `0.0`.
///
/// # Panics
///
/// Panics if `positive` or any negative has a different length from `anchor`.
/// Also panics if `temperature` is not strictly positive (including NaN).
#[must_use]
pub fn info_nce_loss(
    anchor: &Vector<f32>,
    positive: &Vector<f32>,
    negatives: &[Vector<f32>],
    temperature: f32,
) -> f32 {
    assert_eq!(
        anchor.len(),
        positive.len(),
        "Anchor and positive must have same dimension"
    );
    for neg in negatives {
        assert_eq!(
            anchor.len(),
            neg.len(),
            "All embeddings must have same dimension"
        );
    }
    assert!(temperature > 0.0, "Temperature must be positive");

    let sim_pos = cosine_similarity(anchor, positive) / temperature;
    // Score every negative exactly once. Both the max pass and the exp-sum pass
    // need these values, and recomputing them doubled the similarity work.
    let sim_negs: Vec<f32> = negatives
        .iter()
        .map(|neg| cosine_similarity(anchor, neg) / temperature)
        .collect();

    // Shift by the largest score so every exponent is <= 0 (numerical stability).
    let max_sim = sim_negs.iter().copied().fold(sim_pos, f32::max);
    let sum_exp = sim_negs
        .iter()
        .fold((sim_pos - max_sim).exp(), |acc, &s| acc + (s - max_sim).exp());

    -sim_pos + max_sim + sum_exp.ln()
}
/// Binary focal loss, averaged over all elements.
///
/// Each prediction is clamped to `[1e-7, 1 - 1e-7]` before its logarithm is taken.
/// A target greater than `0.5` counts as the positive class and contributes
/// `-alpha * (1 - p)^gamma * ln(p)`. Any other target contributes
/// `-(1 - alpha) * p^gamma * ln(1 - p)`.
///
/// The mean divides by the element count, so empty inputs yield NaN.
///
/// # Panics
///
/// Panics if `predictions` and `targets` have different lengths.
#[must_use]
pub fn focal_loss(predictions: &Vector<f32>, targets: &Vector<f32>, alpha: f32, gamma: f32) -> f32 {
    assert_eq!(
        predictions.len(),
        targets.len(),
        "Predictions and targets must have same length"
    );
    let count = predictions.len();
    let total = (0..count).fold(0.0_f32, |acc, idx| {
        // Keep the probability strictly inside (0, 1) so both logs stay finite.
        let prob = predictions[idx].clamp(1e-7, 1.0 - 1e-7);
        let term = if targets[idx] > 0.5 {
            -alpha * (1.0 - prob).powf(gamma) * prob.ln()
        } else {
            -(1.0 - alpha) * prob.powf(gamma) * (1.0 - prob).ln()
        };
        acc + term
    });
    total / count as f32
}
/// Kullback–Leibler divergence `KL(p || q) = Σ p_i · ln(p_i / q_i)`.
///
/// Terms where `p_i` is not greater than `1e-10` (including NaN) are skipped.
/// Each `q_i` is floored at `1e-10` so the ratio never divides by zero.
/// The inputs are not normalised or otherwise validated as distributions.
///
/// # Panics
///
/// Panics if `p` and `q` have different lengths.
#[must_use]
pub fn kl_divergence(p: &Vector<f32>, q: &Vector<f32>) -> f32 {
    assert_eq!(p.len(), q.len(), "Distributions must have same length");
    (0..p.len())
        .filter(|&idx| p[idx] > 1e-10)
        .fold(0.0_f32, |acc, idx| {
            let p_i = p[idx];
            let q_floor = q[idx].max(1e-10);
            acc + p_i * (p_i / q_floor).ln()
        })
}
/// Configuration wrapper around [`triplet_loss`] that remembers a fixed margin.
#[derive(Debug, Clone, Copy)]
pub struct TripletLoss {
    /// Margin forwarded to [`triplet_loss`] on every call.
    margin: f32,
}

impl TripletLoss {
    /// Creates a wrapper that will call [`triplet_loss`] with `margin`.
    #[must_use]
    pub fn new(margin: f32) -> Self {
        TripletLoss { margin }
    }

    /// Returns the margin supplied to [`TripletLoss::new`].
    #[must_use]
    pub fn margin(&self) -> f32 {
        let Self { margin } = *self;
        margin
    }

    /// Evaluates [`triplet_loss`] on one `(anchor, positive, negative)` triple
    /// using the stored margin.
    #[must_use]
    pub fn compute_triplet(
        &self,
        anchor: &Vector<f32>,
        positive: &Vector<f32>,
        negative: &Vector<f32>,
    ) -> f32 {
        let margin = self.margin();
        triplet_loss(anchor, positive, negative, margin)
    }
}
/// Focal loss with fixed `alpha` and `gamma`, usable through the [`Loss`] trait.
///
/// The parameters are forwarded unchanged to [`focal_loss`].
#[derive(Debug, Clone, Copy)]
pub struct FocalLoss {
    /// Weight on positive-class terms; `1 - alpha` weights the other terms.
    alpha: f32,
    /// Exponent applied to the modulating factor in each term.
    gamma: f32,
}

impl FocalLoss {
    /// Creates a focal loss with the given `alpha` and `gamma`.
    #[must_use]
    pub fn new(alpha: f32, gamma: f32) -> Self {
        FocalLoss { alpha, gamma }
    }

    /// Returns the stored `alpha`.
    #[must_use]
    pub fn alpha(&self) -> f32 {
        let Self { alpha, .. } = *self;
        alpha
    }

    /// Returns the stored `gamma`.
    #[must_use]
    pub fn gamma(&self) -> f32 {
        let Self { gamma, .. } = *self;
        gamma
    }
}

impl Loss for FocalLoss {
    fn compute(&self, y_pred: &Vector<f32>, y_true: &Vector<f32>) -> f32 {
        let Self { alpha, gamma } = *self;
        focal_loss(y_pred, y_true, alpha, gamma)
    }

    fn name(&self) -> &'static str {
        "Focal"
    }
}
/// InfoNCE contrastive loss with a fixed temperature.
///
/// The temperature is forwarded unchanged to [`info_nce_loss`]. It is not
/// validated at construction time.
#[derive(Debug, Clone, Copy)]
pub struct InfoNCELoss {
    /// Divisor applied to every similarity score by [`info_nce_loss`].
    temperature: f32,
}

impl InfoNCELoss {
    /// Creates an InfoNCE loss using `temperature`.
    #[must_use]
    pub fn new(temperature: f32) -> Self {
        InfoNCELoss { temperature }
    }

    /// Returns the stored temperature.
    #[must_use]
    pub fn temperature(&self) -> f32 {
        let Self { temperature } = *self;
        temperature
    }

    /// Evaluates [`info_nce_loss`] for one anchor, its positive, and a set of
    /// negatives, using the stored temperature.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`info_nce_loss`].
    #[must_use]
    pub fn compute_contrastive(
        &self,
        anchor: &Vector<f32>,
        positive: &Vector<f32>,
        negatives: &[Vector<f32>],
    ) -> f32 {
        let temperature = self.temperature();
        info_nce_loss(anchor, positive, negatives, temperature)
    }
}
/// Dice loss: `1 - (2·Σ(pred·true) + smooth) / (Σpred + Σtrue + smooth)`.
///
/// `smooth` is added to both the numerator and the denominator of the Dice ratio.
/// With `smooth == 0.0` and both inputs summing to zero, the ratio is 0/0 and the
/// result is NaN.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` have different lengths.
#[must_use]
pub fn dice_loss(y_pred: &Vector<f32>, y_true: &Vector<f32>, smooth: f32) -> f32 {
    // Give this length check a message, matching the sibling loss functions.
    assert_eq!(
        y_pred.len(),
        y_true.len(),
        "Predictions and targets must have same length"
    );
    let mut intersection = 0.0;
    let mut pred_sum = 0.0;
    let mut true_sum = 0.0;
    for i in 0..y_pred.len() {
        let (pred, truth) = (y_pred[i], y_true[i]);
        intersection += pred * truth;
        pred_sum += pred;
        true_sum += truth;
    }
    let dice = (2.0 * intersection + smooth) / (pred_sum + true_sum + smooth);
    1.0 - dice
}
/// Mean hinge loss: `mean(max(0, margin - y_true · y_pred))`.
///
/// Target labels are not validated. Hinge loss is conventionally used with
/// labels in `{-1, +1}`; other encodings are the caller's responsibility.
/// The mean divides by the element count, so empty inputs yield NaN.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` have different lengths.
#[must_use]
pub fn hinge_loss(y_pred: &Vector<f32>, y_true: &Vector<f32>, margin: f32) -> f32 {
    // Give this length check a message, matching the sibling loss functions.
    assert_eq!(
        y_pred.len(),
        y_true.len(),
        "Predictions and targets must have same length"
    );
    let sum: f32 = (0..y_pred.len())
        .map(|i| (margin - y_true[i] * y_pred[i]).max(0.0))
        .fold(0.0, |acc, loss| acc + loss);
    sum / y_pred.len() as f32
}
/// Mean squared hinge loss: `mean(max(0, margin - y_true · y_pred)²)`.
///
/// This matches `hinge_loss` except that each clipped margin violation is squared
/// before averaging. Target labels are not validated. The mean divides by the
/// element count, so empty inputs yield NaN.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` have different lengths.
#[must_use]
pub fn squared_hinge_loss(y_pred: &Vector<f32>, y_true: &Vector<f32>, margin: f32) -> f32 {
    // Give this length check a message, matching the sibling loss functions.
    assert_eq!(
        y_pred.len(),
        y_true.len(),
        "Predictions and targets must have same length"
    );
    let sum: f32 = (0..y_pred.len())
        .map(|i| (margin - y_true[i] * y_pred[i]).max(0.0))
        .fold(0.0, |acc, loss| acc + loss * loss);
    sum / y_pred.len() as f32
}
/// Dice loss with a fixed smoothing term, usable through the [`Loss`] trait.
///
/// The smoothing term is forwarded unchanged to [`dice_loss`].
#[derive(Debug, Clone, Copy)]
pub struct DiceLoss {
    /// Constant added to both numerator and denominator of the Dice ratio.
    smooth: f32,
}

impl DiceLoss {
    /// Creates a Dice loss with the given `smooth` term.
    #[must_use]
    pub fn new(smooth: f32) -> Self {
        DiceLoss { smooth }
    }

    /// Returns the stored smoothing term.
    #[must_use]
    pub fn smooth(&self) -> f32 {
        let Self { smooth } = *self;
        smooth
    }
}

impl Loss for DiceLoss {
    fn compute(&self, y_pred: &Vector<f32>, y_true: &Vector<f32>) -> f32 {
        let smooth = self.smooth();
        dice_loss(y_pred, y_true, smooth)
    }

    fn name(&self) -> &'static str {
        "Dice"
    }
}
/// Hinge loss with a fixed margin, usable through the [`Loss`] trait.
///
/// The margin is forwarded unchanged to [`hinge_loss`].
#[derive(Debug, Clone, Copy)]
pub struct HingeLoss {
    /// Margin that `y_true · y_pred` must reach before an element stops contributing.
    margin: f32,
}

impl HingeLoss {
    /// Creates a hinge loss with the given `margin`.
    #[must_use]
    pub fn new(margin: f32) -> Self {
        HingeLoss { margin }
    }

    /// Returns the stored margin.
    #[must_use]
    pub fn margin(&self) -> f32 {
        let Self { margin } = *self;
        margin
    }
}

impl Loss for HingeLoss {
    fn compute(&self, y_pred: &Vector<f32>, y_true: &Vector<f32>) -> f32 {
        let margin = self.margin();
        hinge_loss(y_pred, y_true, margin)
    }

    fn name(&self) -> &'static str {
        "Hinge"
    }
}
/// CTC loss configuration.
///
/// NOTE(review): "CTC" presumably means Connectionist Temporal Classification;
/// the loss computation lives outside this struct definition, so confirm there.
/// Unlike the other loss wrappers here, this type derives `Clone` but not `Copy`.
#[derive(Debug, Clone)]
pub struct CTCLoss {
    /// Presumably the index of the blank label in the output alphabet.
    /// TODO(review): inferred from the field name only; confirm against the impl.
    blank_idx: usize,
}