use bytes::{BufMut, Bytes, BytesMut};
use prost::Message;
use crate::protocol::types::{ChunkHeader, Envelope, PackageChunk, RpcError};
/// Size in bytes of the fixed chunk header: `call_id` (8) + `index` (2) + `total` (2) +
/// `payload_len` (4).
const CHUNK_HEADER_SIZE: usize = 16;
/// Maximum number of arguments an envelope may carry; checked by `EnvelopeCodec` on both
/// encode and decode.
const MAX_ARGUMENTS_COUNT: usize = 16;
/// Maximum size in bytes of a single envelope argument (16 MiB).
const MAX_ARGUMENT_SIZE: usize = 16 * 1024 * 1024;
/// Maximum length in bytes of an envelope function name; the length is sent as a `u16` prefix.
const MAX_FUNCTION_NAME_SIZE: usize = u16::MAX as usize;
/// Stateless codec that converts `prost` messages to and from raw bytes.
#[derive(Default, Clone)]
pub struct ProtobufCodec;

impl ProtobufCodec {
    /// Serializes `value` into a freshly allocated buffer sized from `encoded_len()`.
    ///
    /// # Errors
    ///
    /// Returns [`RpcError::Encode`] if the message fails to encode.
    pub fn encode<T: Message>(&self, value: &T) -> Result<Bytes, RpcError> {
        let mut out = BytesMut::with_capacity(value.encoded_len());
        match value.encode(&mut out) {
            Ok(()) => Ok(out.freeze()),
            Err(_) => Err(RpcError::Encode),
        }
    }

    /// Parses `bytes` as a message of type `T`.
    ///
    /// # Errors
    ///
    /// Returns [`RpcError::Decode`] if `bytes` is not a valid encoding of `T`.
    pub fn decode<T: Message + Default>(&self, bytes: &[u8]) -> Result<T, RpcError> {
        match T::decode(bytes) {
            Ok(message) => Ok(message),
            Err(_) => Err(RpcError::Decode),
        }
    }
}
/// Codec for [`PackageChunk`] frames: a fixed 16-byte header followed by the payload.
///
/// Wire layout (all integers little-endian):
///
/// | offset | size | field         |
/// |--------|------|---------------|
/// | 0      | 8    | `call_id`     |
/// | 8      | 2    | `index`       |
/// | 10     | 2    | `total`       |
/// | 12     | 4    | `payload_len` |
/// | 16     | n    | payload       |
#[derive(Default, Clone)]
pub struct PackageChunkCodec;

impl PackageChunkCodec {
    /// Serializes `value` as header + payload.
    ///
    /// The header fields are written little-endian so that [`Self::decode`] reads back exactly
    /// what was written.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for interface stability.
    pub fn encode(&self, value: PackageChunk) -> Result<Bytes, RpcError> {
        let header = value.header();
        // NOTE(review): assumes `header.payload_len()` equals `value.payload().len()` —
        // confirm against how `PackageChunk` is constructed.
        let mut bytes = BytesMut::with_capacity(CHUNK_HEADER_SIZE + header.payload_len() as usize);
        // `put_u64`/`put_u16`/`put_u32` write big-endian; the decoder reads little-endian,
        // so the `_le` writers are required for a correct round trip.
        bytes.put_u64_le(header.call_id());
        bytes.put_u16_le(header.index());
        bytes.put_u16_le(header.total());
        bytes.put_u32_le(header.payload_len());
        bytes.extend_from_slice(value.payload());
        Ok(bytes.freeze())
    }

    /// Parses one chunk from the start of `bytes`.
    ///
    /// Exactly `payload_len` bytes after the header are copied into the chunk; any bytes
    /// following the payload are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`RpcError::ChunkHeaderSizeConstraintViolation`] if `bytes` is shorter than the
    /// header, or shorter than the header plus the payload length the header declares.
    pub fn decode(&self, bytes: &[u8]) -> Result<PackageChunk, RpcError> {
        if bytes.len() < CHUNK_HEADER_SIZE {
            return Err(RpcError::ChunkHeaderSizeConstraintViolation);
        }
        let (raw_header, body) = bytes.split_at(CHUNK_HEADER_SIZE);

        let len = raw_header[12..16]
            .try_into()
            .map(u32::from_le_bytes)
            .map_err(|_| RpcError::Decode)?;
        // Checked sub-slice: no `header + len` addition that could overflow on 32-bit targets,
        // and the payload is exactly `len` bytes long.
        let payload = body
            .get(..len as usize)
            .ok_or(RpcError::ChunkHeaderSizeConstraintViolation)?;

        let call_id = raw_header[..8]
            .try_into()
            .map(u64::from_le_bytes)
            .map_err(|_| RpcError::Decode)?;
        let index = raw_header[8..10]
            .try_into()
            .map(u16::from_le_bytes)
            .map_err(|_| RpcError::Decode)?;
        let total = raw_header[10..12]
            .try_into()
            .map(u16::from_le_bytes)
            .map_err(|_| RpcError::Decode)?;

        let header = ChunkHeader::new(call_id, index, total, len);
        Ok(PackageChunk::new(header, Bytes::copy_from_slice(payload)))
    }
}
/// Codec for [`Envelope`] values: a function name followed by a list of opaque arguments.
///
/// Wire layout (all integers little-endian):
///
/// ```text
/// u16 fn_name_len | fn_name bytes | u16 arg_count | arg_count * (u64 arg_len | arg bytes)
/// ```
#[derive(Default, Clone)]
pub struct EnvelopeCodec;

impl EnvelopeCodec {
    /// Serializes `value` into the envelope wire format.
    ///
    /// # Errors
    ///
    /// * [`RpcError::MaxFunctionNameConstraintViolation`] if the function name is longer than
    ///   `MAX_FUNCTION_NAME_SIZE` bytes.
    /// * [`RpcError::MaxArgumentsConstraintViolation`] if there are more than
    ///   `MAX_ARGUMENTS_COUNT` arguments.
    /// * [`RpcError::MaxArgumentSizeConstraintViolation`] if any argument is larger than
    ///   `MAX_ARGUMENT_SIZE` bytes.
    pub fn encode(&self, value: Envelope) -> Result<Bytes, RpcError> {
        let fn_name = value.fn_name();
        let args = value.parameters();
        if fn_name.len() > MAX_FUNCTION_NAME_SIZE {
            return Err(RpcError::MaxFunctionNameConstraintViolation);
        }
        if args.len() > MAX_ARGUMENTS_COUNT {
            return Err(RpcError::MaxArgumentsConstraintViolation);
        }

        // Two u16 prefixes plus the name, then an 8-byte length prefix per argument.
        let mut capacity = 2 + fn_name.len() + 2;
        for arg in args {
            if arg.len() > MAX_ARGUMENT_SIZE {
                return Err(RpcError::MaxArgumentSizeConstraintViolation);
            }
            capacity += 8 + arg.len();
        }

        let mut buf = BytesMut::with_capacity(capacity);
        // The plain `put_u16`/`put_u64` writers are big-endian; `decode` reads little-endian,
        // so the `_le` writers are required for a correct round trip.
        // The casts below cannot truncate: both lengths were bounded by the checks above.
        buf.put_u16_le(fn_name.len() as u16);
        buf.extend_from_slice(fn_name);
        buf.put_u16_le(args.len() as u16);
        for arg in args {
            buf.put_u64_le(arg.len() as u64);
            buf.extend_from_slice(arg);
        }
        Ok(buf.freeze())
    }

    /// Parses an envelope that must occupy `bytes` exactly.
    ///
    /// # Errors
    ///
    /// * [`RpcError::Decode`] if the input ends before a declared field is complete.
    /// * [`RpcError::MaxFunctionNameConstraintViolation`],
    ///   [`RpcError::MaxArgumentsConstraintViolation`] or
    ///   [`RpcError::MaxArgumentSizeConstraintViolation`] if a declared length or count exceeds
    ///   the corresponding limit.
    /// * [`RpcError::GarbageBytes`] if bytes remain after the last argument.
    pub fn decode(&self, bytes: &[u8]) -> Result<Envelope, RpcError> {
        let mut rest = bytes;

        let fn_len = usize::from(Self::take_u16(&mut rest)?);
        if fn_len > MAX_FUNCTION_NAME_SIZE {
            return Err(RpcError::MaxFunctionNameConstraintViolation);
        }
        let fn_name = Bytes::copy_from_slice(Self::take(&mut rest, fn_len)?);

        let arg_count = usize::from(Self::take_u16(&mut rest)?);
        if arg_count > MAX_ARGUMENTS_COUNT {
            return Err(RpcError::MaxArgumentsConstraintViolation);
        }

        let mut parameters = Vec::with_capacity(arg_count);
        for _ in 0..arg_count {
            let declared_len = Self::take(&mut rest, 8)?
                .try_into()
                .map(u64::from_le_bytes)
                .map_err(|_| RpcError::Decode)?;
            // Range-check before narrowing so that a huge u64 cannot be truncated into a
            // small, seemingly valid usize on 32-bit targets.
            let arg_len = usize::try_from(declared_len)
                .ok()
                .filter(|len| *len <= MAX_ARGUMENT_SIZE)
                .ok_or(RpcError::MaxArgumentSizeConstraintViolation)?;
            parameters.push(Bytes::copy_from_slice(Self::take(&mut rest, arg_len)?));
        }

        if !rest.is_empty() {
            return Err(RpcError::GarbageBytes);
        }
        Ok(Envelope::new(fn_name, parameters))
    }

    /// Splits the first `n` bytes off `rest`, or fails with `Decode` if fewer remain.
    fn take<'a>(rest: &mut &'a [u8], n: usize) -> Result<&'a [u8], RpcError> {
        // Copy the inner slice reference out so the result keeps the input's lifetime.
        let current: &'a [u8] = *rest;
        if current.len() < n {
            return Err(RpcError::Decode);
        }
        let (head, tail) = current.split_at(n);
        *rest = tail;
        Ok(head)
    }

    /// Reads a little-endian `u16` from the front of `rest`.
    fn take_u16(rest: &mut &[u8]) -> Result<u16, RpcError> {
        Self::take(rest, 2)?
            .try_into()
            .map(u16::from_le_bytes)
            .map_err(|_| RpcError::Decode)
    }
}