Skip to main content

go_lib/
net.rs

1// SPDX-License-Identifier: Apache-2.0
2//! Goroutine-aware TCP networking.
3//!
4//! `TcpListener` and `TcpStream` wrap non-blocking OS sockets and integrate
5//! with the go-lib scheduler via the netpoll backend (`epoll` on Linux,
6//! `kqueue` on macOS).  When a socket operation would block (`EAGAIN` /
7//! `EWOULDBLOCK`), the goroutine is parked via `gopark` and re-enqueued by
8//! the netpoll machinery when the socket becomes ready.
9//!
10//! ## Usage
11//!
12//! ```no_run
13//! go_lib::run(|| {
14//!     let listener = go_lib::net::TcpListener::bind("127.0.0.1:8080").unwrap();
15//!     loop {
16//!         let mut stream = listener.accept().unwrap();
17//!         go_lib::go!(move || {
18//!             let mut buf = [0u8; 1024];
19//!             let n = stream.read(&mut buf).unwrap();
20//!             stream.write(&buf[..n]).unwrap();
21//!         });
22//!     }
23//! });
24//! ```
25//!
26//! ## Porting note
27//!
28//! Go's `net` package calls `runtime.poll.pollDesc.waitRead` / `waitWrite`
29//! which translate directly to `netpoll_arm(fd, POLL_READ/WRITE, gp)` +
30//! `gopark`.  The same protocol is used here.
31
32use std::io;
33use std::net::{SocketAddr, ToSocketAddrs};
34use std::os::unix::io::RawFd;
35
36use libc;
37
38use crate::runtime::g::WaitReason;
39use crate::runtime::netpoll::{netpoll_arm, netpoll_unarm, POLL_READ, POLL_WRITE};
40use crate::runtime::park::gopark;
41
42// ---------------------------------------------------------------------------
43// Helpers — non-blocking socket creation and address conversion
44// ---------------------------------------------------------------------------
45
46/// Create a non-blocking `SOCK_STREAM` socket for the given address family.
47///
48/// On Linux, `SOCK_NONBLOCK` is passed directly to `socket(2)`.
49/// On macOS (which lacks `SOCK_NONBLOCK`), `O_NONBLOCK` is set via `fcntl`.
50fn nonblocking_tcp_socket(family: libc::c_int) -> io::Result<RawFd> {
51    #[cfg(target_os = "linux")]
52    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM | libc::SOCK_NONBLOCK, 0) };
53
54    #[cfg(not(target_os = "linux"))]
55    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
56
57    if fd < 0 {
58        return Err(io::Error::last_os_error());
59    }
60
61    // On platforms where SOCK_NONBLOCK is not available, set O_NONBLOCK via fcntl.
62    #[cfg(not(target_os = "linux"))]
63    {
64        let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
65        if flags < 0
66            || unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) } < 0
67        {
68            unsafe { libc::close(fd) };
69            return Err(io::Error::last_os_error());
70        }
71    }
72
73    Ok(fd)
74}
75
76fn set_reuseaddr(fd: RawFd) -> io::Result<()> {
77    let one: libc::c_int = 1;
78    let ret = unsafe {
79        libc::setsockopt(
80            fd,
81            libc::SOL_SOCKET,
82            libc::SO_REUSEADDR,
83            &one as *const _ as *const libc::c_void,
84            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
85        )
86    };
87    if ret < 0 {
88        Err(io::Error::last_os_error())
89    } else {
90        Ok(())
91    }
92}
93
94/// Convert a `SocketAddr` to a `libc::sockaddr_storage` + length.
95fn to_sockaddr(addr: SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
96    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
97    match addr {
98        SocketAddr::V4(v4) => {
99            let sa: &mut libc::sockaddr_in =
100                unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in) };
101            sa.sin_family = libc::AF_INET as libc::sa_family_t;
102            sa.sin_port   = v4.port().to_be();
103            sa.sin_addr.s_addr = u32::from_ne_bytes(v4.ip().octets());
104            (storage, std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t)
105        }
106        SocketAddr::V6(v6) => {
107            let sa: &mut libc::sockaddr_in6 =
108                unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in6) };
109            sa.sin6_family   = libc::AF_INET6 as libc::sa_family_t;
110            sa.sin6_port     = v6.port().to_be();
111            sa.sin6_addr.s6_addr = v6.ip().octets();
112            (storage, std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t)
113        }
114    }
115}
116
117fn addr_family(addr: SocketAddr) -> libc::c_int {
118    match addr {
119        SocketAddr::V4(_) => libc::AF_INET,
120        SocketAddr::V6(_) => libc::AF_INET6,
121    }
122}
123
124/// Park the calling goroutine until `fd` is ready for `mode`
125/// (`POLL_READ` or `POLL_WRITE`).
126///
127/// # Safety
128/// Must be called from a live goroutine context.
129unsafe fn park_on_fd(fd: RawFd, mode: u32) {
130    let gp = crate::runtime::g::current_g();
131    debug_assert!(!gp.is_null(), "park_on_fd: not running on a goroutine");
132    unsafe {
133        netpoll_arm(fd, mode, gp);
134        gopark(WaitReason::IOWait);
135        // gopark suspends this goroutine; execution resumes after goready()
136        // is called by the netpoll machinery.
137    }
138}
139
140// ---------------------------------------------------------------------------
141// TcpListener
142// ---------------------------------------------------------------------------
143
/// A goroutine-aware TCP server socket.
///
/// Calls to [`accept`][TcpListener::accept] park the current goroutine when no
/// connection is immediately available and resume it when one arrives.
///
/// The listener owns its file descriptor; the `Debug` output shows the raw
/// descriptor number.
#[derive(Debug)]
pub struct TcpListener {
    /// Raw descriptor of the listening socket.
    fd: RawFd,
}
151
152impl TcpListener {
153    /// Bind a non-blocking TCP listener to `addr`.
154    ///
155    /// Equivalent to `net.Listen("tcp", addr)` in Go.
156    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
157        let addr = addr
158            .to_socket_addrs()?
159            .next()
160            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no address given"))?;
161
162        let fd = nonblocking_tcp_socket(addr_family(addr))?;
163        set_reuseaddr(fd)?;
164
165        let (sa, sa_len) = to_sockaddr(addr);
166        let ret = unsafe {
167            libc::bind(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
168        };
169        if ret < 0 {
170            unsafe { libc::close(fd) };
171            return Err(io::Error::last_os_error());
172        }
173
174        let ret = unsafe { libc::listen(fd, 128) };
175        if ret < 0 {
176            unsafe { libc::close(fd) };
177            return Err(io::Error::last_os_error());
178        }
179
180        Ok(TcpListener { fd })
181    }
182
183    /// Accept the next incoming connection.
184    ///
185    /// Parks the goroutine if no connection is immediately available, resuming
186    /// it when the OS delivers one.
187    pub fn accept(&self) -> io::Result<TcpStream> {
188        loop {
189            let cfd = unsafe {
190                libc::accept(self.fd, std::ptr::null_mut(), std::ptr::null_mut())
191            };
192            if cfd >= 0 {
193                // Set O_NONBLOCK on the accepted socket.
194                let flags = unsafe { libc::fcntl(cfd, libc::F_GETFL) };
195                if flags >= 0 {
196                    unsafe { libc::fcntl(cfd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
197                }
198                return Ok(TcpStream { fd: cfd });
199            }
200            let err = io::Error::last_os_error();
201            match err.raw_os_error().unwrap_or(0) {
202                libc::EAGAIN => {
203                    // No connection yet — park until the listener fd is readable.
204                    unsafe { park_on_fd(self.fd, POLL_READ) };
205                    // After wakeup, retry accept().
206                }
207                _ => return Err(err),
208            }
209        }
210    }
211
212    /// Return the underlying raw file descriptor.
213    pub fn as_raw_fd(&self) -> RawFd {
214        self.fd
215    }
216}
217
218impl Drop for TcpListener {
219    fn drop(&mut self) {
220        netpoll_unarm(self.fd);
221        unsafe { libc::close(self.fd) };
222    }
223}
224
225// ---------------------------------------------------------------------------
226// TcpStream
227// ---------------------------------------------------------------------------
228
/// A goroutine-aware TCP stream socket.
///
/// [`read`][TcpStream::read] and [`write`][TcpStream::write] park the goroutine
/// when the operation would block and resume it when data is available or the
/// send buffer has space.
///
/// The stream owns its file descriptor; the `Debug` output shows the raw
/// descriptor number.
#[derive(Debug)]
pub struct TcpStream {
    /// Raw descriptor of the connected (or connecting) socket.
    fd: RawFd,
}
237
238impl TcpStream {
239    /// Connect to `addr`.
240    ///
241    /// Parks the goroutine until the connection completes if it does not
242    /// complete immediately (which is typical for non-blocking `connect`).
243    ///
244    /// Equivalent to `net.Dial("tcp", addr)` in Go.
245    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
246        let addr = addr
247            .to_socket_addrs()?
248            .next()
249            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no address given"))?;
250
251        let fd = nonblocking_tcp_socket(addr_family(addr))?;
252        let (sa, sa_len) = to_sockaddr(addr);
253
254        let ret = unsafe {
255            libc::connect(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
256        };
257
258        if ret < 0 {
259            let err = io::Error::last_os_error();
260            match err.raw_os_error().unwrap_or(0) {
261                libc::EINPROGRESS | libc::EAGAIN => {
262                    // Connection in progress — park until the socket is writable.
263                    unsafe { park_on_fd(fd, POLL_WRITE) };
264                    // Check for connect error via SO_ERROR.
265                    let mut so_err: libc::c_int = 0;
266                    let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
267                    unsafe {
268                        libc::getsockopt(
269                            fd,
270                            libc::SOL_SOCKET,
271                            libc::SO_ERROR,
272                            &mut so_err as *mut _ as *mut libc::c_void,
273                            &mut len,
274                        )
275                    };
276                    if so_err != 0 {
277                        unsafe { libc::close(fd) };
278                        return Err(io::Error::from_raw_os_error(so_err));
279                    }
280                }
281                _ => {
282                    unsafe { libc::close(fd) };
283                    return Err(err);
284                }
285            }
286        }
287
288        Ok(TcpStream { fd })
289    }
290
291    /// Read bytes from the stream into `buf`.
292    ///
293    /// Parks the goroutine if no data is immediately available.
294    pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
295        loop {
296            let n = unsafe {
297                libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
298            };
299            if n >= 0 {
300                return Ok(n as usize);
301            }
302            let err = io::Error::last_os_error();
303            match err.raw_os_error().unwrap_or(0) {
304                libc::EAGAIN => {
305                    unsafe { park_on_fd(self.fd, POLL_READ) };
306                }
307                _ => return Err(err),
308            }
309        }
310    }
311
312    /// Write `buf` to the stream.
313    ///
314    /// Parks the goroutine if the send buffer is full.  Returns the number of
315    /// bytes written (may be less than `buf.len()`).
316    pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
317        loop {
318            let n = unsafe {
319                libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len())
320            };
321            if n >= 0 {
322                return Ok(n as usize);
323            }
324            let err = io::Error::last_os_error();
325            match err.raw_os_error().unwrap_or(0) {
326                libc::EAGAIN => {
327                    unsafe { park_on_fd(self.fd, POLL_WRITE) };
328                }
329                _ => return Err(err),
330            }
331        }
332    }
333
334    /// Return the underlying raw file descriptor.
335    pub fn as_raw_fd(&self) -> RawFd {
336        self.fd
337    }
338}
339
340impl Drop for TcpStream {
341    fn drop(&mut self) {
342        netpoll_unarm(self.fd);
343        unsafe { libc::close(self.fd) };
344    }
345}