// yolo_rs/lib.rs
1//! A Rust library for the YOLO v11 object detection model.
2//!
3//! This library provides a high-level API for running the YOLO v11 object detection model.
//! Currently, only inference is supported.
5
6pub mod error;
7pub mod model;
8
9use arcstr::ArcStr;
10use error::YoloError;
11use image::{DynamicImage, GenericImageView, Rgba, imageops::FilterType};
12use model::YoloModelSession;
13use ndarray::{Array4, ArrayBase, ArrayView4, Axis, s};
14use ort::{inputs, value::TensorRef};
15
/// An axis-aligned bounding box in the coordinate space of the original
/// (pre-resize) image, as produced by [`inference`].
#[derive(Debug, Clone, Copy)]
pub struct BoundingBox {
    /// X coordinate of the left edge (`center_x - width / 2`).
    pub x1: f32,
    /// Y coordinate of the top edge (`center_y - height / 2`).
    pub y1: f32,
    /// X coordinate of the right edge (`center_x + width / 2`).
    pub x2: f32,
    /// Y coordinate of the bottom edge (`center_y + height / 2`).
    pub y2: f32,
}
23
/// A preprocessed image ready to be fed to [`inference`].
///
/// Produced by [`image_to_yolo_input_tensor`]; call [`YoloInput::view`]
/// to obtain the borrowed view that [`inference`] accepts.
#[derive(Debug, Clone)]
pub struct YoloInput {
    /// Normalized NCHW pixel data with shape (1, 3, 640, 640), values in [0, 1].
    pub tensor: Array4<f32>, // 640x640
    /// Width of the original image in pixels; used to scale boxes back.
    pub raw_width: u32,
    /// Height of the original image in pixels; used to scale boxes back.
    pub raw_height: u32,
}
30
31impl YoloInput {
32    pub fn view(&self) -> YoloInputView<'_> {
33        YoloInputView {
34            tensor_view: self.tensor.view(),
35            raw_width: self.raw_width,
36            raw_height: self.raw_height,
37        }
38    }
39}
40
/// A borrowed, `Copy`able view of a [`YoloInput`], as consumed by [`inference`].
#[derive(Debug, Clone, Copy)]
pub struct YoloInputView<'a> {
    /// View of the (1, 3, 640, 640) normalized input tensor.
    pub tensor_view: ArrayView4<'a, f32>,
    /// Width of the original image in pixels; used to scale boxes back.
    pub raw_width: u32,
    /// Height of the original image in pixels; used to scale boxes back.
    pub raw_height: u32,
}
47
/// A single entity detected by [`inference`], after non-maximum suppression.
#[derive(Debug, Clone)]
pub struct YoloEntityOutput {
    /// Location of the entity, in original-image pixel coordinates.
    pub bounding_box: BoundingBox,
    /// The label of the detected entity.
    ///
    /// You can check the metadata of the model with
    /// [Netron](https://netron.app) to get the labels.
    pub label: ArcStr,
    /// The confidence of the detected entity.
    pub confidence: f32,
}
59
60/// Convert an image to a YOLO input tensor.
61///
62/// The input image is resized to 640x640 and normalized to the range [0, 1].
63/// The tensor has the shape (1, 3, 640, 640) and the layout is (R, G, B).
64///
65/// You can pass the resulting tensor to the [`inference`] function.
66/// Note that you might need to call [`YoloInput::view`] to get a view of the tensor.
67pub fn image_to_yolo_input_tensor(original_image: &DynamicImage) -> YoloInput {
68    let mut input = ArrayBase::zeros((1, 3, 640, 640));
69
70    let image = original_image.resize_exact(640, 640, FilterType::CatmullRom);
71    for (x, y, Rgba([r, g, b, _])) in image.pixels() {
72        let x = x as usize;
73        let y = y as usize;
74
75        input[[0, 0, y, x]] = (r as f32) / 255.;
76        input[[0, 1, y, x]] = (g as f32) / 255.;
77        input[[0, 2, y, x]] = (b as f32) / 255.;
78    }
79
80    YoloInput {
81        tensor: input,
82        raw_width: original_image.width(),
83        raw_height: original_image.height(),
84    }
85}
86
87/// Inference on the YOLO model, returning the detected entities.
88///
89/// The input tensor should be obtained from the [`image_to_yolo_input_tensor`] function.
90/// The [`YoloModelSession`] can be obtained from the [`YoloModelSession::from_filename_v8`] method.
91pub fn inference(
92    model: &mut YoloModelSession,
93    YoloInputView {
94        tensor_view,
95        raw_width,
96        raw_height,
97    }: YoloInputView,
98) -> Result<Vec<YoloEntityOutput>, YoloError> {
99    fn intersection(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
100        (box1.x2.min(box2.x2) - box1.x1.max(box2.x1))
101            * (box1.y2.min(box2.y2) - box1.y1.max(box2.y1))
102    }
103
104    fn union(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
105        ((box1.x2 - box1.x1) * (box1.y2 - box1.y1)) + ((box2.x2 - box2.x1) * (box2.y2 - box2.y1))
106            - intersection(box1, box2)
107    }
108
109    fn non_maximum_suppression(
110        mut boxes: Vec<YoloEntityOutput>,
111        iou_threshold: f32,
112    ) -> Vec<YoloEntityOutput> {
113        // Early return if no boxes are provided
114        if boxes.is_empty() {
115            return Vec::new();
116        }
117
118        // Sort boxes by confidence descending using sort_unstable_by for better performance
119        boxes.sort_unstable_by(|a, b| b.confidence.total_cmp(&a.confidence));
120
121        let mut result = Vec::with_capacity(boxes.len());
122
123        // Iterate through each box and select it if it doesn't overlap significantly with already selected boxes
124        for current in boxes.into_iter() {
125            // Check if the current box has a high IoU with any box in the result
126            // Using `iter().all()` ensures we short-circuit on the first overlap found
127            if result.iter().all(|selected: &YoloEntityOutput| {
128                let iou = intersection(&selected.bounding_box, &current.bounding_box)
129                    / union(&selected.bounding_box, &current.bounding_box);
130                iou < iou_threshold
131            }) {
132                result.push(current);
133            }
134        }
135
136        result.shrink_to_fit();
137
138        result
139    }
140
141    // Due to the lifetime of the model, we need to clone the
142    // labels and thresholds early.
143    let iou_threshold = model.get_iou_threshold();
144    let probability_threshold = model.get_probability_threshold();
145    let labels = model.get_labels().to_vec();
146
147    // Run YOLOv8 inference
148    let inputs = inputs!["images" => TensorRef::from_array_view(tensor_view).map_err(YoloError::OrtInputError)?];
149    let outputs = model
150        .as_mut()
151        .run(inputs)
152        .map_err(YoloError::OrtInferenceError)?;
153    let output = outputs["output0"]
154        .try_extract_array::<f32>()
155        .map_err(YoloError::OrtExtractSensorError)?
156        .reversed_axes();
157    let output = output.slice(s![.., .., 0]);
158
159    // Turn the output tensor into bounding boxes
160    let boxes = output
161        .axis_iter(Axis(0))
162        .filter_map(|row| {
163            let (class_id, prob) = row
164                .iter()
165                .skip(4) // skip bounding box coordinates
166                .enumerate()
167                .map(|(index, value)| (index, *value))
168                .reduce(|accum, row| if row.1 > accum.1 { row } else { accum })
169                .filter(|(_, prob)| *prob >= probability_threshold)?;
170
171            let label = labels[class_id].clone();
172
173            let xc = row[0_usize] / 640. * (raw_width as f32);
174            let yc = row[1_usize] / 640. * (raw_height as f32);
175            let w = row[2_usize] / 640. * (raw_width as f32);
176            let h = row[3_usize] / 640. * (raw_height as f32);
177
178            Some(YoloEntityOutput {
179                bounding_box: BoundingBox {
180                    x1: xc - w / 2.,
181                    y1: yc - h / 2.,
182                    x2: xc + w / 2.,
183                    y2: yc + h / 2.,
184                },
185                label,
186                confidence: prob,
187            })
188        })
189        .collect::<Vec<YoloEntityOutput>>();
190
191    // Perform non-maximum suppression (NMS)
192    Ok(non_maximum_suppression(boxes, iou_threshold))
193}