Skip to main content

yolo_rs/
model.rs

1//! Load the YOLO model.
2
3use std::path::Path;
4
5use arcstr::ArcStr;
6
7use crate::error::YoloError;
8
9/// The YOLO model.
10///
11/// It is a wrapper around the ONNX runtime session and the YOLO labels.
12#[derive(Debug)]
13pub struct YoloModelSession {
14    pub session: ort::session::Session,
15    pub labels: Vec<ArcStr>,
16
17    pub probability_threshold: Option<f32>, // default = 0.5
18    pub iou_threshold: Option<f32>,         // default = 0.7
19}
20
21impl YoloModelSession {
22    /// Wrap a ONNX session to a [`YoloModelSession`].
23    ///
24    /// The `session` is the ONNX runtime session, and the `labels` are the YOLO labels.
25    pub fn new(
26        session: ort::session::Session,
27        labels: impl Iterator<Item = impl Into<ArcStr>>,
28    ) -> Self {
29        Self {
30            session,
31            labels: labels.map(Into::into).collect(),
32            probability_threshold: None,
33            iou_threshold: None,
34        }
35    }
36
37    /// Wrap a ONNX session to a [`YoloModelSession`] based on the labels of YOLO v8 (v11).
38    pub fn new_v8(session: ort::session::Session) -> Self {
39        const LABELS: &[ArcStr] = &[
40            arcstr::literal!("person"),
41            arcstr::literal!("bicycle"),
42            arcstr::literal!("car"),
43            arcstr::literal!("motorcycle"),
44            arcstr::literal!("airplane"),
45            arcstr::literal!("bus"),
46            arcstr::literal!("train"),
47            arcstr::literal!("truck"),
48            arcstr::literal!("boat"),
49            arcstr::literal!("traffic light"),
50            arcstr::literal!("fire hydrant"),
51            arcstr::literal!("stop sign"),
52            arcstr::literal!("parking meter"),
53            arcstr::literal!("bench"),
54            arcstr::literal!("bird"),
55            arcstr::literal!("cat"),
56            arcstr::literal!("dog"),
57            arcstr::literal!("horse"),
58            arcstr::literal!("sheep"),
59            arcstr::literal!("cow"),
60            arcstr::literal!("elephant"),
61            arcstr::literal!("bear"),
62            arcstr::literal!("zebra"),
63            arcstr::literal!("giraffe"),
64            arcstr::literal!("backpack"),
65            arcstr::literal!("umbrella"),
66            arcstr::literal!("handbag"),
67            arcstr::literal!("tie"),
68            arcstr::literal!("suitcase"),
69            arcstr::literal!("frisbee"),
70            arcstr::literal!("skis"),
71            arcstr::literal!("snowboard"),
72            arcstr::literal!("sports ball"),
73            arcstr::literal!("kite"),
74            arcstr::literal!("baseball bat"),
75            arcstr::literal!("baseball glove"),
76            arcstr::literal!("skateboard"),
77            arcstr::literal!("surfboard"),
78            arcstr::literal!("tennis racket"),
79            arcstr::literal!("bottle"),
80            arcstr::literal!("wine glass"),
81            arcstr::literal!("cup"),
82            arcstr::literal!("fork"),
83            arcstr::literal!("knife"),
84            arcstr::literal!("spoon"),
85            arcstr::literal!("bowl"),
86            arcstr::literal!("banana"),
87            arcstr::literal!("apple"),
88            arcstr::literal!("sandwich"),
89            arcstr::literal!("orange"),
90            arcstr::literal!("broccoli"),
91            arcstr::literal!("carrot"),
92            arcstr::literal!("hot dog"),
93            arcstr::literal!("pizza"),
94            arcstr::literal!("donut"),
95            arcstr::literal!("cake"),
96            arcstr::literal!("chair"),
97            arcstr::literal!("couch"),
98            arcstr::literal!("potted plant"),
99            arcstr::literal!("bed"),
100            arcstr::literal!("dining table"),
101            arcstr::literal!("toilet"),
102            arcstr::literal!("tv"),
103            arcstr::literal!("laptop"),
104            arcstr::literal!("mouse"),
105            arcstr::literal!("remote"),
106            arcstr::literal!("keyboard"),
107            arcstr::literal!("cell phone"),
108            arcstr::literal!("microwave"),
109            arcstr::literal!("oven"),
110            arcstr::literal!("toaster"),
111            arcstr::literal!("sink"),
112            arcstr::literal!("refrigerator"),
113            arcstr::literal!("book"),
114            arcstr::literal!("clock"),
115            arcstr::literal!("vase"),
116            arcstr::literal!("scissors"),
117            arcstr::literal!("teddy bear"),
118            arcstr::literal!("hair drier"),
119            arcstr::literal!("toothbrush"),
120        ];
121
122        Self {
123            session,
124            labels: LABELS.to_vec(),
125            probability_threshold: None,
126            iou_threshold: None,
127        }
128    }
129
130    /// Load the YOLO ONNX model from a filename.
131    ///
132    /// You can use this function to load a YOLO v8 (v11) model from a file.
133    /// The `filename` is the path to the ONNX model file.
134    ///
135    /// You can export the ONNX model file according to
136    /// [Ultralytics' manual](https://docs.ultralytics.com/integrations/onnx/).
137    pub fn from_filename_v8(filename: impl AsRef<Path>) -> Result<Self, YoloError> {
138        let session = ort::session::Session::builder()
139            .map_err(YoloError::OrtSessionBuildError)?
140            .commit_from_file(filename)
141            .map_err(YoloError::OrtSessionLoadError)?;
142
143        Ok(Self::new_v8(session))
144    }
145
146    pub fn get_labels(&self) -> &[ArcStr] {
147        &self.labels
148    }
149
150    pub fn get_probability_threshold(&self) -> f32 {
151        self.probability_threshold.unwrap_or(0.5)
152    }
153
154    pub fn get_iou_threshold(&self) -> f32 {
155        self.iou_threshold.unwrap_or(0.7)
156    }
157}
158
159impl AsRef<ort::session::Session> for YoloModelSession {
160    fn as_ref(&self) -> &ort::session::Session {
161        &self.session
162    }
163}
164
165impl AsMut<ort::session::Session> for YoloModelSession {
166    fn as_mut(&mut self) -> &mut ort::session::Session {
167        &mut self.session
168    }
169}