arta_async_std/net/
udp_socket.rs

1use crate::AsyncStdGlobalRuntime;
2use arta::net::{RuntimeUdpSocket, ToSocketAddrs};
3use cfg_if::cfg_if;
4use futures::{prelude::Future, TryFutureExt};
5use socket2::SockRef;
6use std::net::SocketAddr;
7
8cfg_if! {
9    if #[cfg(windows)] {
10        impl std::os::windows::io::AsRawSocket for AsyncStdUdpSocket {
11            fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
12                self.inner.as_raw_socket()
13            }
14        }
15
16        impl std::os::windows::io::AsSocket for AsyncStdUdpSocket {
17            fn as_socket(&self) -> std::os::windows::io::BorrowedSocket<'_> {
18                let raw_socket = std::os::windows::io::AsRawSocket::as_raw_socket(self);
19                unsafe { std::os::windows::io::BorrowedSocket::borrow_raw(raw_socket) }
20            }
21        }
22
23        impl From<std::os::windows::io::OwnedSocket> for AsyncStdUdpSocket {
24            fn from(socket: std::os::windows::io::OwnedSocket) -> Self {
25                Self {
26                    inner: async_std::net::UdpSocket::from(std::net::UdpSocket::from(socket))
27                }
28            }
29        }
30    } else if #[cfg(any(unix, target_os = "wasi"))] {
31        impl std::os::fd::AsRawFd for AsyncStdUdpSocket {
32            fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
33                self.inner.as_raw_fd()
34            }
35        }
36
37        impl std::os::fd::AsFd for AsyncStdUdpSocket {
38            fn as_fd(&self) -> std::os::unix::prelude::BorrowedFd<'_> {
39                let raw_fd = std::os::fd::AsRawFd::as_raw_fd(self);
40                unsafe { std::os::fd::BorrowedFd::borrow_raw(raw_fd) }
41            }
42        }
43
44        impl From<std::os::fd::OwnedFd> for AsyncStdUdpSocket {
45            fn from(fd: std::os::fd::OwnedFd) -> Self {
46                Self {
47                    inner: async_std::net::UdpSocket::from(std::net::UdpSocket::from(fd))
48                }
49            }
50        }
51    }
52}
53
54/// Async-std specific [`RuntimeUdpSocket`] implementation.
55pub struct AsyncStdUdpSocket {
56    inner: async_std::net::UdpSocket,
57}
58
59impl RuntimeUdpSocket for AsyncStdUdpSocket {
60    type Runtime = AsyncStdGlobalRuntime;
61
62    fn bind(
63        runtime: &Self::Runtime,
64        addrs: impl ToSocketAddrs<Self::Runtime>,
65    ) -> impl Future<Output = std::io::Result<Self>> + Send
66    where
67        Self: Sized,
68    {
69        addrs.for_each_resolved_addr_until_success(runtime, |addr| {
70            async_std::net::UdpSocket::bind(addr).map_ok(|socket| Self { inner: socket })
71        })
72    }
73
74    fn connect(
75        &self,
76        addrs: impl ToSocketAddrs<Self::Runtime>,
77    ) -> impl Future<Output = std::io::Result<()>> + Send {
78        addrs.for_each_resolved_addr_until_success(&AsyncStdGlobalRuntime, |addr| {
79            self.inner.connect(addr)
80        })
81    }
82
83    fn send(&self, buf: &[u8]) -> impl Future<Output = std::io::Result<usize>> + Send {
84        self.inner.send(buf)
85    }
86
87    async fn send_to(
88        &self,
89        buf: &[u8],
90        addrs: impl ToSocketAddrs<Self::Runtime>,
91    ) -> std::io::Result<usize> {
92        if let Some(addr) = addrs.to_socket_addrs(&AsyncStdGlobalRuntime).await?.next() {
93            self.inner.send_to(buf, addr).await
94        } else {
95            Err(std::io::Error::new(
96                std::io::ErrorKind::InvalidData,
97                "no address was resolved",
98            ))
99        }
100    }
101
102    fn recv(&self, buf: &mut [u8]) -> impl Future<Output = std::io::Result<usize>> + Send {
103        self.inner.recv(buf)
104    }
105
106    fn recv_from(
107        &self,
108        buf: &mut [u8],
109    ) -> impl Future<Output = std::io::Result<(usize, SocketAddr)>> + Send {
110        self.inner.recv_from(buf)
111    }
112
113    fn local_addr(&self) -> std::io::Result<SocketAddr> {
114        self.inner.local_addr()
115    }
116
117    fn set_broadcast(&self, is_enabled: bool) -> std::io::Result<()> {
118        self.inner.set_broadcast(is_enabled)
119    }
120
121    fn broadcast(&self) -> std::io::Result<bool> {
122        self.inner.broadcast()
123    }
124
125    fn join_multicast_v4(
126        &self,
127        multiaddr: std::net::Ipv4Addr,
128        interface: std::net::Ipv4Addr,
129    ) -> std::io::Result<()> {
130        self.inner.join_multicast_v4(multiaddr, interface)
131    }
132
133    fn leave_multicast_v4(
134        &self,
135        multiaddr: std::net::Ipv4Addr,
136        interface: std::net::Ipv4Addr,
137    ) -> std::io::Result<()> {
138        self.inner.leave_multicast_v4(multiaddr, interface)
139    }
140
141    fn set_multicast_loop_v4(&self, is_enabled: bool) -> std::io::Result<()> {
142        self.inner.set_multicast_loop_v4(is_enabled)
143    }
144
145    fn multicast_loop_v4(&self) -> std::io::Result<bool> {
146        self.inner.multicast_loop_v4()
147    }
148
149    fn set_multicast_ttl_v4(&self, ttl: u32) -> std::io::Result<()> {
150        self.inner.set_multicast_ttl_v4(ttl)
151    }
152
153    fn multicast_ttl_v4(&self) -> std::io::Result<u32> {
154        self.inner.multicast_ttl_v4()
155    }
156
157    fn join_multicast_v6(
158        &self,
159        multiaddr: std::net::Ipv6Addr,
160        interface: u32,
161    ) -> std::io::Result<()> {
162        self.inner.join_multicast_v6(&multiaddr, interface)
163    }
164
165    fn leave_multicast_v6(
166        &self,
167        multiaddr: std::net::Ipv6Addr,
168        interface: u32,
169    ) -> std::io::Result<()> {
170        self.inner.leave_multicast_v6(&multiaddr, interface)
171    }
172
173    fn set_multicast_loop_v6(&self, is_enabled: bool) -> std::io::Result<()> {
174        self.inner.set_multicast_loop_v6(is_enabled)
175    }
176
177    fn multicast_loop_v6(&self) -> std::io::Result<bool> {
178        self.inner.multicast_loop_v6()
179    }
180
181    fn ttl(&self) -> std::io::Result<u32> {
182        self.inner.ttl()
183    }
184
185    fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
186        self.inner.set_ttl(ttl)
187    }
188
189    fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
190        SockRef::from(self).take_error()
191    }
192}