use std::{path::Path, sync::Arc};
use fs_err::tokio as tokio_fs;
use futures_util::stream::TryStreamExt;
use rattler_conda_types::package::CondaArchiveType;
use rattler_digest::Sha256Hash;
use reqwest::Response;
use tokio::io::BufReader;
use tokio_util::{either::Either, io::StreamReader};
use tracing;
use url::Url;
use zip::result::ZipError;
use crate::{DownloadReporter, ExtractError, ExtractResult};
/// Fragment of a `ZipError::UnsupportedArchive` message that `extract_conda` searches for to
/// recognise archives whose entries use zip data descriptors, which the streaming extractor
/// fails on; such archives are re-downloaded and extracted via the buffering fallback.
const DATA_DESCRIPTOR_ERROR_MESSAGE: &str = "The file length is not available in the local header";
fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
response
.error_for_status()
.map_err(reqwest_middleware::Error::Reqwest)
}
async fn get_reader(
url: Url,
client: reqwest_middleware::ClientWithMiddleware,
expected_sha256: Option<Sha256Hash>,
reporter: Option<Arc<dyn DownloadReporter>>,
) -> Result<impl tokio::io::AsyncRead, ExtractError> {
if let Some(reporter) = &reporter {
reporter.on_download_start();
}
if url.scheme() == "file" {
let file =
tokio_fs::File::open(url.to_file_path().expect("Could not convert to file path"))
.await
.map_err(ExtractError::IoError)?;
Ok(Either::Left(BufReader::new(file)))
} else {
let mut request = client.get(url.clone());
if let Some(sha256) = expected_sha256 {
request = request.header("X-Expected-Sha256", format!("{sha256:x}"));
}
let response = request
.send()
.await
.and_then(error_for_status)
.map_err(ExtractError::ReqwestError)?;
let total_bytes = response.content_length();
let mut bytes_received = Box::new(0);
let byte_stream = response.bytes_stream().inspect_ok(move |frame| {
*bytes_received += frame.len() as u64;
if let Some(reporter) = &reporter {
reporter.on_download_progress(*bytes_received, total_bytes);
}
});
Ok(Either::Right(StreamReader::new(byte_stream.map_err(
|err| {
if err.is_body() {
std::io::Error::new(std::io::ErrorKind::Interrupted, err)
} else if err.is_decode() {
std::io::Error::new(std::io::ErrorKind::InvalidData, err)
} else {
std::io::Error::other(err)
}
},
))))
}
}
/// Downloads the `.tar.bz2` package at `url` and extracts it into `destination`.
///
/// The reader obtained from `get_reader` is handed directly to
/// `crate::tokio::async_read::extract_tar_bz2`. When a `reporter` is supplied,
/// `on_download_complete` is invoked only after extraction has succeeded. It is not invoked if
/// opening the URL or extracting the archive fails.
///
/// # Errors
///
/// Propagates any `ExtractError` produced while opening `url` or while extracting the archive.
pub async fn extract_tar_bz2(
    client: reqwest_middleware::ClientWithMiddleware,
    url: Url,
    destination: &Path,
    expected_sha256: Option<Sha256Hash>,
    reporter: Option<Arc<dyn DownloadReporter>>,
) -> Result<ExtractResult, ExtractError> {
    let archive = get_reader(url, client, expected_sha256, reporter.clone()).await?;
    let extracted = crate::tokio::async_read::extract_tar_bz2(archive, destination).await?;

    if let Some(listener) = reporter.as_deref() {
        listener.on_download_complete();
    }

    Ok(extracted)
}
/// Downloads the `.conda` package at `url` and extracts it into `destination`.
///
/// Extraction is first attempted by streaming the download through
/// `crate::tokio::async_read::extract_conda`. If that fails with a
/// `ZipError::UnsupportedArchive` whose message contains `DATA_DESCRIPTOR_ERROR_MESSAGE`, the
/// archive uses zip data descriptors. In that case a warning is logged, the package is fetched
/// again, and it is extracted with `crate::tokio::async_read::extract_conda_via_buffering`.
///
/// Reporter protocol: `get_reader` calls `on_download_start` for every download it begins. This
/// function calls `on_download_complete` when the streaming attempt is abandoned in favour of
/// the fallback, and once more after a successful extraction. It is not called when extraction
/// ultimately fails.
///
/// # Errors
///
/// Propagates any `ExtractError` from opening `url` or from extraction, other than the
/// data-descriptor error that triggers the buffered fallback.
pub async fn extract_conda(
    client: reqwest_middleware::ClientWithMiddleware,
    url: Url,
    destination: &Path,
    expected_sha256: Option<Sha256Hash>,
    reporter: Option<Arc<dyn DownloadReporter>>,
) -> Result<ExtractResult, ExtractError> {
    let reader = get_reader(
        url.clone(),
        client.clone(),
        expected_sha256,
        reporter.clone(),
    )
    .await?;

    let result = match crate::tokio::async_read::extract_conda(reader, destination).await {
        Err(ExtractError::ZipError(ZipError::UnsupportedArchive(zip_error)))
            if zip_error.contains(DATA_DESCRIPTOR_ERROR_MESSAGE) =>
        {
            tracing::warn!(
                "Failed to stream decompress conda package from '{}' due to the presence of zip data descriptors. Falling back to non streaming decompression",
                url
            );
            // Close out the abandoned streaming download; `get_reader` below reports the start
            // of a fresh one.
            if let Some(reporter) = &reporter {
                reporter.on_download_complete();
            }
            let new_reader = get_reader(url, client, expected_sha256, reporter.clone()).await?;
            crate::tokio::async_read::extract_conda_via_buffering(new_reader, destination).await?
        }
        // Success passes through; any other error is returned as-is.
        other => other?,
    };

    if let Some(reporter) = &reporter {
        reporter.on_download_complete();
    }
    Ok(result)
}
/// Downloads the package at `url` and extracts it into `destination`.
///
/// The extraction routine is chosen from the archive type that `CondaArchiveType::try_from`
/// derives from the URL path. `CondaArchiveType::Conda` is dispatched to `extract_conda` and
/// `CondaArchiveType::TarBz2` to `extract_tar_bz2`. All other arguments are forwarded unchanged.
///
/// # Errors
///
/// Returns `ExtractError::UnsupportedArchiveType` when the URL path is not recognised as a
/// supported archive type. Otherwise it returns whatever the selected extractor returns.
pub async fn extract(
    client: reqwest_middleware::ClientWithMiddleware,
    url: Url,
    destination: &Path,
    expected_sha256: Option<Sha256Hash>,
    reporter: Option<Arc<dyn DownloadReporter>>,
) -> Result<ExtractResult, ExtractError> {
    let Some(archive_type) = CondaArchiveType::try_from(Path::new(url.path())) else {
        return Err(ExtractError::UnsupportedArchiveType);
    };

    match archive_type {
        CondaArchiveType::Conda => {
            extract_conda(client, url, destination, expected_sha256, reporter).await
        }
        CondaArchiveType::TarBz2 => {
            extract_tar_bz2(client, url, destination, expected_sha256, reporter).await
        }
    }
}