use bytes::Bytes;
use futures::stream::{self, StreamExt};
use reqwest::{Client, Response};
use sha2::{Sha256, Digest};
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tokio::sync::Semaphore;
use url::Url;
use crate::types::{
DownloadConfig, DownloadResult, MediaError, MediaResult, MediaType,
};
/// HTTP media downloader with a bounded number of concurrent downloads.
///
/// Clones share the same `semaphore` (it is held behind an `Arc`), so the
/// concurrency limit applies across all clones of one downloader.
#[derive(Debug, Clone)]
pub struct MediaDownloader {
    /// HTTP client used for every request.
    client: Client,
    /// Concurrency, timeout, retry, size-limit and encoding options.
    config: DownloadConfig,
    /// Bounds how many downloads are in flight at once; shared by all clones.
    semaphore: Arc<Semaphore>,
}
impl Default for MediaDownloader {
fn default() -> Self {
Self::new(DownloadConfig::default())
}
}
impl MediaDownloader {
pub fn new(config: DownloadConfig) -> Self {
let client = Client::builder()
.timeout(Duration::from_secs(config.timeout_secs))
.user_agent(&config.user_agent)
.build()
.unwrap_or_default();
let semaphore = Arc::new(Semaphore::new(config.max_concurrent));
Self { client, config, semaphore }
}
pub fn with_client(client: Client, config: DownloadConfig) -> Self {
let semaphore = Arc::new(Semaphore::new(config.max_concurrent));
Self { client, config, semaphore }
}
pub async fn download(&self, url: &str) -> MediaResult<DownloadResult> {
let _permit = self.semaphore.acquire().await
.map_err(|e| MediaError::Download(e.to_string()))?;
self.download_with_retry(url).await
}
async fn download_with_retry(&self, url: &str) -> MediaResult<DownloadResult> {
let mut last_error = None;
for attempt in 0..=self.config.max_retries {
if attempt > 0 {
let delay = Duration::from_millis(self.config.retry_delay_ms * (1 << (attempt - 1)));
tokio::time::sleep(delay).await;
}
match self.do_download(url).await {
Ok(result) => return Ok(result),
Err(e) => {
last_error = Some(e);
}
}
}
Err(last_error.unwrap_or_else(|| MediaError::Download("Unknown error".to_string())))
}
async fn do_download(&self, url: &str) -> MediaResult<DownloadResult> {
let response = self.client.get(url)
.send()
.await
.map_err(|e| MediaError::Network(e.to_string()))?;
if !response.status().is_success() {
return Err(MediaError::Http(response.status().as_u16(), response.status().to_string()));
}
let content_type = response.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.map(|s| s.split(';').next().unwrap_or(s).to_string());
let content_length = response.headers()
.get("content-length")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse().ok());
if let Some(max_size) = self.config.max_file_size {
if let Some(size) = content_length {
if size > max_size {
return Err(MediaError::FileTooLarge(size, max_size));
}
}
}
let bytes = self.download_bytes(response).await?;
if let Some(max_size) = self.config.max_file_size {
if bytes.len() as u64 > max_size {
return Err(MediaError::FileTooLarge(bytes.len() as u64, max_size));
}
}
let hash = compute_sha256(&bytes);
let media_type = detect_media_type(&content_type, url);
let base64 = if self.config.encode_base64 {
use base64::Engine;
Some(base64::engine::general_purpose::STANDARD.encode(&bytes))
} else {
None
};
Ok(DownloadResult {
url: url.to_string(),
bytes,
content_type,
size: content_length.unwrap_or(0),
hash,
media_type,
base64,
})
}
async fn download_bytes(&self, response: Response) -> MediaResult<Bytes> {
response.bytes()
.await
.map_err(|e| MediaError::Download(e.to_string()))
}
pub async fn download_many(&self, urls: &[String]) -> Vec<MediaResult<DownloadResult>> {
stream::iter(urls)
.map(|url| {
let downloader = self.clone();
async move {
downloader.download(url).await
}
})
.buffer_unordered(self.config.max_concurrent)
.collect()
.await
}
pub async fn download_to_file(&self, url: &str, path: &Path) -> MediaResult<DownloadResult> {
let result = self.download(url).await?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.await
.map_err(|e| MediaError::Io(e.to_string()))?;
}
let mut file = fs::File::create(path)
.await
.map_err(|e| MediaError::Io(e.to_string()))?;
file.write_all(&result.bytes)
.await
.map_err(|e| MediaError::Io(e.to_string()))?;
Ok(result)
}
pub async fn download_many_to_dir(
&self,
urls: &[String],
dir: &Path,
) -> Vec<MediaResult<(String, std::path::PathBuf)>> {
stream::iter(urls)
.map(|url| {
let downloader = self.clone();
let dir = dir.to_path_buf();
async move {
let filename = url_to_filename(url);
let path = dir.join(&filename);
downloader.download_to_file(url, &path)
.await
.map(|_| (url.clone(), path))
}
})
.buffer_unordered(self.config.max_concurrent)
.collect()
.await
}
}
/// Returns the SHA-256 digest of `data` formatted as lowercase hex.
pub fn compute_sha256(data: &[u8]) -> String {
    let digest = Sha256::digest(data);
    format!("{:x}", digest)
}
/// Classifies a resource as image, video, audio, document or other.
///
/// A recognised `content_type` wins and is matched case-insensitively, since
/// MIME types are case-insensitive. Otherwise the extension of the last path
/// segment of `url` is used, ignoring any query string (`?...`) or fragment
/// (`#...`), so `photo.png?w=100` is still an image.
pub fn detect_media_type(content_type: &Option<String>, url: &str) -> MediaType {
    if let Some(ct) = content_type {
        let ct = ct.trim().to_ascii_lowercase();
        if ct.starts_with("image/") {
            return MediaType::Image;
        }
        if ct.starts_with("video/") {
            return MediaType::Video;
        }
        if ct.starts_with("audio/") {
            return MediaType::Audio;
        }
        if ct.contains("pdf")
            || ct.contains("document")
            || ct.contains("spreadsheet")
            || ct.contains("presentation")
        {
            return MediaType::Document;
        }
    }
    match url_path_extension(url).as_deref() {
        Some("jpg" | "jpeg" | "png" | "gif" | "webp" | "svg" | "avif") => MediaType::Image,
        Some("mp4" | "webm" | "avi" | "mov" | "mkv") => MediaType::Video,
        Some("mp3" | "wav" | "ogg" | "flac" | "aac") => MediaType::Audio,
        Some("pdf" | "doc" | "docx" | "xls" | "xlsx" | "ppt" | "pptx") => MediaType::Document,
        _ => MediaType::Other,
    }
}

/// Lowercased extension of the last path segment of `url`, with any query
/// string or fragment removed first; `None` if the segment has no dot.
fn url_path_extension(url: &str) -> Option<String> {
    let without_suffix = url.split(|c| c == '?' || c == '#').next().unwrap_or(url);
    let last_segment = without_suffix.rsplit('/').next().unwrap_or(without_suffix);
    last_segment
        .rsplit_once('.')
        .map(|(_, extension)| extension.to_lowercase())
}
/// Derives a safe local file name from `url`.
///
/// Uses the percent-decoded, sanitized last segment of the URL path. Falls
/// back to `download_<first 12 hex chars of SHA-256(url)>` in these cases:
/// the URL does not parse, the last segment is empty, or the sanitized name
/// would be `.` or `..`. Those two names resolve to a directory when joined
/// onto a path; opaque URLs like `x:..` keep such a segment.
pub fn url_to_filename(url: &str) -> String {
    if let Ok(parsed) = Url::parse(url) {
        // `rsplit` always yields at least one (possibly empty) segment.
        let last_segment = parsed.path().rsplit('/').next().unwrap_or("");
        if !last_segment.is_empty() {
            let name = sanitize_filename(last_segment);
            if !name.is_empty() && name != "." && name != ".." {
                return name;
            }
        }
    }
    let hash = compute_sha256(url.as_bytes());
    format!("download_{}", &hash[..12])
}
/// Percent-decodes `name` and replaces every character that is not
/// alphanumeric, `.`, `-` or `_` with `_`. Falls back to the raw `name` when
/// percent-decoding fails.
fn sanitize_filename(name: &str) -> String {
    let decoded = match urlencoding::decode(name) {
        Ok(text) => text,
        Err(_) => name.into(),
    };
    let is_safe = |ch: char| ch.is_alphanumeric() || matches!(ch, '.' | '-' | '_');
    decoded
        .chars()
        .map(|ch| if is_safe(ch) { ch } else { '_' })
        .collect()
}
/// Returns `true` if `url` starts with `http://` or `https://` (case-insensitively).
///
/// Everything else, including `data:` and `javascript:` URLs and relative
/// paths, is rejected.
pub fn is_downloadable(url: &str) -> bool {
    const SCHEMES: [&str; 2] = ["http://", "https://"];
    SCHEMES.iter().any(|scheme| {
        url.get(..scheme.len())
            .map_or(false, |head| head.eq_ignore_ascii_case(scheme))
    })
}
/// Reads a size hint from the `size`, `s` or `bytes` query parameter.
///
/// Returns the first such parameter whose value parses as `u64`, or `None`
/// if the URL does not parse or no parameter qualifies.
pub fn estimate_size_from_url(url: &str) -> Option<u64> {
    let parsed = Url::parse(url).ok()?;
    parsed.query_pairs().find_map(|(key, value)| match &*key {
        "size" | "s" | "bytes" => value.parse::<u64>().ok(),
        _ => None,
    })
}
/// Downloads `url` with a default-configured downloader and returns the body.
pub async fn download_bytes(url: &str) -> MediaResult<Bytes> {
    MediaDownloader::default()
        .download(url)
        .await
        .map(|result| result.bytes)
}
pub async fn download_with_hash(url: &str) -> MediaResult<(Bytes, String)> {
let downloader = MediaDownloader::default();
let result = downloader.download(url).await?;
Ok((result.bytes, result.hash))
}
pub async fn download_to_base64(url: &str) -> MediaResult<String> {
let config = DownloadConfig {
encode_base64: true,
..Default::default()
};
let downloader = MediaDownloader::new(config);
let result = downloader.download(url).await?;
result.base64.ok_or_else(|| MediaError::Download("Base64 encoding failed".to_string()))
}
/// Downloads `url` with a default-configured downloader and writes it to `path`.
pub async fn save_to_file(url: &str, path: &Path) -> MediaResult<()> {
    MediaDownloader::default()
        .download_to_file(url, path)
        .await
        .map(|_| ())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_sha256() {
        // Known-answer vectors: pin the digest itself, not just its length.
        assert_eq!(
            compute_sha256(b"Hello, World!"),
            "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"
        );
        assert_eq!(
            compute_sha256(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(compute_sha256(b"Hello, World!").len(), 64);
    }

    #[test]
    fn test_detect_media_type_from_content_type() {
        assert_eq!(detect_media_type(&Some("image/png".to_string()), ""), MediaType::Image);
        assert_eq!(detect_media_type(&Some("video/mp4".to_string()), ""), MediaType::Video);
        assert_eq!(detect_media_type(&Some("audio/mpeg".to_string()), ""), MediaType::Audio);
        assert_eq!(
            detect_media_type(&Some("application/pdf".to_string()), ""),
            MediaType::Document
        );
    }

    #[test]
    fn test_detect_media_type_from_url() {
        assert_eq!(detect_media_type(&None, "https://example.com/image.png"), MediaType::Image);
        assert_eq!(detect_media_type(&None, "https://example.com/video.mp4"), MediaType::Video);
        assert_eq!(detect_media_type(&None, "https://example.com/audio.mp3"), MediaType::Audio);
        assert_eq!(detect_media_type(&None, "https://example.com/doc.pdf"), MediaType::Document);
        assert_eq!(detect_media_type(&None, "https://example.com/unknown"), MediaType::Other);
    }

    #[test]
    fn test_url_to_filename() {
        assert_eq!(url_to_filename("https://example.com/images/photo.jpg"), "photo.jpg");
        assert_eq!(url_to_filename("https://example.com/file%20name.pdf"), "file_name.pdf");
        assert!(url_to_filename("https://example.com/").starts_with("download_"));
    }

    #[test]
    fn test_sanitize_filename() {
        assert_eq!(sanitize_filename("file.txt"), "file.txt");
        assert_eq!(sanitize_filename("file name.txt"), "file_name.txt");
        assert_eq!(sanitize_filename("file<>:\"/\\|?*.txt"), "file_________.txt");
    }

    #[test]
    fn test_is_downloadable() {
        assert!(is_downloadable("https://example.com/file.jpg"));
        assert!(is_downloadable("http://example.com/file.pdf"));
        assert!(!is_downloadable("data:image/png;base64,abc"));
        assert!(!is_downloadable("javascript:void(0)"));
        assert!(!is_downloadable("/relative/path"));
    }

    #[test]
    fn test_download_config_default() {
        let config = DownloadConfig::default();
        assert!(config.max_concurrent > 0);
        assert!(config.timeout_secs > 0);
    }

    #[test]
    fn test_downloader_creation() {
        let downloader = MediaDownloader::default();
        assert!(downloader.config.max_concurrent > 0);
    }

    #[test]
    fn test_downloader_with_config() {
        let config = DownloadConfig {
            max_concurrent: 10,
            timeout_secs: 60,
            max_retries: 5,
            ..Default::default()
        };
        let downloader = MediaDownloader::new(config);
        assert_eq!(downloader.config.max_concurrent, 10);
        assert_eq!(downloader.config.timeout_secs, 60);
        assert_eq!(downloader.config.max_retries, 5);
    }

    #[test]
    fn test_estimate_size_from_url() {
        assert_eq!(estimate_size_from_url("https://example.com/file?size=1024"), Some(1024));
        assert_eq!(estimate_size_from_url("https://example.com/file"), None);
    }

    #[tokio::test]
    async fn test_download_invalid_url() {
        // No retries: an unparseable URL never becomes valid, so retrying
        // would only add backoff sleeps to the test run.
        let downloader = MediaDownloader::new(DownloadConfig {
            max_retries: 0,
            ..Default::default()
        });
        let result = downloader.download("not-a-valid-url").await;
        assert!(result.is_err());
    }

    #[test]
    fn test_media_type_detection_priority() {
        assert_eq!(
            detect_media_type(&Some("video/mp4".to_string()), "https://example.com/image.png"),
            MediaType::Video
        );
    }
}