use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use futures::future::join_all;
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use tokio::sync::Semaphore;
use tokio::time::sleep;
use crate::core::PdbStructure;
use crate::parser::{parse_mmcif_string, parse_pdb_string};
use super::{DownloadError, FileFormat, RCSB_DOWNLOAD_URL};
/// Tuning knobs for the batch download helpers in this module.
///
/// All fields are public, so a value can be built with a struct literal, from
/// [`Default`], from one of the presets, or through the `with_*` builder
/// methods. The type is comparable (`PartialEq`/`Eq`) so configurations can be
/// checked against each other directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AsyncDownloadOptions {
    /// Number of downloads allowed to be in flight at the same time.
    pub max_concurrent: usize,
    /// Delay, in milliseconds, that every download except the first one in
    /// the batch waits before it sends its request.
    pub rate_limit_ms: u64,
    /// Timeout applied to each HTTP request, in seconds.
    pub timeout_secs: u64,
    /// Number of additional attempts made after a failed first attempt.
    pub retries: usize,
}
impl Default for AsyncDownloadOptions {
fn default() -> Self {
Self {
max_concurrent: 5,
rate_limit_ms: 100,
timeout_secs: 30,
retries: 2,
}
}
}
impl AsyncDownloadOptions {
pub fn new() -> Self {
Self::default()
}
pub fn conservative() -> Self {
Self {
max_concurrent: 2,
rate_limit_ms: 500,
timeout_secs: 60,
retries: 3,
}
}
pub fn fast() -> Self {
Self {
max_concurrent: 20,
rate_limit_ms: 25,
timeout_secs: 30,
retries: 1,
}
}
pub fn with_max_concurrent(mut self, max_concurrent: usize) -> Self {
self.max_concurrent = max_concurrent;
self
}
pub fn with_rate_limit_ms(mut self, rate_limit_ms: u64) -> Self {
self.rate_limit_ms = rate_limit_ms;
self
}
pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
self.timeout_secs = timeout_secs;
self
}
pub fn with_retries(mut self, retries: usize) -> Self {
self.retries = retries;
self
}
}
/// Builds the download URL for `pdb_id` in the requested `format`.
///
/// The identifier is upper-cased and followed by `.pdb` or `.cif`, all under
/// `RCSB_DOWNLOAD_URL`.
fn build_download_url(pdb_id: &str, format: FileFormat) -> String {
    let suffix = match format {
        FileFormat::Pdb => "pdb",
        FileFormat::Cif => "cif",
    };
    let entry = pdb_id.to_uppercase();
    format!("{}/{}.{}", RCSB_DOWNLOAD_URL, entry, suffix)
}
/// Fetches the raw text of one entry.
///
/// # Errors
///
/// * `DownloadError::NotFound` when the server answers 404.
/// * `DownloadError::RequestFailed` for any other non-success status; the
///   message carries the status and the response body (empty if the body
///   could not be read).
/// * Whatever `DownloadError` the underlying `reqwest` error converts into
///   when sending the request or reading the body fails.
pub async fn download_pdb_string_async(
    pdb_id: &str,
    format: FileFormat,
) -> Result<String, DownloadError> {
    let url = build_download_url(pdb_id, format);
    let response = reqwest::Client::new().get(&url).send().await?;

    let status = response.status();
    if status == reqwest::StatusCode::NOT_FOUND {
        Err(DownloadError::NotFound(pdb_id.to_string()))
    } else if status.is_success() {
        let body = response.text().await?;
        Ok(body)
    } else {
        let body = response.text().await.unwrap_or_default();
        Err(DownloadError::RequestFailed(format!(
            "HTTP {}: {}",
            status, body
        )))
    }
}
/// Downloads one entry and parses it into a [`PdbStructure`].
///
/// The text is fetched with [`download_pdb_string_async`] and handed to the
/// parser matching `format`.
///
/// # Errors
///
/// Propagates download errors, and parser errors converted into
/// `DownloadError`.
pub async fn download_structure_async(
    pdb_id: &str,
    format: FileFormat,
) -> Result<PdbStructure, DownloadError> {
    let text = download_pdb_string_async(pdb_id, format).await?;
    match format {
        FileFormat::Pdb => Ok(parse_pdb_string(&text)?),
        FileFormat::Cif => Ok(parse_mmcif_string(&text)?),
    }
}
/// Downloads one entry and writes its raw text to `path`.
///
/// The file is created (or truncated if it already exists) only after the
/// download has succeeded.
///
/// # Errors
///
/// Propagates download errors from [`download_pdb_string_async`], and I/O
/// errors from creating, writing or flushing the file converted into
/// `DownloadError`.
pub async fn download_to_file_async<P: AsRef<Path>>(
    pdb_id: &str,
    path: P,
    format: FileFormat,
) -> Result<(), DownloadError> {
    let content = download_pdb_string_async(pdb_id, format).await?;
    let mut file = File::create(path).await?;
    file.write_all(content.as_bytes()).await?;
    // tokio's `File` performs writes on a background blocking task, so
    // `write_all` can return before the data has been handed to the OS.
    // Flushing waits for that work and surfaces any write error here rather
    // than losing it when the handle is dropped.
    file.flush().await?;
    Ok(())
}
/// Downloads and parses one entry, retrying failed attempts with exponential
/// backoff.
///
/// Up to `retries + 1` attempts are made. Before retry number `k` (1-based)
/// the task sleeps `2^(k-1)` seconds, capped at 64 seconds so that a large
/// `retries` value cannot overflow the shift that computes the delay.
///
/// # Errors
///
/// * `DownloadError::NotFound` is returned immediately, without retrying.
/// * Any other error (request, non-success status, body read, parse) is
///   retried; the error of the last attempt is returned once all attempts
///   are used up.
async fn download_with_retry(
    client: &reqwest::Client,
    pdb_id: &str,
    format: FileFormat,
    retries: usize,
    timeout: Duration,
) -> Result<PdbStructure, DownloadError> {
    // Largest exponent used for the backoff delay (2^6 s = 64 s).
    const MAX_BACKOFF_EXPONENT: usize = 6;

    let url = build_download_url(pdb_id, format);
    let mut last_error = DownloadError::RequestFailed("No attempts made".to_string());

    for attempt in 0..=retries {
        if attempt > 0 {
            let exponent = (attempt - 1).min(MAX_BACKOFF_EXPONENT);
            sleep(Duration::from_secs(1u64 << exponent)).await;
        }

        let result: Result<PdbStructure, DownloadError> = async {
            let response = client.get(&url).timeout(timeout).send().await?;
            let status = response.status();
            if status == reqwest::StatusCode::NOT_FOUND {
                return Err(DownloadError::NotFound(pdb_id.to_string()));
            }
            if !status.is_success() {
                return Err(DownloadError::RequestFailed(format!(
                    "HTTP {}: {}",
                    status,
                    response.text().await.unwrap_or_default()
                )));
            }
            let content = response.text().await?;
            let structure = match format {
                FileFormat::Pdb => parse_pdb_string(&content)?,
                FileFormat::Cif => parse_mmcif_string(&content)?,
            };
            Ok(structure)
        }
        .await;

        match result {
            Ok(structure) => return Ok(structure),
            // A missing entry will not appear on a later attempt.
            Err(e @ DownloadError::NotFound(_)) => return Err(e),
            Err(e) => last_error = e,
        }
    }

    Err(last_error)
}
/// Downloads and parses several entries concurrently.
///
/// Returns one `(pdb_id, result)` pair per input identifier, in input order.
/// A failure for one entry does not affect the others.
///
/// `options` defaults to [`AsyncDownloadOptions::default`] when `None`:
///
/// * `max_concurrent` bounds the number of downloads in flight. It is
///   clamped to at least 1, because a zero-permit semaphore would leave
///   every task waiting forever, and to at most `pdb_ids.len()`, because
///   extra permits are never used.
/// * `rate_limit_ms` is slept by every task except the first, after it has
///   acquired its slot and before it sends its request.
/// * `timeout_secs` and `retries` are passed to each download.
pub async fn download_multiple_async(
    pdb_ids: &[&str],
    format: FileFormat,
    options: Option<AsyncDownloadOptions>,
) -> Vec<(String, Result<PdbStructure, DownloadError>)> {
    let options = options.unwrap_or_default();

    let permits = options.max_concurrent.min(pdb_ids.len()).max(1);
    let semaphore = Arc::new(Semaphore::new(permits));

    let timeout = Duration::from_secs(options.timeout_secs);
    // If the builder fails, fall back to a default client; the timeout is
    // still applied per request inside `download_with_retry`.
    let client = reqwest::Client::builder()
        .timeout(timeout)
        .build()
        .unwrap_or_else(|_| reqwest::Client::new());
    let client = Arc::new(client);

    let rate_limit = Duration::from_millis(options.rate_limit_ms);
    let retries = options.retries;

    let tasks: Vec<_> = pdb_ids
        .iter()
        .enumerate()
        .map(|(idx, &pdb_id)| {
            let semaphore = Arc::clone(&semaphore);
            let client = Arc::clone(&client);
            let pdb_id = pdb_id.to_string();
            async move {
                // Held until the end of this task so the slot stays occupied
                // for the whole download, including retries.
                let _permit = semaphore
                    .acquire()
                    .await
                    .expect("semaphore is never closed");
                if idx > 0 {
                    sleep(rate_limit).await;
                }
                let result =
                    download_with_retry(&client, &pdb_id, format, retries, timeout).await;
                (pdb_id, result)
            }
        })
        .collect();

    join_all(tasks).await
}
/// Downloads one entry and writes its raw text to `path`, retrying failed
/// attempts with exponential backoff.
///
/// Up to `retries + 1` attempts are made. Before retry number `k` (1-based)
/// the task sleeps `2^(k-1)` seconds, capped at 64 seconds so that a large
/// `retries` value cannot overflow the shift that computes the delay.
/// `DownloadError::NotFound` is returned immediately; any other error
/// (including file I/O errors) is retried, and the last one is returned.
async fn fetch_to_file_with_retry(
    client: &reqwest::Client,
    pdb_id: &str,
    format: FileFormat,
    path: &Path,
    retries: usize,
    timeout: Duration,
) -> Result<(), DownloadError> {
    // Largest exponent used for the backoff delay (2^6 s = 64 s).
    const MAX_BACKOFF_EXPONENT: usize = 6;

    let url = build_download_url(pdb_id, format);
    let mut last_error = DownloadError::RequestFailed("No attempts made".to_string());

    for attempt in 0..=retries {
        if attempt > 0 {
            let exponent = (attempt - 1).min(MAX_BACKOFF_EXPONENT);
            sleep(Duration::from_secs(1u64 << exponent)).await;
        }

        let attempt_result: Result<(), DownloadError> = async {
            let response = client.get(&url).timeout(timeout).send().await?;
            let status = response.status();
            if status == reqwest::StatusCode::NOT_FOUND {
                return Err(DownloadError::NotFound(pdb_id.to_string()));
            }
            if !status.is_success() {
                return Err(DownloadError::RequestFailed(format!(
                    "HTTP {}: {}",
                    status,
                    response.text().await.unwrap_or_default()
                )));
            }
            let content = response.text().await?;
            let mut file = File::create(path).await?;
            file.write_all(content.as_bytes()).await?;
            // tokio's `File` performs writes on a background blocking task;
            // flushing waits for them and surfaces any write error here
            // rather than losing it when the handle is dropped.
            file.flush().await?;
            Ok(())
        }
        .await;

        match attempt_result {
            Ok(()) => return Ok(()),
            // A missing entry will not appear on a later attempt.
            Err(e @ DownloadError::NotFound(_)) => return Err(e),
            Err(e) => last_error = e,
        }
    }

    Err(last_error)
}

/// Downloads several entries concurrently and writes each to
/// `output_dir/<PDB_ID>.<extension>`.
///
/// Returns one `(pdb_id, result)` pair per input identifier, in input order.
/// A failure for one entry does not affect the others. `output_dir` is not
/// created by this function.
///
/// `options` defaults to [`AsyncDownloadOptions::default`] when `None`:
///
/// * `max_concurrent` bounds the number of downloads in flight. It is
///   clamped to at least 1, because a zero-permit semaphore would leave
///   every task waiting forever, and to at most `pdb_ids.len()`, because
///   extra permits are never used.
/// * `rate_limit_ms` is slept by every task except the first, after it has
///   acquired its slot and before it sends its request.
/// * `timeout_secs` and `retries` apply to each download.
pub async fn download_multiple_to_files_async<P: AsRef<Path> + Sync>(
    pdb_ids: &[&str],
    output_dir: P,
    format: FileFormat,
    options: Option<AsyncDownloadOptions>,
) -> Vec<(String, Result<(), DownloadError>)> {
    let options = options.unwrap_or_default();

    let permits = options.max_concurrent.min(pdb_ids.len()).max(1);
    let semaphore = Arc::new(Semaphore::new(permits));

    let timeout = Duration::from_secs(options.timeout_secs);
    // If the builder fails, fall back to a default client; the timeout is
    // still applied per request in the helper.
    let client = reqwest::Client::builder()
        .timeout(timeout)
        .build()
        .unwrap_or_else(|_| reqwest::Client::new());
    let client = Arc::new(client);

    let rate_limit = Duration::from_millis(options.rate_limit_ms);
    let retries = options.retries;
    let output_dir = output_dir.as_ref().to_path_buf();

    let tasks: Vec<_> = pdb_ids
        .iter()
        .enumerate()
        .map(|(idx, &pdb_id)| {
            let semaphore = Arc::clone(&semaphore);
            let client = Arc::clone(&client);
            let pdb_id = pdb_id.to_string();
            let filename = format!("{}.{}", pdb_id.to_uppercase(), format.extension());
            let path = output_dir.join(filename);
            async move {
                // Held until the end of this task so the slot stays occupied
                // for the whole download, including retries.
                let _permit = semaphore
                    .acquire()
                    .await
                    .expect("semaphore is never closed");
                if idx > 0 {
                    sleep(rate_limit).await;
                }
                let result =
                    fetch_to_file_with_retry(&client, &pdb_id, format, &path, retries, timeout)
                        .await;
                (pdb_id, result)
            }
        })
        .collect();

    join_all(tasks).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts all four option fields, in declaration order.
    fn assert_fields(
        options: &AsyncDownloadOptions,
        max_concurrent: usize,
        rate_limit_ms: u64,
        timeout_secs: u64,
        retries: usize,
    ) {
        assert_eq!(options.max_concurrent, max_concurrent);
        assert_eq!(options.rate_limit_ms, rate_limit_ms);
        assert_eq!(options.timeout_secs, timeout_secs);
        assert_eq!(options.retries, retries);
    }

    #[test]
    fn test_async_options_default() {
        assert_fields(&AsyncDownloadOptions::default(), 5, 100, 30, 2);
    }

    #[test]
    fn test_async_options_conservative() {
        assert_fields(&AsyncDownloadOptions::conservative(), 2, 500, 60, 3);
    }

    #[test]
    fn test_async_options_fast() {
        assert_fields(&AsyncDownloadOptions::fast(), 20, 25, 30, 1);
    }

    #[test]
    fn test_async_options_builder() {
        let built = AsyncDownloadOptions::new()
            .with_max_concurrent(10)
            .with_rate_limit_ms(50)
            .with_timeout_secs(45)
            .with_retries(5);
        assert_fields(&built, 10, 50, 45, 5);
    }

    #[test]
    fn test_build_download_url_pdb() {
        assert_eq!(
            build_download_url("1ubq", FileFormat::Pdb),
            "https://files.rcsb.org/download/1UBQ.pdb"
        );
    }

    #[test]
    fn test_build_download_url_cif() {
        assert_eq!(
            build_download_url("8hm2", FileFormat::Cif),
            "https://files.rcsb.org/download/8HM2.cif"
        );
    }
}