Skip to main content

bevy_sensor/
batch.rs

1//! Batch rendering API for multiple viewpoints and objects.
2//!
3//! Today this module is a queue-oriented wrapper around sequential `render_to_buffer()`
4//! calls. It does not yet keep a persistent Bevy app alive across renders; that follow-up
5//! remains tracked work. The API is still useful for consumers that want ordered request
6//! management and structured batch outputs without promising reuse semantics that do not
7//! exist yet.
8//!
9//! # Example
10//!
11//! ```ignore
12//! use bevy_sensor::{
13//!     create_batch_renderer, queue_render_request, render_next_in_batch,
14//!     batch::BatchRenderRequest, BatchRenderConfig, RenderConfig, ObjectRotation,
15//!     TargetingPolicy, Vec3,
16//! };
17//! use std::path::PathBuf;
18//!
19//! // Create a batch helper
20//! let config = BatchRenderConfig::default();
21//! let mut renderer = create_batch_renderer(&config)?;
22//!
23//! // Queue multiple renders
24//! for rotation in rotations {
//!     for viewpoint in viewpoints.iter().copied() {
26//!         queue_render_request(&mut renderer, BatchRenderRequest {
27//!             object_dir: "/tmp/ycb/003_cracker_box".into(),
28//!             viewpoint,
29//!             object_rotation: rotation.clone(),
30//!             render_config: RenderConfig::tbp_default(),
31//!             target_point: Vec3::ZERO,
32//!             targeting_policy: TargetingPolicy::Origin,
33//!         })?;
34//!     }
35//! }
36//!
37//! // Execute and collect results
38//! let mut results = Vec::new();
39//! loop {
40//!     match render_next_in_batch(&mut renderer, 500)? {
41//!         Some(output) => results.push(output),
42//!         None => break,
43//!     }
44//! }
45//! ```
46
47use crate::{
48    semantic_3d_from_depth, CameraIntrinsics, ObjectRotation, RenderConfig, RenderHealth,
49    RenderOutput, TargetingPolicy,
50};
51use bevy::prelude::{Transform, Vec3};
52use std::collections::VecDeque;
53use std::path::PathBuf;
54
/// Configuration for batch rendering.
///
/// Within this module only `max_batch_size` is enforced (by
/// [`BatchRenderer::queue_request`]). The remaining fields are carried along
/// unchanged. NOTE(review): they are not read anywhere in this file; presumably
/// the render driver consumes them — confirm.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BatchRenderConfig {
    /// Maximum number of requests that may be pending at once.
    ///
    /// Once the queue holds this many requests, [`BatchRenderer::queue_request`]
    /// rejects further ones with [`BatchRenderError::QueueFull`]. Nothing is
    /// evicted or cleaned up automatically. A value of `0` rejects every request.
    pub max_batch_size: usize,
    /// Timeout in milliseconds per individual render.
    pub frame_timeout_ms: u32,
    /// Enable depth buffer readback.
    pub enable_depth_readback: bool,
    /// Enable asset caching for repeated objects.
    pub enable_asset_caching: bool,
    /// Number of renders before triggering resource cleanup.
    pub resource_cleanup_interval: u32,
}

impl Default for BatchRenderConfig {
    /// Stock settings:
    /// - up to 256 pending requests
    /// - 500 ms per render
    /// - depth readback and asset caching enabled
    /// - resource cleanup every 32 renders
    fn default() -> Self {
        Self {
            max_batch_size: 256,
            frame_timeout_ms: 500,
            enable_depth_readback: true,
            enable_asset_caching: true,
            resource_cleanup_interval: 32,
        }
    }
}
81
/// A single render request in a batch.
///
/// Requests are appended to the queue by `BatchRenderer::queue_request`, and the
/// request is stored in `BatchRenderOutput::request` when an output is built
/// with `BatchRenderOutput::from_render_output`.
#[derive(Clone, Debug)]
pub struct BatchRenderRequest {
    /// Path to YCB object directory (e.g., "/tmp/ycb/003_cracker_box")
    pub object_dir: PathBuf,
    /// Camera transform (position and orientation)
    pub viewpoint: Transform,
    /// Object rotation to apply
    pub object_rotation: ObjectRotation,
    /// Render configuration (resolution, lighting, etc.)
    ///
    /// Its `far_plane` is also what `BatchRenderOutput` uses for health
    /// diagnostics and for `semantic_3d`.
    pub render_config: RenderConfig,
    /// Point the camera was intended to target for this render.
    ///
    /// `from_render_output` copies this value into the output, overriding the
    /// target point reported by the underlying render.
    pub target_point: Vec3,
    /// Policy used to derive `target_point`.
    pub targeting_policy: TargetingPolicy,
}
98
/// Outcome of one render within a batch.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RenderStatus {
    /// Both the RGBA image and the depth buffer were produced.
    Success,
    /// The RGBA image was produced, but extracting depth failed.
    PartialFailure,
    /// The render produced nothing usable.
    Failed,
}
109
/// Output from a single render in a batch.
///
/// Buffer lengths are not validated against `width`/`height` when the struct is
/// built. The `to_rgb_image`/`to_depth_image` helpers fill any missing pixels
/// with zeros instead of panicking.
#[derive(Clone, Debug)]
pub struct BatchRenderOutput {
    /// Original request for this render
    pub request: BatchRenderRequest,
    /// RGBA pixel data (width * height * 4 bytes, row-major)
    pub rgba: Vec<u8>,
    /// Depth data in meters (width * height f64s)
    pub depth: Vec<f64>,
    /// Image width in pixels
    pub width: u32,
    /// Image height in pixels
    pub height: u32,
    /// Camera intrinsics used
    pub intrinsics: CameraIntrinsics,
    /// Camera transform used for world-space depth unprojection.
    pub camera_transform: Transform,
    /// Point the camera was intended to target for this render.
    ///
    /// When built via `from_render_output`, this duplicates `request.target_point`.
    pub target_point: Vec3,
    /// Policy used to derive `target_point`.
    ///
    /// When built via `from_render_output`, this duplicates `request.targeting_policy`.
    pub targeting_policy: TargetingPolicy,
    /// Cheap diagnostics derived from the rendered depth buffer
    ///
    /// `from_render_output` computes these using `request.render_config.far_plane`.
    pub health: RenderHealth,
    /// Status of this render
    ///
    /// `from_render_output` always sets this to `RenderStatus::Success`.
    pub status: RenderStatus,
    /// Error message if status is Failed or PartialFailure
    pub error_message: Option<String>,
}
138
139impl BatchRenderOutput {
140    /// Convert to neocortx-compatible RGB format: Vec<Vec<[u8; 3]>>
141    pub fn to_rgb_image(&self) -> Vec<Vec<[u8; 3]>> {
142        let mut image = Vec::with_capacity(self.height as usize);
143        for y in 0..self.height {
144            let mut row = Vec::with_capacity(self.width as usize);
145            for x in 0..self.width {
146                let idx = ((y * self.width + x) * 4) as usize;
147                if idx + 2 < self.rgba.len() {
148                    row.push([self.rgba[idx], self.rgba[idx + 1], self.rgba[idx + 2]]);
149                } else {
150                    row.push([0, 0, 0]);
151                }
152            }
153            image.push(row);
154        }
155        image
156    }
157
158    /// Convert depth to neocortx-compatible format: Vec<Vec<f64>>
159    pub fn to_depth_image(&self) -> Vec<Vec<f64>> {
160        let mut image = Vec::with_capacity(self.height as usize);
161        for y in 0..self.height {
162            let mut row = Vec::with_capacity(self.width as usize);
163            for x in 0..self.width {
164                let idx = (y * self.width + x) as usize;
165                if idx < self.depth.len() {
166                    row.push(self.depth[idx]);
167                } else {
168                    row.push(0.0);
169                }
170            }
171            image.push(row);
172        }
173        image
174    }
175
176    /// Build TBP-style `semantic_3d` rows using this request's far plane.
177    ///
178    /// The returned vector is row-major with one `[x, y, z, semantic_id]` row
179    /// per pixel. Foreground pixels are unprojected into world space and use
180    /// `object_semantic_id`; background/far pixels are `[0, 0, 0, 0]`.
181    pub fn semantic_3d(&self, object_semantic_id: u32) -> Vec<[f64; 4]> {
182        self.semantic_3d_with_far_plane(
183            object_semantic_id,
184            self.request.render_config.far_plane as f64,
185        )
186    }
187
188    /// Build TBP-style `semantic_3d` rows using a caller-provided far plane.
189    pub fn semantic_3d_with_far_plane(
190        &self,
191        object_semantic_id: u32,
192        far_plane: f64,
193    ) -> Vec<[f64; 4]> {
194        semantic_3d_from_depth(
195            &self.depth,
196            self.width,
197            self.height,
198            &self.intrinsics,
199            self.camera_transform,
200            object_semantic_id,
201            far_plane,
202        )
203    }
204
205    /// Convert from RenderOutput, carrying request-level target metadata.
206    pub fn from_render_output(request: BatchRenderRequest, output: RenderOutput) -> Self {
207        let health = output.health_with_far_plane(request.render_config.far_plane as f64);
208        let camera_transform = output.camera_transform;
209        let target_point = request.target_point;
210        let targeting_policy = request.targeting_policy.clone();
211        Self {
212            request,
213            rgba: output.rgba,
214            depth: output.depth,
215            width: output.width,
216            height: output.height,
217            intrinsics: output.intrinsics,
218            camera_transform,
219            target_point,
220            targeting_policy,
221            health,
222            status: RenderStatus::Success,
223            error_message: None,
224        }
225    }
226}
227
/// Errors reported by the batch rendering API.
#[derive(Clone, Debug)]
pub enum BatchRenderError {
    /// Part of the batch rendered successfully and part of it failed.
    PartialFailure { successful: usize, failed: usize },
    /// Every render in the batch failed; carries a description.
    TotalFailure(String),
    /// The supplied configuration was rejected; carries a description.
    InvalidConfig(String),
    /// The pending queue already holds `max_batch_size` requests.
    QueueFull,
    /// No renders were queued.
    EmptyQueue,
    /// The wgpu device was lost while a render was in progress.
    ///
    /// The `RenderSession::render()` call that hit this error produced no
    /// output. Outputs returned by earlier calls are still valid. To recover,
    /// drop the session and build a fresh one.
    ///
    /// `reason` holds the string form of `wgpu::DeviceLostReason`. This lets
    /// callers tell recoverable losses apart from adapter eviction without
    /// depending on wgpu directly. Phase 1 ships only this string form; a typed
    /// variant may follow once the Bevy re-export surface is clearer.
    DeviceLost { reason: String, message: String },
}

impl std::fmt::Display for BatchRenderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::QueueFull => f.write_str("Batch queue is full"),
            Self::EmptyQueue => f.write_str("No renders queued"),
            Self::TotalFailure(msg) => write!(f, "Batch render total failure: {}", msg),
            Self::InvalidConfig(msg) => write!(f, "Invalid batch config: {}", msg),
            Self::PartialFailure { successful, failed } => write!(
                f,
                "Batch render partial failure: {} succeeded, {} failed",
                successful, failed
            ),
            Self::DeviceLost { reason, message } => {
                write!(f, "wgpu device lost ({}): {}", reason, message)
            }
        }
    }
}

/// Relies on the default `source()` (`None`): no variant wraps another error.
impl std::error::Error for BatchRenderError {}
274
/// State machine for batch rendering lifecycle.
///
/// NOTE(review): within this file only `Idle` is ever assigned (by
/// `BatchRenderer::new`). Transitions through the other states presumably
/// happen in the render driver elsewhere in the crate — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BatchState {
    /// Idle, waiting for requests to queue
    Idle,
    /// Loading object assets (mesh, texture)
    LoadingAssets,
    /// Rendering frame to GPU buffer
    RenderingFrame,
    /// Extracting RGBA and depth from GPU
    ExtractingResults,
    /// Cleaning up resources
    Cleanup,
    /// Shutting down
    Shutdown,
}
291
/// Manages queued render requests and completed outputs for batch-style workflows.
///
/// This type only holds state. Executing the queued renders is presumably done
/// by the crate-level batch functions shown in the module example
/// (e.g. `render_next_in_batch`) — confirm.
///
/// NOTE(review): `current_output`, `frame_count`, `state` and
/// `renders_processed` are initialised in `new` but not updated in this file.
pub struct BatchRenderer {
    /// Queued render requests; `queue_request` appends at the back
    pub pending_requests: VecDeque<BatchRenderRequest>,
    /// Completed results; `take_completed` moves them out
    pub completed_results: Vec<BatchRenderOutput>,
    /// Current request being processed; `is_finished` treats `Some` as in-flight work
    pub current_request: Option<BatchRenderRequest>,
    /// Current render output being built
    pub current_output: Option<BatchRenderOutput>,
    /// Frame counter for timeout management
    pub frame_count: u32,
    /// Current state
    pub state: BatchState,
    /// Configuration; `max_batch_size` bounds `pending_requests`
    pub config: BatchRenderConfig,
    /// Total renders processed
    pub renders_processed: usize,
}
311
312impl BatchRenderer {
313    /// Create a new batch renderer with default configuration.
314    pub fn new(config: BatchRenderConfig) -> Self {
315        Self {
316            pending_requests: VecDeque::new(),
317            completed_results: Vec::new(),
318            current_request: None,
319            current_output: None,
320            frame_count: 0,
321            state: BatchState::Idle,
322            config,
323            renders_processed: 0,
324        }
325    }
326
327    /// Queue a render request for batch processing.
328    pub fn queue_request(&mut self, request: BatchRenderRequest) -> Result<(), BatchRenderError> {
329        if self.pending_requests.len() >= self.config.max_batch_size {
330            return Err(BatchRenderError::QueueFull);
331        }
332        self.pending_requests.push_back(request);
333        Ok(())
334    }
335
336    /// Get the number of pending requests.
337    pub fn pending_count(&self) -> usize {
338        self.pending_requests.len()
339    }
340
341    /// Get the number of completed results.
342    pub fn completed_count(&self) -> usize {
343        self.completed_results.len()
344    }
345
346    /// Get all completed results and clear the internal list.
347    pub fn take_completed(&mut self) -> Vec<BatchRenderOutput> {
348        std::mem::take(&mut self.completed_results)
349    }
350
351    /// Check if all work is done (no pending requests and not currently rendering).
352    pub fn is_finished(&self) -> bool {
353        self.pending_requests.is_empty() && self.current_request.is_none()
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal origin-targeted request shared by the tests below.
    fn sample_request() -> BatchRenderRequest {
        BatchRenderRequest {
            object_dir: "/tmp/test".into(),
            viewpoint: Transform::default(),
            object_rotation: ObjectRotation::identity(),
            render_config: RenderConfig::tbp_default(),
            target_point: Vec3::ZERO,
            targeting_policy: TargetingPolicy::Origin,
        }
    }

    /// Placeholder health diagnostics for a fully-foreground 2x2 image.
    ///
    /// Used by tests that do not inspect `health`.
    fn sample_health() -> RenderHealth {
        RenderHealth {
            center_pixel: Some([1, 1]),
            center_depth: Some(1.0),
            center_foreground: true,
            foreground_pixel_count: 4,
            foreground_coverage: 1.0,
            center_5x5_foreground_count: 4,
            nearest_foreground_pixel: Some([1, 1]),
            nearest_foreground_depth: Some(1.0),
            nearest_foreground_distance_px: Some(0.0),
        }
    }

    /// Build a `BatchRenderOutput` directly from raw buffers.
    ///
    /// This lets tests choose `rgba`, `depth` and the image size independently.
    fn sample_output(
        rgba: Vec<u8>,
        depth: Vec<f64>,
        width: u32,
        height: u32,
    ) -> BatchRenderOutput {
        BatchRenderOutput {
            request: sample_request(),
            rgba,
            depth,
            width,
            height,
            intrinsics: RenderConfig::tbp_default().intrinsics(),
            camera_transform: Transform::default(),
            target_point: Vec3::ZERO,
            targeting_policy: TargetingPolicy::Origin,
            health: sample_health(),
            status: RenderStatus::Success,
            error_message: None,
        }
    }

    #[test]
    fn test_batch_config_defaults() {
        let config = BatchRenderConfig::default();
        assert_eq!(config.max_batch_size, 256);
        assert_eq!(config.frame_timeout_ms, 500);
        assert!(config.enable_depth_readback);
        assert!(config.enable_asset_caching);
        assert_eq!(config.resource_cleanup_interval, 32);
    }

    #[test]
    fn test_batch_renderer_creation() {
        let config = BatchRenderConfig::default();
        let renderer = BatchRenderer::new(config);
        assert_eq!(renderer.state, BatchState::Idle);
        assert_eq!(renderer.pending_count(), 0);
        assert_eq!(renderer.completed_count(), 0);
        assert!(renderer.is_finished());
    }

    #[test]
    fn test_queue_request() {
        let mut renderer = BatchRenderer::new(BatchRenderConfig::default());
        assert!(renderer.queue_request(sample_request()).is_ok());
        assert_eq!(renderer.pending_count(), 1);
        assert!(!renderer.is_finished());
    }

    #[test]
    fn test_queue_full() {
        let config = BatchRenderConfig {
            max_batch_size: 1,
            ..BatchRenderConfig::default()
        };
        let mut renderer = BatchRenderer::new(config);

        assert!(renderer.queue_request(sample_request()).is_ok());
        assert!(matches!(
            renderer.queue_request(sample_request()),
            Err(BatchRenderError::QueueFull)
        ));
        // The rejected request must not have been enqueued.
        assert_eq!(renderer.pending_count(), 1);
    }

    #[test]
    fn test_take_completed_drains_results() {
        let mut renderer = BatchRenderer::new(BatchRenderConfig::default());
        renderer
            .completed_results
            .push(sample_output(vec![0u8; 4], vec![1.0], 1, 1));
        assert_eq!(renderer.completed_count(), 1);

        let taken = renderer.take_completed();
        assert_eq!(taken.len(), 1);
        assert_eq!(renderer.completed_count(), 0);
    }

    #[test]
    fn test_batch_render_output_rgb_conversion() {
        // 2x2 image: pixel (0,0) is opaque red, everything else is zeroed.
        let mut rgba = vec![0u8; 2 * 2 * 4];
        rgba[..4].copy_from_slice(&[255, 0, 0, 255]);
        let output = sample_output(rgba, vec![1.0; 4], 2, 2);

        let rgb = output.to_rgb_image();
        assert_eq!(rgb.len(), 2); // 2 rows
        assert_eq!(rgb[0].len(), 2); // 2 cols
        assert_eq!(rgb[0][0], [255, 0, 0]); // Red
    }

    #[test]
    fn test_conversions_pad_short_buffers() {
        // Buffers only hold one pixel, but the image claims to be 2x1.
        let output = sample_output(vec![9u8; 4], vec![2.0], 2, 1);
        assert_eq!(output.to_rgb_image(), vec![vec![[9, 9, 9], [0, 0, 0]]]);
        assert_eq!(output.to_depth_image(), vec![vec![2.0, 0.0]]);
    }

    #[test]
    fn test_batch_render_output_carries_request_target_metadata() {
        let target_point = Vec3::new(0.25, -0.125, 0.5);
        let camera_transform = Transform::from_xyz(0.0, 0.0, 2.0).looking_at(Vec3::ZERO, Vec3::Y);
        let request = BatchRenderRequest {
            viewpoint: camera_transform,
            target_point,
            targeting_policy: TargetingPolicy::MeshCenter,
            ..sample_request()
        };
        // The render reports a different target; the request's metadata must win.
        let output = RenderOutput {
            rgba: vec![0u8; 4],
            depth: vec![1.0],
            width: 1,
            height: 1,
            intrinsics: RenderConfig::tbp_default().intrinsics(),
            camera_transform,
            object_rotation: ObjectRotation::identity(),
            target_point: Vec3::ZERO,
            targeting_policy: TargetingPolicy::Origin,
        };

        let batch_output = BatchRenderOutput::from_render_output(request, output);

        assert_eq!(batch_output.target_point, target_point);
        assert_eq!(batch_output.targeting_policy, TargetingPolicy::MeshCenter);
        assert_eq!(batch_output.camera_transform, camera_transform);
        assert_eq!(batch_output.request.target_point, target_point);
        assert_eq!(
            batch_output.request.targeting_policy,
            TargetingPolicy::MeshCenter
        );
    }

    #[test]
    fn test_batch_render_output_semantic_3d_uses_camera_transform() {
        let camera_transform = Transform::from_xyz(0.0, 0.0, 2.0).looking_at(Vec3::ZERO, Vec3::Y);
        let request = BatchRenderRequest {
            viewpoint: camera_transform,
            ..sample_request()
        };
        let output = RenderOutput {
            rgba: vec![0u8; 4],
            depth: vec![1.5],
            width: 1,
            height: 1,
            intrinsics: CameraIntrinsics {
                focal_length: [100.0, 100.0],
                principal_point: [0.0, 0.0],
                image_size: [1, 1],
            },
            camera_transform,
            object_rotation: ObjectRotation::identity(),
            target_point: Vec3::ZERO,
            targeting_policy: TargetingPolicy::Origin,
        };

        let batch_output = BatchRenderOutput::from_render_output(request, output);
        let rows = batch_output.semantic_3d(7);

        assert_eq!(rows.len(), 1);
        assert!((rows[0][0]).abs() < 1e-6);
        assert!((rows[0][1]).abs() < 1e-6);
        assert!((rows[0][2] - 0.5).abs() < 1e-6);
        assert_eq!(rows[0][3], 7.0);
    }
}