1pub mod error;
7pub mod model;
8
9use arcstr::ArcStr;
10use error::YoloError;
11use image::{DynamicImage, GenericImageView, Rgba, imageops::FilterType};
12use model::YoloModelSession;
13use ndarray::{Array4, ArrayBase, ArrayView4, Axis, s};
14use ort::{inputs, value::TensorRef};
15
/// Axis-aligned rectangle described by two opposite corners.
///
/// `inference` fills these in as `centre - size / 2` for (`x1`, `y1`) and
/// `centre + size / 2` for (`x2`, `y2`), scaled to the original image size.
#[derive(Clone, Copy, Debug)]
pub struct BoundingBox {
    /// Horizontal coordinate of the first corner.
    pub x1: f32,
    /// Vertical coordinate of the first corner.
    pub y1: f32,
    /// Horizontal coordinate of the opposite corner.
    pub x2: f32,
    /// Vertical coordinate of the opposite corner.
    pub y2: f32,
}
23
24#[derive(Debug, Clone)]
25pub struct YoloInput {
26 pub tensor: Array4<f32>, pub raw_width: u32,
28 pub raw_height: u32,
29}
30
31impl YoloInput {
32 pub fn view(&self) -> YoloInputView<'_> {
33 YoloInputView {
34 tensor_view: self.tensor.view(),
35 raw_width: self.raw_width,
36 raw_height: self.raw_height,
37 }
38 }
39}
40
41#[derive(Debug, Clone, Copy)]
42pub struct YoloInputView<'a> {
43 pub tensor_view: ArrayView4<'a, f32>,
44 pub raw_width: u32,
45 pub raw_height: u32,
46}
47
48#[derive(Debug, Clone)]
49pub struct YoloEntityOutput {
50 pub bounding_box: BoundingBox,
51 pub label: ArcStr,
56 pub confidence: f32,
58}
59
60pub fn image_to_yolo_input_tensor(original_image: &DynamicImage) -> YoloInput {
68 let mut input = ArrayBase::zeros((1, 3, 640, 640));
69
70 let image = original_image.resize_exact(640, 640, FilterType::CatmullRom);
71 for (x, y, Rgba([r, g, b, _])) in image.pixels() {
72 let x = x as usize;
73 let y = y as usize;
74
75 input[[0, 0, y, x]] = (r as f32) / 255.;
76 input[[0, 1, y, x]] = (g as f32) / 255.;
77 input[[0, 2, y, x]] = (b as f32) / 255.;
78 }
79
80 YoloInput {
81 tensor: input,
82 raw_width: original_image.width(),
83 raw_height: original_image.height(),
84 }
85}
86
87pub fn inference(
92 model: &mut YoloModelSession,
93 YoloInputView {
94 tensor_view,
95 raw_width,
96 raw_height,
97 }: YoloInputView,
98) -> Result<Vec<YoloEntityOutput>, YoloError> {
99 fn intersection(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
100 (box1.x2.min(box2.x2) - box1.x1.max(box2.x1))
101 * (box1.y2.min(box2.y2) - box1.y1.max(box2.y1))
102 }
103
104 fn union(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
105 ((box1.x2 - box1.x1) * (box1.y2 - box1.y1)) + ((box2.x2 - box2.x1) * (box2.y2 - box2.y1))
106 - intersection(box1, box2)
107 }
108
109 fn non_maximum_suppression(
110 mut boxes: Vec<YoloEntityOutput>,
111 iou_threshold: f32,
112 ) -> Vec<YoloEntityOutput> {
113 if boxes.is_empty() {
115 return Vec::new();
116 }
117
118 boxes.sort_unstable_by(|a, b| b.confidence.total_cmp(&a.confidence));
120
121 let mut result = Vec::with_capacity(boxes.len());
122
123 for current in boxes.into_iter() {
125 if result.iter().all(|selected: &YoloEntityOutput| {
128 let iou = intersection(&selected.bounding_box, ¤t.bounding_box)
129 / union(&selected.bounding_box, ¤t.bounding_box);
130 iou < iou_threshold
131 }) {
132 result.push(current);
133 }
134 }
135
136 result.shrink_to_fit();
137
138 result
139 }
140
141 let iou_threshold = model.get_iou_threshold();
144 let probability_threshold = model.get_probability_threshold();
145 let labels = model.get_labels().to_vec();
146
147 let inputs = inputs!["images" => TensorRef::from_array_view(tensor_view).map_err(YoloError::OrtInputError)?];
149 let outputs = model
150 .as_mut()
151 .run(inputs)
152 .map_err(YoloError::OrtInferenceError)?;
153 let output = outputs["output0"]
154 .try_extract_array::<f32>()
155 .map_err(YoloError::OrtExtractSensorError)?
156 .reversed_axes();
157 let output = output.slice(s![.., .., 0]);
158
159 let boxes = output
161 .axis_iter(Axis(0))
162 .filter_map(|row| {
163 let (class_id, prob) = row
164 .iter()
165 .skip(4) .enumerate()
167 .map(|(index, value)| (index, *value))
168 .reduce(|accum, row| if row.1 > accum.1 { row } else { accum })
169 .filter(|(_, prob)| *prob >= probability_threshold)?;
170
171 let label = labels[class_id].clone();
172
173 let xc = row[0_usize] / 640. * (raw_width as f32);
174 let yc = row[1_usize] / 640. * (raw_height as f32);
175 let w = row[2_usize] / 640. * (raw_width as f32);
176 let h = row[3_usize] / 640. * (raw_height as f32);
177
178 Some(YoloEntityOutput {
179 bounding_box: BoundingBox {
180 x1: xc - w / 2.,
181 y1: yc - h / 2.,
182 x2: xc + w / 2.,
183 y2: yc + h / 2.,
184 },
185 label,
186 confidence: prob,
187 })
188 })
189 .collect::<Vec<YoloEntityOutput>>();
190
191 Ok(non_maximum_suppression(boxes, iou_threshold))
193}