use yscv_tensor::Tensor;
use crate::{DetectError, Detection};
/// Tuning parameters for model-based detection pre- and post-processing.
#[derive(Debug, Clone)]
pub struct ModelDetectorConfig {
    /// Minimum score a detection must reach (compared with `>=`) to be kept by
    /// `postprocess_detections`.
    pub score_threshold: f32,
    /// IoU above which (strict `>`) a lower-scored detection of the same class
    /// is suppressed during non-maximum suppression.
    pub nms_iou_threshold: f32,
    /// Upper bound on the number of detections returned after suppression.
    pub max_detections: usize,
    /// Model input height in pixels.
    /// NOTE(review): not read by the post-processing shown here; presumably
    /// passed as `target_h` to the preprocessing step — confirm with callers.
    pub input_height: usize,
    /// Model input width in pixels.
    /// NOTE(review): presumably passed as `target_w` to the preprocessing
    /// step — confirm with callers.
    pub input_width: usize,
}

impl Default for ModelDetectorConfig {
    /// 640x640 input, score cut-off 0.5, NMS IoU 0.45, at most 100 detections.
    fn default() -> Self {
        let (input_height, input_width) = (640, 640);
        ModelDetectorConfig {
            input_width,
            input_height,
            max_detections: 100,
            nms_iou_threshold: 0.45,
            score_threshold: 0.5,
        }
    }
}
/// Interface for detectors that produce [`Detection`]s from a tensor input.
pub trait ModelDetector {
/// Runs detection on `input` and returns the detections, or a [`DetectError`].
///
/// NOTE(review): whether implementations already apply score filtering / NMS,
/// or expect callers to run `postprocess_detections` afterwards, is not
/// established here — confirm against implementors.
fn detect_tensor(&self, input: &Tensor) -> Result<Vec<Detection>, DetectError>;
/// Class names known to this detector.
/// NOTE(review): presumably indexed by `Detection::class_id` — confirm.
fn class_labels(&self) -> &[&str];
/// Expected input dimensions, as three values.
/// NOTE(review): axis order is not shown here; `preprocess_rgb8_for_model`
/// builds `[1, h, w, 3]` tensors, which suggests `[h, w, 3]` — confirm.
fn input_shape(&self) -> [usize; 3];
}
/// Filters, ranks and de-duplicates raw detections.
///
/// 1. Drops detections whose `score` is below `config.score_threshold`
///    (NaN scores are dropped as well, because the `>=` comparison is false).
/// 2. Sorts the survivors by descending score. The sort is stable, so equal
///    scores keep their input order.
/// 3. Runs greedy per-class non-maximum suppression. Each kept detection
///    suppresses every lower-ranked detection with the same `class_id` whose
///    IoU with it is strictly greater than `config.nms_iou_threshold`.
/// 4. Stops as soon as `config.max_detections` detections have been kept.
///
/// Returns an empty vector when `config.max_detections` is `0`.
pub fn postprocess_detections(raw: &[Detection], config: &ModelDetectorConfig) -> Vec<Detection> {
    // Without this guard the loop below pushes one detection before it first
    // checks the cap, so a cap of zero would still yield a result.
    if config.max_detections == 0 {
        return Vec::new();
    }

    let mut candidates: Vec<Detection> = raw
        .iter()
        .copied()
        .filter(|d| d.score >= config.score_threshold)
        .collect();
    candidates.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut kept = Vec::with_capacity(config.max_detections.min(candidates.len()));
    let mut suppressed = vec![false; candidates.len()];
    for (i, current) in candidates.iter().enumerate() {
        if suppressed[i] {
            continue;
        }
        kept.push(*current);
        if kept.len() >= config.max_detections {
            break;
        }
        // Only later (lower- or equal-scored) candidates can be suppressed by `current`.
        for (j, other) in candidates.iter().enumerate().skip(i + 1) {
            if suppressed[j] || other.class_id != current.class_id {
                continue;
            }
            if crate::iou(current.bbox, other.bbox) > config.nms_iou_threshold {
                suppressed[j] = true;
            }
        }
    }
    kept
}
/// Resizes a packed RGB8 image with bilinear interpolation and converts it to a
/// normalized `f32` tensor of shape `[1, target_h, target_w, 3]`.
///
/// `rgb8` is read as row-major, interleaved RGB. The byte for channel `c` of
/// pixel `(x, y)` is at `(y * width + x) * 3 + c`. Bytes beyond
/// `width * height * 3` are ignored. Each output value is the interpolated byte
/// divided by `255.0`, so it lies in `[0.0, 1.0]`.
///
/// Destination pixel `(col, row)` samples the source at
/// `(col * width / target_w, row * height / target_h)`. This mapping is
/// top-left aligned, not pixel-centre aligned. Coordinates are clamped to the
/// last source column/row.
///
/// # Errors
///
/// - [`DetectError::InvalidRgb8BufferSize`] if `rgb8` holds fewer than
///   `width * height * 3` bytes. If that product overflows `usize`, `expected`
///   is reported as `usize::MAX`.
/// - [`DetectError::Tensor`] if the output tensor cannot be constructed.
///
/// # Panics
///
/// Panics if the source image is empty (`width == 0` or `height == 0`) while
/// the target is non-empty, since there is no pixel to sample.
pub fn preprocess_rgb8_for_model(
    rgb8: &[u8],
    width: usize,
    height: usize,
    target_h: usize,
    target_w: usize,
) -> Result<Tensor, DetectError> {
    // Saturate rather than wrap: an overflowing size can never be satisfied by a
    // real slice, so the check reliably rejects it instead of letting a wrapped
    // value through and panicking on an out-of-bounds index later.
    let expected = width.saturating_mul(height).saturating_mul(3);
    if rgb8.len() < expected {
        return Err(DetectError::InvalidRgb8BufferSize {
            expected,
            got: rgb8.len(),
        });
    }

    let mut data: Vec<f32> = Vec::with_capacity(target_h * target_w * 3);
    if target_h == 0 || target_w == 0 {
        // An empty target needs no source pixels.
        return Tensor::from_vec(vec![1, target_h, target_w, 3], data).map_err(DetectError::Tensor);
    }
    // Without this, `height - 1` / `width - 1` below would underflow. That panics
    // in debug builds and reads bytes outside the image in release builds.
    assert!(
        width > 0 && height > 0,
        "preprocess_rgb8_for_model: cannot resample an empty {}x{} image to {}x{}",
        width,
        height,
        target_w,
        target_h
    );

    let scale_y = height as f32 / target_h as f32;
    let scale_x = width as f32 / target_w as f32;
    let max_y = height - 1;
    let max_x = width - 1;

    // The horizontal taps depend only on the column, so compute them once
    // rather than once per output row.
    let columns: Vec<(usize, usize, f32)> = (0..target_w)
        .map(|col| {
            let src_x = (col as f32 * scale_x).min(max_x as f32);
            let x0 = src_x as usize;
            (x0, (x0 + 1).min(max_x), src_x - x0 as f32)
        })
        .collect();

    for row in 0..target_h {
        let src_y = (row as f32 * scale_y).min(max_y as f32);
        let y0 = src_y as usize;
        let y1 = (y0 + 1).min(max_y);
        let fy = src_y - y0 as f32;
        let top = y0 * width;
        let bottom = y1 * width;
        for &(x0, x1, fx) in &columns {
            let p00 = (top + x0) * 3;
            let p01 = (top + x1) * 3;
            let p10 = (bottom + x0) * 3;
            let p11 = (bottom + x1) * 3;
            for ch in 0..3 {
                // Same operand order as a textbook bilinear blend.
                let v = rgb8[p00 + ch] as f32 * (1.0 - fx) * (1.0 - fy)
                    + rgb8[p01 + ch] as f32 * fx * (1.0 - fy)
                    + rgb8[p10 + ch] as f32 * (1.0 - fx) * fy
                    + rgb8[p11 + ch] as f32 * fx * fy;
                data.push(v / 255.0);
            }
        }
    }
    Tensor::from_vec(vec![1, target_h, target_w, 3], data).map_err(DetectError::Tensor)
}