1use bevy::asset::RenderAssetUsages;
16use bevy::prelude::*;
17use bevy::render::gpu_readback::{Readback, ReadbackComplete};
18use bevy::render::render_resource::{Extent3d, TextureDimension, TextureFormat, TextureUsages};
19use bevy::tasks::{AsyncComputeTaskPool, Task};
20use bevy_camera::{ImageRenderTarget, RenderTarget};
21use futures_lite::future::{block_on, poll_once};
22use image::RgbaImage;
23use inference::InferenceThresholds;
24use sim_core::{ModeSet, SimRunMode};
25use vision_core::capture::{PrimaryCaptureCamera, PrimaryCaptureReadback, PrimaryCaptureTarget};
26use vision_core::interfaces::{self, Frame};
27use vision_core::overlay::draw_rect;
28
/// ECS resource newtype around [`InferenceThresholds`] so the thresholds can
/// be stored in the Bevy world; derefs to the inner value.
#[derive(Resource, Debug, Clone, Copy, Default)]
pub struct InferenceThresholdsResource(pub InferenceThresholds);
36
impl std::ops::Deref for InferenceThresholdsResource {
    type Target = InferenceThresholds;
    // Transparent read access to the wrapped thresholds.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
43
impl std::ops::DerefMut for InferenceThresholdsResource {
    // Transparent mutable access to the wrapped thresholds.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
49
/// Payload produced by an async inference task:
/// the detector being returned to `DetectorHandle`, its kind, the detection
/// result, the measured inference time in milliseconds, and the
/// (width, height) of the frame that was analysed.
type InferenceJobResult = (
    Box<dyn interfaces::Detector + Send + Sync>,
    DetectorKind,
    interfaces::DetectionResult,
    f32,
    (u32, u32),
);
57
/// Snapshot of the capture camera at the moment a frame was recorded.
#[derive(Clone)]
pub struct PrimaryCameraFrame {
    /// Monotonic frame id (wrapping; mirrors `PrimaryCameraState::frame_counter`).
    pub id: u64,
    /// Camera pose at capture time.
    pub transform: GlobalTransform,
    /// `Time::elapsed_secs_f64()` when the frame was recorded.
    pub captured_at: f64,
}
64
/// Bookkeeping for the primary capture camera.
#[derive(Resource, Default)]
pub struct PrimaryCameraState {
    /// Set once the capture camera has been spawned.
    pub active: bool,
    /// Most recently observed camera transform.
    pub last_transform: Option<GlobalTransform>,
    /// Wrapping counter incremented every tracked frame; used as the frame id.
    pub frame_counter: u64,
}
71
/// Single-slot buffer holding the newest captured frame's metadata.
/// Consumers `take()` the slot, so each frame is consumed at most once.
#[derive(Resource, Default)]
pub struct PrimaryCameraFrameBuffer {
    pub latest: Option<PrimaryCameraFrame>,
}
76
/// Whether a Burn-backed detector is currently active; cleared when the
/// runtime falls back to (or is reset to) the heuristic detector.
#[derive(Resource, Default)]
pub struct ModelLoadedFlag {
    pub model_loaded: bool,
}
82
/// Latest detection data for on-screen overlay rendering.
#[derive(Resource, Default, Clone)]
pub struct DetectionOverlayState {
    /// Detected bounding boxes (coordinate convention defined by the detector).
    pub boxes: Vec<[f32; 4]>,
    /// Per-box confidence scores, parallel to `boxes`.
    pub scores: Vec<f32>,
    /// (width, height) of the frame the detections refer to.
    pub size: (u32, u32),
    /// Human-readable notice when the heuristic fallback detector is active.
    pub fallback: Option<String>,
    /// Duration of the last inference pass, in milliseconds.
    pub inference_ms: Option<f32>,
}
91
/// Which detector implementation is currently in use.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Resource)]
pub enum DetectorKind {
    /// Burn (model-backed) detector.
    Burn,
    /// Heuristic fallback detector.
    Heuristic,
}
97
/// Completed detection outcome cached for consumers beyond the overlay
/// (stored as `AsyncInferenceState::last_result`).
#[derive(Clone)]
pub struct RuntimeDetectionResult {
    /// Id of the frame this detection corresponds to.
    pub frame_id: u64,
    /// Whether the detector reported a positive detection.
    pub positive: bool,
    /// Overall detection confidence.
    pub confidence: f32,
    /// Detected bounding boxes.
    pub boxes: Vec<[f32; 4]>,
    /// Per-box scores, parallel to `boxes`.
    pub scores: Vec<f32>,
}
110
/// State for the (at most one) in-flight async inference job.
#[derive(Resource)]
pub struct AsyncInferenceState {
    /// Currently running task, if any; polled by `poll_inference_task`.
    pub pending: Option<Task<InferenceJobResult>>,
    /// Most recently completed detection.
    pub last_result: Option<RuntimeDetectionResult>,
    /// Rate limiter gating how often new jobs are scheduled.
    pub debounce: Timer,
}
120
121impl Default for AsyncInferenceState {
122 fn default() -> Self {
123 Self {
124 pending: None,
125 last_result: None,
126 debounce: Timer::from_seconds(0.18, TimerMode::Repeating),
127 }
128 }
129}
130
/// Owns the active detector implementation. The boxed detector is temporarily
/// swapped out (loaned to the async task) while a job is in flight.
#[derive(Resource)]
pub struct DetectorHandle {
    pub detector: Box<dyn interfaces::Detector + Send + Sync>,
    pub kind: DetectorKind,
}
136
/// Placeholder detector left in `DetectorHandle` while the real detector is
/// loaned to an async task; unconditionally reports a positive detection.
struct DefaultTestDetector;
138
139impl interfaces::Detector for DefaultTestDetector {
140 fn detect(&mut self, frame: &Frame) -> interfaces::DetectionResult {
141 interfaces::DetectionResult {
142 frame_id: frame.id,
143 positive: true,
144 confidence: 0.8,
145 boxes: Vec::new(),
146 scores: Vec::new(),
147 }
148 }
149}
150
151pub fn setup_primary_capture(
154 mut commands: Commands,
155 mut images: ResMut<Assets<Image>>,
156 mut state: ResMut<PrimaryCameraState>,
157 mut target: ResMut<PrimaryCaptureTarget>,
158) {
159 if target.size != UVec2::ZERO {
161 return;
162 }
163
164 let size = UVec2::new(1280, 720);
165 let mut image = Image::new_fill(
166 Extent3d {
167 width: size.x,
168 height: size.y,
169 ..default()
170 },
171 TextureDimension::D2,
172 &[0, 0, 0, 255],
173 TextureFormat::Rgba8UnormSrgb,
174 RenderAssetUsages::default(),
175 );
176 image.texture_descriptor.usage =
177 TextureUsages::COPY_SRC | TextureUsages::TEXTURE_BINDING | TextureUsages::RENDER_ATTACHMENT;
178 let handle = images.add(image);
179
180 let cam_entity = commands
181 .spawn((
182 Camera3d::default(),
183 Camera {
184 order: -10,
185 is_active: true,
186 target: RenderTarget::Image(ImageRenderTarget::from(handle.clone())),
187 ..default()
188 },
189 Projection::from(PerspectiveProjection {
190 fov: 20.0f32.to_radians(),
191 ..default()
192 }),
193 Transform::from_translation(Vec3::ZERO),
194 GlobalTransform::default(),
195 Visibility::default(),
196 InheritedVisibility::default(),
197 ViewVisibility::default(),
198 PrimaryCaptureCamera,
199 Name::new("PrimaryCaptureCamera"),
200 ))
201 .id();
202
203 target.size = size;
204 target.handle = handle;
205 target.entity = cam_entity;
206 state.active = true;
207}
208
209pub fn track_primary_camera_state(
210 target: Res<PrimaryCaptureTarget>,
211 mut state: ResMut<PrimaryCameraState>,
212 mut buffer: ResMut<PrimaryCameraFrameBuffer>,
213 cameras: Query<&GlobalTransform, With<PrimaryCaptureCamera>>,
214 time: Res<Time>,
215) {
216 let Ok(transform) = cameras.get(target.entity) else {
217 return;
218 };
219 state.last_transform = Some(*transform);
220 state.frame_counter = state.frame_counter.wrapping_add(1);
221 buffer.latest = Some(PrimaryCameraFrame {
222 id: state.frame_counter,
223 transform: *transform,
224 captured_at: time.elapsed_secs_f64(),
225 });
226}
227
228pub fn capture_primary_camera_frame(
229 mode: Res<SimRunMode>,
230 mut commands: Commands,
231 target: Res<PrimaryCaptureTarget>,
232) {
233 if !matches!(*mode, SimRunMode::Datagen | SimRunMode::Inference) {
234 return;
235 }
236 commands
237 .entity(target.entity)
238 .insert(Readback::texture(target.handle.clone()));
239}
240
241pub fn on_primary_capture_readback(
242 ev: On<ReadbackComplete>,
243 target: Res<PrimaryCaptureTarget>,
244 mut readback: ResMut<PrimaryCaptureReadback>,
245) {
246 let expected_len = (target.size.x * target.size.y * 4) as usize;
247 let ev = ev.event();
248 if ev.entity != target.entity {
249 return;
250 }
251 if ev.data.len() == expected_len {
252 readback.latest = Some(ev.data.clone());
253 }
254}
255
256pub struct CapturePlugin;
257
258impl Plugin for CapturePlugin {
259 fn build(&self, app: &mut App) {
260 app.insert_resource(PrimaryCaptureTarget {
261 handle: Handle::default(),
262 size: UVec2::ZERO,
263 entity: Entity::PLACEHOLDER,
264 })
265 .init_resource::<PrimaryCaptureReadback>()
266 .init_resource::<PrimaryCameraState>()
267 .init_resource::<PrimaryCameraFrameBuffer>()
268 .add_systems(Startup, setup_primary_capture)
269 .add_systems(Update, track_primary_camera_state.in_set(ModeSet::Common))
270 .add_systems(Update, capture_primary_camera_frame.in_set(ModeSet::Common))
271 .add_observer(on_primary_capture_readback);
272 }
273}
274
275pub fn schedule_burn_inference(
278 mode: Res<SimRunMode>,
279 time: Res<Time>,
280 mut jobs: ResMut<AsyncInferenceState>,
281 mut buffer: ResMut<PrimaryCameraFrameBuffer>,
282 handle: Option<ResMut<DetectorHandle>>,
283 target: Res<PrimaryCaptureTarget>,
284 mut readback: ResMut<PrimaryCaptureReadback>,
285) {
286 if !matches!(*mode, SimRunMode::Inference) {
287 return;
288 }
289 let Some(mut handle) = handle else {
290 return;
291 };
292
293 jobs.debounce.tick(time.delta());
294 if jobs.pending.is_some() || !jobs.debounce.is_finished() {
295 return;
296 }
297 let Some(frame) = buffer.latest.take() else {
298 return;
299 };
300
301 let rgba = readback.latest.take();
302 let start = std::time::Instant::now();
303 let f = Frame {
304 id: frame.id,
305 timestamp: frame.captured_at,
306 rgba,
307 size: (target.size.x, target.size.y),
308 path: None,
309 };
310 let mut detector = std::mem::replace(&mut handle.detector, Box::new(DefaultTestDetector));
311 let kind = handle.kind;
312 let size = (target.size.x, target.size.y);
313 let task = AsyncComputeTaskPool::get().spawn(async move {
314 let result = detector.detect(&f);
315 let infer_ms = start.elapsed().as_secs_f32() * 1000.0;
316 (detector, kind, result, infer_ms, size)
317 });
318 jobs.pending = Some(task);
319}
320
321pub fn threshold_hotkeys(
322 mode: Res<SimRunMode>,
323 keys: Res<ButtonInput<KeyCode>>,
324 thresh: Option<ResMut<InferenceThresholdsResource>>,
325 handle: Option<ResMut<DetectorHandle>>,
326 burn_loaded: Option<ResMut<ModelLoadedFlag>>,
327) {
328 if !matches!(*mode, SimRunMode::Inference) {
329 return;
330 }
331 let (Some(mut thresh), Some(mut handle)) = (thresh, handle) else {
332 return;
333 };
334 let Some(mut burn_loaded) = burn_loaded else {
335 return;
336 };
337
338 let mut changed = false;
339 if keys.just_pressed(KeyCode::Minus) {
340 thresh.objectness_threshold = (thresh.objectness_threshold - 0.05).clamp(0.0, 1.0);
341 changed = true;
342 }
343 if keys.just_pressed(KeyCode::Equal) {
344 thresh.objectness_threshold = (thresh.objectness_threshold + 0.05).clamp(0.0, 1.0);
345 changed = true;
346 }
347 if keys.just_pressed(KeyCode::BracketLeft) {
348 thresh.iou_threshold = (thresh.iou_threshold - 0.05).clamp(0.1, 0.95);
349 changed = true;
350 }
351 if keys.just_pressed(KeyCode::BracketRight) {
352 thresh.iou_threshold = (thresh.iou_threshold + 0.05).clamp(0.1, 0.95);
353 changed = true;
354 }
355
356 if keys.just_pressed(KeyCode::Digit0) {
357 handle.detector = Box::new(DefaultTestDetector);
358 handle.kind = DetectorKind::Heuristic;
359 burn_loaded.model_loaded = false;
360 changed = true;
361 }
362
363 if changed {
364 info!(
365 "Updated inference thresholds: obj {:.2}, iou {:.2}",
366 thresh.objectness_threshold, thresh.iou_threshold
367 );
368 }
369}
370
371pub struct InferenceRuntimePlugin;
377
378impl Plugin for InferenceRuntimePlugin {
379 fn build(&self, app: &mut App) {
380 app.init_resource::<AsyncInferenceState>()
381 .init_resource::<ModelLoadedFlag>()
382 .init_resource::<DetectionOverlayState>()
383 .add_systems(
384 Update,
385 (
386 schedule_burn_inference,
387 poll_inference_task,
388 threshold_hotkeys,
389 )
390 .in_set(ModeSet::Inference),
391 );
392 }
393}
394
/// Thin delegation shim: draws the rectangle `bbox_px` onto `img` with the
/// given color and line thickness via `vision_core::overlay::draw_rect`.
/// Exists so recorder code can depend on this crate instead of `vision_core`.
pub fn recorder_draw_rect(
    img: &mut RgbaImage,
    bbox_px: [u32; 4],
    color: image::Rgba<u8>,
    thickness: u32,
) {
    draw_rect(img, bbox_px, color, thickness);
}
405
/// Convenience re-exports of this module's public capture/inference types
/// for downstream `use ...::prelude::*` consumers.
pub mod prelude {
    pub use super::{
        AsyncInferenceState, CapturePlugin, DetectionOverlayState, DetectorHandle, DetectorKind,
        InferenceRuntimePlugin, InferenceThresholdsResource, ModelLoadedFlag, PrimaryCameraFrame,
        PrimaryCameraFrameBuffer, PrimaryCameraState, RuntimeDetectionResult,
    };
}
413pub fn poll_inference_task(
414 mut jobs: ResMut<AsyncInferenceState>,
415 mut overlay: ResMut<DetectionOverlayState>,
416 handle: Option<ResMut<DetectorHandle>>,
417 mut burn_detector: ResMut<ModelLoadedFlag>,
418) {
419 let Some(mut task) = jobs.pending.take() else {
420 return;
421 };
422 if let Some((detector, kind, result, infer_ms, size)) = block_on(poll_once(&mut task)) {
423 if let Some(mut handle) = handle {
424 handle.detector = detector;
425 handle.kind = kind;
426 }
427 burn_detector.model_loaded = matches!(kind, DetectorKind::Burn);
428 if matches!(kind, DetectorKind::Heuristic) {
429 overlay.fallback = Some("Heuristic detector active (Burn unavailable)".into());
430 } else {
431 overlay.fallback = None;
432 }
433 overlay.inference_ms = Some(infer_ms);
434 overlay.boxes = result.boxes.clone();
435 overlay.scores = result.scores.clone();
436 overlay.size = size;
437 jobs.last_result = Some(RuntimeDetectionResult {
438 frame_id: result.frame_id,
439 positive: result.positive,
440 confidence: result.confidence,
441 boxes: result.boxes,
442 scores: result.scores,
443 });
444 } else {
445 jobs.pending = Some(task);
447 }
448}