1use crate::{InferenceBackend, InferenceModel, InferenceModelConfig};
2use burn::module::Module;
3use burn::tensor::TensorData;
4#[cfg(feature = "bigdet")]
5use data_contracts::preprocess::{stats_from_rgba_u8, ImageStats};
6use std::path::Path;
7use std::sync::{Arc, Mutex};
8use vision_core::interfaces::{DetectionResult, Detector, Frame};
9
/// Confidence and IoU thresholds applied during detection.
#[derive(Debug, Clone, Copy)]
pub struct InferenceThresholds {
    /// Minimum confidence score for a detection to count as positive.
    pub obj_thresh: f32,
    /// Intersection-over-union threshold; stored but currently unused by the
    /// detectors in this file (presumably reserved for NMS — confirm).
    pub iou_thresh: f32,
}

impl Default for InferenceThresholds {
    /// Default thresholds: 0.3 objectness, 0.5 IoU.
    fn default() -> Self {
        InferenceThresholds {
            obj_thresh: 0.3,
            iou_thresh: 0.5,
        }
    }
}
25
26struct HeuristicDetector {
28 obj_thresh: f32,
29}
30
31impl Detector for HeuristicDetector {
32 fn detect(&mut self, frame: &Frame) -> DetectionResult {
33 let confidence = self.obj_thresh;
34 DetectionResult {
35 frame_id: frame.id,
36 positive: confidence >= self.obj_thresh,
37 confidence,
38 boxes: Vec::new(),
39 scores: Vec::new(),
40 }
41 }
42}
43
/// Detector backed by a Burn model loaded from a checkpoint.
///
/// The model is wrapped in `Arc<Mutex<..>>` so the boxed detector handed out
/// by `InferenceFactory` can satisfy its `Send + Sync` bound.
struct BurnTinyDetDetector {
    // Loaded Burn model; locked for the duration of each forward pass.
    model: Arc<Mutex<InferenceModel<InferenceBackend>>>,
    // Minimum confidence for a positive detection.
    obj_thresh: f32,
    // Stored for parity with `InferenceThresholds` but not yet consumed
    // (presumably reserved for future NMS — hence the allow).
    #[allow(dead_code)]
    iou_thresh: f32,
}
50
impl BurnTinyDetDetector {
    /// Encodes a frame into the model's expected input row.
    ///
    /// With the `bigdet` feature: a `[1, 12]` tensor holding a fixed unit box
    /// `[0, 0, 1, 1]` followed by image-statistic features (the `[1, 12]`
    /// shape implies 8 of them) derived from the frame's RGBA pixels; falls
    /// back to zeroed stats when pixels are missing or extraction fails.
    /// Without the feature: just the `[1, 4]` unit box.
    fn frame_to_tensor(&self, _frame: &Frame) -> TensorData {
        #[cfg(feature = "bigdet")]
        {
            let (w, h) = _frame.size;
            // Degrade gracefully to all-zero mean/std rather than failing the
            // whole detection; the aspect ratio is always computable.
            let stats = if let Some(rgba) = &_frame.rgba {
                stats_from_rgba_u8(w, h, rgba).unwrap_or_else(|_| ImageStats {
                    mean: [0.0; 3],
                    std: [0.0; 3],
                    aspect: w as f32 / h as f32,
                })
            } else {
                ImageStats {
                    mean: [0.0; 3],
                    std: [0.0; 3],
                    aspect: w as f32 / h as f32,
                }
            };

            // 4 box coords + 8 stat features = 12 inputs, matching the
            // `input_dim: Some(4 + 8)` config used when loading the model.
            let mut input = Vec::with_capacity(12);
            input.extend_from_slice(&[0.0, 0.0, 1.0, 1.0]);
            // NOTE(review): the meaning of the `0.0` argument to
            // `feature_vector` is defined in `data_contracts`; assumed to be
            // a padding/default value — confirm against that crate.
            input.extend_from_slice(&stats.feature_vector(0.0));
            TensorData::new(input, [1, 12])
        }
        #[cfg(not(feature = "bigdet"))]
        {
            // Minimal input: the unit bounding box only.
            TensorData::new(vec![0.0, 0.0, 1.0, 1.0], [1, 4])
        }
    }
}
81
82impl Detector for BurnTinyDetDetector {
83 fn detect(&mut self, frame: &Frame) -> DetectionResult {
84 let input = self.frame_to_tensor(frame);
85 let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
86 let model = self.model.lock().expect("model mutex poisoned");
87 let logits = model.forward(burn::tensor::Tensor::<InferenceBackend, 2>::from_data(
88 input, &device,
89 ));
90 let scores = logits.into_data().to_vec::<f32>().unwrap_or_default();
91 let confidence = scores.first().copied().unwrap_or(0.0);
92 DetectionResult {
93 frame_id: frame.id,
94 positive: confidence >= self.obj_thresh,
95 confidence,
96 boxes: Vec::new(),
97 scores,
98 }
99 }
100}
101
/// Builds detectors, preferring a Burn checkpoint and falling back to a
/// trivial heuristic when no usable weights are available.
pub struct InferenceFactory;

impl InferenceFactory {
    /// Returns a boxed detector.
    ///
    /// Tries to load a Burn model from `weights`; on any failure (no path,
    /// missing file, load error) a note is written to stderr and the
    /// always-available `HeuristicDetector` is returned instead.
    pub fn build(
        &self,
        thresh: InferenceThresholds,
        weights: Option<&Path>,
    ) -> Box<dyn vision_core::interfaces::Detector + Send + Sync> {
        if let Some(det) = self.try_load_burn_detector(thresh, weights) {
            return det;
        }
        eprintln!("InferenceFactory: no valid checkpoint provided; using heuristic detector.");
        Box::new(HeuristicDetector {
            obj_thresh: thresh.obj_thresh,
        })
    }

    /// Attempts to load a `BurnTinyDetDetector` from a checkpoint file.
    ///
    /// Returns `None` when `weights` is absent, the file does not exist, or
    /// the recorder fails to load it (the error is reported on stderr).
    fn try_load_burn_detector(
        &self,
        thresh: InferenceThresholds,
        weights: Option<&Path>,
    ) -> Option<Box<dyn vision_core::interfaces::Detector + Send + Sync>> {
        let path = weights?;
        if !path.exists() {
            return None;
        }
        let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
        // Checkpoints are read as full-precision Burn binary records.
        let recorder = burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new();
        // Input width must match `frame_to_tensor`: 4 box coords, plus 8
        // image-stat features when the `bigdet` feature is enabled.
        #[cfg(feature = "bigdet")]
        let config = InferenceModelConfig {
            input_dim: Some(4 + 8),
            ..Default::default()
        };
        #[cfg(not(feature = "bigdet"))]
        let config = InferenceModelConfig::default();

        match InferenceModel::<InferenceBackend>::new(config, &device)
            .load_file(path, &recorder, &device)
        {
            Ok(model) => Some(Box::new(BurnTinyDetDetector {
                model: Arc::new(Mutex::new(model)),
                obj_thresh: thresh.obj_thresh,
                iou_thresh: thresh.iou_thresh,
            })),
            Err(err) => {
                // Fall back rather than abort: a corrupt or mismatched
                // checkpoint should not take the whole pipeline down.
                eprintln!(
                    "Failed to load detector checkpoint {:?}: {err}. Falling back to heuristic.",
                    path
                );
                None
            }
        }
    }
}
156}