mkups 0.1.0

Toolkit for creating, applying, and inspecting .ups patches
Documentation
//! Functions for applying UPS patches. These functions do not verify the checksums included in the
//! patch.
use std::io::prelude::*;
use nom::Finish;

use crate::{
    data::*,
    parse,
};

/// Errors that can occur when applying a patch.
#[derive(Debug)]
pub enum ApplyError {
    /// Underlying I/O error.
    Io(std::io::Error),
    /// Malformed patch file. The string holds the formatted description of the problem.
    Parse(String),
    /// [Direction](crate::apply::Direction) was set to Auto, but the input file length does not
    /// match one of the lengths in the UPS header.
    DirectionMustBeSpecified,
}

impl std::fmt::Display for ApplyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ApplyError::Io(err) => write!(f, "I/O error: {err}"),
            ApplyError::Parse(msg) => write!(f, "malformed patch: {msg}"),
            ApplyError::DirectionMustBeSpecified => f.write_str(
                "cannot detect patch direction: input length matches neither length in the UPS header",
            ),
        }
    }
}

impl std::error::Error for ApplyError {
    /// Exposes the wrapped I/O error, if any, so error reporters can print the full chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ApplyError::Io(err) => Some(err),
            ApplyError::Parse(_) | ApplyError::DirectionMustBeSpecified => None,
        }
    }
}

impl From<std::io::Error> for ApplyError {
    fn from(err: std::io::Error) -> Self {
        ApplyError::Io(err)
    }
}

impl From<parse::Error<'_>> for ApplyError {
    fn from(err: parse::Error) -> Self {
        ApplyError::Parse(parse::format_error(err))
    }
}

/// In which direction the patch should be applied.
///
/// This is a plain fieldless enum, so it is `Copy`: it can be passed by value repeatedly and
/// compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Auto-detect the direction from the input length. In most cases, this is what you want.
    Auto,
    /// Forward-apply the patch (source to destination).
    Forward,
    /// Reverse-apply the patch (destination to source).
    Reverse,
}

/// Patches the input file (which will not be modified) using a UPS patch, and writes the result
/// to the specified writer. Checksums are not verified! `patch` must be the complete UPS patch,
/// including both the header and the trailer.
///
/// The input is read through a [`std::io::BufReader`], so the many small reads performed while
/// applying hunks do not each turn into a system call.
///
/// # Errors
///
/// * [`ApplyError::Io`] if querying the input's metadata, reading the input, or writing the
///   output fails.
/// * [`ApplyError::Parse`] if the header or a hunk is malformed, or if the patch is too short to
///   hold the trailer.
/// * [`ApplyError::DirectionMustBeSpecified`] if `dir` is [`Direction::Auto`] and the direction
///   cannot be determined from the input length.
pub fn patch_file(input: &mut std::fs::File, writer: &mut impl Write, patch: &[u8], dir: Direction) -> Result<(), ApplyError> {
    // Header lengths are `usize`, so a file too large for `usize` cannot realistically match
    // them; saturate instead of panicking so an explicit direction still works.
    let input_size = usize::try_from(input.metadata()?.len()).unwrap_or(usize::MAX);
    let (body, header) = parse::header(patch).finish()?;
    let output_size = get_output_len(input_size, &header, dir)
        .ok_or(ApplyError::DirectionMustBeSpecified)?;
    // A truncated patch must be reported as malformed rather than underflowing here.
    let hunks_len = body.len().checked_sub(Trailer::LENGTH).ok_or_else(|| {
        ApplyError::Parse("patch is too short to contain the UPS trailer".to_string())
    })?;
    let mut reader = std::io::BufReader::new(input);
    apply_hunks(&mut reader, writer, &body[..hunks_len], output_size)
}

/// Determines how long the patched output will be.
///
/// With an explicit [`Direction::Forward`] or [`Direction::Reverse`], the answer comes straight
/// from the header and `input_size` is ignored. With [`Direction::Auto`], `input_size` is
/// matched against the header. An input of the source length yields the destination length;
/// so does any input when the source and destination lengths are equal. An input of the
/// destination length yields the source length. Anything else yields `None`.
pub fn get_output_len(input_size: usize, header: &Header, dir: Direction) -> Option<usize> {
    let (src, dst) = (header.src_len, header.dst_len);
    match dir {
        Direction::Forward => Some(dst),
        Direction::Reverse => Some(src),
        // Equal lengths make the direction irrelevant for the output size.
        Direction::Auto if src == dst || input_size == src => Some(dst),
        Direction::Auto if input_size == dst => Some(src),
        Direction::Auto => None,
    }
}

/// Streams data from `reader` to `writer`, applying the UPS hunks encoded in `hunks` on the way.
///
/// `hunks` must contain only the hunk section of the patch: the UPS header and trailer *must
/// not* be included. After each hunk, one further input byte is passed through unchanged. Once
/// all hunks are processed, the rest of the input is copied. If the input runs out before
/// `output_size` bytes have been produced, zero bytes are written instead. Anything beyond
/// `output_size` bytes is discarded.
pub fn apply_hunks(reader: &mut impl Read,
                   writer: &mut impl Write,
                   hunks: &[u8],
                   output_size: usize) -> Result<(), ApplyError> {
    let mut out = TruncatingWrite::new(writer, output_size);
    let mut remaining = hunks;
    let mut produced: usize = 0;
    let mut one_byte = [0u8; 1];

    loop {
        if remaining.is_empty() {
            break;
        }
        let (rest, hunk) = parse::hunk(remaining).finish()?;
        remaining = rest;
        produced += apply_hunk(reader, &mut out, &hunk)?;
        produced += copy(reader, &mut out, &mut one_byte, 1)?;
    }

    if let Some(missing) = output_size.checked_sub(produced).filter(|&m| m > 0) {
        let mut scratch = vec![0u8; 4096];
        copy(reader, &mut out, &mut scratch, missing)?;
    }
    std::io::copy(reader, &mut out)?;
    Ok(())
}

/// Low level method for applying a single hunk. You probably want to use
/// [apply_hunks](apply_hunks) or [patch_file](patch_file) instead.
///
/// Copies `hunk.skip` bytes from `reader` to `writer` unchanged. It then reads
/// `hunk.xor.len()` further bytes, XORs them with `hunk.xor`, and writes the result. Reading
/// past the end of `reader` behaves as if the input were padded with zero bytes. No more than
/// `hunk.skip + hunk.xor.len()` bytes are read from `reader`.
///
/// Returns the number of bytes written, i.e. `hunk.skip + hunk.xor.len()`.
///
/// # Errors
///
/// Returns [`ApplyError::Io`] if reading or writing fails.
pub fn apply_hunk(reader: &mut impl Read, writer: &mut impl Write, hunk: &Hunk) -> Result<usize, ApplyError> {
    // Stack buffer: avoids a heap allocation for every hunk.
    let mut buf = [0u8; 4096];

    // Unchanged prefix. `copy` transfers exactly `hunk.skip` bytes on success, so the count is
    // taken from the hunk itself rather than from `copy`'s return value.
    copy(reader, writer, &mut buf, hunk.skip)?;

    // XOR-patched region, processed in buffer-sized pieces.
    let mut done = 0;
    while done < hunk.xor.len() {
        let n = read_nonzero(reader, &mut buf, hunk.xor.len() - done)?;
        let chunk = &mut buf[..n];
        for (byte, x) in chunk.iter_mut().zip(&hunk.xor[done..done + n]) {
            *byte ^= *x;
        }
        writer.write_all(chunk)?;
        done += n;
    }

    Ok(hunk.skip + hunk.xor.len())
}

/// Copies exactly `len` bytes from `reader` to `writer`, using `buf` as scratch space.
///
/// If `reader` runs out of data, the missing bytes are written as zeros (see
/// [`read_nonzero`]). Returns the number of bytes written, which is always `len` on success.
///
/// `buf` must not be empty when `len > 0`; otherwise no progress could be made.
fn copy(reader: &mut impl Read, writer: &mut impl Write, buf: &mut [u8], len: usize) -> Result<usize, std::io::Error> {
    debug_assert!(len == 0 || !buf.is_empty(), "copy needs a non-empty scratch buffer");
    let mut left = len;
    while left > 0 {
        let n = read_nonzero(reader, buf, left)?;
        writer.write_all(&buf[..n])?;
        left -= n;
    }
    // The loop only exits once all `len` bytes were written.
    Ok(len)
}

/// Reads up to `n` bytes (bounded by `buf.len()`) into the front of `buf` and returns how many
/// bytes are valid.
///
/// At end of input, the first `min(n, buf.len())` bytes of `buf` are zero-filled and that count
/// is returned, so the input behaves as if padded with zeros. Reads failing with
/// `ErrorKind::Interrupted` are retried, as the `Read::read` contract recommends. The result is
/// non-zero whenever `n > 0` and `buf` is non-empty.
fn read_nonzero(reader: &mut impl Read, buf: &mut [u8], n: usize) -> Result<usize, std::io::Error> {
    let n = n.min(buf.len());
    let buf = &mut buf[..n];
    loop {
        match reader.read(buf) {
            Ok(0) => {
                buf.fill(0);
                return Ok(n);
            }
            Ok(read) => return Ok(read),
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
}

/// A writer adapter that forwards at most `remaining` bytes to the wrapped writer and silently
/// discards everything after that.
///
/// Discarded bytes are still reported as written, so `write_all` on this adapter succeeds even
/// when more data is supplied than the limit allows.
struct TruncatingWrite<W> {
    w: W,
    /// Number of bytes that may still be forwarded to `w`.
    remaining: usize,
}

impl<W: Write> TruncatingWrite<W> {
    /// Wraps `w`, allowing at most `remaining` bytes to reach it.
    pub fn new(w: W, remaining: usize) -> Self {
        TruncatingWrite { w, remaining }
    }

    /// Returns the wrapped writer.
    #[allow(dead_code)] // only called from tests in this file
    pub fn into_inner(self) -> W {
        self.w
    }
}

impl<W: Write> Write for TruncatingWrite<W> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let to_write = buf.len().min(self.remaining);
        let written = self.w.write(&buf[..to_write])?;
        self.remaining -= written;
        if written < to_write {
            // Short write from the inner writer: report only what it accepted so the caller
            // retries the rest. Also claiming the discarded tail here would make `write_all`
            // skip bytes that never reached `w`.
            Ok(written)
        } else {
            // Everything allowed through was accepted; report the excess as written too.
            Ok(buf.len())
        }
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.w.flush()
    }
}

#[test]
pub fn test_get_output_len() {
    // (input size, header src_len, header dst_len, direction, expected output length)
    let cases = [
        (10, 10, 10, Direction::Auto, Some(10)),
        (10, 10, 10, Direction::Forward, Some(10)),
        (10, 10, 10, Direction::Reverse, Some(10)),
        (10, 10, 20, Direction::Auto, Some(20)),
        (10, 10, 20, Direction::Forward, Some(20)),
        (20, 10, 20, Direction::Auto, Some(10)),
        (20, 10, 20, Direction::Reverse, Some(10)),
        (15, 10, 20, Direction::Auto, None),
    ];
    for (input, src_len, dst_len, dir, expected) in cases {
        let header = Header { src_len, dst_len };
        assert_eq!(
            get_output_len(input, &header, dir),
            expected,
            "input={input} src_len={src_len} dst_len={dst_len}"
        );
    }
}

#[test]
pub fn test_truncating_write() {
    let mut sink = TruncatingWrite::new(std::io::Cursor::new(Vec::new()), 16);
    for chunk in [&b"Hello, World!"[..], &b"Foo bar"[..], &b"baz"[..]] {
        // Every byte is reported as consumed, even those beyond the 16-byte limit.
        assert_eq!(sink.write(chunk).ok(), Some(chunk.len()));
    }
    let kept = sink.into_inner().into_inner();
    assert_eq!(kept, b"Hello, World!Foo");
}

#[test]
pub fn test_apply_hunk_inmem() {
    /// Applies `hunk` to `input` in memory and compares the output against `expected`.
    fn check(input: &[u8], hunk: Hunk, expected: &[u8]) {
        let actual = apply_hunk_inmem(input, &hunk).ok();
        assert_eq!(actual.as_deref(), Some(expected));
    }

    // Skip "f", then 'o' ^ ' ' == 'O'; the trailing "o" is not touched by the hunk.
    check(b"foo", Hunk { skip: 1, xor: &b" "[..] }, b"fO");
    // XOR data extending past the end of the input acts on zero bytes.
    check(b"f", Hunk { skip: 1, xor: &b"ooBar"[..] }, b"fooBar");
}

/// Test helper: applies `hunk` to the in-memory input `buf` and returns the produced bytes.
#[cfg(test)]
fn apply_hunk_inmem(buf: &[u8], hunk: &Hunk) -> Result<Vec<u8>, ApplyError> {
    let mut output = Vec::new();
    apply_hunk(&mut std::io::Cursor::new(buf), &mut output, hunk)?;
    Ok(output)
}