use std::io;
use bytes::{BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
use crate::types::{
encoding::{BinaryEncoder, DecodingOptions},
status_code::StatusCode,
};
use super::{
message_chunk::MessageChunk,
tcp_types::{
AcknowledgeMessage, ErrorMessage, HelloMessage, MessageHeader, MessageType,
MESSAGE_HEADER_LEN,
},
};
/// One framed message on the TCP transport, as yielded by the decoder half of
/// [`TcpCodec`] and accepted by its encoder half.
#[derive(Debug)]
pub enum Message {
    /// A hello frame, carrying a decoded [`HelloMessage`].
    Hello(HelloMessage),
    /// An acknowledge frame, carrying a decoded [`AcknowledgeMessage`].
    Acknowledge(AcknowledgeMessage),
    /// An error frame, carrying a decoded [`ErrorMessage`].
    Error(ErrorMessage),
    /// A message chunk frame, carrying a decoded [`MessageChunk`].
    Chunk(MessageChunk),
}
/// Framing codec that converts between raw bytes and [`Message`] values.
pub struct TcpCodec {
    /// Options handed to every header and message decode performed by this codec.
    decoding_options: DecodingOptions,
}
impl Decoder for TcpCodec {
type Item = Message;
type Error = io::Error;
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if buf.len() > MESSAGE_HEADER_LEN {
let message_header = {
let mut buf = io::Cursor::new(&buf[0..MESSAGE_HEADER_LEN]);
MessageHeader::decode(&mut buf, &self.decoding_options)?
};
let message_size = message_header.message_size as usize;
if buf.len() >= message_size {
let mut buf = buf.split_to(message_size);
let message =
Self::decode_message(message_header, &mut buf, &self.decoding_options)
.map_err(|e| {
error!("Codec got an error {} while decoding a message", e);
io::Error::from(e)
})?;
Ok(Some(message))
} else {
Ok(None)
}
} else {
Ok(None)
}
}
}
impl Encoder<Message> for TcpCodec {
type Error = io::Error;
fn encode(&mut self, data: Message, buf: &mut BytesMut) -> Result<(), io::Error> {
match data {
Message::Hello(msg) => self.write(msg, buf),
Message::Acknowledge(msg) => self.write(msg, buf),
Message::Error(msg) => self.write(msg, buf),
Message::Chunk(msg) => self.write(msg, buf),
}
}
}
impl TcpCodec {
    /// Creates a codec that decodes incoming frames with `decoding_options`.
    pub fn new(decoding_options: DecodingOptions) -> TcpCodec {
        TcpCodec { decoding_options }
    }

    /// Appends the binary encoding of `msg` to the end of `buf`.
    ///
    /// Space for `msg.byte_len()` bytes is reserved before encoding starts.
    ///
    /// # Errors
    ///
    /// If encoding fails, the failure is logged and returned as an
    /// [`io::Error`] of kind `Other`. Before returning, `buf` is truncated
    /// back to its original length so that no partially encoded message is
    /// left in the output buffer.
    fn write<T>(&self, msg: T, buf: &mut BytesMut) -> Result<(), io::Error>
    where
        T: BinaryEncoder<T> + std::fmt::Debug,
    {
        let start_len = buf.len();
        buf.reserve(msg.byte_len());
        // Write through an explicit reborrow. The temporary writer is dropped at
        // the end of this statement, so `buf` is usable again below.
        let result = msg.encode(&mut (&mut *buf).writer());
        match result {
            Ok(_) => Ok(()),
            Err(err) => {
                // The output buffer persists between encode calls. Drop any
                // bytes written before the failure so they are never flushed
                // ahead of the next message.
                buf.truncate(start_len);
                error!("Error writing message {:?}, err = {}", msg, err);
                Err(io::Error::new(io::ErrorKind::Other, format!("Error = {}", err)))
            }
        }
    }

    /// Decodes the frame held in `buf` into the [`Message`] variant selected
    /// by `message_header.message_type`.
    ///
    /// Decoding starts at offset 0 of `buf`, so `buf` is expected to hold one
    /// complete frame from its first byte (header included). NOTE(review):
    /// this assumes each message type's `decode` reads its own header; confirm
    /// against those implementations.
    ///
    /// # Errors
    ///
    /// Propagates the decoder's [`StatusCode`] on failure. Returns
    /// `StatusCode::BadCommunicationError` for `MessageType::Invalid`.
    fn decode_message(
        message_header: MessageHeader,
        buf: &mut BytesMut,
        decoding_options: &DecodingOptions,
    ) -> Result<Message, StatusCode> {
        let mut buf = io::Cursor::new(&buf[..]);
        match message_header.message_type {
            MessageType::Acknowledge => Ok(Message::Acknowledge(AcknowledgeMessage::decode(
                &mut buf,
                decoding_options,
            )?)),
            MessageType::Hello => Ok(Message::Hello(HelloMessage::decode(
                &mut buf,
                decoding_options,
            )?)),
            MessageType::Error => Ok(Message::Error(ErrorMessage::decode(
                &mut buf,
                decoding_options,
            )?)),
            MessageType::Chunk => Ok(Message::Chunk(MessageChunk::decode(
                &mut buf,
                decoding_options,
            )?)),
            MessageType::Invalid => {
                error!("Message type for chunk is invalid.");
                Err(StatusCode::BadCommunicationError)
            }
        }
    }
}