Skip to main content

bevy_sensor/
fixtures.rs

//! Test fixtures for pre-rendered YCB images.
//!
//! This module provides utilities for loading pre-rendered images from disk,
//! so tests can run without GPU access.
//!
//! # Usage
//!
//! ```ignore
//! use bevy_sensor::fixtures::TestFixtures;
//!
//! let fixtures = TestFixtures::load("test_fixtures/renders")?;
//!
//! // Get a specific render
//! let render = fixtures.get_render("003_cracker_box", 0, 5)?;
//! let rgb_image = render.to_rgb_image();
//! let depth_image = render.to_depth_image();
//! ```
18
19use crate::{CameraIntrinsics, MeshBoundsMetadata, RenderHealth, RenderOutput, TargetingPolicy};
20use bevy::prelude::*;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use std::fs;
24use std::path::{Path, PathBuf};
25
26/// Error type for fixture loading
27#[derive(Debug)]
28pub enum FixtureError {
29    /// Directory not found
30    NotFound(String),
31    /// Metadata file missing or invalid
32    InvalidMetadata(String),
33    /// Render file missing
34    RenderNotFound {
35        object_id: String,
36        rotation: usize,
37        viewpoint: usize,
38    },
39    /// IO error
40    IoError(std::io::Error),
41    /// JSON parsing error
42    JsonError(serde_json::Error),
43}
44
45impl std::fmt::Display for FixtureError {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            FixtureError::NotFound(path) => write!(f, "Fixture directory not found: {}", path),
49            FixtureError::InvalidMetadata(msg) => write!(f, "Invalid metadata: {}", msg),
50            FixtureError::RenderNotFound {
51                object_id,
52                rotation,
53                viewpoint,
54            } => write!(
55                f,
56                "Render not found: {} r{} v{}",
57                object_id, rotation, viewpoint
58            ),
59            FixtureError::IoError(e) => write!(f, "IO error: {}", e),
60            FixtureError::JsonError(e) => write!(f, "JSON error: {}", e),
61        }
62    }
63}
64
65impl std::error::Error for FixtureError {}
66
67impl From<std::io::Error> for FixtureError {
68    fn from(e: std::io::Error) -> Self {
69        FixtureError::IoError(e)
70    }
71}
72
73impl From<serde_json::Error> for FixtureError {
74    fn from(e: serde_json::Error) -> Self {
75        FixtureError::JsonError(e)
76    }
77}
78
79/// Dataset metadata from pre-rendering
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct DatasetMetadata {
82    pub version: String,
83    #[serde(default)]
84    pub crate_version: Option<String>,
85    #[serde(default)]
86    pub renderer_policy_version: Option<String>,
87    pub objects: Vec<String>,
88    pub viewpoints_per_rotation: usize,
89    pub rotations_per_object: usize,
90    pub renders_per_object: usize,
91    pub resolution: [u32; 2],
92    #[serde(default)]
93    pub resolution_width: Option<u32>,
94    #[serde(default)]
95    pub resolution_height: Option<u32>,
96    #[serde(default)]
97    pub targeting_policy: Option<TargetingPolicy>,
98    pub intrinsics: IntrinsicsMetadata,
99    pub viewpoint_config: ViewpointConfigMetadata,
100    pub rotations: Vec<[f32; 3]>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct IntrinsicsMetadata {
105    pub focal_length: [f32; 2],
106    pub principal_point: [f32; 2],
107    pub image_size: [u32; 2],
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ViewpointConfigMetadata {
112    pub radius: f32,
113    pub yaw_count: usize,
114    pub pitch_angles_deg: Vec<f32>,
115}
116
117/// Metadata for a single render
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct RenderMetadata {
120    pub object_id: String,
121    pub rotation_index: usize,
122    pub viewpoint_index: usize,
123    pub rotation_euler: [f32; 3],
124    pub camera_position: [f32; 3],
125    #[serde(default)]
126    pub camera_rotation_xyzw: Option<[f32; 4]>,
127    #[serde(default)]
128    pub target_point: Option<[f32; 3]>,
129    #[serde(default)]
130    pub targeting_policy: Option<TargetingPolicy>,
131    #[serde(default)]
132    pub mesh_bounds: Option<MeshBoundsMetadata>,
133    #[serde(default)]
134    pub health: Option<RenderHealth>,
135    pub rgba_file: String,
136    pub depth_file: String,
137}
138
139/// Pre-rendered test fixtures loaded from disk
140pub struct TestFixtures {
141    /// Root directory containing fixtures
142    root: PathBuf,
143    /// Dataset metadata
144    pub metadata: DatasetMetadata,
145    /// Per-object render indices
146    indices: HashMap<String, Vec<RenderMetadata>>,
147}
148
149impl TestFixtures {
150    /// Load test fixtures from a directory
151    ///
152    /// # Arguments
153    /// * `path` - Path to the fixtures directory (e.g., "test_fixtures/renders")
154    ///
155    /// # Returns
156    /// * `Ok(TestFixtures)` if loaded successfully
157    /// * `Err(FixtureError)` if loading fails
158    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, FixtureError> {
159        let root = path.as_ref().to_path_buf();
160
161        if !root.exists() {
162            return Err(FixtureError::NotFound(root.display().to_string()));
163        }
164
165        // Load metadata
166        let metadata_path = root.join("metadata.json");
167        if !metadata_path.exists() {
168            return Err(FixtureError::InvalidMetadata(
169                "metadata.json not found".to_string(),
170            ));
171        }
172
173        let metadata_json = fs::read_to_string(&metadata_path)?;
174        let metadata: DatasetMetadata = serde_json::from_str(&metadata_json)?;
175
176        // Load per-object indices
177        let mut indices = HashMap::new();
178        for object_id in &metadata.objects {
179            let index_path = root.join(object_id).join("index.json");
180            if index_path.exists() {
181                let index_json = fs::read_to_string(&index_path)?;
182                let renders: Vec<RenderMetadata> = serde_json::from_str(&index_json)?;
183                indices.insert(object_id.clone(), renders);
184            }
185        }
186
187        Ok(Self {
188            root,
189            metadata,
190            indices,
191        })
192    }
193
194    /// Check if fixtures exist at the given path
195    pub fn exists<P: AsRef<Path>>(path: P) -> bool {
196        let root = path.as_ref();
197        root.exists() && root.join("metadata.json").exists()
198    }
199
200    /// Get list of available objects
201    pub fn objects(&self) -> &[String] {
202        &self.metadata.objects
203    }
204
205    /// Get number of viewpoints per rotation
206    pub fn viewpoints_per_rotation(&self) -> usize {
207        self.metadata.viewpoints_per_rotation
208    }
209
210    /// Get number of rotations per object
211    pub fn rotations_per_object(&self) -> usize {
212        self.metadata.rotations_per_object
213    }
214
215    /// Get total renders available for an object
216    pub fn renders_for_object(&self, object_id: &str) -> usize {
217        self.indices.get(object_id).map(|v| v.len()).unwrap_or(0)
218    }
219
220    /// Get camera intrinsics (converts from f32 metadata to f64 for TBP precision)
221    pub fn intrinsics(&self) -> CameraIntrinsics {
222        CameraIntrinsics {
223            focal_length: [
224                self.metadata.intrinsics.focal_length[0] as f64,
225                self.metadata.intrinsics.focal_length[1] as f64,
226            ],
227            principal_point: [
228                self.metadata.intrinsics.principal_point[0] as f64,
229                self.metadata.intrinsics.principal_point[1] as f64,
230            ],
231            image_size: self.metadata.intrinsics.image_size,
232        }
233    }
234
235    /// Load a specific render by object, rotation index, and viewpoint index
236    ///
237    /// # Arguments
238    /// * `object_id` - YCB object ID (e.g., "003_cracker_box")
239    /// * `rotation_idx` - Rotation index (0-2 for benchmark rotations)
240    /// * `viewpoint_idx` - Viewpoint index (0-23 for default config)
241    pub fn get_render(
242        &self,
243        object_id: &str,
244        rotation_idx: usize,
245        viewpoint_idx: usize,
246    ) -> Result<RenderOutput, FixtureError> {
247        // Find the render metadata
248        let renders = self
249            .indices
250            .get(object_id)
251            .ok_or_else(|| FixtureError::RenderNotFound {
252                object_id: object_id.to_string(),
253                rotation: rotation_idx,
254                viewpoint: viewpoint_idx,
255            })?;
256
257        let render_meta = renders
258            .iter()
259            .find(|r| r.rotation_index == rotation_idx && r.viewpoint_index == viewpoint_idx)
260            .ok_or_else(|| FixtureError::RenderNotFound {
261                object_id: object_id.to_string(),
262                rotation: rotation_idx,
263                viewpoint: viewpoint_idx,
264            })?;
265
266        // Load RGBA from PNG
267        let rgba_path = self.root.join(object_id).join(&render_meta.rgba_file);
268        let rgba = load_rgba_png(&rgba_path)?;
269
270        // Load depth from binary
271        let depth_path = self.root.join(object_id).join(&render_meta.depth_file);
272        let expected_depth_values =
273            (self.metadata.resolution[0] as usize) * (self.metadata.resolution[1] as usize);
274        let depth = load_depth_binary(&depth_path, expected_depth_values)?;
275
276        // Build camera transform from position. New manifests carry the exact
277        // render rotation; older manifests only recorded origin-targeted
278        // camera positions, so reconstruct with a target fallback.
279        let pos = render_meta.camera_position;
280        let translation = Vec3::new(pos[0], pos[1], pos[2]);
281        let camera_transform = if let Some(q) = render_meta.camera_rotation_xyzw {
282            Transform {
283                translation,
284                rotation: Quat::from_xyzw(q[0], q[1], q[2], q[3]),
285                ..Default::default()
286            }
287        } else {
288            let target = render_meta.target_point.unwrap_or([0.0, 0.0, 0.0]);
289            Transform::from_translation(translation)
290                .looking_at(Vec3::new(target[0], target[1], target[2]), Vec3::Y)
291        };
292        let target_point = Vec3::from_array(render_meta.target_point.unwrap_or([0.0, 0.0, 0.0]));
293        let targeting_policy = render_meta
294            .targeting_policy
295            .clone()
296            .or_else(|| self.metadata.targeting_policy.clone())
297            .unwrap_or(TargetingPolicy::Origin);
298
299        // Build object rotation (convert from f32 metadata to f64)
300        let rot = render_meta.rotation_euler;
301        let object_rotation =
302            crate::ObjectRotation::new(rot[0] as f64, rot[1] as f64, rot[2] as f64);
303
304        Ok(RenderOutput {
305            rgba,
306            depth,
307            width: self.metadata.resolution[0],
308            height: self.metadata.resolution[1],
309            intrinsics: self.intrinsics(),
310            camera_transform,
311            object_rotation,
312            target_point,
313            targeting_policy,
314        })
315    }
316
317    /// Load all renders for an object
318    pub fn get_all_renders(&self, object_id: &str) -> Result<Vec<RenderOutput>, FixtureError> {
319        let renders = self
320            .indices
321            .get(object_id)
322            .ok_or_else(|| FixtureError::RenderNotFound {
323                object_id: object_id.to_string(),
324                rotation: 0,
325                viewpoint: 0,
326            })?;
327
328        let mut outputs = Vec::with_capacity(renders.len());
329        for meta in renders {
330            let output = self.get_render(object_id, meta.rotation_index, meta.viewpoint_index)?;
331            outputs.push(output);
332        }
333
334        Ok(outputs)
335    }
336
337    /// Iterate over all renders for an object
338    pub fn iter_renders<'a>(
339        &'a self,
340        object_id: &'a str,
341    ) -> impl Iterator<Item = Result<(usize, usize, RenderOutput), FixtureError>> + 'a {
342        let renders = self.indices.get(object_id);
343
344        renders.into_iter().flat_map(|v| v.iter()).map(move |meta| {
345            let output = self.get_render(object_id, meta.rotation_index, meta.viewpoint_index)?;
346            Ok((meta.rotation_index, meta.viewpoint_index, output))
347        })
348    }
349}
350
351/// Load RGBA data from a PNG file
352fn load_rgba_png(path: &Path) -> Result<Vec<u8>, FixtureError> {
353    let img = image::open(path).map_err(|e| FixtureError::IoError(std::io::Error::other(e)))?;
354
355    let rgba = img.to_rgba8();
356    Ok(rgba.into_raw())
357}
358
359/// Load depth data from binary f32 or f64 file and normalize to f64 for TBP precision.
360fn load_depth_binary(path: &Path, expected_values: usize) -> Result<Vec<f64>, FixtureError> {
361    let bytes = fs::read(path)?;
362
363    if bytes.len() == expected_values * std::mem::size_of::<f64>() {
364        return Ok(bytes
365            .chunks_exact(8)
366            .map(|chunk| {
367                let arr: [u8; 8] = chunk.try_into().unwrap();
368                f64::from_le_bytes(arr)
369            })
370            .collect());
371    }
372
373    if bytes.len() == expected_values * std::mem::size_of::<f32>() {
374        return Ok(bytes
375            .chunks_exact(4)
376            .map(|chunk| {
377                let arr: [u8; 4] = chunk.try_into().unwrap();
378                f32::from_le_bytes(arr) as f64
379            })
380            .collect());
381    }
382
383    Err(FixtureError::InvalidMetadata(format!(
384        "Depth file {} has {} bytes, expected {} f32 values or {} f64 values",
385        path.display(),
386        bytes.len(),
387        expected_values,
388        expected_values
389    )))
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Minimal dataset metadata for `objects` rendered at `size` x `size` pixels.
    fn sample_metadata(objects: &[&str], size: u32) -> DatasetMetadata {
        let center = size as f32 / 2.0;
        DatasetMetadata {
            version: "1.0".to_string(),
            crate_version: None,
            renderer_policy_version: None,
            objects: objects.iter().map(|id| id.to_string()).collect(),
            viewpoints_per_rotation: 24,
            rotations_per_object: 3,
            renders_per_object: 72,
            resolution: [size, size],
            resolution_width: None,
            resolution_height: None,
            targeting_policy: None,
            intrinsics: IntrinsicsMetadata {
                focal_length: [55.4, 55.4],
                principal_point: [center, center],
                image_size: [size, size],
            },
            viewpoint_config: ViewpointConfigMetadata {
                radius: 0.5,
                yaw_count: 8,
                pitch_angles_deg: vec![-30.0, 0.0, 30.0],
            },
            rotations: vec![[0.0, 0.0, 0.0], [0.0, 90.0, 0.0], [0.0, 180.0, 0.0]],
        }
    }

    /// Serialize `metadata` into `dir/metadata.json`.
    fn write_metadata(dir: &Path, metadata: &DatasetMetadata) {
        let json = serde_json::to_string_pretty(metadata).unwrap();
        fs::write(dir.join("metadata.json"), json).unwrap();
    }

    #[test]
    fn test_fixture_not_found() {
        let result = TestFixtures::load("/nonexistent/path");
        assert!(matches!(result, Err(FixtureError::NotFound(_))));
    }

    #[test]
    fn test_fixtures_exists() {
        assert!(!TestFixtures::exists("/nonexistent/path"));
    }

    #[test]
    fn test_fixture_error_display() {
        let errors = vec![
            FixtureError::NotFound("/path".to_string()),
            FixtureError::InvalidMetadata("bad json".to_string()),
            FixtureError::RenderNotFound {
                object_id: "obj".to_string(),
                rotation: 0,
                viewpoint: 5,
            },
            FixtureError::IoError(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "file not found",
            )),
            FixtureError::JsonError(serde_json::from_str::<String>("invalid").unwrap_err()),
        ];

        for err in errors {
            let msg = err.to_string();
            assert!(!msg.is_empty());
        }
    }

    #[test]
    fn test_fixture_missing_metadata() {
        let temp_dir = TempDir::new().unwrap();
        let result = TestFixtures::load(temp_dir.path());
        assert!(matches!(result, Err(FixtureError::InvalidMetadata(_))));
    }

    #[test]
    fn test_fixture_load_metadata() {
        let temp_dir = TempDir::new().unwrap();
        write_metadata(temp_dir.path(), &sample_metadata(&["test_object"], 64));

        // Create object directory with empty index
        let obj_dir = temp_dir.path().join("test_object");
        fs::create_dir_all(&obj_dir).unwrap();
        fs::write(obj_dir.join("index.json"), "[]").unwrap();

        let fixtures = TestFixtures::load(temp_dir.path()).unwrap();

        assert_eq!(fixtures.objects(), &["test_object"]);
        assert_eq!(fixtures.viewpoints_per_rotation(), 24);
        assert_eq!(fixtures.rotations_per_object(), 3);
        assert_eq!(fixtures.renders_for_object("test_object"), 0);
        assert_eq!(fixtures.renders_for_object("nonexistent"), 0);

        let intrinsics = fixtures.intrinsics();
        assert_eq!(intrinsics.image_size, [64, 64]);
    }

    #[test]
    fn test_unknown_object_lookups() {
        let temp_dir = TempDir::new().unwrap();
        write_metadata(temp_dir.path(), &sample_metadata(&["test_object"], 64));
        let fixtures = TestFixtures::load(temp_dir.path()).unwrap();

        assert!(matches!(
            fixtures.get_render("missing", 0, 0),
            Err(FixtureError::RenderNotFound { .. })
        ));
        assert!(matches!(
            fixtures.get_all_renders("missing"),
            Err(FixtureError::RenderNotFound { .. })
        ));
        assert_eq!(fixtures.iter_renders("missing").count(), 0);
    }

    #[test]
    fn test_get_render_from_disk() {
        let temp_dir = TempDir::new().unwrap();
        write_metadata(temp_dir.path(), &sample_metadata(&["obj"], 2));

        let obj_dir = temp_dir.path().join("obj");
        fs::create_dir_all(&obj_dir).unwrap();

        let png = image::RgbaImage::from_pixel(2, 2, image::Rgba([10u8, 20, 30, 255]));
        png.save(obj_dir.join("r0_v00.png")).unwrap();

        let depths: [f32; 4] = [0.25, 0.5, 0.75, 1.0];
        let depth_bytes: Vec<u8> = depths.iter().flat_map(|d| d.to_le_bytes()).collect();
        fs::write(obj_dir.join("r0_v00.depth"), depth_bytes).unwrap();

        let index = vec![RenderMetadata {
            object_id: "obj".to_string(),
            rotation_index: 0,
            viewpoint_index: 0,
            rotation_euler: [0.0, 0.0, 0.0],
            camera_position: [0.0, 0.0, 0.5],
            camera_rotation_xyzw: None,
            target_point: None,
            targeting_policy: None,
            mesh_bounds: None,
            health: None,
            rgba_file: "r0_v00.png".to_string(),
            depth_file: "r0_v00.depth".to_string(),
        }];
        fs::write(
            obj_dir.join("index.json"),
            serde_json::to_string(&index).unwrap(),
        )
        .unwrap();

        let fixtures = TestFixtures::load(temp_dir.path()).unwrap();
        assert_eq!(fixtures.renders_for_object("obj"), 1);

        let render = fixtures.get_render("obj", 0, 0).unwrap();
        assert_eq!((render.width, render.height), (2, 2));
        assert_eq!(render.rgba.len(), 16);
        assert_eq!(&render.rgba[..4], &[10u8, 20, 30, 255][..]);
        assert_eq!(render.depth, vec![0.25, 0.5, 0.75, 1.0]);
        assert_eq!(render.camera_transform.translation.to_array(), [0.0, 0.0, 0.5]);
        assert_eq!(render.target_point.to_array(), [0.0, 0.0, 0.0]);
        assert!(matches!(render.targeting_policy, TargetingPolicy::Origin));

        assert!(matches!(
            fixtures.get_render("obj", 1, 0),
            Err(FixtureError::RenderNotFound { .. })
        ));
        assert_eq!(fixtures.get_all_renders("obj").unwrap().len(), 1);

        let iterated = fixtures
            .iter_renders("obj")
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(iterated.len(), 1);
        assert_eq!((iterated[0].0, iterated[0].1), (0, 0));
    }

    #[test]
    fn test_load_depth_binary_f32() {
        let temp_dir = TempDir::new().unwrap();
        let depth_path = temp_dir.path().join("test.depth");

        // Write test depth values
        let depths: Vec<f32> = vec![0.5, 1.0, 2.0, 10.0];
        let bytes: Vec<u8> = depths.iter().flat_map(|f| f.to_le_bytes()).collect();
        fs::write(&depth_path, &bytes).unwrap();

        // Load and verify
        let loaded = load_depth_binary(&depth_path, depths.len()).unwrap();
        assert_eq!(loaded.len(), 4);
        assert!((loaded[0] - 0.5).abs() < 0.001);
        assert!((loaded[1] - 1.0).abs() < 0.001);
        assert!((loaded[2] - 2.0).abs() < 0.001);
        assert!((loaded[3] - 10.0).abs() < 0.001);
    }

    #[test]
    fn test_load_depth_binary_f64() {
        let temp_dir = TempDir::new().unwrap();
        let depth_path = temp_dir.path().join("test.depth");

        let depths: Vec<f64> = vec![0.5, 1.0, 2.0, 10.0];
        let bytes: Vec<u8> = depths.iter().flat_map(|f| f.to_le_bytes()).collect();
        fs::write(&depth_path, &bytes).unwrap();

        let loaded = load_depth_binary(&depth_path, depths.len()).unwrap();
        assert_eq!(loaded, depths);
    }

    #[test]
    fn test_load_depth_binary_size_mismatch() {
        let temp_dir = TempDir::new().unwrap();
        let depth_path = temp_dir.path().join("bad.depth");

        // 6 bytes matches neither 4 f32 values (16 bytes) nor 4 f64 values (32 bytes).
        fs::write(&depth_path, [0u8; 6]).unwrap();

        let result = load_depth_binary(&depth_path, 4);
        assert!(matches!(result, Err(FixtureError::InvalidMetadata(_))));
    }

    #[test]
    fn test_metadata_serialization_roundtrip() {
        let metadata = DatasetMetadata {
            version: "1.0".to_string(),
            crate_version: Some("0.5.5".to_string()),
            renderer_policy_version: Some(crate::RENDERER_POLICY_VERSION.to_string()),
            objects: vec!["obj1".to_string(), "obj2".to_string()],
            viewpoints_per_rotation: 24,
            rotations_per_object: 3,
            renders_per_object: 72,
            resolution: [64, 64],
            resolution_width: Some(64),
            resolution_height: Some(64),
            targeting_policy: Some(TargetingPolicy::MeshCenter),
            intrinsics: IntrinsicsMetadata {
                focal_length: [55.4, 55.4],
                principal_point: [32.0, 32.0],
                image_size: [64, 64],
            },
            viewpoint_config: ViewpointConfigMetadata {
                radius: 0.5,
                yaw_count: 8,
                pitch_angles_deg: vec![-30.0, 0.0, 30.0],
            },
            rotations: vec![[0.0, 0.0, 0.0]],
        };

        let json = serde_json::to_string(&metadata).unwrap();
        let loaded: DatasetMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(loaded.version, metadata.version);
        assert_eq!(loaded.objects, metadata.objects);
        assert_eq!(loaded.resolution, metadata.resolution);
    }

    #[test]
    fn test_render_metadata_serialization() {
        let meta = RenderMetadata {
            object_id: "003_cracker_box".to_string(),
            rotation_index: 1,
            viewpoint_index: 5,
            rotation_euler: [0.0, 90.0, 0.0],
            camera_position: [0.5, 0.0, 0.0],
            camera_rotation_xyzw: Some([0.0, 0.0, 0.0, 1.0]),
            target_point: Some([0.0, 0.0, 0.0]),
            targeting_policy: Some(TargetingPolicy::Origin),
            mesh_bounds: None,
            health: Some(RenderHealth {
                center_pixel: Some([32, 32]),
                center_depth: Some(0.25),
                center_foreground: true,
                foreground_pixel_count: 1,
                foreground_coverage: 1.0 / 4096.0,
                center_5x5_foreground_count: 1,
                nearest_foreground_pixel: Some([32, 32]),
                nearest_foreground_depth: Some(0.25),
                nearest_foreground_distance_px: Some(0.0),
            }),
            rgba_file: "r1_v05.png".to_string(),
            depth_file: "r1_v05.depth".to_string(),
        };

        let json = serde_json::to_string(&meta).unwrap();
        let loaded: RenderMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(loaded.object_id, meta.object_id);
        assert_eq!(loaded.rotation_index, meta.rotation_index);
        assert_eq!(loaded.viewpoint_index, meta.viewpoint_index);
    }
}