1use bevy::prelude::*;
31use std::f32::consts::PI;
32
33pub use ycbust::{self, DownloadOptions, Subset as YcbSubset, REPRESENTATIVE_OBJECTS, TEN_OBJECTS};
35
36pub mod ycb {
38 pub use ycbust::{download_ycb, DownloadOptions, Subset, REPRESENTATIVE_OBJECTS, TEN_OBJECTS};
39
40 use std::path::Path;
41
42 pub async fn download_models<P: AsRef<Path>>(
55 output_dir: P,
56 subset: Subset,
57 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
58 let options = DownloadOptions {
59 overwrite: false,
60 full: false,
61 show_progress: true,
62 delete_archives: true,
63 };
64 download_ycb(subset, output_dir.as_ref(), options).await?;
65 Ok(())
66 }
67
68 pub async fn download_models_with_options<P: AsRef<Path>>(
70 output_dir: P,
71 subset: Subset,
72 options: DownloadOptions,
73 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
74 download_ycb(subset, output_dir.as_ref(), options).await?;
75 Ok(())
76 }
77
78 pub fn models_exist<P: AsRef<Path>>(output_dir: P) -> bool {
80 let path = output_dir.as_ref();
81 path.join("003_cracker_box/google_16k/textured.obj")
83 .exists()
84 }
85
86 pub fn object_mesh_path<P: AsRef<Path>>(output_dir: P, object_id: &str) -> std::path::PathBuf {
88 output_dir
89 .as_ref()
90 .join(object_id)
91 .join("google_16k")
92 .join("textured.obj")
93 }
94
95 pub fn object_texture_path<P: AsRef<Path>>(
97 output_dir: P,
98 object_id: &str,
99 ) -> std::path::PathBuf {
100 output_dir
101 .as_ref()
102 .join(object_id)
103 .join("google_16k")
104 .join("texture_map.png")
105 }
106}
107
/// Orientation of an object given as three Euler angles in degrees.
///
/// The angles are converted to radians and passed to `Quat::from_euler`
/// with `EulerRot::XYZ` in the order `pitch`, `yaw`, `roll`.
#[derive(Debug, Clone, PartialEq)]
pub struct ObjectRotation {
    /// First Euler angle, in degrees.
    pub pitch: f32,
    /// Second Euler angle, in degrees.
    pub yaw: f32,
    /// Third Euler angle, in degrees.
    pub roll: f32,
}
119
120impl ObjectRotation {
121 pub fn new(pitch: f32, yaw: f32, roll: f32) -> Self {
123 Self { pitch, yaw, roll }
124 }
125
126 pub fn from_array(arr: [f32; 3]) -> Self {
128 Self {
129 pitch: arr[0],
130 yaw: arr[1],
131 roll: arr[2],
132 }
133 }
134
135 pub fn identity() -> Self {
137 Self::new(0.0, 0.0, 0.0)
138 }
139
140 pub fn tbp_benchmark_rotations() -> Vec<Self> {
143 vec![
144 Self::from_array([0.0, 0.0, 0.0]),
145 Self::from_array([0.0, 90.0, 0.0]),
146 Self::from_array([0.0, 180.0, 0.0]),
147 ]
148 }
149
150 pub fn tbp_known_orientations() -> Vec<Self> {
153 vec![
154 Self::from_array([0.0, 0.0, 0.0]), Self::from_array([0.0, 90.0, 0.0]), Self::from_array([0.0, 180.0, 0.0]), Self::from_array([0.0, 270.0, 0.0]), Self::from_array([90.0, 0.0, 0.0]), Self::from_array([-90.0, 0.0, 0.0]), Self::from_array([45.0, 45.0, 0.0]),
163 Self::from_array([45.0, 135.0, 0.0]),
164 Self::from_array([45.0, 225.0, 0.0]),
165 Self::from_array([45.0, 315.0, 0.0]),
166 Self::from_array([-45.0, 45.0, 0.0]),
167 Self::from_array([-45.0, 135.0, 0.0]),
168 Self::from_array([-45.0, 225.0, 0.0]),
169 Self::from_array([-45.0, 315.0, 0.0]),
170 ]
171 }
172
173 pub fn to_quat(&self) -> Quat {
175 Quat::from_euler(
176 EulerRot::XYZ,
177 self.pitch.to_radians(),
178 self.yaw.to_radians(),
179 self.roll.to_radians(),
180 )
181 }
182
183 pub fn to_transform(&self) -> Transform {
185 Transform::from_rotation(self.to_quat())
186 }
187}
188
189impl Default for ObjectRotation {
190 fn default() -> Self {
191 Self::identity()
192 }
193}
194
/// Describes a set of camera positions arranged on a sphere around the origin.
///
/// One ring of positions is produced per entry in `pitch_angles_deg`, and
/// each ring holds `yaw_count` positions.
#[derive(Debug, Clone)]
pub struct ViewpointConfig {
    /// Distance of every position from the origin.
    pub radius: f32,
    /// Number of yaw steps per pitch angle, spread evenly over a full turn.
    pub yaw_count: usize,
    /// Pitch angle of each ring, in degrees.
    pub pitch_angles_deg: Vec<f32>,
}

impl Default for ViewpointConfig {
    /// Radius 0.5, 8 yaw steps, and pitch rings at -30, 0 and 30 degrees.
    fn default() -> Self {
        let pitch_angles_deg = vec![-30.0, 0.0, 30.0];
        Self {
            radius: 0.5,
            yaw_count: 8,
            pitch_angles_deg,
        }
    }
}

impl ViewpointConfig {
    /// Total number of viewpoints: pitch ring count times `yaw_count`.
    pub fn viewpoint_count(&self) -> usize {
        let ring_count = self.pitch_angles_deg.len();
        ring_count * self.yaw_count
    }
}
225
226#[derive(Clone, Debug, Resource)]
228pub struct SensorConfig {
229 pub viewpoints: ViewpointConfig,
231 pub object_rotations: Vec<ObjectRotation>,
233 pub output_dir: String,
235 pub filename_pattern: String,
237}
238
239impl Default for SensorConfig {
240 fn default() -> Self {
241 Self {
242 viewpoints: ViewpointConfig::default(),
243 object_rotations: vec![ObjectRotation::identity()],
244 output_dir: ".".to_string(),
245 filename_pattern: "capture_{rot}_{view}.png".to_string(),
246 }
247 }
248}
249
250impl SensorConfig {
251 pub fn tbp_benchmark() -> Self {
253 Self {
254 viewpoints: ViewpointConfig::default(),
255 object_rotations: ObjectRotation::tbp_benchmark_rotations(),
256 output_dir: ".".to_string(),
257 filename_pattern: "capture_{rot}_{view}.png".to_string(),
258 }
259 }
260
261 pub fn tbp_full_training() -> Self {
263 Self {
264 viewpoints: ViewpointConfig::default(),
265 object_rotations: ObjectRotation::tbp_known_orientations(),
266 output_dir: ".".to_string(),
267 filename_pattern: "capture_{rot}_{view}.png".to_string(),
268 }
269 }
270
271 pub fn total_captures(&self) -> usize {
273 self.viewpoints.viewpoint_count() * self.object_rotations.len()
274 }
275}
276
277pub fn generate_viewpoints(config: &ViewpointConfig) -> Vec<Transform> {
284 let mut views = Vec::with_capacity(config.viewpoint_count());
285
286 for pitch_deg in &config.pitch_angles_deg {
287 let pitch = pitch_deg.to_radians();
288
289 for i in 0..config.yaw_count {
290 let yaw = (i as f32) * 2.0 * PI / (config.yaw_count as f32);
291
292 let x = config.radius * pitch.cos() * yaw.sin();
297 let y = config.radius * pitch.sin();
298 let z = config.radius * pitch.cos() * yaw.cos();
299
300 let transform = Transform::from_xyz(x, y, z).looking_at(Vec3::ZERO, Vec3::Y);
301 views.push(transform);
302 }
303 }
304 views
305}
306
307#[derive(Component)]
309pub struct CaptureTarget;
310
311#[derive(Component)]
313pub struct CaptureCamera;
314
315pub use bevy::prelude::{Quat, Transform, Vec3};
317
/// Unit tests for rotations, viewpoint generation, sensor presets and the
/// YCB path helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_object_rotation_identity() {
        let rot = ObjectRotation::identity();
        assert_eq!(rot.pitch, 0.0);
        assert_eq!(rot.yaw, 0.0);
        assert_eq!(rot.roll, 0.0);
    }

    #[test]
    fn test_object_rotation_from_array() {
        let rot = ObjectRotation::from_array([10.0, 20.0, 30.0]);
        assert_eq!(rot.pitch, 10.0);
        assert_eq!(rot.yaw, 20.0);
        assert_eq!(rot.roll, 30.0);
    }

    #[test]
    fn test_tbp_benchmark_rotations() {
        let rotations = ObjectRotation::tbp_benchmark_rotations();
        assert_eq!(rotations.len(), 3);
        assert_eq!(rotations[0], ObjectRotation::from_array([0.0, 0.0, 0.0]));
        assert_eq!(rotations[1], ObjectRotation::from_array([0.0, 90.0, 0.0]));
        assert_eq!(rotations[2], ObjectRotation::from_array([0.0, 180.0, 0.0]));
    }

    #[test]
    fn test_tbp_known_orientations_count() {
        let orientations = ObjectRotation::tbp_known_orientations();
        assert_eq!(orientations.len(), 14);
    }

    #[test]
    fn test_rotation_to_quat() {
        // Zero angles must produce the identity quaternion (0, 0, 0, 1).
        let rot = ObjectRotation::identity();
        let quat = rot.to_quat();
        assert!((quat.w - 1.0).abs() < 0.001);
        assert!(quat.x.abs() < 0.001);
        assert!(quat.y.abs() < 0.001);
        assert!(quat.z.abs() < 0.001);
    }

    #[test]
    fn test_rotation_90_yaw() {
        // A 90-degree turn about one axis gives w = cos(45°) and
        // axis component = sin(45°), both ≈ 0.707.
        let rot = ObjectRotation::new(0.0, 90.0, 0.0);
        let quat = rot.to_quat();
        assert!((quat.w - 0.707).abs() < 0.01);
        assert!((quat.y - 0.707).abs() < 0.01);
    }

    #[test]
    fn test_viewpoint_config_default() {
        let config = ViewpointConfig::default();
        assert_eq!(config.radius, 0.5);
        assert_eq!(config.yaw_count, 8);
        assert_eq!(config.pitch_angles_deg.len(), 3);
    }

    #[test]
    fn test_viewpoint_count() {
        // 8 yaw steps x 3 pitch angles.
        let config = ViewpointConfig::default();
        assert_eq!(config.viewpoint_count(), 24);
    }

    #[test]
    fn test_generate_viewpoints_count() {
        let config = ViewpointConfig::default();
        let viewpoints = generate_viewpoints(&config);
        assert_eq!(viewpoints.len(), 24);
    }

    #[test]
    fn test_viewpoints_spherical_radius() {
        let config = ViewpointConfig::default();
        let viewpoints = generate_viewpoints(&config);

        for (i, transform) in viewpoints.iter().enumerate() {
            let actual_radius = transform.translation.length();
            assert!(
                (actual_radius - config.radius).abs() < 0.001,
                "Viewpoint {} has incorrect radius: {} (expected {})",
                i,
                actual_radius,
                config.radius
            );
        }
    }

    #[test]
    fn test_viewpoints_looking_at_origin() {
        let config = ViewpointConfig::default();
        let viewpoints = generate_viewpoints(&config);

        for (i, transform) in viewpoints.iter().enumerate() {
            let forward = transform.forward();
            let to_origin = (Vec3::ZERO - transform.translation).normalize();
            let dot = forward.dot(to_origin);
            assert!(
                dot > 0.99,
                "Viewpoint {} not looking at origin, dot product: {}",
                i,
                dot
            );
        }
    }

    #[test]
    fn test_sensor_config_default() {
        let config = SensorConfig::default();
        assert_eq!(config.object_rotations.len(), 1);
        assert_eq!(config.total_captures(), 24);
    }

    #[test]
    fn test_sensor_config_tbp_benchmark() {
        // 3 rotations x 24 viewpoints.
        let config = SensorConfig::tbp_benchmark();
        assert_eq!(config.object_rotations.len(), 3);
        assert_eq!(config.total_captures(), 72);
    }

    #[test]
    fn test_sensor_config_tbp_full() {
        // 14 rotations x 24 viewpoints.
        let config = SensorConfig::tbp_full_training();
        assert_eq!(config.object_rotations.len(), 14);
        assert_eq!(config.total_captures(), 336);
    }

    #[test]
    fn test_ycb_representative_objects() {
        assert_eq!(crate::ycb::REPRESENTATIVE_OBJECTS.len(), 3);
        assert!(crate::ycb::REPRESENTATIVE_OBJECTS.contains(&"003_cracker_box"));
    }

    #[test]
    fn test_ycb_ten_objects() {
        assert_eq!(crate::ycb::TEN_OBJECTS.len(), 10);
    }

    #[test]
    fn test_ycb_object_mesh_path() {
        let path = crate::ycb::object_mesh_path("/tmp/ycb", "003_cracker_box");
        // Compare as paths (component-wise) rather than as strings, so the
        // test does not depend on the platform's path separator.
        assert_eq!(
            path.as_path(),
            Path::new("/tmp/ycb/003_cracker_box/google_16k/textured.obj")
        );
    }

    #[test]
    fn test_ycb_object_texture_path() {
        let path = crate::ycb::object_texture_path("/tmp/ycb", "003_cracker_box");
        // Compare as paths (component-wise) rather than as strings, so the
        // test does not depend on the platform's path separator.
        assert_eq!(
            path.as_path(),
            Path::new("/tmp/ycb/003_cracker_box/google_16k/texture_map.png")
        );
    }
}