Skip to main content

bevy_sensor/
fixtures.rs

1//! Test fixtures for pre-rendered YCB images
2//!
3//! This module provides utilities for loading pre-rendered images from disk,
4//! enabling testing without GPU access.
5//!
6//! # Usage
7//!
8//! ```ignore
9//! use bevy_sensor::fixtures::TestFixtures;
10//!
11//! let fixtures = TestFixtures::load("test_fixtures/renders")?;
12//!
13//! // Get a specific render
14//! let render = fixtures.get_render("003_cracker_box", 0, 5)?;
15//! let rgb_image = render.to_rgb_image();
16//! let depth_image = render.to_depth_image();
17//! ```
18
19use crate::{CameraIntrinsics, RenderOutput};
20use bevy::prelude::*;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use std::fs;
24use std::path::{Path, PathBuf};
25
/// Error type for fixture loading.
///
/// Returned by [`TestFixtures::load`] and by the render accessors
/// (`get_render`, `get_all_renders`, `iter_renders`).
#[derive(Debug)]
pub enum FixtureError {
    /// Fixture root directory not found (payload: the path as displayed)
    NotFound(String),
    /// `metadata.json` missing, or a render file whose size does not match the
    /// metadata (payload: human-readable description)
    InvalidMetadata(String),
    /// No index entry for this object/rotation/viewpoint combination. Also returned
    /// when the object has no loaded index at all. A missing image/depth *file* is
    /// reported as `IoError` instead.
    RenderNotFound {
        /// Object that was requested
        object_id: String,
        /// Requested rotation index
        rotation: usize,
        /// Requested viewpoint index
        viewpoint: usize,
    },
    /// Underlying I/O error; image decoding failures are also wrapped here
    IoError(std::io::Error),
    /// JSON parsing error in `metadata.json` or an object's `index.json`
    JsonError(serde_json::Error),
}
44
45impl std::fmt::Display for FixtureError {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            FixtureError::NotFound(path) => write!(f, "Fixture directory not found: {}", path),
49            FixtureError::InvalidMetadata(msg) => write!(f, "Invalid metadata: {}", msg),
50            FixtureError::RenderNotFound {
51                object_id,
52                rotation,
53                viewpoint,
54            } => write!(
55                f,
56                "Render not found: {} r{} v{}",
57                object_id, rotation, viewpoint
58            ),
59            FixtureError::IoError(e) => write!(f, "IO error: {}", e),
60            FixtureError::JsonError(e) => write!(f, "JSON error: {}", e),
61        }
62    }
63}
64
65impl std::error::Error for FixtureError {}
66
67impl From<std::io::Error> for FixtureError {
68    fn from(e: std::io::Error) -> Self {
69        FixtureError::IoError(e)
70    }
71}
72
73impl From<serde_json::Error> for FixtureError {
74    fn from(e: serde_json::Error) -> Self {
75        FixtureError::JsonError(e)
76    }
77}
78
79/// Dataset metadata from pre-rendering
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct DatasetMetadata {
82    pub version: String,
83    pub objects: Vec<String>,
84    pub viewpoints_per_rotation: usize,
85    pub rotations_per_object: usize,
86    pub renders_per_object: usize,
87    pub resolution: [u32; 2],
88    pub intrinsics: IntrinsicsMetadata,
89    pub viewpoint_config: ViewpointConfigMetadata,
90    pub rotations: Vec<[f32; 3]>,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct IntrinsicsMetadata {
95    pub focal_length: [f32; 2],
96    pub principal_point: [f32; 2],
97    pub image_size: [u32; 2],
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct ViewpointConfigMetadata {
102    pub radius: f32,
103    pub yaw_count: usize,
104    pub pitch_angles_deg: Vec<f32>,
105}
106
107/// Metadata for a single render
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct RenderMetadata {
110    pub object_id: String,
111    pub rotation_index: usize,
112    pub viewpoint_index: usize,
113    pub rotation_euler: [f32; 3],
114    pub camera_position: [f32; 3],
115    pub rgba_file: String,
116    pub depth_file: String,
117}
118
119/// Pre-rendered test fixtures loaded from disk
120pub struct TestFixtures {
121    /// Root directory containing fixtures
122    root: PathBuf,
123    /// Dataset metadata
124    pub metadata: DatasetMetadata,
125    /// Per-object render indices
126    indices: HashMap<String, Vec<RenderMetadata>>,
127}
128
129impl TestFixtures {
130    /// Load test fixtures from a directory
131    ///
132    /// # Arguments
133    /// * `path` - Path to the fixtures directory (e.g., "test_fixtures/renders")
134    ///
135    /// # Returns
136    /// * `Ok(TestFixtures)` if loaded successfully
137    /// * `Err(FixtureError)` if loading fails
138    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, FixtureError> {
139        let root = path.as_ref().to_path_buf();
140
141        if !root.exists() {
142            return Err(FixtureError::NotFound(root.display().to_string()));
143        }
144
145        // Load metadata
146        let metadata_path = root.join("metadata.json");
147        if !metadata_path.exists() {
148            return Err(FixtureError::InvalidMetadata(
149                "metadata.json not found".to_string(),
150            ));
151        }
152
153        let metadata_json = fs::read_to_string(&metadata_path)?;
154        let metadata: DatasetMetadata = serde_json::from_str(&metadata_json)?;
155
156        // Load per-object indices
157        let mut indices = HashMap::new();
158        for object_id in &metadata.objects {
159            let index_path = root.join(object_id).join("index.json");
160            if index_path.exists() {
161                let index_json = fs::read_to_string(&index_path)?;
162                let renders: Vec<RenderMetadata> = serde_json::from_str(&index_json)?;
163                indices.insert(object_id.clone(), renders);
164            }
165        }
166
167        Ok(Self {
168            root,
169            metadata,
170            indices,
171        })
172    }
173
174    /// Check if fixtures exist at the given path
175    pub fn exists<P: AsRef<Path>>(path: P) -> bool {
176        let root = path.as_ref();
177        root.exists() && root.join("metadata.json").exists()
178    }
179
180    /// Get list of available objects
181    pub fn objects(&self) -> &[String] {
182        &self.metadata.objects
183    }
184
185    /// Get number of viewpoints per rotation
186    pub fn viewpoints_per_rotation(&self) -> usize {
187        self.metadata.viewpoints_per_rotation
188    }
189
190    /// Get number of rotations per object
191    pub fn rotations_per_object(&self) -> usize {
192        self.metadata.rotations_per_object
193    }
194
195    /// Get total renders available for an object
196    pub fn renders_for_object(&self, object_id: &str) -> usize {
197        self.indices.get(object_id).map(|v| v.len()).unwrap_or(0)
198    }
199
200    /// Get camera intrinsics (converts from f32 metadata to f64 for TBP precision)
201    pub fn intrinsics(&self) -> CameraIntrinsics {
202        CameraIntrinsics {
203            focal_length: [
204                self.metadata.intrinsics.focal_length[0] as f64,
205                self.metadata.intrinsics.focal_length[1] as f64,
206            ],
207            principal_point: [
208                self.metadata.intrinsics.principal_point[0] as f64,
209                self.metadata.intrinsics.principal_point[1] as f64,
210            ],
211            image_size: self.metadata.intrinsics.image_size,
212        }
213    }
214
215    /// Load a specific render by object, rotation index, and viewpoint index
216    ///
217    /// # Arguments
218    /// * `object_id` - YCB object ID (e.g., "003_cracker_box")
219    /// * `rotation_idx` - Rotation index (0-2 for benchmark rotations)
220    /// * `viewpoint_idx` - Viewpoint index (0-23 for default config)
221    pub fn get_render(
222        &self,
223        object_id: &str,
224        rotation_idx: usize,
225        viewpoint_idx: usize,
226    ) -> Result<RenderOutput, FixtureError> {
227        // Find the render metadata
228        let renders = self
229            .indices
230            .get(object_id)
231            .ok_or_else(|| FixtureError::RenderNotFound {
232                object_id: object_id.to_string(),
233                rotation: rotation_idx,
234                viewpoint: viewpoint_idx,
235            })?;
236
237        let render_meta = renders
238            .iter()
239            .find(|r| r.rotation_index == rotation_idx && r.viewpoint_index == viewpoint_idx)
240            .ok_or_else(|| FixtureError::RenderNotFound {
241                object_id: object_id.to_string(),
242                rotation: rotation_idx,
243                viewpoint: viewpoint_idx,
244            })?;
245
246        // Load RGBA from PNG
247        let rgba_path = self.root.join(object_id).join(&render_meta.rgba_file);
248        let rgba = load_rgba_png(&rgba_path)?;
249
250        // Load depth from binary
251        let depth_path = self.root.join(object_id).join(&render_meta.depth_file);
252        let expected_depth_values =
253            (self.metadata.resolution[0] as usize) * (self.metadata.resolution[1] as usize);
254        let depth = load_depth_binary(&depth_path, expected_depth_values)?;
255
256        // Build camera transform from position (looking at origin)
257        let pos = render_meta.camera_position;
258        let camera_transform =
259            Transform::from_xyz(pos[0], pos[1], pos[2]).looking_at(Vec3::ZERO, Vec3::Y);
260
261        // Build object rotation (convert from f32 metadata to f64)
262        let rot = render_meta.rotation_euler;
263        let object_rotation =
264            crate::ObjectRotation::new(rot[0] as f64, rot[1] as f64, rot[2] as f64);
265
266        Ok(RenderOutput {
267            rgba,
268            depth,
269            width: self.metadata.resolution[0],
270            height: self.metadata.resolution[1],
271            intrinsics: self.intrinsics(),
272            camera_transform,
273            object_rotation,
274        })
275    }
276
277    /// Load all renders for an object
278    pub fn get_all_renders(&self, object_id: &str) -> Result<Vec<RenderOutput>, FixtureError> {
279        let renders = self
280            .indices
281            .get(object_id)
282            .ok_or_else(|| FixtureError::RenderNotFound {
283                object_id: object_id.to_string(),
284                rotation: 0,
285                viewpoint: 0,
286            })?;
287
288        let mut outputs = Vec::with_capacity(renders.len());
289        for meta in renders {
290            let output = self.get_render(object_id, meta.rotation_index, meta.viewpoint_index)?;
291            outputs.push(output);
292        }
293
294        Ok(outputs)
295    }
296
297    /// Iterate over all renders for an object
298    pub fn iter_renders<'a>(
299        &'a self,
300        object_id: &'a str,
301    ) -> impl Iterator<Item = Result<(usize, usize, RenderOutput), FixtureError>> + 'a {
302        let renders = self.indices.get(object_id);
303
304        renders.into_iter().flat_map(|v| v.iter()).map(move |meta| {
305            let output = self.get_render(object_id, meta.rotation_index, meta.viewpoint_index)?;
306            Ok((meta.rotation_index, meta.viewpoint_index, output))
307        })
308    }
309}
310
311/// Load RGBA data from a PNG file
312fn load_rgba_png(path: &Path) -> Result<Vec<u8>, FixtureError> {
313    let img = image::open(path).map_err(|e| FixtureError::IoError(std::io::Error::other(e)))?;
314
315    let rgba = img.to_rgba8();
316    Ok(rgba.into_raw())
317}
318
319/// Load depth data from binary f32 or f64 file and normalize to f64 for TBP precision.
320fn load_depth_binary(path: &Path, expected_values: usize) -> Result<Vec<f64>, FixtureError> {
321    let bytes = fs::read(path)?;
322
323    if bytes.len() == expected_values * std::mem::size_of::<f64>() {
324        return Ok(bytes
325            .chunks_exact(8)
326            .map(|chunk| {
327                let arr: [u8; 8] = chunk.try_into().unwrap();
328                f64::from_le_bytes(arr)
329            })
330            .collect());
331    }
332
333    if bytes.len() == expected_values * std::mem::size_of::<f32>() {
334        return Ok(bytes
335            .chunks_exact(4)
336            .map(|chunk| {
337                let arr: [u8; 4] = chunk.try_into().unwrap();
338                f32::from_le_bytes(arr) as f64
339            })
340            .collect());
341    }
342
343    Err(FixtureError::InvalidMetadata(format!(
344        "Depth file {} has {} bytes, expected {} f32 values or {} f64 values",
345        path.display(),
346        bytes.len(),
347        expected_values,
348        expected_values
349    )))
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A representative `DatasetMetadata` listing `objects`, shared by several tests.
    fn sample_metadata(objects: &[&str]) -> DatasetMetadata {
        DatasetMetadata {
            version: "1.0".to_string(),
            objects: objects.iter().map(|id| id.to_string()).collect(),
            viewpoints_per_rotation: 24,
            rotations_per_object: 3,
            renders_per_object: 72,
            resolution: [64, 64],
            intrinsics: IntrinsicsMetadata {
                focal_length: [55.4, 55.4],
                principal_point: [32.0, 32.0],
                image_size: [64, 64],
            },
            viewpoint_config: ViewpointConfigMetadata {
                radius: 0.5,
                yaw_count: 8,
                pitch_angles_deg: vec![-30.0, 0.0, 30.0],
            },
            rotations: vec![[0.0, 0.0, 0.0], [0.0, 90.0, 0.0], [0.0, 180.0, 0.0]],
        }
    }

    /// Create a temp fixture root with `metadata.json` for a single `test_object`
    /// whose `index.json` lists no renders.
    fn write_minimal_fixture() -> TempDir {
        let temp_dir = TempDir::new().unwrap();
        let metadata_json =
            serde_json::to_string_pretty(&sample_metadata(&["test_object"])).unwrap();
        fs::write(temp_dir.path().join("metadata.json"), metadata_json).unwrap();

        let obj_dir = temp_dir.path().join("test_object");
        fs::create_dir_all(&obj_dir).unwrap();
        fs::write(obj_dir.join("index.json"), "[]").unwrap();
        temp_dir
    }

    /// Write `bytes` to `name` inside `dir` and return the full path.
    fn write_file(dir: &TempDir, name: &str, bytes: &[u8]) -> PathBuf {
        let path = dir.path().join(name);
        fs::write(&path, bytes).unwrap();
        path
    }

    #[test]
    fn test_fixture_not_found() {
        let result = TestFixtures::load("/nonexistent/path");
        assert!(matches!(result, Err(FixtureError::NotFound(_))));
    }

    #[test]
    fn test_fixtures_exists() {
        assert!(!TestFixtures::exists("/nonexistent/path"));

        let temp_dir = write_minimal_fixture();
        assert!(TestFixtures::exists(temp_dir.path()));
    }

    #[test]
    fn test_fixture_error_display() {
        let cases = [
            (
                FixtureError::NotFound("/path".to_string()),
                "Fixture directory not found: /path",
            ),
            (
                FixtureError::InvalidMetadata("bad json".to_string()),
                "Invalid metadata: bad json",
            ),
            (
                FixtureError::RenderNotFound {
                    object_id: "obj".to_string(),
                    rotation: 0,
                    viewpoint: 5,
                },
                "Render not found: obj r0 v5",
            ),
            (
                FixtureError::IoError(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    "file not found",
                )),
                "IO error: file not found",
            ),
        ];

        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }

        let json_err =
            FixtureError::JsonError(serde_json::from_str::<String>("invalid").unwrap_err());
        assert!(json_err.to_string().starts_with("JSON error: "));
    }

    #[test]
    fn test_fixture_missing_metadata() {
        let temp_dir = TempDir::new().unwrap();
        let result = TestFixtures::load(temp_dir.path());
        assert!(matches!(result, Err(FixtureError::InvalidMetadata(_))));
    }

    #[test]
    fn test_fixture_load_metadata() {
        let temp_dir = write_minimal_fixture();
        let fixtures = TestFixtures::load(temp_dir.path()).unwrap();

        assert_eq!(fixtures.objects(), &["test_object"]);
        assert_eq!(fixtures.viewpoints_per_rotation(), 24);
        assert_eq!(fixtures.rotations_per_object(), 3);
        assert_eq!(fixtures.renders_for_object("test_object"), 0);
        assert_eq!(fixtures.renders_for_object("nonexistent"), 0);

        let intrinsics = fixtures.intrinsics();
        assert_eq!(intrinsics.image_size, [64, 64]);
    }

    #[test]
    fn test_missing_render_is_reported() {
        let temp_dir = write_minimal_fixture();
        let fixtures = TestFixtures::load(temp_dir.path()).unwrap();

        // Known object with an empty index.
        assert!(matches!(
            fixtures.get_render("test_object", 0, 0),
            Err(FixtureError::RenderNotFound { .. })
        ));
        assert!(fixtures.get_all_renders("test_object").unwrap().is_empty());
        assert_eq!(fixtures.iter_renders("test_object").count(), 0);

        // Unknown object.
        assert!(matches!(
            fixtures.get_render("nonexistent", 1, 2),
            Err(FixtureError::RenderNotFound {
                rotation: 1,
                viewpoint: 2,
                ..
            })
        ));
        assert!(matches!(
            fixtures.get_all_renders("nonexistent"),
            Err(FixtureError::RenderNotFound { .. })
        ));
        assert_eq!(fixtures.iter_renders("nonexistent").count(), 0);
    }

    #[test]
    fn test_load_depth_binary_f32() {
        let temp_dir = TempDir::new().unwrap();
        let depths: Vec<f32> = vec![0.5, 1.0, 2.0, 10.0];
        let bytes: Vec<u8> = depths.iter().flat_map(|f| f.to_le_bytes()).collect();
        let depth_path = write_file(&temp_dir, "test.depth", &bytes);

        let loaded = load_depth_binary(&depth_path, depths.len()).unwrap();
        // Widening f32 -> f64 is lossless, so the comparison can be exact.
        let expected: Vec<f64> = depths.iter().map(|&d| f64::from(d)).collect();
        assert_eq!(loaded, expected);
    }

    #[test]
    fn test_load_depth_binary_f64() {
        let temp_dir = TempDir::new().unwrap();
        let depths: Vec<f64> = vec![0.5, 1.0, 2.0, 10.0];
        let bytes: Vec<u8> = depths.iter().flat_map(|f| f.to_le_bytes()).collect();
        let depth_path = write_file(&temp_dir, "test.depth", &bytes);

        let loaded = load_depth_binary(&depth_path, depths.len()).unwrap();
        assert_eq!(loaded, depths);
    }

    #[test]
    fn test_load_depth_binary_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let depth_path = write_file(&temp_dir, "short.depth", &[0u8; 5]);

        match load_depth_binary(&depth_path, 4) {
            Err(FixtureError::InvalidMetadata(msg)) => {
                assert!(msg.contains("5 bytes"), "unexpected message: {msg}")
            }
            other => panic!("expected InvalidMetadata, got {other:?}"),
        }
    }

    #[test]
    fn test_load_depth_binary_missing_file() {
        let temp_dir = TempDir::new().unwrap();
        let result = load_depth_binary(&temp_dir.path().join("missing.depth"), 4);
        assert!(matches!(result, Err(FixtureError::IoError(_))));
    }

    #[test]
    fn test_metadata_serialization_roundtrip() {
        let metadata = sample_metadata(&["obj1", "obj2"]);

        let json = serde_json::to_string(&metadata).unwrap();
        let loaded: DatasetMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(loaded.version, metadata.version);
        assert_eq!(loaded.objects, metadata.objects);
        assert_eq!(
            loaded.viewpoints_per_rotation,
            metadata.viewpoints_per_rotation
        );
        assert_eq!(loaded.rotations_per_object, metadata.rotations_per_object);
        assert_eq!(loaded.renders_per_object, metadata.renders_per_object);
        assert_eq!(loaded.resolution, metadata.resolution);
        assert_eq!(
            loaded.intrinsics.focal_length,
            metadata.intrinsics.focal_length
        );
        assert_eq!(
            loaded.intrinsics.principal_point,
            metadata.intrinsics.principal_point
        );
        assert_eq!(loaded.intrinsics.image_size, metadata.intrinsics.image_size);
        assert_eq!(
            loaded.viewpoint_config.radius,
            metadata.viewpoint_config.radius
        );
        assert_eq!(
            loaded.viewpoint_config.yaw_count,
            metadata.viewpoint_config.yaw_count
        );
        assert_eq!(
            loaded.viewpoint_config.pitch_angles_deg,
            metadata.viewpoint_config.pitch_angles_deg
        );
        assert_eq!(loaded.rotations, metadata.rotations);
    }

    #[test]
    fn test_render_metadata_serialization() {
        let meta = RenderMetadata {
            object_id: "003_cracker_box".to_string(),
            rotation_index: 1,
            viewpoint_index: 5,
            rotation_euler: [0.0, 90.0, 0.0],
            camera_position: [0.5, 0.0, 0.0],
            rgba_file: "r1_v05.png".to_string(),
            depth_file: "r1_v05.depth".to_string(),
        };

        let json = serde_json::to_string(&meta).unwrap();
        let loaded: RenderMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(loaded.object_id, meta.object_id);
        assert_eq!(loaded.rotation_index, meta.rotation_index);
        assert_eq!(loaded.viewpoint_index, meta.viewpoint_index);
        assert_eq!(loaded.rotation_euler, meta.rotation_euler);
        assert_eq!(loaded.camera_position, meta.camera_position);
        assert_eq!(loaded.rgba_file, meta.rgba_file);
        assert_eq!(loaded.depth_file, meta.depth_file);
    }
}