1use crate::common::ready_future::ReadyFuture;
2use crate::common::ready_future_state::ReadyFutureResult;
3use crate::net::event_listener;
4use futures::{AsyncRead, AsyncWrite, FutureExt};
5use mio::Token;
6use mio::net::UdpSocket as MioUdpSocket;
7use std::fmt::{Debug, Error, Formatter};
8use std::io;
9use std::net::UdpSocket as StdUdpSocket;
10use std::net::{SocketAddr, ToSocketAddrs};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::time::{Duration, Instant};
14
/// Read half of a [`UdpSocket`].
///
/// Reads first try the mio socket directly. When nothing is queued, they wait on a
/// readiness future obtained from the crate's event listener, bounded by `read_timeout`.
pub struct UdpReadSocket {
    // Underlying mio socket; assumed to be in non-blocking mode (`UdpSocket::from` sets it).
    udp_socket: MioUdpSocket,
    // Token identifying this socket's read registrations with the event listener.
    read_token: Token,
    // Outstanding readiness wait, present only after a read attempt hit `WouldBlock`.
    read_future: Option<ReadyFuture<()>>,
    /// Maximum duration of a single readiness wait before the read fails with
    /// `io::ErrorKind::TimedOut`. `UdpReadSocket::new` sets it to 20 seconds.
    pub read_timeout: Duration,
}
21
22impl UdpReadSocket {
23 pub fn new(udp_socket: MioUdpSocket) -> Self {
24 UdpReadSocket {
25 udp_socket,
26 read_token: event_listener().next_token(),
27 read_future: None,
28 read_timeout: Duration::from_secs(20),
29 }
30 }
31
32 pub fn set_read_timeout(&mut self, duration: Duration) {
33 self.read_timeout = duration;
34 }
35
36 fn wait_read_data(&mut self) -> io::Result<()> {
37 let future = event_listener().listen_read(
38 &mut self.udp_socket,
39 Instant::now() + self.read_timeout,
40 self.read_token,
41 )?;
42 self.read_future = Some(future);
43 Ok(())
44 }
45
46 fn poll_read_attempt(
47 &mut self,
48 cx: &mut Context<'_>,
49 buf: &mut [u8],
50 ) -> Poll<io::Result<(usize, SocketAddr)>> {
51 let mut future = match self.read_future.take() {
52 None => {
53 match self.udp_socket.recv_from(buf) {
54 Ok((size, addr)) => return Poll::Ready(Ok((size, addr))),
55 Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
56 Err(err) => return Poll::Ready(Err(err)),
57 }
58 if let Err(err) = self.wait_read_data() {
59 return Poll::Ready(Err(err));
60 }
61 self.read_future.take().unwrap()
62 }
63 Some(future) => future,
64 };
65 match future.poll_unpin(cx) {
66 Poll::Pending => {
67 self.read_future = Some(future);
68 Poll::Pending
69 }
70 Poll::Ready(ReadyFutureResult::Timeout) => {
71 Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
72 }
73 Poll::Ready(_) => match self.udp_socket.recv_from(buf) {
74 Ok((size, addr)) => Poll::Ready(Ok((size, addr))),
75 Err(err) => Poll::Ready(Err(err)),
76 },
77 }
78 }
79
80 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
81 self.udp_socket.recv_from(buf)
82 }
83
84 pub fn local_addr(&self) -> io::Result<SocketAddr> {
85 self.udp_socket.local_addr()
86 }
87}
88
89impl AsyncRead for UdpReadSocket {
90 fn poll_read(
91 self: Pin<&mut Self>,
92 cx: &mut Context<'_>,
93 buf: &mut [u8],
94 ) -> Poll<io::Result<usize>> {
95 let me = self.get_mut();
96 match me.poll_read_attempt(cx, buf) {
97 Poll::Pending => Poll::Pending,
98 Poll::Ready(Ok((size, _))) => Poll::Ready(Ok(size)),
99 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
100 }
101 }
102}
103
104impl Drop for UdpReadSocket {
105 fn drop(&mut self) {
106 event_listener()
107 .stop_listening(&mut self.udp_socket, self.read_token)
108 .ok();
109 }
110}
111
/// Write half of a [`UdpSocket`].
///
/// Writes first try the mio socket directly. When the socket is not writable, they wait on
/// a readiness future obtained from the crate's event listener, bounded by `write_timeout`.
pub struct UdpWriteSocket {
    // Underlying mio socket; assumed to be in non-blocking mode (`UdpSocket::from` sets it).
    udp_socket: MioUdpSocket,
    // Token identifying this socket's write registrations with the event listener.
    write_token: Token,
    // Outstanding readiness wait, present only after a write attempt hit `WouldBlock`.
    write_future: Option<ReadyFuture<()>>,
    /// Maximum duration of a single readiness wait before the write fails with
    /// `io::ErrorKind::TimedOut`. `UdpWriteSocket::new` sets it to 2 seconds.
    pub write_timeout: Duration,
}
118
119impl UdpWriteSocket {
120 pub fn new(udp_socket: MioUdpSocket) -> Self {
121 UdpWriteSocket {
122 udp_socket,
123 write_token: event_listener().next_token(),
124 write_future: None,
125 write_timeout: Duration::from_secs(2),
126 }
127 }
128
129 pub fn set_write_timeout(&mut self, duration: Duration) {
130 self.write_timeout = duration;
131 }
132
133 fn wait_write_ready(&mut self) -> io::Result<()> {
134 let future = event_listener().listen_write(
135 &mut self.udp_socket,
136 Instant::now() + self.write_timeout,
137 self.write_token,
138 )?;
139 self.write_future = Some(future);
140 Ok(())
141 }
142
143 fn poll_write_attempt(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
144 let mut future = match self.write_future.take() {
145 None => {
146 match self.udp_socket.send(buf) {
147 Ok(size) => return Poll::Ready(Ok(size)),
148 Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
149 Err(err) => return Poll::Ready(Err(err)),
150 }
151
152 if let Err(err) = self.wait_write_ready() {
153 return Poll::Ready(Err(err));
154 }
155 self.write_future.take().unwrap()
156 }
157 Some(future) => future,
158 };
159 match future.poll_unpin(cx) {
160 Poll::Pending => {
161 self.write_future = Some(future);
162 Poll::Pending
163 }
164 Poll::Ready(ReadyFutureResult::Timeout) => {
165 Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
166 }
167 Poll::Ready(_) => match self.udp_socket.send(buf) {
168 Ok(size) => Poll::Ready(Ok(size)),
169 Err(err) => Poll::Ready(Err(err)),
170 },
171 }
172 }
173
174 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
175 self.udp_socket.send_to(buf, target)
176 }
177
178 pub fn local_addr(&self) -> io::Result<SocketAddr> {
179 self.udp_socket.local_addr()
180 }
181}
182
183impl AsyncWrite for UdpWriteSocket {
184 fn poll_write(
185 self: Pin<&mut Self>,
186 cx: &mut Context<'_>,
187 buf: &[u8],
188 ) -> Poll<io::Result<usize>> {
189 let me = self.get_mut();
190 me.poll_write_attempt(cx, buf)
191 }
192
193 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
194 Poll::Ready(Ok(()))
195 }
196
197 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198 Poll::Ready(Ok(()))
199 }
200}
201
202impl Drop for UdpWriteSocket {
203 fn drop(&mut self) {
204 event_listener()
205 .stop_listening(&mut self.udp_socket, self.write_token)
206 .ok();
207 }
208}
209
/// Async UDP socket made of a read half and a write half.
///
/// When built via `UdpSocket::from`, both halves wrap `try_clone`d handles of the same
/// std socket, so they can be used independently after `split`.
pub struct UdpSocket {
    // Handles receives and `AsyncRead`; also the source of `local_addr` and `Debug` output.
    read_socket: UdpReadSocket,
    // Handles sends and `AsyncWrite`.
    write_socket: UdpWriteSocket,
}
219
220impl UdpSocket {
221 pub fn from(udp_socket: StdUdpSocket) -> io::Result<UdpSocket> {
222 udp_socket.set_nonblocking(true)?;
223 Ok(UdpSocket {
224 read_socket: UdpReadSocket::new(MioUdpSocket::from_std(udp_socket.try_clone()?)),
225 write_socket: UdpWriteSocket::new(MioUdpSocket::from_std(udp_socket)),
226 })
227 }
228
229 pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
230 Self::from(StdUdpSocket::bind(addr)?)
231 }
232
233 pub fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> {
234 for addr in addr.to_socket_addrs()? {
235 self.read_socket.udp_socket.connect(addr)?;
236 self.write_socket.udp_socket.connect(addr)?;
237 break;
238 }
239 Ok(())
240 }
241
242 pub fn bind_and_connect<A: ToSocketAddrs, B: ToSocketAddrs>(
243 addr: A,
244 to_addr: B,
245 ) -> io::Result<UdpSocket> {
246 let result = Self::bind(addr)?;
247 result.connect(to_addr)?;
248 Ok(result)
249 }
250
251 pub fn read_socket(&self) -> &UdpReadSocket {
252 &self.read_socket
253 }
254
255 pub fn read_socket_mut(&mut self) -> &mut UdpReadSocket {
256 &mut self.read_socket
257 }
258
259 pub fn write_socket(&self) -> &UdpWriteSocket {
260 &self.write_socket
261 }
262
263 pub fn write_socket_mut(&mut self) -> &mut UdpWriteSocket {
264 &mut self.write_socket
265 }
266
267 pub fn set_read_timeout(&mut self, duration: Duration) {
268 self.read_socket.set_read_timeout(duration);
269 }
270
271 pub fn set_write_timeout(&mut self, duration: Duration) {
272 self.write_socket.set_write_timeout(duration);
273 }
274
275 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
276 self.write_socket.send_to(buf, target)
277 }
278
279 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
280 self.read_socket.recv_from(buf)
281 }
282
283 pub fn local_addr(&self) -> io::Result<SocketAddr> {
284 self.read_socket.local_addr()
285 }
286
287 pub fn split(self) -> (UdpReadSocket, UdpWriteSocket) {
288 (self.read_socket, self.write_socket)
289 }
290}
291
292impl AsyncRead for UdpSocket {
293 fn poll_read(
294 self: Pin<&mut Self>,
295 cx: &mut Context<'_>,
296 buf: &mut [u8],
297 ) -> Poll<io::Result<usize>> {
298 let me = self.get_mut();
299 Pin::new(&mut me.read_socket).poll_read(cx, buf)
300 }
301}
302
303impl AsyncWrite for UdpSocket {
304 fn poll_write(
305 self: Pin<&mut Self>,
306 cx: &mut Context<'_>,
307 buf: &[u8],
308 ) -> Poll<io::Result<usize>> {
309 let me = self.get_mut();
310 Pin::new(&mut me.write_socket).poll_write(cx, buf)
311 }
312
313 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
314 let me = self.get_mut();
315 Pin::new(&mut me.write_socket).poll_flush(cx)
316 }
317
318 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
319 let me = self.get_mut();
320 Pin::new(&mut me.write_socket).poll_close(cx)
321 }
322}
323
324impl Debug for UdpSocket {
325 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
326 write!(f, "{:?}", self.read_socket.udp_socket)
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::timer::Timer;
    use futures::executor::block_on;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    /// Binds a (server, client) pair of plain std sockets on ephemeral loopback ports.
    fn setup_test_sockets() -> (StdUdpSocket, StdUdpSocket) {
        let server = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let client = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        (server, client)
    }

    #[test]
    fn test_udp_wrapper_creation() {
        let socket = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let addr = socket.local_addr().unwrap();

        let wrapper = UdpSocket::from(socket);
        assert!(wrapper.is_ok());

        let wrapper = wrapper.unwrap();
        assert_eq!(wrapper.local_addr().unwrap(), addr);
    }

    #[test]
    fn test_udp_wrapper_bind() {
        let wrapper = UdpSocket::bind("127.0.0.1:0");
        assert!(wrapper.is_ok());

        let wrapper = wrapper.unwrap();
        let addr = wrapper.local_addr().unwrap();
        assert!(addr.port() > 0);
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
    }

    #[test]
    fn test_udp_wrapper_bind_and_connect() {
        let (server, _) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        let wrapper = UdpSocket::bind_and_connect("127.0.0.1:0", server_addr);
        assert!(wrapper.is_ok());
    }

    #[test]
    fn test_timeout_setters() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let mut wrapper = wrapper;

        wrapper.set_read_timeout(Duration::from_secs(30));
        wrapper.set_write_timeout(Duration::from_secs(5));

        assert_eq!(wrapper.read_socket().read_timeout, Duration::from_secs(30));
        assert_eq!(wrapper.write_socket().write_timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_socket_accessors() {
        let mut wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();

        // Defaults set by the half constructors: 20 s read, 2 s write.
        let read_socket = wrapper.read_socket();
        assert_eq!(read_socket.read_timeout, Duration::from_secs(20));

        let read_socket_mut = wrapper.read_socket_mut();
        read_socket_mut.set_read_timeout(Duration::from_secs(15));
        assert_eq!(read_socket_mut.read_timeout, Duration::from_secs(15));

        let write_socket = wrapper.write_socket();
        assert_eq!(write_socket.write_timeout, Duration::from_secs(2));

        let write_socket_mut = wrapper.write_socket_mut();
        write_socket_mut.set_write_timeout(Duration::from_secs(10));
        assert_eq!(write_socket_mut.write_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_sync_send_recv() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        let client_addr = client.local_addr().unwrap();

        let server_wrapper = UdpSocket::from(server).unwrap();
        let client_wrapper = UdpSocket::from(client).unwrap();

        let test_data = b"Hello UDP!";
        let sent = client_wrapper.send_to(test_data, server_addr);
        assert!(sent.is_ok());
        assert_eq!(sent.unwrap(), test_data.len());

        // The wrapper's recv_from is non-blocking, so give loopback delivery a moment.
        thread::sleep(Duration::from_millis(10));

        let mut buf = [0u8; 1024];
        let received = server_wrapper.recv_from(&mut buf);
        assert!(received.is_ok());
        let (size, addr) = received.unwrap();
        assert_eq!(size, test_data.len());
        assert_eq!(&buf[..size], test_data);
        assert_eq!(addr, client_addr);
    }

    #[test]
    fn test_async_read_write() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();

            let test_data = b"Async UDP test!";
            let sent = wrapper.send_to(test_data, server_addr);
            assert!(sent.is_ok());

            // Non-blocking receive: the echo may not have arrived yet, so only check it if
            // it did.
            let mut buf = [0u8; 1024];
            let read_result = wrapper.recv_from(&mut buf);
            if let Ok((size, addr)) = read_result {
                assert_eq!(size, test_data.len());
                assert_eq!(&buf[..size], test_data);
                assert_eq!(addr, server_addr);
            }
        };

        block_on(test_future);
    }

    #[test]
    fn test_async_with_timer() {
        let mut timer = Timer::new();
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                thread::sleep(Duration::from_millis(50));
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();
            timer.wait(Duration::from_millis(20)).await;
            let test_data = b"Delayed UDP!";
            let sent = wrapper.send_to(test_data, server_addr);
            assert!(sent.is_ok());
        };

        block_on(test_future);
    }

    #[test]
    fn test_concurrent_operations() {
        let server = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server_addr = server.local_addr().unwrap();
        // Bound each receive so a lost datagram fails the assertion instead of hanging join().
        server
            .set_read_timeout(Some(Duration::from_secs(2)))
            .unwrap();
        let response_count = Arc::new(Mutex::new(0));
        let response_count_clone = response_count.clone();

        let server_thread = thread::spawn(move || {
            let mut buf = [0u8; 1024];
            for _ in 0..3 {
                if let Ok((size, addr)) = server.recv_from(&mut buf) {
                    let _ = server.send_to(&buf[..size], addr);
                    let mut count = response_count_clone.lock().unwrap();
                    *count += 1;
                }
            }
        });

        let test_future = async {
            let mut futures = Vec::new();

            for i in 0..3 {
                let test_data = format!("Message {}", i);
                let future = async move {
                    let client = StdUdpSocket::bind("127.0.0.1:0").unwrap();
                    let wrapper = UdpSocket::from(client).unwrap();

                    let sent = wrapper.send_to(test_data.as_bytes(), server_addr);
                    assert!(sent.is_ok());
                };
                futures.push(future);
            }

            futures::future::join_all(futures).await;
        };

        block_on(test_future);
        // Wait for the server loop to finish instead of guessing with a fixed sleep.
        server_thread.join().unwrap();
        let count = response_count.lock().unwrap();
        assert_eq!(*count, 3);
    }

    #[test]
    fn test_timeout_behavior() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let mut wrapper = wrapper;

        wrapper.set_read_timeout(Duration::from_millis(50));

        let test_future = async {
            let mut buf = [0u8; 1024];

            // The synchronous recv_from never waits, so an idle socket reports WouldBlock
            // immediately; read_timeout only applies to the AsyncRead path.
            let result = wrapper.recv_from(&mut buf);
            match result {
                Ok(_) => {
                    panic!("Unexpected data received");
                }
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                Err(e) => {
                    panic!("Unexpected error: {:?}", e);
                }
            }
        };

        block_on(test_future);
    }

    #[test]
    fn test_multiple_sends_to_different_addresses() {
        let server1 = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server2 = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server1_addr = server1.local_addr().unwrap();
        let server2_addr = server2.local_addr().unwrap();

        let (_, client) = setup_test_sockets();
        let wrapper = UdpSocket::from(client).unwrap();

        let data1 = b"Hello Server 1";
        let sent1 = wrapper.send_to(data1, server1_addr);
        assert!(sent1.is_ok());
        assert_eq!(sent1.unwrap(), data1.len());

        let data2 = b"Hello Server 2";
        let sent2 = wrapper.send_to(data2, server2_addr);
        assert!(sent2.is_ok());
        assert_eq!(sent2.unwrap(), data2.len());

        thread::sleep(Duration::from_millis(10));

        let mut buf1 = [0u8; 1024];
        let received1 = server1.recv_from(&mut buf1);
        assert!(received1.is_ok());
        let (size1, _) = received1.unwrap();
        assert_eq!(&buf1[..size1], data1);

        let mut buf2 = [0u8; 1024];
        let received2 = server2.recv_from(&mut buf2);
        assert!(received2.is_ok());
        let (size2, _) = received2.unwrap();
        assert_eq!(&buf2[..size2], data2);
    }

    #[test]
    fn test_large_data_transmission() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 2048];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let wrapper = UdpSocket::from(client).unwrap();

        let large_data = vec![0xAB; 1400];
        let sent = wrapper.send_to(&large_data, server_addr);
        assert!(sent.is_ok());
        assert_eq!(sent.unwrap(), large_data.len());

        thread::sleep(Duration::from_millis(20));

        let mut buf = [0u8; 2048];
        let received = wrapper.recv_from(&mut buf);
        assert!(received.is_ok());
        let (size, addr) = received.unwrap();
        assert_eq!(size, large_data.len());
        assert_eq!(&buf[..size], &large_data[..]);
        assert_eq!(addr, server_addr);
    }

    #[test]
    fn test_drop_behavior() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let addr = wrapper.local_addr().unwrap();
        drop(wrapper);

        // Rebinding the same port only succeeds if both cloned handles were closed.
        thread::sleep(Duration::from_millis(10));
        let new_wrapper = UdpSocket::bind(addr);
        assert!(new_wrapper.is_ok());
    }

    #[test]
    fn test_split_sockets_independently() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();
            let (read_socket, write_socket) = wrapper.split();

            let test_data = b"Split socket test";
            write_socket.send_to(test_data, server_addr).unwrap();

            let mut buf = [0u8; 1024];
            let received = read_socket.recv_from(&mut buf);

            // The echo may not have arrived yet on this non-blocking receive.
            match received {
                Ok((size, addr)) => {
                    assert_eq!(&buf[..size], test_data);
                    assert_eq!(addr, server_addr);
                }
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                Err(e) => panic!("Unexpected error: {:?}", e),
            }
        };

        block_on(test_future);
    }

    #[test]
    fn test_connected_socket_operations() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();

        let wrapper = UdpSocket::from(client).unwrap();
        let test_data = b"Connected test";
        let result = wrapper.send_to(test_data, server_addr);
        assert!(result.is_ok());
    }
}