use anyhow::{anyhow, bail, Context, Result};
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use futures_util::{AsyncReadExt, StreamExt, TryStreamExt};
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::Client;
use s3::creds::Credentials;
use s3::error::S3Error;
use s3::region::Region;
use s3::Bucket;
use crate::{fetch_objects, get_tgz_url, url_exists, DownloadOptions, Subset};
use crate::{REPRESENTATIVE_OBJECTS, TEN_OBJECTS};
/// Parsed form of an `s3://bucket/prefix/` destination URL.
#[derive(Clone, Debug)]
pub struct S3Destination {
    /// Bucket name (no scheme, no slashes).
    pub bucket: String,
    /// Key prefix; either empty or normalized to end with `/` so object
    /// keys can be formed by plain concatenation.
    pub prefix: String,
    /// AWS region name; `from_url` defaults this to `us-east-1`.
    pub region: String,
}
impl S3Destination {
pub fn from_url(url: &str) -> Result<Self> {
let url = url.trim();
if !url.starts_with("s3://") {
bail!("S3 URL must start with 's3://', got: {}", url);
}
let path = &url[5..];
if path.is_empty() {
bail!("S3 URL must include a bucket name");
}
let (bucket, prefix) = match path.find('/') {
Some(idx) => {
let bucket = &path[..idx];
let mut prefix = path[idx + 1..].to_string();
if !prefix.is_empty() && !prefix.ends_with('/') {
prefix.push('/');
}
(bucket.to_string(), prefix)
}
None => (path.to_string(), String::new()),
};
if bucket.is_empty() {
bail!("S3 URL must include a bucket name");
}
Ok(Self {
bucket,
prefix,
region: "us-east-1".to_string(),
})
}
pub fn with_region(mut self, region: impl Into<String>) -> Self {
self.region = region.into();
self
}
pub fn full_path(&self, path: &str) -> String {
format!("{}{}", self.prefix, path)
}
pub fn to_url(&self) -> String {
format!("s3://{}/{}", self.bucket, self.prefix)
}
}
/// Aggregate counters for one S3 upload run.
#[derive(Clone, Debug, Default)]
pub struct S3UploadStats {
    /// Number of files actually uploaded.
    pub files_uploaded: usize,
    /// Number of files skipped because they already existed in the bucket
    /// (when overwriting is disabled).
    pub files_skipped: usize,
    /// Total bytes uploaded (sizes of the extracted tar entries).
    pub bytes_uploaded: u64,
}
/// Verifies that usable AWS credentials can be loaded, returning a
/// human-readable confirmation with a masked access key.
///
/// Credentials come from the environment or the given profile (see
/// `get_credentials`).
///
/// # Errors
///
/// Fails when credentials cannot be loaded, or when either key component
/// is missing or empty.
pub async fn check_aws_credentials(profile: Option<&str>) -> Result<String> {
    let creds = get_credentials(profile)?;
    let access_key = creds
        .access_key
        .as_ref()
        .ok_or_else(|| anyhow!("No AWS access key found"))?;
    let secret_key = creds
        .secret_key
        .as_ref()
        .ok_or_else(|| anyhow!("No AWS secret key found"))?;
    if access_key.is_empty() {
        bail!("AWS access key is empty");
    }
    if secret_key.is_empty() {
        bail!("AWS secret key is empty");
    }
    // Show only the first and last four characters of longer keys.
    let masked = if access_key.len() > 8 {
        // `str::get` returns None instead of panicking if a slice boundary
        // would split a multi-byte character (direct `[..4]` indexing would
        // panic on a non-ASCII key).
        match (access_key.get(..4), access_key.get(access_key.len() - 4..)) {
            (Some(head), Some(tail)) => format!("{}...{}", head, tail),
            _ => "****".to_string(),
        }
    } else {
        "****".to_string()
    };
    Ok(format!("AWS credentials loaded (access key: {})", masked))
}
/// Loads AWS credentials, preferring environment variables over profiles.
///
/// If `AWS_ACCESS_KEY_ID` is set, credentials are read from the
/// environment. Otherwise the profile is resolved from the explicit
/// `profile` argument, then `$AWS_PROFILE`, then `"default"`.
///
/// # Errors
///
/// Fails when neither source yields valid credentials.
fn get_credentials(profile: Option<&str>) -> Result<Credentials> {
    // Environment variables win over profile files.
    if std::env::var("AWS_ACCESS_KEY_ID").is_ok() {
        return Credentials::from_env()
            .map_err(|e| anyhow!("Failed to load AWS credentials from environment: {}", e));
    }
    // Resolve the profile name: explicit arg -> $AWS_PROFILE -> "default".
    let profile_name = match profile {
        Some(name) => name.to_string(),
        None => std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string()),
    };
    Credentials::from_profile(Some(&profile_name)).map_err(|e| {
        anyhow!(
            "Failed to load AWS credentials for profile '{}': {}",
            profile_name,
            e
        )
    })
}
/// Builds an S3 bucket handle for `dest` using credentials resolved by
/// `get_credentials`.
///
/// NOTE(review): the endpoint is hard-coded to the AWS
/// `https://s3.<region>.amazonaws.com` form and path-style addressing is
/// forced, so custom/non-AWS endpoints are not supported here.
///
/// # Errors
///
/// Fails when credentials cannot be loaded or the bucket handle cannot be
/// constructed.
async fn create_bucket(dest: &S3Destination, profile: Option<&str>) -> Result<Box<Bucket>> {
    let creds = get_credentials(profile)?;
    let region = Region::Custom {
        region: dest.region.clone(),
        endpoint: format!("https://s3.{}.amazonaws.com", dest.region),
    };
    let bucket = Bucket::new(&dest.bucket, region, creds)
        .map_err(|e| anyhow!("Failed to create S3 bucket handle: {}", e))?
        .with_path_style();
    Ok(bucket)
}
/// Checks whether an object exists in the bucket via a HEAD request.
///
/// A 404 — in any of the shapes the S3 crate reports it — means "does not
/// exist" rather than an error.
///
/// # Errors
///
/// Fails when the HEAD request errors for any reason other than a missing
/// object.
async fn object_exists(bucket: &Bucket, path: &str) -> Result<bool> {
    let err = match bucket.head_object(path).await {
        Ok(_) => return Ok(true),
        Err(S3Error::HttpFailWithBody(404, _)) => return Ok(false),
        Err(other) => other,
    };
    // Some error variants only carry the 404 in their message text, so fall
    // back to string inspection before treating it as a real failure.
    let msg = err.to_string();
    let looks_missing =
        msg.contains("404") || msg.contains("Not Found") || msg.contains("NoSuchKey");
    if looks_missing {
        Ok(false)
    } else {
        Err(anyhow!("Failed to check if object exists: {}", err))
    }
}
/// Streams the selected YCB object archives from the upstream server
/// directly into S3, without writing archives to local disk.
///
/// For each object/file-type pair, the source `.tgz` URL is checked for
/// existence and then extracted and uploaded entry-by-entry via
/// `stream_tgz_to_s3`. A failure on one archive is logged and skipped so it
/// does not abort the whole run.
///
/// Returns aggregate upload statistics across all archives.
///
/// # Errors
///
/// Fails when the bucket handle cannot be created, the object list cannot
/// be fetched (for `Subset::All`), or a source-existence check errors.
pub async fn download_ycb_to_s3(
    subset: Subset,
    dest: S3Destination,
    options: DownloadOptions,
    profile: Option<&str>,
) -> Result<S3UploadStats> {
    let http_client = Client::new();
    let s3_bucket = create_bucket(&dest, profile).await?;
    let mut stats = S3UploadStats::default();
    // Resolve the object list: two static subsets, or the full list fetched
    // from the remote index for `Subset::All`.
    let selected_objects: Vec<String> = match subset {
        Subset::Representative => REPRESENTATIVE_OBJECTS
            .iter()
            .map(|s| s.to_string())
            .collect(),
        Subset::Ten => TEN_OBJECTS.iter().map(|s| s.to_string()).collect(),
        Subset::All => fetch_objects(&http_client).await?,
    };
    // `full` adds the berkeley_processed archives on top of google_16k.
    let file_types = if options.full {
        vec!["berkeley_processed", "google_16k"]
    } else {
        vec!["google_16k"]
    };
    println!(
        "Streaming {} objects to {}",
        selected_objects.len(),
        dest.to_url()
    );
    for object in &selected_objects {
        for file_type in &file_types {
            let url = get_tgz_url(object, file_type);
            // Not every object exists in every file type upstream.
            if !url_exists(&http_client, &url).await? {
                if options.show_progress {
                    println!("Skipping {} ({}): not found on source", object, file_type);
                }
                continue;
            }
            let result = stream_tgz_to_s3(
                &http_client,
                &url,
                &s3_bucket,
                &dest.prefix,
                object,
                file_type,
                &options,
            )
            .await;
            // Per-archive errors are reported but do not stop the run.
            match result {
                Ok((uploaded, skipped, bytes)) => {
                    stats.files_uploaded += uploaded;
                    stats.files_skipped += skipped;
                    stats.bytes_uploaded += bytes;
                }
                Err(e) => {
                    eprintln!("Error processing {} ({}): {}", object, file_type, e);
                }
            }
        }
    }
    Ok(stats)
}
/// Validates a tar entry path and normalizes it to a forward-slash string
/// suitable for use as an S3 key suffix.
///
/// Returns `None` for anything that could escape the destination prefix:
/// absolute paths, empty paths, and any path containing a `..` segment —
/// including backslash-separated ones. On Unix a backslash is NOT a path
/// separator, so `"..\\etc"` is a single ordinary component and the
/// `Component::ParentDir` check alone would miss it; after the `\` -> `/`
/// rewrite it would become a live `../etc` traversal (Zip-Slip). Hence the
/// extra segment check on the normalized string.
fn sanitize_tar_path(path: &std::path::Path) -> Option<String> {
    // Platform-level checks first.
    if path
        .components()
        .any(|c| matches!(c, std::path::Component::ParentDir))
    {
        return None;
    }
    if path.is_absolute() {
        return None;
    }
    // Normalize Windows-style separators to forward slashes for S3 keys.
    let path_str = path.to_string_lossy();
    let normalized = path_str.replace('\\', "/");
    if normalized.is_empty() || normalized.starts_with('/') {
        return None;
    }
    // Re-check after normalization: catches "..\\x" style traversal that is
    // a single component on Unix and so escapes the ParentDir check above.
    if normalized.split('/').any(|segment| segment == "..") {
        return None;
    }
    Some(normalized)
}
/// Streams one `.tgz` archive from `url` into S3: the HTTP body is
/// gzip-decoded and un-tarred on the fly, and each regular-file entry is
/// uploaded under `prefix`. The archive never touches local disk, though
/// each individual entry is fully buffered in memory before its upload.
///
/// When `options.overwrite` is false, entries already present in the bucket
/// are counted as skipped. When `options.show_progress` is true a spinner
/// reports per-archive progress (total size is unknown up front).
///
/// Returns `(files_uploaded, files_skipped, bytes_uploaded)`.
///
/// # Errors
///
/// Fails if the download cannot start, the HTTP status is not a success,
/// the tar stream is unreadable, or any upload fails.
async fn stream_tgz_to_s3(
    client: &Client,
    url: &str,
    bucket: &Bucket,
    prefix: &str,
    object: &str,
    file_type: &str,
    options: &DownloadOptions,
) -> Result<(usize, usize, u64)> {
    let mut uploaded = 0usize;
    let mut skipped = 0usize;
    let mut bytes = 0u64;
    let response = client
        .get(url)
        .send()
        .await
        .context("Failed to start download")?;
    if !response.status().is_success() {
        bail!("HTTP request failed with status: {}", response.status());
    }
    // Optional spinner (not a bar: content length of the tgz is not used).
    let pb = if options.show_progress {
        let pb = ProgressBar::new_spinner();
        pb.set_style(
            ProgressStyle::default_spinner()
                .template("{spinner:.green} [{elapsed_precise}] {msg}")
                .expect("Invalid progress bar template"),
        );
        pb.set_message(format!("{} ({}) - extracting...", object, file_type));
        Some(pb)
    } else {
        None
    };
    // Pipeline: HTTP byte stream -> AsyncRead -> buffered -> gzip -> tar.
    let byte_stream = response.bytes_stream().map_err(std::io::Error::other);
    let stream_reader = byte_stream.into_async_read();
    let buf_reader = futures_util::io::BufReader::new(stream_reader);
    let decoder = GzipDecoder::new(buf_reader);
    let archive = Archive::new(decoder);
    let mut entries = archive.entries().context("Failed to read tar entries")?;
    while let Some(entry_result) = entries.next().await {
        let mut entry = entry_result.context("Failed to read tar entry")?;
        // Copy the path out so `entry` can be mutably borrowed for reading.
        let path = entry
            .path()
            .context("Failed to get entry path")?
            .to_path_buf();
        if entry.header().entry_type().is_dir() {
            continue;
        }
        // Reject traversal/absolute paths before building the S3 key.
        let std_path = std::path::Path::new(path.as_os_str());
        let sanitized_path = match sanitize_tar_path(std_path) {
            Some(p) => p,
            None => {
                eprintln!(
                    "Warning: Skipping invalid/unsafe path in archive: {}",
                    path.display()
                );
                continue;
            }
        };
        let s3_path = format!("{}{}", prefix, sanitized_path);
        // Unless overwriting, skip objects that already exist; a failed
        // existence check is only a warning and the upload proceeds.
        if !options.overwrite {
            match object_exists(bucket, &s3_path).await {
                Ok(true) => {
                    skipped += 1;
                    continue;
                }
                Ok(false) => {}
                Err(e) => {
                    eprintln!("Warning: Failed to check if {} exists: {}", s3_path, e);
                }
            }
        }
        // Buffer the whole entry in memory before upload; entries here are
        // mesh/texture files, presumably modest in size — TODO confirm for
        // the larger berkeley_processed scans.
        let mut content = Vec::new();
        entry
            .read_to_end(&mut content)
            .await
            .context("Failed to read tar entry content")?;
        let content_len = content.len() as u64;
        let content_type = guess_content_type(&sanitized_path);
        bucket
            .put_object_with_content_type(&s3_path, &content, content_type)
            .await
            .map_err(|e| anyhow!("Failed to upload {}: {}", s3_path, e))?;
        uploaded += 1;
        bytes += content_len;
        if let Some(ref pb) = pb {
            pb.set_message(format!(
                "{} ({}) - {} files uploaded",
                object, file_type, uploaded
            ));
        }
    }
    if let Some(pb) = pb {
        pb.finish_with_message(format!("{} ({}) - {} files", object, file_type, uploaded));
    }
    Ok((uploaded, skipped, bytes))
}
/// Maps a file path to a MIME content type by its extension.
///
/// Matching is case-sensitive; anything unrecognized (including paths with
/// no `.` at all) falls back to `application/octet-stream`.
fn guess_content_type(path: &str) -> &'static str {
    // `rsplit_once` isolates the text after the last dot; `None` means the
    // path has no extension.
    match path.rsplit_once('.').map(|(_, ext)| ext) {
        Some("obj") => "model/obj",
        Some("mtl") => "model/mtl",
        Some("png") => "image/png",
        Some("jpg") | Some("jpeg") => "image/jpeg",
        Some("ply") => "application/ply",
        Some("json") => "application/json",
        Some("txt") => "text/plain",
        _ => "application/octet-stream",
    }
}
// Unit tests for the pure helpers: URL parsing, content-type mapping, and
// tar-path sanitization. The S3/network paths are not covered here.
#[cfg(test)]
mod tests {
    use super::*;

    // --- S3Destination::from_url parsing ---

    #[test]
    fn test_s3_destination_from_url_basic() {
        let dest = S3Destination::from_url("s3://my-bucket/prefix/path/").unwrap();
        assert_eq!(dest.bucket, "my-bucket");
        assert_eq!(dest.prefix, "prefix/path/");
    }

    #[test]
    fn test_s3_destination_from_url_no_trailing_slash() {
        // The prefix is normalized to end with '/'.
        let dest = S3Destination::from_url("s3://my-bucket/prefix/path").unwrap();
        assert_eq!(dest.bucket, "my-bucket");
        assert_eq!(dest.prefix, "prefix/path/");
    }

    #[test]
    fn test_s3_destination_from_url_bucket_only() {
        let dest = S3Destination::from_url("s3://my-bucket").unwrap();
        assert_eq!(dest.bucket, "my-bucket");
        assert_eq!(dest.prefix, "");
    }

    #[test]
    fn test_s3_destination_from_url_bucket_with_slash() {
        let dest = S3Destination::from_url("s3://my-bucket/").unwrap();
        assert_eq!(dest.bucket, "my-bucket");
        assert_eq!(dest.prefix, "");
    }

    #[test]
    fn test_s3_destination_from_url_invalid() {
        // Wrong scheme, missing bucket, and local paths must all fail.
        assert!(S3Destination::from_url("http://example.com").is_err());
        assert!(S3Destination::from_url("s3://").is_err());
        assert!(S3Destination::from_url("/local/path").is_err());
    }

    #[test]
    fn test_s3_destination_full_path() {
        let dest = S3Destination::from_url("s3://my-bucket/ycb/").unwrap();
        assert_eq!(
            dest.full_path("003_cracker_box/google_16k/textured.obj"),
            "ycb/003_cracker_box/google_16k/textured.obj"
        );
    }

    #[test]
    fn test_s3_destination_to_url() {
        let dest = S3Destination::from_url("s3://my-bucket/prefix/").unwrap();
        assert_eq!(dest.to_url(), "s3://my-bucket/prefix/");
    }

    // --- guess_content_type ---

    #[test]
    fn test_guess_content_type() {
        assert_eq!(guess_content_type("model.obj"), "model/obj");
        assert_eq!(guess_content_type("texture.png"), "image/png");
        assert_eq!(guess_content_type("data.json"), "application/json");
        assert_eq!(
            guess_content_type("unknown.xyz"),
            "application/octet-stream"
        );
    }

    // --- sanitize_tar_path ---

    #[test]
    fn test_sanitize_tar_path_valid() {
        use std::path::Path;
        assert_eq!(
            sanitize_tar_path(Path::new("foo/bar/file.txt")),
            Some("foo/bar/file.txt".to_string())
        );
        assert_eq!(
            sanitize_tar_path(Path::new("file.obj")),
            Some("file.obj".to_string())
        );
    }

    #[test]
    fn test_sanitize_tar_path_traversal() {
        use std::path::Path;
        // Any ".." component must be rejected, wherever it appears.
        assert_eq!(sanitize_tar_path(Path::new("../etc/passwd")), None);
        assert_eq!(sanitize_tar_path(Path::new("foo/../bar")), None);
        assert_eq!(sanitize_tar_path(Path::new("foo/bar/../../baz")), None);
    }

    #[test]
    fn test_sanitize_tar_path_absolute() {
        use std::path::Path;
        assert_eq!(sanitize_tar_path(Path::new("/etc/passwd")), None);
    }

    #[test]
    fn test_sanitize_tar_path_empty() {
        use std::path::Path;
        assert_eq!(sanitize_tar_path(Path::new("")), None);
    }

    // Backslash paths are only split into components on Windows, so this
    // normalization check is Windows-specific.
    #[cfg(windows)]
    #[test]
    fn test_sanitize_tar_path_windows_separators() {
        use std::path::Path;
        let result = sanitize_tar_path(Path::new("foo\\bar\\file.txt"));
        assert_eq!(result, Some("foo/bar/file.txt".to_string()));
    }
}