use std::collections::HashMap;
use std::convert::TryFrom;
use std::ops::{Deref, DerefMut, RangeBounds};
use std::u32;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use snafu::{ensure, OptionExt, ResultExt};
use uuid::Uuid;
use crate::errors::{self, DecodeError, EncodeError};
use crate::features::ProtocolVersion;
/// Map from a numeric (`u16`) key to a raw byte value.
// NOTE(review): the meaning of each key depends on the message that carries
// it — confirm at call sites.
pub type KeyValues = HashMap<u16, Bytes>;
/// Map from string keys to string values.
pub type Annotations = HashMap<String, String>;
/// Readable view over received protocol bytes, paired with the protocol
/// version they were produced under.
///
/// Reading through the `Buf` implementation consumes bytes from the front
/// of `bytes`.
pub struct Input {
    // NOTE(review): the field is only exposed through `proto()`; the allow
    // presumably silences warnings when nothing reads it — confirm it is
    // still needed.
    #[allow(dead_code)]
    proto: ProtocolVersion,
    // The not-yet-consumed input bytes.
    bytes: Bytes,
}
/// Writable handle over a caller-owned `BytesMut`, paired with the protocol
/// version being encoded for.
pub struct Output<'a> {
    // NOTE(review): the field is only exposed through `proto()`; the allow
    // presumably silences warnings when nothing reads it — confirm it is
    // still needed.
    #[allow(dead_code)]
    proto: &'a ProtocolVersion,
    // Destination buffer borrowed from the caller; encoders write into it.
    bytes: &'a mut BytesMut,
}
/// Types that can serialize themselves into an [`Output`] buffer.
pub(crate) trait Encode {
    /// Writes the encoded form of `self` into `buf` (the implementations in
    /// this file append to the end of the buffer).
    ///
    /// # Errors
    ///
    /// Returns an `EncodeError` when the value cannot be represented in the
    /// wire format (e.g. a length that does not fit in a `u32`).
    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError>;
}
/// Types that can be parsed from an [`Input`] buffer.
///
/// `Sized` is required because `decode` returns `Self` by value.
pub(crate) trait Decode: Sized {
    /// Reads one value from the front of `buf`, advancing it past the
    /// consumed bytes.
    ///
    /// # Errors
    ///
    /// Returns a `DecodeError` when the input is truncated or malformed.
    fn decode(buf: &mut Input) -> Result<Self, DecodeError>;
}
impl Input {
    /// Wraps `bytes` as a readable input tagged with protocol version `proto`.
    pub fn new(proto: ProtocolVersion, bytes: Bytes) -> Input {
        Self { bytes, proto }
    }

    /// Returns the protocol version this input was created with.
    pub fn proto(&self) -> &ProtocolVersion {
        &self.proto
    }

    /// Creates a new `Input` covering `range` of the remaining bytes, with a
    /// cloned protocol version. `self` is not advanced.
    ///
    /// # Panics
    ///
    /// Panics (via `Bytes::slice`) if `range` is out of bounds.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Input {
        let version = self.proto.clone();
        let window = self.bytes.slice(range);
        Input {
            proto: version,
            bytes: window,
        }
    }
}
impl Buf for Input {
fn remaining(&self) -> usize {
self.bytes.remaining()
}
fn chunk(&self) -> &[u8] {
self.bytes.chunk()
}
fn advance(&mut self, cnt: usize) {
self.bytes.advance(cnt)
}
fn copy_to_bytes(&mut self, len: usize) -> Bytes {
self.bytes.copy_to_bytes(len)
}
}
/// Exposes the not-yet-consumed bytes of the input as a plain byte slice.
impl Deref for Input {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        let remaining: &[u8] = &self.bytes;
        remaining
    }
}
/// Exposes the full contents of the underlying `BytesMut` as a byte slice,
/// including anything the buffer held before this `Output` wrapped it.
impl Deref for Output<'_> {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        &**self.bytes
    }
}
/// Mutable access to the underlying buffer's contents, which allows
/// already-written bytes (such as a length placeholder) to be patched in place.
impl DerefMut for Output<'_> {
    fn deref_mut(&mut self) -> &mut [u8] {
        &mut **self.bytes
    }
}
impl Output<'_> {
    /// Wraps a caller-owned `BytesMut` so encoders can write into it for
    /// protocol version `proto`.
    pub fn new<'x>(proto: &'x ProtocolVersion, bytes: &'x mut BytesMut) -> Output<'x> {
        Output { proto, bytes }
    }

    /// Returns the protocol version this output was created with.
    pub fn proto(&self) -> &ProtocolVersion {
        self.proto
    }

    /// Reserves capacity for at least `size` more bytes in the underlying
    /// buffer.
    pub fn reserve(&mut self, size: usize) {
        self.bytes.reserve(size)
    }

    /// Appends `slice` to the end of the underlying buffer.
    pub fn extend(&mut self, slice: &[u8]) {
        // `extend_from_slice` copies the whole slice at once; the generic
        // `Extend<&u8>` path would iterate and push byte by byte.
        self.bytes.extend_from_slice(slice)
    }
}
/// `BufMut` for `Output` forwards its required methods to the wrapped
/// `BytesMut`, so writes through this trait land in the caller's buffer.
///
/// NOTE(review): only the three required methods are forwarded; provided
/// methods such as `put_slice` use the trait's default implementations
/// rather than any `BytesMut` overrides.
// SAFETY: every method delegates unchanged to `BytesMut`'s own `BufMut`
// implementation, so `Output` upholds exactly the invariants `BytesMut`
// already guarantees (the memory returned by `chunk_mut` is the memory that
// `advance_mut` marks as initialized).
unsafe impl BufMut for Output<'_> {
    fn remaining_mut(&self) -> usize {
        self.bytes.remaining_mut()
    }
    // SAFETY: the caller's obligation (the first `cnt` bytes of `chunk_mut()`
    // are initialized) is passed straight through to `BytesMut::advance_mut`,
    // whose `chunk_mut` is the one this impl returns.
    unsafe fn advance_mut(&mut self, cnt: usize) {
        self.bytes.advance_mut(cnt)
    }
    fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
        self.bytes.chunk_mut()
    }
}
/// Writes one framed message: the `code` byte, a big-endian `u32` length,
/// then the body produced by `msg.encode`.
///
/// The length counts the 4-byte length field itself plus the body, but not
/// the leading `code` byte.
///
/// # Errors
///
/// Propagates any error from `msg.encode`, and returns the `MessageTooLong`
/// error if the framed length does not fit in a `u32`. No rollback is done,
/// so on error the buffer may hold a partially written message.
pub(crate) fn encode<T: Encode>(buf: &mut Output, code: u8, msg: &T) -> Result<(), EncodeError> {
    // One type byte plus a four-byte length placeholder.
    buf.reserve(5);
    buf.put_u8(code);

    // Remember where the length goes; it is backpatched once the body size is known.
    let len_pos = buf.len();
    buf.put_u32(0);

    msg.encode(buf)?;

    let framed = buf.len() - len_pos;
    let size = u32::try_from(framed)
        .ok()
        .context(errors::MessageTooLong)?;
    let header = size.to_be_bytes();
    buf[len_pos..len_pos + header.len()].copy_from_slice(&header);
    Ok(())
}
impl Encode for String {
    /// Writes the string as a big-endian `u32` byte length followed by its
    /// UTF-8 bytes.
    ///
    /// # Errors
    ///
    /// Returns the `StringTooLong` error when the byte length does not fit
    /// in a `u32`.
    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
        // Validate before reserving, so an oversized string yields an error
        // instead of first attempting a huge allocation.
        let len = u32::try_from(self.len())
            .ok()
            .context(errors::StringTooLong)?;
        // The length prefix is a 4-byte u32.
        buf.reserve(4 + self.len());
        buf.put_u32(len);
        buf.extend(self.as_bytes());
        Ok(())
    }
}
impl Encode for Bytes {
    /// Writes the data as a big-endian `u32` length followed by the raw bytes.
    ///
    /// # Errors
    ///
    /// Returns the `StringTooLong` error when the length does not fit in a
    /// `u32`.
    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
        // Validate before reserving, so oversized input yields an error
        // instead of first attempting a huge allocation.
        let len = u32::try_from(self.len())
            .ok()
            .context(errors::StringTooLong)?;
        // The length prefix is a 4-byte u32.
        buf.reserve(4 + self.len());
        buf.put_u32(len);
        buf.extend(&self[..]);
        Ok(())
    }
}
impl Decode for String {
    /// Reads a big-endian `u32` byte length followed by that many bytes,
    /// which must form valid UTF-8.
    ///
    /// # Errors
    ///
    /// `Underflow` if the input ends before the length prefix or the
    /// payload; `InvalidUtf8` if the payload is not valid UTF-8 (the input
    /// has already been advanced past it in that case).
    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
        ensure!(buf.remaining() >= 4, errors::Underflow);
        let byte_len = buf.get_u32() as usize;
        ensure!(byte_len <= buf.remaining(), errors::Underflow);

        let raw: Vec<u8> = buf.copy_to_bytes(byte_len).to_vec();
        String::from_utf8(raw)
            .map_err(|err| err.utf8_error())
            .context(errors::InvalidUtf8)
    }
}
impl Decode for Bytes {
    /// Reads a big-endian `u32` length followed by that many raw bytes.
    ///
    /// # Errors
    ///
    /// `Underflow` if the input ends before the length prefix or the payload.
    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
        ensure!(buf.remaining() >= 4, errors::Underflow);
        let size = buf.get_u32() as usize;
        ensure!(size <= buf.remaining(), errors::Underflow);
        let payload = buf.copy_to_bytes(size);
        Ok(payload)
    }
}
impl Decode for Uuid {
    /// Reads exactly 16 raw bytes as a UUID (no length prefix).
    ///
    /// # Errors
    ///
    /// `Underflow` if fewer than 16 bytes remain; `InvalidUuid` if
    /// `Uuid::from_slice` rejects the bytes.
    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
        ensure!(buf.remaining() >= 16, errors::Underflow);
        let mut raw = [0u8; 16];
        buf.copy_to_slice(&mut raw);
        Uuid::from_slice(&raw).context(errors::InvalidUuid)
    }
}
impl Encode for Uuid {
    /// Writes the UUID's 16 raw bytes with no length prefix; never fails.
    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
        let raw: &[u8] = self.as_bytes();
        buf.extend(raw);
        Ok(())
    }
}