Skip to main content

bevy_sensor/
batch.rs

1//! Batch rendering API for multiple viewpoints and objects.
2//!
3//! Today this module is a queue-oriented wrapper around sequential `render_to_buffer()`
4//! calls. It does not yet keep a persistent Bevy app alive across renders; that follow-up
5//! remains tracked work. The API is still useful for consumers that want ordered request
6//! management and structured batch outputs without promising reuse semantics that do not
7//! exist yet.
8//!
9//! # Example
10//!
11//! ```ignore
12//! use bevy_sensor::{
13//!     create_batch_renderer, queue_render_request, render_next_in_batch,
14//!     batch::BatchRenderRequest, BatchRenderConfig, RenderConfig, ObjectRotation,
15//!     TargetingPolicy, Vec3,
16//! };
17//! use std::path::PathBuf;
18//!
19//! // Create a batch helper
20//! let config = BatchRenderConfig::default();
21//! let mut renderer = create_batch_renderer(&config)?;
22//!
23//! // Queue multiple renders
24//! for rotation in rotations {
25//!     for viewpoint in viewpoints {
26//!         queue_render_request(&mut renderer, BatchRenderRequest {
27//!             object_dir: "/tmp/ycb/003_cracker_box".into(),
28//!             viewpoint,
29//!             object_rotation: rotation.clone(),
30//!             object_translation: Vec3::ZERO,
31//!             object_scale: Vec3::ONE,
32//!             render_config: RenderConfig::tbp_default(),
33//!             target_point: Vec3::ZERO,
34//!             targeting_policy: TargetingPolicy::Origin,
35//!         })?;
36//!     }
37//! }
38//!
39//! // Execute and collect results
40//! let mut results = Vec::new();
41//! loop {
42//!     match render_next_in_batch(&mut renderer, 500)? {
43//!         Some(output) => results.push(output),
44//!         None => break,
45//!     }
46//! }
47//! ```
48
49use crate::{
50    semantic_3d_from_depth, CameraIntrinsics, ObjectRotation, RenderConfig, RenderHealth,
51    RenderOutput, TargetingPolicy,
52};
53use bevy::prelude::{Transform, Vec3};
54use std::collections::VecDeque;
55use std::path::PathBuf;
56
/// Configuration for batch rendering.
///
/// [`BatchRenderer`] itself only enforces `max_batch_size`; the remaining fields are
/// not read by the queue logic in this module.
#[derive(Clone, Debug)]
pub struct BatchRenderConfig {
    /// Maximum number of requests that may be pending at once.
    ///
    /// Once this many requests are queued, `BatchRenderer::queue_request` rejects further
    /// requests with `BatchRenderError::QueueFull`. A value of `0` rejects every request.
    pub max_batch_size: usize,
    /// Timeout in milliseconds per individual render
    pub frame_timeout_ms: u32,
    /// Enable depth buffer readback
    pub enable_depth_readback: bool,
    /// Enable asset caching for repeated objects
    pub enable_asset_caching: bool,
    /// Number of renders before triggering resource cleanup
    pub resource_cleanup_interval: u32,
}

impl Default for BatchRenderConfig {
    /// Defaults: 256 pending requests, 500 ms per frame, depth readback and asset caching
    /// enabled, cleanup every 32 renders.
    fn default() -> Self {
        Self {
            max_batch_size: 256,
            frame_timeout_ms: 500,
            enable_depth_readback: true,
            enable_asset_caching: true,
            resource_cleanup_interval: 32,
        }
    }
}
83
84/// A single render request in a batch.
85#[derive(Clone, Debug)]
86pub struct BatchRenderRequest {
87    /// Path to YCB object directory (e.g., "/tmp/ycb/003_cracker_box")
88    pub object_dir: PathBuf,
89    /// Camera transform (position and orientation)
90    pub viewpoint: Transform,
91    /// Object rotation to apply
92    pub object_rotation: ObjectRotation,
93    /// Object world translation to apply
94    pub object_translation: Vec3,
95    /// Object scale to apply
96    pub object_scale: Vec3,
97    /// Render configuration (resolution, lighting, etc.)
98    pub render_config: RenderConfig,
99    /// Point the camera was intended to target for this render.
100    pub target_point: Vec3,
101    /// Policy used to derive `target_point`.
102    pub targeting_policy: TargetingPolicy,
103}
104
105impl BatchRenderRequest {
106    /// Build a request with the current default object transform: origin translation and unit scale.
107    pub fn new(
108        object_dir: PathBuf,
109        viewpoint: Transform,
110        object_rotation: ObjectRotation,
111        render_config: RenderConfig,
112    ) -> Self {
113        Self {
114            object_dir,
115            viewpoint,
116            object_rotation,
117            object_translation: Vec3::ZERO,
118            object_scale: Vec3::ONE,
119            render_config,
120            target_point: Vec3::ZERO,
121            targeting_policy: TargetingPolicy::Origin,
122        }
123    }
124
125    /// Attach explicit object translation and scale.
126    pub fn with_object_transform(mut self, object_translation: Vec3, object_scale: Vec3) -> Self {
127        self.object_translation = object_translation;
128        self.object_scale = object_scale;
129        self
130    }
131
132    /// Attach camera-target metadata used to create the request viewpoint.
133    pub fn with_targeting(mut self, target_point: Vec3, targeting_policy: TargetingPolicy) -> Self {
134        self.target_point = target_point;
135        self.targeting_policy = targeting_policy;
136        self
137    }
138}
139
/// Outcome of one render within a batch.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RenderStatus {
    /// Both the RGBA image and the depth buffer were produced.
    Success,
    /// The RGBA image was produced, but depth extraction failed.
    PartialFailure,
    /// The render produced nothing usable.
    Failed,
}
150
/// Output from a single render in a batch.
///
/// Usually built with `BatchRenderOutput::from_render_output`. In that path,
/// `camera_transform`, `object_translation` and `object_scale` come from the render
/// output, while `target_point` and `targeting_policy` are copied from `request`.
#[derive(Clone, Debug)]
pub struct BatchRenderOutput {
    /// Original request for this render
    pub request: BatchRenderRequest,
    /// RGBA pixel data (width * height * 4 bytes, row-major).
    /// `to_rgb_image` pads missing pixels with black if this buffer is short.
    pub rgba: Vec<u8>,
    /// Depth data in meters (width * height f64s, row-major).
    /// `to_depth_image` pads missing samples with `0.0` if this buffer is short.
    pub depth: Vec<f64>,
    /// Image width in pixels
    pub width: u32,
    /// Image height in pixels
    pub height: u32,
    /// Camera intrinsics used
    pub intrinsics: CameraIntrinsics,
    /// Camera transform used for world-space depth unprojection.
    pub camera_transform: Transform,
    /// Object world translation applied during render.
    pub object_translation: Vec3,
    /// Object scale applied during render.
    pub object_scale: Vec3,
    /// Point the camera was intended to target for this render.
    pub target_point: Vec3,
    /// Policy used to derive `target_point`.
    pub targeting_policy: TargetingPolicy,
    /// Cheap diagnostics derived from the rendered depth buffer
    pub health: RenderHealth,
    /// Status of this render
    pub status: RenderStatus,
    /// Error message if status is Failed or PartialFailure
    pub error_message: Option<String>,
}
183
184impl BatchRenderOutput {
185    /// Convert to neocortx-compatible RGB format: Vec<Vec<[u8; 3]>>
186    pub fn to_rgb_image(&self) -> Vec<Vec<[u8; 3]>> {
187        let mut image = Vec::with_capacity(self.height as usize);
188        for y in 0..self.height {
189            let mut row = Vec::with_capacity(self.width as usize);
190            for x in 0..self.width {
191                let idx = ((y * self.width + x) * 4) as usize;
192                if idx + 2 < self.rgba.len() {
193                    row.push([self.rgba[idx], self.rgba[idx + 1], self.rgba[idx + 2]]);
194                } else {
195                    row.push([0, 0, 0]);
196                }
197            }
198            image.push(row);
199        }
200        image
201    }
202
203    /// Convert depth to neocortx-compatible format: Vec<Vec<f64>>
204    pub fn to_depth_image(&self) -> Vec<Vec<f64>> {
205        let mut image = Vec::with_capacity(self.height as usize);
206        for y in 0..self.height {
207            let mut row = Vec::with_capacity(self.width as usize);
208            for x in 0..self.width {
209                let idx = (y * self.width + x) as usize;
210                if idx < self.depth.len() {
211                    row.push(self.depth[idx]);
212                } else {
213                    row.push(0.0);
214                }
215            }
216            image.push(row);
217        }
218        image
219    }
220
221    /// Build TBP-style `semantic_3d` rows using this request's far plane.
222    ///
223    /// The returned vector is row-major with one `[x, y, z, semantic_id]` row
224    /// per pixel. Foreground pixels are unprojected into world space and use
225    /// `object_semantic_id`; background/far pixels are `[0, 0, 0, 0]`.
226    pub fn semantic_3d(&self, object_semantic_id: u32) -> Vec<[f64; 4]> {
227        self.semantic_3d_with_far_plane(
228            object_semantic_id,
229            self.request.render_config.far_plane as f64,
230        )
231    }
232
233    /// Build TBP-style `semantic_3d` rows using a caller-provided far plane.
234    pub fn semantic_3d_with_far_plane(
235        &self,
236        object_semantic_id: u32,
237        far_plane: f64,
238    ) -> Vec<[f64; 4]> {
239        semantic_3d_from_depth(
240            &self.depth,
241            self.width,
242            self.height,
243            &self.intrinsics,
244            self.camera_transform,
245            object_semantic_id,
246            far_plane,
247        )
248    }
249
250    /// Convert from RenderOutput, carrying request-level target metadata.
251    pub fn from_render_output(request: BatchRenderRequest, output: RenderOutput) -> Self {
252        let health = output.health_with_far_plane(request.render_config.far_plane as f64);
253        let camera_transform = output.camera_transform;
254        let object_translation = output.object_translation;
255        let object_scale = output.object_scale;
256        let target_point = request.target_point;
257        let targeting_policy = request.targeting_policy.clone();
258        Self {
259            request,
260            rgba: output.rgba,
261            depth: output.depth,
262            width: output.width,
263            height: output.height,
264            intrinsics: output.intrinsics,
265            camera_transform,
266            object_translation,
267            object_scale,
268            target_point,
269            targeting_policy,
270            health,
271            status: RenderStatus::Success,
272            error_message: None,
273        }
274    }
275}
276
/// Failures reported by the batch rendering API.
#[derive(Debug, Clone)]
pub enum BatchRenderError {
    /// A mix of renders succeeded and failed; carries both counts.
    PartialFailure { successful: usize, failed: usize },
    /// Every render failed; carries a description of the failure.
    TotalFailure(String),
    /// The supplied configuration was rejected; carries the reason.
    InvalidConfig(String),
    /// The pending queue has reached its configured capacity.
    QueueFull,
    /// There were no renders queued.
    EmptyQueue,
    /// The wgpu device was lost mid-render. The current `RenderSession::render()`
    /// call produced no output; any outputs returned by earlier calls remain valid.
    /// Recovery: drop the session and construct a new one.
    ///
    /// `reason` is a string form of `wgpu::DeviceLostReason` so callers can branch
    /// on recoverable vs. adapter-evicted without taking a direct wgpu dependency.
    /// Phase 1 ships the string form; a typed variant may follow once the Bevy
    /// re-export surface is clearer.
    DeviceLost { reason: String, message: String },
}

impl std::fmt::Display for BatchRenderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PartialFailure { successful, failed } => write!(
                f,
                "Batch render partial failure: {} succeeded, {} failed",
                successful, failed
            ),
            Self::TotalFailure(detail) => write!(f, "Batch render total failure: {}", detail),
            Self::InvalidConfig(detail) => write!(f, "Invalid batch config: {}", detail),
            // Fixed messages need no formatting machinery.
            Self::QueueFull => f.write_str("Batch queue is full"),
            Self::EmptyQueue => f.write_str("No renders queued"),
            Self::DeviceLost { reason, message } => {
                write!(f, "wgpu device lost ({}): {}", reason, message)
            }
        }
    }
}

impl std::error::Error for BatchRenderError {}
323
/// Lifecycle phases of a batch renderer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BatchState {
    /// Nothing in progress; waiting for requests to be queued.
    Idle,
    /// Loading the object's assets (mesh, texture).
    LoadingAssets,
    /// Rendering a frame into a GPU buffer.
    RenderingFrame,
    /// Reading RGBA and depth back from the GPU.
    ExtractingResults,
    /// Releasing resources.
    Cleanup,
    /// Shutting down.
    Shutdown,
}
340
341/// Manages queued render requests and completed outputs for batch-style workflows.
342pub struct BatchRenderer {
343    /// Queued render requests
344    pub pending_requests: VecDeque<BatchRenderRequest>,
345    /// Completed results
346    pub completed_results: Vec<BatchRenderOutput>,
347    /// Current request being processed
348    pub current_request: Option<BatchRenderRequest>,
349    /// Current render output being built
350    pub current_output: Option<BatchRenderOutput>,
351    /// Frame counter for timeout management
352    pub frame_count: u32,
353    /// Current state
354    pub state: BatchState,
355    /// Configuration
356    pub config: BatchRenderConfig,
357    /// Total renders processed
358    pub renders_processed: usize,
359}
360
361impl BatchRenderer {
362    /// Create a new batch renderer with default configuration.
363    pub fn new(config: BatchRenderConfig) -> Self {
364        Self {
365            pending_requests: VecDeque::new(),
366            completed_results: Vec::new(),
367            current_request: None,
368            current_output: None,
369            frame_count: 0,
370            state: BatchState::Idle,
371            config,
372            renders_processed: 0,
373        }
374    }
375
376    /// Queue a render request for batch processing.
377    pub fn queue_request(&mut self, request: BatchRenderRequest) -> Result<(), BatchRenderError> {
378        if self.pending_requests.len() >= self.config.max_batch_size {
379            return Err(BatchRenderError::QueueFull);
380        }
381        self.pending_requests.push_back(request);
382        Ok(())
383    }
384
385    /// Get the number of pending requests.
386    pub fn pending_count(&self) -> usize {
387        self.pending_requests.len()
388    }
389
390    /// Get the number of completed results.
391    pub fn completed_count(&self) -> usize {
392        self.completed_results.len()
393    }
394
395    /// Get all completed results and clear the internal list.
396    pub fn take_completed(&mut self) -> Vec<BatchRenderOutput> {
397        std::mem::take(&mut self.completed_results)
398    }
399
400    /// Check if all work is done (no pending requests and not currently rendering).
401    pub fn is_finished(&self) -> bool {
402        self.pending_requests.is_empty() && self.current_request.is_none()
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal request with default transforms, shared by the tests below.
    fn sample_request() -> BatchRenderRequest {
        BatchRenderRequest {
            object_dir: "/tmp/test".into(),
            viewpoint: Transform::default(),
            object_rotation: ObjectRotation::identity(),
            object_translation: Vec3::ZERO,
            object_scale: Vec3::ONE,
            render_config: RenderConfig::tbp_default(),
            target_point: Vec3::ZERO,
            targeting_policy: TargetingPolicy::Origin,
        }
    }

    /// Build an output directly from raw buffers, bypassing `from_render_output`.
    /// The health record is a fixed placeholder; tests using this do not inspect it.
    fn sample_output(rgba: Vec<u8>, depth: Vec<f64>, width: u32, height: u32) -> BatchRenderOutput {
        BatchRenderOutput {
            request: sample_request(),
            rgba,
            depth,
            width,
            height,
            intrinsics: RenderConfig::tbp_default().intrinsics(),
            camera_transform: Transform::default(),
            object_translation: Vec3::ZERO,
            object_scale: Vec3::ONE,
            target_point: Vec3::ZERO,
            targeting_policy: TargetingPolicy::Origin,
            health: RenderHealth {
                center_pixel: Some([1, 1]),
                center_depth: Some(1.0),
                center_foreground: true,
                foreground_pixel_count: 4,
                foreground_coverage: 1.0,
                center_5x5_foreground_count: 4,
                nearest_foreground_pixel: Some([1, 1]),
                nearest_foreground_depth: Some(1.0),
                nearest_foreground_distance_px: Some(0.0),
            },
            status: RenderStatus::Success,
            error_message: None,
        }
    }

    #[test]
    fn test_batch_config_defaults() {
        let config = BatchRenderConfig::default();
        assert_eq!(config.max_batch_size, 256);
        assert_eq!(config.frame_timeout_ms, 500);
        assert!(config.enable_depth_readback);
        assert!(config.enable_asset_caching);
        assert_eq!(config.resource_cleanup_interval, 32);
    }

    #[test]
    fn test_batch_renderer_creation() {
        let renderer = BatchRenderer::new(BatchRenderConfig::default());
        assert_eq!(renderer.state, BatchState::Idle);
        assert_eq!(renderer.pending_count(), 0);
        assert_eq!(renderer.completed_count(), 0);
        assert!(renderer.is_finished());
    }

    #[test]
    fn test_queue_request() {
        let mut renderer = BatchRenderer::new(BatchRenderConfig::default());
        assert!(renderer.queue_request(sample_request()).is_ok());
        assert_eq!(renderer.pending_count(), 1);
        assert!(!renderer.is_finished());
    }

    #[test]
    fn test_queue_full() {
        let config = BatchRenderConfig {
            max_batch_size: 1,
            ..BatchRenderConfig::default()
        };
        let mut renderer = BatchRenderer::new(config);
        let request = sample_request();

        assert!(renderer.queue_request(request.clone()).is_ok());
        assert!(matches!(
            renderer.queue_request(request),
            Err(BatchRenderError::QueueFull)
        ));
        // The rejected request must not have been enqueued.
        assert_eq!(renderer.pending_count(), 1);
    }

    #[test]
    fn test_take_completed_drains_results() {
        let mut renderer = BatchRenderer::new(BatchRenderConfig::default());
        renderer
            .completed_results
            .push(sample_output(vec![0u8; 4], vec![1.0], 1, 1));
        assert_eq!(renderer.completed_count(), 1);

        let taken = renderer.take_completed();
        assert_eq!(taken.len(), 1);
        assert_eq!(renderer.completed_count(), 0);
    }

    #[test]
    fn test_request_builder_defaults_and_overrides() {
        let request = BatchRenderRequest::new(
            "/tmp/test".into(),
            Transform::default(),
            ObjectRotation::identity(),
            RenderConfig::tbp_default(),
        );
        assert_eq!(request.object_translation, Vec3::ZERO);
        assert_eq!(request.object_scale, Vec3::ONE);
        assert_eq!(request.target_point, Vec3::ZERO);
        assert_eq!(request.targeting_policy, TargetingPolicy::Origin);

        let request = request
            .with_object_transform(Vec3::new(1.0, 2.0, 3.0), Vec3::splat(0.5))
            .with_targeting(Vec3::new(0.0, 0.1, 0.0), TargetingPolicy::MeshCenter);
        assert_eq!(request.object_translation, Vec3::new(1.0, 2.0, 3.0));
        assert_eq!(request.object_scale, Vec3::splat(0.5));
        assert_eq!(request.target_point, Vec3::new(0.0, 0.1, 0.0));
        assert_eq!(request.targeting_policy, TargetingPolicy::MeshCenter);
    }

    #[test]
    fn test_batch_render_output_rgb_conversion() {
        // 2x2 image; pixel (0,0) is opaque red, the rest are zeroed.
        let mut rgba = vec![0u8; 2 * 2 * 4];
        rgba[0] = 255;
        rgba[3] = 255;
        let output = sample_output(rgba, vec![1.0; 4], 2, 2);

        let rgb = output.to_rgb_image();
        assert_eq!(rgb.len(), 2); // 2 rows
        assert_eq!(rgb[0].len(), 2); // 2 cols
        assert_eq!(rgb[0][0], [255, 0, 0]); // Red
    }

    #[test]
    fn test_rgb_conversion_pads_short_buffer() {
        // 2x1 image where only the first pixel's RGB bytes are fully present.
        let output = sample_output(vec![10, 20, 30, 255, 40, 50], vec![1.0; 2], 2, 1);
        assert_eq!(output.to_rgb_image(), vec![vec![[10, 20, 30], [0, 0, 0]]]);
    }

    #[test]
    fn test_depth_conversion_is_row_major_and_pads_short_buffer() {
        let output = sample_output(vec![0u8; 16], vec![1.0, 2.0, 3.0], 2, 2);
        assert_eq!(
            output.to_depth_image(),
            vec![vec![1.0, 2.0], vec![3.0, 0.0]]
        );
    }

    #[test]
    fn test_error_display() {
        assert_eq!(BatchRenderError::QueueFull.to_string(), "Batch queue is full");
        assert_eq!(BatchRenderError::EmptyQueue.to_string(), "No renders queued");
        assert_eq!(
            BatchRenderError::PartialFailure {
                successful: 3,
                failed: 1
            }
            .to_string(),
            "Batch render partial failure: 3 succeeded, 1 failed"
        );
        assert_eq!(
            BatchRenderError::DeviceLost {
                reason: "Destroyed".into(),
                message: "gone".into()
            }
            .to_string(),
            "wgpu device lost (Destroyed): gone"
        );
    }

    #[test]
    fn test_batch_render_output_carries_request_target_metadata() {
        let target_point = Vec3::new(0.25, -0.125, 0.5);
        let camera_transform = Transform::from_xyz(0.0, 0.0, 2.0).looking_at(Vec3::ZERO, Vec3::Y);
        let request = BatchRenderRequest {
            viewpoint: camera_transform,
            object_translation: Vec3::new(0.125, 0.25, -0.5),
            object_scale: Vec3::splat(1.25),
            target_point,
            targeting_policy: TargetingPolicy::MeshCenter,
            ..sample_request()
        };
        let output = RenderOutput {
            rgba: vec![0u8; 4],
            depth: vec![1.0],
            width: 1,
            height: 1,
            intrinsics: RenderConfig::tbp_default().intrinsics(),
            camera_transform,
            object_rotation: ObjectRotation::identity(),
            object_translation: Vec3::new(0.125, 0.25, -0.5),
            object_scale: Vec3::splat(1.25),
            target_point: Vec3::ZERO,
            targeting_policy: TargetingPolicy::Origin,
        };

        let batch_output = BatchRenderOutput::from_render_output(request, output);

        assert_eq!(batch_output.target_point, target_point);
        assert_eq!(batch_output.targeting_policy, TargetingPolicy::MeshCenter);
        assert_eq!(batch_output.camera_transform, camera_transform);
        assert_eq!(
            batch_output.object_translation,
            Vec3::new(0.125, 0.25, -0.5)
        );
        assert_eq!(batch_output.object_scale, Vec3::splat(1.25));
        assert_eq!(batch_output.request.target_point, target_point);
        assert_eq!(
            batch_output.request.object_translation,
            Vec3::new(0.125, 0.25, -0.5)
        );
        assert_eq!(batch_output.request.object_scale, Vec3::splat(1.25));
        assert_eq!(
            batch_output.request.targeting_policy,
            TargetingPolicy::MeshCenter
        );
    }

    #[test]
    fn test_batch_render_output_semantic_3d_uses_camera_transform() {
        let camera_transform = Transform::from_xyz(0.0, 0.0, 2.0).looking_at(Vec3::ZERO, Vec3::Y);
        let request = BatchRenderRequest {
            viewpoint: camera_transform,
            ..sample_request()
        };
        let output = RenderOutput {
            rgba: vec![0u8; 4],
            depth: vec![1.5],
            width: 1,
            height: 1,
            intrinsics: CameraIntrinsics {
                focal_length: [100.0, 100.0],
                principal_point: [0.0, 0.0],
                image_size: [1, 1],
            },
            camera_transform,
            object_rotation: ObjectRotation::identity(),
            object_translation: Vec3::ZERO,
            object_scale: Vec3::ONE,
            target_point: Vec3::ZERO,
            targeting_policy: TargetingPolicy::Origin,
        };

        let batch_output = BatchRenderOutput::from_render_output(request, output);
        let rows = batch_output.semantic_3d(7);

        assert_eq!(rows.len(), 1);
        assert!((rows[0][0]).abs() < 1e-6);
        assert!((rows[0][1]).abs() < 1e-6);
        assert!((rows[0][2] - 0.5).abs() < 1e-6);
        assert_eq!(rows[0][3], 7.0);
    }
}