use std::{
fmt::Debug,
io::{Cursor, Read, Result, Write},
sync::Arc,
};
use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
use chrono::Duration;
use crate::{
sync::Mutex,
types::{constants, status_codes::StatusCode},
};
/// Result type returned by the encoding and decoding helpers in this module; failures are
/// reported as a [`StatusCode`].
pub type EncodingResult<T> = std::result::Result<T, StatusCode>;
/// Guard representing one level of depth taken from a shared [`DepthGauge`].
///
/// [`DepthLock::obtain`] increments the gauge's current depth when it hands out a guard, and
/// dropping the guard decrements it again.
#[derive(Debug)]
pub struct DepthLock {
/// The shared gauge this guard was obtained from.
depth_gauge: Arc<Mutex<DepthGauge>>,
}
impl Drop for DepthLock {
    /// Gives back the depth level held by this guard.
    fn drop(&mut self) {
        let mut gauge = trace_lock!(self.depth_gauge);
        // Saturating: a gauge that is already at zero stays at zero instead of underflowing.
        gauge.current_depth = gauge.current_depth.saturating_sub(1);
    }
}
impl DepthLock {
    /// Takes one level of depth from `depth_gauge` and returns a guard that releases it on drop.
    ///
    /// # Errors
    ///
    /// Returns `StatusCode::BadDecodingError` when the gauge's current depth has already
    /// reached its maximum depth; the gauge is left unchanged in that case.
    pub fn obtain(
        depth_gauge: Arc<Mutex<DepthGauge>>,
    ) -> core::result::Result<DepthLock, StatusCode> {
        {
            // The guard lives only inside this scope so that `depth_gauge` can be moved into
            // the returned lock afterwards.
            let mut gauge = trace_lock!(depth_gauge);
            if gauge.current_depth >= gauge.max_depth {
                warn!("Decoding in stream aborted due maximum recursion depth being reached");
                return Err(StatusCode::BadDecodingError);
            }
            gauge.current_depth += 1;
        }
        Ok(Self { depth_gauge })
    }
}
/// Counter pairing a current depth with the maximum depth allowed.
///
/// [`DepthLock`] guards increment `current_depth` when obtained and decrement it when dropped;
/// obtaining a guard fails once `current_depth` has reached `max_depth`.
#[derive(Debug)]
pub struct DepthGauge {
/// Upper bound for `current_depth`; no further depth locks are handed out once it is reached.
pub(crate) max_depth: usize,
/// Number of depth levels currently taken.
pub(crate) current_depth: usize,
}
impl Default for DepthGauge {
    /// Creates a gauge at depth zero whose limit is `constants::MAX_DECODING_DEPTH`.
    fn default() -> Self {
        Self {
            current_depth: 0,
            max_depth: constants::MAX_DECODING_DEPTH,
        }
    }
}
impl DepthGauge {
    /// Creates a gauge that allows a single level of depth; all other fields take their
    /// default values.
    pub fn minimal() -> Self {
        Self {
            max_depth: 1,
            ..Self::default()
        }
    }

    /// Returns the maximum depth this gauge allows.
    pub fn max_depth(&self) -> usize {
        self.max_depth
    }

    /// Returns the number of depth levels currently taken.
    pub fn current_depth(&self) -> usize {
        self.current_depth
    }
}
/// Limits and settings passed to `BinaryEncoder::decode` implementations.
///
/// Cloning shares the same depth gauge, because the gauge is held behind an `Arc`.
#[derive(Clone, Debug)]
pub struct DecodingOptions {
/// NOTE(review): presumably a time offset between client and server clocks applied while
/// decoding timestamps; it is not used in this part of the file, so confirm against callers.
pub client_offset: Duration,
/// NOTE(review): presumably the maximum size of a message in bytes; not enforced in this
/// part of the file.
pub max_message_size: usize,
/// NOTE(review): presumably the maximum number of chunks in a message; not enforced in this
/// part of the file.
pub max_chunk_count: usize,
/// NOTE(review): presumably the maximum length of a decoded string; not enforced in this
/// part of the file.
pub max_string_length: usize,
/// NOTE(review): presumably the maximum length of a decoded byte string; not enforced in
/// this part of the file.
pub max_byte_string_length: usize,
/// Maximum number of elements `read_array` accepts; longer arrays are rejected with
/// `StatusCode::BadDecodingError`.
pub max_array_length: usize,
/// Shared gauge from which `depth_lock` obtains its [`DepthLock`] guards.
pub decoding_depth_gauge: Arc<Mutex<DepthGauge>>,
}
impl Default for DecodingOptions {
    /// Builds options with a zero client offset, the limits from `constants`, and a fresh
    /// default depth gauge.
    fn default() -> Self {
        let decoding_depth_gauge = Arc::new(Mutex::new(DepthGauge::default()));
        Self {
            client_offset: Duration::zero(),
            max_message_size: constants::MAX_MESSAGE_SIZE,
            max_chunk_count: constants::MAX_CHUNK_COUNT,
            max_string_length: constants::MAX_STRING_LENGTH,
            max_byte_string_length: constants::MAX_BYTE_STRING_LENGTH,
            max_array_length: constants::MAX_ARRAY_LENGTH,
            decoding_depth_gauge,
        }
    }
}
impl DecodingOptions {
    /// Builds options with string, byte string and array limits of 8192 and a depth gauge
    /// that allows a single level; all remaining fields take their default values.
    pub fn minimal() -> Self {
        let decoding_depth_gauge = Arc::new(Mutex::new(DepthGauge::minimal()));
        Self {
            max_string_length: 8192,
            max_byte_string_length: 8192,
            max_array_length: 8192,
            decoding_depth_gauge,
            ..Self::default()
        }
    }

    /// Options used by unit tests; currently identical to the defaults.
    #[cfg(test)]
    pub fn test() -> Self {
        Self::default()
    }

    /// Obtains a [`DepthLock`] from the shared depth gauge of these options.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`DepthLock::obtain`] when the maximum depth is reached.
    pub fn depth_lock(&self) -> core::result::Result<DepthLock, StatusCode> {
        DepthLock::obtain(Arc::clone(&self.decoding_depth_gauge))
    }
}
/// Implemented by types that can be written to and read from a binary stream.
pub trait BinaryEncoder<T> {
    /// Returns the number of bytes the encoded value is expected to occupy; `encode_to_vec`
    /// uses it to pre-size its buffer.
    fn byte_len(&self) -> usize;

    /// Encodes the value into `stream`, returning a byte count on success (the helpers in
    /// this module sum these counts as the number of bytes written).
    fn encode<S: Write>(&self, stream: &mut S) -> EncodingResult<usize>;

    /// Decodes a value from `stream`, subject to the limits in `decoding_options`.
    fn decode<S: Read>(stream: &mut S, decoding_options: &DecodingOptions) -> EncodingResult<T>;

    /// Convenience wrapper that encodes the value into a freshly allocated byte vector.
    ///
    /// Encoding errors are ignored, so on failure the vector holds whatever was written
    /// before the error occurred.
    fn encode_to_vec(&self) -> Vec<u8> {
        let capacity = self.byte_len();
        let mut cursor = Cursor::new(Vec::with_capacity(capacity));
        let _ = self.encode(&mut cursor);
        cursor.into_inner()
    }
}
/// Converts the I/O result of an encode operation into an [`EncodingResult`].
///
/// An I/O error is traced and replaced with `StatusCode::BadEncodingError`; a success value is
/// passed through unchanged.
pub fn process_encode_io_result(result: Result<usize>) -> EncodingResult<usize> {
    match result {
        Ok(written) => Ok(written),
        Err(err) => {
            trace!("Encoding error - {:?}", err);
            Err(StatusCode::BadEncodingError)
        }
    }
}
/// Converts the I/O result of a decode operation into an [`EncodingResult`].
///
/// An I/O error is traced and replaced with `StatusCode::BadDecodingError`; a success value is
/// passed through unchanged.
pub fn process_decode_io_result<T>(result: Result<T>) -> EncodingResult<T>
where
    T: Debug,
{
    match result {
        Ok(value) => Ok(value),
        Err(err) => {
            trace!("Decoding error - {:?}", err);
            Err(StatusCode::BadDecodingError)
        }
    }
}
/// Returns the encoded size in bytes of an optional array: 4 bytes for the length prefix plus
/// the `byte_len` of every element. A `None` array therefore has a size of 4.
pub fn byte_len_array<T: BinaryEncoder<T>>(values: &Option<Vec<T>>) -> usize {
    let elements_len = match values {
        Some(items) => items.iter().map(|item| item.byte_len()).sum::<usize>(),
        None => 0,
    };
    4 + elements_len
}
pub fn write_array<S: Write, T: BinaryEncoder<T>>(
stream: &mut S,
values: &Option<Vec<T>>,
) -> EncodingResult<usize> {
let mut size = 0;
if let Some(ref values) = values {
size += write_i32(stream, values.len() as i32)?;
for value in values.iter() {
size += value.encode(stream)?;
}
} else {
size += write_i32(stream, -1)?;
}
Ok(size)
}
/// Reads an optional array from `stream`: a 32-bit signed length followed by that many
/// elements.
///
/// A length of `-1` yields `Ok(None)`.
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the length is below `-1` or exceeds
/// `decoding_options.max_array_length`, and propagates any error from reading the length or
/// decoding an element.
pub fn read_array<S: Read, T: BinaryEncoder<T>>(
    stream: &mut S,
    decoding_options: &DecodingOptions,
) -> EncodingResult<Option<Vec<T>>> {
    let len = read_i32(stream)?;
    if len == -1 {
        return Ok(None);
    }
    if len < -1 {
        error!("Array length is negative value and invalid");
        return Err(StatusCode::BadDecodingError);
    }

    // `len` is known to be non-negative here, so the conversion cannot wrap.
    let count = len as usize;
    if count > decoding_options.max_array_length {
        error!(
            "Array length {} exceeds decoding limit {}",
            len, decoding_options.max_array_length
        );
        return Err(StatusCode::BadDecodingError);
    }

    let mut items: Vec<T> = Vec::with_capacity(count);
    for _ in 0..count {
        let item = T::decode(stream, decoding_options)?;
        items.push(item);
    }
    Ok(Some(items))
}
/// Writes the byte `value` to `stream` `count` times and returns `count`.
///
/// # Errors
///
/// Returns `StatusCode::BadEncodingError` as soon as one of the writes fails.
pub fn write_bytes(stream: &mut dyn Write, value: u8, count: usize) -> EncodingResult<usize> {
    (0..count)
        .try_for_each(|_| stream.write_u8(value))
        .map_err(|_| StatusCode::BadEncodingError)?;
    Ok(count)
}
/// Writes a single unsigned byte to `stream` and returns the number of bytes written (1).
///
/// # Errors
///
/// Returns `StatusCode::BadEncodingError` when the stream cannot accept the byte.
pub fn write_u8<T>(stream: &mut dyn Write, value: T) -> EncodingResult<usize>
where
    T: Into<u8>,
{
    let buf: [u8; 1] = [value.into()];
    // `Write::write` may report success without writing anything; `write_all` either writes
    // the whole buffer or fails.
    process_encode_io_result(stream.write_all(&buf).map(|_| buf.len()))
}
/// Writes a signed 16-bit value to `stream` in little-endian order and returns the number of
/// bytes written (2).
///
/// # Errors
///
/// Returns `StatusCode::BadEncodingError` when the stream cannot accept all of the bytes.
pub fn write_i16<T>(stream: &mut dyn Write, value: T) -> EncodingResult<usize>
where
    T: Into<i16>,
{
    let mut buf = [0u8; 2];
    LittleEndian::write_i16(&mut buf, value.into());
    // `Write::write` may perform a short write and truncate the value; `write_all` either
    // writes the whole buffer or fails.
    process_encode_io_result(stream.write_all(&buf).map(|_| buf.len()))
}
/// Writes an unsigned 16-bit value to `stream` in little-endian order and returns the number
/// of bytes written (2).
///
/// # Errors
///
/// Returns `StatusCode::BadEncodingError` when the stream cannot accept all of the bytes.
pub fn write_u16<T>(stream: &mut dyn Write, value: T) -> EncodingResult<usize>
where
    T: Into<u16>,
{
    let mut buf = [0u8; 2];
    LittleEndian::write_u16(&mut buf, value.into());
    // `Write::write` may perform a short write and truncate the value; `write_all` either
    // writes the whole buffer or fails.
    process_encode_io_result(stream.write_all(&buf).map(|_| buf.len()))
}
/// Writes a signed 32-bit value to `stream` in little-endian order and returns the number of
/// bytes written (4).
///
/// # Errors
///
/// Returns `StatusCode::BadEncodingError` when the stream cannot accept all of the bytes.
pub fn write_i32<T>(stream: &mut dyn Write, value: T) -> EncodingResult<usize>
where
    T: Into<i32>,
{
    let mut buf = [0u8; 4];
    LittleEndian::write_i32(&mut buf, value.into());
    // `Write::write` may perform a short write and truncate the value; `write_all` either
    // writes the whole buffer or fails.
    process_encode_io_result(stream.write_all(&buf).map(|_| buf.len()))
}
/// Writes an unsigned 32-bit value to `stream` in little-endian order and returns the number
/// of bytes written (4).
///
/// # Errors
///
/// Returns `StatusCode::BadEncodingError` when the stream cannot accept all of the bytes.
pub fn write_u32<T>(stream: &mut dyn Write, value: T) -> EncodingResult<usize>
where
    T: Into<u32>,
{
    let mut buf = [0u8; 4];
    LittleEndian::write_u32(&mut buf, value.into());
    // `Write::write` may perform a short write and truncate the value; `write_all` either
    // writes the whole buffer or fails.
    process_encode_io_result(stream.write_all(&buf).map(|_| buf.len()))
}
/// Writes a signed 64-bit value to `stream` in little-endian order and returns the number of
/// bytes written (8).
///
/// # Errors
///
/// Returns `StatusCode::BadEncodingError` when the stream cannot accept all of the bytes.
pub fn write_i64<T>(stream: &mut dyn Write, value: T) -> EncodingResult<usize>
where
    T: Into<i64>,
{
    let mut buf = [0u8; 8];
    LittleEndian::write_i64(&mut buf, value.into());
    // `Write::write` may perform a short write and truncate the value; `write_all` either
    // writes the whole buffer or fails.
    process_encode_io_result(stream.write_all(&buf).map(|_| buf.len()))
}
/// Writes an unsigned 64-bit value to `stream` in little-endian order and returns the number
/// of bytes written (8).
///
/// # Errors
///
/// Returns `StatusCode::BadEncodingError` when the stream cannot accept all of the bytes.
pub fn write_u64<T>(stream: &mut dyn Write, value: T) -> EncodingResult<usize>
where
    T: Into<u64>,
{
    let mut buf = [0u8; 8];
    LittleEndian::write_u64(&mut buf, value.into());
    // `Write::write` may perform a short write and truncate the value; `write_all` either
    // writes the whole buffer or fails.
    process_encode_io_result(stream.write_all(&buf).map(|_| buf.len()))
}
/// Writes a 32-bit floating point value to `stream` in little-endian order and returns the
/// number of bytes written (4).
///
/// # Errors
///
/// Returns `StatusCode::BadEncodingError` when the stream cannot accept all of the bytes.
pub fn write_f32<T>(stream: &mut dyn Write, value: T) -> EncodingResult<usize>
where
    T: Into<f32>,
{
    let mut buf = [0u8; 4];
    LittleEndian::write_f32(&mut buf, value.into());
    // `Write::write` may perform a short write and truncate the value; `write_all` either
    // writes the whole buffer or fails.
    process_encode_io_result(stream.write_all(&buf).map(|_| buf.len()))
}
/// Writes a 64-bit floating point value to `stream` in little-endian order and returns the
/// number of bytes written (8).
///
/// # Errors
///
/// Returns `StatusCode::BadEncodingError` when the stream cannot accept all of the bytes.
pub fn write_f64<T>(stream: &mut dyn Write, value: T) -> EncodingResult<usize>
where
    T: Into<f64>,
{
    let mut buf = [0u8; 8];
    LittleEndian::write_f64(&mut buf, value.into());
    // `Write::write` may perform a short write and truncate the value; `write_all` either
    // writes the whole buffer or fails.
    process_encode_io_result(stream.write_all(&buf).map(|_| buf.len()))
}
/// Fills `buf` completely from `stream` and returns the number of bytes read (`buf.len()`).
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the stream cannot supply enough bytes.
pub fn read_bytes(stream: &mut dyn Read, buf: &mut [u8]) -> EncodingResult<usize> {
    process_decode_io_result(stream.read_exact(buf)).map(|_| buf.len())
}
/// Reads a single unsigned byte from `stream`.
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the byte cannot be read.
pub fn read_u8(stream: &mut dyn Read) -> EncodingResult<u8> {
    let mut raw = [0u8];
    process_decode_io_result(stream.read_exact(&mut raw)).map(|_| raw[0])
}
/// Reads a little-endian signed 16-bit value from `stream`.
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the 2 bytes cannot be read.
pub fn read_i16(stream: &mut dyn Read) -> EncodingResult<i16> {
    let mut raw = [0u8; 2];
    process_decode_io_result(stream.read_exact(&mut raw)).map(|_| LittleEndian::read_i16(&raw))
}
/// Reads a little-endian unsigned 16-bit value from `stream`.
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the 2 bytes cannot be read.
pub fn read_u16(stream: &mut dyn Read) -> EncodingResult<u16> {
    let mut raw = [0u8; 2];
    process_decode_io_result(stream.read_exact(&mut raw)).map(|_| LittleEndian::read_u16(&raw))
}
/// Reads a little-endian signed 32-bit value from `stream`.
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the 4 bytes cannot be read.
pub fn read_i32(stream: &mut dyn Read) -> EncodingResult<i32> {
    let mut raw = [0u8; 4];
    process_decode_io_result(stream.read_exact(&mut raw)).map(|_| LittleEndian::read_i32(&raw))
}
/// Reads a little-endian unsigned 32-bit value from `stream`.
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the 4 bytes cannot be read.
pub fn read_u32(stream: &mut dyn Read) -> EncodingResult<u32> {
    let mut raw = [0u8; 4];
    process_decode_io_result(stream.read_exact(&mut raw)).map(|_| LittleEndian::read_u32(&raw))
}
/// Reads a little-endian signed 64-bit value from `stream`.
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the 8 bytes cannot be read.
pub fn read_i64(stream: &mut dyn Read) -> EncodingResult<i64> {
    let mut raw = [0u8; 8];
    process_decode_io_result(stream.read_exact(&mut raw)).map(|_| LittleEndian::read_i64(&raw))
}
/// Reads a little-endian unsigned 64-bit value from `stream`.
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the 8 bytes cannot be read.
pub fn read_u64(stream: &mut dyn Read) -> EncodingResult<u64> {
    let mut raw = [0u8; 8];
    process_decode_io_result(stream.read_exact(&mut raw)).map(|_| LittleEndian::read_u64(&raw))
}
/// Reads a little-endian 32-bit floating point value from `stream`.
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the 4 bytes cannot be read.
pub fn read_f32(stream: &mut dyn Read) -> EncodingResult<f32> {
    let mut raw = [0u8; 4];
    process_decode_io_result(stream.read_exact(&mut raw)).map(|_| LittleEndian::read_f32(&raw))
}
/// Reads a little-endian 64-bit floating point value from `stream`.
///
/// # Errors
///
/// Returns `StatusCode::BadDecodingError` when the 8 bytes cannot be read.
pub fn read_f64(stream: &mut dyn Read) -> EncodingResult<f64> {
    let mut raw = [0u8; 8];
    process_decode_io_result(stream.read_exact(&mut raw)).map(|_| LittleEndian::read_f64(&raw))
}