Skip to main content

irontide_wire/
codec.rs

1use bytes::{Buf, Bytes, BytesMut};
2use tokio_util::codec::{Decoder, Encoder};
3
4use crate::error::Error;
5use crate::message::Message;
6
/// Default cap on the payload length a peer may announce in a frame's length
/// prefix: 2^24 bytes (16 MiB). Frames announcing more are refused so that a
/// malicious peer cannot make us buffer arbitrarily large messages.
const MAX_MESSAGE_SIZE: usize = 1 << 24;
9
/// Tokio codec for length-prefixed BitTorrent messages.
///
/// Wire format: `<4-byte big-endian length><payload>`
/// where payload = `<message-id><message-body>` (or empty for keep-alive).
///
/// Construct it with [`MessageCodec::new`] (or `Default`) and adjust the size
/// limit with [`MessageCodec::with_max_size`].
#[derive(Debug, Clone)]
pub struct MessageCodec {
    /// Largest payload length that decoding accepts. It is compared with the value
    /// of the length prefix, so it does not include the 4 prefix bytes themselves.
    max_size: usize,
}
17
18impl MessageCodec {
19    /// Create a codec with the default maximum message size (16 MiB).
20    pub fn new() -> Self {
21        MessageCodec {
22            max_size: MAX_MESSAGE_SIZE,
23        }
24    }
25
26    /// Set the maximum allowed message size.
27    pub fn with_max_size(mut self, max: usize) -> Self {
28        self.max_size = max;
29        self
30    }
31}
32
33impl Default for MessageCodec {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl Decoder for MessageCodec {
40    type Item = Message<Bytes>;
41    type Error = Error;
42
43    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message<Bytes>>, Error> {
44        // Need at least 4 bytes for the length prefix
45        if src.len() < 4 {
46            return Ok(None);
47        }
48
49        // Peek at the length (don't advance yet)
50        let length = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
51
52        if length > self.max_size {
53            return Err(Error::MessageTooLarge {
54                size: length,
55                max: self.max_size,
56            });
57        }
58
59        // Check if we have the full message
60        let total = 4 + length;
61        if src.len() < total {
62            // Reserve space for the rest
63            src.reserve(total - src.len());
64            return Ok(None);
65        }
66
67        // Consume length prefix
68        src.advance(4);
69        // Take the payload
70        let payload = src.split_to(length);
71
72        Message::from_payload(payload.freeze()).map(Some)
73    }
74}
75
76impl Encoder<Message<Bytes>> for MessageCodec {
77    type Error = Error;
78
79    fn encode(&mut self, msg: Message<Bytes>, dst: &mut BytesMut) -> Result<(), Error> {
80        msg.encode_into(dst);
81        Ok(())
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Number of bytes in the big-endian length prefix of every frame.
    const PREFIX_LEN: usize = 4;

    #[test]
    fn codec_decode_single_message() {
        let mut codec = MessageCodec::new();
        let msg = Message::Have { index: 42 };
        let wire = msg.to_bytes();

        let mut buf = BytesMut::from(wire.as_ref());
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_partial_then_complete() {
        let mut codec = MessageCodec::new();
        let msg = Message::Request {
            index: 1,
            begin: 0,
            length: 16384,
        };
        let wire = msg.to_bytes();

        // Feed the length prefix plus one payload byte.
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&wire[..5]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // An incomplete frame must be left untouched for the next call.
        assert_eq!(buf.len(), 5);

        // Feed the rest.
        buf.extend_from_slice(&wire[5..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_decode_header_only_reserves_payload() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&100u32.to_be_bytes());

        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), PREFIX_LEN);
        assert!(buf.capacity() >= PREFIX_LEN + 100);
    }

    #[test]
    fn codec_decode_multiple_messages() {
        let mut codec = MessageCodec::new();
        let msg1 = Message::Choke;
        let msg2 = Message::Have { index: 7 };

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&msg1.to_bytes());
        buf.extend_from_slice(&msg2.to_bytes());

        let d1 = codec.decode(&mut buf).unwrap().unwrap();
        let d2 = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(d1, msg1);
        assert_eq!(d2, msg2);
        assert!(buf.is_empty());
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn codec_decode_keepalive() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::from(&[0u8, 0, 0, 0][..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, Message::KeepAlive);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_reject_oversized() {
        let mut codec = MessageCodec::new().with_max_size(100);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&200u32.to_be_bytes()); // length = 200 > max 100
        buf.extend_from_slice(&[0u8; 200]);

        let result = codec.decode(&mut buf);
        assert!(matches!(
            result,
            Err(Error::MessageTooLarge { size: 200, max: 100 })
        ));
    }

    #[test]
    fn codec_reject_oversized_before_payload_arrives() {
        // The limit is enforced from the length prefix alone.
        let mut codec = MessageCodec::new().with_max_size(100);
        let mut buf = BytesMut::from(&200u32.to_be_bytes()[..]);
        assert!(matches!(
            codec.decode(&mut buf),
            Err(Error::MessageTooLarge { size: 200, max: 100 })
        ));
    }

    #[test]
    fn codec_max_size_is_inclusive() {
        let msg = Message::Have { index: 3 };
        let wire = msg.to_bytes();
        let payload_len = wire.len() - PREFIX_LEN;

        // A payload exactly at the limit is accepted...
        let mut codec = MessageCodec::new().with_max_size(payload_len);
        let mut buf = BytesMut::from(wire.as_ref());
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap(), msg);

        // ...and one byte over the limit is rejected.
        let mut strict = MessageCodec::new().with_max_size(payload_len - 1);
        let mut buf = BytesMut::from(wire.as_ref());
        assert!(matches!(
            strict.decode(&mut buf),
            Err(Error::MessageTooLarge { .. })
        ));
    }

    #[test]
    fn codec_encode() {
        let mut codec = MessageCodec::new();
        let msg = Message::Piece {
            index: 0,
            begin: 0,
            data_0: Bytes::from_static(b"data"),
            data_1: Bytes::new(),
        };

        let mut buf = BytesMut::new();
        codec.encode(msg.clone(), &mut buf).unwrap();

        // Decode it back; the encoded frame must be consumed completely.
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
        assert!(buf.is_empty());
    }

    #[test]
    fn codec_insufficient_header() {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::from(&[0u8, 0][..]); // only 2 bytes
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 2);
    }
}