1use crate::common::ready_future::ReadyFuture;
2use crate::common::ready_future_state::ReadyFutureResult;
3use crate::net::event_listener;
4use futures::{AsyncRead, AsyncWrite, FutureExt};
5use mio::Token;
6use mio::net::UdpSocket as MioUdpSocket;
7use std::fmt::{Debug, Error, Formatter};
8use std::io;
9use std::net::UdpSocket as StdUdpSocket;
10use std::net::{SocketAddr, ToSocketAddrs};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::time::{Duration, Instant};
14
/// Read half of a non-blocking UDP socket whose waits are driven by the crate's
/// event listener.
///
/// Implements [`AsyncRead`]: each successful read yields one datagram, and the
/// sender's address is discarded by `poll_read`.
// NOTE(review): field declaration order is also drop order (`udp_socket` is
// dropped before `read_future`); keep it unless that ordering is known not to matter.
pub struct UdpReadSocket {
    /// The underlying mio socket, handed to the event listener together with `read_token`.
    udp_socket: MioUdpSocket,
    /// Token obtained from `event_listener().next_token()` identifying this half's read waits.
    read_token: Token,
    /// Outstanding readiness wait; `Some` only while a read is parked after `WouldBlock`.
    read_future: Option<ReadyFuture<()>>,
    /// How long a single readiness wait may last before the read fails with `TimedOut`.
    pub read_timeout: Duration,
}
21
22impl UdpReadSocket {
23 pub fn new(udp_socket: MioUdpSocket) -> Self {
24 UdpReadSocket {
25 udp_socket,
26 read_token: event_listener().next_token(),
27 read_future: None,
28 read_timeout: Duration::from_secs(20),
29 }
30 }
31
32 pub fn set_read_timeout(&mut self, duration: Duration) {
33 self.read_timeout = duration;
34 }
35
36 fn wait_read_data(&mut self) -> io::Result<()> {
37 let future = event_listener().listen_read(
38 &mut self.udp_socket,
39 Instant::now() + self.read_timeout,
40 self.read_token,
41 )?;
42 self.read_future = Some(future);
43 Ok(())
44 }
45
46 fn poll_read_attempt(
47 &mut self,
48 cx: &mut Context<'_>,
49 buf: &mut [u8],
50 ) -> Poll<io::Result<(usize, SocketAddr)>> {
51 let mut future = match self.read_future.take() {
52 None => {
53 match self.udp_socket.recv_from(buf) {
54 Ok((size, addr)) => return Poll::Ready(Ok((size, addr))),
55 Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
56 Err(err) => return Poll::Ready(Err(err)),
57 }
58 if let Err(err) = self.wait_read_data() {
59 return Poll::Ready(Err(err));
60 }
61 self.read_future.take().unwrap()
62 }
63 Some(future) => future,
64 };
65 match future.poll_unpin(cx) {
66 Poll::Pending => {
67 self.read_future = Some(future);
68 Poll::Pending
69 }
70 Poll::Ready(ReadyFutureResult::Timeout) => {
71 Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
72 }
73 Poll::Ready(_) => match self.udp_socket.recv_from(buf) {
74 Ok((size, addr)) => Poll::Ready(Ok((size, addr))),
75 Err(err) => Poll::Ready(Err(err)),
76 },
77 }
78 }
79
80 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
81 self.udp_socket.recv_from(buf)
82 }
83
84 pub fn local_addr(&self) -> io::Result<SocketAddr> {
85 self.udp_socket.local_addr()
86 }
87}
88
89impl AsyncRead for UdpReadSocket {
90 fn poll_read(
91 self: Pin<&mut Self>,
92 cx: &mut Context<'_>,
93 buf: &mut [u8],
94 ) -> Poll<io::Result<usize>> {
95 let me = self.get_mut();
96 match me.poll_read_attempt(cx, buf) {
97 Poll::Pending => Poll::Pending,
98 Poll::Ready(Ok((size, _))) => Poll::Ready(Ok(size)),
99 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
100 }
101 }
102}
103
104impl Drop for UdpReadSocket {
105 fn drop(&mut self) {
106 event_listener()
107 .stop_listening(&mut self.udp_socket, self.read_token)
108 .ok();
109 }
110}
111
/// Write half of a non-blocking UDP socket whose waits are driven by the crate's
/// event listener.
///
/// Implements [`AsyncWrite`]: each write sends one datagram with `send`, i.e. to
/// the peer set by `connect`.
// NOTE(review): field declaration order is also drop order (`udp_socket` is
// dropped before `write_future`); keep it unless that ordering is known not to matter.
pub struct UdpWriteSocket {
    /// The underlying mio socket, handed to the event listener together with `write_token`.
    udp_socket: MioUdpSocket,
    /// Token obtained from `event_listener().next_token()` identifying this half's write waits.
    write_token: Token,
    /// Outstanding writability wait; `Some` only while a write is parked after `WouldBlock`.
    write_future: Option<ReadyFuture<()>>,
    /// How long a single writability wait may last before the write fails with `TimedOut`.
    pub write_timeout: Duration,
}
118
119impl UdpWriteSocket {
120 pub fn new(udp_socket: MioUdpSocket) -> Self {
121 UdpWriteSocket {
122 udp_socket,
123 write_token: event_listener().next_token(),
124 write_future: None,
125 write_timeout: Duration::from_secs(2),
126 }
127 }
128
129 pub fn set_write_timeout(&mut self, duration: Duration) {
130 self.write_timeout = duration;
131 }
132
133 fn wait_write_ready(&mut self) -> io::Result<()> {
134 let future = event_listener().listen_write(
135 &mut self.udp_socket,
136 Instant::now() + self.write_timeout,
137 self.write_token,
138 )?;
139 self.write_future = Some(future);
140 Ok(())
141 }
142
143 fn poll_write_attempt(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
144 let mut future = match self.write_future.take() {
145 None => {
146 match self.udp_socket.send(buf) {
147 Ok(size) => return Poll::Ready(Ok(size)),
148 Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
149 Err(err) => return Poll::Ready(Err(err)),
150 }
151
152 if let Err(err) = self.wait_write_ready() {
153 return Poll::Ready(Err(err));
154 }
155 self.write_future.take().unwrap()
156 }
157 Some(future) => future,
158 };
159 match future.poll_unpin(cx) {
160 Poll::Pending => {
161 self.write_future = Some(future);
162 Poll::Pending
163 }
164 Poll::Ready(ReadyFutureResult::Timeout) => {
165 Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
166 }
167 Poll::Ready(_) => match self.udp_socket.send(buf) {
168 Ok(size) => Poll::Ready(Ok(size)),
169 Err(err) => Poll::Ready(Err(err)),
170 },
171 }
172 }
173
174 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
175 self.udp_socket.send_to(buf, target)
176 }
177
178 pub fn local_addr(&self) -> io::Result<SocketAddr> {
179 self.udp_socket.local_addr()
180 }
181}
182
183impl AsyncWrite for UdpWriteSocket {
184 fn poll_write(
185 self: Pin<&mut Self>,
186 cx: &mut Context<'_>,
187 buf: &[u8],
188 ) -> Poll<io::Result<usize>> {
189 let me = self.get_mut();
190 me.poll_write_attempt(cx, buf)
191 }
192
193 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
194 Poll::Ready(Ok(()))
195 }
196
197 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198 Poll::Ready(Ok(()))
199 }
200}
201
202impl Drop for UdpWriteSocket {
203 fn drop(&mut self) {
204 event_listener()
205 .stop_listening(&mut self.udp_socket, self.write_token)
206 .ok();
207 }
208}
209
/// Non-blocking UDP socket made of a read half and a write half.
///
/// Each half wraps its own clone of the OS socket (see `UdpSocket::from`) and
/// holds its own event-listener token.
// NOTE(review): field declaration order is drop order (read half first); keep it.
pub struct UdpSocket {
    /// Receives datagrams; backs `AsyncRead`, `recv_from` and `local_addr`.
    read_socket: UdpReadSocket,
    /// Sends datagrams; backs `AsyncWrite` and `send_to`.
    write_socket: UdpWriteSocket,
}
214
215impl UdpSocket {
216 pub fn from(udp_socket: StdUdpSocket) -> io::Result<UdpSocket> {
217 udp_socket.set_nonblocking(true)?;
218 Ok(UdpSocket {
219 read_socket: UdpReadSocket::new(MioUdpSocket::from_std(udp_socket.try_clone()?)),
220 write_socket: UdpWriteSocket::new(MioUdpSocket::from_std(udp_socket)),
221 })
222 }
223
224 pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
225 Self::from(StdUdpSocket::bind(addr)?)
226 }
227
228 pub fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> {
229 for addr in addr.to_socket_addrs()? {
230 self.read_socket.udp_socket.connect(addr)?;
231 self.write_socket.udp_socket.connect(addr)?;
232 break;
233 }
234 Ok(())
235 }
236
237 pub fn bind_and_connect<A: ToSocketAddrs, B: ToSocketAddrs>(
238 addr: A,
239 to_addr: B,
240 ) -> io::Result<UdpSocket> {
241 let result = Self::bind(addr)?;
242 result.connect(to_addr)?;
243 Ok(result)
244 }
245
246 pub fn read_socket(&self) -> &UdpReadSocket {
247 &self.read_socket
248 }
249
250 pub fn read_socket_mut(&mut self) -> &mut UdpReadSocket {
251 &mut self.read_socket
252 }
253
254 pub fn write_socket(&self) -> &UdpWriteSocket {
255 &self.write_socket
256 }
257
258 pub fn write_socket_mut(&mut self) -> &mut UdpWriteSocket {
259 &mut self.write_socket
260 }
261
262 pub fn set_read_timeout(&mut self, duration: Duration) {
263 self.read_socket.set_read_timeout(duration);
264 }
265
266 pub fn set_write_timeout(&mut self, duration: Duration) {
267 self.write_socket.set_write_timeout(duration);
268 }
269
270 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
271 self.write_socket.send_to(buf, target)
272 }
273
274 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
275 self.read_socket.recv_from(buf)
276 }
277
278 pub fn local_addr(&self) -> io::Result<SocketAddr> {
279 self.read_socket.local_addr()
280 }
281
282 pub fn split(self) -> (UdpReadSocket, UdpWriteSocket) {
283 (self.read_socket, self.write_socket)
284 }
285}
286
287impl AsyncRead for UdpSocket {
288 fn poll_read(
289 self: Pin<&mut Self>,
290 cx: &mut Context<'_>,
291 buf: &mut [u8],
292 ) -> Poll<io::Result<usize>> {
293 let me = self.get_mut();
294 Pin::new(&mut me.read_socket).poll_read(cx, buf)
295 }
296}
297
298impl AsyncWrite for UdpSocket {
299 fn poll_write(
300 self: Pin<&mut Self>,
301 cx: &mut Context<'_>,
302 buf: &[u8],
303 ) -> Poll<io::Result<usize>> {
304 let me = self.get_mut();
305 Pin::new(&mut me.write_socket).poll_write(cx, buf)
306 }
307
308 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
309 let me = self.get_mut();
310 Pin::new(&mut me.write_socket).poll_flush(cx)
311 }
312
313 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
314 let me = self.get_mut();
315 Pin::new(&mut me.write_socket).poll_close(cx)
316 }
317}
318
319impl Debug for UdpSocket {
320 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
321 write!(f, "{:?}", self.read_socket.udp_socket)
322 }
323}
324
#[cfg(test)]
mod tests {
    // Tests for the UDP wrappers. Most exercise the synchronous passthroughs
    // (`send_to` / `recv_from`) on non-blocking sockets, and several rely on short
    // sleeps for datagrams to arrive, so they are timing-sensitive.
    use super::*;
    use crate::timer::timer::Timer;
    use futures::executor::block_on;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    /// Binds two std UDP sockets on loopback with OS-assigned ports.
    fn setup_test_sockets() -> (StdUdpSocket, StdUdpSocket) {
        let server = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let client = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        (server, client)
    }

    /// Wrapping a bound std socket keeps its local address.
    #[test]
    fn test_udp_wrapper_creation() {
        let socket = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let addr = socket.local_addr().unwrap();

        let wrapper = UdpSocket::from(socket);
        assert!(wrapper.is_ok());

        let wrapper = wrapper.unwrap();
        assert_eq!(wrapper.local_addr().unwrap(), addr);
    }

    /// Binding to port 0 yields a loopback address with a real port.
    #[test]
    fn test_udp_wrapper_bind() {
        let wrapper = UdpSocket::bind("127.0.0.1:0");
        assert!(wrapper.is_ok());

        let wrapper = wrapper.unwrap();
        let addr = wrapper.local_addr().unwrap();
        assert!(addr.port() > 0);
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
    }

    /// Binding and connecting to a live loopback address succeeds.
    #[test]
    fn test_udp_wrapper_bind_and_connect() {
        let (server, _) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        let wrapper = UdpSocket::bind_and_connect("127.0.0.1:0", server_addr);
        assert!(wrapper.is_ok());
    }

    /// The timeout setters update the public timeout fields of each half.
    #[test]
    fn test_timeout_setters() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let mut wrapper = wrapper;

        wrapper.set_read_timeout(Duration::from_secs(30));
        wrapper.set_write_timeout(Duration::from_secs(5));

        assert_eq!(wrapper.read_socket().read_timeout, Duration::from_secs(30));
        assert_eq!(wrapper.write_socket().write_timeout, Duration::from_secs(5));
    }

    /// Checks the default timeouts (20 s read, 2 s write) and the `_mut` accessors.
    #[test]
    fn test_socket_accessors() {
        let mut wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();

        let read_socket = wrapper.read_socket();
        assert_eq!(read_socket.read_timeout, Duration::from_secs(20));

        let read_socket_mut = wrapper.read_socket_mut();
        read_socket_mut.set_read_timeout(Duration::from_secs(15));
        assert_eq!(read_socket_mut.read_timeout, Duration::from_secs(15));

        let write_socket = wrapper.write_socket();
        assert_eq!(write_socket.write_timeout, Duration::from_secs(2));

        let write_socket_mut = wrapper.write_socket_mut();
        write_socket_mut.set_write_timeout(Duration::from_secs(10));
        assert_eq!(write_socket_mut.write_timeout, Duration::from_secs(10));
    }

    /// One datagram between two wrappers via the synchronous passthroughs.
    /// The 10 ms sleep gives the datagram time to arrive, because `recv_from`
    /// does not wait.
    #[test]
    fn test_sync_send_recv() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        let client_addr = client.local_addr().unwrap();

        let server_wrapper = UdpSocket::from(server).unwrap();
        let client_wrapper = UdpSocket::from(client).unwrap();

        let test_data = b"Hello UDP!";
        let sent = client_wrapper.send_to(test_data, server_addr);
        assert!(sent.is_ok());
        assert_eq!(sent.unwrap(), test_data.len());

        thread::sleep(Duration::from_millis(10));

        let mut buf = [0u8; 1024];
        let received = server_wrapper.recv_from(&mut buf);
        assert!(received.is_ok());
        let (size, addr) = received.unwrap();
        assert_eq!(size, test_data.len());
        assert_eq!(&buf[..size], test_data);
        assert_eq!(addr, client_addr);
    }

    /// Sends to an echo thread from inside an async block.
    // NOTE(review): despite the name, this uses the synchronous non-blocking
    // `recv_from`, and the result is only checked under `if let Ok`. The test
    // therefore also passes when the echo has not arrived yet; the AsyncRead /
    // AsyncWrite paths are not exercised.
    #[test]
    fn test_async_read_write() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();

            let test_data = b"Async UDP test!";
            let sent = wrapper.send_to(test_data, server_addr);
            assert!(sent.is_ok());

            let mut buf = [0u8; 1024];
            let read_result = wrapper.recv_from(&mut buf);
            if let Ok((size, addr)) = read_result {
                assert_eq!(size, test_data.len());
                assert_eq!(&buf[..size], test_data);
                assert_eq!(addr, server_addr);
            }
        };

        block_on(test_future);
    }

    /// Awaits a 20 ms crate `Timer` inside `block_on`, then sends. The
    /// server's delayed echo is never read, so only the send is checked.
    #[test]
    fn test_async_with_timer() {
        let mut timer = Timer::new();
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                thread::sleep(Duration::from_millis(50));
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();
            timer.wait(Duration::from_millis(20)).await;
            let test_data = b"Delayed UDP!";
            let sent = wrapper.send_to(test_data, server_addr);
            assert!(sent.is_ok());
        };

        block_on(test_future);
    }

    /// Three client wrappers each send one datagram from futures joined on one
    /// executor. The server thread echoes and counts them. The count is read
    /// after a fixed 100 ms sleep, so the test is timing-sensitive.
    #[test]
    fn test_concurrent_operations() {
        let server = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server_addr = server.local_addr().unwrap();
        let response_count = Arc::new(Mutex::new(0));
        let response_count_clone = response_count.clone();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            for _ in 0..3 {
                if let Ok((size, addr)) = server.recv_from(&mut buf) {
                    let _ = server.send_to(&buf[..size], addr);
                    let mut count = response_count_clone.lock().unwrap();
                    *count += 1;
                }
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            // The local `futures` Vec does not clash with the `futures::` crate
            // path below: paths resolve in the type/module namespace.
            let mut futures = Vec::new();

            for i in 0..3 {
                let test_data = format!("Message {}", i);
                let future = async move {
                    let client = StdUdpSocket::bind("127.0.0.1:0").unwrap();
                    let wrapper = UdpSocket::from(client).unwrap();

                    let sent = wrapper.send_to(test_data.as_bytes(), server_addr);
                    assert!(sent.is_ok());
                };
                futures.push(future);
            }

            futures::future::join_all(futures).await;
        };

        block_on(test_future);
        thread::sleep(Duration::from_millis(100));
        let count = response_count.lock().unwrap();
        assert_eq!(*count, 3);
    }

    /// With nothing sent, a read on a fresh socket must not return data.
    // NOTE(review): the 50 ms read timeout is only consulted on the AsyncRead
    // path (when a readiness wait is armed). This test calls the synchronous
    // `recv_from`, which returns WouldBlock immediately, so the timeout itself
    // is never exercised.
    #[test]
    fn test_timeout_behavior() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let mut wrapper = wrapper;

        wrapper.set_read_timeout(Duration::from_millis(50));

        let test_future = async {
            let mut buf = [0u8; 1024];

            let result = wrapper.recv_from(&mut buf);
            match result {
                Ok(_) => {
                    panic!("Unexpected data received");
                }
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    // Expected: nothing is queued on the non-blocking socket.
                }
                Err(e) => {
                    panic!("Unexpected error: {:?}", e);
                }
            }
        };

        block_on(test_future);
    }

    /// One wrapper sends to two different std sockets. Their default
    /// (blocking) `recv_from` then reads each payload.
    #[test]
    fn test_multiple_sends_to_different_addresses() {
        let server1 = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server2 = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server1_addr = server1.local_addr().unwrap();
        let server2_addr = server2.local_addr().unwrap();

        let (_, client) = setup_test_sockets();
        let wrapper = UdpSocket::from(client).unwrap();

        let data1 = b"Hello Server 1";
        let sent1 = wrapper.send_to(data1, server1_addr);
        assert!(sent1.is_ok());
        assert_eq!(sent1.unwrap(), data1.len());

        let data2 = b"Hello Server 2";
        let sent2 = wrapper.send_to(data2, server2_addr);
        assert!(sent2.is_ok());
        assert_eq!(sent2.unwrap(), data2.len());

        thread::sleep(Duration::from_millis(10));

        let mut buf1 = [0u8; 1024];
        let received1 = server1.recv_from(&mut buf1);
        assert!(received1.is_ok());
        let (size1, _) = received1.unwrap();
        assert_eq!(&buf1[..size1], data1);

        let mut buf2 = [0u8; 1024];
        let received2 = server2.recv_from(&mut buf2);
        assert!(received2.is_ok());
        let (size2, _) = received2.unwrap();
        assert_eq!(&buf2[..size2], data2);
    }

    /// A 1400-byte datagram round-trips through an echo thread.
    // NOTE(review): the reply is read with the non-blocking `recv_from` after a
    // 20 ms sleep, so a slow echo makes `received.is_ok()` fail (flaky under load).
    #[test]
    fn test_large_data_transmission() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 2048];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let wrapper = UdpSocket::from(client).unwrap();

        let large_data = vec![0xAB; 1400];
        let sent = wrapper.send_to(&large_data, server_addr);
        assert!(sent.is_ok());
        assert_eq!(sent.unwrap(), large_data.len());

        thread::sleep(Duration::from_millis(20));

        let mut buf = [0u8; 2048];
        let received = wrapper.recv_from(&mut buf);
        assert!(received.is_ok());
        let (size, addr) = received.unwrap();
        assert_eq!(size, large_data.len());
        assert_eq!(&buf[..size], &large_data[..]);
        assert_eq!(addr, server_addr);
    }

    /// Dropping the wrapper must release the port so that it can be bound again.
    #[test]
    fn test_drop_behavior() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let addr = wrapper.local_addr().unwrap();
        drop(wrapper);

        thread::sleep(Duration::from_millis(10));
        let new_wrapper = UdpSocket::bind(addr);
        assert!(new_wrapper.is_ok());
    }

    /// After `split`, the write half sends and the read half receives.
    /// WouldBlock is tolerated, so the echo is not required to have arrived.
    #[test]
    fn test_split_sockets_independently() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();
            let (read_socket, write_socket) = wrapper.split();

            let test_data = b"Split socket test";
            write_socket.send_to(test_data, server_addr).unwrap();

            let mut buf = [0u8; 1024];
            let received = read_socket.recv_from(&mut buf);

            match received {
                Ok((size, addr)) => {
                    assert_eq!(&buf[..size], test_data);
                    assert_eq!(addr, server_addr);
                }
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                Err(e) => panic!("Unexpected error: {:?}", e),
            }
        };

        block_on(test_future);
    }

    /// Sends to a known peer with `send_to`.
    // NOTE(review): despite the name, `connect` is never called here; this
    // only checks an unconnected `send_to`.
    #[test]
    fn test_connected_socket_operations() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        let wrapper = UdpSocket::from(client).unwrap();
        let test_data = b"Connected test";
        let result = wrapper.send_to(test_data, server_addr);
        assert!(result.is_ok());
    }
}