use bytes::{Buf, BufMut, Bytes};
use thiserror::Error;
use tokio_util::codec::{Decoder, Encoder};
/// Marker byte that starts every frame: the encoder writes it first and the
/// decoder rejects any frame whose first byte differs from it.
const WIRE_ID: u8 = 0x01;
/// Errors produced while decoding handshake frames.
#[derive(Debug, Error)]
pub enum Error {
/// An underlying I/O failure; `std::io::Error` converts into this variant
/// through the generated `From` impl.
#[error("IO error: {0:?}")]
Io(#[from] std::io::Error),
/// The first byte of a frame was not `WIRE_ID`; carries the byte received.
#[error("Invalid wire ID: {0}")]
WireId(u8),
/// The acknowledgement frame carried a zero status byte.
#[error("Rejected")]
Rejected,
}
/// Frame codec for the auth/ack handshake.
///
/// The decoder is stateful: `state` selects which frame layout the next call
/// to `decode` expects.
pub struct Codec {
// Frame kind the decoder expects next.
state: State,
}
impl Codec {
pub fn new_client() -> Self {
Self { state: State::Ack }
}
pub fn new_server() -> Self {
Self { state: State::AuthReceive }
}
}
/// Which frame layout the decoder parses next.
#[derive(Debug, Clone)]
enum State {
/// Expect an auth frame: wire id, big-endian `u32` length, then the payload.
AuthReceive,
/// Expect an acknowledgement frame: wire id followed by one status byte.
Ack,
}
/// Messages exchanged during the handshake.
#[derive(Debug, Clone)]
pub enum Message {
/// Authentication payload; encoded as wire id, big-endian `u32` length and
/// the raw bytes.
Auth(Bytes),
/// Positive acknowledgement; encoded as wire id followed by the byte `1`.
Ack,
/// Negative acknowledgement; encoded as wire id followed by the byte `0`.
/// The decoder in this file reports a received zero status byte as
/// `Error::Rejected` rather than yielding this variant.
Reject,
}
impl Decoder for Codec {
type Item = Message;
type Error = Error;
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match self.state {
State::AuthReceive => {
if src.is_empty() {
return Ok(None);
}
let wire_id = u8::from_be_bytes([src[0]]);
if wire_id != WIRE_ID {
return Err(Error::WireId(wire_id));
}
if src.len() < 4 {
return Ok(None);
}
let id_size = u32::from_be_bytes([src[1], src[2], src[3], src[4]]);
if src.len() < id_size as usize {
return Ok(None);
}
src.advance(1);
src.advance(4);
let id = src.split_to(id_size as usize).freeze();
self.state = State::Ack;
Ok(Some(Message::Auth(id)))
}
State::Ack => {
if src.len() < 2 {
return Ok(None);
}
let wire_id = u8::from_be_bytes([src[0]]);
if wire_id != WIRE_ID {
return Err(Error::WireId(wire_id));
}
src.advance(1);
let ack = src.get_u8();
if ack == 0 {
return Err(Error::Rejected);
}
Ok(Some(Message::Ack))
}
}
}
}
impl Encoder<Message> for Codec {
    type Error = std::io::Error;

    /// Appends the wire representation of `item` to `dst`.
    ///
    /// * `Message::Auth(id)` is written as `WIRE_ID`, the payload length as a
    ///   big-endian `u32`, then the payload. It also moves the decoder to the
    ///   ack state.
    /// * `Message::Ack` is written as `WIRE_ID` followed by `1`.
    /// * `Message::Reject` is written as `WIRE_ID` followed by `0`.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error, leaving `dst` and the codec state
    /// untouched, if an auth payload is longer than `u32::MAX` bytes and thus
    /// cannot be described by the length prefix.
    fn encode(&mut self, item: Message, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
        match item {
            Message::Auth(id) => {
                // Validate before mutating anything: a plain `as u32` cast
                // would silently truncate the prefix and corrupt the frame.
                if id.len() > u32::MAX as usize {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "auth payload is too large for a u32 length prefix",
                    ));
                }
                // Lossless: the bound was checked just above.
                let id_len = id.len() as u32;

                self.state = State::Ack;
                dst.reserve(1 + 4 + id.len());
                dst.put_u8(WIRE_ID);
                dst.put_u32(id_len);
                dst.put(id);
            }
            Message::Ack => {
                dst.reserve(1 + 1);
                dst.put_u8(WIRE_ID);
                dst.put_u8(1);
            }
            Message::Reject => {
                dst.reserve(1 + 1);
                dst.put_u8(WIRE_ID);
                dst.put_u8(0);
            }
        }
        Ok(())
    }
}