Skip to main content

moduvex_runtime/net/
udp_socket.rs

1//! Async `UdpSocket` — non-blocking UDP datagram socket.
2//!
3//! `send_to` / `recv_from` return futures that resolve when the OS is ready
4//! to send or has data available, using the reactor's waker registry.
5
6use std::future::Future;
7use std::io;
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use crate::platform::sys::{set_nonblocking, Interest};
13use crate::reactor::source::{next_token, IoSource};
14
15use super::sockaddr::{reclaim_raw_sockaddr, sockaddr_to_socketaddr, socketaddr_to_raw};
16
17// ── UdpSocket ─────────────────────────────────────────────────────────────────
18
19/// Async UDP datagram socket.
20pub struct UdpSocket {
21    source: IoSource,
22}
23
24impl UdpSocket {
25    /// Bind a UDP socket to `addr`.
26    ///
27    /// Creates a `SOCK_DGRAM` socket, binds to `addr`, sets non-blocking, and
28    /// registers with the reactor for both read and write readiness.
29    pub fn bind(addr: SocketAddr) -> io::Result<Self> {
30        let fd = create_udp_socket(addr)?;
31        bind_socket(fd, addr)?;
32        set_nonblocking(fd)?;
33        let source = IoSource::new(fd, next_token(), Interest::READABLE | Interest::WRITABLE)?;
34        Ok(Self { source })
35    }
36
37    /// Return the local address the socket is bound to.
38    pub fn local_addr(&self) -> io::Result<SocketAddr> {
39        raw_local_addr(self.source.raw())
40    }
41
42    /// Return a future that sends `buf` to `target` and resolves to the number
43    /// of bytes sent.
44    pub fn send_to<'a>(&'a self, buf: &'a [u8], target: SocketAddr) -> SendToFuture<'a> {
45        SendToFuture {
46            socket: self,
47            buf,
48            target,
49        }
50    }
51
52    /// Return a future that receives a datagram into `buf` and resolves to
53    /// `(bytes_received, sender_addr)`.
54    pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> RecvFromFuture<'a> {
55        RecvFromFuture { socket: self, buf }
56    }
57}
58
59impl Drop for UdpSocket {
60    fn drop(&mut self) {
61        let fd = self.source.raw();
62        // IoSource Drop deregisters from the reactor first; then we close fd.
63        // SAFETY: we own `fd` exclusively; Drop runs at most once.
64        unsafe { libc::close(fd) };
65    }
66}
67
68// ── SendToFuture ──────────────────────────────────────────────────────────────
69
70/// Future returned by [`UdpSocket::send_to`].
71pub struct SendToFuture<'a> {
72    socket: &'a UdpSocket,
73    buf: &'a [u8],
74    target: SocketAddr,
75}
76
77impl<'a> Future for SendToFuture<'a> {
78    type Output = io::Result<usize>;
79
80    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
81        match try_send_to(self.socket.source.raw(), self.buf, self.target) {
82            Ok(n) => Poll::Ready(Ok(n)),
83            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
84                // Buffer full — wait for WRITABLE, then retry.
85                match Pin::new(&mut self.socket.source.writable()).poll(cx) {
86                    Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
87                    Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
88                }
89            }
90            Err(e) => Poll::Ready(Err(e)),
91        }
92    }
93}
94
95// ── RecvFromFuture ────────────────────────────────────────────────────────────
96
97/// Future returned by [`UdpSocket::recv_from`].
98pub struct RecvFromFuture<'a> {
99    socket: &'a UdpSocket,
100    buf: &'a mut [u8],
101}
102
103impl<'a> Future for RecvFromFuture<'a> {
104    type Output = io::Result<(usize, SocketAddr)>;
105
106    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
107        let fd = self.socket.source.raw();
108        match try_recv_from(fd, self.buf) {
109            Ok(result) => Poll::Ready(Ok(result)),
110            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
111                // No data yet — register waker and wait for READABLE.
112                match Pin::new(&mut self.socket.source.readable()).poll(cx) {
113                    Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
114                    Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
115                }
116            }
117            Err(e) => Poll::Ready(Err(e)),
118        }
119    }
120}
121
122// ── Unix helpers ──────────────────────────────────────────────────────────────
123
124/// Create a UDP socket appropriate for `addr`'s family.
125fn create_udp_socket(addr: SocketAddr) -> io::Result<i32> {
126    let family = match addr {
127        SocketAddr::V4(_) => libc::AF_INET,
128        SocketAddr::V6(_) => libc::AF_INET6,
129    };
130    // SAFETY: documented syscall with valid AF_INET/AF_INET6 + SOCK_DGRAM.
131    let fd = unsafe { libc::socket(family, libc::SOCK_DGRAM, 0) };
132    if fd == -1 {
133        return Err(io::Error::last_os_error());
134    }
135    Ok(fd)
136}
137
138/// Bind `fd` to `addr`.
139fn bind_socket(fd: i32, addr: SocketAddr) -> io::Result<()> {
140    let (sa_ptr, sa_len) = socketaddr_to_raw(addr);
141    // SAFETY: `fd` is a valid unbound socket; `sa_ptr`/`sa_len` are correct.
142    let rc = unsafe { libc::bind(fd, sa_ptr, sa_len) };
143    // SAFETY: reclaims the Box created by `socketaddr_to_raw`.
144    unsafe { reclaim_raw_sockaddr(sa_ptr, addr) };
145    if rc == -1 {
146        return Err(io::Error::last_os_error());
147    }
148    Ok(())
149}
150
151/// Non-blocking `sendto`. Returns `Ok(n)` or `Err(WouldBlock)`.
152fn try_send_to(fd: i32, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
153    let (sa_ptr, sa_len) = socketaddr_to_raw(target);
154    // SAFETY: `fd` is a valid UDP socket; `buf` is a valid readable slice;
155    // `sa_ptr`/`sa_len` describe a valid sockaddr for `target`.
156    let n = unsafe {
157        libc::sendto(
158            fd,
159            buf.as_ptr() as *const libc::c_void,
160            buf.len(),
161            0, // flags
162            sa_ptr,
163            sa_len,
164        )
165    };
166    // SAFETY: reclaims the Box created by `socketaddr_to_raw`.
167    unsafe { reclaim_raw_sockaddr(sa_ptr, target) };
168    if n == -1 {
169        return Err(io::Error::last_os_error());
170    }
171    Ok(n as usize)
172}
173
174/// Non-blocking `recvfrom`. Returns `Ok((n, sender))` or `Err(WouldBlock)`.
175fn try_recv_from(fd: i32, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
176    let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
177    let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
178    // SAFETY: `fd` is a valid UDP socket; `buf` is a valid writable slice;
179    // `addr` is zeroed and large enough for both address families.
180    let n = unsafe {
181        libc::recvfrom(
182            fd,
183            buf.as_mut_ptr() as *mut libc::c_void,
184            buf.len(),
185            0, // flags
186            &mut addr as *mut _ as *mut libc::sockaddr,
187            &mut len,
188        )
189    };
190    if n == -1 {
191        return Err(io::Error::last_os_error());
192    }
193    let sender = sockaddr_to_socketaddr(&addr, len)?;
194    Ok((n as usize, sender))
195}
196
197/// Query the local address of `fd` via `getsockname`.
198fn raw_local_addr(fd: i32) -> io::Result<SocketAddr> {
199    let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
200    let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
201    // SAFETY: `fd` is a valid bound socket; `addr` buffer is large enough.
202    let rc = unsafe { libc::getsockname(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
203    if rc == -1 {
204        return Err(io::Error::last_os_error());
205    }
206    sockaddr_to_socketaddr(&addr, len)
207}
208
209// ── Tests ─────────────────────────────────────────────────────────────────────
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::block_on_with_spawn;

    #[test]
    fn bind_and_local_addr() {
        let sock = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).expect("bind failed");
        let addr = sock.local_addr().expect("local_addr failed");
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
        assert!(addr.port() > 0);
    }

    #[test]
    fn send_to_and_recv_from() {
        block_on_with_spawn(async {
            let receiver = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let recv_addr = receiver.local_addr().unwrap();

            let sender = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();

            // Send a datagram.
            let msg = b"ping";
            let n = sender.send_to(msg, recv_addr).await.unwrap();
            assert_eq!(n, msg.len());

            // Receive it.
            let mut buf = [0u8; 16];
            let (n, from) = receiver.recv_from(&mut buf).await.unwrap();
            assert_eq!(n, msg.len());
            assert_eq!(&buf[..n], msg);
            // `from` should be exactly the sender's bound address (IP and port).
            assert_eq!(from, sender.local_addr().unwrap());
        });
    }

    #[test]
    fn udp_echo_round_trip() {
        block_on_with_spawn(async {
            let server = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let server_addr = server.local_addr().unwrap();
            let client = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();

            // Client sends, server echoes back.
            client.send_to(b"hello", server_addr).await.unwrap();

            let mut buf = [0u8; 16];
            let (n, from) = server.recv_from(&mut buf).await.unwrap();
            server.send_to(&buf[..n], from).await.unwrap();

            let mut reply = [0u8; 16];
            let (rn, _) = client.recv_from(&mut reply).await.unwrap();
            assert_eq!(&reply[..rn], b"hello");
        });
    }

    // ── Additional UDP socket tests ────────────────────────────────────────

    #[test]
    fn udp_bind_port_zero_gets_assigned() {
        let sock = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
        let addr = sock.local_addr().unwrap();
        // The ephemeral port range is OS configuration, so only assert that a
        // concrete (non-zero) port was assigned.
        assert_ne!(addr.port(), 0);
    }

    #[test]
    fn udp_send_returns_correct_byte_count() {
        block_on_with_spawn(async {
            let receiver = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let msg = b"test123";
            let n = sender.send_to(msg, recv_addr).await.unwrap();
            assert_eq!(n, msg.len());
        });
    }

    #[test]
    fn udp_recv_from_returns_sender_ip() {
        block_on_with_spawn(async {
            let receiver = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let sender_addr = sender.local_addr().unwrap();
            sender.send_to(b"hi", recv_addr).await.unwrap();
            let mut buf = [0u8; 8];
            let (_, from) = receiver.recv_from(&mut buf).await.unwrap();
            assert_eq!(from.ip(), sender_addr.ip());
            assert_eq!(from.port(), sender_addr.port());
        });
    }

    #[test]
    fn udp_multiple_datagrams_sequential() {
        block_on_with_spawn(async {
            let receiver = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();

            for i in 0u8..5 {
                let msg = [i; 1];
                sender.send_to(&msg, recv_addr).await.unwrap();
                let mut buf = [0u8; 4];
                let (n, _) = receiver.recv_from(&mut buf).await.unwrap();
                assert_eq!(n, 1);
                assert_eq!(buf[0], i);
            }
        });
    }

    #[test]
    fn udp_large_datagram_fits_buf() {
        block_on_with_spawn(async {
            let receiver = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let msg = [42u8; 1024];
            let n = sender.send_to(&msg, recv_addr).await.unwrap();
            assert_eq!(n, 1024);
            let mut buf = [0u8; 1024];
            let (rn, _) = receiver.recv_from(&mut buf).await.unwrap();
            assert_eq!(rn, 1024);
            assert!(buf.iter().all(|&b| b == 42));
        });
    }

    #[test]
    fn udp_two_sockets_cross_exchange() {
        block_on_with_spawn(async {
            let a = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let b = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let a_addr = a.local_addr().unwrap();
            let b_addr = b.local_addr().unwrap();

            // A sends to B
            a.send_to(b"from_a", b_addr).await.unwrap();
            let mut buf = [0u8; 8];
            let (n, from) = b.recv_from(&mut buf).await.unwrap();
            assert_eq!(&buf[..n], b"from_a");
            assert_eq!(from, a_addr);

            // B sends to A
            b.send_to(b"from_b", a_addr).await.unwrap();
            let mut buf2 = [0u8; 8];
            let (n2, from2) = a.recv_from(&mut buf2).await.unwrap();
            assert_eq!(&buf2[..n2], b"from_b");
            assert_eq!(from2, b_addr);
        });
    }

    #[test]
    fn udp_bind_address_in_use_fails() {
        let first = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
        let addr = first.local_addr().unwrap();
        // Without SO_REUSEADDR/SO_REUSEPORT a second bind to the same address
        // must fail; this exercises the error path of `UdpSocket::bind`.
        match UdpSocket::bind(addr) {
            Ok(_) => panic!("second bind to {} unexpectedly succeeded", addr),
            Err(e) => assert_eq!(e.kind(), io::ErrorKind::AddrInUse),
        }
    }

    #[test]
    fn udp_drop_releases_port() {
        let addr = {
            let sock = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            sock.local_addr().unwrap()
        };
        // The first socket has been dropped, so its fd is closed and the same
        // address can be bound again.
        let again = UdpSocket::bind(addr).expect("rebind after drop failed");
        assert_eq!(again.local_addr().unwrap(), addr);
    }
}