use crate::error::{CnnError, CnnResult};
use serde::{Deserialize, Serialize};
/// Distance metric used by [`TripletLoss`] to compare two embedding vectors.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum TripletDistance {
    /// Euclidean (L2) distance: `sqrt(sum_i (a_i - b_i)^2)`.
    Euclidean,
    /// Squared Euclidean distance: `sum_i (a_i - b_i)^2`.
    /// This is the metric selected by `TripletLoss::new`.
    SquaredEuclidean,
    /// Cosine distance: `1 - dot(a, b) / (|a| * |b|)`.
    /// When the product of the two norms is below `1e-8` the distance is reported as `1.0`.
    Cosine,
}
/// Detailed outcome of evaluating a single (anchor, positive, negative) triplet,
/// as produced by `TripletLoss::forward_detailed`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TripletResult {
    /// Final loss for the triplet, including the L2 penalty when regularization is enabled.
    pub loss: f64,
    /// Distance between the anchor and the positive under the configured metric.
    pub positive_distance: f64,
    /// Distance between the anchor and the negative under the configured metric.
    pub negative_distance: f64,
    /// `true` when `positive_distance - negative_distance + margin > 0`.
    pub is_hard: bool,
    /// `true` when `positive_distance + margin > negative_distance`.
    ///
    /// NOTE(review): this is algebraically the same condition as `is_hard` (the two can
    /// only differ through floating-point rounding). Confirm whether `is_hard` was meant
    /// to use a stricter test such as `negative_distance < positive_distance`.
    pub violates_margin: bool,
}
/// Triplet margin loss over embedding vectors.
///
/// For an (anchor `a`, positive `p`, negative `n`) triplet the loss is
/// `h(d(a, p) - d(a, n) + margin)`. Here `h` is the hinge `max(0, x)`, or the
/// softplus `soft_relu` when soft margin is enabled. An optional L2 penalty on
/// the three embeddings is added on top.
///
/// NOTE(review): the non-negative margin is only enforced by `TripletLoss::new`;
/// values produced through `Deserialize` are not validated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TripletLoss {
    /// Required gap between the negative and positive distances.
    margin: f64,
    /// Metric used for `d(., .)`.
    distance: TripletDistance,
    /// Whether to use `soft_relu` instead of the hard hinge `max(0, x)`.
    soft_margin: bool,
    /// Weight of the mean squared-L2-norm penalty over the three embeddings; `None` disables it.
    l2_regularization: Option<f64>,
}
impl TripletLoss {
    /// Creates a triplet loss with the given `margin`.
    ///
    /// The loss uses squared Euclidean distance, the hard hinge `max(0, x)`
    /// and no L2 regularization.
    ///
    /// # Panics
    ///
    /// Panics if `margin` is negative or NaN.
    pub fn new(margin: f64) -> Self {
        assert!(margin >= 0.0, "Margin must be non-negative");
        Self {
            margin,
            distance: TripletDistance::SquaredEuclidean,
            soft_margin: false,
            l2_regularization: None,
        }
    }

    /// Returns this loss configured to measure embeddings with `distance`.
    pub fn with_distance(mut self, distance: TripletDistance) -> Self {
        self.distance = distance;
        self
    }

    /// Returns this loss configured to apply the smooth `soft_relu` (softplus)
    /// instead of the hard hinge `max(0, x)`.
    pub fn with_soft_margin(mut self) -> Self {
        self.soft_margin = true;
        self
    }

    /// Returns this loss configured to add
    /// `weight * (|anchor|^2 + |positive|^2 + |negative|^2) / 3` to every
    /// per-triplet loss.
    ///
    /// # Panics
    ///
    /// Panics if `weight` is negative or NaN, mirroring the margin check in
    /// [`TripletLoss::new`]. A negative weight would reward large embedding norms.
    pub fn with_l2_regularization(mut self, weight: f64) -> Self {
        assert!(weight >= 0.0, "L2 regularization weight must be non-negative");
        self.l2_regularization = Some(weight);
        self
    }

    /// Returns the configured margin.
    pub fn margin(&self) -> f64 {
        self.margin
    }

    /// Returns the configured distance metric.
    pub fn distance_metric(&self) -> TripletDistance {
        self.distance
    }

    /// Computes the loss for a single triplet.
    ///
    /// Invalid input (empty anchor, mismatched dimensions, NaN/Inf values) yields
    /// `0.0` instead of an error. That can be confused with a triplet that already
    /// satisfies the margin, so prefer [`TripletLoss::forward_detailed`] when the
    /// inputs are not known to be valid.
    pub fn forward(&self, anchor: &[f64], positive: &[f64], negative: &[f64]) -> f64 {
        self.forward_detailed(anchor, positive, negative)
            .map(|r| r.loss)
            .unwrap_or(0.0)
    }

    /// Computes the loss for a single triplet, together with the two distances
    /// and the margin-violation flags.
    ///
    /// # Errors
    ///
    /// - `CnnError::InvalidInput` if `anchor` is empty or any vector contains NaN or Inf.
    /// - `CnnError::DimensionMismatch` if `positive` or `negative` differ in length from `anchor`.
    pub fn forward_detailed(
        &self,
        anchor: &[f64],
        positive: &[f64],
        negative: &[f64],
    ) -> CnnResult<TripletResult> {
        if anchor.is_empty() {
            return Err(CnnError::InvalidInput("anchor cannot be empty".to_string()));
        }
        let dim = anchor.len();
        if positive.len() != dim {
            return Err(CnnError::DimensionMismatch(format!(
                "positive has dimension {}, expected {}",
                positive.len(),
                dim
            )));
        }
        if negative.len() != dim {
            return Err(CnnError::DimensionMismatch(format!(
                "negative has dimension {}, expected {}",
                negative.len(),
                dim
            )));
        }
        for (name, vec) in [("anchor", anchor), ("positive", positive), ("negative", negative)] {
            if vec.iter().any(|x| !x.is_finite()) {
                return Err(CnnError::InvalidInput(format!(
                    "{} contains NaN or Inf",
                    name
                )));
            }
        }

        let pos_dist = self.compute_distance(anchor, positive);
        let neg_dist = self.compute_distance(anchor, negative);
        // Positive when the negative is not at least `margin` farther away than the positive.
        let diff = pos_dist - neg_dist + self.margin;
        let loss = if self.soft_margin {
            soft_relu(diff)
        } else {
            diff.max(0.0)
        };
        let loss = if let Some(weight) = self.l2_regularization {
            let anchor_norm: f64 = anchor.iter().map(|x| x * x).sum();
            let pos_norm: f64 = positive.iter().map(|x| x * x).sum();
            let neg_norm: f64 = negative.iter().map(|x| x * x).sum();
            // Mean of the three squared L2 norms, scaled by the regularization weight.
            loss + weight * (anchor_norm + pos_norm + neg_norm) / 3.0
        } else {
            loss
        };

        Ok(TripletResult {
            loss,
            positive_distance: pos_dist,
            negative_distance: neg_dist,
            is_hard: diff > 0.0,
            violates_margin: pos_dist + self.margin > neg_dist,
        })
    }

    /// Computes the mean loss over a batch of triplets.
    ///
    /// # Errors
    ///
    /// - `CnnError::DimensionMismatch` if the three slices have different lengths.
    /// - `CnnError::InvalidInput` if the batch is empty.
    /// - Any error [`TripletLoss::forward_detailed`] reports for an individual triplet.
    pub fn forward_batch(
        &self,
        anchors: &[Vec<f64>],
        positives: &[Vec<f64>],
        negatives: &[Vec<f64>],
    ) -> CnnResult<f64> {
        if anchors.len() != positives.len() || anchors.len() != negatives.len() {
            return Err(CnnError::DimensionMismatch(format!(
                "Batch sizes must match: anchors={}, positives={}, negatives={}",
                anchors.len(),
                positives.len(),
                negatives.len()
            )));
        }
        if anchors.is_empty() {
            return Err(CnnError::InvalidInput("batch cannot be empty".to_string()));
        }
        let mut total_loss = 0.0;
        for ((anchor, positive), negative) in anchors.iter().zip(positives).zip(negatives) {
            // Use the fallible path: a malformed triplet must fail the whole batch
            // rather than being silently counted as zero loss (which `forward` would do).
            total_loss += self.forward_detailed(anchor, positive, negative)?.loss;
        }
        Ok(total_loss / anchors.len() as f64)
    }

    /// Mines one margin-violating triplet per anchor from a labelled set of embeddings.
    ///
    /// For every anchor, this picks the farthest sample with the same label and the
    /// closest sample with a different label. It keeps `(anchor, positive, negative)`
    /// only if `d(a, p) - d(a, n) + margin > 0`. Anchors that lack either a same-label
    /// or a different-label partner are skipped.
    ///
    /// Returns an empty vector if `embeddings` and `labels` differ in length. All
    /// pairwise distances are computed up front, so time and memory are O(n^2).
    /// Embeddings are not validated here; vectors of unequal length are compared
    /// over their common prefix.
    pub fn mine_hard_triplets(
        &self,
        embeddings: &[Vec<f64>],
        labels: &[usize],
    ) -> Vec<(usize, usize, usize)> {
        if embeddings.len() != labels.len() {
            return vec![];
        }
        let n = embeddings.len();
        let mut triplets = Vec::new();
        let distances = self.compute_distance_matrix(embeddings);
        for anchor_idx in 0..n {
            let anchor_label = labels[anchor_idx];
            let mut hardest_pos_idx = None;
            let mut hardest_pos_dist = f64::NEG_INFINITY;
            let mut hardest_neg_idx = None;
            let mut hardest_neg_dist = f64::INFINITY;
            for other_idx in 0..n {
                if other_idx == anchor_idx {
                    continue;
                }
                let dist = distances[anchor_idx][other_idx];
                if labels[other_idx] == anchor_label {
                    if dist > hardest_pos_dist {
                        hardest_pos_dist = dist;
                        hardest_pos_idx = Some(other_idx);
                    }
                } else if dist < hardest_neg_dist {
                    hardest_neg_dist = dist;
                    hardest_neg_idx = Some(other_idx);
                }
            }
            if let (Some(pos_idx), Some(neg_idx)) = (hardest_pos_idx, hardest_neg_idx) {
                if hardest_pos_dist - hardest_neg_dist + self.margin > 0.0 {
                    triplets.push((anchor_idx, pos_idx, neg_idx));
                }
            }
        }
        triplets
    }

    /// Distance between `a` and `b` under the configured metric.
    ///
    /// Elements are paired with `zip`, so a length mismatch silently truncates to the
    /// shorter slice; `forward_detailed` rejects mismatches before calling this.
    #[inline]
    fn compute_distance(&self, a: &[f64], b: &[f64]) -> f64 {
        match self.distance {
            TripletDistance::Euclidean => {
                let sum_sq: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
                sum_sq.sqrt()
            }
            TripletDistance::SquaredEuclidean => {
                a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
            }
            TripletDistance::Cosine => {
                let mut dot: f64 = 0.0;
                let mut norm_a_sq: f64 = 0.0;
                let mut norm_b_sq: f64 = 0.0;
                for (x, y) in a.iter().zip(b) {
                    dot += x * y;
                    norm_a_sq += x * x;
                    norm_b_sq += y * y;
                }
                // Take the square roots separately. The product of the squared norms
                // overflows to infinity long before either norm does, which would make
                // identical large vectors look orthogonal.
                let norm = norm_a_sq.sqrt() * norm_b_sq.sqrt();
                if norm < 1e-8 {
                    // Direction is undefined for (near-)zero vectors; report distance 1.
                    1.0
                } else {
                    // Rounding can push |cos| slightly past 1; keep the distance in [0, 2].
                    (1.0 - dot / norm).clamp(0.0, 2.0)
                }
            }
        }
    }

    /// Symmetric `n x n` matrix of pairwise distances, with zeros on the diagonal.
    ///
    /// Only the upper triangle is computed; each value is mirrored to the lower triangle.
    fn compute_distance_matrix(&self, embeddings: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let n = embeddings.len();
        let mut matrix = vec![vec![0.0; n]; n];
        for i in 0..n {
            for j in (i + 1)..n {
                let dist = self.compute_distance(&embeddings[i], &embeddings[j]);
                matrix[i][j] = dist;
                matrix[j][i] = dist;
            }
        }
        matrix
    }
}
impl Default for TripletLoss {
fn default() -> Self {
Self::new(1.0)
}
}
/// Numerically stable softplus, `ln(1 + e^x)`: a smooth upper bound of `max(0, x)`.
///
/// Evaluated as `max(x, 0) + ln_1p(e^(-|x|))`. The exponent is never positive, so
/// nothing overflows. `ln_1p` keeps full relative precision for large negative `x`,
/// where the naive `ln(1 + e^x)` rounds `1 + e^x` towards `1`. The result is positive
/// until `e^x` underflows (x below about -745). NaN propagates; `+inf` maps to `+inf`
/// and `-inf` maps to `0`.
#[inline]
fn soft_relu(x: f64) -> f64 {
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-9;

    #[test]
    fn test_triplet_basic() {
        // Squared Euclidean: d(a, p) = 0.02 and d(a, n) = 2.0, so
        // 0.02 - 2.0 + 1.0 < 0 and the hinge clamps the loss to zero.
        let triplet = TripletLoss::new(1.0);
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let negative = vec![0.0, 1.0, 0.0];
        let loss = triplet.forward(&anchor, &positive, &negative);
        assert_eq!(loss, 0.0);
    }

    #[test]
    fn test_triplet_zero_loss() {
        let triplet = TripletLoss::new(0.1);
        let anchor = vec![1.0, 0.0];
        let positive = vec![1.0, 0.0];
        let negative = vec![-1.0, 0.0];
        let result = triplet.forward_detailed(&anchor, &positive, &negative).unwrap();
        assert_eq!(result.loss, 0.0);
        assert_eq!(result.positive_distance, 0.0);
        assert!((result.negative_distance - 4.0).abs() < EPS);
        assert!(!result.is_hard);
        assert!(!result.violates_margin);
    }

    #[test]
    fn test_triplet_hard() {
        // d(a, p) = 4, d(a, n) = 1, margin 1 -> loss = 4 - 1 + 1 = 4.
        let triplet = TripletLoss::new(1.0);
        let anchor = vec![0.0, 0.0];
        let positive = vec![2.0, 0.0];
        let negative = vec![1.0, 0.0];
        let result = triplet.forward_detailed(&anchor, &positive, &negative).unwrap();
        assert!((result.loss - 4.0).abs() < EPS);
        assert!((result.positive_distance - 4.0).abs() < EPS);
        assert!((result.negative_distance - 1.0).abs() < EPS);
        assert!(result.is_hard);
        assert!(result.violates_margin);
    }

    #[test]
    fn test_triplet_distances() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let c = vec![0.0, 0.0];

        let triplet_euclidean = TripletLoss::new(0.0).with_distance(TripletDistance::Euclidean);
        let result = triplet_euclidean.forward_detailed(&a, &b, &c).unwrap();
        assert!((result.positive_distance - 5.0).abs() < 1e-6);
        assert!(result.negative_distance.abs() < 1e-6);

        let triplet_squared =
            TripletLoss::new(0.0).with_distance(TripletDistance::SquaredEuclidean);
        let result = triplet_squared.forward_detailed(&a, &b, &c).unwrap();
        assert!((result.positive_distance - 25.0).abs() < 1e-6);

        let triplet_cosine = TripletLoss::new(0.0).with_distance(TripletDistance::Cosine);
        let x = vec![1.0, 0.0];
        let y = vec![0.0, 1.0];
        let z = vec![1.0, 0.0];
        let result = triplet_cosine.forward_detailed(&x, &y, &z).unwrap();
        // Orthogonal directions are at cosine distance 1; identical directions at 0.
        assert!((result.positive_distance - 1.0).abs() < 1e-6);
        assert!(result.negative_distance.abs() < 1e-6);

        // A zero vector has no direction and is reported at distance 1.
        let zero = vec![0.0, 0.0];
        let result = triplet_cosine.forward_detailed(&x, &zero, &z).unwrap();
        assert!((result.positive_distance - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_soft_margin() {
        // d(a, p) = 1, d(a, n) = 0.25, margin 1 -> hinge argument 1.75.
        let hard = TripletLoss::new(1.0);
        let soft = TripletLoss::new(1.0).with_soft_margin();
        let anchor = vec![0.0, 0.0];
        let positive = vec![1.0, 0.0];
        let negative = vec![0.5, 0.0];
        let hard_loss = hard.forward(&anchor, &positive, &negative);
        let soft_loss = soft.forward(&anchor, &positive, &negative);
        assert!((hard_loss - 1.75).abs() < EPS);
        assert!((soft_loss - (1.0 + 1.75_f64.exp()).ln()).abs() < EPS);
        assert!(soft_loss > hard_loss);
    }

    #[test]
    fn test_batch_triplet() {
        let triplet = TripletLoss::new(1.0);

        // Both triplets already satisfy the margin, so the mean loss is zero.
        let anchors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let positives = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
        let negatives = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let loss = triplet.forward_batch(&anchors, &positives, &negatives).unwrap();
        assert_eq!(loss, 0.0);

        // Mean of a margin-violating triplet (loss 4) and a satisfied one (loss 0).
        let anchors = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        let positives = vec![vec![2.0, 0.0], vec![0.0, 0.0]];
        let negatives = vec![vec![1.0, 0.0], vec![2.0, 0.0]];
        let loss = triplet.forward_batch(&anchors, &positives, &negatives).unwrap();
        assert!((loss - 2.0).abs() < EPS);
    }

    #[test]
    fn test_batch_errors() {
        let triplet = TripletLoss::new(1.0);
        let one = vec![vec![1.0, 0.0]];
        let two = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        assert!(triplet.forward_batch(&one, &two, &two).is_err());
        assert!(triplet.forward_batch(&two, &two, &one).is_err());
        let empty: Vec<Vec<f64>> = Vec::new();
        assert!(triplet.forward_batch(&empty, &empty, &empty).is_err());
    }

    #[test]
    fn test_mine_hard_triplets() {
        let triplet = TripletLoss::new(0.01);
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.95, 0.05],
            vec![0.9, 0.1],
            vec![0.85, 0.15],
        ];
        let labels = vec![0, 0, 1, 1];
        let hard_triplets = triplet.mine_hard_triplets(&embeddings, &labels);
        for (a, p, n) in &hard_triplets {
            assert_eq!(labels[*a], labels[*p]);
            assert_ne!(labels[*a], labels[*n]);
        }
        // Samples 1 and 2 sit on the class boundary. For each, the closest negative is
        // as near as its only positive, so only those two anchors violate the margin.
        assert_eq!(hard_triplets, vec![(1, 0, 2), (2, 3, 1)]);
    }

    #[test]
    fn test_mine_hard_triplets_degenerate_inputs() {
        let triplet = TripletLoss::new(1.0);
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        // Mismatched embedding/label counts yield nothing.
        assert!(triplet.mine_hard_triplets(&embeddings, &[0]).is_empty());
        // A single class has no negatives, so no triplet can be formed.
        assert!(triplet.mine_hard_triplets(&embeddings, &[3, 3]).is_empty());
        assert!(triplet.mine_hard_triplets(&[], &[]).is_empty());
    }

    #[test]
    fn test_l2_regularization() {
        let no_reg = TripletLoss::new(0.0);
        let with_reg = TripletLoss::new(0.0).with_l2_regularization(0.01);
        let anchor = vec![10.0, 0.0];
        let positive = vec![10.0, 0.0];
        let negative = vec![-10.0, 0.0];
        let loss_no_reg = no_reg.forward(&anchor, &positive, &negative);
        let loss_with_reg = with_reg.forward(&anchor, &positive, &negative);
        assert_eq!(loss_no_reg, 0.0);
        // Each embedding has squared norm 100: 0.01 * (100 + 100 + 100) / 3 = 1.
        assert!((loss_with_reg - 1.0).abs() < EPS);
    }

    #[test]
    fn test_error_handling() {
        let triplet = TripletLoss::new(1.0);
        assert!(triplet.forward_detailed(&[], &[1.0], &[1.0]).is_err());
        assert!(triplet.forward_detailed(&[1.0, 2.0], &[1.0], &[1.0, 2.0]).is_err());
        assert!(triplet.forward_detailed(&[1.0, 2.0], &[1.0, 2.0], &[1.0]).is_err());
        assert!(triplet.forward_detailed(&[f64::NAN], &[1.0], &[1.0]).is_err());
        assert!(triplet.forward_detailed(&[1.0], &[f64::INFINITY], &[1.0]).is_err());
        // `forward` maps invalid input to a zero loss instead of an error.
        assert_eq!(triplet.forward(&[f64::NAN], &[1.0], &[1.0]), 0.0);
    }

    #[test]
    #[should_panic(expected = "Margin must be non-negative")]
    fn test_negative_margin_panics() {
        let _ = TripletLoss::new(-1.0);
    }

    #[test]
    fn test_default() {
        let triplet = TripletLoss::default();
        assert_eq!(triplet.margin(), 1.0);
        assert_eq!(triplet.distance_metric(), TripletDistance::SquaredEuclidean);
    }

    #[test]
    fn test_soft_relu() {
        assert!((soft_relu(0.0) - 2.0_f64.ln()).abs() < 1e-6);
        assert!(soft_relu(-100.0) < 1e-10);
        assert!((soft_relu(100.0) - 100.0).abs() < 1e-6);
        // Softplus upper-bounds the ReLU across a sample grid.
        for i in -40..=40 {
            let x = f64::from(i) * 0.75;
            assert!(soft_relu(x) >= x.max(0.0));
        }
        let x = 1.0;
        let y = soft_relu(x);
        assert!(y > x.max(0.0));
    }
}