Skip to main content

irontide_wire/
codec.rs

1use bytes::{Buf, Bytes, BytesMut};
2use tokio_util::codec::{Decoder, Encoder};
3
4use crate::error::Error;
5use crate::message::Message;
6
/// Default upper bound, in bytes, on a single message payload (16 MiB).
///
/// Any frame announcing a longer payload is rejected before it is buffered,
/// so a malicious peer cannot make us allocate arbitrarily large frames.
const MAX_MESSAGE_SIZE: usize = 16 << 20;
9
/// Tokio codec for length-prefixed `BitTorrent` messages.
///
/// Wire format: `<4-byte big-endian length><payload>`
/// where payload = `<message-id><message-body>` (or empty for keep-alive).
#[derive(Debug, Clone)]
pub struct MessageCodec {
    /// Largest payload length (in bytes, not counting the 4-byte length
    /// prefix) that the decoder will accept.
    max_size: usize,
}
17
18impl MessageCodec {
19    /// Create a codec with the default maximum message size (16 MiB).
20    #[must_use]
21    pub fn new() -> Self {
22        Self {
23            max_size: MAX_MESSAGE_SIZE,
24        }
25    }
26
27    /// Set the maximum allowed message size.
28    #[must_use]
29    pub fn with_max_size(mut self, max: usize) -> Self {
30        self.max_size = max;
31        self
32    }
33}
34
35impl Default for MessageCodec {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl Decoder for MessageCodec {
42    type Item = Message<Bytes>;
43    type Error = Error;
44
45    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message<Bytes>>, Error> {
46        // Need at least 4 bytes for the length prefix
47        if src.len() < 4 {
48            return Ok(None);
49        }
50
51        // Peek at the length (don't advance yet)
52        let length = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
53
54        if length > self.max_size {
55            return Err(Error::MessageTooLarge {
56                size: length,
57                max: self.max_size,
58            });
59        }
60
61        // Check if we have the full message
62        let total = 4 + length;
63        if src.len() < total {
64            // Reserve space for the rest
65            src.reserve(total - src.len());
66            return Ok(None);
67        }
68
69        // Consume length prefix
70        src.advance(4);
71        // Take the payload
72        let payload = src.split_to(length);
73
74        Message::from_payload(payload.freeze()).map(Some)
75    }
76}
77
78impl Encoder<Message<Bytes>> for MessageCodec {
79    type Error = Error;
80
81    fn encode(&mut self, msg: Message<Bytes>, dst: &mut BytesMut) -> Result<(), Error> {
82        msg.encode_into(dst);
83        Ok(())
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    #[test]
    fn codec_decode_single_message() {
        let mut codec = MessageCodec::new();
        let msg = Message::Have { index: 42 };
        let wire = msg.to_bytes();

        let mut buf = BytesMut::from(wire.as_ref());
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_partial_then_complete() {
        let mut codec = MessageCodec::new();
        let msg = Message::Request {
            index: 1,
            begin: 0,
            length: 16384,
        };
        let wire = msg.to_bytes();

        // Feed partial data first: length prefix + part of the payload.
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&wire[..5]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // An incomplete frame must not be consumed.
        assert_eq!(buf.len(), 5);

        // Feed the rest
        buf.extend_from_slice(&wire[5..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_multiple_messages() {
        let mut codec = MessageCodec::new();
        let msg1 = Message::Choke;
        let msg2 = Message::Have { index: 7 };

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&msg1.to_bytes());
        buf.extend_from_slice(&msg2.to_bytes());

        let d1 = codec.decode(&mut buf).unwrap().unwrap();
        let d2 = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(d1, msg1);
        assert_eq!(d2, msg2);

        // Both frames fully drained; nothing left to decode.
        assert!(buf.is_empty());
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn codec_decode_keepalive() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::from(&[0u8, 0, 0, 0][..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, Message::KeepAlive);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_reject_oversized() {
        let mut codec = MessageCodec::new().with_max_size(100);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&200u32.to_be_bytes()); // length = 200 > max 100
        buf.extend_from_slice(&[0u8; 200]);

        let result = codec.decode(&mut buf);
        assert!(matches!(
            result,
            Err(Error::MessageTooLarge {
                size: 200,
                max: 100
            })
        ));
    }

    #[test]
    fn codec_accepts_length_equal_to_max() {
        let msg = Message::Have { index: 3 };
        let wire = msg.to_bytes();

        // The limit is inclusive: a payload exactly `max_size` bytes long passes.
        let mut codec = MessageCodec::new().with_max_size(wire.len() - 4);
        let mut buf = BytesMut::from(wire.as_ref());
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn codec_header_only_reserves_without_consuming() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&10u32.to_be_bytes()); // announces a 10-byte payload
        buf.extend_from_slice(&[0u8; 3]); // but only 3 bytes have arrived

        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 7);
        assert!(buf.capacity() >= 4 + 10);
    }

    #[test]
    fn codec_encode() {
        let mut codec = MessageCodec::new();
        let msg = Message::Piece {
            index: 0,
            begin: 0,
            data_0: Bytes::from_static(b"data"),
            data_1: Bytes::new(),
        };

        let mut buf = BytesMut::new();
        codec.encode(msg.clone(), &mut buf).unwrap();

        // Decode it back
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_insufficient_header() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::from(&[0u8, 0][..]); // only 2 bytes
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 2);
    }
}