Skip to main content

go_lib/
net.rs

1// SPDX-License-Identifier: Apache-2.0
2//! Goroutine-aware TCP networking.
3//!
4//! `TcpListener` and `TcpStream` wrap non-blocking OS sockets and integrate
5//! with the go-lib scheduler via the netpoll backend (`epoll` on Linux,
6//! `kqueue` on macOS).  When a socket operation would block (`EAGAIN` /
7//! `EWOULDBLOCK`), the goroutine is parked via `gopark` and re-enqueued by
8//! the netpoll machinery when the socket becomes ready.
9//!
10//! ## Usage
11//!
12//! ```no_run
13//! use std::io::{Read, Write};
14//!
15//! go_lib::run(|| {
16//!     let listener = go_lib::net::TcpListener::bind("127.0.0.1:8080").unwrap();
17//!     loop {
18//!         let mut stream = listener.accept().unwrap();
19//!         go_lib::go!(move || {
20//!             let mut buf = [0u8; 1024];
21//!             let n = stream.read(&mut buf).unwrap();
22//!             stream.write_all(&buf[..n]).unwrap();
23//!         });
24//!     }
25//! });
26//! ```
27//!
28//! ## `std::io` trait implementations
29//!
30//! `TcpStream` implements [`std::io::Read`] and [`std::io::Write`], so it
31//! works directly with any code that accepts `impl Read` or `impl Write` —
32//! including `BufReader`, `BufWriter`, and Rust's I/O adapters — without
33//! any unsafe wrapper or raw-fd manipulation.
34//!
35//! [`TcpStream::try_clone`] duplicates the underlying fd via `dup(2)`,
36//! yielding an independent stream that shares the same TCP connection.  This
37//! is useful for splitting a connection into separate read and write halves:
38//!
39//! ```no_run
40//! use std::io::{Read, Write};
41//!
42//! go_lib::run(|| {
43//!     let listener = go_lib::net::TcpListener::bind("127.0.0.1:9000").unwrap();
44//!     let stream = listener.accept().unwrap();
45//!     let mut writer = stream.try_clone().unwrap();
46//!     go_lib::go!(move || {
47//!         // `stream` is the read half; `writer` is the write half.
48//!         let mut buf = [0u8; 512];
49//!         let n = (&stream).read(&mut buf).unwrap();   // via &TcpStream impl
50//!         writer.write_all(&buf[..n]).unwrap();
51//!     });
52//! });
53//! ```
54//!
55//! ## Porting note
56//!
57//! Go's `net` package calls `runtime.poll.pollDesc.waitRead` / `waitWrite`
58//! which translate directly to `netpoll_arm(fd, POLL_READ/WRITE, gp)` +
59//! `gopark`.  The same protocol is used here.
60
61use std::io::{self, Read, Write};
62use std::net::{SocketAddr, ToSocketAddrs};
63use std::os::unix::io::RawFd;
64
65use libc;
66
67use crate::runtime::g::WaitReason;
68use crate::runtime::netpoll::{netpoll_arm, netpoll_unarm, POLL_READ, POLL_WRITE};
69use crate::runtime::park::gopark;
70
71// ---------------------------------------------------------------------------
72// Helpers — non-blocking socket creation and address conversion
73// ---------------------------------------------------------------------------
74
75/// Create a non-blocking `SOCK_STREAM` socket for the given address family.
76///
77/// On Linux, `SOCK_NONBLOCK` is passed directly to `socket(2)`.
78/// On macOS (which lacks `SOCK_NONBLOCK`), `O_NONBLOCK` is set via `fcntl`.
79fn nonblocking_tcp_socket(family: libc::c_int) -> io::Result<RawFd> {
80    #[cfg(target_os = "linux")]
81    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM | libc::SOCK_NONBLOCK, 0) };
82
83    #[cfg(not(target_os = "linux"))]
84    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
85
86    if fd < 0 {
87        return Err(io::Error::last_os_error());
88    }
89
90    // On platforms where SOCK_NONBLOCK is not available, set O_NONBLOCK via fcntl.
91    #[cfg(not(target_os = "linux"))]
92    {
93        let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
94        if flags < 0
95            || unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) } < 0
96        {
97            unsafe { libc::close(fd) };
98            return Err(io::Error::last_os_error());
99        }
100    }
101
102    Ok(fd)
103}
104
105fn set_reuseaddr(fd: RawFd) -> io::Result<()> {
106    let one: libc::c_int = 1;
107    let ret = unsafe {
108        libc::setsockopt(
109            fd,
110            libc::SOL_SOCKET,
111            libc::SO_REUSEADDR,
112            &one as *const _ as *const libc::c_void,
113            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
114        )
115    };
116    if ret < 0 {
117        Err(io::Error::last_os_error())
118    } else {
119        Ok(())
120    }
121}
122
123/// Convert a `SocketAddr` to a `libc::sockaddr_storage` + length.
124fn to_sockaddr(addr: SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
125    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
126    match addr {
127        SocketAddr::V4(v4) => {
128            let sa: &mut libc::sockaddr_in =
129                unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in) };
130            sa.sin_family = libc::AF_INET as libc::sa_family_t;
131            sa.sin_port   = v4.port().to_be();
132            sa.sin_addr.s_addr = u32::from_ne_bytes(v4.ip().octets());
133            (storage, std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t)
134        }
135        SocketAddr::V6(v6) => {
136            let sa: &mut libc::sockaddr_in6 =
137                unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in6) };
138            sa.sin6_family   = libc::AF_INET6 as libc::sa_family_t;
139            sa.sin6_port     = v6.port().to_be();
140            sa.sin6_addr.s6_addr = v6.ip().octets();
141            (storage, std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t)
142        }
143    }
144}
145
146fn addr_family(addr: SocketAddr) -> libc::c_int {
147    match addr {
148        SocketAddr::V4(_) => libc::AF_INET,
149        SocketAddr::V6(_) => libc::AF_INET6,
150    }
151}
152
153/// Park the calling goroutine until `fd` is ready for `mode`
154/// (`POLL_READ` or `POLL_WRITE`).
155///
156/// # Safety
157/// Must be called from a live goroutine context.
158unsafe fn park_on_fd(fd: RawFd, mode: u32) {
159    let gp = crate::runtime::g::current_g();
160    debug_assert!(!gp.is_null(), "park_on_fd: not running on a goroutine");
161    unsafe {
162        netpoll_arm(fd, mode, gp);
163        gopark(WaitReason::IOWait);
164        // gopark suspends this goroutine; execution resumes after goready()
165        // is called by the netpoll machinery.
166    }
167}
168
169// ---------------------------------------------------------------------------
170// TcpListener
171// ---------------------------------------------------------------------------
172
/// A goroutine-aware TCP server socket.
///
/// Calls to [`accept`][TcpListener::accept] park the current goroutine when no
/// connection is immediately available and resume it when one arrives.
///
/// Implements `Debug` (printing the raw fd) so the type can appear in
/// diagnostics and in `Result`/`Option` helpers that require `Debug`.
#[derive(Debug)]
pub struct TcpListener {
    /// The listening socket's file descriptor.
    fd: RawFd,
}
180
181impl TcpListener {
182    /// Bind a non-blocking TCP listener to `addr`.
183    ///
184    /// Equivalent to `net.Listen("tcp", addr)` in Go.
185    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
186        let addr = addr
187            .to_socket_addrs()?
188            .next()
189            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no address given"))?;
190
191        let fd = nonblocking_tcp_socket(addr_family(addr))?;
192        set_reuseaddr(fd)?;
193
194        let (sa, sa_len) = to_sockaddr(addr);
195        let ret = unsafe {
196            libc::bind(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
197        };
198        if ret < 0 {
199            unsafe { libc::close(fd) };
200            return Err(io::Error::last_os_error());
201        }
202
203        let ret = unsafe { libc::listen(fd, 128) };
204        if ret < 0 {
205            unsafe { libc::close(fd) };
206            return Err(io::Error::last_os_error());
207        }
208
209        Ok(TcpListener { fd })
210    }
211
212    /// Accept the next incoming connection.
213    ///
214    /// Parks the goroutine if no connection is immediately available, resuming
215    /// it when the OS delivers one.
216    pub fn accept(&self) -> io::Result<TcpStream> {
217        loop {
218            let cfd = unsafe {
219                libc::accept(self.fd, std::ptr::null_mut(), std::ptr::null_mut())
220            };
221            if cfd >= 0 {
222                // Set O_NONBLOCK on the accepted socket.
223                let flags = unsafe { libc::fcntl(cfd, libc::F_GETFL) };
224                if flags >= 0 {
225                    unsafe { libc::fcntl(cfd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
226                }
227                return Ok(TcpStream { fd: cfd });
228            }
229            let err = io::Error::last_os_error();
230            match err.raw_os_error().unwrap_or(0) {
231                libc::EAGAIN => {
232                    // No connection yet — park until the listener fd is readable.
233                    unsafe { park_on_fd(self.fd, POLL_READ) };
234                    // After wakeup, retry accept().
235                }
236                _ => return Err(err),
237            }
238        }
239    }
240
241    /// Return the underlying raw file descriptor.
242    pub fn as_raw_fd(&self) -> RawFd {
243        self.fd
244    }
245}
246
247impl Drop for TcpListener {
248    fn drop(&mut self) {
249        netpoll_unarm(self.fd);
250        unsafe { libc::close(self.fd) };
251    }
252}
253
254// ---------------------------------------------------------------------------
255// TcpStream
256// ---------------------------------------------------------------------------
257
/// A goroutine-aware TCP stream socket.
///
/// Blocking reads and writes park the calling goroutine (via the netpoll
/// backend) when the operation would block, resuming it when data is
/// available or the send buffer has space.
///
/// `TcpStream` implements [`std::io::Read`] and [`std::io::Write`] (for both
/// `&mut TcpStream` and `&TcpStream`), so it works with any Rust I/O adapter
/// without unsafe wrapper code.  Use [`try_clone`][TcpStream::try_clone] to
/// split a connection into independent read and write halves.
///
/// Implements `Debug` (printing the raw fd) so the type can appear in
/// diagnostics and in `Result`/`Option` helpers that require `Debug`.
#[derive(Debug)]
pub struct TcpStream {
    /// The connected socket's file descriptor.
    fd: RawFd,
}
271
272impl TcpStream {
273    /// Connect to `addr`.
274    ///
275    /// Parks the goroutine until the connection completes if it does not
276    /// complete immediately (which is typical for non-blocking `connect`).
277    ///
278    /// Equivalent to `net.Dial("tcp", addr)` in Go.
279    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
280        let addr = addr
281            .to_socket_addrs()?
282            .next()
283            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no address given"))?;
284
285        let fd = nonblocking_tcp_socket(addr_family(addr))?;
286        let (sa, sa_len) = to_sockaddr(addr);
287
288        let ret = unsafe {
289            libc::connect(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
290        };
291
292        if ret < 0 {
293            let err = io::Error::last_os_error();
294            match err.raw_os_error().unwrap_or(0) {
295                libc::EINPROGRESS | libc::EAGAIN => {
296                    // Connection in progress — park until the socket is writable.
297                    unsafe { park_on_fd(fd, POLL_WRITE) };
298                    // Check for connect error via SO_ERROR.
299                    let mut so_err: libc::c_int = 0;
300                    let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
301                    unsafe {
302                        libc::getsockopt(
303                            fd,
304                            libc::SOL_SOCKET,
305                            libc::SO_ERROR,
306                            &mut so_err as *mut _ as *mut libc::c_void,
307                            &mut len,
308                        )
309                    };
310                    if so_err != 0 {
311                        unsafe { libc::close(fd) };
312                        return Err(io::Error::from_raw_os_error(so_err));
313                    }
314                }
315                _ => {
316                    unsafe { libc::close(fd) };
317                    return Err(err);
318                }
319            }
320        }
321
322        Ok(TcpStream { fd })
323    }
324
325    /// Read bytes from the stream into `buf`.
326    ///
327    /// Parks the goroutine if no data is immediately available.
328    pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
329        loop {
330            let n = unsafe {
331                libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
332            };
333            if n >= 0 {
334                return Ok(n as usize);
335            }
336            let err = io::Error::last_os_error();
337            match err.raw_os_error().unwrap_or(0) {
338                libc::EAGAIN => {
339                    unsafe { park_on_fd(self.fd, POLL_READ) };
340                }
341                _ => return Err(err),
342            }
343        }
344    }
345
346    /// Write `buf` to the stream.
347    ///
348    /// Parks the goroutine if the send buffer is full.  Returns the number of
349    /// bytes written (may be less than `buf.len()`).
350    pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
351        loop {
352            let n = unsafe {
353                libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len())
354            };
355            if n >= 0 {
356                return Ok(n as usize);
357            }
358            let err = io::Error::last_os_error();
359            match err.raw_os_error().unwrap_or(0) {
360                libc::EAGAIN => {
361                    unsafe { park_on_fd(self.fd, POLL_WRITE) };
362                }
363                _ => return Err(err),
364            }
365        }
366    }
367
368    /// Duplicate this stream, creating a second `TcpStream` that refers to the
369    /// same underlying TCP connection.
370    ///
371    /// The duplicate is an independent `TcpStream` with its own fd (via
372    /// `dup(2)`).  Both streams share the same socket; reads and writes on
373    /// either half see the same data stream.  Closing one does not close the
374    /// other.
375    ///
376    /// The typical use-case is splitting a connection into a dedicated read
377    /// half and a dedicated write half for use in separate goroutines.
378    pub fn try_clone(&self) -> io::Result<TcpStream> {
379        let new_fd = unsafe { libc::dup(self.fd) };
380        if new_fd < 0 {
381            return Err(io::Error::last_os_error());
382        }
383        Ok(TcpStream { fd: new_fd })
384    }
385
386    /// Return the remote address of the peer this stream is connected to.
387    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
388        sockaddr_of(self.fd, /* peer = */ true)
389    }
390
391    /// Return the local address this stream is bound to.
392    pub fn local_addr(&self) -> io::Result<SocketAddr> {
393        sockaddr_of(self.fd, /* peer = */ false)
394    }
395
396    /// Return the underlying raw file descriptor.
397    pub fn as_raw_fd(&self) -> RawFd {
398        self.fd
399    }
400}
401
402impl Drop for TcpStream {
403    fn drop(&mut self) {
404        netpoll_unarm(self.fd);
405        unsafe { libc::close(self.fd) };
406    }
407}
408
409// ---------------------------------------------------------------------------
410// std::io trait implementations for TcpStream
411// ---------------------------------------------------------------------------
412
413/// Implements [`std::io::Read`] by delegating to [`TcpStream::read`].
414///
415/// This allows `TcpStream` to be used with any Rust I/O adapter that accepts
416/// `impl Read`, such as `BufReader`, `Read::read_to_string`, etc., without
417/// any unsafe wrapper or raw-fd manipulation.
418impl Read for TcpStream {
419    #[inline]
420    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
421        TcpStream::read(self, buf)
422    }
423}
424
425/// Implements [`std::io::Read`] on a shared reference by issuing a raw
426/// `libc::read` call.  The fd is non-blocking; EAGAIN causes the goroutine
427/// to park via netpoll exactly as the owned-`&mut self` path does.
428///
429/// This enables using the same `TcpStream` for both reading and writing from
430/// two separate code sites within the same goroutine (e.g. after splitting
431/// into read/write halves conceptually without calling `try_clone`).
432impl Read for &TcpStream {
433    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
434        loop {
435            let n = unsafe {
436                libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
437            };
438            if n >= 0 {
439                return Ok(n as usize);
440            }
441            let err = io::Error::last_os_error();
442            match err.raw_os_error().unwrap_or(0) {
443                libc::EAGAIN => unsafe { park_on_fd(self.fd, POLL_READ) },
444                _ => return Err(err),
445            }
446        }
447    }
448}
449
450/// Implements [`std::io::Write`] by delegating to [`TcpStream::write`].
451/// `flush` is a no-op because the kernel TCP stack handles buffering.
452impl Write for TcpStream {
453    #[inline]
454    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
455        TcpStream::write(self, buf)
456    }
457
458    #[inline]
459    fn flush(&mut self) -> io::Result<()> {
460        Ok(())
461    }
462}
463
464/// Implements [`std::io::Write`] on a shared reference.
465impl Write for &TcpStream {
466    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
467        loop {
468            let n = unsafe {
469                libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len())
470            };
471            if n >= 0 {
472                return Ok(n as usize);
473            }
474            let err = io::Error::last_os_error();
475            match err.raw_os_error().unwrap_or(0) {
476                libc::EAGAIN => unsafe { park_on_fd(self.fd, POLL_WRITE) },
477                _ => return Err(err),
478            }
479        }
480    }
481
482    #[inline]
483    fn flush(&mut self) -> io::Result<()> {
484        Ok(())
485    }
486}
487
488// ---------------------------------------------------------------------------
// TcpListener — local address accessor
490// ---------------------------------------------------------------------------
491
492impl TcpListener {
493    /// Return the local address the listener is bound to.
494    pub fn local_addr(&self) -> io::Result<SocketAddr> {
495        sockaddr_of(self.fd, /* peer = */ false)
496    }
497}
498
499// ---------------------------------------------------------------------------
500// Address helpers
501// ---------------------------------------------------------------------------
502
503/// Query the local or peer address of `fd`.
504fn sockaddr_of(fd: RawFd, peer: bool) -> io::Result<SocketAddr> {
505    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
506    let mut len = std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
507    let ret = unsafe {
508        if peer {
509            libc::getpeername(fd, &mut storage as *mut _ as *mut libc::sockaddr, &mut len)
510        } else {
511            libc::getsockname(fd, &mut storage as *mut _ as *mut libc::sockaddr, &mut len)
512        }
513    };
514    if ret < 0 {
515        return Err(io::Error::last_os_error());
516    }
517    match storage.ss_family as libc::c_int {
518        libc::AF_INET => {
519            let sa: &libc::sockaddr_in =
520                unsafe { &*(&storage as *const _ as *const libc::sockaddr_in) };
521            let ip = std::net::Ipv4Addr::from(u32::from_be(sa.sin_addr.s_addr));
522            let port = u16::from_be(sa.sin_port);
523            Ok(SocketAddr::from((ip, port)))
524        }
525        libc::AF_INET6 => {
526            let sa: &libc::sockaddr_in6 =
527                unsafe { &*(&storage as *const _ as *const libc::sockaddr_in6) };
528            let ip = std::net::Ipv6Addr::from(sa.sin6_addr.s6_addr);
529            let port = u16::from_be(sa.sin6_port);
530            Ok(SocketAddr::from((ip, port)))
531        }
532        family => Err(io::Error::new(
533            io::ErrorKind::Unsupported,
534            format!("unsupported address family: {family}"),
535        )),
536    }
537}