use crate::tensor::Tensor;
/// Sum-reduced squared-error loss.
///
/// Forms the residual `y_pred - y_true`, multiplies it by itself and reduces
/// the product with `sum()`. Despite the name, no averaging is done here; use
/// `mse_loss_mean` for the mean-reduced variant.
pub fn mse_loss(y_pred: &Tensor, y_true: &Tensor) -> Tensor {
    let residual = y_pred - y_true;
    (&residual * &residual).sum()
}
/// Mean-reduced squared-error loss.
///
/// Same residual-squared computation as `mse_loss`, but the product is
/// reduced with `mean()` instead of `sum()`.
pub fn mse_loss_mean(y_pred: &Tensor, y_true: &Tensor) -> Tensor {
    let residual = y_pred - y_true;
    (&residual * &residual).mean()
}
/// Sum-reduced absolute-error (L1) loss.
///
/// Returns `sum(|y_pred - y_true|)`.
pub fn l1_loss(y_pred: &Tensor, y_true: &Tensor) -> Tensor {
    let abs_residual = (y_pred - y_true).abs();
    abs_residual.sum()
}
/// Mean-reduced absolute-error (L1) loss.
///
/// Returns `mean(|y_pred - y_true|)`.
pub fn l1_loss_mean(y_pred: &Tensor, y_true: &Tensor) -> Tensor {
    let abs_residual = (y_pred - y_true).abs();
    abs_residual.mean()
}
/// Sum-reduced smooth L1 loss with transition point `beta`.
///
/// With `a = |y_pred - y_true|`, the quantity that gets summed is
/// `0.5 * clamp(a, 0, beta)^2 / beta + clamp(a - beta, 0, f32::MAX)`, which is
/// `0.5 * a^2 / beta` while `a <= beta` and `a - 0.5 * beta` beyond it.
///
/// `beta == 0.0` is handled explicitly: the quadratic zone is empty, so the
/// loss falls back to the summed absolute error instead of evaluating
/// `0.5 / 0`, whose infinite result multiplied by the zero clamp would turn
/// the whole loss into NaN. Negative `beta` is not meaningful and is not
/// validated here.
pub fn smooth_l1_loss(y_pred: &Tensor, y_true: &Tensor, beta: f32) -> Tensor {
    let error = y_pred - y_true;
    let abs_error = error.abs();

    // Degenerate transition point: plain L1, and no division by zero.
    if beta == 0.0 {
        return abs_error.sum();
    }

    let beta_tensor = Tensor::scalar(&y_pred.context, beta);
    let half = Tensor::scalar(&y_pred.context, 0.5);

    // Quadratic zone: the clamp caps the contribution at 0.5 * beta once the
    // absolute error exceeds beta.
    let clamped = abs_error.clamp(0.0, beta);
    let quadratic = &(&clamped * &clamped) * &(&half / &beta_tensor);

    // Linear zone: only the part of the error above beta grows linearly.
    let excess = &abs_error - &beta_tensor;
    let linear_excess = excess.clamp(0.0, f32::MAX);

    let total = &quadratic + &linear_excess;
    total.sum()
}
pub fn huber_loss(y_pred: &Tensor, y_true: &Tensor, delta: f32) -> Tensor {
smooth_l1_loss(y_pred, y_true, delta)
}
/// Sum-reduced cross-entropy between targets and predictions.
///
/// Returns `-sum(y_true * log(y_pred + eps))`. `eps` is added to the
/// predictions before the logarithm so that a prediction of zero does not
/// reach `log(0)`. `y_pred` is presumably a probability tensor (not logits);
/// confirm against callers.
pub fn cross_entropy_loss(y_pred: &Tensor, y_true: &Tensor, eps: f32) -> Tensor {
    let floor = Tensor::scalar(&y_pred.context, eps);
    let log_probs = (y_pred + &floor).log();
    let weighted = y_true * &log_probs;
    weighted.sum().neg()
}
/// Sum-reduced cross-entropy against label-smoothed targets.
///
/// The targets are first softened to
/// `y_true * (1 - smoothing) + smoothing / num_classes`, then the loss is
/// `-sum(smoothed * log(y_pred + eps))`. `eps` keeps the logarithm away from
/// `log(0)`. `num_classes` is only used as the divisor of `smoothing`.
pub fn cross_entropy_with_label_smoothing(
    y_pred: &Tensor,
    y_true: &Tensor,
    smoothing: f32,
    num_classes: usize,
    eps: f32,
) -> Tensor {
    // Soft targets: shrink the original labels and spread the removed mass
    // evenly using a constant offset.
    let keep_weight = Tensor::scalar(&y_pred.context, 1.0 - smoothing);
    let uniform_mass = Tensor::scalar(&y_pred.context, smoothing / num_classes as f32);
    let soft_targets = &(y_true * &keep_weight) + &uniform_mass;

    let floor = Tensor::scalar(&y_pred.context, eps);
    let log_probs = (y_pred + &floor).log();

    (&soft_targets * &log_probs).sum().neg()
}
/// Sum-reduced binary cross-entropy.
///
/// Returns
/// `-sum(y_true * log(y_pred + eps) + (1 - y_true) * log(1 - y_pred + eps))`.
/// `eps` is added inside both logarithms so that predictions of exactly 0 or
/// 1 do not reach `log(0)`.
pub fn binary_cross_entropy(y_pred: &Tensor, y_true: &Tensor, eps: f32) -> Tensor {
    let floor = Tensor::scalar(&y_pred.context, eps);
    let unit = Tensor::scalar(&y_pred.context, 1.0);

    // Contribution of the positive class: y * log(p + eps).
    let log_p = (y_pred + &floor).log();
    let positive_term = y_true * &log_p;

    // Contribution of the negative class: (1 - y) * log(1 - p + eps).
    let log_not_p = (&(&unit - y_pred) + &floor).log();
    let negative_term = &(&unit - y_true) * &log_not_p;

    (&positive_term + &negative_term).sum().neg()
}
/// Sum-reduced binary cross-entropy computed directly from logits.
///
/// Uses the rearranged form
/// `max(x, 0) - x * y + log(1 + exp(-|x|))`, where `x` is `logits` and `y` is
/// `y_true`. The exponential only ever sees a non-positive argument, so it
/// cannot overflow for large-magnitude logits.
pub fn bce_with_logits(logits: &Tensor, y_true: &Tensor) -> Tensor {
    let unit = Tensor::scalar(&logits.context, 1.0);

    // max(x, 0) - x * y
    let relu_logits = logits.clamp(0.0, f32::MAX);
    let linear_part = &relu_logits - &(logits * y_true);

    // log(1 + exp(-|x|))
    let decay = logits.abs().neg().exp();
    let softplus_tail = (&unit + &decay).log();

    (&linear_part + &softplus_tail).sum()
}
/// Sum-reduced Kullback-Leibler divergence `KL(p || q)`.
///
/// Returns `sum(p * (log(p + eps) - log(q + eps)))`. The same `eps` is added
/// to both inputs before taking logarithms to avoid `log(0)`.
pub fn kl_divergence(p: &Tensor, q: &Tensor, eps: f32) -> Tensor {
    let floor = Tensor::scalar(&p.context, eps);
    let log_of_p = (p + &floor).log();
    let log_of_q = (q + &floor).log();
    let weighted_ratio = p * &(&log_of_p - &log_of_q);
    weighted_ratio.sum()
}
/// Sum-reduced negative log-likelihood.
///
/// Returns `-sum(y_true * log_probs)`. The second operand is used as-is; no
/// logarithm is applied here, so callers are expected to pass values that are
/// already log-probabilities.
pub fn nll_loss(log_probs: &Tensor, y_true: &Tensor) -> Tensor {
    (y_true * log_probs).sum().neg()
}
/// Sum-reduced hinge loss.
///
/// Returns `sum(max(margin - y_true * y_pred, 0))`, with the lower bound
/// applied through `clamp(0.0, f32::MAX)`. `y_true` is presumably a tensor of
/// `+1` / `-1` labels; confirm against callers.
pub fn hinge_loss(y_pred: &Tensor, y_true: &Tensor, margin: f32) -> Tensor {
    let margin_tensor = Tensor::scalar(&y_pred.context, margin);
    let prod = y_true * y_pred;
    // Borrow of the local `prod` (the source had this token corrupted into a
    // single non-ASCII character, which does not compile).
    let diff = &margin_tensor - &prod;
    let hinge = diff.clamp(0.0, f32::MAX);
    hinge.sum()
}
/// Sum-reduced squared hinge loss.
///
/// Returns `sum(max(margin - y_true * y_pred, 0)^2)`, with the lower bound
/// applied through `clamp(0.0, f32::MAX)`. `y_true` is presumably a tensor of
/// `+1` / `-1` labels; confirm against callers.
pub fn squared_hinge_loss(y_pred: &Tensor, y_true: &Tensor, margin: f32) -> Tensor {
    let margin_tensor = Tensor::scalar(&y_pred.context, margin);
    let prod = y_true * y_pred;
    // Borrow of the local `prod` (the source had this token corrupted into a
    // single non-ASCII character, which does not compile).
    let diff = &margin_tensor - &prod;
    let hinge = diff.clamp(0.0, f32::MAX);
    let squared = &hinge * &hinge;
    squared.sum()
}
/// Sum-reduced focal loss.
///
/// With `p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)`, returns
/// `-sum(alpha * (1 - p_t + eps)^gamma * log(p_t + eps))`.
///
/// The power is evaluated as `exp(gamma * log(1 - p_t + eps))`. `eps` keeps
/// both logarithms away from `log(0)`.
pub fn focal_loss(y_pred: &Tensor, y_true: &Tensor, alpha: f32, gamma: f32, eps: f32) -> Tensor {
    let ctx = &y_pred.context;
    let alpha_scale = Tensor::scalar(ctx, alpha);
    let unit = Tensor::scalar(ctx, 1.0);
    let floor = Tensor::scalar(ctx, eps);

    // p_t: y * p + (1 - y) * (1 - p).
    let not_true = &unit - y_true;
    let not_pred = &unit - y_pred;
    let p_t = &(y_true * y_pred) + &(&not_true * &not_pred);

    // Modulating factor (1 - p_t)^gamma, computed in log space.
    let log_complement = (&(&unit - &p_t) + &floor).log();
    let exponent = Tensor::scalar(ctx, gamma);
    let modulator = (&exponent * &log_complement).exp();

    let log_p_t = (&p_t + &floor).log();
    (&(&alpha_scale * &modulator) * &log_p_t).sum().neg()
}
/// Cosine embedding loss for a pair of inputs and a target `y`.
///
/// The cosine similarity is computed as
/// `cos = sum(x1 * x2) / (sqrt(sum(x1 * x1)) * sqrt(sum(x2 * x2)) + eps)`,
/// where `eps` guards the division when either norm is zero.
///
/// The returned value is
/// `0.5 * (1 + y) * (1 - cos) + 0.5 * (1 - y) * max(cos - margin, 0)`.
/// For `y = 1` only the first term is active and for `y = -1` only the
/// second. No further reduction is applied to the result.
pub fn cosine_embedding_loss(
    x1: &Tensor,
    x2: &Tensor,
    y: &Tensor,
    margin: f32,
    eps: f32,
) -> Tensor {
    let one = Tensor::scalar(&x1.context, 1.0);
    let margin_tensor = Tensor::scalar(&x1.context, margin);
    let eps_tensor = Tensor::scalar(&x1.context, eps);

    // Cosine similarity from the dot product and the two L2 norms.
    let dot_sum = (x1 * x2).sum();
    let norm1 = (x1 * x1).sum().sqrt();
    let norm2 = (x2 * x2).sum().sqrt();
    let norm_prod = &(&norm1 * &norm2) + &eps_tensor;
    let cos_sim = &dot_sum / &norm_prod;

    let half = Tensor::scalar(&x1.context, 0.5);
    let one_plus_y = &one + y;
    let one_minus_y = &one - y;

    let pos_loss = &one - &cos_sim;
    let neg_diff = &cos_sim - &margin_tensor;
    let neg_loss = neg_diff.clamp(0.0, f32::MAX);

    // (1 + y) / 2 and (1 - y) / 2 select the branch that matches the target.
    // Both lines borrow the local `half`; the source had that token corrupted
    // into a single non-ASCII character, which does not compile.
    let pos_weight = &one_plus_y * &half;
    let neg_weight = &one_minus_y * &half;

    let weighted_pos = &pos_weight * &pos_loss;
    let weighted_neg = &neg_weight * &neg_loss;
    &weighted_pos + &weighted_neg
}
/// Triplet margin loss on squared distances.
///
/// Returns `max(d(anchor, positive) - d(anchor, negative) + margin, 0)`,
/// where `d(a, b) = sum((a - b) * (a - b))`. Note that no square root is
/// taken, so the distances are squared L2 distances. The lower bound is
/// applied through `clamp(0.0, f32::MAX)` and no further reduction follows.
pub fn triplet_margin_loss(
    anchor: &Tensor,
    positive: &Tensor,
    negative: &Tensor,
    margin: f32,
) -> Tensor {
    let margin_term = Tensor::scalar(&anchor.context, margin);

    let to_positive = anchor - positive;
    let sq_dist_positive = (&to_positive * &to_positive).sum();

    let to_negative = anchor - negative;
    let sq_dist_negative = (&to_negative * &to_negative).sum();

    let gap = &sq_dist_positive - &sq_dist_negative;
    (&gap + &margin_term).clamp(0.0, f32::MAX)
}
/// Sum-reduced margin ranking loss.
///
/// Returns `sum(max(-y * (x1 - x2) + margin, 0))`, with the lower bound
/// applied through `clamp(0.0, f32::MAX)`. `y` is presumably `+1` when `x1`
/// should rank higher and `-1` when `x2` should; confirm against callers.
pub fn margin_ranking_loss(x1: &Tensor, x2: &Tensor, y: &Tensor, margin: f32) -> Tensor {
    let margin_term = Tensor::scalar(&x1.context, margin);
    let gap = x1 - x2;
    let signed_gap = &y.neg() * &gap;
    (&signed_gap + &margin_term).clamp(0.0, f32::MAX).sum()
}