1use bevy::asset::RenderAssetUsages;
2use bevy::prelude::*;
3use bevy::render::gpu_readback::{Readback, ReadbackComplete};
4use bevy::render::render_resource::{Extent3d, TextureDimension, TextureFormat, TextureUsages};
5use bevy::tasks::{AsyncComputeTaskPool, Task};
6use bevy_camera::{ImageRenderTarget, RenderTarget};
7use futures_lite::future::{block_on, poll_once};
8use image::RgbaImage;
9use sim_core::{ModeSet, SimRunMode};
10use vision_core::capture::{
11 FrontCamera, FrontCaptureCamera, FrontCaptureReadback, FrontCaptureTarget,
12};
13use vision_core::interfaces::{self, Frame};
14use vision_core::overlay::draw_rect;
15
16type InferenceJobResult = (
17 Box<dyn interfaces::Detector + Send + Sync>,
18 DetectorKind,
19 interfaces::DetectionResult,
20 f32,
21 (u32, u32),
22);
23
24#[derive(Clone)]
25pub struct FrontCameraFrame {
26 pub id: u64,
27 pub transform: GlobalTransform,
28 pub captured_at: f64,
29}
30
31#[derive(Resource, Default)]
32pub struct FrontCameraState {
33 pub active: bool,
34 pub last_transform: Option<GlobalTransform>,
35 pub frame_counter: u64,
36}
37
38#[derive(Resource, Default)]
39pub struct FrontCameraFrameBuffer {
40 pub latest: Option<FrontCameraFrame>,
41}
42
43#[derive(Resource, Default)]
44pub struct BurnDetector {
45 pub model_loaded: bool,
46}
47
48#[derive(Resource, Default, Clone)]
49pub struct DetectionOverlayState {
50 pub boxes: Vec<[f32; 4]>,
51 pub scores: Vec<f32>,
52 pub size: (u32, u32),
53 pub fallback: Option<String>,
54 pub inference_ms: Option<f32>,
55}
56
57#[derive(Clone, Copy, Debug, PartialEq, Eq, Resource)]
58pub enum DetectorKind {
59 Burn,
60 Heuristic,
61}
62
63#[derive(Resource, Debug, Clone, Copy)]
64pub struct InferenceThresholds {
65 pub obj_thresh: f32,
66 pub iou_thresh: f32,
67}
68
69#[derive(Clone)]
70pub struct BurnDetectionResult {
71 pub frame_id: u64,
72 pub positive: bool,
73 pub confidence: f32,
74 pub boxes: Vec<[f32; 4]>,
75 pub scores: Vec<f32>,
76}
77
78#[derive(Resource)]
79pub struct BurnInferenceState {
80 pub pending: Option<Task<InferenceJobResult>>,
81 pub last_result: Option<BurnDetectionResult>,
82 pub debounce: Timer,
83}
84
85impl Default for BurnInferenceState {
86 fn default() -> Self {
87 Self {
88 pending: None,
89 last_result: None,
90 debounce: Timer::from_seconds(0.18, TimerMode::Repeating),
91 }
92 }
93}
94
95#[derive(Resource)]
96pub struct DetectorHandle {
97 pub detector: Box<dyn interfaces::Detector + Send + Sync>,
98 pub kind: DetectorKind,
99}
100
101struct HeuristicDetector;
102
103impl interfaces::Detector for HeuristicDetector {
104 fn detect(&mut self, frame: &Frame) -> interfaces::DetectionResult {
105 interfaces::DetectionResult {
106 frame_id: frame.id,
107 positive: true,
108 confidence: 0.8,
109 boxes: Vec::new(),
110 scores: Vec::new(),
111 }
112 }
113}
114
115pub fn setup_front_capture(
118 mut commands: Commands,
119 mut images: ResMut<Assets<Image>>,
120 mut state: ResMut<FrontCameraState>,
121 mut target: ResMut<FrontCaptureTarget>,
122) {
123 if target.size != UVec2::ZERO {
125 return;
126 }
127
128 let size = UVec2::new(1280, 720);
129 let mut image = Image::new_fill(
130 Extent3d {
131 width: size.x,
132 height: size.y,
133 ..default()
134 },
135 TextureDimension::D2,
136 &[0, 0, 0, 255],
137 TextureFormat::Rgba8UnormSrgb,
138 RenderAssetUsages::default(),
139 );
140 image.texture_descriptor.usage =
141 TextureUsages::COPY_SRC | TextureUsages::TEXTURE_BINDING | TextureUsages::RENDER_ATTACHMENT;
142 let handle = images.add(image);
143
144 let cam_entity = commands
145 .spawn((
146 Camera3d::default(),
147 Camera {
148 order: -10,
149 is_active: true,
150 target: RenderTarget::Image(ImageRenderTarget::from(handle.clone())),
151 ..default()
152 },
153 Projection::from(PerspectiveProjection {
154 fov: 20.0f32.to_radians(),
155 ..default()
156 }),
157 Transform::from_translation(Vec3::ZERO),
158 GlobalTransform::default(),
159 Visibility::default(),
160 InheritedVisibility::default(),
161 ViewVisibility::default(),
162 FrontCamera,
163 FrontCaptureCamera,
164 Name::new("FrontCaptureCamera"),
165 ))
166 .id();
167
168 target.size = size;
169 target.handle = handle;
170 target.entity = cam_entity;
171 state.active = true;
172}
173
174pub fn track_front_camera_state(
175 target: Res<FrontCaptureTarget>,
176 mut state: ResMut<FrontCameraState>,
177 mut buffer: ResMut<FrontCameraFrameBuffer>,
178 cameras: Query<&GlobalTransform, With<FrontCaptureCamera>>,
179 time: Res<Time>,
180) {
181 let Ok(transform) = cameras.get(target.entity) else {
182 return;
183 };
184 state.last_transform = Some(*transform);
185 state.frame_counter = state.frame_counter.wrapping_add(1);
186 buffer.latest = Some(FrontCameraFrame {
187 id: state.frame_counter,
188 transform: *transform,
189 captured_at: time.elapsed_secs_f64(),
190 });
191}
192
193pub fn capture_front_camera_frame(
194 mode: Res<SimRunMode>,
195 mut commands: Commands,
196 target: Res<FrontCaptureTarget>,
197) {
198 if !matches!(*mode, SimRunMode::Datagen | SimRunMode::Inference) {
199 return;
200 }
201 commands
202 .entity(target.entity)
203 .insert(Readback::texture(target.handle.clone()));
204}
205
206pub fn on_front_capture_readback(
207 ev: On<ReadbackComplete>,
208 target: Res<FrontCaptureTarget>,
209 mut readback: ResMut<FrontCaptureReadback>,
210) {
211 let expected_len = (target.size.x * target.size.y * 4) as usize;
212 let ev = ev.event();
213 if ev.entity != target.entity {
214 return;
215 }
216 if ev.data.len() == expected_len {
217 readback.latest = Some(ev.data.clone());
218 }
219}
220
221pub struct CapturePlugin;
222
223impl Plugin for CapturePlugin {
224 fn build(&self, app: &mut App) {
225 app.insert_resource(FrontCaptureTarget {
226 handle: Handle::default(),
227 size: UVec2::ZERO,
228 entity: Entity::PLACEHOLDER,
229 })
230 .init_resource::<FrontCaptureReadback>()
231 .init_resource::<FrontCameraState>()
232 .init_resource::<FrontCameraFrameBuffer>()
233 .add_systems(Startup, setup_front_capture)
234 .add_systems(Update, track_front_camera_state.in_set(ModeSet::Common))
235 .add_systems(Update, capture_front_camera_frame.in_set(ModeSet::Common))
236 .add_observer(on_front_capture_readback);
237 }
238}
239
240pub fn schedule_burn_inference(
243 mode: Res<SimRunMode>,
244 time: Res<Time>,
245 mut jobs: ResMut<BurnInferenceState>,
246 mut buffer: ResMut<FrontCameraFrameBuffer>,
247 handle: Option<ResMut<DetectorHandle>>,
248 target: Res<FrontCaptureTarget>,
249 mut readback: ResMut<FrontCaptureReadback>,
250) {
251 if !matches!(*mode, SimRunMode::Inference) {
252 return;
253 }
254 let Some(mut handle) = handle else {
255 return;
256 };
257
258 jobs.debounce.tick(time.delta());
259 if jobs.pending.is_some() || !jobs.debounce.is_finished() {
260 return;
261 }
262 let Some(frame) = buffer.latest.take() else {
263 return;
264 };
265
266 let rgba = readback.latest.take();
267 let start = std::time::Instant::now();
268 let f = Frame {
269 id: frame.id,
270 timestamp: frame.captured_at,
271 rgba,
272 size: (target.size.x, target.size.y),
273 path: None,
274 };
275 let mut detector = std::mem::replace(&mut handle.detector, Box::new(HeuristicDetector));
276 let kind = handle.kind;
277 let size = (target.size.x, target.size.y);
278 let task = AsyncComputeTaskPool::get().spawn(async move {
279 let result = detector.detect(&f);
280 let infer_ms = start.elapsed().as_secs_f32() * 1000.0;
281 (detector, kind, result, infer_ms, size)
282 });
283 jobs.pending = Some(task);
284}
285
286pub fn threshold_hotkeys(
287 mode: Res<SimRunMode>,
288 keys: Res<ButtonInput<KeyCode>>,
289 thresh: Option<ResMut<InferenceThresholds>>,
290 handle: Option<ResMut<DetectorHandle>>,
291 burn_loaded: Option<ResMut<BurnDetector>>,
292) {
293 if !matches!(*mode, SimRunMode::Inference) {
294 return;
295 }
296 let (Some(mut thresh), Some(mut handle)) = (thresh, handle) else {
297 return;
298 };
299 let Some(mut burn_loaded) = burn_loaded else {
300 return;
301 };
302
303 let mut changed = false;
304 if keys.just_pressed(KeyCode::Minus) {
305 thresh.obj_thresh = (thresh.obj_thresh - 0.05).clamp(0.0, 1.0);
306 changed = true;
307 }
308 if keys.just_pressed(KeyCode::Equal) {
309 thresh.obj_thresh = (thresh.obj_thresh + 0.05).clamp(0.0, 1.0);
310 changed = true;
311 }
312 if keys.just_pressed(KeyCode::BracketLeft) {
313 thresh.iou_thresh = (thresh.iou_thresh - 0.05).clamp(0.1, 0.95);
314 changed = true;
315 }
316 if keys.just_pressed(KeyCode::BracketRight) {
317 thresh.iou_thresh = (thresh.iou_thresh + 0.05).clamp(0.1, 0.95);
318 changed = true;
319 }
320
321 if keys.just_pressed(KeyCode::Digit0) {
322 handle.detector = Box::new(HeuristicDetector);
323 handle.kind = DetectorKind::Heuristic;
324 burn_loaded.model_loaded = false;
325 changed = true;
326 }
327
328 if changed {
329 info!(
330 "Updated inference thresholds: obj {:.2}, iou {:.2}",
331 thresh.obj_thresh, thresh.iou_thresh
332 );
333 }
334}
335
336pub struct InferencePlugin;
337
338impl Plugin for InferencePlugin {
339 fn build(&self, app: &mut App) {
340 app.init_resource::<BurnInferenceState>()
341 .init_resource::<BurnDetector>()
342 .init_resource::<DetectionOverlayState>()
343 .add_systems(
344 Update,
345 (
346 schedule_burn_inference,
347 poll_inference_task,
348 threshold_hotkeys,
349 )
350 .in_set(ModeSet::Inference),
351 );
352 }
353}
354
355pub fn recorder_draw_rect(
358 img: &mut RgbaImage,
359 bbox_px: [u32; 4],
360 color: image::Rgba<u8>,
361 thickness: u32,
362) {
363 draw_rect(img, bbox_px, color, thickness);
364}
365
366pub mod prelude {
367 pub use super::{
368 BurnDetectionResult, BurnDetector, BurnInferenceState, CapturePlugin,
369 DetectionOverlayState, DetectorHandle, DetectorKind, FrontCameraFrame,
370 FrontCameraFrameBuffer, FrontCameraState, InferencePlugin, InferenceThresholds,
371 };
372}
373pub fn poll_inference_task(
374 mut jobs: ResMut<BurnInferenceState>,
375 mut overlay: ResMut<DetectionOverlayState>,
376 handle: Option<ResMut<DetectorHandle>>,
377 mut burn_detector: ResMut<BurnDetector>,
378) {
379 let Some(mut task) = jobs.pending.take() else {
380 return;
381 };
382 if let Some((detector, kind, result, infer_ms, size)) = block_on(poll_once(&mut task)) {
383 if let Some(mut handle) = handle {
384 handle.detector = detector;
385 handle.kind = kind;
386 }
387 burn_detector.model_loaded = matches!(kind, DetectorKind::Burn);
388 if matches!(kind, DetectorKind::Heuristic) {
389 overlay.fallback = Some("Heuristic detector active (Burn unavailable)".into());
390 } else {
391 overlay.fallback = None;
392 }
393 overlay.inference_ms = Some(infer_ms);
394 overlay.boxes = result.boxes.clone();
395 overlay.scores = result.scores.clone();
396 overlay.size = size;
397 jobs.last_result = Some(BurnDetectionResult {
398 frame_id: result.frame_id,
399 positive: result.positive,
400 confidence: result.confidence,
401 boxes: result.boxes,
402 scores: result.scores,
403 });
404 } else {
405 jobs.pending = Some(task);
407 }
408}