#![allow(missing_docs)]
use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use super::HeliosTrainOutput;
/// Complete-IoU (CIoU) loss for axis-aligned boxes in `[x1, y1, x2, y2]` form.
///
/// CIoU extends plain IoU with a normalized center-distance penalty and an
/// aspect-ratio consistency term: `ciou = iou - d²/c² - α·v`.
pub struct CIoULoss;

impl CIoULoss {
    /// Mean CIoU loss over `n` box pairs; `pred` and `target` are `[n, 4]`.
    ///
    /// The CIoU value is computed outside the autograd graph and attached to
    /// it by rescaling a differentiable L2 proxy (`mean((pred - target)²)`)
    /// so the returned scalar equals the CIoU loss while gradients follow
    /// the proxy. NOTE(review): this is a proxy-gradient approximation, not
    /// the exact CIoU gradient — confirm that is intended.
    pub fn compute(pred: &Variable, target: &Variable) -> Variable {
        let n = pred.shape()[0];
        if n == 0 {
            // No boxes: zero loss, detached from the graph.
            return Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        }
        let pred_data = pred.data().to_vec();
        let target_data = target.data().to_vec();
        // Reuse the shared per-box CIoU kernel instead of duplicating the
        // arithmetic here (the previous inline loop was an exact copy of
        // `ciou_values`). Loss per box is `1 - ciou`.
        let ciou_loss = Self::ciou_values(&pred_data, &target_data, n)
            .iter()
            .map(|c| 1.0 - c)
            .sum::<f32>()
            / n as f32;
        // Differentiable proxy whose value is rescaled to the CIoU loss.
        let diff = pred.sub_var(target);
        let l2_proxy = diff.pow(2.0).mean();
        let proxy_val = l2_proxy.data().to_vec()[0];
        let scale = if proxy_val > 1e-8 {
            ciou_loss / proxy_val
        } else {
            // Proxy is ~0 (pred ≈ target): keep it as-is to avoid dividing by zero.
            1.0
        };
        l2_proxy.mul_scalar(scale)
    }

    /// Raw CIoU value (not loss) for each of `n` box pairs stored as flat
    /// `[x1, y1, x2, y2]` quadruples. Identical boxes yield ≈ 1.0; disjoint
    /// boxes yield values ≤ 0.
    pub fn ciou_values(pred: &[f32], target: &[f32], n: usize) -> Vec<f32> {
        let mut values = Vec::with_capacity(n);
        for i in 0..n {
            let px1 = pred[i * 4];
            let py1 = pred[i * 4 + 1];
            let px2 = pred[i * 4 + 2];
            let py2 = pred[i * 4 + 3];
            let tx1 = target[i * 4];
            let ty1 = target[i * 4 + 1];
            let tx2 = target[i * 4 + 2];
            let ty2 = target[i * 4 + 3];
            // Clamp widths/heights so degenerate boxes cannot produce
            // zero areas or division by zero in the atan ratio below.
            let pw = (px2 - px1).max(1e-6);
            let ph = (py2 - py1).max(1e-6);
            let tw = (tx2 - tx1).max(1e-6);
            let th = (ty2 - ty1).max(1e-6);
            // Intersection-over-union with a small epsilon in the union.
            let ix1 = px1.max(tx1);
            let iy1 = py1.max(ty1);
            let ix2 = px2.min(tx2);
            let iy2 = py2.min(ty2);
            let inter = (ix2 - ix1).max(0.0) * (iy2 - iy1).max(0.0);
            let union = pw * ph + tw * th - inter + 1e-7;
            let iou = inter / union;
            // Squared distance between box centers.
            let pcx = (px1 + px2) * 0.5;
            let pcy = (py1 + py2) * 0.5;
            let tcx = (tx1 + tx2) * 0.5;
            let tcy = (ty1 + ty2) * 0.5;
            let d2 = (pcx - tcx).powi(2) + (pcy - tcy).powi(2);
            // Squared diagonal of the smallest enclosing box.
            let cx1 = px1.min(tx1);
            let cy1 = py1.min(ty1);
            let cx2 = px2.max(tx2);
            let cy2 = py2.max(ty2);
            let c2 = (cx2 - cx1).powi(2) + (cy2 - cy1).powi(2) + 1e-7;
            // Aspect-ratio consistency term v = (4/π²)(atan(w/h) - atan(w'/h'))².
            let v = {
                let diff = (pw / ph).atan() - (tw / th).atan();
                (4.0 / (std::f32::consts::PI * std::f32::consts::PI)) * diff * diff
            };
            // Trade-off coefficient α; epsilon guards iou == 1 with v == 0.
            let alpha = v / (1.0 - iou + v + 1e-7);
            values.push(iou - d2 / c2 - alpha * v);
        }
        values
    }
}
/// Distribution Focal Loss (DFL): cross-entropy between a predicted
/// per-coordinate bin distribution and a two-bin soft label around the
/// (fractional) target distance.
pub struct DFLLoss {
    // Number of bins per box coordinate; valid bin indices are 0..reg_max-1.
    reg_max: usize,
}
impl DFLLoss {
    /// Create a DFL loss with `reg_max` distribution bins per coordinate.
    pub fn new(reg_max: usize) -> Self {
        Self { reg_max }
    }
    /// Mean DFL loss over positive anchors.
    ///
    /// * `pred_dfl` — logits shaped `[n, 4*reg_max, h, w]` (l/t/r/b blocks of
    ///   `reg_max` channels each; layout inferred from the indexing below).
    /// * `target_ltrb` — flat targets indexed `[b][coord][spatial]`, in bin
    ///   units (stride-normalized distances), clamped to `[0, reg_max-1]`.
    /// * `mask` — one flag per `(b, spatial)` position; only `true` positions
    ///   contribute to the loss.
    ///
    /// Gradient flow: like `CIoULoss::compute`, the scalar loss is computed
    /// on the CPU and attached to the graph by rescaling a differentiable
    /// proxy (`mean(pred²)`), so backprop follows the proxy, not the exact
    /// DFL gradient.
    pub fn compute(&self, pred_dfl: &Variable, target_ltrb: &[f32], mask: &[bool]) -> Variable {
        let shape = pred_dfl.shape();
        let n = shape[0];
        let h = shape[2];
        let w = shape[3];
        let nhw = n * h * w;
        let pred_data = pred_dfl.data().to_vec();
        let reg_max = self.reg_max;
        let mut loss_sum = 0.0f32;
        let mut count = 0usize;
        for pos in 0..nhw {
            // Skip anchors that were not assigned a ground truth.
            if !mask[pos] {
                continue;
            }
            let b = pos / (h * w);
            let spatial = pos % (h * w);
            for coord in 0..4 {
                let target_val = target_ltrb[b * 4 * h * w + coord * h * w + spatial]
                    .clamp(0.0, (reg_max - 1) as f32);
                // Soft label: the fractional target splits its mass between
                // the two neighbouring integer bins (left gets 1-frac,
                // right gets frac). At the top bin both collapse to it.
                let target_left = target_val.floor() as usize;
                let target_right = (target_left + 1).min(reg_max - 1);
                let weight_right = target_val - target_left as f32;
                let weight_left = 1.0 - weight_right;
                // Gather this coordinate's reg_max logits for this position.
                let base = b * (4 * reg_max) * h * w + coord * reg_max * h * w;
                let mut logits = vec![0.0f32; reg_max];
                for bin in 0..reg_max {
                    logits[bin] = pred_data[base + bin * h * w + spatial];
                }
                // Numerically stable log-softmax via the max-shift trick.
                let max_val = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let exp_sum: f32 = logits.iter().map(|&v| (v - max_val).exp()).sum();
                let log_sum = max_val + exp_sum.ln();
                let log_prob_left = logits[target_left] - log_sum;
                let log_prob_right = logits[target_right] - log_sum;
                // Weighted cross-entropy against the two-bin soft label.
                loss_sum -= weight_left * log_prob_left + weight_right * log_prob_right;
                count += 1;
            }
        }
        if count == 0 {
            // No positive anchors: zero loss, detached from the graph.
            return Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        }
        let proxy = pred_dfl.pow(2.0).mean();
        let proxy_val = proxy.data().to_vec()[0];
        let dfl_loss = loss_sum / count as f32;
        let scale = if proxy_val > 1e-8 {
            dfl_loss / proxy_val
        } else {
            1.0
        };
        proxy.mul_scalar(scale)
    }
}
/// Per-anchor assignment result produced by [`TaskAlignedAssigner::assign`].
#[derive(Debug, Clone)]
pub struct Assignment {
    /// Index of the matched ground-truth box per anchor, or -1 if unmatched.
    pub gt_indices: Vec<i32>,
    /// Class id of the matched ground truth (0 when unmatched).
    pub target_classes: Vec<usize>,
    /// Regression target as stride-normalized left/top/right/bottom distances.
    pub target_ltrb: Vec<[f32; 4]>,
    /// True for anchors that received a ground-truth match.
    pub positive_mask: Vec<bool>,
}
/// Task-aligned label assigner: each ground truth keeps its `top_k` anchors
/// ranked by the alignment metric `score^alpha * iou^beta`, restricted to
/// anchors whose center lies inside the ground-truth box. Anchors claimed by
/// several ground truths go to the one with the highest alignment.
pub struct TaskAlignedAssigner {
    top_k: usize,
    alpha: f32,
    beta: f32,
}
impl TaskAlignedAssigner {
    pub fn new(top_k: usize, alpha: f32, beta: f32) -> Self {
        Self { top_k, alpha, beta }
    }
    /// Default hyper-parameters used by this crate (top_k=13, α=1, β=6).
    pub fn default_v8() -> Self {
        Self::new(13, 1.0, 6.0)
    }
    /// Assign ground truths to anchors.
    ///
    /// * `cls_scores` — `[num_anchors, num_classes]` flat, post-sigmoid scores.
    /// * `pred_boxes` / `gt_boxes` — flat `[x1, y1, x2, y2]` quadruples.
    /// * `anchor_points` — flat `(cx, cy)` pairs in pixel units.
    /// * `strides` — per-anchor stride, used to normalize the ltrb targets.
    pub fn assign(
        &self,
        cls_scores: &[f32],
        pred_boxes: &[f32],
        gt_boxes: &[f32],
        gt_classes: &[usize],
        anchor_points: &[f32],
        strides: &[f32],
        num_anchors: usize,
        num_classes: usize,
    ) -> Assignment {
        let num_gt = gt_classes.len();
        if num_gt == 0 {
            // No ground truth: every anchor is a negative.
            return Assignment {
                gt_indices: vec![-1; num_anchors],
                target_classes: vec![0; num_anchors],
                target_ltrb: vec![[0.0; 4]; num_anchors],
                positive_mask: vec![false; num_anchors],
            };
        }
        // Candidate gate: anchor center must fall inside the GT box.
        let mut anchor_in_gt = vec![vec![false; num_anchors]; num_gt];
        for g in 0..num_gt {
            let gx1 = gt_boxes[g * 4];
            let gy1 = gt_boxes[g * 4 + 1];
            let gx2 = gt_boxes[g * 4 + 2];
            let gy2 = gt_boxes[g * 4 + 3];
            for a in 0..num_anchors {
                let cx = anchor_points[a * 2];
                let cy = anchor_points[a * 2 + 1];
                anchor_in_gt[g][a] = cx >= gx1 && cx <= gx2 && cy >= gy1 && cy <= gy2;
            }
        }
        // Alignment metric: score^alpha * iou^beta, zero outside the gate.
        let mut alignment = vec![vec![0.0f32; num_anchors]; num_gt];
        for g in 0..num_gt {
            let gt_cls = gt_classes[g];
            for a in 0..num_anchors {
                if !anchor_in_gt[g][a] {
                    continue;
                }
                let s = cls_scores[a * num_classes + gt_cls].max(1e-7);
                let u = iou_single(&pred_boxes[a * 4..a * 4 + 4], &gt_boxes[g * 4..g * 4 + 4]);
                alignment[g][a] = s.powf(self.alpha) * u.powf(self.beta);
            }
        }
        // Keep each GT's top_k best-aligned candidate anchors.
        let mut candidate_mask = vec![vec![false; num_anchors]; num_gt];
        for g in 0..num_gt {
            let mut scored: Vec<(usize, f32)> = (0..num_anchors)
                .filter(|&a| anchor_in_gt[g][a])
                .map(|a| (a, alignment[g][a]))
                .collect();
            // total_cmp gives a NaN-safe total order (descending); the
            // previous partial_cmp(..).unwrap() would panic on NaN scores.
            scored.sort_by(|x, y| y.1.total_cmp(&x.1));
            for (a, _) in scored.iter().take(self.top_k) {
                candidate_mask[g][*a] = true;
            }
        }
        // Resolve conflicts: each anchor goes to its best-aligned candidate GT.
        let mut gt_indices = vec![-1i32; num_anchors];
        let mut target_classes = vec![0usize; num_anchors];
        let mut target_ltrb = vec![[0.0f32; 4]; num_anchors];
        let mut positive_mask = vec![false; num_anchors];
        for a in 0..num_anchors {
            let mut best_gt = -1i32;
            let mut best_align = 0.0f32;
            for g in 0..num_gt {
                if candidate_mask[g][a] && alignment[g][a] > best_align {
                    best_align = alignment[g][a];
                    best_gt = g as i32;
                }
            }
            if best_gt >= 0 {
                let g = best_gt as usize;
                gt_indices[a] = best_gt;
                target_classes[a] = gt_classes[g];
                positive_mask[a] = true;
                // Regression target: distances from the anchor center to the
                // four GT edges, normalized by the anchor's stride.
                let cx = anchor_points[a * 2];
                let cy = anchor_points[a * 2 + 1];
                let stride = strides[a];
                target_ltrb[a] = [
                    (cx - gt_boxes[g * 4]) / stride,
                    (cy - gt_boxes[g * 4 + 1]) / stride,
                    (gt_boxes[g * 4 + 2] - cx) / stride,
                    (gt_boxes[g * 4 + 3] - cy) / stride,
                ];
            }
        }
        Assignment {
            gt_indices,
            target_classes,
            target_ltrb,
            positive_mask,
        }
    }
}
/// Plain IoU of two `[x1, y1, x2, y2]` boxes (epsilon-stabilized union).
fn iou_single(a: &[f32], b: &[f32]) -> f32 {
    let ix1 = a[0].max(b[0]);
    let iy1 = a[1].max(b[1]);
    let ix2 = a[2].min(b[2]);
    let iy2 = a[3].min(b[3]);
    let inter = (ix2 - ix1).max(0.0) * (iy2 - iy1).max(0.0);
    let area_a = (a[2] - a[0]).max(0.0) * (a[3] - a[1]).max(0.0);
    let area_b = (b[2] - b[0]).max(0.0) * (b[3] - b[1]).max(0.0);
    let union = area_a + area_b - inter + 1e-7;
    inter / union
}
/// Composite detection training loss: focal classification + masked L2 box
/// regression + a DFL term, combined with configurable per-term weights.
pub struct HeliosLoss {
    /// Weight applied to the classification (focal) loss term.
    pub cls_weight: f32,
    /// Weight applied to the box-regression loss term.
    pub box_weight: f32,
    /// Weight applied to the DFL loss term.
    pub dfl_weight: f32,
    /// Number of DFL bins per box coordinate.
    pub reg_max: usize,
    // NOTE(review): constructed but never used by `compute` in this file —
    // confirm whether the DFL module is meant to be wired in.
    _dfl_loss: DFLLoss,
    assigner: TaskAlignedAssigner,
}
impl HeliosLoss {
    /// Create a loss with the default term weights (cls 1.0, box 7.5, dfl 1.5).
    ///
    /// Delegates to [`Self::with_weights`] instead of duplicating the field
    /// initialization with hard-coded defaults, so the two constructors
    /// cannot drift apart.
    pub fn new(num_classes: usize, reg_max: usize) -> Self {
        Self::with_weights(num_classes, reg_max, 1.0, 7.5, 1.5)
    }
    /// Create a loss with explicit per-term weights.
    pub fn with_weights(
        num_classes: usize,
        reg_max: usize,
        cls_weight: f32,
        box_weight: f32,
        dfl_weight: f32,
    ) -> Self {
        // `num_classes` is accepted for API symmetry but not stored.
        let _ = num_classes;
        Self {
            cls_weight,
            box_weight,
            dfl_weight,
            reg_max,
            _dfl_loss: DFLLoss::new(reg_max),
            assigner: TaskAlignedAssigner::default_v8(),
        }
    }
    /// Compute the total training loss for one batch.
    ///
    /// Returns `(total_loss, cls_loss_value, box_loss_value, dfl_loss_value)`
    /// where the last three are plain scalars for logging.
    ///
    /// Pipeline: build anchor centers per scale → decode DFL logits to boxes
    /// on the CPU → run the task-aligned assigner per image → focal
    /// classification loss over all anchors → masked, normalized L2 box loss
    /// over positive anchors → weighted sum.
    pub fn compute(
        &self,
        train_out: &HeliosTrainOutput,
        gt_boxes: &[Vec<[f32; 4]>],
        gt_classes: &[Vec<usize>],
        num_classes: usize,
    ) -> (Variable, f32, f32, f32) {
        let batch_size = gt_boxes.len();
        let strides_cfg: Vec<usize> = train_out.scales.iter().map(|s| s.stride).collect();
        // Stage 1: collect per-scale tensors and build anchor centers
        // ((x + 0.5) * stride, (y + 0.5) * stride) in pixel units.
        let mut all_cls_logits = Vec::new();
        let mut all_bbox_dfl = Vec::new();
        let mut all_anchor_points = Vec::new();
        let mut all_strides = Vec::new();
        let mut scale_hw: Vec<(usize, usize)> = Vec::new();
        for (si, scale) in train_out.scales.iter().enumerate() {
            let cls_shape = scale.cls_logits.shape();
            let h = cls_shape[2];
            let w = cls_shape[3];
            let stride = strides_cfg[si] as f32;
            scale_hw.push((h, w));
            for yi in 0..h {
                for xi in 0..w {
                    all_anchor_points.push((xi as f32 + 0.5) * stride);
                    all_anchor_points.push((yi as f32 + 0.5) * stride);
                    all_strides.push(stride);
                }
            }
            all_cls_logits.push(&scale.cls_logits);
            all_bbox_dfl.push(&scale.bbox_dfl);
        }
        // Stage 2: flatten scores and CPU-decoded boxes into batch-major
        // buffers indexed [b][global_anchor] for the assigner.
        let total_anchors: usize = scale_hw.iter().map(|(h, w)| h * w).sum();
        let mut flat_cls_scores = vec![0.0f32; batch_size * total_anchors * num_classes];
        let mut flat_pred_boxes = vec![0.0f32; batch_size * total_anchors * 4];
        let mut anchor_offset = 0;
        for (si, scale) in train_out.scales.iter().enumerate() {
            let cls_data = scale.cls_logits.sigmoid().data().to_vec();
            let (h, w) = scale_hw[si];
            let stride = strides_cfg[si] as f32;
            let bbox_decoded = {
                let dfl_shape = scale.bbox_dfl.shape();
                let n = dfl_shape[0];
                let dfl_data = scale.bbox_dfl.data().to_vec();
                let reg_max = self.reg_max;
                decode_dfl_boxes(
                    &dfl_data,
                    n,
                    reg_max,
                    h,
                    w,
                    stride,
                    &all_anchor_points,
                    anchor_offset,
                )
            };
            for b in 0..batch_size {
                for yi in 0..h {
                    for xi in 0..w {
                        let local_idx = yi * w + xi;
                        let global_idx = anchor_offset + local_idx;
                        // NCHW → per-anchor class vector.
                        for c in 0..num_classes {
                            flat_cls_scores
                                [b * total_anchors * num_classes + global_idx * num_classes + c] =
                                cls_data[b * num_classes * h * w + c * h * w + yi * w + xi];
                        }
                        let this_scale_anchors = h * w;
                        for coord in 0..4 {
                            flat_pred_boxes[b * total_anchors * 4 + global_idx * 4 + coord] =
                                bbox_decoded[b * this_scale_anchors * 4 + local_idx * 4 + coord];
                        }
                    }
                }
            }
            anchor_offset += h * w;
        }
        // Stage 3: per-image assignment; accumulate one-hot class targets and
        // the GT box (in pixels) for every positive anchor.
        let mut total_positives = 0usize;
        let mut all_cls_targets = vec![0.0f32; batch_size * total_anchors * num_classes];
        let mut pos_anchor_indices: Vec<usize> = Vec::new();
        let mut pos_target_boxes: Vec<f32> = Vec::new();
        for b in 0..batch_size {
            let cls_slice = &flat_cls_scores
                [b * total_anchors * num_classes..(b + 1) * total_anchors * num_classes];
            let box_slice = &flat_pred_boxes[b * total_anchors * 4..(b + 1) * total_anchors * 4];
            let gt_b: Vec<f32> = gt_boxes[b]
                .iter()
                .flat_map(|bx| bx.iter().copied())
                .collect();
            let gt_cls_b = &gt_classes[b];
            let assignment = self.assigner.assign(
                cls_slice,
                box_slice,
                &gt_b,
                gt_cls_b,
                &all_anchor_points,
                &all_strides,
                total_anchors,
                num_classes,
            );
            for a in 0..total_anchors {
                if assignment.positive_mask[a] {
                    let cls = assignment.target_classes[a];
                    all_cls_targets[b * total_anchors * num_classes + a * num_classes + cls] = 1.0;
                    total_positives += 1;
                    pos_anchor_indices.push(b * total_anchors + a);
                    let g = assignment.gt_indices[a] as usize;
                    pos_target_boxes.extend_from_slice(&gt_b[g * 4..g * 4 + 4]);
                }
            }
        }
        // Stage 4: classification loss over every anchor (focal).
        let cls_logits_all = concat_scale_cls(all_cls_logits, batch_size, num_classes, &scale_hw);
        let cls_targets = Variable::new(
            Tensor::from_vec(all_cls_targets, &[batch_size * total_anchors, num_classes]).unwrap(),
            false,
        );
        let focal = crate::losses::FocalLoss::new();
        let cls_loss = focal.compute(&cls_logits_all, &cls_targets);
        let total_cls_loss = cls_loss.data().to_vec()[0];
        if total_positives == 0 {
            // No positives anywhere in the batch: only the classification
            // term applies.
            return (
                cls_loss.mul_scalar(self.cls_weight),
                total_cls_loss,
                0.0,
                0.0,
            );
        }
        // Stage 5: differentiable box loss — masked squared error between the
        // graph-decoded boxes and the matched GT boxes, normalized by the
        // positive count and the squared largest anchor coordinate.
        let bbox_pred_all = concat_scale_bbox(
            &all_bbox_dfl,
            batch_size,
            self.reg_max,
            &scale_hw,
            &all_anchor_points,
            &strides_cfg,
        );
        let total_flat = batch_size * total_anchors;
        let mut box_targets_flat = vec![0.0f32; total_flat * 4];
        let mut box_mask_flat = vec![0.0f32; total_flat * 4];
        for (i, &idx) in pos_anchor_indices.iter().enumerate() {
            for c in 0..4 {
                box_targets_flat[idx * 4 + c] = pos_target_boxes[i * 4 + c];
                box_mask_flat[idx * 4 + c] = 1.0;
            }
        }
        let box_target_var = Variable::new(
            Tensor::from_vec(box_targets_flat, &[total_flat, 4]).unwrap(),
            false,
        );
        let box_mask_var = Variable::new(
            Tensor::from_vec(box_mask_flat, &[total_flat, 4]).unwrap(),
            false,
        );
        let box_diff = bbox_pred_all.sub_var(&box_target_var);
        let masked_sq = box_diff.pow(2.0).mul_var(&box_mask_var);
        let max_coord = all_anchor_points.iter().copied().fold(1.0f32, f32::max);
        let box_norm = max_coord * max_coord;
        let box_loss = masked_sq
            .sum()
            .mul_scalar(1.0 / (total_positives as f32 * 4.0 * box_norm));
        let box_loss_val = box_loss.data().to_vec()[0];
        // NOTE(review): the DFL term is currently a detached constant set to
        // 20% of the box loss value (gradient-free placeholder); the stored
        // `_dfl_loss` module is not used here — confirm intent.
        let total_dfl_loss = box_loss_val * 0.2;
        let dfl_loss_var =
            Variable::new(Tensor::from_vec(vec![total_dfl_loss], &[1]).unwrap(), false);
        let total = cls_loss
            .mul_scalar(self.cls_weight)
            .add_var(&box_loss.mul_scalar(self.box_weight))
            .add_var(&dfl_loss_var.mul_scalar(self.dfl_weight));
        (total, total_cls_loss, box_loss_val, total_dfl_loss)
    }
}
/// Flatten and concatenate per-scale classification logits into a single
/// `[batch * total_anchors, num_classes]` matrix whose rows are batch-major:
/// for each image, all anchors of scale 0, then scale 1, and so on.
///
/// Bug fix: the previous version flattened each scale to `[batch*h*w, C]`
/// and concatenated along dim 0, producing scale-major rows (all images of
/// scale 0 first). That ordering only matches the batch-major targets built
/// in `HeliosLoss::compute` (`b * total_anchors + a`) when `batch_size == 1`.
/// Concatenating along the anchor dimension while the batch dimension is
/// still explicit fixes the row order for any batch size.
fn concat_scale_cls(
    scale_logits: Vec<&Variable>,
    batch_size: usize,
    num_classes: usize,
    scale_hw: &[(usize, usize)],
) -> Variable {
    let mut per_scale = Vec::new();
    for (si, logits) in scale_logits.iter().enumerate() {
        let (h, w) = scale_hw[si];
        // [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]
        let flat_spatial = logits.reshape(&[batch_size, num_classes, h * w]);
        per_scale.push(flat_spatial.transpose(1, 2));
    }
    let total_anchors: usize = scale_hw.iter().map(|(h, w)| h * w).sum();
    // Concatenate along the anchor dimension so each image's anchors stay
    // contiguous across scales.
    let mut result = per_scale[0].clone();
    for scale in &per_scale[1..] {
        result = Variable::cat(&[&result, scale], 1);
    }
    // [B, total_anchors, C] -> [B * total_anchors, C], batch-major rows.
    result.reshape(&[batch_size * total_anchors, num_classes])
}
/// Decode DFL logits of every scale inside the autograd graph and return the
/// boxes as a `[batch * total_anchors, 4]` xyxy matrix with batch-major rows
/// (for each image, all anchors of scale 0, then scale 1, …).
///
/// Bug fix: the previous version flattened each scale to `[batch*h*w, 4]`
/// and concatenated along dim 0, yielding scale-major rows that only line up
/// with the batch-major `box_targets_flat` built in `HeliosLoss::compute`
/// when `batch_size == 1`. The per-scale `[B, hw, 4]` tensors are now
/// concatenated along the anchor dimension before the final flatten.
fn concat_scale_bbox(
    scale_dfl: &[&Variable],
    batch_size: usize,
    reg_max: usize,
    scale_hw: &[(usize, usize)],
    anchor_points: &[f32],
    strides_cfg: &[usize],
) -> Variable {
    // Integral weights 0..reg_max-1: softmax(logits) · weights gives the
    // expected bin value for each coordinate.
    let weights_data: Vec<f32> = (0..reg_max).map(|i| i as f32).collect();
    let weights = Variable::new(
        Tensor::from_vec(weights_data, &[reg_max, 1]).unwrap(),
        false,
    );
    let mut decoded_scales = Vec::new();
    let mut anchor_offset = 0;
    for (si, dfl_var) in scale_dfl.iter().enumerate() {
        let (h, w) = scale_hw[si];
        let hw = h * w;
        let stride = strides_cfg[si] as f32;
        // [B, 4*reg_max, H, W] -> rows of reg_max logits per (b, coord, pos).
        let reshaped = dfl_var.reshape(&[batch_size * 4, reg_max, hw]);
        let transposed = reshaped.transpose(1, 2);
        let flat = transposed.reshape(&[batch_size * 4 * hw, reg_max]);
        let probs = flat.softmax(1);
        let decoded = probs.matmul(&weights);
        // Expected l/t/r/b distances in bin units: [B, 4, hw].
        let ltrb = decoded.reshape(&[batch_size, 4, hw]);
        let l_dist = ltrb.narrow(1, 0, 1);
        let t_dist = ltrb.narrow(1, 1, 1);
        let r_dist = ltrb.narrow(1, 2, 1);
        let b_dist = ltrb.narrow(1, 3, 1);
        // Anchor centers for this scale, broadcastable against the distances.
        let mut cx_data = vec![0.0f32; batch_size * hw];
        let mut cy_data = vec![0.0f32; batch_size * hw];
        for b in 0..batch_size {
            for pos in 0..hw {
                let ga = anchor_offset + pos;
                cx_data[b * hw + pos] = anchor_points[ga * 2];
                cy_data[b * hw + pos] = anchor_points[ga * 2 + 1];
            }
        }
        let cx_var = Variable::new(
            Tensor::from_vec(cx_data, &[batch_size, 1, hw]).unwrap(),
            false,
        );
        let cy_var = Variable::new(
            Tensor::from_vec(cy_data, &[batch_size, 1, hw]).unwrap(),
            false,
        );
        // center ∓ distance * stride -> xyxy corners.
        let x1 = cx_var.sub_var(&l_dist.mul_scalar(stride));
        let y1 = cy_var.sub_var(&t_dist.mul_scalar(stride));
        let x2 = cx_var.add_var(&r_dist.mul_scalar(stride));
        let y2 = cy_var.add_var(&b_dist.mul_scalar(stride));
        let xyxy = Variable::cat(&[&x1, &y1, &x2, &y2], 1);
        // Keep the batch dimension: [B, 4, hw] -> [B, hw, 4].
        decoded_scales.push(xyxy.transpose(1, 2));
        anchor_offset += hw;
    }
    let total_anchors: usize = scale_hw.iter().map(|(h, w)| h * w).sum();
    // Concatenate along the anchor dimension, then flatten batch-major.
    let mut result = decoded_scales[0].clone();
    for s in &decoded_scales[1..] {
        result = Variable::cat(&[&result, s], 1);
    }
    result.reshape(&[batch_size * total_anchors, 4])
}
/// Decode DFL logits into xyxy boxes on the CPU (no autograd involvement).
///
/// `dfl_data` is indexed `[b][coord][bin][y][x]` flattened, with `coord` in
/// l/t/r/b order. Each distance is the softmax expectation over `reg_max`
/// bins, scaled by `stride` and applied around the anchor center taken from
/// `anchor_points` starting at `anchor_offset`. Returns a flat
/// `[batch, h*w, 4]` buffer of `[x1, y1, x2, y2]` values.
fn decode_dfl_boxes(
    dfl_data: &[f32],
    batch_size: usize,
    reg_max: usize,
    h: usize,
    w: usize,
    stride: f32,
    anchor_points: &[f32],
    anchor_offset: usize,
) -> Vec<f32> {
    let hw = h * w;
    let mut boxes = vec![0.0f32; batch_size * hw * 4];
    for b in 0..batch_size {
        for spatial in 0..hw {
            let anchor = anchor_offset + spatial;
            let cx = anchor_points[anchor * 2];
            let cy = anchor_points[anchor * 2 + 1];
            // Softmax expectation for each of the four l/t/r/b distances,
            // using the max-shift trick for numerical stability.
            let dist: Vec<f32> = (0..4)
                .map(|coord| {
                    let base = b * (4 * reg_max) * hw + coord * reg_max * hw;
                    let logits: Vec<f32> = (0..reg_max)
                        .map(|bin| dfl_data[base + bin * hw + spatial])
                        .collect();
                    let max_logit =
                        logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                    let exps: Vec<f32> =
                        logits.iter().map(|&z| (z - max_logit).exp()).collect();
                    let denom: f32 = exps.iter().sum();
                    exps.iter()
                        .enumerate()
                        .map(|(bin, &e)| bin as f32 * (e / denom))
                        .sum()
                })
                .collect();
            // Anchor center minus/plus the scaled distances gives the corners.
            let out = (b * hw + spatial) * 4;
            boxes[out] = cx - dist[0] * stride;
            boxes[out + 1] = cy - dist[1] * stride;
            boxes[out + 2] = cx + dist[2] * stride;
            boxes[out + 3] = cy + dist[3] * stride;
        }
    }
    boxes
}
#[cfg(test)]
mod tests {
    use super::*;

    // Identical pred/target boxes must give a near-zero CIoU loss.
    #[test]
    fn test_ciou_identical_boxes() {
        let boxes = Variable::new(
            Tensor::from_vec(vec![10.0, 10.0, 50.0, 50.0], &[1, 4]).unwrap(),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![10.0, 10.0, 50.0, 50.0], &[1, 4]).unwrap(),
            false,
        );
        let loss = CIoULoss::compute(&boxes, &target);
        let val = loss.data().to_vec()[0];
        assert!(
            val < 0.01,
            "Identical boxes → near-zero CIoU loss, got {val}"
        );
    }

    // Non-overlapping boxes must be penalized heavily (loss well above 0.5).
    #[test]
    fn test_ciou_disjoint_boxes() {
        let pred = Variable::new(
            Tensor::from_vec(vec![0.0, 0.0, 10.0, 10.0], &[1, 4]).unwrap(),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![50.0, 50.0, 60.0, 60.0], &[1, 4]).unwrap(),
            false,
        );
        let loss = CIoULoss::compute(&pred, &target);
        let val = loss.data().to_vec()[0];
        assert!(val > 0.5, "Disjoint boxes → large CIoU loss, got {val}");
    }

    // Raw CIoU value (not loss) for an identical pair should be ≈ 1.0.
    #[test]
    fn test_ciou_values() {
        let pred = vec![10.0, 10.0, 50.0, 50.0];
        let target = vec![10.0, 10.0, 50.0, 50.0];
        let vals = CIoULoss::ciou_values(&pred, &target, 1);
        assert!(
            (vals[0] - 1.0).abs() < 0.01,
            "Identical → CIoU≈1.0, got {}",
            vals[0]
        );
    }

    // With no ground truth every anchor must come back negative.
    #[test]
    fn test_task_aligned_assigner_no_gt() {
        let assigner = TaskAlignedAssigner::default_v8();
        let assignment = assigner.assign(
            &[0.5; 10], &[0.0; 20], &[], &[],
            &[16.0, 16.0, 48.0, 16.0, 16.0, 48.0, 48.0, 48.0, 32.0, 32.0], &[8.0; 5],
            5,
            2,
        );
        assert!(assignment.positive_mask.iter().all(|&m| !m));
    }

    // One GT covering the top half of a 2x2 anchor grid: only the two
    // anchors whose centers fall inside the GT box may become positive.
    #[test]
    fn test_task_aligned_assigner_with_gt() {
        let assigner = TaskAlignedAssigner::new(3, 1.0, 6.0);
        let anchor_points = vec![
            8.0, 8.0, 24.0, 8.0, 8.0, 24.0, 24.0, 24.0, ];
        let gt_boxes = vec![0.0, 0.0, 32.0, 16.0];
        let gt_classes = vec![0usize];
        let cls_scores = vec![0.5f32; 8];
        let pred_boxes = vec![
            2.0, 2.0, 30.0, 14.0, 0.0, 0.0, 32.0, 16.0, 2.0, 18.0, 30.0, 30.0, 0.0, 18.0, 32.0, 30.0, ];
        let assignment = assigner.assign(
            &cls_scores,
            &pred_boxes,
            &gt_boxes,
            &gt_classes,
            &anchor_points,
            &[16.0; 4],
            4,
            2,
        );
        assert!(assignment.positive_mask[0], "Anchor 0 should be positive");
        assert!(assignment.positive_mask[1], "Anchor 1 should be positive");
        assert!(!assignment.positive_mask[2], "Anchor 2 should be negative");
        assert!(!assignment.positive_mask[3], "Anchor 3 should be negative");
        assert_eq!(assignment.target_classes[0], 0);
    }

    // Plain IoU sanity checks: identical boxes → 1, half-overlap → 25/175.
    #[test]
    fn test_iou_single() {
        let a = [0.0, 0.0, 10.0, 10.0];
        let b = [0.0, 0.0, 10.0, 10.0];
        assert!((iou_single(&a, &b) - 1.0).abs() < 0.01);
        let c = [5.0, 5.0, 15.0, 15.0];
        let iou = iou_single(&a, &c);
        assert!((iou - 25.0 / 175.0).abs() < 0.01);
    }

    // End-to-end loss with an empty GT list: finite total, and the box/DFL
    // terms must be exactly zero (only classification applies).
    #[test]
    fn test_helios_loss_no_gt() {
        use super::super::Helios;
        let model = Helios::nano(2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 3 * 64 * 64], &[1, 3, 64, 64]).unwrap(),
            false,
        );
        let train_out = model.forward_train(&input);
        let loss_fn = HeliosLoss::new(2, 16);
        let (total, cls_val, box_val, dfl_val) = loss_fn.compute(
            &train_out,
            &[vec![]], &[vec![]], 2,
        );
        let total_val = total.data().to_vec()[0];
        assert!(total_val.is_finite(), "Loss should be finite with no GT");
        assert_eq!(box_val, 0.0, "No GT → no box loss");
        assert_eq!(dfl_val, 0.0, "No GT → no DFL loss");
        assert!(cls_val >= 0.0, "Cls loss should be non-negative");
    }

    // End-to-end loss with one GT box: finite, strictly positive total and
    // non-negative individual terms.
    #[test]
    fn test_helios_loss_with_gt() {
        use super::super::Helios;
        let model = Helios::nano(2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 3 * 64 * 64], &[1, 3, 64, 64]).unwrap(),
            false,
        );
        let train_out = model.forward_train(&input);
        let loss_fn = HeliosLoss::new(2, 16);
        let gt_boxes = vec![vec![[10.0, 10.0, 40.0, 40.0]]];
        let gt_classes = vec![vec![0usize]];
        let (total, cls_val, box_val, _dfl_val) =
            loss_fn.compute(&train_out, &gt_boxes, &gt_classes, 2);
        let total_val = total.data().to_vec()[0];
        assert!(total_val.is_finite(), "Loss should be finite");
        assert!(
            total_val > 0.0,
            "Loss should be positive with GT, got {total_val}"
        );
        assert!(cls_val >= 0.0);
        assert!(box_val >= 0.0);
    }
}