use std::cmp::min;
use std::fmt;
use std::path::PathBuf;
use std::sync::Arc;
use reqwest::header::{HeaderMap, HeaderValue};
use crate::client::proxy::ProxyConfig;
use crate::download::config::speed_profile::SpeedProfile;
use crate::download::types::ProgressCallback;
use crate::error::{Error, Result};
use crate::model::format::HttpHeaders;
use crate::utils::fs;
use crate::utils::retry::RetryPolicy;
/// Default number of segments downloaded concurrently per asset.
const DEFAULT_PARALLEL_SEGMENTS: usize = 4;
/// Default size of a single download segment, in bytes (5 MiB).
const DEFAULT_SEGMENT_SIZE: usize = 5 * 1024 * 1024;
/// Default number of retry attempts for a failed segment.
const DEFAULT_RETRY_ATTEMPTS: usize = 3;
/// HTTP fetcher that downloads assets in parallel byte-range segments,
/// with optional retry, throttling, and progress reporting.
pub struct Fetcher {
/// Target URL to fetch.
pub(super) url: String,
/// Number of segments downloaded concurrently (default 4).
pub(super) parallel_segments: usize,
/// Size of each segment in bytes (default 5 MiB).
pub(super) segment_size: usize,
/// Retry attempts per segment (default 3).
pub(super) retry_attempts: usize,
/// Backoff/retry policy applied on segment failures.
pub(super) retry_policy: RetryPolicy,
/// Shared HTTP client used for all requests.
pub(super) client: Arc<reqwest::Client>,
/// Extra headers set by `with_client_and_headers`; NOTE(review): not
/// attached by `fetch_internal`, presumably applied to segment requests
/// elsewhere — confirm.
pub(super) extra_headers: Option<reqwest::header::HeaderMap>,
/// Progress callback `Fn(u64, u64)`; presumably (downloaded, total) —
/// confirm against the segment runner.
pub(super) progress_callback: Option<ProgressCallback>,
/// Speed/throttling profile for the download.
pub(super) speed_profile: SpeedProfile,
/// Optional inclusive byte range `(start, end)`; when set, `fetch_asset`
/// delegates to `fetch_asset_range`.
pub(super) range_constraint: Option<(u64, u64)>,
}
impl fmt::Debug for Fetcher {
    /// Debug view of the fetcher configuration.
    ///
    /// `client`, `extra_headers`, and `retry_policy` are intentionally
    /// omitted (noisy or non-`Debug`); `finish_non_exhaustive` renders a
    /// trailing `..` so the omission is explicit in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Fetcher")
            .field("url", &self.url)
            .field("parallel_segments", &self.parallel_segments)
            .field("segment_size", &self.segment_size)
            .field("retry_attempts", &self.retry_attempts)
            .field("speed_profile", &self.speed_profile)
            .field("range_constraint", &self.range_constraint)
            .field("has_callback", &self.progress_callback.is_some())
            .finish_non_exhaustive()
    }
}
impl fmt::Display for Fetcher {
    /// One-line human-readable summary of the fetcher configuration.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            url,
            parallel_segments,
            speed_profile,
            range_constraint,
            ..
        } = self;
        write!(f, "Fetcher(url={url}, segments={parallel_segments}, profile={speed_profile}, range={range_constraint:?})")
    }
}
/// Drop guard for a partially-downloaded file: deletes the file at `path`
/// on drop unless `commit` was called, so aborted downloads leave no
/// stale partial files behind.
#[must_use = "dropping the guard immediately deletes the file it protects"]
pub(super) struct PartsGuard {
    // File removed on drop when not committed.
    path: PathBuf,
    // Set by `commit`; suppresses deletion in `Drop`.
    keep: bool,
}
impl PartsGuard {
    /// Starts guarding `path`; the file will be deleted when the guard is
    /// dropped, until `commit` is called.
    pub(super) fn new(path: PathBuf) -> Self {
        Self { path, keep: false }
    }
    /// Marks the download as complete so the file survives the guard.
    pub(super) fn commit(&mut self) {
        self.keep = true;
    }
}
impl Drop for PartsGuard {
    fn drop(&mut self) {
        if !self.keep {
            // Best-effort cleanup: a removal failure (e.g. the file was
            // never created) is deliberately ignored.
            let _ = std::fs::remove_file(&self.path);
        }
    }
}
impl Fetcher {
pub fn new(url: impl AsRef<str>, proxy: Option<&ProxyConfig>, http_headers: Option<HttpHeaders>) -> Result<Self> {
tracing::debug!(
url = %url.as_ref(),
has_proxy = proxy.is_some(),
has_headers = http_headers.is_some(),
"⚙️ Creating fetcher"
);
let (user_agent, default_headers) = match &http_headers {
Some(headers) => (Some(headers.user_agent.clone()), Some(headers.to_header_map())),
None => (None, None),
};
let client = crate::utils::http::build_http_client(crate::utils::http::HttpClientConfig {
proxy,
user_agent,
default_headers,
http2_adaptive_window: true,
..Default::default()
})?;
Ok(Self::with_client(url, client))
}
pub fn with_client(url: impl AsRef<str>, client: Arc<reqwest::Client>) -> Self {
tracing::debug!(
url = %url.as_ref(),
"⚙️ Creating fetcher with custom client"
);
Self {
url: url.as_ref().to_string(),
parallel_segments: DEFAULT_PARALLEL_SEGMENTS,
segment_size: DEFAULT_SEGMENT_SIZE,
retry_attempts: DEFAULT_RETRY_ATTEMPTS,
retry_policy: RetryPolicy::default(),
client,
extra_headers: None,
progress_callback: None,
speed_profile: SpeedProfile::default(),
range_constraint: None,
}
}
pub fn with_client_and_headers(
url: impl AsRef<str>,
client: Arc<reqwest::Client>,
headers: crate::model::format::HttpHeaders,
) -> Self {
let mut header_map = headers.to_header_map();
if let Ok(ua) = reqwest::header::HeaderValue::from_str(&headers.user_agent) {
header_map.insert(reqwest::header::USER_AGENT, ua);
}
let mut fetcher = Self::with_client(url, client);
fetcher.extra_headers = Some(header_map);
fetcher
}
pub fn with_parallel_segments(mut self, segments: usize) -> Self {
tracing::debug!(
segments = segments,
url = %self.url,
"⚙️ Configuring parallel segments for fetcher"
);
self.parallel_segments = segments;
self
}
pub fn with_segment_size(mut self, size: usize) -> Self {
self.segment_size = size;
self
}
pub fn with_retry_attempts(mut self, attempts: usize) -> Self {
self.retry_attempts = attempts;
self
}
pub fn with_progress_callback<F>(mut self, callback: F) -> Self
where
F: Fn(u64, u64) + Send + Sync + 'static,
{
self.progress_callback = Some(Arc::new(callback));
self
}
pub fn with_speed_profile(mut self, profile: SpeedProfile) -> Self {
self.speed_profile = profile;
self
}
pub fn with_range(mut self, start: u64, end: u64) -> Self {
self.range_constraint = Some((start, end));
self
}
pub async fn fetch_json(&self, auth_token: Option<String>) -> Result<serde_json::Value> {
let response = self.fetch_internal(auth_token).await?;
let json = response.json().await?;
Ok(json)
}
pub async fn fetch_text(&self, auth_token: Option<String>) -> Result<String> {
let response = self.fetch_internal(auth_token).await?;
let text = response.text().await?;
Ok(text)
}
async fn fetch_internal(&self, auth_token: Option<String>) -> Result<reqwest::Response> {
tracing::debug!(
url = %self.url,
has_token = auth_token.is_some(),
"📥 Fetching data"
);
let mut headers = HeaderMap::new();
if let Some(auth_token) = auth_token {
let value = HeaderValue::from_str(&format!("Bearer {}", auth_token)).map_err(|e| Error::InvalidHeader {
header: "Authorization".to_string(),
reason: e.to_string(),
})?;
headers.insert(reqwest::header::AUTHORIZATION, value);
}
let response = self
.client
.get(&self.url)
.headers(headers)
.send()
.await?
.error_for_status()?;
Ok(response)
}
pub async fn fetch_asset(&self, destination: impl Into<PathBuf>) -> Result<()> {
let destination: PathBuf = destination.into();
if let Some((start, end)) = self.range_constraint {
return self.fetch_asset_range(destination, start, end).await;
}
tracing::debug!(
url = %self.url,
destination = ?destination,
parallel_segments = self.parallel_segments,
segment_size = self.segment_size,
"📥 Fetching asset to file"
);
fs::create_parent_dir(&destination).await?;
if let Some(parent) = destination.parent()
&& !parent.exists()
{
tokio::fs::create_dir_all(parent).await?;
}
let file_exists = destination.as_path().exists();
let file_size = if file_exists {
match tokio::fs::metadata(&destination).await {
Ok(metadata) => Some(metadata.len()),
Err(_) => None,
}
} else {
None
};
let (supports_ranges, content_length) = self.probe_range_support().await?;
if !supports_ranges {
return self.fetch_asset_simple(destination).await;
}
let Some(content_length) = content_length else {
return self.fetch_asset_simple(destination).await;
};
if file_size.is_some_and(|size| size == content_length) {
tracing::debug!(
destination = ?destination,
size = content_length,
"✅ File already exists with correct size, skipping download"
);
return Ok(());
}
let file = Arc::new(self.open_download_file(&destination, file_size, content_length).await?);
let segment_size = self.segment_size as u64;
let ranges: Vec<(u64, u64)> = (0..content_length.div_ceil(segment_size))
.map(|i| {
let start = i * segment_size;
let end = min(start + segment_size - 1, content_length - 1);
(start, end)
})
.collect();
self.run_parallel_segments(file, file_exists, ranges, 0, content_length, &destination)
.await
}
pub(crate) async fn fetch_asset_range(
&self,
destination: impl Into<PathBuf>,
byte_start: u64,
byte_end: u64,
) -> Result<()> {
let destination: PathBuf = destination.into();
let range_len = byte_end - byte_start + 1;
tracing::debug!(
url = %self.url,
destination = ?destination,
byte_start,
byte_end,
range_len,
parallel_segments = self.parallel_segments,
segment_size = self.segment_size,
"📥 Fetching asset range to file"
);
fs::create_parent_dir(&destination).await?;
if let Some(parent) = destination.parent()
&& !parent.exists()
{
tokio::fs::create_dir_all(parent).await?;
}
let file_exists = destination.as_path().exists();
let file_size = if file_exists {
tokio::fs::metadata(&destination).await.ok().map(|m| m.len())
} else {
None
};
if file_size.is_some_and(|size| size == range_len) {
tracing::debug!(
destination = ?destination,
size = range_len,
"✅ Range already downloaded with correct size, skipping"
);
return Ok(());
}
let file = Arc::new(self.open_download_file(&destination, file_size, range_len).await?);
let segment_size = self.segment_size as u64;
let ranges: Vec<(u64, u64)> = (0..range_len.div_ceil(segment_size))
.map(|i| {
let seg_start = byte_start + i * segment_size;
let seg_end = min(seg_start + segment_size - 1, byte_end);
(seg_start, seg_end)
})
.collect();
self.run_parallel_segments(file, file_exists, ranges, byte_start, range_len, &destination)
.await?;
tracing::debug!(
byte_start,
byte_end,
destination = ?destination,
"✅ Asset range downloaded"
);
Ok(())
}
}