use crate::prelude::*;
/// Outcome of a write through writer `W`: `()` on success, `W`'s error otherwise.
pub type WriteResult<W> = StdResult<(), <W as Write>::Error>;

/// A byte sink that accepts slices of bytes.
pub trait Write {
    /// Error type reported by this writer.
    type Error;

    /// Seekable view of this writer, handed out by [`Write::make_seekable`].
    ///
    /// It shares this writer's error type and is its own seekable view.
    type Seekable: Write<Error = Self::Error, Seekable = Self::Seekable> + Seek;

    /// Builds this writer's error value for invalid input described by `msg`.
    fn invalid_input(msg: &'static str) -> Self::Error;

    /// Writes `buf` to the sink.
    fn write(&mut self, buf: &[u8]) -> StdResult<(), Self::Error>;

    /// Returns a seekable view of this writer when it supports one.
    ///
    /// By default a writer reports that it cannot seek.
    #[inline]
    fn make_seekable(&mut self) -> Option<&mut Self::Seekable> {
        None
    }
}

/// A [`Write`] sink that can report its position and write at an absolute offset.
pub trait Seek: Write {
    /// Returns the current write position.
    fn tell(&mut self) -> StdResult<u64, Self::Error>;

    /// Writes `buf` starting at absolute position `pos`.
    fn write_at(&mut self, buf: &[u8], pos: u64) -> StdResult<(), Self::Error>;
}
/// An uninhabited type: no value of `Never` can ever be constructed.
#[derive(Copy, Clone, Debug)]
enum Never {}

/// Placeholder [`Seek`] type for writers that cannot seek.
///
/// It holds a field of the uninhabited `Never` type, so no value of this
/// type can exist and every method on it is statically unreachable. `W` is
/// only a marker tying the placeholder to the writer it stands in for.
pub struct NotSeekable<W> {
    _phantom: PhantomData<W>,
    never: Never,
}

// `Clone`, `Copy` and `Debug` are written by hand instead of derived so that
// they hold for every `W`: `#[derive]` would add `W: Clone` / `W: Copy` /
// `W: Debug` bounds even though no `W` is ever stored.
impl<W> Clone for NotSeekable<W> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<W> Copy for NotSeekable<W> {}

impl<W> core::fmt::Debug for NotSeekable<W> {
    fn fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Unreachable: a `NotSeekable` value cannot exist.
        self.as_never()
    }
}

impl<W> NotSeekable<W> {
    /// Converts this (impossible) value into `!`, letting callers satisfy
    /// any return type in code paths that can never execute.
    #[inline]
    pub fn as_never(&self) -> ! {
        match self.never {}
    }
}
/// `NotSeekable<W>` borrows `W`'s error type; its `write` can never run
/// because no `NotSeekable` value exists.
impl<W: Write> Write for NotSeekable<W> {
    type Seekable = Self;
    type Error = W::Error;

    /// Delegates error construction to the wrapped writer type.
    #[inline]
    fn invalid_input(msg: &'static str) -> Self::Error {
        <W as Write>::invalid_input(msg)
    }

    #[inline]
    fn write(&mut self, _buf: &[u8]) -> WriteResult<Self> {
        match self.never {}
    }
}
impl<W: Write> Seek for NotSeekable<W> {
#[inline]
fn tell(&mut self) -> StdResult<u64, W::Error> {
self.as_never()
}
#[inline]
fn write_at(&mut self, _: &[u8], _: u64) -> WriteResult<Self> {
self.as_never()
}
}
/// Forwarding implementation: a mutable borrow of a writer is itself a writer
/// with the same error and seekable types.
impl<'a, W: Write> Write for &'a mut W {
    type Error = W::Error;
    type Seekable = W::Seekable;

    #[inline]
    fn invalid_input(msg: &'static str) -> Self::Error {
        <W as Write>::invalid_input(msg)
    }

    #[inline]
    fn write(&mut self, buf: &[u8]) -> WriteResult<Self> {
        <W as Write>::write(&mut **self, buf)
    }

    #[inline]
    fn make_seekable(&mut self) -> Option<&mut Self::Seekable> {
        <W as Write>::make_seekable(&mut **self)
    }
}
/// A growable vector is a writer that appends every buffer it receives and
/// is always seekable (it is its own seekable view).
#[cfg(feature = "alloc")]
impl Write for Vec<u8> {
    type Error = &'static str;
    type Seekable = Self;

    #[inline]
    fn make_seekable(&mut self) -> Option<&mut Self> {
        Some(self)
    }

    /// The error type is the message itself.
    #[inline]
    fn invalid_input(msg: &'static str) -> Self::Error {
        msg
    }

    /// Appends `buf` to the end of the vector; this never reports an error.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> WriteResult<Self> {
        Vec::extend_from_slice(self, buf);
        Ok(())
    }
}
/// Seeking on a vector: the position is its length, and `write_at` may only
/// overwrite bytes that were already written (`..len()`).
#[cfg(feature = "alloc")]
impl Seek for Vec<u8> {
    /// Returns the number of bytes currently in the vector.
    #[inline]
    fn tell(&mut self) -> StdResult<u64, Self::Error> {
        Ok(self.len() as u64)
    }

    /// Overwrites `buf.len()` existing bytes starting at offset `pos`.
    ///
    /// # Errors
    /// Returns `"invalid seekback"` if `pos..pos + buf.len()` is not entirely
    /// inside the vector, including when `pos` does not fit in `usize` or the
    /// end offset would overflow.
    #[inline]
    fn write_at(&mut self, buf: &[u8], pos: u64) -> WriteResult<Self> {
        // A checked conversion: `pos as usize` would silently truncate large
        // offsets on 32-bit targets and patch the wrong bytes.
        let start = <usize as core::convert::TryFrom<u64>>::try_from(pos)
            .map_err(|_| "invalid seekback")?;
        // `checked_add` avoids an overflow panic (debug) / wrap (release).
        let end = start.checked_add(buf.len()).ok_or("invalid seekback")?;
        let out = self.get_mut(start..end).ok_or("invalid seekback")?;
        out.copy_from_slice(buf);
        Ok(())
    }
}
/// A writer over a caller-provided, fixed-size byte buffer.
///
/// Bytes before `cur` have been written; bytes from `cur` onwards are free.
#[derive(Debug)]
pub struct Cursor<'a> {
    /// Destination buffer.
    buf: &'a mut [u8],
    /// Number of bytes written so far; kept `<= buf.len()`.
    cur: usize,
}

impl<'a> Cursor<'a> {
    /// Creates a cursor positioned at the start of `buffer`.
    #[inline]
    pub fn new(buffer: &'a mut [u8]) -> Self {
        Cursor {
            buf: buffer,
            cur: 0,
        }
    }

    /// Creates a cursor over `buffer` that treats the first `cursor` bytes as
    /// already written.
    ///
    /// # Panics
    /// Panics if `cursor > buffer.len()`.
    #[inline]
    pub fn from_parts(buffer: &'a mut [u8], cursor: usize) -> Self {
        assert!(
            cursor <= buffer.len(),
            "cursor beyond the end of the buffer"
        );
        Cursor {
            buf: buffer,
            cur: cursor,
        }
    }

    /// Consumes the cursor, returning the buffer and the cursor position.
    #[inline]
    pub fn into_parts(self) -> (&'a mut [u8], usize) {
        (self.buf, self.cur)
    }

    /// Returns the whole underlying buffer, written and unwritten parts alike.
    #[inline]
    pub fn slice(&self) -> &[u8] {
        self.buf
    }

    /// Mutable access to the whole underlying buffer.
    #[inline]
    pub fn slice_mut(&mut self) -> &mut [u8] {
        self.buf
    }

    /// Returns the number of bytes written so far.
    #[inline]
    pub fn cursor(&self) -> usize {
        self.cur
    }

    /// Returns the bytes written so far (`..cursor()`).
    #[inline]
    pub fn written(&self) -> &[u8] {
        &self.buf[..self.cur]
    }

    /// Returns the free space after the cursor (`cursor()..`).
    #[inline]
    pub fn unwritten(&self) -> &[u8] {
        &self.buf[self.cur..]
    }

    /// Returns `(written, unwritten)` in one call.
    #[inline]
    pub fn split(&self) -> (&[u8], &[u8]) {
        self.buf.split_at(self.cur)
    }

    /// Mutable access to the bytes written so far.
    #[inline]
    pub fn written_mut(&mut self) -> &mut [u8] {
        &mut self.buf[..self.cur]
    }

    /// Mutable access to the free space after the cursor.
    #[inline]
    pub fn unwritten_mut(&mut self) -> &mut [u8] {
        &mut self.buf[self.cur..]
    }

    /// Mutable `(written, unwritten)` halves of the buffer.
    #[inline]
    pub fn split_mut(&mut self) -> (&mut [u8], &mut [u8]) {
        self.buf.split_at_mut(self.cur)
    }
}
/// Sequential writes into the cursor's fixed buffer; the cursor is its own
/// seekable view.
impl<'a> Write for Cursor<'a> {
    type Error = CursorError;
    type Seekable = Self;

    /// Copies as much of `buf` as fits after the cursor and advances past it.
    ///
    /// When `buf` does not fit, the prefix that fits is still copied, the
    /// cursor moves to the end of the buffer and `OutOfSpace` is returned.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> WriteResult<Self> {
        let room = self.buf.len() - self.cur;
        let fits = buf.len() <= room;
        let take = if fits { buf.len() } else { room };
        let start = self.cur;
        self.buf[start..start + take].copy_from_slice(&buf[..take]);
        self.cur = start + take;
        if fits {
            Ok(())
        } else {
            Err(CursorError::OutOfSpace)
        }
    }

    #[inline]
    fn invalid_input(msg: &'static str) -> Self::Error {
        CursorError::InvalidInput(msg)
    }

    #[inline]
    fn make_seekable(&mut self) -> Option<&mut Self> {
        Some(self)
    }
}
/// Seeking within the cursor's buffer: the position is the cursor, and
/// `write_at` may target any range inside the buffer.
impl<'a> Seek for Cursor<'a> {
    /// Returns the number of bytes written so far.
    #[inline]
    fn tell(&mut self) -> StdResult<u64, Self::Error> {
        Ok(self.cur as u64)
    }

    /// Overwrites `buf.len()` bytes of the buffer starting at offset `pos`.
    /// The cursor position is not changed.
    ///
    /// # Errors
    /// Returns `CursorError::OutOfSpace` if `pos..pos + buf.len()` does not lie
    /// within the buffer, including when `pos` does not fit in `usize` or the
    /// end offset would overflow.
    #[inline]
    fn write_at(&mut self, buf: &[u8], pos: u64) -> WriteResult<Self> {
        // Checked conversion: `pos as usize` truncates on 32-bit targets and
        // could redirect the write to an unrelated offset.
        let start = <usize as core::convert::TryFrom<u64>>::try_from(pos)
            .map_err(|_| CursorError::OutOfSpace)?;
        let end = start
            .checked_add(buf.len())
            .ok_or(CursorError::OutOfSpace)?;
        let out = self
            .buf
            .get_mut(start..end)
            .ok_or(CursorError::OutOfSpace)?;
        out.copy_from_slice(buf);
        Ok(())
    }
}
/// Errors produced when writing into fixed-size byte buffers.
#[derive(Debug)]
#[derive(Clone)]
pub enum CursorError {
    /// The destination buffer had too little room for the requested write.
    OutOfSpace,
    /// The input was rejected; the payload is a static description.
    InvalidInput(&'static str),
}
/// Writing into a plain mutable slice consumes it from the front: after each
/// write the slice is narrowed to the bytes that are still free.
impl<'a> Write for &'a mut [u8] {
    type Error = CursorError;
    type Seekable = NotSeekable<Self>;

    #[inline]
    fn invalid_input(msg: &'static str) -> Self::Error {
        CursorError::InvalidInput(msg)
    }

    /// Copies as much of `buf` as fits and shrinks the slice past it.
    ///
    /// If `buf` is longer than the remaining slice, the prefix that fits is
    /// still copied, the slice becomes empty and `OutOfSpace` is returned.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> WriteResult<Self> {
        let remaining = mem::replace(self, &mut []);
        let count = if buf.len() > remaining.len() {
            remaining.len()
        } else {
            buf.len()
        };
        let (filled, rest) = remaining.split_at_mut(count);
        filled.copy_from_slice(&buf[..count]);
        *self = rest;
        if count < buf.len() {
            Err(CursorError::OutOfSpace)
        } else {
            Ok(())
        }
    }
}
/// Adapter exposing any `io::Write` as a non-seekable [`Write`] sink.
#[derive(Debug, Clone, Default)]
pub struct IoWrap<T>(pub T);

#[cfg(feature = "std")]
impl<T: io::Write> Write for IoWrap<T> {
    type Error = io::Error;
    type Seekable = NotSeekable<Self>;

    /// Reports invalid input as an `io::ErrorKind::InvalidInput` error.
    #[inline]
    fn invalid_input(msg: &'static str) -> Self::Error {
        io::Error::new(io::ErrorKind::InvalidInput, msg)
    }

    /// Writes the whole of `buf` to the inner writer via `write_all`.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> WriteResult<Self> {
        let IoWrap(inner) = self;
        io::Write::write_all(inner, buf)
    }
}
/// Adapter exposing any `io::Write + io::Seek` as a seekable [`Write`] sink.
#[derive(Debug, Clone, Default)]
pub struct SeekableWrap<T>(pub T);

#[cfg(feature = "std")]
impl<T: io::Write + io::Seek> Write for SeekableWrap<T> {
    type Error = io::Error;
    type Seekable = Self;

    /// The wrapper is always seekable and is its own seekable view.
    #[inline]
    fn make_seekable(&mut self) -> Option<&mut Self> {
        Some(self)
    }

    /// Reports invalid input as an `io::ErrorKind::InvalidInput` error.
    #[inline]
    fn invalid_input(msg: &'static str) -> Self::Error {
        io::Error::new(io::ErrorKind::InvalidInput, msg)
    }

    /// Writes the whole of `buf` to the inner stream via `write_all`.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> WriteResult<Self> {
        let SeekableWrap(inner) = self;
        io::Write::write_all(inner, buf)
    }
}
/// Seeking is delegated to the inner `io::Seek` stream.
#[cfg(feature = "std")]
impl<T: io::Write + io::Seek> Seek for SeekableWrap<T> {
    /// Returns the inner stream's current position.
    #[inline]
    fn tell(&mut self) -> io::Result<u64> {
        io::Seek::seek(&mut self.0, io::SeekFrom::Current(0))
    }

    /// Writes `buf` at absolute offset `pos`, then returns the stream to the
    /// position it had before the call.
    ///
    /// # Errors
    /// Propagates any seek or write error from the inner stream. If the write
    /// fails, restoring the position is still attempted and the write error
    /// is reported.
    #[inline]
    fn write_at(&mut self, buf: &[u8], pos: u64) -> io::Result<()> {
        // Restore the saved position instead of jumping to `SeekFrom::End(0)`.
        // The stream need not have been at its end (e.g. a pre-filled
        // `io::Cursor` or an existing file). Restoring keeps `tell()`
        // unchanged across `write_at`, matching the `Vec` and `Cursor` impls.
        let resume = io::Seek::seek(&mut self.0, io::SeekFrom::Current(0))?;
        io::Seek::seek(&mut self.0, io::SeekFrom::Start(pos))?;
        let written = io::Write::write_all(&mut self.0, buf);
        let restored = io::Seek::seek(&mut self.0, io::SeekFrom::Start(resume));
        written?;
        restored.map(|_| ())
    }
}
/// A sink that discards data and only accumulates the number of bytes written.
pub(crate) struct WriteCounter(pub u64);

impl Write for WriteCounter {
    type Error = &'static str;
    type Seekable = NotSeekable<Self>;

    /// The error type is the message itself.
    #[inline]
    fn invalid_input(msg: &'static str) -> Self::Error {
        msg
    }

    /// Adds `buf.len()` to the running total; never fails.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> WriteResult<Self> {
        let WriteCounter(total) = self;
        *total += buf.len() as u64;
        Ok(())
    }
}