use axonml_autograd::Variable;
use axonml_tensor::Tensor;
/// Cosine-distance triplet loss on raw embedding slices.
///
/// Distances are computed as `1 - dot(a, b)`, which equals cosine distance
/// only if inputs are L2-normalized — assumes callers pass unit-length
/// embeddings (TODO confirm at call sites).
///
/// Returns `max(d(anchor, positive) - d(anchor, negative) + margin, 0)`.
///
/// # Panics
/// Panics if the three slices have different lengths.
fn triplet_loss_raw(anchor: &[f32], positive: &[f32], negative: &[f32], margin: f32) -> f32 {
    assert_eq!(anchor.len(), positive.len());
    assert_eq!(anchor.len(), negative.len());
    // Iterator zips avoid per-index bounds checks and vectorize well.
    let dot_pos: f32 = anchor.iter().zip(positive).map(|(a, p)| a * p).sum();
    let dot_neg: f32 = anchor.iter().zip(negative).map(|(a, n)| a * n).sum();
    let dist_pos = 1.0 - dot_pos;
    let dist_neg = 1.0 - dot_neg;
    (dist_pos - dist_neg + margin).max(0.0)
}
/// Autograd version of [`triplet_loss_raw`]: builds the same
/// `max(d_pos - d_neg + margin, 0)` expression as a `Variable` graph so
/// gradients flow back into all three embeddings.
fn triplet_loss_var(
    anchor: &Variable,
    positive: &Variable,
    negative: &Variable,
    margin: f32,
) -> Variable {
    // dist(a, b) = 1 - dot(a, b), expressed as graph ops.
    let cosine_distance =
        |other: &Variable| anchor.mul_var(other).sum().mul_scalar(-1.0).add_scalar(1.0);
    let dist_pos = cosine_distance(positive);
    let dist_neg = cosine_distance(negative);
    dist_pos.sub_var(&dist_neg).add_scalar(margin).relu()
}
/// Triplet loss plus a regularizer that penalizes embedding velocities
/// above `target_velocity`, encouraging embeddings to settle ("crystallize").
pub struct CrystallizationLoss {
    /// Triplet margin, in `1 - dot` cosine-distance units.
    pub margin: f32,
    /// Weight applied to the velocity convergence regularizer.
    pub convergence_weight: f32,
    /// Velocities at or below this threshold are not penalized.
    pub target_velocity: f32,
}
impl Default for CrystallizationLoss {
fn default() -> Self {
Self {
margin: 0.3,
convergence_weight: 0.1,
target_velocity: 0.1,
}
}
}
impl CrystallizationLoss {
    /// Scalar loss: triplet term plus `convergence_weight` times the mean
    /// squared excess of each velocity over `target_velocity`. An empty
    /// `velocities` slice contributes zero regularization.
    pub fn compute(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negative: &[f32],
        velocities: &[f32],
    ) -> f32 {
        let triplet = triplet_loss_raw(anchor, positive, negative, self.margin);
        // `.max(1)` guards the division when there are no velocity samples.
        let denom = velocities.len().max(1) as f32;
        let conv_reg = velocities
            .iter()
            .map(|v| (v - self.target_velocity).max(0.0).powi(2))
            .sum::<f32>()
            / denom;
        triplet + self.convergence_weight * conv_reg
    }
    /// Graph-tracked loss. The velocity regularizer is evaluated on the
    /// detached tensor data and added as a constant, so gradients reach only
    /// the triplet term — NOTE(review): confirm velocities are meant to be
    /// excluded from backprop.
    pub fn compute_var(
        &self,
        anchor: &Variable,
        positive: &Variable,
        negative: &Variable,
        velocities: &Variable,
    ) -> Variable {
        let triplet = triplet_loss_var(anchor, positive, negative, self.margin);
        let values = velocities.data().to_vec();
        let denom = values.len().max(1) as f32;
        let conv_reg = values
            .iter()
            .map(|v| (v - self.target_velocity).max(0.0).powi(2))
            .sum::<f32>()
            / denom;
        triplet.add_scalar(self.convergence_weight * conv_reg)
    }
}
/// Classic contrastive loss on squared Euclidean distance: pulls
/// same-identity pairs together, pushes different pairs beyond `margin`.
pub struct ContrastiveLoss {
    /// Minimum Euclidean distance enforced between different identities.
    pub margin: f32,
    /// Weight for an orientation term — NOTE(review): not referenced in the
    /// visible `impl`; confirm it is consumed elsewhere.
    pub orientation_weight: f32,
}
impl Default for ContrastiveLoss {
fn default() -> Self {
Self {
margin: 1.0,
orientation_weight: 0.05,
}
}
}
impl ContrastiveLoss {
    /// Scalar contrastive loss.
    ///
    /// Same pair: squared distance. Different pair: squared hinge
    /// `max(margin - distance, 0)^2`. Indexing is driven by `embedding_a`'s
    /// length, so a shorter `embedding_b` panics.
    pub fn compute(&self, embedding_a: &[f32], embedding_b: &[f32], is_same: bool) -> f32 {
        let dist_sq: f32 = (0..embedding_a.len())
            .map(|i| {
                let d = embedding_a[i] - embedding_b[i];
                d * d
            })
            .sum();
        if is_same {
            dist_sq
        } else {
            let shortfall = (self.margin - dist_sq.sqrt()).max(0.0);
            shortfall.powi(2)
        }
    }
    /// Graph-tracked contrastive loss with the same contract as
    /// [`ContrastiveLoss::compute`].
    pub fn compute_var(
        &self,
        embedding_a: &Variable,
        embedding_b: &Variable,
        is_same: bool,
    ) -> Variable {
        let diff = embedding_a.sub_var(embedding_b);
        let dist_sq = diff.mul_var(&diff).sum();
        if is_same {
            return dist_sq;
        }
        // The epsilon keeps sqrt differentiable at zero distance.
        let dist = dist_sq.add_scalar(1e-8).sqrt();
        let shortfall = dist.mul_scalar(-1.0).add_scalar(self.margin).relu();
        shortfall.mul_var(&shortfall)
    }
}
/// Combined loss for mel-frame prediction plus speaker-identity triplets.
pub struct EchoLoss {
    /// Weight of the mel-prediction MSE term.
    pub prediction_weight: f32,
    /// Weight of the speaker triplet term.
    pub speaker_weight: f32,
    /// Triplet margin for the speaker embeddings.
    pub margin: f32,
}
impl Default for EchoLoss {
fn default() -> Self {
Self {
prediction_weight: 1.0,
speaker_weight: 0.5,
margin: 0.3,
}
}
}
impl EchoLoss {
    /// Mean squared error between predicted and actual frames, computed on
    /// detached tensor data (carries no gradient). Elements are zipped, so a
    /// length mismatch silently truncates to the shorter tensor. Returns 0.0
    /// for empty input (previously `0.0 / 0.0` produced NaN).
    pub fn prediction_loss(predicted: &Variable, actual: &Variable) -> f32 {
        let p = predicted.data().to_vec();
        let a = actual.data().to_vec();
        // Guard the empty case: dividing by zero length yields NaN.
        let n = p.len().max(1) as f32;
        p.iter()
            .zip(a.iter())
            .map(|(pi, ai)| (pi - ai) * (pi - ai))
            .sum::<f32>()
            / n
    }
    /// Graph-tracked MSE between predicted and actual frames.
    pub fn prediction_loss_var(predicted: &Variable, actual: &Variable) -> Variable {
        let diff = predicted.sub_var(actual);
        let sq = diff.mul_var(&diff);
        sq.mean()
    }
    /// Scalar combination: weighted prediction MSE plus weighted speaker
    /// triplet loss over raw embedding slices.
    pub fn compute(
        &self,
        predicted_mel: &Variable,
        actual_mel: &Variable,
        speaker_anchor: &[f32],
        speaker_pos: &[f32],
        speaker_neg: &[f32],
    ) -> f32 {
        let pred_loss = Self::prediction_loss(predicted_mel, actual_mel);
        let speaker_loss = triplet_loss_raw(speaker_anchor, speaker_pos, speaker_neg, self.margin);
        self.prediction_weight * pred_loss + self.speaker_weight * speaker_loss
    }
    /// Graph-tracked combination of the prediction and speaker terms.
    pub fn compute_var(
        &self,
        predicted_mel: &Variable,
        actual_mel: &Variable,
        speaker_anchor: &Variable,
        speaker_pos: &Variable,
        speaker_neg: &Variable,
    ) -> Variable {
        let pred_loss = Self::prediction_loss_var(predicted_mel, actual_mel);
        let speaker_loss = triplet_loss_var(speaker_anchor, speaker_pos, speaker_neg, self.margin);
        pred_loss
            .mul_scalar(self.prediction_weight)
            .add_var(&speaker_loss.mul_scalar(self.speaker_weight))
    }
}
/// Triplet loss plus a phase-consistency term between original and rotated
/// code vectors.
pub struct ArgusLoss {
    /// Triplet margin, in `1 - dot` cosine-distance units.
    pub margin: f32,
    /// Weight of the phase-consistency penalty.
    pub phase_weight: f32,
}
impl Default for ArgusLoss {
fn default() -> Self {
Self {
margin: 0.3,
phase_weight: 0.1,
}
}
}
impl ArgusLoss {
    /// Scalar loss: speaker triplet term plus a phase-consistency penalty
    /// `max(1 - dot(code_original, code_rotated), 0)` rewarding rotated
    /// codes that stay aligned with the originals.
    ///
    /// # Panics
    /// Panics if `code_original` and `code_rotated` differ in length. (Made
    /// explicit: previously a shorter `code_rotated` panicked mid-loop and a
    /// longer one was silently truncated.)
    pub fn compute(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negative: &[f32],
        code_original: &[f32],
        code_rotated: &[f32],
    ) -> f32 {
        let triplet = triplet_loss_raw(anchor, positive, negative, self.margin);
        assert_eq!(code_original.len(), code_rotated.len());
        let dot: f32 = code_original
            .iter()
            .zip(code_rotated)
            .map(|(o, r)| o * r)
            .sum();
        let phase_loss = (1.0 - dot).max(0.0);
        triplet + self.phase_weight * phase_loss
    }
    /// Graph-tracked version of [`ArgusLoss::compute`]; gradients flow
    /// through both the triplet and the phase terms.
    pub fn compute_var(
        &self,
        anchor: &Variable,
        positive: &Variable,
        negative: &Variable,
        code_original: &Variable,
        code_rotated: &Variable,
    ) -> Variable {
        let triplet = triplet_loss_var(anchor, positive, negative, self.margin);
        let dot = code_original.mul_var(code_rotated).sum();
        let phase_loss = dot.mul_scalar(-1.0).add_scalar(1.0).relu();
        triplet.add_var(&phase_loss.mul_scalar(self.phase_weight))
    }
}
/// Verification loss combining match BCE, a triplet term on fused
/// embeddings, and a confidence-calibration penalty.
pub struct ThemisLoss {
    /// Weight of the binary cross-entropy match term.
    pub bce_weight: f32,
    /// Weight of the fused-embedding triplet term.
    pub triplet_weight: f32,
    /// Weight of the confidence calibration term.
    pub calibration_weight: f32,
    /// Triplet margin for the fused embeddings.
    pub margin: f32,
}
impl Default for ThemisLoss {
fn default() -> Self {
Self {
bce_weight: 1.0,
triplet_weight: 0.5,
calibration_weight: 0.1,
margin: 0.3,
}
}
}
impl ThemisLoss {
    /// Binary cross-entropy with the probability clamped away from 0/1 so
    /// the logarithms stay finite.
    fn bce(predicted: f32, target: f32) -> f32 {
        let p = predicted.clamp(1e-7, 1.0 - 1e-7);
        -(target * p.ln() + (1.0 - target) * (1.0 - p).ln())
    }
    /// Squared gap between stated confidence and actual correctness (1/0).
    fn calibration_loss(confidence: f32, was_correct: bool) -> f32 {
        let accuracy = if was_correct { 1.0 } else { 0.0 };
        let gap = confidence - accuracy;
        gap * gap
    }
    /// Scalar combination of the BCE, triplet, and calibration terms.
    pub fn compute(
        &self,
        match_prob: f32,
        is_match: bool,
        fused_anchor: &[f32],
        fused_pos: &[f32],
        fused_neg: &[f32],
        confidence: f32,
    ) -> f32 {
        let target = if is_match { 1.0 } else { 0.0 };
        let bce_term = Self::bce(match_prob, target);
        let triplet_term = triplet_loss_raw(fused_anchor, fused_pos, fused_neg, self.margin);
        // "Correct" means the 0.5-thresholded prediction agrees with truth.
        let correct = (match_prob > 0.5) == is_match;
        let cal_term = Self::calibration_loss(confidence, correct);
        self.bce_weight * bce_term
            + self.triplet_weight * triplet_term
            + self.calibration_weight * cal_term
    }
    /// Graph-tracked combination. BCE and calibration are plain scalars here
    /// (their inputs are `f32`), so only the triplet term carries gradient.
    pub fn compute_var(
        &self,
        match_prob: f32,
        is_match: bool,
        fused_anchor: &Variable,
        fused_pos: &Variable,
        fused_neg: &Variable,
        confidence: f32,
    ) -> Variable {
        let target = if is_match { 1.0 } else { 0.0 };
        let bce_term = Self::bce(match_prob, target);
        let triplet_term = triplet_loss_var(fused_anchor, fused_pos, fused_neg, self.margin);
        let correct = (match_prob > 0.5) == is_match;
        let cal_term = Self::calibration_loss(confidence, correct);
        triplet_term
            .mul_scalar(self.triplet_weight)
            .add_scalar(self.bce_weight * bce_term + self.calibration_weight * cal_term)
    }
}
/// Center loss with per-sample uncertainty attenuation: distances to class
/// centers are down-weighted by `exp(-alpha * variance)`.
pub struct CenterLoss {
    /// Overall scale applied to the loss.
    pub weight: f32,
    /// Controls how strongly uncertain samples are attenuated.
    pub uncertainty_alpha: f32,
}
impl Default for CenterLoss {
fn default() -> Self {
Self {
weight: 0.01,
uncertainty_alpha: 1.0,
}
}
}
impl CenterLoss {
    /// Scalar center loss over `n = log_variances.len()` samples of `dim`
    /// features each, flattened row-major into `embeddings` / `centers`.
    ///
    /// Each sample's squared distance to its center is weighted by
    /// `exp(-alpha * exp(log_variance))`. Returns 0.0 when there are no
    /// samples or `dim == 0`.
    ///
    /// # Panics
    /// Panics if `embeddings` and `centers` differ in length, or if their
    /// length is not `n * dim`.
    pub fn compute(
        &self,
        embeddings: &[f32],
        centers: &[f32],
        log_variances: &[f32],
        dim: usize,
    ) -> f32 {
        assert_eq!(embeddings.len(), centers.len());
        let n = log_variances.len();
        if n == 0 || dim == 0 {
            return 0.0;
        }
        assert_eq!(embeddings.len(), n * dim);
        let total: f32 = embeddings
            .chunks_exact(dim)
            .zip(centers.chunks_exact(dim))
            .zip(log_variances)
            .map(|((emb, ctr), lv)| {
                let dist_sq: f32 = emb.iter().zip(ctr).map(|(e, c)| (e - c) * (e - c)).sum();
                // Down-weight uncertain samples: w = exp(-alpha * variance).
                let w = (-self.uncertainty_alpha * lv.exp()).exp();
                w * dist_sq
            })
            .sum();
        self.weight * total / n as f32
    }
    /// Graph-tracked variant. Uncertainty weights are derived from the
    /// detached `log_variances` data and collapsed into one scalar factor,
    /// so gradients flow only through `embeddings` and `centers` —
    /// NOTE(review): confirm the per-sample weighting of `compute` is not
    /// required here.
    pub fn compute_var(
        &self,
        embeddings: &Variable,
        centers: &Variable,
        log_variances: &Variable,
    ) -> Variable {
        let residual = embeddings.sub_var(centers);
        let sq_residual = residual.mul_var(&residual);
        let lv = log_variances.data().to_vec();
        if lv.is_empty() {
            // Keep the result attached to the graph, but make it zero.
            return sq_residual.mul_scalar(0.0);
        }
        let mean_weight = lv
            .iter()
            .map(|v| (-self.uncertainty_alpha * v.exp()).exp())
            .sum::<f32>()
            / lv.len() as f32;
        sq_residual.mean().mul_scalar(self.weight * mean_weight)
    }
}
/// ArcFace-style additive angular margin loss whose margin shrinks for
/// uncertain samples.
pub struct AngularMarginLoss {
    /// Base angular margin (radians) added to the target-class angle.
    pub margin: f32,
    /// Logit scale applied to cosine similarities before softmax.
    pub scale: f32,
    /// Controls how strongly uncertainty shrinks the margin.
    pub uncertainty_beta: f32,
}
impl Default for AngularMarginLoss {
fn default() -> Self {
Self {
margin: 0.3,
scale: 30.0,
uncertainty_beta: 1.0,
}
}
}
impl AngularMarginLoss {
    /// ArcFace-style loss for one sample given its cosine similarities to
    /// all class centers.
    ///
    /// The margin is attenuated by the sample's uncertainty:
    /// `margin * exp(-beta * exp(log_variance))`, so confident (low
    /// variance) samples face the full margin. Returns the softmax
    /// cross-entropy of the margin-adjusted, scaled logits; 0.0 for empty
    /// input or an out-of-range `target_class`.
    pub fn compute(&self, cos_similarities: &[f32], target_class: usize, log_variance: f32) -> f32 {
        let n_classes = cos_similarities.len();
        if n_classes == 0 || target_class >= n_classes {
            return 0.0;
        }
        let variance = log_variance.exp();
        // Low variance => factor near 1 => full margin applied.
        let adaptive_margin = self.margin * (-self.uncertainty_beta * variance).exp();
        let cos_target = cos_similarities[target_class].clamp(-1.0, 1.0);
        let theta = cos_target.acos();
        // cos(theta + m): the additive angular margin on the target class.
        let cos_with_margin = (theta + adaptive_margin).cos();
        // Streaming pairwise log-sum-exp over the scaled logits; the branch
        // keeps the exp() argument non-positive for numerical stability.
        let mut log_sum_exp = f32::NEG_INFINITY;
        for (j, &cos_j) in cos_similarities.iter().enumerate() {
            let logit = if j == target_class {
                self.scale * cos_with_margin
            } else {
                self.scale * cos_j
            };
            if logit > log_sum_exp {
                log_sum_exp = logit + (1.0 + (log_sum_exp - logit).exp()).ln();
            } else {
                log_sum_exp = log_sum_exp + (1.0 + (logit - log_sum_exp).exp()).ln();
            }
        }
        let target_logit = self.scale * cos_with_margin;
        // Cross-entropy = logsumexp - target logit; already >= 0, the max
        // is a defensive clamp against rounding.
        (log_sum_exp - target_logit).max(0.0)
    }
    /// Graph-attached variant. The NLL is computed entirely from detached
    /// data and added to a zeroed graph node, so no gradient reaches
    /// `cos_similarities` — NOTE(review): confirm this placeholder
    /// behavior is intended.
    pub fn compute_var(
        &self,
        cos_similarities: &Variable,
        target_class: usize,
        log_variance: &Variable,
    ) -> Variable {
        let cos_data = cos_similarities.data().to_vec();
        let n_classes = cos_data.len();
        if n_classes == 0 || target_class >= n_classes {
            return Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        }
        let lv_data = log_variance.data().to_vec();
        // Missing log-variance defaults to 0 (i.e. variance 1).
        let lv = if lv_data.is_empty() { 0.0 } else { lv_data[0] };
        let variance = lv.exp();
        let adaptive_margin = self.margin * (-self.uncertainty_beta * variance).exp();
        let cos_target = cos_data[target_class].clamp(-1.0, 1.0);
        let theta = cos_target.acos();
        let cos_with_margin = (theta + adaptive_margin).cos();
        let mut modified_logits = cos_data.clone();
        modified_logits[target_class] = cos_with_margin;
        let scaled: Vec<f32> = modified_logits.iter().map(|c| self.scale * c).collect();
        // Max-shifted log-sum-exp for numerical stability.
        let max_val = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let sum_exp: f32 = scaled.iter().map(|s| (s - max_val).exp()).sum();
        let log_sum_exp = max_val + sum_exp.ln();
        let nll = log_sum_exp - scaled[target_class];
        // Attach the scalar to the graph as a constant (value only).
        let scaled_var = cos_similarities.mul_scalar(self.scale);
        scaled_var.mul_scalar(0.0).add_scalar(nll)
    }
}
/// Penalizes embedding collapse: the mean pairwise dot product must not
/// exceed `target_similarity`.
pub struct DiversityRegularization {
    /// Overall scale applied to the penalty.
    pub weight: f32,
    /// Mean pairwise similarity at or below this value is not penalized.
    pub target_similarity: f32,
}
impl Default for DiversityRegularization {
fn default() -> Self {
Self {
weight: 0.01,
target_similarity: 0.1,
}
}
}
impl DiversityRegularization {
    /// Scalar penalty over `n` embeddings of `dim` features each, flattened
    /// row-major into `embeddings`. Squared hinge on the mean pairwise dot
    /// product above `target_similarity`. Returns 0.0 for fewer than two
    /// embeddings or `dim == 0`.
    ///
    /// # Panics
    /// Panics if `embeddings` holds fewer than `n * dim` values.
    pub fn compute(&self, embeddings: &[f32], n: usize, dim: usize) -> f32 {
        if n < 2 || dim == 0 {
            return 0.0;
        }
        assert!(embeddings.len() >= n * dim);
        let row = |i: usize| &embeddings[i * dim..(i + 1) * dim];
        let mut total_sim = 0.0f32;
        let mut pair_count = 0;
        for i in 0..n {
            for j in (i + 1)..n {
                let dot: f32 = row(i).iter().zip(row(j)).map(|(a, b)| a * b).sum();
                total_sim += dot;
                pair_count += 1;
            }
        }
        // n >= 2 guarantees at least one pair; keep the guard defensive.
        if pair_count == 0 {
            return 0.0;
        }
        let avg_sim = total_sim / pair_count as f32;
        let excess = (avg_sim - self.target_similarity).max(0.0);
        self.weight * excess * excess
    }
    /// Graph-tracked penalty over a slice of embedding `Variable`s; the
    /// pairwise dot products are accumulated in graph order (0,1), (0,2), …
    pub fn compute_var(&self, embeddings: &[Variable]) -> Variable {
        let n = embeddings.len();
        if n < 2 {
            return Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        }
        let mut pair_sum: Option<Variable> = None;
        let mut pair_count = 0usize;
        for i in 0..n {
            for j in (i + 1)..n {
                let dot_ij = embeddings[i].mul_var(&embeddings[j]).sum();
                pair_sum = Some(match pair_sum {
                    Some(acc) => acc.add_var(&dot_ij),
                    None => dot_ij,
                });
                pair_count += 1;
            }
        }
        let avg_sim = pair_sum
            .expect("n >= 2 guarantees at least one pair")
            .mul_scalar(1.0 / pair_count as f32);
        let excess = avg_sim.add_scalar(-self.target_similarity).relu();
        let penalty = excess.mul_var(&excess);
        penalty.mul_scalar(self.weight)
    }
}
/// Anti-spoofing loss: BCE on the liveness score plus penalties for
/// too-smooth and too-static trajectory deltas (replay-like motion).
pub struct LivenessLoss {
    /// Weight of the smoothness (autocorrelation) penalty.
    pub smoothness_weight: f32,
    /// Weight of the temporal-variance penalty.
    pub variance_weight: f32,
    /// Weight of the BCE classification term.
    pub classification_weight: f32,
    /// Lag-1 autocorrelation above this value is penalized.
    pub target_smoothness: f32,
    /// Temporal variance below this value is penalized.
    pub target_variance: f32,
}
impl Default for LivenessLoss {
fn default() -> Self {
Self {
smoothness_weight: 0.3,
variance_weight: 0.3,
classification_weight: 1.0,
target_smoothness: 0.5,
target_variance: 0.05,
}
}
}
impl LivenessLoss {
    /// Lag-1 autocorrelation of `signal`, clamped to [-1, 1].
    ///
    /// Returns 0.0 for signals shorter than 3 samples and 1.0 for
    /// (near-)constant signals, which are treated as maximally smooth.
    fn autocorrelation(signal: &[f32]) -> f32 {
        let n = signal.len();
        if n < 3 {
            return 0.0;
        }
        let mean: f32 = signal.iter().sum::<f32>() / n as f32;
        let mut var = 0.0f32;
        let mut cov = 0.0f32;
        for i in 0..n {
            let centered = signal[i] - mean;
            var += centered * centered;
            if i < n - 1 {
                cov += centered * (signal[i + 1] - mean);
            }
        }
        if var < 1e-10 {
            // Constant signal: define it as perfectly autocorrelated.
            return 1.0;
        }
        (cov / var).clamp(-1.0, 1.0)
    }
    /// Binary cross-entropy with the probability clamped away from 0/1.
    fn bce(predicted: f32, target: f32) -> f32 {
        let p = predicted.clamp(1e-7, 1.0 - 1e-7);
        -(target * p.ln() + (1.0 - target) * (1.0 - p).ln())
    }
    /// Squared excess of lag-1 autocorrelation over `target_smoothness`
    /// (too-smooth trajectories look like replays).
    fn smoothness_loss(&self, trajectory_deltas: &[f32]) -> f32 {
        let autocorr = Self::autocorrelation(trajectory_deltas);
        let excess = (autocorr - self.target_smoothness).max(0.0);
        excess * excess
    }
    /// Squared deficit of the sample variance (Bessel-corrected) of the
    /// deltas below `target_variance` (live motion should vary).
    fn variance_loss(&self, trajectory_deltas: &[f32]) -> f32 {
        let n = trajectory_deltas.len();
        let temporal_var = if n > 1 {
            let mean: f32 = trajectory_deltas.iter().sum::<f32>() / n as f32;
            trajectory_deltas
                .iter()
                .map(|d| (d - mean) * (d - mean))
                .sum::<f32>()
                / (n - 1) as f32
        } else {
            0.0
        };
        let deficit = (self.target_variance - temporal_var).max(0.0);
        deficit * deficit
    }
    /// Scalar liveness loss: weighted BCE on the score plus smoothness and
    /// variance penalties on the trajectory deltas.
    pub fn compute(&self, liveness_score: f32, is_live: bool, trajectory_deltas: &[f32]) -> f32 {
        let target = if is_live { 1.0 } else { 0.0 };
        let bce = Self::bce(liveness_score, target);
        self.classification_weight * bce
            + self.smoothness_weight * self.smoothness_loss(trajectory_deltas)
            + self.variance_weight * self.variance_loss(trajectory_deltas)
    }
    /// Graph-attached variant. The whole loss is evaluated from the detached
    /// score data and attached as a constant (`score * 0 + loss`), so no
    /// gradient flows — NOTE(review): confirm this placeholder behavior is
    /// intended. (A dead `_ln_score` graph expression that only allocated
    /// autograd nodes has been removed.)
    pub fn compute_var(
        &self,
        liveness_score: &Variable,
        is_live: bool,
        trajectory_deltas: &[f32],
    ) -> Variable {
        let target_val = if is_live { 1.0 } else { 0.0 };
        let liveness_data = liveness_score.data().to_vec();
        // Empty score tensor falls back to a maximally-uncertain 0.5.
        let p = if liveness_data.is_empty() {
            0.5
        } else {
            liveness_data[0]
        };
        let scalar_component = self.classification_weight * Self::bce(p, target_val)
            + self.smoothness_weight * self.smoothness_loss(trajectory_deltas)
            + self.variance_weight * self.variance_loss(trajectory_deltas);
        liveness_score.mul_scalar(0.0).add_scalar(scalar_component)
    }
}
/// Unit tests for all loss types. Fixed an HTML-entity corruption where
/// `&centers` / `&center` had been mangled into `¢ers` / `¢er`,
/// which did not compile.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_triplet_loss_same() {
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![1.0, 0.0, 0.0];
        let negative = vec![0.0, 1.0, 0.0];
        let loss = triplet_loss_raw(&anchor, &positive, &negative, 0.3);
        assert!(
            loss < 0.01,
            "Loss should be ~0 when positive is identical: {}",
            loss
        );
    }

    #[test]
    fn test_triplet_loss_violation() {
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.0, 1.0, 0.0];
        let negative = vec![0.9, 0.1, 0.0];
        let loss = triplet_loss_raw(&anchor, &positive, &negative, 0.3);
        assert!(loss > 0.0, "Loss should be positive when margin violated");
    }

    #[test]
    fn test_triplet_loss_var_graph_tracked() {
        let anchor = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let positive = Variable::new(
            Tensor::from_vec(vec![0.9, 0.1, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let negative = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let loss = triplet_loss_var(&anchor, &positive, &negative, 0.3);
        let loss_val = loss.data().to_vec()[0];
        assert!(loss_val < 0.5, "Triplet loss should be low: {}", loss_val);
    }

    #[test]
    fn test_crystallization_loss() {
        let loss_fn = CrystallizationLoss::default();
        let anchor = vec![1.0, 0.0];
        let positive = vec![0.9, 0.1];
        let negative = vec![-1.0, 0.0];
        let velocities = vec![0.05, 0.08];
        let loss = loss_fn.compute(&anchor, &positive, &negative, &velocities);
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_crystallization_loss_var() {
        let loss_fn = CrystallizationLoss::default();
        let anchor = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[1, 2]).unwrap(), true);
        let positive = Variable::new(Tensor::from_vec(vec![0.9, 0.1], &[1, 2]).unwrap(), true);
        let negative = Variable::new(Tensor::from_vec(vec![-1.0, 0.0], &[1, 2]).unwrap(), true);
        let velocities = Variable::new(Tensor::from_vec(vec![0.05, 0.08], &[2]).unwrap(), false);
        let loss = loss_fn.compute_var(&anchor, &positive, &negative, &velocities);
        let val = loss.data().to_vec()[0];
        assert!(val >= 0.0, "Loss should be non-negative: {}", val);
    }

    #[test]
    fn test_crystallization_high_velocity_penalty() {
        let loss_fn = CrystallizationLoss::default();
        let anchor = vec![1.0, 0.0];
        let positive = vec![1.0, 0.0];
        let negative = vec![0.0, 1.0];
        let low_vel = loss_fn.compute(&anchor, &positive, &negative, &[0.05]);
        let high_vel = loss_fn.compute(&anchor, &positive, &negative, &[0.5]);
        assert!(high_vel > low_vel, "High velocity should be penalized more");
    }

    #[test]
    fn test_contrastive_loss_same() {
        let loss_fn = ContrastiveLoss::default();
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let loss = loss_fn.compute(&a, &b, true);
        assert!(loss < 0.01, "Same identity should have ~0 loss: {}", loss);
    }

    #[test]
    fn test_contrastive_loss_different() {
        let loss_fn = ContrastiveLoss::default();
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        let loss = loss_fn.compute(&a, &b, false);
        assert!(loss < 0.01, "Well-separated different identity: {}", loss);
    }

    #[test]
    fn test_contrastive_loss_var() {
        let loss_fn = ContrastiveLoss::default();
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let b = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let loss = loss_fn.compute_var(&a, &b, true);
        let val = loss.data().to_vec()[0];
        assert!(val < 0.01, "Same identity var loss should be ~0: {}", val);
    }

    #[test]
    fn test_echo_prediction_loss() {
        let predicted = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 1, 3]).unwrap(),
            false,
        );
        let actual = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 1, 3]).unwrap(),
            false,
        );
        let loss = EchoLoss::prediction_loss(&predicted, &actual);
        assert!(
            loss < 0.01,
            "Perfect prediction should have ~0 loss: {}",
            loss
        );
    }

    #[test]
    fn test_echo_loss_var() {
        let predicted = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            true,
        );
        let actual = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let anchor = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[1, 2]).unwrap(), true);
        let pos = Variable::new(Tensor::from_vec(vec![0.9, 0.1], &[1, 2]).unwrap(), true);
        let neg = Variable::new(Tensor::from_vec(vec![0.0, 1.0], &[1, 2]).unwrap(), true);
        let loss_fn = EchoLoss::default();
        let loss = loss_fn.compute_var(&predicted, &actual, &anchor, &pos, &neg);
        let val = loss.data().to_vec()[0];
        assert!(val >= 0.0, "Echo loss should be non-negative: {}", val);
    }

    #[test]
    fn test_argus_loss_var() {
        let anchor = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let pos = Variable::new(
            Tensor::from_vec(vec![0.9, 0.1, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let neg = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let orig = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.5], &[1, 3]).unwrap(),
            true,
        );
        let rot = Variable::new(
            Tensor::from_vec(vec![0.9, 0.1, 0.5], &[1, 3]).unwrap(),
            true,
        );
        let loss_fn = ArgusLoss::default();
        let loss = loss_fn.compute_var(&anchor, &pos, &neg, &orig, &rot);
        let val = loss.data().to_vec()[0];
        assert!(val >= 0.0, "Argus loss should be non-negative: {}", val);
    }

    #[test]
    fn test_themis_bce() {
        let loss = ThemisLoss::bce(0.99, 1.0);
        assert!(loss < 0.1, "High-confidence correct: {}", loss);
        let loss = ThemisLoss::bce(0.01, 1.0);
        assert!(loss > 2.0, "Low-confidence for match: {}", loss);
    }

    #[test]
    fn test_themis_combined() {
        let loss_fn = ThemisLoss::default();
        let fused_a = vec![1.0, 0.0, 0.0];
        let fused_p = vec![0.9, 0.1, 0.0];
        let fused_n = vec![-1.0, 0.0, 0.0];
        let loss = loss_fn.compute(0.9, true, &fused_a, &fused_p, &fused_n, 0.9);
        assert!(loss >= 0.0);
        assert!(loss < 5.0, "Combined loss should be reasonable: {}", loss);
    }

    #[test]
    fn test_themis_loss_var() {
        let loss_fn = ThemisLoss::default();
        let anchor = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let pos = Variable::new(
            Tensor::from_vec(vec![0.9, 0.1, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let neg = Variable::new(
            Tensor::from_vec(vec![-1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let loss = loss_fn.compute_var(0.9, true, &anchor, &pos, &neg, 0.9);
        let val = loss.data().to_vec()[0];
        assert!(
            val >= 0.0,
            "Themis var loss should be non-negative: {}",
            val
        );
    }

    #[test]
    fn test_center_loss_zero_distance() {
        let loss_fn = CenterLoss::default();
        let embeddings = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let centers = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let log_variances = vec![0.0, 0.0];
        // Was `¢ers` (corrupted `&centers`) before the encoding fix.
        let loss = loss_fn.compute(&embeddings, &centers, &log_variances, 3);
        assert!(loss < 1e-6, "Zero distance should give zero loss: {}", loss);
    }

    #[test]
    fn test_center_loss_nonzero_distance() {
        let loss_fn = CenterLoss::default();
        let embeddings = vec![1.0, 0.0, 0.0];
        let centers = vec![0.0, 1.0, 0.0];
        let log_variances = vec![0.0];
        let loss = loss_fn.compute(&embeddings, &centers, &log_variances, 3);
        assert!(
            loss > 0.0,
            "Non-zero distance should give positive loss: {}",
            loss
        );
    }

    #[test]
    fn test_center_loss_uncertainty_attenuation() {
        let loss_fn = CenterLoss::default();
        let embeddings = vec![1.0, 0.0, 0.0];
        let centers = vec![0.0, 1.0, 0.0];
        let loss_confident = loss_fn.compute(&embeddings, &centers, &[-2.0], 3);
        let loss_uncertain = loss_fn.compute(&embeddings, &centers, &[2.0], 3);
        assert!(
            loss_confident > loss_uncertain,
            "Uncertain samples should have lower center loss: confident={}, uncertain={}",
            loss_confident,
            loss_uncertain
        );
    }

    #[test]
    fn test_center_loss_var() {
        let loss_fn = CenterLoss::default();
        let emb = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let center = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0, 0.0], &[1, 3]).unwrap(),
            false,
        );
        let lv = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute_var(&emb, &center, &lv);
        let val = loss.data().to_vec()[0];
        assert!(
            val >= 0.0,
            "Center var loss should be non-negative: {}",
            val
        );
        assert!(
            val > 0.0,
            "Non-zero distance should produce positive loss: {}",
            val
        );
    }

    #[test]
    fn test_center_loss_empty() {
        let loss_fn = CenterLoss::default();
        let loss = loss_fn.compute(&[], &[], &[], 3);
        assert_eq!(loss, 0.0, "Empty inputs should give zero loss");
    }

    #[test]
    fn test_angular_margin_loss_correct_class() {
        let loss_fn = AngularMarginLoss::default();
        let cos_sims = vec![0.95, 0.1, -0.2, 0.05];
        let loss = loss_fn.compute(&cos_sims, 0, 0.0);
        assert!(loss >= 0.0, "Loss should be non-negative: {}", loss);
        assert!(
            loss < 5.0,
            "Loss should be reasonable for correct prediction: {}",
            loss
        );
    }

    #[test]
    fn test_angular_margin_loss_wrong_class() {
        let loss_fn = AngularMarginLoss::default();
        let cos_sims = vec![0.1, 0.95, -0.2, 0.05];
        let loss = loss_fn.compute(&cos_sims, 0, 0.0);
        assert!(
            loss > 5.0,
            "Loss should be high for wrong prediction: {}",
            loss
        );
    }

    #[test]
    fn test_angular_margin_uncertainty_scaling() {
        let loss_fn = AngularMarginLoss::default();
        let cos_sims = vec![0.7, 0.3, 0.1];
        let loss_confident = loss_fn.compute(&cos_sims, 0, -2.0);
        let loss_uncertain = loss_fn.compute(&cos_sims, 0, 2.0);
        assert!(
            loss_confident > loss_uncertain,
            "Confident samples should face larger margin: confident={}, uncertain={}",
            loss_confident,
            loss_uncertain
        );
    }

    #[test]
    fn test_angular_margin_loss_var() {
        let loss_fn = AngularMarginLoss::default();
        let cos_var = Variable::new(Tensor::from_vec(vec![0.9, 0.1, -0.1], &[3]).unwrap(), true);
        let lv_var = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute_var(&cos_var, 0, &lv_var);
        let val = loss.data().to_vec()[0];
        assert!(
            val >= 0.0,
            "Angular margin var loss should be non-negative: {}",
            val
        );
    }

    #[test]
    fn test_angular_margin_empty() {
        let loss_fn = AngularMarginLoss::default();
        let loss = loss_fn.compute(&[], 0, 0.0);
        assert_eq!(loss, 0.0, "Empty input should give zero loss");
    }

    #[test]
    fn test_diversity_collapsed_embeddings() {
        let loss_fn = DiversityRegularization::default();
        let embeddings = vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let loss = loss_fn.compute(&embeddings, 3, 3);
        assert!(
            loss > 0.0,
            "Collapsed embeddings should be penalized: {}",
            loss
        );
    }

    #[test]
    fn test_diversity_orthogonal_embeddings() {
        let loss_fn = DiversityRegularization::default();
        let embeddings = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let loss = loss_fn.compute(&embeddings, 3, 3);
        assert!(
            loss < 1e-6,
            "Orthogonal embeddings should have ~0 penalty: {}",
            loss
        );
    }

    #[test]
    fn test_diversity_single_embedding() {
        let loss_fn = DiversityRegularization::default();
        let loss = loss_fn.compute(&[1.0, 0.0], 1, 2);
        assert_eq!(loss, 0.0, "Single embedding cannot collapse");
    }

    #[test]
    fn test_diversity_var_collapsed() {
        let loss_fn = DiversityRegularization::default();
        let emb1 = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let emb2 = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let emb3 = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let loss = loss_fn.compute_var(&[emb1, emb2, emb3]);
        let val = loss.data().to_vec()[0];
        assert!(
            val > 0.0,
            "Collapsed embeddings should have positive penalty: {}",
            val
        );
    }

    #[test]
    fn test_diversity_var_diverse() {
        let loss_fn = DiversityRegularization::default();
        let emb1 = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let emb2 = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0, 0.0], &[1, 3]).unwrap(),
            true,
        );
        let emb3 = Variable::new(
            Tensor::from_vec(vec![0.0, 0.0, 1.0], &[1, 3]).unwrap(),
            true,
        );
        let loss = loss_fn.compute_var(&[emb1, emb2, emb3]);
        let val = loss.data().to_vec()[0];
        assert!(
            val < 1e-6,
            "Diverse embeddings should have ~0 penalty: {}",
            val
        );
    }

    #[test]
    fn test_liveness_loss_live_sample() {
        let loss_fn = LivenessLoss::default();
        let deltas = vec![0.15, 0.03, 0.22, 0.08, 0.18, 0.05, 0.25, 0.02];
        let loss = loss_fn.compute(0.9, true, &deltas);
        assert!(loss >= 0.0, "Loss should be non-negative: {}", loss);
    }

    #[test]
    fn test_liveness_loss_spoof_sample() {
        let loss_fn = LivenessLoss::default();
        let deltas = vec![0.01, 0.011, 0.012, 0.011, 0.01, 0.011, 0.012, 0.011];
        let loss_spoof = loss_fn.compute(0.1, false, &deltas);
        assert!(
            loss_spoof >= 0.0,
            "Loss should be non-negative: {}",
            loss_spoof
        );
    }

    #[test]
    fn test_liveness_smooth_trajectory_penalty() {
        let loss_fn = LivenessLoss::default();
        let smooth = vec![0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17];
        let irregular = vec![0.2, 0.05, 0.25, 0.02, 0.18, 0.08, 0.22, 0.03];
        let loss_smooth = loss_fn.compute(0.5, true, &smooth);
        let loss_irregular = loss_fn.compute(0.5, true, &irregular);
        assert!(
            loss_smooth > loss_irregular,
            "Smooth trajectory should be penalized more: smooth={}, irregular={}",
            loss_smooth,
            loss_irregular
        );
    }

    #[test]
    fn test_liveness_low_variance_penalty() {
        let loss_fn = LivenessLoss::default();
        let low_var = vec![0.1, 0.1, 0.1, 0.1, 0.1, 0.1];
        let high_var = vec![0.01, 0.5, 0.02, 0.4, 0.03, 0.6];
        let loss_low = loss_fn.compute(0.5, true, &low_var);
        let loss_high = loss_fn.compute(0.5, true, &high_var);
        assert!(
            loss_low > loss_high,
            "Low variance should be penalized more: low={}, high={}",
            loss_low,
            loss_high
        );
    }

    #[test]
    fn test_liveness_autocorrelation() {
        let constant = vec![0.5, 0.5, 0.5, 0.5, 0.5];
        let ac = LivenessLoss::autocorrelation(&constant);
        assert!(
            (ac - 1.0).abs() < 0.01,
            "Constant signal autocorrelation should be ~1: {}",
            ac
        );
        let short = vec![1.0, 2.0];
        let ac_short = LivenessLoss::autocorrelation(&short);
        assert_eq!(ac_short, 0.0, "Too-short signal should return 0");
    }

    #[test]
    fn test_liveness_loss_var() {
        let loss_fn = LivenessLoss::default();
        let score = Variable::new(Tensor::from_vec(vec![0.8], &[1]).unwrap(), true);
        let deltas = vec![0.15, 0.03, 0.22, 0.08, 0.18];
        let loss = loss_fn.compute_var(&score, true, &deltas);
        let val = loss.data().to_vec()[0];
        assert!(
            val >= 0.0,
            "Liveness var loss should be non-negative: {}",
            val
        );
    }

    #[test]
    fn test_liveness_bce_correct() {
        let loss = LivenessLoss::bce(0.99, 1.0);
        assert!(
            loss < 0.1,
            "Confident correct should have low BCE: {}",
            loss
        );
    }

    #[test]
    fn test_liveness_bce_wrong() {
        let loss = LivenessLoss::bce(0.01, 1.0);
        assert!(loss > 2.0, "Confident wrong should have high BCE: {}", loss);
    }

    #[test]
    fn test_all_losses_implement_default() {
        let _ = CrystallizationLoss::default();
        let _ = ContrastiveLoss::default();
        let _ = EchoLoss::default();
        let _ = ArgusLoss::default();
        let _ = ThemisLoss::default();
        let _ = CenterLoss::default();
        let _ = AngularMarginLoss::default();
        let _ = DiversityRegularization::default();
        let _ = LivenessLoss::default();
    }
}