1use bytes::{Buf, Bytes, BytesMut};
2use tokio_util::codec::{Decoder, Encoder};
3
4use crate::error::Error;
5use crate::message::Message;
6
/// Default limit (16 MiB) on a frame's payload length, used by `MessageCodec::new`.
const MAX_MESSAGE_SIZE: usize = 1 << 24;
9
/// Length-prefixed framing codec for [`Message`]s, implementing
/// `tokio_util::codec::{Decoder, Encoder}`.
///
/// On the decode side a frame is a 4-byte big-endian payload length followed
/// by that many payload bytes; frames advertising more than `max_size` payload
/// bytes are rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MessageCodec {
    /// Largest payload length, in bytes and excluding the 4-byte length header,
    /// that the decoder will accept.
    max_size: usize,
}
17
18impl MessageCodec {
19 pub fn new() -> Self {
21 MessageCodec {
22 max_size: MAX_MESSAGE_SIZE,
23 }
24 }
25
26 pub fn with_max_size(mut self, max: usize) -> Self {
28 self.max_size = max;
29 self
30 }
31}
32
33impl Default for MessageCodec {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl Decoder for MessageCodec {
40 type Item = Message<Bytes>;
41 type Error = Error;
42
43 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message<Bytes>>, Error> {
44 if src.len() < 4 {
46 return Ok(None);
47 }
48
49 let length = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
51
52 if length > self.max_size {
53 return Err(Error::MessageTooLarge {
54 size: length,
55 max: self.max_size,
56 });
57 }
58
59 let total = 4 + length;
61 if src.len() < total {
62 src.reserve(total - src.len());
64 return Ok(None);
65 }
66
67 src.advance(4);
69 let payload = src.split_to(length);
71
72 Message::from_payload(payload.freeze()).map(Some)
73 }
74}
75
76impl Encoder<Message<Bytes>> for MessageCodec {
77 type Error = Error;
78
79 fn encode(&mut self, msg: Message<Bytes>, dst: &mut BytesMut) -> Result<(), Error> {
80 msg.encode_into(dst);
81 Ok(())
82 }
83}
84
#[cfg(test)]
mod tests {
    //! Round-trip and framing tests for `MessageCodec`.

    use super::*;
    use bytes::Bytes;

    #[test]
    fn codec_decode_single_message() {
        let mut codec = MessageCodec::new();
        let msg = Message::Have { index: 42 };
        let wire = msg.to_bytes();

        let mut buf = BytesMut::from(wire.as_ref());
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_partial_then_complete() {
        let mut codec = MessageCodec::new();
        let msg = Message::Request {
            index: 1,
            begin: 0,
            length: 16384,
        };
        let wire = msg.to_bytes();

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&wire[..5]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // An incomplete frame must be left in the buffer untouched...
        assert_eq!(&buf[..], &wire[..5]);
        // ...and the decoder should have reserved room for the rest of it.
        assert!(buf.capacity() >= wire.len());

        buf.extend_from_slice(&wire[5..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_multiple_messages() {
        let mut codec = MessageCodec::new();
        let msg1 = Message::Choke;
        let msg2 = Message::Have { index: 7 };

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&msg1.to_bytes());
        buf.extend_from_slice(&msg2.to_bytes());

        let d1 = codec.decode(&mut buf).unwrap().unwrap();
        let d2 = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(d1, msg1);
        assert_eq!(d2, msg2);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_keepalive() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::from(&[0u8, 0, 0, 0][..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, Message::KeepAlive);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_reject_oversized() {
        let mut codec = MessageCodec::new().with_max_size(100);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&200u32.to_be_bytes());
        buf.extend_from_slice(&vec![0u8; 200]);

        // Pin the exact error: a bare `is_err()` would also pass if the size
        // check were missing and `from_payload` rejected the zero bytes instead.
        let result = codec.decode(&mut buf);
        assert!(
            matches!(result, Err(Error::MessageTooLarge { size: 200, max: 100 })),
            "unexpected result: {:?}",
            result
        );
    }

    #[test]
    fn codec_reject_oversized_from_header_alone() {
        let mut codec = MessageCodec::new().with_max_size(100);
        let mut buf = BytesMut::from(&200u32.to_be_bytes()[..]);

        let result = codec.decode(&mut buf);
        assert!(
            matches!(result, Err(Error::MessageTooLarge { size: 200, max: 100 })),
            "unexpected result: {:?}",
            result
        );
    }

    #[test]
    fn codec_max_size_boundary() {
        let msg = Message::Have { index: 3 };
        let wire = msg.to_bytes();
        let payload_len = wire.len() - 4;

        // A payload exactly at the limit is accepted.
        let mut at_limit = MessageCodec::new().with_max_size(payload_len);
        let mut buf = BytesMut::from(wire.as_ref());
        assert_eq!(at_limit.decode(&mut buf).unwrap().unwrap(), msg);

        // One byte over the limit is rejected.
        let mut below = MessageCodec::new().with_max_size(payload_len - 1);
        let mut buf = BytesMut::from(wire.as_ref());
        let result = below.decode(&mut buf);
        assert!(
            matches!(result, Err(Error::MessageTooLarge { .. })),
            "unexpected result: {:?}",
            result
        );
    }

    #[test]
    fn codec_encode() {
        let mut codec = MessageCodec::new();
        let msg = Message::Piece {
            index: 0,
            begin: 0,
            data_0: Bytes::from_static(b"data"),
            data_1: Bytes::new(),
        };

        let mut buf = BytesMut::new();
        codec.encode(msg.clone(), &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_insufficient_header() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::from(&[0u8, 0][..]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // Nothing may be consumed from a buffer too short to hold a header.
        assert_eq!(&buf[..], &[0u8, 0][..]);
    }
}