1use crate::{CameraIntrinsics, MeshBoundsMetadata, RenderHealth, RenderOutput, TargetingPolicy};
20use bevy::prelude::*;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use std::fs;
24use std::path::{Path, PathBuf};
25
/// Errors produced while locating, parsing, or reading an on-disk fixture dataset.
#[derive(Debug)]
pub enum FixtureError {
    /// The fixture root directory does not exist; carries the path that was checked.
    NotFound(String),
    /// Metadata or a data file was missing or malformed; carries a description.
    InvalidMetadata(String),
    /// No render is indexed for the requested object / rotation / viewpoint.
    RenderNotFound {
        /// Identifier of the object that was requested.
        object_id: String,
        /// Requested rotation index.
        rotation: usize,
        /// Requested viewpoint index.
        viewpoint: usize,
    },
    /// An underlying I/O failure (also used to wrap image-decoding errors).
    IoError(std::io::Error),
    /// A JSON document could not be parsed.
    JsonError(serde_json::Error),
}
44
45impl std::fmt::Display for FixtureError {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 FixtureError::NotFound(path) => write!(f, "Fixture directory not found: {}", path),
49 FixtureError::InvalidMetadata(msg) => write!(f, "Invalid metadata: {}", msg),
50 FixtureError::RenderNotFound {
51 object_id,
52 rotation,
53 viewpoint,
54 } => write!(
55 f,
56 "Render not found: {} r{} v{}",
57 object_id, rotation, viewpoint
58 ),
59 FixtureError::IoError(e) => write!(f, "IO error: {}", e),
60 FixtureError::JsonError(e) => write!(f, "JSON error: {}", e),
61 }
62 }
63}
64
// Marker impl: the body is empty, so every `Error` method keeps its default
// (no `source()` override — wrapped errors are surfaced through the `Display` text).
impl std::error::Error for FixtureError {}
66
67impl From<std::io::Error> for FixtureError {
68 fn from(e: std::io::Error) -> Self {
69 FixtureError::IoError(e)
70 }
71}
72
73impl From<serde_json::Error> for FixtureError {
74 fn from(e: serde_json::Error) -> Self {
75 FixtureError::JsonError(e)
76 }
77}
78
/// Dataset-level description, deserialized from `metadata.json` in the fixture root.
///
/// Fields marked `#[serde(default)]` are optional in the JSON and become `None`
/// when the key is absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetMetadata {
    /// Version string recorded in the metadata file.
    pub version: String,
    /// Optional; presumably the crate version that produced the dataset — confirm with the writer.
    #[serde(default)]
    pub crate_version: Option<String>,
    /// Optional; presumably the renderer policy version used when rendering — confirm with the writer.
    #[serde(default)]
    pub renderer_policy_version: Option<String>,
    /// Object identifiers. Each is also the name of the sub-directory under the
    /// fixture root that is probed for an `index.json`.
    pub objects: Vec<String>,
    /// Viewpoint count per rotation as recorded in the file (not cross-checked against the indices here).
    pub viewpoints_per_rotation: usize,
    /// Rotation count per object as recorded in the file (not cross-checked against the indices here).
    pub rotations_per_object: usize,
    /// Render count per object as recorded in the file (not cross-checked against the indices here).
    pub renders_per_object: usize,
    /// `[width, height]`: reported as the size of every loaded render and used to
    /// size-check depth files.
    pub resolution: [u32; 2],
    /// Optional width; not read by the loader in this file.
    #[serde(default)]
    pub resolution_width: Option<u32>,
    /// Optional height; not read by the loader in this file.
    #[serde(default)]
    pub resolution_height: Option<u32>,
    /// Dataset-wide targeting policy, used when a render entry does not carry its own.
    #[serde(default)]
    pub targeting_policy: Option<TargetingPolicy>,
    /// Camera intrinsics shared by all renders.
    pub intrinsics: IntrinsicsMetadata,
    /// Viewpoint sampling configuration; not read by the loader in this file.
    pub viewpoint_config: ViewpointConfigMetadata,
    /// One three-component entry per rotation; presumably Euler angles — confirm units with the writer.
    pub rotations: Vec<[f32; 3]>,
}
102
/// Camera intrinsics as stored in `metadata.json` (single precision).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntrinsicsMetadata {
    /// Two-component focal length; presumably `[fx, fy]` in pixels — confirm.
    pub focal_length: [f32; 2],
    /// Two-component principal point; presumably `[cx, cy]` in pixels — confirm.
    pub principal_point: [f32; 2],
    /// Two-component image size; presumably `[width, height]` — confirm.
    pub image_size: [u32; 2],
}
109
/// Viewpoint sampling parameters as stored in `metadata.json`.
///
/// None of these fields are interpreted by the loader in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ViewpointConfigMetadata {
    /// Radius value; presumably the camera distance from the target — confirm.
    pub radius: f32,
    /// Number of yaw steps.
    pub yaw_count: usize,
    /// Pitch angles; degrees according to the field name.
    pub pitch_angles_deg: Vec<f32>,
}
116
/// One entry of an object's `index.json`, describing a single stored render.
///
/// Fields marked `#[serde(default)]` are optional in the JSON and become `None`
/// when the key is absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RenderMetadata {
    /// Identifier of the rendered object.
    pub object_id: String,
    /// Rotation index; together with `viewpoint_index` it is the lookup key for a render.
    pub rotation_index: usize,
    /// Viewpoint index; together with `rotation_index` it is the lookup key for a render.
    pub viewpoint_index: usize,
    /// Object rotation components handed to `ObjectRotation::new` on load (units not shown here).
    pub rotation_euler: [f32; 3],
    /// Camera translation.
    pub camera_position: [f32; 3],
    /// Camera orientation quaternion in `x, y, z, w` order. When absent, the loader
    /// orients the camera to look at `target_point` (or the origin) with +Y up.
    #[serde(default)]
    pub camera_rotation_xyzw: Option<[f32; 4]>,
    /// Object translation; treated as `[0, 0, 0]` when absent.
    #[serde(default)]
    pub object_translation: Option<[f32; 3]>,
    /// Object scale; treated as `[1, 1, 1]` when absent.
    #[serde(default)]
    pub object_scale: Option<[f32; 3]>,
    /// Point the camera targets; treated as the origin when absent.
    #[serde(default)]
    pub target_point: Option<[f32; 3]>,
    /// Per-render targeting policy; the loader falls back to the dataset-level
    /// policy and then to `TargetingPolicy::Origin`.
    #[serde(default)]
    pub targeting_policy: Option<TargetingPolicy>,
    /// Optional mesh bounds; not read by the loader in this file.
    #[serde(default)]
    pub mesh_bounds: Option<MeshBoundsMetadata>,
    /// Optional render health record; not read by the loader in this file.
    #[serde(default)]
    pub health: Option<RenderHealth>,
    /// RGBA image file name, relative to the object's directory.
    pub rgba_file: String,
    /// Depth file name, relative to the object's directory.
    pub depth_file: String,
}
142
/// Handle to a fixture dataset on disk: parsed metadata plus per-object render indices.
pub struct TestFixtures {
    /// Root directory of the fixture dataset.
    root: PathBuf,
    /// Parsed contents of `metadata.json`.
    pub metadata: DatasetMetadata,
    /// Render index per object id; objects without an `index.json` have no entry.
    indices: HashMap<String, Vec<RenderMetadata>>,
}
152
153impl TestFixtures {
154 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, FixtureError> {
163 let root = path.as_ref().to_path_buf();
164
165 if !root.exists() {
166 return Err(FixtureError::NotFound(root.display().to_string()));
167 }
168
169 let metadata_path = root.join("metadata.json");
171 if !metadata_path.exists() {
172 return Err(FixtureError::InvalidMetadata(
173 "metadata.json not found".to_string(),
174 ));
175 }
176
177 let metadata_json = fs::read_to_string(&metadata_path)?;
178 let metadata: DatasetMetadata = serde_json::from_str(&metadata_json)?;
179
180 let mut indices = HashMap::new();
182 for object_id in &metadata.objects {
183 let index_path = root.join(object_id).join("index.json");
184 if index_path.exists() {
185 let index_json = fs::read_to_string(&index_path)?;
186 let renders: Vec<RenderMetadata> = serde_json::from_str(&index_json)?;
187 indices.insert(object_id.clone(), renders);
188 }
189 }
190
191 Ok(Self {
192 root,
193 metadata,
194 indices,
195 })
196 }
197
198 pub fn exists<P: AsRef<Path>>(path: P) -> bool {
200 let root = path.as_ref();
201 root.exists() && root.join("metadata.json").exists()
202 }
203
204 pub fn objects(&self) -> &[String] {
206 &self.metadata.objects
207 }
208
209 pub fn viewpoints_per_rotation(&self) -> usize {
211 self.metadata.viewpoints_per_rotation
212 }
213
214 pub fn rotations_per_object(&self) -> usize {
216 self.metadata.rotations_per_object
217 }
218
219 pub fn renders_for_object(&self, object_id: &str) -> usize {
221 self.indices.get(object_id).map(|v| v.len()).unwrap_or(0)
222 }
223
224 pub fn intrinsics(&self) -> CameraIntrinsics {
226 CameraIntrinsics {
227 focal_length: [
228 self.metadata.intrinsics.focal_length[0] as f64,
229 self.metadata.intrinsics.focal_length[1] as f64,
230 ],
231 principal_point: [
232 self.metadata.intrinsics.principal_point[0] as f64,
233 self.metadata.intrinsics.principal_point[1] as f64,
234 ],
235 image_size: self.metadata.intrinsics.image_size,
236 }
237 }
238
239 pub fn get_render(
246 &self,
247 object_id: &str,
248 rotation_idx: usize,
249 viewpoint_idx: usize,
250 ) -> Result<RenderOutput, FixtureError> {
251 let renders = self
253 .indices
254 .get(object_id)
255 .ok_or_else(|| FixtureError::RenderNotFound {
256 object_id: object_id.to_string(),
257 rotation: rotation_idx,
258 viewpoint: viewpoint_idx,
259 })?;
260
261 let render_meta = renders
262 .iter()
263 .find(|r| r.rotation_index == rotation_idx && r.viewpoint_index == viewpoint_idx)
264 .ok_or_else(|| FixtureError::RenderNotFound {
265 object_id: object_id.to_string(),
266 rotation: rotation_idx,
267 viewpoint: viewpoint_idx,
268 })?;
269
270 let rgba_path = self.root.join(object_id).join(&render_meta.rgba_file);
272 let rgba = load_rgba_png(&rgba_path)?;
273
274 let depth_path = self.root.join(object_id).join(&render_meta.depth_file);
276 let expected_depth_values =
277 (self.metadata.resolution[0] as usize) * (self.metadata.resolution[1] as usize);
278 let depth = load_depth_binary(&depth_path, expected_depth_values)?;
279
280 let pos = render_meta.camera_position;
284 let translation = Vec3::new(pos[0], pos[1], pos[2]);
285 let camera_transform = if let Some(q) = render_meta.camera_rotation_xyzw {
286 Transform {
287 translation,
288 rotation: Quat::from_xyzw(q[0], q[1], q[2], q[3]),
289 ..Default::default()
290 }
291 } else {
292 let target = render_meta.target_point.unwrap_or([0.0, 0.0, 0.0]);
293 Transform::from_translation(translation)
294 .looking_at(Vec3::new(target[0], target[1], target[2]), Vec3::Y)
295 };
296 let target_point = Vec3::from_array(render_meta.target_point.unwrap_or([0.0, 0.0, 0.0]));
297 let targeting_policy = render_meta
298 .targeting_policy
299 .clone()
300 .or_else(|| self.metadata.targeting_policy.clone())
301 .unwrap_or(TargetingPolicy::Origin);
302
303 let rot = render_meta.rotation_euler;
305 let object_rotation =
306 crate::ObjectRotation::new(rot[0] as f64, rot[1] as f64, rot[2] as f64);
307 let object_translation =
308 Vec3::from_array(render_meta.object_translation.unwrap_or([0.0, 0.0, 0.0]));
309 let object_scale = Vec3::from_array(render_meta.object_scale.unwrap_or([1.0, 1.0, 1.0]));
310
311 Ok(RenderOutput {
312 rgba,
313 depth,
314 width: self.metadata.resolution[0],
315 height: self.metadata.resolution[1],
316 intrinsics: self.intrinsics(),
317 camera_transform,
318 object_rotation,
319 object_translation,
320 object_scale,
321 target_point,
322 targeting_policy,
323 })
324 }
325
326 pub fn get_all_renders(&self, object_id: &str) -> Result<Vec<RenderOutput>, FixtureError> {
328 let renders = self
329 .indices
330 .get(object_id)
331 .ok_or_else(|| FixtureError::RenderNotFound {
332 object_id: object_id.to_string(),
333 rotation: 0,
334 viewpoint: 0,
335 })?;
336
337 let mut outputs = Vec::with_capacity(renders.len());
338 for meta in renders {
339 let output = self.get_render(object_id, meta.rotation_index, meta.viewpoint_index)?;
340 outputs.push(output);
341 }
342
343 Ok(outputs)
344 }
345
346 pub fn iter_renders<'a>(
348 &'a self,
349 object_id: &'a str,
350 ) -> impl Iterator<Item = Result<(usize, usize, RenderOutput), FixtureError>> + 'a {
351 let renders = self.indices.get(object_id);
352
353 renders.into_iter().flat_map(|v| v.iter()).map(move |meta| {
354 let output = self.get_render(object_id, meta.rotation_index, meta.viewpoint_index)?;
355 Ok((meta.rotation_index, meta.viewpoint_index, output))
356 })
357 }
358}
359
360fn load_rgba_png(path: &Path) -> Result<Vec<u8>, FixtureError> {
362 let img = image::open(path).map_err(|e| FixtureError::IoError(std::io::Error::other(e)))?;
363
364 let rgba = img.to_rgba8();
365 Ok(rgba.into_raw())
366}
367
368fn load_depth_binary(path: &Path, expected_values: usize) -> Result<Vec<f64>, FixtureError> {
370 let bytes = fs::read(path)?;
371
372 if bytes.len() == expected_values * std::mem::size_of::<f64>() {
373 return Ok(bytes
374 .chunks_exact(8)
375 .map(|chunk| {
376 let arr: [u8; 8] = chunk.try_into().unwrap();
377 f64::from_le_bytes(arr)
378 })
379 .collect());
380 }
381
382 if bytes.len() == expected_values * std::mem::size_of::<f32>() {
383 return Ok(bytes
384 .chunks_exact(4)
385 .map(|chunk| {
386 let arr: [u8; 4] = chunk.try_into().unwrap();
387 f32::from_le_bytes(arr) as f64
388 })
389 .collect());
390 }
391
392 Err(FixtureError::InvalidMetadata(format!(
393 "Depth file {} has {} bytes, expected {} f32 values or {} f64 values",
394 path.display(),
395 bytes.len(),
396 expected_values,
397 expected_values
398 )))
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Metadata skeleton shared by the tests below; callers override the
    // optional fields they care about with struct-update syntax.
    fn base_metadata(objects: &[&str], rotations: Vec<[f32; 3]>) -> DatasetMetadata {
        DatasetMetadata {
            version: "1.0".to_string(),
            crate_version: None,
            renderer_policy_version: None,
            objects: objects.iter().map(|name| name.to_string()).collect(),
            viewpoints_per_rotation: 24,
            rotations_per_object: 3,
            renders_per_object: 72,
            resolution: [64, 64],
            resolution_width: None,
            resolution_height: None,
            targeting_policy: None,
            intrinsics: IntrinsicsMetadata {
                focal_length: [55.4, 55.4],
                principal_point: [32.0, 32.0],
                image_size: [64, 64],
            },
            viewpoint_config: ViewpointConfigMetadata {
                radius: 0.5,
                yaw_count: 8,
                pitch_angles_deg: vec![-30.0, 0.0, 30.0],
            },
            rotations,
        }
    }

    // Loading from a path that does not exist reports `NotFound`.
    #[test]
    fn test_fixture_not_found() {
        let outcome = TestFixtures::load("/nonexistent/path");
        assert!(matches!(outcome, Err(FixtureError::NotFound(_))));
    }

    // `exists` is false for a path that does not exist.
    #[test]
    fn test_fixtures_exists() {
        assert!(!TestFixtures::exists("/nonexistent/path"));
    }

    // Every error variant renders to a non-empty message.
    #[test]
    fn test_fixture_error_display() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let json_err = serde_json::from_str::<String>("invalid").unwrap_err();

        let samples = [
            FixtureError::NotFound("/path".to_string()),
            FixtureError::InvalidMetadata("bad json".to_string()),
            FixtureError::RenderNotFound {
                object_id: "obj".to_string(),
                rotation: 0,
                viewpoint: 5,
            },
            FixtureError::IoError(io_err),
            FixtureError::JsonError(json_err),
        ];

        assert!(samples.iter().all(|err| !err.to_string().is_empty()));
    }

    // An existing directory without `metadata.json` reports `InvalidMetadata`.
    #[test]
    fn test_fixture_missing_metadata() {
        let scratch = TempDir::new().unwrap();
        let outcome = TestFixtures::load(scratch.path());
        assert!(matches!(outcome, Err(FixtureError::InvalidMetadata(_))));
    }

    // A minimal dataset (metadata plus one empty index) loads and exposes its metadata.
    #[test]
    fn test_fixture_load_metadata() {
        let scratch = TempDir::new().unwrap();
        let root = scratch.path();

        let metadata = base_metadata(
            &["test_object"],
            vec![[0.0, 0.0, 0.0], [0.0, 90.0, 0.0], [0.0, 180.0, 0.0]],
        );
        fs::write(
            root.join("metadata.json"),
            serde_json::to_string_pretty(&metadata).unwrap(),
        )
        .unwrap();

        let object_dir = root.join("test_object");
        fs::create_dir_all(&object_dir).unwrap();
        fs::write(object_dir.join("index.json"), "[]").unwrap();

        let fixtures = TestFixtures::load(root).unwrap();

        assert_eq!(fixtures.objects(), &["test_object"]);
        assert_eq!(fixtures.viewpoints_per_rotation(), 24);
        assert_eq!(fixtures.rotations_per_object(), 3);
        assert_eq!(fixtures.renders_for_object("test_object"), 0);
        assert_eq!(fixtures.renders_for_object("nonexistent"), 0);
        assert_eq!(fixtures.intrinsics().image_size, [64, 64]);
    }

    // A file holding `n` little-endian f32 values decodes to the same values as f64.
    #[test]
    fn test_load_depth_binary_f32() {
        let scratch = TempDir::new().unwrap();
        let depth_path = scratch.path().join("test.depth");

        let depths: Vec<f32> = vec![0.5, 1.0, 2.0, 10.0];
        let mut raw: Vec<u8> = Vec::new();
        for value in &depths {
            raw.extend_from_slice(&value.to_le_bytes());
        }
        fs::write(&depth_path, &raw).unwrap();

        let loaded = load_depth_binary(&depth_path, depths.len()).unwrap();
        assert_eq!(loaded.len(), 4);
        for (got, want) in loaded.iter().zip([0.5_f64, 1.0, 2.0, 10.0]) {
            assert!((got - want).abs() < 0.001);
        }
    }

    // A file holding `n` little-endian f64 values round-trips exactly.
    #[test]
    fn test_load_depth_binary_f64() {
        let scratch = TempDir::new().unwrap();
        let depth_path = scratch.path().join("test.depth");

        let depths: Vec<f64> = vec![0.5, 1.0, 2.0, 10.0];
        let mut raw: Vec<u8> = Vec::new();
        for value in &depths {
            raw.extend_from_slice(&value.to_le_bytes());
        }
        fs::write(&depth_path, &raw).unwrap();

        let loaded = load_depth_binary(&depth_path, depths.len()).unwrap();
        assert_eq!(loaded, depths);
    }

    // Dataset metadata with every optional field populated survives a JSON round trip.
    #[test]
    fn test_metadata_serialization_roundtrip() {
        let metadata = DatasetMetadata {
            crate_version: Some("0.5.5".to_string()),
            renderer_policy_version: Some(crate::RENDERER_POLICY_VERSION.to_string()),
            resolution_width: Some(64),
            resolution_height: Some(64),
            targeting_policy: Some(TargetingPolicy::MeshCenter),
            ..base_metadata(&["obj1", "obj2"], vec![[0.0, 0.0, 0.0]])
        };

        let encoded = serde_json::to_string(&metadata).unwrap();
        let decoded: DatasetMetadata = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.version, metadata.version);
        assert_eq!(decoded.objects, metadata.objects);
        assert_eq!(decoded.resolution, metadata.resolution);
    }

    // A fully populated render entry keeps its identifying fields across a JSON round trip.
    #[test]
    fn test_render_metadata_serialization() {
        let health = RenderHealth {
            center_pixel: Some([32, 32]),
            center_depth: Some(0.25),
            center_foreground: true,
            foreground_pixel_count: 1,
            foreground_coverage: 1.0 / 4096.0,
            center_5x5_foreground_count: 1,
            nearest_foreground_pixel: Some([32, 32]),
            nearest_foreground_depth: Some(0.25),
            nearest_foreground_distance_px: Some(0.0),
        };
        let meta = RenderMetadata {
            object_id: "003_cracker_box".to_string(),
            rotation_index: 1,
            viewpoint_index: 5,
            rotation_euler: [0.0, 90.0, 0.0],
            camera_position: [0.5, 0.0, 0.0],
            camera_rotation_xyzw: Some([0.0, 0.0, 0.0, 1.0]),
            object_translation: Some([0.1, 0.2, 0.3]),
            object_scale: Some([1.0, 1.25, 0.75]),
            target_point: Some([0.0, 0.0, 0.0]),
            targeting_policy: Some(TargetingPolicy::Origin),
            mesh_bounds: None,
            health: Some(health),
            rgba_file: "r1_v05.png".to_string(),
            depth_file: "r1_v05.depth".to_string(),
        };

        let encoded = serde_json::to_string(&meta).unwrap();
        let decoded: RenderMetadata = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.object_id, meta.object_id);
        assert_eq!(decoded.rotation_index, meta.rotation_index);
        assert_eq!(decoded.viewpoint_index, meta.viewpoint_index);
    }
}