use std::collections::HashSet;
use ndarray::ArrayView2;
use crate::dataset::{Annotation, CategoryId, CocoDataset, CocoDetections, EvalDataset, ImageMeta};
use crate::error::EvalError;
use crate::evaluate::{evaluate_with, EvalGrid, EvalKernel, EvaluateParams};
use crate::parity::ParityMode;
use crate::tables::RetainedIous;
use super::params::LrpParams;
use super::tau_search::{search_tau, TauSearchResult};
/// Per-category result of the LRP decomposition.
///
/// Each metric is optional: `decompose_class` reports every metric as `None`
/// when the category has no positive ground truth or the tau search yields
/// nothing, and `decompose_at` leaves individual components `None` when there
/// are no true positives.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) struct PerClassDecomposition {
    /// Index of the category bucket this row was computed for.
    pub category_index: usize,
    /// LRP value reported by the tau search for the selected operating point.
    pub olrp: Option<f64>,
    /// Localisation component at the selected operating point.
    pub olrp_loc: Option<f64>,
    /// False-positive component at the selected operating point.
    pub olrp_fp: Option<f64>,
    /// False-negative component at the selected operating point.
    pub olrp_fn: Option<f64>,
    /// Entry of the tau grid at the selected operating point, when reported.
    pub tau: Option<f64>,
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn decompose_class(
gt: &CocoDataset,
k: usize,
category_id: Option<CategoryId>,
n_images: usize,
image_order: &[&ImageMeta],
retained: &RetainedIous,
grid: &crate::evaluate::EvalGrid,
parity_mode: ParityMode,
params: &LrpParams<'_>,
image_mask: Option<&[bool]>,
) -> Result<PerClassDecomposition, EvalError> {
let cap = params.max_dets_per_image.saturating_mul(n_images);
let mut dt_score: Vec<f64> = Vec::with_capacity(cap);
let mut dt_matched: Vec<bool> = Vec::with_capacity(cap);
let mut dt_iou: Vec<f64> = Vec::with_capacity(cap);
let mut n_pos_gt: u64 = 0;
let mut gt_crowd: Vec<bool> = Vec::new();
let mut gt_ignore_mask: Vec<bool> = Vec::new();
let mut gt_taken: Vec<bool> = Vec::new();
let gt_anns = gt.annotations();
for (i, image) in image_order.iter().enumerate() {
if let Some(mask) = image_mask {
if !mask[i] {
continue;
}
}
let cell = match grid.cell(k, 0, i) {
Some(c) => c,
None => continue,
};
if grid.cell_meta(k, 0, i).is_none() {
continue;
}
let image_id = image.id;
let gt_indices = match category_id {
Some(c) => gt.ann_indices_for(image_id, c),
None => gt.ann_indices_for_image(image_id),
};
for &j in gt_indices {
let ann = >_anns[j];
if !ann.is_crowd() && !ann.effective_ignore(parity_mode) {
n_pos_gt += 1;
}
}
let iou_view = match retained.get(k, i) {
Some(v) => v,
None => continue, };
run_cell_matching(
&iou_view,
cell,
gt_indices,
gt_anns,
parity_mode,
params.tp_threshold,
&mut dt_score,
&mut dt_matched,
&mut dt_iou,
&mut gt_crowd,
&mut gt_ignore_mask,
&mut gt_taken,
);
}
if n_pos_gt == 0 {
return Ok(PerClassDecomposition {
category_index: k,
olrp: None,
olrp_loc: None,
olrp_fp: None,
olrp_fn: None,
tau: None,
});
}
let search = search_tau(
&dt_score,
&dt_matched,
&dt_iou,
n_pos_gt,
params.tp_threshold,
params.tau_grid,
);
let result = match search {
Some(r) => r,
None => {
return Ok(PerClassDecomposition {
category_index: k,
olrp: None,
olrp_loc: None,
olrp_fp: None,
olrp_fn: None,
tau: None,
});
}
};
Ok(decompose_at(result, params, k))
}
/// Splits the operating point chosen by the tau search into LRP components.
///
/// With at least one true positive the components are:
/// * localisation: `(sum_loc / n_tp) / (1 - tp_threshold)`,
/// * false positives: `n_fp / (n_tp + n_fp)`,
/// * false negatives: `n_fn / (n_tp + n_fn)`,
///
/// and `tau` is the `tau_grid` entry at `result.star` (or `None` if that index
/// is out of range). Without true positives the localisation component and
/// `tau` are `None` and the false-negative component is `1.0`.
fn decompose_at(
    result: TauSearchResult,
    params: &LrpParams<'_>,
    category_index: usize,
) -> PerClassDecomposition {
    let stats = &result.stats;
    let (hits, false_alarms, misses) = (stats.n_tp, stats.n_fp, stats.n_fn);

    // A threshold of 1.0 or more would make `1 - tp_threshold` zero or
    // negative, so the localisation term is left unscaled in that case.
    let loc_scale = match params.tp_threshold {
        t if t >= 1.0 => 1.0,
        t => 1.0 - t,
    };
    let tau = params.tau_grid.get(result.star).copied();

    if hits > 0 {
        let mean_loc = stats.sum_loc / (hits as f64);

        let detected = hits + false_alarms;
        let fp_share = if detected > 0 {
            (false_alarms as f64) / (detected as f64)
        } else {
            0.0
        };

        let positives = hits + misses;
        let fn_share = if positives > 0 {
            (misses as f64) / (positives as f64)
        } else {
            0.0
        };

        return PerClassDecomposition {
            category_index,
            olrp: Some(result.lrp),
            olrp_loc: Some(mean_loc / loc_scale),
            olrp_fp: Some(fp_share),
            olrp_fn: Some(fn_share),
            tau,
        };
    }

    // No true positives: nothing to localise and every positive was missed.
    // NOTE(review): the FP component is reported as undefined when false
    // positives exist and as 0.0 when there are none; confirm this matches the
    // reference implementation.
    let fp_component = if false_alarms > 0 { None } else { Some(0.0) };
    PerClassDecomposition {
        category_index,
        olrp: Some(result.lrp),
        olrp_loc: None,
        olrp_fp: fp_component,
        olrp_fn: Some(1.0),
        tau: None,
    }
}
/// Greedily matches the detections of one (category, image) cell against its
/// ground truth and appends the outcome to the pooled detection vectors.
///
/// `iou_mat` is indexed as `(ground truth, detection)`; row `r` corresponds to
/// `gt_indices[r]` and column `c` to `cell.dt_scores[c]` (both checked with
/// `debug_assert_eq!`). Detections are visited in column order — presumably
/// descending score; confirm against the producer of `cell`.
///
/// For each detection:
/// 1. Among ground truth that is not crowd, not ignored and not yet taken, the
///    row with the highest IoU `>= tp_threshold` is selected (ties go to the
///    later row). On success the row is marked taken and the detection is
///    recorded as matched with that IoU.
/// 2. Otherwise, if any crowd or ignored ground truth overlaps the detection
///    with IoU `>= tp_threshold`, the detection is dropped without a record.
/// 3. Otherwise the detection is recorded as unmatched with IoU `0.0`.
///
/// `dt_score`, `dt_matched` and `dt_iou` are appended to in lockstep.
/// `gt_crowd`, `gt_ignore_mask` and `gt_taken` are scratch buffers that are
/// cleared and refilled on every call.
///
/// # Panics
/// Panics if an entry of `gt_indices` is out of range for `gt_anns`, or if the
/// buffer lengths disagree with the shape of `iou_mat`.
// The scratch buffers are passed in so they can be reused across cells.
#[allow(clippy::too_many_arguments)]
fn run_cell_matching(
    iou_mat: &ArrayView2<'_, f64>,
    cell: &crate::accumulate::PerImageEval,
    gt_indices: &[usize],
    gt_anns: &[crate::dataset::CocoAnnotation],
    parity_mode: ParityMode,
    tp_threshold: f64,
    dt_score: &mut Vec<f64>,
    dt_matched: &mut Vec<bool>,
    dt_iou: &mut Vec<f64>,
    gt_crowd: &mut Vec<bool>,
    gt_ignore_mask: &mut Vec<bool>,
    gt_taken: &mut Vec<bool>,
) {
    let n_gt = iou_mat.nrows();
    let n_dt = iou_mat.ncols();

    // Refresh the per-row crowd / ignore flags for this cell.
    gt_crowd.clear();
    gt_ignore_mask.clear();
    for &ann_idx in gt_indices {
        let ann = &gt_anns[ann_idx];
        gt_crowd.push(ann.is_crowd());
        gt_ignore_mask.push(ann.effective_ignore(parity_mode));
    }
    debug_assert_eq!(gt_crowd.len(), n_gt);
    debug_assert_eq!(gt_ignore_mask.len(), n_gt);

    gt_taken.clear();
    gt_taken.resize(n_gt, false);

    let scores = &cell.dt_scores;
    debug_assert_eq!(scores.len(), n_dt);

    for det in 0..n_dt {
        // Pass 1: best still-available regular ground truth at or above the
        // threshold. `>=` lets a later row win a tie.
        let mut best_iou = tp_threshold;
        let mut best_row: Option<usize> = None;
        for row in 0..n_gt {
            if gt_taken[row] || gt_crowd[row] || gt_ignore_mask[row] {
                continue;
            }
            let iou = iou_mat[(row, det)];
            if iou >= best_iou {
                best_iou = iou;
                best_row = Some(row);
            }
        }
        if let Some(row) = best_row {
            gt_taken[row] = true;
            dt_score.push(scores[det]);
            dt_matched.push(true);
            dt_iou.push(best_iou);
            continue;
        }

        // Pass 2: a detection overlapping crowd / ignored ground truth is
        // neither a true nor a false positive. Such rows are never marked
        // taken, so one row can absorb several detections.
        let absorbed = (0..n_gt).any(|row| {
            (gt_crowd[row] || gt_ignore_mask[row]) && iou_mat[(row, det)] >= tp_threshold
        });
        if absorbed {
            continue;
        }

        // Pass 3: plain false positive.
        dt_score.push(scores[det]);
        dt_matched.push(false);
        dt_iou.push(0.0);
    }
}
/// State produced once by `prepare_lrp_pass` and shared by every per-class
/// decomposition (see `decompose_all_classes`).
pub(crate) struct LrpPassContext<'gt> {
/// Ground-truth dataset the pass was prepared for.
pub(crate) gt: &'gt CocoDataset,
/// Retained IoU tables, moved out of the grid's `retained_ious` field.
pub(crate) retained: RetainedIous,
/// Evaluation grid returned by `evaluate_with`; its `retained_ious` field has
/// been taken and is therefore `None`.
pub(crate) grid: EvalGrid,
/// All ground-truth images, sorted ascending by `id.0`.
pub(crate) image_order: Vec<&'gt ImageMeta>,
/// One entry per category bucket: the category ids sorted ascending by `.0`
/// when categories are in use, otherwise a single `None`.
pub(crate) category_buckets: Vec<Option<CategoryId>>,
}
/// Runs the evaluation once with IoU retention enabled and bundles everything
/// the per-class LRP decomposition needs.
///
/// The returned context holds the grid, the retained IoUs moved out of it, the
/// ground-truth images sorted by id, and the category buckets (sorted category
/// ids when `params.use_cats` is set, a single `None` bucket otherwise).
///
/// # Errors
/// Propagates any error from `evaluate_with`, and returns
/// `EvalError::InvalidConfig` if the grid comes back without retained IoUs.
pub(crate) fn prepare_lrp_pass<'gt, K: EvalKernel>(
    gt: &'gt CocoDataset,
    dt: &CocoDetections,
    kernel: &K,
    params: &LrpParams<'_>,
    parity_mode: ParityMode,
) -> Result<LrpPassContext<'gt>, EvalError> {
    let mut grid = evaluate_with(
        gt,
        dt,
        EvaluateParams {
            iou_thresholds: params.iou_thresholds,
            area_ranges: params.area_ranges,
            max_dets_per_image: params.max_dets_per_image,
            use_cats: params.use_cats,
            // The decomposition re-runs matching, so the IoUs must be kept.
            retain_iou: true,
        },
        parity_mode,
        kernel,
    )?;

    let retained = match grid.retained_ious.take() {
        Some(ious) => ious,
        None => {
            return Err(EvalError::InvalidConfig {
                detail: "lrp: evaluate_with returned no retained_ious despite retain_iou=true"
                    .into(),
            });
        }
    };

    // Images in ascending id order.
    let mut image_order: Vec<&ImageMeta> = gt.images().iter().collect();
    image_order.sort_unstable_by_key(|meta| meta.id.0);

    let category_buckets: Vec<Option<CategoryId>> = if !params.use_cats {
        // Category-agnostic evaluation: one bucket covering everything.
        vec![None]
    } else {
        let mut ids: Vec<_> = gt.categories().iter().map(|cat| cat.id).collect();
        ids.sort_unstable_by_key(|id| id.0);
        ids.into_iter().map(Some).collect()
    };
    debug_assert_eq!(category_buckets.len(), grid.n_categories);

    Ok(LrpPassContext {
        gt,
        retained,
        grid,
        image_order,
        category_buckets,
    })
}
/// Runs `decompose_class` for every category bucket of `ctx`, in bucket order.
///
/// `image_filter`, when given, holds positions into `ctx.image_order`; only
/// those images take part. Positions beyond the number of images are ignored.
///
/// # Errors
/// Returns the first error produced by `decompose_class`.
pub(crate) fn decompose_all_classes(
    ctx: &LrpPassContext<'_>,
    parity_mode: ParityMode,
    params: &LrpParams<'_>,
    image_filter: Option<&HashSet<usize>>,
) -> Result<Vec<PerClassDecomposition>, EvalError> {
    let n_images = ctx.image_order.len();

    // Turn the set of kept positions into a dense flag per image position.
    let mask_storage: Option<Vec<bool>> =
        image_filter.map(|keep| (0..n_images).map(|pos| keep.contains(&pos)).collect());
    let image_mask = mask_storage.as_deref();

    ctx.category_buckets
        .iter()
        .enumerate()
        .map(|(k, &bucket)| {
            decompose_class(
                ctx.gt,
                k,
                bucket,
                n_images,
                &ctx.image_order,
                &ctx.retained,
                &ctx.grid,
                parity_mode,
                params,
                image_mask,
            )
        })
        .collect()
}