rbit/peer/
message.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2
3use super::error::PeerError;
4
/// Protocol identifier string carried in every handshake.
pub const PROTOCOL: &[u8] = b"BitTorrent protocol";
/// Total size of an encoded handshake in bytes: one length byte, the
/// protocol string, 8 reserved bytes, a 20-byte info hash and a 20-byte
/// peer ID.
pub const HANDSHAKE_LEN: usize = 1 + PROTOCOL.len() + 8 + 20 + 20;
9
/// One-byte message type tags used by the peer wire protocol.
///
/// Every message other than KeepAlive carries one of these IDs right after
/// its length prefix. The discriminants are the on-wire values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageId {
    /// The sender stops serving data to the peer.
    Choke = 0,
    /// The sender is ready to serve data to the peer.
    Unchoke = 1,
    /// The sender wants data from the peer.
    Interested = 2,
    /// The sender does not want data from the peer.
    NotInterested = 3,
    /// A single piece has just been acquired.
    Have = 4,
    /// The full set of available pieces.
    Bitfield = 5,
    /// Ask for a block of data.
    Request = 6,
    /// Deliver a block of piece data.
    Piece = 7,
    /// Withdraw a pending request.
    Cancel = 8,
    /// Announce the DHT port.
    Port = 9,

    // Fast extension (BEP-6)
    /// Recommend a piece to download.
    Suggest = 13,
    /// The peer holds every piece (seeder).
    HaveAll = 14,
    /// The peer holds no pieces.
    HaveNone = 15,
    /// Refuse a block request.
    Reject = 16,
    /// Permit downloading a piece while choked.
    AllowedFast = 17,

    // Extension protocol (BEP-10)
    /// Extension protocol message.
    Extended = 20,

    // BitTorrent v2 (BEP-52)
    /// Ask for merkle tree hashes.
    HashRequest = 21,
    /// Reply carrying merkle tree hashes.
    Hashes = 22,
    /// Refuse a hash request.
    HashReject = 23,
}
58
59impl TryFrom<u8> for MessageId {
60    type Error = PeerError;
61
62    fn try_from(value: u8) -> Result<Self, Self::Error> {
63        match value {
64            0 => Ok(MessageId::Choke),
65            1 => Ok(MessageId::Unchoke),
66            2 => Ok(MessageId::Interested),
67            3 => Ok(MessageId::NotInterested),
68            4 => Ok(MessageId::Have),
69            5 => Ok(MessageId::Bitfield),
70            6 => Ok(MessageId::Request),
71            7 => Ok(MessageId::Piece),
72            8 => Ok(MessageId::Cancel),
73            9 => Ok(MessageId::Port),
74            13 => Ok(MessageId::Suggest),
75            14 => Ok(MessageId::HaveAll),
76            15 => Ok(MessageId::HaveNone),
77            16 => Ok(MessageId::Reject),
78            17 => Ok(MessageId::AllowedFast),
79            20 => Ok(MessageId::Extended),
80            21 => Ok(MessageId::HashRequest),
81            22 => Ok(MessageId::Hashes),
82            23 => Ok(MessageId::HashReject),
83            _ => Err(PeerError::InvalidMessageId(value)),
84        }
85    }
86}
87
/// The handshake that opens every BitTorrent peer connection.
///
/// On the wire it consists of, in order:
/// - a length byte followed by the protocol identifier ("BitTorrent protocol")
/// - 8 reserved bytes used as capability flags
/// - the 20-byte info hash identifying the torrent
/// - the 20-byte peer ID identifying the client
///
/// # Reserved Bytes
///
/// Individual bits of the reserved bytes advertise protocol extensions:
/// - Byte 5, mask `0x10`: Extension protocol ([BEP-10])
/// - Byte 7, mask `0x01`: DHT ([BEP-5])
/// - Byte 7, mask `0x04`: Fast extension ([BEP-6])
/// - Byte 7, mask `0x10`: BitTorrent v2 ([BEP-52])
///
/// [BEP-5]: http://bittorrent.org/beps/bep_0005.html
/// [BEP-6]: http://bittorrent.org/beps/bep_0006.html
/// [BEP-10]: http://bittorrent.org/beps/bep_0010.html
/// [BEP-52]: http://bittorrent.org/beps/bep_0052.html
#[derive(Debug, Clone)]
pub struct Handshake {
    /// Info hash of the torrent being exchanged.
    pub info_hash: [u8; 20],
    /// Peer ID of the sending client.
    pub peer_id: [u8; 20],
    /// Capability flags for protocol extensions.
    pub reserved: [u8; 8],
}
115
116impl Handshake {
117    /// Creates a new handshake with extension protocol and fast extension enabled.
118    pub fn new(info_hash: [u8; 20], peer_id: [u8; 20]) -> Self {
119        let mut reserved = [0u8; 8];
120        reserved[5] |= 0x10; // Extension protocol (BEP-10)
121        reserved[7] |= 0x04; // Fast extension (BEP-6)
122        Self {
123            info_hash,
124            peer_id,
125            reserved,
126        }
127    }
128
129    /// Creates a new handshake with v2 support enabled.
130    ///
131    /// This sets the extension protocol, fast extension, and v2 capability bits.
132    pub fn new_v2(info_hash: [u8; 20], peer_id: [u8; 20]) -> Self {
133        let mut reserved = [0u8; 8];
134        reserved[5] |= 0x10; // Extension protocol (BEP-10)
135        reserved[7] |= 0x04; // Fast extension (BEP-6)
136        reserved[7] |= 0x10; // BitTorrent v2 (BEP-52)
137        Self {
138            info_hash,
139            peer_id,
140            reserved,
141        }
142    }
143
144    /// Returns `true` if the peer supports the extension protocol ([BEP-10]).
145    pub fn supports_extension_protocol(&self) -> bool {
146        (self.reserved[5] & 0x10) != 0
147    }
148
149    /// Returns `true` if the peer supports the fast extension ([BEP-6]).
150    pub fn supports_fast_extension(&self) -> bool {
151        (self.reserved[7] & 0x04) != 0
152    }
153
154    /// Returns `true` if the peer supports DHT ([BEP-5]).
155    pub fn supports_dht(&self) -> bool {
156        (self.reserved[7] & 0x01) != 0
157    }
158
159    /// Returns `true` if the peer supports BitTorrent v2 ([BEP-52]).
160    ///
161    /// The v2 capability is indicated by the 4th most significant bit
162    /// in the last byte of the reserved field (bit 4, 0x10).
163    pub fn supports_v2(&self) -> bool {
164        (self.reserved[7] & 0x10) != 0
165    }
166
167    /// Sets the v2 support bit in the reserved field.
168    pub fn set_v2_support(&mut self, enabled: bool) {
169        if enabled {
170            self.reserved[7] |= 0x10;
171        } else {
172            self.reserved[7] &= !0x10;
173        }
174    }
175
176    /// Encodes the handshake to bytes for transmission.
177    pub fn encode(&self) -> Bytes {
178        let mut buf = BytesMut::with_capacity(HANDSHAKE_LEN);
179        buf.put_u8(19);
180        buf.put_slice(PROTOCOL);
181        buf.put_slice(&self.reserved);
182        buf.put_slice(&self.info_hash);
183        buf.put_slice(&self.peer_id);
184        buf.freeze()
185    }
186
187    pub fn decode(data: &[u8]) -> Result<Self, PeerError> {
188        if data.len() < HANDSHAKE_LEN {
189            return Err(PeerError::InvalidHandshake);
190        }
191
192        if data[0] != 19 || &data[1..20] != PROTOCOL {
193            return Err(PeerError::InvalidHandshake);
194        }
195
196        let mut reserved = [0u8; 8];
197        reserved.copy_from_slice(&data[20..28]);
198
199        let mut info_hash = [0u8; 20];
200        info_hash.copy_from_slice(&data[28..48]);
201
202        let mut peer_id = [0u8; 20];
203        peer_id.copy_from_slice(&data[48..68]);
204
205        Ok(Self {
206            info_hash,
207            peer_id,
208            reserved,
209        })
210    }
211}
212
213/// A peer wire protocol message.
214///
215/// Messages are length-prefixed: a 4-byte big-endian length followed by
216/// a 1-byte message ID (except KeepAlive which has length 0) and payload.
217///
218/// # Examples
219///
220/// ```
221/// use rbit::peer::Message;
222///
223/// // Create a request for piece 0, offset 0, 16KB
224/// let request = Message::Request {
225///     index: 0,
226///     begin: 0,
227///     length: 16384,
228/// };
229///
230/// // Encode to bytes
231/// let bytes = request.encode();
232/// assert_eq!(bytes.len(), 17); // 4-byte length + 1-byte ID + 12-byte payload
233/// ```
234#[derive(Debug, Clone)]
235pub enum Message {
236    /// Empty message to keep the connection alive.
237    KeepAlive,
238    /// We are choking the peer (not sending data).
239    Choke,
240    /// We are unchoking the peer (ready to send data).
241    Unchoke,
242    /// We are interested in the peer's data.
243    Interested,
244    /// We are not interested in the peer's data.
245    NotInterested,
246    /// Announce that we have a piece.
247    Have { piece: u32 },
248    /// Bitfield of all pieces we have.
249    Bitfield(Bytes),
250    /// Request a block of data.
251    Request { index: u32, begin: u32, length: u32 },
252    /// Send piece data.
253    Piece { index: u32, begin: u32, data: Bytes },
254    /// Cancel a pending request.
255    Cancel { index: u32, begin: u32, length: u32 },
256    /// DHT port announcement.
257    Port(u16),
258    // Fast extension
259    /// Suggest a piece to download (fast extension).
260    Suggest { piece: u32 },
261    /// Peer has all pieces (fast extension, seeder shortcut).
262    HaveAll,
263    /// Peer has no pieces (fast extension).
264    HaveNone,
265    /// Reject a block request (fast extension).
266    Reject { index: u32, begin: u32, length: u32 },
267    /// Allow downloading this piece while choked (fast extension).
268    AllowedFast { piece: u32 },
269    // Extension protocol
270    /// Extension protocol message ([BEP-10]).
271    Extended { id: u8, payload: Bytes },
272    // BitTorrent v2 (BEP-52)
273    /// Request merkle tree hashes from a peer.
274    ///
275    /// Used to request hash blocks from a file's merkle tree.
276    HashRequest {
277        /// The merkle root of the file (32 bytes).
278        pieces_root: [u8; 32],
279        /// The tree layer to request (0 = leaf layer).
280        base_layer: u32,
281        /// Starting index in the layer (must be multiple of length).
282        index: u32,
283        /// Number of hashes to request (must be power of 2, >= 2, <= 512).
284        length: u32,
285        /// Number of ancestor layers to include as proof.
286        proof_layers: u32,
287    },
288    /// Response containing merkle tree hashes.
289    ///
290    /// Contains the requested hashes plus uncle hashes for verification.
291    Hashes {
292        /// The merkle root of the file (32 bytes).
293        pieces_root: [u8; 32],
294        /// The tree layer (0 = leaf layer).
295        base_layer: u32,
296        /// Starting index in the layer.
297        index: u32,
298        /// Number of hashes in the base layer.
299        length: u32,
300        /// Number of proof layers included.
301        proof_layers: u32,
302        /// Concatenated 32-byte hashes (length + proof hashes).
303        hashes: Bytes,
304    },
305    /// Reject a hash request.
306    ///
307    /// Sent when a peer cannot or will not service a hash request.
308    HashReject {
309        /// The merkle root of the file (32 bytes).
310        pieces_root: [u8; 32],
311        /// The tree layer requested.
312        base_layer: u32,
313        /// Starting index requested.
314        index: u32,
315        /// Number of hashes requested.
316        length: u32,
317        /// Number of proof layers requested.
318        proof_layers: u32,
319    },
320}
321
322impl Message {
323    /// Encodes the message to bytes for transmission.
324    ///
325    /// The output includes the 4-byte length prefix.
326    pub fn encode(&self) -> Bytes {
327        let mut buf = BytesMut::new();
328
329        match self {
330            Message::KeepAlive => {
331                buf.put_u32(0);
332            }
333            Message::Choke => {
334                buf.put_u32(1);
335                buf.put_u8(MessageId::Choke as u8);
336            }
337            Message::Unchoke => {
338                buf.put_u32(1);
339                buf.put_u8(MessageId::Unchoke as u8);
340            }
341            Message::Interested => {
342                buf.put_u32(1);
343                buf.put_u8(MessageId::Interested as u8);
344            }
345            Message::NotInterested => {
346                buf.put_u32(1);
347                buf.put_u8(MessageId::NotInterested as u8);
348            }
349            Message::Have { piece } => {
350                buf.put_u32(5);
351                buf.put_u8(MessageId::Have as u8);
352                buf.put_u32(*piece);
353            }
354            Message::Bitfield(bits) => {
355                buf.put_u32(1 + bits.len() as u32);
356                buf.put_u8(MessageId::Bitfield as u8);
357                buf.put_slice(bits);
358            }
359            Message::Request {
360                index,
361                begin,
362                length,
363            } => {
364                buf.put_u32(13);
365                buf.put_u8(MessageId::Request as u8);
366                buf.put_u32(*index);
367                buf.put_u32(*begin);
368                buf.put_u32(*length);
369            }
370            Message::Piece { index, begin, data } => {
371                buf.put_u32(9 + data.len() as u32);
372                buf.put_u8(MessageId::Piece as u8);
373                buf.put_u32(*index);
374                buf.put_u32(*begin);
375                buf.put_slice(data);
376            }
377            Message::Cancel {
378                index,
379                begin,
380                length,
381            } => {
382                buf.put_u32(13);
383                buf.put_u8(MessageId::Cancel as u8);
384                buf.put_u32(*index);
385                buf.put_u32(*begin);
386                buf.put_u32(*length);
387            }
388            Message::Port(port) => {
389                buf.put_u32(3);
390                buf.put_u8(MessageId::Port as u8);
391                buf.put_u16(*port);
392            }
393            Message::Suggest { piece } => {
394                buf.put_u32(5);
395                buf.put_u8(MessageId::Suggest as u8);
396                buf.put_u32(*piece);
397            }
398            Message::HaveAll => {
399                buf.put_u32(1);
400                buf.put_u8(MessageId::HaveAll as u8);
401            }
402            Message::HaveNone => {
403                buf.put_u32(1);
404                buf.put_u8(MessageId::HaveNone as u8);
405            }
406            Message::Reject {
407                index,
408                begin,
409                length,
410            } => {
411                buf.put_u32(13);
412                buf.put_u8(MessageId::Reject as u8);
413                buf.put_u32(*index);
414                buf.put_u32(*begin);
415                buf.put_u32(*length);
416            }
417            Message::AllowedFast { piece } => {
418                buf.put_u32(5);
419                buf.put_u8(MessageId::AllowedFast as u8);
420                buf.put_u32(*piece);
421            }
422            Message::Extended { id, payload } => {
423                buf.put_u32(2 + payload.len() as u32);
424                buf.put_u8(MessageId::Extended as u8);
425                buf.put_u8(*id);
426                buf.put_slice(payload);
427            }
428            // BitTorrent v2 messages (BEP-52)
429            // HashRequest: 1 byte msg_id + 32 bytes root + 4*4 bytes fields = 49 bytes
430            Message::HashRequest {
431                pieces_root,
432                base_layer,
433                index,
434                length,
435                proof_layers,
436            } => {
437                buf.put_u32(49);
438                buf.put_u8(MessageId::HashRequest as u8);
439                buf.put_slice(pieces_root);
440                buf.put_u32(*base_layer);
441                buf.put_u32(*index);
442                buf.put_u32(*length);
443                buf.put_u32(*proof_layers);
444            }
445            // Hashes: 1 byte msg_id + 32 bytes root + 4*4 bytes fields + hashes
446            Message::Hashes {
447                pieces_root,
448                base_layer,
449                index,
450                length,
451                proof_layers,
452                hashes,
453            } => {
454                buf.put_u32(49 + hashes.len() as u32);
455                buf.put_u8(MessageId::Hashes as u8);
456                buf.put_slice(pieces_root);
457                buf.put_u32(*base_layer);
458                buf.put_u32(*index);
459                buf.put_u32(*length);
460                buf.put_u32(*proof_layers);
461                buf.put_slice(hashes);
462            }
463            // HashReject: same as HashRequest (49 bytes)
464            Message::HashReject {
465                pieces_root,
466                base_layer,
467                index,
468                length,
469                proof_layers,
470            } => {
471                buf.put_u32(49);
472                buf.put_u8(MessageId::HashReject as u8);
473                buf.put_slice(pieces_root);
474                buf.put_u32(*base_layer);
475                buf.put_u32(*index);
476                buf.put_u32(*length);
477                buf.put_u32(*proof_layers);
478            }
479        }
480
481        buf.freeze()
482    }
483
484    pub fn decode(mut data: Bytes) -> Result<Self, PeerError> {
485        if data.len() < 4 {
486            return Err(PeerError::InvalidMessage("too short".into()));
487        }
488
489        let length = data.get_u32() as usize;
490
491        if length == 0 {
492            return Ok(Message::KeepAlive);
493        }
494
495        if data.remaining() < length {
496            return Err(PeerError::InvalidMessage("incomplete message".into()));
497        }
498
499        let id = MessageId::try_from(data.get_u8())?;
500
501        match id {
502            MessageId::Choke => Ok(Message::Choke),
503            MessageId::Unchoke => Ok(Message::Unchoke),
504            MessageId::Interested => Ok(Message::Interested),
505            MessageId::NotInterested => Ok(Message::NotInterested),
506            MessageId::Have => {
507                if data.remaining() < 4 {
508                    return Err(PeerError::InvalidMessage("have too short".into()));
509                }
510                Ok(Message::Have {
511                    piece: data.get_u32(),
512                })
513            }
514            MessageId::Bitfield => Ok(Message::Bitfield(data.copy_to_bytes(length - 1))),
515            MessageId::Request => {
516                if data.remaining() < 12 {
517                    return Err(PeerError::InvalidMessage("request too short".into()));
518                }
519                Ok(Message::Request {
520                    index: data.get_u32(),
521                    begin: data.get_u32(),
522                    length: data.get_u32(),
523                })
524            }
525            MessageId::Piece => {
526                if data.remaining() < 8 {
527                    return Err(PeerError::InvalidMessage("piece too short".into()));
528                }
529                let index = data.get_u32();
530                let begin = data.get_u32();
531                let block_data = data.copy_to_bytes(length - 9);
532                Ok(Message::Piece {
533                    index,
534                    begin,
535                    data: block_data,
536                })
537            }
538            MessageId::Cancel => {
539                if data.remaining() < 12 {
540                    return Err(PeerError::InvalidMessage("cancel too short".into()));
541                }
542                Ok(Message::Cancel {
543                    index: data.get_u32(),
544                    begin: data.get_u32(),
545                    length: data.get_u32(),
546                })
547            }
548            MessageId::Port => {
549                if data.remaining() < 2 {
550                    return Err(PeerError::InvalidMessage("port too short".into()));
551                }
552                Ok(Message::Port(data.get_u16()))
553            }
554            MessageId::Suggest => {
555                if data.remaining() < 4 {
556                    return Err(PeerError::InvalidMessage("suggest too short".into()));
557                }
558                Ok(Message::Suggest {
559                    piece: data.get_u32(),
560                })
561            }
562            MessageId::HaveAll => Ok(Message::HaveAll),
563            MessageId::HaveNone => Ok(Message::HaveNone),
564            MessageId::Reject => {
565                if data.remaining() < 12 {
566                    return Err(PeerError::InvalidMessage("reject too short".into()));
567                }
568                Ok(Message::Reject {
569                    index: data.get_u32(),
570                    begin: data.get_u32(),
571                    length: data.get_u32(),
572                })
573            }
574            MessageId::AllowedFast => {
575                if data.remaining() < 4 {
576                    return Err(PeerError::InvalidMessage("allowed fast too short".into()));
577                }
578                Ok(Message::AllowedFast {
579                    piece: data.get_u32(),
580                })
581            }
582            MessageId::Extended => {
583                if data.remaining() < 1 {
584                    return Err(PeerError::InvalidMessage("extended too short".into()));
585                }
586                let ext_id = data.get_u8();
587                let payload = data.copy_to_bytes(length - 2);
588                Ok(Message::Extended {
589                    id: ext_id,
590                    payload,
591                })
592            }
593            // BitTorrent v2 messages (BEP-52)
594            MessageId::HashRequest => {
595                // 32 bytes root + 4*4 bytes = 48 bytes payload
596                if data.remaining() < 48 {
597                    return Err(PeerError::InvalidMessage("hash request too short".into()));
598                }
599                let mut pieces_root = [0u8; 32];
600                pieces_root.copy_from_slice(&data.copy_to_bytes(32));
601                Ok(Message::HashRequest {
602                    pieces_root,
603                    base_layer: data.get_u32(),
604                    index: data.get_u32(),
605                    length: data.get_u32(),
606                    proof_layers: data.get_u32(),
607                })
608            }
609            MessageId::Hashes => {
610                // 32 bytes root + 4*4 bytes = 48 bytes header, rest is hashes
611                if data.remaining() < 48 {
612                    return Err(PeerError::InvalidMessage("hashes too short".into()));
613                }
614                let mut pieces_root = [0u8; 32];
615                pieces_root.copy_from_slice(&data.copy_to_bytes(32));
616                let base_layer = data.get_u32();
617                let index = data.get_u32();
618                let hash_length = data.get_u32();
619                let proof_layers = data.get_u32();
620                // Remaining bytes are the concatenated hashes
621                let hashes_len = length - 49; // length - 1 (msg_id) - 48 (header)
622                if data.remaining() < hashes_len {
623                    return Err(PeerError::InvalidMessage("hashes data too short".into()));
624                }
625                let hashes = data.copy_to_bytes(hashes_len);
626                // Validate hash data length is multiple of 32
627                if hashes.len() % 32 != 0 {
628                    return Err(PeerError::InvalidMessage(
629                        "hashes not multiple of 32 bytes".into(),
630                    ));
631                }
632                Ok(Message::Hashes {
633                    pieces_root,
634                    base_layer,
635                    index,
636                    length: hash_length,
637                    proof_layers,
638                    hashes,
639                })
640            }
641            MessageId::HashReject => {
642                // 32 bytes root + 4*4 bytes = 48 bytes payload
643                if data.remaining() < 48 {
644                    return Err(PeerError::InvalidMessage("hash reject too short".into()));
645                }
646                let mut pieces_root = [0u8; 32];
647                pieces_root.copy_from_slice(&data.copy_to_bytes(32));
648                Ok(Message::HashReject {
649                    pieces_root,
650                    base_layer: data.get_u32(),
651                    index: data.get_u32(),
652                    length: data.get_u32(),
653                    proof_layers: data.get_u32(),
654                })
655            }
656        }
657    }
658}
659
/// Checks the `length` and `index` of a HashRequest against BEP-52 rules.
///
/// Checks run in this order and the first failure wins:
/// 1. `length` is at least 2
/// 2. `length` is a power of 2
/// 3. `length` is at most 512 (a soft limit that is enforced here)
/// 4. `index` is a multiple of `length`
///
/// Returns `Some(message)` describing the first violated rule, or `None`
/// when the request is valid.
pub fn validate_hash_request(length: u32, index: u32) -> Option<&'static str> {
    match length {
        0 | 1 => Some("length must be >= 2"),
        n if !n.is_power_of_two() => Some("length must be power of 2"),
        n if n > 512 => Some("length exceeds 512"),
        // `n` is non-zero here, so the remainder cannot divide by zero.
        n if index % n != 0 => Some("index must be multiple of length"),
        _ => None,
    }
}