Skip to main content

bevy_sensor/
fixtures.rs

1//! Test fixtures for pre-rendered YCB images
2//!
3//! This module provides utilities for loading pre-rendered images from disk,
4//! enabling testing without GPU access.
5//!
6//! # Usage
7//!
8//! ```ignore
9//! use bevy_sensor::fixtures::TestFixtures;
10//!
11//! let fixtures = TestFixtures::load("test_fixtures/renders")?;
12//!
13//! // Get a specific render
14//! let render = fixtures.get_render("003_cracker_box", 0, 5)?;
15//! let rgb_image = render.to_rgb_image();
16//! let depth_image = render.to_depth_image();
17//! ```
18
19use crate::{CameraIntrinsics, MeshBoundsMetadata, RenderHealth, RenderOutput, TargetingPolicy};
20use bevy::prelude::*;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use std::fs;
24use std::path::{Path, PathBuf};
25
/// Errors produced while loading pre-rendered fixtures from disk.
#[derive(Debug)]
pub enum FixtureError {
    /// The fixture root directory does not exist. Carries the path as displayed.
    NotFound(String),
    /// `metadata.json` is missing, or a data file disagrees with what the
    /// metadata describes (e.g. a depth file whose byte size does not match
    /// the dataset resolution). Carries a human-readable explanation.
    InvalidMetadata(String),
    /// No render is indexed for the requested object / rotation / viewpoint.
    ///
    /// When the object itself is unknown to `get_all_renders`, `rotation` and
    /// `viewpoint` are reported as `0` placeholders.
    RenderNotFound {
        /// Object ID that was requested
        object_id: String,
        /// Rotation index that was requested
        rotation: usize,
        /// Viewpoint index that was requested
        viewpoint: usize,
    },
    /// Underlying I/O failure. Image decoding failures are also wrapped in
    /// this variant (via `std::io::Error::other`).
    IoError(std::io::Error),
    /// A JSON manifest (`metadata.json` or an `index.json`) failed to parse
    JsonError(serde_json::Error),
}
44
45impl std::fmt::Display for FixtureError {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            FixtureError::NotFound(path) => write!(f, "Fixture directory not found: {}", path),
49            FixtureError::InvalidMetadata(msg) => write!(f, "Invalid metadata: {}", msg),
50            FixtureError::RenderNotFound {
51                object_id,
52                rotation,
53                viewpoint,
54            } => write!(
55                f,
56                "Render not found: {} r{} v{}",
57                object_id, rotation, viewpoint
58            ),
59            FixtureError::IoError(e) => write!(f, "IO error: {}", e),
60            FixtureError::JsonError(e) => write!(f, "JSON error: {}", e),
61        }
62    }
63}
64
65impl std::error::Error for FixtureError {}
66
67impl From<std::io::Error> for FixtureError {
68    fn from(e: std::io::Error) -> Self {
69        FixtureError::IoError(e)
70    }
71}
72
73impl From<serde_json::Error> for FixtureError {
74    fn from(e: serde_json::Error) -> Self {
75        FixtureError::JsonError(e)
76    }
77}
78
/// Dataset metadata from pre-rendering
///
/// Deserialized from `metadata.json` in the fixture root. Fields marked
/// `#[serde(default)]` may be absent from the manifest, in which case they
/// load as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetMetadata {
    /// Version string of the dataset manifest (e.g. "1.0")
    pub version: String,
    /// Optional crate version string recorded in the manifest
    /// (presumably the crate that produced the renders — TODO confirm)
    #[serde(default)]
    pub crate_version: Option<String>,
    /// Optional renderer policy version string recorded in the manifest
    #[serde(default)]
    pub renderer_policy_version: Option<String>,
    /// Object IDs in the dataset; each is looked up at `<root>/<object_id>/index.json`
    pub objects: Vec<String>,
    /// Number of viewpoints per rotation
    pub viewpoints_per_rotation: usize,
    /// Number of rotations per object
    pub rotations_per_object: usize,
    /// Number of renders per object
    /// (presumably rotations × viewpoints — not checked when loading)
    pub renders_per_object: usize,
    /// Render resolution as `[width, height]` in pixels; also determines the
    /// expected number of depth values per render
    pub resolution: [u32; 2],
    /// Optional explicit width; not read by the loader in this module
    #[serde(default)]
    pub resolution_width: Option<u32>,
    /// Optional explicit height; not read by the loader in this module
    #[serde(default)]
    pub resolution_height: Option<u32>,
    /// Dataset-wide targeting policy, used as a fallback for renders that do
    /// not record their own
    #[serde(default)]
    pub targeting_policy: Option<TargetingPolicy>,
    /// Camera intrinsics shared by all renders
    pub intrinsics: IntrinsicsMetadata,
    /// Viewpoint sampling configuration recorded at render time
    pub viewpoint_config: ViewpointConfigMetadata,
    /// Object rotations as three-component triples
    /// (presumably Euler angles in degrees — TODO confirm)
    pub rotations: Vec<[f32; 3]>,
}
102
/// Camera intrinsics as stored in the dataset manifest (single precision).
///
/// Converted to `CameraIntrinsics` (double precision) by `TestFixtures::intrinsics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntrinsicsMetadata {
    /// Focal length, two components (presumably `[fx, fy]` in pixels — TODO confirm)
    pub focal_length: [f32; 2],
    /// Principal point, two components (presumably `[cx, cy]` in pixels — TODO confirm)
    pub principal_point: [f32; 2],
    /// Image size, two components (presumably `[width, height]` in pixels — TODO confirm)
    pub image_size: [u32; 2],
}
109
/// Viewpoint sampling configuration as stored in the dataset manifest.
///
/// This module only deserializes these values; it does not interpret them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ViewpointConfigMetadata {
    /// Radius value recorded for the viewpoints
    /// (presumably camera distance from the target — TODO confirm)
    pub radius: f32,
    /// Number of yaw steps
    pub yaw_count: usize,
    /// Pitch angles in degrees
    pub pitch_angles_deg: Vec<f32>,
}
116
/// Metadata for a single render
///
/// One entry of an object's `index.json`. Fields marked `#[serde(default)]`
/// may be absent (older manifests) and then load as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RenderMetadata {
    /// Object ID recorded for this render. Note that lookups and file paths
    /// use the object ID supplied by the caller, not this field.
    pub object_id: String,
    /// Rotation index this render belongs to (matched by `get_render`)
    pub rotation_index: usize,
    /// Viewpoint index within the rotation (matched by `get_render`)
    pub viewpoint_index: usize,
    /// Object rotation as three components, passed in order to `ObjectRotation::new`
    /// (presumably Euler angles in degrees — TODO confirm)
    pub rotation_euler: [f32; 3],
    /// Camera position `[x, y, z]`, used as the camera transform's translation
    pub camera_position: [f32; 3],
    /// Exact camera rotation as a quaternion in `[x, y, z, w]` order. When
    /// absent, the rotation is reconstructed by looking at the target point.
    #[serde(default)]
    pub camera_rotation_xyzw: Option<[f32; 4]>,
    /// Object translation; treated as `[0, 0, 0]` when absent
    #[serde(default)]
    pub object_translation: Option<[f32; 3]>,
    /// Object scale; treated as `[1, 1, 1]` when absent
    #[serde(default)]
    pub object_scale: Option<[f32; 3]>,
    /// Point the camera was aimed at; treated as the origin when absent
    #[serde(default)]
    pub target_point: Option<[f32; 3]>,
    /// Targeting policy for this render; falls back to the dataset-level
    /// policy, then to `TargetingPolicy::Origin`
    #[serde(default)]
    pub targeting_policy: Option<TargetingPolicy>,
    /// Optional mesh bounds metadata; not read by the loader in this module
    #[serde(default)]
    pub mesh_bounds: Option<MeshBoundsMetadata>,
    /// Optional render health metadata; not read by the loader in this module
    #[serde(default)]
    pub health: Option<RenderHealth>,
    /// File name of the color image, relative to `<root>/<object_id>/`
    pub rgba_file: String,
    /// File name of the raw depth buffer (little-endian `f32` or `f64`
    /// values), relative to `<root>/<object_id>/`
    pub depth_file: String,
}
142
/// Pre-rendered test fixtures loaded from disk
///
/// Holds the parsed manifests only; image and depth data are read from disk
/// each time a render is requested.
pub struct TestFixtures {
    /// Root directory containing fixtures
    root: PathBuf,
    /// Dataset metadata
    pub metadata: DatasetMetadata,
    /// Per-object render indices, keyed by object ID. Only objects whose
    /// `index.json` existed at load time have an entry.
    indices: HashMap<String, Vec<RenderMetadata>>,
}
152
153impl TestFixtures {
154    /// Load test fixtures from a directory
155    ///
156    /// # Arguments
157    /// * `path` - Path to the fixtures directory (e.g., "test_fixtures/renders")
158    ///
159    /// # Returns
160    /// * `Ok(TestFixtures)` if loaded successfully
161    /// * `Err(FixtureError)` if loading fails
162    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, FixtureError> {
163        let root = path.as_ref().to_path_buf();
164
165        if !root.exists() {
166            return Err(FixtureError::NotFound(root.display().to_string()));
167        }
168
169        // Load metadata
170        let metadata_path = root.join("metadata.json");
171        if !metadata_path.exists() {
172            return Err(FixtureError::InvalidMetadata(
173                "metadata.json not found".to_string(),
174            ));
175        }
176
177        let metadata_json = fs::read_to_string(&metadata_path)?;
178        let metadata: DatasetMetadata = serde_json::from_str(&metadata_json)?;
179
180        // Load per-object indices
181        let mut indices = HashMap::new();
182        for object_id in &metadata.objects {
183            let index_path = root.join(object_id).join("index.json");
184            if index_path.exists() {
185                let index_json = fs::read_to_string(&index_path)?;
186                let renders: Vec<RenderMetadata> = serde_json::from_str(&index_json)?;
187                indices.insert(object_id.clone(), renders);
188            }
189        }
190
191        Ok(Self {
192            root,
193            metadata,
194            indices,
195        })
196    }
197
198    /// Check if fixtures exist at the given path
199    pub fn exists<P: AsRef<Path>>(path: P) -> bool {
200        let root = path.as_ref();
201        root.exists() && root.join("metadata.json").exists()
202    }
203
204    /// Get list of available objects
205    pub fn objects(&self) -> &[String] {
206        &self.metadata.objects
207    }
208
209    /// Get number of viewpoints per rotation
210    pub fn viewpoints_per_rotation(&self) -> usize {
211        self.metadata.viewpoints_per_rotation
212    }
213
214    /// Get number of rotations per object
215    pub fn rotations_per_object(&self) -> usize {
216        self.metadata.rotations_per_object
217    }
218
219    /// Get total renders available for an object
220    pub fn renders_for_object(&self, object_id: &str) -> usize {
221        self.indices.get(object_id).map(|v| v.len()).unwrap_or(0)
222    }
223
224    /// Get camera intrinsics (converts from f32 metadata to f64 for TBP precision)
225    pub fn intrinsics(&self) -> CameraIntrinsics {
226        CameraIntrinsics {
227            focal_length: [
228                self.metadata.intrinsics.focal_length[0] as f64,
229                self.metadata.intrinsics.focal_length[1] as f64,
230            ],
231            principal_point: [
232                self.metadata.intrinsics.principal_point[0] as f64,
233                self.metadata.intrinsics.principal_point[1] as f64,
234            ],
235            image_size: self.metadata.intrinsics.image_size,
236        }
237    }
238
239    /// Load a specific render by object, rotation index, and viewpoint index
240    ///
241    /// # Arguments
242    /// * `object_id` - YCB object ID (e.g., "003_cracker_box")
243    /// * `rotation_idx` - Rotation index (0-2 for benchmark rotations)
244    /// * `viewpoint_idx` - Viewpoint index (0-23 for default config)
245    pub fn get_render(
246        &self,
247        object_id: &str,
248        rotation_idx: usize,
249        viewpoint_idx: usize,
250    ) -> Result<RenderOutput, FixtureError> {
251        // Find the render metadata
252        let renders = self
253            .indices
254            .get(object_id)
255            .ok_or_else(|| FixtureError::RenderNotFound {
256                object_id: object_id.to_string(),
257                rotation: rotation_idx,
258                viewpoint: viewpoint_idx,
259            })?;
260
261        let render_meta = renders
262            .iter()
263            .find(|r| r.rotation_index == rotation_idx && r.viewpoint_index == viewpoint_idx)
264            .ok_or_else(|| FixtureError::RenderNotFound {
265                object_id: object_id.to_string(),
266                rotation: rotation_idx,
267                viewpoint: viewpoint_idx,
268            })?;
269
270        // Load RGBA from PNG
271        let rgba_path = self.root.join(object_id).join(&render_meta.rgba_file);
272        let rgba = load_rgba_png(&rgba_path)?;
273
274        // Load depth from binary
275        let depth_path = self.root.join(object_id).join(&render_meta.depth_file);
276        let expected_depth_values =
277            (self.metadata.resolution[0] as usize) * (self.metadata.resolution[1] as usize);
278        let depth = load_depth_binary(&depth_path, expected_depth_values)?;
279
280        // Build camera transform from position. New manifests carry the exact
281        // render rotation; older manifests only recorded origin-targeted
282        // camera positions, so reconstruct with a target fallback.
283        let pos = render_meta.camera_position;
284        let translation = Vec3::new(pos[0], pos[1], pos[2]);
285        let camera_transform = if let Some(q) = render_meta.camera_rotation_xyzw {
286            Transform {
287                translation,
288                rotation: Quat::from_xyzw(q[0], q[1], q[2], q[3]),
289                ..Default::default()
290            }
291        } else {
292            let target = render_meta.target_point.unwrap_or([0.0, 0.0, 0.0]);
293            Transform::from_translation(translation)
294                .looking_at(Vec3::new(target[0], target[1], target[2]), Vec3::Y)
295        };
296        let target_point = Vec3::from_array(render_meta.target_point.unwrap_or([0.0, 0.0, 0.0]));
297        let targeting_policy = render_meta
298            .targeting_policy
299            .clone()
300            .or_else(|| self.metadata.targeting_policy.clone())
301            .unwrap_or(TargetingPolicy::Origin);
302
303        // Build object rotation (convert from f32 metadata to f64)
304        let rot = render_meta.rotation_euler;
305        let object_rotation =
306            crate::ObjectRotation::new(rot[0] as f64, rot[1] as f64, rot[2] as f64);
307        let object_translation =
308            Vec3::from_array(render_meta.object_translation.unwrap_or([0.0, 0.0, 0.0]));
309        let object_scale = Vec3::from_array(render_meta.object_scale.unwrap_or([1.0, 1.0, 1.0]));
310
311        Ok(RenderOutput {
312            rgba,
313            depth,
314            width: self.metadata.resolution[0],
315            height: self.metadata.resolution[1],
316            intrinsics: self.intrinsics(),
317            camera_transform,
318            object_rotation,
319            object_translation,
320            object_scale,
321            target_point,
322            targeting_policy,
323        })
324    }
325
326    /// Load all renders for an object
327    pub fn get_all_renders(&self, object_id: &str) -> Result<Vec<RenderOutput>, FixtureError> {
328        let renders = self
329            .indices
330            .get(object_id)
331            .ok_or_else(|| FixtureError::RenderNotFound {
332                object_id: object_id.to_string(),
333                rotation: 0,
334                viewpoint: 0,
335            })?;
336
337        let mut outputs = Vec::with_capacity(renders.len());
338        for meta in renders {
339            let output = self.get_render(object_id, meta.rotation_index, meta.viewpoint_index)?;
340            outputs.push(output);
341        }
342
343        Ok(outputs)
344    }
345
346    /// Iterate over all renders for an object
347    pub fn iter_renders<'a>(
348        &'a self,
349        object_id: &'a str,
350    ) -> impl Iterator<Item = Result<(usize, usize, RenderOutput), FixtureError>> + 'a {
351        let renders = self.indices.get(object_id);
352
353        renders.into_iter().flat_map(|v| v.iter()).map(move |meta| {
354            let output = self.get_render(object_id, meta.rotation_index, meta.viewpoint_index)?;
355            Ok((meta.rotation_index, meta.viewpoint_index, output))
356        })
357    }
358}
359
360/// Load RGBA data from a PNG file
361fn load_rgba_png(path: &Path) -> Result<Vec<u8>, FixtureError> {
362    let img = image::open(path).map_err(|e| FixtureError::IoError(std::io::Error::other(e)))?;
363
364    let rgba = img.to_rgba8();
365    Ok(rgba.into_raw())
366}
367
368/// Load depth data from binary f32 or f64 file and normalize to f64 for TBP precision.
369fn load_depth_binary(path: &Path, expected_values: usize) -> Result<Vec<f64>, FixtureError> {
370    let bytes = fs::read(path)?;
371
372    if bytes.len() == expected_values * std::mem::size_of::<f64>() {
373        return Ok(bytes
374            .chunks_exact(8)
375            .map(|chunk| {
376                let arr: [u8; 8] = chunk.try_into().unwrap();
377                f64::from_le_bytes(arr)
378            })
379            .collect());
380    }
381
382    if bytes.len() == expected_values * std::mem::size_of::<f32>() {
383        return Ok(bytes
384            .chunks_exact(4)
385            .map(|chunk| {
386                let arr: [u8; 4] = chunk.try_into().unwrap();
387                f32::from_le_bytes(arr) as f64
388            })
389            .collect());
390    }
391
392    Err(FixtureError::InvalidMetadata(format!(
393        "Depth file {} has {} bytes, expected {} f32 values or {} f64 values",
394        path.display(),
395        bytes.len(),
396        expected_values,
397        expected_values
398    )))
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Minimal single-object manifest shared by the metadata tests.
    fn minimal_metadata() -> DatasetMetadata {
        DatasetMetadata {
            version: "1.0".to_string(),
            crate_version: None,
            renderer_policy_version: None,
            objects: vec!["test_object".to_string()],
            viewpoints_per_rotation: 24,
            rotations_per_object: 3,
            renders_per_object: 72,
            resolution: [64, 64],
            resolution_width: None,
            resolution_height: None,
            targeting_policy: None,
            intrinsics: IntrinsicsMetadata {
                focal_length: [55.4, 55.4],
                principal_point: [32.0, 32.0],
                image_size: [64, 64],
            },
            viewpoint_config: ViewpointConfigMetadata {
                radius: 0.5,
                yaw_count: 8,
                pitch_angles_deg: vec![-30.0, 0.0, 30.0],
            },
            rotations: vec![[0.0, 0.0, 0.0], [0.0, 90.0, 0.0], [0.0, 180.0, 0.0]],
        }
    }

    /// A missing root directory is reported as `NotFound`.
    #[test]
    fn test_fixture_not_found() {
        let outcome = TestFixtures::load("/nonexistent/path");
        assert!(matches!(outcome, Err(FixtureError::NotFound(_))));
    }

    /// `exists` is false for a missing directory.
    #[test]
    fn test_fixtures_exists() {
        assert!(!TestFixtures::exists("/nonexistent/path"));
    }

    /// Every error variant renders a non-empty message.
    #[test]
    fn test_fixture_error_display() {
        let io_error = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let json_error = serde_json::from_str::<String>("invalid").unwrap_err();
        let samples = [
            FixtureError::NotFound("/path".to_string()),
            FixtureError::InvalidMetadata("bad json".to_string()),
            FixtureError::RenderNotFound {
                object_id: "obj".to_string(),
                rotation: 0,
                viewpoint: 5,
            },
            FixtureError::IoError(io_error),
            FixtureError::JsonError(json_error),
        ];

        for sample in &samples {
            assert!(!sample.to_string().is_empty());
        }
    }

    /// An existing directory without `metadata.json` is `InvalidMetadata`.
    #[test]
    fn test_fixture_missing_metadata() {
        let scratch = TempDir::new().unwrap();
        let outcome = TestFixtures::load(scratch.path());
        assert!(matches!(outcome, Err(FixtureError::InvalidMetadata(_))));
    }

    /// A manifest plus an empty object index loads and exposes its accessors.
    #[test]
    fn test_fixture_load_metadata() {
        let scratch = TempDir::new().unwrap();
        let root = scratch.path();

        // Write the dataset manifest
        let manifest = serde_json::to_string_pretty(&minimal_metadata()).unwrap();
        fs::write(root.join("metadata.json"), &manifest).unwrap();

        // Create object directory with empty index
        let object_dir = root.join("test_object");
        fs::create_dir_all(&object_dir).unwrap();
        fs::write(object_dir.join("index.json"), "[]").unwrap();

        let fixtures = TestFixtures::load(root).unwrap();

        assert_eq!(fixtures.objects(), &["test_object"]);
        assert_eq!(fixtures.viewpoints_per_rotation(), 24);
        assert_eq!(fixtures.rotations_per_object(), 3);
        assert_eq!(fixtures.renders_for_object("test_object"), 0);
        assert_eq!(fixtures.renders_for_object("nonexistent"), 0);
        assert_eq!(fixtures.intrinsics().image_size, [64, 64]);
    }

    /// A buffer of `f32` values is decoded and widened to `f64`.
    #[test]
    fn test_load_depth_binary_f32() {
        let scratch = TempDir::new().unwrap();
        let depth_path = scratch.path().join("test.depth");

        let depths: Vec<f32> = vec![0.5, 1.0, 2.0, 10.0];
        let mut bytes: Vec<u8> = Vec::new();
        for value in &depths {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        fs::write(&depth_path, &bytes).unwrap();

        let loaded = load_depth_binary(&depth_path, depths.len()).unwrap();
        assert_eq!(loaded.len(), 4);
        for (got, want) in loaded.iter().zip(&depths) {
            assert!((got - f64::from(*want)).abs() < 0.001);
        }
    }

    /// A buffer of `f64` values is decoded exactly.
    #[test]
    fn test_load_depth_binary_f64() {
        let scratch = TempDir::new().unwrap();
        let depth_path = scratch.path().join("test.depth");

        let depths: Vec<f64> = vec![0.5, 1.0, 2.0, 10.0];
        let mut bytes: Vec<u8> = Vec::new();
        for value in &depths {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        fs::write(&depth_path, &bytes).unwrap();

        let loaded = load_depth_binary(&depth_path, depths.len()).unwrap();
        assert_eq!(loaded, depths);
    }

    /// Dataset metadata survives a JSON round trip.
    #[test]
    fn test_metadata_serialization_roundtrip() {
        let metadata = DatasetMetadata {
            crate_version: Some("0.5.5".to_string()),
            renderer_policy_version: Some(crate::RENDERER_POLICY_VERSION.to_string()),
            objects: vec!["obj1".to_string(), "obj2".to_string()],
            resolution_width: Some(64),
            resolution_height: Some(64),
            targeting_policy: Some(TargetingPolicy::MeshCenter),
            rotations: vec![[0.0, 0.0, 0.0]],
            ..minimal_metadata()
        };

        let encoded = serde_json::to_string(&metadata).unwrap();
        let decoded: DatasetMetadata = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.version, metadata.version);
        assert_eq!(decoded.objects, metadata.objects);
        assert_eq!(decoded.resolution, metadata.resolution);
    }

    /// Per-render metadata survives a JSON round trip.
    #[test]
    fn test_render_metadata_serialization() {
        let health = RenderHealth {
            center_pixel: Some([32, 32]),
            center_depth: Some(0.25),
            center_foreground: true,
            foreground_pixel_count: 1,
            foreground_coverage: 1.0 / 4096.0,
            center_5x5_foreground_count: 1,
            nearest_foreground_pixel: Some([32, 32]),
            nearest_foreground_depth: Some(0.25),
            nearest_foreground_distance_px: Some(0.0),
        };
        let meta = RenderMetadata {
            object_id: "003_cracker_box".to_string(),
            rotation_index: 1,
            viewpoint_index: 5,
            rotation_euler: [0.0, 90.0, 0.0],
            camera_position: [0.5, 0.0, 0.0],
            camera_rotation_xyzw: Some([0.0, 0.0, 0.0, 1.0]),
            object_translation: Some([0.1, 0.2, 0.3]),
            object_scale: Some([1.0, 1.25, 0.75]),
            target_point: Some([0.0, 0.0, 0.0]),
            targeting_policy: Some(TargetingPolicy::Origin),
            mesh_bounds: None,
            health: Some(health),
            rgba_file: "r1_v05.png".to_string(),
            depth_file: "r1_v05.depth".to_string(),
        };

        let encoded = serde_json::to_string(&meta).unwrap();
        let decoded: RenderMetadata = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.object_id, meta.object_id);
        assert_eq!(decoded.rotation_index, meta.rotation_index);
        assert_eq!(decoded.viewpoint_index, meta.viewpoint_index);
    }
}