use std::mem;
use std::cmp::min;
use bytes::Bytes;
use super::error::Error;
/// A sequence of octets that is consumed from the front.
///
/// An implementor exposes the octets that are currently available through
/// [`slice`](Source::slice), reports how many there are through
/// [`request`](Source::request), and drops octets from the front through
/// [`advance`](Source::advance).
pub trait Source {
    /// The error type returned by the fallible operations.
    ///
    /// It must be constructible from the crate's `Error` so that the
    /// provided methods can report malformed input.
    type Err: From<Error>;

    /// Asks for `len` octets and returns how many are available.
    ///
    /// The returned count may differ from `len`; callers compare it against
    /// the amount they actually need.
    fn request(&mut self, len: usize) -> Result<usize, Self::Err>;

    /// Drops the first `len` octets from the source.
    fn advance(&mut self, len: usize) -> Result<(), Self::Err>;

    /// Returns the octets currently available at the front of the source.
    fn slice(&self) -> &[u8];

    /// Returns the octets in the index range `start..end` as a `Bytes` value.
    ///
    /// Indexes are relative to the current front of the source.
    fn bytes(&self, start: usize, end: usize) -> Bytes;

    /// Removes the first octet from the source and returns it.
    ///
    /// # Errors
    ///
    /// Returns a malformed error if no octet is available, or whatever
    /// error `request` or `advance` produce.
    fn take_u8(&mut self) -> Result<u8, Self::Err> {
        let available = self.request(1)?;
        if available == 0 {
            xerr!(return Err(Error::Malformed.into()))
        }
        let octet = self.slice()[0];
        self.advance(1)?;
        Ok(octet)
    }

    /// Removes the first octet from the source if there is one.
    ///
    /// Returns `Ok(None)` instead of an error when no octet is available.
    ///
    /// # Errors
    ///
    /// Returns whatever error `request` or `advance` produce.
    fn take_opt_u8(&mut self) -> Result<Option<u8>, Self::Err> {
        match self.request(1)? {
            0 => Ok(None),
            _ => {
                let octet = self.slice()[0];
                self.advance(1)?;
                Ok(Some(octet))
            }
        }
    }
}
impl Source for Bytes {
type Err = Error;
fn request(&mut self, _len: usize) -> Result<usize, Self::Err> {
Ok(self.len())
}
fn advance(&mut self, len: usize) -> Result<(), Self::Err> {
if len > self.len() {
Err(Error::Malformed)
}
else {
bytes::Buf::advance(self, len);
Ok(())
}
}
fn slice(&self) -> &[u8] {
self.as_ref()
}
fn bytes(&self, start: usize, end: usize) -> Bytes {
self.slice(start..end)
}
}
impl<'a> Source for &'a [u8] {
type Err = Error;
fn request(&mut self, _len: usize) -> Result<usize, Self::Err> {
Ok(self.len())
}
fn advance(&mut self, len: usize) -> Result<(), Self::Err> {
if len > self.len() {
Err(Error::Malformed)
}
else {
*self = &self[len..];
Ok(())
}
}
fn slice(&self) -> &[u8] {
self
}
fn bytes(&self, start: usize, end: usize) -> Bytes {
Bytes::copy_from_slice(&self[start..end])
}
}
/// A mutable reference to a source is itself a source.
///
/// Every method forwards to the referenced source.
impl<'a, T: Source> Source for &'a mut T {
    type Err = T::Err;

    /// Forwards to the referenced source's `request`.
    fn request(&mut self, len: usize) -> Result<usize, Self::Err> {
        (**self).request(len)
    }

    /// Forwards to the referenced source's `advance`.
    fn advance(&mut self, len: usize) -> Result<(), Self::Err> {
        (**self).advance(len)
    }

    /// Forwards to the referenced source's `slice`.
    fn slice(&self) -> &[u8] {
        (**self).slice()
    }

    /// Forwards to the referenced source's `bytes`.
    fn bytes(&self, start: usize, end: usize) -> Bytes {
        (**self).bytes(start, end)
    }
}
/// A wrapper around a source that carries an optional limit.
///
/// The limit is a number of octets; `None` means no limit is set.
#[derive(Clone, Debug)]
pub struct LimitedSource<S> {
    /// The wrapped source.
    source: S,

    /// The current limit, if any.
    limit: Option<usize>,
}

impl<S> LimitedSource<S> {
    /// Wraps `source` without setting a limit.
    pub fn new(source: S) -> Self {
        LimitedSource {
            limit: None,
            source,
        }
    }

    /// Returns the wrapped source, discarding the limit.
    pub fn unwrap(self) -> S {
        let LimitedSource { source, .. } = self;
        source
    }

    /// Returns the current limit, or `None` if there is none.
    pub fn limit(&self) -> Option<usize> {
        self.limit
    }

    /// Replaces the limit with one that is no larger and returns the old one.
    ///
    /// If no limit is currently set, any `limit` is accepted.
    ///
    /// # Panics
    ///
    /// Panics if a limit is currently set and `limit` is either `None` or
    /// larger than the current limit.
    pub fn limit_further(&mut self, limit: Option<usize>) -> Option<usize> {
        match (self.limit, limit) {
            (Some(cur), Some(limit)) => assert!(limit <= cur),
            (Some(_), None) => panic!("relimiting to unlimited"),
            (None, _) => {}
        }
        mem::replace(&mut self.limit, limit)
    }

    /// Sets the limit to `limit` without any checks.
    pub fn set_limit(&mut self, limit: Option<usize>) {
        self.limit = limit;
    }
}
impl<S: Source> LimitedSource<S> {
pub fn skip_all(&mut self) -> Result<(), S::Err> {
let limit = self.limit.unwrap();
self.advance(limit)
}
pub fn take_all(&mut self) -> Result<Bytes, S::Err> {
let limit = self.limit.unwrap();
if self.request(limit)? < limit {
return Err(Error::Malformed.into())
}
let res = self.bytes(0, limit);
self.advance(limit)?;
Ok(res)
}
pub fn exhausted(&mut self) -> Result<(), S::Err> {
match self.limit {
Some(0) => Ok(()),
Some(_limit) => {
xerr!(Err(Error::Malformed.into()))
}
None => {
if self.source.request(1)? == 0 {
Ok(())
}
else {
xerr!(Err(Error::Malformed.into()))
}
}
}
}
}
/// A limited source behaves like the wrapped source cut off at the limit.
impl<S: Source> Source for LimitedSource<S> {
    type Err = S::Err;

    /// Asks the wrapped source for octets, never reporting more than the
    /// limit.
    fn request(&mut self, len: usize) -> Result<usize, Self::Err> {
        match self.limit {
            Some(limit) => {
                // Never ask the wrapped source for more than the limit allows.
                let granted = self.source.request(min(limit, len))?;
                Ok(min(limit, granted))
            }
            None => self.source.request(len),
        }
    }

    /// Advances the wrapped source and reduces the limit by `len`.
    ///
    /// # Errors
    ///
    /// Returns a malformed error if `len` exceeds the limit, or whatever
    /// error the wrapped source produces.
    fn advance(&mut self, len: usize) -> Result<(), Self::Err> {
        if let Some(limit) = self.limit {
            if limit < len {
                xerr!(return Err(Error::Malformed.into()))
            }
            // The limit is reduced before the wrapped source is advanced.
            self.limit = Some(limit - len);
        }
        self.source.advance(len)
    }

    /// Returns the wrapped source's available octets, truncated to the
    /// limit.
    fn slice(&self) -> &[u8] {
        let octets = self.source.slice();
        match self.limit {
            Some(limit) if octets.len() > limit => &octets[..limit],
            _ => octets,
        }
    }

    /// Returns the octets in `start..end` from the wrapped source.
    ///
    /// # Panics
    ///
    /// Panics if a limit is set and `start` or `end` lie beyond it.
    fn bytes(&self, start: usize, end: usize) -> Bytes {
        if let Some(limit) = self.limit {
            assert!(start <= limit);
            assert!(end <= limit);
        }
        self.source.bytes(start, end)
    }
}
/// A source that records how far it was advanced over a borrowed source.
///
/// Advancing a capture source only moves its own position. The borrowed
/// source is advanced when the capture is finished with
/// [`into_bytes`](CaptureSource::into_bytes) or
/// [`skip`](CaptureSource::skip).
pub struct CaptureSource<'a, S: 'a> {
    /// The borrowed underlying source.
    source: &'a mut S,

    /// The number of octets advanced over so far.
    pos: usize,
}

impl<'a, S: Source> CaptureSource<'a, S> {
    /// Starts capturing at the current front of `source`.
    pub fn new(source: &'a mut S) -> Self {
        CaptureSource { pos: 0, source }
    }

    /// Returns the octets advanced over and advances the underlying source
    /// past them.
    ///
    /// # Panics
    ///
    /// Panics if advancing the underlying source fails.
    pub fn into_bytes(self) -> Bytes {
        let captured = self.source.bytes(0, self.pos);
        self.skip();
        captured
    }

    /// Advances the underlying source past the octets advanced over.
    ///
    /// # Panics
    ///
    /// Panics if advancing the underlying source fails.
    pub fn skip(self) {
        let CaptureSource { source, pos } = self;
        assert!(
            source.advance(pos).is_ok(),
            "failed to advance capture source"
        );
    }
}
impl<'a, S: Source + 'a> Source for CaptureSource<'a, S> {
type Err = S::Err;
fn request(&mut self, len: usize) -> Result<usize, Self::Err> {
self.source.request(self.pos + len).map(|res| res - self.pos)
}
fn advance(&mut self, len: usize) -> Result<(), Self::Err> {
if self.request(len)? < len {
return Err(Error::Malformed.into())
}
self.pos += len;
Ok(())
}
fn slice(&self) -> &[u8] {
&self.source.slice()[self.pos..]
}
fn bytes(&self, start: usize, end: usize) -> Bytes {
self.source.bytes(start + self.pos, end + self.pos)
}
}
// Unit tests for the `Source` implementations and the wrapper sources.
#[cfg(test)]
mod test {
use super::*;
// `take_u8` yields each octet in order and then fails with
// `Error::Malformed` once the slice is empty.
#[test]
fn take_u8() {
let mut source = &b"123"[..];
assert_eq!(source.take_u8(), Ok(b'1'));
assert_eq!(source.take_u8(), Ok(b'2'));
assert_eq!(source.take_u8(), Ok(b'3'));
assert_eq!(source.take_u8(), Err(Error::Malformed));
}
// `take_opt_u8` yields each octet in order and then `Ok(None)` once the
// slice is empty.
#[test]
fn take_opt_u8() {
let mut source = &b"123"[..];
assert_eq!(source.take_opt_u8(), Ok(Some(b'1')));
assert_eq!(source.take_opt_u8(), Ok(Some(b'2')));
assert_eq!(source.take_opt_u8(), Ok(Some(b'3')));
assert_eq!(source.take_opt_u8(), Ok(None));
}
// Walks a `Bytes` source in steps of four octets, checking `request`,
// `slice` and `bytes` along the way, and that advancing past the end fails.
#[test]
fn bytes_impl() {
let mut bytes = Bytes::from_static(b"1234567890");
assert!(bytes.request(4).unwrap() >= 4);
assert!(&Source::slice(&bytes)[..4] == b"1234");
assert_eq!(bytes.bytes(2, 4), Bytes::from_static(b"34"));
Source::advance(&mut bytes, 4).unwrap();
assert!(bytes.request(4).unwrap() >= 4);
assert!(&Source::slice(&bytes)[..4] == b"5678");
Source::advance(&mut bytes, 4).unwrap();
// Only two octets are left, so `request` reports fewer than asked for.
assert_eq!(bytes.request(4).unwrap(), 2);
assert!(&Source::slice(&bytes)[..] == b"90");
assert_eq!(
Source::advance(&mut bytes, 4).unwrap_err(),
Error::Malformed
);
}
// The same walk as `bytes_impl`, but over a plain byte slice source.
#[test]
fn slice_impl() {
let mut bytes = &b"1234567890"[..];
assert!(bytes.request(4).unwrap() >= 4);
assert!(&Source::slice(&bytes)[..4] == b"1234");
assert_eq!(bytes.bytes(2, 4), Bytes::from_static(b"34"));
Source::advance(&mut bytes, 4).unwrap();
assert!(bytes.request(4).unwrap() >= 4);
assert!(&Source::slice(&bytes)[..4] == b"5678");
Source::advance(&mut bytes, 4).unwrap();
// Only two octets are left, so `request` reports fewer than asked for.
assert_eq!(bytes.request(4).unwrap(), 2);
assert!(&Source::slice(&bytes)[..] == b"90");
assert_eq!(
Source::advance(&mut bytes, 4).unwrap_err(),
Error::Malformed
);
}
// Exercises a source limited to four of its eight octets. Each scenario
// starts from a fresh clone of `the_source`.
#[test]
fn limited_source() {
let mut the_source = LimitedSource::new(&b"12345678"[..]);
the_source.set_limit(Some(4));
// Scenario: advance in steps of two until the limit is used up.
let mut source = the_source.clone();
assert_eq!(source.exhausted(), Err(Error::Malformed));
assert_eq!(source.request(6).unwrap(), 4);
source.advance(2).unwrap();
assert_eq!(source.exhausted(), Err(Error::Malformed));
assert_eq!(source.request(6).unwrap(), 2);
source.advance(2).unwrap();
source.exhausted().unwrap();
assert_eq!(source.request(6).unwrap(), 0);
source.advance(0).unwrap();
assert_eq!(source.advance(2).unwrap_err(), Error::Malformed);
// Scenario: advancing beyond the limit in one step fails.
let mut source = the_source.clone();
assert_eq!(source.advance(5).unwrap_err(), Error::Malformed);
// Scenario: `skip_all` consumes exactly the limited part.
let mut source = the_source.clone();
source.skip_all().unwrap();
let source = source.unwrap();
assert_eq!(source.slice(), b"5678");
// Scenario: `take_all` returns the limited part and leaves the rest.
let mut source = the_source.clone();
assert_eq!(source.take_all().unwrap(), Bytes::from_static(b"1234"));
source.exhausted().unwrap();
let source = source.unwrap();
assert_eq!(source.slice(), b"5678");
}
// `limit_further` must panic when asked to raise an existing limit.
#[test]
#[should_panic]
fn limit_further() {
let mut source = LimitedSource::new(&b"12345");
source.set_limit(Some(4));
source.limit_further(Some(5));
}
// Both ways of finishing a capture advance the underlying source past the
// captured octets.
#[test]
fn capture_source() {
let mut source = &b"1234567890"[..];
{
let mut capture = CaptureSource::new(&mut source);
capture.advance(4).unwrap();
assert_eq!(capture.into_bytes(), Bytes::from_static(b"1234"));
}
assert_eq!(source, b"567890");
let mut source = &b"1234567890"[..];
{
let mut capture = CaptureSource::new(&mut source);
capture.advance(4).unwrap();
capture.skip();
}
assert_eq!(source, b"567890");
}
}