//! easy-yolo 0.1.3
//!
//! An easy-to-use library for YOLO inference in Rust that requires no additional setup;
//! model weights are included.
/*
Reference: https://linzichun.com/posts/rust-opencv-onnx-yolov8-detect/

TODOs:
- Batch-process multiple images in a single inference call (undecided whether this is wanted).
*/

#![allow(clippy::manual_retain)]

mod nms;

use std::path::Path;
use ndarray::{Array, ArrayView, Axis};
use ort::{
    execution_providers::{CoreMLExecutionProvider, coreml::CoreMLComputeUnits},
    inputs,
    session::{InMemorySession, Session, SessionOutputs},
    value::TensorRef
};
use std::sync::OnceLock;

/************************** ERROR **************************/

/// Errors returned by this crate.
///
/// The enum carries no payload (it is `Copy`); in particular, the conversion from
/// `ort::Error` discards the underlying ONNX Runtime error details.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Error {
    /// ONNX Runtime reported a failure (session creation, inference or tensor extraction).
    OrtError,
    /// An I/O operation failed (e.g. reading an input file).
    IoError,
    /// Caller-supplied input has the wrong size/shape or could not be decoded.
    InvalidInput,
    /// The model does not expose the inputs/outputs this crate expects.
    InvalidModel,
    /// The model's output layout is not supported by the decoder.
    UnsupportedModelOutputFormat,
    /// An internal invariant of this crate was violated (a bug in the library).
    LibraryError,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Error::OrtError => "ONNX Runtime error",
            Error::IoError => "I/O error",
            Error::InvalidInput => "invalid input",
            Error::InvalidModel => "invalid model",
            Error::UnsupportedModelOutputFormat => "unsupported model output format",
            Error::LibraryError => "internal library error",
        };
        f.write_str(message)
    }
}

// No `source()`: the enum does not retain the underlying cause.
impl std::error::Error for Error {}

impl From<ort::Error> for Error {
    fn from(_: ort::Error) -> Self {
        Error::OrtError
    }
}

pub type Result<T> = std::result::Result<T, Error>;

/// The 80 COCO class names, in class-id order (index `i` is the label for `class_id == i`
/// for models trained on this class list).
pub const YOLO_CLASS_LABELS: [&str; 80] = [
    // 0-9
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    // 10-19
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    // 20-29
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    // 30-39
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    // 40-49
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    // 50-59
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    // 60-69
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    // 70-79
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
];

/// Layout of a model's output tensors; selects how `YoloModel::run` decodes them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputType {
    /// Single raw output tensor from an Ultralytics YOLOv8/v11-style export; decoded with
    /// non-maximum suppression inside this crate.
    UltralyticsUnprocessedV8V11,
    /// SuperGradients export with built-in post-processing, read from the
    /// `graph2_num_predictions`, `graph2_pred_boxes`, `graph2_pred_scores` and
    /// `graph2_pred_classes` outputs.
    SuperGradientsProcessedBatchFormat,
    /// SuperGradients export with built-in post-processing in the "flat" layout
    /// (currently rejected by `YoloModel::run` as unsupported).
    SuperGradientsProcessedFlatFormat,
    /// Raw (not post-processed) SuperGradients output, e.g. from YOLO-NAS.
    UnprocessedSuperGradients,
    // TODO: YOLOv5
}

/// Axis-aligned box described by two corners, `(x1, y1)` and `(x2, y2)`.
///
/// `x` is the horizontal and `y` the vertical pixel coordinate. Presumably `(x1, y1)` is the
/// top-left and `(x2, y2)` the bottom-right corner; that ordering comes from the model/decoder
/// and is not enforced by this type.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Horizontal coordinate of the first corner.
    pub x1: f32,
    /// Vertical coordinate of the first corner.
    pub y1: f32,
    /// Horizontal coordinate of the second corner.
    pub x2: f32,
    /// Vertical coordinate of the second corner.
    pub y2: f32
}

/// A single detection returned by `YoloModel::run`.
#[derive(Clone, Debug)]
pub struct YoloResult {
    /// Where the detected object is.
    pub bbox: BoundingBox,
    /// Class index reported by the model (e.g. an index into `YOLO_CLASS_LABELS` for models
    /// trained on those 80 classes).
    pub class_id: usize,
    /// Score the model assigned to this detection.
    pub confidence: f32,
}

/// Ensures the global ONNX Runtime environment is configured at most once.
static INIT_ORT_ENVIRONMENT: OnceLock<()> = OnceLock::new();

/// Configures the global ONNX Runtime environment with the CoreML execution provider
/// (all compute units enabled).
///
/// Only the first successful call does any work; later calls return immediately.
/// TODO: Use CUDA.
///
/// # Panics
/// If ONNX Runtime refuses to commit the environment. In that case the guard stays
/// uninitialised, so a later call will try again.
pub fn init_ort_env() {
    INIT_ORT_ENVIRONMENT.get_or_init(|| {
        // TODO: maybe record if coreML EP fails to build and inform about cpu fallback
        //
        // `.error_on_failure()` is intentionally not used: it would exit the program when the
        // execution provider fails to register, whereas failing silently falls back to CPU.
        let coreml = CoreMLExecutionProvider::default()
            .with_compute_units(CoreMLComputeUnits::All)
            .build();

        ort::init()
            .with_execution_providers([coreml])
            .commit()
            .expect("Failed to initialize ONNX Runtime environment");
    });
}

/// A YOLO model backed by an ONNX Runtime session, plus the knowledge of how to decode
/// its outputs.
///
/// The constructors populate exactly one of the two session fields:
/// `new_from_bytes_borrowed` keeps a session tied to caller-owned bytes for `'a`, while
/// `new_from_bytes` builds a session that holds no borrow (hence `YoloModel<'static>`).
pub struct YoloModel<'a> {
    /// Session built over bytes borrowed from the caller for `'a`.
    session_ref: Option<InMemorySession<'a>>,
    /// Session that does not borrow its model bytes.
    session_own: Option<Session>,
    /// How to interpret the model's output tensors.
    output_format: OutputType,
}

impl YoloModel<'static> {
    /// Builds a model from in-memory ONNX bytes. The resulting session keeps no borrow of
    /// `bytes`, so the model is `'static`.
    ///
    /// # Errors
    /// [`Error::OrtError`] if ONNX Runtime cannot create a session from `bytes`.
    pub fn new_from_bytes(bytes: &[u8], output_format: OutputType) -> Result<Self> {
        // TODO: autodetect yolo-nas vs v8/11 instead of taking `output_format` from the caller.
        let owned_session = Session::builder()?.commit_from_memory(bytes)?;
        Ok(Self {
            session_ref: None,
            session_own: Some(owned_session),
            output_format,
        })
    }
}

impl<'a> YoloModel<'a> {
    pub fn new_from_bytes_borrowed(bytes: &'a [u8], output_format: OutputType) -> Result<Self> {
        Ok(Self {
            session_ref: Some(Session::builder()?.commit_from_memory_directly(bytes)?),
            session_own: None,
            output_format, // TODO: autodetect yolo-nas vs v8/11
        })
    }

    pub fn get_model_input_image_dims(&self) -> (u32, u32) {
        (640, 640) // TODO: detect this from model
    }

    pub fn print_input_and_output_info(&mut self) {
        let model = if let Some(session) = &mut self.session_ref { session }
            else if let Some(session) = &mut self.session_own { session }
            else { unreachable!("Shouldn't happen") };

        // get model input output names
        let input_name = model.inputs[0].name.clone();
        // println!("Using input name: {}", input_name);
        let num_outputs = model.outputs.len();
        println!("Model has {} outputs", num_outputs);
        for i in 0..num_outputs {
        	println!("\n------ Output {}: {:#?}", i, model.outputs[i].name.clone());
        	println!("Type: {:#?}", model.outputs[i].output_type.clone());
        }
    }

    /*
     * Input: flattened f32 array
     * Returns unfiltered results (x1, y1, x2, y2, confidence, class probabilities) from the model.
     * Of the ONNX is a filtered model type, the results will be fake (1.0 probability)
    */
    pub fn run(&mut self, image_data: &[f32], min_confidence: f32) -> Result<Vec<YoloResult>> {
        let (width, height) = self.get_model_input_image_dims();
        if image_data.len() != (width * height * 3) as usize {
            return Err(Error::InvalidInput); // incorrect image size
        }

        let model = if let Some(session) = &mut self.session_ref { session }
            else if let Some(session) = &mut self.session_own { session }
            else { return Err(Error::LibraryError); };

        let input_name = model.inputs[0].name.clone();

        let mut input = ArrayView::from_shape([1, 3, height as usize, width as usize], image_data)
            .map_err(|_| Error::InvalidInput)?;

        let mut out = vec![];

        match self.output_format {
            OutputType::SuperGradientsProcessedBatchFormat => {
                let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
                let num_predictions = outputs["graph2_num_predictions"].try_extract_array::<i64>()?.iter().next().copied().unwrap_or(0);
                let boxes_output = outputs["graph2_pred_boxes"].try_extract_array::<f32>()?.t().into_owned();
                let pred_scores = outputs["graph2_pred_scores"].try_extract_array::<f32>()?.t().into_owned();
                let pred_classes = outputs["graph2_pred_classes"].try_extract_array::<i64>()?.t().into_owned();

                for i in 0..(num_predictions.min(boxes_output.shape()[1] as i64) as usize) {
                    let score = *pred_scores.get([i, 0]).unwrap() as f32;
                    // let label = YOLO_CLASS_LABELS[*pred_classes.get([i, 0]).unwrap() as usize];
                    let data = [
                        *boxes_output.get([0, i, 0]).unwrap(),
                        *boxes_output.get([1, i, 0]).unwrap(),
                        *boxes_output.get([2, i, 0]).unwrap(),
                        *boxes_output.get([3, i, 0]).unwrap()
                    ];
                    out.push(YoloResult {
                        bbox: BoundingBox {
                            x1: data[0] as f32,
                            y1: data[1] as f32,
                            x2: data[2] as f32,
                            y2: data[3] as f32
                        },
                        class_id: *pred_classes.get([i, 0]).unwrap() as usize,
                        confidence: score
                    });
                }
            },
            OutputType::UnprocessedSuperGradients => {
                if model.outputs.len() != 2 {
                    return Err(Error::InvalidModel);
                }
                let (name1, name2) = (model.outputs[0].name.clone(), model.outputs[1].name.clone());
                let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
                let bounding_boxes = outputs[name1.as_str()].try_extract_array::<f32>()?.t().into_owned();
                let class_scores = outputs[name2.as_str()].try_extract_array::<f32>()?.t().into_owned();
            },
            OutputType::UltralyticsUnprocessedV8V11 => {
                /* TODO!!!!!!!!!!!!!!!! */
                // todo!("Unprocessed YOLOv8/11 output format not implemented yet");
                let name = model.outputs[0].name.clone();
                let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
                let output = outputs[name.as_str()].try_extract_array::<f32>()?.t().into_owned();

                return Ok(nms::yolo_nms(output.view(), min_confidence, 0.45))
            }
            _ => {
                return Err(Error::UnsupportedModelOutputFormat);
            }
        }
        
        return Ok(out);
    }

    #[cfg(feature = "image")]
    pub fn run_on_image_from_path(&mut self, path: impl AsRef<Path>, min_confidence: f32) -> Result<Vec<YoloResult>> {
        use image::{GenericImageView, imageops::FilterType};
        let original_img = image::open(path).unwrap();
        let (img_width, img_height) = (original_img.width(), original_img.height());
        let img = original_img.resize_exact(640, 640, FilterType::CatmullRom);

        let channel = |i: usize| img.pixels().map(move |(_,_,c)| c[i] as f32 / 255.0);
        let data = channel(0).chain(channel(1)).chain(channel(2)).collect::<Vec<_>>();

        let start = std::time::Instant::now();
        let res = self.run(&data, min_confidence).map(|mut results| {
            // Scale the bounding boxes back to the original image size
            for result in &mut results {
                // println!("Result: {:#?}", result);
                result.bbox.x1 = result.bbox.x1 * (img_width as f32 / 640.0);
                result.bbox.y1 = result.bbox.y1 * (img_height as f32 / 640.0);
                result.bbox.x2 = result.bbox.x2 * (img_width as f32 / 640.0);
                result.bbox.y2 = result.bbox.y2 * (img_height as f32 / 640.0);
            }
            results
        });
        println!("Inference took: {:?}", start.elapsed());
        return res;
    }
}



/// Builds the bundled, pretrained YOLOv12n model (Ultralytics raw output layout).
///
/// The ONNX weights are embedded in the binary as a zip archive and decompressed on every
/// call; the extracted size is printed to stdout.
///
/// # Panics
/// If the embedded archive cannot be read or ONNX Runtime rejects the model.
#[cfg(feature = "weights")]
pub fn pretrained_v12n() -> YoloModel<'static> {
    /// Decompresses the first entry of an in-memory zip archive.
    fn extract_file_from_zip_bytes(zip_bytes: &[u8]) -> Vec<u8> {
        const EXPECT_MESSAGE: &str = "Should work, has been tested on this data";
        let mut archive = zip::ZipArchive::new(std::io::Cursor::new(zip_bytes)).expect(EXPECT_MESSAGE);
        let mut entry = archive.by_index(0).expect(EXPECT_MESSAGE);
        let mut onnx_bytes = Vec::new();
        std::io::Read::read_to_end(&mut entry, &mut onnx_bytes).expect(EXPECT_MESSAGE);
        onnx_bytes
    }

    const ZIPPED_WEIGHTS: &[u8] = include_bytes!("../yolov12n.onnx.zip");
    let model_bytes = extract_file_from_zip_bytes(ZIPPED_WEIGHTS);
    println!("Extracted yolov12n file, {} bytes", model_bytes.len());
    YoloModel::new_from_bytes(&model_bytes, OutputType::UltralyticsUnprocessedV8V11).expect("Should work")
}