Skip to main content

mtorrent_dht/
udp.rs

1use super::error::Error as DhtError;
2use super::msgs::Message;
3use mtorrent_utils::{benc, debug_stopwatch};
4use std::future::Future;
5use std::mem::MaybeUninit;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::task::{Context, Poll, ready};
9use tokio::io::ReadBuf;
10use tokio::net::UdpSocket;
11use tokio::select;
12use tokio::sync::mpsc;
13use tokio::sync::mpsc::error::TrySendError;
14
15/// Actor that reads and writes UDP packets, and encodes/decodes DHT messages.
16pub struct IoDriver {
17    socket: UdpSocket,
18    ingress_sender: mpsc::Sender<(Message, SocketAddr)>,
19    egress_receiver: mpsc::Receiver<(Message, SocketAddr)>,
20}
21
22pub struct MessageChannelSender(pub(crate) mpsc::Sender<(Message, SocketAddr)>);
23pub struct MessageChannelReceiver(pub(crate) mpsc::Receiver<(Message, SocketAddr)>);
24
25pub(super) const MSG_QUEUE_LEN: usize = 512;
26
27/// Create the networking layer that handles low-level I/O.
28pub fn setup_udp(socket: UdpSocket) -> (MessageChannelSender, MessageChannelReceiver, IoDriver) {
29    let (ingress_sender, ingress_receiver) = mpsc::channel(MSG_QUEUE_LEN);
30    let (egress_sender, egress_receiver) = mpsc::channel(MSG_QUEUE_LEN);
31    let actor = IoDriver {
32        socket,
33        ingress_sender,
34        egress_receiver,
35    };
36    (MessageChannelSender(egress_sender), MessageChannelReceiver(ingress_receiver), actor)
37}
38
39impl IoDriver {
40    pub async fn run(self) {
41        let _sw = debug_stopwatch!("UDP runner");
42        let ingress = Ingress {
43            socket: &self.socket,
44            buffer: Box::new_uninit_slice(RX_BUFFER_SIZE),
45            sink: self.ingress_sender,
46        };
47        let egress = Egress {
48            socket: &self.socket,
49            pending: None,
50            source: self.egress_receiver,
51        };
52        select! {
53            biased;
54            _ = egress => (),
55            _ = ingress => (),
56        }
57    }
58}
59
/// Size in bytes of the buffer that incoming datagrams are received into.
///
/// Incoming DHT messages have been observed at around 32 KiB. If a datagram does not fit, the
/// excess bytes are discarded by the receive call and the truncated payload then fails to
/// decode. The buffer is therefore sized to hold the largest possible UDP payload:
/// 65,507 bytes over IPv4, or 65,527 over IPv6 without jumbograms.
const RX_BUFFER_SIZE: usize = 64 * 1024;
62
63struct Ingress<'s> {
64    socket: &'s UdpSocket,
65    buffer: Box<[MaybeUninit<u8>]>,
66    sink: mpsc::Sender<(Message, SocketAddr)>,
67}
68
69struct Egress<'s> {
70    socket: &'s UdpSocket,
71    pending: Option<(Vec<u8>, SocketAddr)>,
72    source: mpsc::Receiver<(Message, SocketAddr)>,
73}
74
75impl<'s> Future for Ingress<'s> {
76    type Output = ();
77
78    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
79        fn parse_msg(buffer: &[u8]) -> Result<Message, DhtError> {
80            let (bencode, len) = benc::Element::from_bytes_with_len(buffer)?;
81            if len < buffer.len() {
82                log::debug!("Ignored trailing bytes in an incoming UDP packet");
83            }
84            let message = Message::try_from(bencode)?;
85            Ok(message)
86        }
87
88        let Ingress {
89            socket,
90            buffer,
91            sink,
92        } = self.get_mut();
93        let mut buffer = ReadBuf::uninit(buffer);
94        loop {
95            buffer.clear();
96            let src_addr = match ready!(socket.poll_recv_from(cx, &mut buffer)) {
97                Err(e) => {
98                    log::error!("Failed to receive UDP packet: {e}");
99                    continue;
100                }
101                Ok(addr) => addr,
102            };
103            let message = match parse_msg(buffer.filled()) {
104                Err(e) => {
105                    log::debug!("Failed to parse message from {src_addr}: {e}");
106                    continue;
107                }
108                Ok(msg) => msg,
109            };
110            match sink.try_send((message, src_addr)) {
111                Err(TrySendError::Closed(_)) => {
112                    return Poll::Ready(());
113                }
114                Err(TrySendError::Full(_)) => {
115                    log::warn!("Dropping message from {src_addr}: channel is full");
116                    continue;
117                }
118                Ok(()) => continue,
119            }
120        }
121    }
122}
123
124impl<'s> Future for Egress<'s> {
125    type Output = ();
126
127    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
128        let Egress {
129            socket,
130            pending,
131            source,
132        } = self.get_mut();
133
134        loop {
135            match pending {
136                None => match ready!(source.poll_recv(cx)) {
137                    None => {
138                        return Poll::Ready(());
139                    }
140                    Some((message, dst_addr)) => {
141                        let bencode = benc::Element::from(message);
142                        *pending = Some((bencode.encode(), dst_addr));
143                    }
144                },
145                Some((data, dest_addr)) => {
146                    let data_len = data.len();
147                    match ready!(socket.poll_send_to(cx, data, *dest_addr)) {
148                        Err(e) => {
149                            log::warn!("Failed to send UDP packet to {dest_addr}: {e}");
150                        }
151                        Ok(bytes_sent) if bytes_sent != data_len => {
152                            log::error!(
153                                "Could only send {bytes_sent}/{data_len} bytes to {dest_addr}"
154                            );
155                        }
156                        Ok(_) => (),
157                    }
158                    *pending = None;
159                }
160            }
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msgs::*;
    use crate::u160::U160;
    use local_async_utils::prelude::*;
    use std::net::{Ipv4Addr, SocketAddrV4};
    use std::{io, iter};
    use tokio::task;
    use tokio::time::timeout;

    /// Expected bencoding of `generic_error(b"aa")`.
    const GENERIC_ERROR_AA: &[u8] = b"d1:eli201e24:A Generic Error Occurrede1:t2:aa1:y1:ee";
    /// Expected bencoding of `generic_error(b"bb")`.
    const GENERIC_ERROR_BB: &[u8] = b"d1:eli201e24:A Generic Error Occurrede1:t2:bb1:y1:ee";

    /// Bind a UDP socket on the IPv4 loopback interface; port `0` requests an OS-assigned port.
    async fn create_ipv4_socket(port: u16) -> io::Result<UdpSocket> {
        UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await
    }

    /// Bind a loopback socket on an OS-assigned port and return it with its bound address.
    ///
    /// Ephemeral ports keep tests independent of one another and of any other process on the
    /// machine, unlike fixed port numbers.
    async fn bind_ephemeral() -> (UdpSocket, SocketAddr) {
        let socket = create_ipv4_socket(0).await.unwrap();
        let addr = socket.local_addr().unwrap();
        (socket, addr)
    }

    /// Build a ping query with the given transaction id and a node id made of `id_byte`.
    fn ping_query(transaction_id: Vec<u8>, id_byte: u8) -> Message {
        Message {
            transaction_id,
            version: None,
            data: MessageData::Query(
                PingArgs {
                    id: [id_byte; 20].into(),
                }
                .into(),
            ),
        }
    }

    /// Build the generic error message used by the send tests.
    fn generic_error(transaction_id: &[u8]) -> Message {
        Message {
            transaction_id: transaction_id.to_vec(),
            version: None,
            data: MessageData::Error(ErrorMsg {
                error_code: ErrorCode::Generic,
                error_msg: "A Generic Error Occurred".to_owned(),
            }),
        }
    }

    /// Assert that `msg` is a ping query with the given transaction id and node id byte.
    fn assert_ping(msg: Message, transaction_id: &[u8], id_byte: u8) {
        assert_eq!(msg.transaction_id, transaction_id);
        assert_eq!(msg.version, None);
        let ping = match msg.data {
            MessageData::Query(QueryMsg::Ping(ping)) => ping,
            _ => panic!("Expected a ping query"),
        };
        assert_eq!(ping.id, U160::from([id_byte; 20]));
    }

    /// Bencode `msg` and send it from `sock` to `dest` as a raw datagram.
    async fn send_encoded(sock: &UdpSocket, msg: Message, dest: SocketAddr) {
        sock.send_to(&benc::Element::from(msg).encode(), dest).await.unwrap();
    }

    /// Wait (up to 5 s) for the next decoded message delivered by the driver.
    async fn next_message(
        rx: &mut mpsc::Receiver<(Message, SocketAddr)>,
    ) -> (Message, SocketAddr) {
        timeout(sec!(5), rx.recv()).await.unwrap().unwrap()
    }

    /// Wait (up to 5 s) for the next raw datagram arriving at `sock`.
    async fn next_datagram(sock: &UdpSocket) -> (Vec<u8>, SocketAddr) {
        let mut buf = [0u8; 1500];
        let (len, src_addr) = timeout(sec!(5), sock.recv_from(&mut buf)).await.unwrap().unwrap();
        (buf[..len].to_vec(), src_addr)
    }

    #[tokio::test]
    async fn receive_single_message() {
        let (sender_sock, sender_addr) = bind_ephemeral().await;
        let (socket, receiver_addr) = bind_ephemeral().await;
        let (_tx_channel, MessageChannelReceiver(mut rx_channel), runner) = setup_udp(socket);
        task::spawn(runner.run());

        send_encoded(&sender_sock, ping_query(vec![1, 2, 3, 4], 12), receiver_addr).await;

        let (received_msg, src_addr) = next_message(&mut rx_channel).await;
        assert_eq!(src_addr, sender_addr);
        assert_ping(received_msg, &[1, 2, 3, 4], 12);
    }

    #[tokio::test]
    async fn receive_valid_message_after_malformed() {
        let _ = simple_logger::SimpleLogger::new().with_level(log::LevelFilter::Info).init();

        let (bad_sender, _) = bind_ephemeral().await;
        let (good_sender, good_sender_addr) = bind_ephemeral().await;
        let (socket, receiver_addr) = bind_ephemeral().await;
        let (_tx_channel, MessageChannelReceiver(mut rx_channel), runner) = setup_udp(socket);
        task::spawn(runner.run());

        bad_sender.send_to(b"malformed", receiver_addr).await.unwrap();
        send_encoded(&good_sender, ping_query(vec![1, 2, 3, 4], 12), receiver_addr).await;

        let (received_msg, src_addr) = next_message(&mut rx_channel).await;
        assert_eq!(src_addr, good_sender_addr);
        assert_ping(received_msg, &[1, 2, 3, 4], 12);
    }

    #[tokio::test]
    async fn send_single_message() {
        let (receiver_sock, receiver_addr) = bind_ephemeral().await;
        let (socket, sender_addr) = bind_ephemeral().await;
        let (MessageChannelSender(tx_channel), _rx_channel, runner) = setup_udp(socket);
        task::spawn(runner.run());

        tx_channel.send((generic_error(b"aa"), receiver_addr)).await.unwrap();

        let (data, src_addr) = next_datagram(&receiver_sock).await;
        assert_eq!(src_addr, sender_addr);
        assert_eq!(data, GENERIC_ERROR_AA);
    }

    #[tokio::test]
    async fn successful_send_after_failed_send() {
        let non_local_addr: SocketAddr = "212.129.33.59:6881".parse().unwrap();
        let (receiver_sock, receiver_addr) = bind_ephemeral().await;
        let (socket, sender_addr) = bind_ephemeral().await;
        let (MessageChannelSender(tx_channel), _rx_channel, runner) = setup_udp(socket);
        task::spawn(runner.run());

        // send a message to nowhere; the driver is expected to log the failure and carry on
        tx_channel.send((generic_error(b"aa"), non_local_addr)).await.unwrap();

        // send a message to somewhere
        tx_channel.send((generic_error(b"aa"), receiver_addr)).await.unwrap();

        let (data, src_addr) = next_datagram(&receiver_sock).await;
        assert_eq!(src_addr, sender_addr);
        assert_eq!(data, GENERIC_ERROR_AA);
    }

    #[tokio::test]
    async fn receive_multiple_messages() {
        let (sender_sock, sender_addr) = bind_ephemeral().await;
        let (socket, receiver_addr) = bind_ephemeral().await;
        let (_tx_channel, MessageChannelReceiver(mut rx_channel), runner) = setup_udp(socket);
        task::spawn(runner.run());

        send_encoded(&sender_sock, ping_query(vec![1, 2, 3, 4], 12), receiver_addr).await;
        let (received_msg, src_addr) = next_message(&mut rx_channel).await;
        assert_eq!(src_addr, sender_addr);
        assert_ping(received_msg, &[1, 2, 3, 4], 12);

        send_encoded(&sender_sock, ping_query(vec![5, 6, 7, 8], 13), receiver_addr).await;
        let (received_msg, src_addr) = next_message(&mut rx_channel).await;
        assert_eq!(src_addr, sender_addr);
        assert_ping(received_msg, &[5, 6, 7, 8], 13);
    }

    #[tokio::test]
    async fn send_multiple_messages() {
        let (receiver_sock, receiver_addr) = bind_ephemeral().await;
        let (socket, sender_addr) = bind_ephemeral().await;
        let (MessageChannelSender(tx_channel), _rx_channel, runner) = setup_udp(socket);
        task::spawn(runner.run());

        tx_channel.send((generic_error(b"aa"), receiver_addr)).await.unwrap();
        let (data, src_addr) = next_datagram(&receiver_sock).await;
        assert_eq!(src_addr, sender_addr);
        assert_eq!(data, GENERIC_ERROR_AA);

        tx_channel.send((generic_error(b"bb"), receiver_addr)).await.unwrap();
        let (data, src_addr) = next_datagram(&receiver_sock).await;
        assert_eq!(src_addr, sender_addr);
        assert_eq!(data, GENERIC_ERROR_BB);
    }

    #[tokio::test]
    async fn drop_incoming_message_when_channel_is_full() {
        let _ = simple_logger::SimpleLogger::new().with_level(log::LevelFilter::Info).init();

        let (first_sender, first_sender_addr) = bind_ephemeral().await;
        let (second_sender, second_sender_addr) = bind_ephemeral().await;
        let (socket, receiver_addr) = bind_ephemeral().await;
        let (_tx_channel, MessageChannelReceiver(mut rx_channel), runner) = setup_udp(socket);

        // Occupy every slot of the ingress channel but one, so exactly one message fits.
        // The permits are held until the end of the test.
        let reserved: Vec<_> =
            iter::repeat_with(|| runner.ingress_sender.clone().try_reserve_owned())
                .take(MSG_QUEUE_LEN - 1)
                .collect();
        assert!(reserved.iter().all(Result::is_ok));
        task::spawn(runner.run());

        send_encoded(&first_sender, ping_query(vec![1, 2, 3, 4], 12), receiver_addr).await;
        // This one finds the channel full and is dropped.
        send_encoded(&second_sender, ping_query(vec![5, 6, 7, 8], 13), receiver_addr).await;

        let (received_msg, src_addr) = next_message(&mut rx_channel).await;
        assert_eq!(src_addr, first_sender_addr);
        assert_ping(received_msg, &[1, 2, 3, 4], 12);

        // A slot is free again, so the next message gets through.
        send_encoded(&second_sender, ping_query(vec![9, 0], 14), receiver_addr).await;

        let (received_msg, src_addr) = next_message(&mut rx_channel).await;
        assert_eq!(src_addr, second_sender_addr);
        assert_ping(received_msg, &[9, 0], 14);
    }
}