irontide_wire/
handshake.rs1use bytes::{Buf, BufMut, Bytes, BytesMut};
2
3use irontide_core::Id20;
4
5use crate::error::{Error, Result};
6
/// Protocol identifier string carried at the start of every handshake,
/// immediately after its one-byte length prefix.
const PROTOCOL: &[u8] = b"BitTorrent protocol";

/// Total length in bytes of a serialized handshake:
/// 1 length byte + the protocol string + 8 reserved bytes
/// + 20-byte info hash + 20-byte peer id.
pub const HANDSHAKE_SIZE: usize = 1 + PROTOCOL.len() + 8 + 20 + 20;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct Handshake {
23 pub reserved: [u8; 8],
25 pub info_hash: Id20,
27 pub peer_id: Id20,
29}
30
31impl Handshake {
32 pub fn new(info_hash: Id20, peer_id: Id20) -> Self {
34 let mut reserved = [0u8; 8];
35 reserved[5] |= 0x10;
37 Handshake {
38 reserved,
39 info_hash,
40 peer_id,
41 }
42 }
43
44 pub fn supports_extensions(&self) -> bool {
46 self.reserved[5] & 0x10 != 0
47 }
48
49 pub fn supports_dht(&self) -> bool {
51 self.reserved[7] & 0x01 != 0
52 }
53
54 pub fn with_dht(mut self) -> Self {
56 self.reserved[7] |= 0x01;
57 self
58 }
59
60 pub fn supports_fast(&self) -> bool {
62 self.reserved[7] & 0x04 != 0
63 }
64
65 pub fn with_fast(mut self) -> Self {
67 self.reserved[7] |= 0x04;
68 self
69 }
70
71 pub fn to_bytes(&self) -> Bytes {
73 let mut buf = BytesMut::with_capacity(HANDSHAKE_SIZE);
74 buf.put_u8(19);
75 buf.put_slice(PROTOCOL);
76 buf.put_slice(&self.reserved);
77 buf.put_slice(self.info_hash.as_bytes());
78 buf.put_slice(self.peer_id.as_bytes());
79 buf.freeze()
80 }
81
82 pub fn from_bytes(mut data: &[u8]) -> Result<Self> {
84 if data.len() < HANDSHAKE_SIZE {
85 return Err(Error::InvalidHandshake(format!(
86 "need {} bytes, got {}",
87 HANDSHAKE_SIZE,
88 data.len()
89 )));
90 }
91
92 let pstrlen = data.get_u8();
93 if pstrlen != 19 {
94 return Err(Error::InvalidHandshake(format!(
95 "pstrlen {pstrlen}, expected 19"
96 )));
97 }
98
99 let pstr = &data[..19];
100 if pstr != PROTOCOL {
101 return Err(Error::InvalidHandshake("wrong protocol string".into()));
102 }
103 data.advance(19);
104
105 let mut reserved = [0u8; 8];
106 reserved.copy_from_slice(&data[..8]);
107 data.advance(8);
108
109 let info_hash =
110 Id20::from_bytes(&data[..20]).map_err(|e| Error::InvalidHandshake(e.to_string()))?;
111 data.advance(20);
112
113 let peer_id =
114 Id20::from_bytes(&data[..20]).map_err(|e| Error::InvalidHandshake(e.to_string()))?;
115
116 Ok(Handshake {
117 reserved,
118 info_hash,
119 peer_id,
120 })
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing and re-parsing yields an equal handshake of the fixed wire size.
    #[test]
    fn handshake_round_trip() {
        let original = Handshake::new(
            Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap(),
            Id20::from_hex("0102030405060708091011121314151617181920").unwrap(),
        );
        assert!(original.supports_extensions());

        let wire = original.to_bytes();
        assert_eq!(wire.len(), HANDSHAKE_SIZE);

        let decoded = Handshake::from_bytes(&wire).unwrap();
        assert_eq!(original, decoded);
    }

    /// The DHT flag is set by `with_dht` and survives a round trip.
    #[test]
    fn handshake_dht_flag() {
        let with_dht = Handshake::new(Id20::ZERO, Id20::ZERO).with_dht();
        assert!(with_dht.supports_dht());
        assert!(with_dht.supports_extensions());

        let decoded = Handshake::from_bytes(&with_dht.to_bytes()).unwrap();
        assert!(decoded.supports_dht());
    }

    /// Pins the extension bit to `reserved[5] & 0x10` and nothing else.
    #[test]
    fn ext_handshake_reserved_bit_position() {
        let expected_reserved: [u8; 8] = [0, 0, 0, 0, 0, 0x10, 0, 0];

        let hs = Handshake::new(Id20::ZERO, Id20::ZERO);
        assert_eq!(
            hs.reserved, expected_reserved,
            "Handshake::new() reserved field must match BEP 10 spec"
        );
        assert_eq!(
            hs.reserved[5] & 0x10,
            0x10,
            "BEP 10 extension bit must be at reserved[5] & 0x10"
        );
        assert!(hs.supports_extensions());

        // Every byte other than index 5 must be untouched.
        for (i, &byte) in hs.reserved.iter().enumerate() {
            match i {
                5 => assert_eq!(byte, 0x10, "byte 5 should be exactly 0x10"),
                _ => assert_eq!(byte, 0, "byte {i} should be zero when only BEP 10 is set"),
            }
        }

        // Clearing the bit must flip the capability query.
        let mut cleared = hs.clone();
        cleared.reserved[5] &= !0x10;
        assert!(!cleared.supports_extensions());

        // Bit 20 counted from the right of the 64-bit reserved field
        // lands in byte 5 (from the left), bit 4.
        let bit_index = 20;
        let byte_index_from_right = bit_index / 8;
        let bit_within_byte = bit_index % 8;
        let byte_index = 7 - byte_index_from_right;
        assert_eq!(byte_index, 5);
        assert_eq!(bit_within_byte, 4);
        assert_eq!(1u8 << bit_within_byte, 0x10);
    }

    /// Input shorter than a full handshake is rejected.
    #[test]
    fn handshake_too_short() {
        let truncated = [0u8; 10];
        assert!(Handshake::from_bytes(&truncated).is_err());
    }

    /// The fast flag coexists with the DHT flag and survives a round trip.
    #[test]
    fn handshake_fast_flag() {
        let fast = Handshake::new(Id20::ZERO, Id20::ZERO).with_fast();
        assert!(fast.supports_fast());
        assert!(fast.supports_extensions());

        let fast_and_dht = fast.with_dht();
        assert!(fast_and_dht.supports_fast());
        assert!(fast_and_dht.supports_dht());

        let decoded = Handshake::from_bytes(&fast_and_dht.to_bytes()).unwrap();
        assert!(decoded.supports_fast());
        assert!(decoded.supports_dht());
    }
}