1use crate::{CameraIntrinsics, RenderOutput};
20use bevy::prelude::*;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use std::fs;
24use std::path::{Path, PathBuf};
25
26#[derive(Debug)]
28pub enum FixtureError {
29 NotFound(String),
31 InvalidMetadata(String),
33 RenderNotFound {
35 object_id: String,
36 rotation: usize,
37 viewpoint: usize,
38 },
39 IoError(std::io::Error),
41 JsonError(serde_json::Error),
43}
44
45impl std::fmt::Display for FixtureError {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 FixtureError::NotFound(path) => write!(f, "Fixture directory not found: {}", path),
49 FixtureError::InvalidMetadata(msg) => write!(f, "Invalid metadata: {}", msg),
50 FixtureError::RenderNotFound {
51 object_id,
52 rotation,
53 viewpoint,
54 } => write!(
55 f,
56 "Render not found: {} r{} v{}",
57 object_id, rotation, viewpoint
58 ),
59 FixtureError::IoError(e) => write!(f, "IO error: {}", e),
60 FixtureError::JsonError(e) => write!(f, "JSON error: {}", e),
61 }
62 }
63}
64
65impl std::error::Error for FixtureError {}
66
67impl From<std::io::Error> for FixtureError {
68 fn from(e: std::io::Error) -> Self {
69 FixtureError::IoError(e)
70 }
71}
72
73impl From<serde_json::Error> for FixtureError {
74 fn from(e: serde_json::Error) -> Self {
75 FixtureError::JsonError(e)
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct DatasetMetadata {
82 pub version: String,
83 pub objects: Vec<String>,
84 pub viewpoints_per_rotation: usize,
85 pub rotations_per_object: usize,
86 pub renders_per_object: usize,
87 pub resolution: [u32; 2],
88 pub intrinsics: IntrinsicsMetadata,
89 pub viewpoint_config: ViewpointConfigMetadata,
90 pub rotations: Vec<[f32; 3]>,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct IntrinsicsMetadata {
95 pub focal_length: [f32; 2],
96 pub principal_point: [f32; 2],
97 pub image_size: [u32; 2],
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct ViewpointConfigMetadata {
102 pub radius: f32,
103 pub yaw_count: usize,
104 pub pitch_angles_deg: Vec<f32>,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct RenderMetadata {
110 pub object_id: String,
111 pub rotation_index: usize,
112 pub viewpoint_index: usize,
113 pub rotation_euler: [f32; 3],
114 pub camera_position: [f32; 3],
115 pub rgba_file: String,
116 pub depth_file: String,
117}
118
119pub struct TestFixtures {
121 root: PathBuf,
123 pub metadata: DatasetMetadata,
125 indices: HashMap<String, Vec<RenderMetadata>>,
127}
128
129impl TestFixtures {
130 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, FixtureError> {
139 let root = path.as_ref().to_path_buf();
140
141 if !root.exists() {
142 return Err(FixtureError::NotFound(root.display().to_string()));
143 }
144
145 let metadata_path = root.join("metadata.json");
147 if !metadata_path.exists() {
148 return Err(FixtureError::InvalidMetadata(
149 "metadata.json not found".to_string(),
150 ));
151 }
152
153 let metadata_json = fs::read_to_string(&metadata_path)?;
154 let metadata: DatasetMetadata = serde_json::from_str(&metadata_json)?;
155
156 let mut indices = HashMap::new();
158 for object_id in &metadata.objects {
159 let index_path = root.join(object_id).join("index.json");
160 if index_path.exists() {
161 let index_json = fs::read_to_string(&index_path)?;
162 let renders: Vec<RenderMetadata> = serde_json::from_str(&index_json)?;
163 indices.insert(object_id.clone(), renders);
164 }
165 }
166
167 Ok(Self {
168 root,
169 metadata,
170 indices,
171 })
172 }
173
174 pub fn exists<P: AsRef<Path>>(path: P) -> bool {
176 let root = path.as_ref();
177 root.exists() && root.join("metadata.json").exists()
178 }
179
180 pub fn objects(&self) -> &[String] {
182 &self.metadata.objects
183 }
184
185 pub fn viewpoints_per_rotation(&self) -> usize {
187 self.metadata.viewpoints_per_rotation
188 }
189
190 pub fn rotations_per_object(&self) -> usize {
192 self.metadata.rotations_per_object
193 }
194
195 pub fn renders_for_object(&self, object_id: &str) -> usize {
197 self.indices.get(object_id).map(|v| v.len()).unwrap_or(0)
198 }
199
200 pub fn intrinsics(&self) -> CameraIntrinsics {
202 CameraIntrinsics {
203 focal_length: [
204 self.metadata.intrinsics.focal_length[0] as f64,
205 self.metadata.intrinsics.focal_length[1] as f64,
206 ],
207 principal_point: [
208 self.metadata.intrinsics.principal_point[0] as f64,
209 self.metadata.intrinsics.principal_point[1] as f64,
210 ],
211 image_size: self.metadata.intrinsics.image_size,
212 }
213 }
214
215 pub fn get_render(
222 &self,
223 object_id: &str,
224 rotation_idx: usize,
225 viewpoint_idx: usize,
226 ) -> Result<RenderOutput, FixtureError> {
227 let renders = self
229 .indices
230 .get(object_id)
231 .ok_or_else(|| FixtureError::RenderNotFound {
232 object_id: object_id.to_string(),
233 rotation: rotation_idx,
234 viewpoint: viewpoint_idx,
235 })?;
236
237 let render_meta = renders
238 .iter()
239 .find(|r| r.rotation_index == rotation_idx && r.viewpoint_index == viewpoint_idx)
240 .ok_or_else(|| FixtureError::RenderNotFound {
241 object_id: object_id.to_string(),
242 rotation: rotation_idx,
243 viewpoint: viewpoint_idx,
244 })?;
245
246 let rgba_path = self.root.join(object_id).join(&render_meta.rgba_file);
248 let rgba = load_rgba_png(&rgba_path)?;
249
250 let depth_path = self.root.join(object_id).join(&render_meta.depth_file);
252 let expected_depth_values =
253 (self.metadata.resolution[0] as usize) * (self.metadata.resolution[1] as usize);
254 let depth = load_depth_binary(&depth_path, expected_depth_values)?;
255
256 let pos = render_meta.camera_position;
258 let camera_transform =
259 Transform::from_xyz(pos[0], pos[1], pos[2]).looking_at(Vec3::ZERO, Vec3::Y);
260
261 let rot = render_meta.rotation_euler;
263 let object_rotation =
264 crate::ObjectRotation::new(rot[0] as f64, rot[1] as f64, rot[2] as f64);
265
266 Ok(RenderOutput {
267 rgba,
268 depth,
269 width: self.metadata.resolution[0],
270 height: self.metadata.resolution[1],
271 intrinsics: self.intrinsics(),
272 camera_transform,
273 object_rotation,
274 })
275 }
276
277 pub fn get_all_renders(&self, object_id: &str) -> Result<Vec<RenderOutput>, FixtureError> {
279 let renders = self
280 .indices
281 .get(object_id)
282 .ok_or_else(|| FixtureError::RenderNotFound {
283 object_id: object_id.to_string(),
284 rotation: 0,
285 viewpoint: 0,
286 })?;
287
288 let mut outputs = Vec::with_capacity(renders.len());
289 for meta in renders {
290 let output = self.get_render(object_id, meta.rotation_index, meta.viewpoint_index)?;
291 outputs.push(output);
292 }
293
294 Ok(outputs)
295 }
296
297 pub fn iter_renders<'a>(
299 &'a self,
300 object_id: &'a str,
301 ) -> impl Iterator<Item = Result<(usize, usize, RenderOutput), FixtureError>> + 'a {
302 let renders = self.indices.get(object_id);
303
304 renders.into_iter().flat_map(|v| v.iter()).map(move |meta| {
305 let output = self.get_render(object_id, meta.rotation_index, meta.viewpoint_index)?;
306 Ok((meta.rotation_index, meta.viewpoint_index, output))
307 })
308 }
309}
310
311fn load_rgba_png(path: &Path) -> Result<Vec<u8>, FixtureError> {
313 let img = image::open(path).map_err(|e| FixtureError::IoError(std::io::Error::other(e)))?;
314
315 let rgba = img.to_rgba8();
316 Ok(rgba.into_raw())
317}
318
319fn load_depth_binary(path: &Path, expected_values: usize) -> Result<Vec<f64>, FixtureError> {
321 let bytes = fs::read(path)?;
322
323 if bytes.len() == expected_values * std::mem::size_of::<f64>() {
324 return Ok(bytes
325 .chunks_exact(8)
326 .map(|chunk| {
327 let arr: [u8; 8] = chunk.try_into().unwrap();
328 f64::from_le_bytes(arr)
329 })
330 .collect());
331 }
332
333 if bytes.len() == expected_values * std::mem::size_of::<f32>() {
334 return Ok(bytes
335 .chunks_exact(4)
336 .map(|chunk| {
337 let arr: [u8; 4] = chunk.try_into().unwrap();
338 f32::from_le_bytes(arr) as f64
339 })
340 .collect());
341 }
342
343 Err(FixtureError::InvalidMetadata(format!(
344 "Depth file {} has {} bytes, expected {} f32 values or {} f64 values",
345 path.display(),
346 bytes.len(),
347 expected_values,
348 expected_values
349 )))
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a representative dataset description for `objects` at `resolution`.
    fn sample_metadata(objects: &[&str], resolution: [u32; 2]) -> DatasetMetadata {
        DatasetMetadata {
            version: "1.0".to_string(),
            objects: objects.iter().map(|o| o.to_string()).collect(),
            viewpoints_per_rotation: 24,
            rotations_per_object: 3,
            renders_per_object: 72,
            resolution,
            intrinsics: IntrinsicsMetadata {
                focal_length: [55.4, 55.4],
                principal_point: [resolution[0] as f32 / 2.0, resolution[1] as f32 / 2.0],
                image_size: resolution,
            },
            viewpoint_config: ViewpointConfigMetadata {
                radius: 0.5,
                yaw_count: 8,
                pitch_angles_deg: vec![-30.0, 0.0, 30.0],
            },
            rotations: vec![[0.0, 0.0, 0.0], [0.0, 90.0, 0.0], [0.0, 180.0, 0.0]],
        }
    }

    /// Writes `metadata` as `<dir>/metadata.json`.
    fn write_metadata(dir: &Path, metadata: &DatasetMetadata) {
        let json = serde_json::to_string_pretty(metadata).unwrap();
        fs::write(dir.join("metadata.json"), json).unwrap();
    }

    /// Creates a dataset with a single object whose render index is empty.
    fn dataset_with_empty_index(temp_dir: &TempDir, object_id: &str) {
        write_metadata(temp_dir.path(), &sample_metadata(&[object_id], [64, 64]));
        let obj_dir = temp_dir.path().join(object_id);
        fs::create_dir_all(&obj_dir).unwrap();
        fs::write(obj_dir.join("index.json"), "[]").unwrap();
    }

    #[test]
    fn test_fixture_not_found() {
        let result = TestFixtures::load("/nonexistent/path");
        assert!(matches!(result, Err(FixtureError::NotFound(_))));
    }

    #[test]
    fn test_fixtures_exists() {
        assert!(!TestFixtures::exists("/nonexistent/path"));

        let temp_dir = TempDir::new().unwrap();
        assert!(!TestFixtures::exists(temp_dir.path()));
        write_metadata(temp_dir.path(), &sample_metadata(&[], [64, 64]));
        assert!(TestFixtures::exists(temp_dir.path()));
    }

    #[test]
    fn test_fixture_error_display() {
        let cases = vec![
            (
                FixtureError::NotFound("/path".to_string()),
                "Fixture directory not found: /path",
            ),
            (
                FixtureError::InvalidMetadata("bad json".to_string()),
                "Invalid metadata: bad json",
            ),
            (
                FixtureError::RenderNotFound {
                    object_id: "obj".to_string(),
                    rotation: 0,
                    viewpoint: 5,
                },
                "Render not found: obj r0 v5",
            ),
            (
                FixtureError::IoError(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    "file not found",
                )),
                "IO error: file not found",
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }

        let json_err =
            FixtureError::JsonError(serde_json::from_str::<String>("invalid").unwrap_err());
        assert!(json_err.to_string().starts_with("JSON error: "));
    }

    #[test]
    fn test_fixture_missing_metadata() {
        let temp_dir = TempDir::new().unwrap();
        let result = TestFixtures::load(temp_dir.path());
        assert!(matches!(result, Err(FixtureError::InvalidMetadata(_))));
    }

    #[test]
    fn test_fixture_load_metadata() {
        let temp_dir = TempDir::new().unwrap();
        dataset_with_empty_index(&temp_dir, "test_object");

        let fixtures = TestFixtures::load(temp_dir.path()).unwrap();

        assert_eq!(fixtures.objects(), &["test_object"]);
        assert_eq!(fixtures.viewpoints_per_rotation(), 24);
        assert_eq!(fixtures.rotations_per_object(), 3);
        assert_eq!(fixtures.renders_for_object("test_object"), 0);
        assert_eq!(fixtures.renders_for_object("nonexistent"), 0);

        let intrinsics = fixtures.intrinsics();
        assert_eq!(intrinsics.image_size, [64, 64]);
    }

    #[test]
    fn test_missing_renders_are_reported() {
        let temp_dir = TempDir::new().unwrap();
        dataset_with_empty_index(&temp_dir, "test_object");
        let fixtures = TestFixtures::load(temp_dir.path()).unwrap();

        assert!(matches!(
            fixtures.get_render("test_object", 0, 0),
            Err(FixtureError::RenderNotFound { .. })
        ));
        assert!(matches!(
            fixtures.get_render("missing", 1, 2),
            Err(FixtureError::RenderNotFound {
                rotation: 1,
                viewpoint: 2,
                ..
            })
        ));
        assert!(matches!(
            fixtures.get_all_renders("missing"),
            Err(FixtureError::RenderNotFound { .. })
        ));
        assert_eq!(
            fixtures.get_all_renders("test_object").map(|v| v.len()).ok(),
            Some(0)
        );
        assert_eq!(fixtures.iter_renders("missing").count(), 0);
    }

    #[test]
    fn test_get_render_roundtrip() {
        let temp_dir = TempDir::new().unwrap();
        let object_id = "test_object";
        write_metadata(temp_dir.path(), &sample_metadata(&[object_id], [2, 2]));

        let obj_dir = temp_dir.path().join(object_id);
        fs::create_dir_all(&obj_dir).unwrap();

        image::RgbaImage::from_pixel(2, 2, image::Rgba([10, 20, 30, 255]))
            .save(obj_dir.join("r0_v00.png"))
            .unwrap();

        let depths: [f32; 4] = [0.5, 1.0, 1.5, 2.0];
        let depth_bytes: Vec<u8> = depths.iter().flat_map(|d| d.to_le_bytes()).collect();
        fs::write(obj_dir.join("r0_v00.depth"), depth_bytes).unwrap();

        let index = vec![RenderMetadata {
            object_id: object_id.to_string(),
            rotation_index: 0,
            viewpoint_index: 0,
            rotation_euler: [0.0, 0.0, 0.0],
            camera_position: [0.5, 0.0, 0.0],
            rgba_file: "r0_v00.png".to_string(),
            depth_file: "r0_v00.depth".to_string(),
        }];
        fs::write(
            obj_dir.join("index.json"),
            serde_json::to_string(&index).unwrap(),
        )
        .unwrap();

        let fixtures = TestFixtures::load(temp_dir.path()).unwrap();
        assert_eq!(fixtures.renders_for_object(object_id), 1);

        let render = fixtures.get_render(object_id, 0, 0).unwrap();
        assert_eq!((render.width, render.height), (2, 2));
        assert_eq!(render.rgba.len(), 16);
        assert_eq!(&render.rgba[..4], &[10, 20, 30, 255]);
        assert_eq!(render.depth, vec![0.5, 1.0, 1.5, 2.0]);

        assert_eq!(fixtures.get_all_renders(object_id).unwrap().len(), 1);
        let iterated: Vec<_> = fixtures
            .iter_renders(object_id)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(iterated.len(), 1);
        assert_eq!((iterated[0].0, iterated[0].1), (0, 0));
    }

    #[test]
    fn test_load_depth_binary_f32() {
        let temp_dir = TempDir::new().unwrap();
        let depth_path = temp_dir.path().join("test.depth");

        let depths: Vec<f32> = vec![0.5, 1.0, 2.0, 10.0];
        let bytes: Vec<u8> = depths.iter().flat_map(|f| f.to_le_bytes()).collect();
        fs::write(&depth_path, &bytes).unwrap();

        let loaded = load_depth_binary(&depth_path, depths.len()).unwrap();
        assert_eq!(loaded.len(), 4);
        for (got, want) in loaded.iter().zip(&depths) {
            assert!((got - *want as f64).abs() < 0.001);
        }
    }

    #[test]
    fn test_load_depth_binary_f64() {
        let temp_dir = TempDir::new().unwrap();
        let depth_path = temp_dir.path().join("test.depth");

        let depths: Vec<f64> = vec![0.5, 1.0, 2.0, 10.0];
        let bytes: Vec<u8> = depths.iter().flat_map(|f| f.to_le_bytes()).collect();
        fs::write(&depth_path, &bytes).unwrap();

        let loaded = load_depth_binary(&depth_path, depths.len()).unwrap();
        assert_eq!(loaded, depths);
    }

    #[test]
    fn test_load_depth_binary_size_mismatch() {
        let temp_dir = TempDir::new().unwrap();
        let depth_path = temp_dir.path().join("bad.depth");
        fs::write(&depth_path, [0u8; 3]).unwrap();

        let result = load_depth_binary(&depth_path, 1);
        assert!(matches!(result, Err(FixtureError::InvalidMetadata(_))));
    }

    #[test]
    fn test_metadata_serialization_roundtrip() {
        let metadata = sample_metadata(&["obj1", "obj2"], [64, 64]);

        let json = serde_json::to_string(&metadata).unwrap();
        let loaded: DatasetMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(loaded.version, metadata.version);
        assert_eq!(loaded.objects, metadata.objects);
        assert_eq!(loaded.resolution, metadata.resolution);
        assert_eq!(loaded.rotations, metadata.rotations);
    }

    #[test]
    fn test_render_metadata_serialization() {
        let meta = RenderMetadata {
            object_id: "003_cracker_box".to_string(),
            rotation_index: 1,
            viewpoint_index: 5,
            rotation_euler: [0.0, 90.0, 0.0],
            camera_position: [0.5, 0.0, 0.0],
            rgba_file: "r1_v05.png".to_string(),
            depth_file: "r1_v05.depth".to_string(),
        };

        let json = serde_json::to_string(&meta).unwrap();
        let loaded: RenderMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(loaded.object_id, meta.object_id);
        assert_eq!(loaded.rotation_index, meta.rotation_index);
        assert_eq!(loaded.viewpoint_index, meta.viewpoint_index);
        assert_eq!(loaded.rgba_file, meta.rgba_file);
        assert_eq!(loaded.depth_file, meta.depth_file);
    }
}