use crate::error::{NeuralError, Result};
use crate::losses::Loss;
use scirs2_core::ndarray::Array;
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
/// Triplet margin loss configuration.
///
/// Holds the margin that is added to the difference between the
/// anchor-positive and anchor-negative distances before the result is
/// clamped at zero.
#[derive(Debug, Clone, Copy)]
pub struct TripletLoss {
    /// Margin added to `d(anchor, positive) - d(anchor, negative)`.
    margin: f64,
}

impl TripletLoss {
    /// Creates a triplet loss that uses the given `margin`.
    ///
    /// The value is stored as-is; no validation is performed.
    pub fn new(margin: f64) -> Self {
        TripletLoss { margin }
    }
}

impl Default for TripletLoss {
    /// Returns a triplet loss with a margin of `1.0`.
    fn default() -> Self {
        TripletLoss { margin: 1.0 }
    }
}
impl<F: Float + Debug + NumAssign> Loss<F> for TripletLoss {
fn forward(
&self,
predictions: &Array<F, scirs2_core::ndarray::IxDyn>,
_targets: &Array<F, scirs2_core::ndarray::IxDyn>,
) -> Result<F> {
if predictions.ndim() != 3 || predictions.shape()[1] != 3 {
return Err(NeuralError::InferenceError(format!(
"Expected predictions shape (batch_size, 3, embedding_dim), got {:?}",
predictions.shape()
)));
}
let batch_size = predictions.shape()[0];
let embedding_dim = predictions.shape()[2];
let margin = F::from(self.margin).ok_or_else(|| {
NeuralError::InferenceError("Could not convert margin to float".to_string())
})?;
let mut total_loss = F::zero();
let n = F::from(batch_size).ok_or_else(|| {
NeuralError::InferenceError("Could not convert batch size to float".to_string())
})?;
for i in 0..batch_size {
let anchor = predictions.slice(scirs2_core::ndarray::s![i, 0, ..]);
let positive = predictions.slice(scirs2_core::ndarray::s![i, 1, ..]);
let negative = predictions.slice(scirs2_core::ndarray::s![i, 2, ..]);
let mut pos_distance_squared = F::zero();
let mut neg_distance_squared = F::zero();
for j in 0..embedding_dim {
let pos_diff = anchor[j] - positive[j];
pos_distance_squared += pos_diff * pos_diff;
let neg_diff = anchor[j] - negative[j];
neg_distance_squared += neg_diff * neg_diff;
}
let pos_distance = pos_distance_squared.sqrt();
let neg_distance = neg_distance_squared.sqrt();
let zero = F::zero();
let triplet_loss = (pos_distance - neg_distance + margin).max(zero);
total_loss += triplet_loss;
}
let loss = total_loss / n;
Ok(loss)
}
fn backward(
&self,
predictions: &Array<F, scirs2_core::ndarray::IxDyn>,
_targets: &Array<F, scirs2_core::ndarray::IxDyn>,
) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
let batch_size = predictions.shape()[0];
let embedding_dim = predictions.shape()[2];
let margin = F::from(self.margin)
.ok_or_else(|| NeuralError::ComputationError("Failed to convert margin".to_string()))?;
let n = F::from(batch_size).ok_or_else(|| {
NeuralError::ComputationError("Failed to convert batch size".to_string())
})?;
let mut gradients = Array::zeros(predictions.raw_dim());
for i in 0..batch_size {
let anchor = predictions.slice(scirs2_core::ndarray::s![i, 0, ..]);
let positive = predictions.slice(scirs2_core::ndarray::s![i, 1, ..]);
let negative = predictions.slice(scirs2_core::ndarray::s![i, 2, ..]);
let mut pos_distance_squared = F::zero();
let mut neg_distance_squared = F::zero();
for j in 0..embedding_dim {
let pos_diff = anchor[j] - positive[j];
pos_distance_squared += pos_diff * pos_diff;
let neg_diff = anchor[j] - negative[j];
neg_distance_squared += neg_diff * neg_diff;
}
let pos_distance = pos_distance_squared.sqrt();
let neg_distance = neg_distance_squared.sqrt();
let pos_distance_safe =
pos_distance.max(F::from(1e-10).expect("Failed to convert constant to float"));
let neg_distance_safe =
neg_distance.max(F::from(1e-10).expect("Failed to convert constant to float"));
if pos_distance - neg_distance + margin > F::zero() {
for j in 0..embedding_dim {
let pos_grad = (anchor[j] - positive[j]) / pos_distance_safe;
let neg_grad = (anchor[j] - negative[j]) / neg_distance_safe;
gradients[[i, 0, j]] = (pos_grad - neg_grad) / n;
gradients[[i, 1, j]] = -pos_grad / n;
gradients[[i, 2, j]] = neg_grad / n;
}
}
}
Ok(gradients)
}
}