1use bevy::prelude::*;
54use std::f32::consts::PI;
55use std::path::Path;
56
57mod render;
60
61pub mod batch;
63
64pub mod backend;
66
67pub mod cache;
69
70pub mod fixtures;
72
73#[allow(deprecated)]
75pub use ycbust::{
76 self, DownloadOptions, Subset as YcbSubset, REPRESENTATIVE_OBJECTS, TBP_STANDARD_OBJECTS,
77 TEN_OBJECTS,
78};
79
80pub mod ycb {
82 #[allow(deprecated)]
83 pub use ycbust::{
84 download_ycb, DownloadOptions, Subset, REPRESENTATIVE_OBJECTS, TBP_STANDARD_OBJECTS,
85 TEN_OBJECTS,
86 };
87
88 use reqwest::Client;
89 use std::path::Path;
90
91 pub async fn download_models<P: AsRef<Path>>(
104 output_dir: P,
105 subset: Subset,
106 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
107 let options = DownloadOptions {
108 overwrite: false,
109 full: false,
110 show_progress: true,
111 delete_archives: true,
112 };
113 download_ycb(subset, output_dir.as_ref(), options).await?;
114 Ok(())
115 }
116
117 pub async fn download_models_with_options<P: AsRef<Path>>(
119 output_dir: P,
120 subset: Subset,
121 options: DownloadOptions,
122 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
123 download_ycb(subset, output_dir.as_ref(), options).await?;
124 Ok(())
125 }
126
127 pub async fn download_objects<P: AsRef<Path>>(
129 output_dir: P,
130 object_ids: &[&str],
131 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
132 let output_dir = output_dir.as_ref();
133 let client = Client::new();
134 let options = DownloadOptions {
135 overwrite: false,
136 full: false,
137 show_progress: true,
138 delete_archives: true,
139 };
140
141 std::fs::create_dir_all(output_dir)?;
142
143 for object_id in object_ids {
144 let url = ycbust::get_tgz_url(object_id, "google_16k");
145 let archive_path = output_dir.join(format!("{object_id}_google_16k.tgz"));
146
147 if archive_path.exists() && !options.overwrite {
148 continue;
149 }
150
151 ycbust::download_file(&client, &url, &archive_path, options.show_progress).await?;
152 ycbust::extract_tgz(&archive_path, output_dir, options.delete_archives)?;
153 }
154
155 Ok(())
156 }
157
158 pub fn models_exist<P: AsRef<Path>>(output_dir: P) -> bool {
160 let path = output_dir.as_ref();
161 path.join("003_cracker_box/google_16k/textured.obj")
163 .exists()
164 }
165
166 pub fn object_mesh_path<P: AsRef<Path>>(output_dir: P, object_id: &str) -> std::path::PathBuf {
168 output_dir
169 .as_ref()
170 .join(object_id)
171 .join("google_16k")
172 .join("textured.obj")
173 }
174
175 pub fn object_texture_path<P: AsRef<Path>>(
177 output_dir: P,
178 object_id: &str,
179 ) -> std::path::PathBuf {
180 output_dir
181 .as_ref()
182 .join(object_id)
183 .join("google_16k")
184 .join("texture_map.png")
185 }
186}
187
188pub fn initialize() {
222 use std::sync::atomic::{AtomicBool, Ordering};
224 static INITIALIZED: AtomicBool = AtomicBool::new(false);
225
226 if !INITIALIZED.swap(true, Ordering::SeqCst) {
227 let config = backend::BackendConfig::new();
229 config.apply_env();
230 }
231}
232
233#[derive(Clone, Debug, PartialEq)]
236pub struct ObjectRotation {
237 pub pitch: f64,
239 pub yaw: f64,
241 pub roll: f64,
243}
244
245impl ObjectRotation {
246 pub fn new(pitch: f64, yaw: f64, roll: f64) -> Self {
248 Self { pitch, yaw, roll }
249 }
250
251 pub fn from_array(arr: [f64; 3]) -> Self {
253 Self {
254 pitch: arr[0],
255 yaw: arr[1],
256 roll: arr[2],
257 }
258 }
259
260 pub fn identity() -> Self {
262 Self::new(0.0, 0.0, 0.0)
263 }
264
265 pub fn tbp_benchmark_rotations() -> Vec<Self> {
268 vec![
269 Self::from_array([0.0, 0.0, 0.0]),
270 Self::from_array([0.0, 90.0, 0.0]),
271 Self::from_array([0.0, 180.0, 0.0]),
272 ]
273 }
274
275 pub fn tbp_known_orientations() -> Vec<Self> {
278 vec![
279 Self::from_array([0.0, 0.0, 0.0]), Self::from_array([0.0, 90.0, 0.0]), Self::from_array([0.0, 180.0, 0.0]), Self::from_array([0.0, 270.0, 0.0]), Self::from_array([90.0, 0.0, 0.0]), Self::from_array([-90.0, 0.0, 0.0]), Self::from_array([45.0, 45.0, 0.0]),
288 Self::from_array([45.0, 135.0, 0.0]),
289 Self::from_array([45.0, 225.0, 0.0]),
290 Self::from_array([45.0, 315.0, 0.0]),
291 Self::from_array([-45.0, 45.0, 0.0]),
292 Self::from_array([-45.0, 135.0, 0.0]),
293 Self::from_array([-45.0, 225.0, 0.0]),
294 Self::from_array([-45.0, 315.0, 0.0]),
295 ]
296 }
297
298 pub fn to_quat(&self) -> Quat {
300 Quat::from_euler(
301 EulerRot::XYZ,
302 (self.pitch as f32).to_radians(),
303 (self.yaw as f32).to_radians(),
304 (self.roll as f32).to_radians(),
305 )
306 }
307
308 pub fn to_transform(&self) -> Transform {
310 Transform::from_rotation(self.to_quat())
311 }
312}
313
314impl Default for ObjectRotation {
315 fn default() -> Self {
316 Self::identity()
317 }
318}
319
/// Describes a set of camera positions arranged in rings around the origin.
#[derive(Clone, Debug)]
pub struct ViewpointConfig {
    /// Distance of every camera position from the origin.
    pub radius: f32,
    /// Number of evenly spaced yaw steps per pitch ring.
    pub yaw_count: usize,
    /// Pitch angle of each ring, in degrees.
    pub pitch_angles_deg: Vec<f32>,
}

impl Default for ViewpointConfig {
    /// Radius 0.5, eight yaw steps, rings at -30, 0 and +30 degrees.
    fn default() -> Self {
        ViewpointConfig {
            pitch_angles_deg: [-30.0, 0.0, 30.0].to_vec(),
            yaw_count: 8,
            radius: 0.5,
        }
    }
}

impl ViewpointConfig {
    /// Total number of viewpoints: one per (pitch ring, yaw step) pair.
    pub fn viewpoint_count(&self) -> usize {
        let rings = self.pitch_angles_deg.len();
        rings * self.yaw_count
    }
}
350
351#[derive(Clone, Debug, Resource)]
353pub struct SensorConfig {
354 pub viewpoints: ViewpointConfig,
356 pub object_rotations: Vec<ObjectRotation>,
358 pub output_dir: String,
360 pub filename_pattern: String,
362}
363
364impl Default for SensorConfig {
365 fn default() -> Self {
366 Self {
367 viewpoints: ViewpointConfig::default(),
368 object_rotations: vec![ObjectRotation::identity()],
369 output_dir: ".".to_string(),
370 filename_pattern: "capture_{rot}_{view}.png".to_string(),
371 }
372 }
373}
374
375impl SensorConfig {
376 pub fn tbp_benchmark() -> Self {
378 Self {
379 viewpoints: ViewpointConfig::default(),
380 object_rotations: ObjectRotation::tbp_benchmark_rotations(),
381 output_dir: ".".to_string(),
382 filename_pattern: "capture_{rot}_{view}.png".to_string(),
383 }
384 }
385
386 pub fn tbp_full_training() -> Self {
388 Self {
389 viewpoints: ViewpointConfig::default(),
390 object_rotations: ObjectRotation::tbp_known_orientations(),
391 output_dir: ".".to_string(),
392 filename_pattern: "capture_{rot}_{view}.png".to_string(),
393 }
394 }
395
396 pub fn total_captures(&self) -> usize {
398 self.viewpoints.viewpoint_count() * self.object_rotations.len()
399 }
400}
401
402pub fn generate_viewpoints(config: &ViewpointConfig) -> Vec<Transform> {
409 let mut views = Vec::with_capacity(config.viewpoint_count());
410
411 for pitch_deg in &config.pitch_angles_deg {
412 let pitch = pitch_deg.to_radians();
413
414 for i in 0..config.yaw_count {
415 let yaw = (i as f32) * 2.0 * PI / (config.yaw_count as f32);
416
417 let x = config.radius * pitch.cos() * yaw.sin();
422 let y = config.radius * pitch.sin();
423 let z = config.radius * pitch.cos() * yaw.cos();
424
425 let transform = Transform::from_xyz(x, y, z).looking_at(Vec3::ZERO, Vec3::Y);
426 views.push(transform);
427 }
428 }
429 views
430}
431
/// Marker component (carries no data).
///
/// NOTE(review): presumably tags the entity being captured; the systems that
/// query it are not in this file — confirm against the `render` module.
#[derive(Component)]
pub struct CaptureTarget;
435
/// Marker component (carries no data).
///
/// NOTE(review): presumably tags the camera used for captures; the systems
/// that query it are not in this file — confirm against the `render` module.
#[derive(Component)]
pub struct CaptureCamera;
439
/// Settings for a single render: image size, camera zoom, clip planes and
/// lighting.
#[derive(Clone, Debug, PartialEq)]
pub struct RenderConfig {
    /// Image width; reported as `image_size[0]` by `intrinsics()`
    /// (presumably pixels).
    pub width: u32,
    /// Image height; reported as `image_size[1]` by `intrinsics()`
    /// (presumably pixels).
    pub height: u32,
    /// Zoom factor applied to a 90-degree base field of view; larger values
    /// narrow the field of view (see `fov_radians`).
    pub zoom: f32,
    /// Near clip distance. Not read in this file — presumably consumed by the
    /// `render` module; confirm there.
    pub near_plane: f32,
    /// Far clip distance. Not read in this file — presumably consumed by the
    /// `render` module; confirm there.
    pub far_plane: f32,
    /// Lighting setup for the scene.
    pub lighting: LightingConfig,
}
463
/// Scene lighting: an ambient term plus a key light and a fill light.
#[derive(Clone, Debug, PartialEq)]
pub struct LightingConfig {
    /// Ambient light brightness.
    pub ambient_brightness: f32,
    /// Key light intensity.
    pub key_light_intensity: f32,
    /// Key light position as `[x, y, z]`.
    pub key_light_position: [f32; 3],
    /// Fill light intensity.
    pub fill_light_intensity: f32,
    /// Fill light position as `[x, y, z]`.
    pub fill_light_position: [f32; 3],
    /// Whether shadows are enabled (every preset here turns them off).
    pub shadows_enabled: bool,
}

impl Default for LightingConfig {
    /// Ambient 0.3, key 1500 at `[4, 8, 4]`, fill 500 at `[-4, 2, -4]`.
    fn default() -> Self {
        Self::shadowless(0.3, (1500.0, [4.0, 8.0, 4.0]), (500.0, [-4.0, 2.0, -4.0]))
    }
}

impl LightingConfig {
    /// Builds a shadow-free setup from an ambient level and
    /// `(intensity, position)` pairs for the key and fill lights.
    fn shadowless(ambient: f32, key: (f32, [f32; 3]), fill: (f32, [f32; 3])) -> Self {
        let (key_light_intensity, key_light_position) = key;
        let (fill_light_intensity, fill_light_position) = fill;
        LightingConfig {
            ambient_brightness: ambient,
            key_light_intensity,
            key_light_position,
            fill_light_intensity,
            fill_light_position,
            shadows_enabled: false,
        }
    }

    /// Brighter variant of the default: ambient 0.5, key 2000, fill 800,
    /// same light positions.
    pub fn bright() -> Self {
        Self::shadowless(0.5, (2000.0, [4.0, 8.0, 4.0]), (800.0, [-4.0, 2.0, -4.0]))
    }

    /// Softer setup: ambient 0.4, key 1000 at `[3, 6, 3]`, fill 600 at
    /// `[-3, 3, -3]`.
    pub fn soft() -> Self {
        Self::shadowless(0.4, (1000.0, [3.0, 6.0, 3.0]), (600.0, [-3.0, 3.0, -3.0]))
    }

    /// Ambient-only setup: ambient 1.0 with both lights at zero intensity.
    pub fn unlit() -> Self {
        Self::shadowless(1.0, (0.0, [0.0; 3]), (0.0, [0.0; 3]))
    }
}
533
534impl Default for RenderConfig {
535 fn default() -> Self {
536 Self::tbp_default()
537 }
538}
539
540impl RenderConfig {
541 pub fn tbp_default() -> Self {
549 Self {
550 width: 64,
551 height: 64,
552 zoom: 4.0,
553 near_plane: 0.01,
554 far_plane: 10.0,
555 lighting: LightingConfig::default(),
556 }
557 }
558
559 pub fn preview() -> Self {
561 Self {
562 width: 256,
563 height: 256,
564 zoom: 1.0,
565 near_plane: 0.01,
566 far_plane: 10.0,
567 lighting: LightingConfig::default(),
568 }
569 }
570
571 pub fn high_res() -> Self {
573 Self {
574 width: 512,
575 height: 512,
576 zoom: 1.0,
577 near_plane: 0.01,
578 far_plane: 10.0,
579 lighting: LightingConfig::default(),
580 }
581 }
582
583 pub fn fov_radians(&self) -> f32 {
590 let base_hfov_rad = 90.0_f32.to_radians();
591 let half_tan = (base_hfov_rad / 2.0).tan() / self.zoom;
592 2.0 * half_tan.atan()
593 }
594
595 pub fn intrinsics(&self) -> CameraIntrinsics {
603 let base_hfov_rad = 90.0_f64.to_radians();
604 let fx_norm = (base_hfov_rad / 2.0).tan() / self.zoom as f64;
606 let fx = (self.width as f64 / 2.0) / fx_norm;
608 let fy = fx; CameraIntrinsics {
611 focal_length: [fx, fy],
612 principal_point: [self.width as f64 / 2.0, self.height as f64 / 2.0],
613 image_size: [self.width, self.height],
614 }
615 }
616}
617
618#[derive(Clone, Debug, PartialEq)]
623pub struct CameraIntrinsics {
624 pub focal_length: [f64; 2],
626 pub principal_point: [f64; 2],
628 pub image_size: [u32; 2],
630}
631
632impl CameraIntrinsics {
633 pub fn project(&self, point: Vec3) -> Option<[f64; 2]> {
635 if point.z <= 0.0 {
636 return None;
637 }
638 let x = (point.x as f64 / point.z as f64) * self.focal_length[0] + self.principal_point[0];
639 let y = (point.y as f64 / point.z as f64) * self.focal_length[1] + self.principal_point[1];
640 Some([x, y])
641 }
642
643 pub fn unproject(&self, pixel: [f64; 2], depth: f64) -> [f64; 3] {
645 let x = (pixel[0] - self.principal_point[0]) / self.focal_length[0] * depth;
646 let y = (pixel[1] - self.principal_point[1]) / self.focal_length[1] * depth;
647 [x, y, depth]
648 }
649}
650
651#[derive(Clone, Debug)]
653pub struct RenderOutput {
654 pub rgba: Vec<u8>,
656 pub depth: Vec<f64>,
660 pub width: u32,
662 pub height: u32,
664 pub intrinsics: CameraIntrinsics,
666 pub camera_transform: Transform,
668 pub object_rotation: ObjectRotation,
670}
671
672impl RenderOutput {
673 pub fn get_rgba(&self, x: u32, y: u32) -> Option<[u8; 4]> {
675 if x >= self.width || y >= self.height {
676 return None;
677 }
678 let idx = ((y * self.width + x) * 4) as usize;
679 Some([
680 self.rgba[idx],
681 self.rgba[idx + 1],
682 self.rgba[idx + 2],
683 self.rgba[idx + 3],
684 ])
685 }
686
687 pub fn get_depth(&self, x: u32, y: u32) -> Option<f64> {
689 if x >= self.width || y >= self.height {
690 return None;
691 }
692 let idx = (y * self.width + x) as usize;
693 Some(self.depth[idx])
694 }
695
696 pub fn get_rgb(&self, x: u32, y: u32) -> Option<[u8; 3]> {
698 self.get_rgba(x, y).map(|rgba| [rgba[0], rgba[1], rgba[2]])
699 }
700
701 pub fn to_rgb_image(&self) -> Vec<Vec<[u8; 3]>> {
703 let mut image = Vec::with_capacity(self.height as usize);
704 for y in 0..self.height {
705 let mut row = Vec::with_capacity(self.width as usize);
706 for x in 0..self.width {
707 row.push(self.get_rgb(x, y).unwrap_or([0, 0, 0]));
708 }
709 image.push(row);
710 }
711 image
712 }
713
714 pub fn to_depth_image(&self) -> Vec<Vec<f64>> {
716 let mut image = Vec::with_capacity(self.height as usize);
717 for y in 0..self.height {
718 let mut row = Vec::with_capacity(self.width as usize);
719 for x in 0..self.width {
720 row.push(self.get_depth(x, y).unwrap_or(0.0));
721 }
722 image.push(row);
723 }
724 image
725 }
726}
727
/// Errors reported by the rendering API.
#[derive(Debug, Clone)]
pub enum RenderError {
    /// A mesh was not found; carries the path.
    MeshNotFound(String),
    /// A texture was not found; carries the path.
    TextureNotFound(String),
    /// A file could not be found at `path`.
    FileNotFound { path: String, reason: String },
    /// Writing the file at `path` failed.
    FileWriteFailed { path: String, reason: String },
    /// Creating the directory at `path` failed.
    DirectoryCreationFailed { path: String, reason: String },
    /// The render itself failed; carries a message.
    RenderFailed(String),
    /// The configuration was rejected; carries a message.
    InvalidConfig(String),
    /// The input was rejected; carries a message.
    InvalidInput(String),
    /// Serialization failed; carries a message.
    SerializationError(String),
    /// Parsing data failed; carries a message.
    DataParsingError(String),
    /// The render did not finish within `duration_secs` seconds.
    RenderTimeout { duration_secs: u64 },
}

impl std::fmt::Display for RenderError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MeshNotFound(path) => write!(f, "Mesh not found: {path}"),
            Self::TextureNotFound(path) => write!(f, "Texture not found: {path}"),
            Self::FileNotFound { path, reason } => {
                write!(f, "File not found at {path}: {reason}")
            }
            Self::FileWriteFailed { path, reason } => {
                write!(f, "Failed to write file {path}: {reason}")
            }
            Self::DirectoryCreationFailed { path, reason } => {
                write!(f, "Failed to create directory {path}: {reason}")
            }
            Self::RenderFailed(msg) => write!(f, "Render failed: {msg}"),
            Self::InvalidConfig(msg) => write!(f, "Invalid config: {msg}"),
            Self::InvalidInput(msg) => write!(f, "Invalid input: {msg}"),
            Self::SerializationError(msg) => write!(f, "Serialization error: {msg}"),
            Self::DataParsingError(msg) => write!(f, "Data parsing error: {msg}"),
            Self::RenderTimeout { duration_secs } => {
                write!(f, "Render timeout after {duration_secs} seconds")
            }
        }
    }
}

impl std::error::Error for RenderError {}
782
/// Renders the object in `object_dir` from `camera_transform` with the given
/// `object_rotation` and returns the result in memory.
///
/// Thin wrapper that delegates to `render::render_headless`.
///
/// # Errors
/// Returns whatever [`RenderError`] `render::render_headless` reports.
pub fn render_to_buffer(
    object_dir: &Path,
    camera_transform: &Transform,
    object_rotation: &ObjectRotation,
    config: &RenderConfig,
) -> Result<RenderOutput, RenderError> {
    render::render_headless(object_dir, camera_transform, object_rotation, config)
}
816
817pub fn render_all_viewpoints(
830 object_dir: &Path,
831 viewpoint_config: &ViewpointConfig,
832 rotations: &[ObjectRotation],
833 render_config: &RenderConfig,
834) -> Result<Vec<RenderOutput>, RenderError> {
835 let viewpoints = generate_viewpoints(viewpoint_config);
836 let mut outputs = Vec::with_capacity(viewpoints.len() * rotations.len());
837
838 for rotation in rotations {
839 for viewpoint in &viewpoints {
840 let output = render_to_buffer(object_dir, viewpoint, rotation, render_config)?;
841 outputs.push(output);
842 }
843 }
844
845 Ok(outputs)
846}
847
848pub fn render_to_buffer_cached(
916 object_dir: &Path,
917 camera_transform: &Transform,
918 object_rotation: &ObjectRotation,
919 config: &RenderConfig,
920 cache: &mut cache::ModelCache,
921) -> Result<RenderOutput, RenderError> {
922 let mesh_path = object_dir.join("google_16k/textured.obj");
923 let texture_path = object_dir.join("google_16k/texture_map.png");
924
925 cache.cache_scene(mesh_path.clone());
927 cache.cache_texture(texture_path.clone());
928
929 render::render_headless(object_dir, camera_transform, object_rotation, config)
931}
932
/// Renders the object in `object_dir` and writes the results to `rgba_path`
/// and `depth_path`.
///
/// Thin wrapper that delegates to `render::render_to_files`; the file formats
/// are decided there.
///
/// # Errors
/// Returns whatever [`RenderError`] `render::render_to_files` reports.
pub fn render_to_files(
    object_dir: &Path,
    camera_transform: &Transform,
    object_rotation: &ObjectRotation,
    config: &RenderConfig,
    rgba_path: &Path,
    depth_path: &Path,
) -> Result<(), RenderError> {
    render::render_to_files(
        object_dir,
        camera_transform,
        object_rotation,
        config,
        rgba_path,
        depth_path,
    )
}
966
967pub use batch::{
969 BatchRenderConfig, BatchRenderError, BatchRenderOutput, BatchRenderRequest, BatchRenderer,
970 BatchState, RenderStatus,
971};
972
973pub use render::RenderSession;
976
/// Creates a [`BatchRenderer`] from a clone of `config`.
///
/// As written this always returns `Ok`; the `Result` return type leaves room
/// for fallible construction.
pub fn create_batch_renderer(config: &BatchRenderConfig) -> Result<BatchRenderer, RenderError> {
    Ok(BatchRenderer::new(config.clone()))
}
997
998pub fn queue_render_request(
1023 renderer: &mut BatchRenderer,
1024 request: BatchRenderRequest,
1025) -> Result<(), RenderError> {
1026 renderer
1027 .queue_request(request)
1028 .map_err(|e| RenderError::RenderFailed(e.to_string()))
1029}
1030
1031pub fn render_next_in_batch(
1053 renderer: &mut BatchRenderer,
1054 _timeout_ms: u32,
1055) -> Result<Option<BatchRenderOutput>, RenderError> {
1056 if let Some(request) = renderer.pending_requests.pop_front() {
1057 let output = render_to_buffer(
1058 &request.object_dir,
1059 &request.viewpoint,
1060 &request.object_rotation,
1061 &request.render_config,
1062 )?;
1063 let batch_output = BatchRenderOutput::from_render_output(request, output);
1064 renderer.completed_results.push(batch_output.clone());
1065 renderer.renders_processed += 1;
1066 Ok(Some(batch_output))
1067 } else {
1068 Ok(None)
1069 }
1070}
1071
1072pub fn render_batch(
1091 requests: Vec<BatchRenderRequest>,
1092 config: &BatchRenderConfig,
1093) -> Result<Vec<BatchRenderOutput>, RenderError> {
1094 if requests.is_empty() {
1095 return Ok(Vec::new());
1096 }
1097
1098 if requests.len() > 1 && requests_share_batch_context(&requests) {
1099 let first_request = requests[0].clone();
1100 let viewpoints: Vec<Transform> = requests.iter().map(|request| request.viewpoint).collect();
1101 let outputs = render::render_headless_sequence(
1102 &first_request.object_dir,
1103 &viewpoints,
1104 &first_request.object_rotation,
1105 &first_request.render_config,
1106 )?;
1107
1108 return Ok(requests
1109 .into_iter()
1110 .zip(outputs)
1111 .map(|(request, output)| BatchRenderOutput::from_render_output(request, output))
1112 .collect());
1113 }
1114
1115 let mut renderer = create_batch_renderer(config)?;
1116
1117 for request in requests {
1119 queue_render_request(&mut renderer, request)?;
1120 }
1121
1122 let mut results = Vec::new();
1124 while let Some(output) = render_next_in_batch(&mut renderer, config.frame_timeout_ms)? {
1125 results.push(output);
1126 }
1127
1128 Ok(results)
1129}
1130
1131fn requests_share_batch_context(requests: &[BatchRenderRequest]) -> bool {
1132 let Some(first) = requests.first() else {
1133 return true;
1134 };
1135
1136 requests.iter().all(|request| {
1137 request.object_dir == first.object_dir
1138 && request.object_rotation == first.object_rotation
1139 && request.render_config == first.render_config
1140 })
1141}
1142
1143pub use bevy::prelude::{Quat, Transform, Vec3};
1145
1146#[cfg(test)]
1147mod tests {
1148 use super::*;
1149
1150 #[test]
1151 fn test_object_rotation_identity() {
1152 let rot = ObjectRotation::identity();
1153 assert_eq!(rot.pitch, 0.0);
1154 assert_eq!(rot.yaw, 0.0);
1155 assert_eq!(rot.roll, 0.0);
1156 }
1157
1158 #[test]
1159 fn test_object_rotation_from_array() {
1160 let rot = ObjectRotation::from_array([10.0, 20.0, 30.0]);
1161 assert_eq!(rot.pitch, 10.0);
1162 assert_eq!(rot.yaw, 20.0);
1163 assert_eq!(rot.roll, 30.0);
1164 }
1165
1166 #[test]
1167 fn test_requests_share_batch_context_for_homogeneous_batch() {
1168 let config = RenderConfig::tbp_default();
1169 let request = BatchRenderRequest {
1170 object_dir: "/tmp/ycb/003_cracker_box".into(),
1171 viewpoint: Transform::IDENTITY,
1172 object_rotation: ObjectRotation::identity(),
1173 render_config: config.clone(),
1174 };
1175
1176 assert!(requests_share_batch_context(&[
1177 request.clone(),
1178 BatchRenderRequest {
1179 viewpoint: Transform::from_xyz(1.0, 0.0, 0.0),
1180 ..request
1181 },
1182 ]));
1183 }
1184
1185 #[test]
1186 fn test_requests_share_batch_context_rejects_mixed_objects() {
1187 let config = RenderConfig::tbp_default();
1188 let request = BatchRenderRequest {
1189 object_dir: "/tmp/ycb/003_cracker_box".into(),
1190 viewpoint: Transform::IDENTITY,
1191 object_rotation: ObjectRotation::identity(),
1192 render_config: config.clone(),
1193 };
1194
1195 assert!(!requests_share_batch_context(&[
1196 request.clone(),
1197 BatchRenderRequest {
1198 object_dir: "/tmp/ycb/005_tomato_soup_can".into(),
1199 ..request
1200 },
1201 ]));
1202 }
1203
1204 #[test]
1205 fn test_tbp_benchmark_rotations() {
1206 let rotations = ObjectRotation::tbp_benchmark_rotations();
1207 assert_eq!(rotations.len(), 3);
1208 assert_eq!(rotations[0], ObjectRotation::from_array([0.0, 0.0, 0.0]));
1209 assert_eq!(rotations[1], ObjectRotation::from_array([0.0, 90.0, 0.0]));
1210 assert_eq!(rotations[2], ObjectRotation::from_array([0.0, 180.0, 0.0]));
1211 }
1212
1213 #[test]
1214 fn test_tbp_known_orientations_count() {
1215 let orientations = ObjectRotation::tbp_known_orientations();
1216 assert_eq!(orientations.len(), 14);
1217 }
1218
1219 #[test]
1220 fn test_rotation_to_quat() {
1221 let rot = ObjectRotation::identity();
1222 let quat = rot.to_quat();
1223 assert!((quat.w - 1.0).abs() < 0.001);
1225 assert!(quat.x.abs() < 0.001);
1226 assert!(quat.y.abs() < 0.001);
1227 assert!(quat.z.abs() < 0.001);
1228 }
1229
1230 #[test]
1231 fn test_rotation_90_yaw() {
1232 let rot = ObjectRotation::new(0.0, 90.0, 0.0);
1233 let quat = rot.to_quat();
1234 assert!((quat.w - 0.707).abs() < 0.01);
1236 assert!((quat.y - 0.707).abs() < 0.01);
1237 }
1238
1239 #[test]
1240 fn test_viewpoint_config_default() {
1241 let config = ViewpointConfig::default();
1242 assert_eq!(config.radius, 0.5);
1243 assert_eq!(config.yaw_count, 8);
1244 assert_eq!(config.pitch_angles_deg.len(), 3);
1245 }
1246
1247 #[test]
1248 fn test_viewpoint_count() {
1249 let config = ViewpointConfig::default();
1250 assert_eq!(config.viewpoint_count(), 24); }
1252
1253 #[test]
1254 fn test_generate_viewpoints_count() {
1255 let config = ViewpointConfig::default();
1256 let viewpoints = generate_viewpoints(&config);
1257 assert_eq!(viewpoints.len(), 24);
1258 }
1259
1260 #[test]
1261 fn test_viewpoints_spherical_radius() {
1262 let config = ViewpointConfig::default();
1263 let viewpoints = generate_viewpoints(&config);
1264
1265 for (i, transform) in viewpoints.iter().enumerate() {
1266 let actual_radius = transform.translation.length();
1267 assert!(
1268 (actual_radius - config.radius).abs() < 0.001,
1269 "Viewpoint {} has incorrect radius: {} (expected {})",
1270 i,
1271 actual_radius,
1272 config.radius
1273 );
1274 }
1275 }
1276
1277 #[test]
1278 fn test_viewpoints_looking_at_origin() {
1279 let config = ViewpointConfig::default();
1280 let viewpoints = generate_viewpoints(&config);
1281
1282 for (i, transform) in viewpoints.iter().enumerate() {
1283 let forward = transform.forward();
1284 let to_origin = (Vec3::ZERO - transform.translation).normalize();
1285 let dot = forward.dot(to_origin);
1286 assert!(
1287 dot > 0.99,
1288 "Viewpoint {} not looking at origin, dot product: {}",
1289 i,
1290 dot
1291 );
1292 }
1293 }
1294
1295 #[test]
1296 fn test_sensor_config_default() {
1297 let config = SensorConfig::default();
1298 assert_eq!(config.object_rotations.len(), 1);
1299 assert_eq!(config.total_captures(), 24);
1300 }
1301
1302 #[test]
1303 fn test_sensor_config_tbp_benchmark() {
1304 let config = SensorConfig::tbp_benchmark();
1305 assert_eq!(config.object_rotations.len(), 3);
1306 assert_eq!(config.total_captures(), 72); }
1308
1309 #[test]
1310 fn test_sensor_config_tbp_full() {
1311 let config = SensorConfig::tbp_full_training();
1312 assert_eq!(config.object_rotations.len(), 14);
1313 assert_eq!(config.total_captures(), 336); }
1315
1316 #[test]
1317 fn test_ycb_representative_objects() {
1318 assert_eq!(crate::ycb::REPRESENTATIVE_OBJECTS.len(), 3);
1320 assert!(crate::ycb::REPRESENTATIVE_OBJECTS.contains(&"003_cracker_box"));
1321 }
1322
1323 #[test]
1324 #[allow(deprecated)]
1325 fn test_ycb_ten_objects() {
1326 assert_eq!(crate::ycb::TEN_OBJECTS.len(), 10);
1328 }
1329
1330 #[test]
1331 fn test_ycb_object_mesh_path() {
1332 let path = crate::ycb::object_mesh_path("/tmp/ycb", "003_cracker_box");
1333 assert_eq!(
1334 path,
1335 std::path::Path::new("/tmp/ycb")
1336 .join("003_cracker_box")
1337 .join("google_16k")
1338 .join("textured.obj")
1339 );
1340 }
1341
1342 #[test]
1343 fn test_ycb_object_texture_path() {
1344 let path = crate::ycb::object_texture_path("/tmp/ycb", "003_cracker_box");
1345 assert_eq!(
1346 path,
1347 std::path::Path::new("/tmp/ycb")
1348 .join("003_cracker_box")
1349 .join("google_16k")
1350 .join("texture_map.png")
1351 );
1352 }
1353
1354 #[test]
1359 fn test_render_config_tbp_default() {
1360 let config = RenderConfig::tbp_default();
1361 assert_eq!(config.width, 64);
1363 assert_eq!(config.height, 64);
1364 assert!(config.zoom > 0.0);
1366 assert!(config.near_plane > 0.0);
1368 assert!(config.far_plane > config.near_plane);
1369 }
1370
1371 #[test]
1372 fn test_render_config_preview() {
1373 let config = RenderConfig::preview();
1374 assert_eq!(config.width, 256);
1375 assert_eq!(config.height, 256);
1376 }
1377
1378 #[test]
1379 fn test_render_config_default_is_tbp() {
1380 let default = RenderConfig::default();
1381 let tbp = RenderConfig::tbp_default();
1382 assert_eq!(default.width, tbp.width);
1383 assert_eq!(default.height, tbp.height);
1384 }
1385
1386 #[test]
1387 fn test_render_config_fov() {
1388 let config = RenderConfig::tbp_default();
1389 let fov = config.fov_radians();
1390 assert!(fov > 0.0);
1393 assert!(fov < PI);
1394
1395 let zoomed = RenderConfig {
1397 zoom: config.zoom * 2.0,
1398 ..config
1399 };
1400 assert!(zoomed.fov_radians() < fov);
1401 }
1402
1403 #[test]
1404 fn test_render_config_intrinsics() {
1405 let config = RenderConfig::tbp_default();
1406 let intrinsics = config.intrinsics();
1407
1408 assert_eq!(intrinsics.image_size, [config.width, config.height]);
1410 assert_eq!(
1411 intrinsics.principal_point,
1412 [config.width as f64 / 2.0, config.height as f64 / 2.0]
1413 );
1414 assert_eq!(intrinsics.focal_length[0], intrinsics.focal_length[1]);
1416 assert!(intrinsics.focal_length[0] > 0.0);
1417 }
1418
1419 #[test]
1420 fn test_camera_intrinsics_project() {
1421 let intrinsics = CameraIntrinsics {
1422 focal_length: [100.0, 100.0],
1423 principal_point: [32.0, 32.0],
1424 image_size: [64, 64],
1425 };
1426
1427 let center = intrinsics.project(Vec3::new(0.0, 0.0, 1.0));
1429 assert!(center.is_some());
1430 let [x, y] = center.unwrap();
1431 assert!((x - 32.0).abs() < 0.001);
1432 assert!((y - 32.0).abs() < 0.001);
1433
1434 let behind = intrinsics.project(Vec3::new(0.0, 0.0, -1.0));
1436 assert!(behind.is_none());
1437 }
1438
1439 #[test]
1440 fn test_camera_intrinsics_unproject() {
1441 let intrinsics = CameraIntrinsics {
1442 focal_length: [100.0, 100.0],
1443 principal_point: [32.0, 32.0],
1444 image_size: [64, 64],
1445 };
1446
1447 let point = intrinsics.unproject([32.0, 32.0], 1.0);
1449 assert!((point[0]).abs() < 0.001); assert!((point[1]).abs() < 0.001); assert!((point[2] - 1.0).abs() < 0.001); }
1453
1454 #[test]
1455 fn test_render_output_get_rgba() {
1456 let output = RenderOutput {
1457 rgba: vec![
1458 255, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255,
1459 ],
1460 depth: vec![1.0, 2.0, 3.0, 4.0],
1461 width: 2,
1462 height: 2,
1463 intrinsics: RenderConfig::tbp_default().intrinsics(),
1464 camera_transform: Transform::IDENTITY,
1465 object_rotation: ObjectRotation::identity(),
1466 };
1467
1468 assert_eq!(output.get_rgba(0, 0), Some([255, 0, 0, 255]));
1470 assert_eq!(output.get_rgba(1, 0), Some([0, 255, 0, 255]));
1472 assert_eq!(output.get_rgba(0, 1), Some([0, 0, 255, 255]));
1474 assert_eq!(output.get_rgba(1, 1), Some([255, 255, 255, 255]));
1476 assert_eq!(output.get_rgba(2, 0), None);
1478 }
1479
1480 #[test]
1481 fn test_render_output_get_depth() {
1482 let output = RenderOutput {
1483 rgba: vec![0u8; 16],
1484 depth: vec![1.0, 2.0, 3.0, 4.0],
1485 width: 2,
1486 height: 2,
1487 intrinsics: RenderConfig::tbp_default().intrinsics(),
1488 camera_transform: Transform::IDENTITY,
1489 object_rotation: ObjectRotation::identity(),
1490 };
1491
1492 assert_eq!(output.get_depth(0, 0), Some(1.0));
1493 assert_eq!(output.get_depth(1, 0), Some(2.0));
1494 assert_eq!(output.get_depth(0, 1), Some(3.0));
1495 assert_eq!(output.get_depth(1, 1), Some(4.0));
1496 assert_eq!(output.get_depth(2, 0), None);
1497 }
1498
1499 #[test]
1500 fn test_render_output_to_rgb_image() {
1501 let output = RenderOutput {
1502 rgba: vec![
1503 255, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255,
1504 ],
1505 depth: vec![1.0, 2.0, 3.0, 4.0],
1506 width: 2,
1507 height: 2,
1508 intrinsics: RenderConfig::tbp_default().intrinsics(),
1509 camera_transform: Transform::IDENTITY,
1510 object_rotation: ObjectRotation::identity(),
1511 };
1512
1513 let image = output.to_rgb_image();
1514 assert_eq!(image.len(), 2); assert_eq!(image[0].len(), 2); assert_eq!(image[0][0], [255, 0, 0]); assert_eq!(image[0][1], [0, 255, 0]); assert_eq!(image[1][0], [0, 0, 255]); assert_eq!(image[1][1], [255, 255, 255]); }
1521
1522 #[test]
1523 fn test_render_output_to_depth_image() {
1524 let output = RenderOutput {
1525 rgba: vec![0u8; 16],
1526 depth: vec![1.0, 2.0, 3.0, 4.0],
1527 width: 2,
1528 height: 2,
1529 intrinsics: RenderConfig::tbp_default().intrinsics(),
1530 camera_transform: Transform::IDENTITY,
1531 object_rotation: ObjectRotation::identity(),
1532 };
1533
1534 let depth_image = output.to_depth_image();
1535 assert_eq!(depth_image.len(), 2);
1536 assert_eq!(depth_image[0], vec![1.0, 2.0]);
1537 assert_eq!(depth_image[1], vec![3.0, 4.0]);
1538 }
1539
1540 #[test]
1541 fn test_render_error_display() {
1542 let err = RenderError::MeshNotFound("/path/to/mesh.obj".to_string());
1543 assert!(err.to_string().contains("Mesh not found"));
1544 assert!(err.to_string().contains("/path/to/mesh.obj"));
1545 }
1546
1547 #[test]
1552 fn test_object_rotation_extreme_angles() {
1553 let rot = ObjectRotation::new(450.0, -720.0, 1080.0);
1555 let quat = rot.to_quat();
1556 assert!((quat.length() - 1.0).abs() < 0.001);
1558 }
1559
1560 #[test]
1561 fn test_object_rotation_to_transform() {
1562 let rot = ObjectRotation::new(45.0, 90.0, 0.0);
1563 let transform = rot.to_transform();
1564 assert_eq!(transform.translation, Vec3::ZERO);
1566 assert!(transform.rotation != Quat::IDENTITY);
1568 }
1569
1570 #[test]
1571 fn test_viewpoint_config_single_viewpoint() {
1572 let config = ViewpointConfig {
1573 radius: 1.0,
1574 yaw_count: 1,
1575 pitch_angles_deg: vec![0.0],
1576 };
1577 assert_eq!(config.viewpoint_count(), 1);
1578 let viewpoints = generate_viewpoints(&config);
1579 assert_eq!(viewpoints.len(), 1);
1580 let pos = viewpoints[0].translation;
1582 assert!((pos.x).abs() < 0.001);
1583 assert!((pos.y).abs() < 0.001);
1584 assert!((pos.z - 1.0).abs() < 0.001);
1585 }
1586
1587 #[test]
1588 fn test_viewpoint_radius_scaling() {
1589 let config1 = ViewpointConfig {
1590 radius: 0.5,
1591 yaw_count: 4,
1592 pitch_angles_deg: vec![0.0],
1593 };
1594 let config2 = ViewpointConfig {
1595 radius: 2.0,
1596 yaw_count: 4,
1597 pitch_angles_deg: vec![0.0],
1598 };
1599
1600 let v1 = generate_viewpoints(&config1);
1601 let v2 = generate_viewpoints(&config2);
1602
1603 for (vp1, vp2) in v1.iter().zip(v2.iter()) {
1605 let ratio = vp2.translation.length() / vp1.translation.length();
1606 assert!((ratio - 4.0).abs() < 0.01); }
1608 }
1609
1610 #[test]
1611 fn test_camera_intrinsics_project_at_z_zero() {
1612 let intrinsics = CameraIntrinsics {
1613 focal_length: [100.0, 100.0],
1614 principal_point: [32.0, 32.0],
1615 image_size: [64, 64],
1616 };
1617
1618 let result = intrinsics.project(Vec3::new(1.0, 1.0, 0.0));
1620 assert!(result.is_none());
1621 }
1622
/// Projecting a 3D point and unprojecting the result at the point's own depth
/// must recover the original coordinates to within 0.001 per axis.
#[test]
fn test_camera_intrinsics_roundtrip() {
    let camera = CameraIntrinsics {
        focal_length: [100.0, 100.0],
        principal_point: [32.0, 32.0],
        image_size: [64, 64],
    };

    let point = Vec3::new(0.5, -0.3, 2.0);
    let pixel = camera.project(point).unwrap();
    let recovered = camera.unproject(pixel, point.z as f64);

    // Check x, y, z in order; the unprojected value is indexed by axis.
    for (axis, expected) in [point.x, point.y, point.z].iter().copied().enumerate() {
        assert!((recovered[axis] - expected as f64).abs() < 0.001);
    }
}
1643
/// A zero-sized `RenderOutput` must report no pixel at `(0, 0)` and convert to
/// empty RGB and depth images.
#[test]
fn test_render_output_empty() {
    let blank = RenderOutput {
        rgba: Vec::new(),
        depth: Vec::new(),
        width: 0,
        height: 0,
        intrinsics: RenderConfig::tbp_default().intrinsics(),
        camera_transform: Transform::IDENTITY,
        object_rotation: ObjectRotation::identity(),
    };

    // Pixel accessors: nothing to return for an empty buffer.
    assert_eq!(blank.get_rgba(0, 0), None);
    assert_eq!(blank.get_depth(0, 0), None);

    // Image conversions: both must come back empty.
    let rgb_rows = blank.to_rgb_image();
    assert!(rgb_rows.is_empty());
    let depth_rows = blank.to_depth_image();
    assert!(depth_rows.is_empty());
}
1662
/// A single-pixel `RenderOutput` must expose that pixel through the RGBA, RGB
/// and depth accessors and through the RGB image conversion.
#[test]
fn test_render_output_1x1() {
    let single = RenderOutput {
        rgba: vec![128, 64, 32, 255],
        depth: vec![0.5],
        width: 1,
        height: 1,
        intrinsics: RenderConfig::tbp_default().intrinsics(),
        camera_transform: Transform::IDENTITY,
        object_rotation: ObjectRotation::identity(),
    };

    // Per-pixel accessors.
    assert_eq!(single.get_rgba(0, 0), Some([128, 64, 32, 255]));
    assert_eq!(single.get_depth(0, 0), Some(0.5));
    assert_eq!(single.get_rgb(0, 0), Some([128, 64, 32]));

    // Image conversion: one row containing one RGB pixel (alpha dropped).
    let rows = single.to_rgb_image();
    assert_eq!(rows.len(), 1);
    let first_row = &rows[0];
    assert_eq!(first_row.len(), 1);
    assert_eq!(first_row[0], [128, 64, 32]);
}
1684
/// The `high_res` preset must be 512x512, with intrinsics whose image size
/// matches and whose principal point sits at the image centre.
#[test]
fn test_render_config_high_res() {
    let preset = RenderConfig::high_res();
    assert_eq!(preset.width, 512);
    assert_eq!(preset.height, 512);

    let derived = preset.intrinsics();
    let (size, centre) = (derived.image_size, derived.principal_point);
    assert_eq!(size, [512, 512]);
    assert_eq!(centre, [256.0, 256.0]);
}
1695
/// Increasing `zoom` must narrow the field of view, and doubling it must halve
/// `tan(fov / 2)`.
#[test]
fn test_render_config_zoom_affects_fov() {
    let zoom_2x = RenderConfig {
        zoom: 2.0,
        ..RenderConfig::tbp_default()
    };
    let zoom_4x = RenderConfig {
        zoom: 4.0,
        ..RenderConfig::tbp_default()
    };

    // More zoom means a tighter field of view.
    assert!(zoom_4x.fov_radians() < zoom_2x.fov_radians());

    // Zoom went from 2.0 to 4.0, so the half-angle tangents must differ by 2x.
    let wide_half_tan = (zoom_2x.fov_radians() / 2.0).tan();
    let narrow_half_tan = (zoom_4x.fov_radians() / 2.0).tan();
    let tangent_ratio = wide_half_tan / narrow_half_tan;
    assert!((tangent_ratio - 2.0).abs() < 1e-4);
}
1719
/// The horizontal focal length reported by `intrinsics()` must grow with `zoom`
/// and stay proportional to it (`fx / zoom` constant).
#[test]
fn test_render_config_zoom_affects_intrinsics() {
    let low = RenderConfig {
        zoom: 2.0,
        ..RenderConfig::tbp_default()
    };
    let high = RenderConfig {
        zoom: 4.0,
        ..RenderConfig::tbp_default()
    };

    let fx_low = low.intrinsics().focal_length[0];
    let fx_high = high.intrinsics().focal_length[0];

    // Higher zoom means a longer focal length.
    assert!(fx_high > fx_low);

    // Dividing out the zoom factor must give the same base focal length.
    let base_from_low = fx_low / low.zoom as f64;
    let base_from_high = fx_high / high.zoom as f64;
    assert!((base_from_low - base_from_high).abs() < 1e-9);
}
1742
/// Compares the `LightingConfig` presets against the default: `bright` has a
/// stronger key light, `soft` a weaker one, and `unlit` has both lights off
/// with ambient brightness at 1.0.
#[test]
fn test_lighting_config_variants() {
    let baseline = LightingConfig::default();
    let brighter = LightingConfig::bright();
    let softer = LightingConfig::soft();
    let no_lights = LightingConfig::unlit();

    let baseline_key = baseline.key_light_intensity;

    assert!(brighter.key_light_intensity > baseline_key);

    // Unlit: both directional lights are off and ambient brightness is 1.0.
    assert_eq!(no_lights.key_light_intensity, 0.0);
    assert_eq!(no_lights.fill_light_intensity, 0.0);
    assert_eq!(no_lights.ambient_brightness, 1.0);

    assert!(softer.key_light_intensity < baseline_key);
}
1761
/// Each `RenderError` variant constructed here must format to a non-empty
/// message via `to_string()`.
#[test]
fn test_all_render_error_variants() {
    // A fixed-size array is enough; the list is only iterated once.
    let samples = [
        RenderError::MeshNotFound("mesh.obj".to_string()),
        RenderError::TextureNotFound("texture.png".to_string()),
        RenderError::RenderFailed("GPU error".to_string()),
        RenderError::InvalidConfig("bad config".to_string()),
    ];

    for sample in samples {
        let rendered = sample.to_string();
        assert!(!rendered.is_empty());
    }
}
1777
/// No two entries of `ObjectRotation::tbp_known_orientations()` may convert to
/// (nearly) the same quaternion.
///
/// Two orientations count as duplicates when the absolute value of their
/// quaternion dot product is at least 0.999.
#[test]
fn test_tbp_known_orientations_unique() {
    let orientations = ObjectRotation::tbp_known_orientations();
    let quats: Vec<Quat> = orientations.iter().map(|r| r.to_quat()).collect();

    for (i, q1) in quats.iter().enumerate() {
        // |dot| is symmetric, so each unordered pair only needs checking once
        // (j > i). The first failing pair reported is the same as when every
        // ordered pair is scanned.
        for (j, q2) in quats.iter().enumerate().skip(i + 1) {
            // `q` and `-q` encode the same rotation, hence the absolute value.
            let dot = q1.dot(*q2).abs();
            assert!(
                dot < 0.999,
                "Orientations {i} and {j} produce same quaternion"
            );
        }
    }
}
1800}