1use std::future::Future;
7use std::io;
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use crate::platform::sys::{set_nonblocking, Interest};
13use crate::reactor::source::{next_token, IoSource};
14
15use super::sockaddr::{reclaim_raw_sockaddr, sockaddr_to_socketaddr, socketaddr_to_raw};
16
17pub struct UdpSocket {
21 source: IoSource,
22}
23
24impl UdpSocket {
25 pub fn bind(addr: SocketAddr) -> io::Result<Self> {
30 let fd = create_udp_socket(addr)?;
31 bind_socket(fd, addr)?;
32 set_nonblocking(fd)?;
33 let source = IoSource::new(fd, next_token(), Interest::READABLE | Interest::WRITABLE)?;
34 Ok(Self { source })
35 }
36
37 pub fn local_addr(&self) -> io::Result<SocketAddr> {
39 raw_local_addr(self.source.raw())
40 }
41
42 pub fn send_to<'a>(&'a self, buf: &'a [u8], target: SocketAddr) -> SendToFuture<'a> {
45 SendToFuture {
46 socket: self,
47 buf,
48 target,
49 }
50 }
51
52 pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> RecvFromFuture<'a> {
55 RecvFromFuture { socket: self, buf }
56 }
57}
58
59impl Drop for UdpSocket {
60 fn drop(&mut self) {
61 let fd = self.source.raw();
62 unsafe { libc::close(fd) };
65 }
66}
67
68pub struct SendToFuture<'a> {
72 socket: &'a UdpSocket,
73 buf: &'a [u8],
74 target: SocketAddr,
75}
76
77impl<'a> Future for SendToFuture<'a> {
78 type Output = io::Result<usize>;
79
80 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
81 match try_send_to(self.socket.source.raw(), self.buf, self.target) {
82 Ok(n) => Poll::Ready(Ok(n)),
83 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
84 match Pin::new(&mut self.socket.source.writable()).poll(cx) {
86 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
87 Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
88 }
89 }
90 Err(e) => Poll::Ready(Err(e)),
91 }
92 }
93}
94
95pub struct RecvFromFuture<'a> {
99 socket: &'a UdpSocket,
100 buf: &'a mut [u8],
101}
102
103impl<'a> Future for RecvFromFuture<'a> {
104 type Output = io::Result<(usize, SocketAddr)>;
105
106 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
107 let fd = self.socket.source.raw();
108 match try_recv_from(fd, self.buf) {
109 Ok(result) => Poll::Ready(Ok(result)),
110 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
111 match Pin::new(&mut self.socket.source.readable()).poll(cx) {
113 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
114 Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
115 }
116 }
117 Err(e) => Poll::Ready(Err(e)),
118 }
119 }
120}
121
122fn create_udp_socket(addr: SocketAddr) -> io::Result<i32> {
126 let family = match addr {
127 SocketAddr::V4(_) => libc::AF_INET,
128 SocketAddr::V6(_) => libc::AF_INET6,
129 };
130 let fd = unsafe { libc::socket(family, libc::SOCK_DGRAM, 0) };
132 if fd == -1 {
133 return Err(io::Error::last_os_error());
134 }
135 Ok(fd)
136}
137
138fn bind_socket(fd: i32, addr: SocketAddr) -> io::Result<()> {
140 let (sa_ptr, sa_len) = socketaddr_to_raw(addr);
141 let rc = unsafe { libc::bind(fd, sa_ptr, sa_len) };
143 unsafe { reclaim_raw_sockaddr(sa_ptr, addr) };
145 if rc == -1 {
146 return Err(io::Error::last_os_error());
147 }
148 Ok(())
149}
150
151fn try_send_to(fd: i32, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
153 let (sa_ptr, sa_len) = socketaddr_to_raw(target);
154 let n = unsafe {
157 libc::sendto(
158 fd,
159 buf.as_ptr() as *const libc::c_void,
160 buf.len(),
161 0, sa_ptr,
163 sa_len,
164 )
165 };
166 unsafe { reclaim_raw_sockaddr(sa_ptr, target) };
168 if n == -1 {
169 return Err(io::Error::last_os_error());
170 }
171 Ok(n as usize)
172}
173
174fn try_recv_from(fd: i32, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
176 let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
177 let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
178 let n = unsafe {
181 libc::recvfrom(
182 fd,
183 buf.as_mut_ptr() as *mut libc::c_void,
184 buf.len(),
185 0, &mut addr as *mut _ as *mut libc::sockaddr,
187 &mut len,
188 )
189 };
190 if n == -1 {
191 return Err(io::Error::last_os_error());
192 }
193 let sender = sockaddr_to_socketaddr(&addr, len)?;
194 Ok((n as usize, sender))
195}
196
197fn raw_local_addr(fd: i32) -> io::Result<SocketAddr> {
199 let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
200 let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
201 let rc = unsafe { libc::getsockname(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
203 if rc == -1 {
204 return Err(io::Error::last_os_error());
205 }
206 sockaddr_to_socketaddr(&addr, len)
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::block_on_with_spawn;

    /// Binds a new socket to an OS-assigned port on the IPv4 loopback address.
    fn bind_loopback() -> UdpSocket {
        UdpSocket::bind("127.0.0.1:0".parse().unwrap()).expect("bind to 127.0.0.1:0 failed")
    }

    #[test]
    fn bind_and_local_addr() {
        let sock = bind_loopback();
        let addr = sock.local_addr().expect("local_addr failed");
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
        assert!(addr.port() > 0);
    }

    #[test]
    fn send_to_and_recv_from() {
        block_on_with_spawn(async {
            let receiver = bind_loopback();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = bind_loopback();
            let sender_addr = sender.local_addr().unwrap();

            let msg = b"ping";
            let n = sender.send_to(msg, recv_addr).await.unwrap();
            assert_eq!(n, msg.len());

            let mut buf = [0u8; 16];
            let (n, from) = receiver.recv_from(&mut buf).await.unwrap();
            assert_eq!(n, msg.len());
            assert_eq!(&buf[..n], msg);
            // The reported sender must be the exact (ip, port) the sender is bound to.
            assert_eq!(from, sender_addr);
        });
    }

    #[test]
    fn udp_echo_round_trip() {
        block_on_with_spawn(async {
            let server = bind_loopback();
            let server_addr = server.local_addr().unwrap();
            let client = bind_loopback();
            let client_addr = client.local_addr().unwrap();

            client.send_to(b"hello", server_addr).await.unwrap();

            let mut buf = [0u8; 16];
            let (n, from) = server.recv_from(&mut buf).await.unwrap();
            assert_eq!(from, client_addr);
            server.send_to(&buf[..n], from).await.unwrap();

            let mut reply = [0u8; 16];
            let (rn, reply_from) = client.recv_from(&mut reply).await.unwrap();
            assert_eq!(&reply[..rn], b"hello");
            assert_eq!(reply_from, server_addr);
        });
    }

    #[test]
    fn udp_bind_port_zero_gets_assigned() {
        // Port 0 asks the OS for an ephemeral port. The ephemeral range is
        // OS-configurable, so only assert that a real port was assigned and
        // that two live sockets on the same IP do not share it.
        let first = bind_loopback();
        let second = bind_loopback();
        let first_port = first.local_addr().unwrap().port();
        let second_port = second.local_addr().unwrap().port();
        assert_ne!(first_port, 0);
        assert_ne!(second_port, 0);
        assert_ne!(first_port, second_port);
    }

    #[test]
    fn udp_send_returns_correct_byte_count() {
        block_on_with_spawn(async {
            let receiver = bind_loopback();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = bind_loopback();
            let msg = b"test123";
            let n = sender.send_to(msg, recv_addr).await.unwrap();
            assert_eq!(n, msg.len());
        });
    }

    #[test]
    fn udp_recv_from_returns_sender_ip() {
        block_on_with_spawn(async {
            let receiver = bind_loopback();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = bind_loopback();
            let sender_addr = sender.local_addr().unwrap();
            sender.send_to(b"hi", recv_addr).await.unwrap();
            let mut buf = [0u8; 8];
            let (_, from) = receiver.recv_from(&mut buf).await.unwrap();
            assert_eq!(from.ip(), sender_addr.ip());
            assert_eq!(from.port(), sender_addr.port());
        });
    }

    #[test]
    fn udp_multiple_datagrams_sequential() {
        block_on_with_spawn(async {
            let receiver = bind_loopback();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = bind_loopback();

            for i in 0u8..5 {
                let msg = [i; 1];
                sender.send_to(&msg, recv_addr).await.unwrap();
                let mut buf = [0u8; 4];
                let (n, _) = receiver.recv_from(&mut buf).await.unwrap();
                assert_eq!(n, 1);
                assert_eq!(buf[0], i);
            }
        });
    }

    #[test]
    fn udp_large_datagram_fits_buf() {
        block_on_with_spawn(async {
            let receiver = bind_loopback();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = bind_loopback();
            let msg = [42u8; 1024];
            let n = sender.send_to(&msg, recv_addr).await.unwrap();
            assert_eq!(n, 1024);
            let mut buf = [0u8; 1024];
            let (rn, _) = receiver.recv_from(&mut buf).await.unwrap();
            assert_eq!(rn, 1024);
            assert!(buf.iter().all(|&b| b == 42));
        });
    }

    #[test]
    fn udp_two_sockets_cross_exchange() {
        block_on_with_spawn(async {
            let a = bind_loopback();
            let b = bind_loopback();
            let a_addr = a.local_addr().unwrap();
            let b_addr = b.local_addr().unwrap();

            a.send_to(b"from_a", b_addr).await.unwrap();
            let mut buf = [0u8; 8];
            let (n, from) = b.recv_from(&mut buf).await.unwrap();
            assert_eq!(&buf[..n], b"from_a");
            assert_eq!(from, a_addr);

            b.send_to(b"from_b", a_addr).await.unwrap();
            let mut buf2 = [0u8; 8];
            let (n2, from2) = a.recv_from(&mut buf2).await.unwrap();
            assert_eq!(&buf2[..n2], b"from_b");
            assert_eq!(from2, b_addr);
        });
    }
}
359}