Skip to main content

mainline/common/
messages.rs

//! Serialize and deserialize KRPC messages.
2
3// Copied from <https://github.com/raptorswing/rustydht-lib/blob/main/src/packets/public.rs>
4
5#![allow(missing_docs)]
6
7mod internal;
8
9use std::convert::TryInto;
10use std::net::{Ipv4Addr, SocketAddrV4};
11
12use crate::common::{Id, Node, ID_SIZE};
13
14use super::InvalidIdSize;
15
/// A single KRPC message: a request, a response, or an error.
///
/// Converted to bytes with [`Message::to_bytes`] and parsed with [`Message::from_bytes`].
#[derive(Debug, PartialEq, Clone)]
pub(crate) struct Message {
    /// Transaction id; encoded on the wire as 4 big-endian bytes.
    pub transaction_id: u32,

    /// The version of the requester or responder.
    pub version: Option<[u8; 4]>,

    /// The IP address and port ("SocketAddr") of the requester as seen from the responder's point of view.
    /// This should be set only on responses, but is defined at this level with the other common fields
    /// to avoid defining yet another layer on the response objects.
    pub requester_ip: Option<SocketAddrV4>,

    /// Whether this is a request, a response, or an error, together with its type-specific payload.
    pub message_type: MessageType,

    /// For BEP 43. When set to true on a request, indicates that the requester can't reply to requests
    /// and that responders should not add the requester to their routing tables.
    /// Should only be set on requests. The flag is encoded regardless of message type, so setting it
    /// on a response has no defined meaning.
    pub read_only: bool,
}
33
/// The three kinds of KRPC message, each carrying its own payload.
#[derive(Debug, PartialEq, Clone)]
pub enum MessageType {
    /// A query sent to another node.
    Request(RequestSpecific),

    /// A non-error reply to a request.
    Response(ResponseSpecific),

    /// An error reply: a numeric code plus a human-readable description.
    Error(ErrorSpecific),
}
42
/// Payload of an error message, encoded on the wire as a `(code, description)` pair.
#[derive(Debug, PartialEq, Clone)]
pub struct ErrorSpecific {
    /// Numeric error code.
    pub code: i32,
    /// Human-readable error description.
    pub description: String,
}
48
/// Payload of a request: the id of the node sending it plus the query-specific arguments.
#[derive(Debug, PartialEq, Clone)]
pub struct RequestSpecific {
    /// Id of the node that sent the request.
    pub requester_id: Id,
    /// Which query this is and its arguments.
    pub request_type: RequestTypeSpecific,
}
54
/// The supported request (query) types and their arguments.
#[derive(Debug, PartialEq, Clone)]
pub enum RequestTypeSpecific {
    /// `ping`: carries no arguments besides the requester id.
    Ping,
    /// `find_node`: ask for the nodes closest to a target id.
    FindNode(FindNodeRequestArguments),
    /// `get_peers`: ask for peers of an info hash.
    GetPeers(GetPeersRequestArguments),
    /// Get-value request for a target id (immutable or mutable item).
    GetValue(GetValueRequestArguments),

    /// Every request that carries a token: announce-peer, immutable put, and mutable put.
    Put(PutRequest),
}
64
/// A token-carrying request (announce-peer or put).
#[derive(Debug, PartialEq, Clone)]
pub struct PutRequest {
    /// Token sent along with the request; presumably the one obtained from an earlier
    /// response of the same node (see `Message::get_token`) — TODO confirm against callers.
    pub token: Box<[u8]>,
    /// The specific kind of put and its arguments.
    pub put_request_type: PutRequestSpecific,
}
70
/// The kinds of token-carrying request.
#[derive(Debug, PartialEq, Clone)]
pub enum PutRequestSpecific {
    /// `announce_peer` for an info hash.
    AnnouncePeer(AnnouncePeerRequestArguments),
    /// Put of an immutable value (decoded when the wire message has no `k`).
    PutImmutable(PutImmutableRequestArguments),
    /// Put of a mutable value (decoded when the wire message has a `k`).
    PutMutable(PutMutableRequestArguments),
}
77
78impl PutRequestSpecific {
79    pub fn target(&self) -> &Id {
80        match self {
81            PutRequestSpecific::AnnouncePeer(AnnouncePeerRequestArguments {
82                info_hash, ..
83            }) => info_hash,
84            PutRequestSpecific::PutMutable(PutMutableRequestArguments { target, .. }) => target,
85            PutRequestSpecific::PutImmutable(PutImmutableRequestArguments { target, .. }) => target,
86        }
87    }
88}
89
/// The supported response types and their arguments.
///
/// Every variant carries the responder's id; all but `Ping` and `FindNode` also carry a token.
#[derive(Debug, PartialEq, Clone)]
pub enum ResponseSpecific {
    /// Reply to `ping`.
    Ping(PingResponseArguments),
    /// Reply to `find_node`: a list of nodes.
    FindNode(FindNodeResponseArguments),
    /// Reply to `get_peers` that includes peer values.
    GetPeers(GetPeersResponseArguments),
    /// Reply carrying an immutable value.
    GetImmutable(GetImmutableResponseArguments),
    /// Reply carrying a signed mutable value.
    GetMutable(GetMutableResponseArguments),
    /// Reply with a token (and optionally nodes) but no values.
    NoValues(NoValuesResponseArguments),
    /// Reply stating the responder has no value more recent than the requested `seq`.
    NoMoreRecentValue(NoMoreRecentValueResponseArguments),
}
100
// === PING ===
/// Arguments of a `ping` response.
#[derive(Debug, PartialEq, Clone)]
pub struct PingResponseArguments {
    /// Id of the responding node.
    pub responder_id: Id,
}
106
// === FIND_NODE ===
/// Arguments of a `find_node` request.
#[derive(Debug, PartialEq, Clone)]
pub struct FindNodeRequestArguments {
    /// The id whose closest nodes are requested.
    pub target: Id,
}
112
/// Arguments of a `find_node` response.
#[derive(Debug, PartialEq, Clone)]
pub struct FindNodeResponseArguments {
    /// Id of the responding node.
    pub responder_id: Id,
    /// Nodes returned by the responder, encoded on the wire in compact IPv4 form.
    pub nodes: Box<[Node]>,
}
118
// === Get Value (immutable or mutable) ===

/// Arguments of a get-value request.
#[derive(Debug, PartialEq, Clone)]
pub struct GetValueRequestArguments {
    /// Id of the item being requested.
    pub target: Id,
    /// Optional sequence number sent with the request.
    pub seq: Option<i64>,
    // A bit of a hack, using this to carry an optional
    // salt in the query.request field of [crate::query].
    // It is never encoded, decoded, or sent over the wire:
    // encoding drops it and decoding always sets it to `None`.
    pub salt: Option<Box<[u8]>>,
}
130
/// Arguments of a response that returns a token but no values.
#[derive(Debug, PartialEq, Clone)]
pub struct NoValuesResponseArguments {
    /// Id of the responding node.
    pub responder_id: Id,
    /// Token handed out by the responder.
    pub token: Box<[u8]>,
    /// Optional nodes, encoded on the wire in compact IPv4 form.
    pub nodes: Option<Box<[Node]>>,
}
137
// === Get Peers ===

/// Arguments of a `get_peers` request.
#[derive(Debug, PartialEq, Clone)]
pub struct GetPeersRequestArguments {
    /// Info hash whose peers are requested.
    pub info_hash: Id,
}
144
/// Arguments of a `get_peers` response carrying peer values.
#[derive(Debug, PartialEq, Clone)]
pub struct GetPeersResponseArguments {
    /// Id of the responding node.
    pub responder_id: Id,
    /// Token handed out by the responder.
    pub token: Box<[u8]>,
    /// Peer addresses; each is encoded on the wire as a separate 6-byte compact IPv4 string.
    pub values: Vec<SocketAddrV4>,
    /// Optional nodes, encoded on the wire in compact IPv4 form.
    pub nodes: Option<Box<[Node]>>,
}
152
// === Announce Peer ===

/// Arguments of an `announce_peer` request (the token lives in the enclosing `PutRequest`).
#[derive(Debug, PartialEq, Clone)]
pub struct AnnouncePeerRequestArguments {
    /// Info hash being announced.
    pub info_hash: Id,
    /// Announced port.
    pub port: u16,
    /// Optional `implied_port` flag; on the wire a non-zero integer decodes as `true`.
    pub implied_port: Option<bool>,
}
161
// === Get Immutable ===

/// Arguments of a response carrying an immutable value.
#[derive(Debug, PartialEq, Clone)]
pub struct GetImmutableResponseArguments {
    /// Id of the responding node.
    pub responder_id: Id,
    /// Token handed out by the responder.
    pub token: Box<[u8]>,
    /// Optional nodes, encoded on the wire in compact IPv4 form.
    pub nodes: Option<Box<[Node]>>,
    /// The stored value bytes.
    pub v: Box<[u8]>,
}
171
// === Get Mutable ===

/// Arguments of a response carrying a mutable value.
#[derive(Debug, PartialEq, Clone)]
pub struct GetMutableResponseArguments {
    /// Id of the responding node.
    pub responder_id: Id,
    /// Token handed out by the responder.
    pub token: Box<[u8]>,
    /// Optional nodes, encoded on the wire in compact IPv4 form.
    pub nodes: Option<Box<[Node]>>,
    /// The stored value bytes.
    pub v: Box<[u8]>,
    /// 32-byte key; presumably the signer's public key — TODO confirm.
    pub k: [u8; 32],
    /// Sequence number of the value.
    pub seq: i64,
    /// 64-byte signature over the value.
    pub sig: [u8; 64],
}
184
/// Arguments of a response reporting only the responder's current sequence number
/// (no value body).
#[derive(Debug, PartialEq, Clone)]
pub struct NoMoreRecentValueResponseArguments {
    /// Id of the responding node.
    pub responder_id: Id,
    /// Token handed out by the responder.
    pub token: Box<[u8]>,
    /// Optional nodes, encoded on the wire in compact IPv4 form.
    pub nodes: Option<Box<[Node]>>,
    /// Sequence number the responder holds.
    pub seq: i64,
}
192
// === Put Immutable ===

/// Arguments of an immutable put (the token lives in the enclosing `PutRequest`).
#[derive(Debug, PartialEq, Clone)]
pub struct PutImmutableRequestArguments {
    /// Id the value is stored under.
    pub target: Id,
    /// Value bytes to store.
    pub v: Box<[u8]>,
}
200
// === Put Mutable ===

/// Arguments of a mutable put (the token lives in the enclosing `PutRequest`).
#[derive(Debug, PartialEq, Clone)]
pub struct PutMutableRequestArguments {
    /// Id the value is stored under.
    pub target: Id,
    /// Value bytes to store.
    pub v: Box<[u8]>,
    /// 32-byte key; presumably the signer's public key — TODO confirm.
    pub k: [u8; 32],
    /// Sequence number of the value.
    pub seq: i64,
    /// 64-byte signature over the value.
    pub sig: [u8; 64],
    /// Optional salt, sent on the wire with the put.
    pub salt: Option<Box<[u8]>>,
    /// Optional compare-and-swap sequence number, sent on the wire with the put.
    pub cas: Option<i64>,
}
213
214impl Message {
215    fn into_serde_message(self) -> internal::DHTMessage {
216        internal::DHTMessage {
217            transaction_id: self.transaction_id.to_be_bytes(),
218            version: self.version,
219            ip: self
220                .requester_ip
221                .map(|sockaddr| sockaddr_to_bytes(&sockaddr)),
222            read_only: if self.read_only { Some(1) } else { Some(0) },
223            variant: match self.message_type {
224                MessageType::Request(RequestSpecific {
225                    requester_id,
226                    request_type,
227                }) => internal::DHTMessageVariant::Request(match request_type {
228                    RequestTypeSpecific::Ping => internal::DHTRequestSpecific::Ping {
229                        arguments: internal::DHTPingRequestArguments {
230                            id: requester_id.into(),
231                        },
232                    },
233                    RequestTypeSpecific::FindNode(find_node_args) => {
234                        internal::DHTRequestSpecific::FindNode {
235                            arguments: internal::DHTFindNodeRequestArguments {
236                                id: requester_id.into(),
237                                target: find_node_args.target.into(),
238                            },
239                        }
240                    }
241                    RequestTypeSpecific::GetPeers(get_peers_args) => {
242                        internal::DHTRequestSpecific::GetPeers {
243                            arguments: internal::DHTGetPeersRequestArguments {
244                                id: requester_id.into(),
245                                info_hash: get_peers_args.info_hash.into(),
246                            },
247                        }
248                    }
249                    RequestTypeSpecific::GetValue(get_mutable_args) => {
250                        internal::DHTRequestSpecific::GetValue {
251                            arguments: internal::DHTGetValueRequestArguments {
252                                id: requester_id.into(),
253                                target: get_mutable_args.target.into(),
254                                seq: get_mutable_args.seq,
255                            },
256                        }
257                    }
258                    RequestTypeSpecific::Put(PutRequest {
259                        token,
260                        put_request_type,
261                    }) => match put_request_type {
262                        PutRequestSpecific::AnnouncePeer(announce_peer_args) => {
263                            internal::DHTRequestSpecific::AnnouncePeer {
264                                arguments: internal::DHTAnnouncePeerRequestArguments {
265                                    id: requester_id.into(),
266                                    token,
267
268                                    info_hash: announce_peer_args.info_hash.into(),
269                                    port: announce_peer_args.port,
270                                    implied_port: if announce_peer_args.implied_port.is_some() {
271                                        Some(1)
272                                    } else {
273                                        Some(0)
274                                    },
275                                },
276                            }
277                        }
278                        PutRequestSpecific::PutImmutable(put_immutable_arguments) => {
279                            internal::DHTRequestSpecific::PutValue {
280                                arguments: internal::DHTPutValueRequestArguments {
281                                    id: requester_id.into(),
282                                    token,
283
284                                    target: put_immutable_arguments.target.into(),
285                                    v: put_immutable_arguments.v,
286                                    k: None,
287                                    seq: None,
288                                    sig: None,
289                                    salt: None,
290                                    cas: None,
291                                },
292                            }
293                        }
294                        PutRequestSpecific::PutMutable(put_mutable_arguments) => {
295                            internal::DHTRequestSpecific::PutValue {
296                                arguments: internal::DHTPutValueRequestArguments {
297                                    id: requester_id.into(),
298                                    token,
299
300                                    target: put_mutable_arguments.target.into(),
301                                    v: put_mutable_arguments.v,
302                                    k: Some(put_mutable_arguments.k),
303                                    seq: Some(put_mutable_arguments.seq),
304                                    sig: Some(put_mutable_arguments.sig),
305                                    salt: put_mutable_arguments.salt,
306                                    cas: put_mutable_arguments.cas,
307                                },
308                            }
309                        }
310                    },
311                }),
312
313                MessageType::Response(res) => internal::DHTMessageVariant::Response(match res {
314                    ResponseSpecific::Ping(ping_args) => internal::DHTResponseSpecific::Ping {
315                        arguments: internal::DHTPingResponseArguments {
316                            id: ping_args.responder_id.into(),
317                        },
318                    },
319                    ResponseSpecific::FindNode(find_node_args) => {
320                        internal::DHTResponseSpecific::FindNode {
321                            arguments: internal::DHTFindNodeResponseArguments {
322                                id: find_node_args.responder_id.into(),
323                                nodes: nodes4_to_bytes(&find_node_args.nodes),
324                            },
325                        }
326                    }
327                    ResponseSpecific::GetPeers(get_peers_args) => {
328                        internal::DHTResponseSpecific::GetPeers {
329                            arguments: internal::DHTGetPeersResponseArguments {
330                                id: get_peers_args.responder_id.into(),
331                                token: get_peers_args.token,
332                                nodes: get_peers_args
333                                    .nodes
334                                    .as_ref()
335                                    .map(|nodes| nodes4_to_bytes(nodes)),
336                                values: peers_to_bytes(&get_peers_args.values),
337                            },
338                        }
339                    }
340                    ResponseSpecific::NoValues(no_values_arguments) => {
341                        internal::DHTResponseSpecific::NoValues {
342                            arguments: internal::DHTNoValuesResponseArguments {
343                                id: no_values_arguments.responder_id.into(),
344                                token: no_values_arguments.token,
345                                nodes: no_values_arguments
346                                    .nodes
347                                    .as_ref()
348                                    .map(|nodes| nodes4_to_bytes(nodes)),
349                            },
350                        }
351                    }
352                    ResponseSpecific::GetImmutable(get_immutable_args) => {
353                        internal::DHTResponseSpecific::GetImmutable {
354                            arguments: internal::DHTGetImmutableResponseArguments {
355                                id: get_immutable_args.responder_id.into(),
356                                token: get_immutable_args.token,
357                                nodes: get_immutable_args
358                                    .nodes
359                                    .as_ref()
360                                    .map(|nodes| nodes4_to_bytes(nodes)),
361                                v: get_immutable_args.v,
362                            },
363                        }
364                    }
365                    ResponseSpecific::GetMutable(get_mutable_args) => {
366                        internal::DHTResponseSpecific::GetMutable {
367                            arguments: internal::DHTGetMutableResponseArguments {
368                                id: get_mutable_args.responder_id.into(),
369                                token: get_mutable_args.token,
370                                nodes: get_mutable_args
371                                    .nodes
372                                    .as_ref()
373                                    .map(|nodes| nodes4_to_bytes(nodes)),
374                                v: get_mutable_args.v,
375                                k: get_mutable_args.k,
376                                seq: get_mutable_args.seq,
377                                sig: get_mutable_args.sig,
378                            },
379                        }
380                    }
381                    ResponseSpecific::NoMoreRecentValue(args) => {
382                        internal::DHTResponseSpecific::NoMoreRecentValue {
383                            arguments: internal::DHTNoMoreRecentValueResponseArguments {
384                                id: args.responder_id.into(),
385                                token: args.token,
386                                nodes: args.nodes.as_ref().map(|nodes| nodes4_to_bytes(nodes)),
387                                seq: args.seq,
388                            },
389                        }
390                    }
391                }),
392
393                MessageType::Error(err) => {
394                    internal::DHTMessageVariant::Error(internal::DHTErrorSpecific {
395                        error_info: (err.code, err.description),
396                    })
397                }
398            },
399        }
400    }
401
402    fn from_serde_message(msg: internal::DHTMessage) -> Result<Message, DecodeMessageError> {
403        Ok(Message {
404            transaction_id: u32::from_be_bytes(msg.transaction_id),
405            version: msg.version,
406            requester_ip: match msg.ip {
407                Some(ip) => Some(bytes_to_sockaddr(ip)?),
408                _ => None,
409            },
410            read_only: if let Some(read_only) = msg.read_only {
411                read_only > 0
412            } else {
413                false
414            },
415            message_type: match msg.variant {
416                internal::DHTMessageVariant::Request(req_variant) => {
417                    MessageType::Request(match req_variant {
418                        internal::DHTRequestSpecific::Ping { arguments } => RequestSpecific {
419                            requester_id: Id::from_bytes(arguments.id)?,
420                            request_type: RequestTypeSpecific::Ping,
421                        },
422                        internal::DHTRequestSpecific::FindNode { arguments } => RequestSpecific {
423                            requester_id: Id::from_bytes(arguments.id)?,
424                            request_type: RequestTypeSpecific::FindNode(FindNodeRequestArguments {
425                                target: Id::from_bytes(arguments.target)?,
426                            }),
427                        },
428                        internal::DHTRequestSpecific::GetPeers { arguments } => RequestSpecific {
429                            requester_id: Id::from_bytes(arguments.id)?,
430                            request_type: RequestTypeSpecific::GetPeers(GetPeersRequestArguments {
431                                info_hash: Id::from_bytes(arguments.info_hash)?,
432                            }),
433                        },
434                        internal::DHTRequestSpecific::GetValue { arguments } => RequestSpecific {
435                            requester_id: Id::from_bytes(arguments.id)?,
436
437                            request_type: RequestTypeSpecific::GetValue(GetValueRequestArguments {
438                                target: Id::from_bytes(arguments.target)?,
439                                seq: arguments.seq,
440                                salt: None,
441                            }),
442                        },
443                        internal::DHTRequestSpecific::AnnouncePeer { arguments } => {
444                            RequestSpecific {
445                                requester_id: Id::from_bytes(arguments.id)?,
446                                request_type: RequestTypeSpecific::Put(PutRequest {
447                                    token: arguments.token,
448                                    put_request_type: PutRequestSpecific::AnnouncePeer(
449                                        AnnouncePeerRequestArguments {
450                                            implied_port: arguments
451                                                .implied_port
452                                                .map(|implied_port| implied_port != 0),
453                                            info_hash: arguments.info_hash.into(),
454                                            port: arguments.port,
455                                        },
456                                    ),
457                                }),
458                            }
459                        }
460                        internal::DHTRequestSpecific::PutValue { arguments } => {
461                            if let Some(k) = arguments.k {
462                                RequestSpecific {
463                                    requester_id: Id::from_bytes(arguments.id)?,
464
465                                    request_type: RequestTypeSpecific::Put(PutRequest {
466                                        token: arguments.token,
467                                        put_request_type: PutRequestSpecific::PutMutable(
468                                            PutMutableRequestArguments {
469                                                target: Id::from_bytes(arguments.target)?,
470                                                v: arguments.v,
471                                                k,
472                                                seq: arguments.seq.expect(
473                                                    "Put mutable message to have sequence number",
474                                                ),
475                                                sig: arguments.sig.expect(
476                                                    "Put mutable message to have a signature",
477                                                ),
478                                                salt: arguments.salt,
479                                                cas: arguments.cas,
480                                            },
481                                        ),
482                                    }),
483                                }
484                            } else {
485                                RequestSpecific {
486                                    requester_id: Id::from_bytes(arguments.id)?,
487
488                                    request_type: RequestTypeSpecific::Put(PutRequest {
489                                        token: arguments.token,
490                                        put_request_type: PutRequestSpecific::PutImmutable(
491                                            PutImmutableRequestArguments {
492                                                target: Id::from_bytes(arguments.target)?,
493                                                v: arguments.v,
494                                            },
495                                        ),
496                                    }),
497                                }
498                            }
499                        }
500                    })
501                }
502
503                internal::DHTMessageVariant::Response(res_variant) => {
504                    MessageType::Response(match res_variant {
505                        internal::DHTResponseSpecific::Ping { arguments } => {
506                            ResponseSpecific::Ping(PingResponseArguments {
507                                responder_id: Id::from_bytes(arguments.id)?,
508                            })
509                        }
510                        internal::DHTResponseSpecific::FindNode { arguments } => {
511                            ResponseSpecific::FindNode(FindNodeResponseArguments {
512                                responder_id: Id::from_bytes(arguments.id)?,
513                                nodes: bytes_to_nodes4(&arguments.nodes)?,
514                            })
515                        }
516                        internal::DHTResponseSpecific::GetPeers { arguments } => {
517                            ResponseSpecific::GetPeers(GetPeersResponseArguments {
518                                responder_id: Id::from_bytes(arguments.id)?,
519                                token: arguments.token,
520                                nodes: match arguments.nodes {
521                                    Some(nodes) => Some(bytes_to_nodes4(nodes)?),
522                                    None => None,
523                                },
524                                values: bytes_to_peers(arguments.values)?,
525                            })
526                        }
527                        internal::DHTResponseSpecific::NoValues { arguments } => {
528                            ResponseSpecific::NoValues(NoValuesResponseArguments {
529                                responder_id: Id::from_bytes(arguments.id)?,
530                                token: arguments.token,
531                                nodes: match arguments.nodes {
532                                    Some(nodes) => Some(bytes_to_nodes4(nodes)?),
533                                    None => None,
534                                },
535                            })
536                        }
537                        internal::DHTResponseSpecific::GetImmutable { arguments } => {
538                            ResponseSpecific::GetImmutable(GetImmutableResponseArguments {
539                                responder_id: Id::from_bytes(arguments.id)?,
540                                token: arguments.token,
541                                nodes: match arguments.nodes {
542                                    Some(nodes) => Some(bytes_to_nodes4(nodes)?),
543                                    None => None,
544                                },
545                                v: arguments.v,
546                            })
547                        }
548                        internal::DHTResponseSpecific::GetMutable { arguments } => {
549                            ResponseSpecific::GetMutable(GetMutableResponseArguments {
550                                responder_id: Id::from_bytes(arguments.id)?,
551                                token: arguments.token,
552                                nodes: match arguments.nodes {
553                                    Some(nodes) => Some(bytes_to_nodes4(nodes)?),
554                                    None => None,
555                                },
556                                v: arguments.v,
557                                k: arguments.k,
558                                seq: arguments.seq,
559                                sig: arguments.sig,
560                            })
561                        }
562                        internal::DHTResponseSpecific::NoMoreRecentValue { arguments } => {
563                            ResponseSpecific::NoMoreRecentValue(
564                                NoMoreRecentValueResponseArguments {
565                                    responder_id: Id::from_bytes(arguments.id)?,
566                                    token: arguments.token,
567                                    nodes: match arguments.nodes {
568                                        Some(nodes) => Some(bytes_to_nodes4(nodes)?),
569                                        None => None,
570                                    },
571                                    seq: arguments.seq,
572                                },
573                            )
574                        }
575                    })
576                }
577
578                internal::DHTMessageVariant::Error(err) => MessageType::Error(ErrorSpecific {
579                    code: err.error_info.0,
580                    description: err.error_info.1,
581                }),
582            },
583        })
584    }
585
586    pub fn to_bytes(&self) -> Result<Vec<u8>, serde_bencode::Error> {
587        self.clone().into_serde_message().to_bytes()
588    }
589
590    pub fn from_bytes(bytes: &[u8]) -> Result<Message, DecodeMessageError> {
591        if bytes.len() < 15 {
592            return Err(DecodeMessageError::TooShort);
593        } else if bytes[0] != 100 {
594            return Err(DecodeMessageError::NotBencodeDictionary);
595        }
596
597        Message::from_serde_message(internal::DHTMessage::from_bytes(bytes)?)
598    }
599
600    /// Return the Id of the sender of the Message
601    ///
602    /// This is less straightforward than it seems because not *all* messages are sent
603    /// with an Id (all are except Error messages). This is reflected in the structure
604    /// of DHT Messages, and makes it a bit annoying to learn the sender's Id without
605    /// unraveling the entire message. This method is a convenience method to extract
606    /// the sender (or "author") Id from the guts of any Message.
607    pub fn get_author_id(&self) -> Option<Id> {
608        let id = match &self.message_type {
609            MessageType::Request(arguments) => arguments.requester_id,
610            MessageType::Response(response_variant) => match response_variant {
611                ResponseSpecific::Ping(arguments) => arguments.responder_id,
612                ResponseSpecific::FindNode(arguments) => arguments.responder_id,
613                ResponseSpecific::GetPeers(arguments) => arguments.responder_id,
614                ResponseSpecific::GetImmutable(arguments) => arguments.responder_id,
615                ResponseSpecific::GetMutable(arguments) => arguments.responder_id,
616                ResponseSpecific::NoValues(arguments) => arguments.responder_id,
617                ResponseSpecific::NoMoreRecentValue(arguments) => arguments.responder_id,
618            },
619            MessageType::Error(_) => {
620                return None;
621            }
622        };
623
624        Some(id)
625    }
626
627    /// If the response contains a closer nodes to the target, return that!
628    pub fn get_closer_nodes(&self) -> Option<&[Node]> {
629        match &self.message_type {
630            MessageType::Response(response_variant) => match response_variant {
631                ResponseSpecific::Ping(_) => None,
632                ResponseSpecific::FindNode(arguments) => Some(&arguments.nodes),
633                ResponseSpecific::GetPeers(arguments) => arguments.nodes.as_deref(),
634                ResponseSpecific::GetMutable(arguments) => arguments.nodes.as_deref(),
635                ResponseSpecific::GetImmutable(arguments) => arguments.nodes.as_deref(),
636                ResponseSpecific::NoValues(arguments) => arguments.nodes.as_deref(),
637                ResponseSpecific::NoMoreRecentValue(arguments) => arguments.nodes.as_deref(),
638            },
639            _ => None,
640        }
641    }
642
643    pub fn get_token(&self) -> Option<(Id, &[u8])> {
644        match &self.message_type {
645            MessageType::Response(response_variant) => match response_variant {
646                ResponseSpecific::Ping(_) => None,
647                ResponseSpecific::FindNode(_) => None,
648                ResponseSpecific::GetPeers(arguments) => {
649                    Some((arguments.responder_id, &arguments.token))
650                }
651                ResponseSpecific::GetImmutable(arguments) => {
652                    Some((arguments.responder_id, &arguments.token))
653                }
654                ResponseSpecific::GetMutable(arguments) => {
655                    Some((arguments.responder_id, &arguments.token))
656                }
657                ResponseSpecific::NoValues(arguments) => {
658                    Some((arguments.responder_id, &arguments.token))
659                }
660                ResponseSpecific::NoMoreRecentValue(arguments) => {
661                    Some((arguments.responder_id, &arguments.token))
662                }
663            },
664            _ => None,
665        }
666    }
667}
668
669fn bytes_to_sockaddr<T: AsRef<[u8]>>(bytes: T) -> Result<SocketAddrV4, DecodeMessageError> {
670    let bytes = bytes.as_ref();
671    match bytes.len() {
672        6 => {
673            let ip = Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]);
674
675            let port_bytes_as_array: [u8; 2] = bytes[4..6]
676                .try_into()
677                .map_err(|_| DecodeMessageError::InvalidPortEncoding)?;
678
679            let port: u16 = u16::from_be_bytes(port_bytes_as_array);
680
681            Ok(SocketAddrV4::new(ip, port))
682        }
683        18 => Err(DecodeMessageError::Ipv6Unsupported),
684        _ => Err(DecodeMessageError::InvalidSocketAddrEncodingLength),
685    }
686}
687
/// Encodes an IPv4 socket address in compact form: 4 address octets, then the port in
/// big-endian order.
pub fn sockaddr_to_bytes(sockaddr: &SocketAddrV4) -> [u8; 6] {
    let [a, b, c, d] = sockaddr.ip().octets();
    let [port_hi, port_lo] = sockaddr.port().to_be_bytes();
    [a, b, c, d, port_hi, port_lo]
}
697
698const NODE_BYTE_SIZE: usize = ID_SIZE + 6;
699
700fn nodes4_to_bytes(nodes: &[Node]) -> Box<[u8]> {
701    let mut bytes = Vec::with_capacity(NODE_BYTE_SIZE * nodes.len());
702
703    for node in nodes {
704        bytes.extend_from_slice(node.id().as_bytes());
705        bytes.extend_from_slice(&sockaddr_to_bytes(&node.address()));
706    }
707
708    bytes.into_boxed_slice()
709}
710
711fn bytes_to_nodes4<T: AsRef<[u8]>>(bytes: T) -> Result<Box<[Node]>, DecodeMessageError> {
712    let bytes = bytes.as_ref();
713
714    if bytes.len() % NODE_BYTE_SIZE != 0 {
715        return Err(DecodeMessageError::InvalidNodes4);
716    }
717
718    let expected_num = bytes.len() / NODE_BYTE_SIZE;
719    let mut to_ret = Vec::with_capacity(expected_num);
720    for i in 0..bytes.len() / NODE_BYTE_SIZE {
721        let i = i * NODE_BYTE_SIZE;
722        let id = Id::from_bytes(&bytes[i..i + ID_SIZE])?;
723        let sockaddr = bytes_to_sockaddr(&bytes[i + ID_SIZE..i + NODE_BYTE_SIZE])?;
724        let node = Node::new(id, sockaddr);
725        to_ret.push(node);
726    }
727
728    Ok(to_ret.into_boxed_slice())
729}
730
731fn peers_to_bytes(peers: &[SocketAddrV4]) -> Vec<serde_bytes::ByteBuf> {
732    peers
733        .iter()
734        .map(|p| serde_bytes::ByteBuf::from(sockaddr_to_bytes(p)))
735        .collect()
736}
737
738fn bytes_to_peers<T: AsRef<[serde_bytes::ByteBuf]>>(
739    bytes: T,
740) -> Result<Vec<SocketAddrV4>, DecodeMessageError> {
741    let bytes = bytes.as_ref();
742    bytes.iter().map(bytes_to_sockaddr).collect()
743}
744
#[derive(thiserror::Error, Debug)]
/// Errors that can occur while decoding a KRPC message and its compact node / peer /
/// socket-address encodings.
pub enum DecodeMessageError {
    /// The input was too short to be a message.
    /// NOTE(review): the 15-character threshold appears only in the message text; confirm
    /// against the code that returns this variant.
    #[error("Expected message to be longer than 15 characters")]
    TooShort,

    /// The input did not start with `d`, the bencode dictionary prefix.
    #[error("Expected message to start with 'd'")]
    NotBencodeDictionary,

    /// A compact IPv4 nodes byte string's length is not a multiple of the per-node
    /// entry size.
    #[error("Wrong number of bytes for nodes")]
    InvalidNodes4,

    /// The port bytes of a compact socket address could not be converted to a `u16`.
    #[error("wrong number of bytes for port")]
    InvalidPortEncoding,

    /// A compact socket address had the 18-byte IPv6 length (16-byte address plus port).
    /// Decoding IPv6 addresses is not implemented.
    #[error("IPv6 is not yet implemented")]
    Ipv6Unsupported,

    /// A compact socket address was neither 6 (IPv4) nor 18 (IPv6) bytes long.
    #[error("Wrong number of bytes for sockaddr")]
    InvalidSocketAddrEncodingLength,

    /// The raw bytes could not be (de)serialized by `serde_bencode`.
    #[error("Failed to parse packet bytes: {0}")]
    BencodeError(#[from] serde_bencode::Error),

    /// A node id had the wrong number of bytes. Wraps [`InvalidIdSize`] transparently.
    #[error(transparent)]
    InvalidIdSize(#[from] InvalidIdSize),
}
772
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `msg` via its internal serde representation into bytes, then decodes those
    /// bytes back into a [`Message`]. Panics if any step of the round trip fails.
    fn roundtrip(msg: &Message) -> Message {
        let serde_msg = msg.clone().into_serde_message();
        let bytes = serde_msg.to_bytes().unwrap();
        let parsed_serde_msg = internal::DHTMessage::from_bytes(&bytes).unwrap();
        Message::from_serde_message(parsed_serde_msg).unwrap()
    }

    #[test]
    fn test_ping_request() {
        let original_msg = Message {
            transaction_id: 258,
            version: None,
            requester_ip: None,
            read_only: false,
            message_type: MessageType::Request(RequestSpecific {
                requester_id: Id::random(),
                request_type: RequestTypeSpecific::Ping,
            }),
        };

        assert_eq!(roundtrip(&original_msg), original_msg);
    }

    #[test]
    fn test_ping_response() {
        let original_msg = Message {
            transaction_id: 258,
            version: Some([0xde, 0xad, 0, 1]),
            requester_ip: Some("99.100.101.102:1030".parse().unwrap()),
            read_only: false,
            message_type: MessageType::Response(ResponseSpecific::Ping(PingResponseArguments {
                responder_id: Id::random(),
            })),
        };

        assert_eq!(roundtrip(&original_msg), original_msg);
    }

    #[test]
    fn test_find_node_request() {
        let original_msg = Message {
            transaction_id: 258,
            version: Some([0x62, 0x61, 0x72, 0x66]),
            requester_ip: None,
            read_only: false,
            message_type: MessageType::Request(RequestSpecific {
                requester_id: Id::random(),
                request_type: RequestTypeSpecific::FindNode(FindNodeRequestArguments {
                    target: Id::random(),
                }),
            }),
        };

        assert_eq!(roundtrip(&original_msg), original_msg);
    }

    #[test]
    fn test_find_node_request_read_only() {
        let original_msg = Message {
            transaction_id: 258,
            version: Some([0x62, 0x61, 0x72, 0x66]),
            requester_ip: None,
            read_only: true,
            message_type: MessageType::Request(RequestSpecific {
                requester_id: Id::random(),
                request_type: RequestTypeSpecific::FindNode(FindNodeRequestArguments {
                    target: Id::random(),
                }),
            }),
        };

        assert_eq!(roundtrip(&original_msg), original_msg);
    }

    #[test]
    fn test_find_node_response() {
        let original_msg = Message {
            transaction_id: 258,
            version: Some([1, 2, 3, 4]),
            requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
            read_only: false,
            message_type: MessageType::Response(ResponseSpecific::FindNode(
                FindNodeResponseArguments {
                    responder_id: Id::random(),
                    nodes: [Node::new(Id::random(), "49.50.52.52:5354".parse().unwrap())].into(),
                },
            )),
        };

        let parsed_msg = roundtrip(&original_msg);

        assert_eq!(parsed_msg.transaction_id, original_msg.transaction_id);
        assert_eq!(parsed_msg.version, original_msg.version);
        assert_eq!(parsed_msg.requester_ip, original_msg.requester_ip);
        assert_eq!(parsed_msg.get_author_id(), original_msg.get_author_id());
        // NOTE(review): nodes are compared by (id, address) rather than whole-message
        // equality, presumably because `Node` equality covers more than these fields.
        assert_eq!(
            parsed_msg.get_closer_nodes().map(|nodes| nodes
                .iter()
                .map(|n| (n.id(), n.address()))
                .collect::<Vec<_>>()),
            original_msg.get_closer_nodes().map(|nodes| nodes
                .iter()
                .map(|n| (n.id(), n.address()))
                .collect::<Vec<_>>())
        );
    }

    #[test]
    fn test_get_peers_request() {
        let original_msg = Message {
            transaction_id: 258,
            version: Some([72, 73, 0, 1]),
            requester_ip: None,
            read_only: false,
            message_type: MessageType::Request(RequestSpecific {
                requester_id: Id::random(),
                request_type: RequestTypeSpecific::GetPeers(GetPeersRequestArguments {
                    info_hash: Id::random(),
                }),
            }),
        };

        assert_eq!(roundtrip(&original_msg), original_msg);
    }

    #[test]
    fn test_get_peers_response() {
        let original_msg = Message {
            transaction_id: 3,
            version: Some([1, 2, 3, 4]),
            requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
            // `read_only` is documented as request-only (undefined on responses).
            read_only: false,
            message_type: MessageType::Response(ResponseSpecific::NoValues(
                NoValuesResponseArguments {
                    responder_id: Id::random(),
                    token: [99, 100, 101, 102].into(),
                    nodes: Some(
                        [Node::new(Id::random(), "49.50.52.52:5354".parse().unwrap())].into(),
                    ),
                },
            )),
        };

        let parsed_msg = roundtrip(&original_msg);

        assert_eq!(parsed_msg.transaction_id, original_msg.transaction_id);
        assert_eq!(parsed_msg.version, original_msg.version);
        assert_eq!(parsed_msg.requester_ip, original_msg.requester_ip);
        assert_eq!(parsed_msg.get_author_id(), original_msg.get_author_id());
        assert_eq!(
            parsed_msg.get_closer_nodes().map(|nodes| nodes
                .iter()
                .map(|n| (n.id(), n.address()))
                .collect::<Vec<_>>()),
            original_msg.get_closer_nodes().map(|nodes| nodes
                .iter()
                .map(|n| (n.id(), n.address()))
                .collect::<Vec<_>>())
        );
    }

    #[test]
    fn test_get_peers_response_peers() {
        let original_msg = Message {
            transaction_id: 3,
            version: Some([1, 2, 3, 4]),
            requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
            read_only: false,
            message_type: MessageType::Response(ResponseSpecific::GetPeers(
                GetPeersResponseArguments {
                    responder_id: Id::random(),
                    token: vec![99, 100, 101, 102].into(),
                    nodes: None,
                    values: ["123.123.123.123:123".parse().unwrap()].into(),
                },
            )),
        };

        assert_eq!(roundtrip(&original_msg), original_msg);
    }

    #[test]
    fn test_get_peers_response_neither() {
        let serde_message = internal::DHTMessage {
            ip: None,
            read_only: None,
            transaction_id: [1, 2, 3, 4],
            version: None,
            variant: internal::DHTMessageVariant::Response(
                internal::DHTResponseSpecific::NoValues {
                    arguments: internal::DHTNoValuesResponseArguments {
                        id: Id::random().into(),
                        token: vec![0, 1].into(),
                        nodes: None,
                    },
                },
            ),
        };
        let parsed_msg = Message::from_serde_message(serde_message).unwrap();
        assert!(matches!(
            parsed_msg.message_type,
            MessageType::Response(ResponseSpecific::NoValues(NoValuesResponseArguments { .. }))
        ));
    }

    #[test]
    fn test_get_immutable_request() {
        let original_msg = Message {
            transaction_id: 258,
            version: Some([72, 73, 0, 1]),
            requester_ip: None,
            read_only: false,
            message_type: MessageType::Request(RequestSpecific {
                requester_id: Id::random(),
                request_type: RequestTypeSpecific::GetValue(GetValueRequestArguments {
                    target: Id::random(),
                    seq: Some(1231),
                    salt: None,
                }),
            }),
        };

        assert_eq!(roundtrip(&original_msg), original_msg);
    }

    #[test]
    fn test_get_immutable_response() {
        let original_msg = Message {
            transaction_id: 3,
            version: Some([1, 2, 3, 4]),
            requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
            read_only: false,
            message_type: MessageType::Response(ResponseSpecific::GetImmutable(
                GetImmutableResponseArguments {
                    responder_id: Id::random(),
                    token: [99, 100, 101, 102].into(),
                    nodes: None,
                    v: [99, 100, 101, 102].into(),
                },
            )),
        };

        assert_eq!(roundtrip(&original_msg), original_msg);
    }

    #[test]
    fn test_put_immutable_request() {
        let original_msg = Message {
            transaction_id: 3,
            version: Some([1, 2, 3, 4]),
            requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
            read_only: false,
            message_type: MessageType::Request(RequestSpecific {
                requester_id: Id::random(),
                request_type: RequestTypeSpecific::Put(PutRequest {
                    token: [99, 100, 101, 102].into(),
                    put_request_type: PutRequestSpecific::PutImmutable(
                        PutImmutableRequestArguments {
                            target: Id::random(),
                            v: [99, 100, 101, 102].into(),
                        },
                    ),
                }),
            }),
        };

        assert_eq!(roundtrip(&original_msg), original_msg);
    }

    #[test]
    fn test_put_mutable_request() {
        let original_msg = Message {
            transaction_id: 3,
            version: Some([1, 2, 3, 4]),
            requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
            read_only: false,
            message_type: MessageType::Request(RequestSpecific {
                requester_id: Id::random(),
                request_type: RequestTypeSpecific::Put(PutRequest {
                    token: [99, 100, 101, 102].into(),
                    put_request_type: PutRequestSpecific::PutMutable(PutMutableRequestArguments {
                        target: Id::random(),
                        v: [99, 100, 101, 102].into(),
                        k: [100; 32],
                        seq: 100,
                        sig: [0; 64],
                        salt: Some([0, 2, 4, 8].into()),
                        cas: Some(100),
                    }),
                }),
            }),
        };

        assert_eq!(roundtrip(&original_msg), original_msg);
    }
}