#[cfg(feature = "s3")]
pub mod s3;
use anyhow::{Context, Result};
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::Client;
use serde::Deserialize;
use std::fs::{self, File};
use std::io::Write;
use std::path::Path;
/// Root of the public YCB benchmark bucket; every archive URL is built relative to it.
pub const BASE_URL: &str = "https://ycb-benchmarks.s3.amazonaws.com/data/";
/// JSON listing of all object names; fetched only when [`Subset::All`] is requested.
pub const OBJECTS_URL: &str = "https://ycb-benchmarks.s3.amazonaws.com/data/objects.json";
/// Three-object subset used by the default [`Subset::Representative`].
pub const REPRESENTATIVE_OBJECTS: &[&str] = &[
    "003_cracker_box",
    "004_sugar_box",
    "005_tomato_soup_can",
];
/// Legacy ten-object subset selected by [`Subset::Ten`].
///
/// Kept only for backward compatibility; new code should pick one of the TBP lists.
#[deprecated(
    since = "0.3.0",
    note = "Use TBP_STANDARD_OBJECTS or TBP_SIMILAR_OBJECTS instead"
)]
pub const TEN_OBJECTS: &[&str] = &[
    // Boxes and cans.
    "003_cracker_box",
    "004_sugar_box",
    "005_tomato_soup_can",
    "006_mustard_bottle",
    "007_tuna_fish_can",
    "008_pudding_box",
    "009_gelatin_box",
    "010_potted_meat_can",
    // Remaining items.
    "011_banana",
    "019_pitcher_base",
];
/// Ten-object subset selected by [`Subset::TbpStandard`].
///
/// Order is significant: it is the order in which objects are downloaded.
pub const TBP_STANDARD_OBJECTS: &[&str] = &[
    "025_mug",
    "024_bowl",
    "010_potted_meat_can",
    "031_spoon",
    "012_strawberry",
    "006_mustard_bottle",
    "062_dice",
    "058_golf_ball",
    "073-c_lego_duplo",
    "011_banana",
];
/// Ten-object subset selected by [`Subset::TbpSimilar`].
///
/// Order is significant: it is the order in which objects are downloaded.
pub const TBP_SIMILAR_OBJECTS: &[&str] = &[
    "003_cracker_box",
    "004_sugar_box",
    "009_gelatin_box",
    "021_bleach_cleanser",
    "036_wood_block",
    "039_key",
    "040_large_marker",
    "051_large_clamp",
    "052_extra_large_clamp",
    "061_foam_brick",
];
/// Which group of YCB objects to download.
///
/// Variant order is part of the derived `Ord`, so new variants should be appended.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Default)]
pub enum Subset {
    /// The three objects in `REPRESENTATIVE_OBJECTS` (the default).
    #[default]
    Representative,
    /// The legacy `TEN_OBJECTS` list, which is deprecated.
    #[allow(deprecated)]
    Ten,
    /// The ten objects in `TBP_STANDARD_OBJECTS`.
    TbpStandard,
    /// The ten objects in `TBP_SIMILAR_OBJECTS`.
    TbpSimilar,
    /// Every object listed at `OBJECTS_URL`, fetched over the network at download time.
    All,
}
/// Settings that control how [`download_ycb`] fetches and unpacks archives.
#[derive(Clone, Debug)]
pub struct DownloadOptions {
    /// When `false`, an object/file type is skipped if its `.tgz` is already in the output
    /// directory or (for `google_16k`) its extracted mesh already exists.
    pub overwrite: bool,
    /// When `true`, `berkeley_processed` archives are fetched in addition to `google_16k`.
    pub full: bool,
    /// Show a progress bar while each archive downloads.
    pub show_progress: bool,
    /// Delete each `.tgz` after it has been extracted successfully.
    pub delete_archives: bool,
}

impl Default for DownloadOptions {
    /// Keep existing data, fetch only `google_16k`, show progress, and remove archives.
    fn default() -> Self {
        let (show_progress, delete_archives) = (true, true);
        let (overwrite, full) = (false, false);
        Self {
            delete_archives,
            show_progress,
            full,
            overwrite,
        }
    }
}
/// Body of the JSON document served at `OBJECTS_URL`.
#[derive(Debug, Deserialize)]
struct ObjectsResponse {
    /// Object identifiers taken from the document's `objects` array.
    objects: Vec<String>,
}
/// Resolves `subset` to concrete object names.
///
/// Fixed subsets come from the built-in lists. Only [`Subset::All`] reaches the network,
/// via [`fetch_objects`].
async fn selected_objects_for_subset(subset: Subset, client: &Client) -> Result<Vec<String>> {
    if let Some(fixed) = get_subset_objects(subset) {
        return Ok(fixed);
    }
    fetch_objects(client).await
}
/// Archive kinds to fetch per object: `google_16k`, preceded by `berkeley_processed` when
/// `full` is set.
fn download_file_types(full: bool) -> &'static [&'static str] {
    const GOOGLE_ONLY: &[&str] = &["google_16k"];
    const WITH_BERKELEY: &[&str] = &["berkeley_processed", "google_16k"];
    match full {
        true => WITH_BERKELEY,
        false => GOOGLE_ONLY,
    }
}
/// Reports whether the extracted contents of `file_type` for `object` are already on disk.
///
/// Only `google_16k` has a known marker file (its `textured.obj`). Every other file type
/// always reports `false`.
fn local_artifact_exists(output_dir: &Path, object: &str, file_type: &str) -> bool {
    file_type == "google_16k" && object_mesh_path(output_dir, object).exists()
}
pub async fn fetch_objects(client: &Client) -> Result<Vec<String>> {
let response = client
.get(OBJECTS_URL)
.send()
.await
.with_context(|| format!("Failed to fetch objects list from {}", OBJECTS_URL))?;
let response = response
.error_for_status()
.with_context(|| format!("YCB objects endpoint returned an error for {}", OBJECTS_URL))?;
let objects_response: ObjectsResponse = response
.json()
.await
.context("Failed to parse objects JSON")?;
Ok(objects_response.objects)
}
pub fn get_tgz_url(object: &str, file_type: &str) -> String {
if file_type == "berkeley_rgbd" || file_type == "berkeley_rgb_highres" {
format!(
"{}berkeley/{}/{}_{}.tgz",
BASE_URL, object, object, file_type
)
} else if file_type == "berkeley_processed" {
format!(
"{}berkeley/{}/{}_berkeley_meshes.tgz",
BASE_URL, object, object
)
} else {
format!("{}google/{}_{}.tgz", BASE_URL, object, file_type)
}
}
pub async fn download_file(
client: &Client,
url: &str,
dest_path: &Path,
show_progress: bool,
) -> Result<()> {
let res = client
.get(url)
.send()
.await
.with_context(|| format!("Failed to send request to {}", url))?;
let res = res
.error_for_status()
.with_context(|| format!("YCB download failed for {}", url))?;
let total_size = res.content_length().unwrap_or(0);
let filename = dest_path
.file_name()
.map(|n| n.to_string_lossy().to_string())
.unwrap_or_else(|| "unknown".to_string());
let pb = if show_progress {
let pb = ProgressBar::new(total_size);
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
.expect("Invalid progress bar template - this is a bug")
.progress_chars("#>-"),
);
pb.set_message(format!("Downloading {}", filename));
Some(pb)
} else {
None
};
let mut file = File::create(dest_path).context("Failed to create file")?;
let mut stream = res.bytes_stream();
while let Some(item) = stream.next().await {
let chunk = item.context("Error while downloading chunk")?;
file.write_all(&chunk)
.context("Error while writing to file")?;
if let Some(ref pb) = pb {
pb.inc(chunk.len() as u64);
}
}
if let Some(pb) = pb {
pb.finish_with_message("Done");
}
Ok(())
}
pub fn extract_tgz(tgz_path: &Path, output_dir: &Path, delete_archive: bool) -> Result<()> {
let tar_gz = File::open(tgz_path)?;
let tar = flate2::read::GzDecoder::new(tar_gz);
let mut archive = tar::Archive::new(tar);
for entry in archive
.entries()
.context("Failed to read archive entries")?
{
let mut entry = entry.context("Failed to read archive entry")?;
let path = entry
.path()
.context("Failed to get entry path")?
.to_path_buf();
if path
.components()
.any(|c| matches!(c, std::path::Component::ParentDir))
{
anyhow::bail!(
"Archive contains invalid path with '..': {}",
path.display()
);
}
let dest = output_dir.join(&path);
let canonical_output = output_dir
.canonicalize()
.unwrap_or_else(|_| output_dir.to_path_buf());
if let Ok(canonical_dest) = dest.canonicalize() {
if !canonical_dest.starts_with(&canonical_output) {
anyhow::bail!(
"Archive tries to write outside output directory: {}",
dest.display()
);
}
}
if let Some(parent) = dest.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("Failed to create directory: {}", parent.display()))?;
}
entry
.unpack(&dest)
.with_context(|| format!("Failed to extract: {}", path.display()))?;
}
if delete_archive {
fs::remove_file(tgz_path)?;
}
Ok(())
}
/// Returns whether `url` answers a `HEAD` request with a success (2xx) status.
///
/// Any non-success status, including 404, yields `Ok(false)`.
///
/// # Errors
/// Fails only when the request cannot be sent (DNS, connection, TLS, ...). The error
/// context names the URL, consistent with the other network helpers in this module.
pub async fn url_exists(client: &Client, url: &str) -> Result<bool> {
    let response = client
        .head(url)
        .send()
        .await
        .with_context(|| format!("Failed to check URL {}", url))?;
    Ok(response.status().is_success())
}
/// Downloads and extracts the YCB archives for `subset` into `output_dir`.
///
/// For every selected object and every archive kind (`google_16k`, plus
/// `berkeley_processed` when `options.full` is set), this function:
/// 1. Unless `options.overwrite` is set, skips the pair when `{object}_{file_type}.tgz`
///    is already in `output_dir`, or (for `google_16k`) the extracted mesh already exists.
/// 2. Skips the pair silently when its archive URL does not answer `HEAD` with success.
/// 3. Otherwise downloads the archive and extracts it. The archive is deleted afterwards
///    when `options.delete_archives` is set.
///
/// NOTE(review): non-`google_16k` kinds have no extracted-artifact check. With
/// `delete_archives` enabled, they appear to be re-downloaded on every run; confirm this
/// is intended.
///
/// # Errors
/// The first failure aborts the run. Objects already processed stay on disk.
pub async fn download_ycb(
    subset: Subset,
    output_dir: &Path,
    options: DownloadOptions,
) -> Result<()> {
    let http = Client::new();
    let objects = selected_objects_for_subset(subset, &http).await?;
    fs::create_dir_all(output_dir).context("Failed to create output directory")?;
    let kinds = download_file_types(options.full);

    for object in objects.iter() {
        for file_type in kinds.iter().copied() {
            let archive = output_dir.join(format!("{}_{}.tgz", object, file_type));
            let keep_existing = !options.overwrite
                && (archive.exists() || local_artifact_exists(output_dir, object, file_type));
            if keep_existing {
                continue;
            }

            let url = get_tgz_url(object, file_type);
            if url_exists(&http, &url).await? {
                download_file(&http, &url, &archive, options.show_progress).await?;
                extract_tgz(&archive, output_dir, options.delete_archives)?;
            }
        }
    }
    Ok(())
}
/// Returns the object names of a built-in subset as owned strings.
///
/// [`Subset::All`] has no local list (it is fetched from `OBJECTS_URL`), so it yields
/// `None`.
pub fn get_subset_objects(subset: Subset) -> Option<Vec<String>> {
    // TEN_OBJECTS is deprecated but must still be served for `Subset::Ten`.
    #[allow(deprecated)]
    let names: &[&str] = match subset {
        Subset::Representative => REPRESENTATIVE_OBJECTS,
        Subset::Ten => TEN_OBJECTS,
        Subset::TbpStandard => TBP_STANDARD_OBJECTS,
        Subset::TbpSimilar => TBP_SIMILAR_OBJECTS,
        Subset::All => return None,
    };
    Some(names.iter().map(|&name| name.to_owned()).collect())
}
/// Location of an object's extracted `google_16k` mesh:
/// `{ycb_dir}/{object}/google_16k/textured.obj`.
pub fn object_mesh_path(ycb_dir: &Path, object: &str) -> std::path::PathBuf {
    let mut mesh = ycb_dir.join(object);
    mesh.push("google_16k");
    mesh.push("textured.obj");
    mesh
}
/// Location of an object's extracted `google_16k` texture:
/// `{ycb_dir}/{object}/google_16k/texture_map.png`.
pub fn object_texture_path(ycb_dir: &Path, object: &str) -> std::path::PathBuf {
    let mut texture = ycb_dir.join(object);
    texture.push("google_16k");
    texture.push("texture_map.png");
    texture
}
/// On-disk presence of one object's `google_16k` mesh and texture.
#[derive(Debug, Clone)]
pub struct ObjectValidation {
    /// Object identifier, e.g. `"003_cracker_box"`.
    pub name: String,
    /// Whether `textured.obj` was found.
    pub mesh_present: bool,
    /// Whether `texture_map.png` was found.
    pub texture_present: bool,
}

impl ObjectValidation {
    /// `true` only when both the mesh and the texture are present.
    pub fn is_complete(&self) -> bool {
        matches!((self.mesh_present, self.texture_present), (true, true))
    }
}
/// Checks, for each name in `objects`, whether its mesh and texture exist under `ycb_dir`.
///
/// Results are returned in the same order as `objects`.
pub fn validate_objects(ycb_dir: &Path, objects: &[&str]) -> Vec<ObjectValidation> {
    let mut report = Vec::with_capacity(objects.len());
    for &object in objects {
        let mesh_present = object_mesh_path(ycb_dir, object).exists();
        let texture_present = object_texture_path(ycb_dir, object).exists();
        report.push(ObjectValidation {
            name: object.to_owned(),
            mesh_present,
            texture_present,
        });
    }
    report
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `get_tgz_url(object, file_type)` is exactly `expected`.
    fn assert_tgz_url(object: &str, file_type: &str, expected: &str) {
        assert_eq!(get_tgz_url(object, file_type), expected);
    }

    /// Size of a subset that resolves locally (anything but `Subset::All`).
    fn fixed_subset_len(subset: Subset) -> usize {
        get_subset_objects(subset).unwrap().len()
    }

    #[test]
    fn test_get_tgz_url_google_16k() {
        assert_tgz_url(
            "003_cracker_box",
            "google_16k",
            "https://ycb-benchmarks.s3.amazonaws.com/data/google/003_cracker_box_google_16k.tgz",
        );
    }

    #[test]
    fn test_get_tgz_url_berkeley_processed() {
        assert_tgz_url(
            "003_cracker_box",
            "berkeley_processed",
            "https://ycb-benchmarks.s3.amazonaws.com/data/berkeley/003_cracker_box/003_cracker_box_berkeley_meshes.tgz",
        );
    }

    #[test]
    fn test_get_tgz_url_berkeley_rgbd() {
        assert_tgz_url(
            "003_cracker_box",
            "berkeley_rgbd",
            "https://ycb-benchmarks.s3.amazonaws.com/data/berkeley/003_cracker_box/003_cracker_box_berkeley_rgbd.tgz",
        );
    }

    #[test]
    fn test_get_tgz_url_berkeley_rgb_highres() {
        assert_tgz_url(
            "003_cracker_box",
            "berkeley_rgb_highres",
            "https://ycb-benchmarks.s3.amazonaws.com/data/berkeley/003_cracker_box/003_cracker_box_berkeley_rgb_highres.tgz",
        );
    }

    #[test]
    fn test_get_tgz_url_different_objects() {
        for &object in ["004_sugar_box", "005_tomato_soup_can"].iter() {
            assert!(get_tgz_url(object, "google_16k").contains(object));
        }
    }

    #[test]
    fn test_subset_default() {
        assert_eq!(Subset::default(), Subset::Representative);
    }

    #[test]
    fn test_download_options_default() {
        let DownloadOptions {
            overwrite,
            full,
            show_progress,
            delete_archives,
        } = DownloadOptions::default();
        assert!(!overwrite);
        assert!(!full);
        assert!(show_progress);
        assert!(delete_archives);
    }

    #[test]
    fn test_get_subset_objects_representative() {
        assert_eq!(fixed_subset_len(Subset::Representative), 3);
    }

    #[test]
    fn test_get_subset_objects_ten() {
        assert_eq!(fixed_subset_len(Subset::Ten), 10);
    }

    #[test]
    fn test_get_subset_objects_tbp_standard() {
        assert_eq!(fixed_subset_len(Subset::TbpStandard), 10);
    }

    #[test]
    fn test_get_subset_objects_tbp_similar() {
        assert_eq!(fixed_subset_len(Subset::TbpSimilar), 10);
    }

    #[test]
    fn test_get_subset_objects_all() {
        assert!(get_subset_objects(Subset::All).is_none());
    }

    #[test]
    fn test_local_artifact_exists_for_google_16k_mesh() {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        let mesh = object_mesh_path(root, "003_cracker_box");
        fs::create_dir_all(mesh.parent().unwrap()).unwrap();
        File::create(&mesh).unwrap();

        assert!(local_artifact_exists(root, "003_cracker_box", "google_16k"));
        assert!(!local_artifact_exists(
            root,
            "003_cracker_box",
            "berkeley_processed"
        ));
    }
}