1use bytes::{Buf, Bytes, BytesMut};
2use tokio_util::codec::{Decoder, Encoder};
3
4use crate::error::Error;
5use crate::message::Message;
6
/// Default upper bound, in bytes, on the payload length the codec will decode
/// (16 MiB; the 4-byte length prefix is not counted).
const MAX_MESSAGE_SIZE: usize = 1 << 24;
9
/// Length-prefixed framing codec for `Message` values.
///
/// Each frame decoded by this codec is a 4-byte big-endian payload length
/// followed by that many payload bytes. Frames announcing a payload longer
/// than the configured limit are rejected.
#[derive(Debug, Clone)]
pub struct MessageCodec {
    /// Largest payload length, in bytes, that the decoder accepts.
    /// The 4-byte length prefix is not included in this count.
    max_size: usize,
}
17
18impl MessageCodec {
19 #[must_use]
21 pub fn new() -> Self {
22 Self {
23 max_size: MAX_MESSAGE_SIZE,
24 }
25 }
26
27 #[must_use]
29 pub fn with_max_size(mut self, max: usize) -> Self {
30 self.max_size = max;
31 self
32 }
33}
34
35impl Default for MessageCodec {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl Decoder for MessageCodec {
42 type Item = Message<Bytes>;
43 type Error = Error;
44
45 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message<Bytes>>, Error> {
46 if src.len() < 4 {
48 return Ok(None);
49 }
50
51 let length = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
53
54 if length > self.max_size {
55 return Err(Error::MessageTooLarge {
56 size: length,
57 max: self.max_size,
58 });
59 }
60
61 let total = 4 + length;
63 if src.len() < total {
64 src.reserve(total - src.len());
66 return Ok(None);
67 }
68
69 src.advance(4);
71 let payload = src.split_to(length);
73
74 Message::from_payload(payload.freeze()).map(Some)
75 }
76}
77
78impl Encoder<Message<Bytes>> for MessageCodec {
79 type Error = Error;
80
81 fn encode(&mut self, msg: Message<Bytes>, dst: &mut BytesMut) -> Result<(), Error> {
82 msg.encode_into(dst);
83 Ok(())
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    #[test]
    fn codec_decode_single_message() {
        let mut codec = MessageCodec::new();
        let msg = Message::Have { index: 42 };
        let wire = msg.to_bytes();

        let mut buf = BytesMut::from(wire.as_ref());
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_partial_then_complete() {
        let mut codec = MessageCodec::new();
        let msg = Message::Request {
            index: 1,
            begin: 0,
            length: 16384,
        };
        let wire = msg.to_bytes();

        // The length prefix plus a single payload byte is not yet a full frame.
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&wire[..5]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(&wire[5..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_multiple_messages() {
        let mut codec = MessageCodec::new();
        let msg1 = Message::Choke;
        let msg2 = Message::Have { index: 7 };

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&msg1.to_bytes());
        buf.extend_from_slice(&msg2.to_bytes());

        let d1 = codec.decode(&mut buf).unwrap().unwrap();
        let d2 = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(d1, msg1);
        assert_eq!(d2, msg2);

        // Both frames consumed; nothing left to decode.
        assert!(buf.is_empty());
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn codec_decode_keepalive() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::from(&[0u8, 0, 0, 0][..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, Message::KeepAlive);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_reject_oversized() {
        let mut codec = MessageCodec::new().with_max_size(100);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&200u32.to_be_bytes());
        buf.extend_from_slice(&[0u8; 200]);

        let result = codec.decode(&mut buf);
        assert!(matches!(
            result,
            Err(Error::MessageTooLarge { size: 200, max: 100 })
        ));
    }

    #[test]
    fn codec_reject_oversized_from_header_alone() {
        // The limit is enforced before the announced body has arrived.
        let mut codec = MessageCodec::new().with_max_size(100);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&101u32.to_be_bytes());

        let result = codec.decode(&mut buf);
        assert!(matches!(
            result,
            Err(Error::MessageTooLarge { size: 101, max: 100 })
        ));
    }

    #[test]
    fn codec_accepts_payload_at_exact_limit() {
        let msg = Message::Have { index: 3 };
        let wire = msg.to_bytes();
        // The limit counts payload bytes only, so subtract the 4-byte prefix.
        let mut codec = MessageCodec::new().with_max_size(wire.len() - 4);

        let mut buf = BytesMut::from(wire.as_ref());
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_waits_for_announced_body_and_reserves_space() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&10u32.to_be_bytes());
        buf.extend_from_slice(&[0u8; 3]);

        assert!(codec.decode(&mut buf).unwrap().is_none());
        // Nothing is consumed until the full frame is present...
        assert_eq!(buf.len(), 7);
        // ...but room for the whole 4 + 10 byte frame has been reserved.
        assert!(buf.capacity() >= 14);
    }

    #[test]
    fn codec_encode() {
        let mut codec = MessageCodec::new();
        let msg = Message::Piece {
            index: 0,
            begin: 0,
            data_0: Bytes::from_static(b"data"),
            data_1: Bytes::new(),
        };

        let mut buf = BytesMut::new();
        codec.encode(msg.clone(), &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_insufficient_header() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::from(&[0u8, 0][..]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // A partial prefix must be left untouched for the next read.
        assert_eq!(buf.len(), 2);
    }
}