1pub mod extended;
6
7use bincode::Options;
8use buffers::{ByteBuf, ByteBufOwned, ByteBufT};
9use byteorder::{ByteOrder, BE};
10use bytes::Bytes;
11use clone_to_owned::CloneToOwned;
12use extended::PeerExtendedMessageIds;
13use librqbit_core::{constants::CHUNK_SIZE, hash_id::Id20, lengths::ChunkInfo};
14use serde::{Deserialize, Serialize};
15
16use self::extended::ExtendedMessage;
17
18const INTEGER_LEN: usize = 4;
19const MSGID_LEN: usize = 1;
20const PREAMBLE_LEN: usize = INTEGER_LEN + MSGID_LEN;
21const PIECE_MESSAGE_PREAMBLE_LEN: usize = PREAMBLE_LEN + INTEGER_LEN * 2;
22pub const PIECE_MESSAGE_DEFAULT_LEN: usize = PIECE_MESSAGE_PREAMBLE_LEN + CHUNK_SIZE as usize;
23
24const NO_PAYLOAD_MSG_LEN: usize = PREAMBLE_LEN;
25
26const PSTR_BT1: &str = "BitTorrent protocol";
27
28const LEN_PREFIX_KEEPALIVE: u32 = 0;
29const LEN_PREFIX_CHOKE: u32 = 1;
30const LEN_PREFIX_UNCHOKE: u32 = 1;
31const LEN_PREFIX_INTERESTED: u32 = 1;
32const LEN_PREFIX_NOT_INTERESTED: u32 = 1;
33const LEN_PREFIX_HAVE: u32 = 5;
34const LEN_PREFIX_PIECE: u32 = 9;
35const LEN_PREFIX_REQUEST: u32 = 13;
36
37const MSGID_CHOKE: u8 = 0;
38const MSGID_UNCHOKE: u8 = 1;
39const MSGID_INTERESTED: u8 = 2;
40const MSGID_NOT_INTERESTED: u8 = 3;
41const MSGID_HAVE: u8 = 4;
42const MSGID_BITFIELD: u8 = 5;
43const MSGID_REQUEST: u8 = 6;
44const MSGID_PIECE: u8 = 7;
45const MSGID_CANCEL: u8 = 8;
46const MSGID_EXTENDED: u8 = 20;
47
48pub const EXTENDED_UT_METADATA_KEY: &[u8] = b"ut_metadata";
49pub const MY_EXTENDED_UT_METADATA: u8 = 3;
50
51pub const EXTENDED_UT_PEX_KEY: &[u8] = b"ut_pex";
52pub const MY_EXTENDED_UT_PEX: u8 = 1;
53
/// Errors returned when decoding peer wire messages and handshakes.
#[derive(Debug)]
pub enum MessageDeserializeError {
    /// The input ended before the named item could be parsed. The `usize` is
    /// the byte count reported by the parser (displayed as "at least N more
    /// bytes").
    NotEnoughData(usize, &'static str),
    /// The message id is not one this module decodes.
    UnsupportedMessageId(u8),
    /// The length prefix is not valid for `msg_id`.
    IncorrectLenPrefix {
        received: u32,
        expected: u32,
        msg_id: u8,
    },
    /// bincode failed to decode the payload of message `msg_id`.
    OtherBincode {
        error: bincode::Error,
        msg_id: u8,
        len_prefix: u32,
        name: &'static str,
    },
    /// Any other failure (e.g. a bad handshake protocol string).
    Other(anyhow::Error),
}
71
72pub fn serialize_piece_preamble(chunk: &ChunkInfo, mut buf: &mut [u8]) -> usize {
73 BE::write_u32(&mut buf[0..4], LEN_PREFIX_PIECE + chunk.size);
74 buf[4] = MSGID_PIECE;
75
76 buf = &mut buf[PREAMBLE_LEN..];
77 BE::write_u32(&mut buf[0..4], chunk.piece_index.get());
78 BE::write_u32(&mut buf[4..8], chunk.offset);
79
80 PIECE_MESSAGE_PREAMBLE_LEN
81}
82
/// Payload of a `piece` message: a block of data located at byte offset
/// `begin` within piece `index`.
#[derive(Debug)]
pub struct Piece<B> {
    pub index: u32,
    pub begin: u32,
    /// The block data; its length is implied by the message length prefix.
    pub block: B,
}
89
90impl<B: CloneToOwned> CloneToOwned for Piece<B> {
91 type Target = Piece<B::Target>;
92
93 fn clone_to_owned(&self, within_buffer: Option<&Bytes>) -> Self::Target {
94 Piece {
95 index: self.index,
96 begin: self.begin,
97 block: self.block.clone_to_owned(within_buffer),
98 }
99 }
100}
101
102impl<B> Piece<B>
103where
104 B: AsRef<[u8]>,
105{
106 pub fn from_data<T>(index: u32, begin: u32, block: T) -> Piece<B>
107 where
108 B: From<T>,
109 {
110 Piece {
111 index,
112 begin,
113 block: B::from(block),
114 }
115 }
116
117 pub fn serialize(&self, mut buf: &mut [u8]) -> usize {
118 byteorder::BigEndian::write_u32(&mut buf[0..4], self.index);
119 byteorder::BigEndian::write_u32(&mut buf[4..8], self.begin);
120 buf = &mut buf[8..];
121 buf.copy_from_slice(self.block.as_ref());
122 self.block.as_ref().len() + 8
123 }
124 pub fn deserialize<'a>(buf: &'a [u8]) -> Piece<B>
125 where
126 B: From<&'a [u8]> + 'a,
127 {
128 let index = byteorder::BigEndian::read_u32(&buf[0..4]);
129 let begin = byteorder::BigEndian::read_u32(&buf[4..8]);
130 let block = B::from(&buf[8..]);
131 Piece {
132 index,
133 begin,
134 block,
135 }
136 }
137}
138
139impl std::fmt::Display for MessageDeserializeError {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 match self {
142 MessageDeserializeError::NotEnoughData(b, name) => {
143 write!(
144 f,
145 "not enough data to deserialize {name}: expected at least {b} more bytes"
146 )
147 }
148 MessageDeserializeError::UnsupportedMessageId(msg_id) => {
149 write!(f, "unsupported message id {msg_id}")
150 }
151 MessageDeserializeError::IncorrectLenPrefix {
152 received,
153 expected,
154 msg_id,
155 } => write!(
156 f,
157 "incorrect len prefix for message id {msg_id}, expected {expected}, received {received}"
158 ),
159 MessageDeserializeError::OtherBincode {
160 error,
161 msg_id,
162 name,
163 len_prefix,
164 } => write!(
165 f,
166 "error deserializing {name} (msg_id={msg_id}, len_prefix={len_prefix}): {error:#}"
167 ),
168 MessageDeserializeError::Other(e) => write!(f, "{e}"),
169 }
170 }
171}
172
173impl std::error::Error for MessageDeserializeError {
174 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
175 match self {
176 MessageDeserializeError::OtherBincode { error, .. } => Some(error),
177 _ => None,
178 }
179 }
180}
181
182impl From<anyhow::Error> for MessageDeserializeError {
183 fn from(e: anyhow::Error) -> Self {
184 MessageDeserializeError::Other(e)
185 }
186}
187
/// A peer wire message, generic over the buffer type used for payloads
/// (see the `MessageBorrowed` / `MessageOwned` aliases).
#[derive(Debug)]
pub enum Message<ByteBuf: ByteBufT> {
    /// `request` (id 6).
    Request(Request),
    /// `cancel` (id 8); same body layout as `request`.
    Cancel(Request),
    /// `bitfield` (id 5); raw bitfield bytes.
    Bitfield(ByteBuf),
    /// Zero length prefix; no message id on the wire.
    KeepAlive,
    /// `have` (id 4); the piece index.
    Have(u32),
    /// `choke` (id 0).
    Choke,
    /// `unchoke` (id 1).
    Unchoke,
    /// `interested` (id 2).
    Interested,
    /// `not interested` (id 3).
    NotInterested,
    /// `piece` (id 7); index, begin and block data.
    Piece(Piece<ByteBuf>),
    /// Extension-protocol message (id 20).
    Extended(ExtendedMessage<ByteBuf>),
}
202
/// Message whose payload buffers are `ByteBuf<'a>` (tied to lifetime `'a`).
pub type MessageBorrowed<'a> = Message<ByteBuf<'a>>;
/// Message whose payload buffers are `ByteBufOwned`.
pub type MessageOwned = Message<ByteBufOwned>;

/// Borrowed bit view of a bitfield, using most-significant-bit-first order.
pub type BitfieldBorrowed<'a> = &'a bitvec::slice::BitSlice<u8, bitvec::order::Msb0>;
/// Owned bit vector with the same bit order as `BitfieldBorrowed`.
pub type BitfieldOwned = bitvec::vec::BitVec<u8, bitvec::order::Msb0>;

/// A bitfield viewed as bits rather than bytes.
pub struct Bitfield<'a> {
    pub data: BitfieldBorrowed<'a>,
}
212
213impl<ByteBuf> CloneToOwned for Message<ByteBuf>
214where
215 ByteBuf: ByteBufT,
216 <ByteBuf as CloneToOwned>::Target: ByteBufT,
217{
218 type Target = Message<<ByteBuf as CloneToOwned>::Target>;
219
220 fn clone_to_owned(&self, within_buffer: Option<&Bytes>) -> Self::Target {
221 match self {
222 Message::Request(req) => Message::Request(*req),
223 Message::Cancel(req) => Message::Cancel(*req),
224 Message::Bitfield(b) => Message::Bitfield(b.clone_to_owned(within_buffer)),
225 Message::Choke => Message::Choke,
226 Message::Unchoke => Message::Unchoke,
227 Message::Interested => Message::Interested,
228 Message::Piece(piece) => Message::Piece(Piece {
229 index: piece.index,
230 begin: piece.begin,
231 block: piece.block.clone_to_owned(within_buffer),
232 }),
233 Message::KeepAlive => Message::KeepAlive,
234 Message::Have(v) => Message::Have(*v),
235 Message::NotInterested => Message::NotInterested,
236 Message::Extended(e) => Message::Extended(e.clone_to_owned(within_buffer)),
237 }
238 }
239}
240
241impl<'a> Bitfield<'a> {
242 pub fn new_from_slice(buf: &'a [u8]) -> anyhow::Result<Self> {
243 Ok(Self {
244 data: bitvec::slice::BitSlice::from_slice(buf),
245 })
246 }
247}
248
249impl std::fmt::Debug for Bitfield<'_> {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 f.debug_struct("Bitfield")
252 .field("_ones", &self.data.count_ones())
253 .field("_len", &self.data.len())
254 .finish()
255 }
256}
257
258impl<ByteBuf> Message<ByteBuf>
259where
260 ByteBuf: ByteBufT,
261{
262 pub fn len_prefix_and_msg_id(&self) -> (u32, u8) {
263 match self {
264 Message::Request(_) => (LEN_PREFIX_REQUEST, MSGID_REQUEST),
265 Message::Cancel(_) => (LEN_PREFIX_REQUEST, MSGID_CANCEL),
266 Message::Bitfield(b) => (1 + b.as_ref().len() as u32, MSGID_BITFIELD),
267 Message::Choke => (LEN_PREFIX_CHOKE, MSGID_CHOKE),
268 Message::Unchoke => (LEN_PREFIX_UNCHOKE, MSGID_UNCHOKE),
269 Message::Interested => (LEN_PREFIX_INTERESTED, MSGID_INTERESTED),
270 Message::NotInterested => (LEN_PREFIX_NOT_INTERESTED, MSGID_NOT_INTERESTED),
271 Message::Piece(p) => (
272 LEN_PREFIX_PIECE + p.block.as_ref().len() as u32,
273 MSGID_PIECE,
274 ),
275 Message::KeepAlive => (LEN_PREFIX_KEEPALIVE, 0),
276 Message::Have(_) => (LEN_PREFIX_HAVE, MSGID_HAVE),
277 Message::Extended(_) => (0, MSGID_EXTENDED),
278 }
279 }
280 pub fn serialize(
281 &self,
282 out: &mut Vec<u8>,
283 peer_extended_messages: &dyn Fn() -> PeerExtendedMessageIds,
284 ) -> anyhow::Result<usize> {
285 let (lp, msg_id) = self.len_prefix_and_msg_id();
286
287 out.resize(PREAMBLE_LEN, 0);
288
289 byteorder::BigEndian::write_u32(&mut out[..4], lp);
290 out[4] = msg_id;
291
292 let ser = bopts();
293
294 match self {
295 Message::Request(request) | Message::Cancel(request) => {
296 const MSG_LEN: usize = PREAMBLE_LEN + 12;
297 out.resize(MSG_LEN, 0);
298 debug_assert_eq!(out[PREAMBLE_LEN..].len(), 12);
299 ser.serialize_into(&mut out[PREAMBLE_LEN..], request)
300 .unwrap();
301 Ok(MSG_LEN)
302 }
303 Message::Bitfield(b) => {
304 let block_len = b.as_ref().len();
305 let msg_len = PREAMBLE_LEN + block_len;
306 out.resize(msg_len, 0);
307 out[PREAMBLE_LEN..PREAMBLE_LEN + block_len].copy_from_slice(b.as_ref());
308 Ok(msg_len)
309 }
310 Message::Choke | Message::Unchoke | Message::Interested | Message::NotInterested => {
311 Ok(PREAMBLE_LEN)
312 }
313 Message::Piece(p) => {
314 let block_len = p.block.as_ref().len();
315 let payload_len = 8 + block_len;
316 let msg_len = PREAMBLE_LEN + payload_len;
317 out.resize(msg_len, 0);
318 let tmp = &mut out[PREAMBLE_LEN..];
319 p.serialize(&mut tmp[..payload_len]);
320 Ok(msg_len)
321 }
322 Message::KeepAlive => {
323 Ok(4)
325 }
326 Message::Have(v) => {
327 let msg_len = PREAMBLE_LEN + 4;
328 out.resize(msg_len, 0);
329 BE::write_u32(&mut out[PREAMBLE_LEN..], *v);
330 Ok(msg_len)
331 }
332 Message::Extended(e) => {
333 e.serialize(out, peer_extended_messages)?;
334 let msg_size = out.len();
335 BE::write_u32(&mut out[..4], (msg_size - PREAMBLE_LEN + 1) as u32);
338 Ok(msg_size)
339 }
340 }
341 }
342 pub fn deserialize<'a>(
343 buf: &'a [u8],
344 ) -> Result<(Message<ByteBuf>, usize), MessageDeserializeError>
345 where
346 ByteBuf: From<&'a [u8]> + 'a + Deserialize<'a>,
347 {
348 let len_prefix = match buf.get(0..4) {
349 Some(bytes) => byteorder::BigEndian::read_u32(bytes),
350 None => return Err(MessageDeserializeError::NotEnoughData(4, "message")),
351 };
352 if len_prefix == 0 {
353 return Ok((Message::KeepAlive, 4));
354 }
355
356 let msg_id = match buf.get(4) {
357 Some(msg_id) => *msg_id,
358 None => return Err(MessageDeserializeError::NotEnoughData(1, "message")),
359 };
360 let rest = &buf[5..];
361 let decoder_config = bincode::DefaultOptions::new()
362 .with_fixint_encoding()
363 .with_big_endian();
364
365 match msg_id {
366 MSGID_CHOKE => {
367 if len_prefix != LEN_PREFIX_CHOKE {
368 return Err(MessageDeserializeError::IncorrectLenPrefix {
369 received: len_prefix,
370 expected: LEN_PREFIX_CHOKE,
371 msg_id,
372 });
373 }
374 Ok((Message::Choke, NO_PAYLOAD_MSG_LEN))
375 }
376 MSGID_UNCHOKE => {
377 if len_prefix != LEN_PREFIX_UNCHOKE {
378 return Err(MessageDeserializeError::IncorrectLenPrefix {
379 received: len_prefix,
380 expected: LEN_PREFIX_UNCHOKE,
381 msg_id,
382 });
383 }
384 Ok((Message::Unchoke, NO_PAYLOAD_MSG_LEN))
385 }
386 MSGID_INTERESTED => {
387 if len_prefix != LEN_PREFIX_INTERESTED {
388 return Err(MessageDeserializeError::IncorrectLenPrefix {
389 received: len_prefix,
390 expected: LEN_PREFIX_INTERESTED,
391 msg_id,
392 });
393 }
394 Ok((Message::Interested, NO_PAYLOAD_MSG_LEN))
395 }
396 MSGID_NOT_INTERESTED => {
397 if len_prefix != LEN_PREFIX_NOT_INTERESTED {
398 return Err(MessageDeserializeError::IncorrectLenPrefix {
399 received: len_prefix,
400 expected: LEN_PREFIX_NOT_INTERESTED,
401 msg_id,
402 });
403 }
404 Ok((Message::NotInterested, NO_PAYLOAD_MSG_LEN))
405 }
406 MSGID_HAVE => {
407 let expected_len = 4;
408 match rest.get(..expected_len) {
409 Some(h) => Ok((Message::Have(BE::read_u32(h)), PREAMBLE_LEN + expected_len)),
410 None => {
411 let missing = expected_len - rest.len();
412 Err(MessageDeserializeError::NotEnoughData(missing, "have"))
413 }
414 }
415 }
416 MSGID_BITFIELD => {
417 if len_prefix <= 1 {
418 return Err(MessageDeserializeError::IncorrectLenPrefix {
419 expected: 2,
420 received: len_prefix,
421 msg_id,
422 });
423 }
424 let expected_len = len_prefix as usize - 1;
425 match rest.get(..expected_len) {
426 Some(bitfield) => Ok((
427 Message::Bitfield(ByteBuf::from(bitfield)),
428 PREAMBLE_LEN + expected_len,
429 )),
430 None => {
431 let missing = expected_len - rest.len();
432 Err(MessageDeserializeError::NotEnoughData(missing, "bitfield"))
433 }
434 }
435 }
436 MSGID_REQUEST | MSGID_CANCEL => {
437 let expected_len = 12;
438 match rest.get(..expected_len) {
439 Some(b) => {
440 let request = decoder_config.deserialize::<Request>(b).unwrap();
441 let req = if msg_id == MSGID_REQUEST {
442 Message::Request(request)
443 } else {
444 Message::Cancel(request)
445 };
446 Ok((req, PREAMBLE_LEN + expected_len))
447 }
448 None => {
449 let missing = expected_len - rest.len();
450 Err(MessageDeserializeError::NotEnoughData(
451 missing,
452 if msg_id == MSGID_REQUEST {
453 "request"
454 } else {
455 "cancel"
456 },
457 ))
458 }
459 }
460 }
461 MSGID_PIECE => {
462 if len_prefix <= 9 {
463 return Err(MessageDeserializeError::IncorrectLenPrefix {
464 expected: 10,
465 received: len_prefix,
466 msg_id,
467 });
468 }
469 let expected_len = len_prefix as usize - 9 + 8;
471 match rest.get(..expected_len) {
472 Some(b) => Ok((
473 Message::Piece(Piece::deserialize(b)),
474 PREAMBLE_LEN + expected_len,
475 )),
476 None => Err(MessageDeserializeError::NotEnoughData(
477 expected_len - rest.len(),
478 "piece",
479 )),
480 }
481 }
482 MSGID_EXTENDED => {
483 if len_prefix <= 6 {
484 return Err(MessageDeserializeError::IncorrectLenPrefix {
485 expected: 6,
486 received: len_prefix,
487 msg_id,
488 });
489 }
490 let expected_len = len_prefix as usize - 1;
492 match rest.get(..expected_len) {
493 Some(b) => Ok((
494 Message::Extended(ExtendedMessage::deserialize(b)?),
495 PREAMBLE_LEN + expected_len,
496 )),
497 None => Err(MessageDeserializeError::NotEnoughData(
498 expected_len - rest.len(),
499 "extended",
500 )),
501 }
502 }
503 msg_id => Err(MessageDeserializeError::UnsupportedMessageId(msg_id)),
504 }
505 }
506}
507
/// The peer handshake: protocol string, 8 reserved bytes, the torrent info
/// hash and the sender's peer id.
#[derive(Serialize, Deserialize, Debug)]
pub struct Handshake<ByteBuf> {
    /// Protocol string; `"BitTorrent protocol"` for v1.
    pub pstr: ByteBuf,
    /// Reserved bytes whose bits advertise optional features
    /// (e.g. the extension protocol, see `supports_extended`).
    pub reserved: [u8; 8],
    pub info_hash: [u8; 20],
    pub peer_id: [u8; 20],
}
515
516fn bopts() -> impl bincode::Options {
517 bincode::DefaultOptions::new()
518 .with_fixint_encoding()
519 .with_big_endian()
520}
521
522impl Handshake<ByteBuf<'static>> {
523 pub fn new(info_hash: Id20, peer_id: Id20) -> Handshake<ByteBuf<'static>> {
524 debug_assert_eq!(PSTR_BT1.len(), 19);
525
526 let mut reserved: u64 = 0;
527 reserved |= 1 << 20;
529 let mut reserved_arr = [0u8; 8];
530 BE::write_u64(&mut reserved_arr, reserved);
531
532 Handshake {
533 pstr: ByteBuf(PSTR_BT1.as_bytes()),
534 reserved: reserved_arr,
535 info_hash: info_hash.0,
536 peer_id: peer_id.0,
537 }
538 }
539
540 pub fn deserialize(
541 b: &[u8],
542 ) -> Result<(Handshake<ByteBuf<'_>>, usize), MessageDeserializeError> {
543 let pstr_len = *b
544 .first()
545 .ok_or(MessageDeserializeError::NotEnoughData(1, "handshake"))?;
546 if pstr_len as usize != PSTR_BT1.len() {
547 return Err(MessageDeserializeError::Other(anyhow::anyhow!(
548 "pstr should be {} bytes long, but received {}",
549 PSTR_BT1.len(),
550 pstr_len
551 )));
552 }
553 let expected_len = 1usize + pstr_len as usize + 48;
554 let hbuf = b
555 .get(..expected_len)
556 .ok_or(MessageDeserializeError::NotEnoughData(
557 expected_len,
558 "handshake",
559 ))?;
560 let h = Self::bopts()
561 .deserialize::<Handshake<ByteBuf<'_>>>(hbuf)
562 .map_err(|e| MessageDeserializeError::Other(e.into()))?;
563 if h.pstr.0 != PSTR_BT1.as_bytes() {
564 return Err(MessageDeserializeError::Other(anyhow::anyhow!(
565 "pstr doesn't match bittorrent V1"
566 )));
567 }
568 Ok((h, expected_len))
569 }
570}
571
572impl<B> Handshake<B> {
573 pub fn supports_extended(&self) -> bool {
574 self.reserved[5] & 0x10 > 0
575 }
576 fn bopts() -> impl bincode::Options {
577 bincode::DefaultOptions::new()
578 }
579
580 pub fn serialize(&self, buf: &mut Vec<u8>)
581 where
582 B: Serialize,
583 {
584 Self::bopts().serialize_into(buf, &self).unwrap()
585 }
586}
587
588impl<B> CloneToOwned for Handshake<B>
589where
590 B: CloneToOwned,
591{
592 type Target = Handshake<<B as CloneToOwned>::Target>;
593
594 fn clone_to_owned(&self, within_buffer: Option<&Bytes>) -> Self::Target {
595 Handshake {
596 pstr: self.pstr.clone_to_owned(within_buffer),
597 reserved: self.reserved,
598 info_hash: self.info_hash,
599 peer_id: self.peer_id,
600 }
601 }
602}
603
/// Body of a `request` or `cancel` message: `length` bytes starting at byte
/// offset `begin` within piece `index`. On the wire these are three
/// big-endian `u32`s in field order.
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub struct Request {
    pub index: u32,
    pub begin: u32,
    pub length: u32,
}
610
611impl Request {
612 pub fn new(index: u32, begin: u32, length: u32) -> Self {
613 Self {
614 index,
615 begin,
616 length,
617 }
618 }
619}
620
#[cfg(test)]
mod tests {
    use crate::extended::handshake::ExtendedHandshake;

    use super::*;

    const TEST_ID: [u8; 20] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
    ];

    #[test]
    fn test_handshake_serialize() {
        let info_hash = Id20::new(TEST_ID);
        let peer_id = Id20::new(TEST_ID);
        let mut buf = Vec::new();
        Handshake::new(info_hash, peer_id).serialize(&mut buf);
        // pstrlen + pstr + reserved + info_hash + peer_id
        assert_eq!(buf.len(), 20 + 20 + 8 + 19 + 1);

        // The encoded handshake must parse back to the same fields.
        let (parsed, consumed) = Handshake::<ByteBuf<'static>>::deserialize(&buf).unwrap();
        assert_eq!(consumed, buf.len());
        assert_eq!(parsed.info_hash, TEST_ID);
        assert_eq!(parsed.peer_id, TEST_ID);
        assert!(parsed.supports_extended());
    }

    #[test]
    fn test_extended_serialize() {
        let msg = Message::Extended(ExtendedMessage::Handshake(ExtendedHandshake::new()));
        let mut out = Vec::new();
        let written = msg.serialize(&mut out, &Default::default).unwrap();
        // The returned size covers the whole buffer, the id byte is the
        // extended id, and the length prefix excludes only its own 4 bytes.
        assert_eq!(written, out.len());
        assert_eq!(out[4], MSGID_EXTENDED);
        assert_eq!(BE::read_u32(&out[..4]) as usize, written - 4);
    }

    #[test]
    fn test_deserialize_serialize_extended_is_same() {
        use std::fs::File;
        use std::io::Read;
        let mut buf = Vec::new();
        File::open("../librqbit/resources/test/extended-handshake.bin")
            .unwrap()
            .read_to_end(&mut buf)
            .unwrap();
        let (msg, size) = MessageBorrowed::deserialize(&buf).unwrap();
        assert_eq!(size, buf.len());
        let mut write_buf = Vec::new();
        msg.serialize(&mut write_buf, &Default::default).unwrap();
        if buf != write_buf {
            {
                use std::io::Write;
                let mut f = std::fs::OpenOptions::new()
                    .create(true)
                    .truncate(true)
                    .write(true)
                    .open("/tmp/test_deserialize_serialize_extended_is_same")
                    .unwrap();
                f.write_all(&write_buf).unwrap();
            }
            panic!("resources/test/extended-handshake.bin did not serialize exactly the same. Dumped to /tmp/test_deserialize_serialize_extended_is_same, you can compare with resources/test/extended-handshake.bin")
        }
    }
}