use rand::seq::SliceRandom;
use rand::SeedableRng;
use std::collections::HashSet;
use std::fs;
use std::path::PathBuf;
/// Synchronous HTTP GET abstraction.
///
/// Implemented in this file by [`ReqwestClient`] and, in test builds only, by
/// `ustar_test_utils::MockHttpClient`. The `Send + Sync` supertraits require
/// every implementor to be usable from multiple threads.
pub trait HttpClient: Send + Sync {
/// Fetches `url` and returns the response body as text.
fn get(&self, url: &str) -> Result<String, DownloadError>;
/// Fetches `url` and returns the response body as raw bytes.
fn get_bytes(&self, url: &str) -> Result<Vec<u8>, DownloadError>;
}
/// Error type shared by the HTTP clients, data sources and the downloader.
#[derive(Debug)]
pub enum DownloadError {
/// Wraps a `reqwest::Error`.
RequestError(reqwest::Error),
/// Wraps a `std::io::Error`.
IoError(std::io::Error),
/// The data source returned an empty list of available entries.
NoEntriesFound,
/// A download failed; the string carries the reason (for example a
/// non-success HTTP status or a runtime start-up failure).
DownloadFailed(String),
/// Wraps a `serde_json::Error`.
JsonError(serde_json::Error),
}
impl std::fmt::Display for DownloadError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DownloadError::RequestError(e) => write!(f, "Request error: {}", e),
DownloadError::IoError(e) => write!(f, "IO error: {}", e),
DownloadError::NoEntriesFound => write!(f, "No entries found"),
DownloadError::DownloadFailed(msg) => write!(f, "Download failed: {}", msg),
DownloadError::JsonError(e) => write!(f, "JSON error: {}", e),
}
}
}
impl std::error::Error for DownloadError {}
impl From<reqwest::Error> for DownloadError {
fn from(err: reqwest::Error) -> Self {
DownloadError::RequestError(err)
}
}
impl From<std::io::Error> for DownloadError {
fn from(err: std::io::Error) -> Self {
DownloadError::IoError(err)
}
}
impl From<serde_json::Error> for DownloadError {
fn from(err: serde_json::Error) -> Self {
DownloadError::JsonError(err)
}
}
impl From<String> for DownloadError {
fn from(err: String) -> Self {
DownloadError::DownloadFailed(err)
}
}
/// Stateless [`HttpClient`] implementation backed by `reqwest`.
pub struct ReqwestClient;
// Test-only adapter: lets `ustar_test_utils::MockHttpClient` be used wherever
// an `HttpClient` is expected.
//
// NOTE(review): the `self.get(..)` / `self.get_bytes(..)` calls below must
// resolve to inherent methods on `MockHttpClient` (inherent methods take
// precedence over trait methods). Their error type is evidently `String`,
// since it is mapped through `DownloadError::DownloadFailed` — confirm in
// `ustar_test_utils`.
#[cfg(test)]
impl HttpClient for ustar_test_utils::MockHttpClient {
fn get(&self, url: &str) -> Result<String, DownloadError> {
// Forward to the mock and wrap its error message as `DownloadFailed`.
self.get(url).map_err(DownloadError::DownloadFailed)
}
fn get_bytes(&self, url: &str) -> Result<Vec<u8>, DownloadError> {
// Forward to the mock and wrap its error message as `DownloadFailed`.
self.get_bytes(url).map_err(DownloadError::DownloadFailed)
}
}
impl ReqwestClient {
    /// Creates the Tokio runtime used to drive a single request to completion.
    ///
    /// # Errors
    /// Returns [`DownloadError::DownloadFailed`] if the runtime cannot be built.
    fn runtime() -> Result<tokio::runtime::Runtime, DownloadError> {
        tokio::runtime::Runtime::new().map_err(|e| {
            DownloadError::DownloadFailed(format!("Failed to create runtime: {}", e))
        })
    }

    /// Sends a GET request to `url` and returns the response only if its
    /// status is a success code.
    ///
    /// # Errors
    /// * [`DownloadError::RequestError`] if the request itself fails.
    /// * [`DownloadError::DownloadFailed`] (`"HTTP <status>: <url>"`) for any
    ///   non-success status.
    async fn fetch(url: &str) -> Result<reqwest::Response, DownloadError> {
        let response = reqwest::get(url).await?;
        let status = response.status();
        if !status.is_success() {
            return Err(DownloadError::DownloadFailed(format!(
                "HTTP {}: {}",
                status, url
            )));
        }
        Ok(response)
    }
}

/// Blocking `reqwest`-based implementation.
///
/// Each call builds its own Tokio runtime and blocks the current thread until
/// the request finishes. NOTE(review): per the Tokio documentation,
/// `Runtime::block_on` panics when called from inside an asynchronous
/// context, so these methods are meant for synchronous callers.
impl HttpClient for ReqwestClient {
    /// Fetches `url` and returns the body decoded as text.
    fn get(&self, url: &str) -> Result<String, DownloadError> {
        let rt = Self::runtime()?;
        rt.block_on(async { Ok(Self::fetch(url).await?.text().await?) })
    }

    /// Fetches `url` and returns the body as raw bytes.
    fn get_bytes(&self, url: &str) -> Result<Vec<u8>, DownloadError> {
        let rt = Self::runtime()?;
        rt.block_on(async { Ok(Self::fetch(url).await?.bytes().await?.to_vec()) })
    }
}
// Command-line options parsed with clap's derive API; going by the name,
// presumably shared by several downloader front-ends — confirm at the callers.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive turns
// `///` doc comments into user-visible help/about text, so adding them would
// change the program's `--help` output.
#[derive(clap::Parser, Debug)]
pub struct CommonDownloaderCli {
// Positional argument shown as `COUNT`, default 50. Presumably the number of
// entries to download — confirm at the call site.
#[arg(default_value_t = 50, value_name = "COUNT")]
pub count: usize,
// `-o` / `--output-dir`, default "tests/test_data".
#[arg(short, long, default_value = "tests/test_data")]
pub output_dir: String,
// `--verbose` flag (false when absent).
#[arg(long)]
pub verbose: bool,
// `--list` flag (false when absent). Presumably selects listing instead of
// downloading — confirm at the call site.
#[arg(long)]
pub list: bool,
// `--seed`, default 42. Presumably the RNG seed for the random selection —
// confirm at the call site.
#[arg(long, default_value_t = 42)]
pub seed: u64,
}
/// Settings for a downloader, assembled with chained builder-style setters.
pub struct DownloaderConfig {
    /// Directory that downloaded files are placed in.
    pub output_dir: PathBuf,
    /// Whether progress and diagnostic messages are printed.
    pub verbose: bool,
    /// File extension (without the leading dot) given to downloaded files.
    pub file_extension: String,
}

impl Default for DownloaderConfig {
    /// Same values as [`DownloaderConfig::new`]: current directory, verbose
    /// output enabled, `"cif"` extension.
    fn default() -> Self {
        Self {
            output_dir: PathBuf::from("."),
            verbose: true,
            file_extension: "cif".to_string(),
        }
    }
}

impl DownloaderConfig {
    /// Creates a configuration with the default values: output directory
    /// `"."`, `verbose = true`, extension `"cif"`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the output directory and returns the updated configuration.
    pub fn output_dir<P: Into<PathBuf>>(mut self, dir: P) -> Self {
        self.output_dir = dir.into();
        self
    }

    /// Sets the verbosity flag and returns the updated configuration.
    pub fn verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Replaces the file extension (no leading dot) and returns the updated
    /// configuration.
    pub fn file_extension<S: Into<String>>(mut self, ext: S) -> Self {
        self.file_extension = ext.into();
        self
    }
}
/// A catalogue of downloadable entries, consumed by [`GenericDownloader`].
pub trait DataSource {
/// Returns the identifiers of the entries this source can provide.
fn get_available_entries(&self) -> Result<Vec<String>, DownloadError>;
/// Downloads the entry `entry_id`, with `output_path` as the target file path.
///
/// [`GenericDownloader`] passes `<output_dir>/<entry_id>.<file_extension>`
/// as `output_path` and hands the returned path back to its caller;
/// presumably that is the path actually written — confirm against the
/// implementors.
///
/// NOTE(review): `&Path` would be the more general parameter type, but
/// changing it would break existing implementors of this trait.
fn download_entry(
&self,
entry_id: &str,
output_path: &PathBuf,
) -> Result<PathBuf, DownloadError>;
}
/// Downloader that pairs a [`DownloaderConfig`] with any [`DataSource`].
pub struct GenericDownloader<T: DataSource> {
// Output directory, verbosity and file extension used by every operation.
config: DownloaderConfig,
// Source queried for the entry list and asked to perform each download.
data_source: T,
}
impl<T: DataSource> GenericDownloader<T> {
    /// Creates a downloader that uses `config` for its settings and
    /// `data_source` for listing and fetching entries.
    pub fn new(config: DownloaderConfig, data_source: T) -> Self {
        Self {
            config,
            data_source,
        }
    }

    /// Downloads up to `count` entries that are not yet present in the output
    /// directory, visiting the available entries in a pseudo-random order
    /// derived from `seed`.
    ///
    /// Entries whose target file `<output_dir>/<entry_id>.<file_extension>`
    /// already exists are skipped and do not count towards `count`. A failed
    /// download is reported on stderr (when verbose) and otherwise ignored, so
    /// the returned list may be shorter than `count`.
    ///
    /// Returns `(entry_id, path)` pairs for the successful downloads, where
    /// `path` is whatever the data source returned for that entry.
    ///
    /// # Errors
    /// * Any error from [`DataSource::get_available_entries`].
    /// * [`DownloadError::NoEntriesFound`] if the source lists no entries.
    pub fn download_unique_random_batch(
        &self,
        count: usize,
        seed: u64,
    ) -> Result<Vec<(String, PathBuf)>, DownloadError> {
        let mut entries = self.data_source.get_available_entries()?;
        if entries.is_empty() {
            return Err(DownloadError::NoEntriesFound);
        }

        // Same seed + same entry list => same visiting order.
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        entries.shuffle(&mut rng);

        let mut results: Vec<(String, PathBuf)> = Vec::with_capacity(count.min(entries.len()));
        // Borrowed ids are enough to detect duplicates in the entry list; only
        // ids that were actually downloaded get cloned into the result.
        let mut tried: HashSet<&str> = HashSet::new();

        for entry_id in &entries {
            if results.len() >= count {
                break;
            }
            if !tried.insert(entry_id.as_str()) {
                continue;
            }

            let filename = format!("{}.{}", entry_id, self.config.file_extension);
            let filepath = self.config.output_dir.join(&filename);
            if filepath.exists() {
                if self.config.verbose {
                    println!("Already exists, skipping: {}", filepath.display());
                }
                continue;
            }

            match self.data_source.download_entry(entry_id, &filepath) {
                Ok(path) => results.push((entry_id.clone(), path)),
                Err(e) => {
                    // Best effort: a single failure must not abort the batch.
                    if self.config.verbose {
                        eprintln!("Failed to download {}: {}", entry_id, e);
                    }
                }
            }
        }
        Ok(results)
    }

    /// Prints every available entry to stdout, marking those that already have
    /// a file in the output directory, followed by a downloaded/total summary.
    ///
    /// The match between entries and files is case-insensitive. In verbose
    /// mode every line is prefixed with `[VERBOSE] `.
    ///
    /// # Errors
    /// Any error from [`DataSource::get_available_entries`]. An unreadable or
    /// missing output directory is not an error; it simply means nothing is
    /// marked as downloaded.
    pub fn list_files(&self) -> Result<(), DownloadError> {
        let entries = self.data_source.get_available_entries()?;
        let downloaded = self.downloaded_stems();

        if self.config.verbose {
            println!("[VERBOSE] Total available files: {}", entries.len());
        } else {
            println!("Total available files: {}", entries.len());
        }

        let mut downloaded_count = 0;
        for entry_id in &entries {
            let is_downloaded = downloaded.contains(&entry_id.to_lowercase());
            if is_downloaded {
                downloaded_count += 1;
            }
            let status = if is_downloaded { " [downloaded]" } else { "" };
            if self.config.verbose {
                println!("[VERBOSE] {}{}", entry_id, status);
            } else {
                println!("{}{}", entry_id, status);
            }
        }

        if self.config.verbose {
            println!(
                "[VERBOSE] Total downloaded: {} / {}",
                downloaded_count,
                entries.len()
            );
        } else {
            println!(
                "\nTotal downloaded: {} / {}",
                downloaded_count,
                entries.len()
            );
        }
        Ok(())
    }

    /// Collects the lower-cased stems of all files in the output directory
    /// whose name ends in `.<file_extension>`.
    ///
    /// Unreadable directories, unreadable entries and non-UTF-8 file names are
    /// silently skipped.
    fn downloaded_stems(&self) -> HashSet<String> {
        // Built once instead of once per directory entry.
        let suffix = format!(".{}", self.config.file_extension);
        let mut stems = HashSet::new();
        if let Ok(dir_entries) = fs::read_dir(&self.config.output_dir) {
            for entry in dir_entries.flatten() {
                if let Some(stem) = entry
                    .file_name()
                    .to_str()
                    .and_then(|name| name.strip_suffix(suffix.as_str()))
                {
                    stems.insert(stem.to_lowercase());
                }
            }
        }
        stems
    }
}