use crate::errors::*;
use amq_protocol::frame::{parse_frame, AMQPFrame};
use amq_protocol::types::parsing::parse_long_uint;
use bytes::Buf;
use input_buffer::{InputBuffer, MIN_READ};
use log::trace;
use snafu::ResultExt;
use std::io;
use std::marker::PhantomData;
/// Reassembles AMQP frames from a byte stream that may deliver them in
/// arbitrary pieces.
pub struct FrameBuffer(Inner<AmqpFrameKind>);

impl FrameBuffer {
    /// Creates a frame buffer with no pending bytes.
    pub fn new() -> FrameBuffer {
        FrameBuffer(Inner::new())
    }

    /// Reads from `stream` until it reports `WouldBlock`, calling `handler`
    /// once for every complete AMQP frame, in order.
    ///
    /// Bytes belonging to an incomplete trailing frame stay buffered and are
    /// completed by later calls. Returns the number of bytes read from
    /// `stream` during this call.
    ///
    /// # Errors
    ///
    /// - `MalformedFrame` if a complete frame cannot be decoded.
    /// - Any error returned by `handler`, which stops reading immediately.
    /// - `UnexpectedSocketClose` if `stream` reaches end-of-file.
    /// - `IoErrorReadingSocket` for other fatal I/O errors.
    pub fn read_from<S, F>(&mut self, stream: &mut S, handler: F) -> Result<usize>
    where
        S: io::Read,
        F: FnMut(AMQPFrame) -> Result<()>,
    {
        self.0.read_from(stream, handler)
    }
}

// Mirrors `new`, as Clippy's `new_without_default` lint expects for a public
// zero-argument constructor.
impl Default for FrameBuffer {
    fn default() -> FrameBuffer {
        FrameBuffer::new()
    }
}
/// Describes how to delimit and decode one kind of length-prefixed frame.
trait FrameKind {
    /// The decoded value handed to the caller for each frame.
    type Frame;

    /// Returns the total length in bytes of the frame that begins at
    /// `data[0]`, or `None` while `data` is still too short to contain the
    /// length field.
    fn parse_size(data: &[u8]) -> Option<usize>;

    /// Decodes a single frame; `data` is cut to exactly the length that
    /// `parse_size` reported for it.
    fn parse_frame(data: &[u8]) -> Result<Self::Frame>;
}
/// Type-level marker selecting AMQP wire framing for `Inner`; it has no values.
enum AmqpFrameKind {}

impl AmqpFrameKind {
    /// Byte range of the 32-bit payload-size field inside the frame header
    /// (per the AMQP 0-9-1 layout it follows the 1-byte type and 2-byte
    /// channel fields).
    const AMQP_FRAME_SIZE_POS: std::ops::Range<usize> = 3..7;

    /// Bytes a frame occupies beyond its payload: the 7-byte header (which
    /// ends where `AMQP_FRAME_SIZE_POS` ends) plus the 1-byte frame-end marker.
    const AMQP_FRAME_OVERHEAD: usize = 8;
}

impl FrameKind for AmqpFrameKind {
    type Frame = AMQPFrame;

    /// Returns header + payload + frame-end length once the size field is
    /// fully buffered, `None` before that.
    fn parse_size(buf: &[u8]) -> Option<usize> {
        let size_bytes = buf.get(Self::AMQP_FRAME_SIZE_POS)?;
        let (_, size) = parse_long_uint(size_bytes)
            .expect("a 4-byte slice always holds a complete long-uint");
        // NOTE(review): on 32-bit targets a size near u32::MAX overflows this
        // addition; on 64-bit targets it cannot.
        Some(size as usize + Self::AMQP_FRAME_OVERHEAD)
    }

    /// Decodes one frame, rejecting input that does not parse or that leaves
    /// bytes unconsumed.
    fn parse_frame(buf: &[u8]) -> Result<AMQPFrame> {
        match parse_frame(buf) {
            // `buf` is cut to exactly one frame, so leftover bytes mean the size
            // field and the decoded frame disagree.
            Ok((rest, frame)) if rest.is_empty() => Ok(frame),
            _ => MalformedFrameSnafu.fail(),
        }
    }
}
struct Inner<Kind: FrameKind> {
buf: InputBuffer,
phantom: PhantomData<Kind>,
}
impl<Kind: FrameKind> Inner<Kind> {
    /// Creates an empty buffer.
    fn new() -> Inner<Kind> {
        Inner {
            buf: InputBuffer::new(),
            phantom: PhantomData,
        }
    }

    /// Reads from `stream` until it would block, handing each complete frame
    /// to `handler` in the order it was received.
    ///
    /// Returns the number of bytes read from `stream` during this call. A
    /// trailing partial frame stays buffered and is completed by later calls.
    ///
    /// # Errors
    ///
    /// - whatever `Kind::parse_frame` returns for a complete but undecodable
    ///   frame;
    /// - any error from `handler` (that frame is not consumed from the buffer);
    /// - `UnexpectedSocketClose` when a read returns 0 bytes (end of stream);
    /// - `IoErrorReadingSocket` for I/O errors other than `WouldBlock` and
    ///   `Interrupted`.
    fn read_from<S, F>(&mut self, stream: &mut S, mut handler: F) -> Result<usize>
    where
        S: io::Read,
        F: FnMut(Kind::Frame) -> Result<()>,
    {
        let mut bytes_read = 0;
        loop {
            // Deliver every frame that is already fully buffered before reading more.
            let bytes = self.buf.chunk();
            let mut reserve = MIN_READ;
            if let Some(frame_size) = Kind::parse_size(bytes) {
                if bytes.len() >= frame_size {
                    let frame = Kind::parse_frame(&bytes[..frame_size])?;
                    handler(frame)?;
                    self.buf.advance(frame_size);
                    continue;
                }
                // Reserve space sized to the whole pending frame (at least MIN_READ).
                // NOTE(review): `frame_size` comes from the peer and is not capped here
                // (e.g. at a negotiated maximum), so a bogus length field can force a
                // large allocation.
                reserve = usize::max(MIN_READ, frame_size);
            }
            match self.buf.prepare_reserve(reserve).read_from(stream) {
                Ok(0) => return UnexpectedSocketCloseSnafu.fail(),
                Ok(n) => {
                    trace!("read {} bytes", n);
                    bytes_read += n;
                }
                // `io::Read` documents `Interrupted` as non-fatal: just retry the read.
                Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
                // No more data for now; report how much this call consumed.
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(bytes_read),
                Err(err) => return Err(err).context(IoErrorReadingSocketSnafu),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{FrameKind, Inner, Result};
    use crate::errors::*;
    use mockstream::FailingMockStream;
    use std::io::{self, Cursor, Read};

    /// Minimal framing for exercising `Inner`: byte 1 of every frame is the
    /// frame's total length, and a 6-byte frame whose body is `fail` is rejected.
    struct FakeFrameKind {}

    impl FrameKind for FakeFrameKind {
        type Frame = Vec<u8>;

        fn parse_size(buf: &[u8]) -> Option<usize> {
            buf.get(1).map(|&len| usize::from(len))
        }

        fn parse_frame(buf: &[u8]) -> Result<Self::Frame> {
            assert!(buf.len() == buf[1] as usize);
            match buf {
                [_, _, b'f', b'a', b'i', b'l'] => MalformedFrameSnafu.fail(),
                frame => Ok(frame.to_vec()),
            }
        }
    }

    fn make_buffer() -> Inner<FakeFrameKind> {
        Inner::new()
    }

    /// Stream segment that makes `read_from` stop with `WouldBlock`; the tests
    /// rely on the chain moving past it on the following call.
    fn would_block() -> FailingMockStream {
        FailingMockStream::new(io::ErrorKind::WouldBlock, "", 1)
    }

    /// Handler for cases where no frame may ever be delivered.
    fn never_called(_: Vec<u8>) -> Result<()> {
        panic!("should not be called")
    }

    /// Runs one `read_from` pass and returns the reported byte count together
    /// with the last frame delivered during that pass, if any.
    fn read_last<S: Read>(
        buf: &mut Inner<FakeFrameKind>,
        stream: &mut S,
    ) -> (usize, Option<Vec<u8>>) {
        let mut last = None;
        let n = buf
            .read_from(stream, |frame| {
                last = Some(frame);
                Ok(())
            })
            .unwrap();
        (n, last)
    }

    /// Runs one `read_from` pass, appending every delivered frame to `frames`,
    /// and returns the reported byte count.
    fn read_into<S: Read>(
        buf: &mut Inner<FakeFrameKind>,
        stream: &mut S,
        frames: &mut Vec<Vec<u8>>,
    ) -> usize {
        buf.read_from(stream, |frame| {
            frames.push(frame);
            Ok(())
        })
        .unwrap()
    }

    #[test]
    fn full_frame_available() {
        let frame = b"a\x04aa";
        let mut stream = Cursor::new(frame).chain(would_block());
        let mut buf = make_buffer();
        let (n, last) = read_last(&mut buf, &mut stream);
        assert_eq!(n, 4);
        assert_eq!(last, Some(frame.to_vec()));
    }

    #[test]
    fn two_full_frames_available() {
        let first = b"a\x04aa";
        let second = b"b\x04bb";
        let mut stream = Cursor::new(first)
            .chain(Cursor::new(second))
            .chain(would_block());
        let mut buf = make_buffer();
        let mut frames = Vec::new();
        assert_eq!(read_into(&mut buf, &mut stream, &mut frames), 8);
        assert_eq!(frames, vec![first.to_vec(), second.to_vec()]);
    }

    #[test]
    fn partial_first_frame() {
        let mut stream = Cursor::new(b"a\x04")
            .chain(would_block())
            .chain(Cursor::new(b"aa"))
            .chain(would_block());
        let mut buf = make_buffer();

        let (n, last) = read_last(&mut buf, &mut stream);
        assert_eq!(n, 2);
        assert!(last.is_none());

        let (n, last) = read_last(&mut buf, &mut stream);
        assert_eq!(n, 2);
        assert_eq!(last, Some(b"a\x04aa".to_vec()));
    }

    #[test]
    fn split_frames() {
        let mut stream = Cursor::new(b"a\x04")
            .chain(would_block())
            .chain(Cursor::new(b"aab\x04b"))
            .chain(would_block())
            .chain(Cursor::new(b"bc\x04"))
            .chain(would_block());
        let mut buf = make_buffer();
        let mut frames = Vec::new();

        assert_eq!(read_into(&mut buf, &mut stream, &mut frames), 2);
        assert!(frames.is_empty());

        assert_eq!(read_into(&mut buf, &mut stream, &mut frames), 5);
        assert_eq!(frames, vec![b"a\x04aa".to_vec()]);

        assert_eq!(read_into(&mut buf, &mut stream, &mut frames), 3);
        assert_eq!(frames, vec![b"a\x04aa".to_vec(), b"b\x04bb".to_vec()]);
    }

    #[test]
    fn parse_fail() {
        let mut stream = Cursor::new(b"x\x06fail").chain(would_block());
        let mut buf = make_buffer();
        let err = buf.read_from(&mut stream, never_called).unwrap_err();
        assert!(matches!(err, Error::MalformedFrame), "unexpected error {}", err);
    }

    #[test]
    fn callback_fail() {
        let mut stream = Cursor::new(b"a\x04aa").chain(would_block());
        let mut buf = make_buffer();
        let err = buf
            .read_from(&mut stream, |_| __NonexhaustiveSnafu.fail())
            .unwrap_err();
        assert!(matches!(err, Error::__Nonexhaustive), "unexpected error {}", err);
    }

    #[test]
    fn eof_fail() {
        let mut stream = Cursor::new(b"a\x04a");
        let mut buf = make_buffer();
        let err = buf.read_from(&mut stream, never_called).unwrap_err();
        assert!(
            matches!(err, Error::UnexpectedSocketClose),
            "unexpected error {}",
            err
        );
    }

    #[test]
    fn io_fail() {
        let reset = FailingMockStream::new(io::ErrorKind::ConnectionReset, "", 1);
        let mut stream = Cursor::new(b"a\x04a").chain(reset);
        let mut buf = make_buffer();
        let err = buf.read_from(&mut stream, never_called).unwrap_err();
        assert!(
            matches!(err, Error::IoErrorReadingSocket { .. }),
            "unexpected error {}",
            err
        );
    }
}