use crate::error::{NeuralError, Result};
use crate::losses::Loss;
use scirs2_core::ndarray::{s, Array, IxDyn};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
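
/// Contrastive loss for metric learning on pairs of embeddings.
///
/// For a pair of embeddings `(x1, x2)` with Euclidean distance
/// `d = ||x1 - x2||` and binary label `y` (1 = similar, 0 = dissimilar),
/// the per-pair loss is
///
/// ```text
/// y * d^2 + (1 - y) * max(margin - d, 0)^2
/// ```
///
/// so similar pairs are pulled together and dissimilar pairs are pushed
/// apart until they are at least `margin` away (Hadsell, Chopra & LeCun,
/// 2006). The batch loss is the mean over all pairs.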
#[derive(Debug, Clone, Copy)]
pub struct ContrastiveLoss {
    /// Margin below which dissimilar pairs are penalized.
    margin: f64,
}
impl ContrastiveLoss {
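    /// Creates a contrastive loss with the given margin.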
pub fn new(margin: f64) -> Self {
Self { margin }
}
}
impl Default for ContrastiveLoss {
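    /// Defaults to a margin of `1.0`.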
fn default() -> Self {
Self::new(1.0)
}
}
impl<F: Float + Debug + NumAssign> Loss<F> for ContrastiveLoss {
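    /// Computes the mean contrastive loss over a batch of embedding pairs.
    ///
    /// Expects `predictions` with shape `(batch_size, 2, embedding_dim)`,
    /// holding both embeddings of each pair along the second axis, and
    /// `targets` with shape `(batch_size, 1)`, where a value greater than
    /// zero marks a pair as similar and anything else as dissimilar.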
fn forward(
&self,
        predictions: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
) -> Result<F> {
        // Validate input shapes before any indexing.
        if predictions.ndim() != 3 || predictions.shape()[1] != 2 {
return Err(NeuralError::InferenceError(format!(
"Expected predictions shape (batch_size, 2, embedding_dim), got {:?}",
predictions.shape()
)));
}
if targets.ndim() != 2 || targets.shape()[1] != 1 {
return Err(NeuralError::InferenceError(format!(
"Expected targets shape (batch_size, 1), got {:?}",
targets.shape()
)));
}
if predictions.shape()[0] != targets.shape()[0] {
return Err(NeuralError::InferenceError(format!(
"Batch size mismatch: predictions {} vs targets {}",
predictions.shape()[0],
targets.shape()[0]
)));
}
let batch_size = predictions.shape()[0];
let embedding_dim = predictions.shape()[2];
let margin = F::from(self.margin).ok_or_else(|| {
NeuralError::InferenceError("Could not convert margin to float".to_string())
})?;
let mut total_loss = F::zero();
let n = F::from(batch_size).ok_or_else(|| {
NeuralError::InferenceError("Could not convert batch size to float".to_string())
})?;
        for i in 0..batch_size {
            let x1 = predictions.slice(s![i, 0, ..]);
            let x2 = predictions.slice(s![i, 1, ..]);

            // Squared Euclidean distance between the two embeddings of the pair.
            let mut distance_squared = F::zero();
            for j in 0..embedding_dim {
                let diff = x1[j] - x2[j];
                distance_squared += diff * diff;
            }
            let distance = distance_squared.sqrt();

            let y = targets[[i, 0]];
            let pair_loss = if y > F::zero() {
                // Similar pair: penalize the squared distance directly.
                distance_squared
            } else {
                // Dissimilar pair: penalize only while closer than the margin.
                let margin_term = (margin - distance).max(F::zero());
                margin_term * margin_term
            };
            total_loss += pair_loss;
        }
        // Mean loss over the batch.
        Ok(total_loss / n)
}
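
    /// Computes gradients of the mean contrastive loss with respect to
    /// `predictions`, returned with the same `(batch_size, 2, embedding_dim)`
    /// shape. Dissimilar pairs already separated by at least the margin
    /// contribute zero gradient.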
fn backward(
&self,
        predictions: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Shapes are assumed to match the contract documented on `forward`;
        // the indexing below will panic on malformed input.
        let batch_size = predictions.shape()[0];
        let embedding_dim = predictions.shape()[2];
        let n = F::from(batch_size).ok_or_else(|| {
            NeuralError::ComputationError("Failed to convert batch size".to_string())
        })?;
        let margin = F::from(self.margin)
            .ok_or_else(|| NeuralError::ComputationError("Failed to convert margin".to_string()))?;
        let mut gradients = Array::zeros(predictions.raw_dim());
        // Gradient of the per-pair loss, averaged over the batch:
        //   similar (y > 0):             d/dx1 [d^2] = 2 * (x1 - x2)
        //   dissimilar with d < margin:  d/dx1 [(margin - d)^2]
        //                                  = -2 * (margin - d) * (x1 - x2) / d
        let two = F::from(2.0).expect("Failed to convert constant to float");
        let eps = F::from(1e-10).expect("Failed to convert constant to float");
        for i in 0..batch_size {
            let x1 = predictions.slice(s![i, 0, ..]);
            let x2 = predictions.slice(s![i, 1, ..]);
            let y = targets[[i, 0]];

            let mut distance_sq = F::zero();
            for j in 0..embedding_dim {
                let diff = x1[j] - x2[j];
                distance_sq += diff * diff;
            }
            let distance = distance_sq.sqrt();
            // Guard against division by zero when the two embeddings coincide.
            let distance_safe = distance.max(eps);

            if y > F::zero() {
                // Similar pair: pull the embeddings together.
                for j in 0..embedding_dim {
                    gradients[[i, 0, j]] = two * (x1[j] - x2[j]) / n;
                    gradients[[i, 1, j]] = two * (x2[j] - x1[j]) / n;
                }
            } else if distance < margin {
                // Dissimilar pair inside the margin: push the embeddings apart.
                // Pairs at or beyond the margin keep their zero-initialized
                // gradient entries.
                let factor = -two * (margin - distance) / distance_safe / n;
                for j in 0..embedding_dim {
                    gradients[[i, 0, j]] = factor * (x1[j] - x2[j]);
                    gradients[[i, 1, j]] = factor * (x2[j] - x1[j]);
                }
            }
        }
Ok(gradients)
}
}
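
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal smoke test sketching how the loss is meant to be called.
    // It assumes `scirs2_core::ndarray` re-exports ndarray's
    // `Array::from_shape_vec` and the `IxDyn` shape constructor, which the
    // implementation above already relies on.
    #[test]
    fn forward_matches_hand_computed_loss() {
        let loss = ContrastiveLoss::default(); // margin = 1.0

        // Two pairs of 3-d embeddings: shape (batch_size = 2, 2, embedding_dim = 3).
        let predictions = Array::from_shape_vec(
            IxDyn(&[2, 2, 3]),
            vec![
                0.0, 0.0, 0.0, // pair 0, first embedding
                1.0, 0.0, 0.0, // pair 0, second embedding
                0.0, 0.0, 0.0, // pair 1, first embedding
                2.0, 0.0, 0.0, // pair 1, second embedding
            ],
        )
        .unwrap();
        // Pair 0 is similar (y = 1); pair 1 is dissimilar (y = 0).
        let targets = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0]).unwrap();

        // Similar pair: d^2 = 1. Dissimilar pair: d = 2 >= margin, so 0.
        // Mean over the batch: (1.0 + 0.0) / 2 = 0.5.
        let value: f64 = loss.forward(&predictions, &targets).unwrap();
        assert!((value - 0.5).abs() < 1e-12);
    }
}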