inference/
factory.rs

1use crate::{InferenceBackend, InferenceModel, InferenceModelConfig};
2use burn::module::Module;
3use burn::tensor::TensorData;
4#[cfg(feature = "convolutional_detector")]
5use data_contracts::preprocess::{stats_from_rgba_u8, ImageStats};
6use std::path::Path;
7use std::sync::{Arc, Mutex};
8use vision_core::interfaces::{DetectionResult, Detector, Frame};
9
/// Score cut-offs handed to detectors: an objectness threshold and an IoU threshold.
///
/// This is a framework-agnostic type. For Bevy applications, use
/// `InferenceThresholdsResource` from the `vision_runtime` crate.
///
/// `PartialEq` is derived so that threshold sets can be compared directly
/// (for example in tests or change detection). `Eq` is intentionally absent
/// because the fields are floats.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct InferenceThresholds {
    /// Minimum confidence at which a detector reports a frame as positive
    /// (detectors compare with `confidence >= objectness_threshold`).
    pub objectness_threshold: f32,
    /// Intersection-over-union threshold.
    ///
    /// NOTE(review): presumably intended for box suppression; the detectors
    /// in this module store it but do not read it yet — confirm intended use.
    pub iou_threshold: f32,
}
19
20impl Default for InferenceThresholds {
21    fn default() -> Self {
22        Self {
23            objectness_threshold: 0.3,
24            iou_threshold: 0.5,
25        }
26    }
27}
28
/// Stand-in detector that performs no real inference.
///
/// `InferenceFactory::build` falls back to it whenever a Burn checkpoint
/// cannot be loaded.
struct HeuristicDetector {
    /// Objectness threshold; also echoed back as the reported confidence.
    obj_thresh: f32,
}
33
34impl Detector for HeuristicDetector {
35    fn detect(&mut self, frame: &Frame) -> DetectionResult {
36        let confidence = self.obj_thresh;
37        DetectionResult {
38            frame_id: frame.id,
39            positive: confidence >= self.obj_thresh,
40            confidence,
41            boxes: Vec::new(),
42            scores: Vec::new(),
43        }
44    }
45}
46
47struct BurnDetector {
48    model: Arc<Mutex<InferenceModel<InferenceBackend>>>,
49    obj_thresh: f32,
50    #[allow(dead_code)]
51    iou_thresh: f32,
52}
53
54impl BurnDetector {
55    fn frame_to_tensor(&self, _frame: &Frame) -> TensorData {
56        #[cfg(feature = "convolutional_detector")]
57        {
58            let (w, h) = _frame.size;
59            let stats = if let Some(rgba) = &_frame.rgba {
60                stats_from_rgba_u8(w, h, rgba).unwrap_or_else(|_| ImageStats {
61                    mean: [0.0; 3],
62                    std: [0.0; 3],
63                    aspect: w as f32 / h as f32,
64                })
65            } else {
66                ImageStats {
67                    mean: [0.0; 3],
68                    std: [0.0; 3],
69                    aspect: w as f32 / h as f32,
70                }
71            };
72
73            let mut input = Vec::with_capacity(12);
74            input.extend_from_slice(&[0.0, 0.0, 1.0, 1.0]);
75            input.extend_from_slice(&stats.feature_vector(0.0));
76            TensorData::new(input, [1, 12])
77        }
78        #[cfg(not(feature = "convolutional_detector"))]
79        {
80            TensorData::new(vec![0.0, 0.0, 1.0, 1.0], [1, 4])
81        }
82    }
83}
84
85impl Detector for BurnDetector {
86    fn detect(&mut self, frame: &Frame) -> DetectionResult {
87        let input = self.frame_to_tensor(frame);
88        let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
89        let model = self.model.lock().expect("model mutex poisoned");
90        let logits = model.forward(burn::tensor::Tensor::<InferenceBackend, 2>::from_data(
91            input, &device,
92        ));
93        let scores = logits.into_data().to_vec::<f32>().unwrap_or_default();
94        let confidence = scores.first().copied().unwrap_or(0.0);
95        DetectionResult {
96            frame_id: frame.id,
97            positive: confidence >= self.obj_thresh,
98            confidence,
99            boxes: Vec::new(),
100            scores,
101        }
102    }
103}
104
/// Factory for detectors.
///
/// [`InferenceFactory::build`] returns a Burn-backed detector when a
/// checkpoint can be loaded and falls back to the heuristic placeholder
/// otherwise. The type is stateless, so it is cheap to copy and has a
/// trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct InferenceFactory;
107
108impl InferenceFactory {
109    pub fn build(
110        &self,
111        thresh: InferenceThresholds,
112        weights: Option<&Path>,
113    ) -> Box<dyn vision_core::interfaces::Detector + Send + Sync> {
114        if let Some(det) = self.try_load_burn_detector(thresh, weights) {
115            return det;
116        }
117        eprintln!("InferenceFactory: no valid checkpoint provided; using heuristic detector.");
118        Box::new(HeuristicDetector {
119            obj_thresh: thresh.objectness_threshold,
120        })
121    }
122
123    fn try_load_burn_detector(
124        &self,
125        thresh: InferenceThresholds,
126        weights: Option<&Path>,
127    ) -> Option<Box<dyn vision_core::interfaces::Detector + Send + Sync>> {
128        let path = weights?;
129        if !path.exists() {
130            return None;
131        }
132        let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
133        let recorder = burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new();
134        #[cfg(feature = "convolutional_detector")]
135        let config = InferenceModelConfig {
136            input_dim: Some(4 + 8),
137            ..Default::default()
138        };
139        #[cfg(not(feature = "convolutional_detector"))]
140        let config = InferenceModelConfig::default();
141
142        match InferenceModel::<InferenceBackend>::new(config, &device)
143            .load_file(path, &recorder, &device)
144        {
145            Ok(model) => Some(Box::new(BurnDetector {
146                model: Arc::new(Mutex::new(model)),
147                obj_thresh: thresh.objectness_threshold,
148                iou_thresh: thresh.iou_threshold,
149            })),
150            Err(err) => {
151                eprintln!(
152                    "Failed to load detector checkpoint {:?}: {err}. Falling back to heuristic.",
153                    path
154                );
155                None
156            }
157        }
158    }
159}