use crate::dataset::{
CocoAnnotation, CocoDataset, CocoDetection, CocoDetections, DetectionInput, EvalDataset,
};
use crate::error::EvalError;
use super::assignment::{BinAssignment, DtBin};
/// Which error category [`apply_fix`] should "fix" when rewriting the ground
/// truth / detections pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FixKind {
    /// Detections binned as `DtBin::Cls` take the category of their target GT.
    Cls,
    /// Detections binned as `DtBin::Loc` take the bbox and segmentation of
    /// their target GT.
    Loc,
    /// Detections binned as `DtBin::Both` are dropped.
    Both,
    /// Detections binned as `DtBin::Dupe` are dropped.
    Dupe,
    /// Detections binned as `DtBin::Bkg` are dropped.
    Bkg,
    /// GTs listed in the assignment's `missed_gts` are marked ignored
    /// (`ignore_flag = Some(true)`); detections are left unchanged.
    Missed,
    /// Detections binned as any of `Cls`, `Loc`, `Both`, `Dupe` or `Bkg` are
    /// dropped.
    AllFp,
}
pub fn apply_fix(
gt: &CocoDataset,
dt: &CocoDetections,
assignment: &BinAssignment,
fix: FixKind,
) -> Result<(CocoDataset, CocoDetections), EvalError> {
let mut gts: Vec<CocoAnnotation> = gt.annotations().to_vec();
if matches!(fix, FixKind::Missed) {
let missed_set: std::collections::HashSet<(i64, usize)> =
assignment.missed_gts.iter().copied().collect();
for (gt_input_idx, ann) in gts.iter_mut().enumerate() {
if missed_set.contains(&(ann.image_id.0, gt_input_idx)) {
ann.ignore_flag = Some(true);
}
}
}
let new_gt = CocoDataset::from_parts(gt.images().to_vec(), gts, gt.categories().to_vec())?;
let mut new_dts: Vec<DetectionInput> = Vec::with_capacity(dt.detections().len());
let original_anns = gt.annotations();
for (dt_input_idx, det) in dt.detections().iter().enumerate() {
let key = (det.image_id.0, dt_input_idx);
let label = assignment.dt_labels.get(&key).copied();
let resolve_target = |target: i32| -> Result<&CocoAnnotation, EvalError> {
let local_indices = new_gt.ann_indices_for_image(det.image_id);
let target_usize =
usize::try_from(target).map_err(|_| EvalError::InvalidAnnotation {
detail: format!(
"rewrite: invalid target_gt_local_idx={target} for DT id={} on image {}",
det.id.0, det.image_id.0
),
})?;
local_indices
.get(target_usize)
.map(|&j| &original_anns[j])
.ok_or_else(|| EvalError::InvalidAnnotation {
detail: format!(
"rewrite: target_gt_local_idx={target} out of range \
for image {} (have {} GTs)",
det.image_id.0,
local_indices.len()
),
})
};
match (fix, label) {
(FixKind::AllFp, Some(lbl))
if matches!(
lbl.bin,
DtBin::Cls | DtBin::Loc | DtBin::Both | DtBin::Dupe | DtBin::Bkg
) =>
{
continue;
}
(FixKind::Cls, Some(lbl)) if lbl.bin == DtBin::Cls => {
let target = resolve_target(lbl.target_gt_local_idx)?;
new_dts.push(DetectionInput {
id: Some(det.id),
image_id: det.image_id,
category_id: target.category_id,
score: det.score,
bbox: det.bbox,
segmentation: det.segmentation.clone(),
keypoints: det.keypoints.clone(),
num_keypoints: det.num_keypoints,
});
}
(FixKind::Loc, Some(lbl)) if lbl.bin == DtBin::Loc => {
let target = resolve_target(lbl.target_gt_local_idx)?;
new_dts.push(DetectionInput {
id: Some(det.id),
image_id: det.image_id,
category_id: det.category_id,
score: det.score,
bbox: target.bbox,
segmentation: target.segmentation.clone(),
keypoints: det.keypoints.clone(),
num_keypoints: det.num_keypoints,
});
}
(FixKind::Both, Some(lbl)) if lbl.bin == DtBin::Both => continue,
(FixKind::Dupe, Some(lbl)) if lbl.bin == DtBin::Dupe => continue,
(FixKind::Bkg, Some(lbl)) if lbl.bin == DtBin::Bkg => continue,
_ => {
new_dts.push(passthrough_input(det));
}
}
}
let new_dt = CocoDetections::from_inputs(new_dts)?;
Ok((new_gt, new_dt))
}
fn passthrough_input(det: &CocoDetection) -> DetectionInput {
DetectionInput {
id: Some(det.id),
image_id: det.image_id,
category_id: det.category_id,
score: det.score,
bbox: det.bbox,
segmentation: det.segmentation.clone(),
keypoints: det.keypoints.clone(),
num_keypoints: det.num_keypoints,
}
}