go_lib/net.rs
1// SPDX-License-Identifier: Apache-2.0
2//! Goroutine-aware TCP networking.
3//!
4//! `TcpListener` and `TcpStream` wrap non-blocking OS sockets and integrate
5//! with the go-lib scheduler via the netpoll backend (`epoll` on Linux,
6//! `kqueue` on macOS). When a socket operation would block (`EAGAIN` /
7//! `EWOULDBLOCK`), the goroutine is parked via `gopark` and re-enqueued by
8//! the netpoll machinery when the socket becomes ready.
9//!
10//! ## Usage
11//!
12//! ```no_run
13//! use std::io::{Read, Write};
14//!
15//! #[go_lib::main]
16//! fn main() {
17//! let listener = go_lib::net::TcpListener::bind("127.0.0.1:8080").unwrap();
18//! loop {
19//! let mut stream = listener.accept().unwrap();
20//! go_lib::go!(move || {
21//! let mut buf = [0u8; 1024];
22//! let n = stream.read(&mut buf).unwrap();
23//! stream.write_all(&buf[..n]).unwrap();
24//! });
25//! }
26//! }
27//! ```
28//!
29//! ## `std::io` trait implementations
30//!
31//! `TcpStream` implements [`std::io::Read`] and [`std::io::Write`], so it
32//! works directly with any code that accepts `impl Read` or `impl Write` —
33//! including `BufReader`, `BufWriter`, and Rust's I/O adapters — without
34//! any unsafe wrapper or raw-fd manipulation.
35//!
36//! [`TcpStream::try_clone`] duplicates the underlying fd via `dup(2)`,
37//! yielding an independent stream that shares the same TCP connection. This
38//! is useful for splitting a connection into separate read and write halves:
39//!
40//! ```no_run
41//! use std::io::{Read, Write};
42//!
43//! #[go_lib::main]
44//! fn main() {
45//! let listener = go_lib::net::TcpListener::bind("127.0.0.1:9000").unwrap();
46//! let stream = listener.accept().unwrap();
47//! let mut writer = stream.try_clone().unwrap();
48//! go_lib::go!(move || {
49//! // `stream` is the read half; `writer` is the write half.
50//! let mut buf = [0u8; 512];
51//! let n = (&stream).read(&mut buf).unwrap(); // via &TcpStream impl
52//! writer.write_all(&buf[..n]).unwrap();
53//! });
54//! }
55//! ```
56//!
57//! ## Porting note
58//!
//! Go's `net` package waits for readiness via
//! `internal/poll.(*pollDesc).waitRead` / `waitWrite`, which map onto
//! `netpoll_arm(fd, POLL_READ/WRITE, gp)` + `gopark`. The same protocol is used here.
62
63use std::io::{self, Read, Write};
64use std::net::{SocketAddr, ToSocketAddrs};
65use std::os::unix::io::RawFd;
66
67use libc;
68
69use crate::runtime::g::WaitReason;
70use crate::runtime::netpoll::{netpoll_arm, netpoll_unarm, POLL_READ, POLL_WRITE};
71use crate::runtime::park::gopark;
72
73// ---------------------------------------------------------------------------
74// Helpers — non-blocking socket creation and address conversion
75// ---------------------------------------------------------------------------
76
77/// Create a non-blocking `SOCK_STREAM` socket for the given address family.
78///
79/// On Linux, `SOCK_NONBLOCK` is passed directly to `socket(2)`.
80/// On macOS (which lacks `SOCK_NONBLOCK`), `O_NONBLOCK` is set via `fcntl`.
81fn nonblocking_tcp_socket(family: libc::c_int) -> io::Result<RawFd> {
82 #[cfg(target_os = "linux")]
83 let fd = unsafe { libc::socket(family, libc::SOCK_STREAM | libc::SOCK_NONBLOCK, 0) };
84
85 #[cfg(not(target_os = "linux"))]
86 let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
87
88 if fd < 0 {
89 return Err(io::Error::last_os_error());
90 }
91
92 // On platforms where SOCK_NONBLOCK is not available, set O_NONBLOCK via fcntl.
93 #[cfg(not(target_os = "linux"))]
94 {
95 let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
96 if flags < 0
97 || unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) } < 0
98 {
99 unsafe { libc::close(fd) };
100 return Err(io::Error::last_os_error());
101 }
102 }
103
104 Ok(fd)
105}
106
107fn set_reuseaddr(fd: RawFd) -> io::Result<()> {
108 let one: libc::c_int = 1;
109 let ret = unsafe {
110 libc::setsockopt(
111 fd,
112 libc::SOL_SOCKET,
113 libc::SO_REUSEADDR,
114 &one as *const _ as *const libc::c_void,
115 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
116 )
117 };
118 if ret < 0 {
119 Err(io::Error::last_os_error())
120 } else {
121 Ok(())
122 }
123}
124
125/// Convert a `SocketAddr` to a `libc::sockaddr_storage` + length.
126fn to_sockaddr(addr: SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
127 let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
128 match addr {
129 SocketAddr::V4(v4) => {
130 let sa: &mut libc::sockaddr_in =
131 unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in) };
132 sa.sin_family = libc::AF_INET as libc::sa_family_t;
133 sa.sin_port = v4.port().to_be();
134 sa.sin_addr.s_addr = u32::from_ne_bytes(v4.ip().octets());
135 (storage, std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t)
136 }
137 SocketAddr::V6(v6) => {
138 let sa: &mut libc::sockaddr_in6 =
139 unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in6) };
140 sa.sin6_family = libc::AF_INET6 as libc::sa_family_t;
141 sa.sin6_port = v6.port().to_be();
142 sa.sin6_addr.s6_addr = v6.ip().octets();
143 (storage, std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t)
144 }
145 }
146}
147
148fn addr_family(addr: SocketAddr) -> libc::c_int {
149 match addr {
150 SocketAddr::V4(_) => libc::AF_INET,
151 SocketAddr::V6(_) => libc::AF_INET6,
152 }
153}
154
/// Park the calling goroutine until `fd` is ready for `mode`
/// (`POLL_READ` or `POLL_WRITE`).
///
/// Looks up the current goroutine, registers `(fd, mode)` for it with
/// `netpoll_arm`, and then suspends it with `gopark(WaitReason::IOWait)`.
/// Callers in this file retry their syscall after this function returns.
///
/// # Safety
/// Must be called from a live goroutine context. The null check on the
/// current-goroutine pointer is a `debug_assert!`, so it is only enforced in
/// builds with debug assertions enabled.
unsafe fn park_on_fd(fd: RawFd, mode: u32) {
    let gp = crate::runtime::g::current_g();
    debug_assert!(!gp.is_null(), "park_on_fd: not running on a goroutine");
    unsafe {
        // Arm first, then park.
        // NOTE(review): presumably this ordering is what prevents a lost
        // wakeup — confirm against the netpoll_arm / gopark contract.
        netpoll_arm(fd, mode, gp);
        gopark(WaitReason::IOWait);
        // gopark suspends this goroutine; execution resumes after goready()
        // is called by the netpoll machinery.
    }
}
170
171// ---------------------------------------------------------------------------
172// TcpListener
173// ---------------------------------------------------------------------------
174
/// A goroutine-aware TCP server socket.
///
/// Calls to [`accept`][TcpListener::accept] park the current goroutine when no
/// connection is immediately available and resume it when one arrives.
///
/// The listener owns its file descriptor: dropping the listener unregisters
/// the descriptor from netpoll and closes it.
pub struct TcpListener {
    /// Raw descriptor of the listening socket, owned by this value.
    fd: RawFd,
}
182
183impl TcpListener {
184 /// Bind a non-blocking TCP listener to `addr`.
185 ///
186 /// Equivalent to `net.Listen("tcp", addr)` in Go.
187 pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
188 let addr = addr
189 .to_socket_addrs()?
190 .next()
191 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no address given"))?;
192
193 let fd = nonblocking_tcp_socket(addr_family(addr))?;
194 set_reuseaddr(fd)?;
195
196 let (sa, sa_len) = to_sockaddr(addr);
197 let ret = unsafe {
198 libc::bind(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
199 };
200 if ret < 0 {
201 unsafe { libc::close(fd) };
202 return Err(io::Error::last_os_error());
203 }
204
205 let ret = unsafe { libc::listen(fd, 128) };
206 if ret < 0 {
207 unsafe { libc::close(fd) };
208 return Err(io::Error::last_os_error());
209 }
210
211 Ok(TcpListener { fd })
212 }
213
214 /// Accept the next incoming connection.
215 ///
216 /// Parks the goroutine if no connection is immediately available, resuming
217 /// it when the OS delivers one.
218 pub fn accept(&self) -> io::Result<TcpStream> {
219 loop {
220 let cfd = unsafe {
221 libc::accept(self.fd, std::ptr::null_mut(), std::ptr::null_mut())
222 };
223 if cfd >= 0 {
224 // Set O_NONBLOCK on the accepted socket.
225 let flags = unsafe { libc::fcntl(cfd, libc::F_GETFL) };
226 if flags >= 0 {
227 unsafe { libc::fcntl(cfd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
228 }
229 return Ok(TcpStream { fd: cfd });
230 }
231 let err = io::Error::last_os_error();
232 match err.raw_os_error().unwrap_or(0) {
233 libc::EAGAIN => {
234 // No connection yet — park until the listener fd is readable.
235 unsafe { park_on_fd(self.fd, POLL_READ) };
236 // After wakeup, retry accept().
237 }
238 _ => return Err(err),
239 }
240 }
241 }
242
243 /// Return the underlying raw file descriptor.
244 pub fn as_raw_fd(&self) -> RawFd {
245 self.fd
246 }
247}
248
/// Releases the listening socket when the listener goes out of scope.
impl Drop for TcpListener {
    fn drop(&mut self) {
        // Unregister from netpoll before closing.
        // NOTE(review): presumably the order matters so the poller never holds
        // a registration for a closed (and possibly reused) descriptor — confirm.
        netpoll_unarm(self.fd);
        // The result of close(2) is ignored; `drop` has no way to report it.
        unsafe { libc::close(self.fd) };
    }
}
255
256// ---------------------------------------------------------------------------
257// TcpStream
258// ---------------------------------------------------------------------------
259
/// A goroutine-aware TCP stream socket.
///
/// Blocking reads and writes park the calling goroutine (via the netpoll
/// backend) when the operation would block, resuming it when data is
/// available or the send buffer has space.
///
/// `TcpStream` implements [`std::io::Read`] and [`std::io::Write`] (for both
/// `&mut TcpStream` and `&TcpStream`), so it works with any Rust I/O adapter
/// without unsafe wrapper code. Use [`try_clone`][TcpStream::try_clone] to
/// split a connection into independent read and write halves.
///
/// The stream owns its file descriptor: dropping the stream unregisters the
/// descriptor from netpoll and closes it.
pub struct TcpStream {
    /// Raw descriptor of the stream socket, owned by this value.
    fd: RawFd,
}
273
274impl TcpStream {
275 /// Connect to `addr`.
276 ///
277 /// Parks the goroutine until the connection completes if it does not
278 /// complete immediately (which is typical for non-blocking `connect`).
279 ///
280 /// Equivalent to `net.Dial("tcp", addr)` in Go.
281 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
282 let addr = addr
283 .to_socket_addrs()?
284 .next()
285 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no address given"))?;
286
287 let fd = nonblocking_tcp_socket(addr_family(addr))?;
288 let (sa, sa_len) = to_sockaddr(addr);
289
290 let ret = unsafe {
291 libc::connect(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
292 };
293
294 if ret < 0 {
295 let err = io::Error::last_os_error();
296 match err.raw_os_error().unwrap_or(0) {
297 libc::EINPROGRESS | libc::EAGAIN => {
298 // Connection in progress — park until the socket is writable.
299 unsafe { park_on_fd(fd, POLL_WRITE) };
300 // Check for connect error via SO_ERROR.
301 let mut so_err: libc::c_int = 0;
302 let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
303 unsafe {
304 libc::getsockopt(
305 fd,
306 libc::SOL_SOCKET,
307 libc::SO_ERROR,
308 &mut so_err as *mut _ as *mut libc::c_void,
309 &mut len,
310 )
311 };
312 if so_err != 0 {
313 unsafe { libc::close(fd) };
314 return Err(io::Error::from_raw_os_error(so_err));
315 }
316 }
317 _ => {
318 unsafe { libc::close(fd) };
319 return Err(err);
320 }
321 }
322 }
323
324 Ok(TcpStream { fd })
325 }
326
327 /// Read bytes from the stream into `buf`.
328 ///
329 /// Parks the goroutine if no data is immediately available.
330 pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
331 loop {
332 let n = unsafe {
333 libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
334 };
335 if n >= 0 {
336 return Ok(n as usize);
337 }
338 let err = io::Error::last_os_error();
339 match err.raw_os_error().unwrap_or(0) {
340 libc::EAGAIN => {
341 unsafe { park_on_fd(self.fd, POLL_READ) };
342 }
343 _ => return Err(err),
344 }
345 }
346 }
347
348 /// Write `buf` to the stream.
349 ///
350 /// Parks the goroutine if the send buffer is full. Returns the number of
351 /// bytes written (may be less than `buf.len()`).
352 pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
353 loop {
354 let n = unsafe {
355 libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len())
356 };
357 if n >= 0 {
358 return Ok(n as usize);
359 }
360 let err = io::Error::last_os_error();
361 match err.raw_os_error().unwrap_or(0) {
362 libc::EAGAIN => {
363 unsafe { park_on_fd(self.fd, POLL_WRITE) };
364 }
365 _ => return Err(err),
366 }
367 }
368 }
369
370 /// Duplicate this stream, creating a second `TcpStream` that refers to the
371 /// same underlying TCP connection.
372 ///
373 /// The duplicate is an independent `TcpStream` with its own fd (via
374 /// `dup(2)`). Both streams share the same socket; reads and writes on
375 /// either half see the same data stream. Closing one does not close the
376 /// other.
377 ///
378 /// The typical use-case is splitting a connection into a dedicated read
379 /// half and a dedicated write half for use in separate goroutines.
380 pub fn try_clone(&self) -> io::Result<TcpStream> {
381 let new_fd = unsafe { libc::dup(self.fd) };
382 if new_fd < 0 {
383 return Err(io::Error::last_os_error());
384 }
385 Ok(TcpStream { fd: new_fd })
386 }
387
388 /// Return the remote address of the peer this stream is connected to.
389 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
390 sockaddr_of(self.fd, /* peer = */ true)
391 }
392
393 /// Return the local address this stream is bound to.
394 pub fn local_addr(&self) -> io::Result<SocketAddr> {
395 sockaddr_of(self.fd, /* peer = */ false)
396 }
397
398 /// Return the underlying raw file descriptor.
399 pub fn as_raw_fd(&self) -> RawFd {
400 self.fd
401 }
402}
403
/// Releases the stream socket when the stream goes out of scope.
impl Drop for TcpStream {
    fn drop(&mut self) {
        // Unregister from netpoll before closing.
        // NOTE(review): presumably the order matters so the poller never holds
        // a registration for a closed (and possibly reused) descriptor — confirm.
        netpoll_unarm(self.fd);
        // The result of close(2) is ignored; `drop` has no way to report it.
        unsafe { libc::close(self.fd) };
    }
}
410
411// ---------------------------------------------------------------------------
412// std::io trait implementations for TcpStream
413// ---------------------------------------------------------------------------
414
415/// Implements [`std::io::Read`] by delegating to [`TcpStream::read`].
416///
417/// This allows `TcpStream` to be used with any Rust I/O adapter that accepts
418/// `impl Read`, such as `BufReader`, `Read::read_to_string`, etc., without
419/// any unsafe wrapper or raw-fd manipulation.
420impl Read for TcpStream {
421 #[inline]
422 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
423 TcpStream::read(self, buf)
424 }
425}
426
427/// Implements [`std::io::Read`] on a shared reference by issuing a raw
428/// `libc::read` call. The fd is non-blocking; EAGAIN causes the goroutine
429/// to park via netpoll exactly as the owned-`&mut self` path does.
430///
431/// This enables using the same `TcpStream` for both reading and writing from
432/// two separate code sites within the same goroutine (e.g. after splitting
433/// into read/write halves conceptually without calling `try_clone`).
434impl Read for &TcpStream {
435 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
436 loop {
437 let n = unsafe {
438 libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
439 };
440 if n >= 0 {
441 return Ok(n as usize);
442 }
443 let err = io::Error::last_os_error();
444 match err.raw_os_error().unwrap_or(0) {
445 libc::EAGAIN => unsafe { park_on_fd(self.fd, POLL_READ) },
446 _ => return Err(err),
447 }
448 }
449 }
450}
451
452/// Implements [`std::io::Write`] by delegating to [`TcpStream::write`].
453/// `flush` is a no-op because the kernel TCP stack handles buffering.
454impl Write for TcpStream {
455 #[inline]
456 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
457 TcpStream::write(self, buf)
458 }
459
460 #[inline]
461 fn flush(&mut self) -> io::Result<()> {
462 Ok(())
463 }
464}
465
466/// Implements [`std::io::Write`] on a shared reference.
467impl Write for &TcpStream {
468 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
469 loop {
470 let n = unsafe {
471 libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len())
472 };
473 if n >= 0 {
474 return Ok(n as usize);
475 }
476 let err = io::Error::last_os_error();
477 match err.raw_os_error().unwrap_or(0) {
478 libc::EAGAIN => unsafe { park_on_fd(self.fd, POLL_WRITE) },
479 _ => return Err(err),
480 }
481 }
482 }
483
484 #[inline]
485 fn flush(&mut self) -> io::Result<()> {
486 Ok(())
487 }
488}
489
490// ---------------------------------------------------------------------------
// Address accessors for TcpListener
492// ---------------------------------------------------------------------------
493
494impl TcpListener {
495 /// Return the local address the listener is bound to.
496 pub fn local_addr(&self) -> io::Result<SocketAddr> {
497 sockaddr_of(self.fd, /* peer = */ false)
498 }
499}
500
501// ---------------------------------------------------------------------------
502// Address helpers
503// ---------------------------------------------------------------------------
504
505/// Query the local or peer address of `fd`.
506fn sockaddr_of(fd: RawFd, peer: bool) -> io::Result<SocketAddr> {
507 let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
508 let mut len = std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
509 let ret = unsafe {
510 if peer {
511 libc::getpeername(fd, &mut storage as *mut _ as *mut libc::sockaddr, &mut len)
512 } else {
513 libc::getsockname(fd, &mut storage as *mut _ as *mut libc::sockaddr, &mut len)
514 }
515 };
516 if ret < 0 {
517 return Err(io::Error::last_os_error());
518 }
519 match storage.ss_family as libc::c_int {
520 libc::AF_INET => {
521 let sa: &libc::sockaddr_in =
522 unsafe { &*(&storage as *const _ as *const libc::sockaddr_in) };
523 let ip = std::net::Ipv4Addr::from(u32::from_be(sa.sin_addr.s_addr));
524 let port = u16::from_be(sa.sin_port);
525 Ok(SocketAddr::from((ip, port)))
526 }
527 libc::AF_INET6 => {
528 let sa: &libc::sockaddr_in6 =
529 unsafe { &*(&storage as *const _ as *const libc::sockaddr_in6) };
530 let ip = std::net::Ipv6Addr::from(sa.sin6_addr.s6_addr);
531 let port = u16::from_be(sa.sin6_port);
532 Ok(SocketAddr::from((ip, port)))
533 }
534 family => Err(io::Error::new(
535 io::ErrorKind::Unsupported,
536 format!("unsupported address family: {family}"),
537 )),
538 }
539}