use std::collections::HashMap;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum DetectionError {
#[error("Invalid bounding box: {0}")]
InvalidBbox(String),
#[error("Empty image")]
EmptyImage,
#[error("Model error: {0}")]
ModelError(String),
}
/// Settings consumed by `ObjectDetectionPipeline`.
#[derive(Debug, Clone)]
pub struct ObjectDetectionConfig {
    /// Model identifier.
    pub model_name: String,
    /// Detections with a confidence below this value are discarded.
    pub confidence_threshold: f32,
    /// IoU above which non-maximum suppression drops the lower-scoring box.
    pub iou_threshold: f32,
    /// Upper bound on the number of detections returned per image.
    pub max_detections: usize,
    /// Model input size as a pair of dimensions; both must be non-zero.
    pub input_size: (usize, usize),
    /// Number of object classes.
    pub num_classes: usize,
}

impl Default for ObjectDetectionConfig {
    /// Defaults for `facebook/detr-resnet-50`: 0.5 confidence and IoU thresholds,
    /// at most 100 detections, an 800x800 input and 91 classes.
    fn default() -> Self {
        let model_name = String::from("facebook/detr-resnet-50");
        Self {
            num_classes: 91,
            input_size: (800, 800),
            max_detections: 100,
            iou_threshold: 0.5,
            confidence_threshold: 0.5,
            model_name,
        }
    }
}
/// Axis-aligned box described by its minimum corner `(x1, y1)` and maximum corner `(x2, y2)`.
///
/// The fields are public, so a box built with a struct literal is not validated; use
/// [`BoundingBox::is_valid`] to check that both extents are positive.
#[derive(Debug, Clone)]
pub struct BoundingBox {
    /// Minimum x coordinate.
    pub x1: f32,
    /// Minimum y coordinate.
    pub y1: f32,
    /// Maximum x coordinate.
    pub x2: f32,
    /// Maximum y coordinate.
    pub y2: f32,
}

impl BoundingBox {
    /// Creates a box whose coordinates are normalized to `[0, 1]`.
    ///
    /// # Errors
    /// Returns [`DetectionError::InvalidBbox`] if any coordinate lies outside `[0, 1]`
    /// (NaN included), or if `x2 <= x1` or `y2 <= y1`.
    pub fn new(x1: f32, y1: f32, x2: f32, y2: f32) -> Result<Self, DetectionError> {
        for (name, v) in [("x1", x1), ("y1", y1), ("x2", x2), ("y2", y2)] {
            if !(0.0..=1.0).contains(&v) {
                return Err(DetectionError::InvalidBbox(format!(
                    "{name} = {v} is outside [0, 1]"
                )));
            }
        }
        // The range check passed; the ordering rules are shared with `new_unchecked`.
        Self::new_unchecked(x1, y1, x2, y2)
    }

    /// Creates a box without the `[0, 1]` range check (e.g. for pixel coordinates).
    ///
    /// Despite the name, the corner ordering is still validated.
    ///
    /// # Errors
    /// Returns [`DetectionError::InvalidBbox`] unless `x2 > x1` and `y2 > y1`. A NaN
    /// coordinate fails these checks, which keeps this constructor consistent with
    /// [`BoundingBox::is_valid`].
    pub fn new_unchecked(x1: f32, y1: f32, x2: f32, y2: f32) -> Result<Self, DetectionError> {
        use std::cmp::Ordering;
        // Written as "not strictly greater" rather than `x2 <= x1`, because every comparison
        // with NaN is false and `<=` would let NaN coordinates through.
        if x2.partial_cmp(&x1) != Some(Ordering::Greater) {
            return Err(DetectionError::InvalidBbox(format!(
                "x2 ({x2}) must be > x1 ({x1})"
            )));
        }
        if y2.partial_cmp(&y1) != Some(Ordering::Greater) {
            return Err(DetectionError::InvalidBbox(format!(
                "y2 ({y2}) must be > y1 ({y1})"
            )));
        }
        Ok(Self { x1, y1, x2, y2 })
    }

    /// Area as `width() * height()`; negative for boxes whose corners are swapped in one axis.
    pub fn area(&self) -> f32 {
        self.width() * self.height()
    }

    /// Intersection-over-union with `other`, in `[0, 1]` for valid boxes.
    ///
    /// Returns `0.0` when the union is not positive (e.g. two zero-area boxes).
    pub fn iou(&self, other: &BoundingBox) -> f32 {
        let ix1 = self.x1.max(other.x1);
        let iy1 = self.y1.max(other.y1);
        let ix2 = self.x2.min(other.x2);
        let iy2 = self.y2.min(other.y2);
        // Disjoint boxes give a negative extent; clamp so the intersection is 0.
        let inter_w = (ix2 - ix1).max(0.0);
        let inter_h = (iy2 - iy1).max(0.0);
        let inter = inter_w * inter_h;
        let union = self.area() + other.area() - inter;
        if union <= 0.0 {
            0.0
        } else {
            inter / union
        }
    }

    /// `true` when both extents are strictly positive (NaN coordinates make it `false`).
    pub fn is_valid(&self) -> bool {
        self.x2 > self.x1 && self.y2 > self.y1
    }

    /// Clamps every coordinate into `[0, w]` (x) and `[0, h]` (y).
    ///
    /// A negative or NaN `w`/`h` is treated as `0.0`, which collapses that axis to 0. The result
    /// can therefore be degenerate; check it with [`BoundingBox::is_valid`].
    pub fn clip_to_image(&self, w: f32, h: f32) -> Self {
        // `f32::clamp` panics when `max < min` or a bound is NaN, so sanitize the upper bounds
        // first. `f32::max` returns the non-NaN operand, so NaN becomes 0.0.
        let w = w.max(0.0);
        let h = h.max(0.0);
        Self {
            x1: self.x1.clamp(0.0, w),
            y1: self.y1.clamp(0.0, h),
            x2: self.x2.clamp(0.0, w),
            y2: self.y2.clamp(0.0, h),
        }
    }

    /// Horizontal extent `x2 - x1`.
    pub fn width(&self) -> f32 {
        self.x2 - self.x1
    }

    /// Vertical extent `y2 - y1`.
    pub fn height(&self) -> f32 {
        self.y2 - self.y1
    }

    /// Midpoint `((x1 + x2) / 2, (y1 + y2) / 2)`.
    pub fn center(&self) -> (f32, f32) {
        ((self.x1 + self.x2) / 2.0, (self.y1 + self.y2) / 2.0)
    }
}
/// One detected object: where it is, what it is, and how confident the detector is.
#[derive(Clone, Debug)]
pub struct Detection {
    /// Location of the object.
    pub bbox: BoundingBox,
    /// Class name.
    pub label: String,
    /// Numeric class id associated with `label`.
    pub label_id: usize,
    /// Detection score.
    pub confidence: f32,
}

impl Detection {
    /// Returns the detection's confidence; an alias for the `confidence` field.
    pub fn score(&self) -> f32 {
        let Self { confidence, .. } = self;
        *confidence
    }
}
#[derive(Debug, Clone)]
pub struct DetectionResult {
pub detections: Vec<Detection>,
pub image_height: usize,
pub image_width: usize,
pub inference_time_ms: u64,
}
impl DetectionResult {
pub fn filter_by_confidence(&self, threshold: f32) -> Self {
Self {
detections: self
.detections
.iter()
.filter(|d| d.confidence >= threshold)
.cloned()
.collect(),
image_height: self.image_height,
image_width: self.image_width,
inference_time_ms: self.inference_time_ms,
}
}
pub fn filter_by_label(&self, label: &str) -> Self {
Self {
detections: self.detections.iter().filter(|d| d.label == label).cloned().collect(),
image_height: self.image_height,
image_width: self.image_width,
inference_time_ms: self.inference_time_ms,
}
}
pub fn top_k(&self, k: usize) -> Self {
let mut sorted = self.detections.clone();
sorted.sort_by(|a, b| {
b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
});
sorted.truncate(k);
Self {
detections: sorted,
image_height: self.image_height,
image_width: self.image_width,
inference_time_ms: self.inference_time_ms,
}
}
pub fn count_by_label(&self) -> HashMap<String, usize> {
let mut counts: HashMap<String, usize> = HashMap::new();
for d in &self.detections {
*counts.entry(d.label.clone()).or_insert(0) += 1;
}
counts
}
}
/// Greedy, class-agnostic non-maximum suppression.
///
/// Detections are visited from highest to lowest confidence. A detection is kept unless its
/// IoU with an already-kept detection is strictly greater than `iou_threshold`. Labels are not
/// consulted, so boxes of different classes can suppress each other.
///
/// NaN confidences are ordered after every real value, so those detections are considered last.
///
/// Returns clones of the kept detections in descending-confidence order (ties keep input order).
pub fn nms(detections: &[Detection], iou_threshold: f32) -> Vec<Detection> {
    // Descending comparator that is a *total* order. A comparator that treats NaN as equal to
    // every value is not transitive, and `slice::sort_by` may panic on it.
    fn desc_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (true, true) => Ordering::Equal,
        }
    }

    let mut order: Vec<&Detection> = detections.iter().collect();
    order.sort_by(|a, b| desc_nan_last(a.confidence, b.confidence));

    // A candidate is suppressed exactly when some earlier *kept* box overlaps it too much.
    // Checking against the kept set is therefore equivalent to having each kept box mark its
    // overlapping successors, but it only compares against survivors.
    let mut kept: Vec<&Detection> = Vec::new();
    for candidate in order {
        let overlaps_kept = kept
            .iter()
            .any(|k| k.bbox.iou(&candidate.bbox) > iou_threshold);
        if !overlaps_kept {
            kept.push(candidate);
        }
    }
    kept.into_iter().cloned().collect()
}
/// Gaussian soft non-maximum suppression.
///
/// Repeatedly picks the detection with the highest current score and emits a clone of it whose
/// `confidence` is that (possibly decayed) score. It then multiplies every other still-positive
/// score by `exp(-iou^2 / sigma)`, where `iou` is measured against the picked box and `sigma` is
/// floored at `1e-6`.
///
/// Stops when the best remaining score is `<= score_threshold`, when no score `> 0.0` remains,
/// or after `detections.len()` picks. Detections whose score is not `> 0.0` (including NaN) are
/// never emitted. Among equal top scores, the later input index is picked first.
pub fn soft_nms(detections: &[Detection], sigma: f32, score_threshold: f32) -> Vec<Detection> {
    if detections.is_empty() {
        return Vec::new();
    }
    // Floors sigma so the exponent stays finite for sigma <= 0 (or NaN, which `max` discards).
    let spread = sigma.max(1e-6);
    // Working scores, parallel to `detections`; 0.0 marks an already-emitted detection.
    let mut live: Vec<f32> = detections.iter().map(|d| d.confidence).collect();
    let mut emitted: Vec<Detection> = Vec::new();

    for _round in 0..detections.len() {
        // Highest strictly-positive score; `>=` makes ties resolve to the later index.
        let mut winner: Option<(usize, f32)> = None;
        for (idx, &s) in live.iter().enumerate() {
            let improves = match winner {
                None => true,
                Some((_, w)) => s >= w,
            };
            if s > 0.0 && improves {
                winner = Some((idx, s));
            }
        }
        let (pick, pick_score) = match winner {
            Some(found) => found,
            None => break,
        };
        if pick_score <= score_threshold {
            break;
        }

        let chosen = &detections[pick];
        emitted.push(Detection {
            confidence: pick_score,
            ..chosen.clone()
        });
        live[pick] = 0.0;

        for (idx, s) in live.iter_mut().enumerate() {
            if idx == pick || *s <= 0.0 {
                continue;
            }
            let overlap = chosen.bbox.iou(&detections[idx].bbox);
            *s *= (-overlap * overlap / spread).exp();
        }
    }
    emitted
}
/// Detection pipeline holding its configuration and its class-label table.
pub struct ObjectDetectionPipeline {
    /// Class names, indexed by label id.
    labels: Vec<String>,
    /// Configuration supplied at construction.
    config: ObjectDetectionConfig,
}
/// Twenty COCO class names used as the pipeline's label table; a detection's `label_id` is an
/// index into this slice.
const COCO_LABELS_20: &[&str] = &[
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
];
impl ObjectDetectionPipeline {
    /// Builds a pipeline from `config`, using the built-in 20-entry COCO label table.
    ///
    /// # Errors
    /// Returns [`DetectionError::ModelError`] in these cases:
    /// - either `input_size` dimension is 0;
    /// - `confidence_threshold` or `iou_threshold` is NaN.
    ///
    /// A NaN threshold would otherwise fail silently. Every `>=`/`>` comparison against it is
    /// false, so `detect` would drop every detection (confidence) or never suppress
    /// anything (IoU).
    pub fn new(config: ObjectDetectionConfig) -> Result<Self, DetectionError> {
        if config.input_size.0 == 0 || config.input_size.1 == 0 {
            return Err(DetectionError::ModelError(
                "input_size dimensions must be > 0".to_string(),
            ));
        }
        if config.confidence_threshold.is_nan() {
            return Err(DetectionError::ModelError(
                "confidence_threshold must not be NaN".to_string(),
            ));
        }
        if config.iou_threshold.is_nan() {
            return Err(DetectionError::ModelError(
                "iou_threshold must not be NaN".to_string(),
            ));
        }
        let labels: Vec<String> = COCO_LABELS_20.iter().map(|s| s.to_string()).collect();
        Ok(Self { config, labels })
    }

    /// Runs detection on one image buffer.
    ///
    /// NOTE: no model is invoked here. The body synthesizes `image.len() % 10 + 1`
    /// deterministic detections with these properties:
    /// - labels are assigned round-robin from the label table;
    /// - confidences lie in `(0.55, 0.95)`;
    /// - only `image.len()` influences the output;
    /// - `height` and `width` are copied into the result without validation.
    ///
    /// The synthesized detections are filtered by `confidence_threshold`, passed through the
    /// module-level `nms` with `iou_threshold`, and truncated to `max_detections`.
    /// `inference_time_ms` records the wall-clock time spent after input validation.
    ///
    /// # Errors
    /// Returns [`DetectionError::EmptyImage`] if `image` is empty.
    pub fn detect(
        &self,
        image: &[f32],
        height: usize,
        width: usize,
    ) -> Result<DetectionResult, DetectionError> {
        if image.is_empty() {
            return Err(DetectionError::EmptyImage);
        }
        let started = std::time::Instant::now();

        let num_detections = (image.len() % 10) + 1;
        let mut detections: Vec<Detection> = (0..num_detections)
            .map(|i| {
                let label_id = i % self.labels.len();
                let label = self.labels[label_id].clone();
                // `seed` lies strictly inside (0, 1) and drives both geometry and confidence.
                let seed = (i as f32 + 1.0) / (num_detections as f32 + 1.0);
                let x1 = (seed * 0.5).min(0.49);
                let y1 = (seed * 0.4).min(0.39);
                let x2 = (x1 + 0.3 + seed * 0.1).min(1.0);
                let y2 = (y1 + 0.3 + seed * 0.1).min(1.0);
                // Guarantee a strictly positive extent in both axes.
                let x2 = x2.max(x1 + 0.01);
                let y2 = y2.max(y1 + 0.01);
                let bbox = BoundingBox { x1, y1, x2, y2 };
                let confidence = 0.55 + seed * 0.4;
                Detection {
                    bbox,
                    label,
                    label_id,
                    confidence,
                }
            })
            .collect();
        detections.retain(|d| d.confidence >= self.config.confidence_threshold);
        let mut after_nms = nms(&detections, self.config.iou_threshold);
        after_nms.truncate(self.config.max_detections);

        // `as_millis` returns u128; saturate instead of truncating on (absurd) overflow.
        let inference_time_ms = started.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
        Ok(DetectionResult {
            detections: after_nms,
            image_height: height,
            image_width: width,
            inference_time_ms,
        })
    }

    /// Runs [`ObjectDetectionPipeline::detect`] on each `(data, height, width)` entry in order.
    ///
    /// # Errors
    /// Returns [`DetectionError::EmptyImage`] if `images` is empty. Otherwise it returns the
    /// first error produced by `detect`, and stops processing at that point.
    pub fn detect_batch(
        &self,
        images: &[(&[f32], usize, usize)],
    ) -> Result<Vec<DetectionResult>, DetectionError> {
        if images.is_empty() {
            return Err(DetectionError::EmptyImage);
        }
        images.iter().map(|&(data, h, w)| self.detect(data, h, w)).collect()
    }

    /// Associated-function alias for the module-level `nms`.
    pub fn nms(detections: &[Detection], iou_threshold: f32) -> Vec<Detection> {
        // An unqualified call resolves to the free function (associated functions would need
        // `Self::`), so this does not recurse.
        nms(detections, iou_threshold)
    }

    /// Associated-function alias for the module-level `soft_nms`.
    pub fn soft_nms(detections: &[Detection], sigma: f32, score_threshold: f32) -> Vec<Detection> {
        // Resolves to the free function, not to this method.
        soft_nms(detections, sigma, score_threshold)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for bounding-box geometry, NMS / soft-NMS, result helpers and the pipeline.
    use super::*;

    /// Synthetic `h * w * 3` buffer whose values cycle through `[0, 1]`.
    fn make_image(h: usize, w: usize) -> Vec<f32> {
        (0..h * w * 3).map(|i| (i % 256) as f32 / 255.0).collect()
    }

    /// Detection labelled "test" with the given corners and confidence (box not validated).
    fn make_det(x1: f32, y1: f32, x2: f32, y2: f32, confidence: f32) -> Detection {
        Detection {
            bbox: BoundingBox { x1, y1, x2, y2 },
            label: "test".to_string(),
            label_id: 0,
            confidence,
        }
    }

    #[test]
    fn test_bbox_valid() {
        let bbox = BoundingBox::new(0.1, 0.1, 0.9, 0.9).expect("valid bbox");
        assert!((bbox.x1 - 0.1).abs() < 1e-6);
        assert!((bbox.x2 - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_bbox_invalid_x() {
        let result = BoundingBox::new(0.5, 0.1, 0.3, 0.9);
        assert!(matches!(result, Err(DetectionError::InvalidBbox(_))));
    }

    #[test]
    fn test_bbox_invalid_range() {
        let result = BoundingBox::new(-0.1, 0.0, 0.5, 1.0);
        assert!(matches!(result, Err(DetectionError::InvalidBbox(_))));
    }

    #[test]
    fn test_bbox_area() {
        let bbox = BoundingBox::new(0.0, 0.0, 0.5, 0.4).expect("valid");
        assert!((bbox.area() - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_bbox_iou_no_overlap() {
        let a = BoundingBox::new(0.0, 0.0, 0.4, 0.4).expect("valid");
        let b = BoundingBox::new(0.6, 0.6, 1.0, 1.0).expect("valid");
        assert!((a.iou(&b) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_bbox_iou_identical() {
        let a = BoundingBox::new(0.2, 0.2, 0.8, 0.8).expect("valid");
        let b = a.clone();
        assert!((a.iou(&b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_bbox_iou_partial() {
        let a = BoundingBox::new(0.0, 0.0, 0.6, 0.6).expect("valid");
        let b = BoundingBox::new(0.4, 0.4, 1.0, 1.0).expect("valid");
        // Intersection 0.2 * 0.2 = 0.04; union 0.36 + 0.36 - 0.04 = 0.68.
        let expected = 0.04 / 0.68;
        assert!((a.iou(&b) - expected).abs() < 1e-5);
    }

    #[test]
    fn test_nms_removes_overlapping() {
        let bbox_hi = BoundingBox::new(0.0, 0.0, 0.5, 0.5).expect("valid");
        let bbox_lo = BoundingBox::new(0.01, 0.01, 0.49, 0.49).expect("valid");
        let dets = vec![
            Detection {
                bbox: bbox_lo,
                label: "cat".into(),
                label_id: 0,
                confidence: 0.6,
            },
            Detection {
                bbox: bbox_hi,
                label: "cat".into(),
                label_id: 0,
                confidence: 0.9,
            },
        ];
        let result = nms(&dets, 0.5);
        assert_eq!(result.len(), 1);
        assert!((result[0].confidence - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_nms_keeps_non_overlapping() {
        let b1 = BoundingBox::new(0.0, 0.0, 0.3, 0.3).expect("valid");
        let b2 = BoundingBox::new(0.7, 0.7, 1.0, 1.0).expect("valid");
        let dets = vec![
            Detection {
                bbox: b1,
                label: "dog".into(),
                label_id: 1,
                confidence: 0.8,
            },
            Detection {
                bbox: b2,
                label: "cat".into(),
                label_id: 0,
                confidence: 0.7,
            },
        ];
        let result = nms(&dets, 0.5);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_filter_by_confidence() {
        let pipeline = ObjectDetectionPipeline::new(ObjectDetectionConfig::default()).expect("ok");
        let image = make_image(100, 100);
        let result = pipeline.detect(&image, 100, 100).expect("ok");
        let filtered = result.filter_by_confidence(0.8);
        assert!(filtered.detections.iter().all(|d| d.confidence >= 0.8));
        // The check above passes vacuously when nothing survives, so also confirm that a
        // threshold no finite confidence can fail keeps every detection.
        let all = result.filter_by_confidence(f32::NEG_INFINITY);
        assert_eq!(all.detections.len(), result.detections.len());
    }

    #[test]
    fn test_filter_by_label() {
        let b = BoundingBox::new(0.1, 0.1, 0.5, 0.5).expect("valid");
        let dets = vec![
            Detection {
                bbox: b.clone(),
                label: "cat".into(),
                label_id: 0,
                confidence: 0.9,
            },
            Detection {
                bbox: b.clone(),
                label: "dog".into(),
                label_id: 1,
                confidence: 0.8,
            },
            Detection {
                bbox: b.clone(),
                label: "cat".into(),
                label_id: 0,
                confidence: 0.7,
            },
        ];
        let result = DetectionResult {
            detections: dets,
            image_height: 100,
            image_width: 100,
            inference_time_ms: 0,
        };
        let cats = result.filter_by_label("cat");
        assert_eq!(cats.detections.len(), 2);
        assert!(cats.detections.iter().all(|d| d.label == "cat"));
    }

    #[test]
    fn test_top_k() {
        let b = BoundingBox::new(0.1, 0.1, 0.5, 0.5).expect("valid");
        let dets: Vec<Detection> = (0..5)
            .map(|i| Detection {
                bbox: b.clone(),
                label: "x".into(),
                label_id: i,
                confidence: i as f32 * 0.1 + 0.1,
            })
            .collect();
        let result = DetectionResult {
            detections: dets,
            image_height: 10,
            image_width: 10,
            inference_time_ms: 0,
        };
        let top2 = result.top_k(2);
        assert_eq!(top2.detections.len(), 2);
        assert!(top2.detections[0].confidence >= top2.detections[1].confidence);
    }

    #[test]
    fn test_count_by_label() {
        let b = BoundingBox::new(0.1, 0.1, 0.5, 0.5).expect("valid");
        let dets = vec![
            Detection {
                bbox: b.clone(),
                label: "cat".into(),
                label_id: 0,
                confidence: 0.9,
            },
            Detection {
                bbox: b.clone(),
                label: "dog".into(),
                label_id: 1,
                confidence: 0.8,
            },
            Detection {
                bbox: b.clone(),
                label: "cat".into(),
                label_id: 0,
                confidence: 0.7,
            },
        ];
        let result = DetectionResult {
            detections: dets,
            image_height: 10,
            image_width: 10,
            inference_time_ms: 0,
        };
        let counts = result.count_by_label();
        assert_eq!(counts["cat"], 2);
        assert_eq!(counts["dog"], 1);
    }

    #[test]
    fn test_detect_basic() {
        let config = ObjectDetectionConfig {
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let pipeline = ObjectDetectionPipeline::new(config).expect("ok");
        let image = make_image(50, 50);
        let result = pipeline.detect(&image, 50, 50).expect("ok");
        assert!(!result.detections.is_empty());
        assert_eq!(result.image_height, 50);
        assert_eq!(result.image_width, 50);
    }

    #[test]
    fn test_detect_empty_image() {
        let pipeline = ObjectDetectionPipeline::new(ObjectDetectionConfig::default()).expect("ok");
        let result = pipeline.detect(&[], 10, 10);
        assert!(matches!(result, Err(DetectionError::EmptyImage)));
    }

    #[test]
    fn test_bbox_is_valid() {
        let valid = BoundingBox {
            x1: 0.1,
            y1: 0.1,
            x2: 0.9,
            y2: 0.9,
        };
        assert!(valid.is_valid(), "should be valid");
        let degenerate = BoundingBox {
            x1: 0.5,
            y1: 0.1,
            x2: 0.5,
            y2: 0.9,
        };
        assert!(!degenerate.is_valid(), "x1 == x2 is invalid");
    }

    #[test]
    fn test_bbox_clip_to_image() {
        let big = BoundingBox {
            x1: -0.1,
            y1: -0.2,
            x2: 1.5,
            y2: 2.0,
        };
        let clipped = big.clip_to_image(1.0, 1.0);
        assert!((clipped.x1 - 0.0).abs() < 1e-6);
        assert!((clipped.y1 - 0.0).abs() < 1e-6);
        assert!((clipped.x2 - 1.0).abs() < 1e-6);
        assert!((clipped.y2 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_iou_symmetry() {
        let a = BoundingBox::new(0.0, 0.0, 0.6, 0.6).expect("valid");
        let b = BoundingBox::new(0.3, 0.3, 0.9, 0.9).expect("valid");
        assert!(
            (a.iou(&b) - b.iou(&a)).abs() < 1e-6,
            "IoU must be symmetric"
        );
    }

    #[test]
    fn test_nms_output_sorted_by_confidence() {
        let dets = vec![
            make_det(0.0, 0.0, 0.3, 0.3, 0.6),
            make_det(0.5, 0.5, 0.8, 0.8, 0.9),
            make_det(0.1, 0.1, 0.4, 0.4, 0.75),
        ];
        let result = nms(&dets, 0.3);
        for w in result.windows(2) {
            assert!(
                w[0].confidence >= w[1].confidence,
                "NMS output should be sorted descending"
            );
        }
    }

    #[test]
    fn test_soft_nms_decays_scores() {
        // IoU ~= 0.92, so with sigma = 0.5 the second score decays to ~0.155, still above 0.01.
        let dets = vec![
            make_det(0.0, 0.0, 0.5, 0.5, 0.9),
            make_det(0.01, 0.01, 0.49, 0.49, 0.85),
        ];
        let result = soft_nms(&dets, 0.5, 0.01);
        assert_eq!(
            result.len(),
            2,
            "decayed box is still above the 0.01 score threshold"
        );
        assert!(
            (result[0].confidence - 0.9).abs() < 1e-6,
            "top detection keeps its original score"
        );
        assert!(
            result[1].confidence < 0.85,
            "overlapping detection must have its score decayed"
        );
        assert!(
            result[1].confidence > 0.01,
            "kept detection must exceed the score threshold"
        );
    }

    #[test]
    fn test_soft_nms_removes_low_score_boxes() {
        // The second box starts below the 0.5 threshold and decays further, so it must go.
        let dets = vec![
            make_det(0.0, 0.0, 0.9, 0.9, 0.95),
            make_det(0.01, 0.01, 0.89, 0.89, 0.3),
        ];
        let result = soft_nms(&dets, 0.1, 0.5);
        assert_eq!(result.len(), 1, "box below score_threshold must be dropped");
        assert!(
            (result[0].confidence - 0.95).abs() < 1e-6,
            "surviving box keeps its original score"
        );
    }

    #[test]
    fn test_detect_batch() {
        let config = ObjectDetectionConfig {
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let pipeline = ObjectDetectionPipeline::new(config).expect("ok");
        let img1 = make_image(20, 20);
        let img2 = make_image(30, 30);
        let batch: Vec<(&[f32], usize, usize)> =
            vec![(img1.as_slice(), 20, 20), (img2.as_slice(), 30, 30)];
        let results = pipeline.detect_batch(&batch).expect("batch ok");
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].image_height, 20);
        assert_eq!(results[1].image_height, 30);
    }

    #[test]
    fn test_bbox_area_positive() {
        let b = BoundingBox::new(0.1, 0.2, 0.6, 0.8).expect("valid");
        assert!(b.area() > 0.0, "area should be positive for valid bbox");
    }

    #[test]
    fn test_detection_score_alias() {
        let det = make_det(0.1, 0.1, 0.5, 0.5, 0.77);
        assert!(
            (det.score() - 0.77).abs() < 1e-6,
            "score() should equal confidence"
        );
    }

    #[test]
    fn test_nms_empty_input() {
        let result = nms(&[], 0.5);
        assert!(result.is_empty(), "NMS on empty input should return empty");
    }
}