pub mod binary;
pub mod error;
pub mod rw_ext;
use std::sync::Arc;
use bytes::{Buf, BufMut};
pub use error::*;
use tokio::io::AsyncRead;
pub use self::binary::TAsyncBinaryProtocol;
// Nesting budget that `TInputProtocol::skip` hands to `skip_till_depth`;
// every struct/list/set/map level entered while skipping consumes one unit.
const MAXIMUM_SKIP_DEPTH: i8 = 64;
lazy_static::lazy_static! {
/// Struct identifier whose name is `"void"`.
///
/// NOTE(review): presumably used when encoding results of `void` RPC
/// methods — confirm against callers.
pub static ref VOID_IDENT: TStructIdentifier = TStructIdentifier { name: "void" };
}
/// A Thrift value that can be written to an output protocol and read back
/// from an input protocol.
///
/// `decode_async` is specialised to [`TAsyncBinaryProtocol`] reading from any
/// `AsyncRead + Unpin + Send` source.
#[async_trait::async_trait]
pub trait Message: Sized + Send {
/// Writes `self` to `protocol`.
fn encode<T: TOutputProtocol>(&self, protocol: &mut T) -> Result<(), Error>;
/// Reads a value of this type from `protocol`.
fn decode<T: TInputProtocol>(protocol: &mut T) -> Result<Self, Error>;
/// Asynchronous counterpart of `decode`, reading through the binary protocol.
async fn decode_async<R>(protocol: &mut TAsyncBinaryProtocol<R>) -> Result<Self, Error>
where
R: AsyncRead + Unpin + Send;
}
/// Like [`Message`], but decoding additionally receives a
/// [`TMessageIdentifier`] supplied by the caller.
///
/// NOTE(review): presumably the message header already read with
/// `read_message_begin`, so implementations can dispatch on the method name
/// or message type — confirm against implementors.
#[async_trait::async_trait]
pub trait EntryMessage: Sized + Send {
/// Writes `self` to `protocol`.
fn encode<T: TOutputProtocol>(&self, protocol: &mut T) -> Result<(), Error>;
/// Reads a value of this type from `protocol`, given the message identifier `msg_ident`.
fn decode<T: TInputProtocol>(
protocol: &mut T,
msg_ident: &TMessageIdentifier,
) -> Result<Self, Error>;
/// Asynchronous counterpart of `decode`, reading through the binary protocol.
async fn decode_async<R>(
protocol: &mut TAsyncBinaryProtocol<R>,
msg_ident: &TMessageIdentifier,
) -> Result<Self, Error>
where
R: AsyncRead + Unpin + Send;
}
/// Reading half of a Thrift protocol: decodes message headers, structs,
/// fields, scalars and containers from an underlying [`Buf`].
pub trait TInputProtocol {
    /// Buffer type the protocol reads from; exposed through `buf_mut`.
    type Buf: Buf;

    /// Reads a message header.
    fn read_message_begin(&mut self) -> Result<TMessageIdentifier, Error>;
    /// Finishes reading a message.
    fn read_message_end(&mut self) -> Result<(), Error>;
    /// Begins reading a struct. The identifier is optional — presumably `None`
    /// when the wire format does not carry struct names (confirm per implementation).
    fn read_struct_begin(&mut self) -> Result<Option<TStructIdentifier>, Error>;
    /// Finishes reading a struct.
    fn read_struct_end(&mut self) -> Result<(), Error>;
    /// Reads a field header; a `field_type` of [`TType::Stop`] ends the enclosing struct.
    fn read_field_begin(&mut self) -> Result<TFieldIdentifier, Error>;
    /// Finishes reading a field.
    fn read_field_end(&mut self) -> Result<(), Error>;
    /// Reads a boolean.
    fn read_bool(&mut self) -> Result<bool, Error>;
    /// Reads a byte string as an owned vector.
    fn read_bytes(&mut self) -> Result<Vec<u8>, Error>;
    /// Reads a signed 8-bit integer.
    fn read_i8(&mut self) -> Result<i8, Error>;
    /// Reads a signed 16-bit integer.
    fn read_i16(&mut self) -> Result<i16, Error>;
    /// Reads a signed 32-bit integer.
    fn read_i32(&mut self) -> Result<i32, Error>;
    /// Reads a signed 64-bit integer.
    fn read_i64(&mut self) -> Result<i64, Error>;
    /// Reads a 64-bit float.
    fn read_double(&mut self) -> Result<f64, Error>;
    /// Reads a string.
    fn read_string(&mut self) -> Result<String, Error>;
    /// Reads a list header (element type and element count).
    fn read_list_begin(&mut self) -> Result<TListIdentifier, Error>;
    /// Finishes reading a list.
    fn read_list_end(&mut self) -> Result<(), Error>;
    /// Reads a set header (element type and element count).
    fn read_set_begin(&mut self) -> Result<TSetIdentifier, Error>;
    /// Finishes reading a set.
    fn read_set_end(&mut self) -> Result<(), Error>;
    /// Reads a map header (key type, value type and entry count).
    fn read_map_begin(&mut self) -> Result<TMapIdentifier, Error>;
    /// Finishes reading a map.
    fn read_map_end(&mut self) -> Result<(), Error>;

    /// Reads and discards one value of `field_type`, allowing at most
    /// `MAXIMUM_SKIP_DEPTH` levels of nesting.
    ///
    /// # Errors
    ///
    /// Same as `skip_till_depth`.
    fn skip(&mut self, field_type: TType) -> Result<(), Error> {
        self.skip_till_depth(field_type, MAXIMUM_SKIP_DEPTH)
    }

    /// Reads and discards one value of `field_type`, recursing into structs,
    /// lists, sets and maps with a budget of `depth` nesting levels.
    ///
    /// # Errors
    ///
    /// * `ProtocolErrorKind::DepthLimit` when the nesting budget is exhausted
    ///   (including when `depth` is passed in as zero or negative).
    /// * `ProtocolErrorKind::InvalidData` when `field_type` cannot be skipped
    ///   (`Stop`, `Void`, `Utf8`, `Utf16`).
    /// * Any error returned by the underlying `read_*` calls.
    fn skip_till_depth(&mut self, field_type: TType, depth: i8) -> Result<(), Error> {
        // `<= 0` rather than `== 0`: a negative budget from a caller would
        // otherwise keep recursing until the `i8` decrement overflows.
        if depth <= 0 {
            return Err(new_protocol_error(
                ProtocolErrorKind::DepthLimit,
                format!("cannot parse past {:?}", field_type),
            ));
        }
        // Cannot underflow: the guard above ensures `depth >= 1`.
        let next_depth = depth - 1;

        match field_type {
            TType::Bool => self.read_bool().map(|_| ()),
            TType::I08 => self.read_i8().map(|_| ()),
            TType::I16 => self.read_i16().map(|_| ()),
            TType::I32 => self.read_i32().map(|_| ()),
            TType::I64 => self.read_i64().map(|_| ()),
            TType::Double => self.read_double().map(|_| ()),
            TType::String => self.read_string().map(|_| ()),
            TType::Struct => {
                self.read_struct_begin()?;
                loop {
                    let field_ident = self.read_field_begin()?;
                    if field_ident.field_type == TType::Stop {
                        break;
                    }
                    self.skip_till_depth(field_ident.field_type, next_depth)?;
                    // Pair every `read_field_begin` with `read_field_end` (as the
                    // reference Thrift skip does) so protocols whose
                    // `read_field_end` is not a no-op stay in sync.
                    self.read_field_end()?;
                }
                self.read_struct_end()
            }
            TType::List => {
                let list_ident = self.read_list_begin()?;
                for _ in 0..list_ident.size {
                    self.skip_till_depth(list_ident.element_type, next_depth)?;
                }
                self.read_list_end()
            }
            TType::Set => {
                let set_ident = self.read_set_begin()?;
                for _ in 0..set_ident.size {
                    self.skip_till_depth(set_ident.element_type, next_depth)?;
                }
                self.read_set_end()
            }
            TType::Map => {
                let map_ident = self.read_map_begin()?;
                for _ in 0..map_ident.size {
                    self.skip_till_depth(map_ident.key_type, next_depth)?;
                    self.skip_till_depth(map_ident.value_type, next_depth)?;
                }
                self.read_map_end()
            }
            // An unskippable type is malformed input, not a depth problem.
            unskippable => Err(new_protocol_error(
                ProtocolErrorKind::InvalidData,
                format!("cannot skip field type {:?}", unskippable),
            )),
        }
    }

    /// Reads a single unsigned byte.
    fn read_byte(&mut self) -> Result<u8, Error>;
    /// Returns a mutable reference to the underlying read buffer.
    fn buf_mut(&mut self) -> &mut Self::Buf;
}
/// Size-computation counterpart of the `write_*` methods of
/// [`TOutputProtocol`], which requires this trait as a supertrait.
///
/// NOTE(review): each `*_len` method presumably returns the number of bytes
/// the matching `TOutputProtocol::write_*` call would emit for the same
/// argument, without writing anything — confirm against implementations.
/// [`Size::size`] takes a `&T: TLengthProtocol` to measure values.
pub trait TLengthProtocol {
fn write_message_begin_len(&self, identifier: &TMessageIdentifier) -> usize;
fn write_message_end_len(&self) -> usize;
fn write_struct_begin_len(&self, identifier: &TStructIdentifier) -> usize;
fn write_struct_end_len(&self) -> usize;
fn write_field_begin_len(&self, identifier: &TFieldIdentifier) -> usize;
fn write_field_end_len(&self) -> usize;
fn write_field_stop_len(&self) -> usize;
fn write_bool_len(&self, b: bool) -> usize;
fn write_bytes_len(&self, b: &[u8]) -> usize;
fn write_byte_len(&self, b: u8) -> usize;
fn write_i8_len(&self, i: i8) -> usize;
fn write_i16_len(&self, i: i16) -> usize;
fn write_i32_len(&self, i: i32) -> usize;
fn write_i64_len(&self, i: i64) -> usize;
fn write_double_len(&self, d: f64) -> usize;
fn write_string_len(&self, s: &str) -> usize;
fn write_list_begin_len(&self, identifier: &TListIdentifier) -> usize;
fn write_list_end_len(&self) -> usize;
fn write_set_begin_len(&self, identifier: &TSetIdentifier) -> usize;
fn write_set_end_len(&self) -> usize;
fn write_map_begin_len(&self, identifier: &TMapIdentifier) -> usize;
fn write_map_end_len(&self) -> usize;
}
/// Writing half of a Thrift protocol: encodes message headers, structs,
/// fields, scalars and containers into an underlying [`BufMut`].
///
/// Requires [`TLengthProtocol`], so the same protocol value can also be used
/// to compute encoded sizes.
pub trait TOutputProtocol: TLengthProtocol {
/// Buffer type the protocol writes into; exposed through `buf_mut`.
type Buf: BufMut;
fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> Result<(), Error>;
fn write_message_end(&mut self) -> Result<(), Error>;
fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> Result<(), Error>;
fn write_struct_end(&mut self) -> Result<(), Error>;
fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> Result<(), Error>;
fn write_field_end(&mut self) -> Result<(), Error>;
/// Writes the end-of-fields marker — presumably what `TInputProtocol::read_field_begin`
/// later reports as a `TType::Stop` field.
fn write_field_stop(&mut self) -> Result<(), Error>;
fn write_bool(&mut self, b: bool) -> Result<(), Error>;
fn write_bytes(&mut self, b: &[u8]) -> Result<(), Error>;
fn write_byte(&mut self, b: u8) -> Result<(), Error>;
fn write_i8(&mut self, i: i8) -> Result<(), Error>;
fn write_i16(&mut self, i: i16) -> Result<(), Error>;
fn write_i32(&mut self, i: i32) -> Result<(), Error>;
fn write_i64(&mut self, i: i64) -> Result<(), Error>;
fn write_double(&mut self, d: f64) -> Result<(), Error>;
fn write_string(&mut self, s: &str) -> Result<(), Error>;
fn write_list_begin(&mut self, identifier: &TListIdentifier) -> Result<(), Error>;
fn write_list_end(&mut self) -> Result<(), Error>;
fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> Result<(), Error>;
fn write_set_end(&mut self) -> Result<(), Error>;
fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> Result<(), Error>;
fn write_map_end(&mut self) -> Result<(), Error>;
/// Flushes pending output; exact semantics are implementation-defined.
fn flush(&mut self) -> Result<(), Error>;
/// Hints that about `size` more bytes will be written — presumably reserves
/// capacity in the output buffer (confirm per implementation).
fn reserve(&mut self, size: usize);
/// Returns a mutable reference to the underlying output buffer.
fn buf_mut(&mut self) -> &mut Self::Buf;
}
/// Identifies a Thrift struct by its compile-time name.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TStructIdentifier {
    /// The struct's name.
    pub name: &'static str,
}

impl TStructIdentifier {
    /// Creates an identifier for the struct called `name`.
    pub fn new(name: &'static str) -> TStructIdentifier {
        Self { name }
    }
}
/// Thrift wire type codes; each discriminant is the byte written on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum TType {
    Stop = 0,
    Void = 1,
    Bool = 2,
    I08 = 3,
    Double = 4,
    I16 = 6,
    I32 = 8,
    I64 = 10,
    String = 11,
    Struct = 12,
    Map = 13,
    Set = 14,
    List = 15,
    Utf8 = 16,
    Utf16 = 17,
}

impl From<TType> for u8 {
    /// Returns the wire code of `value` (its `repr(u8)` discriminant).
    fn from(value: TType) -> u8 {
        value as Self
    }
}
impl TryFrom<u8> for TType {
    type Error = Error;

    /// Decodes a wire byte into a [`TType`]; unknown codes produce an
    /// `InvalidData` protocol error.
    #[inline]
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let ttype = match value {
            0 => Self::Stop,
            1 => Self::Void,
            2 => Self::Bool,
            3 => Self::I08,
            4 => Self::Double,
            6 => Self::I16,
            8 => Self::I32,
            10 => Self::I64,
            11 => Self::String,
            12 => Self::Struct,
            13 => Self::Map,
            14 => Self::Set,
            15 => Self::List,
            16 => Self::Utf8,
            17 => Self::Utf16,
            _ => {
                return Err(new_protocol_error(
                    ProtocolErrorKind::InvalidData,
                    format!("invalid ttype {}", value),
                ));
            }
        };
        Ok(ttype)
    }
}
/// Kind of a Thrift message; each discriminant is the byte used on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum TMessageType {
/// Request message.
Call = 1,
/// Response to a `Call`.
Reply = 2,
/// Error response.
Exception = 3,
/// Request for which, presumably (standard Thrift oneway semantics), no reply is sent.
OneWay = 4,
}
impl TryFrom<u8> for TMessageType {
    type Error = Error;

    /// Decodes a message-type byte; anything outside `1..=4` produces an
    /// `InvalidData` protocol error.
    #[inline]
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let message_type = match value {
            1 => Self::Call,
            2 => Self::Reply,
            3 => Self::Exception,
            4 => Self::OneWay,
            _ => {
                return Err(new_protocol_error(
                    ProtocolErrorKind::InvalidData,
                    format!("invalid tmessage type {}", value),
                ));
            }
        };
        Ok(message_type)
    }
}
impl From<TMessageType> for u8 {
    /// Returns the wire byte of `value` (its `repr(u8)` discriminant).
    fn from(value: TMessageType) -> u8 {
        value as Self
    }
}
/// Header of a Thrift message: method name, message kind and sequence number.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TMessageIdentifier {
    /// Name carried in the message header.
    pub name: smol_str::SmolStr,
    /// Kind of message (call, reply, exception, oneway).
    pub message_type: TMessageType,
    /// Sequence number carried in the header.
    pub sequence_number: i32,
}

impl TMessageIdentifier {
    /// Builds a message header from its three components.
    pub fn new(
        name: smol_str::SmolStr,
        message_type: TMessageType,
        sequence_number: i32,
    ) -> TMessageIdentifier {
        Self {
            name,
            message_type,
            sequence_number,
        }
    }
}
/// Header of a Thrift list: element type and element count.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TListIdentifier {
    /// Wire type shared by every element.
    pub element_type: TType,
    /// Number of elements that follow the header.
    pub size: usize,
}

impl TListIdentifier {
    /// Creates a list header for `size` elements of `element_type`.
    pub fn new(element_type: TType, size: usize) -> TListIdentifier {
        Self { element_type, size }
    }
}
/// Header of a Thrift set: element type and element count.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TSetIdentifier {
    /// Wire type shared by every element.
    pub element_type: TType,
    /// Number of elements that follow the header.
    pub size: usize,
}

impl TSetIdentifier {
    /// Creates a set header for `size` elements of `element_type`.
    pub fn new(element_type: TType, size: usize) -> TSetIdentifier {
        Self { element_type, size }
    }
}
/// Header of a struct field: optional name, wire type and optional id.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TFieldIdentifier {
    /// Field name, when known.
    pub name: Option<&'static str>,
    /// Wire type of the field's value; `TType::Stop` marks the end of a struct.
    pub field_type: TType,
    /// Field id, when known.
    pub id: Option<i16>,
}

impl TFieldIdentifier {
    /// Creates a field header. `name` and `id` accept either a bare value or
    /// an `Option` of it.
    pub fn new<N, I>(name: N, field_type: TType, id: I) -> TFieldIdentifier
    where
        N: Into<Option<&'static str>>,
        I: Into<Option<i16>>,
    {
        let name = name.into();
        let id = id.into();
        Self {
            name,
            field_type,
            id,
        }
    }
}
/// Header of a Thrift map: key type, value type and entry count.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TMapIdentifier {
    /// Wire type shared by every key.
    pub key_type: TType,
    /// Wire type shared by every value.
    pub value_type: TType,
    /// Number of key/value pairs that follow the header.
    pub size: usize,
}

impl TMapIdentifier {
    /// Creates a map header for `size` entries; key and value types accept
    /// anything convertible into [`TType`].
    pub fn new<K, V>(key_type: K, value_type: V, size: usize) -> TMapIdentifier
    where
        K: Into<TType>,
        V: Into<TType>,
    {
        let key_type: TType = key_type.into();
        let value_type: TType = value_type.into();
        Self {
            key_type,
            value_type,
            size,
        }
    }
}
/// Types whose encoded size can be computed before encoding, using the
/// `*_len` methods of a [`TLengthProtocol`].
pub trait Size {
/// Returns the encoded size of `self` as measured by `protocol` (presumably in bytes).
fn size<T: TLengthProtocol>(&self, protocol: &T) -> usize;
}
/// Lets a shared `Arc<M>` be used anywhere an [`EntryMessage`] is expected by
/// delegating to the wrapped value.
///
/// The generic parameter is named `M` rather than `Message` so it does not
/// shadow the [`Message`] trait inside this impl.
#[async_trait::async_trait]
impl<M> EntryMessage for Arc<M>
where
    M: EntryMessage + Sync,
{
    fn encode<T: TOutputProtocol>(&self, protocol: &mut T) -> Result<(), Error> {
        // `&Arc<M>` deref-coerces to `&M` at the call.
        <M as EntryMessage>::encode(self, protocol)
    }

    fn decode<T: TInputProtocol>(
        protocol: &mut T,
        msg_ident: &TMessageIdentifier,
    ) -> Result<Self, Error> {
        let inner = <M as EntryMessage>::decode(protocol, msg_ident)?;
        Ok(Arc::new(inner))
    }

    async fn decode_async<R>(
        protocol: &mut TAsyncBinaryProtocol<R>,
        msg_ident: &TMessageIdentifier,
    ) -> Result<Self, Error>
    where
        R: AsyncRead + Unpin + Send,
    {
        let inner = <M as EntryMessage>::decode_async(protocol, msg_ident).await?;
        Ok(Arc::new(inner))
    }
}
/// Measures a shared `Arc<M>` exactly like the value it wraps.
///
/// The generic parameter is named `M` so it does not shadow the `Message` trait.
impl<M> Size for Arc<M>
where
    M: Size + Sync,
{
    fn size<T: TLengthProtocol>(&self, protocol: &T) -> usize {
        // `&Arc<M>` deref-coerces to `&M` at the call.
        <M as Size>::size(self, protocol)
    }
}