use bevy::prelude::*;
use serde::{Deserialize, Serialize};
use std::f32::consts::PI;
use std::path::{Path, PathBuf};
mod render;
pub mod batch;
pub mod benchmark;
pub mod backend;
pub mod cache;
pub mod fixtures;
/// Version tag for the renderer policy; the current value is `"tbp-targeting-v1"`.
/// NOTE(review): presumably recorded next to render outputs so consumers can detect
/// policy changes — confirm against callers.
pub const RENDERER_POLICY_VERSION: &str = "tbp-targeting-v1";
pub use ycbust::{
self, DownloadOptions, Subset as YcbSubset, GOOGLE_16K_MESH_RELATIVE, REPRESENTATIVE_OBJECTS,
TBP_SIMILAR_OBJECTS, TBP_STANDARD_OBJECTS,
};
/// Convenience wrappers around the `ycbust` crate for downloading, validating and
/// locating YCB object models on disk.
pub mod ycb {
    pub use ycbust::{
        download_ycb, DownloadOptions, Subset, REPRESENTATIVE_OBJECTS, TBP_SIMILAR_OBJECTS,
        TBP_STANDARD_OBJECTS,
    };
    use std::path::{Path, PathBuf};

    /// Downloads `subset` into `output_dir` using `DownloadOptions::default()`.
    ///
    /// # Errors
    /// Propagates any error reported by `download_ycb`.
    pub async fn download_models<P: AsRef<Path>>(
        output_dir: P,
        subset: Subset,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        download_models_with_options(output_dir, subset, DownloadOptions::default()).await
    }

    /// Downloads `subset` into `output_dir` with caller-supplied `options`.
    ///
    /// # Errors
    /// Propagates any error reported by `download_ycb`.
    pub async fn download_models_with_options<P: AsRef<Path>>(
        output_dir: P,
        subset: Subset,
        options: DownloadOptions,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let destination = output_dir.as_ref();
        download_ycb(subset, destination, options).await?;
        Ok(())
    }

    /// Downloads only the objects named in `object_ids`, using default options.
    ///
    /// # Errors
    /// Propagates any error reported by `ycbust::download_objects`.
    pub async fn download_objects<P: AsRef<Path>>(
        output_dir: P,
        object_ids: &[&str],
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let defaults = DownloadOptions::default();
        ycbust::download_objects(object_ids, output_dir.as_ref(), defaults).await?;
        Ok(())
    }

    /// Returns the names of the requested objects whose validation is not complete.
    pub fn missing_objects<P: AsRef<Path>>(output_dir: P, object_ids: &[&str]) -> Vec<String> {
        let mut missing = Vec::new();
        for validation in ycbust::validate_objects(output_dir.as_ref(), object_ids) {
            if !validation.is_complete() {
                missing.push(validation.name);
            }
        }
        missing
    }

    /// Returns `true` when none of `object_ids` is reported missing.
    pub fn objects_exist<P: AsRef<Path>>(output_dir: P, object_ids: &[&str]) -> bool {
        let missing = missing_objects(output_dir, object_ids);
        missing.is_empty()
    }

    /// Returns `true` when every object in `REPRESENTATIVE_OBJECTS` is present.
    pub fn models_exist<P: AsRef<Path>>(output_dir: P) -> bool {
        let required = REPRESENTATIVE_OBJECTS;
        objects_exist(output_dir, required)
    }

    /// Path of the mesh file for `object_id`, as resolved by `ycbust`.
    pub fn object_mesh_path<P: AsRef<Path>>(output_dir: P, object_id: &str) -> PathBuf {
        let root = output_dir.as_ref();
        ycbust::object_mesh_path(root, object_id)
    }

    /// Path of the texture file for `object_id`, as resolved by `ycbust`.
    pub fn object_texture_path<P: AsRef<Path>>(output_dir: P, object_id: &str) -> PathBuf {
        let root = output_dir.as_ref();
        ycbust::object_texture_path(root, object_id)
    }
}
/// Applies the backend environment configuration the first time it is called.
///
/// Later calls are no-ops: a process-wide atomic flag records that setup has
/// already been triggered.
pub fn initialize() {
    use std::sync::atomic::{AtomicBool, Ordering};

    static INITIALIZED: AtomicBool = AtomicBool::new(false);

    // `swap` returns the previous value, so only the very first caller observes `false`.
    let already_initialized = INITIALIZED.swap(true, Ordering::SeqCst);
    if already_initialized {
        return;
    }

    let config = backend::BackendConfig::new();
    config.apply_env();
}
/// Euler-angle orientation of the rendered object.
///
/// The angles are in degrees: `to_quat` converts each one with `to_radians()` and
/// feeds them to `Quat::from_euler(EulerRot::XYZ, pitch, yaw, roll)`.
#[derive(Clone, Debug, PartialEq)]
pub struct ObjectRotation {
/// First Euler angle (X slot of `EulerRot::XYZ`), in degrees.
pub pitch: f64,
/// Second Euler angle (Y slot of `EulerRot::XYZ`), in degrees.
pub yaw: f64,
/// Third Euler angle (Z slot of `EulerRot::XYZ`), in degrees.
pub roll: f64,
}
impl ObjectRotation {
    /// Builds a rotation from individual angles (degrees).
    pub fn new(pitch: f64, yaw: f64, roll: f64) -> Self {
        Self { pitch, yaw, roll }
    }

    /// Builds a rotation from `[pitch, yaw, roll]` (degrees).
    pub fn from_array(arr: [f64; 3]) -> Self {
        let [pitch, yaw, roll] = arr;
        Self { pitch, yaw, roll }
    }

    /// The zero rotation.
    pub fn identity() -> Self {
        Self {
            pitch: 0.0,
            yaw: 0.0,
            roll: 0.0,
        }
    }

    /// The three benchmark rotations: yaw of 0, 90 and 180 degrees.
    pub fn tbp_benchmark_rotations() -> Vec<Self> {
        const BENCHMARK: [[f64; 3]; 3] = [[0.0, 0.0, 0.0], [0.0, 90.0, 0.0], [0.0, 180.0, 0.0]];
        BENCHMARK.iter().copied().map(Self::from_array).collect()
    }

    /// The fourteen known orientations: four yaw-only, two pitch-only (+/-90) and
    /// eight combining a +/-45 degree pitch with an odd multiple of 45 degrees yaw.
    pub fn tbp_known_orientations() -> Vec<Self> {
        const KNOWN: [[f64; 3]; 14] = [
            [0.0, 0.0, 0.0],
            [0.0, 90.0, 0.0],
            [0.0, 180.0, 0.0],
            [0.0, 270.0, 0.0],
            [90.0, 0.0, 0.0],
            [-90.0, 0.0, 0.0],
            [45.0, 45.0, 0.0],
            [45.0, 135.0, 0.0],
            [45.0, 225.0, 0.0],
            [45.0, 315.0, 0.0],
            [-45.0, 45.0, 0.0],
            [-45.0, 135.0, 0.0],
            [-45.0, 225.0, 0.0],
            [-45.0, 315.0, 0.0],
        ];
        KNOWN.iter().copied().map(Self::from_array).collect()
    }

    /// Converts the degree angles to a quaternion using the `EulerRot::XYZ` order.
    pub fn to_quat(&self) -> Quat {
        // Narrow to f32 first, then convert to radians, exactly as the renderer expects.
        let pitch_rad = (self.pitch as f32).to_radians();
        let yaw_rad = (self.yaw as f32).to_radians();
        let roll_rad = (self.roll as f32).to_radians();
        Quat::from_euler(EulerRot::XYZ, pitch_rad, yaw_rad, roll_rad)
    }

    /// A transform carrying only this rotation.
    pub fn to_transform(&self) -> Transform {
        let rotation = self.to_quat();
        Transform::from_rotation(rotation)
    }

    /// A transform combining this rotation with the given translation and scale.
    pub fn to_transform_with_translation_scale(&self, translation: Vec3, scale: Vec3) -> Transform {
        let rotation = self.to_quat();
        Transform {
            translation,
            rotation,
            scale,
        }
    }
}

impl Default for ObjectRotation {
    /// Defaults to the identity rotation.
    fn default() -> Self {
        Self::identity()
    }
}
/// Describes a set of camera viewpoints: one ring of `yaw_count` cameras for every
/// entry of `pitch_angles_deg`, all at distance `radius` from the target.
#[derive(Clone, Debug)]
pub struct ViewpointConfig {
    /// Distance from the target point to every camera.
    pub radius: f32,
    /// Number of evenly spaced yaw steps per pitch ring.
    pub yaw_count: usize,
    /// Elevation of each ring, in degrees.
    pub pitch_angles_deg: Vec<f32>,
}

impl Default for ViewpointConfig {
    /// Radius 0.5, eight yaw steps, rings at -30, 0 and +30 degrees.
    fn default() -> Self {
        let pitch_angles_deg = [-30.0, 0.0, 30.0].to_vec();
        Self {
            radius: 0.5,
            yaw_count: 8,
            pitch_angles_deg,
        }
    }
}

impl ViewpointConfig {
    /// Total number of viewpoints: yaw steps times pitch rings.
    pub fn viewpoint_count(&self) -> usize {
        let pitch_rings = self.pitch_angles_deg.len();
        self.yaw_count * pitch_rings
    }
}
/// Axis-aligned bounding box of a mesh plus the number of vertices it was built from.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MeshBounds {
    /// Component-wise minimum over all vertices.
    pub min: Vec3,
    /// Component-wise maximum over all vertices.
    pub max: Vec3,
    /// Midpoint between `min` and `max`.
    pub center: Vec3,
    /// Number of vertices that contributed to the bounds.
    pub vertex_count: usize,
}

impl MeshBounds {
    /// Size of the box along each axis (`max - min`).
    pub fn extents(&self) -> Vec3 {
        let MeshBounds { min, max, .. } = *self;
        max - min
    }
}
/// Selects the point the cameras orbit and look at.
///
/// Serialized adjacently tagged: the variant name goes under `"policy"` (snake_case)
/// and any payload under `"target"`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "policy", content = "target", rename_all = "snake_case")]
pub enum TargetingPolicy {
    /// Aim at the world origin.
    Origin,
    /// Aim at the (rotated) center of the mesh bounding box.
    MeshCenter,
    /// Aim at a caller-supplied point.
    ExplicitTarget([f32; 3]),
}

impl TargetingPolicy {
    /// Short hyphenated label for the policy (differs from the serde snake_case tag).
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Origin => "origin",
            Self::MeshCenter => "mesh-center",
            Self::ExplicitTarget(..) => "explicit-target",
        }
    }
}
/// Result of `generate_targeted_viewpoints`: the camera transforms together with the
/// policy and target point that produced them.
#[derive(Clone, Debug, PartialEq)]
pub struct TargetedViewpoints {
/// Policy that was used to pick `target_point`.
pub policy: TargetingPolicy,
/// Point every viewpoint looks at.
pub target_point: Vec3,
/// Mesh bounds; populated only for `TargetingPolicy::MeshCenter`.
pub mesh_bounds: Option<MeshBounds>,
/// Camera transforms placed around `target_point`.
pub viewpoints: Vec<Transform>,
}
/// Bevy resource describing a capture run: which viewpoints and object rotations to
/// render and where the results go.
#[derive(Clone, Debug, Resource)]
pub struct SensorConfig {
/// Camera placement parameters.
pub viewpoints: ViewpointConfig,
/// Object orientations to capture; each is combined with every viewpoint.
pub object_rotations: Vec<ObjectRotation>,
/// Output directory for captures.
pub output_dir: String,
/// Output file name template, e.g. `capture_{rot}_{view}.png`.
/// NOTE(review): `{rot}` / `{view}` are presumably replaced with indices by the
/// capture code — the substitution is not visible here.
pub filename_pattern: String,
}
impl Default for SensorConfig {
fn default() -> Self {
Self {
viewpoints: ViewpointConfig::default(),
object_rotations: vec![ObjectRotation::identity()],
output_dir: ".".to_string(),
filename_pattern: "capture_{rot}_{view}.png".to_string(),
}
}
}
impl SensorConfig {
pub fn tbp_benchmark() -> Self {
Self {
viewpoints: ViewpointConfig::default(),
object_rotations: ObjectRotation::tbp_benchmark_rotations(),
output_dir: ".".to_string(),
filename_pattern: "capture_{rot}_{view}.png".to_string(),
}
}
pub fn tbp_full_training() -> Self {
Self {
viewpoints: ViewpointConfig::default(),
object_rotations: ObjectRotation::tbp_known_orientations(),
output_dir: ".".to_string(),
filename_pattern: "capture_{rot}_{view}.png".to_string(),
}
}
pub fn total_captures(&self) -> usize {
self.viewpoints.viewpoint_count() * self.object_rotations.len()
}
}
/// Generates viewpoints orbiting the world origin.
pub fn generate_viewpoints(config: &ViewpointConfig) -> Vec<Transform> {
    let origin = Vec3::ZERO;
    generate_viewpoints_around_target(config, origin)
}
/// Places cameras on rings around `target`, each looking at `target` with +Y up.
///
/// Output order is ring by ring (in `pitch_angles_deg` order), and within a ring by
/// increasing yaw starting on the +Z side of the target.
pub fn generate_viewpoints_around_target(config: &ViewpointConfig, target: Vec3) -> Vec<Transform> {
    let yaw_steps = config.yaw_count;
    let mut cameras = Vec::with_capacity(config.viewpoint_count());

    for pitch in config.pitch_angles_deg.iter().map(|deg| deg.to_radians()) {
        // Everything that depends only on the ring's elevation is computed once per ring.
        let ring_radius = config.radius * pitch.cos();
        let height = config.radius * pitch.sin();

        cameras.extend((0..yaw_steps).map(|step| {
            let yaw = (step as f32) * 2.0 * PI / (yaw_steps as f32);
            let offset = Vec3::new(ring_radius * yaw.sin(), height, ring_radius * yaw.cos());
            Transform::from_translation(target + offset).looking_at(target, Vec3::Y)
        }));
    }

    cameras
}
/// Applies `object_rotation` to `mesh_center`, giving the center's position once the
/// object has been rotated about the origin.
pub fn rotated_mesh_center(mesh_center: Vec3, object_rotation: &ObjectRotation) -> Vec3 {
    let rotation = object_rotation.to_quat();
    rotation * mesh_center
}
/// Generates viewpoints that orbit the mesh center after the object rotation is applied.
pub fn generate_object_centered_viewpoints(
    config: &ViewpointConfig,
    mesh_center: Vec3,
    object_rotation: &ObjectRotation,
) -> Vec<Transform> {
    let target = rotated_mesh_center(mesh_center, object_rotation);
    generate_viewpoints_around_target(config, target)
}
/// Parses the OBJ file at `mesh_path` and returns the bounding box of all its vertices.
///
/// # Errors
/// * `RenderError::MeshNotFound` if the path does not exist.
/// * `RenderError::DataParsingError` if the OBJ cannot be parsed or has no vertices.
pub fn load_mesh_bounds(mesh_path: &Path) -> Result<MeshBounds, RenderError> {
    if !mesh_path.exists() {
        return Err(RenderError::MeshNotFound(mesh_path.display().to_string()));
    }

    let load_options = tobj::LoadOptions {
        triangulate: false,
        single_index: true,
        ..Default::default()
    };
    let (models, _) = match tobj::load_obj(mesh_path, &load_options) {
        Ok(loaded) => loaded,
        Err(err) => {
            return Err(RenderError::DataParsingError(format!(
                "Failed to parse OBJ mesh {}: {}",
                mesh_path.display(),
                err
            )));
        }
    };

    // Positions are a flat [x, y, z, x, y, z, ...] list; fold every triple into the box.
    let start = (Vec3::splat(f32::INFINITY), Vec3::splat(f32::NEG_INFINITY), 0usize);
    let (min, max, vertex_count) = models
        .iter()
        .flat_map(|model| model.mesh.positions.chunks_exact(3))
        .fold(start, |(lo, hi, seen), xyz| {
            let point = Vec3::new(xyz[0], xyz[1], xyz[2]);
            (lo.min(point), hi.max(point), seen + 1)
        });

    if vertex_count == 0 {
        return Err(RenderError::DataParsingError(format!(
            "OBJ mesh {} contains no vertices",
            mesh_path.display()
        )));
    }

    let center = (min + max) * 0.5;
    Ok(MeshBounds {
        min,
        max,
        center,
        vertex_count,
    })
}
/// Loads the bounds of the mesh at `GOOGLE_16K_MESH_RELATIVE` inside `object_dir`.
///
/// # Errors
/// Same as `load_mesh_bounds`.
pub fn load_ycb_mesh_bounds(object_dir: &Path) -> Result<MeshBounds, RenderError> {
    let mesh_path = object_dir.join(GOOGLE_16K_MESH_RELATIVE);
    load_mesh_bounds(&mesh_path)
}
/// Loads the YCB mesh bounds for `object_dir` and generates viewpoints around the
/// rotated mesh center.
///
/// # Errors
/// Propagates any error from `load_ycb_mesh_bounds`.
pub fn generate_ycb_object_centered_viewpoints(
    object_dir: &Path,
    config: &ViewpointConfig,
    object_rotation: &ObjectRotation,
) -> Result<Vec<Transform>, RenderError> {
    load_ycb_mesh_bounds(object_dir)
        .map(|bounds| generate_object_centered_viewpoints(config, bounds.center, object_rotation))
}
/// Resolves the target point for `policy` and generates viewpoints around it.
///
/// * `Origin` targets `Vec3::ZERO`.
/// * `MeshCenter` loads the YCB mesh bounds and targets the rotated bounds center;
///   this is the only policy that reads from `object_dir` and fills `mesh_bounds`.
/// * `ExplicitTarget` targets the supplied point as-is.
///
/// # Errors
/// Propagates mesh loading errors for `MeshCenter`.
pub fn generate_targeted_viewpoints(
    object_dir: &Path,
    config: &ViewpointConfig,
    object_rotation: &ObjectRotation,
    policy: &TargetingPolicy,
) -> Result<TargetedViewpoints, RenderError> {
    let (target_point, mesh_bounds) = match policy {
        TargetingPolicy::Origin => (Vec3::ZERO, None),
        TargetingPolicy::MeshCenter => {
            let bounds = load_ycb_mesh_bounds(object_dir)?;
            let center = rotated_mesh_center(bounds.center, object_rotation);
            (center, Some(bounds))
        }
        TargetingPolicy::ExplicitTarget(target) => (Vec3::from_array(*target), None),
    };

    let viewpoints = generate_viewpoints_around_target(config, target_point);
    Ok(TargetedViewpoints {
        policy: policy.clone(),
        target_point,
        mesh_bounds,
        viewpoints,
    })
}
/// Marker component with no data.
/// NOTE(review): presumably tags the entity being captured — confirm in the render module.
#[derive(Component)]
pub struct CaptureTarget;
/// Marker component with no data.
/// NOTE(review): presumably tags the capture camera entity — confirm in the render module.
#[derive(Component)]
pub struct CaptureCamera;
/// Image size, camera and lighting parameters for a single render.
#[derive(Clone, Debug, PartialEq)]
pub struct RenderConfig {
/// Image width in pixels.
pub width: u32,
/// Image height in pixels.
pub height: u32,
/// Zoom factor; `fov_radians` and `intrinsics_for_size` divide the tangent of half
/// the 90 degree base field of view by this value.
pub zoom: f32,
/// Near clipping distance.
pub near_plane: f32,
/// Far clipping distance.
pub far_plane: f32,
/// Scene lighting parameters.
pub lighting: LightingConfig,
}
/// Ambient, key and fill light parameters for a render.
#[derive(Clone, Debug, PartialEq)]
pub struct LightingConfig {
    /// Ambient light brightness.
    pub ambient_brightness: f32,
    /// Intensity of the key light.
    pub key_light_intensity: f32,
    /// Position of the key light as `[x, y, z]`.
    pub key_light_position: [f32; 3],
    /// Intensity of the fill light.
    pub fill_light_intensity: f32,
    /// Position of the fill light as `[x, y, z]`.
    pub fill_light_position: [f32; 3],
    /// Whether shadows are enabled.
    pub shadows_enabled: bool,
}

impl Default for LightingConfig {
    /// Baseline preset: ambient 0.3, key 1500 at `[4, 8, 4]`, fill 500 at `[-4, 2, -4]`,
    /// shadows off.
    fn default() -> Self {
        Self {
            ambient_brightness: 0.3,
            key_light_intensity: 1500.0,
            key_light_position: [4.0, 8.0, 4.0],
            fill_light_intensity: 500.0,
            fill_light_position: [-4.0, 2.0, -4.0],
            shadows_enabled: false,
        }
    }
}

impl LightingConfig {
    /// Brighter preset: same light positions as the default, higher intensities.
    pub fn bright() -> Self {
        Self {
            ambient_brightness: 0.5,
            key_light_intensity: 2000.0,
            fill_light_intensity: 800.0,
            ..Self::default()
        }
    }

    /// Softer preset: weaker key light, stronger fill, lights moved closer.
    pub fn soft() -> Self {
        Self {
            ambient_brightness: 0.4,
            key_light_intensity: 1000.0,
            key_light_position: [3.0, 6.0, 3.0],
            fill_light_intensity: 600.0,
            fill_light_position: [-3.0, 3.0, -3.0],
            ..Self::default()
        }
    }

    /// Ambient-only preset: full ambient brightness, both directional lights at zero.
    pub fn unlit() -> Self {
        Self {
            ambient_brightness: 1.0,
            key_light_intensity: 0.0,
            key_light_position: [0.0; 3],
            fill_light_intensity: 0.0,
            fill_light_position: [0.0; 3],
            ..Self::default()
        }
    }
}
impl Default for RenderConfig {
    /// Same as `RenderConfig::tbp_default`.
    fn default() -> Self {
        Self::tbp_default()
    }
}

impl RenderConfig {
    /// Shared constructor for the square presets: they differ only in size and zoom.
    fn square(side: u32, zoom: f32) -> Self {
        Self {
            width: side,
            height: side,
            zoom,
            near_plane: 0.01,
            far_plane: 10.0,
            lighting: LightingConfig::default(),
        }
    }

    /// 64x64 image with 4x zoom.
    pub fn tbp_default() -> Self {
        Self::square(64, 4.0)
    }

    /// 256x256 image with no zoom.
    pub fn preview() -> Self {
        Self::square(256, 1.0)
    }

    /// 512x512 image with no zoom.
    pub fn high_res() -> Self {
        Self::square(512, 1.0)
    }

    /// Field of view in radians: a 90 degree base field of view narrowed by `zoom`
    /// (the tangent of the half angle is divided by `zoom`).
    pub fn fov_radians(&self) -> f32 {
        let base_half_angle = 90.0_f32.to_radians() / 2.0;
        let zoomed_half_tan = base_half_angle.tan() / self.zoom;
        2.0 * zoomed_half_tan.atan()
    }

    /// Intrinsics for the configured image size.
    pub fn intrinsics(&self) -> CameraIntrinsics {
        self.intrinsics_for_size(self.width, self.height)
    }

    /// Pinhole intrinsics for an image of `width` x `height` pixels.
    ///
    /// The focal length is derived from the width and used for both axes; the
    /// principal point is the image center.
    pub fn intrinsics_for_size(&self, width: u32, height: u32) -> CameraIntrinsics {
        let half_width = f64::from(width) / 2.0;
        let half_height = f64::from(height) / 2.0;

        let base_half_angle = 90.0_f64.to_radians() / 2.0;
        let zoomed_half_tan = base_half_angle.tan() / f64::from(self.zoom);
        let focal = half_width / zoomed_half_tan;

        CameraIntrinsics {
            focal_length: [focal, focal],
            principal_point: [half_width, half_height],
            image_size: [width, height],
        }
    }
}
/// Pinhole camera intrinsics in pixel units.
#[derive(Clone, Debug, PartialEq)]
pub struct CameraIntrinsics {
/// Focal lengths `[fx, fy]`.
pub focal_length: [f64; 2],
/// Principal point `[cx, cy]`.
pub principal_point: [f64; 2],
/// Image size `[width, height]`.
pub image_size: [u32; 2],
}
impl CameraIntrinsics {
    /// Projects a camera-space point to pixel coordinates.
    ///
    /// Returns `None` when `point.z <= 0.0`. This method treats +Z as the viewing
    /// direction and applies no axis flips.
    pub fn project(&self, point: Vec3) -> Option<[f64; 2]> {
        if point.z <= 0.0 {
            return None;
        }
        let [fx, fy] = self.focal_length;
        let [cx, cy] = self.principal_point;
        let depth = f64::from(point.z);

        let u = (f64::from(point.x) / depth) * fx + cx;
        let v = (f64::from(point.y) / depth) * fy + cy;
        Some([u, v])
    }

    /// Inverse of `project`: lifts `pixel` at the given `depth` to `[x, y, depth]`.
    pub fn unproject(&self, pixel: [f64; 2], depth: f64) -> [f64; 3] {
        let [u, v] = pixel;
        let [fx, fy] = self.focal_length;
        let [cx, cy] = self.principal_point;

        let x = (u - cx) / fx * depth;
        let y = (v - cy) / fy * depth;
        [x, y, depth]
    }
}
/// Diagnostic summary of a depth image, produced by `RenderOutput::health_with_far_plane`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RenderHealth {
/// Pixel used as the image center; `None` for an empty image.
pub center_pixel: Option<[u32; 2]>,
/// Raw (unfiltered) depth at the center pixel.
pub center_depth: Option<f64>,
/// Whether the center depth counts as foreground.
pub center_foreground: bool,
/// Number of foreground pixels in the whole image.
pub foreground_pixel_count: usize,
/// `foreground_pixel_count` divided by the total pixel count (0.0 for an empty image).
pub foreground_coverage: f64,
/// Foreground pixels within 2 pixels of the center on both axes (a 5x5 window).
pub center_5x5_foreground_count: usize,
/// Foreground pixel closest to the center pixel.
pub nearest_foreground_pixel: Option<[u32; 2]>,
/// Depth at `nearest_foreground_pixel`.
pub nearest_foreground_depth: Option<f64>,
/// Euclidean pixel distance from the center to `nearest_foreground_pixel`.
pub nearest_foreground_distance_px: Option<f64>,
}
/// Result of one render: color and depth buffers plus the camera and object state
/// they were rendered with.
#[derive(Clone, Debug)]
pub struct RenderOutput {
/// Color buffer, 4 bytes per pixel, row-major (`get_rgba` reads at `(y * width + x) * 4`).
pub rgba: Vec<u8>,
/// Depth buffer, one value per pixel, row-major (`get_depth` reads at `y * width + x`).
/// NOTE(review): presumably meters, given `TBP_FAR_PLANE_METERS` — confirm in the render module.
pub depth: Vec<f64>,
/// Image width in pixels.
pub width: u32,
/// Image height in pixels.
pub height: u32,
/// Camera intrinsics for this image.
pub intrinsics: CameraIntrinsics,
/// World transform of the camera.
pub camera_transform: Transform,
/// Rotation applied to the object.
pub object_rotation: ObjectRotation,
/// Translation applied to the object.
pub object_translation: Vec3,
/// Scale applied to the object.
pub object_scale: Vec3,
/// Point the camera was aimed at (set via `with_targeting`).
pub target_point: Vec3,
/// Policy that selected `target_point` (set via `with_targeting`).
pub targeting_policy: TargetingPolicy,
}
/// Builds one `[x, y, z, semantic_id]` row per pixel, in row-major order.
///
/// Foreground pixels get their world-space surface point and `object_semantic_id`.
/// Pixels that are background, fall outside the `depth` slice, or cannot be
/// unprojected get an all-zero row.
///
/// The pixel index is computed in `usize` (consistent with `total_pixels`), so very
/// large images cannot overflow `u32` arithmetic.
pub(crate) fn semantic_3d_from_depth(
    depth: &[f64],
    width: u32,
    height: u32,
    intrinsics: &CameraIntrinsics,
    camera_transform: Transform,
    object_semantic_id: u32,
    far_plane: f64,
) -> Vec<[f64; 4]> {
    const BACKGROUND: [f64; 4] = [0.0, 0.0, 0.0, 0.0];

    let row_stride = width as usize;
    let total_pixels = row_stride.saturating_mul(height as usize);
    let mut rows = Vec::with_capacity(total_pixels);

    for y in 0..height {
        // Saturating math: an index that would overflow simply misses the slice below.
        let row_start = (y as usize).saturating_mul(row_stride);
        for x in 0..width {
            let idx = row_start.saturating_add(x as usize);
            let world = depth.get(idx).and_then(|&pixel_depth| {
                pixel_surface_point_world_from_parts(
                    pixel_depth,
                    [x, y],
                    intrinsics,
                    camera_transform,
                    far_plane,
                )
            });
            rows.push(match world {
                Some([wx, wy, wz]) => [wx, wy, wz, object_semantic_id as f64],
                None => BACKGROUND,
            });
        }
    }

    rows
}
/// Converts a pixel and its depth into a world-space point.
///
/// Returns `None` when the depth is not foreground or when either focal length is
/// non-finite or effectively zero. The camera-space point uses +X right, +Y up and
/// looks down -Z, so the image row offset is negated and depth becomes `-z`.
fn pixel_surface_point_world_from_parts(
    depth: f64,
    pixel: [u32; 2],
    intrinsics: &CameraIntrinsics,
    camera_transform: Transform,
    far_plane: f64,
) -> Option<[f64; 3]> {
    if !RenderOutput::is_foreground_depth(depth, far_plane) {
        return None;
    }

    let [fx, fy] = intrinsics.focal_length;
    let usable = |focal: f64| focal.is_finite() && focal.abs() > f64::EPSILON;
    if !(usable(fx) && usable(fy)) {
        return None;
    }

    let [cx, cy] = intrinsics.principal_point;
    let [col, row] = pixel;
    let right = (col as f64 - cx) / fx * depth;
    let up = -((row as f64 - cy) / fy * depth);

    let camera_point = Vec3::new(right as f32, up as f32, (-depth) as f32);
    let world = camera_transform.translation + camera_transform.rotation * camera_point;
    Some([world.x as f64, world.y as f64, world.z as f64])
}
impl RenderOutput {
    /// Far plane used by the convenience methods that take no explicit `far_plane`.
    pub const TBP_FAR_PLANE_METERS: f64 = 10.0;

    /// Returns `self` with the targeting metadata replaced.
    pub fn with_targeting(mut self, target_point: Vec3, targeting_policy: TargetingPolicy) -> Self {
        self.target_point = target_point;
        self.targeting_policy = targeting_policy;
        self
    }

    /// Returns `self` with the object translation and scale metadata replaced.
    pub fn with_object_transform(mut self, object_translation: Vec3, object_scale: Vec3) -> Self {
        self.object_translation = object_translation;
        self.object_scale = object_scale;
        self
    }

    /// Row-major pixel index for `(x, y)`, or `None` when the pixel is outside the
    /// image. Computed in `usize` with checked math so it cannot wrap.
    fn pixel_index(&self, x: u32, y: u32) -> Option<usize> {
        if x >= self.width || y >= self.height {
            return None;
        }
        (y as usize)
            .checked_mul(self.width as usize)?
            .checked_add(x as usize)
    }

    /// RGBA value at `(x, y)`.
    ///
    /// Returns `None` when the pixel is outside the image or when the `rgba` buffer
    /// is too short to contain it (instead of panicking on an inconsistent buffer).
    pub fn get_rgba(&self, x: u32, y: u32) -> Option<[u8; 4]> {
        let start = self.pixel_index(x, y)?.checked_mul(4)?;
        let end = start.checked_add(4)?;
        let px = self.rgba.get(start..end)?;
        Some([px[0], px[1], px[2], px[3]])
    }

    /// Depth value at `(x, y)`.
    ///
    /// Returns `None` when the pixel is outside the image or when the `depth` buffer
    /// is too short to contain it (instead of panicking on an inconsistent buffer).
    pub fn get_depth(&self, x: u32, y: u32) -> Option<f64> {
        let idx = self.pixel_index(x, y)?;
        self.depth.get(idx).copied()
    }

    /// RGB value at `(x, y)` (alpha dropped); `None` under the same conditions as `get_rgba`.
    pub fn get_rgb(&self, x: u32, y: u32) -> Option<[u8; 3]> {
        self.get_rgba(x, y).map(|rgba| [rgba[0], rgba[1], rgba[2]])
    }

    /// Pixel nearest to the principal point, clamped into the image.
    ///
    /// Returns `None` when the image has zero width or height.
    pub fn center_pixel(&self) -> Option<[u32; 2]> {
        if self.width == 0 || self.height == 0 {
            return None;
        }
        let x = self.intrinsics.principal_point[0]
            .round()
            .clamp(0.0, (self.width - 1) as f64) as u32;
        let y = self.intrinsics.principal_point[1]
            .round()
            .clamp(0.0, (self.height - 1) as f64) as u32;
        Some([x, y])
    }

    /// Depth at the center pixel with no foreground filtering.
    pub fn center_pixel_raw_depth(&self) -> Option<f64> {
        let [x, y] = self.center_pixel()?;
        self.get_depth(x, y)
    }

    /// Depth at the center pixel if it is foreground under the default far plane.
    pub fn center_pixel_depth(&self) -> Option<f64> {
        self.center_pixel_depth_with_far_plane(Self::TBP_FAR_PLANE_METERS)
    }

    /// Depth at the center pixel if it is foreground under `far_plane`.
    pub fn center_pixel_depth_with_far_plane(&self, far_plane: f64) -> Option<f64> {
        self.center_pixel_raw_depth()
            .filter(|depth| Self::is_foreground_depth(*depth, far_plane))
    }

    /// A depth is foreground when it is finite, positive and strictly closer than
    /// 99.9% of a finite `far_plane`.
    pub fn is_foreground_depth(depth: f64, far_plane: f64) -> bool {
        depth.is_finite() && depth > 0.0 && far_plane.is_finite() && depth < far_plane * 0.999
    }

    /// Health summary using the default far plane.
    pub fn health(&self) -> RenderHealth {
        self.health_with_far_plane(Self::TBP_FAR_PLANE_METERS)
    }

    /// Scans the depth image and reports foreground coverage and how close the
    /// foreground is to the image center.
    pub fn health_with_far_plane(&self, far_plane: f64) -> RenderHealth {
        let center_pixel = self.center_pixel();
        let center_depth = self.center_pixel_raw_depth();
        let center_foreground = center_depth
            .map(|depth| Self::is_foreground_depth(depth, far_plane))
            .unwrap_or(false);
        let total_pixels = (self.width as usize).saturating_mul(self.height as usize);

        let mut foreground_pixel_count = 0usize;
        let mut center_5x5_foreground_count = 0usize;
        let mut nearest_foreground_pixel = None;
        let mut nearest_foreground_depth = None;
        let mut nearest_foreground_distance_px: Option<f64> = None;

        for y in 0..self.height {
            for x in 0..self.width {
                // Pixels missing from the buffer are treated as background.
                let Some(depth) = self.get_depth(x, y) else {
                    continue;
                };
                if !Self::is_foreground_depth(depth, far_plane) {
                    continue;
                }
                foreground_pixel_count += 1;

                if let Some([cx, cy]) = center_pixel {
                    let dx = x as i64 - cx as i64;
                    let dy = y as i64 - cy as i64;
                    if dx.abs() <= 2 && dy.abs() <= 2 {
                        center_5x5_foreground_count += 1;
                    }
                    let distance = ((dx * dx + dy * dy) as f64).sqrt();
                    // Strict `<` keeps the first pixel found at the minimum distance.
                    if nearest_foreground_distance_px
                        .map(|current| distance < current)
                        .unwrap_or(true)
                    {
                        nearest_foreground_pixel = Some([x, y]);
                        nearest_foreground_depth = Some(depth);
                        nearest_foreground_distance_px = Some(distance);
                    }
                }
            }
        }

        RenderHealth {
            center_pixel,
            center_depth,
            center_foreground,
            foreground_pixel_count,
            foreground_coverage: if total_pixels > 0 {
                foreground_pixel_count as f64 / total_pixels as f64
            } else {
                0.0
            },
            center_5x5_foreground_count,
            nearest_foreground_pixel,
            nearest_foreground_depth,
            nearest_foreground_distance_px,
        }
    }

    /// Transforms a camera-space point to world space (rotate, then translate).
    /// The math is done in `f32`.
    pub fn camera_to_world_point(&self, camera_point: [f64; 3]) -> [f64; 3] {
        let point = Vec3::new(
            camera_point[0] as f32,
            camera_point[1] as f32,
            camera_point[2] as f32,
        );
        let rotated = self.camera_transform.rotation * point;
        let translated = self.camera_transform.translation + rotated;
        [
            translated.x as f64,
            translated.y as f64,
            translated.z as f64,
        ]
    }

    /// Transforms a world-space point to camera space (inverse of
    /// `camera_to_world_point`). The math is done in `f32`.
    pub fn world_to_camera_point(&self, world_point: [f64; 3]) -> [f64; 3] {
        let point = Vec3::new(
            world_point[0] as f32,
            world_point[1] as f32,
            world_point[2] as f32,
        );
        let relative = point - self.camera_transform.translation;
        let camera_point = self.camera_transform.rotation.inverse() * relative;
        [
            camera_point.x as f64,
            camera_point.y as f64,
            camera_point.z as f64,
        ]
    }

    /// World-space surface point under the center pixel, using the default far plane.
    pub fn center_surface_point_world(&self) -> Option<[f64; 3]> {
        self.center_surface_point_world_with_far_plane(Self::TBP_FAR_PLANE_METERS)
    }

    /// World-space surface point under the center pixel; `None` if it is background.
    pub fn center_surface_point_world_with_far_plane(&self, far_plane: f64) -> Option<[f64; 3]> {
        let center = self.center_pixel()?;
        self.pixel_surface_point_world_with_far_plane(center, far_plane)
    }

    /// World-space surface point under `pixel`, using the default far plane.
    pub fn pixel_surface_point_world(&self, pixel: [u32; 2]) -> Option<[f64; 3]> {
        self.pixel_surface_point_world_with_far_plane(pixel, Self::TBP_FAR_PLANE_METERS)
    }

    /// World-space surface point under `pixel`.
    ///
    /// Returns `None` when the pixel is outside the image, has no depth sample, is
    /// background, or the intrinsics are unusable.
    pub fn pixel_surface_point_world_with_far_plane(
        &self,
        pixel: [u32; 2],
        far_plane: f64,
    ) -> Option<[f64; 3]> {
        let [x, y] = pixel;
        let depth = self.get_depth(x, y)?;
        pixel_surface_point_world_from_parts(
            depth,
            pixel,
            &self.intrinsics,
            self.camera_transform,
            far_plane,
        )
    }

    /// Per-pixel `[x, y, z, semantic_id]` rows, using the default far plane.
    pub fn semantic_3d(&self, object_semantic_id: u32) -> Vec<[f64; 4]> {
        self.semantic_3d_with_far_plane(object_semantic_id, Self::TBP_FAR_PLANE_METERS)
    }

    /// Per-pixel `[x, y, z, semantic_id]` rows; see `semantic_3d_from_depth`.
    pub fn semantic_3d_with_far_plane(
        &self,
        object_semantic_id: u32,
        far_plane: f64,
    ) -> Vec<[f64; 4]> {
        semantic_3d_from_depth(
            &self.depth,
            self.width,
            self.height,
            &self.intrinsics,
            self.camera_transform,
            object_semantic_id,
            far_plane,
        )
    }

    /// Color image as rows of RGB triples; missing pixels become `[0, 0, 0]`.
    pub fn to_rgb_image(&self) -> Vec<Vec<[u8; 3]>> {
        (0..self.height)
            .map(|y| {
                (0..self.width)
                    .map(|x| self.get_rgb(x, y).unwrap_or([0, 0, 0]))
                    .collect()
            })
            .collect()
    }

    /// Depth image as rows of depth values; missing pixels become `0.0`.
    pub fn to_depth_image(&self) -> Vec<Vec<f64>> {
        (0..self.height)
            .map(|y| {
                (0..self.width)
                    .map(|x| self.get_depth(x, y).unwrap_or(0.0))
                    .collect()
            })
            .collect()
    }
}
/// Errors reported by the rendering API.
#[derive(Debug, Clone)]
pub enum RenderError {
    /// A mesh file does not exist; carries the path.
    MeshNotFound(String),
    /// A texture file does not exist; carries the path.
    TextureNotFound(String),
    /// A file could not be found.
    FileNotFound { path: String, reason: String },
    /// A file could not be written.
    FileWriteFailed { path: String, reason: String },
    /// A directory could not be created.
    DirectoryCreationFailed { path: String, reason: String },
    /// The render itself failed.
    RenderFailed(String),
    /// A configuration value is invalid.
    InvalidConfig(String),
    /// An input value is invalid.
    InvalidInput(String),
    /// Serialization failed.
    SerializationError(String),
    /// Input data could not be parsed.
    DataParsingError(String),
    /// The render did not finish in time.
    RenderTimeout { duration_secs: u64 },
}

impl std::fmt::Display for RenderError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MeshNotFound(location) => write!(f, "Mesh not found: {}", location),
            Self::TextureNotFound(location) => write!(f, "Texture not found: {}", location),
            Self::FileNotFound { path: location, reason: why } => {
                write!(f, "File not found at {}: {}", location, why)
            }
            Self::FileWriteFailed { path: location, reason: why } => {
                write!(f, "Failed to write file {}: {}", location, why)
            }
            Self::DirectoryCreationFailed { path: location, reason: why } => {
                write!(f, "Failed to create directory {}: {}", location, why)
            }
            Self::RenderFailed(detail) => write!(f, "Render failed: {}", detail),
            Self::InvalidConfig(detail) => write!(f, "Invalid config: {}", detail),
            Self::InvalidInput(detail) => write!(f, "Invalid input: {}", detail),
            Self::SerializationError(detail) => write!(f, "Serialization error: {}", detail),
            Self::DataParsingError(detail) => write!(f, "Data parsing error: {}", detail),
            Self::RenderTimeout { duration_secs: seconds } => {
                write!(f, "Render timeout after {} seconds", seconds)
            }
        }
    }
}

impl std::error::Error for RenderError {}
/// Renders the object in `object_dir` from `camera_transform` with no object
/// translation (`Vec3::ZERO`) and unit scale (`Vec3::ONE`).
///
/// # Errors
/// Propagates any error from the headless renderer.
pub fn render_to_buffer(
    object_dir: &Path,
    camera_transform: &Transform,
    object_rotation: &ObjectRotation,
    config: &RenderConfig,
) -> Result<RenderOutput, RenderError> {
    let no_translation = Vec3::ZERO;
    let unit_scale = Vec3::ONE;
    render::render_headless(
        object_dir,
        camera_transform,
        object_rotation,
        no_translation,
        unit_scale,
        config,
    )
}
/// Renders the object in `object_dir` from `camera_transform`, applying the given
/// object rotation, translation and scale.
///
/// This forwards every argument unchanged to `render::render_headless`.
///
/// # Errors
/// Propagates any error from the headless renderer.
pub fn render_to_buffer_with_object_transform(
object_dir: &Path,
camera_transform: &Transform,
object_rotation: &ObjectRotation,
object_translation: Vec3,
object_scale: Vec3,
config: &RenderConfig,
) -> Result<RenderOutput, RenderError> {
render::render_headless(
object_dir,
camera_transform,
object_rotation,
object_translation,
object_scale,
config,
)
}
/// Like `render_to_buffer`, then records `target_point` and `targeting_policy` on the
/// output. The targeting values are metadata only; they do not move the camera.
///
/// # Errors
/// Propagates any error from `render_to_buffer`.
pub fn render_to_buffer_with_target(
    object_dir: &Path,
    camera_transform: &Transform,
    object_rotation: &ObjectRotation,
    config: &RenderConfig,
    target_point: Vec3,
    targeting_policy: TargetingPolicy,
) -> Result<RenderOutput, RenderError> {
    let rendered = render_to_buffer(object_dir, camera_transform, object_rotation, config)?;
    Ok(rendered.with_targeting(target_point, targeting_policy))
}
/// Like `render_to_buffer_with_object_transform`, then records `target_point` and
/// `targeting_policy` on the output as metadata.
///
/// # Errors
/// Propagates any error from `render_to_buffer_with_object_transform`.
// The argument list mirrors the underlying renderer call plus the targeting metadata.
#[allow(clippy::too_many_arguments)]
pub fn render_to_buffer_with_target_and_object_transform(
    object_dir: &Path,
    camera_transform: &Transform,
    object_rotation: &ObjectRotation,
    object_translation: Vec3,
    object_scale: Vec3,
    config: &RenderConfig,
    target_point: Vec3,
    targeting_policy: TargetingPolicy,
) -> Result<RenderOutput, RenderError> {
    let rendered = render_to_buffer_with_object_transform(
        object_dir,
        camera_transform,
        object_rotation,
        object_translation,
        object_scale,
        config,
    )?;
    Ok(rendered.with_targeting(target_point, targeting_policy))
}
/// Renders every origin-centered viewpoint for every rotation.
///
/// Outputs are ordered rotation-major: all viewpoints for the first rotation, then
/// all viewpoints for the second, and so on.
///
/// # Errors
/// Stops at and returns the first render error.
pub fn render_all_viewpoints(
    object_dir: &Path,
    viewpoint_config: &ViewpointConfig,
    rotations: &[ObjectRotation],
    render_config: &RenderConfig,
) -> Result<Vec<RenderOutput>, RenderError> {
    let cameras = generate_viewpoints(viewpoint_config);
    let mut rendered = Vec::with_capacity(cameras.len() * rotations.len());

    for rotation in rotations.iter() {
        for camera in cameras.iter() {
            rendered.push(render_to_buffer(object_dir, camera, rotation, render_config)?);
        }
    }

    Ok(rendered)
}
/// Outcome of `validate_center_hits` for one object: one entry per tested rotation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CenterHitValidationReport {
    /// Identifier of the validated object.
    pub object_id: String,
    /// Object directory, as a display string.
    pub object_dir: String,
    /// Targeting policy used for every rotation.
    pub target_policy: TargetingPolicy,
    /// Per-rotation results, in input order.
    pub rotations: Vec<CenterHitRotationReport>,
}

impl CenterHitValidationReport {
    /// `true` when no rotation has zero center hits (vacuously `true` with no rotations).
    pub fn is_valid(&self) -> bool {
        !self.rotations.iter().any(|report| report.center_hits == 0)
    }

    /// Indices of the rotations for which no viewpoint hit the object at the center pixel.
    pub fn zero_hit_rotations(&self) -> Vec<usize> {
        let mut indices = Vec::new();
        for report in &self.rotations {
            if report.center_hits == 0 {
                indices.push(report.rotation_index);
            }
        }
        indices
    }
}
/// Center-hit statistics for one object rotation, as filled in by `validate_center_hits`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CenterHitRotationReport {
/// Position of the rotation in the input slice.
pub rotation_index: usize,
/// Rotation as `[pitch, yaw, roll]`.
pub rotation_euler: [f64; 3],
/// Point the viewpoints were aimed at.
pub target_point: [f32; 3],
/// Mesh bounds, when the targeting policy loaded them.
pub mesh_bounds: Option<MeshBoundsMetadata>,
/// Number of render outputs examined.
pub total_viewpoints: usize,
/// Viewpoints whose center pixel was foreground.
pub center_hits: usize,
/// Viewpoints whose center pixel was not foreground.
pub center_misses: usize,
/// Details for every miss.
pub misses: Vec<CenterHitMiss>,
}
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct MeshBoundsMetadata {
pub min: [f32; 3],
pub max: [f32; 3],
pub center: [f32; 3],
pub vertex_count: usize,
}
impl From<MeshBounds> for MeshBoundsMetadata {
fn from(bounds: MeshBounds) -> Self {
Self {
min: bounds.min.to_array(),
max: bounds.max.to_array(),
center: bounds.center.to_array(),
vertex_count: bounds.vertex_count,
}
}
}
/// One viewpoint whose center pixel did not land on the object.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CenterHitMiss {
/// Index of the viewpoint within its rotation's batch.
pub viewpoint_index: usize,
/// Camera translation as `[x, y, z]`.
pub camera_position: [f32; 3],
/// Camera rotation quaternion as `[x, y, z, w]`.
pub camera_rotation_xyzw: [f32; 4],
/// Health summary of the render that missed.
pub health: RenderHealth,
}
pub fn validate_center_hits(
object_id: impl Into<String>,
object_dir: &Path,
viewpoint_config: &ViewpointConfig,
rotations: &[ObjectRotation],
render_config: &RenderConfig,
target_policy: &TargetingPolicy,
) -> Result<CenterHitValidationReport, RenderError> {
let object_id = object_id.into();
let mut rotation_reports = Vec::with_capacity(rotations.len());
for (rotation_index, rotation) in rotations.iter().enumerate() {
let targeted =
generate_targeted_viewpoints(object_dir, viewpoint_config, rotation, target_policy)?;
let requests: Vec<batch::BatchRenderRequest> = targeted
.viewpoints
.iter()
.map(|viewpoint| batch::BatchRenderRequest {
object_dir: PathBuf::from(object_dir),
viewpoint: *viewpoint,
object_rotation: rotation.clone(),
object_translation: Vec3::ZERO,
object_scale: Vec3::ONE,
render_config: render_config.clone(),
target_point: targeted.target_point,
targeting_policy: target_policy.clone(),
})
.collect();
let outputs = render_batch(requests, &batch::BatchRenderConfig::default())
.map_err(|error| RenderError::RenderFailed(error.to_string()))?;
let mut center_hits = 0usize;
let mut misses = Vec::new();
for (viewpoint_index, output) in outputs.iter().enumerate() {
if output.status != batch::RenderStatus::Success {
return Err(RenderError::RenderFailed(format!(
"Render failed for {} rotation {} viewpoint {}: {:?}",
object_id, rotation_index, viewpoint_index, output.error_message
)));
}
if output.health.center_foreground {
center_hits += 1;
} else {
let t = output.request.viewpoint.translation;
let q = output.request.viewpoint.rotation;
misses.push(CenterHitMiss {
viewpoint_index,
camera_position: [t.x, t.y, t.z],
camera_rotation_xyzw: [q.x, q.y, q.z, q.w],
health: output.health.clone(),
});
}
}
rotation_reports.push(CenterHitRotationReport {
rotation_index,
rotation_euler: [rotation.pitch, rotation.yaw, rotation.roll],
target_point: targeted.target_point.to_array(),
mesh_bounds: targeted.mesh_bounds.map(MeshBoundsMetadata::from),
total_viewpoints: outputs.len(),
center_hits,
center_misses: outputs.len().saturating_sub(center_hits),
misses,
});
}
Ok(CenterHitValidationReport {
object_id,
object_dir: object_dir.display().to_string(),
target_policy: target_policy.clone(),
rotations: rotation_reports,
})
}
/// Cached render with no object translation (`Vec3::ZERO`) and unit scale (`Vec3::ONE`).
///
/// # Errors
/// Propagates any error from `render_to_buffer_cached_with_object_transform`.
pub fn render_to_buffer_cached(
    object_dir: &Path,
    camera_transform: &Transform,
    object_rotation: &ObjectRotation,
    config: &RenderConfig,
    cache: &mut cache::ModelCache,
) -> Result<RenderOutput, RenderError> {
    let no_translation = Vec3::ZERO;
    let unit_scale = Vec3::ONE;
    render_to_buffer_cached_with_object_transform(
        object_dir,
        camera_transform,
        object_rotation,
        no_translation,
        unit_scale,
        config,
        cache,
    )
}
/// Registers the object's mesh and texture paths with `cache`, then renders.
///
/// The registered paths are `google_16k/textured.obj` and
/// `google_16k/texture_map.png` under `object_dir`.
/// NOTE(review): the render call itself does not receive `cache`, so the cache only
/// records the paths here — confirm whether the renderer is meant to consult it.
///
/// # Errors
/// Propagates any error from the headless renderer.
pub fn render_to_buffer_cached_with_object_transform(
    object_dir: &Path,
    camera_transform: &Transform,
    object_rotation: &ObjectRotation,
    object_translation: Vec3,
    object_scale: Vec3,
    config: &RenderConfig,
    cache: &mut cache::ModelCache,
) -> Result<RenderOutput, RenderError> {
    // The paths are not needed afterwards, so hand them to the cache by value
    // instead of cloning them.
    cache.cache_scene(object_dir.join("google_16k/textured.obj"));
    cache.cache_texture(object_dir.join("google_16k/texture_map.png"));
    render::render_headless(
        object_dir,
        camera_transform,
        object_rotation,
        object_translation,
        object_scale,
        config,
    )
}
/// Renders to `rgba_path` and `depth_path` with no object translation (`Vec3::ZERO`)
/// and unit scale (`Vec3::ONE`).
///
/// # Errors
/// Propagates any error from `render_to_files_with_object_transform`.
pub fn render_to_files(
    object_dir: &Path,
    camera_transform: &Transform,
    object_rotation: &ObjectRotation,
    config: &RenderConfig,
    rgba_path: &Path,
    depth_path: &Path,
) -> Result<(), RenderError> {
    let no_translation = Vec3::ZERO;
    let unit_scale = Vec3::ONE;
    render_to_files_with_object_transform(
        object_dir,
        camera_transform,
        object_rotation,
        no_translation,
        unit_scale,
        config,
        rgba_path,
        depth_path,
    )
}
/// Renders the object and writes the results to `rgba_path` and `depth_path`.
///
/// This forwards every argument unchanged to `render::render_to_files`.
///
/// # Errors
/// Propagates any error from `render::render_to_files`.
// The argument list mirrors the underlying renderer call.
#[allow(clippy::too_many_arguments)]
pub fn render_to_files_with_object_transform(
object_dir: &Path,
camera_transform: &Transform,
object_rotation: &ObjectRotation,
object_translation: Vec3,
object_scale: Vec3,
config: &RenderConfig,
rgba_path: &Path,
depth_path: &Path,
) -> Result<(), RenderError> {
render::render_to_files(
object_dir,
camera_transform,
object_rotation,
object_translation,
object_scale,
config,
rgba_path,
depth_path,
)
}
pub use batch::{
BatchRenderConfig, BatchRenderError, BatchRenderOutput, BatchRenderRequest, BatchRenderer,
BatchState, RenderStatus,
};
pub use render::RenderSession;
pub use render::PersistentRenderer;
/// Creates a `BatchRenderer` from a clone of `config`.
///
/// As written this always returns `Ok`; the `Result` return type leaves room for a
/// fallible constructor.
pub fn create_batch_renderer(config: &BatchRenderConfig) -> Result<BatchRenderer, RenderError> {
Ok(BatchRenderer::new(config.clone()))
}
/// Adds `request` to the renderer's queue.
///
/// # Errors
/// Any queueing error is converted to `RenderError::RenderFailed` carrying its message.
pub fn queue_render_request(
    renderer: &mut BatchRenderer,
    request: BatchRenderRequest,
) -> Result<(), RenderError> {
    match renderer.queue_request(request) {
        Ok(queued) => Ok(queued),
        Err(e) => Err(RenderError::RenderFailed(e.to_string())),
    }
}
/// Renders the oldest pending request, if any.
///
/// Returns `Ok(None)` when the queue is empty. On success the output is also stored
/// in `renderer.completed_results` and `renderer.renders_processed` is incremented.
/// `_timeout_ms` is currently unused.
///
/// # Errors
/// Propagates the render error; the failed request has already been removed from
/// the queue at that point.
pub fn render_next_in_batch(
    renderer: &mut BatchRenderer,
    _timeout_ms: u32,
) -> Result<Option<BatchRenderOutput>, RenderError> {
    let Some(request) = renderer.pending_requests.pop_front() else {
        return Ok(None);
    };

    let rendered = render_to_buffer_with_object_transform(
        &request.object_dir,
        &request.viewpoint,
        &request.object_rotation,
        request.object_translation,
        request.object_scale,
        &request.render_config,
    )?;

    let batch_output = BatchRenderOutput::from_render_output(request, rendered);
    renderer.completed_results.push(batch_output.clone());
    renderer.renders_processed += 1;
    Ok(Some(batch_output))
}
pub fn render_batch(
requests: Vec<BatchRenderRequest>,
config: &BatchRenderConfig,
) -> Result<Vec<BatchRenderOutput>, RenderError> {
if requests.is_empty() {
return Ok(Vec::new());
}
if requests.len() > 1 && requests_share_batch_context(&requests) {
let first_request = requests[0].clone();
let viewpoints: Vec<Transform> = requests.iter().map(|request| request.viewpoint).collect();
let outputs = render::render_headless_sequence(
&first_request.object_dir,
&viewpoints,
&first_request.object_rotation,
first_request.object_translation,
first_request.object_scale,
&first_request.render_config,
)?;
return Ok(requests
.into_iter()
.zip(outputs)
.map(|(request, output)| BatchRenderOutput::from_render_output(request, output))
.collect());
}
let mut renderer = create_batch_renderer(config)?;
for request in requests {
queue_render_request(&mut renderer, request)?;
}
let mut results = Vec::new();
while let Some(output) = render_next_in_batch(&mut renderer, config.frame_timeout_ms)? {
results.push(output);
}
Ok(results)
}
/// Reports whether every request uses the same object directory, rotation,
/// translation, scale and render config as the first request.
///
/// Viewpoints are deliberately not compared. An empty slice returns `true`.
fn requests_share_batch_context(requests: &[BatchRenderRequest]) -> bool {
    match requests.first() {
        None => true,
        Some(reference) => !requests.iter().any(|candidate| {
            candidate.object_dir != reference.object_dir
                || candidate.object_rotation != reference.object_rotation
                || candidate.object_translation != reference.object_translation
                || candidate.object_scale != reference.object_scale
                || candidate.render_config != reference.render_config
        }),
    }
}
pub use bevy::prelude::{Quat, Transform, Vec3};
#[cfg(test)]
mod tests {
use super::*;
/// Panics unless `actual` lies within `1e-5` (Euclidean distance) of `expected`.
fn assert_vec3_close(actual: Vec3, expected: Vec3) {
    let distance = (actual - expected).length();
    assert!(
        distance < 1e-5,
        "expected {:?}, got {:?}",
        expected,
        actual
    );
}
/// Panics unless every coordinate of `actual` is within `1e-5` of the
/// matching coordinate of `expected`.
fn assert_point_close(actual: [f64; 3], expected: [f64; 3]) {
    for (axis, (got, want)) in actual.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-5,
            "axis {} expected {:?}, got {:?}",
            axis,
            expected,
            actual
        );
    }
}
/// Builds a `RenderOutput` test fixture around a caller-supplied depth buffer.
///
/// The color buffer is all zeros (`width * height * 4` bytes), the object is
/// untransformed, and targeting is the origin policy.
fn render_output_for_depth(
    width: u32,
    height: u32,
    depth: Vec<f64>,
    intrinsics: CameraIntrinsics,
    camera_transform: Transform,
) -> RenderOutput {
    let rgba_len = (width * height * 4) as usize;
    RenderOutput {
        rgba: vec![0u8; rgba_len],
        depth,
        width,
        height,
        intrinsics,
        camera_transform,
        object_rotation: ObjectRotation::identity(),
        object_translation: Vec3::ZERO,
        object_scale: Vec3::ONE,
        target_point: Vec3::ZERO,
        targeting_policy: TargetingPolicy::Origin,
    }
}
/// The identity rotation has all three Euler angles at zero.
#[test]
fn test_object_rotation_identity() {
    let identity = ObjectRotation::identity();
    assert_eq!(identity.pitch, 0.0);
    assert_eq!(identity.yaw, 0.0);
    assert_eq!(identity.roll, 0.0);
}
/// `from_array` maps the array elements to pitch, yaw, roll in that order.
#[test]
fn test_object_rotation_from_array() {
    let rotation = ObjectRotation::from_array([10.0, 20.0, 30.0]);
    assert_eq!(rotation.pitch, 10.0);
    assert_eq!(rotation.yaw, 20.0);
    assert_eq!(rotation.roll, 30.0);
}
/// Requests that differ only in viewpoint still share a batch context.
#[test]
fn test_requests_share_batch_context_for_homogeneous_batch() {
    let base = BatchRenderRequest {
        object_dir: "/tmp/ycb/003_cracker_box".into(),
        viewpoint: Transform::IDENTITY,
        object_rotation: ObjectRotation::identity(),
        object_translation: Vec3::ZERO,
        object_scale: Vec3::ONE,
        render_config: RenderConfig::tbp_default(),
        target_point: Vec3::ZERO,
        targeting_policy: TargetingPolicy::Origin,
    };
    let moved_camera = BatchRenderRequest {
        viewpoint: Transform::from_xyz(1.0, 0.0, 0.0),
        ..base.clone()
    };

    let batch = [base, moved_camera];
    assert!(requests_share_batch_context(&batch));
}
/// Requests pointing at different object directories do not share a context.
#[test]
fn test_requests_share_batch_context_rejects_mixed_objects() {
    let base = BatchRenderRequest {
        object_dir: "/tmp/ycb/003_cracker_box".into(),
        viewpoint: Transform::IDENTITY,
        object_rotation: ObjectRotation::identity(),
        object_translation: Vec3::ZERO,
        object_scale: Vec3::ONE,
        render_config: RenderConfig::tbp_default(),
        target_point: Vec3::ZERO,
        targeting_policy: TargetingPolicy::Origin,
    };
    let other_object = BatchRenderRequest {
        object_dir: "/tmp/ycb/005_tomato_soup_can".into(),
        ..base.clone()
    };

    let batch = [base, other_object];
    assert!(!requests_share_batch_context(&batch));
}
/// Requests with different object translations do not share a context.
#[test]
fn test_requests_share_batch_context_rejects_mixed_object_translation() {
    let base = BatchRenderRequest {
        object_dir: "/tmp/ycb/003_cracker_box".into(),
        viewpoint: Transform::IDENTITY,
        object_rotation: ObjectRotation::identity(),
        object_translation: Vec3::ZERO,
        object_scale: Vec3::ONE,
        render_config: RenderConfig::tbp_default(),
        target_point: Vec3::ZERO,
        targeting_policy: TargetingPolicy::Origin,
    };
    let shifted = BatchRenderRequest {
        object_translation: Vec3::new(0.1, 0.0, 0.0),
        ..base.clone()
    };

    let batch = [base, shifted];
    assert!(!requests_share_batch_context(&batch));
}
/// Requests with different object scales do not share a context.
#[test]
fn test_requests_share_batch_context_rejects_mixed_object_scale() {
    let base = BatchRenderRequest {
        object_dir: "/tmp/ycb/003_cracker_box".into(),
        viewpoint: Transform::IDENTITY,
        object_rotation: ObjectRotation::identity(),
        object_translation: Vec3::ZERO,
        object_scale: Vec3::ONE,
        render_config: RenderConfig::tbp_default(),
        target_point: Vec3::ZERO,
        targeting_policy: TargetingPolicy::Origin,
    };
    let rescaled = BatchRenderRequest {
        object_scale: Vec3::splat(1.25),
        ..base.clone()
    };

    let batch = [base, rescaled];
    assert!(!requests_share_batch_context(&batch));
}
/// The benchmark rotation set is three yaw steps: 0, 90 and 180 degrees.
#[test]
fn test_tbp_benchmark_rotations() {
    let rotations = ObjectRotation::tbp_benchmark_rotations();
    assert_eq!(rotations.len(), 3);

    let expected = [
        ObjectRotation::from_array([0.0, 0.0, 0.0]),
        ObjectRotation::from_array([0.0, 90.0, 0.0]),
        ObjectRotation::from_array([0.0, 180.0, 0.0]),
    ];
    for (index, want) in expected.iter().enumerate() {
        assert_eq!(&rotations[index], want);
    }
}
/// The known-orientation set contains 14 entries.
#[test]
fn test_tbp_known_orientations_count() {
    let known = ObjectRotation::tbp_known_orientations();
    assert_eq!(known.len(), 14);
}
/// The identity rotation converts to the identity quaternion (w = 1, xyz = 0).
#[test]
fn test_rotation_to_quat() {
    let quat = ObjectRotation::identity().to_quat();
    assert!((quat.w - 1.0).abs() < 0.001);
    for component in [quat.x, quat.y, quat.z] {
        assert!(component.abs() < 0.001);
    }
}
/// A 90 degree yaw gives w and y components of roughly 0.707.
#[test]
fn test_rotation_90_yaw() {
    let quarter_turn = ObjectRotation::new(0.0, 90.0, 0.0).to_quat();
    assert!((quarter_turn.w - 0.707).abs() < 0.01);
    assert!((quarter_turn.y - 0.707).abs() < 0.01);
}
/// Default viewpoint config: radius 0.5, eight yaw steps, three pitch angles.
#[test]
fn test_viewpoint_config_default() {
    let defaults = ViewpointConfig::default();
    assert_eq!(defaults.radius, 0.5);
    assert_eq!(defaults.yaw_count, 8);
    assert_eq!(defaults.pitch_angles_deg.len(), 3);
}
/// The default config reports 24 viewpoints.
#[test]
fn test_viewpoint_count() {
    let count = ViewpointConfig::default().viewpoint_count();
    assert_eq!(count, 24);
}
/// Generating viewpoints from the default config yields 24 transforms.
#[test]
fn test_generate_viewpoints_count() {
    let generated = generate_viewpoints(&ViewpointConfig::default());
    assert_eq!(generated.len(), 24);
}
/// Every default viewpoint sits at the configured radius from the origin.
#[test]
fn test_viewpoints_spherical_radius() {
    let config = ViewpointConfig::default();
    for (i, camera) in generate_viewpoints(&config).iter().enumerate() {
        let actual_radius = camera.translation.length();
        let deviation = (actual_radius - config.radius).abs();
        assert!(
            deviation < 0.001,
            "Viewpoint {} has incorrect radius: {} (expected {})",
            i,
            actual_radius,
            config.radius
        );
    }
}
/// Every default viewpoint's forward axis points at the origin.
#[test]
fn test_viewpoints_looking_at_origin() {
    let cameras = generate_viewpoints(&ViewpointConfig::default());
    for (i, camera) in cameras.iter().enumerate() {
        let to_origin = (Vec3::ZERO - camera.translation).normalize();
        let dot = camera.forward().dot(to_origin);
        assert!(
            dot > 0.99,
            "Viewpoint {} not looking at origin, dot product: {}",
            i,
            dot
        );
    }
}
/// Orbiting an off-origin target keeps the radius and aims at the target.
#[test]
fn test_generate_viewpoints_around_target_preserves_orbit() {
    let orbit = ViewpointConfig {
        radius: 2.0,
        yaw_count: 4,
        pitch_angles_deg: vec![0.0],
    };
    let target = Vec3::new(1.0, -0.5, 0.25);

    let cameras = generate_viewpoints_around_target(&orbit, target);
    assert_eq!(cameras.len(), 4);

    for (i, camera) in cameras.iter().enumerate() {
        let distance = (camera.translation - target).length();
        assert!(
            (distance - orbit.radius).abs() < 1e-5,
            "viewpoint {} has radius {}, expected {}",
            i,
            distance,
            orbit.radius
        );

        let to_target = (target - camera.translation).normalize();
        let alignment = camera.forward().dot(to_target);
        assert!(
            alignment > 0.99,
            "viewpoint {} is not looking at target",
            i
        );
    }
}
/// `generate_viewpoints` matches an explicit orbit around the origin.
#[test]
fn test_generate_viewpoints_keeps_origin_targeting() {
    let single_view = ViewpointConfig {
        radius: 1.0,
        yaw_count: 1,
        pitch_angles_deg: vec![0.0],
    };
    let implicit = generate_viewpoints(&single_view)[0];
    let explicit = generate_viewpoints_around_target(&single_view, Vec3::ZERO)[0];
    assert_vec3_close(implicit.translation, explicit.translation);

    let to_origin = (Vec3::ZERO - implicit.translation).normalize();
    assert!(implicit.forward().dot(to_origin) > 0.99);
}
/// Object-centered viewpoints orbit the rotated mesh center, not the raw one.
#[test]
fn test_object_centered_viewpoints_apply_yaw_rotation_to_target() {
    let single_view = ViewpointConfig {
        radius: 1.0,
        yaw_count: 1,
        pitch_angles_deg: vec![0.0],
    };
    let mesh_center = Vec3::new(0.25, 0.0, 0.0);
    let quarter_yaw = ObjectRotation::new(0.0, 90.0, 0.0);

    // The rotation must actually move the center for the test to mean anything.
    let rotated_center = rotated_mesh_center(mesh_center, &quarter_yaw);
    assert!(rotated_center.distance(mesh_center) > 0.1);

    let origin_view = generate_viewpoints(&single_view)[0];
    let centered_view =
        generate_object_centered_viewpoints(&single_view, mesh_center, &quarter_yaw)[0];
    assert_vec3_close(
        centered_view.translation,
        origin_view.translation + rotated_center,
    );

    let to_target = (rotated_center - centered_view.translation).normalize();
    assert!(centered_view.forward().dot(to_target) > 0.99);
}
/// Bounds are read from `google_16k/textured.obj` under the object directory.
#[test]
fn test_load_ycb_mesh_bounds_from_standard_obj_path() {
    let object_dir = tempfile::tempdir().unwrap();
    let mesh_dir = object_dir.path().join("google_16k");
    std::fs::create_dir_all(&mesh_dir).unwrap();

    // Three vertices and one face.
    let obj_path = mesh_dir.join("textured.obj");
    let obj_source = "v -1.0 -2.0 -3.0\nv 3.0 4.0 5.0\nv 1.0 0.0 2.0\nf 1 2 3\n";
    std::fs::write(obj_path, obj_source).unwrap();

    let bounds = load_ycb_mesh_bounds(object_dir.path()).unwrap();
    assert_eq!(bounds.vertex_count, 3);
    assert_vec3_close(bounds.min, Vec3::new(-1.0, -2.0, -3.0));
    assert_vec3_close(bounds.max, Vec3::new(3.0, 4.0, 5.0));
    assert_vec3_close(bounds.center, Vec3::new(1.0, 1.0, 1.0));
    assert_vec3_close(bounds.extents(), Vec3::new(4.0, 6.0, 8.0));
}
/// Labels use hyphens, JSON uses `mesh_center`, and JSON round-trips.
#[test]
fn test_targeting_policy_serializes_stable_label() {
    assert_eq!(TargetingPolicy::Origin.label(), "origin");
    assert_eq!(TargetingPolicy::MeshCenter.label(), "mesh-center");

    let encoded = serde_json::to_string(&TargetingPolicy::MeshCenter).unwrap();
    assert!(encoded.contains("mesh_center"));

    let decoded: TargetingPolicy = serde_json::from_str(&encoded).unwrap();
    assert_eq!(decoded, TargetingPolicy::MeshCenter);
}
/// `with_targeting` replaces the fixture's origin target and policy.
#[test]
fn test_render_output_with_targeting_overrides_origin_default() {
    let target_point = Vec3::new(0.1, 0.2, -0.3);
    let origin_output = render_output_for_depth(
        1,
        1,
        vec![1.0],
        RenderConfig::tbp_default().intrinsics(),
        Transform::IDENTITY,
    );

    let retargeted = origin_output.with_targeting(target_point, TargetingPolicy::MeshCenter);
    assert_eq!(retargeted.target_point, target_point);
    assert_eq!(retargeted.targeting_policy, TargetingPolicy::MeshCenter);
}
/// A report with a rotation that never hits the center is invalid, and that
/// rotation's index is listed by `zero_hit_rotations`.
#[test]
fn test_center_hit_validation_report_detects_zero_hit_rotation() {
    let one_hit = CenterHitRotationReport {
        rotation_index: 0,
        rotation_euler: [0.0, 0.0, 0.0],
        target_point: [0.0, 0.0, 0.0],
        mesh_bounds: None,
        total_viewpoints: 24,
        center_hits: 1,
        center_misses: 23,
        misses: Vec::new(),
    };
    let no_hits = CenterHitRotationReport {
        rotation_index: 1,
        rotation_euler: [0.0, 90.0, 0.0],
        target_point: [0.1, 0.0, 0.0],
        mesh_bounds: None,
        total_viewpoints: 24,
        center_hits: 0,
        center_misses: 24,
        misses: Vec::new(),
    };
    let report = CenterHitValidationReport {
        object_id: "test_object".to_string(),
        object_dir: "/tmp/ycb/test_object".to_string(),
        target_policy: TargetingPolicy::MeshCenter,
        rotations: vec![one_hit, no_hits],
    };

    assert!(!report.is_valid());
    assert_eq!(report.zero_hit_rotations(), vec![1]);
}
/// Default sensor config: one rotation, 24 captures in total.
#[test]
fn test_sensor_config_default() {
    let sensor = SensorConfig::default();
    assert_eq!(sensor.object_rotations.len(), 1);
    assert_eq!(sensor.total_captures(), 24);
}
/// Benchmark sensor config: three rotations, 72 captures in total.
#[test]
fn test_sensor_config_tbp_benchmark() {
    let sensor = SensorConfig::tbp_benchmark();
    assert_eq!(sensor.object_rotations.len(), 3);
    assert_eq!(sensor.total_captures(), 72);
}
/// Full-training sensor config: 14 rotations, 336 captures in total.
#[test]
fn test_sensor_config_tbp_full() {
    let sensor = SensorConfig::tbp_full_training();
    assert_eq!(sensor.object_rotations.len(), 14);
    assert_eq!(sensor.total_captures(), 336);
}
/// The representative set has three objects, including the cracker box.
#[test]
fn test_ycb_representative_objects() {
    let objects = crate::ycb::REPRESENTATIVE_OBJECTS;
    assert_eq!(objects.len(), 3);
    assert!(objects.contains(&"003_cracker_box"));
}
/// The TBP standard set has ten objects, including the mug.
#[test]
fn test_ycb_tbp_standard_objects() {
    let objects = crate::ycb::TBP_STANDARD_OBJECTS;
    assert_eq!(objects.len(), 10);
    assert!(objects.contains(&"025_mug"));
}
/// The TBP similar set has ten objects, including the cracker box.
#[test]
fn test_ycb_tbp_similar_objects() {
    let objects = crate::ycb::TBP_SIMILAR_OBJECTS;
    assert_eq!(objects.len(), 10);
    assert!(objects.contains(&"003_cracker_box"));
}
/// The mesh path is `<root>/<object>/google_16k/textured.obj`.
#[test]
fn test_ycb_object_mesh_path() {
    let expected: std::path::PathBuf = [
        "/tmp/ycb",
        "003_cracker_box",
        "google_16k",
        "textured.obj",
    ]
    .iter()
    .collect();

    let actual = crate::ycb::object_mesh_path("/tmp/ycb", "003_cracker_box");
    assert_eq!(actual, expected);
}
/// The texture path is `<root>/<object>/google_16k/texture_map.png`.
#[test]
fn test_ycb_object_texture_path() {
    let mut expected = std::path::PathBuf::from("/tmp/ycb");
    expected.push("003_cracker_box");
    expected.push("google_16k");
    expected.push("texture_map.png");

    let actual = crate::ycb::object_texture_path("/tmp/ycb", "003_cracker_box");
    assert_eq!(actual, expected);
}
/// The TBP default is 64x64 with positive zoom and ordered clip planes.
#[test]
fn test_render_config_tbp_default() {
    let tbp = RenderConfig::tbp_default();
    assert_eq!(tbp.width, 64);
    assert_eq!(tbp.height, 64);
    assert!(tbp.zoom > 0.0);
    assert!(tbp.near_plane > 0.0);
    assert!(tbp.far_plane > tbp.near_plane);
}
/// The preview config renders at 256x256.
#[test]
fn test_render_config_preview() {
    let preview = RenderConfig::preview();
    assert_eq!(preview.width, 256);
    assert_eq!(preview.height, 256);
}
/// `Default` yields the same image size as the TBP default.
#[test]
fn test_render_config_default_is_tbp() {
    let from_default = RenderConfig::default();
    let from_tbp = RenderConfig::tbp_default();
    assert_eq!(from_default.width, from_tbp.width);
    assert_eq!(from_default.height, from_tbp.height);
}
/// The field of view lies in (0, PI) and shrinks when zoom doubles.
#[test]
fn test_render_config_fov() {
    let base = RenderConfig::tbp_default();
    let base_fov = base.fov_radians();
    assert!(base_fov > 0.0);
    assert!(base_fov < PI);

    let zoomed_in = RenderConfig {
        zoom: base.zoom * 2.0,
        ..base
    };
    assert!(zoomed_in.fov_radians() < base_fov);
}
/// Intrinsics use the config's image size, a centered principal point and
/// equal, positive focal lengths.
#[test]
fn test_render_config_intrinsics() {
    let config = RenderConfig::tbp_default();
    let intrinsics = config.intrinsics();

    assert_eq!(intrinsics.image_size, [config.width, config.height]);

    let expected_center = [config.width as f64 / 2.0, config.height as f64 / 2.0];
    assert_eq!(intrinsics.principal_point, expected_center);

    assert_eq!(intrinsics.focal_length[0], intrinsics.focal_length[1]);
    assert!(intrinsics.focal_length[0] > 0.0);
}
/// At 64x64 with zoom 4 both focal lengths are 128, not `64 * zoom`.
#[test]
fn test_render_config_intrinsics_for_size_uses_tbp_zoom_formula() {
    let config = RenderConfig {
        width: 64,
        height: 64,
        zoom: 4.0,
        ..RenderConfig::tbp_default()
    };
    let intrinsics = config.intrinsics_for_size(64, 64);
    let [fx, fy] = intrinsics.focal_length;

    assert!((fx - 128.0).abs() < 1e-9);
    assert!((fy - 128.0).abs() < 1e-9);
    assert_ne!(fx, 64.0 * config.zoom as f64);
    assert_eq!(intrinsics.principal_point, [32.0, 32.0]);
    assert_eq!(intrinsics.image_size, [64, 64]);
}
/// Intrinsics follow the size passed in (128x96), not the configured 64x64.
#[test]
fn test_render_config_intrinsics_for_size_tracks_actual_readback_size() {
    let config = RenderConfig {
        width: 64,
        height: 64,
        zoom: 4.0,
        ..RenderConfig::tbp_default()
    };
    let intrinsics = config.intrinsics_for_size(128, 96);
    let [fx, fy] = intrinsics.focal_length;

    assert!((fx - 256.0).abs() < 1e-9);
    assert!((fy - 256.0).abs() < 1e-9);
    assert_eq!(intrinsics.principal_point, [64.0, 48.0]);
    assert_eq!(intrinsics.image_size, [128, 96]);
}
/// A point on the optical axis projects to the principal point; a point with
/// negative z does not project.
#[test]
fn test_camera_intrinsics_project() {
    let intrinsics = CameraIntrinsics {
        focal_length: [100.0, 100.0],
        principal_point: [32.0, 32.0],
        image_size: [64, 64],
    };

    let on_axis = intrinsics.project(Vec3::new(0.0, 0.0, 1.0));
    assert!(on_axis.is_some());
    let [px, py] = on_axis.unwrap();
    assert!((px - 32.0).abs() < 0.001);
    assert!((py - 32.0).abs() < 0.001);

    let behind_camera = intrinsics.project(Vec3::new(0.0, 0.0, -1.0));
    assert!(behind_camera.is_none());
}
/// Unprojecting the principal point at depth 1 gives (0, 0, 1).
#[test]
fn test_camera_intrinsics_unproject() {
    let intrinsics = CameraIntrinsics {
        focal_length: [100.0, 100.0],
        principal_point: [32.0, 32.0],
        image_size: [64, 64],
    };
    let point = intrinsics.unproject([32.0, 32.0], 1.0);

    assert!((point[0]).abs() < 0.001);
    assert!((point[1]).abs() < 0.001);
    assert!((point[2] - 1.0).abs() < 0.001);
}
/// `get_rgba` reads pixels row-major by `(x, y)` and returns `None` when out
/// of range.
#[test]
fn test_render_output_get_rgba() {
    // Red, green / blue, white.
    let pixels = vec![
        255, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255,
    ];
    let output = RenderOutput {
        rgba: pixels,
        ..render_output_for_depth(
            2,
            2,
            vec![1.0, 2.0, 3.0, 4.0],
            RenderConfig::tbp_default().intrinsics(),
            Transform::IDENTITY,
        )
    };

    assert_eq!(output.get_rgba(0, 0), Some([255, 0, 0, 255]));
    assert_eq!(output.get_rgba(1, 0), Some([0, 255, 0, 255]));
    assert_eq!(output.get_rgba(0, 1), Some([0, 0, 255, 255]));
    assert_eq!(output.get_rgba(1, 1), Some([255, 255, 255, 255]));
    assert_eq!(output.get_rgba(2, 0), None);
}
/// `get_depth` reads depths row-major by `(x, y)` and returns `None` when out
/// of range.
#[test]
fn test_render_output_get_depth() {
    let output = render_output_for_depth(
        2,
        2,
        vec![1.0, 2.0, 3.0, 4.0],
        RenderConfig::tbp_default().intrinsics(),
        Transform::IDENTITY,
    );

    assert_eq!(output.get_depth(0, 0), Some(1.0));
    assert_eq!(output.get_depth(1, 0), Some(2.0));
    assert_eq!(output.get_depth(0, 1), Some(3.0));
    assert_eq!(output.get_depth(1, 1), Some(4.0));
    assert_eq!(output.get_depth(2, 0), None);
}
/// `to_rgb_image` yields rows of RGB triples with the alpha channel dropped.
#[test]
fn test_render_output_to_rgb_image() {
    let pixels = vec![
        255, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255,
    ];
    let output = RenderOutput {
        rgba: pixels,
        ..render_output_for_depth(
            2,
            2,
            vec![1.0, 2.0, 3.0, 4.0],
            RenderConfig::tbp_default().intrinsics(),
            Transform::IDENTITY,
        )
    };

    let rows = output.to_rgb_image();
    assert_eq!(rows.len(), 2);
    assert_eq!(rows[0].len(), 2);
    assert_eq!(rows[0][0], [255, 0, 0]);
    assert_eq!(rows[0][1], [0, 255, 0]);
    assert_eq!(rows[1][0], [0, 0, 255]);
    assert_eq!(rows[1][1], [255, 255, 255]);
}
/// `to_depth_image` splits the flat depth buffer into rows.
#[test]
fn test_render_output_to_depth_image() {
    let output = render_output_for_depth(
        2,
        2,
        vec![1.0, 2.0, 3.0, 4.0],
        RenderConfig::tbp_default().intrinsics(),
        Transform::IDENTITY,
    );

    let rows = output.to_depth_image();
    assert_eq!(rows.len(), 2);
    assert_eq!(rows[0], vec![1.0, 2.0]);
    assert_eq!(rows[1], vec![3.0, 4.0]);
}
/// Near pixels carry the semantic id and a 3D point; far or infinite pixels
/// are all zeros.
#[test]
fn test_render_output_semantic_3d_marks_foreground_and_background() {
    let intrinsics = CameraIntrinsics {
        focal_length: [1.0, 1.0],
        principal_point: [0.0, 0.0],
        image_size: [2, 2],
    };
    let depth = vec![0.25, 10.0, 0.5, f64::INFINITY];
    let output = render_output_for_depth(2, 2, depth, intrinsics, Transform::IDENTITY);

    let semantic = output.semantic_3d(42);
    assert_eq!(semantic.len(), 4);
    assert_eq!(semantic[0][3], 42.0);
    assert_eq!(semantic[1], [0.0, 0.0, 0.0, 0.0]);
    assert_eq!(semantic[2][3], 42.0);
    assert_eq!(semantic[3], [0.0, 0.0, 0.0, 0.0]);

    let first_point = [semantic[0][0], semantic[0][1], semantic[0][2]];
    assert_point_close(first_point, [0.0, 0.0, -0.25]);

    let third_point = [semantic[2][0], semantic[2][1], semantic[2][2]];
    assert_point_close(third_point, [0.0, -0.5, -0.5]);
}
/// Semantic points agree with `pixel_surface_point_world` for the same pixels.
#[test]
fn test_render_output_semantic_3d_matches_pixel_surface_points() {
    let intrinsics = CameraIntrinsics {
        focal_length: [1.0, 1.0],
        principal_point: [1.0, 1.0],
        image_size: [3, 3],
    };
    // Foreground at the top-right pixel (index 2) and the center (index 4).
    let depth = vec![10.0, 10.0, 2.0, 10.0, 0.25, 10.0, 10.0, 10.0, 10.0];
    let output = render_output_for_depth(3, 3, depth, intrinsics, Transform::IDENTITY);

    let semantic = output.semantic_3d(3);
    let top_right = output
        .pixel_surface_point_world([2, 0])
        .expect("foreground point");
    let center = output
        .pixel_surface_point_world([1, 1])
        .expect("foreground point");

    let semantic_top_right = [semantic[2][0], semantic[2][1], semantic[2][2]];
    assert_point_close(semantic_top_right, top_right);
    assert_eq!(semantic[2][3], 3.0);

    let semantic_center = [semantic[4][0], semantic[4][1], semantic[4][2]];
    assert_point_close(semantic_center, center);
    assert_eq!(semantic[4][3], 3.0);
}
/// With foreground at the center pixel, health reports a center hit at
/// distance zero.
#[test]
fn test_render_health_center_hit() {
    let side = 7;
    let mut depth = vec![10.0; side * side];
    depth[3 * side + 3] = 0.25; // center pixel
    depth[6 * side + 6] = 0.5; // bottom-right corner

    let intrinsics = CameraIntrinsics {
        focal_length: [10.0, 10.0],
        principal_point: [3.0, 3.0],
        image_size: [7, 7],
    };
    let output = render_output_for_depth(7, 7, depth, intrinsics, Transform::IDENTITY);

    let health = output.health();
    assert_eq!(health.center_pixel, Some([3, 3]));
    assert_eq!(health.center_depth, Some(0.25));
    assert!(health.center_foreground);
    assert_eq!(health.foreground_pixel_count, 2);
    assert!((health.foreground_coverage - 2.0 / 49.0).abs() < 1e-12);
    assert_eq!(health.center_5x5_foreground_count, 1);
    assert_eq!(health.nearest_foreground_pixel, Some([3, 3]));
    assert_eq!(health.nearest_foreground_depth, Some(0.25));
    assert_eq!(health.nearest_foreground_distance_px, Some(0.0));
}
/// With a far center pixel, health falls back to the nearest foreground pixel.
#[test]
fn test_render_health_far_center_uses_nearest_foreground() {
    let side = 7;
    let mut depth = vec![10.0; side * side];
    depth[3 * side + 1] = 0.5; // same row as the center, two pixels to the left

    let intrinsics = CameraIntrinsics {
        focal_length: [10.0, 10.0],
        principal_point: [3.0, 3.0],
        image_size: [7, 7],
    };
    let output = render_output_for_depth(7, 7, depth, intrinsics, Transform::IDENTITY);

    let health = output.health();
    assert_eq!(health.center_pixel, Some([3, 3]));
    assert_eq!(health.center_depth, Some(10.0));
    assert!(!health.center_foreground);
    assert_eq!(health.foreground_pixel_count, 1);
    assert_eq!(health.center_5x5_foreground_count, 1);
    assert_eq!(health.nearest_foreground_pixel, Some([1, 3]));
    assert_eq!(health.nearest_foreground_depth, Some(0.5));
    assert_eq!(health.nearest_foreground_distance_px, Some(2.0));
}
/// With an identity camera, center depth 0.25 maps to world z = -0.25.
#[test]
fn test_center_surface_point_world_uses_bevy_camera_forward() {
    let mut depth = vec![10.0; 3 * 3];
    depth[3 + 1] = 0.25; // center pixel of the 3x3 image

    let intrinsics = CameraIntrinsics {
        focal_length: [1.0, 1.0],
        principal_point: [1.0, 1.0],
        image_size: [3, 3],
    };
    let output = render_output_for_depth(3, 3, depth, intrinsics, Transform::IDENTITY);

    assert_eq!(output.center_pixel_depth(), Some(0.25));
    let surface = output.center_surface_point_world().expect("surface point");
    assert_point_close(surface, [0.0, 0.0, -0.25]);
}
/// The top-right pixel (image y = 0) maps to a positive world y.
#[test]
fn test_pixel_surface_point_world_maps_image_y_down_to_camera_y_up() {
    let mut depth = vec![10.0; 3 * 3];
    depth[2] = 2.0; // top-right pixel

    let intrinsics = CameraIntrinsics {
        focal_length: [1.0, 1.0],
        principal_point: [1.0, 1.0],
        image_size: [3, 3],
    };
    let output = render_output_for_depth(3, 3, depth, intrinsics, Transform::IDENTITY);

    let surface = output
        .pixel_surface_point_world([2, 0])
        .expect("surface point");
    assert_point_close(surface, [2.0, 2.0, -2.0]);
}
/// World-to-camera followed by camera-to-world returns the original point.
#[test]
fn test_camera_world_point_helpers_roundtrip() {
    // Camera one unit up the z axis, looking at the origin.
    let camera = Transform::from_xyz(0.0, 0.0, 1.0).looking_at(Vec3::ZERO, Vec3::Y);
    let intrinsics = CameraIntrinsics {
        focal_length: [1.0, 1.0],
        principal_point: [0.0, 0.0],
        image_size: [1, 1],
    };
    let output = render_output_for_depth(1, 1, vec![0.25], intrinsics, camera);

    let surface = output.center_surface_point_world().expect("surface point");
    assert_point_close(surface, [0.0, 0.0, 0.75]);

    let world_point = [0.1, -0.2, 0.7];
    let in_camera = output.world_to_camera_point(world_point);
    let back_in_world = output.camera_to_world_point(in_camera);
    assert_point_close(back_in_world, world_point);
}
/// The mesh-not-found message names the error and the offending path.
#[test]
fn test_render_error_display() {
    let message = RenderError::MeshNotFound("/path/to/mesh.obj".to_string()).to_string();
    assert!(message.contains("Mesh not found"));
    assert!(message.contains("/path/to/mesh.obj"));
}
/// Angles far outside one turn still produce a unit quaternion.
#[test]
fn test_object_rotation_extreme_angles() {
    let quat = ObjectRotation::new(450.0, -720.0, 1080.0).to_quat();
    let norm_error = (quat.length() - 1.0).abs();
    assert!(norm_error < 0.001);
}
/// `to_transform` keeps zero translation and applies a non-identity rotation.
#[test]
fn test_object_rotation_to_transform() {
    let transform = ObjectRotation::new(45.0, 90.0, 0.0).to_transform();
    assert_eq!(transform.translation, Vec3::ZERO);
    assert!(transform.rotation != Quat::IDENTITY);
}
/// Translation and scale pass through unchanged; rotation matches `to_quat`.
#[test]
fn test_object_rotation_to_transform_with_translation_scale() {
    let rotation = ObjectRotation::new(0.0, 90.0, 0.0);
    let offset = Vec3::new(0.25, -0.5, 1.25);
    let stretch = Vec3::new(1.0, 1.5, 0.75);

    let transform = rotation.to_transform_with_translation_scale(offset, stretch);
    assert_eq!(transform.translation, offset);
    assert_eq!(transform.scale, stretch);
    assert_eq!(transform.rotation, rotation.to_quat());
}
/// One yaw step at zero pitch gives a single camera at (0, 0, radius).
#[test]
fn test_viewpoint_config_single_viewpoint() {
    let single_view = ViewpointConfig {
        radius: 1.0,
        yaw_count: 1,
        pitch_angles_deg: vec![0.0],
    };
    assert_eq!(single_view.viewpoint_count(), 1);

    let cameras = generate_viewpoints(&single_view);
    assert_eq!(cameras.len(), 1);

    let position = cameras[0].translation;
    assert!((position.x).abs() < 0.001);
    assert!((position.y).abs() < 0.001);
    assert!((position.z - 1.0).abs() < 0.001);
}
/// Quadrupling the radius quadruples each camera's distance from the origin.
#[test]
fn test_viewpoint_radius_scaling() {
    let near_config = ViewpointConfig {
        radius: 0.5,
        yaw_count: 4,
        pitch_angles_deg: vec![0.0],
    };
    let far_config = ViewpointConfig {
        radius: 2.0,
        yaw_count: 4,
        pitch_angles_deg: vec![0.0],
    };

    let near_views = generate_viewpoints(&near_config);
    let far_views = generate_viewpoints(&far_config);
    for (near, far) in near_views.iter().zip(far_views.iter()) {
        let ratio = far.translation.length() / near.translation.length();
        assert!((ratio - 4.0).abs() < 0.01);
    }
}
/// A point with z = 0 cannot be projected.
#[test]
fn test_camera_intrinsics_project_at_z_zero() {
    let intrinsics = CameraIntrinsics {
        focal_length: [100.0, 100.0],
        principal_point: [32.0, 32.0],
        image_size: [64, 64],
    };
    let on_camera_plane = Vec3::new(1.0, 1.0, 0.0);
    assert!(intrinsics.project(on_camera_plane).is_none());
}
/// Projecting then unprojecting at the same depth recovers the point.
#[test]
fn test_camera_intrinsics_roundtrip() {
    let intrinsics = CameraIntrinsics {
        focal_length: [100.0, 100.0],
        principal_point: [32.0, 32.0],
        image_size: [64, 64],
    };
    let original = Vec3::new(0.5, -0.3, 2.0);

    let pixel = intrinsics.project(original).unwrap();
    let recovered = intrinsics.unproject(pixel, original.z as f64);

    assert!((recovered[0] - original.x as f64).abs() < 0.001);
    assert!((recovered[1] - original.y as f64).abs() < 0.001);
    assert!((recovered[2] - original.z as f64).abs() < 0.001);
}
/// A 0x0 output has no pixels and produces empty images.
#[test]
fn test_render_output_empty() {
    let output = render_output_for_depth(
        0,
        0,
        Vec::new(),
        RenderConfig::tbp_default().intrinsics(),
        Transform::IDENTITY,
    );

    assert_eq!(output.get_rgba(0, 0), None);
    assert_eq!(output.get_depth(0, 0), None);
    assert!(output.to_rgb_image().is_empty());
    assert!(output.to_depth_image().is_empty());
}
/// A 1x1 output exposes its single pixel through every accessor.
#[test]
fn test_render_output_1x1() {
    let output = RenderOutput {
        rgba: vec![128, 64, 32, 255],
        ..render_output_for_depth(
            1,
            1,
            vec![0.5],
            RenderConfig::tbp_default().intrinsics(),
            Transform::IDENTITY,
        )
    };

    assert_eq!(output.get_rgba(0, 0), Some([128, 64, 32, 255]));
    assert_eq!(output.get_depth(0, 0), Some(0.5));
    assert_eq!(output.get_rgb(0, 0), Some([128, 64, 32]));

    let rows = output.to_rgb_image();
    assert_eq!(rows.len(), 1);
    assert_eq!(rows[0].len(), 1);
    assert_eq!(rows[0][0], [128, 64, 32]);
}
/// The high-res config is 512x512 with a centered principal point.
#[test]
fn test_render_config_high_res() {
    let high_res = RenderConfig::high_res();
    assert_eq!(high_res.width, 512);
    assert_eq!(high_res.height, 512);

    let intrinsics = high_res.intrinsics();
    assert_eq!(intrinsics.image_size, [512, 512]);
    assert_eq!(intrinsics.principal_point, [256.0, 256.0]);
}
/// Doubling zoom halves the tangent of the half field of view.
#[test]
fn test_render_config_zoom_affects_fov() {
    let zoom_two = RenderConfig {
        zoom: 2.0,
        ..RenderConfig::tbp_default()
    };
    let zoom_four = RenderConfig {
        zoom: 4.0,
        ..RenderConfig::tbp_default()
    };
    assert!(zoom_four.fov_radians() < zoom_two.fov_radians());

    let half_tan_two = (zoom_two.fov_radians() / 2.0).tan();
    let half_tan_four = (zoom_four.fov_radians() / 2.0).tan();
    assert!((half_tan_two / half_tan_four - 2.0).abs() < 1e-4);
}
/// Focal length grows with zoom and is proportional to it.
#[test]
fn test_render_config_zoom_affects_intrinsics() {
    let zoom_two = RenderConfig {
        zoom: 2.0,
        ..RenderConfig::tbp_default()
    };
    let zoom_four = RenderConfig {
        zoom: 4.0,
        ..RenderConfig::tbp_default()
    };

    let fx_two = zoom_two.intrinsics().focal_length[0];
    let fx_four = zoom_four.intrinsics().focal_length[0];
    assert!(fx_four > fx_two);

    let per_zoom_two = fx_two / zoom_two.zoom as f64;
    let per_zoom_four = fx_four / zoom_four.zoom as f64;
    assert!((per_zoom_two - per_zoom_four).abs() < 1e-9);
}
/// Lighting presets order by key-light intensity; `unlit` has no directional
/// light and full ambient brightness.
#[test]
fn test_lighting_config_variants() {
    let baseline = LightingConfig::default();

    let bright = LightingConfig::bright();
    assert!(bright.key_light_intensity > baseline.key_light_intensity);

    let unlit = LightingConfig::unlit();
    assert_eq!(unlit.key_light_intensity, 0.0);
    assert_eq!(unlit.fill_light_intensity, 0.0);
    assert_eq!(unlit.ambient_brightness, 1.0);

    let soft = LightingConfig::soft();
    assert!(soft.key_light_intensity < baseline.key_light_intensity);
}
/// Each listed error variant renders a non-empty display message.
#[test]
fn test_all_render_error_variants() {
    let variants = [
        RenderError::MeshNotFound("mesh.obj".to_string()),
        RenderError::TextureNotFound("texture.png".to_string()),
        RenderError::RenderFailed("GPU error".to_string()),
        RenderError::InvalidConfig("bad config".to_string()),
    ];
    for variant in variants {
        assert!(!variant.to_string().is_empty());
    }
}
/// No two known orientations map to the same quaternion (up to sign).
#[test]
fn test_tbp_known_orientations_unique() {
    let quats: Vec<Quat> = ObjectRotation::tbp_known_orientations()
        .iter()
        .map(|rotation| rotation.to_quat())
        .collect();

    // The dot product is symmetric, so each unordered pair is checked once.
    for (i, q1) in quats.iter().enumerate() {
        for (j, q2) in quats.iter().enumerate().skip(i + 1) {
            let dot = q1.dot(*q2).abs();
            assert!(
                dot < 0.999,
                "Orientations {} and {} produce same quaternion",
                i,
                j
            );
        }
    }
}
}