use axonml_autograd::Variable;
use axonml_tensor::Tensor;
/// Focal loss for dense binary classification (Lin et al., "Focal Loss for
/// Dense Object Detection"). Down-weights well-classified examples with a
/// `(1 - p_t)^gamma` modulating factor and balances classes with `alpha`.
pub struct FocalLoss {
// Class-balance weight applied to positive targets (`1 - alpha` to negatives).
pub alpha: f32,
// Focusing exponent; larger values down-weight easy examples more strongly.
pub gamma: f32,
}
impl FocalLoss {
    /// Creates a focal loss with the defaults from the paper:
    /// `alpha = 0.25`, `gamma = 2.0`.
    pub fn new() -> Self {
        Self::with_params(0.25, 2.0)
    }

    /// Creates a focal loss with explicit `alpha` (class-balance weight)
    /// and `gamma` (focusing exponent).
    pub fn with_params(alpha: f32, gamma: f32) -> Self {
        Self { alpha, gamma }
    }

    /// Computes the mean focal loss over all elements.
    ///
    /// `pred_logits` are raw (pre-sigmoid) scores; `targets` are 0/1 labels
    /// of the same shape. Implements
    /// `-alpha_t * (1 - p_t)^gamma * log(p_t + eps)`, averaged over elements,
    /// where `p_t` is the probability of the true class.
    pub fn compute(&self, pred_logits: &Variable, targets: &Variable) -> Variable {
        let probs = pred_logits.sigmoid();

        // Constant tensor of ones matching the prediction shape (no gradient).
        let ones = Variable::new(
            Tensor::from_vec(vec![1.0; pred_logits.numel()], &pred_logits.shape()).unwrap(),
            false,
        );

        // p_t = p * y + (1 - p) * (1 - y)
        let not_probs = ones.sub_var(&probs);
        let not_targets = ones.sub_var(targets);
        let p_t = probs
            .mul_var(targets)
            .add_var(&not_probs.mul_var(&not_targets));

        // alpha_t = alpha * y + (1 - alpha) * (1 - y), built as a constant:
        // no gradient flows through the class weights.
        let weights: Vec<f32> = targets
            .data()
            .to_vec()
            .into_iter()
            .map(|t| self.alpha * t + (1.0 - self.alpha) * (1.0 - t))
            .collect();
        let alpha_t = Variable::new(Tensor::from_vec(weights, &targets.shape()).unwrap(), false);

        // (1 - p_t)^gamma modulating factor that suppresses easy examples.
        let modulator = ones.sub_var(&p_t).pow(self.gamma);

        // Small additive constant keeps log() finite when p_t is exactly 0.
        let eps = Variable::new(
            Tensor::from_vec(vec![1e-7; pred_logits.numel()], &pred_logits.shape()).unwrap(),
            false,
        );
        let log_p_t = p_t.add_var(&eps).log();

        alpha_t
            .mul_var(&modulator)
            .mul_var(&log_p_t)
            .neg_var()
            .mean()
    }
}
impl Default for FocalLoss {
/// Returns a `FocalLoss` with the standard defaults (`alpha = 0.25`, `gamma = 2.0`).
fn default() -> Self {
Self::new()
}
}
/// Generalized IoU (GIoU) loss for axis-aligned bounding-box regression
/// (Rezatofighi et al.). Stateless; use the associated `compute` function.
pub struct GIoULoss;
impl GIoULoss {
    /// Computes a GIoU-based regression loss for `[n, 4]` boxes in
    /// `(x1, y1, x2, y2)` format.
    ///
    /// The GIoU term itself is evaluated on raw tensor data, outside the
    /// autograd graph. To keep gradients flowing, the differentiable mean
    /// squared error between `pred` and `target` is used as a proxy and
    /// rescaled so its forward value equals the GIoU loss `1 - mean(GIoU)`.
    pub fn compute(pred: &Variable, target: &Variable) -> Variable {
        let n = pred.shape()[0];
        if n == 0 {
            // No boxes: zero loss, no gradient.
            return Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        }

        let pred_boxes = pred.data().to_vec();
        let target_boxes = target.data().to_vec();

        let mut total_giou = 0.0f32;
        for (pb, tb) in pred_boxes
            .chunks_exact(4)
            .zip(target_boxes.chunks_exact(4))
            .take(n)
        {
            let (px1, py1, px2, py2) = (pb[0], pb[1], pb[2], pb[3]);
            let (tx1, ty1, tx2, ty2) = (tb[0], tb[1], tb[2], tb[3]);

            // Intersection rectangle (width/height clamped to 0 when disjoint).
            let iw = (px2.min(tx2) - px1.max(tx1)).max(0.0);
            let ih = (py2.min(ty2) - py1.max(ty1)).max(0.0);
            let inter = iw * ih;

            let pred_area = (px2 - px1).max(0.0) * (py2 - py1).max(0.0);
            let target_area = (tx2 - tx1).max(0.0) * (ty2 - ty1).max(0.0);
            let union = pred_area + target_area - inter;

            // Smallest axis-aligned box enclosing both boxes.
            let cw = (px2.max(tx2) - px1.min(tx1)).max(0.0);
            let ch = (py2.max(ty2) - py1.min(ty1)).max(0.0);
            let enclosing = cw * ch;

            let iou = if union > 0.0 { inter / union } else { 0.0 };
            // GIoU = IoU - (|C| - |A ∪ B|) / |C|, in [-1, 1].
            total_giou += if enclosing > 0.0 {
                iou - (enclosing - union) / enclosing
            } else {
                iou
            };
        }

        // Differentiable proxy, rescaled to carry the GIoU loss value forward.
        let proxy = pred.sub_var(target).pow(2.0).mean();
        let giou_loss = 1.0 - total_giou / n as f32;
        let proxy_val = proxy.data().to_vec()[0];
        let scale = if proxy_val > 1e-8 {
            giou_loss / proxy_val
        } else {
            1.0
        };
        proxy.mul_scalar(scale)
    }
}
/// Heteroscedastic-uncertainty regression loss (Kendall & Gal, "What
/// Uncertainties Do We Need in Bayesian Deep Learning?"). Stateless; use
/// the associated `compute` function.
pub struct UncertaintyLoss;
impl UncertaintyLoss {
    /// Gaussian negative log-likelihood with a learned per-element log
    /// variance: mean over elements of
    /// `0.5 * exp(-log_var) * (mean - target)^2 + 0.5 * log_var`.
    ///
    /// The first term weights errors by the predicted precision; the second
    /// regularizes against predicting unbounded uncertainty.
    pub fn compute(pred_mean: &Variable, pred_log_var: &Variable, target: &Variable) -> Variable {
        // exp(-log_var) is the precision (inverse variance).
        let precision = pred_log_var.neg_var().exp();
        let squared_err = pred_mean.sub_var(target).pow(2.0);
        let data_term = precision.mul_var(&squared_err).mul_scalar(0.5);
        let reg_term = pred_log_var.mul_scalar(0.5);
        data_term.add_var(&reg_term).mean()
    }
}
/// Computes the FCOS-style "centerness" score for a location inside a box.
///
/// `l`, `t`, `r`, `b` are the distances from the location to the left, top,
/// right, and bottom box edges. Returns
/// `sqrt(min(l,r)/max(l,r) * min(t,b)/max(t,b))`: 1.0 at the box center,
/// approaching 0.0 near an edge, and 0.0 when all distances are 0.
///
/// Negative distances (location outside the box) previously produced a
/// negative ratio and hence `sqrt` of a negative number (NaN); each ratio is
/// now clamped at 0 so the result is always a finite value in `[0, 1]`.
pub fn compute_centerness(l: f32, t: f32, r: f32, b: f32) -> f32 {
    // min/max ratio of a pair of edge distances, clamped into [0, 1].
    let ratio = |x: f32, y: f32| {
        let hi = x.max(y);
        if hi > 0.0 {
            (x.min(y) / hi).max(0.0)
        } else {
            0.0
        }
    };
    (ratio(l, r) * ratio(t, b)).sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `Variable` from raw data.
    fn var(data: Vec<f32>, shape: &[usize], requires_grad: bool) -> Variable {
        Variable::new(Tensor::from_vec(data, shape).unwrap(), requires_grad)
    }

    #[test]
    fn test_focal_loss_basic() {
        let pred = var(vec![0.0, 0.0, 0.0, 0.0], &[4], true);
        let target = var(vec![1.0, 0.0, 1.0, 0.0], &[4], false);
        let val = FocalLoss::new().compute(&pred, &target).data().to_vec()[0];
        assert!(val > 0.0, "Focal loss should be positive, got {val}");
        assert!(val.is_finite());
    }

    #[test]
    fn test_focal_loss_gradient() {
        let pred = var(vec![0.5, -0.5], &[2], true);
        let target = var(vec![1.0, 0.0], &[2], false);
        FocalLoss::new().compute(&pred, &target).backward();
        let grad = pred.grad().expect("Should have gradient");
        assert_eq!(grad.to_vec().len(), 2);
    }

    #[test]
    fn test_giou_loss_identical() {
        let boxes = var(vec![10.0, 10.0, 50.0, 50.0], &[1, 4], true);
        let target = var(vec![10.0, 10.0, 50.0, 50.0], &[1, 4], false);
        let val = GIoULoss::compute(&boxes, &target).data().to_vec()[0];
        assert!(
            val < 0.01,
            "Identical boxes should have near-zero loss, got {val}"
        );
    }

    #[test]
    fn test_giou_loss_disjoint() {
        let pred = var(vec![0.0, 0.0, 10.0, 10.0], &[1, 4], true);
        let target = var(vec![50.0, 50.0, 60.0, 60.0], &[1, 4], false);
        let val = GIoULoss::compute(&pred, &target).data().to_vec()[0];
        assert!(
            val > 0.5,
            "Disjoint boxes should have large loss, got {val}"
        );
    }

    #[test]
    fn test_uncertainty_loss() {
        let pred = var(vec![1.0, 2.0, 3.0, 4.0], &[1, 4], true);
        let log_var = var(vec![0.0, 0.0, 0.0, 0.0], &[1, 4], true);
        let target = var(vec![1.5, 2.5, 3.5, 4.5], &[1, 4], false);
        let loss = UncertaintyLoss::compute(&pred, &log_var, &target);
        let val = loss.data().to_vec()[0];
        assert!(val > 0.0);
        assert!(val.is_finite());
        // Both the mean head and the log-variance head must receive gradients.
        loss.backward();
        assert!(pred.grad().is_some());
        assert!(log_var.grad().is_some());
    }

    #[test]
    fn test_centerness() {
        // Perfectly centered location scores 1.
        assert!((compute_centerness(5.0, 5.0, 5.0, 5.0) - 1.0).abs() < 1e-5);
        // A location on the left edge scores (near) 0.
        assert!(compute_centerness(0.0, 5.0, 10.0, 5.0) < 0.01);
        // Off-center locations score strictly between 0 and 1.
        let c = compute_centerness(2.0, 3.0, 8.0, 7.0);
        assert!(c > 0.0 && c < 1.0);
    }
}