1use crate::{InferenceBackend, InferenceModel, InferenceModelConfig};
2use burn::module::Module;
3use burn::tensor::TensorData;
4use std::path::Path;
5use std::sync::{Arc, Mutex};
6use vision_core::interfaces::{DetectionResult, Detector, Frame};
7
/// Detection thresholds shared by every detector this module builds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct InferenceThresholds {
    /// Minimum score for a detection to be reported as positive.
    pub obj_thresh: f32,
    /// IoU threshold; carried by the Burn detector but not read yet
    /// (see the `#[allow(dead_code)]` on its field).
    pub iou_thresh: f32,
}
14
15impl Default for InferenceThresholds {
16 fn default() -> Self {
17 Self {
18 obj_thresh: 0.3,
19 iou_thresh: 0.5,
20 }
21 }
22}
23
/// Fallback detector used when no model checkpoint is available.
/// Holds only the objectness threshold, which `detect` echoes back
/// as the reported confidence.
struct HeuristicDetector {
    // Objectness threshold; also reported verbatim as the confidence.
    obj_thresh: f32,
}
28
29impl Detector for HeuristicDetector {
30 fn detect(&mut self, frame: &Frame) -> DetectionResult {
31 let confidence = self.obj_thresh;
32 DetectionResult {
33 frame_id: frame.id,
34 positive: confidence >= self.obj_thresh,
35 confidence,
36 boxes: Vec::new(),
37 scores: Vec::new(),
38 }
39 }
40}
41
/// Detector backed by a Burn `InferenceModel` checkpoint.
struct BurnTinyDetDetector {
    // Lock-guarded, shared model so the boxed detector can be Send + Sync.
    model: Arc<Mutex<InferenceModel<InferenceBackend>>>,
    // Minimum first-logit score for a positive detection (see `detect`).
    obj_thresh: f32,
    // Stored but never read in this file — presumably reserved for future
    // box suppression; TODO(review) confirm intended use.
    #[allow(dead_code)]
    iou_thresh: f32,
}
48
49impl BurnTinyDetDetector {
50 fn frame_to_tensor(&self, frame: &Frame) -> TensorData {
51 let (w, h) = frame.size;
52 if let Some(rgba) = &frame.rgba {
53 let mut mean = [0f32; 3];
54 let mut count = 0usize;
55 for chunk in rgba.chunks_exact(4) {
56 mean[0] += chunk[0] as f32;
57 mean[1] += chunk[1] as f32;
58 mean[2] += chunk[2] as f32;
59 count += 1;
60 }
61 if count > 0 {
62 mean[0] /= count as f32 * 255.0;
63 mean[1] /= count as f32 * 255.0;
64 mean[2] /= count as f32 * 255.0;
65 }
66 TensorData::new(vec![mean[0], mean[1], mean[2], w as f32 / h as f32], [1, 4])
67 } else {
68 TensorData::new(vec![0.0, 0.0, 0.0, w as f32 / h as f32], [1, 4])
69 }
70 }
71}
72
73impl Detector for BurnTinyDetDetector {
74 fn detect(&mut self, frame: &Frame) -> DetectionResult {
75 let input = self.frame_to_tensor(frame);
76 let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
77 let model = self.model.lock().expect("model mutex poisoned");
78 let logits = model.forward(burn::tensor::Tensor::<InferenceBackend, 2>::from_data(
79 input, &device,
80 ));
81 let scores = logits.into_data().to_vec::<f32>().unwrap_or_default();
82 let confidence = scores.first().copied().unwrap_or(0.0);
83 DetectionResult {
84 frame_id: frame.id,
85 positive: confidence >= self.obj_thresh,
86 confidence,
87 boxes: Vec::new(),
88 scores,
89 }
90 }
91}
92
/// Stateless factory that wires thresholds and optional weights into a boxed detector.
#[derive(Debug, Default, Clone, Copy)]
pub struct InferenceFactory;
95
96impl InferenceFactory {
97 pub fn build(
98 &self,
99 thresh: InferenceThresholds,
100 weights: Option<&Path>,
101 ) -> Box<dyn vision_core::interfaces::Detector + Send + Sync> {
102 if let Some(det) = self.try_load_burn_detector(thresh, weights) {
103 return det;
104 }
105 eprintln!("InferenceFactory: no valid checkpoint provided; using heuristic detector.");
106 Box::new(HeuristicDetector {
107 obj_thresh: thresh.obj_thresh,
108 })
109 }
110
111 fn try_load_burn_detector(
112 &self,
113 thresh: InferenceThresholds,
114 weights: Option<&Path>,
115 ) -> Option<Box<dyn vision_core::interfaces::Detector + Send + Sync>> {
116 let path = weights?;
117 if !path.exists() {
118 return None;
119 }
120 let device = <InferenceBackend as burn::tensor::backend::Backend>::Device::default();
121 let recorder = burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new();
122 match InferenceModel::<InferenceBackend>::new(InferenceModelConfig::default(), &device)
123 .load_file(path, &recorder, &device)
124 {
125 Ok(model) => Some(Box::new(BurnTinyDetDetector {
126 model: Arc::new(Mutex::new(model)),
127 obj_thresh: thresh.obj_thresh,
128 iou_thresh: thresh.iou_thresh,
129 })),
130 Err(err) => {
131 eprintln!(
132 "Failed to load detector checkpoint {:?}: {err}. Falling back to heuristic.",
133 path
134 );
135 None
136 }
137 }
138 }
139}