use bytes::{Buf, Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
use crate::error::Error;
use crate::message::Message;
/// Default upper bound, in bytes, on the payload length a [`MessageCodec`]
/// accepts from a frame's length prefix (16 MiB).
const MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;

/// Length-prefixed framing codec for [`Message`] values.
///
/// Each frame on the wire is a 4-byte big-endian payload length followed by
/// that many payload bytes. Incoming frames whose declared payload length
/// exceeds the configured maximum are rejected when decoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MessageCodec {
    /// Largest payload length (excluding the 4-byte prefix) accepted when decoding.
    max_size: usize,
}
impl MessageCodec {
    /// Creates a codec using the module's default payload limit of 16 MiB.
    #[must_use]
    pub fn new() -> Self {
        let max_size = MAX_MESSAGE_SIZE;
        Self { max_size }
    }

    /// Consumes this codec and returns it with the payload limit set to `max`.
    ///
    /// The limit is compared against the length prefix of each incoming frame,
    /// i.e. the payload size not counting the 4-byte prefix itself.
    #[must_use]
    pub fn with_max_size(self, max: usize) -> Self {
        let mut codec = self;
        codec.max_size = max;
        codec
    }
}
impl Default for MessageCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for MessageCodec {
    type Item = Message<Bytes>;
    type Error = Error;

    /// Decodes one length-prefixed frame from `src`.
    ///
    /// A frame is a big-endian `u32` payload length followed by that many
    /// payload bytes. The payload is handed to `Message::from_payload`.
    ///
    /// Returns `Ok(None)` without consuming anything in two cases: the 4-byte
    /// prefix has not fully arrived, or the payload is still incomplete. In
    /// the second case, capacity for the rest of the frame is reserved in
    /// `src` first.
    ///
    /// # Errors
    ///
    /// Returns `Error::MessageTooLarge` as soon as the prefix announces a
    /// payload longer than the configured maximum, without waiting for the
    /// payload. Errors from `Message::from_payload` are returned unchanged;
    /// by then the offending frame has already been removed from `src`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message<Bytes>>, Error> {
        const PREFIX_LEN: usize = 4;

        if src.len() < PREFIX_LEN {
            return Ok(None);
        }

        // Widening cast; assumes a target where `usize` is at least 32 bits.
        let length = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
        if length > self.max_size {
            return Err(Error::MessageTooLarge {
                size: length,
                max: self.max_size,
            });
        }

        // Compare against what is buffered after the prefix rather than
        // computing `PREFIX_LEN + length`. That sum can overflow a 32-bit
        // `usize` when `max_size` is configured close to `usize::MAX`, which
        // would panic (debug) or wrap and then panic in `split_to` (release).
        let buffered = src.len() - PREFIX_LEN;
        if buffered < length {
            src.reserve(length - buffered);
            return Ok(None);
        }

        src.advance(PREFIX_LEN);
        let payload = src.split_to(length);
        Message::from_payload(payload.freeze()).map(Some)
    }
}
impl Encoder<Message<Bytes>> for MessageCodec {
    type Error = Error;

    /// Appends `msg` to `dst` in wire form.
    ///
    /// All serialization is delegated to `Message::encode_into`. This codec
    /// adds nothing of its own and does not check the outgoing size against
    /// `max_size`. Always returns `Ok(())`.
    fn encode(&mut self, msg: Message<Bytes>, dst: &mut BytesMut) -> Result<(), Error> {
        msg.encode_into(dst);

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    #[test]
    fn codec_decode_single_message() {
        let mut codec = MessageCodec::new();
        let msg = Message::Have { index: 42 };
        let wire = msg.to_bytes();
        let mut buf = BytesMut::from(wire.as_ref());
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_partial_then_complete() {
        let mut codec = MessageCodec::new();
        let msg = Message::Request {
            index: 1,
            begin: 0,
            length: 16384,
        };
        let wire = msg.to_bytes();
        let mut buf = BytesMut::new();

        // Length prefix plus one payload byte: not a complete frame yet.
        buf.extend_from_slice(&wire[..5]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // An incomplete frame must be left in the buffer untouched.
        assert_eq!(&buf[..], &wire[..5]);

        buf.extend_from_slice(&wire[5..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_multiple_messages() {
        let mut codec = MessageCodec::new();
        let msg1 = Message::Choke;
        let msg2 = Message::Have { index: 7 };
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&msg1.to_bytes());
        buf.extend_from_slice(&msg2.to_bytes());

        let d1 = codec.decode(&mut buf).unwrap().unwrap();
        let d2 = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(d1, msg1);
        assert_eq!(d2, msg2);

        // Both frames consumed; nothing left to decode.
        assert!(buf.is_empty());
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn codec_decode_keepalive() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::from(&[0u8, 0, 0, 0][..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, Message::KeepAlive);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_reject_oversized() {
        let mut codec = MessageCodec::new().with_max_size(100);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&200u32.to_be_bytes());
        buf.extend_from_slice(&[0u8; 200]);
        let result = codec.decode(&mut buf);
        assert!(matches!(
            result,
            Err(Error::MessageTooLarge { size: 200, max: 100 })
        ));
    }

    #[test]
    fn codec_reject_oversized_before_payload_arrives() {
        // The limit is enforced from the length prefix alone, so a hostile
        // peer cannot make us buffer an oversized payload first.
        let mut codec = MessageCodec::new().with_max_size(100);
        let mut buf = BytesMut::from(&101u32.to_be_bytes()[..]);
        assert!(matches!(
            codec.decode(&mut buf),
            Err(Error::MessageTooLarge { size: 101, max: 100 })
        ));
    }

    #[test]
    fn codec_accepts_payload_at_max_size() {
        // A payload exactly `max_size` bytes long is still allowed.
        let msg = Message::Have { index: 3 };
        let wire = msg.to_bytes();
        let payload_len = wire.len() - 4;
        let mut codec = MessageCodec::new().with_max_size(payload_len);
        let mut buf = BytesMut::from(wire.as_ref());
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap(), msg);
    }

    #[test]
    fn codec_encode() {
        let mut codec = MessageCodec::new();
        let msg = Message::Piece {
            index: 0,
            begin: 0,
            data_0: Bytes::from_static(b"data"),
            data_1: Bytes::new(),
        };
        let mut buf = BytesMut::new();
        codec.encode(msg.clone(), &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_insufficient_header() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::from(&[0u8, 0][..]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // A partial header must not be consumed.
        assert_eq!(&buf[..], &[0u8, 0][..]);
    }
}