ultralytics_inference/inference.rs
1// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
3//! Inference configuration and common types.
4//!
5//! This module defines the [`InferenceConfig`] struct, which controls various parameters
6//! for YOLO model inference, such as confidence thresholds, Non-Maximum Suppression (NMS),
7//! input image sizing, and hardware execution options.
8
/// Configuration for YOLO inference.
///
/// This struct is used to customize the behavior of the inference engine.
/// It uses a builder pattern for convenient construction; all fields are
/// public, so direct struct construction/mutation is also supported.
///
/// # Examples
///
/// Basic configuration:
/// ```rust
/// use ultralytics_inference::InferenceConfig;
///
/// let config = InferenceConfig::new()
/// .with_confidence(0.5)
/// .with_iou(0.45)
/// .with_max_det(300)
/// .with_imgsz(640, 640);
/// ```
///
/// With specific hardware device:
/// ```rust
/// use ultralytics_inference::{InferenceConfig, Device};
///
/// let config = InferenceConfig::new()
/// .with_confidence(0.5)
/// .with_device(Device::Cuda(0));
/// ```
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)] // flags are independent toggles, not a state machine
pub struct InferenceConfig {
    /// Confidence threshold for detections (0.0 to 1.0).
    /// Detections with confidence scores lower than this value will be discarded.
    pub confidence_threshold: f32,
    /// Intersection over Union (`IoU`) threshold for Non-Maximum Suppression (NMS) (0.0 to 1.0).
    /// Used to merge overlapping boxes. Lower values filter more duplicates.
    pub iou_threshold: f32,
    /// Maximum number of detections to return per image.
    /// The top-k detections sorted by confidence will be returned.
    pub max_det: usize,
    /// Explicit input image size (height, width).
    /// If `None`, the model's metadata will be used to determine input size.
    pub imgsz: Option<(usize, usize)>,
    /// Batch size for inference when using [`BatchProcessor`](crate::batch::BatchProcessor).
    /// If `None`, defaults to 1 (single-image inference).
    pub batch: Option<usize>,
    /// Number of intra-op threads for ONNX Runtime.
    /// Setting this to `0` allows ONNX Runtime to choose the optimal number.
    pub num_threads: usize,
    /// Whether to use FP16 (half-precision) inference.
    /// This can improve performance on compatible hardware (e.g., GPUs) but may
    /// result in slight precision loss.
    pub half: bool,
    /// Hardware device to use for inference.
    /// If `None`, the best available device will be automatically selected.
    pub device: Option<crate::Device>,
    /// Whether to save annotated results.
    /// Defaults to `true`.
    pub save: bool,
    /// Whether to save individual frames instead of a video file when input is video.
    /// Defaults to `false` (save as video).
    pub save_frames: bool,
    /// Whether to use minimal padding (rectangular inference). Defaults to `true`.
    pub rect: bool,
    /// Class IDs to filter predictions. If `None`, all classes are returned.
    /// Useful for focusing on specific objects in multi-class detection tasks.
    pub classes: Option<Vec<usize>>,
}
75
76impl Default for InferenceConfig {
77 fn default() -> Self {
78 Self {
79 confidence_threshold: Self::DEFAULT_CONF,
80 iou_threshold: Self::DEFAULT_IOU,
81 max_det: Self::DEFAULT_MAX_DET,
82 imgsz: None,
83 batch: None,
84 num_threads: 0, // 0 = let ONNX Runtime decide (typically uses all cores efficiently)
85 half: Self::DEFAULT_HALF,
86 device: None,
87 save: Self::DEFAULT_SAVE,
88 save_frames: Self::DEFAULT_SAVE_FRAMES,
89 rect: Self::DEFAULT_RECT,
90 classes: None,
91 }
92 }
93}
94
95impl InferenceConfig {
96 /// Default confidence threshold (0.0 to 1.0).
97 pub const DEFAULT_CONF: f32 = 0.25;
98 /// Default `IoU` threshold for NMS (0.0 to 1.0).
99 pub const DEFAULT_IOU: f32 = 0.7;
100 /// Default maximum number of detections per image.
101 pub const DEFAULT_MAX_DET: usize = 300;
102 /// Default for FP16 half-precision inference.
103 pub const DEFAULT_HALF: bool = false;
104 /// Default for saving annotated results.
105 pub const DEFAULT_SAVE: bool = true;
106 /// Default for saving individual frames (vs video).
107 pub const DEFAULT_SAVE_FRAMES: bool = false;
108 /// Default for rectangular (minimal padding) inference.
109 pub const DEFAULT_RECT: bool = true;
110
111 /// Create a new configuration with default values.
112 ///
113 /// # Returns
114 ///
115 /// * A new `InferenceConfig` instance with default settings.
116 #[must_use]
117 pub fn new() -> Self {
118 Self::default()
119 }
120
121 /// Set the batch size.
122 ///
123 /// # Arguments
124 ///
125 /// * `batch` - The batch size.
126 ///
127 /// # Returns
128 ///
129 /// * The modified `InferenceConfig`.
130 #[must_use]
131 pub const fn with_batch(mut self, batch: usize) -> Self {
132 self.batch = Some(batch);
133 self
134 }
135
136 /// Set the confidence threshold.
137 ///
138 /// Detections with a confidence score below this threshold will be filtered out.
139 ///
140 /// # Arguments
141 ///
142 /// * `threshold` - The minimum confidence score (0.0 to 1.0).
143 ///
144 /// # Returns
145 ///
146 /// * The modified `InferenceConfig`.
147 #[must_use]
148 pub const fn with_confidence(mut self, threshold: f32) -> Self {
149 self.confidence_threshold = threshold;
150 self
151 }
152
153 /// Set the `IoU` threshold for Non-Maximum Suppression (NMS).
154 ///
155 /// NMS suppresses overlapping bounding boxes. This threshold determines how much overlap
156 /// is allowed before boxes are considered duplicates.
157 ///
158 /// # Arguments
159 ///
160 /// * `threshold` - The `IoU` threshold (0.0 to 1.0).
161 ///
162 /// # Returns
163 ///
164 /// * The modified `InferenceConfig`.
165 #[must_use]
166 pub const fn with_iou(mut self, threshold: f32) -> Self {
167 self.iou_threshold = threshold;
168 self
169 }
170
171 /// Set the maximum number of detections to return.
172 ///
173 /// Only the top `max` detections (sorted by confidence) will be kept after NMS.
174 ///
175 /// # Arguments
176 ///
177 /// * `max` - The maximum number of detections.
178 ///
179 /// # Returns
180 ///
181 /// * The modified `InferenceConfig`.
182 #[must_use]
183 pub const fn with_max_det(mut self, max: usize) -> Self {
184 self.max_det = max;
185 self
186 }
187
188 /// Set the input image size.
189 ///
190 /// This explicitly sets the size to resize images to before inference.
191 /// If not set, the model's internal metadata size will be used.
192 ///
193 /// # Arguments
194 ///
195 /// * `height` - The target image height.
196 /// * `width` - The target image width.
197 ///
198 /// # Returns
199 ///
200 /// * The modified `InferenceConfig`.
201 #[must_use]
202 pub const fn with_imgsz(mut self, height: usize, width: usize) -> Self {
203 self.imgsz = Some((height, width));
204 self
205 }
206
207 /// Set the number of threads for inference.
208 ///
209 /// # Arguments
210 ///
211 /// * `threads` - The number of intra-op threads. Set to `0` for auto-configuration.
212 ///
213 /// # Returns
214 ///
215 /// * The modified `InferenceConfig`.
216 #[must_use]
217 pub const fn with_threads(mut self, threads: usize) -> Self {
218 self.num_threads = threads;
219 self
220 }
221
222 /// Enable or disable FP16 (half-precision) inference.
223 ///
224 /// Using FP16 can significantly speed up inference on GPUs and some CPUS,
225 /// at the cost of potential minor precision loss.
226 ///
227 /// # Arguments
228 ///
229 /// * `half` - `true` to enable FP16, `false` for FP32.
230 ///
231 /// # Returns
232 ///
233 /// * The modified `InferenceConfig`.
234 #[must_use]
235 pub const fn with_half(mut self, half: bool) -> Self {
236 self.half = half;
237 self
238 }
239
240 /// Set the hardware device for inference.
241 ///
242 /// # Arguments
243 ///
244 /// * `device` - The device to use (e.g., CPU, CUDA, MPS).
245 ///
246 /// # Example
247 ///
248 /// ```rust
249 /// use ultralytics_inference::{InferenceConfig, Device};
250 ///
251 /// let config = InferenceConfig::new()
252 /// .with_device(Device::Mps); // Use Apple Metal Performance Shaders
253 /// ```
254 ///
255 /// # Returns
256 ///
257 /// * The modified `InferenceConfig`.
258 #[must_use]
259 pub const fn with_device(mut self, device: crate::Device) -> Self {
260 self.device = Some(device);
261 self
262 }
263
264 /// Set whether to save annotated results.
265 ///
266 /// # Arguments
267 ///
268 /// * `save` - `true` to save results, `false` to skip saving.
269 ///
270 /// # Returns
271 ///
272 /// * The modified `InferenceConfig`.
273 #[must_use]
274 pub const fn with_save(mut self, save: bool) -> Self {
275 self.save = save;
276 self
277 }
278
279 /// Set whether to save individual frames for video inputs.
280 ///
281 /// # Arguments
282 ///
283 /// * `save_frames` - `true` to save frames, `false` to save as video.
284 ///
285 /// # Returns
286 ///
287 /// * The modified `InferenceConfig`.
288 #[must_use]
289 pub const fn with_save_frames(mut self, save_frames: bool) -> Self {
290 self.save_frames = save_frames;
291 self
292 }
293
294 /// Set whether to use minimal padding (rectangular inference).
295 ///
296 /// # Arguments
297 ///
298 /// * `rect` - `true` to enable, `false` to disable.
299 ///
300 /// # Returns
301 ///
302 /// * The modified `InferenceConfig`.
303 #[must_use]
304 pub const fn with_rect(mut self, rect: bool) -> Self {
305 self.rect = rect;
306 self
307 }
308
309 /// Set the class IDs to filter predictions.
310 ///
311 /// Only detections belonging to the specified classes will be returned.
312 ///
313 /// # Arguments
314 ///
315 /// * `classes` - A vector of class IDs to keep.
316 ///
317 /// # Example
318 ///
319 /// ```rust
320 /// use ultralytics_inference::InferenceConfig;
321 ///
322 /// // Only detect persons (class 0) and cars (class 2)
323 /// let config = InferenceConfig::new()
324 /// .with_classes(vec![0, 2]);
325 /// ```
326 ///
327 /// # Returns
328 ///
329 /// * The modified `InferenceConfig`.
330 #[must_use]
331 pub fn with_classes(mut self, classes: Vec<usize>) -> Self {
332 self.classes = Some(classes);
333 self
334 }
335 /// Check if a class should be included in the results.
336 ///
337 /// # Arguments
338 ///
339 /// * `class_id` - The class index to check.
340 ///
341 /// # Returns
342 ///
343 /// * `true` if the class should be kept.
344 /// * `false` if the class should be filtered out.
345 #[must_use]
346 pub fn keep_class(&self, class_id: usize) -> bool {
347 self.classes.as_ref().is_none_or(|c| c.contains(&class_id))
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default field must match its documented default.
    #[test]
    fn test_config_default() {
        let config = InferenceConfig::default();
        assert!((config.confidence_threshold - InferenceConfig::DEFAULT_CONF).abs() < f32::EPSILON);
        assert!((config.iou_threshold - InferenceConfig::DEFAULT_IOU).abs() < f32::EPSILON);
        // Use the constant rather than a literal so this test tracks the source of truth.
        assert_eq!(config.max_det, InferenceConfig::DEFAULT_MAX_DET);
        assert_eq!(config.imgsz, None);
        assert_eq!(config.batch, None);
        assert_eq!(config.num_threads, 0);
        assert_eq!(config.half, InferenceConfig::DEFAULT_HALF);
        assert!(config.device.is_none());
        assert_eq!(config.save, InferenceConfig::DEFAULT_SAVE);
        assert_eq!(config.save_frames, InferenceConfig::DEFAULT_SAVE_FRAMES);
        assert_eq!(config.rect, InferenceConfig::DEFAULT_RECT);
        assert_eq!(config.classes, None);
    }

    /// Each builder method must store its value in the corresponding field.
    #[test]
    fn test_config_builder() {
        let config = InferenceConfig::new()
            .with_confidence(0.5)
            .with_iou(0.6)
            .with_max_det(300)
            .with_imgsz(640, 640)
            .with_threads(8)
            .with_batch(4)
            .with_half(true)
            .with_save(false)
            .with_save_frames(true)
            .with_rect(false);

        assert!((config.confidence_threshold - 0.5).abs() < f32::EPSILON);
        assert!((config.iou_threshold - 0.6).abs() < f32::EPSILON);
        assert_eq!(config.max_det, 300);
        assert_eq!(config.imgsz, Some((640, 640)));
        assert_eq!(config.num_threads, 8);
        assert_eq!(config.batch, Some(4));
        assert!(config.half);
        assert!(!config.save);
        assert!(config.save_frames);
        assert!(!config.rect);
    }

    /// `keep_class` keeps everything without a filter and respects the filter list.
    #[test]
    fn test_keep_class() {
        let config = InferenceConfig::default();
        assert!(config.keep_class(0));
        assert!(config.keep_class(100));

        let config_filtered = InferenceConfig::new().with_classes(vec![1, 3]);
        assert_eq!(config_filtered.classes, Some(vec![1, 3]));
        assert!(config_filtered.keep_class(1));
        assert!(config_filtered.keep_class(3));
        assert!(!config_filtered.keep_class(0));
        assert!(!config_filtered.keep_class(2));

        // An empty filter list rejects every class.
        let config_empty = InferenceConfig::new().with_classes(Vec::new());
        assert!(!config_empty.keep_class(0));
    }
}
391}