vision_runtime/lib.rs

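//! Vision runtime for the simulator: `CapturePlugin` renders the front camera
//! into an offscreen RGBA texture and reads it back to the CPU, while
//! `InferencePlugin` debounces captured frames, runs the active detector on the
//! async compute pool, and publishes results to `DetectionOverlayState`.
//!
//! A minimal wiring sketch (assumes `SimRunMode` is inserted as a resource and
//! the `ModeSet` system sets are configured by `sim_core`'s own setup):
//!
//! ```ignore
//! use bevy::prelude::*;
//! use vision_runtime::prelude::*;
//!
//! App::new()
//!     .add_plugins(DefaultPlugins)
//!     .insert_resource(sim_core::SimRunMode::Inference)
//!     .add_plugins((CapturePlugin, InferencePlugin))
//!     .run();
//! ```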
use bevy::asset::RenderAssetUsages;
use bevy::prelude::*;
use bevy::render::gpu_readback::{Readback, ReadbackComplete};
use bevy::render::render_resource::{Extent3d, TextureDimension, TextureFormat, TextureUsages};
use bevy::tasks::{AsyncComputeTaskPool, Task};
use bevy_camera::{ImageRenderTarget, RenderTarget};
use futures_lite::future::{block_on, poll_once};
use image::RgbaImage;
use sim_core::{ModeSet, SimRunMode};
use vision_core::capture::{
    FrontCamera, FrontCaptureCamera, FrontCaptureReadback, FrontCaptureTarget,
};
use vision_core::interfaces::{self, Frame};
use vision_core::overlay::draw_rect;

/// Payload handed back by an inference task: the detector (returned to the
/// `DetectorHandle`), its kind, the detection result, the inference time in
/// milliseconds, and the frame size in pixels.
type InferenceJobResult = (
    Box<dyn interfaces::Detector + Send + Sync>,
    DetectorKind,
    interfaces::DetectionResult,
    f32,
    (u32, u32),
);

/// Metadata for a single captured front-camera frame.
#[derive(Clone)]
pub struct FrontCameraFrame {
    pub id: u64,
    pub transform: GlobalTransform,
    pub captured_at: f64,
}

/// Tracks whether the capture camera is active, its last pose, and a running
/// frame counter.
#[derive(Resource, Default)]
pub struct FrontCameraState {
    pub active: bool,
    pub last_transform: Option<GlobalTransform>,
    pub frame_counter: u64,
}

/// Holds the most recently captured frame, consumed by the inference scheduler.
#[derive(Resource, Default)]
pub struct FrontCameraFrameBuffer {
    pub latest: Option<FrontCameraFrame>,
}

/// Whether a Burn model is currently loaded and active.
#[derive(Resource, Default)]
pub struct BurnDetector {
    pub model_loaded: bool,
}

/// Latest detection output published for the on-screen overlay.
#[derive(Resource, Default, Clone)]
pub struct DetectionOverlayState {
    pub boxes: Vec<[f32; 4]>,
    pub scores: Vec<f32>,
    pub size: (u32, u32),
    pub fallback: Option<String>,
    pub inference_ms: Option<f32>,
}

/// Which detector backend is active.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Resource)]
pub enum DetectorKind {
    Burn,
    Heuristic,
}

/// Objectness and IoU thresholds used by the detector; adjustable at runtime
/// via `threshold_hotkeys`.
#[derive(Resource, Debug, Clone, Copy)]
pub struct InferenceThresholds {
    pub obj_thresh: f32,
    pub iou_thresh: f32,
}

/// Result of one inference pass over a captured frame.
#[derive(Clone)]
pub struct BurnDetectionResult {
    pub frame_id: u64,
    pub positive: bool,
    pub confidence: f32,
    pub boxes: Vec<[f32; 4]>,
    pub scores: Vec<f32>,
}

/// In-flight inference task, the last completed result, and the debounce timer
/// that rate-limits inference runs.
#[derive(Resource)]
pub struct BurnInferenceState {
    pub pending: Option<Task<InferenceJobResult>>,
    pub last_result: Option<BurnDetectionResult>,
    pub debounce: Timer,
}

impl Default for BurnInferenceState {
    fn default() -> Self {
        Self {
            pending: None,
            last_result: None,
            debounce: Timer::from_seconds(0.18, TimerMode::Repeating),
        }
    }
}

/// The currently active detector implementation and its kind.
#[derive(Resource)]
pub struct DetectorHandle {
    pub detector: Box<dyn interfaces::Detector + Send + Sync>,
    pub kind: DetectorKind,
}
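
// `InferencePlugin` does not initialize `DetectorHandle`; a Burn-backed
// detector is expected to be inserted by the embedding application. A rough
// sketch (the `my_burn_detector` value is hypothetical, standing in for any
// `interfaces::Detector` implementation backed by a loaded Burn model):
//
//     app.insert_resource(DetectorHandle {
//         detector: Box::new(my_burn_detector),
//         kind: DetectorKind::Burn,
//     });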

/// Trivial fallback detector used when no Burn model is available: always
/// reports a positive detection with fixed confidence and no boxes.
struct HeuristicDetector;

impl interfaces::Detector for HeuristicDetector {
    fn detect(&mut self, frame: &Frame) -> interfaces::DetectionResult {
        interfaces::DetectionResult {
            frame_id: frame.id,
            positive: true,
            confidence: 0.8,
            boxes: Vec::new(),
            scores: Vec::new(),
        }
    }
}

// Capture setup/readback -----------------------------------------------------

/// Spawns the offscreen front-capture camera and its render-target image.
pub fn setup_front_capture(
    mut commands: Commands,
    mut images: ResMut<Assets<Image>>,
    mut state: ResMut<FrontCameraState>,
    mut target: ResMut<FrontCaptureTarget>,
) {
    // Only set up once.
    if target.size != UVec2::ZERO {
        return;
    }

    let size = UVec2::new(1280, 720);
    let mut image = Image::new_fill(
        Extent3d {
            width: size.x,
            height: size.y,
            ..default()
        },
        TextureDimension::D2,
        &[0, 0, 0, 255],
        TextureFormat::Rgba8UnormSrgb,
        RenderAssetUsages::default(),
    );
    // The texture must be usable as a render attachment and as a copy source
    // for GPU readback.
    image.texture_descriptor.usage =
        TextureUsages::COPY_SRC | TextureUsages::TEXTURE_BINDING | TextureUsages::RENDER_ATTACHMENT;
    let handle = images.add(image);

    let cam_entity = commands
        .spawn((
            Camera3d::default(),
            Camera {
                order: -10,
                is_active: true,
                target: RenderTarget::Image(ImageRenderTarget::from(handle.clone())),
                ..default()
            },
            Projection::from(PerspectiveProjection {
                fov: 20.0f32.to_radians(),
                ..default()
            }),
            Transform::from_translation(Vec3::ZERO),
            GlobalTransform::default(),
            Visibility::default(),
            InheritedVisibility::default(),
            ViewVisibility::default(),
            FrontCamera,
            FrontCaptureCamera,
            Name::new("FrontCaptureCamera"),
        ))
        .id();

    target.size = size;
    target.handle = handle;
    target.entity = cam_entity;
    state.active = true;
}

/// Mirrors the capture camera's pose into `FrontCameraState` and records a new
/// frame entry in `FrontCameraFrameBuffer` each update.
pub fn track_front_camera_state(
    target: Res<FrontCaptureTarget>,
    mut state: ResMut<FrontCameraState>,
    mut buffer: ResMut<FrontCameraFrameBuffer>,
    cameras: Query<&GlobalTransform, With<FrontCaptureCamera>>,
    time: Res<Time>,
) {
    let Ok(transform) = cameras.get(target.entity) else {
        return;
    };
    state.last_transform = Some(*transform);
    state.frame_counter = state.frame_counter.wrapping_add(1);
    buffer.latest = Some(FrontCameraFrame {
        id: state.frame_counter,
        transform: *transform,
        captured_at: time.elapsed_secs_f64(),
    });
}

/// Requests GPU readback of the capture texture while in datagen or inference
/// mode; re-inserting `Readback` on the same entity simply overwrites the
/// existing component.
pub fn capture_front_camera_frame(
    mode: Res<SimRunMode>,
    mut commands: Commands,
    target: Res<FrontCaptureTarget>,
) {
    if !matches!(*mode, SimRunMode::Datagen | SimRunMode::Inference) {
        return;
    }
    commands
        .entity(target.entity)
        .insert(Readback::texture(target.handle.clone()));
}

/// Observer for `ReadbackComplete`: stores the raw RGBA bytes if they came from
/// the capture camera and match the expected buffer size.
pub fn on_front_capture_readback(
    ev: On<ReadbackComplete>,
    target: Res<FrontCaptureTarget>,
    mut readback: ResMut<FrontCaptureReadback>,
) {
    let expected_len = (target.size.x * target.size.y * 4) as usize;
    let ev = ev.event();
    if ev.entity != target.entity {
        return;
    }
    if ev.data.len() == expected_len {
        readback.latest = Some(ev.data.clone());
    }
}

/// Registers the capture resources, the offscreen camera setup, and the
/// readback observer.
pub struct CapturePlugin;

impl Plugin for CapturePlugin {
    fn build(&self, app: &mut App) {
        app.insert_resource(FrontCaptureTarget {
            handle: Handle::default(),
            size: UVec2::ZERO,
            entity: Entity::PLACEHOLDER,
        })
        .init_resource::<FrontCaptureReadback>()
        .init_resource::<FrontCameraState>()
        .init_resource::<FrontCameraFrameBuffer>()
        .add_systems(Startup, setup_front_capture)
        .add_systems(Update, track_front_camera_state.in_set(ModeSet::Common))
        .add_systems(Update, capture_front_camera_frame.in_set(ModeSet::Common))
        .add_observer(on_front_capture_readback);
    }
}
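
// `FrontCaptureReadback.latest` holds tightly packed RGBA8 bytes at
// `FrontCaptureTarget.size` (the observer above only accepts buffers of exactly
// `width * height * 4`). A sketch of turning one buffer into an
// `image::RgbaImage` for saving or drawing overlays:
//
//     if let Some(bytes) = readback.latest.clone() {
//         let mut img = RgbaImage::from_raw(target.size.x, target.size.y, bytes)
//             .expect("readback length was validated against the target size");
//         // e.g. recorder_draw_rect(&mut img, ...) or img.save(...)
//     }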

// Inference ---------------------------------------------------------------

/// Debounced scheduler: grabs the latest frame and readback bytes, hands the
/// detector to an async compute task, and leaves a `HeuristicDetector`
/// placeholder in the handle until `poll_inference_task` returns the real one.
pub fn schedule_burn_inference(
    mode: Res<SimRunMode>,
    time: Res<Time>,
    mut jobs: ResMut<BurnInferenceState>,
    mut buffer: ResMut<FrontCameraFrameBuffer>,
    handle: Option<ResMut<DetectorHandle>>,
    target: Res<FrontCaptureTarget>,
    mut readback: ResMut<FrontCaptureReadback>,
) {
    if !matches!(*mode, SimRunMode::Inference) {
        return;
    }
    let Some(mut handle) = handle else {
        return;
    };

    jobs.debounce.tick(time.delta());
    if jobs.pending.is_some() || !jobs.debounce.is_finished() {
        return;
    }
    let Some(frame) = buffer.latest.take() else {
        return;
    };

    let rgba = readback.latest.take();
    let start = std::time::Instant::now();
    let f = Frame {
        id: frame.id,
        timestamp: frame.captured_at,
        rgba,
        size: (target.size.x, target.size.y),
        path: None,
    };
    // Move the detector into the task; the swapped-in heuristic is only a
    // placeholder until the task completes.
    let mut detector = std::mem::replace(&mut handle.detector, Box::new(HeuristicDetector));
    let kind = handle.kind;
    let size = (target.size.x, target.size.y);
    let task = AsyncComputeTaskPool::get().spawn(async move {
        let result = detector.detect(&f);
        let infer_ms = start.elapsed().as_secs_f32() * 1000.0;
        (detector, kind, result, infer_ms, size)
    });
    jobs.pending = Some(task);
}

/// Runtime tuning hotkeys (inference mode only): `-`/`=` step the objectness
/// threshold by 0.05, `[`/`]` step the IoU threshold by 0.05, and `0` swaps in
/// the heuristic fallback detector.
pub fn threshold_hotkeys(
    mode: Res<SimRunMode>,
    keys: Res<ButtonInput<KeyCode>>,
    thresh: Option<ResMut<InferenceThresholds>>,
    handle: Option<ResMut<DetectorHandle>>,
    burn_loaded: Option<ResMut<BurnDetector>>,
) {
    if !matches!(*mode, SimRunMode::Inference) {
        return;
    }
    let (Some(mut thresh), Some(mut handle)) = (thresh, handle) else {
        return;
    };
    let Some(mut burn_loaded) = burn_loaded else {
        return;
    };

    let mut changed = false;
    if keys.just_pressed(KeyCode::Minus) {
        thresh.obj_thresh = (thresh.obj_thresh - 0.05).clamp(0.0, 1.0);
        changed = true;
    }
    if keys.just_pressed(KeyCode::Equal) {
        thresh.obj_thresh = (thresh.obj_thresh + 0.05).clamp(0.0, 1.0);
        changed = true;
    }
    if keys.just_pressed(KeyCode::BracketLeft) {
        thresh.iou_thresh = (thresh.iou_thresh - 0.05).clamp(0.1, 0.95);
        changed = true;
    }
    if keys.just_pressed(KeyCode::BracketRight) {
        thresh.iou_thresh = (thresh.iou_thresh + 0.05).clamp(0.1, 0.95);
        changed = true;
    }

    if keys.just_pressed(KeyCode::Digit0) {
        handle.detector = Box::new(HeuristicDetector);
        handle.kind = DetectorKind::Heuristic;
        burn_loaded.model_loaded = false;
        changed = true;
    }

    if changed {
        info!(
            "Updated inference thresholds: obj {:.2}, iou {:.2}",
            thresh.obj_thresh, thresh.iou_thresh
        );
    }
}

/// Registers the inference scheduler, task polling, and threshold hotkeys.
pub struct InferencePlugin;

impl Plugin for InferencePlugin {
    fn build(&self, app: &mut App) {
        app.init_resource::<BurnInferenceState>()
            .init_resource::<BurnDetector>()
            .init_resource::<DetectionOverlayState>()
            .add_systems(
                Update,
                (
                    schedule_burn_inference,
                    poll_inference_task,
                    threshold_hotkeys,
                )
                    .in_set(ModeSet::Inference),
            );
    }
}
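
// `DetectionOverlayState` is the read-side contract for whatever draws the HUD.
// A sketch of a consumer system (the `draw_detections` name and drawing steps
// are illustrative, not part of this crate):
//
//     fn draw_detections(overlay: Res<DetectionOverlayState>) {
//         if let Some(msg) = &overlay.fallback {
//             // show the fallback banner
//         }
//         for (bbox, score) in overlay.boxes.iter().zip(&overlay.scores) {
//             // draw `bbox`, scaling from `overlay.size` to the viewport
//         }
//     }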

/// Polls the pending inference task: on completion, returns the detector to the
/// handle, updates `BurnDetector` and the overlay state, and records the
/// result; otherwise the task is put back for the next frame.
pub fn poll_inference_task(
    mut jobs: ResMut<BurnInferenceState>,
    mut overlay: ResMut<DetectionOverlayState>,
    handle: Option<ResMut<DetectorHandle>>,
    mut burn_detector: ResMut<BurnDetector>,
) {
    let Some(mut task) = jobs.pending.take() else {
        return;
    };
    if let Some((detector, kind, result, infer_ms, size)) = block_on(poll_once(&mut task)) {
        if let Some(mut handle) = handle {
            handle.detector = detector;
            handle.kind = kind;
        }
        burn_detector.model_loaded = matches!(kind, DetectorKind::Burn);
        if matches!(kind, DetectorKind::Heuristic) {
            overlay.fallback = Some("Heuristic detector active (Burn unavailable)".into());
        } else {
            overlay.fallback = None;
        }
        overlay.inference_ms = Some(infer_ms);
        overlay.boxes = result.boxes.clone();
        overlay.scores = result.scores.clone();
        overlay.size = size;
        jobs.last_result = Some(BurnDetectionResult {
            frame_id: result.frame_id,
            positive: result.positive,
            confidence: result.confidence,
            boxes: result.boxes,
            scores: result.scores,
        });
    } else {
        // Task not finished; put it back.
        jobs.pending = Some(task);
    }
}

// Overlay helpers (draw run overlays) -----------------------------------------

/// Thin wrapper over `vision_core::overlay::draw_rect` for recorder output.
pub fn recorder_draw_rect(
    img: &mut RgbaImage,
    bbox_px: [u32; 4],
    color: image::Rgba<u8>,
    thickness: u32,
) {
    draw_rect(img, bbox_px, color, thickness);
}

pub mod prelude {
    pub use super::{
        BurnDetectionResult, BurnDetector, BurnInferenceState, CapturePlugin,
        DetectionOverlayState, DetectorHandle, DetectorKind, FrontCameraFrame,
        FrontCameraFrameBuffer, FrontCameraState, InferencePlugin, InferenceThresholds,
    };
}