/// Settings for running inference, typically assembled with the `with_*` builder methods.
///
/// Start from [`InferenceConfig::new`] or [`Default::default`]. Each field's default is
/// either one of the `DEFAULT_*` associated constants, `0` (`num_threads`), or `None`.
#[derive(Debug, Clone)]
// `half`, `save`, `save_frames` and `rect` are independent on/off flags, which trips
// clippy's `struct_excessive_bools` lint; the lint is silenced here deliberately.
#[allow(clippy::struct_excessive_bools)]
pub struct InferenceConfig {
    /// Confidence threshold; defaults to [`Self::DEFAULT_CONF`].
    /// NOTE(review): presumably results scoring below this are discarded — confirm in the predictor.
    pub confidence_threshold: f32,
    /// IoU threshold; defaults to [`Self::DEFAULT_IOU`].
    /// NOTE(review): presumably the overlap cutoff used during NMS — confirm.
    pub iou_threshold: f32,
    /// Maximum number of detections; defaults to [`Self::DEFAULT_MAX_DET`].
    pub max_det: usize,
    /// Input size stored as `(height, width)` (height first, see `with_imgsz`); `None` by default.
    pub imgsz: Option<(usize, usize)>,
    /// Batch size; `None` by default.
    pub batch: Option<usize>,
    /// Thread count; defaults to `0`.
    /// NOTE(review): `0` presumably means "let the runtime decide" — confirm.
    pub num_threads: usize,
    /// Half flag; defaults to [`Self::DEFAULT_HALF`].
    /// NOTE(review): presumably requests half-precision (FP16) inference — confirm.
    pub half: bool,
    /// Device to run on; `None` by default.
    pub device: Option<crate::Device>,
    /// Save flag; defaults to [`Self::DEFAULT_SAVE`].
    pub save: bool,
    /// Save-frames flag; defaults to [`Self::DEFAULT_SAVE_FRAMES`].
    pub save_frames: bool,
    /// Rect flag; defaults to [`Self::DEFAULT_RECT`].
    /// NOTE(review): presumably enables rectangular (non-square) input handling — confirm.
    pub rect: bool,
    /// Class-id allow-list. `None` (the default) keeps every class; `Some(list)` keeps only
    /// the ids in `list`, so an empty list keeps nothing (see `keep_class`).
    pub classes: Option<Vec<usize>>,
}
impl Default for InferenceConfig {
    /// Returns the baseline configuration.
    ///
    /// Scalar settings come from the `DEFAULT_*` associated constants, `num_threads`
    /// starts at `0`, and every optional setting (`imgsz`, `batch`, `device`, `classes`)
    /// starts unset.
    fn default() -> Self {
        Self {
            // Thresholds and limits.
            confidence_threshold: Self::DEFAULT_CONF,
            iou_threshold: Self::DEFAULT_IOU,
            max_det: Self::DEFAULT_MAX_DET,
            num_threads: 0,
            // On/off switches.
            half: Self::DEFAULT_HALF,
            save: Self::DEFAULT_SAVE,
            save_frames: Self::DEFAULT_SAVE_FRAMES,
            rect: Self::DEFAULT_RECT,
            // Optional settings, unset until a builder method fills them in.
            imgsz: None,
            batch: None,
            device: None,
            classes: None,
        }
    }
}
impl InferenceConfig {
    /// Default for `confidence_threshold`.
    pub const DEFAULT_CONF: f32 = 0.25;
    /// Default for `iou_threshold`.
    pub const DEFAULT_IOU: f32 = 0.7;
    /// Default for `max_det`.
    pub const DEFAULT_MAX_DET: usize = 300;
    /// Default for `half`.
    pub const DEFAULT_HALF: bool = false;
    /// Default for `save`.
    pub const DEFAULT_SAVE: bool = true;
    /// Default for `save_frames`.
    pub const DEFAULT_SAVE_FRAMES: bool = false;
    /// Default for `rect`.
    pub const DEFAULT_RECT: bool = true;

    /// Creates a configuration with every field at its default.
    ///
    /// Equivalent to [`Default::default`]; this is the conventional start of a
    /// `with_*` builder chain.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the batch size (stored as `Some(batch)`).
    #[must_use]
    pub const fn with_batch(mut self, batch: usize) -> Self {
        self.batch = Some(batch);
        self
    }

    /// Sets `confidence_threshold`. The value is stored as-is; no range check is done.
    #[must_use]
    pub const fn with_confidence(mut self, threshold: f32) -> Self {
        self.confidence_threshold = threshold;
        self
    }

    /// Sets `iou_threshold`. The value is stored as-is; no range check is done.
    #[must_use]
    pub const fn with_iou(mut self, threshold: f32) -> Self {
        self.iou_threshold = threshold;
        self
    }

    /// Sets `max_det`, the maximum number of detections.
    #[must_use]
    pub const fn with_max_det(mut self, max: usize) -> Self {
        self.max_det = max;
        self
    }

    /// Sets `imgsz` to `(height, width)` — note the height-first argument order.
    #[must_use]
    pub const fn with_imgsz(mut self, height: usize, width: usize) -> Self {
        self.imgsz = Some((height, width));
        self
    }

    /// Sets `num_threads` (the default configuration uses `0`).
    ///
    /// NOTE(review): `0` presumably means "let the runtime decide" — confirm.
    #[must_use]
    pub const fn with_threads(mut self, threads: usize) -> Self {
        self.num_threads = threads;
        self
    }

    /// Sets the `half` flag.
    ///
    /// NOTE(review): presumably toggles half-precision (FP16) inference — confirm.
    #[must_use]
    pub const fn with_half(mut self, half: bool) -> Self {
        self.half = half;
        self
    }

    /// Selects the device to run on (stored as `Some(device)`).
    #[must_use]
    pub const fn with_device(mut self, device: crate::Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Sets the `save` flag.
    #[must_use]
    pub const fn with_save(mut self, save: bool) -> Self {
        self.save = save;
        self
    }

    /// Sets the `save_frames` flag.
    #[must_use]
    pub const fn with_save_frames(mut self, save_frames: bool) -> Self {
        self.save_frames = save_frames;
        self
    }

    /// Sets the `rect` flag.
    #[must_use]
    pub const fn with_rect(mut self, rect: bool) -> Self {
        self.rect = rect;
        self
    }

    /// Sets the class allow-list consulted by [`Self::keep_class`].
    ///
    /// Accepts any iterable of class ids: a `Vec<usize>` as before, or an array
    /// (`[0, 2]`), a range (`0..5`) or an iterator. The ids are stored in iteration order.
    /// An empty iterable produces a filter that rejects every class. To keep all classes,
    /// leave `classes` as `None` (the default) instead of calling this method.
    #[must_use]
    pub fn with_classes(mut self, classes: impl IntoIterator<Item = usize>) -> Self {
        // For a `Vec` argument, std typically performs this collect in place, so existing
        // callers should not pay an extra allocation.
        self.classes = Some(classes.into_iter().collect());
        self
    }

    /// Returns `true` if `class_id` passes the class filter.
    ///
    /// With no filter (`classes` is `None`) every id passes. Otherwise only ids present
    /// in the list pass, found by a linear scan over that list.
    #[must_use]
    pub fn keep_class(&self, class_id: usize) -> bool {
        self.classes.as_ref().is_none_or(|c| c.contains(&class_id))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default config reports the published `DEFAULT_*` values and leaves every
    /// optional setting unset.
    #[test]
    fn test_config_default() {
        let config = InferenceConfig::default();
        assert!((config.confidence_threshold - InferenceConfig::DEFAULT_CONF).abs() < f32::EPSILON);
        assert!((config.iou_threshold - InferenceConfig::DEFAULT_IOU).abs() < f32::EPSILON);
        assert_eq!(config.max_det, InferenceConfig::DEFAULT_MAX_DET);
        assert_eq!(config.imgsz, None);
        assert_eq!(config.batch, None);
        assert_eq!(config.num_threads, 0);
        assert_eq!(config.half, InferenceConfig::DEFAULT_HALF);
        assert!(config.device.is_none());
        assert_eq!(config.save, InferenceConfig::DEFAULT_SAVE);
        assert_eq!(config.save_frames, InferenceConfig::DEFAULT_SAVE_FRAMES);
        assert_eq!(config.rect, InferenceConfig::DEFAULT_RECT);
        assert_eq!(config.classes, None);
    }

    /// Every setter writes its own field.
    ///
    /// All values differ from the defaults, so a setter that silently does nothing
    /// cannot pass. `imgsz` is non-square, so a swapped `(height, width)` order is caught.
    #[test]
    fn test_config_builder() {
        let config = InferenceConfig::new()
            .with_confidence(0.5)
            .with_iou(0.6)
            .with_max_det(100)
            .with_imgsz(480, 640)
            .with_threads(8)
            .with_batch(4)
            .with_half(!InferenceConfig::DEFAULT_HALF)
            .with_save(!InferenceConfig::DEFAULT_SAVE)
            .with_save_frames(!InferenceConfig::DEFAULT_SAVE_FRAMES)
            .with_rect(!InferenceConfig::DEFAULT_RECT);
        assert!((config.confidence_threshold - 0.5).abs() < f32::EPSILON);
        assert!((config.iou_threshold - 0.6).abs() < f32::EPSILON);
        assert_eq!(config.max_det, 100);
        assert_eq!(config.imgsz, Some((480, 640)));
        assert_eq!(config.num_threads, 8);
        assert_eq!(config.batch, Some(4));
        assert_eq!(config.half, !InferenceConfig::DEFAULT_HALF);
        assert_eq!(config.save, !InferenceConfig::DEFAULT_SAVE);
        assert_eq!(config.save_frames, !InferenceConfig::DEFAULT_SAVE_FRAMES);
        assert_eq!(config.rect, !InferenceConfig::DEFAULT_RECT);
        // Setters that were not called leave their fields untouched.
        assert!(config.device.is_none());
        assert_eq!(config.classes, None);
    }

    /// `None` keeps every class; a list keeps only its members.
    #[test]
    fn test_keep_class() {
        let config = InferenceConfig::default();
        assert!(config.keep_class(0));
        assert!(config.keep_class(100));

        let config_filtered = InferenceConfig::new().with_classes(vec![1, 3]);
        assert!(config_filtered.keep_class(1));
        assert!(config_filtered.keep_class(3));
        assert!(!config_filtered.keep_class(0));
        assert!(!config_filtered.keep_class(2));

        // An explicit empty allow-list matches nothing; it is not the same as "no filter".
        let config_empty = InferenceConfig::new().with_classes(Vec::<usize>::new());
        assert!(!config_empty.keep_class(0));
        assert!(!config_empty.keep_class(1));
    }
}