async_foundation/net/
udp_socket.rs

1use crate::common::ready_future::ReadyFuture;
2use crate::common::ready_future_state::ReadyFutureResult;
3use crate::net::event_listener;
4use futures::{AsyncRead, AsyncWrite, FutureExt};
5use mio::Token;
6use mio::net::UdpSocket as MioUdpSocket;
7use std::fmt::{Debug, Error, Formatter};
8use std::io;
9use std::net::UdpSocket as StdUdpSocket;
10use std::net::{SocketAddr, ToSocketAddrs};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::time::{Duration, Instant};
14
/// Read half of the async UDP wrapper.
///
/// Receives datagrams from a mio UDP socket and, when no datagram is queued,
/// waits for read readiness through the shared event listener.
pub struct UdpReadSocket {
    // The mio socket datagrams are received from.
    udp_socket: MioUdpSocket,
    // Token identifying this half's read registration with the event listener.
    read_token: Token,
    // In-flight readiness wait, kept across polls while a read is pending.
    read_future: Option<ReadyFuture<()>>,
    /// Used to build the deadline (`Instant::now() + read_timeout`) handed to the
    /// event listener when waiting for read readiness.
    pub read_timeout: Duration,
}
21
22impl UdpReadSocket {
23    pub fn new(udp_socket: MioUdpSocket) -> Self {
24        UdpReadSocket {
25            udp_socket,
26            read_token: event_listener().next_token(),
27            read_future: None,
28            read_timeout: Duration::from_secs(20),
29        }
30    }
31
32    pub fn set_read_timeout(&mut self, duration: Duration) {
33        self.read_timeout = duration;
34    }
35
36    fn wait_read_data(&mut self) -> io::Result<()> {
37        let future = event_listener().listen_read(
38            &mut self.udp_socket,
39            Instant::now() + self.read_timeout,
40            self.read_token,
41        )?;
42        self.read_future = Some(future);
43        Ok(())
44    }
45
46    fn poll_read_attempt(
47        &mut self,
48        cx: &mut Context<'_>,
49        buf: &mut [u8],
50    ) -> Poll<io::Result<(usize, SocketAddr)>> {
51        let mut future = match self.read_future.take() {
52            None => {
53                match self.udp_socket.recv_from(buf) {
54                    Ok((size, addr)) => return Poll::Ready(Ok((size, addr))),
55                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
56                    Err(err) => return Poll::Ready(Err(err)),
57                }
58                if let Err(err) = self.wait_read_data() {
59                    return Poll::Ready(Err(err));
60                }
61                self.read_future.take().unwrap()
62            }
63            Some(future) => future,
64        };
65        match future.poll_unpin(cx) {
66            Poll::Pending => {
67                self.read_future = Some(future);
68                Poll::Pending
69            }
70            Poll::Ready(ReadyFutureResult::Timeout) => {
71                Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
72            }
73            Poll::Ready(_) => match self.udp_socket.recv_from(buf) {
74                Ok((size, addr)) => Poll::Ready(Ok((size, addr))),
75                Err(err) => Poll::Ready(Err(err)),
76            },
77        }
78    }
79
80    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
81        self.udp_socket.recv_from(buf)
82    }
83
84    pub fn local_addr(&self) -> io::Result<SocketAddr> {
85        self.udp_socket.local_addr()
86    }
87}
88
89impl AsyncRead for UdpReadSocket {
90    fn poll_read(
91        self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93        buf: &mut [u8],
94    ) -> Poll<io::Result<usize>> {
95        let me = self.get_mut();
96        match me.poll_read_attempt(cx, buf) {
97            Poll::Pending => Poll::Pending,
98            Poll::Ready(Ok((size, _))) => Poll::Ready(Ok(size)),
99            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
100        }
101    }
102}
103
104impl Drop for UdpReadSocket {
105    fn drop(&mut self) {
106        event_listener()
107            .stop_listening(&mut self.udp_socket, self.read_token)
108            .ok();
109    }
110}
111
/// Write half of the async UDP wrapper.
///
/// Sends datagrams on a mio UDP socket and, when the send would block, waits
/// for write readiness through the shared event listener.
pub struct UdpWriteSocket {
    // The mio socket datagrams are sent on.
    udp_socket: MioUdpSocket,
    // Token identifying this half's write registration with the event listener.
    write_token: Token,
    // In-flight readiness wait, kept across polls while a write is pending.
    write_future: Option<ReadyFuture<()>>,
    /// Used to build the deadline (`Instant::now() + write_timeout`) handed to the
    /// event listener when waiting for write readiness.
    pub write_timeout: Duration,
}
118
119impl UdpWriteSocket {
120    pub fn new(udp_socket: MioUdpSocket) -> Self {
121        UdpWriteSocket {
122            udp_socket,
123            write_token: event_listener().next_token(),
124            write_future: None,
125            write_timeout: Duration::from_secs(2),
126        }
127    }
128
129    pub fn set_write_timeout(&mut self, duration: Duration) {
130        self.write_timeout = duration;
131    }
132
133    fn wait_write_ready(&mut self) -> io::Result<()> {
134        let future = event_listener().listen_write(
135            &mut self.udp_socket,
136            Instant::now() + self.write_timeout,
137            self.write_token,
138        )?;
139        self.write_future = Some(future);
140        Ok(())
141    }
142
143    fn poll_write_attempt(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
144        let mut future = match self.write_future.take() {
145            None => {
146                match self.udp_socket.send(buf) {
147                    Ok(size) => return Poll::Ready(Ok(size)),
148                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
149                    Err(err) => return Poll::Ready(Err(err)),
150                }
151
152                if let Err(err) = self.wait_write_ready() {
153                    return Poll::Ready(Err(err));
154                }
155                self.write_future.take().unwrap()
156            }
157            Some(future) => future,
158        };
159        match future.poll_unpin(cx) {
160            Poll::Pending => {
161                self.write_future = Some(future);
162                Poll::Pending
163            }
164            Poll::Ready(ReadyFutureResult::Timeout) => {
165                Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
166            }
167            Poll::Ready(_) => match self.udp_socket.send(buf) {
168                Ok(size) => Poll::Ready(Ok(size)),
169                Err(err) => Poll::Ready(Err(err)),
170            },
171        }
172    }
173
174    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
175        self.udp_socket.send_to(buf, target)
176    }
177
178    pub fn local_addr(&self) -> io::Result<SocketAddr> {
179        self.udp_socket.local_addr()
180    }
181}
182
183impl AsyncWrite for UdpWriteSocket {
184    fn poll_write(
185        self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187        buf: &[u8],
188    ) -> Poll<io::Result<usize>> {
189        let me = self.get_mut();
190        me.poll_write_attempt(cx, buf)
191    }
192
193    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
194        Poll::Ready(Ok(()))
195    }
196
197    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198        Poll::Ready(Ok(()))
199    }
200}
201
202impl Drop for UdpWriteSocket {
203    fn drop(&mut self) {
204        event_listener()
205            .stop_listening(&mut self.udp_socket, self.write_token)
206            .ok();
207    }
208}
209
/// Async-friendly UDP socket wrapper.
///
/// `UdpSocket` owns separate read/write halves that integrate with the shared
/// event listener for readiness notifications. It implements both [`AsyncRead`]
/// and [`AsyncWrite`] for stream-like usage when needed.
///
/// Both halves wrap handles to the same OS socket: `UdpSocket::from`
/// duplicates it with `try_clone`. Use `split` to take the halves apart.
///
/// `AsyncRead` discards the sender's address. `AsyncWrite` uses `send`, so it
/// expects the socket to have been connected first (see `connect`).
pub struct UdpSocket {
    // Receiving half; also used for `local_addr` and `Debug`.
    read_socket: UdpReadSocket,
    // Sending half.
    write_socket: UdpWriteSocket,
}
219
220impl UdpSocket {
221    pub fn from(udp_socket: StdUdpSocket) -> io::Result<UdpSocket> {
222        udp_socket.set_nonblocking(true)?;
223        Ok(UdpSocket {
224            read_socket: UdpReadSocket::new(MioUdpSocket::from_std(udp_socket.try_clone()?)),
225            write_socket: UdpWriteSocket::new(MioUdpSocket::from_std(udp_socket)),
226        })
227    }
228
229    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
230        Self::from(StdUdpSocket::bind(addr)?)
231    }
232
233    pub fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> {
234        for addr in addr.to_socket_addrs()? {
235            self.read_socket.udp_socket.connect(addr)?;
236            self.write_socket.udp_socket.connect(addr)?;
237            break;
238        }
239        Ok(())
240    }
241
242    pub fn bind_and_connect<A: ToSocketAddrs, B: ToSocketAddrs>(
243        addr: A,
244        to_addr: B,
245    ) -> io::Result<UdpSocket> {
246        let result = Self::bind(addr)?;
247        result.connect(to_addr)?;
248        Ok(result)
249    }
250
251    pub fn read_socket(&self) -> &UdpReadSocket {
252        &self.read_socket
253    }
254
255    pub fn read_socket_mut(&mut self) -> &mut UdpReadSocket {
256        &mut self.read_socket
257    }
258
259    pub fn write_socket(&self) -> &UdpWriteSocket {
260        &self.write_socket
261    }
262
263    pub fn write_socket_mut(&mut self) -> &mut UdpWriteSocket {
264        &mut self.write_socket
265    }
266
267    pub fn set_read_timeout(&mut self, duration: Duration) {
268        self.read_socket.set_read_timeout(duration);
269    }
270
271    pub fn set_write_timeout(&mut self, duration: Duration) {
272        self.write_socket.set_write_timeout(duration);
273    }
274
275    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
276        self.write_socket.send_to(buf, target)
277    }
278
279    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
280        self.read_socket.recv_from(buf)
281    }
282
283    pub fn local_addr(&self) -> io::Result<SocketAddr> {
284        self.read_socket.local_addr()
285    }
286
287    pub fn split(self) -> (UdpReadSocket, UdpWriteSocket) {
288        (self.read_socket, self.write_socket)
289    }
290}
291
292impl AsyncRead for UdpSocket {
293    fn poll_read(
294        self: Pin<&mut Self>,
295        cx: &mut Context<'_>,
296        buf: &mut [u8],
297    ) -> Poll<io::Result<usize>> {
298        let me = self.get_mut();
299        Pin::new(&mut me.read_socket).poll_read(cx, buf)
300    }
301}
302
303impl AsyncWrite for UdpSocket {
304    fn poll_write(
305        self: Pin<&mut Self>,
306        cx: &mut Context<'_>,
307        buf: &[u8],
308    ) -> Poll<io::Result<usize>> {
309        let me = self.get_mut();
310        Pin::new(&mut me.write_socket).poll_write(cx, buf)
311    }
312
313    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
314        let me = self.get_mut();
315        Pin::new(&mut me.write_socket).poll_flush(cx)
316    }
317
318    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
319        let me = self.get_mut();
320        Pin::new(&mut me.write_socket).poll_close(cx)
321    }
322}
323
324impl Debug for UdpSocket {
325    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
326        write!(f, "{:?}", self.read_socket.udp_socket)
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::timer::Timer;
    use futures::executor::block_on;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    /// Binds two plain (blocking) std UDP sockets on ephemeral loopback ports,
    /// returned as `(server, client)`.
    fn setup_test_sockets() -> (StdUdpSocket, StdUdpSocket) {
        let server = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let client = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        (server, client)
    }

    // Wrapping a bound std socket succeeds and keeps its local address.
    #[test]
    fn test_udp_wrapper_creation() {
        let socket = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let addr = socket.local_addr().unwrap();

        let wrapper = UdpSocket::from(socket);
        assert!(wrapper.is_ok());

        let wrapper = wrapper.unwrap();
        assert_eq!(wrapper.local_addr().unwrap(), addr);
    }

    // Binding to port 0 yields a concrete ephemeral port on the requested IP.
    #[test]
    fn test_udp_wrapper_bind() {
        let wrapper = UdpSocket::bind("127.0.0.1:0");
        assert!(wrapper.is_ok());

        let wrapper = wrapper.unwrap();
        let addr = wrapper.local_addr().unwrap();
        assert!(addr.port() > 0);
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
    }

    // bind_and_connect to a live loopback peer succeeds; only the Result is checked.
    #[test]
    fn test_udp_wrapper_bind_and_connect() {
        let (server, _) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        let wrapper = UdpSocket::bind_and_connect("127.0.0.1:0", server_addr);
        assert!(wrapper.is_ok());
    }

    // The wrapper's timeout setters are forwarded to the read/write halves.
    #[test]
    fn test_timeout_setters() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let mut wrapper = wrapper;

        wrapper.set_read_timeout(Duration::from_secs(30));
        wrapper.set_write_timeout(Duration::from_secs(5));

        assert_eq!(wrapper.read_socket().read_timeout, Duration::from_secs(30));
        assert_eq!(wrapper.write_socket().write_timeout, Duration::from_secs(5));
    }

    // Accessors expose the halves; the defaults are 20 s (read) and 2 s (write).
    #[test]
    fn test_socket_accessors() {
        let mut wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();

        let read_socket = wrapper.read_socket();
        assert_eq!(read_socket.read_timeout, Duration::from_secs(20));

        let read_socket_mut = wrapper.read_socket_mut();
        read_socket_mut.set_read_timeout(Duration::from_secs(15));
        assert_eq!(read_socket_mut.read_timeout, Duration::from_secs(15));

        let write_socket = wrapper.write_socket();
        assert_eq!(write_socket.write_timeout, Duration::from_secs(2));

        let write_socket_mut = wrapper.write_socket_mut();
        write_socket_mut.set_write_timeout(Duration::from_secs(10));
        assert_eq!(write_socket_mut.write_timeout, Duration::from_secs(10));
    }

    // Round trip through the synchronous send_to/recv_from wrappers.
    // NOTE(review): the wrapped sockets are non-blocking, so this relies on the
    // 10 ms sleep for the datagram to arrive before recv_from is called.
    #[test]
    fn test_sync_send_recv() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        let client_addr = client.local_addr().unwrap();

        let server_wrapper = UdpSocket::from(server).unwrap();
        let client_wrapper = UdpSocket::from(client).unwrap();

        let test_data = b"Hello UDP!";
        let sent = client_wrapper.send_to(test_data, server_addr);
        assert!(sent.is_ok());
        assert_eq!(sent.unwrap(), test_data.len());

        thread::sleep(Duration::from_millis(10));

        let mut buf = [0u8; 1024];
        let received = server_wrapper.recv_from(&mut buf);
        assert!(received.is_ok());
        let (size, addr) = received.unwrap();
        assert_eq!(size, test_data.len());
        assert_eq!(&buf[..size], test_data);
        assert_eq!(addr, client_addr);
    }

    // Echo round trip driven from inside block_on.
    // NOTE(review): despite the name, only the synchronous send_to/recv_from are
    // used; AsyncRead/AsyncWrite are not exercised. The assertions sit inside
    // `if let Ok`, so they are skipped if the echo has not arrived yet.
    #[test]
    fn test_async_read_write() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();

            let test_data = b"Async UDP test!";
            let sent = wrapper.send_to(test_data, server_addr);
            assert!(sent.is_ok());

            let mut buf = [0u8; 1024];
            let read_result = wrapper.recv_from(&mut buf);
            if let Ok((size, addr)) = read_result {
                assert_eq!(size, test_data.len());
                assert_eq!(&buf[..size], test_data);
                assert_eq!(addr, server_addr);
            }
        };

        block_on(test_future);
    }

    // Awaits the Timer for 20 ms, then sends. The server's delayed echo is
    // never read, so only the send result is checked.
    #[test]
    fn test_async_with_timer() {
        let mut timer = Timer::new();
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                thread::sleep(Duration::from_millis(50));
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();
            timer.wait(Duration::from_millis(20)).await;
            let test_data = b"Delayed UDP!";
            let sent = wrapper.send_to(test_data, server_addr);
            assert!(sent.is_ok());
        };

        block_on(test_future);
    }

    // Three separate client sockets each send one message; the server thread
    // echoes and counts them.
    // NOTE(review): the inner async blocks contain no `.await`, so each
    // completes on its first poll and join_all runs them one after another. The
    // final count relies on the 100 ms sleep for the server thread to catch up.
    #[test]
    fn test_concurrent_operations() {
        let server = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server_addr = server.local_addr().unwrap();
        let response_count = Arc::new(Mutex::new(0));
        let response_count_clone = response_count.clone();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            for _ in 0..3 {
                if let Ok((size, addr)) = server.recv_from(&mut buf) {
                    let _ = server.send_to(&buf[..size], addr);
                    let mut count = response_count_clone.lock().unwrap();
                    *count += 1;
                }
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let mut futures = Vec::new();

            for i in 0..3 {
                let test_data = format!("Message {}", i);
                let future = async move {
                    let client = StdUdpSocket::bind("127.0.0.1:0").unwrap();
                    let wrapper = UdpSocket::from(client).unwrap();

                    let sent = wrapper.send_to(test_data.as_bytes(), server_addr);
                    assert!(sent.is_ok());
                };
                futures.push(future);
            }

            futures::future::join_all(futures).await;
        };

        block_on(test_future);
        thread::sleep(Duration::from_millis(100));
        let count = response_count.lock().unwrap();
        assert_eq!(*count, 3);
    }

    // NOTE(review): despite the name, the 50 ms read timeout is never observed.
    // `recv_from` is the synchronous path, which bypasses the event listener,
    // so the expected outcome with no data queued is an immediate WouldBlock.
    // Exercising the timeout would require polling through AsyncRead.
    #[test]
    fn test_timeout_behavior() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let mut wrapper = wrapper;

        wrapper.set_read_timeout(Duration::from_millis(50));

        let test_future = async {
            let mut buf = [0u8; 1024];

            let result = wrapper.recv_from(&mut buf);
            match result {
                Ok(_) => {
                    panic!("Unexpected data received");
                }
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    // This is expected for non-blocking UDP with no data
                }
                Err(e) => {
                    panic!("Unexpected error: {:?}", e);
                }
            }
        };

        block_on(test_future);
    }

    // One wrapper sends to two different peers; each peer (plain blocking std
    // socket) receives exactly its own payload.
    #[test]
    fn test_multiple_sends_to_different_addresses() {
        let server1 = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server2 = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server1_addr = server1.local_addr().unwrap();
        let server2_addr = server2.local_addr().unwrap();

        let (_, client) = setup_test_sockets();
        let wrapper = UdpSocket::from(client).unwrap();

        let data1 = b"Hello Server 1";
        let sent1 = wrapper.send_to(data1, server1_addr);
        assert!(sent1.is_ok());
        assert_eq!(sent1.unwrap(), data1.len());

        let data2 = b"Hello Server 2";
        let sent2 = wrapper.send_to(data2, server2_addr);
        assert!(sent2.is_ok());
        assert_eq!(sent2.unwrap(), data2.len());

        thread::sleep(Duration::from_millis(10));

        let mut buf1 = [0u8; 1024];
        let received1 = server1.recv_from(&mut buf1);
        assert!(received1.is_ok());
        let (size1, _) = received1.unwrap();
        assert_eq!(&buf1[..size1], data1);

        let mut buf2 = [0u8; 1024];
        let received2 = server2.recv_from(&mut buf2);
        assert!(received2.is_ok());
        let (size2, _) = received2.unwrap();
        assert_eq!(&buf2[..size2], data2);
    }

    // A 1400-byte payload is echoed back intact.
    // Presumably 1400 was chosen to stay under a typical Ethernet MTU; confirm.
    // NOTE(review): the client's recv_from is non-blocking and relies on the
    // 20 ms sleep for the echo to arrive.
    #[test]
    fn test_large_data_transmission() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 2048];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let wrapper = UdpSocket::from(client).unwrap();

        let large_data = vec![0xAB; 1400];
        let sent = wrapper.send_to(&large_data, server_addr);
        assert!(sent.is_ok());
        assert_eq!(sent.unwrap(), large_data.len());

        thread::sleep(Duration::from_millis(20));

        let mut buf = [0u8; 2048];
        let received = wrapper.recv_from(&mut buf);
        assert!(received.is_ok());
        let (size, addr) = received.unwrap();
        assert_eq!(size, large_data.len());
        assert_eq!(&buf[..size], &large_data[..]);
        assert_eq!(addr, server_addr);
    }

    // After the wrapper is dropped, its address can be bound again. This
    // suggests both duplicated handles (read and write halves) were released.
    #[test]
    fn test_drop_behavior() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let addr = wrapper.local_addr().unwrap();
        drop(wrapper);

        thread::sleep(Duration::from_millis(10));
        let new_wrapper = UdpSocket::bind(addr);
        assert!(new_wrapper.is_ok());
    }

    // The halves returned by `split` can be used on their own.
    // NOTE(review): WouldBlock on the read is tolerated, so the echo assertions
    // only run if the reply has already arrived.
    #[test]
    fn test_split_sockets_independently() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();
            let (read_socket, write_socket) = wrapper.split();

            // Test that split sockets work independently
            let test_data = b"Split socket test";
            write_socket.send_to(test_data, server_addr).unwrap();

            let mut buf = [0u8; 1024];
            let received = read_socket.recv_from(&mut buf);

            match received {
                Ok((size, addr)) => {
                    assert_eq!(&buf[..size], test_data);
                    assert_eq!(addr, server_addr);
                }
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                Err(e) => panic!("Unexpected error: {:?}", e),
            }
        };

        block_on(test_future);
    }

    // NOTE(review): the name suggests a connected socket, but `connect` is never
    // called. This only checks `send_to` on an unconnected wrapper.
    #[test]
    fn test_connected_socket_operations() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        let wrapper = UdpSocket::from(client).unwrap();
        let test_data = b"Connected test";
        let result = wrapper.send_to(test_data, server_addr);
        assert!(result.is_ok());
    }
}