use std::ffi::OsStr;
use std::io::{self, Read, Write};
use std::os::unix::ffi::OsStrExt;
use std::path::{Component, Path};
use blake3::{Hash, Hasher};
use pkgar_core::{Entry, PackageSrc};
use crate::{Error, ErrorKind, ResultExt};
/// Validation helpers for package [`Entry`] records.
pub trait EntryExt {
    /// Interprets the entry's raw path bytes as a [`Path`] and returns it only
    /// if every component is a plain name (no root, prefix, `.` or `..`).
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidPathComponent` (with `ErrorKind::Entry`
    /// context) naming the first offending component.
    fn check_path(&self) -> Result<&Path, Error>;

    /// Checks that `size` and `blake3` agree with the size and hash recorded
    /// in the entry.
    ///
    /// # Errors
    ///
    /// Returns a length-mismatch error if the sizes differ, otherwise a
    /// blake3 error if the hashes differ.
    fn verify(&self, blake3: Hash, size: u64) -> Result<(), Error>;
}
impl EntryExt for Entry {
fn check_path(&self) -> Result<&Path, Error> {
let path = Path::new(OsStr::from_bytes(self.path_bytes()));
for component in path.components() {
match component {
Component::Normal(_) => {}
invalid => {
let bad_component: &Path = invalid.as_ref();
return Err(Error::from_kind(ErrorKind::InvalidPathComponent(
bad_component.to_path_buf(),
)))
.chain_err(|| ErrorKind::Entry(*self));
}
}
}
Ok(&path)
}
fn verify(&self, blake3: Hash, size: u64) -> Result<(), Error> {
if size != self.size() {
Err(Error::from_kind(ErrorKind::LengthMismatch(
size,
self.size(),
)))
.chain_err(|| ErrorKind::Entry(*self))
} else if blake3 != self.blake3() {
Err(pkgar_core::Error::InvalidBlake3.into())
} else {
Ok(())
}
}
}
/// Convenience methods layered on top of [`PackageSrc`].
pub trait PackageSrcExt: PackageSrc + Sized {
    /// Path associated with this source.
    ///
    /// NOTE(review): presumably the file the package was opened from — confirm
    /// against implementors.
    fn path(&self) -> &Path;

    /// Wraps `entry` in an [`EntryReader`] that starts at position 0 and
    /// mutably borrows this source for as long as the reader lives.
    fn entry_reader(&mut self, entry: Entry) -> EntryReader<'_, Self> {
        EntryReader {
            pos: 0,
            entry,
            src: self,
        }
    }
}
/// Sequential reader over the data of a single [`Entry`] in a package source.
///
/// Created by `PackageSrcExt::entry_reader`.
pub struct EntryReader<'a, Src: PackageSrc> {
    /// Source that `read_entry` is called on.
    src: &'a mut Src,
    /// The entry whose data is being read.
    entry: Entry,
    /// Bytes read so far; passed to `read_entry` as the starting position of
    /// the next read.
    pos: usize,
}
impl<Src, E> Read for EntryReader<'_, Src>
where
    Src: PackageSrc<Err = E>,
    E: From<pkgar_core::Error> + std::error::Error,
{
    /// Reads the next bytes of the entry into `buf` and advances the stored
    /// position by the number of bytes returned.
    ///
    /// Source errors become `io::ErrorKind::Other` carrying the error's
    /// display text. `E` is not required to be `Send + Sync`, so it cannot be
    /// boxed into the `io::Error` directly.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.src.read_entry(self.entry, self.pos, buf) {
            Ok(read) => {
                self.pos += read;
                Ok(read)
            }
            Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string())),
        }
    }
}
/// Copies everything from `read` to `write`, using `buf` as the transfer
/// buffer, and computes the BLAKE3 hash of the copied bytes.
///
/// Returns the total number of bytes copied together with their hash.
///
/// # Errors
///
/// - `io::ErrorKind::InvalidInput` if `buf` is empty.
/// - Any error from `read` other than `io::ErrorKind::Interrupted`.
/// - Any error from `write`.
///
/// Interrupted reads are retried, matching `std::io::copy`.
pub(crate) fn copy_and_hash<R: Read, W: Write>(
    mut read: R,
    mut write: W,
    buf: &mut [u8],
) -> Result<(u64, Hash), io::Error> {
    // `Read::read` on a zero-length buffer returns `Ok(0)`, which looks exactly
    // like EOF; without this guard the copy would "succeed" with nothing copied.
    if buf.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "copy_and_hash requires a non-empty buffer",
        ));
    }

    let mut hasher = Hasher::new();
    let mut written: u64 = 0;
    loop {
        let count = match read.read(buf) {
            Ok(0) => break,
            Ok(count) => count,
            // A signal interrupted the read before any data arrived; retry.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        };
        let chunk = &buf[..count];
        hasher.update_with_join::<blake3::join::RayonJoin>(chunk);
        write.write_all(chunk)?;
        written += count as u64;
    }
    Ok((written, hasher.finalize()))
}