inference/
factory.rs

1use crate::{InferenceBackend, InferenceModel, InferenceModelConfig};
2use burn::module::Module;
3use burn::tensor::TensorData;
4#[cfg(feature = "bigdet")]
5use data_contracts::preprocess::{stats_from_rgba_u8, ImageStats};
6use std::path::Path;
7use std::sync::{Arc, Mutex};
8use vision_core::interfaces::{DetectionResult, Detector, Frame};
9
/// Thresholds for inference (objectness + IoU).
#[derive(Debug, Clone, Copy)]
pub struct InferenceThresholds {
    /// Minimum objectness score for a detection to be reported positive.
    pub obj_thresh: f32,
    /// IoU threshold for overlap-based filtering.
    pub iou_thresh: f32,
}

impl Default for InferenceThresholds {
    /// Default thresholds: 0.3 objectness, 0.5 IoU.
    fn default() -> Self {
        InferenceThresholds {
            obj_thresh: 0.3,
            iou_thresh: 0.5,
        }
    }
}
25
/// Heuristic detector placeholder; used when Burn weights are unavailable.
struct HeuristicDetector {
    // Objectness threshold; the placeholder also reports this value as its
    // fixed confidence (see the `Detector` impl below).
    obj_thresh: f32,
}
30
31impl Detector for HeuristicDetector {
32    fn detect(&mut self, frame: &Frame) -> DetectionResult {
33        let confidence = self.obj_thresh;
34        DetectionResult {
35            frame_id: frame.id,
36            positive: confidence >= self.obj_thresh,
37            confidence,
38            boxes: Vec::new(),
39            scores: Vec::new(),
40        }
41    }
42}
43
/// Detector backed by a Burn model checkpoint.
struct BurnTinyDetDetector {
    // Model behind Arc<Mutex<..>> so the boxed detector can be Send + Sync.
    model: Arc<Mutex<InferenceModel<InferenceBackend>>>,
    // Minimum confidence for a positive detection.
    obj_thresh: f32,
    // Currently unused by `detect` (hence the allow); presumably reserved
    // for future IoU/NMS filtering — TODO confirm.
    #[allow(dead_code)]
    iou_thresh: f32,
}
50
51impl BurnTinyDetDetector {
52    fn frame_to_tensor(&self, _frame: &Frame) -> TensorData {
53        #[cfg(feature = "bigdet")]
54        {
55            let (w, h) = _frame.size;
56            let stats = if let Some(rgba) = &_frame.rgba {
57                stats_from_rgba_u8(w, h, rgba).unwrap_or_else(|_| ImageStats {
58                    mean: [0.0; 3],
59                    std: [0.0; 3],
60                    aspect: w as f32 / h as f32,
61                })
62            } else {
63                ImageStats {
64                    mean: [0.0; 3],
65                    std: [0.0; 3],
66                    aspect: w as f32 / h as f32,
67                }
68            };
69
70            let mut input = Vec::with_capacity(12);
71            input.extend_from_slice(&[0.0, 0.0, 1.0, 1.0]);
72            input.extend_from_slice(&stats.feature_vector(0.0));
73            TensorData::new(input, [1, 12])
74        }
75        #[cfg(not(feature = "bigdet"))]
76        {
77            TensorData::new(vec![0.0, 0.0, 1.0, 1.0], [1, 4])
78        }
79    }
80}
81
82impl Detector for BurnTinyDetDetector {
83    fn detect(&mut self, frame: &Frame) -> DetectionResult {
84        let input = self.frame_to_tensor(frame);
85        let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
86        let model = self.model.lock().expect("model mutex poisoned");
87        let logits = model.forward(burn::tensor::Tensor::<InferenceBackend, 2>::from_data(
88            input, &device,
89        ));
90        let scores = logits.into_data().to_vec::<f32>().unwrap_or_default();
91        let confidence = scores.first().copied().unwrap_or(0.0);
92        DetectionResult {
93            frame_id: frame.id,
94            positive: confidence >= self.obj_thresh,
95            confidence,
96            boxes: Vec::new(),
97            scores,
98        }
99    }
100}
101
/// Factory for detectors: loads a Burn checkpoint when a usable weights
/// path is provided, otherwise falls back to the heuristic placeholder.
pub struct InferenceFactory;
104
105impl InferenceFactory {
106    pub fn build(
107        &self,
108        thresh: InferenceThresholds,
109        weights: Option<&Path>,
110    ) -> Box<dyn vision_core::interfaces::Detector + Send + Sync> {
111        if let Some(det) = self.try_load_burn_detector(thresh, weights) {
112            return det;
113        }
114        eprintln!("InferenceFactory: no valid checkpoint provided; using heuristic detector.");
115        Box::new(HeuristicDetector {
116            obj_thresh: thresh.obj_thresh,
117        })
118    }
119
120    fn try_load_burn_detector(
121        &self,
122        thresh: InferenceThresholds,
123        weights: Option<&Path>,
124    ) -> Option<Box<dyn vision_core::interfaces::Detector + Send + Sync>> {
125        let path = weights?;
126        if !path.exists() {
127            return None;
128        }
129        let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
130        let recorder = burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new();
131        #[cfg(feature = "bigdet")]
132        let config = InferenceModelConfig {
133            input_dim: Some(4 + 8),
134            ..Default::default()
135        };
136        #[cfg(not(feature = "bigdet"))]
137        let config = InferenceModelConfig::default();
138
139        match InferenceModel::<InferenceBackend>::new(config, &device)
140            .load_file(path, &recorder, &device)
141        {
142            Ok(model) => Some(Box::new(BurnTinyDetDetector {
143                model: Arc::new(Mutex::new(model)),
144                obj_thresh: thresh.obj_thresh,
145                iou_thresh: thresh.iou_thresh,
146            })),
147            Err(err) => {
148                eprintln!(
149                    "Failed to load detector checkpoint {:?}: {err}. Falling back to heuristic.",
150                    path
151                );
152                None
153            }
154        }
155    }
156}