use std::io::{Error, Result};
/// Largest value that fits in the 29-bit size field of a frame header
/// (`2^29 - 1`); the top three bits of the header word are the type tag.
pub(crate) const MAX_PAYLOAD: u32 = (1 << 29) - 1;
/// Version header for protocol version 2, tagged with the ASCII bytes `tx5`.
pub(crate) const PROTO_VER_2: ProtoHeader = ProtoHeader::Version(2, *b"tx5");
/// A 4-byte frame header.
///
/// On the wire a header is a single big-endian `u32`. Its top three bits
/// select the variant (2 = `Version`, 3..=6 = the sized variants below in
/// declaration order) and its low 29 bits carry the variant's data.
#[derive(Debug, PartialEq)]
pub(crate) enum ProtoHeader {
    /// A 5-bit protocol version followed by a three-byte tag.
    Version(u8, [u8; 3]),
    /// A whole message in a single frame; holds the payload length.
    CompleteMessage(u32),
    /// One chunk of a larger message; holds this chunk's payload length.
    MultipartMessage(u32),
    /// Asks permission to send a multipart message; holds the total length.
    PermitRequest(u32),
    /// Answers a permit request; holds a length value.
    PermitGrant(u32),
}
impl ProtoHeader {
    /// Parses the four bytes of a wire header.
    ///
    /// The bytes form one big-endian `u32`. Its top three bits pick the
    /// variant. For `Version`, the next five bits are the version number and
    /// the remaining three bytes are the tag. Every other variant carries its
    /// value in the low 29 bits.
    ///
    /// # Errors
    ///
    /// Returns a `ReservedHeaderBits` error when the three type bits are
    /// 0, 1 or 7.
    pub fn decode(a: u8, b: u8, c: u8, d: u8) -> Result<Self> {
        let word = u32::from_be_bytes([a, b, c, d]);
        // Low 29 bits: the size field shared by every sized variant.
        let size = word & MAX_PAYLOAD;
        let header = match word >> 29 {
            // The version number is the low five bits of the first byte.
            2 => Self::Version(a & 0b0001_1111, [b, c, d]),
            3 => Self::CompleteMessage(size),
            4 => Self::MultipartMessage(size),
            5 => Self::PermitRequest(size),
            6 => Self::PermitGrant(size),
            _ => return Err(Error::other("ReservedHeaderBits")),
        };
        Ok(header)
    }

    /// Serialises this header into its four wire bytes, most significant byte
    /// first. This is the inverse of [`ProtoHeader::decode`].
    ///
    /// # Errors
    ///
    /// Returns `VersionOverflow` if a version number does not fit in five
    /// bits, or `SizeOverflow` if a size exceeds [`MAX_PAYLOAD`].
    pub fn encode(&self) -> Result<(u8, u8, u8, u8)> {
        let (type_bits, size) = match *self {
            Self::Version(version, [t0, t1, t2]) => {
                if version > 0b0001_1111 {
                    return Err(Error::other("VersionOverflow"));
                }
                // Type bits `010` occupy the top of the first byte, with the
                // version number packed into the five bits below them.
                return Ok(((2 << 5) | version, t0, t1, t2));
            }
            Self::CompleteMessage(s) => (3u32, s),
            Self::MultipartMessage(s) => (4, s),
            Self::PermitRequest(s) => (5, s),
            Self::PermitGrant(s) => (6, s),
        };
        if size > MAX_PAYLOAD {
            return Err(Error::other("SizeOverflow"));
        }
        let [w0, w1, w2, w3] = ((type_bits << 29) | size).to_be_bytes();
        Ok((w0, w1, w2, w3))
    }
}
/// Output of [`proto_encode`]: the frame(s) needed to send one payload.
pub(crate) enum ProtoEncodeResult {
    /// The payload is too large for one frame. `permit_req` is an encoded
    /// `PermitRequest` header for the total length. `msg_payload` holds the
    /// `MultipartMessage` frames (header plus chunk) in send order.
    NeedPermit {
        permit_req: Vec<u8>,
        msg_payload: Vec<Vec<u8>>,
    },
    /// A single `CompleteMessage` frame: 4-byte header followed by the payload.
    OneMessage(Vec<u8>),
}
pub(crate) fn proto_encode(data: &[u8]) -> Result<ProtoEncodeResult> {
const MAX: usize = (16 * 1024) - 4;
let len = data.len();
if len > MAX_PAYLOAD as usize {
return Err(Error::other("PayloadSizeOverflow"));
}
if len <= MAX {
let (a, b, c, d) = ProtoHeader::CompleteMessage(len as u32).encode()?;
let mut buf = Vec::with_capacity(len + 4);
buf.extend_from_slice(&[a, b, c, d]);
buf.extend_from_slice(data);
Ok(ProtoEncodeResult::OneMessage(buf))
} else {
let (a, b, c, d) = ProtoHeader::PermitRequest(len as u32).encode()?;
let permit_req = vec![a, b, c, d];
let mut msg_payload = Vec::new();
let mut cur = 0;
while len - cur > 0 {
let amt = std::cmp::min((16 * 1024) - 4, len - cur);
let (a, b, c, d) =
ProtoHeader::MultipartMessage(amt as u32).encode()?;
let mut buf = Vec::with_capacity(amt + 4);
buf.extend_from_slice(&[a, b, c, d]);
buf.extend_from_slice(&data[cur..cur + amt]);
msg_payload.push(buf);
cur += amt;
}
Ok(ProtoEncodeResult::NeedPermit {
permit_req,
msg_payload,
})
}
}
/// What a single call to [`ProtoDecoder::decode`] produced.
#[derive(Debug, PartialEq)]
pub(crate) enum ProtoDecodeResult {
    /// The frame was consumed, but there is nothing to hand to the caller yet.
    Idle,
    /// A complete application payload, either from one frame or reassembled
    /// from multipart chunks.
    Message(Vec<u8>),
    /// The remote asked for a permit to send this many bytes.
    RemotePermitRequest(u32),
    /// The remote answered our permit request with this length value.
    RemotePermitGrant(u32),
}
/// Position of a [`ProtoDecoder`] in its receive state machine.
#[derive(Copy, Clone, PartialEq)]
enum DecodeState {
    /// Initial state: only the peer's version header is accepted.
    NeedVersion,
    /// Handshake complete: complete messages and permit requests are accepted.
    Ready,
    /// The remote requested a permit for this many bytes, and the local side
    /// has not granted it yet.
    RemoteAwaitingPermit(u32),
    /// A permit was granted; multipart chunks are being reassembled.
    ReceiveChunked,
}
/// Stateful decoder for incoming frames on one connection.
///
/// Once any operation fails the decoder is poisoned, and every later call
/// returns an error.
pub(crate) struct ProtoDecoder {
    /// Current position in the receive state machine.
    state: DecodeState,
    /// Total byte count expected for the multipart message being reassembled.
    want_size: usize,
    /// Reassembly buffer for multipart chunks.
    incoming: Vec<u8>,
    /// True while a locally sent permit request is waiting for the remote's grant.
    want_remote_grant: bool,
    /// Set once any operation has failed; makes all later calls error.
    did_error: bool,
    /// Permit held while a granted multipart message is being received;
    /// released when the message completes.
    grant_permit: Option<tokio::sync::OwnedSemaphorePermit>,
    /// Signalled when the remote's permit grant arrives.
    grant_notify: Option<tokio::sync::oneshot::Sender<()>>,
}
impl Default for ProtoDecoder {
fn default() -> Self {
Self {
state: DecodeState::NeedVersion,
want_size: 0,
incoming: Vec::new(),
want_remote_grant: false,
did_error: false,
grant_permit: None,
grant_notify: None,
}
}
}
impl ProtoDecoder {
pub fn sent_remote_permit_grant(
&mut self,
grant_permit: tokio::sync::OwnedSemaphorePermit,
) -> Result<()> {
self.check_err()?;
if let DecodeState::RemoteAwaitingPermit(permit_len) = self.state {
self.state = DecodeState::ReceiveChunked;
self.want_size = permit_len as usize;
self.incoming.reserve(self.want_size);
self.grant_permit = Some(grant_permit);
Ok(())
} else {
self.did_error = true;
Err(Error::other("InvalidStateToSendPermit"))
}
}
pub fn sent_remote_permit_request(
&mut self,
grant_notify: Option<tokio::sync::oneshot::Sender<()>>,
) -> Result<()> {
self.check_err()?;
if self.want_remote_grant {
self.did_error = true;
Err(Error::other("InvalidDuplicatePermitRequest"))
} else {
self.want_remote_grant = true;
self.grant_notify = grant_notify;
Ok(())
}
}
pub fn decode(&mut self, chunk: &[u8]) -> Result<ProtoDecodeResult> {
self.check_err()?;
match self.priv_decode(chunk) {
Ok(r) => Ok(r),
Err(err) => {
self.did_error = true;
Err(err)
}
}
}
fn check_err(&self) -> Result<()> {
if self.did_error {
Err(Error::other("FnCallOnErroredDecoder"))
} else {
Ok(())
}
}
fn priv_decode(&mut self, chunk: &[u8]) -> Result<ProtoDecodeResult> {
let len = chunk.len();
if len < 4 {
return Err(Error::other("InvalidHeaderLen"));
}
match ProtoHeader::decode(chunk[0], chunk[1], chunk[2], chunk[3])? {
ProtoHeader::Version(v, [a, b, c]) => {
if v != 2 || a != b't' || b != b'x' || c != b'5' {
return Err(Error::other(format!(
"invalid version v = {v}, tag = {}",
String::from_utf8_lossy(&[a, b, c][..]),
)));
}
if self.state == DecodeState::NeedVersion {
self.state = DecodeState::Ready;
Ok(ProtoDecodeResult::Idle)
} else {
Err(Error::other("RecvUnexpectedVersionMessage"))
}
}
ProtoHeader::CompleteMessage(msg_len) => {
if self.state == DecodeState::Ready {
if msg_len as usize != len - 4 {
return Err(Error::other("InvalidCompleteMessageLen"));
}
Ok(ProtoDecodeResult::Message(chunk[4..].to_vec()))
} else {
Err(Error::other("RecvUnexpectedCompleteMessage"))
}
}
ProtoHeader::MultipartMessage(msg_len) => {
if self.state == DecodeState::ReceiveChunked {
if msg_len as usize != len - 4 || msg_len == 0 {
return Err(Error::other("InvalidMultipartMessageLen"));
}
if msg_len as usize + self.incoming.len() > self.want_size {
return Err(Error::other("ChunkTooLarge"));
}
self.incoming.extend_from_slice(&chunk[4..]);
if self.incoming.len() == self.want_size {
drop(self.grant_permit.take());
self.state = DecodeState::Ready;
Ok(ProtoDecodeResult::Message(std::mem::take(
&mut self.incoming,
)))
} else {
Ok(ProtoDecodeResult::Idle)
}
} else {
Err(Error::other("RecvUnexpectedMultipartMessage"))
}
}
ProtoHeader::PermitRequest(permit_len) => {
if self.state == DecodeState::Ready {
self.state = DecodeState::RemoteAwaitingPermit(permit_len);
Ok(ProtoDecodeResult::RemotePermitRequest(permit_len))
} else {
Err(Error::other("RecvUnexpectedPermitRequest"))
}
}
ProtoHeader::PermitGrant(permit_len) => {
if self.want_remote_grant {
self.want_remote_grant = false;
if let Some(grant_notify) = self.grant_notify.take() {
let _ = grant_notify.send(());
}
Ok(ProtoDecodeResult::RemotePermitGrant(permit_len))
} else {
Err(Error::other("RecvUnexpectedPermitGrant"))
}
}
}
}
}
// Unit tests for header encoding/decoding, payload framing, and the decoder
// state machine.
#[cfg(test)]
mod test {
    use super::*;

    // Every header variant survives an encode/decode round trip, including
    // the largest version (31) and the largest size (2^29 - 1).
    #[test]
    fn proto_header_encode_decode() {
        fn check(hdr: ProtoHeader) {
            let (a, b, c, d) = hdr.encode().unwrap();
            let res = ProtoHeader::decode(a, b, c, d).unwrap();
            assert_eq!(hdr, res);
        }
        for v in 0..32 {
            check(ProtoHeader::Version(v, [b't', b'x', b'5']));
        }
        for v in &[0, 42, 0b00011111111111111111111111111111] {
            check(ProtoHeader::CompleteMessage(*v));
            check(ProtoHeader::MultipartMessage(*v));
            check(ProtoHeader::PermitRequest(*v));
            check(ProtoHeader::PermitGrant(*v));
        }
    }

    // Versions above 31 and sizes above MAX_PAYLOAD are rejected by `encode`.
    #[test]
    fn proto_header_overflow() {
        assert!(ProtoHeader::Version(0b00100000, [b't', b'x', b'5'])
            .encode()
            .is_err());
        assert!(ProtoHeader::CompleteMessage(u32::MAX).encode().is_err());
        assert!(ProtoHeader::MultipartMessage(u32::MAX).encode().is_err());
        assert!(ProtoHeader::PermitRequest(u32::MAX).encode().is_err());
        assert!(ProtoHeader::PermitGrant(u32::MAX).encode().is_err());
    }

    // PROTO_VER_2 encodes to the expected wire bytes: type bits 010 and
    // version 00010 in the first byte, then the ASCII tag "tx5".
    #[test]
    fn proto_header_version_2() {
        const PROTO_VERSION_2: &[u8; 4] = &[
            0b01000010, b't', b'x', b'5',
        ];
        let (a, b, c, d) = PROTO_VER_2.encode().unwrap();
        assert_eq!(PROTO_VERSION_2[0], a);
        assert_eq!(PROTO_VERSION_2[1], b);
        assert_eq!(PROTO_VERSION_2[2], c);
        assert_eq!(PROTO_VERSION_2[3], d);
    }

    // A small payload is framed as one CompleteMessage and decodes back to
    // the original bytes once the version handshake has been received.
    #[test]
    fn proto_decode_complete_msg() {
        let mut dec = ProtoDecoder::default();
        let (a, b, c, d) = PROTO_VER_2.encode().unwrap();
        assert_eq!(ProtoDecodeResult::Idle, dec.decode(&[a, b, c, d]).unwrap(),);
        match proto_encode(b"hello").unwrap() {
            ProtoEncodeResult::OneMessage(buf) => {
                match dec.decode(&buf).unwrap() {
                    ProtoDecodeResult::Message(msg) => {
                        assert_eq!(b"hello", msg.as_slice());
                    }
                    _ => panic!(),
                }
            }
            _ => panic!(),
        }
    }

    // A 15 MiB payload goes through the full permit flow: the permit request
    // is surfaced, the local grant is recorded, every chunk but the last
    // yields Idle, and the last chunk yields the reassembled message.
    #[test]
    fn proto_decode_chunked_msg() {
        use rand::Rng;
        let mut dec = ProtoDecoder::default();
        let (a, b, c, d) = PROTO_VER_2.encode().unwrap();
        assert_eq!(ProtoDecodeResult::Idle, dec.decode(&[a, b, c, d]).unwrap(),);
        let mut msg = vec![0; 15 * 1024 * 1024];
        rand::rng().fill(&mut msg[..]);
        match proto_encode(&msg).unwrap() {
            ProtoEncodeResult::NeedPermit {
                permit_req,
                mut msg_payload,
            } => {
                match dec.decode(&permit_req).unwrap() {
                    ProtoDecodeResult::RemotePermitRequest(permit_len) => {
                        assert_eq!(15 * 1024 * 1024, permit_len);
                    }
                    _ => panic!(),
                }
                // Any owned permit will do; the decoder only holds it.
                dec.sent_remote_permit_grant(
                    std::sync::Arc::new(tokio::sync::Semaphore::new(1))
                        .try_acquire_owned()
                        .unwrap(),
                )
                .unwrap();
                while msg_payload.len() > 1 {
                    assert_eq!(
                        ProtoDecodeResult::Idle,
                        dec.decode(&msg_payload.remove(0)).unwrap(),
                    )
                }
                match dec.decode(&msg_payload.remove(0)).unwrap() {
                    ProtoDecodeResult::Message(msg_res) => {
                        assert_eq!(msg, msg_res);
                    }
                    _ => panic!(),
                }
            }
            _ => panic!(),
        }
    }

    // b"hello" decodes as a CompleteMessage header ('h' = 0b011_01000),
    // which is not accepted before the version handshake.
    #[test]
    fn proto_decode_bad_version() {
        let mut dec = ProtoDecoder::default();
        assert!(dec.decode(b"hello").is_err());
    }

    // A second local permit request while one is outstanding is an error.
    #[test]
    fn proto_decode_no_duplicate_permit_requests() {
        let mut dec = ProtoDecoder::default();
        let (a, b, c, d) = PROTO_VER_2.encode().unwrap();
        assert_eq!(ProtoDecodeResult::Idle, dec.decode(&[a, b, c, d]).unwrap(),);
        dec.sent_remote_permit_request(None).unwrap();
        assert!(dec.sent_remote_permit_request(None).is_err());
    }

    // A remote PermitGrant answering our own request may arrive between the
    // chunks of an incoming multipart message without disturbing reassembly.
    #[test]
    fn proto_decode_grant_during_multipart() {
        use rand::Rng;
        let mut dec = ProtoDecoder::default();
        let (a, b, c, d) = PROTO_VER_2.encode().unwrap();
        assert_eq!(ProtoDecodeResult::Idle, dec.decode(&[a, b, c, d]).unwrap(),);
        dec.sent_remote_permit_request(None).unwrap();
        let mut msg = vec![0; 17 * 1024];
        rand::rng().fill(&mut msg[..]);
        match proto_encode(&msg).unwrap() {
            ProtoEncodeResult::NeedPermit {
                permit_req,
                mut msg_payload,
            } => {
                match dec.decode(&permit_req).unwrap() {
                    ProtoDecodeResult::RemotePermitRequest(permit_len) => {
                        assert_eq!(17 * 1024, permit_len);
                    }
                    _ => panic!(),
                }
                dec.sent_remote_permit_grant(
                    std::sync::Arc::new(tokio::sync::Semaphore::new(1))
                        .try_acquire_owned()
                        .unwrap(),
                )
                .unwrap();
                // 17 KiB splits into exactly two multipart frames.
                assert_eq!(2, msg_payload.len());
                assert_eq!(
                    ProtoDecodeResult::Idle,
                    dec.decode(&msg_payload.remove(0)).unwrap(),
                );
                // Interleave the remote's grant between the two chunks.
                let (a, b, c, d) =
                    ProtoHeader::PermitGrant(18 * 1024).encode().unwrap();
                assert_eq!(
                    ProtoDecodeResult::RemotePermitGrant(18 * 1024),
                    dec.decode(&[a, b, c, d]).unwrap(),
                );
                match dec.decode(&msg_payload.remove(0)).unwrap() {
                    ProtoDecodeResult::Message(msg_res) => {
                        assert_eq!(msg, msg_res);
                    }
                    _ => panic!(),
                }
            }
            _ => panic!(),
        }
    }
}