use std::convert::{From, TryFrom};
use std::fmt;
use std::fmt::{Display, Formatter};
use crate::transport::{TReadTransport, TWriteTransport};
use crate::{ProtocolError, ProtocolErrorKind, TConfiguration};
#[cfg(test)]
/// Test helper: assert that the bytes written through a protocol's transport
/// equal an expected buffer.
///
/// `$protocol` must have a `transport` field exposing `write_bytes()`; its
/// result is compared against `&$expected`.
macro_rules! assert_eq_written_bytes {
    ($protocol:ident, $expected:ident) => {{
        let written = $protocol.transport.write_bytes();
        assert_eq!(written, &$expected);
    }};
}
#[cfg(test)]
/// Test helper: invoke `copy_write_buffer_to_read_buffer()` on a protocol's
/// `transport` field, so that bytes just written can be read back.
macro_rules! copy_write_buffer_to_read_buffer {
    ($protocol:ident) => {{
        $protocol.transport.copy_write_buffer_to_read_buffer();
    }};
}
#[cfg(test)]
/// Test helper: hand `$payload` to `set_readable_bytes` on a protocol's
/// `transport` field.
///
/// Expands to a single statement, not a block expression.
macro_rules! set_readable_bytes {
    ($protocol:ident, $payload:expr) => {
        $protocol.transport.set_readable_bytes($payload);
    };
}
mod binary;
mod compact;
mod multiplexed;
mod stored;
pub use self::binary::{
TBinaryInputProtocol, TBinaryInputProtocolFactory, TBinaryOutputProtocol,
TBinaryOutputProtocolFactory,
};
pub use self::compact::{
TCompactInputProtocol, TCompactInputProtocolFactory, TCompactOutputProtocol,
TCompactOutputProtocolFactory,
};
pub use self::multiplexed::TMultiplexedOutputProtocol;
pub use self::stored::TStoredInputProtocol;
/// A type that can read itself from a [`TInputProtocol`] and write itself to
/// a [`TOutputProtocol`].
pub trait TSerializable: Sized {
    /// Read an instance of `Self` from `i_prot`.
    fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> crate::Result<Self>;
    /// Write `self` to `o_prot`.
    fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> crate::Result<()>;
}
/// Depth budget that [`TInputProtocol::skip`] passes to `skip_till_depth`.
///
/// Each nested struct/list/set/map level consumes one unit. Exhausting the
/// budget yields a `ProtocolErrorKind::DepthLimit` error instead of
/// unbounded recursion.
const MAXIMUM_SKIP_DEPTH: i8 = 64;
/// Reads Thrift messages, structs, fields, primitives and containers from an
/// underlying encoding.
///
/// Each `read_*_begin` call is paired with the matching `read_*_end` call, as
/// `skip_till_depth` below does.
pub trait TInputProtocol {
    /// Read the beginning of a Thrift message.
    fn read_message_begin(&mut self) -> crate::Result<TMessageIdentifier>;
    /// Read the end of a Thrift message.
    fn read_message_end(&mut self) -> crate::Result<()>;
    /// Read the beginning of a struct. The identifier is optional and may be `None`.
    fn read_struct_begin(&mut self) -> crate::Result<Option<TStructIdentifier>>;
    /// Read the end of a struct.
    fn read_struct_end(&mut self) -> crate::Result<()>;
    /// Read a field header. A `field_type` of `TType::Stop` marks the end of
    /// the enclosing struct's fields.
    fn read_field_begin(&mut self) -> crate::Result<TFieldIdentifier>;
    /// Read the end of a field.
    fn read_field_end(&mut self) -> crate::Result<()>;
    /// Read a `bool`.
    fn read_bool(&mut self) -> crate::Result<bool>;
    /// Read a binary value as raw bytes.
    fn read_bytes(&mut self) -> crate::Result<Vec<u8>>;
    /// Read an `i8`.
    fn read_i8(&mut self) -> crate::Result<i8>;
    /// Read an `i16`.
    fn read_i16(&mut self) -> crate::Result<i16>;
    /// Read an `i32`.
    fn read_i32(&mut self) -> crate::Result<i32>;
    /// Read an `i64`.
    fn read_i64(&mut self) -> crate::Result<i64>;
    /// Read an `f64`.
    fn read_double(&mut self) -> crate::Result<f64>;
    /// Read a UUID.
    fn read_uuid(&mut self) -> crate::Result<uuid::Uuid>;
    /// Read a `String`.
    fn read_string(&mut self) -> crate::Result<String>;
    /// Read a list header.
    fn read_list_begin(&mut self) -> crate::Result<TListIdentifier>;
    /// Read the end of a list.
    fn read_list_end(&mut self) -> crate::Result<()>;
    /// Read a set header.
    fn read_set_begin(&mut self) -> crate::Result<TSetIdentifier>;
    /// Read the end of a set.
    fn read_set_end(&mut self) -> crate::Result<()>;
    /// Read a map header.
    fn read_map_begin(&mut self) -> crate::Result<TMapIdentifier>;
    /// Read the end of a map.
    fn read_map_end(&mut self) -> crate::Result<()>;

    /// Skip over a value of type `field_type`, descending at most
    /// `MAXIMUM_SKIP_DEPTH` levels into nested structs and containers.
    fn skip(&mut self, field_type: TType) -> crate::Result<()> {
        self.skip_till_depth(field_type, MAXIMUM_SKIP_DEPTH)
    }

    /// Skip over a value of type `field_type`, allowing at most `depth`
    /// levels of nesting (the value itself counts as one level).
    ///
    /// # Errors
    ///
    /// * `ProtocolErrorKind::DepthLimit` if `depth` is zero or negative.
    /// * `ProtocolErrorKind::InvalidData` if a non-empty map header lacks its
    ///   key or value type.
    /// * `ProtocolErrorKind::Unknown` for `Stop`, `Void` and `Utf7`, which
    ///   cannot be skipped.
    /// * Any error returned by the underlying `read_*` calls.
    fn skip_till_depth(&mut self, field_type: TType, depth: i8) -> crate::Result<()> {
        // `<= 0` (not `== 0`): a caller-supplied negative depth must fail
        // immediately rather than recurse until `depth - 1` overflows.
        if depth <= 0 {
            return Err(crate::Error::Protocol(ProtocolError {
                kind: ProtocolErrorKind::DepthLimit,
                message: format!("cannot parse past {:?}", field_type),
            }));
        }

        match field_type {
            TType::Bool => self.read_bool().map(|_| ()),
            TType::I08 => self.read_i8().map(|_| ()),
            TType::I16 => self.read_i16().map(|_| ()),
            TType::I32 => self.read_i32().map(|_| ()),
            TType::I64 => self.read_i64().map(|_| ()),
            TType::Double => self.read_double().map(|_| ()),
            // Read as raw bytes so binary (non-UTF-8) payloads can be skipped.
            TType::String => self.read_bytes().map(|_| ()),
            TType::Uuid => self.read_uuid().map(|_| ()),
            TType::Struct => {
                self.read_struct_begin()?;
                loop {
                    let field_ident = self.read_field_begin()?;
                    if field_ident.field_type == TType::Stop {
                        break;
                    }
                    self.skip_till_depth(field_ident.field_type, depth - 1)?;
                    // Balance `read_field_begin` so protocols whose field end
                    // is not a no-op stay in sync.
                    self.read_field_end()?;
                }
                self.read_struct_end()
            }
            TType::List => {
                let list_ident = self.read_list_begin()?;
                for _ in 0..list_ident.size {
                    self.skip_till_depth(list_ident.element_type, depth - 1)?;
                }
                self.read_list_end()
            }
            TType::Set => {
                let set_ident = self.read_set_begin()?;
                for _ in 0..set_ident.size {
                    self.skip_till_depth(set_ident.element_type, depth - 1)?;
                }
                self.read_set_end()
            }
            TType::Map => {
                let map_ident = self.read_map_begin()?;
                // Key/value types are only required when there are entries to
                // skip. A malformed header is reported as an error rather than
                // a panic.
                if map_ident.size > 0 {
                    let (key_type, val_type) = match (map_ident.key_type, map_ident.value_type) {
                        (Some(key_type), Some(val_type)) => (key_type, val_type),
                        _ => {
                            return Err(crate::Error::Protocol(ProtocolError {
                                kind: ProtocolErrorKind::InvalidData,
                                message: format!(
                                    "non-zero sized map should contain key and value types: {:?}",
                                    map_ident
                                ),
                            }))
                        }
                    };
                    for _ in 0..map_ident.size {
                        self.skip_till_depth(key_type, depth - 1)?;
                        self.skip_till_depth(val_type, depth - 1)?;
                    }
                }
                self.read_map_end()
            }
            TType::Stop | TType::Void | TType::Utf7 => Err(crate::Error::Protocol(ProtocolError {
                kind: ProtocolErrorKind::Unknown,
                message: format!("cannot skip field type {:?}", &field_type),
            })),
        }
    }

    /// Read a single raw byte.
    fn read_byte(&mut self) -> crate::Result<u8>;

    /// Minimum number of bytes a value of `field_type` occupies when
    /// serialized.
    ///
    /// The default uses the compact protocol's sizes.
    fn min_serialized_size(&self, field_type: TType) -> usize {
        self::compact::compact_protocol_min_serialized_size(field_type)
    }
}
/// Writes Thrift messages, structs, fields, primitives and containers to an
/// underlying encoding.
///
/// Each `write_*_begin` call is expected to be paired with the matching
/// `write_*_end` call.
pub trait TOutputProtocol {
    /// Write the beginning of a Thrift message.
    fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> crate::Result<()>;
    /// Write the end of a Thrift message.
    fn write_message_end(&mut self) -> crate::Result<()>;
    /// Write the beginning of a struct.
    fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> crate::Result<()>;
    /// Write the end of a struct.
    fn write_struct_end(&mut self) -> crate::Result<()>;
    /// Write a field header.
    fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> crate::Result<()>;
    /// Write the end of a field.
    fn write_field_end(&mut self) -> crate::Result<()>;
    /// Write a field-stop marker.
    fn write_field_stop(&mut self) -> crate::Result<()>;
    /// Write a `bool`.
    fn write_bool(&mut self, b: bool) -> crate::Result<()>;
    /// Write a binary value from raw bytes.
    fn write_bytes(&mut self, b: &[u8]) -> crate::Result<()>;
    /// Write an `i8`.
    fn write_i8(&mut self, i: i8) -> crate::Result<()>;
    /// Write an `i16`.
    fn write_i16(&mut self, i: i16) -> crate::Result<()>;
    /// Write an `i32`.
    fn write_i32(&mut self, i: i32) -> crate::Result<()>;
    /// Write an `i64`.
    fn write_i64(&mut self, i: i64) -> crate::Result<()>;
    /// Write an `f64`.
    fn write_double(&mut self, d: f64) -> crate::Result<()>;
    /// Write a UUID.
    fn write_uuid(&mut self, uuid: &uuid::Uuid) -> crate::Result<()>;
    /// Write a string.
    fn write_string(&mut self, s: &str) -> crate::Result<()>;
    /// Write a list header.
    fn write_list_begin(&mut self, identifier: &TListIdentifier) -> crate::Result<()>;
    /// Write the end of a list.
    fn write_list_end(&mut self) -> crate::Result<()>;
    /// Write a set header.
    fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> crate::Result<()>;
    /// Write the end of a set.
    fn write_set_end(&mut self) -> crate::Result<()>;
    /// Write a map header.
    fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> crate::Result<()>;
    /// Write the end of a map.
    fn write_map_end(&mut self) -> crate::Result<()>;
    /// Flush pending output.
    fn flush(&mut self) -> crate::Result<()>;
    /// Write a single raw byte.
    fn write_byte(&mut self, b: u8) -> crate::Result<()>;
}
impl<P> TInputProtocol for Box<P>
where
P: TInputProtocol + ?Sized,
{
fn read_message_begin(&mut self) -> crate::Result<TMessageIdentifier> {
(**self).read_message_begin()
}
fn read_message_end(&mut self) -> crate::Result<()> {
(**self).read_message_end()
}
fn read_struct_begin(&mut self) -> crate::Result<Option<TStructIdentifier>> {
(**self).read_struct_begin()
}
fn read_struct_end(&mut self) -> crate::Result<()> {
(**self).read_struct_end()
}
fn read_field_begin(&mut self) -> crate::Result<TFieldIdentifier> {
(**self).read_field_begin()
}
fn read_field_end(&mut self) -> crate::Result<()> {
(**self).read_field_end()
}
fn read_bool(&mut self) -> crate::Result<bool> {
(**self).read_bool()
}
fn read_bytes(&mut self) -> crate::Result<Vec<u8>> {
(**self).read_bytes()
}
fn read_i8(&mut self) -> crate::Result<i8> {
(**self).read_i8()
}
fn read_i16(&mut self) -> crate::Result<i16> {
(**self).read_i16()
}
fn read_i32(&mut self) -> crate::Result<i32> {
(**self).read_i32()
}
fn read_i64(&mut self) -> crate::Result<i64> {
(**self).read_i64()
}
fn read_double(&mut self) -> crate::Result<f64> {
(**self).read_double()
}
fn read_uuid(&mut self) -> crate::Result<uuid::Uuid> {
(**self).read_uuid()
}
fn read_string(&mut self) -> crate::Result<String> {
(**self).read_string()
}
fn read_list_begin(&mut self) -> crate::Result<TListIdentifier> {
(**self).read_list_begin()
}
fn read_list_end(&mut self) -> crate::Result<()> {
(**self).read_list_end()
}
fn read_set_begin(&mut self) -> crate::Result<TSetIdentifier> {
(**self).read_set_begin()
}
fn read_set_end(&mut self) -> crate::Result<()> {
(**self).read_set_end()
}
fn read_map_begin(&mut self) -> crate::Result<TMapIdentifier> {
(**self).read_map_begin()
}
fn read_map_end(&mut self) -> crate::Result<()> {
(**self).read_map_end()
}
fn read_byte(&mut self) -> crate::Result<u8> {
(**self).read_byte()
}
fn min_serialized_size(&self, field_type: TType) -> usize {
(**self).min_serialized_size(field_type)
}
}
/// Lets a boxed protocol (including `Box<dyn TOutputProtocol>`) be used
/// wherever a `TOutputProtocol` is expected.
///
/// Every method is a fully-qualified call on `P` with the unboxed value.
impl<P> TOutputProtocol for Box<P>
where
    P: TOutputProtocol + ?Sized,
{
    fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> crate::Result<()> {
        P::write_message_begin(&mut **self, identifier)
    }

    fn write_message_end(&mut self) -> crate::Result<()> {
        P::write_message_end(&mut **self)
    }

    fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> crate::Result<()> {
        P::write_struct_begin(&mut **self, identifier)
    }

    fn write_struct_end(&mut self) -> crate::Result<()> {
        P::write_struct_end(&mut **self)
    }

    fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> crate::Result<()> {
        P::write_field_begin(&mut **self, identifier)
    }

    fn write_field_end(&mut self) -> crate::Result<()> {
        P::write_field_end(&mut **self)
    }

    fn write_field_stop(&mut self) -> crate::Result<()> {
        P::write_field_stop(&mut **self)
    }

    fn write_bool(&mut self, b: bool) -> crate::Result<()> {
        P::write_bool(&mut **self, b)
    }

    fn write_bytes(&mut self, b: &[u8]) -> crate::Result<()> {
        P::write_bytes(&mut **self, b)
    }

    fn write_i8(&mut self, i: i8) -> crate::Result<()> {
        P::write_i8(&mut **self, i)
    }

    fn write_i16(&mut self, i: i16) -> crate::Result<()> {
        P::write_i16(&mut **self, i)
    }

    fn write_i32(&mut self, i: i32) -> crate::Result<()> {
        P::write_i32(&mut **self, i)
    }

    fn write_i64(&mut self, i: i64) -> crate::Result<()> {
        P::write_i64(&mut **self, i)
    }

    fn write_double(&mut self, d: f64) -> crate::Result<()> {
        P::write_double(&mut **self, d)
    }

    fn write_uuid(&mut self, uuid: &uuid::Uuid) -> crate::Result<()> {
        P::write_uuid(&mut **self, uuid)
    }

    fn write_string(&mut self, s: &str) -> crate::Result<()> {
        P::write_string(&mut **self, s)
    }

    fn write_list_begin(&mut self, identifier: &TListIdentifier) -> crate::Result<()> {
        P::write_list_begin(&mut **self, identifier)
    }

    fn write_list_end(&mut self) -> crate::Result<()> {
        P::write_list_end(&mut **self)
    }

    fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> crate::Result<()> {
        P::write_set_begin(&mut **self, identifier)
    }

    fn write_set_end(&mut self) -> crate::Result<()> {
        P::write_set_end(&mut **self)
    }

    fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> crate::Result<()> {
        P::write_map_begin(&mut **self, identifier)
    }

    fn write_map_end(&mut self) -> crate::Result<()> {
        P::write_map_end(&mut **self)
    }

    fn flush(&mut self) -> crate::Result<()> {
        P::flush(&mut **self)
    }

    fn write_byte(&mut self, b: u8) -> crate::Result<()> {
        P::write_byte(&mut **self, b)
    }
}
/// Creates boxed [`TInputProtocol`] instances that read from a given transport.
pub trait TInputProtocolFactory {
    /// Create a `TInputProtocol` that reads from `transport`.
    fn create(&self, transport: Box<dyn TReadTransport + Send>) -> Box<dyn TInputProtocol + Send>;
}
/// Lets a boxed factory (including `Box<dyn TInputProtocolFactory>`) be used
/// as a factory itself, by delegating to the boxed value.
impl<T> TInputProtocolFactory for Box<T>
where
    T: TInputProtocolFactory + ?Sized,
{
    fn create(&self, transport: Box<dyn TReadTransport + Send>) -> Box<dyn TInputProtocol + Send> {
        T::create(&**self, transport)
    }
}
/// Creates boxed [`TOutputProtocol`] instances that write to a given transport.
pub trait TOutputProtocolFactory {
    /// Create a `TOutputProtocol` that writes to `transport`.
    fn create(&self, transport: Box<dyn TWriteTransport + Send>)
    -> Box<dyn TOutputProtocol + Send>;
}
/// Lets a boxed factory (including `Box<dyn TOutputProtocolFactory>`) be used
/// as a factory itself, by delegating to the boxed value.
impl<T> TOutputProtocolFactory for Box<T>
where
    T: TOutputProtocolFactory + ?Sized,
{
    fn create(
        &self,
        transport: Box<dyn TWriteTransport + Send>,
    ) -> Box<dyn TOutputProtocol + Send> {
        T::create(&**self, transport)
    }
}
/// Header of a Thrift message: its name, type and sequence number.
///
/// Returned by `TInputProtocol::read_message_begin` and consumed by
/// `TOutputProtocol::write_message_begin`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TMessageIdentifier {
    /// Message name.
    pub name: String,
    /// Kind of message.
    pub message_type: TMessageType,
    /// Sequence number carried by the message.
    pub sequence_number: i32,
}

impl TMessageIdentifier {
    /// Build a message identifier; `name` accepts anything convertible into a `String`.
    pub fn new<S>(name: S, message_type: TMessageType, sequence_number: i32) -> TMessageIdentifier
    where
        S: Into<String>,
    {
        let name = name.into();
        Self {
            name,
            message_type,
            sequence_number,
        }
    }
}
/// Identifies a Thrift struct by name.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TStructIdentifier {
    /// Struct name.
    pub name: String,
}

impl TStructIdentifier {
    /// Build a struct identifier; `name` accepts anything convertible into a `String`.
    pub fn new<S>(name: S) -> TStructIdentifier
    where
        S: Into<String>,
    {
        let name = name.into();
        Self { name }
    }
}
/// Header of a struct field: optional name, type and optional numeric id.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TFieldIdentifier {
    /// Field name, if known.
    pub name: Option<String>,
    /// Field type.
    pub field_type: TType,
    /// Field id, if known.
    pub id: Option<i16>,
}

impl TFieldIdentifier {
    /// Build a field identifier.
    ///
    /// `name` and `id` may be passed either as bare values or as `Option`s;
    /// a present `name` is converted into a `String`.
    pub fn new<N, S, I>(name: N, field_type: TType, id: I) -> TFieldIdentifier
    where
        N: Into<Option<S>>,
        S: Into<String>,
        I: Into<Option<i16>>,
    {
        let name: Option<String> = name.into().map(Into::into);
        Self {
            name,
            field_type,
            id: id.into(),
        }
    }
}
/// Header of a Thrift list: element type and element count.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TListIdentifier {
    /// Type of every element in the list.
    pub element_type: TType,
    /// Number of elements.
    pub size: i32,
}

impl TListIdentifier {
    /// Build a list header for `size` elements of `element_type`.
    pub fn new(element_type: TType, size: i32) -> TListIdentifier {
        Self {
            element_type: element_type,
            size: size,
        }
    }
}
/// Header of a Thrift set: element type and element count.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TSetIdentifier {
    /// Type of every element in the set.
    pub element_type: TType,
    /// Number of elements.
    pub size: i32,
}

impl TSetIdentifier {
    /// Build a set header for `size` elements of `element_type`.
    pub fn new(element_type: TType, size: i32) -> TSetIdentifier {
        Self {
            element_type: element_type,
            size: size,
        }
    }
}
/// Header of a Thrift map: optional key/value types and entry count.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TMapIdentifier {
    /// Key type, if present.
    pub key_type: Option<TType>,
    /// Value type, if present.
    pub value_type: Option<TType>,
    /// Number of key/value entries.
    pub size: i32,
}

impl TMapIdentifier {
    /// Build a map header.
    ///
    /// `key_type` and `value_type` may be passed as bare `TType`s or as `Option`s.
    pub fn new<K, V>(key_type: K, value_type: V, size: i32) -> TMapIdentifier
    where
        K: Into<Option<TType>>,
        V: Into<Option<TType>>,
    {
        let key_type = key_type.into();
        let value_type = value_type.into();
        Self {
            key_type,
            value_type,
            size,
        }
    }
}
/// Kind of a Thrift message. The byte code of each variant is given by
/// `From<TMessageType> for u8`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TMessageType {
    /// Encoded as `0x01`.
    Call,
    /// Encoded as `0x02`.
    Reply,
    /// Encoded as `0x03`.
    Exception,
    /// Encoded as `0x04`.
    OneWay,
}

impl Display for TMessageType {
    /// Writes the variant name (`"Call"`, `"Reply"`, `"Exception"`, `"OneWay"`).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            TMessageType::Call => "Call",
            TMessageType::Reply => "Reply",
            TMessageType::Exception => "Exception",
            TMessageType::OneWay => "OneWay",
        })
    }
}

impl From<TMessageType> for u8 {
    /// Map a message type to its byte code (1 through 4).
    fn from(message_type: TMessageType) -> Self {
        match message_type {
            TMessageType::Call => 1,
            TMessageType::Reply => 2,
            TMessageType::Exception => 3,
            TMessageType::OneWay => 4,
        }
    }
}
impl TryFrom<u8> for TMessageType {
type Error = crate::Error;
fn try_from(b: u8) -> Result<Self, Self::Error> {
match b {
0x01 => Ok(TMessageType::Call),
0x02 => Ok(TMessageType::Reply),
0x03 => Ok(TMessageType::Exception),
0x04 => Ok(TMessageType::OneWay),
unkn => Err(crate::Error::Protocol(ProtocolError {
kind: ProtocolErrorKind::InvalidData,
message: format!("cannot convert {} to TMessageType", unkn),
})),
}
}
}
/// Type tag of a Thrift value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TType {
    /// Marks the end of a struct's fields.
    Stop,
    /// Void.
    Void,
    /// `bool`.
    Bool,
    /// `i8`.
    I08,
    /// `f64`.
    Double,
    /// `i16`.
    I16,
    /// `i32`.
    I32,
    /// `i64`.
    I64,
    /// String or binary value.
    String,
    /// UTF-7 tag (cannot be skipped by `TInputProtocol::skip`).
    Utf7,
    /// Struct.
    Struct,
    /// Map.
    Map,
    /// Set.
    Set,
    /// List.
    List,
    /// UUID.
    Uuid,
}

impl Display for TType {
    /// Writes a short name for the type: lower case except `STOP`, `UTF7` and `UUID`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let name = match self {
            TType::Stop => "STOP",
            TType::Void => "void",
            TType::Bool => "bool",
            TType::I08 => "i08",
            TType::Double => "double",
            TType::I16 => "i16",
            TType::I32 => "i32",
            TType::I64 => "i64",
            TType::String => "string",
            TType::Utf7 => "UTF7",
            TType::Struct => "struct",
            TType::Map => "map",
            TType::Set => "set",
            TType::List => "list",
            TType::Uuid => "UUID",
        };
        f.write_str(name)
    }
}
pub fn verify_expected_sequence_number(expected: i32, actual: i32) -> crate::Result<()> {
if expected == actual {
Ok(())
} else {
Err(crate::Error::Application(crate::ApplicationError {
kind: crate::ApplicationErrorKind::BadSequenceId,
message: format!("expected {} got {}", expected, actual),
}))
}
}
pub fn verify_expected_service_call(expected: &str, actual: &str) -> crate::Result<()> {
if expected == actual {
Ok(())
} else {
Err(crate::Error::Application(crate::ApplicationError {
kind: crate::ApplicationErrorKind::WrongMethodName,
message: format!("expected {} got {}", expected, actual),
}))
}
}
pub fn verify_expected_message_type(
expected: TMessageType,
actual: TMessageType,
) -> crate::Result<()> {
if expected == actual {
Ok(())
} else {
Err(crate::Error::Application(crate::ApplicationError {
kind: crate::ApplicationErrorKind::InvalidMessageType,
message: format!("expected {} got {}", expected, actual),
}))
}
}
pub fn verify_required_field_exists<T>(field_name: &str, field: &Option<T>) -> crate::Result<()> {
match *field {
Some(_) => Ok(()),
None => Err(crate::Error::Protocol(crate::ProtocolError {
kind: crate::ProtocolErrorKind::Unknown,
message: format!("missing required field {}", field_name),
})),
}
}
/// Validate a container size read from the wire against `config`'s limits.
///
/// The checks run in this order:
/// 1. `container_size` must not be negative.
/// 2. It must not exceed `config.max_container_size()`, if set.
/// 3. `container_size * element_size` must not overflow `usize`.
/// 4. That product must not exceed `config.max_message_size()`, if set.
///
/// # Errors
///
/// Returns a protocol error of kind `NegativeSize` for step 1, and of kind
/// `SizeLimit` for steps 2-4.
pub(crate) fn check_container_size(
    config: &TConfiguration,
    container_size: i32,
    element_size: usize,
) -> crate::Result<()> {
    // Reject negatives before the `as usize` cast, which would wrap them.
    if container_size < 0 {
        return Err(crate::Error::Protocol(ProtocolError::new(
            ProtocolErrorKind::NegativeSize,
            format!("Negative container size: {}", container_size),
        )));
    }
    let element_count = container_size as usize;

    match config.max_container_size() {
        Some(max_size) if element_count > max_size => {
            return Err(crate::Error::Protocol(ProtocolError::new(
                ProtocolErrorKind::SizeLimit,
                format!(
                    "Container size {} exceeds maximum allowed size of {}",
                    container_size, max_size
                ),
            )));
        }
        _ => {}
    }

    let min_bytes_needed = match element_count.checked_mul(element_size) {
        Some(bytes) => bytes,
        None => {
            return Err(crate::Error::Protocol(ProtocolError::new(
                ProtocolErrorKind::SizeLimit,
                format!(
                    "Container size {} with element size {} bytes would result in overflow",
                    container_size, element_size
                ),
            )))
        }
    };

    match config.max_message_size() {
        Some(max_message_size) if min_bytes_needed > max_message_size => {
            Err(crate::Error::Protocol(ProtocolError::new(
                ProtocolErrorKind::SizeLimit,
                format!(
                    "Container would require {} bytes, exceeding message size limit of {}",
                    min_bytes_needed, max_message_size
                ),
            )))
        }
        _ => Ok(()),
    }
}
pub fn field_id(field_ident: &TFieldIdentifier) -> crate::Result<i16> {
field_ident.id.ok_or_else(|| {
crate::Error::Protocol(crate::ProtocolError {
kind: crate::ProtocolErrorKind::Unknown,
message: format!("missing field id in {:?}", field_ident),
})
})
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;
    use crate::transport::{TReadTransport, TWriteTransport};

    // A concrete input protocol satisfies the generic `TInputProtocol` bound.
    #[test]
    fn must_create_usable_input_protocol_from_concrete_input_protocol() {
        let r: Box<dyn TReadTransport> = Box::new(Cursor::new([0, 1, 2]));
        let mut t = TCompactInputProtocol::new(r);
        takes_input_protocol(&mut t)
    }

    // A boxed trait object satisfies the bound via the `Box<P>` forwarding impl.
    #[test]
    fn must_create_usable_input_protocol_from_boxed_input() {
        let r: Box<dyn TReadTransport> = Box::new(Cursor::new([0, 1, 2]));
        let mut t: Box<dyn TInputProtocol> = Box::new(TCompactInputProtocol::new(r));
        takes_input_protocol(&mut t)
    }

    // A concrete output protocol satisfies the generic `TOutputProtocol` bound.
    #[test]
    fn must_create_usable_output_protocol_from_concrete_output_protocol() {
        let w: Box<dyn TWriteTransport> = Box::new(vec![0u8; 10]);
        let mut t = TCompactOutputProtocol::new(w);
        takes_output_protocol(&mut t)
    }

    // A boxed trait object satisfies the bound via the `Box<P>` forwarding impl.
    #[test]
    fn must_create_usable_output_protocol_from_boxed_output() {
        let w: Box<dyn TWriteTransport> = Box::new(vec![0u8; 10]);
        let mut t: Box<dyn TOutputProtocol> = Box::new(TCompactOutputProtocol::new(w));
        takes_output_protocol(&mut t)
    }

    /// Exercise a `TInputProtocol` through a generic bound by reading one byte.
    fn takes_input_protocol<R>(t: &mut R)
    where
        R: TInputProtocol,
    {
        t.read_byte().unwrap();
    }

    /// Exercise a `TOutputProtocol` through a generic bound by flushing it.
    fn takes_output_protocol<W>(t: &mut W)
    where
        W: TOutputProtocol,
    {
        t.flush().unwrap();
    }

    /// Hand-encode a binary-protocol struct body with two fields, all
    /// integers big-endian:
    /// * field id 1, type byte `0x0A` (read back as `TType::I64` by the
    ///   reader below), value 42;
    /// * field id 99, type byte `0x0B`, holding `payload` behind an `i32`
    ///   length prefix;
    ///
    /// followed by a `0x00` stop byte.
    fn build_struct_with_unknown_binary_field(payload: &[u8]) -> Vec<u8> {
        let mut buf = Vec::new();
        // Field 1: type byte, i16 id, i64 value.
        buf.push(0x0A);
        buf.extend_from_slice(&1_i16.to_be_bytes());
        buf.extend_from_slice(&42_i64.to_be_bytes());
        // Field 99: type byte, i16 id, i32 length, raw payload bytes.
        buf.push(0x0B);
        buf.extend_from_slice(&99_i16.to_be_bytes());
        buf.extend_from_slice(&(payload.len() as i32).to_be_bytes());
        buf.extend_from_slice(payload);
        // Stop byte ends the struct.
        buf.push(0x00);
        buf
    }

    /// Read a struct with `TBinaryInputProtocol`, returning the value of
    /// i64 field 1 and passing every other field to `skip`.
    fn read_struct_skipping_unknown(data: &[u8]) -> crate::Result<i64> {
        let cursor = Cursor::new(data.to_vec());
        let mut proto = TBinaryInputProtocol::new(cursor, true);
        proto.read_struct_begin()?;
        let mut known_value: Option<i64> = None;
        loop {
            let field = proto.read_field_begin()?;
            if field.field_type == TType::Stop {
                break;
            }
            match field.id {
                Some(1) if field.field_type == TType::I64 => {
                    known_value = Some(proto.read_i64()?);
                }
                _ => {
                    proto.skip(field.field_type)?;
                }
            }
            proto.read_field_end()?;
        }
        proto.read_struct_end()?;
        known_value.ok_or_else(|| {
            crate::Error::Protocol(crate::ProtocolError {
                kind: crate::ProtocolErrorKind::InvalidData,
                message: "missing known field".to_string(),
            })
        })
    }

    // Skipping must not require the skipped payload to be valid UTF-8.
    #[test]
    fn must_skip_binary_field_with_non_utf8_bytes() {
        let non_utf8: Vec<u8> = vec![
            0x04, 0xFF, 0xFE, 0x80, 0x90, 0xAB, 0xCD, 0xEF,
            0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE, 0xBA, 0xBE,
        ];
        // Precondition: the payload really is invalid UTF-8.
        assert!(String::from_utf8(non_utf8.clone()).is_err());
        let data = build_struct_with_unknown_binary_field(&non_utf8);
        let result = read_struct_skipping_unknown(&data);
        assert!(result.is_ok(), "skip() failed on non-UTF-8 binary: {:?}", result.err());
        assert_eq!(result.unwrap(), 42);
    }

    // A UTF-8 payload is skipped just as well.
    #[test]
    fn must_skip_valid_utf8_string_field() {
        let data = build_struct_with_unknown_binary_field(b"hello world");
        assert_eq!(read_struct_skipping_unknown(&data).unwrap(), 42);
    }

    // A zero-length payload is skipped without consuming the next field.
    #[test]
    fn must_skip_empty_binary_field() {
        let data = build_struct_with_unknown_binary_field(&[]);
        assert_eq!(read_struct_skipping_unknown(&data).unwrap(), 42);
    }
}