// easy-yolo 0.1.2
//
// Easy-to-use library for YOLO inference in Rust, requiring no additional setup (weights included).
// Documentation
use ndarray::{ArrayView, IxDyn, s};
use crate::{YoloResult, BoundingBox};
use core::cmp::Ordering;

fn iou(a: &YoloResult, b: &YoloResult) -> f32 {
    let xx1 = a.bbox.x1.max(b.bbox.x1);
    let yy1 = a.bbox.y1.max(b.bbox.y1);
    let xx2 = a.bbox.x2.min(b.bbox.x2);
    let yy2 = a.bbox.y2.min(b.bbox.y2);
    let w = (xx2 - xx1).max(0.0);
    let h = (yy2 - yy1).max(0.0);
    let inter = w * h;
    let area_a = (a.bbox.x2 - a.bbox.x1) * (a.bbox.y2 - a.bbox.y1);
    let area_b = (b.bbox.x2 - b.bbox.x1) * (b.bbox.y2 - b.bbox.y1);
    inter / (area_a + area_b - inter + 1e-6)
}

/// Greedy, class-agnostic non-maximum suppression.
///
/// The algorithm repeats until no candidates remain:
/// 1. Take the highest-confidence remaining detection and keep it.
/// 2. Discard every remaining candidate whose IoU with it is `>= iou_thresh`.
///
/// `class_id` is not consulted, so overlapping boxes of different classes also suppress each
/// other.
///
/// Returns the kept detections in descending confidence order. An empty input yields an empty
/// output.
fn nms(mut dets: Vec<YoloResult>, iou_thresh: f32) -> Vec<YoloResult> {
    // Sort ascending so that `pop` (which takes from the back) always yields the
    // highest-confidence remaining detection. `total_cmp` is a total order, so a NaN confidence
    // cannot make `sort_by` panic (it may panic on non-total comparators since Rust 1.81).
    dets.sort_by(|a, b| a.confidence.total_cmp(&b.confidence));

    let mut keep = Vec::new();
    while let Some(best) = dets.pop() {
        dets.retain(|d| iou(&best, d) < iou_thresh);
        keep.push(best);
    }
    keep
}

/// Decodes raw ultralytics-format YOLO predictions (v8 / v11 / v12 heads) and runs NMS.
///
/// The tensor is indexed as `output[[candidate, channel, 0]]`:
/// - `shape()[0]` is the number of candidate boxes and `shape()[1]` the number of channels;
/// - channels `0..4` are read as box centre `x`, centre `y`, width and height;
/// - channels `4..` are per-class scores. No separate objectness channel is read, so an
///   80-class model has 84 channels in total.
///
/// NOTE(review): only index `0` of the third axis is ever read. This presumably expects the
/// model's `[1, C, N]` output transposed to `[N, C, 1]` (e.g. via `.t()`); confirm at the call
/// site.
///
/// For each candidate, the highest-scoring class is selected; `class_id` counts from channel 4.
/// Candidates whose score is strictly greater than `conf_thresh` are converted to corner form
/// (`x1, y1, x2, y2`). The survivors are then filtered by `nms` at `iou_thresh`.
///
/// # Panics
/// Panics (through ndarray indexing/slicing) if `output` is not laid out as described above,
/// e.g. if it has fewer than four channels.
pub fn yolo_nms(output: ArrayView<f32, IxDyn>, conf_thresh: f32, iou_thresh: f32) -> Vec<YoloResult> {
    let dims = output.shape();
    let (num_candidates, num_channels) = (dims[0], dims[1]);

    let candidates: Vec<YoloResult> = (0..num_candidates)
        .filter_map(|row| {
            // Box geometry: centre point and size in the first four channels.
            let centre_x = output[[row, 0, 0]];
            let centre_y = output[[row, 1, 0]];
            let half_w = output[[row, 2, 0]] / 2.0;
            let half_h = output[[row, 3, 0]] / 2.0;

            // Best class among the remaining channels. Incomparable (NaN) scores are treated
            // as ties instead of panicking; a model with no class channels yields no candidate.
            let scores = output.slice(s![row, 4..num_channels, 0]);
            let (class_id, &confidence) = scores
                .iter()
                .enumerate()
                .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap_or(Ordering::Equal))?;

            (confidence > conf_thresh).then(|| YoloResult {
                bbox: BoundingBox {
                    x1: centre_x - half_w,
                    y1: centre_y - half_h,
                    x2: centre_x + half_w,
                    y2: centre_y + half_h,
                },
                confidence,
                class_id,
            })
        })
        .collect();

    nms(candidates, iou_thresh)
}