Skip to main content

go_lib/
net.rs

1// SPDX-License-Identifier: Apache-2.0
2//! Goroutine-aware TCP networking.
3//!
4//! `TcpListener` and `TcpStream` wrap non-blocking OS sockets and integrate
5//! with the go-lib scheduler via the netpoll backend (`epoll` on Linux,
6//! `kqueue` on macOS).  When a socket operation would block (`EAGAIN` /
7//! `EWOULDBLOCK`), the goroutine is parked via `gopark` and re-enqueued by
8//! the netpoll machinery when the socket becomes ready.
9//!
10//! ## Usage
11//!
12//! ```no_run
13//! use std::io::{Read, Write};
14//!
15//! #[go_lib::main]
16//! fn main() {
17//!     let listener = go_lib::net::TcpListener::bind("127.0.0.1:8080").unwrap();
18//!     loop {
19//!         let mut stream = listener.accept().unwrap();
20//!         go_lib::go!(move || {
21//!             let mut buf = [0u8; 1024];
22//!             let n = stream.read(&mut buf).unwrap();
23//!             stream.write_all(&buf[..n]).unwrap();
24//!         });
25//!     }
26//! }
27//! ```
28//!
29//! ## `std::io` trait implementations
30//!
31//! `TcpStream` implements [`std::io::Read`] and [`std::io::Write`], so it
32//! works directly with any code that accepts `impl Read` or `impl Write` —
33//! including `BufReader`, `BufWriter`, and Rust's I/O adapters — without
34//! any unsafe wrapper or raw-fd manipulation.
35//!
36//! [`TcpStream::try_clone`] duplicates the underlying fd via `dup(2)`,
37//! yielding an independent stream that shares the same TCP connection.  This
38//! is useful for splitting a connection into separate read and write halves:
39//!
40//! ```no_run
41//! use std::io::{Read, Write};
42//!
43//! #[go_lib::main]
44//! fn main() {
45//!     let listener = go_lib::net::TcpListener::bind("127.0.0.1:9000").unwrap();
46//!     let stream = listener.accept().unwrap();
47//!     let mut writer = stream.try_clone().unwrap();
48//!     go_lib::go!(move || {
49//!         // `stream` is the read half; `writer` is the write half.
50//!         let mut buf = [0u8; 512];
51//!         let n = (&stream).read(&mut buf).unwrap();   // via &TcpStream impl
52//!         writer.write_all(&buf[..n]).unwrap();
53//!     });
54//! }
55//! ```
56//!
57//! ## Porting note
58//!
//! Go's `net` package blocks through `internal/poll.(*pollDesc).waitRead` /
//! `waitWrite`, which call into the runtime via `runtime_pollWait`; that path
//! translates directly to `netpoll_arm(fd, POLL_READ/WRITE, gp)` + `gopark`.
//! The same protocol is used here.
62
63use std::io::{self, Read, Write};
64use std::net::{SocketAddr, ToSocketAddrs};
65use std::os::unix::io::RawFd;
66
67use libc;
68
69use crate::runtime::g::WaitReason;
70use crate::runtime::netpoll::{netpoll_arm, netpoll_unarm, POLL_READ, POLL_WRITE};
71use crate::runtime::park::gopark;
72
73// ---------------------------------------------------------------------------
74// Helpers — non-blocking socket creation and address conversion
75// ---------------------------------------------------------------------------
76
77/// Create a non-blocking `SOCK_STREAM` socket for the given address family.
78///
79/// On Linux, `SOCK_NONBLOCK` is passed directly to `socket(2)`.
80/// On macOS (which lacks `SOCK_NONBLOCK`), `O_NONBLOCK` is set via `fcntl`.
81fn nonblocking_tcp_socket(family: libc::c_int) -> io::Result<RawFd> {
82    #[cfg(target_os = "linux")]
83    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM | libc::SOCK_NONBLOCK, 0) };
84
85    #[cfg(not(target_os = "linux"))]
86    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
87
88    if fd < 0 {
89        return Err(io::Error::last_os_error());
90    }
91
92    // On platforms where SOCK_NONBLOCK is not available, set O_NONBLOCK via fcntl.
93    #[cfg(not(target_os = "linux"))]
94    {
95        let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
96        if flags < 0
97            || unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) } < 0
98        {
99            unsafe { libc::close(fd) };
100            return Err(io::Error::last_os_error());
101        }
102    }
103
104    Ok(fd)
105}
106
107fn set_reuseaddr(fd: RawFd) -> io::Result<()> {
108    let one: libc::c_int = 1;
109    let ret = unsafe {
110        libc::setsockopt(
111            fd,
112            libc::SOL_SOCKET,
113            libc::SO_REUSEADDR,
114            &one as *const _ as *const libc::c_void,
115            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
116        )
117    };
118    if ret < 0 {
119        Err(io::Error::last_os_error())
120    } else {
121        Ok(())
122    }
123}
124
125/// Convert a `SocketAddr` to a `libc::sockaddr_storage` + length.
126fn to_sockaddr(addr: SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
127    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
128    match addr {
129        SocketAddr::V4(v4) => {
130            let sa: &mut libc::sockaddr_in =
131                unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in) };
132            sa.sin_family = libc::AF_INET as libc::sa_family_t;
133            sa.sin_port   = v4.port().to_be();
134            sa.sin_addr.s_addr = u32::from_ne_bytes(v4.ip().octets());
135            (storage, std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t)
136        }
137        SocketAddr::V6(v6) => {
138            let sa: &mut libc::sockaddr_in6 =
139                unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in6) };
140            sa.sin6_family   = libc::AF_INET6 as libc::sa_family_t;
141            sa.sin6_port     = v6.port().to_be();
142            sa.sin6_addr.s6_addr = v6.ip().octets();
143            (storage, std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t)
144        }
145    }
146}
147
148fn addr_family(addr: SocketAddr) -> libc::c_int {
149    match addr {
150        SocketAddr::V4(_) => libc::AF_INET,
151        SocketAddr::V6(_) => libc::AF_INET6,
152    }
153}
154
155/// Park the calling goroutine until `fd` is ready for `mode`
156/// (`POLL_READ` or `POLL_WRITE`).
157///
158/// # Safety
159/// Must be called from a live goroutine context.
160unsafe fn park_on_fd(fd: RawFd, mode: u32) {
161    let gp = crate::runtime::g::current_g();
162    debug_assert!(!gp.is_null(), "park_on_fd: not running on a goroutine");
163    unsafe {
164        netpoll_arm(fd, mode, gp);
165        gopark(WaitReason::IOWait);
166        // gopark suspends this goroutine; execution resumes after goready()
167        // is called by the netpoll machinery.
168    }
169}
170
171// ---------------------------------------------------------------------------
172// Address resolution
173// ---------------------------------------------------------------------------
174
/// Resolve `addr` to its first concrete [`SocketAddr`], running the lookup on a
/// dedicated OS thread.
///
/// [`ToSocketAddrs::to_socket_addrs`] may invoke the platform resolver
/// (`getaddrinfo`), which can consume tens of kilobytes of stack — far more
/// than a goroutine's small fixed stack (32 KiB in release builds; see
/// [`crate::runtime::stack`]).  go-lib has no compiler-inserted `morestack`
/// checks, so a stack-hungry C function whose prologue jumps past the guard
/// page in a single `sub rsp, N` faults in unmapped memory and the SIGSEGV
/// handler cannot recover.  Resolving a hostname directly on a goroutine
/// therefore crashes in release builds.
///
/// To avoid this we perform the resolution on a freshly spawned OS thread,
/// which has a full-size system stack that `getaddrinfo` cannot overflow.  The
/// calling goroutine's M blocks in `join()` for the (usually brief) duration
/// of the lookup, mirroring Go's use of a dedicated thread for blocking
/// `cgo`/resolver calls.  A scoped thread is used so non-`'static` inputs
/// (e.g. `&str`) can be borrowed without allocation.
///
/// # Errors
///
/// * the resolver's own error, or `InvalidInput` if it yields no addresses;
/// * the OS error if the resolver thread cannot be spawned (e.g. thread
///   limits reached).  This is reported as an error, not a panic.
/// * an `Other` error if the resolver thread panics.
fn resolve_first_addr<A: ToSocketAddrs + Send>(addr: A) -> io::Result<SocketAddr> {
    std::thread::scope(|scope| {
        // `Builder::spawn_scoped` returns an `io::Error` when the OS refuses
        // to create a thread, whereas `Scope::spawn` would panic.
        let handle = std::thread::Builder::new()
            .name("go_lib-resolver".into())
            .spawn_scoped(scope, move || {
                addr.to_socket_addrs()?.next().ok_or_else(|| {
                    io::Error::new(io::ErrorKind::InvalidInput, "no address given")
                })
            })?;
        handle
            .join()
            .map_err(|_| io::Error::other("address resolution thread panicked"))?
    })
}
205
206// ---------------------------------------------------------------------------
207// TcpListener
208// ---------------------------------------------------------------------------
209
/// A goroutine-aware TCP server socket.
///
/// Calls to [`accept`][TcpListener::accept] park the current goroutine when no
/// connection is immediately available and resume it when one arrives.
///
/// The listener owns its descriptor: dropping it deregisters the fd from
/// netpoll and closes it.
#[derive(Debug)]
pub struct TcpListener {
    fd: RawFd,
}
217
218impl TcpListener {
219    /// Bind a non-blocking TCP listener to `addr`.
220    ///
221    /// Equivalent to `net.Listen("tcp", addr)` in Go.
222    pub fn bind<A: ToSocketAddrs + Send>(addr: A) -> io::Result<Self> {
223        let addr = resolve_first_addr(addr)?;
224
225        let fd = nonblocking_tcp_socket(addr_family(addr))?;
226        set_reuseaddr(fd)?;
227
228        let (sa, sa_len) = to_sockaddr(addr);
229        let ret = unsafe {
230            libc::bind(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
231        };
232        if ret < 0 {
233            unsafe { libc::close(fd) };
234            return Err(io::Error::last_os_error());
235        }
236
237        let ret = unsafe { libc::listen(fd, 128) };
238        if ret < 0 {
239            unsafe { libc::close(fd) };
240            return Err(io::Error::last_os_error());
241        }
242
243        Ok(TcpListener { fd })
244    }
245
246    /// Accept the next incoming connection.
247    ///
248    /// Parks the goroutine if no connection is immediately available, resuming
249    /// it when the OS delivers one.
250    pub fn accept(&self) -> io::Result<TcpStream> {
251        loop {
252            let cfd = unsafe {
253                libc::accept(self.fd, std::ptr::null_mut(), std::ptr::null_mut())
254            };
255            if cfd >= 0 {
256                // Set O_NONBLOCK on the accepted socket.
257                let flags = unsafe { libc::fcntl(cfd, libc::F_GETFL) };
258                if flags >= 0 {
259                    unsafe { libc::fcntl(cfd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
260                }
261                return Ok(TcpStream { fd: cfd });
262            }
263            let err = io::Error::last_os_error();
264            match err.raw_os_error().unwrap_or(0) {
265                libc::EAGAIN => {
266                    // No connection yet — park until the listener fd is readable.
267                    unsafe { park_on_fd(self.fd, POLL_READ) };
268                    // After wakeup, retry accept().
269                }
270                _ => return Err(err),
271            }
272        }
273    }
274
275    /// Return the underlying raw file descriptor.
276    pub fn as_raw_fd(&self) -> RawFd {
277        self.fd
278    }
279}
280
281impl Drop for TcpListener {
282    fn drop(&mut self) {
283        netpoll_unarm(self.fd);
284        unsafe { libc::close(self.fd) };
285    }
286}
287
288// ---------------------------------------------------------------------------
289// TcpStream
290// ---------------------------------------------------------------------------
291
/// A goroutine-aware TCP stream socket.
///
/// Blocking reads and writes park the calling goroutine (via the netpoll
/// backend) when the operation would block, resuming it when data is
/// available or the send buffer has space.
///
/// `TcpStream` implements [`std::io::Read`] and [`std::io::Write`] (for both
/// `&mut TcpStream` and `&TcpStream`), so it works with any Rust I/O adapter
/// without unsafe wrapper code.  Use [`try_clone`][TcpStream::try_clone] to
/// split a connection into independent read and write halves.
///
/// The stream owns its descriptor: dropping it deregisters the fd from
/// netpoll and closes it.
#[derive(Debug)]
pub struct TcpStream {
    fd: RawFd,
}
305
306impl TcpStream {
307    /// Connect to `addr`.
308    ///
309    /// Parks the goroutine until the connection completes if it does not
310    /// complete immediately (which is typical for non-blocking `connect`).
311    ///
312    /// Equivalent to `net.Dial("tcp", addr)` in Go.
313    pub fn connect<A: ToSocketAddrs + Send>(addr: A) -> io::Result<Self> {
314        let addr = resolve_first_addr(addr)?;
315
316        let fd = nonblocking_tcp_socket(addr_family(addr))?;
317        let (sa, sa_len) = to_sockaddr(addr);
318
319        let ret = unsafe {
320            libc::connect(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
321        };
322
323        if ret < 0 {
324            let err = io::Error::last_os_error();
325            match err.raw_os_error().unwrap_or(0) {
326                libc::EINPROGRESS | libc::EAGAIN => {
327                    // Connection in progress — park until the socket is writable.
328                    unsafe { park_on_fd(fd, POLL_WRITE) };
329                    // Check for connect error via SO_ERROR.
330                    let mut so_err: libc::c_int = 0;
331                    let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
332                    unsafe {
333                        libc::getsockopt(
334                            fd,
335                            libc::SOL_SOCKET,
336                            libc::SO_ERROR,
337                            &mut so_err as *mut _ as *mut libc::c_void,
338                            &mut len,
339                        )
340                    };
341                    if so_err != 0 {
342                        unsafe { libc::close(fd) };
343                        return Err(io::Error::from_raw_os_error(so_err));
344                    }
345                }
346                _ => {
347                    unsafe { libc::close(fd) };
348                    return Err(err);
349                }
350            }
351        }
352
353        Ok(TcpStream { fd })
354    }
355
356    /// Read bytes from the stream into `buf`.
357    ///
358    /// Parks the goroutine if no data is immediately available.
359    pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
360        loop {
361            let n = unsafe {
362                libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
363            };
364            if n >= 0 {
365                return Ok(n as usize);
366            }
367            let err = io::Error::last_os_error();
368            match err.raw_os_error().unwrap_or(0) {
369                libc::EAGAIN => {
370                    unsafe { park_on_fd(self.fd, POLL_READ) };
371                }
372                _ => return Err(err),
373            }
374        }
375    }
376
377    /// Write `buf` to the stream.
378    ///
379    /// Parks the goroutine if the send buffer is full.  Returns the number of
380    /// bytes written (may be less than `buf.len()`).
381    pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
382        loop {
383            let n = unsafe {
384                libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len())
385            };
386            if n >= 0 {
387                return Ok(n as usize);
388            }
389            let err = io::Error::last_os_error();
390            match err.raw_os_error().unwrap_or(0) {
391                libc::EAGAIN => {
392                    unsafe { park_on_fd(self.fd, POLL_WRITE) };
393                }
394                _ => return Err(err),
395            }
396        }
397    }
398
399    /// Duplicate this stream, creating a second `TcpStream` that refers to the
400    /// same underlying TCP connection.
401    ///
402    /// The duplicate is an independent `TcpStream` with its own fd (via
403    /// `dup(2)`).  Both streams share the same socket; reads and writes on
404    /// either half see the same data stream.  Closing one does not close the
405    /// other.
406    ///
407    /// The typical use-case is splitting a connection into a dedicated read
408    /// half and a dedicated write half for use in separate goroutines.
409    pub fn try_clone(&self) -> io::Result<TcpStream> {
410        let new_fd = unsafe { libc::dup(self.fd) };
411        if new_fd < 0 {
412            return Err(io::Error::last_os_error());
413        }
414        Ok(TcpStream { fd: new_fd })
415    }
416
417    /// Return the remote address of the peer this stream is connected to.
418    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
419        sockaddr_of(self.fd, /* peer = */ true)
420    }
421
422    /// Return the local address this stream is bound to.
423    pub fn local_addr(&self) -> io::Result<SocketAddr> {
424        sockaddr_of(self.fd, /* peer = */ false)
425    }
426
427    /// Return the underlying raw file descriptor.
428    pub fn as_raw_fd(&self) -> RawFd {
429        self.fd
430    }
431}
432
433impl Drop for TcpStream {
434    fn drop(&mut self) {
435        netpoll_unarm(self.fd);
436        unsafe { libc::close(self.fd) };
437    }
438}
439
440// ---------------------------------------------------------------------------
441// std::io trait implementations for TcpStream
442// ---------------------------------------------------------------------------
443
444/// Implements [`std::io::Read`] by delegating to [`TcpStream::read`].
445///
446/// This allows `TcpStream` to be used with any Rust I/O adapter that accepts
447/// `impl Read`, such as `BufReader`, `Read::read_to_string`, etc., without
448/// any unsafe wrapper or raw-fd manipulation.
449impl Read for TcpStream {
450    #[inline]
451    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
452        TcpStream::read(self, buf)
453    }
454}
455
456/// Implements [`std::io::Read`] on a shared reference by issuing a raw
457/// `libc::read` call.  The fd is non-blocking; EAGAIN causes the goroutine
458/// to park via netpoll exactly as the owned-`&mut self` path does.
459///
460/// This enables using the same `TcpStream` for both reading and writing from
461/// two separate code sites within the same goroutine (e.g. after splitting
462/// into read/write halves conceptually without calling `try_clone`).
463impl Read for &TcpStream {
464    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
465        loop {
466            let n = unsafe {
467                libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
468            };
469            if n >= 0 {
470                return Ok(n as usize);
471            }
472            let err = io::Error::last_os_error();
473            match err.raw_os_error().unwrap_or(0) {
474                libc::EAGAIN => unsafe { park_on_fd(self.fd, POLL_READ) },
475                _ => return Err(err),
476            }
477        }
478    }
479}
480
481/// Implements [`std::io::Write`] by delegating to [`TcpStream::write`].
482/// `flush` is a no-op because the kernel TCP stack handles buffering.
483impl Write for TcpStream {
484    #[inline]
485    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
486        TcpStream::write(self, buf)
487    }
488
489    #[inline]
490    fn flush(&mut self) -> io::Result<()> {
491        Ok(())
492    }
493}
494
495/// Implements [`std::io::Write`] on a shared reference.
496impl Write for &TcpStream {
497    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
498        loop {
499            let n = unsafe {
500                libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len())
501            };
502            if n >= 0 {
503                return Ok(n as usize);
504            }
505            let err = io::Error::last_os_error();
506            match err.raw_os_error().unwrap_or(0) {
507                libc::EAGAIN => unsafe { park_on_fd(self.fd, POLL_WRITE) },
508                _ => return Err(err),
509            }
510        }
511    }
512
513    #[inline]
514    fn flush(&mut self) -> io::Result<()> {
515        Ok(())
516    }
517}
518
519// ---------------------------------------------------------------------------
// Address accessors for TcpListener
521// ---------------------------------------------------------------------------
522
523impl TcpListener {
524    /// Return the local address the listener is bound to.
525    pub fn local_addr(&self) -> io::Result<SocketAddr> {
526        sockaddr_of(self.fd, /* peer = */ false)
527    }
528}
529
530// ---------------------------------------------------------------------------
531// Address helpers
532// ---------------------------------------------------------------------------
533
534/// Query the local or peer address of `fd`.
535fn sockaddr_of(fd: RawFd, peer: bool) -> io::Result<SocketAddr> {
536    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
537    let mut len = std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
538    let ret = unsafe {
539        if peer {
540            libc::getpeername(fd, &mut storage as *mut _ as *mut libc::sockaddr, &mut len)
541        } else {
542            libc::getsockname(fd, &mut storage as *mut _ as *mut libc::sockaddr, &mut len)
543        }
544    };
545    if ret < 0 {
546        return Err(io::Error::last_os_error());
547    }
548    match storage.ss_family as libc::c_int {
549        libc::AF_INET => {
550            let sa: &libc::sockaddr_in =
551                unsafe { &*(&storage as *const _ as *const libc::sockaddr_in) };
552            let ip = std::net::Ipv4Addr::from(u32::from_be(sa.sin_addr.s_addr));
553            let port = u16::from_be(sa.sin_port);
554            Ok(SocketAddr::from((ip, port)))
555        }
556        libc::AF_INET6 => {
557            let sa: &libc::sockaddr_in6 =
558                unsafe { &*(&storage as *const _ as *const libc::sockaddr_in6) };
559            let ip = std::net::Ipv6Addr::from(sa.sin6_addr.s6_addr);
560            let port = u16::from_be(sa.sin6_port);
561            Ok(SocketAddr::from((ip, port)))
562        }
563        family => Err(io::Error::new(
564            io::ErrorKind::Unsupported,
565            format!("unsupported address family: {family}"),
566        )),
567    }
568}