pub mod error;
pub mod model;
use arcstr::ArcStr;
use error::YoloError;
use image::{DynamicImage, GenericImageView, Rgba, imageops::FilterType};
use model::YoloModelSession;
use ndarray::{Array4, ArrayBase, ArrayView4, Axis, s};
use ort::inputs;
#[derive(Debug, Clone, Copy)]
pub struct BoundingBox {
pub x1: f32,
pub y1: f32,
pub x2: f32,
pub y2: f32,
}
/// Owned, preprocessed YOLO input: a normalized NCHW tensor plus the
/// dimensions of the image it was produced from (needed to map predictions
/// back to original pixel coordinates).
#[derive(Debug, Clone)]
pub struct YoloInput {
    /// `(1, 3, 640, 640)` RGB tensor with channel values scaled to `0.0..=1.0`.
    pub tensor: Array4<f32>,
    /// Width of the original (pre-resize) image in pixels.
    pub raw_width: u32,
    /// Height of the original (pre-resize) image in pixels.
    pub raw_height: u32,
}
impl YoloInput {
    /// Borrow this input as a cheap, `Copy`-able view suitable for passing
    /// into [`inference`] without giving up ownership of the tensor.
    pub fn view(&self) -> YoloInputView {
        let Self {
            tensor,
            raw_width,
            raw_height,
        } = self;
        YoloInputView {
            tensor_view: tensor.view(),
            raw_width: *raw_width,
            raw_height: *raw_height,
        }
    }
}
/// Borrowed, copyable view of a [`YoloInput`]; see [`YoloInput::view`].
#[derive(Debug, Clone, Copy)]
pub struct YoloInputView<'a> {
    /// View of the `(1, 3, 640, 640)` normalized RGB tensor.
    pub tensor_view: ArrayView4<'a, f32>,
    /// Width of the original (pre-resize) image in pixels.
    pub raw_width: u32,
    /// Height of the original (pre-resize) image in pixels.
    pub raw_height: u32,
}
/// One detected entity produced by [`inference`].
#[derive(Debug, Clone)]
pub struct YoloEntityOutput {
    /// Detection rectangle, scaled back to original-image pixel coordinates.
    pub bounding_box: BoundingBox,
    /// Class label looked up from the model's label table.
    pub label: ArcStr,
    /// Class probability of the winning class for this detection.
    pub confidence: f32,
}
pub fn image_to_yolo_input_tensor(original_image: &DynamicImage) -> YoloInput {
let mut input = ArrayBase::zeros((1, 3, 640, 640));
let image = original_image.resize_exact(640, 640, FilterType::CatmullRom);
for (x, y, Rgba([r, g, b, _])) in image.pixels() {
let x = x as usize;
let y = y as usize;
input[[0, 0, y, x]] = (r as f32) / 255.;
input[[0, 1, y, x]] = (g as f32) / 255.;
input[[0, 2, y, x]] = (b as f32) / 255.;
}
YoloInput {
tensor: input,
raw_width: original_image.width(),
raw_height: original_image.height(),
}
}
/// Run YOLO inference on a preprocessed input view and return the detected
/// entities, deduplicated with greedy non-maximum suppression (NMS).
///
/// Each output row is read as `[xc, yc, w, h, score_class0, score_class1, ...]`
/// in 640x640 model space; the winning class must reach the model's
/// probability threshold, and boxes are rescaled to the original image
/// dimensions (`raw_width` x `raw_height`).
///
/// # Errors
///
/// Returns a [`YoloError`] when building the session inputs, running the
/// model, or extracting the output tensor fails.
pub fn inference(
    model: &YoloModelSession,
    YoloInputView {
        tensor_view,
        raw_width,
        raw_height,
    }: YoloInputView,
) -> Result<Vec<YoloEntityOutput>, YoloError> {
    // Overlap area of two boxes. Each extent is clamped to >= 0 so that boxes
    // disjoint on both axes yield 0 instead of a spurious positive area (the
    // product of two negative extents), which previously inflated IoU.
    fn intersection(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
        let overlap_w = (box1.x2.min(box2.x2) - box1.x1.max(box2.x1)).max(0.);
        let overlap_h = (box1.y2.min(box2.y2) - box1.y1.max(box2.y1)).max(0.);
        overlap_w * overlap_h
    }
    // Combined area of two boxes (inclusion-exclusion).
    fn union(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
        ((box1.x2 - box1.x1) * (box1.y2 - box1.y1))
            + ((box2.x2 - box2.x1) * (box2.y2 - box2.y1))
            - intersection(box1, box2)
    }
    // Intersection-over-union, guarded against a zero/degenerate union so a
    // NaN never reaches the suppression predicate below.
    fn iou(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
        let u = union(box1, box2);
        if u > 0. {
            intersection(box1, box2) / u
        } else {
            0.
        }
    }
    // Greedy NMS: visit boxes in descending confidence order, keeping a box
    // only if its IoU with every already-kept box stays below the threshold.
    fn non_maximum_suppression(
        mut boxes: Vec<YoloEntityOutput>,
        iou_threshold: f32,
    ) -> Vec<YoloEntityOutput> {
        if boxes.is_empty() {
            return Vec::new();
        }
        boxes.sort_unstable_by(|a, b| b.confidence.total_cmp(&a.confidence));
        let mut result: Vec<YoloEntityOutput> = Vec::with_capacity(boxes.len());
        for current in boxes {
            let keep = result.iter().all(|selected| {
                iou(&selected.bounding_box, &current.bounding_box) < iou_threshold
            });
            if keep {
                result.push(current);
            }
        }
        result.shrink_to_fit();
        result
    }
    let inputs = inputs!["images" => tensor_view].map_err(YoloError::OrtInputError)?;
    let outputs = model
        .as_ref()
        .run(inputs)
        .map_err(YoloError::OrtInferenceError)?;
    // Transpose so each slice along axis 0 is one candidate detection row.
    let output = outputs["output0"]
        .try_extract_tensor::<f32>()
        .map_err(YoloError::OrtExtractSensorError)?
        .reversed_axes();
    let output = output.slice(s![.., .., 0]);
    let boxes = output
        .axis_iter(Axis(0))
        .filter_map(|row| {
            // Columns 4.. hold per-class scores: pick the best class, then
            // drop the candidate if it falls below the probability threshold.
            let (class_id, prob) = row
                .iter()
                .skip(4)
                .copied()
                .enumerate()
                .reduce(|best, cand| if cand.1 > best.1 { cand } else { best })
                .filter(|&(_, prob)| prob >= model.get_probability_threshold())?;
            let label = model.labels[class_id].clone();
            // Center/size are in 640x640 model space; rescale to raw pixels.
            let xc = row[0_usize] / 640. * (raw_width as f32);
            let yc = row[1_usize] / 640. * (raw_height as f32);
            let w = row[2_usize] / 640. * (raw_width as f32);
            let h = row[3_usize] / 640. * (raw_height as f32);
            Some(YoloEntityOutput {
                bounding_box: BoundingBox {
                    x1: xc - w / 2.,
                    y1: yc - h / 2.,
                    x2: xc + w / 2.,
                    y2: yc + h / 2.,
                },
                label,
                confidence: prob,
            })
        })
        .collect::<Vec<YoloEntityOutput>>();
    Ok(non_maximum_suppression(boxes, model.get_iou_threshold()))
}