use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::inputs;
use ort::value::TensorRef;
#[cfg(feature = "ort-cuda-backend")]
use ort::execution_providers::CUDAExecutionProvider;
#[cfg(feature = "ort-tensorrt-backend")]
use ort::execution_providers::TensorRTExecutionProvider;
use crate::bbox::BBox;
use crate::image_buffer::ImageBuffer;
use crate::postprocess::{Detection, nms, filter_by_class, detections_to_vecs, argmax};
use crate::preprocessing::{preprocess_into, PreprocessMeta};
use super::OrtModelError;
/// YOLOv5 object detector that runs inference through an ONNX Runtime (`ort`) session.
///
/// Holds the session together with the preprocessing configuration and a
/// reusable input tensor, so repeated calls to `forward` do not reallocate
/// the input buffer.
pub struct ModelYOLOv5Ort {
/// ONNX Runtime session used to execute the model.
session: Session,
/// Model input width in pixels (first element of the `input_size` tuple).
input_width: u32,
/// Model input height in pixels (second element of the `input_size` tuple).
input_height: u32,
/// Class indices handed to `filter_by_class` after parsing the model output.
class_filters: Vec<usize>,
/// Flag forwarded to the preprocessing step; defaults to the `letterbox` feature state.
use_letterbox: bool,
/// Reusable input tensor of shape `(1, 3, input_height, input_width)`.
tensor_buf: ndarray::Array4<f32>,
}
impl ModelYOLOv5Ort {
pub fn new_from_file(
model_path: &str,
input_size: (u32, u32),
class_filters: Vec<usize>,
) -> Result<Self, OrtModelError> {
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.commit_from_file(model_path)?;
Ok(Self {
session,
input_width: input_size.0,
input_height: input_size.1,
class_filters,
#[cfg(feature = "letterbox")]
use_letterbox: true,
#[cfg(not(feature = "letterbox"))]
use_letterbox: false,
tensor_buf: ndarray::Array4::zeros((1, 3, input_size.1 as usize, input_size.0 as usize)),
})
}
/// Loads an ONNX model from `model_path` with the CUDA execution provider registered.
///
/// The provider is registered with its default configuration and the session
/// uses graph optimization level 3. Only available with the
/// `ort-cuda-backend` feature.
///
/// # Arguments
/// * `model_path` - Path to the ONNX model file.
/// * `input_size` - Model input size as `(width, height)`.
/// * `class_filters` - Class indices passed to the class filter after inference.
///
/// # Errors
/// Returns an [`OrtModelError`] if the session builder cannot be created or
/// configured, or if the model cannot be loaded from `model_path`.
#[cfg(feature = "ort-cuda-backend")]
pub fn new_from_file_cuda(
    model_path: &str,
    input_size: (u32, u32),
    class_filters: Vec<usize>,
) -> Result<Self, OrtModelError> {
    let session = Session::builder()?
        .with_execution_providers([CUDAExecutionProvider::default().build()])?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .commit_from_file(model_path)?;
    // Struct assembly (letterbox default, input tensor allocation) lives in
    // `from_session`; delegate rather than duplicating it per backend.
    Ok(Self::from_session(session, input_size, class_filters))
}
/// Loads an ONNX model from `model_path` with the TensorRT execution provider registered.
///
/// The provider is registered with its default configuration and the session
/// uses graph optimization level 3. Only available with the
/// `ort-tensorrt-backend` feature.
///
/// # Arguments
/// * `model_path` - Path to the ONNX model file.
/// * `input_size` - Model input size as `(width, height)`.
/// * `class_filters` - Class indices passed to the class filter after inference.
///
/// # Errors
/// Returns an [`OrtModelError`] if the session builder cannot be created or
/// configured, or if the model cannot be loaded from `model_path`.
#[cfg(feature = "ort-tensorrt-backend")]
pub fn new_from_file_tensorrt(
    model_path: &str,
    input_size: (u32, u32),
    class_filters: Vec<usize>,
) -> Result<Self, OrtModelError> {
    let session = Session::builder()?
        .with_execution_providers([TensorRTExecutionProvider::default().build()])?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .commit_from_file(model_path)?;
    // Struct assembly (letterbox default, input tensor allocation) lives in
    // `from_session`; delegate rather than duplicating it per backend.
    Ok(Self::from_session(session, input_size, class_filters))
}
/// Wraps an already configured `Session` in a detector.
///
/// # Arguments
/// * `session` - Session the caller has built (any execution provider).
/// * `input_size` - Model input size as `(width, height)`.
/// * `class_filters` - Class indices passed to the class filter after inference.
///
/// The letterbox flag starts as `true` when the crate is compiled with the
/// `letterbox` feature and `false` otherwise; it can be changed later with
/// `set_letterbox`.
pub fn from_session(
    session: Session,
    input_size: (u32, u32),
    class_filters: Vec<usize>,
) -> Self {
    let (input_width, input_height) = input_size;
    // Allocated once here and reused by every inference call.
    // Layout is (batch = 1, channels = 3, height, width).
    let tensor_buf =
        ndarray::Array4::zeros((1, 3, input_height as usize, input_width as usize));
    Self {
        session,
        input_width,
        input_height,
        class_filters,
        use_letterbox: cfg!(feature = "letterbox"),
        tensor_buf,
    }
}
/// Sets the letterbox flag that is forwarded to the preprocessing step.
///
/// Overrides the construction-time default, which follows the `letterbox`
/// Cargo feature.
pub fn set_letterbox(&mut self, enabled: bool) {
self.use_letterbox = enabled;
}
/// Returns the model input size as `(width, height)`.
pub fn input_size(&self) -> (u32, u32) {
(self.input_width, self.input_height)
}
pub fn forward(
&mut self,
image: &ImageBuffer,
conf_threshold: f32,
nms_threshold: f32,
) -> Result<(Vec<BBox>, Vec<usize>, Vec<f32>), OrtModelError> {
let meta = preprocess_into(
image,
&mut self.tensor_buf,
self.use_letterbox,
);
let outputs = self.session.run(
inputs!["images" => TensorRef::from_array_view(&self.tensor_buf)?]
)?;
let output = outputs["output0"]
.try_extract_array::<f32>()?
.into_owned();
let class_filters = self.class_filters.clone();
let detections = Self::parse_output_array_static(
&output.view(),
conf_threshold,
&meta,
self.input_width,
self.input_height,
)?;
let filtered = filter_by_class(&detections, &class_filters);
let final_detections = nms(&filtered, nms_threshold);
Ok(detections_to_vecs(final_detections))
}
/// Converts a raw model output tensor into a list of detections.
///
/// The tensor must have shape `[1, N, C]` with `C >= 6`. Each of the `N` rows
/// is read as `[cx, cy, w, h, objectness, class scores...]`.
///
/// A row is kept when its objectness is not below `conf_threshold` and the
/// product of objectness and the best class score is at least
/// `conf_threshold`. Kept boxes are mapped back through
/// `meta.inverse_transform` before being stored.
///
/// # Arguments
/// * `output` - Model output view.
/// * `conf_threshold` - Threshold applied to objectness and to the final confidence.
/// * `meta` - Preprocessing metadata used to map boxes back.
/// * `input_width` / `input_height` - Model input size, used to scale boxes
///   that appear to be normalized.
///
/// # Errors
/// Returns [`OrtModelError::InvalidOutputShape`] when the tensor is not
/// `[1, N, C]` or has fewer than 6 features per row.
fn parse_output_array_static(
    output: &ndarray::ArrayViewD<f32>,
    conf_threshold: f32,
    meta: &PreprocessMeta,
    input_width: u32,
    input_height: u32,
) -> Result<Vec<Detection>, OrtModelError> {
    let shape = output.shape();
    if shape.len() != 3 || shape[0] != 1 {
        return Err(OrtModelError::InvalidOutputShape(format!(
            "Expected shape [1, N, C], got {:?}",
            shape
        )));
    }
    let num_predictions = shape[1];
    let num_features = shape[2];
    if num_features < 6 {
        return Err(OrtModelError::InvalidOutputShape(format!(
            "Expected at least 6 features (4 bbox + 1 obj + 1 class), got {}",
            num_features
        )));
    }

    let scale_w = input_width as f32;
    let scale_h = input_height as f32;

    let mut detections = Vec::new();
    // One scratch buffer for the class scores, reused for every candidate row
    // instead of allocating a fresh `Vec` per row.
    let mut class_scores: Vec<f32> = Vec::with_capacity(num_features - 5);

    for i in 0..num_predictions {
        let objectness = output[[0, i, 4]];
        // Cheap rejection before touching the class scores.
        if objectness < conf_threshold {
            continue;
        }

        class_scores.clear();
        class_scores.extend((5..num_features).map(|j| output[[0, i, j]]));

        let Some((class_idx, max_class_score)) = argmax(&class_scores) else {
            continue;
        };

        let confidence = objectness * max_class_score;
        // Kept as `>=` (not a negated `<`) so a NaN confidence is still dropped.
        if confidence >= conf_threshold {
            let mut cx = output[[0, i, 0]];
            let mut cy = output[[0, i, 1]];
            let mut w = output[[0, i, 2]];
            let mut h = output[[0, i, 3]];

            // Heuristic: when every box value is below 2.0 the box is treated
            // as normalized and scaled up to the model input size.
            if cx < 2.0 && cy < 2.0 && w < 2.0 && h < 2.0 {
                cx *= scale_w;
                cy *= scale_h;
                w *= scale_w;
                h *= scale_h;
            }

            let (x_orig, y_orig, w_orig, h_orig) = meta.inverse_transform(cx, cy, w, h);
            let bbox = BBox::from_center(x_orig, y_orig, w_orig, h_orig);
            detections.push(Detection::new(bbox, class_idx, confidence));
        }
    }

    Ok(detections)
}
}
/// `crate::ObjectDetector` implementation that delegates to `ModelYOLOv5Ort::forward`.
impl crate::ObjectDetector for ModelYOLOv5Ort {
/// Image type accepted by `detect`.
type Input = ImageBuffer;
/// Error type returned by `detect`.
type Error = OrtModelError;
/// Runs detection on `input`; forwards all arguments unchanged to `forward`
/// and returns its result (boxes, class indices, confidences).
fn detect(
&mut self,
input: &Self::Input,
conf_threshold: f32,
nms_threshold: f32,
) -> Result<(Vec<BBox>, Vec<usize>, Vec<f32>), Self::Error> {
self.forward(input, conf_threshold, nms_threshold)
}
}
#[cfg(feature = "ort-opencv-compat")]
mod opencv_compat_impl {
use super::*;
use opencv::core::{Mat, Rect};
use opencv::Error as OpenCvError;
impl ModelYOLOv5Ort {
pub fn forward_mat(
&mut self,
image: &Mat,
conf_threshold: f32,
nms_threshold: f32,
) -> Result<(Vec<Rect>, Vec<usize>, Vec<f32>), OpenCvError> {
let (tensor, meta) = crate::opencv_compat::preprocess_mat(
image,
self.input_width,
self.input_height,
self.use_letterbox,
)?;
let outputs = self.session.run(
inputs!["images" => TensorRef::from_array_view(&tensor).map_err(|e| {
OpenCvError::new(opencv::core::StsError, format!("ORT error: {}", e))
})?]
).map_err(|e| {
OpenCvError::new(opencv::core::StsError, format!("ORT inference error: {}", e))
})?;
let output = outputs["output0"]
.try_extract_array::<f32>()
.map_err(|e| {
OpenCvError::new(opencv::core::StsError, format!("Output extraction error: {}", e))
})?
.into_owned();
let detections = Self::parse_output_array_static(
&output.view(),
conf_threshold,
&meta,
self.input_width,
self.input_height,
).map_err(|e| {
OpenCvError::new(opencv::core::StsError, format!("Parse error: {}", e))
})?;
let class_filters = self.class_filters.clone();
let filtered = filter_by_class(&detections, &class_filters);
let final_detections = nms(&filtered, nms_threshold);
let (bboxes, class_ids, confidences) = detections_to_vecs(final_detections);
let rects: Vec<Rect> = bboxes
.into_iter()
.map(|bbox| Rect::new(bbox.x, bbox.y, bbox.width, bbox.height))
.collect();
Ok((rects, class_ids, confidences))
}
}
/// `crate::opencv_compat::ModelTrait` implementation that delegates to
/// `ModelYOLOv5Ort::forward_mat`.
impl crate::opencv_compat::ModelTrait for ModelYOLOv5Ort {
/// Runs detection on an OpenCV `Mat`; forwards all arguments unchanged to
/// `forward_mat` and returns its result (rectangles, class indices, confidences).
fn forward(
&mut self,
image: &Mat,
conf_threshold: f32,
nms_threshold: f32,
) -> Result<(Vec<Rect>, Vec<usize>, Vec<f32>), OpenCvError> {
self.forward_mat(image, conf_threshold, nms_threshold)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loading a model from a path that does not exist must yield an error.
    #[test]
    fn test_model_creation_error() {
        let missing_model = "nonexistent.onnx";
        let outcome = ModelYOLOv5Ort::new_from_file(missing_model, (640, 640), Vec::new());
        assert!(outcome.is_err());
    }
}