Skip to main content

go_lib/
net.rs

1// SPDX-License-Identifier: Apache-2.0
2//! Goroutine-aware TCP networking.
3//!
4//! `TcpListener` and `TcpStream` wrap non-blocking OS sockets and integrate
5//! with the go-lib scheduler via the netpoll backend (`epoll` on Linux,
6//! `kqueue` on macOS).  When a socket operation would block (`EAGAIN` /
7//! `EWOULDBLOCK`), the goroutine is parked via `gopark` and re-enqueued by
8//! the netpoll machinery when the socket becomes ready.
9//!
10//! ## Usage
11//!
12//! ```no_run
13//! use std::io::{Read, Write};
14//!
15//! #[go_lib::main]
16//! fn main() {
17//!     let listener = go_lib::net::TcpListener::bind("127.0.0.1:8080").unwrap();
18//!     loop {
19//!         let mut stream = listener.accept().unwrap();
20//!         go_lib::go!(move || {
21//!             let mut buf = [0u8; 1024];
22//!             let n = stream.read(&mut buf).unwrap();
23//!             stream.write_all(&buf[..n]).unwrap();
24//!         });
25//!     }
26//! }
27//! ```
28//!
29//! ## `std::io` trait implementations
30//!
31//! `TcpStream` implements [`std::io::Read`] and [`std::io::Write`], so it
32//! works directly with any code that accepts `impl Read` or `impl Write` —
33//! including `BufReader`, `BufWriter`, and Rust's I/O adapters — without
34//! any unsafe wrapper or raw-fd manipulation.
35//!
36//! [`TcpStream::try_clone`] duplicates the underlying fd via `dup(2)`,
37//! yielding an independent stream that shares the same TCP connection.  This
38//! is useful for splitting a connection into separate read and write halves:
39//!
40//! ```no_run
41//! use std::io::{Read, Write};
42//!
43//! #[go_lib::main]
44//! fn main() {
45//!     let listener = go_lib::net::TcpListener::bind("127.0.0.1:9000").unwrap();
46//!     let stream = listener.accept().unwrap();
47//!     let mut writer = stream.try_clone().unwrap();
48//!     go_lib::go!(move || {
49//!         // `stream` is the read half; `writer` is the write half.
50//!         let mut buf = [0u8; 512];
51//!         let n = (&stream).read(&mut buf).unwrap();   // via &TcpStream impl
52//!         writer.write_all(&buf[..n]).unwrap();
53//!     });
54//! }
55//! ```
56//!
57//! ## Porting note
58//!
//! Go's `net` package blocks through `internal/poll`, whose
//! `pollDesc.waitRead` / `waitWrite` call into the runtime netpoller and park
//! the goroutine.  That maps directly onto
//! `netpoll_arm(fd, POLL_READ/WRITE, gp)` + `gopark`; the same protocol is
//! used here.
62
63use std::io::{self, Read, Write};
64use std::net::{SocketAddr, ToSocketAddrs};
65use std::os::unix::io::RawFd;
66
67use libc;
68
69use crate::runtime::g::WaitReason;
70use crate::runtime::netpoll::{netpoll_arm, netpoll_unarm, POLL_READ, POLL_WRITE};
71use crate::runtime::park::gopark;
72
73// ---------------------------------------------------------------------------
74// Helpers — non-blocking socket creation and address conversion
75// ---------------------------------------------------------------------------
76
77/// Create a non-blocking `SOCK_STREAM` socket for the given address family.
78///
79/// On Linux, `SOCK_NONBLOCK` is passed directly to `socket(2)`.
80/// On macOS (which lacks `SOCK_NONBLOCK`), `O_NONBLOCK` is set via `fcntl`.
81fn nonblocking_tcp_socket(family: libc::c_int) -> io::Result<RawFd> {
82    #[cfg(target_os = "linux")]
83    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM | libc::SOCK_NONBLOCK, 0) };
84
85    #[cfg(not(target_os = "linux"))]
86    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
87
88    if fd < 0 {
89        return Err(io::Error::last_os_error());
90    }
91
92    // On platforms where SOCK_NONBLOCK is not available, set O_NONBLOCK via fcntl.
93    #[cfg(not(target_os = "linux"))]
94    {
95        let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
96        if flags < 0
97            || unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) } < 0
98        {
99            unsafe { libc::close(fd) };
100            return Err(io::Error::last_os_error());
101        }
102    }
103
104    Ok(fd)
105}
106
107fn set_reuseaddr(fd: RawFd) -> io::Result<()> {
108    let one: libc::c_int = 1;
109    let ret = unsafe {
110        libc::setsockopt(
111            fd,
112            libc::SOL_SOCKET,
113            libc::SO_REUSEADDR,
114            &one as *const _ as *const libc::c_void,
115            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
116        )
117    };
118    if ret < 0 {
119        Err(io::Error::last_os_error())
120    } else {
121        Ok(())
122    }
123}
124
125/// Convert a `SocketAddr` to a `libc::sockaddr_storage` + length.
126fn to_sockaddr(addr: SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
127    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
128    match addr {
129        SocketAddr::V4(v4) => {
130            let sa: &mut libc::sockaddr_in =
131                unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in) };
132            sa.sin_family = libc::AF_INET as libc::sa_family_t;
133            sa.sin_port   = v4.port().to_be();
134            sa.sin_addr.s_addr = u32::from_ne_bytes(v4.ip().octets());
135            (storage, std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t)
136        }
137        SocketAddr::V6(v6) => {
138            let sa: &mut libc::sockaddr_in6 =
139                unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in6) };
140            sa.sin6_family   = libc::AF_INET6 as libc::sa_family_t;
141            sa.sin6_port     = v6.port().to_be();
142            sa.sin6_addr.s6_addr = v6.ip().octets();
143            (storage, std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t)
144        }
145    }
146}
147
148fn addr_family(addr: SocketAddr) -> libc::c_int {
149    match addr {
150        SocketAddr::V4(_) => libc::AF_INET,
151        SocketAddr::V6(_) => libc::AF_INET6,
152    }
153}
154
155/// Park the calling goroutine until `fd` is ready for `mode`
156/// (`POLL_READ` or `POLL_WRITE`).
157///
158/// # Safety
159/// Must be called from a live goroutine context.
160unsafe fn park_on_fd(fd: RawFd, mode: u32) {
161    let gp = crate::runtime::g::current_g();
162    debug_assert!(!gp.is_null(), "park_on_fd: not running on a goroutine");
163    unsafe {
164        netpoll_arm(fd, mode, gp);
165        gopark(WaitReason::IOWait);
166        // gopark suspends this goroutine; execution resumes after goready()
167        // is called by the netpoll machinery.
168    }
169}
170
171// ---------------------------------------------------------------------------
172// TcpListener
173// ---------------------------------------------------------------------------
174
/// A goroutine-aware TCP server socket.
///
/// Calls to [`accept`][TcpListener::accept] park the current goroutine when no
/// connection is immediately available and resume it when one arrives.
///
/// Implements `Debug` so that values (and `Result`s holding them) can be used
/// with `unwrap_err`, `expect_err`, `dbg!` and derived `Debug` impls.
#[derive(Debug)]
pub struct TcpListener {
    /// Raw descriptor of the non-blocking listening socket.
    fd: RawFd,
}
182
183impl TcpListener {
184    /// Bind a non-blocking TCP listener to `addr`.
185    ///
186    /// Equivalent to `net.Listen("tcp", addr)` in Go.
187    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
188        let addr = addr
189            .to_socket_addrs()?
190            .next()
191            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no address given"))?;
192
193        let fd = nonblocking_tcp_socket(addr_family(addr))?;
194        set_reuseaddr(fd)?;
195
196        let (sa, sa_len) = to_sockaddr(addr);
197        let ret = unsafe {
198            libc::bind(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
199        };
200        if ret < 0 {
201            unsafe { libc::close(fd) };
202            return Err(io::Error::last_os_error());
203        }
204
205        let ret = unsafe { libc::listen(fd, 128) };
206        if ret < 0 {
207            unsafe { libc::close(fd) };
208            return Err(io::Error::last_os_error());
209        }
210
211        Ok(TcpListener { fd })
212    }
213
214    /// Accept the next incoming connection.
215    ///
216    /// Parks the goroutine if no connection is immediately available, resuming
217    /// it when the OS delivers one.
218    pub fn accept(&self) -> io::Result<TcpStream> {
219        loop {
220            let cfd = unsafe {
221                libc::accept(self.fd, std::ptr::null_mut(), std::ptr::null_mut())
222            };
223            if cfd >= 0 {
224                // Set O_NONBLOCK on the accepted socket.
225                let flags = unsafe { libc::fcntl(cfd, libc::F_GETFL) };
226                if flags >= 0 {
227                    unsafe { libc::fcntl(cfd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
228                }
229                return Ok(TcpStream { fd: cfd });
230            }
231            let err = io::Error::last_os_error();
232            match err.raw_os_error().unwrap_or(0) {
233                libc::EAGAIN => {
234                    // No connection yet — park until the listener fd is readable.
235                    unsafe { park_on_fd(self.fd, POLL_READ) };
236                    // After wakeup, retry accept().
237                }
238                _ => return Err(err),
239            }
240        }
241    }
242
243    /// Return the underlying raw file descriptor.
244    pub fn as_raw_fd(&self) -> RawFd {
245        self.fd
246    }
247}
248
249impl Drop for TcpListener {
250    fn drop(&mut self) {
251        netpoll_unarm(self.fd);
252        unsafe { libc::close(self.fd) };
253    }
254}
255
256// ---------------------------------------------------------------------------
257// TcpStream
258// ---------------------------------------------------------------------------
259
/// A goroutine-aware TCP stream socket.
///
/// Blocking reads and writes park the calling goroutine (via the netpoll
/// backend) when the operation would block, resuming it when data is
/// available or the send buffer has space.
///
/// `TcpStream` implements [`std::io::Read`] and [`std::io::Write`] (for both
/// `&mut TcpStream` and `&TcpStream`), so it works with any Rust I/O adapter
/// without unsafe wrapper code.  Use [`try_clone`][TcpStream::try_clone] to
/// split a connection into independent read and write halves.
///
/// Implements `Debug` so that values (and `Result`s holding them) can be used
/// with `unwrap_err`, `expect_err`, `dbg!` and derived `Debug` impls.
#[derive(Debug)]
pub struct TcpStream {
    /// Raw descriptor of the non-blocking connected socket.
    fd: RawFd,
}
273
274impl TcpStream {
275    /// Connect to `addr`.
276    ///
277    /// Parks the goroutine until the connection completes if it does not
278    /// complete immediately (which is typical for non-blocking `connect`).
279    ///
280    /// Equivalent to `net.Dial("tcp", addr)` in Go.
281    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
282        let addr = addr
283            .to_socket_addrs()?
284            .next()
285            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no address given"))?;
286
287        let fd = nonblocking_tcp_socket(addr_family(addr))?;
288        let (sa, sa_len) = to_sockaddr(addr);
289
290        let ret = unsafe {
291            libc::connect(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
292        };
293
294        if ret < 0 {
295            let err = io::Error::last_os_error();
296            match err.raw_os_error().unwrap_or(0) {
297                libc::EINPROGRESS | libc::EAGAIN => {
298                    // Connection in progress — park until the socket is writable.
299                    unsafe { park_on_fd(fd, POLL_WRITE) };
300                    // Check for connect error via SO_ERROR.
301                    let mut so_err: libc::c_int = 0;
302                    let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
303                    unsafe {
304                        libc::getsockopt(
305                            fd,
306                            libc::SOL_SOCKET,
307                            libc::SO_ERROR,
308                            &mut so_err as *mut _ as *mut libc::c_void,
309                            &mut len,
310                        )
311                    };
312                    if so_err != 0 {
313                        unsafe { libc::close(fd) };
314                        return Err(io::Error::from_raw_os_error(so_err));
315                    }
316                }
317                _ => {
318                    unsafe { libc::close(fd) };
319                    return Err(err);
320                }
321            }
322        }
323
324        Ok(TcpStream { fd })
325    }
326
327    /// Read bytes from the stream into `buf`.
328    ///
329    /// Parks the goroutine if no data is immediately available.
330    pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
331        loop {
332            let n = unsafe {
333                libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
334            };
335            if n >= 0 {
336                return Ok(n as usize);
337            }
338            let err = io::Error::last_os_error();
339            match err.raw_os_error().unwrap_or(0) {
340                libc::EAGAIN => {
341                    unsafe { park_on_fd(self.fd, POLL_READ) };
342                }
343                _ => return Err(err),
344            }
345        }
346    }
347
348    /// Write `buf` to the stream.
349    ///
350    /// Parks the goroutine if the send buffer is full.  Returns the number of
351    /// bytes written (may be less than `buf.len()`).
352    pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
353        loop {
354            let n = unsafe {
355                libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len())
356            };
357            if n >= 0 {
358                return Ok(n as usize);
359            }
360            let err = io::Error::last_os_error();
361            match err.raw_os_error().unwrap_or(0) {
362                libc::EAGAIN => {
363                    unsafe { park_on_fd(self.fd, POLL_WRITE) };
364                }
365                _ => return Err(err),
366            }
367        }
368    }
369
370    /// Duplicate this stream, creating a second `TcpStream` that refers to the
371    /// same underlying TCP connection.
372    ///
373    /// The duplicate is an independent `TcpStream` with its own fd (via
374    /// `dup(2)`).  Both streams share the same socket; reads and writes on
375    /// either half see the same data stream.  Closing one does not close the
376    /// other.
377    ///
378    /// The typical use-case is splitting a connection into a dedicated read
379    /// half and a dedicated write half for use in separate goroutines.
380    pub fn try_clone(&self) -> io::Result<TcpStream> {
381        let new_fd = unsafe { libc::dup(self.fd) };
382        if new_fd < 0 {
383            return Err(io::Error::last_os_error());
384        }
385        Ok(TcpStream { fd: new_fd })
386    }
387
388    /// Return the remote address of the peer this stream is connected to.
389    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
390        sockaddr_of(self.fd, /* peer = */ true)
391    }
392
393    /// Return the local address this stream is bound to.
394    pub fn local_addr(&self) -> io::Result<SocketAddr> {
395        sockaddr_of(self.fd, /* peer = */ false)
396    }
397
398    /// Return the underlying raw file descriptor.
399    pub fn as_raw_fd(&self) -> RawFd {
400        self.fd
401    }
402}
403
404impl Drop for TcpStream {
405    fn drop(&mut self) {
406        netpoll_unarm(self.fd);
407        unsafe { libc::close(self.fd) };
408    }
409}
410
411// ---------------------------------------------------------------------------
412// std::io trait implementations for TcpStream
413// ---------------------------------------------------------------------------
414
415/// Implements [`std::io::Read`] by delegating to [`TcpStream::read`].
416///
417/// This allows `TcpStream` to be used with any Rust I/O adapter that accepts
418/// `impl Read`, such as `BufReader`, `Read::read_to_string`, etc., without
419/// any unsafe wrapper or raw-fd manipulation.
420impl Read for TcpStream {
421    #[inline]
422    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
423        TcpStream::read(self, buf)
424    }
425}
426
427/// Implements [`std::io::Read`] on a shared reference by issuing a raw
428/// `libc::read` call.  The fd is non-blocking; EAGAIN causes the goroutine
429/// to park via netpoll exactly as the owned-`&mut self` path does.
430///
431/// This enables using the same `TcpStream` for both reading and writing from
432/// two separate code sites within the same goroutine (e.g. after splitting
433/// into read/write halves conceptually without calling `try_clone`).
434impl Read for &TcpStream {
435    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
436        loop {
437            let n = unsafe {
438                libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
439            };
440            if n >= 0 {
441                return Ok(n as usize);
442            }
443            let err = io::Error::last_os_error();
444            match err.raw_os_error().unwrap_or(0) {
445                libc::EAGAIN => unsafe { park_on_fd(self.fd, POLL_READ) },
446                _ => return Err(err),
447            }
448        }
449    }
450}
451
452/// Implements [`std::io::Write`] by delegating to [`TcpStream::write`].
453/// `flush` is a no-op because the kernel TCP stack handles buffering.
454impl Write for TcpStream {
455    #[inline]
456    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
457        TcpStream::write(self, buf)
458    }
459
460    #[inline]
461    fn flush(&mut self) -> io::Result<()> {
462        Ok(())
463    }
464}
465
466/// Implements [`std::io::Write`] on a shared reference.
467impl Write for &TcpStream {
468    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
469        loop {
470            let n = unsafe {
471                libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len())
472            };
473            if n >= 0 {
474                return Ok(n as usize);
475            }
476            let err = io::Error::last_os_error();
477            match err.raw_os_error().unwrap_or(0) {
478                libc::EAGAIN => unsafe { park_on_fd(self.fd, POLL_WRITE) },
479                _ => return Err(err),
480            }
481        }
482    }
483
484    #[inline]
485    fn flush(&mut self) -> io::Result<()> {
486        Ok(())
487    }
488}
489
490// ---------------------------------------------------------------------------
// TcpListener — address accessors
492// ---------------------------------------------------------------------------
493
494impl TcpListener {
495    /// Return the local address the listener is bound to.
496    pub fn local_addr(&self) -> io::Result<SocketAddr> {
497        sockaddr_of(self.fd, /* peer = */ false)
498    }
499}
500
501// ---------------------------------------------------------------------------
502// Address helpers
503// ---------------------------------------------------------------------------
504
505/// Query the local or peer address of `fd`.
506fn sockaddr_of(fd: RawFd, peer: bool) -> io::Result<SocketAddr> {
507    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
508    let mut len = std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
509    let ret = unsafe {
510        if peer {
511            libc::getpeername(fd, &mut storage as *mut _ as *mut libc::sockaddr, &mut len)
512        } else {
513            libc::getsockname(fd, &mut storage as *mut _ as *mut libc::sockaddr, &mut len)
514        }
515    };
516    if ret < 0 {
517        return Err(io::Error::last_os_error());
518    }
519    match storage.ss_family as libc::c_int {
520        libc::AF_INET => {
521            let sa: &libc::sockaddr_in =
522                unsafe { &*(&storage as *const _ as *const libc::sockaddr_in) };
523            let ip = std::net::Ipv4Addr::from(u32::from_be(sa.sin_addr.s_addr));
524            let port = u16::from_be(sa.sin_port);
525            Ok(SocketAddr::from((ip, port)))
526        }
527        libc::AF_INET6 => {
528            let sa: &libc::sockaddr_in6 =
529                unsafe { &*(&storage as *const _ as *const libc::sockaddr_in6) };
530            let ip = std::net::Ipv6Addr::from(sa.sin6_addr.s6_addr);
531            let port = u16::from_be(sa.sin6_port);
532            Ok(SocketAddr::from((ip, port)))
533        }
534        family => Err(io::Error::new(
535            io::ErrorKind::Unsupported,
536            format!("unsupported address family: {family}"),
537        )),
538    }
539}