use crate::{Result, ZiPatchError};
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::path::Path;
/// Random-access byte source over a chain of one or more patch files.
///
/// A source holds its patches in order; `patch` index 0 is the first one.
/// Reads address a patch by that index plus a byte offset within it.
pub trait PatchSource {
    /// Fills the whole of `dst` with bytes from patch `patch`, starting at byte
    /// `offset` of that patch.
    ///
    /// # Errors
    ///
    /// Both implementations in this module report:
    /// - [`ZiPatchError::PatchIndexOutOfRange`] when `patch` is not a valid index
    ///   into the chain;
    /// - [`ZiPatchError::PatchSourceTooShort`] when fewer than `dst.len()` bytes
    ///   are available at `offset`.
    ///
    /// The file-backed source may also surface other I/O failures.
    fn read(&mut self, patch: u32, offset: u64, dst: &mut [u8]) -> Result<()>;
}
#[derive(Debug)]
/// [`PatchSource`] backed by one open [`File`] handle per patch in a chain.
pub struct FilePatchSource {
    // One handle per patch, in chain order; `read`'s `patch` argument indexes
    // this vector.
    files: Vec<File>,
}
impl FilePatchSource {
    /// Opens a single patch file as a one-element chain (patch index 0).
    ///
    /// # Errors
    ///
    /// Returns [`ZiPatchError::Io`] if the file cannot be opened.
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        // Map explicitly to `Io`, exactly like `open_chain`, instead of relying
        // on whatever `From<io::Error>` conversion the crate error provides.
        File::open(path)
            .map(Self::from_file)
            .map_err(ZiPatchError::Io)
    }

    /// Opens every path in `paths`, in iteration order, as a patch chain.
    ///
    /// The position of a path in the iterator becomes its `patch` index for
    /// [`PatchSource::read`]. An empty iterator yields a source with zero
    /// patches, on which every read fails with
    /// [`ZiPatchError::PatchIndexOutOfRange`].
    ///
    /// # Errors
    ///
    /// Returns [`ZiPatchError::Io`] for the first path that cannot be opened.
    /// Handles opened before the failure are dropped, which closes them.
    pub fn open_chain<I, P>(paths: I) -> Result<Self>
    where
        I: IntoIterator<Item = P>,
        P: AsRef<Path>,
    {
        let iter = paths.into_iter();
        // Preallocate from the lower size hint; collecting into `Result`
        // would hide the hint behind its short-circuiting adapter.
        let mut files = Vec::with_capacity(iter.size_hint().0);
        for p in iter {
            files.push(File::open(p).map_err(ZiPatchError::Io)?);
        }
        Ok(Self { files })
    }

    /// Wraps an already-open file as a one-element chain (patch index 0).
    #[must_use]
    pub fn from_file(file: File) -> Self {
        Self { files: vec![file] }
    }

    /// Number of patches (open file handles) in the chain.
    #[must_use]
    pub fn patch_count(&self) -> usize {
        self.files.len()
    }
}
impl PatchSource for FilePatchSource {
    /// Seeks the `patch`-th file to `offset` and fills `dst` completely.
    ///
    /// If end-of-file arrives before `dst` is full, this reports
    /// [`ZiPatchError::PatchSourceTooShort`]. A seek failure is propagated with
    /// `?`, and any other read failure is wrapped in [`ZiPatchError::Io`].
    fn read(&mut self, patch: u32, offset: u64, dst: &mut [u8]) -> Result<()> {
        // Taken before the lookup: `get_mut` keeps `files` mutably borrowed.
        let count = self.files.len();
        let handle = match self.files.get_mut(patch as usize) {
            Some(handle) => handle,
            None => return Err(ZiPatchError::PatchIndexOutOfRange { patch, count }),
        };
        handle.seek(SeekFrom::Start(offset))?;
        let requested = dst.len();
        handle.read_exact(dst).map_err(|err| {
            if err.kind() == std::io::ErrorKind::UnexpectedEof {
                ZiPatchError::PatchSourceTooShort { offset, requested }
            } else {
                ZiPatchError::Io(err)
            }
        })
    }
}
/// In-memory [`PatchSource`] over a chain of byte buffers.
///
/// Compiled only for tests or when the `test-utils` feature is enabled.
/// Buffers are held as `Arc<[u8]>`, so cloning the source shares the
/// underlying bytes rather than copying them.
#[cfg(any(test, feature = "test-utils"))]
#[derive(Debug, Clone)]
pub struct MemoryPatchSource {
    // One buffer per patch, in chain order; `read`'s `patch` argument indexes
    // this vector.
    bufs: Vec<std::sync::Arc<[u8]>>,
}
#[cfg(any(test, feature = "test-utils"))]
impl MemoryPatchSource {
    /// Wraps one owned buffer as a single-patch source (patch index 0).
    #[must_use]
    pub fn new(buf: Vec<u8>) -> Self {
        Self {
            bufs: vec![buf.into()],
        }
    }

    /// Copies one borrowed buffer into a single-patch source (patch index 0).
    #[must_use]
    pub fn from_slice(buf: &[u8]) -> Self {
        // `Arc::from(&[u8])` copies straight into the shared allocation. The
        // former `Vec::from(buf).into()` allocated and copied the bytes twice.
        Self {
            bufs: vec![std::sync::Arc::<[u8]>::from(buf)],
        }
    }

    /// Wraps owned buffers as a patch chain; `bufs[i]` becomes patch `i`.
    #[must_use]
    pub fn new_chain(bufs: Vec<Vec<u8>>) -> Self {
        Self {
            bufs: bufs.into_iter().map(Into::into).collect(),
        }
    }

    /// Copies borrowed buffers into a patch chain; `bufs[i]` becomes patch `i`.
    #[must_use]
    pub fn from_slices(bufs: &[&[u8]]) -> Self {
        Self {
            bufs: bufs
                .iter()
                .map(|&b| std::sync::Arc::<[u8]>::from(b))
                .collect(),
        }
    }

    /// Number of patches (buffers) in the chain.
    #[must_use]
    pub fn patch_count(&self) -> usize {
        self.bufs.len()
    }
}
#[cfg(any(test, feature = "test-utils"))]
impl PatchSource for MemoryPatchSource {
    /// Copies `dst.len()` bytes from patch `patch`, starting at `offset`.
    ///
    /// Every way the requested range can fail to fit is reported as
    /// [`ZiPatchError::PatchSourceTooShort`]:
    /// - `offset` does not fit in `usize`;
    /// - `offset + dst.len()` overflows;
    /// - the range ends past the buffer.
    ///
    /// `dst` is only written after all checks pass, so it is left untouched on
    /// error.
    fn read(&mut self, patch: u32, offset: u64, dst: &mut [u8]) -> Result<()> {
        let count = self.bufs.len();
        let buf = self
            .bufs
            .get(patch as usize)
            .ok_or(ZiPatchError::PatchIndexOutOfRange { patch, count })?;
        let requested = dst.len();
        // Single definition of the range error, shared by all three checks.
        let too_short = || ZiPatchError::PatchSourceTooShort { offset, requested };
        let start = usize::try_from(offset).map_err(|_| too_short())?;
        let end = start.checked_add(requested).ok_or_else(too_short)?;
        // `get` folds the bounds check into the lookup, so there is no
        // panicking index.
        let src = buf.get(start..end).ok_or_else(too_short)?;
        dst.copy_from_slice(src);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    // Exercises both `PatchSource` implementations: happy-path ranges, chain
    // indexing, and every error variant each source can produce.
    use super::*;

    #[test]
    fn memory_source_round_trips_arbitrary_ranges() {
        let bytes: Vec<u8> = (0..=255u8).collect();
        let mut src = MemoryPatchSource::new(bytes.clone());
        let mut head = [0u8; 16];
        src.read(0, 0, &mut head).unwrap();
        assert_eq!(&head, &bytes[..16]);
        let mut mid = [0u8; 32];
        src.read(0, 100, &mut mid).unwrap();
        assert_eq!(&mid, &bytes[100..132]);
        let mut tail = [0u8; 16];
        src.read(0, 240, &mut tail).unwrap();
        assert_eq!(&tail, &bytes[240..256]);
        let mut empty = [0u8; 0];
        src.read(0, 0, &mut empty).unwrap();
        src.read(0, 256, &mut empty).unwrap();
    }

    #[test]
    fn memory_source_out_of_range_returns_too_short() {
        let mut src = MemoryPatchSource::new(vec![0u8; 16]);
        let mut buf = [0u8; 4];
        let err = src
            .read(0, 15, &mut buf)
            .expect_err("read past end must fail");
        match err {
            ZiPatchError::PatchSourceTooShort { offset, requested } => {
                assert_eq!(offset, 15);
                assert_eq!(requested, 4);
            }
            other => panic!("expected PatchSourceTooShort, got {other:?}"),
        }
        let err = src
            .read(0, 1_000_000, &mut buf)
            .expect_err("read far past end must fail");
        assert!(matches!(err, ZiPatchError::PatchSourceTooShort { .. }));
    }

    #[test]
    fn memory_source_huge_offset_is_too_short_not_panic() {
        // Exercises the `usize::try_from` / `checked_add` overflow guards.
        let mut src = MemoryPatchSource::new(vec![0u8; 16]);
        let mut buf = [0u8; 4];
        let err = src
            .read(0, u64::MAX, &mut buf)
            .expect_err("u64::MAX offset must fail");
        assert!(matches!(
            err,
            ZiPatchError::PatchSourceTooShort { offset, requested }
                if offset == u64::MAX && requested == 4
        ));
    }

    #[test]
    fn memory_source_chain_indexes_each_patch() {
        let p0: Vec<u8> = (0..16u8).map(|i| 0xA0 | i).collect();
        let p1: Vec<u8> = (0..16u8).map(|i| 0xB0 | i).collect();
        let mut src = MemoryPatchSource::new_chain(vec![p0.clone(), p1.clone()]);
        let mut buf = [0u8; 4];
        src.read(0, 0, &mut buf).unwrap();
        assert_eq!(&buf, &p0[..4]);
        src.read(1, 0, &mut buf).unwrap();
        assert_eq!(&buf, &p1[..4]);
        src.read(0, 12, &mut buf).unwrap();
        assert_eq!(&buf, &p0[12..16]);
    }

    #[test]
    fn memory_source_chain_rejects_out_of_range_patch() {
        let mut src = MemoryPatchSource::new_chain(vec![vec![0u8; 16]]);
        let mut buf = [0u8; 4];
        let err = src
            .read(1, 0, &mut buf)
            .expect_err("patch 1 must be out of range");
        match err {
            ZiPatchError::PatchIndexOutOfRange { patch, count } => {
                assert_eq!(patch, 1);
                assert_eq!(count, 1);
            }
            other => panic!("expected PatchIndexOutOfRange, got {other:?}"),
        }
    }

    #[test]
    fn memory_source_slice_constructors_copy_bytes() {
        let mut single = MemoryPatchSource::from_slice(b"abcdef");
        assert_eq!(single.patch_count(), 1);
        let mut buf = [0u8; 4];
        single.read(0, 2, &mut buf).unwrap();
        assert_eq!(&buf, b"cdef");

        let mut chain = MemoryPatchSource::from_slices(&[&b"abcd"[..], &b"efgh"[..]]);
        assert_eq!(chain.patch_count(), 2);
        chain.read(1, 0, &mut buf).unwrap();
        assert_eq!(&buf, b"efgh");
    }

    #[test]
    fn file_source_round_trips_arbitrary_ranges() {
        let bytes: Vec<u8> = (0..=255u8).collect();
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("source.bin");
        std::fs::write(&path, &bytes).unwrap();
        let mut src = FilePatchSource::open(&path).unwrap();
        let mut head = [0u8; 16];
        src.read(0, 0, &mut head).unwrap();
        assert_eq!(&head, &bytes[..16]);
        let mut mid = [0u8; 32];
        src.read(0, 100, &mut mid).unwrap();
        assert_eq!(&mid, &bytes[100..132]);
    }

    #[test]
    fn file_source_short_returns_too_short() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("source.bin");
        std::fs::write(&path, [0u8; 16]).unwrap();
        let mut src = FilePatchSource::open(&path).unwrap();
        let mut buf = [0u8; 32];
        let err = src
            .read(0, 0, &mut buf)
            .expect_err("read past end must fail");
        assert!(matches!(err, ZiPatchError::PatchSourceTooShort { .. }));
    }

    #[test]
    fn file_source_offset_past_end_returns_too_short() {
        // Seeking beyond EOF succeeds; the subsequent read hits EOF.
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("source.bin");
        std::fs::write(&path, [0u8; 16]).unwrap();
        let mut src = FilePatchSource::open(&path).unwrap();
        let mut buf = [0u8; 4];
        let err = src
            .read(0, 1_000, &mut buf)
            .expect_err("read beyond EOF must fail");
        match err {
            ZiPatchError::PatchSourceTooShort { offset, requested } => {
                assert_eq!(offset, 1_000);
                assert_eq!(requested, 4);
            }
            other => panic!("expected PatchSourceTooShort, got {other:?}"),
        }
    }

    #[test]
    fn file_source_rejects_out_of_range_patch() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("source.bin");
        std::fs::write(&path, [0u8; 16]).unwrap();
        let file = std::fs::File::open(&path).unwrap();
        let mut src = FilePatchSource::from_file(file);
        assert_eq!(src.patch_count(), 1);
        let mut buf = [0u8; 4];
        let err = src
            .read(1, 0, &mut buf)
            .expect_err("patch 1 must be out of range");
        match err {
            ZiPatchError::PatchIndexOutOfRange { patch, count } => {
                assert_eq!(patch, 1);
                assert_eq!(count, 1);
            }
            other => panic!("expected PatchIndexOutOfRange, got {other:?}"),
        }
    }

    #[test]
    fn file_source_open_reports_missing_file() {
        let tmp = tempfile::tempdir().unwrap();
        let present = tmp.path().join("p0.bin");
        std::fs::write(&present, b"AAAA").unwrap();
        let missing = tmp.path().join("missing.bin");
        assert!(FilePatchSource::open(&missing).is_err());
        let err = FilePatchSource::open_chain([&present, &missing])
            .expect_err("missing chain member must fail");
        assert!(matches!(err, ZiPatchError::Io(_)));
    }

    #[test]
    fn file_source_chain_indexes_each_file() {
        let tmp = tempfile::tempdir().unwrap();
        let p0 = tmp.path().join("p0.bin");
        let p1 = tmp.path().join("p1.bin");
        std::fs::write(&p0, b"AAAAAAAA").unwrap();
        std::fs::write(&p1, b"BBBBBBBB").unwrap();
        let mut src = FilePatchSource::open_chain([&p0, &p1]).unwrap();
        assert_eq!(src.patch_count(), 2);
        let mut buf = [0u8; 4];
        src.read(0, 0, &mut buf).unwrap();
        assert_eq!(&buf, b"AAAA");
        src.read(1, 4, &mut buf).unwrap();
        assert_eq!(&buf, b"BBBB");
    }
}