use futures::StreamExt;
use futures::TryStreamExt;
use once_cell::sync::Lazy;
use reqwest::{Client, ClientBuilder};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use butterfly_common::{Error, Result};
use crate::core::source::{DownloadSource, SourceConfig};
use crate::core::stream::{create_http_stream, DownloadOptions, DownloadStream, OverwriteBehavior};
/// Maximum number of retries for transient network failures.
const MAX_RETRY_ATTEMPTS: u32 = 3;
/// Initial retry delay in milliseconds; doubles on each attempt (exponential backoff).
const BASE_RETRY_DELAY_MS: u64 = 1000;
/// Process-wide HTTP client, shared so every download reuses the same
/// connection pool and keep-alive settings. Built lazily on first use.
static GLOBAL_CLIENT: Lazy<Client> = Lazy::new(|| {
    let user_agent = format!("butterfly-dl/{}", env!("BUTTERFLY_VERSION"));
    ClientBuilder::new()
        .user_agent(user_agent)
        .connect_timeout(Duration::from_secs(10))
        .timeout(Duration::from_secs(30))
        .tcp_keepalive(Duration::from_secs(60))
        .pool_idle_timeout(Duration::from_secs(90))
        .pool_max_idle_per_host(20)
        .build()
        .expect("Failed to create HTTP client")
});
async fn retry_on_network_error<F, Fut, T>(operation: F) -> Result<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
let mut attempt = 0;
loop {
match operation().await {
Ok(result) => return Ok(result),
Err(Error::NetworkError(msg)) if attempt < MAX_RETRY_ATTEMPTS => {
attempt += 1;
let delay = BASE_RETRY_DELAY_MS * (1 << (attempt - 1)); eprintln!("⚠️ Network error (attempt {attempt}): {msg}. Retrying in {delay}ms...");
tokio::time::sleep(Duration::from_millis(delay)).await;
}
Err(e) => return Err(e), }
}
}
async fn check_overwrite_permission(file_path: &str, behavior: &OverwriteBehavior) -> Result<bool> {
if !std::path::Path::new(file_path).exists() {
return Ok(true); }
match behavior {
OverwriteBehavior::Force => {
eprintln!("⚠️ Overwriting existing file: {file_path}");
Ok(true)
}
OverwriteBehavior::NeverOverwrite => Err(Error::IoError(std::io::Error::new(
std::io::ErrorKind::AlreadyExists,
format!("File already exists: {file_path} (use --force to overwrite)"),
))),
OverwriteBehavior::Prompt => {
eprintln!("⚠️ File already exists: {file_path}");
eprint!("Overwrite? [y/N]: ");
use std::io::Write;
std::io::stderr().flush().map_err(Error::IoError)?;
let mut input = String::new();
std::io::stdin()
.read_line(&mut input)
.map_err(Error::IoError)?;
let response = input.trim().to_lowercase();
match response.as_str() {
"y" | "yes" => {
eprintln!("✅ Overwriting file");
Ok(true)
}
_ => {
eprintln!("❌ Download cancelled");
Err(Error::IoError(std::io::Error::new(
std::io::ErrorKind::Interrupted,
"Download cancelled by user",
)))
}
}
}
}
}
/// High-level download entry point.
///
/// Holds the `SourceConfig` that `resolve_source` uses to turn a
/// user-supplied source string into a concrete download source.
pub struct Downloader {
    // Passed to crate::core::source::resolve_source for every request.
    config: SourceConfig,
}
impl Default for Downloader {
    /// Equivalent to [`Downloader::new`]: a downloader with the default config.
    fn default() -> Self {
        Self::new()
    }
}
impl Downloader {
    /// Creates a downloader using the default `SourceConfig`.
    pub fn new() -> Self {
        Self {
            config: SourceConfig::default(),
        }
    }

    /// Creates a downloader with an explicit source configuration.
    pub fn with_config(config: SourceConfig) -> Self {
        Self { config }
    }

    /// Downloads `source` to `file_path`.
    ///
    /// Enforces the overwrite policy in `options` before any network I/O,
    /// resolves `source` to a concrete URL, then delegates to the HTTP path.
    pub async fn download_to_file(
        &self,
        source: &str,
        file_path: &str,
        options: &DownloadOptions,
    ) -> Result<()> {
        // Fail fast (or prompt the user) before touching the network.
        check_overwrite_permission(file_path, &options.overwrite).await?;
        let download_source = crate::core::source::resolve_source(source, &self.config)?;
        match download_source {
            DownloadSource::Http { url } => {
                self.download_http_to_file(&url, file_path, options).await
            }
        }
    }

    /// Resolves `source` and returns a byte stream plus the total size the
    /// server reported (0 when content-length is absent or unparsable).
    pub async fn download_stream(
        &self,
        source: &str,
        options: &DownloadOptions,
    ) -> Result<(DownloadStream, u64)> {
        let download_source = crate::core::source::resolve_source(source, &self.config)?;
        match download_source {
            DownloadSource::Http { url } => self.create_http_stream(&url, options).await,
        }
    }

    /// Downloads `url` into `file_path`.
    ///
    /// Probes the server with a HEAD request (retried on network errors) to
    /// learn the content length and whether byte ranges are supported, then
    /// chooses between a single resilient stream and a parallel ranged
    /// download.
    async fn download_http_to_file(
        &self,
        url: &str,
        file_path: &str,
        options: &DownloadOptions,
    ) -> Result<()> {
        let client = &*GLOBAL_CLIENT;
        let (total_size, supports_ranges) = retry_on_network_error(|| async {
            let head_response = client.head(url).send().await?;
            if !head_response.status().is_success() {
                return Err(create_helpful_http_error(url, head_response.status()));
            }
            // Unlike create_http_stream, a missing content-length is fatal
            // here: both download paths below need the exact size up front.
            let total_size = head_response
                .headers()
                .get("content-length")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<u64>().ok())
                .ok_or_else(|| Error::HttpError("Could not determine file size".to_string()))?;
            let supports_ranges = head_response
                .headers()
                .get("accept-ranges")
                .is_some_and(|v| v.to_str().unwrap_or("") == "bytes");
            Ok((total_size, supports_ranges))
        })
        .await?;
        let file = create_optimized_file(file_path, Some(total_size)).await?;
        let optimal_connections =
            calculate_optimal_connections(total_size, options.max_connections);
        // Parallel ranged download only when the server supports ranges and
        // more than one connection is worthwhile for this file size.
        if !supports_ranges || optimal_connections == 1 {
            self.download_single_resilient(
                client,
                url,
                Box::new(file),
                total_size,
                supports_ranges,
                options,
            )
            .await
        } else {
            self.download_http_parallel_resilient(client, url, Box::new(file), total_size, options)
                .await
        }
    }

    /// Builds a streaming GET for `url` (no retry logic on this path).
    ///
    /// Tolerates an unknown content-length and reports it as 0.
    async fn create_http_stream(
        &self,
        url: &str,
        _options: &DownloadOptions,
    ) -> Result<(DownloadStream, u64)> {
        let client = &*GLOBAL_CLIENT;
        let head_response = client.head(url).send().await?;
        if !head_response.status().is_success() {
            return Err(create_helpful_http_error(url, head_response.status()));
        }
        let total_size = head_response
            .headers()
            .get("content-length")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(0);
        let response = client.get(url).send().await?;
        if !response.status().is_success() {
            let status = response.status();
            return Err(Error::HttpError(format!("Failed to download: {status}")));
        }
        // Shadows the imported free function crate::core::stream::create_http_stream.
        let stream = create_http_stream(response);
        Ok((stream, total_size))
    }

    /// Downloads with a single connection, resuming via Range requests after
    /// transient network errors when the server supports byte ranges.
    ///
    /// `downloaded` tracks how many bytes have been written so far so a
    /// resumed request can start at the correct offset.
    async fn download_single_resilient(
        &self,
        client: &Client,
        url: &str,
        mut writer: Box<dyn AsyncWrite + Send + Unpin>,
        total_size: u64,
        supports_ranges: bool,
        options: &DownloadOptions,
    ) -> Result<()> {
        let mut downloaded = 0u64;
        while downloaded < total_size {
            let result = if downloaded == 0 {
                // First pass: plain GET from the start of the file.
                retry_on_network_error(|| async {
                    let response = client.get(url).send().await?;
                    let stream = create_http_stream(response);
                    Ok(stream)
                })
                .await
            } else if supports_ranges {
                // Resume: request only the bytes we have not written yet.
                retry_on_network_error(|| async {
                    let range_header = format!("bytes={downloaded}-");
                    let response = client.get(url).header("Range", range_header).send().await?;
                    let stream = create_http_stream(response);
                    Ok(stream)
                })
                .await
            } else {
                // Mid-transfer failure without range support cannot be resumed.
                return Err(Error::NetworkError(
                    "Cannot resume download - server doesn't support ranges".to_string(),
                ));
            };
            match result {
                Ok(stream) => {
                    match self
                        .stream_to_writer_resilient(
                            stream,
                            &mut writer,
                            total_size,
                            &mut downloaded,
                            options,
                        )
                        .await
                    {
                        // NOTE(review): a clean early EOF (stream ends before
                        // `total_size` bytes without a transport error) also
                        // returns Ok and breaks here, silently accepting a
                        // short file — confirm whether that is intended.
                        Ok(()) => break,
                        Err(Error::NetworkError(_)) => {
                            eprintln!("⚠️ Stream interrupted at {downloaded} bytes, resuming...");
                            continue;
                        }
                        Err(e) => return Err(e),
                    }
                }
                Err(e) => return Err(e),
            }
        }
        writer.flush().await?;
        Ok(())
    }

    /// Copies `stream` into `writer`, updating `downloaded` and invoking the
    /// progress callback after each chunk.
    ///
    /// Read failures are mapped to `Error::NetworkError` so the caller can
    /// decide to resume from the recorded offset.
    async fn stream_to_writer_resilient(
        &self,
        mut stream: DownloadStream,
        writer: &mut Box<dyn AsyncWrite + Send + Unpin>,
        total_size: u64,
        downloaded: &mut u64,
        options: &DownloadOptions,
    ) -> Result<()> {
        let mut buffer = vec![0u8; options.buffer_size];
        loop {
            let bytes_read = stream
                .read(&mut buffer)
                .await
                .map_err(|e| Error::NetworkError(format!("Stream read error: {e}")))?;
            if bytes_read == 0 {
                // End of stream.
                break;
            }
            writer.write_all(&buffer[..bytes_read]).await?;
            *downloaded += bytes_read as u64;
            if let Some(ref progress) = options.progress {
                progress(*downloaded, total_size);
            }
        }
        Ok(())
    }

    /// Downloads `url` with multiple ranged connections and reassembles the
    /// chunks in byte order into `writer`.
    ///
    /// Ranges are fetched concurrently (each retried independently on network
    /// errors); completed chunks are buffered in memory until every earlier
    /// chunk has been written, so output order always matches file order.
    async fn download_http_parallel_resilient(
        &self,
        client: &Client,
        url: &str,
        mut writer: Box<dyn AsyncWrite + Send + Unpin>,
        total_size: u64,
        options: &DownloadOptions,
    ) -> Result<()> {
        let connections = calculate_optimal_connections(total_size, options.max_connections);
        let chunk_size = total_size / connections as u64;
        // Inclusive byte ranges; the final range absorbs the remainder.
        let ranges: Vec<(u64, u64)> = (0..connections)
            .map(|i| {
                let start = i as u64 * chunk_size;
                let end = if i == connections - 1 {
                    total_size - 1
                } else {
                    start + chunk_size - 1
                };
                (start, end)
            })
            .collect();
        let downloaded_bytes = Arc::new(AtomicU64::new(0));
        // Background task reporting aggregate progress every 100ms until done;
        // aborted below once all chunks are written.
        let progress_handle = if let Some(progress_fn) = options.progress.clone() {
            let downloaded_clone = Arc::clone(&downloaded_bytes);
            Some(tokio::spawn(async move {
                while downloaded_clone.load(Ordering::Relaxed) < total_size {
                    let current = downloaded_clone.load(Ordering::Relaxed);
                    progress_fn(current, total_size);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
                progress_fn(total_size, total_size);
            }))
        } else {
            None
        };
        // Reassembly buffer: slot i holds chunk i until it can be written in order.
        let mut ring_buffer: Vec<Option<Vec<u8>>> = vec![None; ranges.len()];
        let mut next_chunk_to_write = 0;
        let stream = futures::stream::iter(ranges.into_iter().enumerate())
            .map(|(idx, (start, end))| {
                let client = client.clone();
                let url = url.to_string();
                let downloaded_bytes = Arc::clone(&downloaded_bytes);
                async move {
                    retry_on_network_error(|| async {
                        let range_header = format!("bytes={start}-{end}");
                        let response = client
                            .get(&url)
                            .header("Range", range_header)
                            .send()
                            .await?;
                        // NOTE(review): 206 is itself a success status, so the
                        // second clause is redundant — and a 200 reply (server
                        // ignoring Range and sending the whole body) passes
                        // this check and would corrupt the reassembled file.
                        // Consider requiring StatusCode::PARTIAL_CONTENT.
                        if !response.status().is_success() && response.status().as_u16() != 206 {
                            let status = response.status();
                            return Err(Error::HttpError(format!(
                                "Range request failed: {status}"
                            )));
                        }
                        let mut chunk_data = Vec::new();
                        let mut stream = response.bytes_stream();
                        while let Some(bytes_chunk) = stream.try_next().await? {
                            chunk_data.extend_from_slice(&bytes_chunk);
                            downloaded_bytes.fetch_add(bytes_chunk.len() as u64, Ordering::Relaxed);
                        }
                        Ok::<(usize, Vec<u8>), Error>((idx, chunk_data))
                    })
                    .await
                }
            })
            .buffer_unordered(connections);
        tokio::pin!(stream);
        while let Some(result) = stream.next().await {
            let (idx, data) = result?;
            ring_buffer[idx] = Some(data);
            // Flush every chunk now contiguous with what has been written.
            while next_chunk_to_write < ring_buffer.len()
                && ring_buffer[next_chunk_to_write].is_some()
            {
                if let Some(chunk) = ring_buffer[next_chunk_to_write].take() {
                    writer.write_all(&chunk).await?;
                }
                next_chunk_to_write += 1;
            }
        }
        writer.flush().await?;
        // Stop the progress reporter; it may still be sleeping in its loop.
        if let Some(handle) = progress_handle {
            handle.abort();
        }
        Ok(())
    }
}
/// Picks a connection count for downloading `file_size` bytes.
///
/// Scales with file size (1 connection up to 1 MiB, stepping to 16 above
/// 1 GiB), then caps at both the caller-supplied `max_connections` and
/// twice the CPU count.
///
/// Fix: clamps the result to at least 1. Previously `max_connections == 0`
/// produced 0 connections, which the parallel path divides by when
/// computing per-connection chunk sizes (divide-by-zero panic).
fn calculate_optimal_connections(file_size: u64, max_connections: usize) -> usize {
    let cpu_count = num_cpus::get();
    let base_connections = match file_size {
        size if size <= 1024 * 1024 => 1,         // <= 1 MiB
        size if size <= 10 * 1024 * 1024 => 2,    // <= 10 MiB
        size if size <= 100 * 1024 * 1024 => 4,   // <= 100 MiB
        size if size <= 512 * 1024 * 1024 => 8,   // <= 512 MiB
        size if size <= 1024 * 1024 * 1024 => 12, // <= 1 GiB
        _ => 16,
    };
    base_connections
        .min(max_connections)
        .min(cpu_count * 2)
        .max(1) // never return 0 — callers use this as a divisor
}
/// Creates the destination file, optionally via a platform-specific fast path.
///
/// On Linux, files expected to exceed 1 GiB are opened with `O_DIRECT` to
/// bypass the page cache; if that open fails for any reason the code falls
/// back to a plain `tokio::fs::File::create`.
///
/// NOTE(review): `O_DIRECT` imposes alignment requirements on buffer
/// address, length, and file offset — the arbitrary-sized async writes used
/// by the download paths are unlikely to satisfy them, so writes through an
/// `O_DIRECT` handle may fail with EINVAL. Confirm this path is actually
/// exercised and correct for files over 1 GiB.
async fn create_optimized_file(path: &str, _size_hint: Option<u64>) -> Result<tokio::fs::File> {
    #[cfg(unix)]
    {
        #[cfg(target_os = "linux")]
        {
            use std::os::unix::fs::OpenOptionsExt;
            if let Some(size) = _size_hint {
                // Only bother with O_DIRECT for very large (> 1 GiB) files.
                if size > 1024 * 1024 * 1024 {
                    match std::fs::OpenOptions::new()
                        .write(true)
                        .create(true)
                        .truncate(true)
                        .custom_flags(libc::O_DIRECT)
                        .open(path)
                    {
                        Ok(file) => {
                            return Ok(tokio::fs::File::from_std(file));
                        }
                        Err(_) => {
                            // Fall through to the standard create below.
                        }
                    }
                }
            }
        }
    }
    // Default path for all platforms (and the Linux fallback).
    tokio::fs::File::create(path).await.map_err(Into::into)
}
fn create_helpful_http_error(url: &str, status: reqwest::StatusCode) -> Error {
let mut message = format!("Failed to get file info: {status}");
if status == reqwest::StatusCode::NOT_FOUND {
let source = if url.contains("planet.openstreetmap.org") {
Some("planet".to_string())
} else if url.contains("download.geofabrik.de") {
url.split("download.geofabrik.de/")
.nth(1)
.and_then(|after_domain| after_domain.strip_suffix("-latest.osm.pbf"))
.map(|s| s.to_string())
} else {
None
};
if let Some(source) = source {
if let Some(suggestion) = butterfly_common::error::suggest_correction(&source) {
message = format!("Source '{source}' not found. Did you mean '{suggestion}'?");
} else {
message = format!(
"Source '{source}' not found. Check the URL or try common sources like: planet, europe, asia"
);
}
} else {
message = format!(
"Source not found ({status}): {url}. Check the URL or try common sources like: planet, europe, asia"
);
}
}
Error::HttpError(message)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tempfile::NamedTempFile;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Connection count scales with file size, capped at twice the CPU count.
    #[test]
    fn test_calculate_optimal_connections() {
        let cpu_count = num_cpus::get();
        assert_eq!(calculate_optimal_connections(512 * 1024, 16), 1);
        assert_eq!(calculate_optimal_connections(5 * 1024 * 1024, 16), 2);
        assert_eq!(calculate_optimal_connections(50 * 1024 * 1024, 16), 4);
        assert_eq!(
            calculate_optimal_connections(200 * 1024 * 1024, 16),
            std::cmp::min(8, cpu_count * 2)
        );
        assert_eq!(
            calculate_optimal_connections(2 * 1024 * 1024 * 1024, 16),
            std::cmp::min(16, cpu_count * 2)
        );
    }

    // The caller-supplied max_connections caps the computed count.
    #[test]
    fn test_calculate_optimal_connections_with_limit() {
        let result = calculate_optimal_connections(2 * 1024 * 1024 * 1024, 8);
        let cpu_count = num_cpus::get();
        let expected = std::cmp::min(8, cpu_count * 2);
        assert_eq!(result, expected);
    }

    // Files at or below 1 MiB use one connection; just above the threshold, two.
    #[test]
    fn test_calculate_optimal_connections_small_files() {
        assert_eq!(calculate_optimal_connections(100 * 1024, 16), 1);
        assert_eq!(calculate_optimal_connections(500 * 1024, 16), 1);
        assert_eq!(calculate_optimal_connections(1024 * 1024, 16), 1);
        assert_eq!(calculate_optimal_connections(1024 * 1024 + 1, 16), 2);
    }

    // End-to-end happy path against a mock server: exactly one HEAD probe and
    // one GET, with the written file matching the served bytes.
    #[tokio::test]
    async fn test_resilient_download_with_network_failure() {
        let mock_server = MockServer::start().await;
        let test_data = b"A".repeat(1024);
        let total_size = test_data.len() as u64;
        // Counters to verify how many requests the downloader actually made.
        let head_call_count = Arc::new(AtomicUsize::new(0));
        let get_call_count = Arc::new(AtomicUsize::new(0));
        let head_count_clone = Arc::clone(&head_call_count);
        // HEAD advertises the size and range support.
        Mock::given(method("HEAD"))
            .and(path("/test-file.pbf"))
            .respond_with(move |_: &wiremock::Request| {
                head_count_clone.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(200)
                    .insert_header("content-length", total_size.to_string().as_str())
                    .insert_header("accept-ranges", "bytes")
            })
            .mount(&mock_server)
            .await;
        let get_count_clone = Arc::clone(&get_call_count);
        let test_data_clone = test_data.clone();
        // GET serves the full body.
        Mock::given(method("GET"))
            .and(path("/test-file.pbf"))
            .respond_with(move |_req: &wiremock::Request| {
                let call_num = get_count_clone.fetch_add(1, Ordering::SeqCst) + 1;
                println!("GET call #{call_num}");
                println!("Call {call_num} - returning full data");
                ResponseTemplate::new(200)
                    .insert_header("content-length", total_size.to_string().as_str())
                    .set_body_raw(test_data_clone.clone(), "application/octet-stream")
            })
            .mount(&mock_server)
            .await;
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_str().unwrap();
        let downloader = Downloader::new();
        let options = DownloadOptions::default();
        let base_uri = mock_server.uri();
        let url = format!("{base_uri}/test-file.pbf");
        let result = downloader
            .download_http_to_file(&url, file_path, &options)
            .await;
        assert!(result.is_ok(), "Download should succeed: {result:?}");
        let downloaded_data = std::fs::read(file_path).unwrap();
        assert_eq!(
            downloaded_data, test_data,
            "Downloaded file should match original data"
        );
        let head_calls = head_call_count.load(Ordering::SeqCst);
        let get_calls = get_call_count.load(Ordering::SeqCst);
        println!("HEAD calls: {head_calls}, GET calls: {get_calls}");
        assert_eq!(head_calls, 1, "Should have made 1 HEAD request");
        assert_eq!(get_calls, 1, "Should have made 1 GET request");
        println!("✅ Basic download test passed! Made {head_calls} HEAD and {get_calls} GET calls");
    }

    // Two simulated network failures, then success: three calls total, and the
    // 1s + 2s backoff delays must actually elapse.
    #[tokio::test]
    async fn test_retry_exponential_backoff() {
        use std::time::Instant;
        let start_time = Instant::now();
        let call_count = Arc::new(AtomicUsize::new(0));
        let result = retry_on_network_error(|| {
            let count_clone = Arc::clone(&call_count);
            async move {
                let call_num = count_clone.fetch_add(1, Ordering::SeqCst) + 1;
                if call_num <= 2 {
                    Err(Error::NetworkError("Simulated network failure".to_string()))
                } else {
                    Ok("success")
                }
            }
        })
        .await;
        let elapsed = start_time.elapsed();
        let calls = call_count.load(Ordering::SeqCst);
        assert!(result.is_ok());
        assert_eq!(calls, 3);
        assert!(
            elapsed >= Duration::from_secs(3),
            "Should implement exponential backoff delays"
        );
        println!("✅ Exponential backoff test passed! {calls} calls in {elapsed:?}");
    }

    // Force policy allows overwriting an existing file.
    #[tokio::test]
    async fn test_overwrite_behavior_force() {
        use crate::core::stream::OverwriteBehavior;
        use tempfile::NamedTempFile;
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_str().unwrap();
        std::fs::write(file_path, "existing content").unwrap();
        assert!(std::path::Path::new(file_path).exists());
        let result = check_overwrite_permission(file_path, &OverwriteBehavior::Force).await;
        assert!(result.is_ok(), "Force overwrite should succeed");
        assert!(result.unwrap(), "Force overwrite should return true");
        println!("✅ Force overwrite test passed!");
    }

    // NeverOverwrite policy errors with AlreadyExists when the file is present.
    #[tokio::test]
    async fn test_overwrite_behavior_never() {
        use crate::core::stream::OverwriteBehavior;
        use tempfile::NamedTempFile;
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_str().unwrap();
        std::fs::write(file_path, "existing content").unwrap();
        assert!(std::path::Path::new(file_path).exists());
        let result =
            check_overwrite_permission(file_path, &OverwriteBehavior::NeverOverwrite).await;
        assert!(
            result.is_err(),
            "Never overwrite should fail when file exists"
        );
        let error = result.unwrap_err();
        match error {
            Error::IoError(io_err) => {
                assert_eq!(io_err.kind(), std::io::ErrorKind::AlreadyExists);
                assert!(io_err.to_string().contains("use --force to overwrite"));
            }
            _ => panic!("Expected IoError with AlreadyExists kind"),
        }
        println!("✅ Never overwrite test passed!");
    }

    // Every policy (including Prompt) permits writing a file that doesn't exist,
    // without touching stdin.
    #[tokio::test]
    async fn test_overwrite_behavior_new_file() {
        use crate::core::stream::OverwriteBehavior;
        use tempfile::tempdir;
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("nonexistent.pbf");
        let file_path_str = file_path.to_str().unwrap();
        assert!(!std::path::Path::new(file_path_str).exists());
        for behavior in [
            OverwriteBehavior::Force,
            OverwriteBehavior::NeverOverwrite,
            OverwriteBehavior::Prompt,
        ] {
            let result = check_overwrite_permission(file_path_str, &behavior).await;
            assert!(
                result.is_ok(),
                "All behaviors should succeed for non-existent file"
            );
            assert!(
                result.unwrap(),
                "All behaviors should return true for non-existent file"
            );
        }
        println!("✅ New file test passed!");
    }
}