Skip to main content

moduvex_runtime/net/
udp_socket.rs

//! Async `UdpSocket` — non-blocking UDP datagram socket.
//!
//! `send_to` / `recv_from` return futures that attempt the non-blocking
//! syscall and, when it would block, wait via the reactor's waker registry
//! until the socket is writable / readable before retrying.
5
6use std::future::Future;
7use std::io;
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use crate::platform::sys::{set_nonblocking, Interest};
13use crate::reactor::source::{next_token, IoSource};
14
15// ── UdpSocket ─────────────────────────────────────────────────────────────────
16
17/// Async UDP datagram socket.
18pub struct UdpSocket {
19    source: IoSource,
20}
21
22impl UdpSocket {
23    /// Bind a UDP socket to `addr`.
24    ///
25    /// Creates a `SOCK_DGRAM` socket, binds to `addr`, sets non-blocking, and
26    /// registers with the reactor for both read and write readiness.
27    pub fn bind(addr: SocketAddr) -> io::Result<Self> {
28        let fd = create_udp_socket(addr)?;
29        bind_socket(fd, addr)?;
30        set_nonblocking(fd)?;
31        let source = IoSource::new(fd, next_token(), Interest::READABLE | Interest::WRITABLE)?;
32        Ok(Self { source })
33    }
34
35    /// Return the local address the socket is bound to.
36    pub fn local_addr(&self) -> io::Result<SocketAddr> {
37        raw_local_addr(self.source.raw())
38    }
39
40    /// Return a future that sends `buf` to `target` and resolves to the number
41    /// of bytes sent.
42    pub fn send_to<'a>(&'a self, buf: &'a [u8], target: SocketAddr) -> SendToFuture<'a> {
43        SendToFuture {
44            socket: self,
45            buf,
46            target,
47        }
48    }
49
50    /// Return a future that receives a datagram into `buf` and resolves to
51    /// `(bytes_received, sender_addr)`.
52    pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> RecvFromFuture<'a> {
53        RecvFromFuture { socket: self, buf }
54    }
55}
56
57impl Drop for UdpSocket {
58    fn drop(&mut self) {
59        let fd = self.source.raw();
60        // IoSource Drop deregisters from the reactor first; then we close fd.
61        // SAFETY: we own `fd` exclusively; Drop runs at most once.
62        unsafe { libc::close(fd) };
63    }
64}
65
66// ── SendToFuture ──────────────────────────────────────────────────────────────
67
68/// Future returned by [`UdpSocket::send_to`].
69pub struct SendToFuture<'a> {
70    socket: &'a UdpSocket,
71    buf: &'a [u8],
72    target: SocketAddr,
73}
74
75impl<'a> Future for SendToFuture<'a> {
76    type Output = io::Result<usize>;
77
78    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
79        match try_send_to(self.socket.source.raw(), self.buf, self.target) {
80            Ok(n) => Poll::Ready(Ok(n)),
81            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
82                // Buffer full — wait for WRITABLE, then retry.
83                match Pin::new(&mut self.socket.source.writable()).poll(cx) {
84                    Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
85                    Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
86                }
87            }
88            Err(e) => Poll::Ready(Err(e)),
89        }
90    }
91}
92
93// ── RecvFromFuture ────────────────────────────────────────────────────────────
94
95/// Future returned by [`UdpSocket::recv_from`].
96pub struct RecvFromFuture<'a> {
97    socket: &'a UdpSocket,
98    buf: &'a mut [u8],
99}
100
101impl<'a> Future for RecvFromFuture<'a> {
102    type Output = io::Result<(usize, SocketAddr)>;
103
104    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
105        let fd = self.socket.source.raw();
106        match try_recv_from(fd, self.buf) {
107            Ok(result) => Poll::Ready(Ok(result)),
108            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
109                // No data yet — register waker and wait for READABLE.
110                match Pin::new(&mut self.socket.source.readable()).poll(cx) {
111                    Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
112                    Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
113                }
114            }
115            Err(e) => Poll::Ready(Err(e)),
116        }
117    }
118}
119
120// ── Unix helpers ──────────────────────────────────────────────────────────────
121
122/// Create a UDP socket appropriate for `addr`'s family.
123fn create_udp_socket(addr: SocketAddr) -> io::Result<i32> {
124    let family = match addr {
125        SocketAddr::V4(_) => libc::AF_INET,
126        SocketAddr::V6(_) => libc::AF_INET6,
127    };
128    // SAFETY: documented syscall with valid AF_INET/AF_INET6 + SOCK_DGRAM.
129    let fd = unsafe { libc::socket(family, libc::SOCK_DGRAM, 0) };
130    if fd == -1 {
131        return Err(io::Error::last_os_error());
132    }
133    Ok(fd)
134}
135
136/// Bind `fd` to `addr`.
137fn bind_socket(fd: i32, addr: SocketAddr) -> io::Result<()> {
138    let (sa_ptr, sa_len) = socketaddr_to_raw(addr);
139    // SAFETY: `fd` is a valid unbound socket; `sa_ptr`/`sa_len` are correct.
140    let rc = unsafe { libc::bind(fd, sa_ptr, sa_len) };
141    // SAFETY: reclaims the Box created by `socketaddr_to_raw`.
142    unsafe { reclaim_sockaddr(sa_ptr, addr) };
143    if rc == -1 {
144        return Err(io::Error::last_os_error());
145    }
146    Ok(())
147}
148
149/// Non-blocking `sendto`. Returns `Ok(n)` or `Err(WouldBlock)`.
150fn try_send_to(fd: i32, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
151    let (sa_ptr, sa_len) = socketaddr_to_raw(target);
152    // SAFETY: `fd` is a valid UDP socket; `buf` is a valid readable slice;
153    // `sa_ptr`/`sa_len` describe a valid sockaddr for `target`.
154    let n = unsafe {
155        libc::sendto(
156            fd,
157            buf.as_ptr() as *const libc::c_void,
158            buf.len(),
159            0, // flags
160            sa_ptr,
161            sa_len,
162        )
163    };
164    // SAFETY: reclaims the Box created by `socketaddr_to_raw`.
165    unsafe { reclaim_sockaddr(sa_ptr, target) };
166    if n == -1 {
167        return Err(io::Error::last_os_error());
168    }
169    Ok(n as usize)
170}
171
172/// Non-blocking `recvfrom`. Returns `Ok((n, sender))` or `Err(WouldBlock)`.
173fn try_recv_from(fd: i32, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
174    let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
175    let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
176    // SAFETY: `fd` is a valid UDP socket; `buf` is a valid writable slice;
177    // `addr` is zeroed and large enough for both address families.
178    let n = unsafe {
179        libc::recvfrom(
180            fd,
181            buf.as_mut_ptr() as *mut libc::c_void,
182            buf.len(),
183            0, // flags
184            &mut addr as *mut _ as *mut libc::sockaddr,
185            &mut len,
186        )
187    };
188    if n == -1 {
189        return Err(io::Error::last_os_error());
190    }
191    let sender = sockaddr_to_socketaddr(&addr, len)?;
192    Ok((n as usize, sender))
193}
194
195/// Query the local address of `fd` via `getsockname`.
196fn raw_local_addr(fd: i32) -> io::Result<SocketAddr> {
197    let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
198    let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
199    // SAFETY: `fd` is a valid bound socket; `addr` buffer is large enough.
200    let rc = unsafe { libc::getsockname(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
201    if rc == -1 {
202        return Err(io::Error::last_os_error());
203    }
204    sockaddr_to_socketaddr(&addr, len)
205}
206
207/// Convert `SocketAddr` to a heap-allocated raw sockaddr pair.
208/// Caller must call `reclaim_sockaddr` with the same `addr` after use.
209fn socketaddr_to_raw(addr: SocketAddr) -> (*const libc::sockaddr, libc::socklen_t) {
210    match addr {
211        SocketAddr::V4(v4) => {
212            let octets = v4.ip().octets();
213            // SAFETY: zeroed() is a valid initial bit pattern; all fields set below.
214            let mut sin: libc::sockaddr_in = unsafe { std::mem::zeroed() };
215            sin.sin_family = libc::AF_INET as libc::sa_family_t;
216            sin.sin_port = v4.port().to_be();
217            sin.sin_addr = libc::in_addr {
218                s_addr: u32::from_be_bytes(octets).to_be(),
219            };
220            let ptr = Box::into_raw(Box::new(sin)) as *const libc::sockaddr;
221            (
222                ptr,
223                std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
224            )
225        }
226        SocketAddr::V6(v6) => {
227            // SAFETY: zeroed() is a valid initial bit pattern; all fields set below.
228            let mut sin6: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
229            sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
230            sin6.sin6_port = v6.port().to_be();
231            sin6.sin6_flowinfo = v6.flowinfo();
232            sin6.sin6_addr = libc::in6_addr {
233                s6_addr: v6.ip().octets(),
234            };
235            sin6.sin6_scope_id = v6.scope_id();
236            let ptr = Box::into_raw(Box::new(sin6)) as *const libc::sockaddr;
237            (
238                ptr,
239                std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
240            )
241        }
242    }
243}
244
245/// # Safety
246/// `ptr` must have been produced by `socketaddr_to_raw` with the same `addr`.
247unsafe fn reclaim_sockaddr(ptr: *const libc::sockaddr, addr: SocketAddr) {
248    match addr {
249        SocketAddr::V4(_) => drop(Box::from_raw(ptr as *mut libc::sockaddr_in)),
250        SocketAddr::V6(_) => drop(Box::from_raw(ptr as *mut libc::sockaddr_in6)),
251    }
252}
253
254/// Convert a kernel-filled `sockaddr_in6` buffer to `SocketAddr`.
255/// The buffer may actually contain a `sockaddr_in` — family field disambiguates.
256fn sockaddr_to_socketaddr(
257    addr: &libc::sockaddr_in6,
258    len: libc::socklen_t,
259) -> io::Result<SocketAddr> {
260    let family = addr.sin6_family as libc::c_int;
261    match family {
262        libc::AF_INET if len >= std::mem::size_of::<libc::sockaddr_in>() as u32 => {
263            // SAFETY: kernel wrote AF_INET data of the correct size; reinterpreting
264            // the buffer as sockaddr_in is valid because the layouts are compatible.
265            let v4: &libc::sockaddr_in =
266                unsafe { &*(addr as *const _ as *const libc::sockaddr_in) };
267            let ip = std::net::Ipv4Addr::from(u32::from_be(v4.sin_addr.s_addr));
268            let port = u16::from_be(v4.sin_port);
269            Ok(SocketAddr::V4(std::net::SocketAddrV4::new(ip, port)))
270        }
271        libc::AF_INET6 if len >= std::mem::size_of::<libc::sockaddr_in6>() as u32 => {
272            let ip = std::net::Ipv6Addr::from(addr.sin6_addr.s6_addr);
273            let port = u16::from_be(addr.sin6_port);
274            Ok(SocketAddr::V6(std::net::SocketAddrV6::new(
275                ip,
276                port,
277                addr.sin6_flowinfo,
278                addr.sin6_scope_id,
279            )))
280        }
281        _ => Err(io::Error::new(
282            io::ErrorKind::InvalidData,
283            format!("unsupported address family: {family}"),
284        )),
285    }
286}
287
288// ── Tests ─────────────────────────────────────────────────────────────────────
289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::block_on_with_spawn;

    /// Bind a socket to an ephemeral port on the IPv4 loopback interface.
    fn bind_loopback() -> UdpSocket {
        UdpSocket::bind("127.0.0.1:0".parse().unwrap()).expect("bind failed")
    }

    #[test]
    fn bind_and_local_addr() {
        let sock = bind_loopback();
        let addr = sock.local_addr().expect("local_addr failed");
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
        assert!(addr.port() > 0);
    }

    #[test]
    fn bind_to_address_in_use_fails() {
        // Exercises the error path of `bind` (socket created, `bind(2)` fails).
        let first = bind_loopback();
        let taken = first.local_addr().unwrap();
        let err = UdpSocket::bind(taken)
            .err()
            .expect("second bind to the same address should fail");
        assert_eq!(err.kind(), io::ErrorKind::AddrInUse);
    }

    #[test]
    fn send_to_and_recv_from() {
        block_on_with_spawn(async {
            let receiver = bind_loopback();
            let recv_addr = receiver.local_addr().unwrap();

            let sender = bind_loopback();

            // Send a datagram.
            let msg = b"ping";
            let n = sender.send_to(msg, recv_addr).await.unwrap();
            assert_eq!(n, msg.len());

            // Receive it.
            let mut buf = [0u8; 16];
            let (n, from) = receiver.recv_from(&mut buf).await.unwrap();
            assert_eq!(n, msg.len());
            assert_eq!(&buf[..n], msg);
            // `from` should be exactly the sender's address (IP and port).
            assert_eq!(from, sender.local_addr().unwrap());
        });
    }

    #[test]
    fn udp_echo_round_trip() {
        block_on_with_spawn(async {
            let server = bind_loopback();
            let server_addr = server.local_addr().unwrap();
            let client = bind_loopback();

            // Client sends, server echoes back.
            client.send_to(b"hello", server_addr).await.unwrap();

            let mut buf = [0u8; 16];
            let (n, from) = server.recv_from(&mut buf).await.unwrap();
            assert_eq!(from, client.local_addr().unwrap());
            server.send_to(&buf[..n], from).await.unwrap();

            let mut reply = [0u8; 16];
            let (rn, reply_from) = client.recv_from(&mut reply).await.unwrap();
            assert_eq!(&reply[..rn], b"hello");
            assert_eq!(reply_from, server_addr);
        });
    }
}