use std::io::{self, Read, Error, ErrorKind};
use std::collections::VecDeque;
use value::Value;
/// Incrementally pulls bytes from a `Read` source and splits them into
/// decoded `Value` frames, which are then drained through `iter_mut`.
#[derive(Debug)]
pub struct FrameReader {
    frames: Frames,
}

impl FrameReader {
    /// Builds a reader whose internal receive buffer is `max_frame_size` bytes.
    pub fn new(max_frame_size: u32) -> FrameReader {
        let frames = Frames::new(max_frame_size);
        FrameReader { frames: frames }
    }

    /// Reads whatever `reader` currently has available, queueing every
    /// complete frame it finds.
    ///
    /// Returns the number of bytes taken from `reader` during this call.
    pub fn read<T: Read>(&mut self, reader: &mut T) -> io::Result<usize> {
        let frames = &mut self.frames;
        frames.read(reader)
    }

    /// Returns an iterator that removes and yields the queued frames,
    /// oldest first.
    pub fn iter_mut<'a>(&'a mut self) -> Iter<'a> {
        let frames = &mut self.frames;
        Iter { frames: frames }
    }
}
/// Draining iterator over completed frames, created by `FrameReader::iter_mut`.
///
/// Each call to `next` removes the oldest decoded frame from the queue, so
/// frames come out in the order they were decoded.
pub struct Iter<'a> {
    frames: &'a mut Frames,
}

impl<'a> Iterator for Iter<'a> {
    type Item = Value;

    fn next(&mut self) -> Option<Self::Item> {
        self.frames.completed_frames.pop_front()
    }

    /// Exact: the iterator holds the only (mutable) borrow of the queue, so
    /// nothing can add frames while it is alive; only `next` shrinks it.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.frames.completed_frames.len();
        (remaining, Some(remaining))
    }
}

impl<'a> ExactSizeIterator for Iter<'a> {}
/// Fixed-size receive buffer plus the queue of frames decoded from it.
#[derive(Debug)]
struct Frames {
    /// Size of `current`; a single frame larger than this can never be decoded.
    max_frame_size: u32,
    /// Number of received-but-not-yet-consumed bytes at the front of `current`.
    bytes_read: usize,
    /// Receive buffer; its length is fixed at `max_frame_size` bytes.
    current: Vec<u8>,
    /// Fully decoded frames waiting to be drained via `Iter`.
    completed_frames: VecDeque<Value>,
}

impl Frames {
    /// Creates an empty buffer able to hold `max_frame_size` bytes.
    pub fn new(max_frame_size: u32) -> Frames {
        Frames {
            max_frame_size: max_frame_size,
            bytes_read: 0,
            // Zero-initialised: the free tail of this buffer is handed to
            // `Read::read`, and exposing uninitialised bytes there is
            // undefined behaviour.
            current: vec![0; max_frame_size as usize],
            completed_frames: VecDeque::new(),
        }
    }

    /// Reads from `reader` until it reports end-of-stream (`Ok(0)`) or
    /// `WouldBlock`, decoding complete frames after every chunk.
    ///
    /// Returns the total number of bytes read during this call.
    ///
    /// # Errors
    /// - `UnexpectedEof` if the stream reports 0 bytes before any data was read
    ///   in this call.
    /// - `InvalidData` if the buffer is full but holds no complete frame.
    /// - Any error from `reader` other than `WouldBlock`/`Interrupted`, or any
    ///   error from `Value::decode`.
    fn read<T: Read>(&mut self, reader: &mut T) -> io::Result<usize> {
        let mut total_bytes_read = 0;
        loop {
            match self.do_read(reader) {
                Ok(0) if total_bytes_read == 0 => {
                    return Err(Error::new(ErrorKind::UnexpectedEof, "Read 0 bytes"));
                }
                Ok(0) => return Ok(total_bytes_read),
                Ok(bytes_read) => total_bytes_read += bytes_read,
                Err(ref e) if e.kind() == ErrorKind::WouldBlock => return Ok(total_bytes_read),
                // `Read::read` documents `Interrupted` as non-fatal: retry.
                Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        }
    }

    /// Performs one read into the free tail of the buffer, then decodes any
    /// complete frames it now contains. Returns the number of bytes read.
    fn do_read<T: Read>(&mut self, reader: &mut T) -> io::Result<usize> {
        if self.bytes_read == self.current.len() {
            // Every complete frame was consumed by the previous decode pass, so a
            // full buffer means the pending frame can never fit. Reading into an
            // empty slice would return Ok(0) and be mistaken for end-of-stream.
            return Err(Error::new(
                ErrorKind::InvalidData,
                "frame exceeds max_frame_size",
            ));
        }
        let bytes_read = reader.read(&mut self.current[self.bytes_read..])?;
        self.bytes_read += bytes_read;
        self.decode_frames()?;
        Ok(bytes_read)
    }

    /// Moves every complete frame at the front of the buffer into
    /// `completed_frames`, then discards the bytes they occupied.
    ///
    /// The consumed bytes are discarded even when `Value::decode` fails part-way.
    /// Otherwise the frames already queued would be decoded and queued a second
    /// time on the next call.
    fn decode_frames(&mut self) -> io::Result<()> {
        let mut offset = 0;
        let result: io::Result<()> = loop {
            match Value::decode(&self.current[..self.bytes_read], offset) {
                // A returned offset of 0 signals that no complete frame is buffered yet.
                Ok((_, 0)) => break Ok(()),
                Ok((val, next_offset)) => {
                    self.completed_frames.push_back(val);
                    offset = next_offset;
                }
                Err(e) => break Err(e.into()),
            }
        };
        self.consume(offset);
        result
    }

    /// Drops the first `count` buffered bytes and shifts any trailing partial
    /// frame to the front of the buffer.
    fn consume(&mut self, count: usize) {
        if count > 0 {
            debug_assert!(count <= self.bytes_read);
            self.current.copy_within(count..self.bytes_read, 0);
            self.bytes_read -= count;
        }
    }
}
#[cfg(test)]
mod tests {
    use std::thread;
    use std::io::Cursor;
    use std::io::Write;
    use std::net::{TcpListener, TcpStream};
    use super::FrameReader;
    use super::super::value::Value;

    #[test]
    fn partial_and_complete_reads() {
        let buf1 = b"+Hello World\r\n".to_vec();
        let buf2 = b"-Error\r\n".to_vec();
        let mut reader = FrameReader::new(64);

        // First half of a frame: bytes are buffered, nothing is decoded yet.
        let mut data = Cursor::new(&buf1[0..5]);
        let bytes_read = reader.read(&mut data).unwrap();
        assert_eq!(5, bytes_read);
        assert_eq!(None, reader.iter_mut().next());

        // Remainder of the frame completes it.
        let mut data = Cursor::new(&buf1[5..]);
        let bytes_read = reader.read(&mut data).unwrap();
        assert_eq!(9, bytes_read);
        let val = reader.iter_mut().next().unwrap();
        assert_eq!(Value::Status("Hello World".to_string()), val);

        // Two frames delivered in a single read are both decoded, in order.
        let mut data = Cursor::new(Vec::with_capacity(buf1.len() + buf2.len()));
        assert_eq!(buf1.len(), data.write(&buf1).unwrap());
        assert_eq!(buf2.len(), data.write(&buf2).unwrap());
        data.set_position(0);
        let bytes_read = reader.read(&mut data).unwrap();
        assert_eq!(buf1.len() + buf2.len(), bytes_read);
        let val = reader.iter_mut().next().unwrap();
        assert_eq!(Value::Status("Hello World".to_string()), val);
        let val = reader.iter_mut().next().unwrap();
        assert_eq!(Value::Error("Error".to_string()), val);
    }

    /// Port 0 lets the OS pick a free port, so the test cannot collide with
    /// other processes or parallel test runs that hold a fixed port.
    const ANY_LOCAL_PORT: &'static str = "127.0.0.1:0";

    #[test]
    fn would_block() {
        let listener = TcpListener::bind(ANY_LOCAL_PORT).unwrap();
        let addr = listener.local_addr().unwrap();
        let h = thread::spawn(move || {
            let (mut conn, _) = listener.accept().unwrap();
            conn.set_nonblocking(true).unwrap();
            let mut reader = FrameReader::new(512);
            // The client never writes, so the read reports WouldBlock, which
            // `FrameReader::read` maps to Ok(0).
            let result = reader.read(&mut conn).unwrap();
            assert_eq!(result, 0);
        });
        // Keep the client socket open until the server thread finishes: a
        // closed peer would make the read return 0 bytes (EOF) instead.
        let _sock = TcpStream::connect(addr).unwrap();
        h.join().unwrap();
    }
}