1use super::error::Error as DhtError;
2use super::msgs::Message;
3use mtorrent_utils::{benc, debug_stopwatch};
4use std::future::Future;
5use std::mem::MaybeUninit;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::task::{Context, Poll, ready};
9use tokio::io::ReadBuf;
10use tokio::net::UdpSocket;
11use tokio::select;
12use tokio::sync::mpsc;
13use tokio::sync::mpsc::error::TrySendError;
14
/// Owns the UDP socket and the queue endpoints that connect it to the rest of the DHT.
///
/// Created by [`setup_udp`]. The socket is only read from or written to inside
/// [`IoDriver::run`], so no traffic flows until that future is polled.
pub struct IoDriver {
    /// Socket shared by reference between the incoming and outgoing halves in `run`.
    socket: UdpSocket,
    /// Producer side of the incoming queue: each parsed datagram is pushed here together
    /// with the address it came from.
    ingress_sender: mpsc::Sender<(Message, SocketAddr)>,
    /// Consumer side of the outgoing queue: each item is a message and its destination.
    egress_receiver: mpsc::Receiver<(Message, SocketAddr)>,
}
21
/// Handle for queuing outgoing messages; each item pairs a message with the address to send
/// it to.
pub struct MessageChannelSender(pub(crate) mpsc::Sender<(Message, SocketAddr)>);
/// Handle for receiving incoming messages; each item pairs a parsed message with the address
/// of the peer that sent it.
pub struct MessageChannelReceiver(pub(crate) mpsc::Receiver<(Message, SocketAddr)>);
24
/// Capacity (in messages) of each of the two queues created by [`setup_udp`]: incoming and
/// outgoing.
pub(super) const MSG_QUEUE_LEN: usize = 512;
26
27pub fn setup_udp(socket: UdpSocket) -> (MessageChannelSender, MessageChannelReceiver, IoDriver) {
29 let (ingress_sender, ingress_receiver) = mpsc::channel(MSG_QUEUE_LEN);
30 let (egress_sender, egress_receiver) = mpsc::channel(MSG_QUEUE_LEN);
31 let actor = IoDriver {
32 socket,
33 ingress_sender,
34 egress_receiver,
35 };
36 (MessageChannelSender(egress_sender), MessageChannelReceiver(ingress_receiver), actor)
37}
38
39impl IoDriver {
40 pub async fn run(self) {
41 let _sw = debug_stopwatch!("UDP runner");
42 let ingress = Ingress {
43 socket: &self.socket,
44 buffer: Box::new_uninit_slice(RX_BUFFER_SIZE),
45 sink: self.ingress_sender,
46 };
47 let egress = Egress {
48 socket: &self.socket,
49 pending: None,
50 source: self.egress_receiver,
51 };
52 select! {
53 biased;
54 _ = egress => (),
55 _ = ingress => (),
56 }
57 }
58}
59
/// Size in bytes of the buffer each incoming datagram is read into (allocated once per
/// driver run and reused for every datagram).
// NOTE(review): datagrams longer than this do not fit; what the OS does with the excess
// (truncate vs. error) is platform-dependent — confirm if large packets ever matter.
const RX_BUFFER_SIZE: usize = 32 * 1024;
62
/// Receiving half of the driver: a future that reads datagrams from `socket`, parses them as
/// DHT messages and forwards them to `sink`.
struct Ingress<'s> {
    /// Socket borrowed from the owning `IoDriver`.
    socket: &'s UdpSocket,
    /// Receive buffer reused for every datagram; its contents are only read after the
    /// socket has filled them.
    buffer: Box<[MaybeUninit<u8>]>,
    /// Destination for parsed messages, each paired with the sender's address.
    sink: mpsc::Sender<(Message, SocketAddr)>,
}
68
/// Sending half of the driver: a future that takes messages from `source`, bencodes them and
/// sends each one as a datagram on `socket`.
struct Egress<'s> {
    /// Socket borrowed from the owning `IoDriver`.
    socket: &'s UdpSocket,
    /// Encoded datagram (and its destination) already taken from `source` but not yet
    /// accepted by the socket; kept across polls while the socket is not ready.
    pending: Option<(Vec<u8>, SocketAddr)>,
    /// Queue of outgoing messages, each paired with its destination address.
    source: mpsc::Receiver<(Message, SocketAddr)>,
}
74
75impl<'s> Future for Ingress<'s> {
76 type Output = ();
77
78 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
79 fn parse_msg(buffer: &[u8]) -> Result<Message, DhtError> {
80 let (bencode, len) = benc::Element::from_bytes_with_len(buffer)?;
81 if len < buffer.len() {
82 log::debug!("Ignored trailing bytes in an incoming UDP packet");
83 }
84 let message = Message::try_from(bencode)?;
85 Ok(message)
86 }
87
88 let Ingress {
89 socket,
90 buffer,
91 sink,
92 } = self.get_mut();
93 let mut buffer = ReadBuf::uninit(buffer);
94 loop {
95 buffer.clear();
96 let src_addr = match ready!(socket.poll_recv_from(cx, &mut buffer)) {
97 Err(e) => {
98 log::error!("Failed to receive UDP packet: {e}");
99 continue;
100 }
101 Ok(addr) => addr,
102 };
103 let message = match parse_msg(buffer.filled()) {
104 Err(e) => {
105 log::debug!("Failed to parse message from {src_addr}: {e}");
106 continue;
107 }
108 Ok(msg) => msg,
109 };
110 match sink.try_send((message, src_addr)) {
111 Err(TrySendError::Closed(_)) => {
112 return Poll::Ready(());
113 }
114 Err(TrySendError::Full(_)) => {
115 log::warn!("Dropping message from {src_addr}: channel is full");
116 continue;
117 }
118 Ok(()) => continue,
119 }
120 }
121 }
122}
123
124impl<'s> Future for Egress<'s> {
125 type Output = ();
126
127 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
128 let Egress {
129 socket,
130 pending,
131 source,
132 } = self.get_mut();
133
134 loop {
135 match pending {
136 None => match ready!(source.poll_recv(cx)) {
137 None => {
138 return Poll::Ready(());
139 }
140 Some((message, dst_addr)) => {
141 let bencode = benc::Element::from(message);
142 *pending = Some((bencode.encode(), dst_addr));
143 }
144 },
145 Some((data, dest_addr)) => {
146 let data_len = data.len();
147 match ready!(socket.poll_send_to(cx, data, *dest_addr)) {
148 Err(e) => {
149 log::warn!("Failed to send UDP packet to {dest_addr}: {e}");
150 }
151 Ok(bytes_sent) if bytes_sent != data_len => {
152 log::error!(
153 "Could only send {bytes_sent}/{data_len} bytes to {dest_addr}"
154 );
155 }
156 Ok(_) => (),
157 }
158 *pending = None;
159 }
160 }
161 }
162 }
163}
164
/// Loopback tests for the UDP driver.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msgs::*;
    use crate::u160::U160;
    use local_async_utils::prelude::*;
    use std::net::{Ipv4Addr, SocketAddrV4};
    use std::{io, iter};
    use tokio::task;
    use tokio::time::timeout;

    /// Binds a UDP socket to an OS-assigned port on the loopback interface.
    ///
    /// Port 0 is used instead of hard-coded port numbers so that tests (which run in
    /// parallel) never collide with each other or with unrelated processes.
    async fn bind_loopback() -> io::Result<UdpSocket> {
        UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await
    }

    /// Sets up an [`IoDriver`] on a fresh loopback socket, spawns it, and returns both channel
    /// handles together with the driver's socket address.
    ///
    /// Callers should keep the returned handles alive (bind them to `_name`, not `_`):
    /// dropping the sender ends the outgoing half, which stops the whole driver.
    async fn spawn_driver() -> (MessageChannelSender, MessageChannelReceiver, SocketAddr) {
        let socket = bind_loopback().await.unwrap();
        let local_addr = socket.local_addr().unwrap();
        let (sender, receiver, driver) = setup_udp(socket);
        task::spawn(driver.run());
        (sender, receiver, local_addr)
    }

    /// Builds a ping query with the given transaction id and querying node id.
    fn ping_query(transaction_id: Vec<u8>, node_id: [u8; 20]) -> Message {
        Message {
            transaction_id,
            version: None,
            data: MessageData::Query(PingArgs { id: node_id.into() }.into()),
        }
    }

    /// Builds the generic error message used by the sending tests.
    fn generic_error(transaction_id: &[u8]) -> Message {
        Message {
            transaction_id: transaction_id.to_vec(),
            version: None,
            data: MessageData::Error(ErrorMsg {
                error_code: ErrorCode::Generic,
                error_msg: "A Generic Error Ocurred".to_owned(),
            }),
        }
    }

    /// Bencodes `msg` and sends it as a single datagram from `socket` to `dest`.
    async fn send_encoded(socket: &UdpSocket, msg: Message, dest: SocketAddr) {
        socket.send_to(&benc::Element::from(msg).encode(), dest).await.unwrap();
    }

    /// Waits (up to 5 s) for the next message the driver delivers on `rx`.
    async fn recv_message(
        rx: &mut mpsc::Receiver<(Message, SocketAddr)>,
    ) -> (Message, SocketAddr) {
        timeout(sec!(5), rx.recv()).await.unwrap().unwrap()
    }

    /// Waits (up to 5 s) for one raw datagram on `socket`.
    async fn recv_datagram(socket: &UdpSocket) -> (Vec<u8>, SocketAddr) {
        let mut buf = [0u8; 1500];
        let (len, src_addr) =
            timeout(sec!(5), socket.recv_from(&mut buf)).await.unwrap().unwrap();
        (buf[..len].to_vec(), src_addr)
    }

    /// Asserts that `msg` is a version-less ping query with the given ids.
    fn assert_ping(msg: Message, transaction_id: &[u8], node_id: [u8; 20]) {
        assert_eq!(msg.transaction_id, transaction_id);
        assert_eq!(msg.version, None);
        let MessageData::Query(QueryMsg::Ping(ping)) = msg.data else {
            panic!("Expected a ping query");
        };
        assert_eq!(ping.id, U160::from(node_id));
    }

    #[tokio::test]
    async fn receive_single_message() {
        let sender = bind_loopback().await.unwrap();
        let sender_addr = sender.local_addr().unwrap();
        let (_tx, MessageChannelReceiver(mut rx), driver_addr) = spawn_driver().await;

        send_encoded(&sender, ping_query(vec![1, 2, 3, 4], [12u8; 20]), driver_addr).await;

        let (received_msg, src_addr) = recv_message(&mut rx).await;
        assert_eq!(src_addr, sender_addr);
        assert_ping(received_msg, &[1, 2, 3, 4], [12u8; 20]);
    }

    #[tokio::test]
    async fn receive_valid_message_after_malformed() {
        let _ = simple_logger::SimpleLogger::new().with_level(log::LevelFilter::Info).init();

        let bad_sender = bind_loopback().await.unwrap();
        let good_sender = bind_loopback().await.unwrap();
        let good_sender_addr = good_sender.local_addr().unwrap();
        let (_tx, MessageChannelReceiver(mut rx), driver_addr) = spawn_driver().await;

        bad_sender.send_to(b"malformed", driver_addr).await.unwrap();
        send_encoded(&good_sender, ping_query(vec![1, 2, 3, 4], [12u8; 20]), driver_addr).await;

        let (received_msg, src_addr) = recv_message(&mut rx).await;
        assert_eq!(src_addr, good_sender_addr);
        assert_ping(received_msg, &[1, 2, 3, 4], [12u8; 20]);
    }

    #[tokio::test]
    async fn send_single_message() {
        let receiver = bind_loopback().await.unwrap();
        let receiver_addr = receiver.local_addr().unwrap();
        let (MessageChannelSender(tx), _rx, driver_addr) = spawn_driver().await;

        tx.send((generic_error(b"aa"), receiver_addr)).await.unwrap();

        let (datagram, src_addr) = recv_datagram(&receiver).await;
        assert_eq!(src_addr, driver_addr);
        assert_eq!(datagram, b"d1:eli201e23:A Generic Error Ocurrede1:t2:aa1:y1:ee");
    }

    #[tokio::test]
    async fn successful_send_after_failed_send() {
        // A non-loopback destination, which a loopback-bound socket is expected to fail
        // to send to.
        let non_local_addr: SocketAddr = "212.129.33.59:6881".parse().unwrap();
        let receiver = bind_loopback().await.unwrap();
        let receiver_addr = receiver.local_addr().unwrap();
        let (MessageChannelSender(tx), _rx, driver_addr) = spawn_driver().await;

        tx.send((generic_error(b"aa"), non_local_addr)).await.unwrap();
        tx.send((generic_error(b"aa"), receiver_addr)).await.unwrap();

        let (datagram, src_addr) = recv_datagram(&receiver).await;
        assert_eq!(src_addr, driver_addr);
        assert_eq!(datagram, b"d1:eli201e23:A Generic Error Ocurrede1:t2:aa1:y1:ee");
    }

    #[tokio::test]
    async fn receive_multiple_messages() {
        let sender = bind_loopback().await.unwrap();
        let sender_addr = sender.local_addr().unwrap();
        let (_tx, MessageChannelReceiver(mut rx), driver_addr) = spawn_driver().await;

        send_encoded(&sender, ping_query(vec![1, 2, 3, 4], [12u8; 20]), driver_addr).await;
        let (received_msg, src_addr) = recv_message(&mut rx).await;
        assert_eq!(src_addr, sender_addr);
        assert_ping(received_msg, &[1, 2, 3, 4], [12u8; 20]);

        send_encoded(&sender, ping_query(vec![5, 6, 7, 8], [13u8; 20]), driver_addr).await;
        let (received_msg, src_addr) = recv_message(&mut rx).await;
        assert_eq!(src_addr, sender_addr);
        assert_ping(received_msg, &[5, 6, 7, 8], [13u8; 20]);
    }

    #[tokio::test]
    async fn send_multiple_messages() {
        let receiver = bind_loopback().await.unwrap();
        let receiver_addr = receiver.local_addr().unwrap();
        let (MessageChannelSender(tx), _rx, driver_addr) = spawn_driver().await;

        tx.send((generic_error(b"aa"), receiver_addr)).await.unwrap();
        let (datagram, src_addr) = recv_datagram(&receiver).await;
        assert_eq!(src_addr, driver_addr);
        assert_eq!(datagram, b"d1:eli201e23:A Generic Error Ocurrede1:t2:aa1:y1:ee");

        tx.send((generic_error(b"bb"), receiver_addr)).await.unwrap();
        let (datagram, src_addr) = recv_datagram(&receiver).await;
        assert_eq!(src_addr, driver_addr);
        assert_eq!(datagram, b"d1:eli201e23:A Generic Error Ocurrede1:t2:bb1:y1:ee");
    }

    #[tokio::test]
    async fn drop_incoming_message_when_channel_is_full() {
        let _ = simple_logger::SimpleLogger::new().with_level(log::LevelFilter::Info).init();

        let first_sender = bind_loopback().await.unwrap();
        let first_sender_addr = first_sender.local_addr().unwrap();
        let second_sender = bind_loopback().await.unwrap();
        let second_sender_addr = second_sender.local_addr().unwrap();

        let socket = bind_loopback().await.unwrap();
        let driver_addr = socket.local_addr().unwrap();
        let (_tx, MessageChannelReceiver(mut rx), driver) = setup_udp(socket);

        // Occupy all but one slot of the incoming queue so that exactly one message fits.
        let reserved_permits: Vec<_> =
            iter::repeat_with(|| driver.ingress_sender.clone().try_reserve_owned())
                .take(MSG_QUEUE_LEN - 1)
                .collect();
        assert!(reserved_permits.iter().all(Result::is_ok));

        // Queue both datagrams in the socket before the driver starts, so the driver sees
        // them back-to-back and the second one finds the queue full instead of racing with
        // this task draining the first one.
        let first_msg = ping_query(vec![1, 2, 3, 4], [12u8; 20]);
        let dropped_msg = ping_query(vec![5, 6, 7, 8], [13u8; 20]);
        send_encoded(&first_sender, first_msg, driver_addr).await;
        send_encoded(&second_sender, dropped_msg, driver_addr).await;
        task::spawn(driver.run());

        let (received_msg, src_addr) = recv_message(&mut rx).await;
        assert_eq!(src_addr, first_sender_addr);
        assert_ping(received_msg, &[1, 2, 3, 4], [12u8; 20]);

        // Receiving the first message freed its slot, so the next one gets through.
        let last_msg = ping_query(vec![9, 0], [14u8; 20]);
        send_encoded(&second_sender, last_msg, driver_addr).await;

        let (received_msg, src_addr) = recv_message(&mut rx).await;
        assert_eq!(src_addr, second_sender_addr);
        assert_ping(received_msg, &[9, 0], [14u8; 20]);
    }
}