Skip to main content

irontide_wire/
handshake.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2
3use irontide_core::Id20;
4
5use crate::error::{Error, Result};
6
/// BitTorrent protocol string (`pstr`) that opens every handshake.
const PROTOCOL: &[u8] = b"BitTorrent protocol";

/// Total handshake size in bytes (68).
///
/// Derived from the field layout rather than hard-coded so it stays in sync with
/// `PROTOCOL`: `pstrlen` (1), `pstr` (19), reserved (8), info_hash (20) and
/// peer_id (20).
pub const HANDSHAKE_SIZE: usize = 1 + PROTOCOL.len() + 8 + 20 + 20;
12
13/// Peer wire handshake (BEP 3).
14///
15/// Format: `<pstrlen><pstr><reserved><info_hash><peer_id>`
16/// - pstrlen: 1 byte (19)
17/// - pstr: 19 bytes ("BitTorrent protocol")
18/// - reserved: 8 bytes (extension flags)
19/// - info_hash: 20 bytes
20/// - peer_id: 20 bytes
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct Handshake {
23    /// 8-byte reserved field for extension flags.
24    pub reserved: [u8; 8],
25    /// Info hash of the torrent.
26    pub info_hash: Id20,
27    /// Peer ID.
28    pub peer_id: Id20,
29}
30
31impl Handshake {
32    /// Create a new handshake with extension protocol support (BEP 10).
33    pub fn new(info_hash: Id20, peer_id: Id20) -> Self {
34        let mut reserved = [0u8; 8];
35        // Bit 20 (byte 5, bit 4) = Extension Protocol (BEP 10)
36        reserved[5] |= 0x10;
37        Handshake {
38            reserved,
39            info_hash,
40            peer_id,
41        }
42    }
43
44    /// Check if the peer supports the Extension Protocol (BEP 10).
45    pub fn supports_extensions(&self) -> bool {
46        self.reserved[5] & 0x10 != 0
47    }
48
49    /// Check if the peer supports DHT (BEP 5).
50    pub fn supports_dht(&self) -> bool {
51        self.reserved[7] & 0x01 != 0
52    }
53
54    /// Enable DHT support flag.
55    pub fn with_dht(mut self) -> Self {
56        self.reserved[7] |= 0x01;
57        self
58    }
59
60    /// Check if the peer supports Fast Extension (BEP 6).
61    pub fn supports_fast(&self) -> bool {
62        self.reserved[7] & 0x04 != 0
63    }
64
65    /// Enable Fast Extension flag.
66    pub fn with_fast(mut self) -> Self {
67        self.reserved[7] |= 0x04;
68        self
69    }
70
71    /// Serialize to bytes.
72    pub fn to_bytes(&self) -> Bytes {
73        let mut buf = BytesMut::with_capacity(HANDSHAKE_SIZE);
74        buf.put_u8(19);
75        buf.put_slice(PROTOCOL);
76        buf.put_slice(&self.reserved);
77        buf.put_slice(self.info_hash.as_bytes());
78        buf.put_slice(self.peer_id.as_bytes());
79        buf.freeze()
80    }
81
82    /// Parse from bytes. Input must be exactly 68 bytes.
83    pub fn from_bytes(mut data: &[u8]) -> Result<Self> {
84        if data.len() < HANDSHAKE_SIZE {
85            return Err(Error::InvalidHandshake(format!(
86                "need {} bytes, got {}",
87                HANDSHAKE_SIZE,
88                data.len()
89            )));
90        }
91
92        let pstrlen = data.get_u8();
93        if pstrlen != 19 {
94            return Err(Error::InvalidHandshake(format!(
95                "pstrlen {pstrlen}, expected 19"
96            )));
97        }
98
99        let pstr = &data[..19];
100        if pstr != PROTOCOL {
101            return Err(Error::InvalidHandshake("wrong protocol string".into()));
102        }
103        data.advance(19);
104
105        let mut reserved = [0u8; 8];
106        reserved.copy_from_slice(&data[..8]);
107        data.advance(8);
108
109        let info_hash =
110            Id20::from_bytes(&data[..20]).map_err(|e| Error::InvalidHandshake(e.to_string()))?;
111        data.advance(20);
112
113        let peer_id =
114            Id20::from_bytes(&data[..20]).map_err(|e| Error::InvalidHandshake(e.to_string()))?;
115
116        Ok(Handshake {
117            reserved,
118            info_hash,
119            peer_id,
120        })
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize a default handshake into an owned, mutable byte vector so
    /// individual wire bytes can be corrupted.
    fn wire_bytes() -> Vec<u8> {
        Handshake::new(Id20::ZERO, Id20::ZERO).to_bytes().to_vec()
    }

    #[test]
    fn handshake_round_trip() {
        let info_hash = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
        let peer_id = Id20::from_hex("0102030405060708091011121314151617181920").unwrap();

        let hs = Handshake::new(info_hash, peer_id);
        assert!(hs.supports_extensions());

        let bytes = hs.to_bytes();
        assert_eq!(bytes.len(), HANDSHAKE_SIZE);

        let parsed = Handshake::from_bytes(&bytes).unwrap();
        assert_eq!(hs, parsed);
    }

    #[test]
    fn handshake_dht_flag() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_dht();
        assert!(hs.supports_dht());
        assert!(hs.supports_extensions());

        let parsed = Handshake::from_bytes(&hs.to_bytes()).unwrap();
        assert!(parsed.supports_dht());
    }

    /// BEP 10: extension protocol is signalled by bit 20 from the right in the
    /// 8-byte reserved field. Bit 20 maps to byte index 5, bit 4 (0x10).
    #[test]
    fn ext_handshake_reserved_bit_position() {
        // Construct expected reserved bytes with ONLY bit 20 set
        let mut expected_reserved = [0u8; 8];
        // Bit 20 from the right: byte index = 20 / 8 = 2 from the right = index 5,
        // bit position = 20 % 8 = 4, so mask = 0x10.
        expected_reserved[5] = 0x10;

        // Verify the Handshake::new() constructor sets exactly this bit
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO);
        assert_eq!(
            hs.reserved, expected_reserved,
            "Handshake::new() reserved field must match BEP 10 spec"
        );
        assert_eq!(
            hs.reserved[5] & 0x10,
            0x10,
            "BEP 10 extension bit must be at reserved[5] & 0x10"
        );
        assert!(hs.supports_extensions());

        // Verify no other reserved bytes are set by default (only BEP 10)
        for (i, &byte) in hs.reserved.iter().enumerate() {
            if i == 5 {
                assert_eq!(byte, 0x10, "byte 5 should be exactly 0x10");
            } else {
                assert_eq!(byte, 0, "byte {i} should be zero when only BEP 10 is set");
            }
        }

        // Verify that clearing byte 5 bit 4 disables extension support
        let mut hs_no_ext = hs.clone();
        hs_no_ext.reserved[5] &= !0x10;
        assert!(!hs_no_ext.supports_extensions());

        // Verify bit 20 interpretation: counting from bit 0 (rightmost of reserved[7])
        // bit 20 = byte 5 (from left, 0-indexed), bit 4 within that byte
        let bit_index = 20;
        let byte_index_from_right = bit_index / 8; // = 2
        let byte_index = 7 - byte_index_from_right; // = 5
        let bit_within_byte = bit_index % 8; // = 4
        assert_eq!(byte_index, 5);
        assert_eq!(bit_within_byte, 4);
        assert_eq!(1u8 << bit_within_byte, 0x10);
    }

    #[test]
    fn handshake_too_short() {
        assert!(Handshake::from_bytes(&[0u8; 10]).is_err());
    }

    /// One byte below the minimum must be rejected, not read out of bounds.
    #[test]
    fn handshake_one_byte_short() {
        let bytes = wire_bytes();
        assert!(matches!(
            Handshake::from_bytes(&bytes[..HANDSHAKE_SIZE - 1]),
            Err(Error::InvalidHandshake(_))
        ));
    }

    #[test]
    fn handshake_rejects_wrong_pstrlen() {
        let mut bytes = wire_bytes();
        bytes[0] = 18;
        assert!(matches!(
            Handshake::from_bytes(&bytes),
            Err(Error::InvalidHandshake(_))
        ));
    }

    #[test]
    fn handshake_rejects_wrong_protocol_string() {
        let mut bytes = wire_bytes();
        // Lower-case the leading 'B' of "BitTorrent protocol".
        bytes[1] = b'b';
        assert!(matches!(
            Handshake::from_bytes(&bytes),
            Err(Error::InvalidHandshake(_))
        ));
    }

    /// Only the first HANDSHAKE_SIZE bytes are parsed; trailing data is ignored.
    #[test]
    fn handshake_ignores_trailing_bytes() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_fast();
        let mut bytes = hs.to_bytes().to_vec();
        bytes.extend_from_slice(&[0xff; 4]);
        assert_eq!(Handshake::from_bytes(&bytes).unwrap(), hs);
    }

    #[test]
    fn handshake_fast_flag() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_fast();
        assert!(hs.supports_fast());
        assert!(hs.supports_extensions());
        // Ensure DHT and fast don't interfere
        let hs2 = hs.with_dht();
        assert!(hs2.supports_fast());
        assert!(hs2.supports_dht());

        let parsed = Handshake::from_bytes(&hs2.to_bytes()).unwrap();
        assert!(parsed.supports_fast());
        assert!(parsed.supports_dht());
    }
}