// inference/lib.rs

//! Detector factory and runtime configuration for CortenForge inference.
//!
//! This crate bridges the `models` crate (Burn modules) and the
//! `vision_core::interfaces::Detector` trait. The `InferenceFactory` loads weights and
//! constructs detector implementations that can be used in vision pipelines.
//!
//! ## Backend Selection
//! - `backend-wgpu` feature: uses WGPU for GPU acceleration (recommended for production).
//! - Default: falls back to the NdArray CPU backend.
//!
//! ## Model Selection
//! - `convolutional_detector` feature: uses `MultiboxModel` for multi-box detection.
//! - Default: uses `LinearClassifier` for binary classification.
//!
//! The type aliases `InferenceBackend`, `InferenceModel`, and `InferenceModelConfig` adapt to
//! the selected features.

17#![recursion_limit = "256"]
18
19pub mod factory;
20
21#[cfg(feature = "backend-wgpu")]
22pub type InferenceBackend = burn_wgpu::Wgpu<f32>;
23#[cfg(not(feature = "backend-wgpu"))]
24pub type InferenceBackend = burn_ndarray::NdArray<f32>;
25
26#[cfg(feature = "convolutional_detector")]
27pub type InferenceModel<B> = models::MultiboxModel<B>;
28#[cfg(feature = "convolutional_detector")]
29pub type InferenceModelConfig = models::MultiboxModelConfig;
30#[cfg(not(feature = "convolutional_detector"))]
31pub type InferenceModel<B> = models::LinearClassifier<B>;
32#[cfg(not(feature = "convolutional_detector"))]
33pub type InferenceModelConfig = models::LinearClassifierConfig;
34
35pub use factory::{InferenceFactory, InferenceThresholds};
36
37pub mod prelude {
38    pub use crate::factory::{InferenceFactory, InferenceThresholds};
39    pub use crate::{InferenceBackend, InferenceModel, InferenceModelConfig};
40}
41
#[cfg(test)]
mod tests {
    use super::*;
    use vision_core::interfaces::Frame;

    /// `InferenceFactory::build` called with `None` as its second argument (presumably the
    /// optional weights location) must not panic. The detector it returns must still process a
    /// frame and tag the result with that frame's id.
    #[test]
    fn inference_factory_falls_back_without_weights() {
        let factory = InferenceFactory;
        let mut detector = factory.build(InferenceThresholds::default(), None);

        // Minimal 1x1 frame with no pixel data and no source path.
        let frame = Frame {
            id: 0,
            timestamp: 0.0,
            rgba: None,
            size: (1, 1),
            path: None,
        };

        let result = detector.detect(&frame);
        // `assert_eq!` reports the actual id on failure, unlike `assert!(a == b)`.
        assert_eq!(result.frame_id, 0);
    }
}