#![allow(unused)]
use crate::constants::DEFAULT_USER_AGENT;
use anyhow::{Result, anyhow};
use reqwest::{Client, header};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
#[cfg(feature = "hickory_dns")]
use std::sync::{Arc, OnceLock};
/// Process-wide cache for the hickory DNS resolver.
///
/// Starts empty and is filled at most once by `get_or_init_dns_resolver`; every later
/// call hands out clones of the stored `Arc`. Only compiled when the `hickory_dns`
/// feature is enabled.
#[cfg(feature = "hickory_dns")]
static GLOBAL_DNS_RESOLVER: OnceLock<
Arc<hickory_resolver::Resolver<hickory_resolver::name_server::TokioConnectionProvider>>,
> = OnceLock::new();
/// Returns the process-wide DNS resolver, building and caching it on first use.
///
/// If the `PAYLOAD_DUMPER_CUSTOM_DNS` environment variable is set, it is read as a
/// comma-separated list of IP addresses and each address becomes a UDP name server on
/// port 53. Otherwise hickory's built-in Cloudflare configuration is used.
///
/// # Errors
///
/// Fails when `PAYLOAD_DUMPER_CUSTOM_DNS` contains an entry that is not a valid IP
/// address, or when it is set but contains no addresses at all.
#[cfg(feature = "hickory_dns")]
async fn get_or_init_dns_resolver()
-> Result<Arc<hickory_resolver::Resolver<hickory_resolver::name_server::TokioConnectionProvider>>> {
    use hickory_proto::xfer::Protocol;
    use hickory_resolver::Resolver;
    use hickory_resolver::config::*;
    use hickory_resolver::name_server::TokioConnectionProvider;

    // Fast path: the resolver has already been built by an earlier call.
    if let Some(resolver) = GLOBAL_DNS_RESOLVER.get() {
        return Ok(Arc::clone(resolver));
    }

    let config = if let Ok(custom_dns) = std::env::var("PAYLOAD_DUMPER_CUSTOM_DNS") {
        let dns_ips = parse_dns_server_list(&custom_dns)
            .map_err(|e| anyhow!("Invalid DNS IP in PAYLOAD_DUMPER_CUSTOM_DNS: {}", e))?;
        // Reachable now that blank segments are skipped: covers "", " " and ",".
        if dns_ips.is_empty() {
            return Err(anyhow!("PAYLOAD_DUMPER_CUSTOM_DNS is empty"));
        }
        let mut config = ResolverConfig::new();
        for ip in dns_ips {
            let socket_addr = std::net::SocketAddr::new(ip, 53);
            config.add_name_server(NameServerConfig {
                socket_addr,
                protocol: Protocol::Udp,
                tls_dns_name: None,
                http_endpoint: None,
                trust_negative_responses: true,
                bind_addr: None,
            });
        }
        config
    } else {
        ResolverConfig::cloudflare()
    };

    let resolver = Arc::new(
        Resolver::builder_with_config(config, TokioConnectionProvider::default()).build(),
    );
    // If several callers raced to this point, only the first stored resolver is kept
    // and everyone receives a handle to that one.
    Ok(Arc::clone(GLOBAL_DNS_RESOLVER.get_or_init(|| resolver)))
}

/// Parses a comma-separated list of DNS server IP addresses.
///
/// Whitespace around each entry is trimmed and blank entries (for example the one
/// produced by a trailing comma) are skipped, so the result may be empty. The first
/// entry that is not a valid IPv4/IPv6 address aborts parsing with its parse error.
///
/// Intentionally not feature-gated: it depends only on `std`, which keeps it
/// unit-testable without the `hickory_dns` feature.
fn parse_dns_server_list(
    raw: &str,
) -> std::result::Result<Vec<std::net::IpAddr>, std::net::AddrParseError> {
    raw.split(',')
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .map(|segment| segment.parse::<std::net::IpAddr>())
        .collect()
}
/// Builds the `reqwest` client shared by all requests of an `HttpReader`.
///
/// `user_agent` falls back to `DEFAULT_USER_AGENT` when `None`. `cookies`, when present
/// and non-empty, is sent verbatim as the `Cookie` header. A fixed set of default
/// headers, timeouts, pooling and keep-alive options is applied, and redirects are
/// followed up to 10 hops. With the `hickory_dns` feature enabled, host names are
/// resolved through the shared hickory resolver instead of the default one.
///
/// # Errors
///
/// Fails when the user agent or cookie string is not a valid header value, when the
/// DNS resolver cannot be created, or when `reqwest` fails to build the client.
async fn create_http_client(user_agent: Option<&str>, cookies: Option<&str>) -> Result<Client> {
    let mut default_headers = header::HeaderMap::new();

    let agent_value = header::HeaderValue::from_str(user_agent.unwrap_or(DEFAULT_USER_AGENT))
        .map_err(|e| anyhow!("Invalid user agent string: {}", e))?;
    default_headers.insert(header::USER_AGENT, agent_value);

    // An empty cookie string is treated the same as no cookies at all.
    if let Some(raw_cookies) = cookies.filter(|c| !c.is_empty()) {
        let cookie_value = header::HeaderValue::from_str(raw_cookies)
            .map_err(|e| anyhow!("Invalid cookie string: {}", e))?;
        default_headers.insert(header::COOKIE, cookie_value);
    }

    // NOTE(review): compressed encodings are advertised on every request, including
    // byte-range requests — confirm target servers never compress these downloads.
    for (name, value) in [
        (header::ACCEPT_ENCODING, "gzip, deflate, br"),
        (header::ACCEPT, "*/*"),
        (header::CONNECTION, "keep-alive"),
        (header::CACHE_CONTROL, "no-transform"),
    ] {
        default_headers.insert(name, header::HeaderValue::from_static(value));
    }

    let request_timeout = Duration::from_secs(600);
    let connect_timeout = Duration::from_secs(30);
    let idle_timeout = Duration::from_secs(90);
    let keepalive_interval = Duration::from_secs(30);

    let mut client_builder = Client::builder()
        .timeout(request_timeout)
        .connect_timeout(connect_timeout)
        .pool_max_idle_per_host(10)
        .pool_idle_timeout(idle_timeout)
        .tcp_keepalive(Some(keepalive_interval))
        .http2_keep_alive_interval(Some(keepalive_interval))
        .http2_adaptive_window(true)
        .default_headers(default_headers)
        .redirect(reqwest::redirect::Policy::limited(10));

    #[cfg(feature = "hickory_dns")]
    {
        use hickory_resolver::name_server::TokioConnectionProvider;
        use reqwest::dns::{Name, Resolve, Resolving};
        use std::net::SocketAddr;

        /// Adapter exposing the shared hickory resolver through reqwest's `Resolve` trait.
        struct HickoryDns {
            inner: Arc<hickory_resolver::Resolver<TokioConnectionProvider>>,
        }

        impl Resolve for HickoryDns {
            fn resolve(&self, name: Name) -> Resolving {
                let inner = Arc::clone(&self.inner);
                Box::pin(async move {
                    let host = name.as_str();
                    let lookup = inner
                        .lookup_ip(host)
                        .await
                        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
                    // Every resolved address is returned with port 0.
                    let addrs: Box<dyn Iterator<Item = SocketAddr> + Send> =
                        Box::new(lookup.into_iter().map(|ip| SocketAddr::new(ip, 0)));
                    Ok(addrs)
                })
            }
        }

        let shared_resolver = get_or_init_dns_resolver()
            .await
            .map_err(|e| anyhow!("Failed to create DNS resolver: {}", e))?;
        let adapter = HickoryDns {
            inner: shared_resolver,
        };
        client_builder = client_builder.dns_resolver(Arc::new(adapter));
    }

    client_builder
        .build()
        .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))
}
/// Random-access reader over a remote file, backed by HTTP range requests.
///
/// Construct it with `HttpReader::new`, which probes the URL with a `HEAD` request.
pub struct HttpReader {
/// HTTP client used for every request issued by this reader.
pub client: Client,
/// URL of the remote file.
pub url: String,
/// Size of the remote file in bytes, taken from the `Content-Length` header of the
/// initial `HEAD` response; never zero for a successfully constructed reader.
pub content_length: u64,
}
impl HttpReader {
pub async fn new(url: String, user_agent: Option<&str>, cookies: Option<&str>) -> Result<Self> {
let client = create_http_client(user_agent, cookies).await?;
url::Url::parse(&url).map_err(|e| anyhow!("Invalid URL: {}", e))?;
let mut retry_count = 0;
const MAX_RETRIES: u32 = 3;
let mut last_error = None;
while retry_count < MAX_RETRIES {
match client.head(&url).send().await {
Ok(response) => {
if !response.status().is_success() {
return Err(anyhow!("Failed to access URL: {}", response.status()));
}
let supports_ranges = response
.headers()
.get(header::ACCEPT_RANGES)
.and_then(|v| v.to_str().ok())
.map(|v| v == "bytes")
.unwrap_or(false);
let content_length = response
.headers()
.get(header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.ok_or_else(|| anyhow!("Could not determine content length"))?;
if content_length == 0 {
return Err(anyhow!("File size is 0"));
}
return Ok(Self {
client,
url,
content_length,
});
}
Err(e) => {
last_error = Some(e);
retry_count += 1;
if retry_count < MAX_RETRIES {
tokio::time::sleep(Duration::from_secs(2u64.pow(retry_count))).await;
}
}
}
}
Err(anyhow!(
"Failed to connect after {} retries. Last error: {}",
MAX_RETRIES,
last_error.unwrap()
))
}
pub async fn read_at(&self, offset: u64, buf: &mut [u8]) -> Result<()> {
if offset >= self.content_length {
return Err(anyhow!(
"Offset {} exceeds content length {}",
offset,
self.content_length
));
}
let remaining = self.content_length - offset;
let to_read = std::cmp::min(buf.len() as u64, remaining) as usize;
if to_read == 0 {
return Ok(());
}
let end = offset + to_read as u64 - 1;
let range_header = format!("bytes={}-{}", offset, end);
let mut retry_count = 0;
const MAX_RETRIES: u32 = 3;
let mut last_error = None;
while retry_count < MAX_RETRIES {
match self
.client
.get(&self.url)
.header(header::RANGE, &range_header)
.send()
.await
{
Ok(response) => {
let status = response.status();
if !status.is_success() && status.as_u16() != 206 {
return Err(anyhow!("Range request failed: {}", status));
}
let bytes = response.bytes().await?;
if bytes.len() != to_read {
return Err(anyhow!(
"Server returned incorrect bytes: expected {}, got {}",
to_read,
bytes.len()
));
}
buf[..to_read].copy_from_slice(&bytes);
return Ok(());
}
Err(e) => {
last_error = Some(e);
retry_count += 1;
if retry_count < MAX_RETRIES {
tokio::time::sleep(Duration::from_secs(2u64.pow(retry_count))).await;
}
}
}
}
Err(anyhow!(
"Failed to read after {} retries. Last error: {}",
MAX_RETRIES,
last_error.unwrap()
))
}
}
#[async_trait::async_trait]
impl crate::zip::zip_io::ZipIO for HttpReader {
async fn read_at(&self, offset: u64, buf: &mut [u8]) -> Result<()> {
self.read_at(offset, buf).await
}
async fn size(&self) -> Result<u64> {
Ok(self.content_length)
}
}