use std::{
cmp,
error::Error,
fmt::{self, Display, Formatter},
io::{self, BufRead, BufReader, ErrorKind, Read, Seek, SeekFrom, Write},
};
use byteorder::{LittleEndian, ReadBytesExt};
use integer_encoding::VarIntReader;
use zstd::Decoder;
use crate::header::{MAGIC, VERSION_MAJOR};
/// Length in bytes of the scratch buffer allocated by each `Patcher`.
///
/// The buffer holds diff bytes read from the patch while an add section is
/// processed, so this also caps how many add bytes one step of
/// `Patcher::read` handles.
const DEFAULT_BUF_SIZE: usize = 8192;
/// Reader that yields patched output built from an old input and a patch stream.
///
/// The patch header is consumed when the value is constructed; the rest of the
/// patch is read through a zstd [`Decoder`]. The patched bytes are obtained via
/// the [`Read`] implementation.
pub struct Patcher<'a, O, B>
where
O: Read + Seek,
B: BufRead,
{
/// Input the patch is applied against; read sequentially and repositioned
/// with relative seeks taken from the patch.
old: O,
/// zstd decoder over the patch stream, positioned after the header.
patch: Decoder<'a, B>,
/// Where the reader currently is within the patch's control records.
state: PatcherState,
/// Scratch space for diff bytes read from the patch during add sections.
buf: Vec<u8>,
/// Information parsed from the patch header.
metadata: PatchMetadata,
}
/// Position of a `Patcher` within the decoded patch stream.
enum PatcherState {
/// The next value to decode from the patch is the length of an add section.
AtNextControl,
/// Inside an add section; holds the number of bytes still to be produced.
Add(usize),
/// Inside a copy section; holds the number of bytes still to be produced.
Copy(usize),
}
impl<'a, O, B> Patcher<'a, O, B>
where
    O: Read + Seek,
    B: BufRead,
{
    /// Builds a patcher from `old` and an already-buffered `patch` stream.
    ///
    /// The patch header is read from `patch` first; whatever follows is handed
    /// to a zstd decoder.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `read_header`, or the I/O error raised
    /// while creating the decoder.
    pub fn with_buffer(old: O, mut patch: B) -> Result<Self, PatchError> {
        let metadata = read_header(&mut patch)?;
        Ok(Self {
            old,
            patch: Decoder::with_buffer(patch)?,
            state: PatcherState::AtNextControl,
            buf: vec![0; DEFAULT_BUF_SIZE],
            metadata,
        })
    }

    /// Returns the metadata parsed from the patch header.
    pub fn metadata(&self) -> &PatchMetadata {
        &self.metadata
    }
}
impl<'a, O, P> Patcher<'a, O, BufReader<P>>
where
    O: Read + Seek,
    P: Read,
{
    /// Builds a patcher from `old` and an unbuffered `patch` stream.
    ///
    /// The patch header is read directly from `patch`; the remainder is then
    /// passed to `Decoder::new`, which yields a decoder over a `BufReader<P>`.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `read_header`, or the I/O error raised
    /// while creating the decoder.
    pub fn new(old: O, mut patch: P) -> Result<Self, PatchError> {
        let metadata = read_header(&mut patch)?;
        Ok(Self {
            old,
            patch: Decoder::new(patch)?,
            state: PatcherState::AtNextControl,
            buf: vec![0; DEFAULT_BUF_SIZE],
            metadata,
        })
    }
}
impl<'a, O, B> Read for Patcher<'a, O, B>
where
    O: Read + Seek,
    B: BufRead,
{
    /// Writes the next bytes of patched output into `buf`.
    ///
    /// The decoded patch stream is consumed as repeating records: an add
    /// length, that many diff bytes, a copy length, that many literal bytes,
    /// and finally a relative seek applied to `old`. Each add byte is the
    /// wrapping sum of a byte read from `old` and the matching diff byte;
    /// copy bytes are taken from the patch unchanged.
    ///
    /// The loop runs until `buf` is full, or until reading the next add
    /// length reports `UnexpectedEof`, which is treated as the end of the
    /// patch. The number of bytes written so far is returned.
    ///
    /// # Errors
    ///
    /// Every other I/O error from `old` or the patch stream is returned
    /// unchanged, including `UnexpectedEof` raised anywhere except while
    /// reading an add length.
    ///
    /// NOTE(review): when an error occurs after earlier iterations of the same
    /// call already wrote into `buf`, the count of those bytes is not reported.
    fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
        let mut produced = 0;
        while !buf.is_empty() {
            let step = match self.state {
                PatcherState::AtNextControl => match self.patch.read_varint() {
                    Ok(add_len) => {
                        self.state = PatcherState::Add(add_len);
                        // Nothing is written by this step; the loop re-enters in `Add`.
                        0
                    }
                    // Running out of patch data at a record boundary ends the output.
                    Err(e) if e.kind() == ErrorKind::UnexpectedEof => break,
                    Err(e) => return Err(e),
                },
                PatcherState::Add(remaining) => {
                    // Bounded by the section, the caller's buffer and the scratch buffer.
                    let chunk_len = cmp::min(cmp::min(remaining, buf.len()), self.buf.len());
                    let out = &mut buf[..chunk_len];
                    self.old.read_exact(out)?;
                    let diff = &mut self.buf[..chunk_len];
                    self.patch.read_exact(diff)?;
                    for (dst, delta) in out.iter_mut().zip(diff.iter()) {
                        *dst = dst.wrapping_add(*delta);
                    }
                    self.state = if remaining == chunk_len {
                        // Add section finished: the copy length follows in the patch.
                        PatcherState::Copy(self.patch.read_varint()?)
                    } else {
                        PatcherState::Add(remaining - chunk_len)
                    };
                    chunk_len
                }
                PatcherState::Copy(remaining) => {
                    let chunk_len = cmp::min(remaining, buf.len());
                    self.patch.read_exact(&mut buf[..chunk_len])?;
                    if remaining == chunk_len {
                        // Copy section finished: reposition `old` before the next record.
                        let offset: i64 = self.patch.read_varint()?;
                        self.old.seek(SeekFrom::Current(offset))?;
                        self.state = PatcherState::AtNextControl;
                    } else {
                        self.state = PatcherState::Copy(remaining - chunk_len);
                    }
                    chunk_len
                }
            };
            produced += step;
            buf = &mut buf[step..];
        }
        Ok(produced)
    }
}
/// Errors reported while opening or applying a patch.
#[derive(Debug)]
pub enum PatchError {
/// An underlying I/O operation failed.
Io(io::Error),
/// The patch did not start with the expected magic number; holds the value
/// that was found instead.
BadMagic(u32),
/// The patch declares a major version that is not supported; holds that
/// major version.
UnsupportedVersion(u16),
}
impl Display for PatchError {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
PatchError::Io(e) => write!(f, "I/O error: {e}"),
PatchError::BadMagic(magic) => {
write!(f, "bad magic: expected {MAGIC:x}, found {magic:x}")
}
PatchError::UnsupportedVersion(version) => {
write!(
f,
"unsupported version: found version {version}.x, \
supported versions are {VERSION_MAJOR}.x",
)
}
}
}
}
impl Error for PatchError {
    /// Returns the lower-level cause of this error, if any.
    ///
    /// For the I/O variant this forwards to the wrapped error's own `source()`
    /// rather than returning the wrapped error itself; that error's message is
    /// already part of this type's `Display` output. The other variants have
    /// no underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            PatchError::Io(e) => e.source(),
            PatchError::BadMagic(_) | PatchError::UnsupportedVersion(_) => None,
        }
    }
}
impl From<io::Error> for PatchError {
fn from(value: io::Error) -> Self {
PatchError::Io(value)
}
}
impl From<TryFromValueError> for PatchError {
fn from(value: TryFromValueError) -> Self {
PatchError::UnsupportedVersion(value.0)
}
}
/// Information parsed from a patch header.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
pub struct PatchMetadata {
/// Format version declared by the patch.
version: PatchVersion,
}
impl PatchMetadata {
fn new(version: PatchVersion) -> Self {
Self { version }
}
pub fn version(&self) -> PatchVersion {
self.version
}
}
/// Format version of a patch, split into a major and a minor component.
///
/// The derived ordering compares `major` first, then `minor`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
pub struct PatchVersion {
/// Major version; only values accepted by `MajorVersion` can be stored.
major: MajorVersion,
/// Minor version, stored as read from the header.
minor: u16,
}
impl PatchVersion {
fn from_values(major: u16, minor: u16) -> Result<Self, TryFromValueError> {
let major = major.try_into()?;
Ok(Self { major, minor })
}
pub fn major(&self) -> u16 {
self.major.into()
}
pub fn minor(&self) -> u16 {
self.minor
}
}
/// Major patch format versions this code accepts.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
enum MajorVersion {
/// Major version 1.
One,
}
impl TryFrom<u16> for MajorVersion {
    type Error = TryFromValueError;

    /// Maps a raw major version number to its enum variant.
    ///
    /// # Errors
    ///
    /// Any value other than `1` is rejected with a `TryFromValueError`
    /// carrying that value.
    fn try_from(value: u16) -> Result<Self, Self::Error> {
        if value == 1 {
            Ok(Self::One)
        } else {
            Err(TryFromValueError(value))
        }
    }
}
impl From<MajorVersion> for u16 {
    /// Returns the numeric value of a major version (`One` is `1`).
    fn from(value: MajorVersion) -> u16 {
        // Kept as an exhaustive match so a new variant forces an update here.
        match value {
            MajorVersion::One => 1,
        }
    }
}
/// Error returned when a raw version number is outside the supported set.
///
/// Holds the rejected value.
#[derive(Debug)]
struct TryFromValueError(u16);

impl Display for TryFromValueError {
    /// Writes a fixed message; the rejected value is not included.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.write_str("version out of supported range")
    }
}

// This error has no underlying cause, so the trait's default `source`
// (which returns `None`) is used.
impl Error for TryFromValueError {}
pub fn read_header<P>(mut patch: &mut P) -> Result<PatchMetadata, PatchError>
where
P: Read + ?Sized,
{
let magic = patch.read_u32::<LittleEndian>()?;
if magic != MAGIC {
return Err(PatchError::BadMagic(magic));
}
let version_major = patch.read_u16::<LittleEndian>()?;
let version_minor = patch.read_u16::<LittleEndian>()?;
let patch_version = PatchVersion::from_values(version_major, version_minor)?;
let data_offset = patch.read_varint()?;
io::copy(&mut patch.take(data_offset), &mut io::sink())?;
Ok(PatchMetadata::new(patch_version))
}
pub fn patch<O, P, W>(old: O, patch: P, new: &mut W) -> Result<u64, PatchError>
where
O: Read + Seek,
P: Read,
W: Write + ?Sized,
{
let mut patcher = Patcher::new(old, patch)?;
Ok(io::copy(&mut patcher, new)?)
}