use std::os::raw::c_void;
use crate::bbox::BBox;
use crate::image_buffer::ImageBuffer;
use crate::postprocess::{Detection, nms, filter_by_class, detections_to_vecs, argmax};
use crate::preprocessing::{preprocess_into, PreprocessMeta};
use tensorrt_infer::{TrtEngine, TrtContext, CudaBuffer, CudaStream, BindingInfo, TrtError};
/// Errors produced while loading a TensorRT engine or running inference with it.
#[derive(Debug)]
pub enum TrtModelError {
/// A failure reported as text: either a converted `TrtError` (stored via its
/// `to_string()` output) or a missing input/output binding on the engine.
Trt(String),
/// A tensor shape did not match what this module expects. Despite the name,
/// this variant is also used when the *input* binding is not `[1, 3, H, W]`.
InvalidOutputShape(String),
/// An underlying I/O error, wrapped unchanged.
Io(std::io::Error),
}
impl std::fmt::Display for TrtModelError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TrtModelError::Trt(e) => write!(f, "TensorRT error: {}", e),
TrtModelError::InvalidOutputShape(s) => write!(f, "Invalid output shape: {}", s),
TrtModelError::Io(e) => write!(f, "IO error: {}", e),
}
}
}
// No methods are overridden here, so the trait's default `source()` (which
// returns `None`) applies; a wrapped `std::io::Error` is only surfaced through
// the `Display` text of the `Io` variant.
impl std::error::Error for TrtModelError {}
impl From<std::io::Error> for TrtModelError {
fn from(e: std::io::Error) -> Self {
TrtModelError::Io(e)
}
}
impl From<TrtError> for TrtModelError {
fn from(e: TrtError) -> Self {
TrtModelError::Trt(e.to_string())
}
}
/// Object detector backed by a TensorRT engine, together with the host and
/// device buffers it reuses across calls to `forward`.
///
/// Rust drops fields in declaration order, so the device buffers and the
/// execution context are released before the stream and the engine.
/// NOTE(review): this ordering looks deliberate — confirm before reordering fields.
pub struct ModelUltralyticsRt {
/// One device buffer per engine binding, in binding order, each allocated
/// with that binding's `byte_size`.
gpu_buffers: Vec<CudaBuffer>,
/// Execution context created from `engine`; inference is enqueued on it.
context: TrtContext,
/// Stream on which host/device copies and inference are issued, and which
/// is synchronized before the output is read.
stream: CudaStream,
/// The engine loaded from the file given to `new_from_file`.
engine: TrtEngine,
/// Input width, taken from dimension 3 of the input binding.
input_width: u32,
/// Input height, taken from dimension 2 of the input binding.
input_height: u32,
/// Class list handed to `filter_by_class` after decoding.
/// NOTE(review): presumably the class ids to keep — confirm against `filter_by_class`.
class_filters: Vec<usize>,
/// Flag forwarded to `preprocess_into`; defaults to whether the
/// `letterbox` cargo feature is enabled.
use_letterbox: bool,
/// Host-side `[1, 3, H, W]` input tensor, refilled by `preprocess_into`
/// on every call and then copied to the input binding.
tensor_buf: ndarray::Array4<f32>,
/// Host-side destination for the output binding, sized as the binding's
/// `byte_size / 4` `f32` elements.
output_host_buf: Vec<f32>,
/// Binding descriptions as returned by `engine.bindings()`.
bindings: Vec<BindingInfo>,
/// Index of the first binding flagged `is_input`.
input_binding_idx: usize,
/// Index of the first binding not flagged `is_input`.
output_binding_idx: usize,
/// Dimensions of the output binding, as reported by the engine.
output_shape: Vec<i32>,
}
impl ModelUltralyticsRt {
    /// Loads a TensorRT engine from `engine_path` and allocates every buffer
    /// needed to run it.
    ///
    /// The first binding flagged `is_input` is used as the image input and the
    /// first binding not flagged `is_input` as the detection output.
    /// `class_filters` is stored and later handed to `filter_by_class`.
    ///
    /// # Errors
    ///
    /// * [`TrtModelError::Trt`] if the engine cannot be loaded, has no input or
    ///   no output binding, a CUDA/TensorRT resource cannot be created, or a
    ///   binding's byte size does not match an `f32` host buffer.
    /// * [`TrtModelError::InvalidOutputShape`] if the input binding is not
    ///   `[1, 3, H, W]` with strictly positive `H` and `W`.
    pub fn new_from_file(
        engine_path: &str,
        class_filters: Vec<usize>,
    ) -> Result<Self, TrtModelError> {
        const F32_BYTES: usize = std::mem::size_of::<f32>();

        let engine = TrtEngine::from_file(engine_path)?;
        let bindings = engine.bindings();
        let input_binding_idx = bindings
            .iter()
            .position(|b| b.is_input)
            .ok_or_else(|| TrtModelError::Trt("No input binding found".into()))?;
        let output_binding_idx = bindings
            .iter()
            .position(|b| !b.is_input)
            .ok_or_else(|| TrtModelError::Trt("No output binding found".into()))?;

        let input_dims = &bindings[input_binding_idx].dims;
        if input_dims.len() != 4 || input_dims[0] != 1 || input_dims[1] != 3 {
            return Err(TrtModelError::InvalidOutputShape(
                format!("Expected input [1, 3, H, W], got {:?}", input_dims),
            ));
        }
        // Dynamic dimensions are reported as negative values; casting those to
        // `u32` would silently yield an absurd size, so reject them up front.
        if input_dims[2] <= 0 || input_dims[3] <= 0 {
            return Err(TrtModelError::InvalidOutputShape(format!(
                "Expected positive input height and width, got {:?}",
                input_dims
            )));
        }
        // Both values were just checked to be positive, so the casts are lossless.
        let input_width = input_dims[3] as u32;
        let input_height = input_dims[2] as u32;

        // `forward` uploads this tensor's raw bytes to the input binding, so
        // the two sizes must agree exactly (i.e. the binding must be f32).
        let tensor_buf = ndarray::Array4::<f32>::zeros((
            1, 3, input_height as usize, input_width as usize,
        ));
        let tensor_bytes = tensor_buf.len() * F32_BYTES;
        let input_bytes = bindings[input_binding_idx].byte_size;
        if input_bytes != tensor_bytes {
            return Err(TrtModelError::Trt(format!(
                "Input binding is {} bytes but an f32 [1, 3, {}, {}] tensor is {} bytes",
                input_bytes, input_height, input_width, tensor_bytes
            )));
        }

        // `forward` downloads the output binding into a `Vec<f32>`; a size
        // that is not a whole number of f32 values would overrun that buffer.
        let output_bytes = bindings[output_binding_idx].byte_size;
        if output_bytes % F32_BYTES != 0 {
            return Err(TrtModelError::Trt(format!(
                "Output binding size {} is not a multiple of {} bytes (f32)",
                output_bytes, F32_BYTES
            )));
        }
        let output_host_buf = vec![0.0f32; output_bytes / F32_BYTES];
        let output_shape = bindings[output_binding_idx].dims.clone();

        let mut gpu_buffers = Vec::with_capacity(bindings.len());
        for binding in &bindings {
            gpu_buffers.push(CudaBuffer::new(binding.byte_size)?);
        }
        let context = engine.create_context()?;
        let stream = CudaStream::new()?;

        Ok(Self {
            engine,
            context,
            stream,
            input_width,
            input_height,
            class_filters,
            // Letterboxing defaults to the state of the `letterbox` cargo feature.
            use_letterbox: cfg!(feature = "letterbox"),
            tensor_buf,
            output_host_buf,
            gpu_buffers,
            bindings,
            input_binding_idx,
            output_binding_idx,
            output_shape,
        })
    }

    /// Overrides the letterbox flag that is passed to `preprocess_into`.
    pub fn set_letterbox(&mut self, enabled: bool) {
        self.use_letterbox = enabled;
    }

    /// Returns the network input size as `(width, height)`.
    pub fn input_size(&self) -> (u32, u32) {
        (self.input_width, self.input_height)
    }

    /// Runs one inference pass on `image` and returns the surviving
    /// detections as parallel vectors of boxes, class indices and scores.
    ///
    /// Steps: preprocess into the reusable host tensor, upload it, enqueue the
    /// engine, download the output, synchronize the stream, decode
    /// predictions scoring at least `conf_threshold`, apply the class filter,
    /// then run `nms` with `nms_threshold`.
    ///
    /// # Errors
    ///
    /// * [`TrtModelError::Trt`] if a copy, the enqueue or the synchronization
    ///   fails, or the host input tensor no longer matches the input binding.
    /// * [`TrtModelError::InvalidOutputShape`] if the output buffer cannot be
    ///   viewed with the engine's output shape or that shape is not decodable.
    pub fn forward(
        &mut self,
        image: &ImageBuffer,
        conf_threshold: f32,
        nms_threshold: f32,
    ) -> Result<(Vec<BBox>, Vec<usize>, Vec<f32>), TrtModelError> {
        let meta = preprocess_into(image, &mut self.tensor_buf, self.use_letterbox);

        // Take the byte length from the tensor itself rather than trusting the
        // binding size, and refuse to upload if the two disagree.
        let input_floats = self.tensor_buf.as_slice().ok_or_else(|| {
            TrtModelError::Trt("Input tensor is not contiguous in standard layout".into())
        })?;
        let input_byte_len = std::mem::size_of_val(input_floats);
        let expected_input_bytes = self.bindings[self.input_binding_idx].byte_size;
        if input_byte_len != expected_input_bytes {
            return Err(TrtModelError::Trt(format!(
                "Input tensor is {} bytes but the input binding expects {}",
                input_byte_len, expected_input_bytes
            )));
        }
        // SAFETY: `input_floats` is a live, initialized `&[f32]` and
        // `input_byte_len` is exactly its size in bytes; `u8` has alignment 1
        // and every bit pattern of an `f32` is a valid sequence of `u8`s.
        let input_bytes = unsafe {
            std::slice::from_raw_parts(input_floats.as_ptr() as *const u8, input_byte_len)
        };
        self.gpu_buffers[self.input_binding_idx].copy_from_host(input_bytes, &self.stream)?;

        let mut binding_ptrs: Vec<*mut c_void> = self.gpu_buffers
            .iter()
            .map(|buf| buf.as_ptr())
            .collect();
        self.context.enqueue(&mut binding_ptrs, &self.stream)?;

        let output_byte_len = std::mem::size_of_val(self.output_host_buf.as_slice());
        // SAFETY: the pointer and length describe exactly the initialized
        // storage of `output_host_buf`, which is not accessed through any
        // other path until the stream has been synchronized below; any byte
        // pattern written into it is a valid `f32`.
        let output_bytes = unsafe {
            std::slice::from_raw_parts_mut(
                self.output_host_buf.as_mut_ptr() as *mut u8,
                output_byte_len,
            )
        };
        self.gpu_buffers[self.output_binding_idx].copy_to_host(output_bytes, &self.stream)?;
        // The output must not be read before the work queued on the stream is done.
        self.stream.synchronize()?;

        let output_dims: Vec<usize> = self.output_shape.iter()
            .map(|&d| d as usize)
            .collect();
        let output_view = ndarray::ArrayViewD::from_shape(
            output_dims.as_slice(),
            &self.output_host_buf,
        ).map_err(|e| TrtModelError::InvalidOutputShape(e.to_string()))?;

        let detections = Self::parse_output_array_static(&output_view, conf_threshold, &meta)?;
        // Borrow the filter list directly; no per-call clone is needed.
        let filtered = filter_by_class(&detections, &self.class_filters);
        let final_detections = nms(&filtered, nms_threshold);
        Ok(detections_to_vecs(final_detections))
    }

    /// Decodes a raw `[1, C, N]` output tensor into detections.
    ///
    /// For each of the `N` predictions, rows `0..4` are read as
    /// `cx, cy, w, h` and rows `4..C` as per-class scores. A prediction is
    /// kept when its best class score is at least `conf_threshold`; its box is
    /// mapped back through `meta.inverse_transform` before being stored.
    ///
    /// # Errors
    ///
    /// Returns [`TrtModelError::InvalidOutputShape`] if `output` is not
    /// three-dimensional with a leading dimension of 1, or has fewer than the
    /// four rows needed for the box coordinates.
    fn parse_output_array_static(
        output: &ndarray::ArrayViewD<f32>,
        conf_threshold: f32,
        meta: &PreprocessMeta,
    ) -> Result<Vec<Detection>, TrtModelError> {
        // Rows 0..4 hold the box (cx, cy, w, h); class scores start after them.
        const BOX_ROWS: usize = 4;

        let shape = output.shape();
        if shape.len() != 3 || shape[0] != 1 {
            return Err(TrtModelError::InvalidOutputShape(format!(
                "Expected shape [1, C, N], got {:?}",
                shape,
            )));
        }
        let num_features = shape[1];
        let num_predictions = shape[2];
        // Without this check the box reads below would panic on a short tensor.
        if num_features < BOX_ROWS {
            return Err(TrtModelError::InvalidOutputShape(format!(
                "Expected at least {} feature rows (cx, cy, w, h), got {:?}",
                BOX_ROWS, shape,
            )));
        }

        let mut detections = Vec::new();
        // One scratch buffer reused for every prediction instead of a fresh
        // allocation per column.
        let mut class_scores: Vec<f32> = Vec::with_capacity(num_features - BOX_ROWS);
        for i in 0..num_predictions {
            class_scores.clear();
            class_scores.extend((BOX_ROWS..num_features).map(|j| output[[0, j, i]]));
            if let Some((class_idx, max_score)) = argmax(&class_scores) {
                if max_score >= conf_threshold {
                    let cx = output[[0, 0, i]];
                    let cy = output[[0, 1, i]];
                    let w = output[[0, 2, i]];
                    let h = output[[0, 3, i]];
                    let (x_orig, y_orig, w_orig, h_orig) = meta.inverse_transform(cx, cy, w, h);
                    let bbox = BBox::from_center(x_orig, y_orig, w_orig, h_orig);
                    detections.push(Detection::new(bbox, class_idx, max_score));
                }
            }
        }
        Ok(detections)
    }
}
/// OpenCV-facing adapters, compiled only with the `tensorrt-opencv-compat`
/// feature: they accept a `Mat`, return `Rect`s and report failures as
/// OpenCV errors.
#[cfg(feature = "tensorrt-opencv-compat")]
mod opencv_compat_impl {
    use super::*;
    use crate::image_buffer::ImageBuffer;
    use opencv::core::{Mat, Rect};
    use opencv::Error as OpenCvError;

    impl ModelUltralyticsRt {
        /// Runs `forward` on an OpenCV `Mat`.
        ///
        /// The image is first converted with `ImageBuffer::from_mat`. Any
        /// failure from `forward` is re-reported as an OpenCV error carrying
        /// the `StsError` code, and each resulting box is turned into a `Rect`
        /// from its `x`, `y`, `width` and `height` fields.
        pub fn forward_mat(
            &mut self,
            image: &Mat,
            conf_threshold: f32,
            nms_threshold: f32,
        ) -> Result<(Vec<Rect>, Vec<usize>, Vec<f32>), OpenCvError> {
            let buffer = ImageBuffer::from_mat(image)?;
            let (boxes, class_ids, confidences) =
                match self.forward(&buffer, conf_threshold, nms_threshold) {
                    Ok(outputs) => outputs,
                    Err(err) => {
                        return Err(OpenCvError::new(
                            opencv::core::StsError,
                            format!("TensorRT error: {}", err),
                        ));
                    }
                };

            let mut rects: Vec<Rect> = Vec::with_capacity(boxes.len());
            for b in boxes {
                rects.push(Rect::new(b.x, b.y, b.width, b.height));
            }
            Ok((rects, class_ids, confidences))
        }
    }

    impl crate::opencv_compat::ModelTrait for ModelUltralyticsRt {
        /// Trait entry point; simply forwards to `forward_mat`.
        fn forward(
            &mut self,
            image: &Mat,
            conf_threshold: f32,
            nms_threshold: f32,
        ) -> Result<(Vec<Rect>, Vec<usize>, Vec<f32>), OpenCvError> {
            self.forward_mat(image, conf_threshold, nms_threshold)
        }
    }
}
impl crate::ObjectDetector for ModelUltralyticsRt {
type Input = ImageBuffer;
type Error = TrtModelError;
fn detect(
&mut self,
input: &Self::Input,
conf_threshold: f32,
nms_threshold: f32,
) -> Result<(Vec<BBox>, Vec<usize>, Vec<f32>), Self::Error> {
self.forward(input, conf_threshold, nms_threshold)
}
}