1#![recursion_limit = "256"]
18
19pub mod factory;
20
21#[cfg(feature = "backend-wgpu")]
22pub type InferenceBackend = burn_wgpu::Wgpu<f32>;
23#[cfg(not(feature = "backend-wgpu"))]
24pub type InferenceBackend = burn_ndarray::NdArray<f32>;
25
26#[cfg(feature = "convolutional_detector")]
27pub type InferenceModel<B> = models::MultiboxModel<B>;
28#[cfg(feature = "convolutional_detector")]
29pub type InferenceModelConfig = models::MultiboxModelConfig;
30#[cfg(not(feature = "convolutional_detector"))]
31pub type InferenceModel<B> = models::LinearClassifier<B>;
32#[cfg(not(feature = "convolutional_detector"))]
33pub type InferenceModelConfig = models::LinearClassifierConfig;
34
35pub use factory::{InferenceFactory, InferenceThresholds};
36
37pub mod prelude {
38 pub use crate::factory::{InferenceFactory, InferenceThresholds};
39 pub use crate::{InferenceBackend, InferenceModel, InferenceModelConfig};
40}
41
#[cfg(test)]
mod tests {
    use super::*;

    /// Building a detector with no weights argument (`None`) must still yield
    /// a detector that can process a frame; the result's `frame_id` is
    /// expected to be 0 for the frame with id 0 used here.
    #[test]
    fn inference_factory_falls_back_without_weights() {
        let factory = InferenceFactory;
        let mut detector = factory.build(InferenceThresholds::default(), None);

        // Minimal input: 1x1 size, no RGBA data, no backing path.
        let frame = vision_core::interfaces::Frame {
            id: 0,
            timestamp: 0.0,
            rgba: None,
            size: (1, 1),
            path: None,
        };
        let result = detector.detect(&frame);

        // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(result.frame_id, 0);
    }
}