use rknn_runtime::{RknnModel, Nc1hwc2Layout};
use crate::bbox::BBox;
use crate::image_buffer::ImageBuffer;
use crate::postprocess::{Detection, nms, filter_by_class, detections_to_vecs};
use crate::preprocessing::{LetterboxMeta, PreprocessMeta, StretchMeta};
#[derive(Debug)]
pub enum RknnModelError {
Rknn(rknn_runtime::Error),
InvalidOutputShape(String),
}
impl std::fmt::Display for RknnModelError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RknnModelError::Rknn(e) => write!(f, "RKNN error: {}", e),
RknnModelError::InvalidOutputShape(s) => write!(f, "Invalid output shape: {}", s),
}
}
}
impl std::error::Error for RknnModelError {}
impl From<rknn_runtime::Error> for RknnModelError {
fn from(e: rknn_runtime::Error) -> Self {
RknnModelError::Rknn(e)
}
}
/// Ultralytics-style object detector running on the RKNN runtime.
///
/// Holds the loaded model, its input size, the precomputed NC1HWC2 layout of
/// output 0, and a scratch buffer that is reused for every resize in `forward`.
pub struct ModelUltralyticsRknn {
    /// The loaded RKNN model.
    model: RknnModel,
    /// Model input width, read from dimension 2 of the input shape.
    input_width: u32,
    /// Model input height, read from dimension 1 of the input shape.
    input_height: u32,
    /// Class ids handed to `filter_by_class` after decoding.
    class_filters: Vec<usize>,
    /// Selects letterbox (aspect-preserving, padded) instead of stretch resizing
    /// when the input image size differs from the model input size.
    use_letterbox: bool,
    /// Scratch buffer of `input_width * input_height * 3` bytes for resized input.
    resize_buf: Vec<u8>,
    /// NC1HWC2 layout of output tensor 0.
    layout: Nc1hwc2Layout,
    /// Raw-buffer offset of each class-score channel for the first prediction;
    /// the decoder adds the per-prediction stride to these.
    class_raw_offsets: Vec<usize>,
}
impl ModelUltralyticsRknn {
    /// Loads the model at `model_path` via `RknnModel::load` and prepares it for
    /// decoding.
    ///
    /// `num_classes` is the number of class-score channels, which are expected to
    /// follow the four box channels (channel 4 onward) of output 0.
    /// `class_filters` is passed to `filter_by_class` on every `forward` call.
    ///
    /// # Errors
    /// Returns [`RknnModelError::Rknn`] if the runtime fails to load the model or
    /// to report its output layout. Returns [`RknnModelError::InvalidOutputShape`]
    /// if the output's C2 block holds fewer than 4 channels.
    ///
    /// # Panics
    /// Panics if the model's input shape has fewer than 3 dimensions.
    pub fn new_from_file(
        model_path: &str,
        num_classes: usize,
        class_filters: Vec<usize>,
    ) -> Result<Self, RknnModelError> {
        let model = RknnModel::load(model_path)?;
        Self::from_model(model, num_classes, class_filters)
    }

    /// Same as [`Self::new_from_file`], but uses `RknnModel::load_with_lib` with
    /// the runtime library at `lib_path`.
    ///
    /// # Errors
    /// Same as [`Self::new_from_file`].
    ///
    /// # Panics
    /// Panics if the model's input shape has fewer than 3 dimensions.
    pub fn new_with_lib(
        model_path: &str,
        lib_path: &str,
        num_classes: usize,
        class_filters: Vec<usize>,
    ) -> Result<Self, RknnModelError> {
        let model = RknnModel::load_with_lib(model_path, lib_path)?;
        Self::from_model(model, num_classes, class_filters)
    }

    /// Shared constructor. Reads the input size, validates the output layout and
    /// precomputes the class-channel offsets used by the decoder.
    fn from_model(
        model: RknnModel,
        num_classes: usize,
        class_filters: Vec<usize>,
    ) -> Result<Self, RknnModelError> {
        let input_shape = &model.input_attr().shape;
        // Dims 1 and 2 are taken as height and width. This assumes an NHWC input
        // tensor — TODO confirm against the RKNN input attributes.
        let input_height = input_shape[1];
        let input_width = input_shape[2];
        let layout = model.output_nc1hwc2_layout(0)?;
        // The decoder reads the four box values at consecutive raw offsets, so
        // they must all live in the first C2 block.
        if layout.c2() < 4 {
            return Err(RknnModelError::InvalidOutputShape(format!(
                "NC1HWC2 c2={} < 4: bbox channels must fit in one block",
                layout.c2()
            )));
        }
        let class_raw_offsets = layout.precompute_channel_offsets(4, num_classes);
        Ok(Self {
            model,
            input_width,
            input_height,
            class_filters,
            // Compile-time default; can be changed later with `set_letterbox`.
            use_letterbox: cfg!(feature = "letterbox"),
            resize_buf: vec![0u8; input_width as usize * input_height as usize * 3],
            layout,
            class_raw_offsets,
        })
    }

    /// Chooses letterbox (`true`) or stretch (`false`) resizing for inputs whose
    /// size differs from the model input.
    pub fn set_letterbox(&mut self, enabled: bool) {
        self.use_letterbox = enabled;
    }

    /// Returns the model input size as `(width, height)`.
    pub fn input_size(&self) -> (u32, u32) {
        (self.input_width, self.input_height)
    }

    /// Runs detection on `image`.
    ///
    /// If the image already matches the model input size, its bytes are passed
    /// to the model directly. Otherwise it is resized into the internal scratch
    /// buffer, with letterbox or stretch depending on `set_letterbox`. Raw
    /// output 0 is decoded with `conf_threshold`, filtered by the configured
    /// class ids, and reduced with NMS at `nms_threshold`. The result is whatever
    /// `detections_to_vecs` produces — presumably (boxes, class ids, confidences).
    ///
    /// # Errors
    /// Returns [`RknnModelError::Rknn`] if inference or output retrieval fails.
    ///
    /// # Panics
    /// Panics if the image does not have exactly 3 channels, or if its buffer is
    /// not contiguous.
    pub fn forward(
        &mut self,
        image: &ImageBuffer,
        conf_threshold: f32,
        nms_threshold: f32,
    ) -> Result<(Vec<BBox>, Vec<usize>, Vec<f32>), RknnModelError> {
        let (orig_h, orig_w, channels) = image.shape();
        // Both resize helpers and the zero-copy path treat the image as packed
        // 3-byte pixels. Any other channel count would feed misinterpreted bytes
        // to the model, or (for 1 channel) read past the end of the buffer.
        assert_eq!(
            channels, 3,
            "ModelUltralyticsRknn::forward expects a 3-channel image"
        );
        let dst_w = self.input_width as usize;
        let dst_h = self.input_height as usize;
        let already_correct_size = orig_w == dst_w && orig_h == dst_h;
        let meta = if already_correct_size {
            PreprocessMeta::Stretch(StretchMeta {
                scale_x: 1.0,
                scale_y: 1.0,
                original_width: orig_w as i32,
                original_height: orig_h as i32,
            })
        } else {
            let src = image.as_slice().expect("ImageBuffer not contiguous");
            if self.use_letterbox {
                let lm = resize_letterbox_nearest_into(
                    src, orig_w, orig_h,
                    &mut self.resize_buf, dst_w, dst_h,
                );
                PreprocessMeta::Letterbox(lm)
            } else {
                resize_nearest_rgb_into(
                    src, orig_w, orig_h,
                    &mut self.resize_buf, dst_w, dst_h,
                );
                PreprocessMeta::Stretch(StretchMeta {
                    scale_x: orig_w as f32 / dst_w as f32,
                    scale_y: orig_h as f32 / dst_h as f32,
                    original_width: orig_w as i32,
                    original_height: orig_h as i32,
                })
            }
        };
        let input_bytes = if already_correct_size {
            image.as_slice().expect("ImageBuffer not contiguous")
        } else {
            &self.resize_buf
        };
        self.model.run(input_bytes)?;
        let raw = self.model.output_raw(0)?;
        let detections = parse_nc1hwc2_direct(
            raw,
            &self.class_raw_offsets,
            &self.layout,
            conf_threshold,
            self.input_width as f32,
            self.input_height as f32,
            &meta,
        );
        let filtered = filter_by_class(&detections, &self.class_filters);
        let final_detections = nms(&filtered, nms_threshold);
        Ok(detections_to_vecs(final_detections))
    }
}
impl crate::ObjectDetector for ModelUltralyticsRknn {
type Input = ImageBuffer;
type Error = RknnModelError;
fn detect(
&mut self,
input: &Self::Input,
conf_threshold: f32,
nms_threshold: f32,
) -> Result<(Vec<BBox>, Vec<usize>, Vec<f32>), Self::Error> {
self.forward(input, conf_threshold, nms_threshold)
}
}
/// Nearest-neighbour resize of a packed 3-byte-per-pixel image into `dst`.
///
/// Destination pixel `(x, y)` copies source pixel
/// `(x * src_w / dst_w, y * src_h / dst_h)` (integer division). Only the first
/// `dst_w * dst_h * 3` bytes of `dst` are written. An empty destination is a
/// no-op.
///
/// # Panics
/// Panics if the destination is non-empty and any of these hold:
/// - the source has zero width or height;
/// - `src` is shorter than `src_w * src_h * 3` bytes;
/// - `dst` is shorter than `dst_w * dst_h * 3` bytes.
///
/// These checks are what make the unchecked copies below sound.
#[inline(never)]
fn resize_nearest_rgb_into(
    src: &[u8], src_w: usize, src_h: usize,
    dst: &mut [u8], dst_w: usize, dst_h: usize,
) {
    if dst_w == 0 || dst_h == 0 {
        return;
    }
    assert!(
        src_w > 0 && src_h > 0,
        "resize_nearest_rgb_into: cannot resize an empty {}x{} source",
        src_w,
        src_h
    );
    let src_len = src_w.checked_mul(src_h).and_then(|n| n.checked_mul(3));
    assert!(
        matches!(src_len, Some(n) if n <= src.len()),
        "resize_nearest_rgb_into: source buffer too small ({} bytes for {}x{} RGB)",
        src.len(),
        src_w,
        src_h
    );
    let dst_len = dst_w.checked_mul(dst_h).and_then(|n| n.checked_mul(3));
    assert!(
        matches!(dst_len, Some(n) if n <= dst.len()),
        "resize_nearest_rgb_into: destination buffer too small ({} bytes for {}x{} RGB)",
        dst.len(),
        dst_w,
        dst_h
    );

    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_mut_ptr();
    for y in 0..dst_h {
        let src_y = (y * src_h) / dst_h;
        let dst_row = y * dst_w * 3;
        let src_row = src_y * src_w * 3;
        for x in 0..dst_w {
            let src_x = (x * src_w) / dst_w;
            let si = src_row + src_x * 3;
            let di = dst_row + x * 3;
            // SAFETY: y < dst_h and x < dst_w give src_y < src_h and src_x < src_w,
            // so si + 3 <= src_w * src_h * 3 <= src.len(). Likewise
            // di + 3 <= dst_w * dst_h * 3 <= dst.len() (both asserted above).
            // `src` and `dst` are distinct borrows, so the ranges cannot overlap.
            unsafe {
                std::ptr::copy_nonoverlapping(src_ptr.add(si), dst_ptr.add(di), 3);
            }
        }
    }
}
/// Letterbox-resizes a packed 3-byte-per-pixel image into `dst` with
/// nearest-neighbour sampling.
///
/// The source is scaled by `min(dst_w / src_w, dst_h / src_h)` so that it fits
/// inside the destination. It is then centred, with the left/top padding
/// rounded down. All of `dst` is first filled with 114, so every byte outside
/// the scaled image keeps that value. The returned [`LetterboxMeta`] records
/// the scale, padding and original size so detections can be mapped back.
///
/// # Panics
/// Panics if `src` is shorter than `src_w * src_h * 3` bytes or `dst` is
/// shorter than `dst_w * dst_h * 3` bytes. These checks are what make the
/// unchecked copies below sound.
#[inline(never)]
fn resize_letterbox_nearest_into(
    src: &[u8], src_w: usize, src_h: usize,
    dst: &mut [u8], dst_w: usize, dst_h: usize,
) -> LetterboxMeta {
    let src_len = src_w.checked_mul(src_h).and_then(|n| n.checked_mul(3));
    assert!(
        matches!(src_len, Some(n) if n <= src.len()),
        "resize_letterbox_nearest_into: source buffer too small ({} bytes for {}x{} RGB)",
        src.len(),
        src_w,
        src_h
    );
    let dst_len = dst_w.checked_mul(dst_h).and_then(|n| n.checked_mul(3));
    assert!(
        matches!(dst_len, Some(n) if n <= dst.len()),
        "resize_letterbox_nearest_into: destination buffer too small ({} bytes for {}x{} RGB)",
        dst.len(),
        dst_w,
        dst_h
    );

    let scale = f32::min(dst_w as f32 / src_w as f32, dst_h as f32 / src_h as f32);
    // Clamp so float rounding can never make the scaled image larger than the
    // destination, which would underflow the padding subtraction below.
    let new_w = ((src_w as f32 * scale).round() as usize).min(dst_w);
    let new_h = ((src_h as f32 * scale).round() as usize).min(dst_h);
    let pad_left = (dst_w - new_w) / 2;
    let pad_top = (dst_h - new_h) / 2;

    dst.fill(114);
    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_mut_ptr();
    for y in 0..new_h {
        let src_y = (y * src_h) / new_h;
        let dst_row = (y + pad_top) * dst_w * 3;
        let src_row = src_y * src_w * 3;
        for x in 0..new_w {
            let src_x = (x * src_w) / new_w;
            let si = src_row + src_x * 3;
            let di = dst_row + (x + pad_left) * 3;
            // SAFETY: new_w/new_h > 0 (the loops ran) imply src_w/src_h > 0,
            // because round(0 * scale) is 0. Then src_y < src_h and src_x < src_w,
            // so si + 3 <= src_w * src_h * 3 <= src.len(). Also
            // y + pad_top < new_h + (dst_h - new_h) = dst_h, and similarly
            // x + pad_left < dst_w, so di + 3 <= dst_w * dst_h * 3 <= dst.len().
            // `src` and `dst` are distinct borrows, so they cannot overlap.
            unsafe {
                std::ptr::copy_nonoverlapping(src_ptr.add(si), dst_ptr.add(di), 3);
            }
        }
    }

    LetterboxMeta {
        scale,
        pad_left: pad_left as i32,
        pad_top: pad_top as i32,
        original_width: src_w as i32,
        original_height: src_h as i32,
    }
}
/// Decodes detections straight from the quantized NC1HWC2 output buffer.
///
/// `raw` holds `layout.num_predictions()` predictions, spaced
/// `layout.prediction_stride()` apart. For each prediction:
/// - The class with the largest raw `i8` score is chosen; the first one wins on
///   ties. The other class scores are never dequantized.
/// - If that score reaches `layout.threshold_i8(conf_threshold)`, the four box
///   values at the prediction's base offset are dequantized and scaled by
///   `input_width`/`input_height`. This treats them as normalized — TODO
///   confirm for the exported model.
/// - The scaled box is mapped back through `meta.inverse_transform` and kept
///   when both width and height are positive.
///
/// Returns an empty vector when `class_raw_offsets` is empty, since there are
/// then no class scores to threshold.
///
/// # Panics
/// Panics if `raw` is too short to hold every channel read for every
/// prediction. This check is what makes the unchecked reads below sound.
#[inline(never)]
fn parse_nc1hwc2_direct(
    raw: &[i8],
    class_raw_offsets: &[usize],
    layout: &Nc1hwc2Layout,
    conf_threshold: f32,
    input_width: f32,
    input_height: f32,
    meta: &PreprocessMeta,
) -> Vec<Detection> {
    // Without class channels, every prediction would "win" with class 0 and a
    // score of i8::MIN. That passes the check whenever threshold_i8 saturates
    // to i8::MIN.
    if class_raw_offsets.is_empty() {
        return Vec::new();
    }
    let num_predictions = layout.num_predictions() as usize;
    if num_predictions == 0 {
        return Vec::new();
    }
    let threshold_i8 = layout.threshold_i8(conf_threshold);
    let stride = layout.prediction_stride();

    // Highest index read: the last prediction's base plus the largest channel
    // offset, which is either a class offset or box channel 3.
    let max_channel_offset = class_raw_offsets.iter().copied().max().unwrap_or(0).max(3);
    let max_index = (num_predictions - 1)
        .checked_mul(stride)
        .and_then(|base| base.checked_add(max_channel_offset));
    assert!(
        matches!(max_index, Some(i) if i < raw.len()),
        "NC1HWC2 output too small: {} bytes for {} predictions with stride {}",
        raw.len(),
        num_predictions,
        stride
    );

    let mut detections = Vec::new();
    for base in (0..num_predictions).map(|p| p * stride) {
        // Arg-max on raw quantized scores; strict `>` keeps the first maximum.
        let mut best_raw = i8::MIN;
        let mut best_cls = 0usize;
        for (cls, &off) in class_raw_offsets.iter().enumerate() {
            // SAFETY: base + off <= max_index < raw.len() (asserted above).
            let v = unsafe { *raw.get_unchecked(base + off) };
            if v > best_raw {
                best_raw = v;
                best_cls = cls;
            }
        }
        if best_raw < threshold_i8 {
            continue;
        }
        let best_conf = layout.dequant(best_raw);
        // SAFETY: base + 3 <= max_index < raw.len() (asserted above).
        let (rx, ry, rw, rh) = unsafe {
            (
                *raw.get_unchecked(base),
                *raw.get_unchecked(base + 1),
                *raw.get_unchecked(base + 2),
                *raw.get_unchecked(base + 3),
            )
        };
        let cx = layout.dequant(rx) * input_width;
        let cy = layout.dequant(ry) * input_height;
        let bw = layout.dequant(rw) * input_width;
        let bh = layout.dequant(rh) * input_height;
        if bw > 0.0 && bh > 0.0 {
            let (x, y, w_out, h_out) = meta.inverse_transform(cx, cy, bw, bh);
            detections.push(Detection::new(
                BBox::from_center(x, y, w_out, h_out),
                best_cls,
                best_conf,
            ));
        }
    }
    detections
}