Skip to main content

easy_yolo/
lib.rs

1/* 
2https://linzichun.com/posts/rust-opencv-onnx-yolov8-detect/
3
4TODOS:
5- batch processing multiple images at once (maybe, not sure if I want it)
6*/
7
8#![allow(clippy::manual_retain)]
9
10mod nms;
11
12use std::path::Path;
13use ndarray::{Array, ArrayView, Axis};
14use ort::{
15    execution_providers::{CoreMLExecutionProvider, coreml::CoreMLComputeUnits},
16    inputs,
17    session::{InMemorySession, Session, SessionOutputs},
18    value::TensorRef
19};
20use std::sync::OnceLock;
21
22/************************** ERROR **************************/
23
/// Failure categories reported by this crate.
///
/// The variants carry no payload, so the type stays `Copy` and cheap to pass around.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// ONNX Runtime reported a failure; the underlying `ort::Error` is not retained.
    OrtError,
    /// A file could not be read or decoded.
    IoError,
    /// The caller supplied input that does not match what the model expects
    /// (for example an image buffer of the wrong length).
    InvalidInput,
    /// The loaded model does not have the structure required by the selected output format.
    InvalidModel,
    /// The selected output format cannot be decoded by this crate.
    UnsupportedModelOutputFormat,
    /// An internal invariant of this crate was violated.
    LibraryError,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Error::OrtError => "ONNX Runtime error",
            Error::IoError => "I/O error",
            Error::InvalidInput => "invalid input",
            Error::InvalidModel => "invalid model",
            Error::UnsupportedModelOutputFormat => "unsupported model output format",
            Error::LibraryError => "internal library error",
        };
        f.write_str(message)
    }
}

// Lets callers use `?` into `Box<dyn std::error::Error>` and similar error wrappers.
impl std::error::Error for Error {}
33
34impl From<ort::Error> for Error {
35    fn from(_: ort::Error) -> Self {
36        Error::OrtError
37    }
38}
39
/// Shorthand for `std::result::Result` with this crate's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
41
/// Human-readable names for the 80 classes, indexed by class id.
///
/// Rows are grouped in tens so an id can be located by eye.
pub const YOLO_CLASS_LABELS: [&str; 80] = [
    // 0-9
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    // 10-19
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    // 20-29
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    // 30-39
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    // 40-49
    "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
    // 50-59
    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed",
    // 60-69
    "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
    // 70-79
    "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
];
52
/// Selects how the raw ONNX outputs of a model are decoded into detections.
///
/// The derives make the selector printable, comparable and freely copyable, matching the
/// other public types in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputType {
    /// Raw Ultralytics YOLOv8/v11 output; `YoloModel::run` decodes it with non-maximum suppression.
    UltralyticsUnprocessedV8V11,
    /// SuperGradients export with post-processing baked in, read from the
    /// `graph2_num_predictions` / `graph2_pred_boxes` / `graph2_pred_scores` /
    /// `graph2_pred_classes` outputs.
    SuperGradientsProcessedBatchFormat,
    /// SuperGradients export with post-processing baked in, flat layout.
    /// NOTE(review): no decoder for this layout is visible in this file.
    SuperGradientsProcessedFlatFormat,
    /// Unprocessed SuperGradients output (YOLO-NAS or similar).
    UnprocessedSuperGradients,
    // YoloV5
}
60
/// Axis-aligned rectangle described by two points, `(x1, y1)` and `(x2, y2)`.
///
/// NOTE(review): presumably the top-left and bottom-right corners in pixel units; confirm
/// against the decoder that produces the values.
#[derive(Clone, Copy, Debug)]
pub struct BoundingBox {
    /// Horizontal coordinate of the first point.
    pub x1: f32,
    /// Vertical coordinate of the first point.
    pub y1: f32,
    /// Horizontal coordinate of the second point.
    pub x2: f32,
    /// Vertical coordinate of the second point.
    pub y2: f32,
}
68
69#[derive(Debug, Clone)]
70pub struct YoloResult {
71    pub bbox: BoundingBox,
72    pub class_id: usize,
73    pub confidence: f32,
74}
75
/// One-time guard: `init_ort_env` runs its setup closure through this cell, so repeated
/// calls perform the ONNX Runtime environment initialization only once.
static INIT_ORT_ENVIRONMENT: OnceLock<()> = OnceLock::new();
77
78/* Does setup for GPU stuff for ORT. TODO: Use CUDA */
79pub fn init_ort_env() {
80    INIT_ORT_ENVIRONMENT.get_or_init(|| {
81        // TODO: maybe record if coreML EP fails to build and inform about cpu fallback
82        ort::init()
83            .with_execution_providers([
84                CoreMLExecutionProvider::default()
85                    .with_compute_units(CoreMLComputeUnits::All)
86                    .build()
87                    // .error_on_failure(), // exit the program with an error if the Execution Provider fails to register, better to fail silently and fallback to CPU
88            ])
89            .commit()
90            .expect("Failed to initialize ONNX Runtime environment");
91    });
92}
93
/// A YOLO detector backed by an ONNX Runtime session.
///
/// The constructors in this file fill exactly one of the two session slots: the borrowed
/// slot when the model bytes are referenced for `'a`, the owned slot otherwise.
pub struct YoloModel<'a> {
    // Session created with `commit_from_memory_directly`; tied to the lifetime of the model bytes.
    session_ref: Option<InMemorySession<'a>>,
    // Session created with `commit_from_memory`.
    session_own: Option<Session>,
    // Which output layout `run` should decode.
    output_format: OutputType,
}
99
100impl YoloModel<'static> {
101    pub fn new_from_bytes(bytes: &[u8], output_format: OutputType) -> Result<Self> {
102        Ok(Self {
103            session_own: Some(Session::builder()?.commit_from_memory(bytes)?),
104            session_ref: None,
105            output_format, // TODO: autodetect yolo-nas vs v8/11
106        })
107    }
108}
109
110impl<'a> YoloModel<'a> {
111    pub fn new_from_bytes_borrowed(bytes: &'a [u8], output_format: OutputType) -> Result<Self> {
112        Ok(Self {
113            session_ref: Some(Session::builder()?.commit_from_memory_directly(bytes)?),
114            session_own: None,
115            output_format, // TODO: autodetect yolo-nas vs v8/11
116        })
117    }
118
119    pub fn get_model_input_image_dims(&self) -> (u32, u32) {
120        (640, 640) // TODO: detect this from model
121    }
122
123    pub fn print_input_and_output_info(&mut self) {
124        let model = if let Some(session) = &mut self.session_ref { session }
125            else if let Some(session) = &mut self.session_own { session }
126            else { unreachable!("Shouldn't happen") };
127
128        // get model input output names
129        let input_name = model.inputs[0].name.clone();
130        // println!("Using input name: {}", input_name);
131        let num_outputs = model.outputs.len();
132        println!("Model has {} outputs", num_outputs);
133        for i in 0..num_outputs {
134        	println!("\n------ Output {}: {:#?}", i, model.outputs[i].name.clone());
135        	println!("Type: {:#?}", model.outputs[i].output_type.clone());
136        }
137    }
138
139    /*
140     * Input: flattened f32 array
141     * Returns unfiltered results (x1, y1, x2, y2, confidence, class probabilities) from the model.
142     * Of the ONNX is a filtered model type, the results will be fake (1.0 probability)
143    */
144    pub fn run(&mut self, image_data: &[f32], min_confidence: f32) -> Result<Vec<YoloResult>> {
145        let (width, height) = self.get_model_input_image_dims();
146        if image_data.len() != (width * height * 3) as usize {
147            return Err(Error::InvalidInput); // incorrect image size
148        }
149
150        let model = if let Some(session) = &mut self.session_ref { session }
151            else if let Some(session) = &mut self.session_own { session }
152            else { return Err(Error::LibraryError); };
153
154        let input_name = model.inputs[0].name.clone();
155
156        let mut input = ArrayView::from_shape([1, 3, height as usize, width as usize], image_data)
157            .map_err(|_| Error::InvalidInput)?;
158
159        let mut out = vec![];
160
161        match self.output_format {
162            OutputType::SuperGradientsProcessedBatchFormat => {
163                let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
164                let num_predictions = outputs["graph2_num_predictions"].try_extract_array::<i64>()?.iter().next().copied().unwrap_or(0);
165                let boxes_output = outputs["graph2_pred_boxes"].try_extract_array::<f32>()?.t().into_owned();
166                let pred_scores = outputs["graph2_pred_scores"].try_extract_array::<f32>()?.t().into_owned();
167                let pred_classes = outputs["graph2_pred_classes"].try_extract_array::<i64>()?.t().into_owned();
168
169                for i in 0..(num_predictions.min(boxes_output.shape()[1] as i64) as usize) {
170                    let score = *pred_scores.get([i, 0]).unwrap() as f32;
171                    // let label = YOLO_CLASS_LABELS[*pred_classes.get([i, 0]).unwrap() as usize];
172                    let data = [
173                        *boxes_output.get([0, i, 0]).unwrap(),
174                        *boxes_output.get([1, i, 0]).unwrap(),
175                        *boxes_output.get([2, i, 0]).unwrap(),
176                        *boxes_output.get([3, i, 0]).unwrap()
177                    ];
178                    out.push(YoloResult {
179                        bbox: BoundingBox {
180                            x1: data[0] as f32,
181                            y1: data[1] as f32,
182                            x2: data[2] as f32,
183                            y2: data[3] as f32
184                        },
185                        class_id: *pred_classes.get([i, 0]).unwrap() as usize,
186                        confidence: score
187                    });
188                }
189            },
190            OutputType::UnprocessedSuperGradients => {
191                if model.outputs.len() != 2 {
192                    return Err(Error::InvalidModel);
193                }
194                let (name1, name2) = (model.outputs[0].name.clone(), model.outputs[1].name.clone());
195                let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
196                let bounding_boxes = outputs[name1.as_str()].try_extract_array::<f32>()?.t().into_owned();
197                let class_scores = outputs[name2.as_str()].try_extract_array::<f32>()?.t().into_owned();
198            },
199            OutputType::UltralyticsUnprocessedV8V11 => {
200                /* TODO!!!!!!!!!!!!!!!! */
201                // todo!("Unprocessed YOLOv8/11 output format not implemented yet");
202                let name = model.outputs[0].name.clone();
203                let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
204                let output = outputs[name.as_str()].try_extract_array::<f32>()?.t().into_owned();
205
206                return Ok(nms::yolo_nms(output.view(), min_confidence, 0.45))
207            }
208            _ => {
209                return Err(Error::UnsupportedModelOutputFormat);
210            }
211        }
212        
213        return Ok(out);
214    }
215
216    #[cfg(feature = "image")]
217    pub fn run_on_image_from_path(&mut self, path: impl AsRef<Path>, min_confidence: f32) -> Result<Vec<YoloResult>> {
218        use image::{GenericImageView, imageops::FilterType};
219        let original_img = image::open(path).unwrap();
220        let (img_width, img_height) = (original_img.width(), original_img.height());
221        let img = original_img.resize_exact(640, 640, FilterType::CatmullRom);
222
223        let channel = |i: usize| img.pixels().map(move |(_,_,c)| c[i] as f32 / 255.0);
224        let data = channel(0).chain(channel(1)).chain(channel(2)).collect::<Vec<_>>();
225
226        let start = std::time::Instant::now();
227        let res = self.run(&data, min_confidence).map(|mut results| {
228            // Scale the bounding boxes back to the original image size
229            for result in &mut results {
230                // println!("Result: {:#?}", result);
231                result.bbox.x1 = result.bbox.x1 * (img_width as f32 / 640.0);
232                result.bbox.y1 = result.bbox.y1 * (img_height as f32 / 640.0);
233                result.bbox.x2 = result.bbox.x2 * (img_width as f32 / 640.0);
234                result.bbox.y2 = result.bbox.y2 * (img_height as f32 / 640.0);
235            }
236            results
237        });
238        println!("Inference took: {:?}", start.elapsed());
239        return res;
240    }
241}
242
243
244
/// Builds a model from the `yolov12n.onnx.zip` archive embedded at compile time.
///
/// The first entry of the archive is decompressed into memory, its size is printed to
/// stdout, and the bytes are loaded with the `UltralyticsUnprocessedV8V11` output format.
///
/// # Panics
///
/// Panics if the embedded archive cannot be read or the model fails to load.
#[cfg(feature = "weights")]
pub fn pretrained_v12n() -> YoloModel<'static> {
    const EXPECT_MESSAGE: &str = "Should work, has been tested on this data";

    let zipped: &[u8] = include_bytes!("../yolov12n.onnx.zip");
    let mut archive = zip::ZipArchive::new(std::io::Cursor::new(zipped)).expect(EXPECT_MESSAGE);

    let mut bytes = Vec::new();
    {
        // Scoped so the entry's borrow of the archive ends once the bytes are read.
        let mut entry = archive.by_index(0).expect(EXPECT_MESSAGE);
        std::io::Read::read_to_end(&mut entry, &mut bytes).expect(EXPECT_MESSAGE);
    }

    println!("Extracted yolov12n file, {} bytes", bytes.len());
    YoloModel::new_from_bytes(&bytes, OutputType::UltralyticsUnprocessedV8V11).expect("Should work")
}