irontide_wire/
handshake.rs1use bytes::{Buf, BufMut, Bytes, BytesMut};
2
3use irontide_core::Id20;
4
5use crate::error::{Error, Result};
6
/// Protocol identifier string that follows the length byte at the start of
/// every handshake.
const PROTOCOL: &[u8] = b"BitTorrent protocol";

/// Exact number of bytes in a serialized handshake: one length-prefix byte,
/// the protocol string, 8 reserved bytes, and two 20-byte identifiers
/// (info hash and peer id).
pub const HANDSHAKE_SIZE: usize = 1 + PROTOCOL.len() + 8 + 20 + 20;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct Handshake {
23 pub reserved: [u8; 8],
25 pub info_hash: Id20,
27 pub peer_id: Id20,
29}
30
31impl Handshake {
32 #[must_use]
34 pub fn new(info_hash: Id20, peer_id: Id20) -> Self {
35 let mut reserved = [0u8; 8];
36 reserved[5] |= 0x10;
38 Self {
39 reserved,
40 info_hash,
41 peer_id,
42 }
43 }
44
45 #[must_use]
47 pub fn supports_extensions(&self) -> bool {
48 self.reserved[5] & 0x10 != 0
49 }
50
51 #[must_use]
53 pub fn supports_dht(&self) -> bool {
54 self.reserved[7] & 0x01 != 0
55 }
56
57 #[must_use]
59 pub fn with_dht(mut self) -> Self {
60 self.reserved[7] |= 0x01;
61 self
62 }
63
64 #[must_use]
66 pub fn supports_fast(&self) -> bool {
67 self.reserved[7] & 0x04 != 0
68 }
69
70 #[must_use]
72 pub fn with_fast(mut self) -> Self {
73 self.reserved[7] |= 0x04;
74 self
75 }
76
77 #[must_use]
79 pub fn to_bytes(&self) -> Bytes {
80 let mut buf = BytesMut::with_capacity(HANDSHAKE_SIZE);
81 buf.put_u8(19);
82 buf.put_slice(PROTOCOL);
83 buf.put_slice(&self.reserved);
84 buf.put_slice(self.info_hash.as_bytes());
85 buf.put_slice(self.peer_id.as_bytes());
86 buf.freeze()
87 }
88
89 pub fn from_bytes(mut data: &[u8]) -> Result<Self> {
95 if data.len() < HANDSHAKE_SIZE {
96 return Err(Error::InvalidHandshake(format!(
97 "need {} bytes, got {}",
98 HANDSHAKE_SIZE,
99 data.len()
100 )));
101 }
102
103 let pstrlen = data.get_u8();
104 if pstrlen != 19 {
105 return Err(Error::InvalidHandshake(format!(
106 "pstrlen {pstrlen}, expected 19"
107 )));
108 }
109
110 let pstr = &data[..19];
111 if pstr != PROTOCOL {
112 return Err(Error::InvalidHandshake("wrong protocol string".into()));
113 }
114 data.advance(19);
115
116 let mut reserved = [0u8; 8];
117 reserved.copy_from_slice(&data[..8]);
118 data.advance(8);
119
120 let info_hash =
121 Id20::from_bytes(&data[..20]).map_err(|e| Error::InvalidHandshake(e.to_string()))?;
122 data.advance(20);
123
124 let peer_id =
125 Id20::from_bytes(&data[..20]).map_err(|e| Error::InvalidHandshake(e.to_string()))?;
126
127 Ok(Self {
128 reserved,
129 info_hash,
130 peer_id,
131 })
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    // Serializing and re-parsing must reproduce the original handshake.
    #[test]
    fn handshake_round_trip() {
        let info_hash = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
        let peer_id = Id20::from_hex("0102030405060708091011121314151617181920").unwrap();

        let hs = Handshake::new(info_hash, peer_id);
        assert!(hs.supports_extensions());

        let bytes = hs.to_bytes();
        assert_eq!(bytes.len(), HANDSHAKE_SIZE);

        let parsed = Handshake::from_bytes(&bytes).unwrap();
        assert_eq!(hs, parsed);
    }

    // The DHT flag survives a round trip and does not disturb the extension flag.
    #[test]
    fn handshake_dht_flag() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_dht();
        assert!(hs.supports_dht());
        assert!(hs.supports_extensions());

        let parsed = Handshake::from_bytes(&hs.to_bytes()).unwrap();
        assert!(parsed.supports_dht());
    }

    // Pins the exact position of the extension bit inside the reserved field.
    #[test]
    fn ext_handshake_reserved_bit_position() {
        let mut expected_reserved = [0u8; 8];
        expected_reserved[5] = 0x10;

        let hs = Handshake::new(Id20::ZERO, Id20::ZERO);
        assert_eq!(
            hs.reserved, expected_reserved,
            "Handshake::new() reserved field must match BEP 10 spec"
        );
        assert_eq!(
            hs.reserved[5] & 0x10,
            0x10,
            "BEP 10 extension bit must be at reserved[5] & 0x10"
        );
        assert!(hs.supports_extensions());

        // No other bit may be set by `new()`.
        for (i, &byte) in hs.reserved.iter().enumerate() {
            if i == 5 {
                assert_eq!(byte, 0x10, "byte 5 should be exactly 0x10");
            } else {
                assert_eq!(byte, 0, "byte {i} should be zero when only BEP 10 is set");
            }
        }

        // Clearing the bit must flip the accessor.
        let mut hs_no_ext = hs;
        hs_no_ext.reserved[5] &= !0x10;
        assert!(!hs_no_ext.supports_extensions());

        // Arithmetic cross-check: counting bits from the least-significant end
        // of the 8-byte field, bit 20 falls in byte 5 with mask 0x10.
        let bit_index = 20;
        let byte_index_from_right = bit_index / 8;
        let byte_index = 7 - byte_index_from_right;
        let bit_within_byte = bit_index % 8;
        assert_eq!(byte_index, 5);
        assert_eq!(bit_within_byte, 4);
        assert_eq!(1u8 << bit_within_byte, 0x10);
    }

    // Inputs shorter than HANDSHAKE_SIZE are rejected, including the
    // one-byte-short boundary case.
    #[test]
    fn handshake_too_short() {
        assert!(Handshake::from_bytes(&[0u8; 10]).is_err());

        let bytes = Handshake::new(Id20::ZERO, Id20::ZERO).to_bytes();
        assert!(Handshake::from_bytes(&bytes[..HANDSHAKE_SIZE - 1]).is_err());
    }

    // A full-length input whose length prefix is not 19 is rejected.
    #[test]
    fn handshake_rejects_wrong_pstrlen() {
        let mut raw = Handshake::new(Id20::ZERO, Id20::ZERO).to_bytes().to_vec();
        raw[0] = 18;
        assert!(Handshake::from_bytes(&raw).is_err());
    }

    // A full-length input with a corrupted protocol string is rejected.
    #[test]
    fn handshake_rejects_wrong_protocol_string() {
        let mut raw = Handshake::new(Id20::ZERO, Id20::ZERO).to_bytes().to_vec();
        // Byte 1 is the first byte of the protocol string.
        raw[1] ^= 0xff;
        assert!(Handshake::from_bytes(&raw).is_err());
    }

    // The fast flag composes with the DHT flag and survives a round trip.
    #[test]
    fn handshake_fast_flag() {
        let hs = Handshake::new(Id20::ZERO, Id20::ZERO).with_fast();
        assert!(hs.supports_fast());
        assert!(hs.supports_extensions());
        let hs2 = hs.with_dht();
        assert!(hs2.supports_fast());
        assert!(hs2.supports_dht());

        let parsed = Handshake::from_bytes(&hs2.to_bytes()).unwrap();
        assert!(parsed.supports_fast());
        assert!(parsed.supports_dht());
    }
}