Skip to main content

irontide_wire/
message.rs

1use bytes::{BufMut, Bytes, BytesMut};
2
3use crate::error::{Error, Result};
4
5/// Standard BitTorrent peer wire messages (BEP 3).
6///
7/// Generic over buffer type `B` to support both owned (`Bytes`) and borrowed
8/// (`&[u8]`) payloads.  Data-carrying variants (`Piece`, `Bitfield`,
9/// `Extended`) use `B`; fixed-field variants are buffer-agnostic.
10///
11/// The `Piece` variant carries two buffer fields (`data_0`, `data_1`) to
12/// support ring-buffer wrap-around: when a block spans the ring boundary the
13/// two non-contiguous slices are stored separately. In the common (no-wrap)
14/// case `data_1` is empty.
15#[derive(Debug, Clone, PartialEq, Eq)]
16pub enum Message<B = Bytes> {
17    /// Keep connection alive (no payload, length=0).
18    KeepAlive,
19    /// Peer is choking us.
20    Choke,
21    /// Peer is unchoking us.
22    Unchoke,
23    /// We're interested in the peer's data.
24    Interested,
25    /// We're not interested.
26    NotInterested,
27    /// Peer has piece `index`.
28    Have {
29        /// Piece index.
30        index: u32,
31    },
32    /// Peer's complete bitfield.
33    Bitfield(B),
34    /// Request a block: piece index, byte offset within piece, length.
35    Request {
36        /// Piece index.
37        index: u32,
38        /// Byte offset within the piece.
39        begin: u32,
40        /// Requested block length in bytes.
41        length: u32,
42    },
43    /// A data block: piece index, byte offset, data.
44    ///
45    /// Two buffer fields support ring-buffer wrap-around.  When the block does
46    /// not straddle the ring boundary, `data_1` is empty.
47    Piece {
48        /// Piece index.
49        index: u32,
50        /// Byte offset within the piece.
51        begin: u32,
52        /// First (or only) contiguous slice of block payload.
53        data_0: B,
54        /// Second contiguous slice when the ring buffer wraps; empty otherwise.
55        data_1: B,
56    },
57    /// Cancel a previously sent request.
58    Cancel {
59        /// Piece index.
60        index: u32,
61        /// Byte offset within the piece.
62        begin: u32,
63        /// Block length in bytes.
64        length: u32,
65    },
66    /// DHT port (BEP 5).
67    Port(u16),
68    /// Extension message (BEP 10). ext_id=0 is handshake.
69    Extended {
70        /// Extension message ID (0 = handshake).
71        ext_id: u8,
72        /// Bencoded extension payload.
73        payload: B,
74    },
75    /// BEP 6: Suggest a piece for the peer to download.
76    SuggestPiece(u32),
77    /// BEP 6: We have all pieces.
78    HaveAll,
79    /// BEP 6: We have no pieces.
80    HaveNone,
81    /// BEP 6: Reject a request from the peer.
82    RejectRequest {
83        /// Piece index.
84        index: u32,
85        /// Byte offset within the piece.
86        begin: u32,
87        /// Block length in bytes.
88        length: u32,
89    },
90    /// BEP 6: Piece index the peer is allowed to request while choked.
91    AllowedFast(u32),
92    /// BEP 52: Request hashes from a file's Merkle tree.
93    HashRequest {
94        /// File root hash identifying the Merkle tree.
95        pieces_root: irontide_core::Id32,
96        /// Tree layer (0 = leaf/block layer).
97        base: u32,
98        /// Starting node index within the layer.
99        index: u32,
100        /// Number of consecutive hashes requested.
101        count: u32,
102        /// Number of uncle proof layers to include.
103        proof_layers: u32,
104    },
105    /// BEP 52: Response with hashes and uncle proof.
106    Hashes {
107        /// File root hash identifying the Merkle tree.
108        pieces_root: irontide_core::Id32,
109        /// Tree layer (0 = leaf/block layer).
110        base: u32,
111        /// Starting node index within the layer.
112        index: u32,
113        /// Number of consecutive hashes in the response.
114        count: u32,
115        /// Number of uncle proof layers included.
116        proof_layers: u32,
117        /// Hash values followed by uncle proof hashes.
118        hashes: Vec<irontide_core::Id32>,
119    },
120    /// BEP 52: Reject a hash request.
121    HashReject {
122        /// File root hash identifying the Merkle tree.
123        pieces_root: irontide_core::Id32,
124        /// Tree layer that was requested.
125        base: u32,
126        /// Starting node index that was requested.
127        index: u32,
128        /// Number of hashes that was requested.
129        count: u32,
130        /// Number of proof layers that was requested.
131        proof_layers: u32,
132    },
133}
134
// Core peer wire message IDs (BEP 3).
const ID_CHOKE: u8 = 0;
const ID_UNCHOKE: u8 = 1;
const ID_INTERESTED: u8 = 2;
const ID_NOT_INTERESTED: u8 = 3;
const ID_HAVE: u8 = 4;
const ID_BITFIELD: u8 = 5;
const ID_REQUEST: u8 = 6;
const ID_PIECE: u8 = 7;
const ID_CANCEL: u8 = 8;

// DHT listen-port message (BEP 5).
const ID_PORT: u8 = 9;

// Fast Extension (BEP 6), IDs 0x0D..=0x11.
const ID_SUGGEST_PIECE: u8 = 13;
const ID_HAVE_ALL: u8 = 14;
const ID_HAVE_NONE: u8 = 15;
const ID_REJECT_REQUEST: u8 = 16;
const ID_ALLOWED_FAST: u8 = 17;

// Extension protocol (BEP 10).
const ID_EXTENDED: u8 = 20;

// Hash-transfer messages (BEP 52).
const ID_HASH_REQUEST: u8 = 21;
const ID_HASHES: u8 = 22;
const ID_HASH_REJECT: u8 = 23;
159
160impl<B: AsRef<[u8]>> Message<B> {
161    /// Serialize a message to bytes (length-prefix + id + payload).
162    ///
163    /// The returned bytes include the 4-byte length prefix.
164    pub fn to_bytes(&self) -> Bytes {
165        match self {
166            Message::KeepAlive => {
167                let mut buf = BytesMut::with_capacity(4);
168                buf.put_u32(0);
169                buf.freeze()
170            }
171            Message::Choke => fixed_msg(ID_CHOKE),
172            Message::Unchoke => fixed_msg(ID_UNCHOKE),
173            Message::Interested => fixed_msg(ID_INTERESTED),
174            Message::NotInterested => fixed_msg(ID_NOT_INTERESTED),
175            Message::Have { index } => {
176                let mut buf = BytesMut::with_capacity(9);
177                buf.put_u32(5);
178                buf.put_u8(ID_HAVE);
179                buf.put_u32(*index);
180                buf.freeze()
181            }
182            Message::Bitfield(bits) => {
183                let bits = bits.as_ref();
184                let mut buf = BytesMut::with_capacity(5 + bits.len());
185                buf.put_u32(1 + bits.len() as u32);
186                buf.put_u8(ID_BITFIELD);
187                buf.put_slice(bits);
188                buf.freeze()
189            }
190            Message::Request {
191                index,
192                begin,
193                length,
194            } => triple_msg(ID_REQUEST, *index, *begin, *length),
195            Message::Piece {
196                index,
197                begin,
198                data_0,
199                data_1,
200            } => {
201                let d0 = data_0.as_ref();
202                let d1 = data_1.as_ref();
203                let data_len = d0.len() + d1.len();
204                let mut buf = BytesMut::with_capacity(13 + data_len);
205                buf.put_u32(9 + data_len as u32);
206                buf.put_u8(ID_PIECE);
207                buf.put_u32(*index);
208                buf.put_u32(*begin);
209                buf.put_slice(d0);
210                buf.put_slice(d1);
211                buf.freeze()
212            }
213            Message::Cancel {
214                index,
215                begin,
216                length,
217            } => triple_msg(ID_CANCEL, *index, *begin, *length),
218            Message::Port(port) => {
219                let mut buf = BytesMut::with_capacity(7);
220                buf.put_u32(3);
221                buf.put_u8(ID_PORT);
222                buf.put_u16(*port);
223                buf.freeze()
224            }
225            Message::Extended { ext_id, payload } => {
226                let payload = payload.as_ref();
227                let mut buf = BytesMut::with_capacity(6 + payload.len());
228                buf.put_u32(2 + payload.len() as u32);
229                buf.put_u8(ID_EXTENDED);
230                buf.put_u8(*ext_id);
231                buf.put_slice(payload);
232                buf.freeze()
233            }
234            Message::SuggestPiece(index) => {
235                let mut buf = BytesMut::with_capacity(9);
236                buf.put_u32(5);
237                buf.put_u8(ID_SUGGEST_PIECE);
238                buf.put_u32(*index);
239                buf.freeze()
240            }
241            Message::HaveAll => fixed_msg(ID_HAVE_ALL),
242            Message::HaveNone => fixed_msg(ID_HAVE_NONE),
243            Message::RejectRequest {
244                index,
245                begin,
246                length,
247            } => triple_msg(ID_REJECT_REQUEST, *index, *begin, *length),
248            Message::AllowedFast(index) => {
249                let mut buf = BytesMut::with_capacity(9);
250                buf.put_u32(5);
251                buf.put_u8(ID_ALLOWED_FAST);
252                buf.put_u32(*index);
253                buf.freeze()
254            }
255            Message::HashRequest {
256                pieces_root,
257                base,
258                index,
259                count,
260                proof_layers,
261            }
262            | Message::HashReject {
263                pieces_root,
264                base,
265                index,
266                count,
267                proof_layers,
268            } => {
269                let id = match self {
270                    Message::HashRequest { .. } => ID_HASH_REQUEST,
271                    _ => ID_HASH_REJECT,
272                };
273                let mut buf = BytesMut::with_capacity(53);
274                buf.put_u32(49); // 1 + 32 + 4*4
275                buf.put_u8(id);
276                buf.put_slice(&pieces_root.0);
277                buf.put_u32(*base);
278                buf.put_u32(*index);
279                buf.put_u32(*count);
280                buf.put_u32(*proof_layers);
281                buf.freeze()
282            }
283            Message::Hashes {
284                pieces_root,
285                base,
286                index,
287                count,
288                proof_layers,
289                hashes,
290            } => {
291                let hash_bytes = hashes.len() * 32;
292                let payload_len = 1 + 32 + 16 + hash_bytes;
293                let mut buf = BytesMut::with_capacity(4 + payload_len);
294                buf.put_u32(payload_len as u32);
295                buf.put_u8(ID_HASHES);
296                buf.put_slice(&pieces_root.0);
297                buf.put_u32(*base);
298                buf.put_u32(*index);
299                buf.put_u32(*count);
300                buf.put_u32(*proof_layers);
301                for h in hashes {
302                    buf.put_slice(&h.0);
303                }
304                buf.freeze()
305            }
306        }
307    }
308
309    /// Encode this message (with length prefix) directly into a buffer.
310    ///
311    /// Unlike [`to_bytes`](Self::to_bytes), this writes directly into `dst`
312    /// without allocating an intermediate `Bytes`, avoiding a double-copy
313    /// when used with `tokio_util::codec::Encoder`.
314    pub fn encode_into(&self, dst: &mut BytesMut) {
315        match self {
316            Message::KeepAlive => {
317                dst.put_u32(0);
318            }
319            Message::Choke => encode_fixed_into(dst, ID_CHOKE),
320            Message::Unchoke => encode_fixed_into(dst, ID_UNCHOKE),
321            Message::Interested => encode_fixed_into(dst, ID_INTERESTED),
322            Message::NotInterested => encode_fixed_into(dst, ID_NOT_INTERESTED),
323            Message::Have { index } => {
324                dst.put_u32(5);
325                dst.put_u8(ID_HAVE);
326                dst.put_u32(*index);
327            }
328            Message::Bitfield(bits) => {
329                let bits = bits.as_ref();
330                dst.reserve(5 + bits.len());
331                dst.put_u32(1 + bits.len() as u32);
332                dst.put_u8(ID_BITFIELD);
333                dst.put_slice(bits);
334            }
335            Message::Request {
336                index,
337                begin,
338                length,
339            } => encode_triple_into(dst, ID_REQUEST, *index, *begin, *length),
340            Message::Piece {
341                index,
342                begin,
343                data_0,
344                data_1,
345            } => {
346                let d0 = data_0.as_ref();
347                let d1 = data_1.as_ref();
348                let data_len = d0.len() + d1.len();
349                dst.reserve(13 + data_len);
350                dst.put_u32(9 + data_len as u32);
351                dst.put_u8(ID_PIECE);
352                dst.put_u32(*index);
353                dst.put_u32(*begin);
354                dst.put_slice(d0);
355                dst.put_slice(d1);
356            }
357            Message::Cancel {
358                index,
359                begin,
360                length,
361            } => encode_triple_into(dst, ID_CANCEL, *index, *begin, *length),
362            Message::Port(port) => {
363                dst.put_u32(3);
364                dst.put_u8(ID_PORT);
365                dst.put_u16(*port);
366            }
367            Message::Extended { ext_id, payload } => {
368                let payload = payload.as_ref();
369                dst.reserve(6 + payload.len());
370                dst.put_u32(2 + payload.len() as u32);
371                dst.put_u8(ID_EXTENDED);
372                dst.put_u8(*ext_id);
373                dst.put_slice(payload);
374            }
375            Message::SuggestPiece(index) => {
376                dst.put_u32(5);
377                dst.put_u8(ID_SUGGEST_PIECE);
378                dst.put_u32(*index);
379            }
380            Message::HaveAll => encode_fixed_into(dst, ID_HAVE_ALL),
381            Message::HaveNone => encode_fixed_into(dst, ID_HAVE_NONE),
382            Message::RejectRequest {
383                index,
384                begin,
385                length,
386            } => encode_triple_into(dst, ID_REJECT_REQUEST, *index, *begin, *length),
387            Message::AllowedFast(index) => {
388                dst.put_u32(5);
389                dst.put_u8(ID_ALLOWED_FAST);
390                dst.put_u32(*index);
391            }
392            Message::HashRequest {
393                pieces_root,
394                base,
395                index,
396                count,
397                proof_layers,
398            }
399            | Message::HashReject {
400                pieces_root,
401                base,
402                index,
403                count,
404                proof_layers,
405            } => {
406                let id = match self {
407                    Message::HashRequest { .. } => ID_HASH_REQUEST,
408                    _ => ID_HASH_REJECT,
409                };
410                dst.put_u32(49); // 1 + 32 + 4*4
411                dst.put_u8(id);
412                dst.put_slice(&pieces_root.0);
413                dst.put_u32(*base);
414                dst.put_u32(*index);
415                dst.put_u32(*count);
416                dst.put_u32(*proof_layers);
417            }
418            Message::Hashes {
419                pieces_root,
420                base,
421                index,
422                count,
423                proof_layers,
424                hashes,
425            } => {
426                let hash_bytes = hashes.len() * 32;
427                let payload_len = 1 + 32 + 16 + hash_bytes;
428                dst.reserve(4 + payload_len);
429                dst.put_u32(payload_len as u32);
430                dst.put_u8(ID_HASHES);
431                dst.put_slice(&pieces_root.0);
432                dst.put_u32(*base);
433                dst.put_u32(*index);
434                dst.put_u32(*count);
435                dst.put_u32(*proof_layers);
436                for h in hashes {
437                    dst.put_slice(&h.0);
438                }
439            }
440        }
441    }
442
443    /// Return the exact encoded wire length in bytes, including the 4-byte
444    /// length prefix.
445    ///
446    /// This is a pure computation — no allocation, no encoding.  Use it to
447    /// check whether a message fits in a fixed-size buffer before calling
448    /// [`encode_to_slice`](Self::encode_to_slice).
449    #[must_use]
450    pub fn wire_len(&self) -> usize {
451        match self {
452            Message::KeepAlive => 4,
453            Message::Choke
454            | Message::Unchoke
455            | Message::Interested
456            | Message::NotInterested
457            | Message::HaveAll
458            | Message::HaveNone => 5,
459            Message::Have { .. } | Message::SuggestPiece(_) | Message::AllowedFast(_) => 9,
460            Message::Port(_) => 7,
461            Message::Request { .. } | Message::Cancel { .. } | Message::RejectRequest { .. } => 17,
462            Message::Bitfield(bits) => 5 + bits.as_ref().len(),
463            Message::Piece { data_0, data_1, .. } => {
464                13 + data_0.as_ref().len() + data_1.as_ref().len()
465            }
466            Message::Extended { payload, .. } => 6 + payload.as_ref().len(),
467            Message::HashRequest { .. } | Message::HashReject { .. } => 53,
468            Message::Hashes { hashes, .. } => 53 + hashes.len() * 32,
469        }
470    }
471
472    /// Encode this message (with length prefix) into a raw byte slice.
473    ///
474    /// Returns the number of bytes written. The caller must ensure `dst` is
475    /// large enough to hold the encoded message. For peer wire messages,
476    /// `MAX_MSG_LEN` (16397) is always sufficient.
477    ///
478    /// Uses `std::io::Cursor<&mut [u8]>` + `std::io::Write` instead of
479    /// `BufMut`, producing identical bytes to [`encode_into`](Self::encode_into).
480    #[must_use]
481    pub fn encode_to_slice(&self, dst: &mut [u8]) -> usize {
482        use std::io::{Cursor, Write};
483
484        let mut cursor = Cursor::new(dst);
485
486        match self {
487            Message::KeepAlive => {
488                cursor.write_all(&0u32.to_be_bytes()).unwrap();
489            }
490            Message::Choke => {
491                cursor.write_all(&1u32.to_be_bytes()).unwrap();
492                cursor.write_all(&[ID_CHOKE]).unwrap();
493            }
494            Message::Unchoke => {
495                cursor.write_all(&1u32.to_be_bytes()).unwrap();
496                cursor.write_all(&[ID_UNCHOKE]).unwrap();
497            }
498            Message::Interested => {
499                cursor.write_all(&1u32.to_be_bytes()).unwrap();
500                cursor.write_all(&[ID_INTERESTED]).unwrap();
501            }
502            Message::NotInterested => {
503                cursor.write_all(&1u32.to_be_bytes()).unwrap();
504                cursor.write_all(&[ID_NOT_INTERESTED]).unwrap();
505            }
506            Message::Have { index } => {
507                cursor.write_all(&5u32.to_be_bytes()).unwrap();
508                cursor.write_all(&[ID_HAVE]).unwrap();
509                cursor.write_all(&index.to_be_bytes()).unwrap();
510            }
511            Message::Bitfield(bits) => {
512                let bits = bits.as_ref();
513                cursor
514                    .write_all(&(1 + bits.len() as u32).to_be_bytes())
515                    .unwrap();
516                cursor.write_all(&[ID_BITFIELD]).unwrap();
517                cursor.write_all(bits).unwrap();
518            }
519            Message::Request {
520                index,
521                begin,
522                length,
523            } => {
524                cursor.write_all(&13u32.to_be_bytes()).unwrap();
525                cursor.write_all(&[ID_REQUEST]).unwrap();
526                cursor.write_all(&index.to_be_bytes()).unwrap();
527                cursor.write_all(&begin.to_be_bytes()).unwrap();
528                cursor.write_all(&length.to_be_bytes()).unwrap();
529            }
530            Message::Piece {
531                index,
532                begin,
533                data_0,
534                data_1,
535            } => {
536                let d0 = data_0.as_ref();
537                let d1 = data_1.as_ref();
538                let data_len = d0.len() + d1.len();
539                cursor
540                    .write_all(&(9 + data_len as u32).to_be_bytes())
541                    .unwrap();
542                cursor.write_all(&[ID_PIECE]).unwrap();
543                cursor.write_all(&index.to_be_bytes()).unwrap();
544                cursor.write_all(&begin.to_be_bytes()).unwrap();
545                cursor.write_all(d0).unwrap();
546                cursor.write_all(d1).unwrap();
547            }
548            Message::Cancel {
549                index,
550                begin,
551                length,
552            } => {
553                cursor.write_all(&13u32.to_be_bytes()).unwrap();
554                cursor.write_all(&[ID_CANCEL]).unwrap();
555                cursor.write_all(&index.to_be_bytes()).unwrap();
556                cursor.write_all(&begin.to_be_bytes()).unwrap();
557                cursor.write_all(&length.to_be_bytes()).unwrap();
558            }
559            Message::Port(port) => {
560                cursor.write_all(&3u32.to_be_bytes()).unwrap();
561                cursor.write_all(&[ID_PORT]).unwrap();
562                cursor.write_all(&port.to_be_bytes()).unwrap();
563            }
564            Message::Extended { ext_id, payload } => {
565                let payload = payload.as_ref();
566                cursor
567                    .write_all(&(2 + payload.len() as u32).to_be_bytes())
568                    .unwrap();
569                cursor.write_all(&[ID_EXTENDED]).unwrap();
570                cursor.write_all(&[*ext_id]).unwrap();
571                cursor.write_all(payload).unwrap();
572            }
573            Message::SuggestPiece(index) => {
574                cursor.write_all(&5u32.to_be_bytes()).unwrap();
575                cursor.write_all(&[ID_SUGGEST_PIECE]).unwrap();
576                cursor.write_all(&index.to_be_bytes()).unwrap();
577            }
578            Message::HaveAll => {
579                cursor.write_all(&1u32.to_be_bytes()).unwrap();
580                cursor.write_all(&[ID_HAVE_ALL]).unwrap();
581            }
582            Message::HaveNone => {
583                cursor.write_all(&1u32.to_be_bytes()).unwrap();
584                cursor.write_all(&[ID_HAVE_NONE]).unwrap();
585            }
586            Message::RejectRequest {
587                index,
588                begin,
589                length,
590            } => {
591                cursor.write_all(&13u32.to_be_bytes()).unwrap();
592                cursor.write_all(&[ID_REJECT_REQUEST]).unwrap();
593                cursor.write_all(&index.to_be_bytes()).unwrap();
594                cursor.write_all(&begin.to_be_bytes()).unwrap();
595                cursor.write_all(&length.to_be_bytes()).unwrap();
596            }
597            Message::AllowedFast(index) => {
598                cursor.write_all(&5u32.to_be_bytes()).unwrap();
599                cursor.write_all(&[ID_ALLOWED_FAST]).unwrap();
600                cursor.write_all(&index.to_be_bytes()).unwrap();
601            }
602            Message::HashRequest {
603                pieces_root,
604                base,
605                index,
606                count,
607                proof_layers,
608            }
609            | Message::HashReject {
610                pieces_root,
611                base,
612                index,
613                count,
614                proof_layers,
615            } => {
616                let id = match self {
617                    Message::HashRequest { .. } => ID_HASH_REQUEST,
618                    _ => ID_HASH_REJECT,
619                };
620                cursor.write_all(&49u32.to_be_bytes()).unwrap();
621                cursor.write_all(&[id]).unwrap();
622                cursor.write_all(&pieces_root.0).unwrap();
623                cursor.write_all(&base.to_be_bytes()).unwrap();
624                cursor.write_all(&index.to_be_bytes()).unwrap();
625                cursor.write_all(&count.to_be_bytes()).unwrap();
626                cursor.write_all(&proof_layers.to_be_bytes()).unwrap();
627            }
628            Message::Hashes {
629                pieces_root,
630                base,
631                index,
632                count,
633                proof_layers,
634                hashes,
635            } => {
636                let hash_bytes = hashes.len() * 32;
637                let payload_len = 1 + 32 + 16 + hash_bytes;
638                cursor
639                    .write_all(&(payload_len as u32).to_be_bytes())
640                    .unwrap();
641                cursor.write_all(&[ID_HASHES]).unwrap();
642                cursor.write_all(&pieces_root.0).unwrap();
643                cursor.write_all(&base.to_be_bytes()).unwrap();
644                cursor.write_all(&index.to_be_bytes()).unwrap();
645                cursor.write_all(&count.to_be_bytes()).unwrap();
646                cursor.write_all(&proof_layers.to_be_bytes()).unwrap();
647                for h in hashes {
648                    cursor.write_all(&h.0).unwrap();
649                }
650            }
651        }
652
653        cursor.position() as usize
654    }
655}
656
657impl Message<&[u8]> {
658    /// Convert a borrowed message to an owned `Message<Bytes>`.
659    ///
660    /// Fixed-field variants are zero-cost (no data to copy).
661    /// Data-carrying variants (`Piece`, `Bitfield`, `Extended`) copy their
662    /// slices into fresh `Bytes` allocations.
663    #[must_use]
664    pub fn to_owned_bytes(&self) -> Message<Bytes> {
665        match *self {
666            Message::KeepAlive => Message::KeepAlive,
667            Message::Choke => Message::Choke,
668            Message::Unchoke => Message::Unchoke,
669            Message::Interested => Message::Interested,
670            Message::NotInterested => Message::NotInterested,
671            Message::Have { index } => Message::Have { index },
672            Message::Bitfield(data) => Message::Bitfield(Bytes::copy_from_slice(data)),
673            Message::Request {
674                index,
675                begin,
676                length,
677            } => Message::Request {
678                index,
679                begin,
680                length,
681            },
682            Message::Piece {
683                index,
684                begin,
685                data_0,
686                data_1,
687            } => Message::Piece {
688                index,
689                begin,
690                data_0: Bytes::copy_from_slice(data_0),
691                data_1: Bytes::copy_from_slice(data_1),
692            },
693            Message::Cancel {
694                index,
695                begin,
696                length,
697            } => Message::Cancel {
698                index,
699                begin,
700                length,
701            },
702            Message::Port(port) => Message::Port(port),
703            Message::Extended { ext_id, payload } => Message::Extended {
704                ext_id,
705                payload: Bytes::copy_from_slice(payload),
706            },
707            Message::SuggestPiece(index) => Message::SuggestPiece(index),
708            Message::HaveAll => Message::HaveAll,
709            Message::HaveNone => Message::HaveNone,
710            Message::RejectRequest {
711                index,
712                begin,
713                length,
714            } => Message::RejectRequest {
715                index,
716                begin,
717                length,
718            },
719            Message::AllowedFast(index) => Message::AllowedFast(index),
720            Message::HashRequest {
721                pieces_root,
722                base,
723                index,
724                count,
725                proof_layers,
726            } => Message::HashRequest {
727                pieces_root,
728                base,
729                index,
730                count,
731                proof_layers,
732            },
733            Message::Hashes {
734                ref pieces_root,
735                base,
736                index,
737                count,
738                proof_layers,
739                ref hashes,
740            } => Message::Hashes {
741                pieces_root: *pieces_root,
742                base,
743                index,
744                count,
745                proof_layers,
746                hashes: hashes.clone(),
747            },
748            Message::HashReject {
749                pieces_root,
750                base,
751                index,
752                count,
753                proof_layers,
754            } => Message::HashReject {
755                pieces_root,
756                base,
757                index,
758                count,
759                proof_layers,
760            },
761        }
762    }
763}
764
765impl Message<Bytes> {
766    /// Parse a message from its payload (after the 4-byte length prefix has
767    /// been consumed). `payload` is everything after the length prefix.
768    pub fn from_payload(payload: Bytes) -> Result<Self> {
769        if payload.is_empty() {
770            return Ok(Message::KeepAlive);
771        }
772
773        let id = payload[0];
774        let body = &payload[1..];
775
776        match id {
777            ID_CHOKE => Ok(Message::Choke),
778            ID_UNCHOKE => Ok(Message::Unchoke),
779            ID_INTERESTED => Ok(Message::Interested),
780            ID_NOT_INTERESTED => Ok(Message::NotInterested),
781            ID_HAVE => {
782                ensure_len(body, 4, "Have")?;
783                Ok(Message::Have {
784                    index: read_u32(body),
785                })
786            }
787            ID_BITFIELD => Ok(Message::Bitfield(payload.slice(1..))),
788            ID_REQUEST => {
789                ensure_len(body, 12, "Request")?;
790                Ok(Message::Request {
791                    index: read_u32(body),
792                    begin: read_u32(&body[4..]),
793                    length: read_u32(&body[8..]),
794                })
795            }
796            ID_PIECE => {
797                ensure_len(body, 8, "Piece")?;
798                let index = read_u32(body);
799                let begin = read_u32(&body[4..]);
800                Ok(Message::Piece {
801                    index,
802                    begin,
803                    data_0: payload.slice(9..),
804                    data_1: Bytes::new(),
805                })
806            }
807            ID_CANCEL => {
808                ensure_len(body, 12, "Cancel")?;
809                Ok(Message::Cancel {
810                    index: read_u32(body),
811                    begin: read_u32(&body[4..]),
812                    length: read_u32(&body[8..]),
813                })
814            }
815            ID_PORT => {
816                ensure_len(body, 2, "Port")?;
817                Ok(Message::Port(u16::from_be_bytes([body[0], body[1]])))
818            }
819            ID_EXTENDED => {
820                ensure_len(body, 1, "Extended")?;
821                let ext_id = body[0];
822                Ok(Message::Extended {
823                    ext_id,
824                    payload: payload.slice(2..),
825                })
826            }
827            ID_SUGGEST_PIECE => {
828                ensure_len(body, 4, "SuggestPiece")?;
829                Ok(Message::SuggestPiece(read_u32(body)))
830            }
831            ID_HAVE_ALL => Ok(Message::HaveAll),
832            ID_HAVE_NONE => Ok(Message::HaveNone),
833            ID_REJECT_REQUEST => {
834                ensure_len(body, 12, "RejectRequest")?;
835                Ok(Message::RejectRequest {
836                    index: read_u32(body),
837                    begin: read_u32(&body[4..]),
838                    length: read_u32(&body[8..]),
839                })
840            }
841            ID_ALLOWED_FAST => {
842                ensure_len(body, 4, "AllowedFast")?;
843                Ok(Message::AllowedFast(read_u32(body)))
844            }
845            ID_HASH_REQUEST | ID_HASH_REJECT => {
846                ensure_len(body, 48, "HashRequest/Reject")?;
847                let mut root = [0u8; 32];
848                root.copy_from_slice(&body[..32]);
849                let pieces_root = irontide_core::Id32(root);
850                let base = read_u32(&body[32..]);
851                let index = read_u32(&body[36..]);
852                let count = read_u32(&body[40..]);
853                let proof_layers = read_u32(&body[44..]);
854                if id == ID_HASH_REQUEST {
855                    Ok(Message::HashRequest {
856                        pieces_root,
857                        base,
858                        index,
859                        count,
860                        proof_layers,
861                    })
862                } else {
863                    Ok(Message::HashReject {
864                        pieces_root,
865                        base,
866                        index,
867                        count,
868                        proof_layers,
869                    })
870                }
871            }
872            ID_HASHES => {
873                ensure_len(body, 48, "Hashes")?;
874                let mut root = [0u8; 32];
875                root.copy_from_slice(&body[..32]);
876                let pieces_root = irontide_core::Id32(root);
877                let base = read_u32(&body[32..]);
878                let index = read_u32(&body[36..]);
879                let count = read_u32(&body[40..]);
880                let proof_layers = read_u32(&body[44..]);
881                let hash_data = &body[48..];
882                if !hash_data.len().is_multiple_of(32) {
883                    return Err(Error::MessageTooShort {
884                        expected: 48 + 32,
885                        got: body.len(),
886                    });
887                }
888                let hashes = hash_data
889                    .chunks_exact(32)
890                    .map(|chunk| {
891                        let mut h = [0u8; 32];
892                        h.copy_from_slice(chunk);
893                        irontide_core::Id32(h)
894                    })
895                    .collect();
896                Ok(Message::Hashes {
897                    pieces_root,
898                    base,
899                    index,
900                    count,
901                    proof_layers,
902                    hashes,
903                })
904            }
905            _ => Err(Error::InvalidMessageId(id)),
906        }
907    }
908}
909
910fn encode_fixed_into(dst: &mut BytesMut, id: u8) {
911    dst.put_u32(1);
912    dst.put_u8(id);
913}
914
915fn encode_triple_into(dst: &mut BytesMut, id: u8, a: u32, b: u32, c: u32) {
916    dst.put_u32(13);
917    dst.put_u8(id);
918    dst.put_u32(a);
919    dst.put_u32(b);
920    dst.put_u32(c);
921}
922
923fn fixed_msg(id: u8) -> Bytes {
924    let mut buf = BytesMut::with_capacity(5);
925    buf.put_u32(1);
926    buf.put_u8(id);
927    buf.freeze()
928}
929
930fn triple_msg(id: u8, a: u32, b: u32, c: u32) -> Bytes {
931    let mut buf = BytesMut::with_capacity(17);
932    buf.put_u32(13);
933    buf.put_u8(id);
934    buf.put_u32(a);
935    buf.put_u32(b);
936    buf.put_u32(c);
937    buf.freeze()
938}
939
/// Decode the first four bytes of `buf` as a big-endian `u32`.
///
/// Any bytes after the first four are ignored.
///
/// # Panics
///
/// Panics if `buf` is shorter than four bytes; callers validate lengths with
/// `ensure_len` before reading.
fn read_u32(buf: &[u8]) -> u32 {
    buf[..4]
        .iter()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte))
}
945
946fn ensure_len(body: &[u8], min: usize, _name: &str) -> Result<()> {
947    if body.len() < min {
948        Err(Error::MessageTooShort {
949            expected: min,
950            got: body.len(),
951        })
952    } else {
953        Ok(())
954    }
955}
956
957/// BEP 6 Allowed-Fast set generation.
958///
959/// Generates a deterministic set of piece indices that a peer is allowed
960/// to request even while choked. Uses IP masking + info_hash + SHA1.
961///
962/// For IPv4: masks to /24 (matching BEP 6 spec).
963/// For IPv6: masks to /48 (matching libtorrent convention).
964pub fn allowed_fast_set(
965    info_hash: &irontide_core::Id20,
966    peer_ip: std::net::Ipv4Addr,
967    num_pieces: u32,
968    count: usize,
969) -> Vec<u32> {
970    allowed_fast_set_for_ip(info_hash, std::net::IpAddr::V4(peer_ip), num_pieces, count)
971}
972
973/// BEP 6 Allowed-Fast set generation for any IP address family.
974///
975/// IPv4: /24 prefix mask. IPv6: /48 prefix mask (libtorrent convention).
976pub fn allowed_fast_set_for_ip(
977    info_hash: &irontide_core::Id20,
978    peer_ip: std::net::IpAddr,
979    num_pieces: u32,
980    count: usize,
981) -> Vec<u32> {
982    use irontide_core::sha1;
983
984    if num_pieces == 0 {
985        return Vec::new();
986    }
987
988    let count = count.min(num_pieces as usize);
989    let mut result = Vec::with_capacity(count);
990
991    // Build masked IP bytes based on address family
992    let masked: Vec<u8> = match peer_ip {
993        std::net::IpAddr::V4(ipv4) => {
994            // Mask to /24
995            let o = ipv4.octets();
996            vec![o[0], o[1], o[2], 0]
997        }
998        std::net::IpAddr::V6(ipv6) => {
999            // Mask to /48: keep first 6 bytes, zero the rest
1000            let o = ipv6.octets();
1001            let mut masked = [0u8; 16];
1002            masked[..6].copy_from_slice(&o[..6]);
1003            masked.to_vec()
1004        }
1005    };
1006
1007    // Initial hash: SHA1(masked_ip + info_hash)
1008    let mut input = Vec::with_capacity(masked.len() + 20);
1009    input.extend_from_slice(&masked);
1010    input.extend_from_slice(info_hash.as_bytes());
1011    let mut hash = sha1(&input);
1012
1013    while result.len() < count {
1014        let hash_bytes = hash.as_bytes();
1015        // Each 20-byte hash gives us 5 candidate indices (4 bytes each)
1016        for i in (0..20).step_by(4) {
1017            if result.len() >= count {
1018                break;
1019            }
1020            let index = u32::from_be_bytes([
1021                hash_bytes[i],
1022                hash_bytes[i + 1],
1023                hash_bytes[i + 2],
1024                hash_bytes[i + 3],
1025            ]) % num_pieces;
1026            if !result.contains(&index) {
1027                result.push(index);
1028            }
1029        }
1030        // Re-hash for more candidates
1031        hash = sha1(hash.as_bytes());
1032    }
1033
1034    result
1035}
1036
1037#[cfg(test)]
1038mod tests {
1039    use super::*;
1040
1041    fn round_trip(msg: Message) {
1042        let bytes = msg.to_bytes();
1043        // Skip the 4-byte length prefix for parsing
1044        let parsed = Message::from_payload(Bytes::copy_from_slice(&bytes[4..])).unwrap();
1045        assert_eq!(msg, parsed);
1046    }
1047
1048    #[test]
1049    fn keepalive() {
1050        round_trip(Message::KeepAlive);
1051    }
1052
1053    #[test]
1054    fn choke_unchoke() {
1055        round_trip(Message::Choke);
1056        round_trip(Message::Unchoke);
1057    }
1058
1059    #[test]
1060    fn interested() {
1061        round_trip(Message::Interested);
1062        round_trip(Message::NotInterested);
1063    }
1064
1065    #[test]
1066    fn have() {
1067        round_trip(Message::Have { index: 42 });
1068    }
1069
1070    #[test]
1071    fn bitfield() {
1072        round_trip(Message::Bitfield(Bytes::from_static(&[0xFF, 0x80])));
1073    }
1074
1075    #[test]
1076    fn request() {
1077        round_trip(Message::Request {
1078            index: 1,
1079            begin: 0,
1080            length: 16384,
1081        });
1082    }
1083
1084    #[test]
1085    fn piece() {
1086        round_trip(Message::Piece {
1087            index: 1,
1088            begin: 0,
1089            data_0: Bytes::from_static(b"hello world"),
1090            data_1: Bytes::new(),
1091        });
1092    }
1093
1094    #[test]
1095    fn cancel() {
1096        round_trip(Message::Cancel {
1097            index: 1,
1098            begin: 0,
1099            length: 16384,
1100        });
1101    }
1102
1103    #[test]
1104    fn port() {
1105        round_trip(Message::Port(6881));
1106    }
1107
1108    #[test]
1109    fn extended() {
1110        round_trip(Message::Extended {
1111            ext_id: 1,
1112            payload: Bytes::from_static(b"test payload"),
1113        });
1114    }
1115
1116    #[test]
1117    fn invalid_message_id() {
1118        assert!(Message::from_payload(Bytes::from_static(&[99u8])).is_err());
1119    }
1120
1121    #[test]
1122    fn suggest_piece() {
1123        round_trip(Message::SuggestPiece(42));
1124    }
1125
1126    #[test]
1127    fn have_all() {
1128        round_trip(Message::HaveAll);
1129    }
1130
1131    #[test]
1132    fn have_none() {
1133        round_trip(Message::HaveNone);
1134    }
1135
1136    #[test]
1137    fn reject_request() {
1138        round_trip(Message::RejectRequest {
1139            index: 1,
1140            begin: 0,
1141            length: 16384,
1142        });
1143    }
1144
1145    #[test]
1146    fn allowed_fast() {
1147        round_trip(Message::AllowedFast(7));
1148    }
1149
1150    #[test]
1151    fn allowed_fast_set_deterministic() {
1152        use irontide_core::Id20;
1153        let ih = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
1154        let ip: std::net::Ipv4Addr = "192.168.1.100".parse().unwrap();
1155        let set1 = allowed_fast_set(&ih, ip, 1000, 10);
1156        let set2 = allowed_fast_set(&ih, ip, 1000, 10);
1157        assert_eq!(set1, set2);
1158        assert_eq!(set1.len(), 10);
1159        // All indices in range
1160        assert!(set1.iter().all(|&i| i < 1000));
1161    }
1162
1163    #[test]
1164    fn allowed_fast_set_unique() {
1165        use irontide_core::Id20;
1166        let ih = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
1167        let ip: std::net::Ipv4Addr = "10.0.0.1".parse().unwrap();
1168        let set = allowed_fast_set(&ih, ip, 50, 10);
1169        let unique: std::collections::HashSet<u32> = set.iter().copied().collect();
1170        assert_eq!(set.len(), unique.len(), "all indices should be unique");
1171    }
1172
1173    #[test]
1174    fn allowed_fast_set_empty_torrent() {
1175        use irontide_core::Id20;
1176        let ih = Id20::ZERO;
1177        let ip: std::net::Ipv4Addr = "127.0.0.1".parse().unwrap();
1178        assert!(allowed_fast_set(&ih, ip, 0, 10).is_empty());
1179    }
1180
1181    #[test]
1182    fn allowed_fast_set_ipv6() {
1183        use irontide_core::Id20;
1184        let ih = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
1185        let ip: std::net::IpAddr = "2001:db8::1".parse().unwrap();
1186        let set = allowed_fast_set_for_ip(&ih, ip, 1000, 10);
1187        assert_eq!(set.len(), 10);
1188        assert!(set.iter().all(|&i| i < 1000));
1189
1190        // Same /48 prefix should produce same set
1191        let ip2: std::net::IpAddr = "2001:db8::ffff".parse().unwrap();
1192        let set2 = allowed_fast_set_for_ip(&ih, ip2, 1000, 10);
1193        assert_eq!(set, set2);
1194
1195        // Different /48 prefix should produce different set
1196        let ip3: std::net::IpAddr = "2001:db9::1".parse().unwrap();
1197        let set3 = allowed_fast_set_for_ip(&ih, ip3, 1000, 10);
1198        assert_ne!(set, set3);
1199    }
1200
1201    #[test]
1202    fn hash_request_round_trip() {
1203        let msg = Message::HashRequest {
1204            pieces_root: irontide_core::Id32::ZERO,
1205            base: 7,
1206            index: 0,
1207            count: 512,
1208            proof_layers: 3,
1209        };
1210        round_trip(msg);
1211    }
1212
1213    #[test]
1214    fn hash_reject_round_trip() {
1215        let msg = Message::HashReject {
1216            pieces_root: irontide_core::Id32::ZERO,
1217            base: 7,
1218            index: 0,
1219            count: 512,
1220            proof_layers: 3,
1221        };
1222        round_trip(msg);
1223    }
1224
1225    #[test]
1226    fn hashes_round_trip() {
1227        let h1 = irontide_core::sha256(b"block1");
1228        let h2 = irontide_core::sha256(b"block2");
1229        let uncle = irontide_core::sha256(b"uncle");
1230        let msg = Message::Hashes {
1231            pieces_root: irontide_core::Id32::ZERO,
1232            base: 0,
1233            index: 0,
1234            count: 2,
1235            proof_layers: 1,
1236            hashes: vec![h1, h2, uncle],
1237        };
1238        round_trip(msg);
1239    }
1240
1241    #[test]
1242    fn hash_request_exact_wire_size() {
1243        let msg: Message = Message::HashRequest {
1244            pieces_root: irontide_core::Id32::ZERO,
1245            base: 0,
1246            index: 0,
1247            count: 1,
1248            proof_layers: 0,
1249        };
1250        let bytes = msg.to_bytes();
1251        // 4 (length prefix) + 1 (msg id) + 32 (root) + 4*4 (fields) = 53
1252        assert_eq!(bytes.len(), 53);
1253    }
1254
1255    #[test]
1256    fn hashes_variable_length() {
1257        let h = irontide_core::sha256(b"test");
1258        let msg: Message = Message::Hashes {
1259            pieces_root: irontide_core::Id32::ZERO,
1260            base: 0,
1261            index: 0,
1262            count: 1,
1263            proof_layers: 0,
1264            hashes: vec![h],
1265        };
1266        let bytes = msg.to_bytes();
1267        // 4 + 1 + 32 + 4*4 + 1*32 = 85
1268        assert_eq!(bytes.len(), 85);
1269    }
1270
1271    #[test]
1272    fn hash_request_too_short() {
1273        // msg id 21, but only 10 bytes of body (need 48)
1274        let mut payload = vec![21u8];
1275        payload.extend_from_slice(&[0u8; 10]);
1276        assert!(Message::from_payload(Bytes::from(payload)).is_err());
1277    }
1278
1279    #[test]
1280    fn encode_into_matches_to_bytes() {
1281        let messages = vec![
1282            Message::KeepAlive,
1283            Message::Choke,
1284            Message::Unchoke,
1285            Message::Interested,
1286            Message::NotInterested,
1287            Message::Have { index: 42 },
1288            Message::Bitfield(Bytes::from_static(b"\xff\x00")),
1289            Message::Request {
1290                index: 1,
1291                begin: 0,
1292                length: 16384,
1293            },
1294            Message::Piece {
1295                index: 0,
1296                begin: 0,
1297                data_0: Bytes::from_static(b"block data here"),
1298                data_1: Bytes::new(),
1299            },
1300            Message::Cancel {
1301                index: 1,
1302                begin: 0,
1303                length: 16384,
1304            },
1305            Message::Port(6881),
1306            Message::Extended {
1307                ext_id: 0,
1308                payload: Bytes::from_static(b"ext payload"),
1309            },
1310            Message::SuggestPiece(7),
1311            Message::HaveAll,
1312            Message::HaveNone,
1313            Message::RejectRequest {
1314                index: 1,
1315                begin: 0,
1316                length: 16384,
1317            },
1318            Message::AllowedFast(5),
1319            Message::HashRequest {
1320                pieces_root: irontide_core::Id32::ZERO,
1321                base: 7,
1322                index: 0,
1323                count: 512,
1324                proof_layers: 3,
1325            },
1326            Message::HashReject {
1327                pieces_root: irontide_core::Id32::ZERO,
1328                base: 7,
1329                index: 0,
1330                count: 512,
1331                proof_layers: 3,
1332            },
1333            Message::Hashes {
1334                pieces_root: irontide_core::Id32::ZERO,
1335                base: 0,
1336                index: 0,
1337                count: 2,
1338                proof_layers: 1,
1339                hashes: vec![
1340                    irontide_core::sha256(b"block1"),
1341                    irontide_core::sha256(b"block2"),
1342                    irontide_core::sha256(b"uncle"),
1343                ],
1344            },
1345        ];
1346        for msg in messages {
1347            let expected = msg.to_bytes();
1348            let mut buf = BytesMut::new();
1349            msg.encode_into(&mut buf);
1350            assert_eq!(&expected[..], &buf[..], "mismatch for {msg:?}");
1351        }
1352    }
1353
1354    #[test]
1355    fn allowed_fast_set_ipv4_compat() {
1356        // allowed_fast_set (IPv4-only) and allowed_fast_set_for_ip with V4 should match
1357        use irontide_core::Id20;
1358        let ih = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
1359        let ipv4: std::net::Ipv4Addr = "192.168.1.100".parse().unwrap();
1360        let set_v4 = allowed_fast_set(&ih, ipv4, 1000, 10);
1361        let set_ip = allowed_fast_set_for_ip(&ih, std::net::IpAddr::V4(ipv4), 1000, 10);
1362        assert_eq!(set_v4, set_ip);
1363    }
1364
1365    // ── M158: BEP 6 AllowedFast spec vector conformance ──
1366
1367    #[test]
1368    fn allowed_fast_bep6_spec_vector_k7() {
1369        // BEP 6 exact test vector: IP 80.4.4.200, InfoHash 0xaaaa..aa, 1313 pieces, k=7
1370        use irontide_core::Id20;
1371        let ih = Id20([0xaa; 20]);
1372        let ip: std::net::Ipv4Addr = "80.4.4.200".parse().unwrap();
1373        let set = allowed_fast_set(&ih, ip, 1313, 7);
1374        assert_eq!(set, vec![1059, 431, 808, 1217, 287, 376, 1188]);
1375    }
1376
1377    #[test]
1378    fn allowed_fast_bep6_spec_vector_k9() {
1379        // BEP 6 exact test vector: same inputs, k=9 — extends k=7 with [353, 508]
1380        use irontide_core::Id20;
1381        let ih = Id20([0xaa; 20]);
1382        let ip: std::net::Ipv4Addr = "80.4.4.200".parse().unwrap();
1383        let set = allowed_fast_set(&ih, ip, 1313, 9);
1384        assert_eq!(set, vec![1059, 431, 808, 1217, 287, 376, 1188, 353, 508]);
1385    }
1386
1387    #[test]
1388    fn allowed_fast_ip_masking() {
1389        // BEP 6 spec: peer IP is masked to /24 before hashing.
1390        // 80.4.4.200 should mask to 80.4.4.0 = [0x50, 0x04, 0x04, 0x00].
1391        // Verify that any IP in the same /24 produces the same set.
1392        use irontide_core::Id20;
1393        let ih = Id20([0xaa; 20]);
1394        let ip_a: std::net::Ipv4Addr = "80.4.4.200".parse().unwrap();
1395        let ip_b: std::net::Ipv4Addr = "80.4.4.0".parse().unwrap();
1396        let ip_c: std::net::Ipv4Addr = "80.4.4.255".parse().unwrap();
1397        let set_a = allowed_fast_set(&ih, ip_a, 1313, 7);
1398        let set_b = allowed_fast_set(&ih, ip_b, 1313, 7);
1399        let set_c = allowed_fast_set(&ih, ip_c, 1313, 7);
1400        assert_eq!(set_a, set_b, "80.4.4.200 and 80.4.4.0 must produce same set (same /24)");
1401        assert_eq!(set_a, set_c, "80.4.4.200 and 80.4.4.255 must produce same set (same /24)");
1402
1403        // Verify the masked bytes are correct: 80=0x50, 4=0x04, 4=0x04, 0=0x00
1404        let octets = ip_a.octets();
1405        let masked = [octets[0], octets[1], octets[2], 0u8];
1406        assert_eq!(masked, [0x50, 0x04, 0x04, 0x00]);
1407    }
1408
1409    // ── M110: Generic Message<B> tests ──
1410
1411    #[test]
1412    fn message_piece_two_fields_round_trip() {
1413        // Encode a Piece with data_0 only (data_1 empty) — common case.
1414        let msg = Message::Piece {
1415            index: 5,
1416            begin: 16384,
1417            data_0: Bytes::from_static(b"block payload here"),
1418            data_1: Bytes::new(),
1419        };
1420        let bytes = msg.to_bytes();
1421        let parsed = Message::from_payload(Bytes::copy_from_slice(&bytes[4..])).unwrap();
1422        assert_eq!(msg, parsed);
1423    }
1424
1425    #[test]
1426    fn message_piece_split_data_round_trip() {
1427        // Encode a Piece with both data_0 and data_1 populated (ring-wrap case).
1428        // The wire format concatenates them, so decoding puts everything into data_0.
1429        let msg = Message::Piece {
1430            index: 3,
1431            begin: 0,
1432            data_0: Bytes::from_static(b"first half"),
1433            data_1: Bytes::from_static(b" second half"),
1434        };
1435        let bytes = msg.to_bytes();
1436        let parsed = Message::from_payload(Bytes::copy_from_slice(&bytes[4..])).unwrap();
1437        // After round-trip through the wire, the data is concatenated into data_0
1438        // with data_1 empty (the receiver doesn't know about ring-wrap).
1439        assert_eq!(
1440            parsed,
1441            Message::Piece {
1442                index: 3,
1443                begin: 0,
1444                data_0: Bytes::from_static(b"first half second half"),
1445                data_1: Bytes::new(),
1446            }
1447        );
1448    }
1449
1450    #[test]
1451    fn message_generic_encode_borrowed() {
1452        // Verify that Message<&[u8]> can encode_into just like Message<Bytes>.
1453        let borrowed: Message<&[u8]> = Message::Piece {
1454            index: 1,
1455            begin: 0,
1456            data_0: b"borrowed data",
1457            data_1: b"",
1458        };
1459        let owned: Message<Bytes> = Message::Piece {
1460            index: 1,
1461            begin: 0,
1462            data_0: Bytes::from_static(b"borrowed data"),
1463            data_1: Bytes::new(),
1464        };
1465        let mut buf_borrowed = BytesMut::new();
1466        borrowed.encode_into(&mut buf_borrowed);
1467        let mut buf_owned = BytesMut::new();
1468        owned.encode_into(&mut buf_owned);
1469        assert_eq!(
1470            buf_borrowed, buf_owned,
1471            "borrowed and owned encode identically"
1472        );
1473
1474        // Also test to_bytes
1475        assert_eq!(borrowed.to_bytes(), owned.to_bytes());
1476
1477        // Test borrowed Bitfield
1478        let bf_borrowed: Message<&[u8]> = Message::Bitfield(b"\xff\x80");
1479        let bf_owned: Message<Bytes> = Message::Bitfield(Bytes::from_static(b"\xff\x80"));
1480        assert_eq!(bf_borrowed.to_bytes(), bf_owned.to_bytes());
1481
1482        // Test borrowed Extended
1483        let ext_borrowed: Message<&[u8]> = Message::Extended {
1484            ext_id: 1,
1485            payload: b"payload",
1486        };
1487        let ext_owned: Message<Bytes> = Message::Extended {
1488            ext_id: 1,
1489            payload: Bytes::from_static(b"payload"),
1490        };
1491        assert_eq!(ext_borrowed.to_bytes(), ext_owned.to_bytes());
1492    }
1493
1494    // ── M115: encode_to_slice tests ──
1495
1496    /// Build the complete set of message variants used for encode_to_slice tests.
1497    fn all_message_variants() -> Vec<Message> {
1498        vec![
1499            Message::KeepAlive,
1500            Message::Choke,
1501            Message::Unchoke,
1502            Message::Interested,
1503            Message::NotInterested,
1504            Message::Have { index: 42 },
1505            Message::Bitfield(Bytes::from_static(b"\xff\x00")),
1506            Message::Request {
1507                index: 1,
1508                begin: 0,
1509                length: 16384,
1510            },
1511            Message::Piece {
1512                index: 0,
1513                begin: 0,
1514                data_0: Bytes::from_static(b"block data here"),
1515                data_1: Bytes::new(),
1516            },
1517            Message::Piece {
1518                index: 3,
1519                begin: 0,
1520                data_0: Bytes::from_static(b"first half"),
1521                data_1: Bytes::from_static(b" second half"),
1522            },
1523            Message::Cancel {
1524                index: 1,
1525                begin: 0,
1526                length: 16384,
1527            },
1528            Message::Port(6881),
1529            Message::Extended {
1530                ext_id: 0,
1531                payload: Bytes::from_static(b"ext payload"),
1532            },
1533            Message::SuggestPiece(7),
1534            Message::HaveAll,
1535            Message::HaveNone,
1536            Message::RejectRequest {
1537                index: 1,
1538                begin: 0,
1539                length: 16384,
1540            },
1541            Message::AllowedFast(5),
1542            Message::HashRequest {
1543                pieces_root: irontide_core::Id32::ZERO,
1544                base: 7,
1545                index: 0,
1546                count: 512,
1547                proof_layers: 3,
1548            },
1549            Message::HashReject {
1550                pieces_root: irontide_core::Id32::ZERO,
1551                base: 7,
1552                index: 0,
1553                count: 512,
1554                proof_layers: 3,
1555            },
1556            Message::Hashes {
1557                pieces_root: irontide_core::Id32::ZERO,
1558                base: 0,
1559                index: 0,
1560                count: 2,
1561                proof_layers: 1,
1562                hashes: vec![
1563                    irontide_core::sha256(b"block1"),
1564                    irontide_core::sha256(b"block2"),
1565                    irontide_core::sha256(b"uncle"),
1566                ],
1567            },
1568        ]
1569    }
1570
1571    #[test]
1572    fn encode_to_slice_roundtrip() {
1573        for msg in all_message_variants() {
1574            let mut buf = [0u8; 4096];
1575            let n = msg.encode_to_slice(&mut buf);
1576            // Skip the 4-byte length prefix for parsing
1577            let parsed = Message::from_payload(Bytes::copy_from_slice(&buf[4..n])).unwrap();
1578            // For Piece with split data, wire format concatenates into data_0
1579            match &msg {
1580                Message::Piece {
1581                    index,
1582                    begin,
1583                    data_0,
1584                    data_1,
1585                } if !data_1.is_empty() => {
1586                    let mut combined = Vec::from(data_0.as_ref());
1587                    combined.extend_from_slice(data_1.as_ref());
1588                    let expected = Message::Piece {
1589                        index: *index,
1590                        begin: *begin,
1591                        data_0: Bytes::from(combined),
1592                        data_1: Bytes::new(),
1593                    };
1594                    assert_eq!(parsed, expected, "roundtrip mismatch for split Piece");
1595                }
1596                _ => {
1597                    assert_eq!(msg, parsed, "roundtrip mismatch for {msg:?}");
1598                }
1599            }
1600        }
1601    }
1602
1603    #[test]
1604    fn encode_to_slice_matches_encode_into() {
1605        for msg in all_message_variants() {
1606            let mut slice_buf = [0u8; 4096];
1607            let n = msg.encode_to_slice(&mut slice_buf);
1608
1609            let mut bytes_buf = BytesMut::new();
1610            msg.encode_into(&mut bytes_buf);
1611
1612            assert_eq!(
1613                &slice_buf[..n],
1614                &bytes_buf[..],
1615                "encode_to_slice vs encode_into mismatch for {msg:?}"
1616            );
1617        }
1618    }
1619
1620    #[test]
1621    fn wire_len_matches_encoded_size() {
1622        for msg in all_message_variants() {
1623            let expected = msg.to_bytes().len();
1624            assert_eq!(msg.wire_len(), expected, "wire_len mismatch for {msg:?}");
1625        }
1626    }
1627
1628    #[test]
1629    fn wire_len_large_bitfield() {
1630        // Bitfield for a torrent with >131K pieces (16,384 bytes of bitfield data).
1631        let bits = vec![0xFFu8; 20_000];
1632        let msg = Message::Bitfield(Bytes::from(bits.clone()));
1633        assert_eq!(msg.wire_len(), 5 + bits.len());
1634        assert_eq!(msg.wire_len(), msg.to_bytes().len());
1635    }
1636}