//! remozipsy 0.2.0
//!
//! Remote Zip Sync - sync remote zip to local fs.
//! Documentation
use bytes::{Buf, BufMut, BytesMut};
use core::fmt::Debug;
use std::{
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    time::Instant,
};
use tokio::sync::mpsc;
use zip_core::{
    Signature,
    raw::{
        LocalFileHeader, LocalFileHeaderFixed,
        parse::{Parse, ParseExtend},
    },
};

use crate::model::{Error, FileInfo};

use super::{
    FileSystem, RemoteZip,
    compare::Batch,
    remote::{ProcessedRemoteFileInfo, SupportedCompressionMethod},
};

/// Result of [`download_batch`]: `Ok(())` on success, otherwise a combined
/// remote/file-system [`Error`].
#[expect(type_alias_bounds)]
pub(super) type DownloadResult<R: RemoteZip, F: FileSystem> =
    Result<(), Error<<R as RemoteZip>::Error, <F as FileSystem>::Error>>;
/// Result of [`unzip_file`]. The `Ok` payload is `None` on every success path
/// visible here; presumably the `Some((u32, String))` variant is reserved for
/// reporting per-file information — TODO confirm against callers.
#[expect(type_alias_bounds)]
pub(super) type UnzipResult<R: RemoteZip, F: FileSystem> =
    Result<Option<(u32, String)>, Error<<R as RemoteZip>::Error, <F as FileSystem>::Error>>;

/// Downloads one batch of files with a single ranged request, parsing local
/// file headers on the fly and sending each file's (still compressed) payload
/// plus its metadata on `files_tx`.
///
/// A batch consists of a list of files, separated by junk data, e.g.
/// ```text
/// [file][junk][file][junk][file][junk]
/// ```
/// We MUST NOT have junk data at the beginning of a batch.
///
/// `F` is only needed to name the error type; no file-system calls happen here.
pub(super) async fn download_batch<R, F>(
    mut batch: Batch,
    remote: R,
    bytes_downloaded: Arc<AtomicUsize>,
    files_tx: mpsc::UnboundedSender<(BytesMut, ProcessedRemoteFileInfo)>,
) -> DownloadResult<R, F>
where
    R: RemoteZip + 'static,
    F: FileSystem + Clone + Send + 'static,
{
    let batchsize = batch.len();
    // An empty batch is trivially done.
    let Some(mut next_rfile) = batch.pop_front() else {
        return Ok(());
    };

    // The request spans from the first file's start to the last file's
    // inclusive end. With a single-entry batch, `next_rfile` is both the first
    // and (since the deque is now empty) the fallback "last" entry.
    let (start_location, end_location) = {
        let last = batch.back().unwrap_or(&next_rfile);
        (next_rfile.start_offset as usize, (last.end_offset_inclusive as usize))
    };

    let before = Instant::now();
    let range = start_location..=end_location;
    let stream = remote.fetch_bytes_stream(range.clone()).await.map_err(Error::Remote)?;
    let elapsed = before.elapsed();
    tracing::trace!(?elapsed, ?range, ?batchsize, "fetched batch metadata from zip");

    use futures_lite::StreamExt;
    let mut stream = Box::pin(stream);

    // Pre-size the buffer for the whole inclusive range (+1 for inclusivity).
    let mut storage = BytesMut::with_capacity(end_location.saturating_add(1) - start_location);

    // We use a state machine here, because sometimes we have enough bytes to parse
    // the LocalFileHeaderFixed, but don't know if it's enough for the whole file.
    #[derive(Debug)]
    enum State {
        // Waiting for the next local file header to become parseable.
        None,
        // Fixed-size header part parsed; variable-length tail still pending.
        LocalHeaderFixed(LocalFileHeaderFixed),
        // Full header parsed; the usize is the total header size in bytes.
        LocalHeader(LocalFileHeader, usize),
        // Skip this many junk bytes before looking for the next header.
        Discard(usize),
    }

    let mut state = State::None;

    while let Some(chunk) = stream.try_next().await.map_err(Error::Remote)? {
        bytes_downloaded.fetch_add(chunk.len(), Ordering::SeqCst);
        storage.put(chunk);

        // Drain as many complete parse steps as the buffered bytes allow
        // before awaiting the next chunk.
        'moredata: loop {
            match state {
                State::Discard(to_be_discarded) => {
                    // Only drop the junk once ALL of it is buffered; otherwise
                    // wait for more data (the discard count stays unchanged).
                    if storage.len() >= to_be_discarded {
                        let _ = storage.split_to(to_be_discarded);
                        state = State::None;
                    } else {
                        break 'moredata;
                    }
                },
                State::None => {
                    if let Ok(header) = LocalFileHeaderFixed::from_buf(&mut storage) {
                        // A batch must start directly on a local file header,
                        // so a bad signature here means corrupt input.
                        if !header.is_valid_signature() {
                            return Err(Error::InvalidLocalHeaderSignature {
                                file_name: next_rfile.file_name,
                            });
                        };
                        state = State::LocalHeaderFixed(header);
                    } else {
                        // Not enough bytes for the fixed header part yet.
                        break 'moredata;
                    }
                },
                State::LocalHeaderFixed(header) => {
                    let storage_before = storage.len();
                    match LocalFileHeader::from_buf_fixed(&mut storage, header) {
                        Ok(header) => {
                            // Total header size = fixed part + however many
                            // bytes the variable tail (name/extra) consumed.
                            state = State::LocalHeader(
                                header,
                                LocalFileHeaderFixed::SIZE_IN_BYTES + storage_before.saturating_sub(storage.len()),
                            );
                        },
                        Err((_, header)) => {
                            // Variable tail incomplete; keep the fixed part and retry.
                            state = State::LocalHeaderFixed(header);
                            break 'moredata;
                        },
                    }
                },
                State::LocalHeader(header, bytes_read) => {
                    // now that we finally found the header, verify size
                    if storage.len() >= next_rfile.compressed_size as usize {
                        let data = storage.split_to(next_rfile.compressed_size as usize);
                        let bytes_read = bytes_read + next_rfile.compressed_size as usize;
                        let current_start_offset = next_rfile.start_offset as usize;

                        // this only errors when the receiver was dropped (i.e. the state
                        // machine is being torn down early); then there is no one
                        // to report an error to anyway...
                        let _ = files_tx.send((data, next_rfile));

                        if let Some(next) = batch.pop_front() {
                            next_rfile = next;
                        } else {
                            // Last file of the batch delivered; any trailing
                            // junk in the range is simply never parsed.
                            return Ok(());
                        }

                        // mark junk for cleanup (we cannot drop it here, because we don't
                        // know if the FULL junk is fetched already)
                        let to_be_discarded = (next_rfile.start_offset as usize)
                            .checked_sub(current_start_offset + bytes_read)
                            .ok_or(Error::OverlappingBytesForDifferentFiles)?;
                        if to_be_discarded == 0 {
                            state = State::None;
                        } else {
                            state = State::Discard(to_be_discarded);
                        }
                    } else {
                        // Payload not fully buffered yet; keep the parsed header.
                        state = State::LocalHeader(header, bytes_read);
                        break 'moredata;
                    }
                },
            }
        }
    }

    // when we reach this, the `next_rfile` hasn't been processed. so we got a bug
    Err(Error::InsufficientDownloadRange {
        file_name: next_rfile.file_name,
        storage_size: storage.len() as u64,
        expected_compressed_size: next_rfile.compressed_size as u64,
    })
}

/// Verifies, decompresses and persists one downloaded file.
///
/// The file-system "prepare" step is spawned up front so it runs concurrently
/// with decompression and CRC verification; the data is only stored after the
/// computed CRC32 matches `rfile.crc32`. Returns `Ok(None)` on success (no
/// `Some` payload is produced on any path here).
///
/// Cancel Safety: This function is (obviously) NOT cancel safe. Aside from it
/// owning file bytes, cancelling async file system writing methods may leave a
/// file in a partially written state. Care should be taken to avoid
/// cancellation (which might happen due to e.g. dropping the associated JoinSet
/// without awaiting all tasks first) in order to prevent this.
pub(super) async fn unzip_file<R, F>(
    mut compressed: BytesMut,
    rfile: ProcessedRemoteFileInfo,
    file_system: F,
    runtime_handle: Option<tokio::runtime::Handle>,
    bytes_unzipped: Arc<AtomicUsize>,
) -> UnzipResult<R, F>
where
    R: RemoteZip + 'static,
    F: FileSystem + Clone + Send + 'static,
{
    let file_name = rfile.file_name.clone();
    // Defensive check: the downloader must hand us exactly the compressed payload.
    if compressed.len() != rfile.compressed_size as usize {
        return Err(Error::WrongBytesLength {
            file_name,
            bytes_cnt: compressed.len() as u64,
            expected_compressed_size: rfile.compressed_size as u64,
        });
    }

    // Kick off file preparation now so it overlaps with decompression below.
    let info = FileInfo {
        local_unix_path: file_name.clone(),
        crc32: rfile.crc32,
    };
    let file_system2 = file_system.clone();
    let future = Box::pin(async move { file_system2.prepare_store_file(info).await });
    let prepared = match runtime_handle {
        Some(rt) => rt.spawn(future),
        None => tokio::spawn(future),
    };

    let file_data = match rfile.compression_method {
        // Stored entries are the file bytes as-is.
        SupportedCompressionMethod::Stored => compressed.copy_to_bytes(rfile.compressed_size as usize),
        #[cfg(feature = "deflate")]
        SupportedCompressionMethod::Deflated => {
            use flate2::read::DeflateDecoder;
            use std::io::Read;

            let mut deflate_reader = DeflateDecoder::new(compressed.reader());
            // Reserve for the DECOMPRESSED size (previously this reserved only
            // `compressed_size`, forcing `read_to_end` to regrow the buffer).
            let mut decompressed = Vec::with_capacity(rfile.uncompressed_size as usize);
            deflate_reader
                .read_to_end(&mut decompressed)
                .map_err(|_| Error::CompressionError)?;
            // `Bytes::from(Vec<u8>)` takes ownership without copying, unlike
            // the previous `Bytes::copy_from_slice` which duplicated the data.
            bytes::Bytes::from(decompressed)
        },
    };

    // Verify integrity against the CRC32 carried in the remote metadata.
    let hash = crc32fast::hash(&file_data);
    if hash != rfile.crc32 {
        return Err(Error::InvalidHash {
            remote:     rfile.crc32,
            calculated: hash,
        });
    }

    // Join the preparation task spawned above before writing.
    let prepared = prepared
        .await
        .map_err(|_| Error::JoinError)?
        .map_err(Error::FileSystem)?;

    let data_len = file_data.len();
    file_system
        .store_file(prepared, file_data)
        .await
        .map_err(Error::FileSystem)?;

    // Count uncompressed bytes only after the store succeeded.
    bytes_unzipped.fetch_add(data_len, Ordering::SeqCst);
    Ok(None)
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::RemoteFileInfo;
    use bytes::Bytes;
    use std::ops::RangeInclusive;

    // NOTE(review): despite the name, this exercises a TWO-file batch against a
    // real zip fixture (offsets below are hard-coded positions inside
    // example1.zip), not empty inputs — consider renaming.
    #[test]
    fn test_empty_inputs() {
        const ZIPFILE: &[u8] = include_bytes!("../../../tests/testfiles/example1.zip");

        // Serves byte ranges straight out of the in-memory fixture; no network.
        pub(crate) struct DummyRemoteZip {}
        impl RemoteZip for DummyRemoteZip {
            type Error = ();

            async fn fetch_remote_file_info(&self) -> Result<Vec<RemoteFileInfo>, ()> { Ok(vec![]) }

            async fn fetch_bytes_stream(
                &self,
                range: RangeInclusive<usize>,
            ) -> Result<impl futures_lite::Stream<Item = Result<Bytes, ()>> + std::marker::Send, ()> {
                // download_batch must request exactly first-start..=last-end.
                assert_eq!(range, 68..=507);
                // Deliver the whole range as a single stream chunk.
                Ok(futures_lite::stream::once(Ok(Bytes::from_static(ZIPFILE).slice(range))))
            }
        }

        // those are actual positions of data in the zip, because we verify the headers
        let batch: Batch = ([
            ProcessedRemoteFileInfo {
                start_offset: 68,
                end_offset_inclusive: 252,
                compressed_size: 185,
                uncompressed_size: 10000,
                compression_method: SupportedCompressionMethod::Deflated,
                crc32: 0,
                file_name: "foo".to_string(),
            },
            ProcessedRemoteFileInfo {
                start_offset: 333,
                end_offset_inclusive: 507,
                compressed_size: 106,
                uncompressed_size: 10000,
                compression_method: SupportedCompressionMethod::Deflated,
                crc32: 0,
                file_name: "bat".to_string(),
            },
        ])
        .into_iter()
        .collect();

        let bytes_downloaded: Arc<AtomicUsize> = Default::default();
        let (downloaded_tx, mut downloaded_rx) = mpsc::unbounded_channel();

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        rt.block_on(download_batch::<DummyRemoteZip, crate::tokio::TokioLocalStorage>(
            batch,
            DummyRemoteZip {},
            bytes_downloaded.clone(),
            downloaded_tx,
        ))
        .unwrap();

        // Each received item is (compressed payload, file metadata).
        let first_download_result = downloaded_rx
            .blocking_recv()
            .expect("download_batch didn't produce a first output");
        assert_eq!(first_download_result.1.start_offset, 68);
        assert_eq!(first_download_result.0.len(), 185);

        let second_download_result = downloaded_rx
            .blocking_recv()
            .expect("download_batch didn't produce a second output");
        assert_eq!(second_download_result.1.start_offset, 333);
        assert_eq!(second_download_result.0.len(), 106);

        // Exactly two files in the batch, so nothing further may be queued.
        assert!(downloaded_rx.is_empty());

        // The entire inclusive range (507 - 68 + 1 = 440 bytes) was counted.
        assert_eq!(bytes_downloaded.load(Ordering::SeqCst), 440);
    }
}