1#![allow(clippy::manual_retain)]
9
10mod nms;
11
12use std::path::Path;
13use ndarray::{Array, ArrayView, Axis};
14use ort::{
15 execution_providers::{CoreMLExecutionProvider, coreml::CoreMLComputeUnits},
16 inputs,
17 session::{InMemorySession, Session, SessionOutputs},
18 value::TensorRef
19};
20use std::sync::OnceLock;
21
/// Errors reported by this crate.
///
/// The variants carry no payload, which keeps the type `Copy`; the underlying
/// cause (for example the original `ort::Error`) is not preserved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// ONNX Runtime reported an error (converted from `ort::Error`).
    OrtError,
    /// An I/O operation failed.
    IoError,
    /// The data handed to the model does not have the expected size or shape.
    InvalidInput,
    /// The model does not have the structure the selected output format requires.
    InvalidModel,
    /// The selected output format cannot be decoded.
    UnsupportedModelOutputFormat,
    /// An internal invariant of this library was violated.
    LibraryError,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Error::OrtError => "ONNX Runtime error",
            Error::IoError => "I/O error",
            Error::InvalidInput => "invalid input",
            Error::InvalidModel => "invalid model",
            Error::UnsupportedModelOutputFormat => "unsupported model output format",
            Error::LibraryError => "internal library error",
        };
        f.write_str(message)
    }
}

// Implementing the standard trait lets callers use `Box<dyn Error>`, `anyhow`, etc.
impl std::error::Error for Error {}
33
34impl From<ort::Error> for Error {
35 fn from(_: ort::Error) -> Self {
36 Error::OrtError
37 }
38}
39
40pub type Result<T> = std::result::Result<T, Error>;
41
/// Human-readable names for the 80 object classes, in class-id order.
///
/// Presumably meant to be indexed with `YoloResult::class_id` — verify against
/// the model in use.
pub const YOLO_CLASS_LABELS: [&str; 80] = [
    // 0..=9
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    // 10..=19
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    // 20..=29
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    // 30..=39
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    // 40..=49
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    // 50..=59
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    // 60..=69
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    // 70..=79
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
];
52
/// Layout of the tensors a model produces, which decides how
/// `YoloModel::run` decodes them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputType {
    /// Raw output of the model's first output tensor; `YoloModel::run`
    /// post-processes it with `nms::yolo_nms`.
    UltralyticsUnprocessedV8V11,
    /// Already post-processed output exposed through the tensors named
    /// `graph2_num_predictions`, `graph2_pred_boxes`, `graph2_pred_scores`
    /// and `graph2_pred_classes`.
    SuperGradientsProcessedBatchFormat,
    /// `YoloModel::run` does not decode this format.
    SuperGradientsProcessedFlatFormat,
    /// `YoloModel::run` does not yet produce detections for this format.
    UnprocessedSuperGradients,
}
60
/// Axis-aligned box described by two corner points.
///
/// NOTE(review): presumably `(x1, y1)` is the top-left and `(x2, y2)` the
/// bottom-right corner, in pixels — confirm against the model output.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BoundingBox {
    /// Horizontal coordinate of the first corner.
    pub x1: f32,
    /// Vertical coordinate of the first corner.
    pub y1: f32,
    /// Horizontal coordinate of the second corner.
    pub x2: f32,
    /// Vertical coordinate of the second corner.
    pub y2: f32,
}
68
69#[derive(Debug, Clone)]
70pub struct YoloResult {
71 pub bbox: BoundingBox,
72 pub class_id: usize,
73 pub confidence: f32,
74}
75
76static INIT_ORT_ENVIRONMENT: OnceLock<()> = OnceLock::new();
77
78pub fn init_ort_env() {
80 INIT_ORT_ENVIRONMENT.get_or_init(|| {
81 ort::init()
83 .with_execution_providers([
84 CoreMLExecutionProvider::default()
85 .with_compute_units(CoreMLComputeUnits::All)
86 .build()
87 ])
89 .commit()
90 .expect("Failed to initialize ONNX Runtime environment");
91 });
92}
93
94pub struct YoloModel<'a> {
95 session_ref: Option<InMemorySession<'a>>,
96 session_own: Option<Session>,
97 output_format: OutputType,
98}
99
100impl YoloModel<'static> {
101 pub fn new_from_bytes(bytes: &[u8], output_format: OutputType) -> Result<Self> {
102 Ok(Self {
103 session_own: Some(Session::builder()?.commit_from_memory(bytes)?),
104 session_ref: None,
105 output_format, })
107 }
108}
109
110impl<'a> YoloModel<'a> {
111 pub fn new_from_bytes_borrowed(bytes: &'a [u8], output_format: OutputType) -> Result<Self> {
112 Ok(Self {
113 session_ref: Some(Session::builder()?.commit_from_memory_directly(bytes)?),
114 session_own: None,
115 output_format, })
117 }
118
119 pub fn get_model_input_image_dims(&self) -> (u32, u32) {
120 (640, 640) }
122
123 pub fn print_input_and_output_info(&mut self) {
124 let model = if let Some(session) = &mut self.session_ref { session }
125 else if let Some(session) = &mut self.session_own { session }
126 else { unreachable!("Shouldn't happen") };
127
128 let input_name = model.inputs[0].name.clone();
130 let num_outputs = model.outputs.len();
132 println!("Model has {} outputs", num_outputs);
133 for i in 0..num_outputs {
134 println!("\n------ Output {}: {:#?}", i, model.outputs[i].name.clone());
135 println!("Type: {:#?}", model.outputs[i].output_type.clone());
136 }
137 }
138
139 pub fn run(&mut self, image_data: &[f32], min_confidence: f32) -> Result<Vec<YoloResult>> {
145 let (width, height) = self.get_model_input_image_dims();
146 if image_data.len() != (width * height * 3) as usize {
147 return Err(Error::InvalidInput); }
149
150 let model = if let Some(session) = &mut self.session_ref { session }
151 else if let Some(session) = &mut self.session_own { session }
152 else { return Err(Error::LibraryError); };
153
154 let input_name = model.inputs[0].name.clone();
155
156 let mut input = ArrayView::from_shape([1, 3, height as usize, width as usize], image_data)
157 .map_err(|_| Error::InvalidInput)?;
158
159 let mut out = vec![];
160
161 match self.output_format {
162 OutputType::SuperGradientsProcessedBatchFormat => {
163 let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
164 let num_predictions = outputs["graph2_num_predictions"].try_extract_array::<i64>()?.iter().next().copied().unwrap_or(0);
165 let boxes_output = outputs["graph2_pred_boxes"].try_extract_array::<f32>()?.t().into_owned();
166 let pred_scores = outputs["graph2_pred_scores"].try_extract_array::<f32>()?.t().into_owned();
167 let pred_classes = outputs["graph2_pred_classes"].try_extract_array::<i64>()?.t().into_owned();
168
169 for i in 0..(num_predictions.min(boxes_output.shape()[1] as i64) as usize) {
170 let score = *pred_scores.get([i, 0]).unwrap() as f32;
171 let data = [
173 *boxes_output.get([0, i, 0]).unwrap(),
174 *boxes_output.get([1, i, 0]).unwrap(),
175 *boxes_output.get([2, i, 0]).unwrap(),
176 *boxes_output.get([3, i, 0]).unwrap()
177 ];
178 out.push(YoloResult {
179 bbox: BoundingBox {
180 x1: data[0] as f32,
181 y1: data[1] as f32,
182 x2: data[2] as f32,
183 y2: data[3] as f32
184 },
185 class_id: *pred_classes.get([i, 0]).unwrap() as usize,
186 confidence: score
187 });
188 }
189 },
190 OutputType::UnprocessedSuperGradients => {
191 if model.outputs.len() != 2 {
192 return Err(Error::InvalidModel);
193 }
194 let (name1, name2) = (model.outputs[0].name.clone(), model.outputs[1].name.clone());
195 let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
196 let bounding_boxes = outputs[name1.as_str()].try_extract_array::<f32>()?.t().into_owned();
197 let class_scores = outputs[name2.as_str()].try_extract_array::<f32>()?.t().into_owned();
198 },
199 OutputType::UltralyticsUnprocessedV8V11 => {
200 let name = model.outputs[0].name.clone();
203 let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
204 let output = outputs[name.as_str()].try_extract_array::<f32>()?.t().into_owned();
205
206 return Ok(nms::yolo_nms(output.view(), min_confidence, 0.45))
207 }
208 _ => {
209 return Err(Error::UnsupportedModelOutputFormat);
210 }
211 }
212
213 return Ok(out);
214 }
215
216 #[cfg(feature = "image")]
217 pub fn run_on_image_from_path(&mut self, path: impl AsRef<Path>, min_confidence: f32) -> Result<Vec<YoloResult>> {
218 use image::{GenericImageView, imageops::FilterType};
219 let original_img = image::open(path).unwrap();
220 let (img_width, img_height) = (original_img.width(), original_img.height());
221 let img = original_img.resize_exact(640, 640, FilterType::CatmullRom);
222
223 let channel = |i: usize| img.pixels().map(move |(_,_,c)| c[i] as f32 / 255.0);
224 let data = channel(0).chain(channel(1)).chain(channel(2)).collect::<Vec<_>>();
225
226 let start = std::time::Instant::now();
227 let res = self.run(&data, min_confidence).map(|mut results| {
228 for result in &mut results {
230 result.bbox.x1 = result.bbox.x1 * (img_width as f32 / 640.0);
232 result.bbox.y1 = result.bbox.y1 * (img_height as f32 / 640.0);
233 result.bbox.x2 = result.bbox.x2 * (img_width as f32 / 640.0);
234 result.bbox.y2 = result.bbox.y2 * (img_height as f32 / 640.0);
235 }
236 results
237 });
238 println!("Inference took: {:?}", start.elapsed());
239 return res;
240 }
241}
242
243
244
/// Builds a model from the `yolov12n.onnx.zip` archive embedded at compile time.
///
/// The first entry of the archive is decompressed into memory and loaded with
/// [`OutputType::UltralyticsUnprocessedV8V11`]. Prints the decompressed size
/// to stdout.
///
/// # Panics
///
/// Panics if the embedded archive cannot be read or the model cannot be loaded.
#[cfg(feature = "weights")]
pub fn pretrained_v12n() -> YoloModel<'static> {
    const EXPECT_MESSAGE: &str = "Should work, has been tested on this data";

    let zipped: &[u8] = include_bytes!("../yolov12n.onnx.zip");
    let mut archive = zip::ZipArchive::new(std::io::Cursor::new(zipped)).expect(EXPECT_MESSAGE);
    let mut entry = archive.by_index(0).expect(EXPECT_MESSAGE);

    let mut bytes = Vec::new();
    std::io::Read::read_to_end(&mut entry, &mut bytes).expect(EXPECT_MESSAGE);

    println!("Extracted yolov12n file, {} bytes", bytes.len());
    YoloModel::new_from_bytes(&bytes, OutputType::UltralyticsUnprocessedV8V11).expect("Should work")
}