use std::collections::{HashMap, HashSet};
use std::mem::size_of;
use crate::accumulate::{accumulate, AccumulateParams, PerImageEval};
use crate::dataset::{
AnnId, CategoryId, CocoDataset, CocoDetection, CocoDetections, DetectionInput, EvalDataset,
ImageId,
};
use crate::error::EvalError;
use crate::evaluate::{evaluate_with, EvalKernel, OwnedEvaluateParams};
use crate::parity::{recall_thresholds, ParityMode};
use crate::summarize::{summarize_detection, summarize_with, StatRequest, Summary};
/// Fallback hard cap for the cell-store memory budget: 8 GiB.
const DEFAULT_BUDGET_BYTES: usize = 8 * 1024 * 1024 * 1024;
/// Fraction of the budget at which the one-shot soft warning fires.
const DEFAULT_SOFT_WARN_FRACTION: f64 = 0.80;
/// Memory-budget configuration for a streaming evaluation run.
#[derive(Debug, Clone, Copy)]
pub struct MemoryBudget {
    /// Hard limit in bytes; an update that would exceed it is rejected outright.
    pub bytes: usize,
    /// Fraction of `bytes` at which a soft warning is reported (fires once).
    pub soft_warn_fraction: f64,
}
impl MemoryBudget {
    /// Default budget: the smaller of [`DEFAULT_BUDGET_BYTES`] and half of the
    /// detected system memory; falls back to the default cap when detection
    /// is unavailable. The soft-warn fraction is always the default.
    pub fn auto_default() -> Self {
        let bytes = match Self::system_total_bytes() {
            Some(total) => DEFAULT_BUDGET_BYTES.min(total / 2),
            None => DEFAULT_BUDGET_BYTES,
        };
        Self {
            bytes,
            soft_warn_fraction: DEFAULT_SOFT_WARN_FRACTION,
        }
    }
    /// Best-effort total physical memory in bytes.
    ///
    /// Linux only: parses the `MemTotal:` line of `/proc/meminfo` (reported in
    /// kB). Returns `None` on other targets or on any read/parse failure.
    fn system_total_bytes() -> Option<usize> {
        if !cfg!(target_os = "linux") {
            return None;
        }
        let meminfo = std::fs::read_to_string("/proc/meminfo").ok()?;
        let value = meminfo
            .lines()
            .find_map(|line| line.strip_prefix("MemTotal:"))?;
        let kb: usize = value.trim().strip_suffix(" kB")?.trim().parse().ok()?;
        Some(kb.saturating_mul(1024))
    }
}
/// Dimensions and id→dense-index maps of the (category, area range, image)
/// evaluation grid, fixed at evaluator construction.
#[derive(Debug, Clone)]
pub struct EvalGridMeta {
    /// Number of category slots (1 when categories are collapsed).
    pub n_categories: usize,
    /// Number of area-range slots.
    pub n_area_ranges: usize,
    /// Number of images in the ground-truth dataset.
    pub n_images: usize,
    /// Category id → dense index (empty when `use_cats` is false).
    pub category_id_to_idx: HashMap<CategoryId, usize>,
    /// Image id → dense index, assigned in ascending id order.
    pub image_id_to_idx: HashMap<ImageId, usize>,
}
/// Sparse store of per-image evaluation results, keyed by
/// (category index, area-range index, image index).
#[derive(Debug, Clone, Default)]
pub struct PerImageEvalStore {
    // Only cells that were actually produced are kept; `flatten` fills the gaps.
    cells: HashMap<(usize, usize, usize), PerImageEval>,
}
impl PerImageEvalStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }
    /// Number of populated (category, area, image) cells.
    pub fn len(&self) -> usize {
        self.cells.len()
    }
    /// `true` when no cell has been inserted yet.
    pub fn is_empty(&self) -> bool {
        self.cells.is_empty()
    }
    /// Stores `cell` for (category `k`, area range `a`, image `i`),
    /// replacing any previous entry at that coordinate.
    pub fn insert(&mut self, k: usize, a: usize, i: usize, cell: PerImageEval) {
        self.cells.insert((k, a, i), cell);
    }
    /// Expands the sparse map into the dense layout (category-major, then
    /// area range, then image); missing cells become `None`.
    pub fn flatten(&self, meta: &EvalGridMeta) -> Vec<Option<PerImageEval>> {
        let mut dense: Vec<Option<PerImageEval>> =
            Vec::with_capacity(meta.n_categories * meta.n_area_ranges * meta.n_images);
        for cat in 0..meta.n_categories {
            for area in 0..meta.n_area_ranges {
                for img in 0..meta.n_images {
                    dense.push(self.cells.get(&(cat, area, img)).cloned());
                }
            }
        }
        dense
    }
}
/// Per-batch outcome returned by `StreamingEvaluator::update`.
#[derive(Debug, Clone)]
pub struct UpdateReport {
    /// Detections accepted from this batch.
    pub n_detections_accepted: usize,
    /// Distinct image ids present in this batch.
    pub n_images_in_batch: usize,
    /// Per-image evaluation cells stored for this batch.
    pub n_cells_inserted: usize,
    /// `true` the first time usage crosses the soft-warn threshold.
    pub soft_warn_triggered: bool,
}
/// Detections parsed ahead of time, tagged with the kernel type `K` so they
/// can only be fed to a matching `StreamingEvaluator<K>`.
#[derive(Debug, Clone)]
pub struct ParsedDetections<K: EvalKernel> {
    /// The parsed detection set.
    pub detections: CocoDetections,
    // Zero-sized tag tying this batch to kernel type `K`.
    _kernel: std::marker::PhantomData<K>,
}
impl<K: EvalKernel> ParsedDetections<K> {
    /// Wraps an already-parsed detection set, tagging it with kernel type `K`.
    pub fn from_detections(detections: CocoDetections) -> Self {
        let _kernel = std::marker::PhantomData;
        Self { detections, _kernel }
    }
    /// Parses a JSON byte buffer into a tagged detection set.
    ///
    /// # Errors
    /// Propagates any `EvalError` raised while decoding the JSON payload.
    pub fn from_json_bytes(bytes: &[u8]) -> Result<Self, EvalError> {
        let detections = CocoDetections::from_json_bytes(bytes)?;
        Ok(Self::from_detections(detections))
    }
}
/// Incremental evaluator: detection batches arrive via `update`/`update_parsed`,
/// per-image evaluation cells are cached under a memory budget, and a `Summary`
/// can be produced at any time via `snapshot`/`finalize`.
#[derive(Debug)]
pub struct StreamingEvaluator<K: EvalKernel> {
    /// Ground-truth dataset, fixed for the evaluator's lifetime.
    dataset: CocoDataset,
    /// Similarity kernel; `is_keypoints()` selects the keypoint summary path.
    kernel: K,
    /// Evaluation parameters (IoU thresholds, area ranges, ...).
    params: OwnedEvaluateParams,
    /// Parity mode forwarded to evaluate/accumulate.
    parity_mode: ParityMode,
    /// Grid dimensions and id→index maps derived from `dataset` and `params`.
    grid_meta: EvalGridMeta,
    /// Sparse per-image evaluation cells accepted so far.
    cells: PerImageEvalStore,
    /// Raw image ids already submitted (re-submission is an error).
    seen_images: HashSet<i64>,
    /// Dense grid indices of submitted images (used to backfill unseen ones).
    seen_image_indices: HashSet<usize>,
    /// Cached cells from an empty-detection pass, for images never submitted.
    gt_only_cells: Option<Vec<Option<PerImageEval>>>,
    /// Total detections accepted.
    n_detections: usize,
    // Advanced by the number of accepted detections per batch; not read
    // elsewhere in the visible code — presumably reserved for id assignment.
    next_dt_id: i64,
    /// Byte accounting: struct headers of stored cells.
    bytes_cells_struct: usize,
    /// Byte accounting: detection-score buffers.
    bytes_dt_scores: usize,
    /// Byte accounting: match/ignore flag matrices.
    bytes_match_flags: usize,
    /// Configured memory budget.
    budget: MemoryBudget,
    /// Ensures the soft warning is reported at most once.
    soft_warn_fired: bool,
}
impl<K: EvalKernel> StreamingEvaluator<K> {
    /// Creates a streaming evaluator over `dataset` with a fixed evaluation grid.
    ///
    /// # Errors
    /// `EvalError::InvalidConfig` when `params.area_ranges` is empty.
    pub fn new(
        dataset: CocoDataset,
        kernel: K,
        params: OwnedEvaluateParams,
        parity_mode: ParityMode,
        budget: MemoryBudget,
    ) -> Result<Self, EvalError> {
        if params.area_ranges.is_empty() {
            return Err(EvalError::InvalidConfig {
                detail: "OwnedEvaluateParams.area_ranges must be non-empty".into(),
            });
        }
        // Fix: this previously read `¶ms` (HTML-entity mojibake of `&params`)
        // and did not compile.
        let grid_meta = build_grid_meta(&dataset, &params);
        Ok(Self {
            dataset,
            kernel,
            params,
            parity_mode,
            grid_meta,
            cells: PerImageEvalStore::new(),
            seen_images: HashSet::new(),
            seen_image_indices: HashSet::new(),
            gt_only_cells: None,
            n_detections: 0,
            next_dt_id: 1,
            bytes_cells_struct: 0,
            bytes_dt_scores: 0,
            bytes_match_flags: 0,
            budget,
            soft_warn_fired: false,
        })
    }
    /// Number of distinct image ids submitted so far.
    pub fn images_seen(&self) -> usize {
        self.seen_images.len()
    }
    /// Total detections accepted across all updates.
    pub fn detections_seen(&self) -> usize {
        self.n_detections
    }
    /// Dataset images that have not yet received a detection batch.
    pub fn images_pending(&self) -> usize {
        self.grid_meta.n_images.saturating_sub(self.images_seen())
    }
    /// Estimated bytes currently held by stored evaluation cells.
    pub fn memory_used_bytes(&self) -> usize {
        self.bytes_cells_struct + self.bytes_dt_scores + self.bytes_match_flags
    }
    /// The configured memory budget.
    pub fn budget(&self) -> MemoryBudget {
        self.budget
    }
    /// Dimensions and id→index maps of the evaluation grid.
    pub fn grid_meta(&self) -> &EvalGridMeta {
        &self.grid_meta
    }
    /// Parses `json_bytes` as a detection batch and applies it.
    ///
    /// # Errors
    /// Parse failures, plus everything [`Self::update_parsed`] can return.
    pub fn update(&mut self, json_bytes: &[u8]) -> Result<UpdateReport, EvalError> {
        let parsed = ParsedDetections::<K>::from_json_bytes(json_bytes)?;
        self.update_parsed(parsed)
    }
    /// Evaluates one batch of detections and folds the resulting per-image
    /// cells into the store.
    ///
    /// All detections for an image must arrive in a single batch; repeating an
    /// image id across batches is rejected. On any error the evaluator state
    /// is left unchanged (cells and counters are only committed after the
    /// budget check passes).
    ///
    /// # Errors
    /// - `EvalError::InvalidAnnotation` for a re-submitted image id.
    /// - `EvalError::OutOfBudget` when accepting the batch would exceed the
    ///   hard budget (with a per-component byte breakdown).
    /// - Any error propagated from `evaluate_with`.
    pub fn update_parsed(
        &mut self,
        parsed: ParsedDetections<K>,
    ) -> Result<UpdateReport, EvalError> {
        let detections = parsed.detections;
        let mut batch_image_ids: HashSet<i64> = HashSet::new();
        for dt in detections.detections() {
            let id = dt.image_id.0;
            if self.seen_images.contains(&id) {
                return Err(EvalError::InvalidAnnotation {
                    detail: format!(
                        "image_id={id} was already submitted in a prior update(); \
                        StreamingEvaluator does not silently merge — submit all \
                        detections for an image in a single batch"
                    ),
                });
            }
            batch_image_ids.insert(id);
        }
        let grid = evaluate_with(
            &self.dataset,
            &detections,
            self.params.borrow(),
            self.parity_mode,
            &self.kernel,
        )?;
        // Map batch image ids to dense grid indices; ids absent from the
        // dataset simply contribute no cells.
        let mut batch_image_indices: HashSet<usize> =
            HashSet::with_capacity(batch_image_ids.len());
        for id in &batch_image_ids {
            if let Some(&idx) = self.grid_meta.image_id_to_idx.get(&ImageId(*id)) {
                batch_image_indices.insert(idx);
            }
        }
        let n_t = self.params.iou_thresholds.len();
        let n_k = grid.n_categories;
        let n_a = grid.n_area_ranges;
        let n_i = grid.n_images;
        // Stage cells and their cost first so the budget check can reject the
        // whole batch atomically, before any state is mutated.
        let mut staged: Vec<(usize, usize, usize, PerImageEval, CellCost)> = Vec::new();
        let mut cost_total = CellCost::default();
        for &i in &batch_image_indices {
            for k in 0..n_k {
                for a in 0..n_a {
                    let flat = k * n_a * n_i + a * n_i + i;
                    if let Some(cell) = grid.eval_imgs.get(flat).and_then(|opt| opt.as_ref()) {
                        let cost = cell_cost(cell, n_t);
                        cost_total = cost_total.add(cost);
                        staged.push((k, a, i, cell.clone(), cost));
                    }
                }
            }
        }
        // Saturating add: an absurd projection should report OutOfBudget,
        // not wrap around (was an unchecked `+`).
        let projected = self.memory_used_bytes().saturating_add(cost_total.total());
        if projected > self.budget.bytes {
            let mut breakdown: HashMap<&'static str, usize> = HashMap::new();
            breakdown.insert(
                "cells_store",
                self.bytes_cells_struct + cost_total.cells_struct,
            );
            breakdown.insert("scores", self.bytes_dt_scores + cost_total.dt_scores);
            breakdown.insert(
                "match_flags",
                self.bytes_match_flags + cost_total.match_flags,
            );
            return Err(EvalError::OutOfBudget {
                used_bytes: projected,
                budget_bytes: self.budget.bytes,
                breakdown,
            });
        }
        // Budget accepted: commit cells, byte accounting, and counters.
        let n_cells_inserted = staged.len();
        for (k, a, i, cell, cost) in staged {
            self.cells.insert(k, a, i, cell);
            self.bytes_cells_struct += cost.cells_struct;
            self.bytes_dt_scores += cost.dt_scores;
            self.bytes_match_flags += cost.match_flags;
        }
        let n_detections_accepted = detections.detections().len();
        self.n_detections += n_detections_accepted;
        self.next_dt_id = self.next_dt_id.saturating_add(n_detections_accepted as i64);
        self.seen_images.extend(batch_image_ids.iter().copied());
        self.seen_image_indices.extend(batch_image_indices.iter().copied());
        // The soft warning is reported at most once per evaluator.
        let total_used = self.memory_used_bytes();
        let threshold = (self.budget.bytes as f64 * self.budget.soft_warn_fraction) as usize;
        let soft_warn_triggered = total_used >= threshold && !self.soft_warn_fired;
        if soft_warn_triggered {
            self.soft_warn_fired = true;
        }
        Ok(UpdateReport {
            n_detections_accepted,
            n_images_in_batch: batch_image_ids.len(),
            n_cells_inserted,
            soft_warn_triggered,
        })
    }
    /// Computes a summary over everything seen so far without consuming the
    /// evaluator (unseen images are backfilled with GT-only cells).
    pub fn snapshot(&mut self) -> Result<Summary, EvalError> {
        self.compute_summary()
    }
    /// Equivalent to [`Self::snapshot`].
    pub fn snapshot_running(&mut self) -> Result<Summary, EvalError> {
        self.snapshot()
    }
    /// Consumes the evaluator and produces the final summary.
    pub fn finalize(mut self) -> Result<Summary, EvalError> {
        self.compute_summary()
    }
    /// Serializes evaluator state for a later [`Self::restore`].
    ///
    /// # Errors
    /// Always `EvalError::NotImplemented` — checkpointing is not implemented yet.
    pub fn checkpoint(&self) -> Result<Vec<u8>, EvalError> {
        Err(EvalError::NotImplemented {
            feature: "StreamingEvaluator::checkpoint",
        })
    }
    /// Rebuilds an evaluator from [`Self::checkpoint`] bytes.
    ///
    /// # Errors
    /// Always `EvalError::NotImplemented` — restore is not implemented yet.
    pub fn restore(_bytes: &[u8]) -> Result<Self, EvalError> {
        Err(EvalError::NotImplemented {
            feature: "StreamingEvaluator::restore",
        })
    }
    /// Lazily computes and caches evaluation cells for an empty detection set.
    /// These represent the ground truth of images never submitted (pure
    /// false negatives) and are spliced into the summary grid.
    fn ensure_gt_only_cells(&mut self) -> Result<(), EvalError> {
        if self.gt_only_cells.is_some() {
            return Ok(());
        }
        let empty_dt = CocoDetections::from_inputs(Vec::new())?;
        let grid = evaluate_with(
            &self.dataset,
            &empty_dt,
            self.params.borrow(),
            self.parity_mode,
            &self.kernel,
        )?;
        self.gt_only_cells = Some(grid.eval_imgs);
        Ok(())
    }
    /// Flattens stored cells into the dense grid, backfills images that were
    /// never submitted with GT-only cells, then accumulates and summarizes
    /// (keypoint or detection protocol depending on the kernel).
    fn compute_summary(&mut self) -> Result<Summary, EvalError> {
        let mut eval_imgs = self.cells.flatten(&self.grid_meta);
        if self.images_seen() < self.grid_meta.n_images {
            self.ensure_gt_only_cells()?;
            let n_k = self.grid_meta.n_categories;
            let n_a = self.grid_meta.n_area_ranges;
            let n_i = self.grid_meta.n_images;
            let gt_only = self
                .gt_only_cells
                .as_ref()
                .ok_or_else(|| EvalError::InvalidConfig {
                    detail: "gt_only_cells cache missing after init".into(),
                })?;
            for i in 0..n_i {
                if self.seen_image_indices.contains(&i) {
                    continue;
                }
                for k in 0..n_k {
                    for a in 0..n_a {
                        let flat = k * n_a * n_i + a * n_i + i;
                        if let Some(cell) = gt_only.get(flat).and_then(|opt| opt.as_ref()) {
                            eval_imgs[flat] = Some(cell.clone());
                        }
                    }
                }
            }
        }
        // Each branch runs exactly one accumulation pass. Previously the
        // detection-path accumulate ran unconditionally and its result was
        // discarded on the keypoints path — a full wasted pass.
        if self.kernel.is_keypoints() {
            // COCO keypoint protocol uses a single max-detections setting of 20.
            let kp_max_dets: [usize; 1] = [20];
            let accum_params_kp = AccumulateParams {
                iou_thresholds: &self.params.iou_thresholds,
                recall_thresholds: recall_thresholds(),
                max_dets: &kp_max_dets,
                n_categories: self.grid_meta.n_categories,
                n_area_ranges: self.grid_meta.n_area_ranges,
                n_images: self.grid_meta.n_images,
            };
            let accumulated_kp = accumulate(&eval_imgs, accum_params_kp, self.parity_mode)?;
            let plan = StatRequest::coco_keypoints_default();
            summarize_with(
                &accumulated_kp,
                &plan,
                &self.params.iou_thresholds,
                &kp_max_dets,
            )
        } else {
            // COCO detection protocol: max-detections thresholds of 1/10/100.
            let max_dets: [usize; 3] = [1, 10, 100];
            let accum_params = AccumulateParams {
                iou_thresholds: &self.params.iou_thresholds,
                recall_thresholds: recall_thresholds(),
                max_dets: &max_dets,
                n_categories: self.grid_meta.n_categories,
                n_area_ranges: self.grid_meta.n_area_ranges,
                n_images: self.grid_meta.n_images,
            };
            let accumulated = accumulate(&eval_imgs, accum_params, self.parity_mode)?;
            summarize_detection(&accumulated, &self.params.iou_thresholds, &max_dets)
        }
    }
}
/// Estimated byte cost of one per-image evaluation cell, split by component
/// so budget-overrun errors can report a breakdown.
#[derive(Debug, Default, Clone, Copy)]
struct CellCost {
    /// Fixed struct-header bytes of the cell.
    cells_struct: usize,
    /// Bytes of the detection-score buffer (at allocated capacity).
    dt_scores: usize,
    /// Bytes of the match/ignore flag matrices.
    match_flags: usize,
}
impl CellCost {
    /// Sum of all components, saturating at `usize::MAX` so a pathological
    /// projection reports over-budget instead of wrapping (the surrounding
    /// accounting code already saturates; `+` here was the odd one out).
    fn total(self) -> usize {
        self.cells_struct
            .saturating_add(self.dt_scores)
            .saturating_add(self.match_flags)
    }
    /// Component-wise sum, saturating at `usize::MAX` for the same reason
    /// as [`Self::total`].
    fn add(self, other: Self) -> Self {
        Self {
            cells_struct: self.cells_struct.saturating_add(other.cells_struct),
            dt_scores: self.dt_scores.saturating_add(other.dt_scores),
            match_flags: self.match_flags.saturating_add(other.match_flags),
        }
    }
}
/// Estimates the heap footprint of one per-image evaluation cell.
///
/// The score buffer is charged at its allocated capacity; the match and
/// ignore flags are approximated as two boolean matrices of shape
/// `n_iou_thresholds x n_detections`.
fn cell_cost(cell: &PerImageEval, n_iou_thresholds: usize) -> CellCost {
    let n_detections = cell.dt_scores.len();
    let flag_bytes = 2usize
        .saturating_mul(n_iou_thresholds)
        .saturating_mul(n_detections)
        .saturating_mul(size_of::<bool>());
    CellCost {
        cells_struct: size_of::<PerImageEval>(),
        dt_scores: cell.dt_scores.capacity() * size_of::<f64>(),
        match_flags: flag_bytes,
    }
}
/// Derives the evaluation-grid dimensions and id→dense-index maps from the
/// dataset and parameters. Images (and categories, when `use_cats` is set)
/// are indexed in ascending id order so the layout is deterministic.
fn build_grid_meta(dataset: &CocoDataset, params: &OwnedEvaluateParams) -> EvalGridMeta {
    let mut sorted_image_ids: Vec<ImageId> =
        dataset.images().iter().map(|im| im.id).collect();
    sorted_image_ids.sort_unstable_by_key(|id| id.0);
    let image_id_to_idx: HashMap<ImageId, usize> = sorted_image_ids
        .into_iter()
        .enumerate()
        .map(|(idx, id)| (id, idx))
        .collect();
    let (n_categories, category_id_to_idx) = if params.use_cats {
        let mut sorted_cat_ids: Vec<CategoryId> =
            dataset.categories().iter().map(|c| c.id).collect();
        sorted_cat_ids.sort_unstable_by_key(|c| c.0);
        let n_cats = sorted_cat_ids.len();
        let map: HashMap<CategoryId, usize> = sorted_cat_ids
            .into_iter()
            .enumerate()
            .map(|(idx, id)| (id, idx))
            .collect();
        (n_cats, map)
    } else {
        // Category-agnostic evaluation collapses everything into one bucket.
        (1, HashMap::new())
    };
    EvalGridMeta {
        n_categories,
        n_area_ranges: params.area_ranges.len(),
        n_images: dataset.images().len(),
        category_id_to_idx,
        image_id_to_idx,
    }
}
// Appears to exist solely to keep `AnnId`, `CocoDetection`, and
// `DetectionInput` referenced so the `use` block above stays exercised
// outside of `cfg(test)` builds — NOTE(review): confirm intent.
#[allow(dead_code)]
fn _docs_typecheck(_a: AnnId, _b: CocoDetection, _c: DetectionInput) {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::{Bbox, CategoryMeta, CocoAnnotation, ImageMeta};
    use crate::evaluate::AreaRange;
    use crate::parity::iou_thresholds;
    use crate::similarity::BboxIou;
    /// Minimal `ImageMeta` fixture.
    fn img(id: i64, w: u32, h: u32) -> ImageMeta {
        ImageMeta {
            id: ImageId(id),
            width: w,
            height: h,
            file_name: None,
        }
    }
    /// Minimal `CategoryMeta` fixture.
    fn cat(id: i64, name: &str) -> CategoryMeta {
        CategoryMeta {
            id: CategoryId(id),
            name: name.into(),
            supercategory: None,
        }
    }
    /// Bbox annotation fixture; `area` is derived as bbox w*h.
    fn ann(id: i64, image: i64, cat: i64, bbox: (f64, f64, f64, f64)) -> CocoAnnotation {
        CocoAnnotation {
            id: AnnId(id),
            image_id: ImageId(image),
            category_id: CategoryId(cat),
            area: bbox.2 * bbox.3,
            is_crowd: false,
            ignore_flag: None,
            bbox: Bbox {
                x: bbox.0,
                y: bbox.1,
                w: bbox.2,
                h: bbox.3,
            },
            segmentation: None,
            keypoints: None,
            num_keypoints: None,
        }
    }
    /// Two 100x100 images, one category, one ground-truth box per image.
    fn tiny_dataset() -> CocoDataset {
        let images = vec![img(1, 100, 100), img(2, 100, 100)];
        let cats = vec![cat(1, "thing")];
        let anns = vec![
            ann(1, 1, 1, (0.0, 0.0, 10.0, 10.0)),
            ann(2, 2, 1, (50.0, 50.0, 10.0, 10.0)),
        ];
        CocoDataset::from_parts(images, anns, cats).unwrap()
    }
    /// Default COCO-style parameters with categories enabled.
    fn default_params() -> OwnedEvaluateParams {
        OwnedEvaluateParams {
            iou_thresholds: iou_thresholds().to_vec(),
            area_ranges: AreaRange::coco_default().to_vec(),
            max_dets_per_image: 100,
            use_cats: true,
        }
    }
    // The auto budget must always be usable and carry the default fraction.
    #[test]
    fn auto_default_budget_is_nonzero() {
        let b = MemoryBudget::auto_default();
        assert!(b.bytes > 0);
        assert!((b.soft_warn_fraction - DEFAULT_SOFT_WARN_FRACTION).abs() < 1e-12);
    }
    // A freshly built evaluator exposes an empty state and the grid shape
    // derived from the tiny dataset (1 category, 4 COCO area ranges, 2 images).
    #[test]
    fn fresh_evaluator_reports_zero_counters() {
        let ds = tiny_dataset();
        let ev = StreamingEvaluator::new(
            ds,
            BboxIou,
            default_params(),
            ParityMode::Strict,
            MemoryBudget::auto_default(),
        )
        .unwrap();
        assert_eq!(ev.images_seen(), 0);
        assert_eq!(ev.detections_seen(), 0);
        assert_eq!(ev.memory_used_bytes(), 0);
        assert_eq!(ev.images_pending(), 2);
        assert_eq!(ev.grid_meta().n_categories, 1);
        assert_eq!(ev.grid_meta().n_area_ranges, 4);
        assert_eq!(ev.grid_meta().n_images, 2);
    }
    // An empty JSON batch is accepted but must not change any counter.
    #[test]
    fn empty_update_returns_zero_counters() {
        let ds = tiny_dataset();
        let mut ev = StreamingEvaluator::new(
            ds,
            BboxIou,
            default_params(),
            ParityMode::Strict,
            MemoryBudget::auto_default(),
        )
        .unwrap();
        let report = ev.update(b"[]").unwrap();
        assert_eq!(report.n_detections_accepted, 0);
        assert_eq!(report.n_images_in_batch, 0);
        assert_eq!(report.n_cells_inserted, 0);
        assert!(!report.soft_warn_triggered);
        assert_eq!(ev.detections_seen(), 0);
        assert_eq!(ev.images_seen(), 0);
        assert_eq!(ev.memory_used_bytes(), 0);
    }
    // Finalizing with no detections still yields the canonical 12-line
    // COCO detection summary (unseen images are backfilled as GT-only).
    #[test]
    fn finalize_returns_summary_with_canonical_shape() {
        let ds = tiny_dataset();
        let ev = StreamingEvaluator::new(
            ds,
            BboxIou,
            default_params(),
            ParityMode::Strict,
            MemoryBudget::auto_default(),
        )
        .unwrap();
        let summary = ev.finalize().unwrap();
        assert_eq!(summary.lines.len(), 12);
    }
    // Re-submitting an image id in a later batch is an error, and the failed
    // batch must not change evaluator state.
    #[test]
    fn duplicate_image_id_across_updates_is_rejected() {
        let ds = tiny_dataset();
        let mut ev = StreamingEvaluator::new(
            ds,
            BboxIou,
            default_params(),
            ParityMode::Strict,
            MemoryBudget::auto_default(),
        )
        .unwrap();
        let batch1 =
            br#"[{"image_id": 1, "category_id": 1, "score": 0.9, "bbox": [0, 0, 10, 10]}]"#;
        ev.update(batch1).unwrap();
        assert_eq!(ev.images_seen(), 1);
        let batch2 =
            br#"[{"image_id": 1, "category_id": 1, "score": 0.8, "bbox": [50, 50, 10, 10]}]"#;
        let err = ev.update(batch2).unwrap_err();
        assert!(matches!(err, EvalError::InvalidAnnotation { .. }));
        assert_eq!(ev.images_seen(), 1);
        assert_eq!(ev.detections_seen(), 1);
    }
    // A 1-byte budget forces OutOfBudget; the rejection must be atomic
    // (no counters or bytes accounted) and carry the full breakdown.
    #[test]
    fn out_of_budget_does_not_mutate_state() {
        let ds = tiny_dataset();
        let tiny_budget = MemoryBudget {
            bytes: 1,
            soft_warn_fraction: 0.80,
        };
        let mut ev = StreamingEvaluator::new(
            ds,
            BboxIou,
            default_params(),
            ParityMode::Strict,
            tiny_budget,
        )
        .unwrap();
        let batch = br#"[{"image_id": 1, "category_id": 1, "score": 0.9, "bbox": [0, 0, 10, 10]}]"#;
        let err = ev.update(batch).unwrap_err();
        match err {
            EvalError::OutOfBudget {
                used_bytes,
                budget_bytes,
                breakdown,
            } => {
                assert!(used_bytes > budget_bytes);
                assert_eq!(budget_bytes, 1);
                assert!(breakdown.contains_key("cells_store"));
                assert!(breakdown.contains_key("scores"));
                assert!(breakdown.contains_key("match_flags"));
            }
            other => panic!("expected OutOfBudget, got {other:?}"),
        }
        assert_eq!(ev.images_seen(), 0);
        assert_eq!(ev.detections_seen(), 0);
        assert_eq!(ev.memory_used_bytes(), 0);
    }
    // Checkpoint/restore are declared but intentionally unimplemented.
    #[test]
    fn checkpoint_and_restore_return_not_implemented() {
        let ds = tiny_dataset();
        let ev = StreamingEvaluator::new(
            ds,
            BboxIou,
            default_params(),
            ParityMode::Strict,
            MemoryBudget::auto_default(),
        )
        .unwrap();
        let err = ev.checkpoint().unwrap_err();
        assert!(matches!(err, EvalError::NotImplemented { .. }));
        let err = StreamingEvaluator::<BboxIou>::restore(&[]).unwrap_err();
        assert!(matches!(err, EvalError::NotImplemented { .. }));
    }
    // flatten() must produce the dense grid with None for missing cells.
    #[test]
    fn flatten_round_trips_to_dense_layout() {
        let mut store = PerImageEvalStore::new();
        let cell = PerImageEval {
            dt_scores: vec![0.5],
            dt_matched: ndarray::Array2::default((1, 1)),
            dt_ignore: ndarray::Array2::default((1, 1)),
            gt_ignore: vec![false],
        };
        store.insert(0, 0, 0, cell);
        let meta = EvalGridMeta {
            n_categories: 1,
            n_area_ranges: 1,
            n_images: 2,
            category_id_to_idx: HashMap::new(),
            image_id_to_idx: HashMap::new(),
        };
        let dense = store.flatten(&meta);
        assert_eq!(dense.len(), 2);
        assert!(dense[0].is_some());
        assert!(dense[1].is_none());
    }
}