vision_runtime/lib.rs

//! Bevy-integrated vision runtime: camera capture, async inference, and detection overlays.
//!
//! This crate provides Bevy plugins and systems for:
//! - Off-screen camera capture with GPU readback (`CapturePlugin`).
//! - Async inference scheduling and result polling (`InferenceRuntimePlugin`).
//! - Detection overlay rendering and threshold hotkeys.
//!
//! It bridges the framework-agnostic `vision_core` interfaces with Bevy ECS resources and
//! systems, enabling real-time detector integration in Bevy apps.
//!
//! ## Architecture Note
//! This is the *runtime* layer. Core detector interfaces live in `vision_core`, while
//! detector implementations live in `inference`. See the architecture docs for the full split.
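//!
//! ## Example
//! A minimal wiring sketch for a hypothetical host app. It assumes the host
//! already registers `SimRunMode` and the `ModeSet` system sets from `sim_core`,
//! which the systems below rely on:
//! ```ignore
//! use bevy::prelude::*;
//! use vision_runtime::prelude::*;
//!
//! App::new()
//!     .add_plugins(DefaultPlugins)
//!     .add_plugins((CapturePlugin, InferenceRuntimePlugin))
//!     .run();
//! ```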

use bevy::asset::RenderAssetUsages;
use bevy::prelude::*;
use bevy::render::gpu_readback::{Readback, ReadbackComplete};
use bevy::render::render_resource::{Extent3d, TextureDimension, TextureFormat, TextureUsages};
use bevy::tasks::{AsyncComputeTaskPool, Task};
use bevy_camera::{ImageRenderTarget, RenderTarget};
use futures_lite::future::{block_on, poll_once};
use image::RgbaImage;
use inference::InferenceThresholds;
use sim_core::{ModeSet, SimRunMode};
use vision_core::capture::{PrimaryCaptureCamera, PrimaryCaptureReadback, PrimaryCaptureTarget};
use vision_core::interfaces::{self, Frame};
use vision_core::overlay::draw_rect;

/// Bevy resource wrapper for inference thresholds.
///
/// This bridges the framework-agnostic `inference` crate with Bevy ECS.
/// The inner `InferenceThresholds` type can be used in non-Bevy contexts
/// (CLI tools, web services, etc.) without pulling in Bevy dependencies.
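///
/// # Example
/// A sketch of reading the thresholds from a system, assuming the host app has
/// inserted the resource; `Deref` exposes the inner fields directly (the same
/// fields `threshold_hotkeys` adjusts below):
/// ```ignore
/// fn log_thresholds(thresh: Res<InferenceThresholdsResource>) {
///     info!("obj {:.2}, iou {:.2}", thresh.objectness_threshold, thresh.iou_threshold);
/// }
/// ```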
#[derive(Resource, Debug, Clone, Copy, Default)]
pub struct InferenceThresholdsResource(pub InferenceThresholds);

impl std::ops::Deref for InferenceThresholdsResource {
    type Target = InferenceThresholds;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::DerefMut for InferenceThresholdsResource {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

/// Payload returned by an async inference task: the detector (moved back to the
/// main world by `poll_inference_task`), its kind, the raw detection result, the
/// elapsed inference time in milliseconds, and the capture size in pixels.
type InferenceJobResult = (
    Box<dyn interfaces::Detector + Send + Sync>,
    DetectorKind,
    interfaces::DetectionResult,
    f32,
    (u32, u32),
);

/// Metadata for the most recently captured primary-camera frame.
#[derive(Clone)]
pub struct PrimaryCameraFrame {
    pub id: u64,
    pub transform: GlobalTransform,
    pub captured_at: f64,
}

/// Tracks capture-camera activity: its latest transform and a running frame counter.
#[derive(Resource, Default)]
pub struct PrimaryCameraState {
    pub active: bool,
    pub last_transform: Option<GlobalTransform>,
    pub frame_counter: u64,
}

/// Single-slot buffer holding the newest captured frame, if any.
#[derive(Resource, Default)]
pub struct PrimaryCameraFrameBuffer {
    pub latest: Option<PrimaryCameraFrame>,
}

/// Resource tracking whether a model detector is loaded (vs. heuristic fallback).
#[derive(Resource, Default)]
pub struct ModelLoadedFlag {
    pub model_loaded: bool,
}

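/// Detection overlay data published by `poll_inference_task` for UI rendering.
///
/// # Example
/// A sketch of a hypothetical system that reads the overlay state; the box
/// layout follows whatever convention the active detector emits:
/// ```ignore
/// fn log_overlay(overlay: Res<DetectionOverlayState>) {
///     for (bbox, score) in overlay.boxes.iter().zip(&overlay.scores) {
///         info!("box {:?} score {:.2}", bbox, score);
///     }
/// }
/// ```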
#[derive(Resource, Default, Clone)]
pub struct DetectionOverlayState {
    pub boxes: Vec<[f32; 4]>,
    pub scores: Vec<f32>,
    pub size: (u32, u32),
    pub fallback: Option<String>,
    pub inference_ms: Option<f32>,
}

/// Which detector backend is currently active.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Resource)]
pub enum DetectorKind {
    Burn,
    Heuristic,
}

/// Detection result from `vision_runtime` async inference.
///
/// This is a runtime-specific result type that aggregates detection output.
/// Not to be confused with `vision_core::interfaces::DetectionResult`.
#[derive(Clone)]
pub struct RuntimeDetectionResult {
    pub frame_id: u64,
    pub positive: bool,
    pub confidence: f32,
    pub boxes: Vec<[f32; 4]>,
    pub scores: Vec<f32>,
}

/// Resource managing async inference task state.
///
/// Tracks pending async inference jobs, debouncing, and the most recent result.
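///
/// # Example
/// A sketch of a hypothetical system consuming the latest result (fields are
/// those of `RuntimeDetectionResult` above):
/// ```ignore
/// fn react_to_detections(jobs: Res<AsyncInferenceState>) {
///     if let Some(result) = &jobs.last_result {
///         info!(
///             "frame {}: positive={} conf={:.2}",
///             result.frame_id, result.positive, result.confidence
///         );
///     }
/// }
/// ```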
#[derive(Resource)]
pub struct AsyncInferenceState {
    pub pending: Option<Task<InferenceJobResult>>,
    pub last_result: Option<RuntimeDetectionResult>,
    pub debounce: Timer,
}

impl Default for AsyncInferenceState {
    fn default() -> Self {
        Self {
            pending: None,
            last_result: None,
            // Schedule at most one inference job every 180 ms.
            debounce: Timer::from_seconds(0.18, TimerMode::Repeating),
        }
    }
}

/// Resource owning the active detector and its kind.
///
/// The boxed detector is temporarily checked out into async tasks by
/// `schedule_burn_inference` and restored by `poll_inference_task`.
#[derive(Resource)]
pub struct DetectorHandle {
    pub detector: Box<dyn interfaces::Detector + Send + Sync>,
    pub kind: DetectorKind,
}

/// Placeholder detector that always reports a positive detection. It holds the
/// `DetectorHandle` slot while the real detector is checked out into a task,
/// and serves as the `0`-key heuristic fallback.
struct DefaultTestDetector;

impl interfaces::Detector for DefaultTestDetector {
    fn detect(&mut self, frame: &Frame) -> interfaces::DetectionResult {
        interfaces::DetectionResult {
            frame_id: frame.id,
            positive: true,
            confidence: 0.8,
            boxes: Vec::new(),
            scores: Vec::new(),
        }
    }
}

// Capture setup/readback -----------------------------------------------------

/// Spawns the off-screen capture camera and its 1280x720 render-target image,
/// then records both in `PrimaryCaptureTarget`.
pub fn setup_primary_capture(
    mut commands: Commands,
    mut images: ResMut<Assets<Image>>,
    mut state: ResMut<PrimaryCameraState>,
    mut target: ResMut<PrimaryCaptureTarget>,
) {
    // Only set up once.
    if target.size != UVec2::ZERO {
        return;
    }

    let size = UVec2::new(1280, 720);
    let mut image = Image::new_fill(
        Extent3d {
            width: size.x,
            height: size.y,
            ..default()
        },
        TextureDimension::D2,
        &[0, 0, 0, 255],
        TextureFormat::Rgba8UnormSrgb,
        RenderAssetUsages::default(),
    );
    // The texture must be a render attachment (camera output), copyable
    // (GPU readback), and bindable (shader sampling).
    image.texture_descriptor.usage =
        TextureUsages::COPY_SRC | TextureUsages::TEXTURE_BINDING | TextureUsages::RENDER_ATTACHMENT;
    let handle = images.add(image);

    let cam_entity = commands
        .spawn((
            Camera3d::default(),
            Camera {
                // Negative order renders this camera before the default camera.
                order: -10,
                is_active: true,
                target: RenderTarget::Image(ImageRenderTarget::from(handle.clone())),
                ..default()
            },
            Projection::from(PerspectiveProjection {
                fov: 20.0f32.to_radians(),
                ..default()
            }),
            Transform::from_translation(Vec3::ZERO),
            GlobalTransform::default(),
            Visibility::default(),
            InheritedVisibility::default(),
            ViewVisibility::default(),
            PrimaryCaptureCamera,
            Name::new("PrimaryCaptureCamera"),
        ))
        .id();

    target.size = size;
    target.handle = handle;
    target.entity = cam_entity;
    state.active = true;
}

/// Records the capture camera's transform and publishes a new frame record
/// into the single-slot buffer each update.
pub fn track_primary_camera_state(
    target: Res<PrimaryCaptureTarget>,
    mut state: ResMut<PrimaryCameraState>,
    mut buffer: ResMut<PrimaryCameraFrameBuffer>,
    cameras: Query<&GlobalTransform, With<PrimaryCaptureCamera>>,
    time: Res<Time>,
) {
    let Ok(transform) = cameras.get(target.entity) else {
        return;
    };
    state.last_transform = Some(*transform);
    state.frame_counter = state.frame_counter.wrapping_add(1);
    buffer.latest = Some(PrimaryCameraFrame {
        id: state.frame_counter,
        transform: *transform,
        captured_at: time.elapsed_secs_f64(),
    });
}

/// Requests a GPU readback of the capture texture while in `Datagen` or
/// `Inference` mode.
pub fn capture_primary_camera_frame(
    mode: Res<SimRunMode>,
    mut commands: Commands,
    target: Res<PrimaryCaptureTarget>,
) {
    if !matches!(*mode, SimRunMode::Datagen | SimRunMode::Inference) {
        return;
    }
    commands
        .entity(target.entity)
        .insert(Readback::texture(target.handle.clone()));
}

/// Observer storing completed readback bytes, ignoring readbacks from other
/// entities or with unexpected sizes.
pub fn on_primary_capture_readback(
    ev: On<ReadbackComplete>,
    target: Res<PrimaryCaptureTarget>,
    mut readback: ResMut<PrimaryCaptureReadback>,
) {
    let ev = ev.event();
    if ev.entity != target.entity {
        return;
    }
    // Four bytes per pixel (RGBA8).
    let expected_len = (target.size.x * target.size.y * 4) as usize;
    if ev.data.len() == expected_len {
        readback.latest = Some(ev.data.clone());
    }
}

/// Bevy plugin wiring up off-screen capture: render target setup, camera state
/// tracking, per-frame readback requests, and the readback observer.
pub struct CapturePlugin;

impl Plugin for CapturePlugin {
    fn build(&self, app: &mut App) {
        app.insert_resource(PrimaryCaptureTarget {
            handle: Handle::default(),
            size: UVec2::ZERO,
            entity: Entity::PLACEHOLDER,
        })
        .init_resource::<PrimaryCaptureReadback>()
        .init_resource::<PrimaryCameraState>()
        .init_resource::<PrimaryCameraFrameBuffer>()
        .add_systems(Startup, setup_primary_capture)
        .add_systems(Update, track_primary_camera_state.in_set(ModeSet::Common))
        .add_systems(Update, capture_primary_camera_frame.in_set(ModeSet::Common))
        .add_observer(on_primary_capture_readback);
    }
}

// Inference -------------------------------------------------------------------

/// Schedules an async inference job on the latest captured frame, debounced and
/// limited to one in-flight task.
///
/// The boxed detector is checked out of `DetectorHandle` (a placeholder takes
/// its slot) and moved into the task; `poll_inference_task` restores it.
pub fn schedule_burn_inference(
    mode: Res<SimRunMode>,
    time: Res<Time>,
    mut jobs: ResMut<AsyncInferenceState>,
    mut buffer: ResMut<PrimaryCameraFrameBuffer>,
    handle: Option<ResMut<DetectorHandle>>,
    target: Res<PrimaryCaptureTarget>,
    mut readback: ResMut<PrimaryCaptureReadback>,
) {
    if !matches!(*mode, SimRunMode::Inference) {
        return;
    }
    let Some(mut handle) = handle else {
        return;
    };

    jobs.debounce.tick(time.delta());
    if jobs.pending.is_some() || !jobs.debounce.is_finished() {
        return;
    }
    let Some(frame) = buffer.latest.take() else {
        return;
    };

    let rgba = readback.latest.take();
    let start = std::time::Instant::now();
    let f = Frame {
        id: frame.id,
        timestamp: frame.captured_at,
        rgba,
        size: (target.size.x, target.size.y),
        path: None,
    };
    // Check the detector out of its resource slot; the placeholder holds the
    // slot until poll_inference_task puts the real detector back.
    let mut detector = std::mem::replace(&mut handle.detector, Box::new(DefaultTestDetector));
    let kind = handle.kind;
    let size = (target.size.x, target.size.y);
    let task = AsyncComputeTaskPool::get().spawn(async move {
        let result = detector.detect(&f);
        let infer_ms = start.elapsed().as_secs_f32() * 1000.0;
        (detector, kind, result, infer_ms, size)
    });
    jobs.pending = Some(task);
}

/// Threshold-adjustment hotkeys, active in `Inference` mode: `-`/`=` step the
/// objectness threshold by 0.05, `[`/`]` step the IoU threshold by 0.05, and
/// `0` swaps in the heuristic fallback detector.
pub fn threshold_hotkeys(
    mode: Res<SimRunMode>,
    keys: Res<ButtonInput<KeyCode>>,
    thresh: Option<ResMut<InferenceThresholdsResource>>,
    handle: Option<ResMut<DetectorHandle>>,
    burn_loaded: Option<ResMut<ModelLoadedFlag>>,
) {
    if !matches!(*mode, SimRunMode::Inference) {
        return;
    }
    let (Some(mut thresh), Some(mut handle)) = (thresh, handle) else {
        return;
    };
    let Some(mut burn_loaded) = burn_loaded else {
        return;
    };

    let mut changed = false;
    if keys.just_pressed(KeyCode::Minus) {
        thresh.objectness_threshold = (thresh.objectness_threshold - 0.05).clamp(0.0, 1.0);
        changed = true;
    }
    if keys.just_pressed(KeyCode::Equal) {
        thresh.objectness_threshold = (thresh.objectness_threshold + 0.05).clamp(0.0, 1.0);
        changed = true;
    }
    if keys.just_pressed(KeyCode::BracketLeft) {
        thresh.iou_threshold = (thresh.iou_threshold - 0.05).clamp(0.1, 0.95);
        changed = true;
    }
    if keys.just_pressed(KeyCode::BracketRight) {
        thresh.iou_threshold = (thresh.iou_threshold + 0.05).clamp(0.1, 0.95);
        changed = true;
    }

    if keys.just_pressed(KeyCode::Digit0) {
        handle.detector = Box::new(DefaultTestDetector);
        handle.kind = DetectorKind::Heuristic;
        burn_loaded.model_loaded = false;
        changed = true;
    }

    if changed {
        info!(
            "Updated inference thresholds: obj {:.2}, iou {:.2}",
            thresh.objectness_threshold, thresh.iou_threshold
        );
    }
}

/// Bevy plugin managing runtime inference coordination.
///
/// Handles async inference scheduling, model state tracking, detection overlays,
/// and threshold adjustment hotkeys. This is the runtime/visualization layer;
/// for core inference logic, see the `inference` crate.
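///
/// # Example
/// This plugin does not initialize `DetectorHandle`; without it, scheduling
/// early-returns. A hypothetical host-app setup sketch (`MyDetector` stands in
/// for any `vision_core::interfaces::Detector` implementation):
/// ```ignore
/// fn install_detector(mut commands: Commands) {
///     commands.insert_resource(DetectorHandle {
///         detector: Box::new(MyDetector::default()),
///         kind: DetectorKind::Burn,
///     });
/// }
/// ```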
pub struct InferenceRuntimePlugin;

impl Plugin for InferenceRuntimePlugin {
    fn build(&self, app: &mut App) {
        app.init_resource::<AsyncInferenceState>()
            .init_resource::<ModelLoadedFlag>()
            .init_resource::<DetectionOverlayState>()
            .add_systems(
                Update,
                (
                    schedule_burn_inference,
                    poll_inference_task,
                    threshold_hotkeys,
                )
                    .in_set(ModeSet::Inference),
            );
    }
}

// Overlay helpers (draw run overlays) -----------------------------------------

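/// Draws a rectangle outline onto an RGBA image; a thin wrapper over
/// `vision_core::overlay::draw_rect` for recorder output.
///
/// # Example
/// A sketch using the `image` crate; the bbox values are illustrative:
/// ```ignore
/// let mut img = image::RgbaImage::new(1280, 720);
/// recorder_draw_rect(&mut img, [100, 100, 300, 240], image::Rgba([255, 0, 0, 255]), 2);
/// ```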
pub fn recorder_draw_rect(
    img: &mut RgbaImage,
    bbox_px: [u32; 4],
    color: image::Rgba<u8>,
    thickness: u32,
) {
    draw_rect(img, bbox_px, color, thickness);
}

/// Commonly used runtime types and plugins.
pub mod prelude {
    pub use super::{
        AsyncInferenceState, CapturePlugin, DetectionOverlayState, DetectorHandle, DetectorKind,
        InferenceRuntimePlugin, InferenceThresholdsResource, ModelLoadedFlag, PrimaryCameraFrame,
        PrimaryCameraFrameBuffer, PrimaryCameraState, RuntimeDetectionResult,
    };
}

/// Polls the pending inference task. On completion, restores the detector to
/// `DetectorHandle`, updates overlay state, and records the latest result.
pub fn poll_inference_task(
    mut jobs: ResMut<AsyncInferenceState>,
    mut overlay: ResMut<DetectionOverlayState>,
    handle: Option<ResMut<DetectorHandle>>,
    mut burn_detector: ResMut<ModelLoadedFlag>,
) {
    let Some(mut task) = jobs.pending.take() else {
        return;
    };
    if let Some((detector, kind, result, infer_ms, size)) = block_on(poll_once(&mut task)) {
        // Move the detector back into its resource slot.
        if let Some(mut handle) = handle {
            handle.detector = detector;
            handle.kind = kind;
        }
        burn_detector.model_loaded = matches!(kind, DetectorKind::Burn);
        if matches!(kind, DetectorKind::Heuristic) {
            overlay.fallback = Some("Heuristic detector active (Burn unavailable)".into());
        } else {
            overlay.fallback = None;
        }
        overlay.inference_ms = Some(infer_ms);
        overlay.boxes = result.boxes.clone();
        overlay.scores = result.scores.clone();
        overlay.size = size;
        jobs.last_result = Some(RuntimeDetectionResult {
            frame_id: result.frame_id,
            positive: result.positive,
            confidence: result.confidence,
            boxes: result.boxes,
            scores: result.scores,
        });
    } else {
        // Task not finished; put it back.
        jobs.pending = Some(task);
    }
}