use std::io::prelude::*;
use nom::Finish;
use crate::{
data::*,
parse,
};
/// Errors that can occur while applying a patch.
#[derive(Debug)]
pub enum ApplyError {
    /// Reading the input, reading file metadata, or writing the output failed.
    Io(std::io::Error),
    /// The patch could not be parsed; holds a human-readable description of the failure.
    Parse(String),
    /// `Direction::Auto` could not pick a direction because the input size matches
    /// neither length recorded in the patch header.
    DirectionMustBeSpecified,
}

impl std::fmt::Display for ApplyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ApplyError::Io(err) => write!(f, "I/O error: {}", err),
            ApplyError::Parse(msg) => write!(f, "invalid patch: {}", msg),
            ApplyError::DirectionMustBeSpecified => f.write_str(
                "input size matches neither side of the patch; specify the direction explicitly",
            ),
        }
    }
}

impl std::error::Error for ApplyError {
    /// Exposes the underlying I/O error so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ApplyError::Io(err) => Some(err),
            ApplyError::Parse(_) | ApplyError::DirectionMustBeSpecified => None,
        }
    }
}
impl From<std::io::Error> for ApplyError {
fn from(err: std::io::Error) -> Self {
ApplyError::Io(err)
}
}
impl From<parse::Error<'_>> for ApplyError {
fn from(err: parse::Error) -> Self {
ApplyError::Parse(parse::format_error(err))
}
}
/// Which way a patch is applied; see [`get_output_len`] for how each choice selects the
/// output length from the patch header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Choose a direction from the input size.
    Auto,
    /// The output has the header's `dst_len` length.
    Forward,
    /// The output has the header's `src_len` length.
    Reverse,
}
pub fn patch_file(input: &mut std::fs::File, writer: &mut impl Write, patch: &[u8], dir: Direction) -> Result<(), ApplyError> {
let input_size = input.metadata()?.len();
let (patch, header) = parse::header(patch).finish()?;
let output_size = get_output_len(input_size.try_into().unwrap(), &header, dir)
.ok_or(ApplyError::DirectionMustBeSpecified)?;
let end = patch.len() - Trailer::LENGTH;
apply_hunks(input, writer, &patch[0..end], output_size)
}
/// Determines how many bytes the patched output should contain.
///
/// `Forward` yields `header.dst_len` and `Reverse` yields `header.src_len`. `Auto` yields
/// `dst_len` when both header lengths are equal or `input_size` equals `src_len`, yields
/// `src_len` when `input_size` equals `dst_len`, and otherwise returns `None` so the caller
/// can ask for an explicit direction.
pub fn get_output_len(input_size: usize, header: &Header, dir: Direction) -> Option<usize> {
    let (src, dst) = (header.src_len, header.dst_len);
    let forward = match dir {
        Direction::Forward => true,
        Direction::Reverse => false,
        Direction::Auto if src == dst || input_size == src => true,
        Direction::Auto if input_size == dst => false,
        Direction::Auto => return None,
    };
    Some(if forward { dst } else { src })
}
/// Applies a sequence of encoded hunks to `reader`, writing `output_size` bytes to `writer`.
///
/// `hunks` is the patch body between header and trailer. Each hunk is parsed with
/// `parse::hunk` and applied with [`apply_hunk`]; after every hunk one more input byte is
/// copied through unchanged. Any output still missing after the last hunk is copied from
/// `reader` (zero-filled past EOF). Everything beyond `output_size` is discarded.
///
/// # Errors
///
/// Returns [`ApplyError::Parse`] if a hunk cannot be parsed and [`ApplyError::Io`] if
/// reading or writing fails.
pub fn apply_hunks(
    reader: &mut impl Read,
    writer: &mut impl Write,
    hunks: &[u8],
    output_size: usize,
) -> Result<(), ApplyError> {
    let mut hunks = hunks;
    let mut written = 0usize;
    let mut writer = TruncatingWrite::new(writer, output_size);
    while !hunks.is_empty() {
        let (rest, hunk) = parse::hunk(hunks).finish()?;
        hunks = rest;
        written += apply_hunk(reader, &mut writer, &hunk)?;
        // The byte following each hunk passes through unchanged.
        copy(reader, &mut writer, &mut [0], 1)?;
        written += 1;
    }
    if written < output_size {
        let mut buf = vec![0; 4096];
        copy(reader, &mut writer, &mut buf, output_size - written)?;
    }
    // Input beyond `output_size` is deliberately left unread: the truncating writer has
    // no room left and would discard it anyway.
    Ok(())
}
/// Applies a single hunk from `reader` to `writer`.
///
/// First copies `hunk.skip` input bytes unchanged, then XORs the next `hunk.xor.len()`
/// input bytes with `hunk.xor` and writes the result. Input past EOF reads as zero
/// bytes (via `read_nonzero`), so a hunk can extend the output beyond the input.
///
/// Returns the number of bytes written, `hunk.skip + hunk.xor.len()`.
///
/// # Errors
///
/// Returns [`ApplyError::Io`] if reading or writing fails.
pub fn apply_hunk(
    reader: &mut impl Read,
    writer: &mut impl Write,
    hunk: &Hunk,
) -> Result<usize, ApplyError> {
    let mut buf = vec![0u8; 4096];
    copy(reader, writer, &mut buf, hunk.skip)?;
    let total = hunk.xor.len();
    let mut done = 0;
    while done < total {
        let n = read_nonzero(reader, &mut buf, total - done)?;
        let chunk = &mut buf[..n];
        for (byte, mask) in chunk.iter_mut().zip(&hunk.xor[done..done + n]) {
            *byte ^= *mask;
        }
        writer.write_all(chunk)?;
        done += n;
    }
    // Computed from the hunk itself rather than from `copy`'s return value.
    Ok(hunk.skip + total)
}
/// Copies exactly `skip` bytes from `reader` to `writer`, using `buf` as scratch space.
///
/// Once `reader` reaches EOF the remaining bytes are written as zeros (see
/// `read_nonzero`). `buf` must be non-empty whenever `skip > 0`; otherwise no progress
/// can be made and the loop never terminates.
///
/// Returns the number of bytes written, which is always `skip`.
fn copy(
    reader: &mut impl Read,
    writer: &mut impl Write,
    buf: &mut [u8],
    skip: usize,
) -> Result<usize, std::io::Error> {
    let mut remaining = skip;
    while remaining > 0 {
        let n = read_nonzero(reader, buf, remaining)?;
        writer.write_all(&buf[..n])?;
        remaining -= n;
    }
    // Report the bytes copied. The previous version returned the decremented
    // counter, which is always 0 here.
    Ok(skip)
}
/// Reads up to `min(n, buf.len())` bytes from `reader` into the front of `buf`.
///
/// Unlike a plain `read`, this never reports 0 bytes for a non-empty request. At EOF the
/// requested prefix of `buf` is zero-filled and its full length is returned, so the
/// input behaves as if it were padded with zeros. Reads that fail with
/// `ErrorKind::Interrupted` are retried.
///
/// Returns `Ok(0)` only when `n == 0` or `buf` is empty.
fn read_nonzero(
    reader: &mut impl Read,
    buf: &mut [u8],
    n: usize,
) -> Result<usize, std::io::Error> {
    let n = n.min(buf.len());
    let buf = &mut buf[..n];
    loop {
        match reader.read(buf) {
            // EOF: pretend the input continues with zero bytes.
            Ok(0) => {
                buf.fill(0);
                return Ok(n);
            }
            Ok(read) => return Ok(read),
            // Not a real failure: `Read::read` documents that the call can be retried.
            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        }
    }
}
/// Writer adapter that passes at most `remaining` bytes on to `w` and drops the rest.
///
/// Dropped bytes are still reported as written. `write_all` on this adapter therefore
/// succeeds even when more data than the limit is written, and everything past the limit
/// is silently ignored.
struct TruncatingWrite<W> {
    /// Destination for the bytes that fit within the limit.
    w: W,
    /// How many more bytes may still be forwarded to `w`.
    remaining: usize,
}

impl<W> TruncatingWrite<W> {
    /// Wraps `w`, allowing at most `remaining` bytes through.
    pub fn new(w: W, remaining: usize) -> Self {
        TruncatingWrite { w, remaining }
    }

    /// Returns the wrapped writer.
    // Only the tests call this; allow it without a dead-code warning in normal builds.
    #[allow(dead_code)]
    pub fn into_inner(self) -> W {
        self.w
    }
}

impl<W: Write> Write for TruncatingWrite<W> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let to_write = buf.len().min(self.remaining);
        let hidden = buf.len() - to_write;
        if to_write == 0 {
            // Limit reached (or nothing to write): discard without touching `w`.
            return Ok(hidden);
        }
        let written = self.w.write(&buf[..to_write])?;
        self.remaining -= written;
        if written == to_write {
            // Everything that fits was forwarded; report the discarded tail as written too.
            Ok(written + hidden)
        } else {
            // Short inner write: report only what was forwarded so the caller resubmits
            // the rest of the head in order. Claiming the tail here would make
            // `write_all` skip bytes that were never written.
            Ok(written)
        }
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.w.flush()
    }
}
/// Table-driven check of every direction-selection rule in `get_output_len`.
#[test]
pub fn test_get_output_len() {
    // (input size, src_len, dst_len, direction, expected output length)
    let cases = [
        (10, 10, 10, Direction::Auto, Some(10)),
        (10, 10, 10, Direction::Forward, Some(10)),
        (10, 10, 10, Direction::Reverse, Some(10)),
        (10, 10, 20, Direction::Auto, Some(20)),
        (10, 10, 20, Direction::Forward, Some(20)),
        (20, 10, 20, Direction::Auto, Some(10)),
        (20, 10, 20, Direction::Reverse, Some(10)),
        (15, 10, 20, Direction::Auto, None),
    ];
    for (input_size, src_len, dst_len, dir, expected) in cases {
        let header = Header { src_len, dst_len };
        assert_eq!(get_output_len(input_size, &header, dir), expected);
    }
}
/// `TruncatingWrite` keeps only the first `remaining` bytes but reports each write as complete.
#[test]
pub fn test_truncating_write() {
    let chunks: [&[u8]; 3] = [b"Hello, World!", b"Foo bar", b"baz"];
    let mut write = TruncatingWrite::new(std::io::Cursor::new(Vec::<u8>::new()), 16);
    for chunk in chunks {
        assert_eq!(write.write(chunk).ok(), Some(chunk.len()));
    }
    let kept = write.into_inner().into_inner();
    assert_eq!(kept, b"Hello, World!Foo");
}
/// Applies single hunks to in-memory inputs, including one that reads past EOF.
#[test]
pub fn test_apply_hunk_inmem() {
    let cases = [
        (&b"foo"[..], Hunk { skip: 1, xor: &b" "[..] }, &b"fO"[..]),
        (&b"f"[..], Hunk { skip: 1, xor: &b"ooBar"[..] }, &b"fooBar"[..]),
    ];
    for (input, hunk, expected) in cases {
        let output = apply_hunk_inmem(input, &hunk);
        assert_eq!(output.ok(), Some(expected.to_vec()));
    }
}
/// Test helper: applies `hunk` to the bytes in `buf` and returns the produced output.
#[cfg(test)]
fn apply_hunk_inmem(buf: &[u8], hunk: &Hunk) -> Result<Vec<u8>, ApplyError> {
    let mut output = Vec::new();
    apply_hunk(&mut std::io::Cursor::new(buf), &mut output, hunk)?;
    Ok(output)
}