use std::{error, fmt, mem, ops};
use std::cmp::min;
use std::convert::Infallible;
use bytes::Bytes;
use super::error::{ContentError, DecodeError};
/// A source of octets for decoding.
///
/// Access happens in steps: `request` asks for a number of octets to become
/// available, `slice` and `bytes` give access to the available octets
/// relative to the current position, and `advance` consumes octets from the
/// front.
pub trait Source {
    /// The error produced when the source fails to provide data.
    type Error: error::Error;

    /// Returns the current logical position, for use in error reporting.
    fn pos(&self) -> Pos;

    /// Asks for `len` octets to be made available.
    ///
    /// Returns the number of octets actually available. This may be less
    /// than `len` near the end of the data, and may also be more.
    fn request(&mut self, len: usize) -> Result<usize, Self::Error>;

    /// Returns the available octets, starting at the current position.
    fn slice(&self) -> &[u8];

    /// Returns the available octets `start..end`, counted from the current
    /// position, as a `Bytes` value.
    fn bytes(&self, start: usize, end: usize) -> Bytes;

    /// Consumes `len` octets from the front of the available data.
    fn advance(&mut self, len: usize);

    /// Advances by at most `len` octets.
    ///
    /// Returns the number of octets skipped. This is smaller than `len` only
    /// if fewer octets were available.
    fn skip(&mut self, len: usize) -> Result<usize, Self::Error> {
        let available = self.request(len)?;
        let skipped = min(available, len);
        self.advance(skipped);
        Ok(skipped)
    }

    /// Takes the next octet.
    ///
    /// Fails with an "unexpected end of data" content error if no octet is
    /// available.
    fn take_u8(&mut self) -> Result<u8, DecodeError<Self::Error>> {
        match self.request(1)? {
            0 => Err(self.content_err("unexpected end of data")),
            _ => {
                let octet = self.slice()[0];
                self.advance(1);
                Ok(octet)
            }
        }
    }

    /// Takes the next octet, or returns `None` if no octet is available.
    fn take_opt_u8(&mut self) -> Result<Option<u8>, Self::Error> {
        let octet = match self.request(1)? {
            0 => return Ok(None),
            _ => self.slice()[0],
        };
        self.advance(1);
        Ok(Some(octet))
    }

    /// Creates a content error located at the current position.
    fn content_err(
        &self, err: impl Into<ContentError>
    ) -> DecodeError<Self::Error> {
        let content: ContentError = err.into();
        DecodeError::content(content, self.pos())
    }
}
impl<T: Source> Source for &'_ mut T {
type Error = T::Error;
fn request(&mut self, len: usize) -> Result<usize, Self::Error> {
Source::request(*self, len)
}
fn advance(&mut self, len: usize) {
Source::advance(*self, len)
}
fn slice(&self) -> &[u8] {
Source::slice(*self)
}
fn bytes(&self, start: usize, end: usize) -> Bytes {
Source::bytes(*self, start, end)
}
fn pos(&self) -> Pos {
Source::pos(*self)
}
}
/// A value that can be turned into a [`Source`].
pub trait IntoSource {
    /// The source type produced by the conversion.
    type Source: Source;

    /// Converts `self` into its source.
    fn into_source(self) -> Self::Source;
}

/// Any source converts into itself unchanged.
impl<T: Source> IntoSource for T {
    type Source = T;

    fn into_source(self) -> T {
        self
    }
}
/// A logical position within a source, reported alongside decoding errors.
///
/// This is a newtype over an octet count. Outside this module it can only be
/// created from a `usize`, added to another position, and displayed.
#[derive(Clone, Copy, Debug, Default)]
pub struct Pos(usize);

impl From<usize> for Pos {
    fn from(offset: usize) -> Self {
        Self(offset)
    }
}

impl ops::Add for Pos {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        let Pos(lhs) = self;
        let Pos(rhs) = rhs;
        Pos(lhs + rhs)
    }
}

impl fmt::Display for Pos {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Forward to the integer so width/alignment flags are honored.
        fmt::Display::fmt(&self.0, f)
    }
}
#[derive(Clone, Debug)]
pub struct BytesSource {
data: Bytes,
pos: usize,
offset: Pos,
}
impl BytesSource {
pub fn new(data: Bytes) -> Self {
BytesSource { data, pos: 0, offset: 0.into() }
}
pub fn with_offset(data: Bytes, offset: Pos) -> Self {
BytesSource { data, pos: 0, offset }
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn split_to(&mut self, len: usize) -> Bytes {
let res = self.data.split_to(len);
self.pos += len;
res
}
pub fn into_bytes(self) -> Bytes {
self.data
}
}
impl Source for BytesSource {
type Error = Infallible;
fn pos(&self) -> Pos {
self.offset + self.pos.into()
}
fn request(&mut self, _len: usize) -> Result<usize, Self::Error> {
Ok(self.data.len())
}
fn slice(&self) -> &[u8] {
self.data.as_ref()
}
fn bytes(&self, start: usize, end: usize) -> Bytes {
self.data.slice(start..end)
}
fn advance(&mut self, len: usize) {
assert!(len <= self.data.len());
bytes::Buf::advance(&mut self.data, len);
self.pos += len;
}
}
/// A `Bytes` value decodes through a [`BytesSource`] starting at position zero.
impl IntoSource for Bytes {
    type Source = BytesSource;

    fn into_source(self) -> BytesSource {
        let data = self;
        BytesSource::new(data)
    }
}
/// A source reading from a borrowed octet slice.
///
/// The position reported is the number of octets consumed since creation.
#[derive(Clone, Copy, Debug)]
pub struct SliceSource<'a> {
    /// The data not yet consumed.
    data: &'a [u8],
    /// Number of octets consumed since creation.
    pos: usize
}

impl<'a> SliceSource<'a> {
    /// Creates a source over `data`, starting at position zero.
    pub fn new(data: &'a [u8]) -> Self {
        Self { data, pos: 0 }
    }

    /// Returns the number of octets not yet consumed.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns whether all octets have been consumed.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Removes and returns the first `len` octets, advancing the position.
    ///
    /// Panics (via `split_at`) if `len` exceeds the remaining data.
    pub fn split_to(&mut self, len: usize) -> &'a [u8] {
        let whole: &'a [u8] = self.data;
        let (head, tail) = whole.split_at(len);
        self.data = tail;
        self.pos += len;
        head
    }
}
impl Source for SliceSource<'_> {
type Error = Infallible;
fn pos(&self) -> Pos {
self.pos.into()
}
fn request(&mut self, _len: usize) -> Result<usize, Self::Error> {
Ok(self.data.len())
}
fn advance(&mut self, len: usize) {
assert!(len <= self.data.len());
self.data = &self.data[len..];
self.pos += len;
}
fn slice(&self) -> &[u8] {
self.data
}
fn bytes(&self, start: usize, end: usize) -> Bytes {
Bytes::copy_from_slice(&self.data[start..end])
}
}
/// An octet slice decodes through a [`SliceSource`] borrowing it.
impl<'a> IntoSource for &'a [u8] {
    type Source = SliceSource<'a>;

    fn into_source(self) -> SliceSource<'a> {
        let data: &'a [u8] = self;
        SliceSource::new(data)
    }
}
/// A wrapper that can restrict how many octets of a source may be consumed.
#[derive(Clone, Debug)]
pub struct LimitedSource<S> {
    /// The wrapped source.
    source: S,
    /// Octets that may still be consumed, or `None` for no restriction.
    limit: Option<usize>,
}

impl<S> LimitedSource<S> {
    /// Wraps `source` without any limit.
    pub fn new(source: S) -> Self {
        Self { source, limit: None }
    }

    /// Returns the wrapped source, discarding the limit.
    pub fn unwrap(self) -> S {
        let Self { source, .. } = self;
        source
    }

    /// Returns the current limit.
    pub fn limit(&self) -> Option<usize> {
        self.limit
    }

    /// Replaces the limit with a limit that is no larger, and returns the
    /// previous limit.
    ///
    /// # Panics
    ///
    /// Panics if a limit is set and `limit` is larger or `None`.
    pub fn limit_further(&mut self, limit: Option<usize>) -> Option<usize> {
        match (self.limit, limit) {
            (Some(cur), Some(limit)) => assert!(limit <= cur),
            (Some(_), None) => panic!("relimiting to unlimited"),
            (None, _) => {}
        }
        mem::replace(&mut self.limit, limit)
    }

    /// Sets the limit unconditionally.
    pub fn set_limit(&mut self, limit: Option<usize>) {
        self.limit = limit;
    }
}
impl<S: Source> LimitedSource<S> {
    /// Skips over all remaining data up to the current limit.
    ///
    /// On success the limit is `Some(0)` afterwards.
    ///
    /// # Errors
    ///
    /// Returns an "unexpected end of data" content error if the underlying
    /// source provides fewer octets than the limit. Returns the source's own
    /// error if requesting data fails.
    ///
    /// # Panics
    ///
    /// Panics if the source currently has no limit.
    pub fn skip_all(&mut self) -> Result<(), DecodeError<S::Error>> {
        let limit = self.limit.expect(
            "LimitedSource::skip_all called on an unlimited source"
        );
        if self.request(limit)? < limit {
            return Err(self.content_err("unexpected end of data"))
        }
        self.advance(limit);
        Ok(())
    }

    /// Returns all remaining data up to the current limit and advances past it.
    ///
    /// On success the limit is `Some(0)` afterwards.
    ///
    /// # Errors
    ///
    /// Returns an "unexpected end of data" content error if the underlying
    /// source provides fewer octets than the limit. Returns the source's own
    /// error if requesting data fails.
    ///
    /// # Panics
    ///
    /// Panics if the source currently has no limit.
    pub fn take_all(&mut self) -> Result<Bytes, DecodeError<S::Error>> {
        let limit = self.limit.expect(
            "LimitedSource::take_all called on an unlimited source"
        );
        if self.request(limit)? < limit {
            return Err(self.content_err("unexpected end of data"))
        }
        let res = self.bytes(0, limit);
        self.advance(limit);
        Ok(res)
    }

    /// Checks that no data remains.
    ///
    /// With a limit, the limit must have been consumed completely. Without
    /// one, the underlying source must report no further data.
    ///
    /// # Errors
    ///
    /// Returns a "trailing data" content error if data remains. Returns the
    /// source's own error if requesting data fails.
    pub fn exhausted(&mut self) -> Result<(), DecodeError<S::Error>> {
        match self.limit {
            Some(0) => Ok(()),
            Some(_) => Err(self.content_err("trailing data")),
            None => {
                if self.source.request(1)? == 0 {
                    Ok(())
                }
                else {
                    Err(self.content_err("trailing data"))
                }
            }
        }
    }
}
impl<S: Source> Source for LimitedSource<S> {
    type Error = S::Error;

    fn pos(&self) -> Pos {
        self.source.pos()
    }

    /// Requests from the inner source, never asking for or reporting more
    /// than the remaining limit.
    fn request(&mut self, len: usize) -> Result<usize, Self::Error> {
        match self.limit {
            None => self.source.request(len),
            Some(limit) => {
                let available = self.source.request(min(limit, len))?;
                Ok(min(limit, available))
            }
        }
    }

    /// Advances the inner source and reduces the remaining limit.
    fn advance(&mut self, len: usize) {
        if let Some(remaining) = self.limit.as_mut() {
            assert!(
                len <= *remaining,
                "advanced past end of limit"
            );
            *remaining -= len;
        }
        self.source.advance(len)
    }

    /// Returns the inner source's data, truncated to the remaining limit.
    fn slice(&self) -> &[u8] {
        let data = self.source.slice();
        match self.limit {
            Some(limit) if data.len() > limit => &data[..limit],
            _ => data,
        }
    }

    fn bytes(&self, start: usize, end: usize) -> Bytes {
        match self.limit {
            Some(limit) => {
                assert!(start <= limit);
                assert!(end <= limit);
            }
            None => {}
        }
        self.source.bytes(start, end)
    }
}
/// A wrapper that records the octets read through it.
///
/// Reading through the capture does not advance the wrapped source; the
/// capture only tracks its own position. `into_bytes` returns the captured
/// octets and `skip` drops them. Both then advance the wrapped source past
/// the captured octets.
pub struct CaptureSource<'a, S: 'a> {
    /// The wrapped source; advanced only when the capture is finished.
    source: &'a mut S,
    /// Octets available in `source` as of the last request.
    len: usize,
    /// Octets consumed through the capture so far.
    pos: usize,
}

impl<'a, S: Source> CaptureSource<'a, S> {
    /// Starts capturing at the current position of `source`.
    pub fn new(source: &'a mut S) -> Self {
        Self { source, len: 0, pos: 0 }
    }

    /// Finishes the capture, returning the consumed octets and advancing
    /// the wrapped source past them.
    pub fn into_bytes(self) -> Bytes {
        let captured = self.source.bytes(0, self.pos);
        self.skip();
        captured
    }

    /// Finishes the capture, advancing the wrapped source past the consumed
    /// octets without returning them.
    pub fn skip(self) {
        let consumed = self.pos;
        self.source.advance(consumed)
    }
}
impl<'a, S: Source + 'a> Source for CaptureSource<'a, S> {
type Error = S::Error;
fn pos(&self) -> Pos {
self.source.pos() + self.pos.into()
}
fn request(&mut self, len: usize) -> Result<usize, Self::Error> {
self.len = self.source.request(self.pos + len)?;
Ok(self.len - self.pos)
}
fn slice(&self) -> &[u8] {
&self.source.slice()[self.pos..]
}
fn bytes(&self, start: usize, end: usize) -> Bytes {
let start = start + self.pos;
let end = end + self.pos;
assert!(
self.len >= start,
"start past the end of data"
);
assert!(
self.len >= end,
"end past the end of data"
);
self.source.bytes(start, end)
}
fn advance(&mut self, len: usize) {
assert!(
self.len >= self.pos + len,
"advanced past the end of data"
);
self.pos += len;
}
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn take_u8() {
        let mut source = b"123".into_source();
        for &expected in b"123" {
            assert_eq!(source.take_u8().unwrap(), expected);
        }
        assert!(source.take_u8().is_err())
    }

    #[test]
    fn take_opt_u8() {
        let mut source = b"123".into_source();
        let expected = [Some(b'1'), Some(b'2'), Some(b'3'), None];
        for &want in expected.iter() {
            assert_eq!(source.take_opt_u8().unwrap(), want);
        }
    }

    /// Runs the same read/advance scenario against any source that holds
    /// the octets "1234567890".
    fn exercise_ten_digits<S: Source>(mut src: S) {
        assert!(src.request(4).unwrap() >= 4);
        assert_eq!(&src.slice()[..4], b"1234");
        assert_eq!(src.bytes(2, 4), Bytes::from_static(b"34"));
        src.advance(4);

        assert!(src.request(4).unwrap() >= 4);
        assert_eq!(&src.slice()[..4], b"5678");
        src.advance(4);

        assert_eq!(src.request(4).unwrap(), 2);
        assert_eq!(src.slice(), b"90");
        src.advance(2);

        assert_eq!(src.request(4).unwrap(), 0);
    }

    #[test]
    fn bytes_impl() {
        exercise_ten_digits(Bytes::from_static(b"1234567890").into_source());
    }

    #[test]
    fn slice_impl() {
        exercise_ten_digits(b"1234567890".into_source());
    }

    #[test]
    fn limited_source() {
        let mut template = LimitedSource::new(b"12345678".into_source());
        template.set_limit(Some(4));

        // Consume the limited part step by step.
        let mut source = template.clone();
        assert!(source.exhausted().is_err());
        assert_eq!(source.request(6).unwrap(), 4);
        source.advance(2);
        assert!(source.exhausted().is_err());
        assert_eq!(source.request(6).unwrap(), 2);
        source.advance(2);
        source.exhausted().unwrap();
        assert_eq!(source.request(6).unwrap(), 0);

        // Skip the limited part in one go.
        let mut source = template.clone();
        source.skip_all().unwrap();
        let inner = source.unwrap();
        assert_eq!(inner.slice(), b"5678");

        // Take the limited part in one go.
        let mut source = template.clone();
        assert_eq!(source.take_all().unwrap(), Bytes::from_static(b"1234"));
        source.exhausted().unwrap();
        let inner = source.unwrap();
        assert_eq!(inner.slice(), b"5678");
    }

    #[test]
    #[should_panic]
    fn limited_source_far_advance() {
        let mut source = LimitedSource::new(b"12345678".into_source());
        source.set_limit(Some(4));
        assert_eq!(source.request(6).unwrap(), 4);
        source.advance(4);
        assert_eq!(source.request(6).unwrap(), 0);
        // Only this advance is expected to panic: the limit is now zero.
        source.advance(6);
    }

    #[test]
    #[should_panic]
    fn limit_further() {
        let mut source = LimitedSource::new(b"12345".into_source());
        source.set_limit(Some(4));
        source.limit_further(Some(5));
    }

    #[test]
    fn capture_source() {
        // Finishing a capture by taking its bytes or by skipping it must
        // both leave the underlying source just past the captured part.
        for &take_bytes in &[true, false] {
            let mut source = b"1234567890".into_source();
            {
                let mut capture = CaptureSource::new(&mut source);
                assert_eq!(capture.request(4).unwrap(), 10);
                capture.advance(4);
                if take_bytes {
                    assert_eq!(
                        capture.into_bytes(), Bytes::from_static(b"1234")
                    );
                }
                else {
                    capture.skip();
                }
            }
            assert_eq!(source.data, b"567890");
        }
    }
}