inference/
factory.rs

1use crate::{InferenceBackend, InferenceModel, InferenceModelConfig};
2use burn::module::Module;
3use burn::tensor::TensorData;
4use std::path::Path;
5use std::sync::{Arc, Mutex};
6use vision_core::interfaces::{DetectionResult, Detector, Frame};
7
/// Thresholds for inference (objectness + IoU).
#[derive(Debug, Clone, Copy)]
pub struct InferenceThresholds {
    /// Minimum confidence score for a detection to be reported as positive.
    pub obj_thresh: f32,
    /// IoU threshold; presumably intended for NMS — currently marked
    /// dead_code at its only use site (BurnTinyDetDetector), so unused.
    pub iou_thresh: f32,
}
14
15impl Default for InferenceThresholds {
16    fn default() -> Self {
17        Self {
18            obj_thresh: 0.3,
19            iou_thresh: 0.5,
20        }
21    }
22}
23
/// Heuristic detector placeholder; used when Burn weights are unavailable.
struct HeuristicDetector {
    // Objectness threshold; detect() echoes this value back as the confidence.
    obj_thresh: f32,
}
28
29impl Detector for HeuristicDetector {
30    fn detect(&mut self, frame: &Frame) -> DetectionResult {
31        let confidence = self.obj_thresh;
32        DetectionResult {
33            frame_id: frame.id,
34            positive: confidence >= self.obj_thresh,
35            confidence,
36            boxes: Vec::new(),
37            scores: Vec::new(),
38        }
39    }
40}
41
/// Detector backed by a loaded Burn `InferenceModel` checkpoint.
struct BurnTinyDetDetector {
    // Lock-guarded shared model so the boxed detector can be Send + Sync.
    model: Arc<Mutex<InferenceModel<InferenceBackend>>>,
    // Minimum first-score value for a positive detection.
    obj_thresh: f32,
    // Stored but not yet read; presumably reserved for future NMS.
    #[allow(dead_code)]
    iou_thresh: f32,
}
48
49impl BurnTinyDetDetector {
50    fn frame_to_tensor(&self, frame: &Frame) -> TensorData {
51        let (w, h) = frame.size;
52        if let Some(rgba) = &frame.rgba {
53            let mut mean = [0f32; 3];
54            let mut count = 0usize;
55            for chunk in rgba.chunks_exact(4) {
56                mean[0] += chunk[0] as f32;
57                mean[1] += chunk[1] as f32;
58                mean[2] += chunk[2] as f32;
59                count += 1;
60            }
61            if count > 0 {
62                mean[0] /= count as f32 * 255.0;
63                mean[1] /= count as f32 * 255.0;
64                mean[2] /= count as f32 * 255.0;
65            }
66            TensorData::new(vec![mean[0], mean[1], mean[2], w as f32 / h as f32], [1, 4])
67        } else {
68            TensorData::new(vec![0.0, 0.0, 0.0, w as f32 / h as f32], [1, 4])
69        }
70    }
71}
72
impl Detector for BurnTinyDetDetector {
    /// Runs the Burn model on the frame's summary tensor and thresholds the
    /// first output score against `obj_thresh` to decide positivity.
    fn detect(&mut self, frame: &Frame) -> DetectionResult {
        let input = self.frame_to_tensor(frame);
        let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
        // A poisoned mutex means another thread panicked mid-inference and
        // model state is suspect, so panicking here is deliberate.
        let model = self.model.lock().expect("model mutex poisoned");
        let logits = model.forward(burn::tensor::Tensor::<InferenceBackend, 2>::from_data(
            input, &device,
        ));
        // NOTE(review): a to_vec failure is silently mapped to an empty score
        // list, which yields confidence 0.0 and a negative result below —
        // confirm this silent fallback is intended rather than an error path.
        let scores = logits.into_data().to_vec::<f32>().unwrap_or_default();
        let confidence = scores.first().copied().unwrap_or(0.0);
        DetectionResult {
            frame_id: frame.id,
            positive: confidence >= self.obj_thresh,
            confidence,
            boxes: Vec::new(),
            scores,
        }
    }
}
92
/// Factory for detectors: loads a Burn checkpoint when a valid weights path
/// is supplied, otherwise falls back to the heuristic placeholder.
pub struct InferenceFactory;
95
96impl InferenceFactory {
97    pub fn build(
98        &self,
99        thresh: InferenceThresholds,
100        weights: Option<&Path>,
101    ) -> Box<dyn vision_core::interfaces::Detector + Send + Sync> {
102        if let Some(det) = self.try_load_burn_detector(thresh, weights) {
103            return det;
104        }
105        eprintln!("InferenceFactory: no valid checkpoint provided; using heuristic detector.");
106        Box::new(HeuristicDetector {
107            obj_thresh: thresh.obj_thresh,
108        })
109    }
110
111    fn try_load_burn_detector(
112        &self,
113        thresh: InferenceThresholds,
114        weights: Option<&Path>,
115    ) -> Option<Box<dyn vision_core::interfaces::Detector + Send + Sync>> {
116        let path = weights?;
117        if !path.exists() {
118            return None;
119        }
120        let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
121        let recorder = burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new();
122        match InferenceModel::<InferenceBackend>::new(InferenceModelConfig::default(), &device)
123            .load_file(path, &recorder, &device)
124        {
125            Ok(model) => Some(Box::new(BurnTinyDetDetector {
126                model: Arc::new(Mutex::new(model)),
127                obj_thresh: thresh.obj_thresh,
128                iou_thresh: thresh.iou_thresh,
129            })),
130            Err(err) => {
131                eprintln!(
132                    "Failed to load detector checkpoint {:?}: {err}. Falling back to heuristic.",
133                    path
134                );
135                None
136            }
137        }
138    }
139}