#![allow(clippy::too_many_arguments)]
#![allow(clippy::similar_names)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::module_name_repetitions)]
use crate::detect::object::{BoundingBox, Detection, ObjectDetector};
use crate::detect::yolo_utils::{
decode_yolov5_output, decode_yolov8_output, draw_detections, letterbox_resize, LetterboxParams,
};
use crate::error::{CvError, CvResult};
use ndarray::{Array, ArrayD, IxDyn};
use oxionnx::Session;
use std::collections::HashMap;
use std::path::Path;
/// YOLO model family, used to pick the decoder applied to the raw model output.
///
/// `V8` and `V9` are decoded by the same routine; `V5` has its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum YoloVersion {
/// Output is decoded with `decode_yolov5_output`.
V5,
/// Output is decoded with `decode_yolov8_output`.
V8,
/// Output is decoded with `decode_yolov8_output`, exactly like `V8`.
V9,
}
/// Policy for choosing the network input size for a given source image.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputResolution {
    /// Always use this `(width, height)`, whatever the image size is.
    Fixed(u32, u32),
    /// Derive the size from the image: scale it so its longer side falls inside
    /// `[min_size, max_size]`, then round each side up to a multiple of `stride`.
    Dynamic {
        /// Lower limit for each side of the computed size.
        min_size: u32,
        /// Upper limit for each side of the computed size.
        max_size: u32,
        /// Alignment of each side; `0` is treated as `1`.
        stride: u32,
    },
}

impl Default for InputResolution {
    /// Returns `Fixed(640, 640)`.
    fn default() -> Self {
        Self::Fixed(640, 640)
    }
}

impl InputResolution {
    /// Computes the `(width, height)` to feed the network for an image of
    /// `img_width` x `img_height` pixels.
    ///
    /// * `Fixed(w, h)` returns `(w, h)` unchanged.
    /// * `Dynamic` scales the image so that its longer side lies within
    ///   `[min_size, max_size]`, rounds both sides up to a multiple of `stride`
    ///   and finally limits them to the stride-aligned part of that range.
    ///
    /// For `Dynamic` the result is always a positive multiple of `stride`:
    /// the upper limit is `max_size` rounded *down* to the stride (but at least
    /// one stride) and the lower limit is `min_size` rounded *up* to the stride
    /// (but never above the upper limit). If `min_size > max_size`, `max_size`
    /// takes precedence. This function never panics.
    #[must_use]
    pub fn compute_size(&self, img_width: u32, img_height: u32) -> (u32, u32) {
        match *self {
            Self::Fixed(w, h) => (w, h),
            Self::Dynamic {
                min_size,
                max_size,
                stride,
            } => {
                let stride = stride.max(1);
                // An inverted range would make `clamp` panic; let the upper limit win.
                let min_size = min_size.min(max_size);

                let max_dim = img_width.max(img_height);
                let scale = if max_dim == 0 {
                    // Degenerate image: avoid dividing by zero, the limits below decide.
                    1.0
                } else if max_dim > max_size {
                    f64::from(max_size) / f64::from(max_dim)
                } else if max_dim < min_size {
                    f64::from(min_size) / f64::from(max_dim)
                } else {
                    1.0
                };

                // Alignment is done in u64 so that rounding up near `u32::MAX`
                // cannot overflow.
                let step = u64::from(stride);
                let round_up = |value: u64| (value + step - 1) / step * step;

                // Limits that are themselves multiples of the stride, so that
                // clamping can never break the alignment.
                let upper = (u64::from(max_size) / step * step).max(step);
                let lower = round_up(u64::from(min_size)).min(upper);

                let fit = |dim: u32| -> u32 {
                    // Float-to-int `as` saturates and maps NaN to 0.
                    let scaled = ((f64::from(dim) * scale) as u32).max(stride);
                    let bounded = round_up(u64::from(scaled)).clamp(lower, upper);
                    // `upper` is at most max(max_size, stride), so this always fits.
                    u32::try_from(bounded).unwrap_or(u32::MAX)
                };

                (fit(img_width), fit(img_height))
            }
        }
    }
}
/// Settings that control how a YOLO model is fed and how its output is filtered.
///
/// Build one with [`YoloConfig::new`] (or `Default`) and refine it with the
/// chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct YoloConfig {
    /// Model family; decides which output decoder is applied.
    pub version: YoloVersion,
    /// Fixed `(width, height)` network input, used when `input_resolution` is `None`.
    pub input_size: (u32, u32),
    /// Optional resolution policy; when set it takes precedence over `input_size`.
    pub input_resolution: Option<InputResolution>,
    /// Confidence threshold handed to the output decoder.
    pub confidence_threshold: f32,
    /// IoU threshold handed to the output decoder.
    pub iou_threshold: f32,
    /// Detection limit handed to the output decoder.
    pub max_detections: usize,
    /// Custom class labels; `None` makes the detector fall back to the COCO labels.
    pub class_names: Option<Vec<String>>,
    /// Per-class NMS switch handed to the output decoder.
    pub per_class_nms: bool,
    /// Requested execution provider names.
    // NOTE(review): the visible session-building code does not read this field;
    // confirm whether it is applied anywhere.
    pub execution_providers: Vec<String>,
}

impl Default for YoloConfig {
    /// YOLOv8, 640x640 fixed input, confidence 0.25, IoU 0.45, at most 300
    /// detections, no custom labels, class-agnostic NMS, CPU execution provider.
    fn default() -> Self {
        let cpu_only = vec![String::from("CPUExecutionProvider")];
        Self {
            version: YoloVersion::V8,
            input_resolution: None,
            input_size: (640, 640),
            class_names: None,
            execution_providers: cpu_only,
            per_class_nms: false,
            max_detections: 300,
            iou_threshold: 0.45,
            confidence_threshold: 0.25,
        }
    }
}

impl YoloConfig {
    /// Creates a configuration holding the default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Replaces the model family.
    #[must_use]
    pub fn with_version(self, version: YoloVersion) -> Self {
        Self { version, ..self }
    }

    /// Replaces the fixed network input size.
    #[must_use]
    pub fn with_input_size(self, width: u32, height: u32) -> Self {
        Self {
            input_size: (width, height),
            ..self
        }
    }

    /// Replaces the confidence threshold (stored as given, not validated).
    #[must_use]
    pub fn with_confidence_threshold(self, threshold: f32) -> Self {
        Self {
            confidence_threshold: threshold,
            ..self
        }
    }

    /// Replaces the IoU threshold (stored as given, not validated).
    #[must_use]
    pub fn with_iou_threshold(self, threshold: f32) -> Self {
        Self {
            iou_threshold: threshold,
            ..self
        }
    }

    /// Replaces the maximum number of detections.
    #[must_use]
    pub fn with_max_detections(self, max: usize) -> Self {
        Self {
            max_detections: max,
            ..self
        }
    }

    /// Installs custom class labels.
    #[must_use]
    pub fn with_class_names(self, names: Vec<String>) -> Self {
        Self {
            class_names: Some(names),
            ..self
        }
    }

    /// Switches per-class NMS on or off.
    #[must_use]
    pub fn with_per_class_nms(self, enabled: bool) -> Self {
        Self {
            per_class_nms: enabled,
            ..self
        }
    }

    /// Replaces the list of requested execution providers.
    #[must_use]
    pub fn with_execution_providers(self, providers: Vec<String>) -> Self {
        Self {
            execution_providers: providers,
            ..self
        }
    }

    /// Installs a resolution policy that overrides `input_size`.
    #[must_use]
    pub fn with_input_resolution(self, resolution: InputResolution) -> Self {
        Self {
            input_resolution: Some(resolution),
            ..self
        }
    }

    /// Returns the network input size to use for an image of the given size:
    /// the policy's answer when `input_resolution` is set, `input_size` otherwise.
    #[must_use]
    pub fn effective_input_size(&self, img_width: u32, img_height: u32) -> (u32, u32) {
        self.input_resolution
            .as_ref()
            .map_or(self.input_size, |policy| {
                policy.compute_size(img_width, img_height)
            })
    }
}
/// YOLO object detector that runs a model through an `oxionnx` session.
pub struct YoloDetector {
/// Loaded inference session used to run the model.
session: Session,
/// Settings supplied when the detector was constructed.
config: YoloConfig,
/// Label table: `config.class_names` when provided, otherwise the COCO labels.
class_names: Vec<String>,
/// Length of `class_names`; handed to the output decoders as the class count.
num_classes: usize,
}
impl YoloDetector {
pub fn new(model_path: impl AsRef<Path>, config: YoloConfig) -> CvResult<Self> {
let session = Session::builder()
.with_optimization_level(oxionnx::OptLevel::All)
.load(model_path.as_ref())
.map_err(|e| CvError::detection_failed(format!("Failed to load model: {e}")))?;
let class_names = config.class_names.clone().unwrap_or_else(coco_class_names);
let num_classes = class_names.len();
Ok(Self {
session,
config,
class_names,
num_classes,
})
}
pub fn from_bytes(model_bytes: &[u8], config: YoloConfig) -> CvResult<Self> {
let session = Session::builder()
.with_optimization_level(oxionnx::OptLevel::All)
.load_from_bytes(model_bytes)
.map_err(|e| {
CvError::detection_failed(format!("Failed to load model from memory: {e}"))
})?;
let class_names = config.class_names.clone().unwrap_or_else(coco_class_names);
let num_classes = class_names.len();
Ok(Self {
session,
config,
class_names,
num_classes,
})
}
pub fn detect(&mut self, image: &[u8], width: u32, height: u32) -> CvResult<Vec<Detection>> {
if width == 0 || height == 0 {
return Err(CvError::invalid_dimensions(width, height));
}
let expected_size = (width as usize) * (height as usize) * 3;
if image.len() != expected_size {
return Err(CvError::insufficient_data(expected_size, image.len()));
}
let (input_w, input_h) = self.config.effective_input_size(width, height);
let (input_tensor, letterbox_params) =
self.preprocess(image, width, height, input_w, input_h)?;
let outputs = self.run_inference(&input_tensor)?;
let detections = self.postprocess(&outputs, &letterbox_params, width, height)?;
Ok(detections)
}
pub fn detect_with_visualization(
&mut self,
image: &[u8],
width: u32,
height: u32,
) -> CvResult<(Vec<u8>, Vec<Detection>)> {
let detections = self.detect(image, width, height)?;
let annotated = draw_detections(image, width, height, &detections)?;
Ok((annotated, detections))
}
fn preprocess(
&self,
image: &[u8],
width: u32,
height: u32,
input_w: u32,
input_h: u32,
) -> CvResult<(ArrayD<f32>, LetterboxParams)> {
let (resized, params) = letterbox_resize(image, width, height, input_w, input_h)?;
let mut input_tensor = Array::zeros(IxDyn(&[1, 3, input_h as usize, input_w as usize]));
for c in 0..3 {
for y in 0..input_h as usize {
for x in 0..input_w as usize {
let idx = (y * input_w as usize + x) * 3 + c;
let value = resized[idx] as f32 / 255.0;
input_tensor[[0, c, y, x]] = value;
}
}
}
Ok((input_tensor, params))
}
fn run_inference(&mut self, input: &ArrayD<f32>) -> CvResult<ArrayD<f32>> {
let flat: Vec<f32> = input.iter().copied().collect();
let shape: Vec<usize> = input.shape().to_vec();
let tensor = oxionnx::Tensor::new(flat, shape);
let input_name = self
.session
.input_names()
.first()
.cloned()
.unwrap_or_else(|| "images".to_string());
let mut inputs = HashMap::new();
inputs.insert(input_name.as_str(), tensor);
let outputs = self
.session
.run(&inputs)
.map_err(|e| CvError::detection_failed(format!("Inference failed: {e}")))?;
let output_name = self
.session
.output_names()
.first()
.cloned()
.unwrap_or_default();
let out_tensor = outputs
.get(&output_name)
.ok_or_else(|| CvError::detection_failed("No output tensor found".to_owned()))?;
let out_shape: Vec<usize> = out_tensor.shape.clone();
ArrayD::from_shape_vec(IxDyn(&out_shape), out_tensor.data.clone())
.map_err(|e| CvError::detection_failed(format!("Failed to create output array: {e}")))
}
fn postprocess(
&self,
output_tensor: &ArrayD<f32>,
letterbox_params: &LetterboxParams,
orig_width: u32,
orig_height: u32,
) -> CvResult<Vec<Detection>> {
let detections = match self.config.version {
YoloVersion::V5 => decode_yolov5_output(
output_tensor,
self.config.confidence_threshold,
self.config.iou_threshold,
self.num_classes,
self.config.per_class_nms,
self.config.max_detections,
)?,
YoloVersion::V8 | YoloVersion::V9 => decode_yolov8_output(
output_tensor,
self.config.confidence_threshold,
self.config.iou_threshold,
self.num_classes,
self.config.per_class_nms,
self.config.max_detections,
)?,
};
let detections =
self.transform_detections(detections, letterbox_params, orig_width, orig_height);
Ok(detections)
}
fn transform_detections(
&self,
detections: Vec<Detection>,
params: &LetterboxParams,
orig_width: u32,
orig_height: u32,
) -> Vec<Detection> {
detections
.into_iter()
.map(|mut det| {
let x = (det.bbox.x - params.pad_left as f32) / params.scale;
let y = (det.bbox.y - params.pad_top as f32) / params.scale;
let w = det.bbox.width / params.scale;
let h = det.bbox.height / params.scale;
det.bbox =
BoundingBox::new(x, y, w, h).clamp(orig_width as f32, orig_height as f32);
if det.class_id < self.class_names.len() as u32 {
det.class_name = Some(self.class_names[det.class_id as usize].clone());
}
det
})
.collect()
}
#[must_use]
pub const fn input_size(&self) -> (u32, u32) {
self.config.input_size
}
#[must_use]
pub const fn confidence_threshold(&self) -> f32 {
self.config.confidence_threshold
}
#[must_use]
pub const fn iou_threshold(&self) -> f32 {
self.config.iou_threshold
}
#[must_use]
pub const fn num_classes(&self) -> usize {
self.num_classes
}
}
/// Exposes `YoloDetector` through the generic `ObjectDetector` trait.
impl ObjectDetector for YoloDetector {
/// Forwards to the inherent `YoloDetector::detect`; the explicit path makes
/// it clear that the inherent method, not this trait method, is being called.
fn detect(&mut self, image: &[u8], width: u32, height: u32) -> CvResult<Vec<Detection>> {
YoloDetector::detect(self, image, width, height)
}
/// Returns the label table the detector was built with.
fn class_names(&self) -> &[String] {
&self.class_names
}
}
/// Returns the 80 COCO class labels as owned strings, in class-id order
/// (index 0 is `"person"`, index 79 is `"toothbrush"`).
#[must_use]
pub fn coco_class_names() -> Vec<String> {
    // Ten labels per row, so row `r` starts at class id `r * 10`.
    const LABELS: [&str; 80] = [
        "person", "bicycle", "car", "motorcycle", "airplane",
        "bus", "train", "truck", "boat", "traffic light",
        "fire hydrant", "stop sign", "parking meter", "bench", "bird",
        "cat", "dog", "horse", "sheep", "cow",
        "elephant", "bear", "zebra", "giraffe", "backpack",
        "umbrella", "handbag", "tie", "suitcase", "frisbee",
        "skis", "snowboard", "sports ball", "kite", "baseball bat",
        "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
        "wine glass", "cup", "fork", "knife", "spoon",
        "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut",
        "cake", "chair", "couch", "potted plant", "bed",
        "dining table", "toilet", "tv", "laptop", "mouse",
        "remote", "keyboard", "cell phone", "microwave", "oven",
        "toaster", "sink", "refrigerator", "book", "clock",
        "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
    ];
    LABELS.iter().map(|&label| label.to_owned()).collect()
}
// Unit tests for the configuration types and the COCO label table.
// None of these tests load a model, so `YoloDetector` itself is not exercised.
#[cfg(test)]
mod tests {
use super::*;
// The COCO table has 80 entries with the expected first and last labels.
#[test]
fn test_coco_class_names() {
let names = coco_class_names();
assert_eq!(names.len(), 80);
assert_eq!(names[0], "person");
assert_eq!(names[79], "toothbrush");
}
// `YoloConfig::default()` yields the documented default values.
#[test]
fn test_yolo_config_default() {
let config = YoloConfig::default();
assert_eq!(config.input_size, (640, 640));
assert_eq!(config.confidence_threshold, 0.25);
assert_eq!(config.iou_threshold, 0.45);
assert_eq!(config.max_detections, 300);
assert!(!config.per_class_nms);
}
// Each builder method stores the value it was given.
#[test]
fn test_yolo_config_builder() {
let config = YoloConfig::new()
.with_version(YoloVersion::V5)
.with_input_size(416, 416)
.with_confidence_threshold(0.5)
.with_iou_threshold(0.4)
.with_max_detections(100)
.with_per_class_nms(true);
assert_eq!(config.version, YoloVersion::V5);
assert_eq!(config.input_size, (416, 416));
assert_eq!(config.confidence_threshold, 0.5);
assert_eq!(config.iou_threshold, 0.4);
assert_eq!(config.max_detections, 100);
assert!(config.per_class_nms);
}
// `YoloVersion` variants compare by identity.
#[test]
fn test_yolo_version() {
assert_eq!(YoloVersion::V5, YoloVersion::V5);
assert_ne!(YoloVersion::V5, YoloVersion::V8);
}
// NOTE(review): despite its name this test only re-checks the default input
// size; it does not exercise the zero-dimension validation in
// `YoloDetector::detect`, which needs a loaded model to call.
#[test]
fn test_invalid_dimensions() {
let config = YoloConfig::default();
assert_eq!(config.input_size, (640, 640));
}
// Custom class names are stored verbatim in the configuration.
#[test]
fn test_class_names_custom() {
let custom_names = vec!["cat".to_string(), "dog".to_string()];
let config = YoloConfig::new().with_class_names(custom_names.clone());
assert_eq!(config.class_names, Some(custom_names));
}
// Spot-checks a few well-known class ids in the COCO table.
#[test]
fn test_coco_class_names_specific_entries() {
let names = coco_class_names();
assert_eq!(names[14], "bird");
assert_eq!(names[15], "cat");
assert_eq!(names[16], "dog");
assert_eq!(names[39], "bottle");
assert_eq!(names[56], "chair");
}
// The execution provider list can be replaced through the builder.
#[test]
fn test_yolo_config_execution_providers() {
let config =
YoloConfig::new().with_execution_providers(vec!["CUDAExecutionProvider".to_string()]);
assert_eq!(config.execution_providers[0], "CUDAExecutionProvider");
}
// By default exactly one provider, the CPU provider, is requested.
#[test]
fn test_yolo_config_default_providers() {
let config = YoloConfig::default();
assert_eq!(config.execution_providers.len(), 1);
assert_eq!(config.execution_providers[0], "CPUExecutionProvider");
}
// Reflexive equality of `YoloVersion` (overlaps with `test_yolo_version`).
#[test]
fn test_yolo_version_equality() {
assert_eq!(YoloVersion::V8, YoloVersion::V8);
assert_eq!(YoloVersion::V5, YoloVersion::V5);
}
// `with_max_detections` accepts both small and large limits.
#[test]
fn test_yolo_config_max_detections() {
let config = YoloConfig::new().with_max_detections(50);
assert_eq!(config.max_detections, 50);
let config2 = YoloConfig::new().with_max_detections(1000);
assert_eq!(config2.max_detections, 1000);
}
// Every COCO label appears only once.
#[test]
fn test_coco_class_names_no_duplicates() {
let names = coco_class_names();
let mut seen = std::collections::HashSet::new();
for name in &names {
assert!(seen.insert(name.as_str()), "Duplicate class name: {name}");
}
}
// No COCO label is the empty string.
#[test]
fn test_coco_class_names_non_empty() {
let names = coco_class_names();
for name in &names {
assert!(!name.is_empty(), "Class name should not be empty");
}
}
// The V9 variant can be selected and is distinct from V8 and V5.
#[test]
fn test_yolo_version_v9() {
let config = YoloConfig::new().with_version(YoloVersion::V9);
assert_eq!(config.version, YoloVersion::V9);
assert_ne!(config.version, YoloVersion::V8);
assert_ne!(config.version, YoloVersion::V5);
}
// A fixed resolution ignores the image size entirely.
#[test]
fn test_input_resolution_fixed() {
let res = InputResolution::Fixed(640, 640);
assert_eq!(res.compute_size(1920, 1080), (640, 640));
assert_eq!(res.compute_size(100, 100), (640, 640));
}
// An image larger than `max_size` is scaled down and stays stride-aligned.
#[test]
fn test_input_resolution_dynamic_downscale() {
let res = InputResolution::Dynamic {
min_size: 320,
max_size: 1280,
stride: 32,
};
let (w, h) = res.compute_size(1920, 1080);
assert_eq!(w % 32, 0, "Width must be aligned to stride");
assert_eq!(h % 32, 0, "Height must be aligned to stride");
assert!(w <= 1280);
assert!(h <= 1280);
}
// An image smaller than `min_size` is scaled up and stays stride-aligned.
#[test]
fn test_input_resolution_dynamic_upscale() {
let res = InputResolution::Dynamic {
min_size: 320,
max_size: 1280,
stride: 32,
};
let (w, h) = res.compute_size(100, 100);
assert!(w >= 320);
assert!(h >= 320);
assert_eq!(w % 32, 0);
assert_eq!(h % 32, 0);
}
// An image already inside the range yields an aligned size within the range.
#[test]
fn test_input_resolution_dynamic_passthrough() {
let res = InputResolution::Dynamic {
min_size: 320,
max_size: 1280,
stride: 32,
};
let (w, h) = res.compute_size(640, 480);
assert_eq!(w % 32, 0);
assert_eq!(h % 32, 0);
assert!(w >= 320 && w <= 1280);
assert!(h >= 320 && h <= 1280);
}
// The default resolution policy is a fixed 640x640 input.
#[test]
fn test_input_resolution_default() {
let res = InputResolution::default();
assert_eq!(res, InputResolution::Fixed(640, 640));
}
// A dynamic policy set on the config drives `effective_input_size`.
#[test]
fn test_config_with_input_resolution() {
let config = YoloConfig::new().with_input_resolution(InputResolution::Dynamic {
min_size: 320,
max_size: 1280,
stride: 32,
});
assert!(config.input_resolution.is_some());
let (w, h) = config.effective_input_size(800, 600);
assert_eq!(w % 32, 0);
assert_eq!(h % 32, 0);
}
// Without a policy, `effective_input_size` returns the fixed `input_size`.
#[test]
fn test_effective_input_size_without_dynamic() {
let config = YoloConfig::new().with_input_size(416, 416);
assert_eq!(config.effective_input_size(1920, 1080), (416, 416));
}
// End-to-end builder chain for a V9 configuration with a dynamic policy.
#[test]
fn test_yolo_v9_config_full() {
let config = YoloConfig::new()
.with_version(YoloVersion::V9)
.with_input_resolution(InputResolution::Dynamic {
min_size: 320,
max_size: 1280,
stride: 32,
})
.with_confidence_threshold(0.3)
.with_iou_threshold(0.5)
.with_max_detections(200)
.with_per_class_nms(true);
assert_eq!(config.version, YoloVersion::V9);
assert!(config.input_resolution.is_some());
assert_eq!(config.confidence_threshold, 0.3);
assert_eq!(config.iou_threshold, 0.5);
assert_eq!(config.max_detections, 200);
assert!(config.per_class_nms);
}
}