use super::{ExtractError, ExtractResult};
use std::io::{Seek, SeekFrom, copy};
use std::mem::ManuallyDrop;
use std::{
ffi::OsStr,
io::Read,
path::{Component, Path, PathBuf},
};
use tempfile::SpooledTempFile;
use zip::read::{ZipArchive, ZipFile, read_zipfile_from_stream};
/// Lower bound, in seconds since the Unix epoch, that `set_mtime_safe` clamps entry
/// modification times to before applying them.
///
/// `315_532_800` is 1980-01-01T00:00:00Z. NOTE(review): presumably chosen because some
/// filesystems / timestamp formats (e.g. FAT or DOS-style times) cannot represent earlier
/// dates — confirm.
const SAFE_MTIME_FLOOR: u64 = 315_532_800;
/// Returns a streaming [`tar::Archive`] over the bzip2-decompressed contents of `reader`.
pub fn stream_tar_bz2(reader: impl Read) -> tar::Archive<impl Read + Sized> {
    let decompressed = bzip2::read::BzDecoder::new(reader);
    tar::Archive::new(decompressed)
}
/// Returns a streaming [`tar::Archive`] over the zstd-decompressed contents of `reader`.
///
/// # Errors
/// Fails if the zstd decoder cannot be constructed; the I/O error is converted into an
/// [`ExtractError`].
pub(crate) fn stream_tar_zst(
    reader: impl Read,
) -> Result<tar::Archive<impl Read + Sized>, ExtractError> {
    let decoder = zstd::stream::read::Decoder::new(reader)?;
    Ok(tar::Archive::new(decoder))
}
/// Extracts a `.tar.bz2` package read from `reader` into `destination`.
///
/// `destination` (and any missing parents) is created first. The returned [`ExtractResult`]
/// holds the SHA-256 and MD5 digests and the byte count of the whole input stream, including
/// any bytes left over after the tar archive ended.
///
/// # Errors
/// Returns [`ExtractError::CouldNotCreateDestination`] if the directory cannot be created, or
/// another [`ExtractError`] if reading, decompressing or unpacking fails.
pub fn extract_tar_bz2(
    reader: impl Read,
    destination: &Path,
) -> Result<ExtractResult, ExtractError> {
    std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;
    process_with_hashing(reader, |hashed_reader| {
        unpack_tar_archive_sync(&mut stream_tar_bz2(hashed_reader), destination)
    })
}
/// Extracts a `.conda` (zip) package into `destination`, reading zip entries one after
/// another straight from `reader` without seeking.
///
/// Every member whose name ends in `.tar.zst` is unpacked into `destination`. All other
/// members are read and discarded. The returned [`ExtractResult`] covers the entire input
/// stream.
///
/// # Errors
/// Returns [`ExtractError::CouldNotCreateDestination`] if the directory cannot be created, or
/// another [`ExtractError`] if reading the zip stream or unpacking a member fails.
pub fn extract_conda_via_streaming(
    reader: impl Read,
    destination: &Path,
) -> Result<ExtractResult, ExtractError> {
    std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;
    process_with_hashing(reader, |hashed_reader| loop {
        match read_zipfile_from_stream(hashed_reader)? {
            Some(member) => extract_zipfile(member, destination)?,
            None => break Ok(()),
        }
    })
}
/// Extracts a `.conda` (zip) package into `destination` by first buffering the whole input
/// and then walking the zip archive by index.
///
/// Any existing `destination` is removed before it is recreated. The input is copied into a
/// [`SpooledTempFile`] with a 5 MiB in-memory threshold before being opened as a
/// [`ZipArchive`]. Every member is passed to `extract_zipfile`. The returned
/// [`ExtractResult`] covers the entire input stream.
///
/// # Errors
/// Failing to remove or create `destination` is reported as
/// [`ExtractError::CouldNotCreateDestination`]. Buffering, zip parsing and unpacking errors are
/// converted into the corresponding [`ExtractError`].
pub fn extract_conda_via_buffering(
    reader: impl Read,
    destination: &Path,
) -> Result<ExtractResult, ExtractError> {
    // Start from an empty directory so stale files from a previous extraction don't linger.
    if destination.exists() {
        std::fs::remove_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;
    }
    std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;

    process_with_hashing(reader, |hashed_reader| {
        let mut spooled = SpooledTempFile::new(5 * 1024 * 1024);
        copy(hashed_reader, &mut spooled)?;
        spooled.seek(SeekFrom::Start(0))?;

        let mut zip = ZipArchive::new(spooled)?;
        (0..zip.len()).try_for_each(|index| extract_zipfile(zip.by_index(index)?, destination))
    })
}
/// Handles a single zip member: `.tar.zst` members are unpacked into `destination`, and any
/// other member is read to its end and discarded.
///
/// The member is wrapped in [`ManuallyDrop`] so that it is *not* dropped when an error
/// propagates out through `?`. On the success path it is dropped explicitly at the end.
/// NOTE(review): presumably this avoids `ZipFile`'s `Drop` impl draining the remaining entry
/// data after a failure. The trade-off is that the member's resources are leaked on error —
/// confirm against the zip crate version in use.
fn extract_zipfile<R: std::io::Read>(
    zip_file: ZipFile<'_, R>,
    destination: &Path,
) -> Result<(), ExtractError> {
    let mut member = ManuallyDrop::new(zip_file);

    let is_tar_zst = member
        .mangled_name()
        .file_name()
        .map(OsStr::to_string_lossy)
        .is_some_and(|name| name.ends_with(".tar.zst"));

    if is_tar_zst {
        unpack_tar_archive_sync(&mut stream_tar_zst(&mut *member)?, destination)?;
    } else {
        // Consume the member's bytes so the underlying stream is positioned after it.
        std::io::copy(&mut *member, &mut std::io::sink())?;
    }

    // Success: hand the member back to normal drop semantics.
    drop(ManuallyDrop::into_inner(member));
    Ok(())
}
/// Unpacks every entry of `archive` into `destination`, then applies a clamped modification
/// time (via `set_mtime_safe`) to each entry that was actually written.
///
/// tar's own mtime handling is disabled so that only the clamped value is ever applied.
///
/// Directory entries are deferred until all other entries have been unpacked. They are then
/// applied in reverse lexicographic path order, so children come before their parents. This
/// mirrors `tar::Archive::unpack` and matters for two reasons:
/// * a directory whose mode is not writable would otherwise prevent its own children from
///   being extracted;
/// * creating children inside a directory updates that directory's mtime, which would
///   overwrite the timestamp applied to it here.
///
/// Entries that `tar::Entry::unpack_in` declines to write (e.g. paths containing `..`) get no
/// mtime adjustment.
///
/// # Errors
/// Returns [`ExtractError::IoError`] if iterating the archive, reading an entry path, or
/// unpacking an entry fails.
fn unpack_tar_archive_sync<R: Read>(
    archive: &mut tar::Archive<R>,
    destination: &Path,
) -> Result<(), ExtractError> {
    archive.set_preserve_mtime(false);

    let mut directories = Vec::new();
    for entry in archive.entries().map_err(ExtractError::IoError)? {
        let mut entry = entry.map_err(ExtractError::IoError)?;
        if entry.header().entry_type().is_dir() {
            directories.push(entry);
        } else {
            unpack_entry_with_safe_mtime(&mut entry, destination)?;
        }
    }

    // Reverse lexicographic order visits `a/b` before `a`, i.e. deepest directories first.
    directories.sort_by(|a, b| b.path_bytes().cmp(&a.path_bytes()));
    for mut directory in directories {
        unpack_entry_with_safe_mtime(&mut directory, destination)?;
    }
    Ok(())
}

/// Unpacks one tar entry into `destination` and, if it was written, sets its clamped mtime.
fn unpack_entry_with_safe_mtime<R: Read>(
    entry: &mut tar::Entry<'_, R>,
    destination: &Path,
) -> Result<(), ExtractError> {
    // An unreadable mtime falls back to 0, which `set_mtime_safe` raises to its floor.
    let mtime = entry.header().mtime().unwrap_or(0);
    let is_symlink = entry.header().entry_type().is_symlink();
    // Capture the path before unpacking so the on-disk location can be recomputed afterwards.
    let entry_path = entry.path().map_err(ExtractError::IoError)?.into_owned();
    let unpacked = entry
        .unpack_in(destination)
        .map_err(ExtractError::IoError)?;
    if unpacked && let Some(full_path) = unpacked_destination_path(destination, &entry_path) {
        set_mtime_safe(&full_path, mtime, is_symlink);
    }
    Ok(())
}
/// Recomputes where `tar::Entry::unpack_in` places `entry_path` under `destination`.
///
/// Prefix, root and `.` components are ignored, and normal components are appended. Returns
/// `None` if any component is `..`, or if the path reduces to `destination` itself.
fn unpacked_destination_path(destination: &Path, entry_path: &Path) -> Option<PathBuf> {
    let resolved = entry_path.components().try_fold(
        destination.to_path_buf(),
        |mut acc, component| match component {
            Component::ParentDir => None,
            Component::Normal(segment) => {
                acc.push(segment);
                Some(acc)
            }
            Component::Prefix(_) | Component::RootDir | Component::CurDir => Some(acc),
        },
    )?;
    (resolved != destination).then_some(resolved)
}
/// Best-effort: sets the modification time of `path` to `mtime` (seconds since the Unix epoch),
/// clamped into a range that is safe to hand to the filesystem.
///
/// * Values below `SAFE_MTIME_FLOOR` are raised to that floor.
/// * Values that do not fit in an `i64` saturate to `i64::MAX`. A plain `as i64` cast would
///   wrap them to a negative (pre-epoch) timestamp, which is exactly what the floor is meant
///   to prevent.
///
/// For symlinks, the link itself gets both its access and modification time set. Otherwise,
/// only the modification time is set. Failures are logged with `tracing::warn!` and otherwise
/// ignored.
fn set_mtime_safe(path: &Path, mtime: u64, is_symlink: bool) {
    let clamped = mtime.max(SAFE_MTIME_FLOOR);
    // The mtime comes from an archive header and can be any u64; saturate instead of wrapping.
    let seconds = i64::try_from(clamped).unwrap_or(i64::MAX);
    let file_time = filetime::FileTime::from_unix_time(seconds, 0);
    let result = if is_symlink {
        filetime::set_symlink_file_times(path, file_time, file_time)
    } else {
        filetime::set_file_mtime(path, file_time)
    };
    if let Err(e) = result {
        tracing::warn!(
            "Failed to set mtime for '{}': {}. \
            The target filesystem may not support this timestamp. \
            This does not affect package integrity.",
            path.display(),
            e
        );
    }
}
/// A reader adapter that keeps a running total of the bytes read through it.
pub(crate) struct SizeCountingReader<R> {
    /// The wrapped reader.
    inner: R,
    /// Number of bytes successfully returned by reads so far.
    size: u64,
}

impl<R> SizeCountingReader<R> {
    /// Wraps `inner`, starting the byte count at zero.
    pub(crate) fn new(inner: R) -> Self {
        Self { size: 0, inner }
    }

    /// Consumes the adapter and returns the wrapped reader together with the byte total.
    pub(crate) fn finalize(self) -> (R, u64) {
        let Self { inner, size } = self;
        (inner, size)
    }
}

impl<R: Read> Read for SizeCountingReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Only successful reads contribute to the total; errors pass through untouched.
        self.inner
            .read(buf)
            .inspect(|&count| self.size += count as u64)
    }
}
/// Async counterpart of the `Read` impl: forwards to the inner reader and adds the number of
/// newly filled bytes to the running total whenever a read completes successfully.
impl<R: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for SizeCountingReader<R> {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let filled_before = buf.filled().len();
        let this = self.get_mut();
        let poll = std::pin::Pin::new(&mut this.inner).poll_read(cx, buf);
        // Pending and error results leave the count unchanged.
        if matches!(poll, std::task::Poll::Ready(Ok(()))) {
            this.size += (buf.filled().len() - filled_before) as u64;
        }
        poll
    }
}
/// Runs `processor` over `reader` while computing SHA-256, MD5 and the byte count of the data.
///
/// After `processor` returns successfully, any bytes it did not consume are drained to a sink.
/// This ensures the digests and size always describe the complete input stream.
///
/// # Errors
/// Propagates the error returned by `processor`. I/O errors while draining the remaining
/// input are converted into `E`.
fn process_with_hashing<E, R, F>(reader: R, processor: F) -> Result<ExtractResult, E>
where
    R: Read,
    E: From<std::io::Error>,
    F: FnOnce(
        &mut SizeCountingReader<
            &mut rattler_digest::HashingReader<
                rattler_digest::HashingReader<R, rattler_digest::Sha256>,
                rattler_digest::Md5,
            >,
        >,
    ) -> Result<(), E>,
{
    // Layering, innermost first: raw input -> SHA-256 -> MD5 -> byte counter.
    let mut md5_reader = rattler_digest::HashingReader::<_, rattler_digest::Md5>::new(
        rattler_digest::HashingReader::<_, rattler_digest::Sha256>::new(reader),
    );

    let total_size = {
        let mut counting_reader = SizeCountingReader::new(&mut md5_reader);
        processor(&mut counting_reader)?;
        std::io::copy(&mut counting_reader, &mut std::io::sink())?;
        counting_reader.finalize().1
    };

    let (sha256_reader, md5) = md5_reader.finalize();
    let sha256 = sha256_reader.finalize().1;
    Ok(ExtractResult {
        sha256,
        md5,
        total_size,
    })
}