use std::io::{self, Cursor, Read, Write};
use std::thread;
use std::time::Duration;
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
/// Number of header bytes covered by the checksum: a big-endian `u16` version
/// followed by a big-endian `u32` payload length. One checksum byte follows
/// them on the wire, so a complete V1 header occupies `HEADER_LENGTH + 1` bytes.
const HEADER_LENGTH: usize = 6;
/// Errors produced while negotiating a frame version or while reading and
/// writing frames.
#[derive(Debug)]
pub enum FrameError {
/// An underlying I/O operation failed.
IoError(io::Error),
/// The checksum byte of a received header does not match the checksum
/// computed over the header fields.
InvalidChecksum,
/// The bytes following the version field of a header could not be read in
/// full; carries the number of bytes that were actually read.
InvalidHeaderLength(usize),
/// The peer used or selected a frame version this implementation does not
/// support, or no common version could be agreed on during negotiation.
UnsupportedVersion,
/// The version handshake could not be completed (for example because the
/// connection was closed); carries a human-readable description.
HandshakeFailure(String),
}
impl std::fmt::Display for FrameError {
    /// Writes a human-readable description of the error to `f`.
    ///
    /// I/O errors and handshake failures are rendered using their own text;
    /// the remaining variants use a fixed message.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            FrameError::IoError(err) => write!(f, "{}", err),
            FrameError::HandshakeFailure(msg) => write!(f, "{}", msg),
            FrameError::InvalidHeaderLength(n) => write!(
                f,
                "Invalid header length expected {} but was {}",
                HEADER_LENGTH, n
            ),
            FrameError::InvalidChecksum => write!(f, "Invalid checksum in frame header"),
            FrameError::UnsupportedVersion => write!(f, "Unsupported frame version"),
        }
    }
}
impl std::error::Error for FrameError {
    /// Returns the wrapped I/O error for [`FrameError::IoError`]; every other
    /// variant has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            FrameError::IoError(err) => Some(err),
            // Listed explicitly (no catch-all) so that adding a variant forces
            // a decision about its source.
            FrameError::InvalidChecksum
            | FrameError::InvalidHeaderLength(_)
            | FrameError::UnsupportedVersion
            | FrameError::HandshakeFailure(_) => None,
        }
    }
}
impl From<io::Error> for FrameError {
fn from(err: io::Error) -> Self {
FrameError::IoError(err)
}
}
/// Frame protocol versions known to this implementation.
///
/// The discriminant of each variant is the numeric value that is sent on the
/// wire during negotiation.
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum FrameVersion {
/// Version 1: header of version, payload length and checksum, then payload.
V1 = 1,
}
impl std::fmt::Display for FrameVersion {
    /// Formats the version as `v` followed by its numeric value (e.g. `v1`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let number = *self as u32;
        f.write_fmt(format_args!("v{}", number))
    }
}
/// An owned frame payload, as produced by `Frame::read`.
pub struct Frame {
/// The payload bytes that followed the frame header.
data: Vec<u8>,
}
impl Frame {
pub fn into_inner(self) -> Vec<u8> {
self.data
}
pub fn read<R: Read>(reader: &mut R) -> Result<Self, FrameError> {
let frame_header = loop {
match FrameHeader::read(reader) {
Err(FrameError::IoError(ref e)) if e.kind() == io::ErrorKind::WouldBlock => {
thread::sleep(Duration::from_millis(100));
continue;
}
Err(err) => return Err(err),
Ok(header) => break header,
};
};
match frame_header {
FrameHeader::V1 { length } => {
let mut buffer = vec![0; length as usize];
let mut remaining = &mut buffer[..];
while !remaining.is_empty() {
match reader.read(remaining) {
Ok(0) => break,
Ok(n) => {
let tmp = remaining;
remaining = &mut tmp[n..];
}
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
thread::sleep(Duration::from_millis(100));
}
Err(e) => return Err(FrameError::IoError(e)),
}
}
if !remaining.is_empty() {
Err(FrameError::IoError(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Could not receive complete frame",
)))
} else {
Ok(Self { data: buffer })
}
}
}
}
}
/// A borrowed payload paired with the frame version to use when writing it.
pub struct FrameRef<'a> {
/// Frame version that determines the header format written before the payload.
version: FrameVersion,
/// The payload bytes to send; borrowed from the caller.
data: &'a [u8],
}
impl<'a> FrameRef<'a> {
pub fn new<'b: 'a>(version: FrameVersion, data: &'b [u8]) -> FrameRef<'a> {
Self { version, data }
}
pub fn write<W: Write>(self, writer: &mut W) -> Result<(), FrameError> {
let frame_header = match self.version {
FrameVersion::V1 => FrameHeader::v1(self.data.len() as u32),
};
loop {
match frame_header.write(writer) {
Err(FrameError::IoError(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
thread::sleep(Duration::from_millis(100));
continue;
}
Err(err) => return Err(err),
Ok(_) => break,
}
}
let mut buffer = self.data;
while !buffer.is_empty() {
match writer.write(buffer) {
Ok(0) => {
return Err(FrameError::IoError(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"failed to write whole buffer",
)))
}
Ok(n) => buffer = &buffer[n..],
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
thread::sleep(Duration::from_millis(100));
}
Err(e) => return Err(FrameError::IoError(e)),
}
}
writer.flush()?;
Ok(())
}
}
/// The parsed header that precedes a frame payload on the wire.
#[derive(Debug, PartialEq)]
enum FrameHeader {
/// Version 1 header; `length` is the size of the payload in bytes.
V1 { length: u32 },
}
impl FrameHeader {
fn v1(length: u32) -> Self {
FrameHeader::V1 { length }
}
fn read<R: Read>(reader: &mut R) -> Result<Self, FrameError> {
let version = reader.read_u16::<BigEndian>()?;
match version {
1 => {
let mut buffer = [0u8; HEADER_LENGTH + 1];
let mut cursor = Cursor::new(&mut buffer[..]);
cursor.write_u16::<BigEndian>(1u16)?;
let n = reader.read(&mut cursor.get_mut()[std::mem::size_of::<u16>()..])?;
if n != HEADER_LENGTH + 1 - std::mem::size_of::<u16>() {
return Err(FrameError::InvalidHeaderLength(n));
}
let checksum = compute_checksum(&cursor.get_ref()[..HEADER_LENGTH]);
if checksum != cursor.get_ref()[HEADER_LENGTH] {
return Err(FrameError::InvalidChecksum);
}
Ok(FrameHeader::V1 {
length: cursor.read_u32::<BigEndian>()?,
})
}
_ => Err(FrameError::UnsupportedVersion),
}
}
fn write<W: Write>(&self, writer: &mut W) -> Result<(), FrameError> {
match *self {
FrameHeader::V1 { length } => {
let mut header_bytes = [0u8; HEADER_LENGTH + 1];
let mut cursor = Cursor::new(&mut header_bytes[..]);
cursor.write_u16::<BigEndian>(1)?;
cursor.write_u32::<BigEndian>(length)?;
cursor.get_mut()[HEADER_LENGTH] =
compute_checksum(&cursor.get_ref()[..HEADER_LENGTH]);
writer.write_all(cursor.into_inner())?;
}
}
Ok(())
}
}
/// Computes the one-byte checksum of `buffer`: the two's complement of the
/// sum of all bytes, with the sum taken modulo 256.
///
/// Adding the returned value to the byte sum of `buffer` therefore yields
/// zero modulo 256. An empty buffer has checksum `0`.
fn compute_checksum(buffer: &[u8]) -> u8 {
    let byte_sum = buffer
        .iter()
        .fold(0u8, |total, byte| total.wrapping_add(*byte));
    byte_sum.wrapping_neg()
}
/// The role a peer plays in the frame-version handshake.
pub enum FrameNegotiation {
/// The initiating side: offers an inclusive range of versions and reads
/// back the version chosen by the peer.
Outbound {
/// Lowest version offered to the peer.
min: FrameVersion,
/// Highest version offered to the peer.
max: FrameVersion,
},
/// The accepting side: reads the offered range and replies with `version`
/// if it lies within that range, or with `0` otherwise.
Inbound { version: FrameVersion },
}
impl FrameNegotiation {
    /// Creates the initiating side of a handshake, offering every version
    /// from `min` to `max` inclusive.
    pub fn outbound(min: FrameVersion, max: FrameVersion) -> Self {
        FrameNegotiation::Outbound { min, max }
    }

    /// Creates the accepting side of a handshake, which supports exactly
    /// `version`.
    pub fn inbound(version: FrameVersion) -> Self {
        FrameNegotiation::Inbound { version }
    }

    /// Runs the handshake over `stream` and returns the agreed version.
    ///
    /// All values are exchanged as big-endian `u16`s. The outbound side sends
    /// its minimum and maximum versions and reads the peer's choice; the
    /// inbound side reads the range and answers with its version, or with `0`
    /// when its version is outside the range. The stream is flushed after
    /// each side has written, so that buffered streams cannot stall the
    /// handshake while the peer waits for data.
    ///
    /// # Errors
    ///
    /// * `FrameError::UnsupportedVersion` if no common version exists, or if
    ///   the peer answers with a version that is unknown or outside the
    ///   offered range.
    /// * `FrameError::HandshakeFailure` if the connection is closed during
    ///   the handshake.
    /// * `FrameError::IoError` for any other I/O failure.
    pub fn negotiate<S: Read + Write>(self, stream: &mut S) -> Result<FrameVersion, FrameError> {
        match self {
            FrameNegotiation::Outbound { min, max } => {
                stream
                    .write_u16::<BigEndian>(min as u16)
                    .map_err(Self::map_io_err)?;
                stream
                    .write_u16::<BigEndian>(max as u16)
                    .map_err(Self::map_io_err)?;
                // Make sure the offer has actually left any write buffer
                // before blocking on the reply.
                stream.flush().map_err(Self::map_io_err)?;

                let selected = stream.read_u16::<BigEndian>().map_err(Self::map_io_err)?;
                // `0` is the peer's "nothing in common" reply; it is always
                // below `min`, as is any other answer outside the offer.
                if selected < (min as u16) || selected > (max as u16) {
                    return Err(FrameError::UnsupportedVersion);
                }
                match selected {
                    1 => Ok(FrameVersion::V1),
                    _ => Err(FrameError::UnsupportedVersion),
                }
            }
            FrameNegotiation::Inbound { version } => {
                let min = stream.read_u16::<BigEndian>().map_err(Self::map_io_err)?;
                let max = stream.read_u16::<BigEndian>().map_err(Self::map_io_err)?;
                let supported = version as u16;

                if min > supported || max < supported {
                    stream.write_u16::<BigEndian>(0).map_err(Self::map_io_err)?;
                    stream.flush().map_err(Self::map_io_err)?;
                    Err(FrameError::UnsupportedVersion)
                } else {
                    stream
                        .write_u16::<BigEndian>(supported)
                        .map_err(Self::map_io_err)?;
                    stream.flush().map_err(Self::map_io_err)?;
                    Ok(version)
                }
            }
        }
    }

    /// Maps I/O errors that indicate a closed connection to
    /// `FrameError::HandshakeFailure`; all other errors are wrapped unchanged
    /// in `FrameError::IoError`.
    fn map_io_err(err: io::Error) -> FrameError {
        use io::ErrorKind::*;
        match err.kind() {
            UnexpectedEof | ConnectionReset | ConnectionAborted | BrokenPipe => {
                FrameError::HandshakeFailure(
                    "unable to complete handshake due to closed connection".into(),
                )
            }
            _ => FrameError::IoError(err),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// A V1 header with length 0 and a matching checksum parses successfully.
    #[test]
    fn read_version_1_zero_length() {
        let header_bytes = vec![0u8; HEADER_LENGTH + 1];
        let mut header_cursor = Cursor::new(header_bytes);
        header_cursor
            .write_u16::<BigEndian>(1)
            .expect("Could not write to cursor");
        header_cursor.get_mut()[HEADER_LENGTH] = 0xff;
        header_cursor.set_position(0);
        let frame_header = FrameHeader::read(&mut header_cursor).expect("Unable to read_header");
        assert_eq!(FrameHeader::v1(0), frame_header);
    }

    /// A V1 header carrying a non-zero length parses to that length.
    #[test]
    fn read_version_and_length() {
        let header_bytes = vec![0u8; HEADER_LENGTH + 1];
        let mut header_cursor = Cursor::new(header_bytes);
        header_cursor
            .write_u16::<BigEndian>(1)
            .expect("Could not write version to cursor");
        header_cursor
            .write_u32::<BigEndian>(2)
            .expect("Could not write length to cursor");
        header_cursor.get_mut()[HEADER_LENGTH] = 0xfd;
        header_cursor.set_position(0);
        let frame_header = FrameHeader::read(&mut header_cursor).expect("Unable to read_header");
        assert_eq!(FrameHeader::v1(2), frame_header);
    }

    /// A header whose checksum byte is left at zero is rejected.
    #[test]
    fn fail_checksum() {
        let header_bytes = vec![0u8; HEADER_LENGTH + 1];
        let mut header_cursor = Cursor::new(header_bytes);
        header_cursor
            .write_u16::<BigEndian>(1)
            .expect("Could not write version to cursor");
        header_cursor
            .write_u32::<BigEndian>(2)
            .expect("Could not write length to cursor");
        header_cursor.set_position(0);
        match FrameHeader::read(&mut header_cursor) {
            Ok(_) => panic!("Should not have produced a frame header"),
            Err(FrameError::InvalidChecksum) => (),
            Err(err) => panic!("Produced invalid error: {}", err),
        }
    }

    /// Writing a header produces version, length and checksum in wire order.
    #[test]
    fn standard_write() {
        let header_bytes = vec![0u8; HEADER_LENGTH + 1];
        let mut header_cursor = Cursor::new(header_bytes);
        let frame_header = FrameHeader::v1(3);
        frame_header
            .write(&mut header_cursor)
            .expect("Unable to write frame header");
        header_cursor.set_position(0);
        assert_eq!(
            1,
            header_cursor
                .read_u16::<BigEndian>()
                .expect("Unable to read version")
        );
        assert_eq!(
            3,
            header_cursor
                .read_u32::<BigEndian>()
                .expect("Unable to read length")
        );
        assert_eq!(0xFC, header_cursor.get_ref()[HEADER_LENGTH]);
    }

    /// A header survives a write followed by a read.
    #[test]
    fn round_trip() {
        let header_bytes = vec![0u8; HEADER_LENGTH + 1];
        let mut header_cursor = Cursor::new(header_bytes);
        let frame_header = FrameHeader::v1(100);
        frame_header
            .write(&mut header_cursor)
            .expect("Unable to write frame header");
        header_cursor.set_position(0);
        let FrameHeader::V1 { length } =
            FrameHeader::read(&mut header_cursor).expect("Unable to read header");
        assert_eq!(100, length);
    }

    /// Both sides agree on V1 when the offered range contains it.
    #[test]
    fn basic_outbound_negotiation() {
        let (mut tx, mut rx) = stream::byte_stream_pair();
        let (done_tx, done_rx) = std::sync::mpsc::channel();
        let join_handle = thread::spawn(move || {
            let res = FrameNegotiation::inbound(FrameVersion::V1)
                .negotiate(&mut rx)
                .expect("Should have successfully negotiated");
            // Keep `rx` alive until the main thread has finished its side.
            done_rx.recv().unwrap();
            res
        });
        let version = FrameNegotiation::outbound(FrameVersion::V1, FrameVersion::V1)
            .negotiate(&mut tx)
            .expect("Unable to negotiate a valid version");
        assert_eq!(FrameVersion::V1, version);
        done_tx.send(1u8).expect("unable to send stop signal");
        let remote_res = join_handle.join().expect("Unable to join thread");
        assert_eq!(FrameVersion::V1, remote_res);
    }

    /// The outbound side reports `UnsupportedVersion` when the peer replies 0.
    #[test]
    fn unsupported_range() {
        let (mut tx, mut rx) = stream::byte_stream_pair();
        let (done_tx, done_rx) = std::sync::mpsc::channel();
        let join_handle = thread::spawn(move || {
            rx.write_u16::<BigEndian>(0)
                .expect("Unable to write unsupported version");
            done_rx.recv().unwrap();
        });
        let res = FrameNegotiation::outbound(FrameVersion::V1, FrameVersion::V1).negotiate(&mut tx);
        done_tx.send(1u8).expect("Unable to send stop signal");
        join_handle.join().expect("Unable to join thread");
        match res {
            Err(FrameError::UnsupportedVersion) => (),
            res => {
                panic!("Unexpected result: {:?}", res);
            }
        }
    }

    /// A peer that disconnects mid-handshake yields `HandshakeFailure`.
    #[test]
    fn frame_negotation_eof() {
        let (mut tx, rx) = stream::byte_stream_pair();
        let (done_tx, done_rx) = std::sync::mpsc::channel();
        let join_handle = thread::spawn(move || {
            drop(rx);
            done_rx.recv().unwrap();
        });
        let res = FrameNegotiation::inbound(FrameVersion::V1).negotiate(&mut tx);
        done_tx.send(1u8).expect("Unable to send stop signal");
        join_handle.join().expect("Unable to join thread");
        match res {
            Err(FrameError::HandshakeFailure(_)) => (),
            res => {
                panic!("Unexpected result: {:?}", res);
            }
        }
    }

    /// The inbound side answers 0 when its version is outside the offer.
    #[test]
    fn out_of_range() {
        let (mut tx, mut rx) = stream::byte_stream_pair();
        let (done_tx, done_rx) = std::sync::mpsc::channel();
        let join_handle = thread::spawn(move || {
            rx.write_u16::<BigEndian>(3)
                .expect("Unable to write min version");
            rx.write_u16::<BigEndian>(5)
                .expect("Unable to write max version");
            let res = rx
                .read_u16::<BigEndian>()
                .expect("Unable to read negotiated version");
            done_rx.recv().unwrap();
            res
        });
        let res = FrameNegotiation::inbound(FrameVersion::V1).negotiate(&mut tx);
        done_tx.send(1u8).expect("Unable to send stop signal");
        let remote_res = join_handle.join().expect("unable to join thread");
        match res {
            Err(FrameError::UnsupportedVersion) => (),
            res => panic!("Unexpected result: {:?}", res),
        }
        assert_eq!(0u16, remote_res);
    }

    /// A hand-assembled V1 frame is read back with the right payload.
    #[test]
    fn read_frame_v1() {
        let input = b"hello";
        let mut cursor = Cursor::new(vec![0; 128]);
        FrameHeader::v1(input.len() as u32)
            .write(&mut cursor)
            .expect("Unable to write header");
        cursor.write_all(&input[..]).expect("Unable to write data");
        cursor.set_position(0);
        let frame = Frame::read(&mut cursor).expect("Unable to read frame");
        assert_eq!(input.to_vec(), frame.data);
    }

    /// A payload written through `FrameRef` is read back by `Frame`.
    #[test]
    fn frame_round_trip() {
        let input = b"hello world";
        let frame_ref = FrameRef::new(FrameVersion::V1, input);
        let mut cursor = Cursor::new(vec![0; 128]);
        frame_ref.write(&mut cursor).expect("Unable to write data");
        cursor.set_position(0);
        let frame = Frame::read(&mut cursor).expect("Unable to read frame");
        assert_eq!(input.to_vec(), frame.data);
    }

    /// Channel-backed in-process byte stream for platforms without Unix
    /// domain sockets.
    #[cfg(not(unix))]
    mod stream {
        use std::io::{Error as IoError, Read, Write};
        use std::sync::mpsc::{channel, Receiver, Sender};

        pub struct InProcByteStream {
            outbound: Sender<Vec<u8>>,
            inbound: Receiver<Vec<u8>>,
            /// Bytes received from the peer that have not been read yet.
            pending: Vec<u8>,
        }

        /// Creates two connected streams; bytes written to one can be read
        /// from the other.
        pub fn byte_stream_pair() -> (InProcByteStream, InProcByteStream) {
            let (left_tx, left_rx) = channel();
            let (right_tx, right_rx) = channel();
            (
                InProcByteStream {
                    outbound: right_tx,
                    inbound: left_rx,
                    pending: Vec::new(),
                },
                InProcByteStream {
                    outbound: left_tx,
                    inbound: right_rx,
                    pending: Vec::new(),
                },
            )
        }

        impl Read for InProcByteStream {
            fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
                if buf.is_empty() {
                    return Ok(0);
                }
                // Block until at least one byte is available; empty messages
                // carry no data and are skipped.
                while self.pending.is_empty() {
                    self.pending = self
                        .inbound
                        .recv()
                        .map_err(|e| IoError::new(std::io::ErrorKind::UnexpectedEof, e))?;
                }
                // Hand out only as much as fits and keep the rest for the
                // next call, so message and buffer sizes need not match.
                let n = buf.len().min(self.pending.len());
                buf[..n].copy_from_slice(&self.pending[..n]);
                self.pending.drain(..n);
                Ok(n)
            }
        }

        impl Write for InProcByteStream {
            fn write(&mut self, buf: &[u8]) -> Result<usize, IoError> {
                let n = buf.len();
                // A closed peer is reported as `BrokenPipe`; `Interrupted`
                // would make `write_all` retry forever.
                self.outbound
                    .send(buf.to_vec())
                    .map_err(|e| IoError::new(std::io::ErrorKind::BrokenPipe, e))?;
                Ok(n)
            }

            fn flush(&mut self) -> Result<(), IoError> {
                Ok(())
            }
        }
    }

    /// Socket-backed in-process byte stream for Unix platforms.
    #[cfg(unix)]
    mod stream {
        use std::io::{Error as IoError, Read, Write};
        use std::os::unix::net::UnixStream;

        pub struct InProcByteStream {
            inner: UnixStream,
        }

        /// Creates two connected streams backed by a Unix socket pair.
        pub fn byte_stream_pair() -> (InProcByteStream, InProcByteStream) {
            let (left, right) = UnixStream::pair().expect("Unable to create unix stream");
            (
                InProcByteStream { inner: left },
                InProcByteStream { inner: right },
            )
        }

        impl Read for InProcByteStream {
            fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
                self.inner.read(buf)
            }
        }

        impl Write for InProcByteStream {
            fn write(&mut self, buf: &[u8]) -> Result<usize, IoError> {
                self.inner.write(buf)
            }

            fn flush(&mut self) -> Result<(), IoError> {
                self.inner.flush()
            }
        }
    }
}