1use crate::{InferenceBackend, InferenceModel, InferenceModelConfig};
2use burn::module::Module;
3use burn::tensor::TensorData;
4#[cfg(feature = "convolutional_detector")]
5use data_contracts::preprocess::{stats_from_rgba_u8, ImageStats};
6use std::path::Path;
7use std::sync::{Arc, Mutex};
8use vision_core::interfaces::{DetectionResult, Detector, Frame};
9
/// Score thresholds controlling how detections are accepted.
#[derive(Debug, Clone, Copy)]
pub struct InferenceThresholds {
    /// Minimum objectness score for a frame to be flagged positive.
    pub objectness_threshold: f32,
    /// IoU threshold; stored on the Burn detector but not yet applied.
    pub iou_threshold: f32,
}

impl Default for InferenceThresholds {
    /// Defaults: objectness 0.3, IoU 0.5.
    fn default() -> Self {
        InferenceThresholds {
            objectness_threshold: 0.3,
            iou_threshold: 0.5,
        }
    }
}
28
/// Fallback detector used when no model checkpoint can be loaded.
struct HeuristicDetector {
    // Copied from InferenceThresholds::objectness_threshold at build time.
    obj_thresh: f32,
}
33
34impl Detector for HeuristicDetector {
35 fn detect(&mut self, frame: &Frame) -> DetectionResult {
36 let confidence = self.obj_thresh;
37 DetectionResult {
38 frame_id: frame.id,
39 positive: confidence >= self.obj_thresh,
40 confidence,
41 boxes: Vec::new(),
42 scores: Vec::new(),
43 }
44 }
45}
46
/// Detector backed by a Burn model loaded from a checkpoint file.
struct BurnDetector {
    // Lock-guarded shared model so the boxed detector can be Send + Sync.
    model: Arc<Mutex<InferenceModel<InferenceBackend>>>,
    // Minimum first-logit score for a positive detection.
    obj_thresh: f32,
    // Stored for interface parity with InferenceThresholds; not read yet,
    // hence the dead_code allow.
    #[allow(dead_code)]
    iou_thresh: f32,
}
53
54impl BurnDetector {
55 fn frame_to_tensor(&self, _frame: &Frame) -> TensorData {
56 #[cfg(feature = "convolutional_detector")]
57 {
58 let (w, h) = _frame.size;
59 let stats = if let Some(rgba) = &_frame.rgba {
60 stats_from_rgba_u8(w, h, rgba).unwrap_or_else(|_| ImageStats {
61 mean: [0.0; 3],
62 std: [0.0; 3],
63 aspect: w as f32 / h as f32,
64 })
65 } else {
66 ImageStats {
67 mean: [0.0; 3],
68 std: [0.0; 3],
69 aspect: w as f32 / h as f32,
70 }
71 };
72
73 let mut input = Vec::with_capacity(12);
74 input.extend_from_slice(&[0.0, 0.0, 1.0, 1.0]);
75 input.extend_from_slice(&stats.feature_vector(0.0));
76 TensorData::new(input, [1, 12])
77 }
78 #[cfg(not(feature = "convolutional_detector"))]
79 {
80 TensorData::new(vec![0.0, 0.0, 1.0, 1.0], [1, 4])
81 }
82 }
83}
84
85impl Detector for BurnDetector {
86 fn detect(&mut self, frame: &Frame) -> DetectionResult {
87 let input = self.frame_to_tensor(frame);
88 let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
89 let model = self.model.lock().expect("model mutex poisoned");
90 let logits = model.forward(burn::tensor::Tensor::<InferenceBackend, 2>::from_data(
91 input, &device,
92 ));
93 let scores = logits.into_data().to_vec::<f32>().unwrap_or_default();
94 let confidence = scores.first().copied().unwrap_or(0.0);
95 DetectionResult {
96 frame_id: frame.id,
97 positive: confidence >= self.obj_thresh,
98 confidence,
99 boxes: Vec::new(),
100 scores,
101 }
102 }
103}
104
/// Builds detector instances: a Burn-backed detector when a checkpoint
/// loads successfully, otherwise the heuristic fallback.
pub struct InferenceFactory;
107
108impl InferenceFactory {
109 pub fn build(
110 &self,
111 thresh: InferenceThresholds,
112 weights: Option<&Path>,
113 ) -> Box<dyn vision_core::interfaces::Detector + Send + Sync> {
114 if let Some(det) = self.try_load_burn_detector(thresh, weights) {
115 return det;
116 }
117 eprintln!("InferenceFactory: no valid checkpoint provided; using heuristic detector.");
118 Box::new(HeuristicDetector {
119 obj_thresh: thresh.objectness_threshold,
120 })
121 }
122
123 fn try_load_burn_detector(
124 &self,
125 thresh: InferenceThresholds,
126 weights: Option<&Path>,
127 ) -> Option<Box<dyn vision_core::interfaces::Detector + Send + Sync>> {
128 let path = weights?;
129 if !path.exists() {
130 return None;
131 }
132 let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
133 let recorder = burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new();
134 #[cfg(feature = "convolutional_detector")]
135 let config = InferenceModelConfig {
136 input_dim: Some(4 + 8),
137 ..Default::default()
138 };
139 #[cfg(not(feature = "convolutional_detector"))]
140 let config = InferenceModelConfig::default();
141
142 match InferenceModel::<InferenceBackend>::new(config, &device)
143 .load_file(path, &recorder, &device)
144 {
145 Ok(model) => Some(Box::new(BurnDetector {
146 model: Arc::new(Mutex::new(model)),
147 obj_thresh: thresh.objectness_threshold,
148 iou_thresh: thresh.iou_threshold,
149 })),
150 Err(err) => {
151 eprintln!(
152 "Failed to load detector checkpoint {:?}: {err}. Falling back to heuristic.",
153 path
154 );
155 None
156 }
157 }
158 }
159}