use std::{
io::{Read, Seek, Write},
mem::ManuallyDrop,
ops::Coroutine,
pin::Pin,
};
use base::FlattenResult;
use crate::{
base::{
self, DataWriteHandle, Flags, GeneratorRead, GeneratorSeek, GeneratorWrite, PkgState,
ReadSeekRequest, ReadSeekWriteRequest, ReadSeekWriteTruncateRequest, Response,
WriteRequest,
},
errors,
util::{ReadSeekWriteExt, WriteExt},
EntryInfo,
};
/// [`errors::CreateError`] specialized to `std::io::Error`; returned by [`Pkg::create`].
pub type CreateError = errors::CreateError<std::io::Error>;
/// [`errors::ParseError`] specialized to `std::io::Error`; returned by [`Pkg::parse`].
pub type ParseError = errors::ParseError<std::io::Error>;
/// [`errors::InsertError`] specialized to `std::io::Error`; returned by [`Pkg::insert`].
pub type InsertError = errors::InsertError<std::io::Error>;
/// [`errors::OpenError`] specialized to `std::io::Error`; returned by [`Pkg::open`].
pub type OpenError = errors::OpenError<std::io::Error>;
/// [`errors::RemoveError`] specialized to `std::io::Error`; returned by [`Pkg::remove`].
pub type RemoveError = errors::RemoveError<std::io::Error>;
/// [`errors::RenameError`] specialized to `std::io::Error`; returned by [`Pkg::rename`].
pub type RenameError = errors::RenameError<std::io::Error>;
/// [`errors::RepackError`] specialized to `std::io::Error`; returned by [`Pkg::repack`].
pub type RepackError = errors::RepackError<std::io::Error>;
/// [`errors::ReplaceError`] specialized to `std::io::Error`; returned by [`Pkg::replace`].
pub type ReplaceError = errors::ReplaceError<std::io::Error>;
/// Storage whose total length can be set in place.
///
/// Required by [`Pkg::repack`], whose coroutine may issue truncate requests.
pub trait Truncate {
    /// Sets the storage length to exactly `len` bytes.
    ///
    /// Despite the name, the implementations in this module also extend storage that is
    /// currently shorter than `len` (`Vec<u8>` via `resize`, `File` via `set_len`).
    fn truncate(&mut self, len: u64) -> std::io::Result<()>;
}
/// In-memory storage: the vector is shrunk or zero-extended to the requested length.
impl Truncate for Vec<u8> {
    /// Resizes the vector to exactly `len` bytes, filling any newly added bytes with `0`.
    ///
    /// # Errors
    /// Returns an error of kind [`std::io::ErrorKind::InvalidInput`] if `len` does not fit
    /// in `usize` (possible on 32-bit targets).
    fn truncate(&mut self, len: u64) -> std::io::Result<()> {
        // A plain `len as usize` would silently drop the high bits on 32-bit targets and
        // resize to a completely different length; reject such lengths instead.
        let len = <usize as std::convert::TryFrom<u64>>::try_from(len)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
        self.resize(len, 0);
        Ok(())
    }
}
/// File-backed storage: the length is changed on disk with `File::set_len`.
impl Truncate for std::fs::File {
    /// Sets the file length to `len`; per `set_len`, growing the file fills the new region
    /// with zeros.
    fn truncate(&mut self, len: u64) -> std::io::Result<()> {
        std::fs::File::set_len(self, len)
    }
}
/// A cursor forwards truncation to the buffer it wraps.
impl<T: Truncate> Truncate for std::io::Cursor<T> {
    /// Truncates the inner buffer; the cursor's position is not modified.
    fn truncate(&mut self, len: u64) -> std::io::Result<()> {
        T::truncate(self.get_mut(), len)
    }
}
/// Forwards to the referenced storage, so `&mut S` can be used wherever `S` can.
///
/// `T` may be unsized (e.g. a trait object whose trait has `Truncate` as a supertrait),
/// mirroring the `?Sized` blanket impls std provides for `Read`, `Write` and `Seek` on
/// `&mut R`.
impl<T: Truncate + ?Sized> Truncate for &mut T {
    fn truncate(&mut self, len: u64) -> std::io::Result<()> {
        (**self).truncate(len)
    }
}
/// Forwards to the boxed storage.
///
/// `T` may be unsized, so `Box<dyn Trait>` storage works when `Trait` has `Truncate` as a
/// supertrait. This mirrors std's `?Sized` impls of `Read`, `Write` and `Seek` for `Box<R>`.
impl<T: Truncate + ?Sized> Truncate for Box<T> {
    fn truncate(&mut self, len: u64) -> std::io::Result<()> {
        (**self).truncate(len)
    }
}
/// Synchronously services the I/O requests yielded by `base` coroutines against `storage`.
///
/// Which requests can be serviced depends on the traits `S` implements: read/seek
/// (`Read + Seek`), write (`+ Write`) and truncate (`+ Truncate`).
struct SyncDriver<S> {
    /// The storage every request is executed against.
    storage: S,
}
impl<S: Read + Seek> SyncDriver<S> {
    /// Creates a driver that services coroutine I/O requests against `storage`.
    pub fn new(storage: S) -> Self {
        Self { storage }
    }

    /// Returns a mutable reference to the wrapped storage.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.storage
    }

    /// Executes one read/seek request against the storage. The outcome is wrapped in the
    /// [`Response`] that is passed back into the coroutine on its next resume.
    fn handle_readseek(&mut self, request: ReadSeekRequest) -> std::io::Result<Response> {
        Ok(match request {
            ReadSeekRequest::Read(count) => {
                let mut buf = vec![0; count as usize];
                // A single `read` may return fewer bytes than requested; only the bytes that
                // were actually read are handed back (this is the inherent `Vec::truncate`).
                let read = self.storage.read(&mut buf)?;
                buf.truncate(read);
                Response::Read(buf)
            }
            ReadSeekRequest::ReadExact(count) => {
                let mut buf = vec![0; count as usize];
                self.storage.read_exact(&mut buf)?;
                Response::Read(buf)
            }
            ReadSeekRequest::Seek(offset) => Response::Seek(self.storage.seek(offset.into())?),
        })
    }

    /// Runs `coroutine` to completion, answering every read/seek request it yields.
    ///
    /// The first resume receives [`Response::None`]. Each later resume receives the response
    /// to the request the coroutine yielded last.
    ///
    /// # Errors
    /// Returns the first I/O error raised by the storage. In that case the coroutine is
    /// dropped without being resumed again.
    pub fn drive_read<R>(
        &mut self,
        coroutine: impl Coroutine<Response, Return = R, Yield = ReadSeekRequest>,
    ) -> std::io::Result<R> {
        use std::ops::CoroutineState;

        // `pin!` pins the coroutine to this stack frame without `unsafe`. It enforces the
        // "never moved again" guarantee that `Pin::new_unchecked` left to the programmer.
        let mut coroutine: Pin<&mut _> = std::pin::pin!(coroutine);
        let mut response = Response::None;
        loop {
            match coroutine.as_mut().resume(response) {
                CoroutineState::Yielded(request) => response = self.handle_readseek(request)?,
                CoroutineState::Complete(result) => return Ok(result),
            }
        }
    }
}
impl<S: Read + Seek + Write> SyncDriver<S> {
    /// Executes one write-side request against the storage and builds the [`Response`]
    /// passed back into the coroutine.
    fn handle_write(&mut self, request: WriteRequest) -> std::io::Result<Response> {
        Ok(match request {
            WriteRequest::WriteAll(ptr, count) => {
                // SAFETY: assumes the coroutine yields a (ptr, count) pair describing an
                // initialized byte buffer that stays alive and unmodified while the coroutine
                // is suspended at this yield.
                // NOTE(review): that invariant is established in `base`; confirm it there.
                let bytes = unsafe { core::slice::from_raw_parts(ptr, count) };
                self.storage.write_all(bytes)?;
                Response::None
            }
            WriteRequest::Write(ptr, count) => {
                // SAFETY: same assumption as for `WriteAll` above.
                let bytes = unsafe { core::slice::from_raw_parts(ptr, count) };
                Response::Written(self.storage.write(bytes)?)
            }
            WriteRequest::Copy { from, count, to } => {
                // `copy_within` comes from the `util` extension traits; presumably an
                // in-storage copy of `count` bytes from `from` to `to` (confirm in `util`).
                self.storage.copy_within(from, count, to)?;
                Response::None
            }
            WriteRequest::WriteRepeated { value, count } => {
                self.storage.fill(value, count)?;
                Response::None
            }
        })
    }

    /// Runs `coroutine` to completion, answering every read/seek or write request it yields.
    ///
    /// The first resume receives [`Response::None`]. Each later resume receives the response
    /// to the request the coroutine yielded last.
    ///
    /// # Errors
    /// Returns the first I/O error raised by the storage. In that case the coroutine is
    /// dropped without being resumed again.
    pub fn drive_write<R>(
        &mut self,
        coroutine: impl Coroutine<Response, Return = R, Yield = ReadSeekWriteRequest>,
    ) -> std::io::Result<R> {
        use std::ops::CoroutineState;

        // `pin!` pins the coroutine to this stack frame without `unsafe`
        // (replaces `Pin::new_unchecked`).
        let mut coroutine: Pin<&mut _> = std::pin::pin!(coroutine);
        let mut response = Response::None;
        loop {
            response = match coroutine.as_mut().resume(response) {
                CoroutineState::Yielded(ReadSeekWriteRequest::ReadSeek(request)) => {
                    self.handle_readseek(request)?
                }
                CoroutineState::Yielded(ReadSeekWriteRequest::Write(request)) => {
                    self.handle_write(request)?
                }
                CoroutineState::Complete(result) => return Ok(result),
            };
        }
    }
}
impl<S: Read + Seek + Write + Truncate> SyncDriver<S> {
pub fn drive_truncate<R>(
&mut self,
mut coroutine: impl Coroutine<Response, Return = R, Yield = ReadSeekWriteTruncateRequest>,
) -> std::io::Result<R> {
let mut response = Response::None;
loop {
use std::ops::CoroutineState;
match unsafe { Pin::new_unchecked(&mut coroutine) }.resume(response) {
CoroutineState::Yielded(ReadSeekWriteTruncateRequest::ReadSeek(request)) => {
response = self.handle_readseek(request)?
}
CoroutineState::Yielded(ReadSeekWriteTruncateRequest::Write(request)) => {
response = self.handle_write(request)?
}
CoroutineState::Yielded(ReadSeekWriteTruncateRequest::Truncate(size)) => {
self.storage.truncate(size)?;
response = Response::None;
}
CoroutineState::Complete(result) => break Ok(result),
}
}
}
}
/// A package stored in synchronous storage `S`.
///
/// It pairs the storage, wrapped in a `SyncDriver` that executes `base` coroutines, with the
/// in-memory [`PkgState`].
pub struct Pkg<S>
where
    S: Read + Seek,
{
    driver: SyncDriver<S>,
    state: PkgState,
}
/// Reader over a single entry of a [`Pkg`], obtained from [`Pkg::open`].
///
/// It holds the package's driver mutably for `'a`, so the package cannot be used while the
/// reader is alive.
pub struct EntryReader<'a, S>
where
    S: Read + Seek,
{
    driver: &'a mut SyncDriver<S>,
    handle: base::ReadHandle,
}
impl<S> EntryReader<'_, S>
where
    S: Read + Seek,
{
    /// Whether this entry supports [`Seek`], as reported by its read handle.
    pub fn is_seekable(&self) -> bool {
        let handle = &self.handle;
        handle.is_seekable()
    }

    /// Whether this entry's data is compressed, as reported by its read handle.
    pub fn is_compressed(&self) -> bool {
        let handle = &self.handle;
        handle.is_compressed()
    }
}
impl<S> Read for EntryReader<'_, S>
where
    S: Read + Seek,
{
    /// Reads entry data into `buf` by running the handle's read coroutine against the storage.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let request = self.handle.read(buf);
        self.driver.drive_read(request)
    }
}
impl<S> Seek for EntryReader<'_, S>
where
    S: Read + Seek,
{
    /// Seeks within an uncompressed entry.
    ///
    /// # Errors
    /// Compressed (deflate) entries cannot be seeked and yield
    /// [`std::io::ErrorKind::NotSeekable`]. Storage and seek failures are propagated.
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        let raw = match &mut self.handle {
            base::ReadHandle::Raw(raw) => raw,
            base::ReadHandle::Deflate(_) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::NotSeekable,
                    "Cannot seek on compressed entry reader",
                ));
            }
        };
        let position = self.driver.drive_read(raw.seek(pos.into())).flatten()?;
        Ok(position)
    }
}
impl<S: Read + Seek> Pkg<S> {
pub fn inner(&self) -> &S {
&self.driver.storage
}
pub fn contains(&self, path: &str) -> bool {
self.state.contains(path)
}
pub fn paths(&self) -> impl Iterator<Item = &String> {
self.state.paths()
}
pub fn parse(storage: S) -> Result<Self, ParseError> {
let mut driver = SyncDriver::new(storage);
let state = driver.drive_read(base::parse(true)).flatten()?;
Ok(Self { driver, state })
}
pub fn open(&mut self, path: &str) -> Result<EntryReader<S>, OpenError> {
let handle = self
.driver
.drive_read(base::open(&self.state, path))
.flatten()?;
Ok(EntryReader {
driver: &mut self.driver,
handle,
})
}
pub fn metadata(&self, path: &str) -> Option<EntryInfo> {
self.state.index(path).and_then(|idx| self.state.info(idx))
}
}
/// Writer for a single entry being inserted into a [`Pkg`], obtained from [`Pkg::insert`].
///
/// `handle` is wrapped in `ManuallyDrop` so that it can be moved out and consumed by value,
/// both in `finish` and in this type's `Drop` impl.
pub struct EntryWriter<'a, S>
where
    S: Read + Seek + Write,
{
    driver: &'a mut SyncDriver<S>,
    handle: ManuallyDrop<base::WriteHandle<'a>>,
}
impl<S: Read + Seek + Write> EntryWriter<'_, S> {
    /// Whether this entry supports seeking, as reported by its write handle.
    pub fn is_seekable(&self) -> bool {
        self.handle.is_seekable()
    }

    /// Whether this entry's data is compressed, as reported by its write handle.
    pub fn is_compressed(&self) -> bool {
        self.handle.is_compressed()
    }

    /// Finalizes the entry and flushes the underlying storage.
    ///
    /// Unlike dropping the writer, this reports failures to the caller.
    ///
    /// # Errors
    /// Returns any I/O error raised while finalizing the entry or flushing the storage.
    pub fn finish(self) -> std::io::Result<()> {
        // Disarm `Drop` *before* anything fallible runs. Forgetting `self` only after success
        // lets an early `?` return drop `self`, and `Drop` would then take `handle` a second
        // time (double take / double drop).
        let mut this = ManuallyDrop::new(self);
        // SAFETY: `this` is never dropped, so `EntryWriter::drop` cannot run for it, and this
        // is the only `take` of `handle` on this path.
        let handle = unsafe { ManuallyDrop::take(&mut this.handle) };
        this.driver.drive_write(handle.finish())?;
        this.driver.get_mut().flush()
    }
}
impl<S> Write for EntryWriter<'_, S>
where
    S: Read + Seek + Write,
{
    /// Writes `buf` through whichever data handle (raw or deflate) backs this entry.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = match self.handle.inner_mut() {
            DataWriteHandle::Raw(raw) => self.driver.drive_write(raw.write(buf))?,
            DataWriteHandle::Deflate(deflate) => self.driver.drive_write(deflate.write(buf))?,
        };
        Ok(written)
    }

    /// Runs the handle's flush coroutine, then flushes the underlying storage.
    fn flush(&mut self) -> std::io::Result<()> {
        let pending = self.handle.flush();
        self.driver.drive_write(pending)?;
        self.driver.get_mut().flush()
    }
}
impl<S> Read for EntryWriter<'_, S>
where
    S: Read + Seek + Write,
{
    /// Reads back data from an uncompressed entry that is being written.
    ///
    /// # Errors
    /// Compressed (deflate) entries cannot be read back and yield an
    /// [`std::io::ErrorKind::Other`] error.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let raw = match self.handle.inner_mut() {
            DataWriteHandle::Raw(raw) => raw,
            DataWriteHandle::Deflate(_) => {
                return Err(std::io::Error::other(
                    "Cannot read on compressed entry writer",
                ));
            }
        };
        self.driver.drive_read(raw.read(buf))
    }
}
impl<S> Seek for EntryWriter<'_, S>
where
    S: Read + Seek + Write,
{
    /// Seeks within an uncompressed entry that is being written.
    ///
    /// # Errors
    /// Compressed (deflate) entries cannot be seeked and yield
    /// [`std::io::ErrorKind::NotSeekable`]. Storage and seek failures are propagated.
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        let raw = match self.handle.inner_mut() {
            DataWriteHandle::Raw(raw) => raw,
            DataWriteHandle::Deflate(_) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::NotSeekable,
                    "Cannot seek on compressed entry writer",
                ));
            }
        };
        let position = self.driver.drive_read(raw.seek(pos.into())).flatten()?;
        Ok(position)
    }
}
impl<S> Drop for EntryWriter<'_, S>
where
    S: Read + Seek + Write,
{
    /// Best-effort finalization for a writer that was not explicitly finished.
    ///
    /// Errors are discarded because `drop` cannot report them. Unlike `finish`, the storage
    /// itself is not flushed here.
    fn drop(&mut self) {
        // SAFETY: requires that `handle` has not already been taken. Any other
        // `ManuallyDrop::take` of this field must make sure this `drop` never runs afterwards.
        let handle = unsafe { ManuallyDrop::take(&mut self.handle) };
        let _ = self.driver.drive_write(handle.finish());
    }
}
impl<S: Read + Seek + Write> Pkg<S> {
pub fn create(storage: S) -> Result<Self, CreateError> {
let mut driver = SyncDriver::new(storage);
let state = driver.drive_write(PkgState::create()).flatten()?;
Ok(Self { driver, state })
}
pub fn remove(&mut self, path: &str) -> Result<(), RemoveError> {
self.driver.drive_write(self.state.remove(path)).flatten()
}
pub fn rename(&mut self, src: &str, dst: String) -> Result<(), RenameError> {
self.driver
.drive_write(self.state.rename(src, dst))
.flatten()
}
pub fn replace(&mut self, src: &str, dst: String) -> Result<(), ReplaceError> {
self.driver
.drive_write(self.state.replace(src, dst))
.flatten()
}
pub fn insert(&mut self, path: String, flags: Flags) -> Result<EntryWriter<S>, InsertError> {
let handle = self
.driver
.drive_write(self.state.insert(path, flags))
.flatten()?;
Ok(EntryWriter {
driver: &mut self.driver,
handle: ManuallyDrop::new(handle),
})
}
pub fn flush(&mut self) -> std::io::Result<()> {
self.driver.get_mut().flush()
}
}
impl<S> Pkg<S>
where
    S: Read + Seek + Write + Truncate,
{
    /// Repacks the package by driving `PkgState::repack` through `SyncDriver::drive_truncate`.
    /// This lets the repack coroutine issue truncate requests against the storage.
    ///
    /// # Errors
    /// Returns a [`RepackError`] if the repack fails.
    pub fn repack(&mut self) -> Result<(), RepackError> {
        let job = self.state.repack();
        self.driver.drive_truncate(job).flatten()
    }
}