// irontide_wire/handshake.rs

use bytes::{Buf, BufMut, Bytes, BytesMut};

use irontide_core::Id20;

use crate::error::{Error, Result};

/// `BitTorrent` protocol string (`pstr`) that opens every handshake.
const PROTOCOL: &[u8] = b"BitTorrent protocol";

/// Total handshake size: 1 (pstrlen) + 19 (pstr) + 8 (reserved) + 20 (info hash)
/// + 20 (peer id) = 68 bytes.
///
/// Derived from [`PROTOCOL`] so the two can never drift apart.
pub const HANDSHAKE_SIZE: usize = 1 + PROTOCOL.len() + 8 + 20 + 20;

// Compile-time guard: the wire format fixes the handshake at exactly 68 bytes.
const _: () = assert!(HANDSHAKE_SIZE == 68);

13/// Peer wire handshake (BEP 3).
14///
15/// Format: `<pstrlen><pstr><reserved><info_hash><peer_id>`
16/// - pstrlen: 1 byte (19)
17/// - pstr: 19 bytes ("`BitTorrent` protocol")
18/// - reserved: 8 bytes (extension flags)
19/// - `info_hash`: 20 bytes
20/// - `peer_id`: 20 bytes
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct Handshake {
23    /// 8-byte reserved field for extension flags.
24    pub reserved: [u8; 8],
25    /// Info hash of the torrent.
26    pub info_hash: Id20,
27    /// Peer ID.
28    pub peer_id: Id20,
29}
30
31impl Handshake {
32    /// Create a new handshake with extension protocol support (BEP 10).
33    #[must_use]
34    pub fn new(info_hash: Id20, peer_id: Id20) -> Self {
35        let mut reserved = [0u8; 8];
36        // Bit 20 (byte 5, bit 4) = Extension Protocol (BEP 10)
37        reserved[5] |= 0x10;
38        Self {
39            reserved,
40            info_hash,
41            peer_id,
42        }
43    }
44
45    /// Check if the peer supports the Extension Protocol (BEP 10).
46    #[must_use]
47    pub fn supports_extensions(&self) -> bool {
48        self.reserved[5] & 0x10 != 0
49    }
50
51    /// Check if the peer supports DHT (BEP 5).
52    #[must_use]
53    pub fn supports_dht(&self) -> bool {
54        self.reserved[7] & 0x01 != 0
55    }
56
57    /// Enable DHT support flag.
58    #[must_use]
59    pub fn with_dht(mut self) -> Self {
60        self.reserved[7] |= 0x01;
61        self
62    }
63
64    /// Check if the peer supports Fast Extension (BEP 6).
65    #[must_use]
66    pub fn supports_fast(&self) -> bool {
67        self.reserved[7] & 0x04 != 0
68    }
69
70    /// Enable Fast Extension flag.
71    #[must_use]
72    pub fn with_fast(mut self) -> Self {
73        self.reserved[7] |= 0x04;
74        self
75    }
76
77    /// Serialize to bytes.
78    #[must_use]
79    pub fn to_bytes(&self) -> Bytes {
80        let mut buf = BytesMut::with_capacity(HANDSHAKE_SIZE);
81        buf.put_u8(19);
82        buf.put_slice(PROTOCOL);
83        buf.put_slice(&self.reserved);
84        buf.put_slice(self.info_hash.as_bytes());
85        buf.put_slice(self.peer_id.as_bytes());
86        buf.freeze()
87    }
88
89    /// Parse from bytes. Input must be exactly 68 bytes.
90    ///
91    /// # Errors
92    ///
93    /// Returns an error if the data is too short or has an invalid protocol string.
94    pub fn from_bytes(mut data: &[u8]) -> Result<Self> {
95        if data.len() < HANDSHAKE_SIZE {
96            return Err(Error::InvalidHandshake(format!(
97                "need {} bytes, got {}",
98                HANDSHAKE_SIZE,
99                data.len()
100            )));
101        }
102
103        let pstrlen = data.get_u8();
104        if pstrlen != 19 {
105            return Err(Error::InvalidHandshake(format!(
106                "pstrlen {pstrlen}, expected 19"
107            )));
108        }
109
110        let pstr = &data[..19];
111        if pstr != PROTOCOL {
112            return Err(Error::InvalidHandshake("wrong protocol string".into()));
113        }
114        data.advance(19);
115
116        let mut reserved = [0u8; 8];
117        reserved.copy_from_slice(&data[..8]);
118        data.advance(8);
119
120        let info_hash =
121            Id20::from_bytes(&data[..20]).map_err(|e| Error::InvalidHandshake(e.to_string()))?;
122        data.advance(20);
123
124        let peer_id =
125            Id20::from_bytes(&data[..20]).map_err(|e| Error::InvalidHandshake(e.to_string()))?;
126
127        Ok(Self {
128            reserved,
129            info_hash,
130            peer_id,
131        })
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing and re-parsing a handshake must yield an identical value.
    #[test]
    fn handshake_round_trip() {
        let info_hash = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
        let peer_id = Id20::from_hex("0102030405060708091011121314151617181920").unwrap();

        let hs = Handshake::new(info_hash, peer_id);
        assert!(hs.supports_extensions());

        let bytes = hs.to_bytes();
        assert_eq!(bytes.len(), HANDSHAKE_SIZE);

        let parsed = Handshake::from_bytes(&bytes).unwrap();
        assert_eq!(hs, parsed);
    }

    /// The DHT flag survives a serialize/parse round trip.
    #[test]
    fn handshake_dht_flag() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_dht();
        assert!(hs.supports_dht());
        assert!(hs.supports_extensions());

        let parsed = Handshake::from_bytes(&hs.to_bytes()).unwrap();
        assert!(parsed.supports_dht());
    }

    /// BEP 10: extension protocol is signalled by bit 20 from the right in the
    /// 8-byte reserved field. Bit 20 maps to byte index 5, bit 4 (0x10).
    #[test]
    fn ext_handshake_reserved_bit_position() {
        // Construct expected reserved bytes with ONLY bit 20 set
        let mut expected_reserved = [0u8; 8];
        // Bit 20 from the right: byte index = 20 / 8 = 2 from the right = index 5,
        // bit position = 20 % 8 = 4, so mask = 0x10.
        expected_reserved[5] = 0x10;

        // Verify the Handshake::new() constructor sets exactly this bit
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO);
        assert_eq!(
            hs.reserved, expected_reserved,
            "Handshake::new() reserved field must match BEP 10 spec"
        );
        assert_eq!(
            hs.reserved[5] & 0x10,
            0x10,
            "BEP 10 extension bit must be at reserved[5] & 0x10"
        );
        assert!(hs.supports_extensions());

        // Verify no other reserved bytes are set by default (only BEP 10)
        for (i, &byte) in hs.reserved.iter().enumerate() {
            if i == 5 {
                assert_eq!(byte, 0x10, "byte 5 should be exactly 0x10");
            } else {
                assert_eq!(byte, 0, "byte {i} should be zero when only BEP 10 is set");
            }
        }

        // Verify that clearing byte 5 bit 4 disables extension support
        let mut hs_no_ext = hs;
        hs_no_ext.reserved[5] &= !0x10;
        assert!(!hs_no_ext.supports_extensions());

        // Verify bit 20 interpretation: counting from bit 0 (rightmost of reserved[7])
        // bit 20 = byte 5 (from left, 0-indexed), bit 4 within that byte
        let bit_index = 20;
        let byte_index_from_right = bit_index / 8; // = 2
        let byte_index = 7 - byte_index_from_right; // = 5
        let bit_within_byte = bit_index % 8; // = 4
        assert_eq!(byte_index, 5);
        assert_eq!(bit_within_byte, 4);
        assert_eq!(1u8 << bit_within_byte, 0x10);
    }

    /// Input shorter than `HANDSHAKE_SIZE` is rejected.
    #[test]
    fn handshake_too_short() {
        assert!(Handshake::from_bytes(&[0u8; 10]).is_err());
    }

    /// A full-length buffer whose `pstrlen` byte is not 19 is rejected.
    #[test]
    fn handshake_bad_pstrlen() {
        let mut raw = Handshake::new(Id20::ZERO, Id20::ZERO).to_bytes().to_vec();
        raw[0] = 18;
        assert!(Handshake::from_bytes(&raw).is_err());
    }

    /// A full-length buffer with a corrupted protocol string is rejected.
    #[test]
    fn handshake_wrong_protocol_string() {
        let mut raw = Handshake::new(Id20::ZERO, Id20::ZERO).to_bytes().to_vec();
        raw[1] ^= 0xff;
        assert!(Handshake::from_bytes(&raw).is_err());
    }

    /// Bytes following the 68-byte handshake are ignored by the parser.
    #[test]
    fn handshake_ignores_trailing_bytes() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_fast();
        let mut raw = hs.to_bytes().to_vec();
        raw.extend_from_slice(&[0xAA; 4]);

        let parsed = Handshake::from_bytes(&raw).unwrap();
        assert_eq!(hs, parsed);
    }

    /// The Fast Extension flag is independent of the DHT flag and survives a
    /// serialize/parse round trip.
    #[test]
    fn handshake_fast_flag() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_fast();
        assert!(hs.supports_fast());
        assert!(hs.supports_extensions());
        // Ensure DHT and fast don't interfere
        let hs2 = hs.with_dht();
        assert!(hs2.supports_fast());
        assert!(hs2.supports_dht());

        let parsed = Handshake::from_bytes(&hs2.to_bytes()).unwrap();
        assert!(parsed.supports_fast());
        assert!(parsed.supports_dht());
    }
}