#![allow(clippy::manual_retain)]
mod nms;
use std::path::Path;
use ndarray::{Array, ArrayView, Axis};
use ort::{
execution_providers::{CoreMLExecutionProvider, coreml::CoreMLComputeUnits},
inputs,
session::{InMemorySession, Session, SessionOutputs},
value::TensorRef
};
use std::sync::OnceLock;
/// Errors produced by this crate.
///
/// Variants carry no payload; in particular, converting an `ort::Error` into
/// [`Error::OrtError`] discards the original error details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// ONNX Runtime reported an error.
    OrtError,
    /// An I/O operation failed.
    IoError,
    /// Caller-supplied input was rejected (for example, wrong length or shape).
    InvalidInput,
    /// The model's inputs/outputs do not match what the selected output format requires.
    InvalidModel,
    /// The model's output format cannot be decoded.
    UnsupportedModelOutputFormat,
    /// An internal invariant of the library was violated.
    LibraryError,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Error::OrtError => "ONNX Runtime error",
            Error::IoError => "I/O error",
            Error::InvalidInput => "invalid input",
            Error::InvalidModel => "model does not match the expected inputs/outputs",
            Error::UnsupportedModelOutputFormat => "unsupported model output format",
            Error::LibraryError => "internal library error",
        };
        f.write_str(message)
    }
}

/// Lets [`Error`] be used with `Box<dyn std::error::Error>` and other generic error handling.
impl std::error::Error for Error {}
impl From<ort::Error> for Error {
fn from(_: ort::Error) -> Self {
Error::OrtError
}
}
pub type Result<T> = std::result::Result<T, Error>;
/// Human-readable names for 80 object classes, in class-index order.
///
/// Presumably meant to be indexed by [`YoloResult::class_id`] for models trained on these
/// 80 classes — confirm for custom-trained models.
pub const YOLO_CLASS_LABELS: [&str; 80] = [
    "person", "bicycle", "car", "motorcycle", "airplane", // 0-4
    "bus", "train", "truck", "boat", "traffic light", // 5-9
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", // 10-14
    "cat", "dog", "horse", "sheep", "cow", // 15-19
    "elephant", "bear", "zebra", "giraffe", "backpack", // 20-24
    "umbrella", "handbag", "tie", "suitcase", "frisbee", // 25-29
    "skis", "snowboard", "sports ball", "kite", "baseball bat", // 30-34
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", // 35-39
    "wine glass", "cup", "fork", "knife", "spoon", // 40-44
    "bowl", "banana", "apple", "sandwich", "orange", // 45-49
    "broccoli", "carrot", "hot dog", "pizza", "donut", // 50-54
    "cake", "chair", "couch", "potted plant", "bed", // 55-59
    "dining table", "toilet", "tv", "laptop", "mouse", // 60-64
    "remote", "keyboard", "cell phone", "microwave", "oven", // 65-69
    "toaster", "sink", "refrigerator", "book", "clock", // 70-74
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush", // 75-79
];
/// Layout of a model's output tensors; selects how [`YoloModel::run`] decodes them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputType {
    /// Raw single-output export (named after Ultralytics YOLOv8–v11); the first output
    /// tensor is passed through the crate's non-maximum-suppression routine.
    UltralyticsUnprocessedV8V11,
    /// SuperGradients export with built-in post-processing in "batch" form, read from the
    /// `graph2_num_predictions`, `graph2_pred_boxes`, `graph2_pred_scores` and
    /// `graph2_pred_classes` outputs.
    SuperGradientsProcessedBatchFormat,
    /// SuperGradients export with built-in post-processing in "flat" form.
    /// NOTE(review): not decoded by `run` at present.
    SuperGradientsProcessedFlatFormat,
    /// SuperGradients export without post-processing (two output tensors).
    /// NOTE(review): not decoded by `run` at present.
    UnprocessedSuperGradients,
}
/// Axis-aligned box given by two corner points.
///
/// Coordinates are in pixels of whatever image space the producer used: model-input space
/// from `YoloModel::run`, original-image space from `run_on_image_from_path`.
#[derive(Clone, Copy, Debug)]
pub struct BoundingBox {
    /// First corner, horizontal coordinate.
    pub x1: f32,
    /// First corner, vertical coordinate.
    pub y1: f32,
    /// Second corner, horizontal coordinate.
    pub x2: f32,
    /// Second corner, vertical coordinate.
    pub y2: f32,
}

/// A single detection returned by the model.
#[derive(Clone, Debug)]
pub struct YoloResult {
    /// Location of the detected object.
    pub bbox: BoundingBox,
    /// Predicted class index (presumably an index into `YOLO_CLASS_LABELS` for models
    /// trained on those classes — confirm for custom models).
    pub class_id: usize,
    /// Detection score as reported by the model / NMS step.
    pub confidence: f32,
}
/// Guards [`init_ort_env`] so the global ONNX Runtime environment is committed at most once.
static INIT_ORT_ENVIRONMENT: OnceLock<()> = OnceLock::new();

/// Configures the global ONNX Runtime environment with the CoreML execution provider
/// (all compute units enabled).
///
/// Once a call has succeeded, later calls are no-ops.
///
/// # Panics
/// Panics if committing the environment fails.
pub fn init_ort_env() {
    fn commit_environment() {
        let environment = ort::init();
        let coreml = CoreMLExecutionProvider::default()
            .with_compute_units(CoreMLComputeUnits::All)
            .build();
        environment
            .with_execution_providers([coreml])
            .commit()
            .expect("Failed to initialize ONNX Runtime environment");
    }

    INIT_ORT_ENVIRONMENT.get_or_init(commit_environment);
}
/// A YOLO object detector backed by an ONNX Runtime session.
///
/// The constructors populate exactly one of the two session fields. `'a` is the lifetime of
/// the borrowed model bytes for `new_from_bytes_borrowed`, and `'static` for `new_from_bytes`.
pub struct YoloModel<'a> {
    /// Session that borrows the caller's model bytes (set by `new_from_bytes_borrowed`).
    session_ref: Option<InMemorySession<'a>>,
    /// Session that does not borrow the input bytes (set by `new_from_bytes`).
    session_own: Option<Session>,
    /// How `run` decodes the model's output tensors.
    output_format: OutputType,
}
impl YoloModel<'static> {
    /// Creates a model whose ONNX Runtime session does not borrow `bytes`.
    ///
    /// The session is built with `commit_from_memory`, so the returned model is `'static`
    /// regardless of how long `bytes` lives.
    ///
    /// # Errors
    /// Returns [`Error::OrtError`] if ONNX Runtime fails to build the session.
    pub fn new_from_bytes(bytes: &[u8], output_format: OutputType) -> Result<Self> {
        let owned_session = Session::builder()?.commit_from_memory(bytes)?;
        Ok(Self {
            session_ref: None,
            session_own: Some(owned_session),
            output_format,
        })
    }
}
impl<'a> YoloModel<'a> {
pub fn new_from_bytes_borrowed(bytes: &'a [u8], output_format: OutputType) -> Result<Self> {
Ok(Self {
session_ref: Some(Session::builder()?.commit_from_memory_directly(bytes)?),
session_own: None,
output_format, })
}
pub fn get_model_input_image_dims(&self) -> (u32, u32) {
(640, 640) }
pub fn print_input_and_output_info(&mut self) {
let model = if let Some(session) = &mut self.session_ref { session }
else if let Some(session) = &mut self.session_own { session }
else { unreachable!("Shouldn't happen") };
let input_name = model.inputs[0].name.clone();
let num_outputs = model.outputs.len();
println!("Model has {} outputs", num_outputs);
for i in 0..num_outputs {
println!("\n------ Output {}: {:#?}", i, model.outputs[i].name.clone());
println!("Type: {:#?}", model.outputs[i].output_type.clone());
}
}
pub fn run(&mut self, image_data: &[f32], min_confidence: f32) -> Result<Vec<YoloResult>> {
let (width, height) = self.get_model_input_image_dims();
if image_data.len() != (width * height * 3) as usize {
return Err(Error::InvalidInput); }
let model = if let Some(session) = &mut self.session_ref { session }
else if let Some(session) = &mut self.session_own { session }
else { return Err(Error::LibraryError); };
let input_name = model.inputs[0].name.clone();
let mut input = ArrayView::from_shape([1, 3, height as usize, width as usize], image_data)
.map_err(|_| Error::InvalidInput)?;
let mut out = vec![];
match self.output_format {
OutputType::SuperGradientsProcessedBatchFormat => {
let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
let num_predictions = outputs["graph2_num_predictions"].try_extract_array::<i64>()?.iter().next().copied().unwrap_or(0);
let boxes_output = outputs["graph2_pred_boxes"].try_extract_array::<f32>()?.t().into_owned();
let pred_scores = outputs["graph2_pred_scores"].try_extract_array::<f32>()?.t().into_owned();
let pred_classes = outputs["graph2_pred_classes"].try_extract_array::<i64>()?.t().into_owned();
for i in 0..(num_predictions.min(boxes_output.shape()[1] as i64) as usize) {
let score = *pred_scores.get([i, 0]).unwrap() as f32;
let data = [
*boxes_output.get([0, i, 0]).unwrap(),
*boxes_output.get([1, i, 0]).unwrap(),
*boxes_output.get([2, i, 0]).unwrap(),
*boxes_output.get([3, i, 0]).unwrap()
];
out.push(YoloResult {
bbox: BoundingBox {
x1: data[0] as f32,
y1: data[1] as f32,
x2: data[2] as f32,
y2: data[3] as f32
},
class_id: *pred_classes.get([i, 0]).unwrap() as usize,
confidence: score
});
}
},
OutputType::UnprocessedSuperGradients => {
if model.outputs.len() != 2 {
return Err(Error::InvalidModel);
}
let (name1, name2) = (model.outputs[0].name.clone(), model.outputs[1].name.clone());
let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
let bounding_boxes = outputs[name1.as_str()].try_extract_array::<f32>()?.t().into_owned();
let class_scores = outputs[name2.as_str()].try_extract_array::<f32>()?.t().into_owned();
},
OutputType::UltralyticsUnprocessedV8V11 => {
let name = model.outputs[0].name.clone();
let outputs: SessionOutputs = model.run(inputs![&input_name => TensorRef::from_array_view(input)?])?;
let output = outputs[name.as_str()].try_extract_array::<f32>()?.t().into_owned();
return Ok(nms::yolo_nms(output.view(), min_confidence, 0.45))
}
_ => {
return Err(Error::UnsupportedModelOutputFormat);
}
}
return Ok(out);
}
#[cfg(feature = "image")]
pub fn run_on_image_from_path(&mut self, path: impl AsRef<Path>, min_confidence: f32) -> Result<Vec<YoloResult>> {
use image::{GenericImageView, imageops::FilterType};
let original_img = image::open(path).unwrap();
let (img_width, img_height) = (original_img.width(), original_img.height());
let img = original_img.resize_exact(640, 640, FilterType::CatmullRom);
let channel = |i: usize| img.pixels().map(move |(_,_,c)| c[i] as f32 / 255.0);
let data = channel(0).chain(channel(1)).chain(channel(2)).collect::<Vec<_>>();
let start = std::time::Instant::now();
let res = self.run(&data, min_confidence).map(|mut results| {
for result in &mut results {
result.bbox.x1 = result.bbox.x1 * (img_width as f32 / 640.0);
result.bbox.y1 = result.bbox.y1 * (img_height as f32 / 640.0);
result.bbox.x2 = result.bbox.x2 * (img_width as f32 / 640.0);
result.bbox.y2 = result.bbox.y2 * (img_height as f32 / 640.0);
}
results
});
println!("Inference took: {:?}", start.elapsed());
return res;
}
}
/// Builds a YOLOv12n detector from weights embedded in the binary at compile time.
///
/// The embedded file is a zip archive. Its first entry is extracted and loaded with
/// [`OutputType::UltralyticsUnprocessedV8V11`].
///
/// # Panics
/// Panics if the embedded archive cannot be read or the session cannot be created.
#[cfg(feature = "weights")]
pub fn pretrained_v12n() -> YoloModel<'static> {
    const EXPECT_MESSAGE: &str = "Should work, has been tested on this data";
    const ZIPPED_WEIGHTS: &[u8] = include_bytes!("../yolov12n.onnx.zip");

    let onnx_bytes = {
        let cursor = std::io::Cursor::new(ZIPPED_WEIGHTS);
        let mut archive = zip::ZipArchive::new(cursor).expect(EXPECT_MESSAGE);
        let mut first_entry = archive.by_index(0).expect(EXPECT_MESSAGE);
        let mut buffer = Vec::new();
        std::io::Read::read_to_end(&mut first_entry, &mut buffer).expect(EXPECT_MESSAGE);
        buffer
    };
    println!("Extracted yolov12n file, {} bytes", onnx_bytes.len());
    YoloModel::new_from_bytes(&onnx_bytes, OutputType::UltralyticsUnprocessedV8V11)
        .expect("Should work")
}