async_foundation/net/
udp_socket.rs

1use crate::common::ready_future::ReadyFuture;
2use crate::common::ready_future_state::ReadyFutureResult;
3use crate::net::event_listener;
4use futures::{AsyncRead, AsyncWrite, FutureExt};
5use mio::Token;
6use mio::net::UdpSocket as MioUdpSocket;
7use std::fmt::{Debug, Error, Formatter};
8use std::io;
9use std::net::UdpSocket as StdUdpSocket;
10use std::net::{SocketAddr, ToSocketAddrs};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::time::{Duration, Instant};
14
15pub struct UdpReadSocket {
16    udp_socket: MioUdpSocket,
17    read_token: Token,
18    read_future: Option<ReadyFuture<()>>,
19    pub read_timeout: Duration,
20}
21
22impl UdpReadSocket {
23    pub fn new(udp_socket: MioUdpSocket) -> Self {
24        UdpReadSocket {
25            udp_socket,
26            read_token: event_listener().next_token(),
27            read_future: None,
28            read_timeout: Duration::from_secs(20),
29        }
30    }
31
32    pub fn set_read_timeout(&mut self, duration: Duration) {
33        self.read_timeout = duration;
34    }
35
36    fn wait_read_data(&mut self) -> io::Result<()> {
37        let future = event_listener().listen_read(
38            &mut self.udp_socket,
39            Instant::now() + self.read_timeout,
40            self.read_token,
41        )?;
42        self.read_future = Some(future);
43        Ok(())
44    }
45
46    fn poll_read_attempt(
47        &mut self,
48        cx: &mut Context<'_>,
49        buf: &mut [u8],
50    ) -> Poll<io::Result<(usize, SocketAddr)>> {
51        let mut future = match self.read_future.take() {
52            None => {
53                match self.udp_socket.recv_from(buf) {
54                    Ok((size, addr)) => return Poll::Ready(Ok((size, addr))),
55                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
56                    Err(err) => return Poll::Ready(Err(err)),
57                }
58                if let Err(err) = self.wait_read_data() {
59                    return Poll::Ready(Err(err));
60                }
61                self.read_future.take().unwrap()
62            }
63            Some(future) => future,
64        };
65        match future.poll_unpin(cx) {
66            Poll::Pending => {
67                self.read_future = Some(future);
68                Poll::Pending
69            }
70            Poll::Ready(ReadyFutureResult::Timeout) => {
71                Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
72            }
73            Poll::Ready(_) => match self.udp_socket.recv_from(buf) {
74                Ok((size, addr)) => Poll::Ready(Ok((size, addr))),
75                Err(err) => Poll::Ready(Err(err)),
76            },
77        }
78    }
79
80    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
81        self.udp_socket.recv_from(buf)
82    }
83
84    pub fn local_addr(&self) -> io::Result<SocketAddr> {
85        self.udp_socket.local_addr()
86    }
87}
88
89impl AsyncRead for UdpReadSocket {
90    fn poll_read(
91        self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93        buf: &mut [u8],
94    ) -> Poll<io::Result<usize>> {
95        let me = self.get_mut();
96        match me.poll_read_attempt(cx, buf) {
97            Poll::Pending => Poll::Pending,
98            Poll::Ready(Ok((size, _))) => Poll::Ready(Ok(size)),
99            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
100        }
101    }
102}
103
104impl Drop for UdpReadSocket {
105    fn drop(&mut self) {
106        event_listener()
107            .stop_listening(&mut self.udp_socket, self.read_token)
108            .ok();
109    }
110}
111
112pub struct UdpWriteSocket {
113    udp_socket: MioUdpSocket,
114    write_token: Token,
115    write_future: Option<ReadyFuture<()>>,
116    pub write_timeout: Duration,
117}
118
119impl UdpWriteSocket {
120    pub fn new(udp_socket: MioUdpSocket) -> Self {
121        UdpWriteSocket {
122            udp_socket,
123            write_token: event_listener().next_token(),
124            write_future: None,
125            write_timeout: Duration::from_secs(2),
126        }
127    }
128
129    pub fn set_write_timeout(&mut self, duration: Duration) {
130        self.write_timeout = duration;
131    }
132
133    fn wait_write_ready(&mut self) -> io::Result<()> {
134        let future = event_listener().listen_write(
135            &mut self.udp_socket,
136            Instant::now() + self.write_timeout,
137            self.write_token,
138        )?;
139        self.write_future = Some(future);
140        Ok(())
141    }
142
143    fn poll_write_attempt(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
144        let mut future = match self.write_future.take() {
145            None => {
146                match self.udp_socket.send(buf) {
147                    Ok(size) => return Poll::Ready(Ok(size)),
148                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
149                    Err(err) => return Poll::Ready(Err(err)),
150                }
151
152                if let Err(err) = self.wait_write_ready() {
153                    return Poll::Ready(Err(err));
154                }
155                self.write_future.take().unwrap()
156            }
157            Some(future) => future,
158        };
159        match future.poll_unpin(cx) {
160            Poll::Pending => {
161                self.write_future = Some(future);
162                Poll::Pending
163            }
164            Poll::Ready(ReadyFutureResult::Timeout) => {
165                Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
166            }
167            Poll::Ready(_) => match self.udp_socket.send(buf) {
168                Ok(size) => Poll::Ready(Ok(size)),
169                Err(err) => Poll::Ready(Err(err)),
170            },
171        }
172    }
173
174    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
175        self.udp_socket.send_to(buf, target)
176    }
177
178    pub fn local_addr(&self) -> io::Result<SocketAddr> {
179        self.udp_socket.local_addr()
180    }
181}
182
183impl AsyncWrite for UdpWriteSocket {
184    fn poll_write(
185        self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187        buf: &[u8],
188    ) -> Poll<io::Result<usize>> {
189        let me = self.get_mut();
190        me.poll_write_attempt(cx, buf)
191    }
192
193    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
194        Poll::Ready(Ok(()))
195    }
196
197    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198        Poll::Ready(Ok(()))
199    }
200}
201
202impl Drop for UdpWriteSocket {
203    fn drop(&mut self) {
204        event_listener()
205            .stop_listening(&mut self.udp_socket, self.write_token)
206            .ok();
207    }
208}
209
210pub struct UdpSocket {
211    read_socket: UdpReadSocket,
212    write_socket: UdpWriteSocket,
213}
214
215impl UdpSocket {
216    pub fn from(udp_socket: StdUdpSocket) -> io::Result<UdpSocket> {
217        udp_socket.set_nonblocking(true)?;
218        Ok(UdpSocket {
219            read_socket: UdpReadSocket::new(MioUdpSocket::from_std(udp_socket.try_clone()?)),
220            write_socket: UdpWriteSocket::new(MioUdpSocket::from_std(udp_socket)),
221        })
222    }
223
224    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
225        Self::from(StdUdpSocket::bind(addr)?)
226    }
227
228    pub fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> {
229        for addr in addr.to_socket_addrs()? {
230            self.read_socket.udp_socket.connect(addr)?;
231            self.write_socket.udp_socket.connect(addr)?;
232            break;
233        }
234        Ok(())
235    }
236
237    pub fn bind_and_connect<A: ToSocketAddrs, B: ToSocketAddrs>(
238        addr: A,
239        to_addr: B,
240    ) -> io::Result<UdpSocket> {
241        let result = Self::bind(addr)?;
242        result.connect(to_addr)?;
243        Ok(result)
244    }
245
246    pub fn read_socket(&self) -> &UdpReadSocket {
247        &self.read_socket
248    }
249
250    pub fn read_socket_mut(&mut self) -> &mut UdpReadSocket {
251        &mut self.read_socket
252    }
253
254    pub fn write_socket(&self) -> &UdpWriteSocket {
255        &self.write_socket
256    }
257
258    pub fn write_socket_mut(&mut self) -> &mut UdpWriteSocket {
259        &mut self.write_socket
260    }
261
262    pub fn set_read_timeout(&mut self, duration: Duration) {
263        self.read_socket.set_read_timeout(duration);
264    }
265
266    pub fn set_write_timeout(&mut self, duration: Duration) {
267        self.write_socket.set_write_timeout(duration);
268    }
269
270    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
271        self.write_socket.send_to(buf, target)
272    }
273
274    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
275        self.read_socket.recv_from(buf)
276    }
277
278    pub fn local_addr(&self) -> io::Result<SocketAddr> {
279        self.read_socket.local_addr()
280    }
281
282    pub fn split(self) -> (UdpReadSocket, UdpWriteSocket) {
283        (self.read_socket, self.write_socket)
284    }
285}
286
287impl AsyncRead for UdpSocket {
288    fn poll_read(
289        self: Pin<&mut Self>,
290        cx: &mut Context<'_>,
291        buf: &mut [u8],
292    ) -> Poll<io::Result<usize>> {
293        let me = self.get_mut();
294        Pin::new(&mut me.read_socket).poll_read(cx, buf)
295    }
296}
297
298impl AsyncWrite for UdpSocket {
299    fn poll_write(
300        self: Pin<&mut Self>,
301        cx: &mut Context<'_>,
302        buf: &[u8],
303    ) -> Poll<io::Result<usize>> {
304        let me = self.get_mut();
305        Pin::new(&mut me.write_socket).poll_write(cx, buf)
306    }
307
308    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
309        let me = self.get_mut();
310        Pin::new(&mut me.write_socket).poll_flush(cx)
311    }
312
313    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
314        let me = self.get_mut();
315        Pin::new(&mut me.write_socket).poll_close(cx)
316    }
317}
318
319impl Debug for UdpSocket {
320    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
321        write!(f, "{:?}", self.read_socket.udp_socket)
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::timer::Timer;
    use futures::executor::block_on;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    /// Binds two independent std UDP sockets on loopback with OS-chosen ports.
    fn setup_test_sockets() -> (StdUdpSocket, StdUdpSocket) {
        let bind_loopback = || StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server = bind_loopback();
        let client = bind_loopback();
        (server, client)
    }

    /// Wrapping a std socket succeeds and keeps its local address.
    #[test]
    fn test_udp_wrapper_creation() {
        let std_socket = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let expected_addr = std_socket.local_addr().unwrap();

        let wrapped = UdpSocket::from(std_socket);
        assert!(wrapped.is_ok());

        let wrapped = wrapped.unwrap();
        assert_eq!(wrapped.local_addr().unwrap(), expected_addr);
    }

    /// `bind` yields a socket on loopback with a non-zero port.
    #[test]
    fn test_udp_wrapper_bind() {
        let bound = UdpSocket::bind("127.0.0.1:0");
        assert!(bound.is_ok());

        let bound = bound.unwrap();
        let local = bound.local_addr().unwrap();
        assert!(local.port() > 0);
        assert_eq!(local.ip().to_string(), "127.0.0.1");
    }

    /// `bind_and_connect` succeeds against a live loopback peer.
    #[test]
    fn test_udp_wrapper_bind_and_connect() {
        let (server, _) = setup_test_sockets();
        let peer = server.local_addr().unwrap();

        let connected = UdpSocket::bind_and_connect("127.0.0.1:0", peer);
        assert!(connected.is_ok());
    }

    /// The timeout setters on the wrapper reach the matching half.
    #[test]
    fn test_timeout_setters() {
        let mut wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();

        wrapper.set_read_timeout(Duration::from_secs(30));
        wrapper.set_write_timeout(Duration::from_secs(5));

        assert_eq!(wrapper.read_socket().read_timeout, Duration::from_secs(30));
        assert_eq!(wrapper.write_socket().write_timeout, Duration::from_secs(5));
    }

    /// Accessors expose the default timeouts and allow changing them.
    #[test]
    fn test_socket_accessors() {
        let mut wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();

        assert_eq!(wrapper.read_socket().read_timeout, Duration::from_secs(20));
        {
            let reader = wrapper.read_socket_mut();
            reader.set_read_timeout(Duration::from_secs(15));
            assert_eq!(reader.read_timeout, Duration::from_secs(15));
        }

        assert_eq!(wrapper.write_socket().write_timeout, Duration::from_secs(2));
        {
            let writer = wrapper.write_socket_mut();
            writer.set_write_timeout(Duration::from_secs(10));
            assert_eq!(writer.write_timeout, Duration::from_secs(10));
        }
    }

    /// A datagram sent with `send_to` is readable with `recv_from` shortly after.
    #[test]
    fn test_sync_send_recv() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        let client_addr = client.local_addr().unwrap();

        let server_wrapper = UdpSocket::from(server).unwrap();
        let client_wrapper = UdpSocket::from(client).unwrap();

        let payload = b"Hello UDP!";
        let sent = client_wrapper.send_to(payload, server_addr);
        assert!(sent.is_ok());
        assert_eq!(sent.unwrap(), payload.len());

        // Give the loopback datagram time to arrive before the direct read.
        thread::sleep(Duration::from_millis(10));

        let mut buf = [0u8; 1024];
        let received = server_wrapper.recv_from(&mut buf);
        assert!(received.is_ok());
        let (size, sender) = received.unwrap();
        assert_eq!(size, payload.len());
        assert_eq!(&buf[..size], payload);
        assert_eq!(sender, client_addr);
    }

    /// Sends to an echo thread from inside an async block; the echoed reply is
    /// only checked when it has already arrived.
    #[test]
    fn test_async_read_write() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        // Echo server: returns the first datagram to its sender.
        thread::spawn(move || {
            let mut echo_buf = [0u8; 1024];
            if let Ok((len, from)) = server.recv_from(&mut echo_buf) {
                let _ = server.send_to(&echo_buf[..len], from);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let scenario = async {
            let wrapper = UdpSocket::from(client).unwrap();

            let payload = b"Async UDP test!";
            let sent = wrapper.send_to(payload, server_addr);
            assert!(sent.is_ok());

            let mut buf = [0u8; 1024];
            if let Ok((size, sender)) = wrapper.recv_from(&mut buf) {
                assert_eq!(size, payload.len());
                assert_eq!(&buf[..size], payload);
                assert_eq!(sender, server_addr);
            }
        };

        block_on(scenario);
    }

    /// Awaits the project timer before sending, then checks the send succeeds.
    #[test]
    fn test_async_with_timer() {
        let mut timer = Timer::new();
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        // Echo server that answers after a short delay.
        thread::spawn(move || {
            let mut echo_buf = [0u8; 1024];
            if let Ok((len, from)) = server.recv_from(&mut echo_buf) {
                thread::sleep(Duration::from_millis(50));
                let _ = server.send_to(&echo_buf[..len], from);
            }
        });

        let scenario = async {
            let wrapper = UdpSocket::from(client).unwrap();
            timer.wait(Duration::from_millis(20)).await;
            let payload = b"Delayed UDP!";
            let sent = wrapper.send_to(payload, server_addr);
            assert!(sent.is_ok());
        };

        block_on(scenario);
    }

    /// Three clients each send one datagram; the echo thread must see all three.
    #[test]
    fn test_concurrent_operations() {
        let server = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server_addr = server.local_addr().unwrap();
        let response_count = Arc::new(Mutex::new(0));
        let counter_for_server = response_count.clone();

        thread::spawn(move || {
            let mut echo_buf = [0u8; 1024];
            for _ in 0..3 {
                if let Ok((len, from)) = server.recv_from(&mut echo_buf) {
                    let _ = server.send_to(&echo_buf[..len], from);
                    let mut handled = counter_for_server.lock().unwrap();
                    *handled += 1;
                }
            }
        });

        thread::sleep(Duration::from_millis(10));

        let scenario = async {
            let senders: Vec<_> = (0..3)
                .map(|i| {
                    let payload = format!("Message {}", i);
                    async move {
                        let client = StdUdpSocket::bind("127.0.0.1:0").unwrap();
                        let wrapper = UdpSocket::from(client).unwrap();

                        let sent = wrapper.send_to(payload.as_bytes(), server_addr);
                        assert!(sent.is_ok());
                    }
                })
                .collect();

            futures::future::join_all(senders).await;
        };

        block_on(scenario);
        // Let the echo thread finish processing before reading the counter.
        thread::sleep(Duration::from_millis(100));
        let handled = response_count.lock().unwrap();
        assert_eq!(*handled, 3);
    }

    /// A direct `recv_from` on an idle socket reports `WouldBlock`.
    #[test]
    fn test_timeout_behavior() {
        let mut wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();

        wrapper.set_read_timeout(Duration::from_millis(50));

        let scenario = async {
            let mut buf = [0u8; 1024];

            match wrapper.recv_from(&mut buf) {
                Ok(_) => {
                    panic!("Unexpected data received");
                }
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    // Expected: non-blocking UDP socket with no data queued.
                }
                Err(e) => {
                    panic!("Unexpected error: {:?}", e);
                }
            }
        };

        block_on(scenario);
    }

    /// One wrapper can address two different servers with `send_to`.
    #[test]
    fn test_multiple_sends_to_different_addresses() {
        let server1 = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server2 = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server1_addr = server1.local_addr().unwrap();
        let server2_addr = server2.local_addr().unwrap();

        let (_, client) = setup_test_sockets();
        let wrapper = UdpSocket::from(client).unwrap();

        let first_payload = b"Hello Server 1";
        let first_sent = wrapper.send_to(first_payload, server1_addr);
        assert!(first_sent.is_ok());
        assert_eq!(first_sent.unwrap(), first_payload.len());

        let second_payload = b"Hello Server 2";
        let second_sent = wrapper.send_to(second_payload, server2_addr);
        assert!(second_sent.is_ok());
        assert_eq!(second_sent.unwrap(), second_payload.len());

        thread::sleep(Duration::from_millis(10));

        let mut first_buf = [0u8; 1024];
        let first_received = server1.recv_from(&mut first_buf);
        assert!(first_received.is_ok());
        let (first_len, _) = first_received.unwrap();
        assert_eq!(&first_buf[..first_len], first_payload);

        let mut second_buf = [0u8; 1024];
        let second_received = server2.recv_from(&mut second_buf);
        assert!(second_received.is_ok());
        let (second_len, _) = second_received.unwrap();
        assert_eq!(&second_buf[..second_len], second_payload);
    }

    /// A 1400-byte datagram survives a round trip through an echo thread.
    #[test]
    fn test_large_data_transmission() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut echo_buf = [0u8; 2048];
            if let Ok((len, from)) = server.recv_from(&mut echo_buf) {
                let _ = server.send_to(&echo_buf[..len], from);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let wrapper = UdpSocket::from(client).unwrap();

        let large_data = vec![0xAB; 1400];
        let sent = wrapper.send_to(&large_data, server_addr);
        assert!(sent.is_ok());
        assert_eq!(sent.unwrap(), large_data.len());

        // Wait for the echo before the direct, non-waiting read.
        thread::sleep(Duration::from_millis(20));

        let mut buf = [0u8; 2048];
        let received = wrapper.recv_from(&mut buf);
        assert!(received.is_ok());
        let (size, sender) = received.unwrap();
        assert_eq!(size, large_data.len());
        assert_eq!(&buf[..size], &large_data[..]);
        assert_eq!(sender, server_addr);
    }

    /// After a wrapper is dropped, its address can be bound again.
    #[test]
    fn test_drop_behavior() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let released_addr = wrapper.local_addr().unwrap();
        drop(wrapper);

        thread::sleep(Duration::from_millis(10));
        let rebound = UdpSocket::bind(released_addr);
        assert!(rebound.is_ok());
    }

    /// The halves returned by `split` can send and receive on their own.
    #[test]
    fn test_split_sockets_independently() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut echo_buf = [0u8; 1024];
            if let Ok((len, from)) = server.recv_from(&mut echo_buf) {
                let _ = server.send_to(&echo_buf[..len], from);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let scenario = async {
            let wrapper = UdpSocket::from(client).unwrap();
            let (read_half, write_half) = wrapper.split();

            // Send through the write half, then read through the read half.
            let payload = b"Split socket test";
            write_half.send_to(payload, server_addr).unwrap();

            let mut buf = [0u8; 1024];
            match read_half.recv_from(&mut buf) {
                Ok((size, sender)) => {
                    assert_eq!(&buf[..size], payload);
                    assert_eq!(sender, server_addr);
                }
                // The echo may not have arrived yet; that is tolerated here.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                Err(e) => panic!("Unexpected error: {:?}", e),
            }
        };

        block_on(scenario);
    }

    /// `send_to` towards a bound peer succeeds (no `connect` call is made here).
    #[test]
    fn test_connected_socket_operations() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        let wrapper = UdpSocket::from(client).unwrap();
        let payload = b"Connected test";
        let outcome = wrapper.send_to(payload, server_addr);
        assert!(outcome.is_ok());
    }
}