1use std::future::Future;
7use std::io;
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use crate::platform::sys::{set_nonblocking, Interest};
13use crate::reactor::source::{next_token, IoSource};
14
15pub struct UdpSocket {
19 source: IoSource,
20}
21
22impl UdpSocket {
23 pub fn bind(addr: SocketAddr) -> io::Result<Self> {
28 let fd = create_udp_socket(addr)?;
29 bind_socket(fd, addr)?;
30 set_nonblocking(fd)?;
31 let source = IoSource::new(fd, next_token(), Interest::READABLE | Interest::WRITABLE)?;
32 Ok(Self { source })
33 }
34
35 pub fn local_addr(&self) -> io::Result<SocketAddr> {
37 raw_local_addr(self.source.raw())
38 }
39
40 pub fn send_to<'a>(&'a self, buf: &'a [u8], target: SocketAddr) -> SendToFuture<'a> {
43 SendToFuture {
44 socket: self,
45 buf,
46 target,
47 }
48 }
49
50 pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> RecvFromFuture<'a> {
53 RecvFromFuture { socket: self, buf }
54 }
55}
56
57impl Drop for UdpSocket {
58 fn drop(&mut self) {
59 let fd = self.source.raw();
60 unsafe { libc::close(fd) };
63 }
64}
65
66pub struct SendToFuture<'a> {
70 socket: &'a UdpSocket,
71 buf: &'a [u8],
72 target: SocketAddr,
73}
74
75impl<'a> Future for SendToFuture<'a> {
76 type Output = io::Result<usize>;
77
78 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
79 match try_send_to(self.socket.source.raw(), self.buf, self.target) {
80 Ok(n) => Poll::Ready(Ok(n)),
81 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
82 match Pin::new(&mut self.socket.source.writable()).poll(cx) {
84 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
85 Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
86 }
87 }
88 Err(e) => Poll::Ready(Err(e)),
89 }
90 }
91}
92
93pub struct RecvFromFuture<'a> {
97 socket: &'a UdpSocket,
98 buf: &'a mut [u8],
99}
100
101impl<'a> Future for RecvFromFuture<'a> {
102 type Output = io::Result<(usize, SocketAddr)>;
103
104 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
105 let fd = self.socket.source.raw();
106 match try_recv_from(fd, self.buf) {
107 Ok(result) => Poll::Ready(Ok(result)),
108 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
109 match Pin::new(&mut self.socket.source.readable()).poll(cx) {
111 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
112 Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
113 }
114 }
115 Err(e) => Poll::Ready(Err(e)),
116 }
117 }
118}
119
120fn create_udp_socket(addr: SocketAddr) -> io::Result<i32> {
124 let family = match addr {
125 SocketAddr::V4(_) => libc::AF_INET,
126 SocketAddr::V6(_) => libc::AF_INET6,
127 };
128 let fd = unsafe { libc::socket(family, libc::SOCK_DGRAM, 0) };
130 if fd == -1 {
131 return Err(io::Error::last_os_error());
132 }
133 Ok(fd)
134}
135
136fn bind_socket(fd: i32, addr: SocketAddr) -> io::Result<()> {
138 let (sa_ptr, sa_len) = socketaddr_to_raw(addr);
139 let rc = unsafe { libc::bind(fd, sa_ptr, sa_len) };
141 unsafe { reclaim_sockaddr(sa_ptr, addr) };
143 if rc == -1 {
144 return Err(io::Error::last_os_error());
145 }
146 Ok(())
147}
148
149fn try_send_to(fd: i32, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
151 let (sa_ptr, sa_len) = socketaddr_to_raw(target);
152 let n = unsafe {
155 libc::sendto(
156 fd,
157 buf.as_ptr() as *const libc::c_void,
158 buf.len(),
159 0, sa_ptr,
161 sa_len,
162 )
163 };
164 unsafe { reclaim_sockaddr(sa_ptr, target) };
166 if n == -1 {
167 return Err(io::Error::last_os_error());
168 }
169 Ok(n as usize)
170}
171
172fn try_recv_from(fd: i32, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
174 let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
175 let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
176 let n = unsafe {
179 libc::recvfrom(
180 fd,
181 buf.as_mut_ptr() as *mut libc::c_void,
182 buf.len(),
183 0, &mut addr as *mut _ as *mut libc::sockaddr,
185 &mut len,
186 )
187 };
188 if n == -1 {
189 return Err(io::Error::last_os_error());
190 }
191 let sender = sockaddr_to_socketaddr(&addr, len)?;
192 Ok((n as usize, sender))
193}
194
195fn raw_local_addr(fd: i32) -> io::Result<SocketAddr> {
197 let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
198 let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
199 let rc = unsafe { libc::getsockname(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
201 if rc == -1 {
202 return Err(io::Error::last_os_error());
203 }
204 sockaddr_to_socketaddr(&addr, len)
205}
206
207fn socketaddr_to_raw(addr: SocketAddr) -> (*const libc::sockaddr, libc::socklen_t) {
210 match addr {
211 SocketAddr::V4(v4) => {
212 let octets = v4.ip().octets();
213 let mut sin: libc::sockaddr_in = unsafe { std::mem::zeroed() };
215 sin.sin_family = libc::AF_INET as libc::sa_family_t;
216 sin.sin_port = v4.port().to_be();
217 sin.sin_addr = libc::in_addr {
218 s_addr: u32::from_be_bytes(octets).to_be(),
219 };
220 let ptr = Box::into_raw(Box::new(sin)) as *const libc::sockaddr;
221 (
222 ptr,
223 std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
224 )
225 }
226 SocketAddr::V6(v6) => {
227 let mut sin6: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
229 sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
230 sin6.sin6_port = v6.port().to_be();
231 sin6.sin6_flowinfo = v6.flowinfo();
232 sin6.sin6_addr = libc::in6_addr {
233 s6_addr: v6.ip().octets(),
234 };
235 sin6.sin6_scope_id = v6.scope_id();
236 let ptr = Box::into_raw(Box::new(sin6)) as *const libc::sockaddr;
237 (
238 ptr,
239 std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
240 )
241 }
242 }
243}
244
245unsafe fn reclaim_sockaddr(ptr: *const libc::sockaddr, addr: SocketAddr) {
248 match addr {
249 SocketAddr::V4(_) => drop(Box::from_raw(ptr as *mut libc::sockaddr_in)),
250 SocketAddr::V6(_) => drop(Box::from_raw(ptr as *mut libc::sockaddr_in6)),
251 }
252}
253
254fn sockaddr_to_socketaddr(
257 addr: &libc::sockaddr_in6,
258 len: libc::socklen_t,
259) -> io::Result<SocketAddr> {
260 let family = addr.sin6_family as libc::c_int;
261 match family {
262 libc::AF_INET if len >= std::mem::size_of::<libc::sockaddr_in>() as u32 => {
263 let v4: &libc::sockaddr_in =
266 unsafe { &*(addr as *const _ as *const libc::sockaddr_in) };
267 let ip = std::net::Ipv4Addr::from(u32::from_be(v4.sin_addr.s_addr));
268 let port = u16::from_be(v4.sin_port);
269 Ok(SocketAddr::V4(std::net::SocketAddrV4::new(ip, port)))
270 }
271 libc::AF_INET6 if len >= std::mem::size_of::<libc::sockaddr_in6>() as u32 => {
272 let ip = std::net::Ipv6Addr::from(addr.sin6_addr.s6_addr);
273 let port = u16::from_be(addr.sin6_port);
274 Ok(SocketAddr::V6(std::net::SocketAddrV6::new(
275 ip,
276 port,
277 addr.sin6_flowinfo,
278 addr.sin6_scope_id,
279 )))
280 }
281 _ => Err(io::Error::new(
282 io::ErrorKind::InvalidData,
283 format!("unsupported address family: {family}"),
284 )),
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::block_on_with_spawn;

    /// Binds a socket to an OS-assigned port on the IPv4 loopback address.
    fn loopback_socket() -> UdpSocket {
        UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap()
    }

    /// Binding to port 0 yields a loopback address with a concrete port.
    #[test]
    fn bind_and_local_addr() {
        let requested = "127.0.0.1:0".parse().unwrap();
        let sock = UdpSocket::bind(requested).expect("bind failed");
        let bound = sock.local_addr().expect("local_addr failed");
        assert_eq!(bound.ip().to_string(), "127.0.0.1");
        assert!(bound.port() > 0);
    }

    /// A datagram sent to a bound socket arrives intact and carries the sender's IP.
    #[test]
    fn send_to_and_recv_from() {
        block_on_with_spawn(async {
            let receiver = loopback_socket();
            let recv_addr = receiver.local_addr().unwrap();

            let sender = loopback_socket();

            let payload = b"ping";
            let sent = sender.send_to(payload, recv_addr).await.unwrap();
            assert_eq!(sent, payload.len());

            let mut buf = [0u8; 16];
            let (received, from) = receiver.recv_from(&mut buf).await.unwrap();
            assert_eq!(received, payload.len());
            assert_eq!(&buf[..received], payload);
            assert_eq!(from.ip(), sender.local_addr().unwrap().ip());
        });
    }

    /// A server echoing the datagram back to its source returns the original payload.
    #[test]
    fn udp_echo_round_trip() {
        block_on_with_spawn(async {
            let server = loopback_socket();
            let server_addr = server.local_addr().unwrap();
            let client = loopback_socket();

            client.send_to(b"hello", server_addr).await.unwrap();

            let mut inbound = [0u8; 16];
            let (len, peer) = server.recv_from(&mut inbound).await.unwrap();
            server.send_to(&inbound[..len], peer).await.unwrap();

            let mut reply = [0u8; 16];
            let (reply_len, _) = client.recv_from(&mut reply).await.unwrap();
            assert_eq!(&reply[..reply_len], b"hello");
        });
    }
}