1use bytes::{Buf, BufMut, Bytes, BytesMut};
2
3use super::error::PeerError;
4
/// Protocol identifier string carried in every handshake; its length (19)
/// is the handshake's first byte.
pub const PROTOCOL: &[u8] = b"BitTorrent protocol";

/// Size in bytes of an encoded handshake:
/// length byte + protocol string + 8 reserved bytes + 20-byte info hash + 20-byte peer id.
pub const HANDSHAKE_LEN: usize = 1 + PROTOCOL.len() + 8 + 20 + 20;
9
/// One-byte message type that follows the length prefix of every
/// non-keep-alive peer wire message.
///
/// Ids 0-8 are the core messages, 9 is the DHT port message, 13-17 belong to
/// the fast extension, 20 to the extension protocol and 21-23 to the v2
/// hash-transfer messages. Ids 10-12 and 18-19 are not recognised.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageId {
    /// Core message, id 0.
    Choke = 0,
    /// Core message, id 1.
    Unchoke = 1,
    /// Core message, id 2.
    Interested = 2,
    /// Core message, id 3.
    NotInterested = 3,
    /// Core message, id 4.
    Have = 4,
    /// Core message, id 5.
    Bitfield = 5,
    /// Core message, id 6.
    Request = 6,
    /// Core message, id 7.
    Piece = 7,
    /// Core message, id 8.
    Cancel = 8,
    /// DHT port message, id 9.
    Port = 9,
    /// Fast extension, id 13.
    Suggest = 13,
    /// Fast extension, id 14.
    HaveAll = 14,
    /// Fast extension, id 15.
    HaveNone = 15,
    /// Fast extension, id 16.
    Reject = 16,
    /// Fast extension, id 17.
    AllowedFast = 17,
    /// Extension protocol envelope, id 20.
    Extended = 20,
    /// v2 hash request, id 21.
    HashRequest = 21,
    /// v2 hash response, id 22.
    Hashes = 22,
    /// v2 hash request rejection, id 23.
    HashReject = 23,
}
58
59impl TryFrom<u8> for MessageId {
60 type Error = PeerError;
61
62 fn try_from(value: u8) -> Result<Self, Self::Error> {
63 match value {
64 0 => Ok(MessageId::Choke),
65 1 => Ok(MessageId::Unchoke),
66 2 => Ok(MessageId::Interested),
67 3 => Ok(MessageId::NotInterested),
68 4 => Ok(MessageId::Have),
69 5 => Ok(MessageId::Bitfield),
70 6 => Ok(MessageId::Request),
71 7 => Ok(MessageId::Piece),
72 8 => Ok(MessageId::Cancel),
73 9 => Ok(MessageId::Port),
74 13 => Ok(MessageId::Suggest),
75 14 => Ok(MessageId::HaveAll),
76 15 => Ok(MessageId::HaveNone),
77 16 => Ok(MessageId::Reject),
78 17 => Ok(MessageId::AllowedFast),
79 20 => Ok(MessageId::Extended),
80 21 => Ok(MessageId::HashRequest),
81 22 => Ok(MessageId::Hashes),
82 23 => Ok(MessageId::HashReject),
83 _ => Err(PeerError::InvalidMessageId(value)),
84 }
85 }
86}
87
88#[derive(Debug, Clone)]
107pub struct Handshake {
108 pub info_hash: [u8; 20],
110 pub peer_id: [u8; 20],
112 pub reserved: [u8; 8],
114}
115
116impl Handshake {
117 pub fn new(info_hash: [u8; 20], peer_id: [u8; 20]) -> Self {
119 let mut reserved = [0u8; 8];
120 reserved[5] |= 0x10; reserved[7] |= 0x04; Self {
123 info_hash,
124 peer_id,
125 reserved,
126 }
127 }
128
129 pub fn new_v2(info_hash: [u8; 20], peer_id: [u8; 20]) -> Self {
133 let mut reserved = [0u8; 8];
134 reserved[5] |= 0x10; reserved[7] |= 0x04; reserved[7] |= 0x10; Self {
138 info_hash,
139 peer_id,
140 reserved,
141 }
142 }
143
144 pub fn supports_extension_protocol(&self) -> bool {
146 (self.reserved[5] & 0x10) != 0
147 }
148
149 pub fn supports_fast_extension(&self) -> bool {
151 (self.reserved[7] & 0x04) != 0
152 }
153
154 pub fn supports_dht(&self) -> bool {
156 (self.reserved[7] & 0x01) != 0
157 }
158
159 pub fn supports_v2(&self) -> bool {
164 (self.reserved[7] & 0x10) != 0
165 }
166
167 pub fn set_v2_support(&mut self, enabled: bool) {
169 if enabled {
170 self.reserved[7] |= 0x10;
171 } else {
172 self.reserved[7] &= !0x10;
173 }
174 }
175
176 pub fn encode(&self) -> Bytes {
178 let mut buf = BytesMut::with_capacity(HANDSHAKE_LEN);
179 buf.put_u8(19);
180 buf.put_slice(PROTOCOL);
181 buf.put_slice(&self.reserved);
182 buf.put_slice(&self.info_hash);
183 buf.put_slice(&self.peer_id);
184 buf.freeze()
185 }
186
187 pub fn decode(data: &[u8]) -> Result<Self, PeerError> {
188 if data.len() < HANDSHAKE_LEN {
189 return Err(PeerError::InvalidHandshake);
190 }
191
192 if data[0] != 19 || &data[1..20] != PROTOCOL {
193 return Err(PeerError::InvalidHandshake);
194 }
195
196 let mut reserved = [0u8; 8];
197 reserved.copy_from_slice(&data[20..28]);
198
199 let mut info_hash = [0u8; 20];
200 info_hash.copy_from_slice(&data[28..48]);
201
202 let mut peer_id = [0u8; 20];
203 peer_id.copy_from_slice(&data[48..68]);
204
205 Ok(Self {
206 info_hash,
207 peer_id,
208 reserved,
209 })
210 }
211}
212
213#[derive(Debug, Clone)]
235pub enum Message {
236 KeepAlive,
238 Choke,
240 Unchoke,
242 Interested,
244 NotInterested,
246 Have { piece: u32 },
248 Bitfield(Bytes),
250 Request { index: u32, begin: u32, length: u32 },
252 Piece { index: u32, begin: u32, data: Bytes },
254 Cancel { index: u32, begin: u32, length: u32 },
256 Port(u16),
258 Suggest { piece: u32 },
261 HaveAll,
263 HaveNone,
265 Reject { index: u32, begin: u32, length: u32 },
267 AllowedFast { piece: u32 },
269 Extended { id: u8, payload: Bytes },
272 HashRequest {
277 pieces_root: [u8; 32],
279 base_layer: u32,
281 index: u32,
283 length: u32,
285 proof_layers: u32,
287 },
288 Hashes {
292 pieces_root: [u8; 32],
294 base_layer: u32,
296 index: u32,
298 length: u32,
300 proof_layers: u32,
302 hashes: Bytes,
304 },
305 HashReject {
309 pieces_root: [u8; 32],
311 base_layer: u32,
313 index: u32,
315 length: u32,
317 proof_layers: u32,
319 },
320}
321
322impl Message {
323 pub fn encode(&self) -> Bytes {
327 let mut buf = BytesMut::new();
328
329 match self {
330 Message::KeepAlive => {
331 buf.put_u32(0);
332 }
333 Message::Choke => {
334 buf.put_u32(1);
335 buf.put_u8(MessageId::Choke as u8);
336 }
337 Message::Unchoke => {
338 buf.put_u32(1);
339 buf.put_u8(MessageId::Unchoke as u8);
340 }
341 Message::Interested => {
342 buf.put_u32(1);
343 buf.put_u8(MessageId::Interested as u8);
344 }
345 Message::NotInterested => {
346 buf.put_u32(1);
347 buf.put_u8(MessageId::NotInterested as u8);
348 }
349 Message::Have { piece } => {
350 buf.put_u32(5);
351 buf.put_u8(MessageId::Have as u8);
352 buf.put_u32(*piece);
353 }
354 Message::Bitfield(bits) => {
355 buf.put_u32(1 + bits.len() as u32);
356 buf.put_u8(MessageId::Bitfield as u8);
357 buf.put_slice(bits);
358 }
359 Message::Request {
360 index,
361 begin,
362 length,
363 } => {
364 buf.put_u32(13);
365 buf.put_u8(MessageId::Request as u8);
366 buf.put_u32(*index);
367 buf.put_u32(*begin);
368 buf.put_u32(*length);
369 }
370 Message::Piece { index, begin, data } => {
371 buf.put_u32(9 + data.len() as u32);
372 buf.put_u8(MessageId::Piece as u8);
373 buf.put_u32(*index);
374 buf.put_u32(*begin);
375 buf.put_slice(data);
376 }
377 Message::Cancel {
378 index,
379 begin,
380 length,
381 } => {
382 buf.put_u32(13);
383 buf.put_u8(MessageId::Cancel as u8);
384 buf.put_u32(*index);
385 buf.put_u32(*begin);
386 buf.put_u32(*length);
387 }
388 Message::Port(port) => {
389 buf.put_u32(3);
390 buf.put_u8(MessageId::Port as u8);
391 buf.put_u16(*port);
392 }
393 Message::Suggest { piece } => {
394 buf.put_u32(5);
395 buf.put_u8(MessageId::Suggest as u8);
396 buf.put_u32(*piece);
397 }
398 Message::HaveAll => {
399 buf.put_u32(1);
400 buf.put_u8(MessageId::HaveAll as u8);
401 }
402 Message::HaveNone => {
403 buf.put_u32(1);
404 buf.put_u8(MessageId::HaveNone as u8);
405 }
406 Message::Reject {
407 index,
408 begin,
409 length,
410 } => {
411 buf.put_u32(13);
412 buf.put_u8(MessageId::Reject as u8);
413 buf.put_u32(*index);
414 buf.put_u32(*begin);
415 buf.put_u32(*length);
416 }
417 Message::AllowedFast { piece } => {
418 buf.put_u32(5);
419 buf.put_u8(MessageId::AllowedFast as u8);
420 buf.put_u32(*piece);
421 }
422 Message::Extended { id, payload } => {
423 buf.put_u32(2 + payload.len() as u32);
424 buf.put_u8(MessageId::Extended as u8);
425 buf.put_u8(*id);
426 buf.put_slice(payload);
427 }
428 Message::HashRequest {
431 pieces_root,
432 base_layer,
433 index,
434 length,
435 proof_layers,
436 } => {
437 buf.put_u32(49);
438 buf.put_u8(MessageId::HashRequest as u8);
439 buf.put_slice(pieces_root);
440 buf.put_u32(*base_layer);
441 buf.put_u32(*index);
442 buf.put_u32(*length);
443 buf.put_u32(*proof_layers);
444 }
445 Message::Hashes {
447 pieces_root,
448 base_layer,
449 index,
450 length,
451 proof_layers,
452 hashes,
453 } => {
454 buf.put_u32(49 + hashes.len() as u32);
455 buf.put_u8(MessageId::Hashes as u8);
456 buf.put_slice(pieces_root);
457 buf.put_u32(*base_layer);
458 buf.put_u32(*index);
459 buf.put_u32(*length);
460 buf.put_u32(*proof_layers);
461 buf.put_slice(hashes);
462 }
463 Message::HashReject {
465 pieces_root,
466 base_layer,
467 index,
468 length,
469 proof_layers,
470 } => {
471 buf.put_u32(49);
472 buf.put_u8(MessageId::HashReject as u8);
473 buf.put_slice(pieces_root);
474 buf.put_u32(*base_layer);
475 buf.put_u32(*index);
476 buf.put_u32(*length);
477 buf.put_u32(*proof_layers);
478 }
479 }
480
481 buf.freeze()
482 }
483
484 pub fn decode(mut data: Bytes) -> Result<Self, PeerError> {
485 if data.len() < 4 {
486 return Err(PeerError::InvalidMessage("too short".into()));
487 }
488
489 let length = data.get_u32() as usize;
490
491 if length == 0 {
492 return Ok(Message::KeepAlive);
493 }
494
495 if data.remaining() < length {
496 return Err(PeerError::InvalidMessage("incomplete message".into()));
497 }
498
499 let id = MessageId::try_from(data.get_u8())?;
500
501 match id {
502 MessageId::Choke => Ok(Message::Choke),
503 MessageId::Unchoke => Ok(Message::Unchoke),
504 MessageId::Interested => Ok(Message::Interested),
505 MessageId::NotInterested => Ok(Message::NotInterested),
506 MessageId::Have => {
507 if data.remaining() < 4 {
508 return Err(PeerError::InvalidMessage("have too short".into()));
509 }
510 Ok(Message::Have {
511 piece: data.get_u32(),
512 })
513 }
514 MessageId::Bitfield => Ok(Message::Bitfield(data.copy_to_bytes(length - 1))),
515 MessageId::Request => {
516 if data.remaining() < 12 {
517 return Err(PeerError::InvalidMessage("request too short".into()));
518 }
519 Ok(Message::Request {
520 index: data.get_u32(),
521 begin: data.get_u32(),
522 length: data.get_u32(),
523 })
524 }
525 MessageId::Piece => {
526 if data.remaining() < 8 {
527 return Err(PeerError::InvalidMessage("piece too short".into()));
528 }
529 let index = data.get_u32();
530 let begin = data.get_u32();
531 let block_data = data.copy_to_bytes(length - 9);
532 Ok(Message::Piece {
533 index,
534 begin,
535 data: block_data,
536 })
537 }
538 MessageId::Cancel => {
539 if data.remaining() < 12 {
540 return Err(PeerError::InvalidMessage("cancel too short".into()));
541 }
542 Ok(Message::Cancel {
543 index: data.get_u32(),
544 begin: data.get_u32(),
545 length: data.get_u32(),
546 })
547 }
548 MessageId::Port => {
549 if data.remaining() < 2 {
550 return Err(PeerError::InvalidMessage("port too short".into()));
551 }
552 Ok(Message::Port(data.get_u16()))
553 }
554 MessageId::Suggest => {
555 if data.remaining() < 4 {
556 return Err(PeerError::InvalidMessage("suggest too short".into()));
557 }
558 Ok(Message::Suggest {
559 piece: data.get_u32(),
560 })
561 }
562 MessageId::HaveAll => Ok(Message::HaveAll),
563 MessageId::HaveNone => Ok(Message::HaveNone),
564 MessageId::Reject => {
565 if data.remaining() < 12 {
566 return Err(PeerError::InvalidMessage("reject too short".into()));
567 }
568 Ok(Message::Reject {
569 index: data.get_u32(),
570 begin: data.get_u32(),
571 length: data.get_u32(),
572 })
573 }
574 MessageId::AllowedFast => {
575 if data.remaining() < 4 {
576 return Err(PeerError::InvalidMessage("allowed fast too short".into()));
577 }
578 Ok(Message::AllowedFast {
579 piece: data.get_u32(),
580 })
581 }
582 MessageId::Extended => {
583 if data.remaining() < 1 {
584 return Err(PeerError::InvalidMessage("extended too short".into()));
585 }
586 let ext_id = data.get_u8();
587 let payload = data.copy_to_bytes(length - 2);
588 Ok(Message::Extended {
589 id: ext_id,
590 payload,
591 })
592 }
593 MessageId::HashRequest => {
595 if data.remaining() < 48 {
597 return Err(PeerError::InvalidMessage("hash request too short".into()));
598 }
599 let mut pieces_root = [0u8; 32];
600 pieces_root.copy_from_slice(&data.copy_to_bytes(32));
601 Ok(Message::HashRequest {
602 pieces_root,
603 base_layer: data.get_u32(),
604 index: data.get_u32(),
605 length: data.get_u32(),
606 proof_layers: data.get_u32(),
607 })
608 }
609 MessageId::Hashes => {
610 if data.remaining() < 48 {
612 return Err(PeerError::InvalidMessage("hashes too short".into()));
613 }
614 let mut pieces_root = [0u8; 32];
615 pieces_root.copy_from_slice(&data.copy_to_bytes(32));
616 let base_layer = data.get_u32();
617 let index = data.get_u32();
618 let hash_length = data.get_u32();
619 let proof_layers = data.get_u32();
620 let hashes_len = length - 49; if data.remaining() < hashes_len {
623 return Err(PeerError::InvalidMessage("hashes data too short".into()));
624 }
625 let hashes = data.copy_to_bytes(hashes_len);
626 if hashes.len() % 32 != 0 {
628 return Err(PeerError::InvalidMessage(
629 "hashes not multiple of 32 bytes".into(),
630 ));
631 }
632 Ok(Message::Hashes {
633 pieces_root,
634 base_layer,
635 index,
636 length: hash_length,
637 proof_layers,
638 hashes,
639 })
640 }
641 MessageId::HashReject => {
642 if data.remaining() < 48 {
644 return Err(PeerError::InvalidMessage("hash reject too short".into()));
645 }
646 let mut pieces_root = [0u8; 32];
647 pieces_root.copy_from_slice(&data.copy_to_bytes(32));
648 Ok(Message::HashReject {
649 pieces_root,
650 base_layer: data.get_u32(),
651 index: data.get_u32(),
652 length: data.get_u32(),
653 proof_layers: data.get_u32(),
654 })
655 }
656 }
657 }
658}
659
/// Checks the `length` and `index` fields of a v2 hash request.
///
/// Returns `Some(reason)` naming the first rule that is violated, or `None`
/// when the request is acceptable. The rules, in order:
/// `length` is at least 2, is a power of two, is at most 512, and `index` is
/// a multiple of `length`.
pub fn validate_hash_request(length: u32, index: u32) -> Option<&'static str> {
    let reason = if length < 2 {
        "length must be >= 2"
    } else if !length.is_power_of_two() {
        // `length >= 2` here, so this matches the `length & (length - 1)` test.
        "length must be power of 2"
    } else if length > 512 {
        "length exceeds 512"
    } else if index % length != 0 {
        "index must be multiple of length"
    } else {
        return None;
    };
    Some(reason)
}