use asynchronous_codec::{Decoder, Encoder};
use bytes::{Buf, BufMut, BytesMut};
use miltr_common::decoding::ServerCommand;
use miltr_common::encoding::{ClientMessage, Writable};
use miltr_common::ProtocolError;
use miltr_utils::trace;
/// Length-prefixed framing codec for the milter wire protocol.
///
/// `max_buffer_size` caps the size of a single frame. The decoder rejects frames whose
/// announced payload length exceeds it. The encoder rejects messages whose body does.
#[derive(Debug, Clone)]
pub(crate) struct MilterCodec {
    /// Largest payload, in bytes, accepted for a single frame.
    max_buffer_size: usize,
}

impl MilterCodec {
    /// Build a codec that refuses frames larger than `max_buffer_size` bytes.
    pub(crate) fn new(max_buffer_size: usize) -> Self {
        MilterCodec { max_buffer_size }
    }
}
impl Decoder for MilterCodec {
    type Item = ServerCommand;
    type Error = ProtocolError;

    /// Decode one length-prefixed milter frame from `src`.
    ///
    /// A frame is a 4-byte big-endian length followed by that many payload bytes. The
    /// payload (without the length prefix) is handed to [`ServerCommand::parse`].
    ///
    /// Returns `Ok(None)`, consuming nothing, while the length prefix or the payload is
    /// still incomplete. In the incomplete-payload case, room for the rest of the frame is
    /// reserved in `src`.
    ///
    /// # Errors
    /// - [`ProtocolError::TooMuchData`] if the announced length exceeds `max_buffer_size`,
    ///   or if the total frame size does not fit in `usize`.
    /// - Any error produced by [`ServerCommand::parse`] for the payload.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        /// Size of the big-endian length prefix that starts every frame.
        const HEADER_LEN: usize = 4;

        let Some(header) = src.get(..HEADER_LEN) else {
            // Not even the length prefix has arrived yet.
            return Ok(None);
        };
        let mut length_bytes = [0u8; HEADER_LEN];
        length_bytes.copy_from_slice(header);
        // Lossless wherever `usize` is at least 32 bits wide.
        let length = u32::from_be_bytes(length_bytes) as usize;
        if length > self.max_buffer_size {
            return Err(ProtocolError::TooMuchData(length));
        }

        // `length` is peer-controlled. Where `usize` is 32 bits and `max_buffer_size` is
        // close to `usize::MAX`, adding the header size could overflow. That would panic in
        // debug builds or wrap in release builds, so report it as an oversized frame instead.
        let frame_len = HEADER_LEN
            .checked_add(length)
            .ok_or(ProtocolError::TooMuchData(length))?;
        if src.len() < frame_len {
            // Pre-allocate so the remainder of the frame fits without further reallocation.
            src.reserve(frame_len - src.len());
            return Ok(None);
        }

        let mut payload = src.split_to(frame_len);
        payload.advance(HEADER_LEN);
        trace!(length = payload.len(), "Read bytes from the network");
        Ok(Some(ServerCommand::parse(payload)?))
    }
}
impl Encoder for MilterCodec {
type Item<'i> = &'i ClientMessage;
type Error = ProtocolError;
fn encode(&mut self, item: &ClientMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
let item_len = item.len();
if item_len > self.max_buffer_size || item_len > usize::MAX - 1 {
return Err(ProtocolError::TooMuchData(item_len));
}
let packet_len = 1_usize .checked_add(item_len) .ok_or(ProtocolError::TooMuchData(item_len))?;
let packet_len_be = u32::to_be_bytes(packet_len as u32);
dst.reserve(packet_len);
dst.extend_from_slice(&packet_len_be);
dst.put_u8(item.code());
item.write(dst);
trace!(length = dst.len(), "Wrote bytes to the network");
Ok(())
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Regression input (fuzzer-found, judging by the name).
    ///
    /// The frame itself is complete: a length of 4 followed by 4 payload bytes. The payload
    /// `109, 255, 255, 7` must still be reported as an error.
    #[test]
    fn test_fuzz_1() {
        let mut input = BytesMut::from_iter([0, 0, 0, 4, 109, 255, 255, 7]);
        let mut codec = MilterCodec::new(2_usize.pow(16));
        let _output = codec
            .decode(&mut input)
            .expect_err("This is not enough data");
    }

    /// Fewer than four bytes means the length prefix is incomplete, so nothing is consumed.
    #[test]
    fn decode_waits_for_complete_header() {
        let mut input = BytesMut::from_iter([0u8, 0, 0]);
        let mut codec = MilterCodec::new(2_usize.pow(16));
        assert!(matches!(codec.decode(&mut input), Ok(None)));
        assert_eq!(input.len(), 3);
    }

    /// The header announces 4 payload bytes but only 1 has arrived, so the decoder waits.
    /// It consumes nothing and reserves room for the rest of the frame.
    #[test]
    fn decode_waits_for_complete_payload() {
        let mut input = BytesMut::from_iter([0u8, 0, 0, 4, 109]);
        let mut codec = MilterCodec::new(2_usize.pow(16));
        assert!(matches!(codec.decode(&mut input), Ok(None)));
        assert_eq!(input.len(), 5);
        assert!(input.capacity() >= 8);
    }

    /// An announced length above `max_buffer_size` is rejected immediately.
    #[test]
    fn decode_rejects_oversized_frame() {
        // 0x0001_0001 = 65_537, one more than the configured maximum of 65_536.
        let mut input = BytesMut::from_iter([0u8, 1, 0, 1]);
        let mut codec = MilterCodec::new(2_usize.pow(16));
        assert!(matches!(
            codec.decode(&mut input),
            Err(ProtocolError::TooMuchData(65_537))
        ));
    }
}