Skip to main content

nexus_async_rt/net/
tcp.rs

//! Async TCP stream, listener, and pre-bind socket configuration.
//!
//! Wraps mio's TCP types with the runtime's IO driver for readiness-based
//! async IO. Sockets register with mio lazily on first poll — the task
//! pointer comes from the `Context`'s waker.
//!
//! # Split
//!
//! [`TcpStream::split`] borrows the stream into separate read/write halves
//! for concurrent IO within a single task. [`TcpStream::into_split`]
//! consumes the stream into owned halves that can be moved to different
//! tasks.
13
14use std::io::{self, Read, Write};
15use std::net::SocketAddr;
16use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd};
17use std::pin::Pin;
18use std::task::{Context, Poll, Waker};
19use std::time::Duration;
20
21use mio::{Interest, Token};
22
23use super::{AsyncRead, AsyncWrite, waker_to_ptr};
24use crate::io::IoHandle;
25
26// =============================================================================
27// TcpStream
28// =============================================================================
29
30/// Async TCP stream backed by mio.
31///
32/// Created via [`TcpListener::accept`], [`TcpStream::connect`], or
33/// [`TcpSocket::connect`]. Implements [`AsyncRead`] and [`AsyncWrite`].
34///
35/// The stream registers with mio lazily on the first read or write.
36/// Uses edge-triggered epoll — registration happens once and persists.
37pub struct TcpStream {
38    inner: mio::net::TcpStream,
39    io: IoHandle,
40    token: Option<Token>,
41    /// Task pointer from the last registration. Used to detect when the
42    /// stream moves to a different task (e.g., via `into_split`) and
43    /// reregister with the IO driver to wake the correct task.
44    registered_task: *mut u8,
45}
46
47impl TcpStream {
48    /// Wrap a mio TcpStream. Registration deferred to first poll.
49    pub(crate) fn new(inner: mio::net::TcpStream, io: IoHandle) -> Self {
50        Self {
51            inner,
52            io,
53            token: None,
54            registered_task: std::ptr::null_mut(),
55        }
56    }
57
58    /// Initiate an async TCP connection to `addr`.
59    ///
60    /// The connection completes asynchronously. The first read or write
61    /// will register with mio and detect when the connection is
62    /// established.
63    pub fn connect(addr: SocketAddr, io: IoHandle) -> io::Result<Self> {
64        let inner = mio::net::TcpStream::connect(addr)?;
65        Ok(Self::new(inner, io))
66    }
67
68    /// Convert from a `std::net::TcpStream`.
69    ///
70    /// The stream must be set to non-blocking mode before calling this.
71    pub fn from_std(stream: std::net::TcpStream, io: IoHandle) -> io::Result<Self> {
72        let inner = mio::net::TcpStream::from_std(stream);
73        Ok(Self::new(inner, io))
74    }
75
76    /// Convert into a `std::net::TcpStream`.
77    ///
78    /// Deregisters from mio. The returned stream is still non-blocking.
79    pub fn into_std(mut self) -> io::Result<std::net::TcpStream> {
80        if let Some(token) = self.token.take() {
81            // SAFETY: IoHandle valid (Runtime lifetime).
82            let _ = unsafe { self.io.deregister(&mut self.inner, token) };
83        }
84        let fd = self.inner.as_raw_fd();
85        std::mem::forget(self); // skip Drop (already deregistered)
86        // SAFETY: fd is valid, we own it.
87        Ok(unsafe { std::net::TcpStream::from_raw_fd(fd) })
88    }
89
90    // =========================================================================
91    // Address
92    // =========================================================================
93
94    /// Returns the local address of this stream.
95    pub fn local_addr(&self) -> io::Result<SocketAddr> {
96        self.inner.local_addr()
97    }
98
99    /// Returns the remote address of this stream.
100    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
101        self.inner.peer_addr()
102    }
103
104    // =========================================================================
105    // Socket options (via socket2)
106    // =========================================================================
107
108    /// Helper: get a socket2::Socket reference for option access.
109    fn socket_ref(&self) -> socket2::SockRef<'_> {
110        socket2::SockRef::from(&self.inner)
111    }
112
113    /// Get TCP_NODELAY.
114    pub fn nodelay(&self) -> io::Result<bool> {
115        self.inner.nodelay()
116    }
117
118    /// Set TCP_NODELAY (disable Nagle's algorithm).
119    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
120        self.inner.set_nodelay(nodelay)
121    }
122
123    /// Get IP_TTL.
124    pub fn ttl(&self) -> io::Result<u32> {
125        self.socket_ref().ttl()
126    }
127
128    /// Set IP_TTL.
129    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
130        self.socket_ref().set_ttl(ttl)
131    }
132
133    /// Get SO_LINGER.
134    pub fn linger(&self) -> io::Result<Option<Duration>> {
135        self.socket_ref().linger()
136    }
137
138    /// Set SO_LINGER.
139    pub fn set_linger(&self, duration: Option<Duration>) -> io::Result<()> {
140        self.socket_ref().set_linger(duration)
141    }
142
143    /// Get SO_KEEPALIVE.
144    pub fn keepalive(&self) -> io::Result<bool> {
145        self.socket_ref().keepalive()
146    }
147
148    /// Set SO_KEEPALIVE.
149    pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
150        self.socket_ref().set_keepalive(keepalive)
151    }
152
153    /// Get SO_SNDBUF.
154    pub fn send_buffer_size(&self) -> io::Result<usize> {
155        self.socket_ref().send_buffer_size()
156    }
157
158    /// Set SO_SNDBUF.
159    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
160        self.socket_ref().set_send_buffer_size(size)
161    }
162
163    /// Get SO_RCVBUF.
164    pub fn recv_buffer_size(&self) -> io::Result<usize> {
165        self.socket_ref().recv_buffer_size()
166    }
167
168    /// Set SO_RCVBUF.
169    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
170        self.socket_ref().set_recv_buffer_size(size)
171    }
172
173    /// Get SO_ERROR and clear it.
174    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
175        self.socket_ref().take_error()
176    }
177
178    // =========================================================================
179    // Non-blocking try methods (no context needed)
180    // =========================================================================
181
182    /// Try to read without blocking. Returns `WouldBlock` if not ready.
183    pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
184        (&self.inner).read(buf)
185    }
186
187    /// Try to write without blocking. Returns `WouldBlock` if not ready.
188    pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
189        (&self.inner).write(buf)
190    }
191
192    /// Read without consuming from the buffer (MSG_PEEK).
193    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
194        // SAFETY: u8 and MaybeUninit<u8> have the same layout.
195        let buf = unsafe { &mut *(buf as *mut [u8] as *mut [std::mem::MaybeUninit<u8>]) };
196        self.socket_ref().peek(buf)
197    }
198
199    // =========================================================================
200    // Async convenience methods
201    // =========================================================================
202
203    /// Read bytes from the stream. Returns when at least 1 byte is read
204    /// or EOF (0 bytes).
205    pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
206        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_read(cx, buf)).await
207    }
208
209    /// Write bytes to the stream. Returns when at least 1 byte is written.
210    pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
211        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_write(cx, buf)).await
212    }
213
214    /// Write all bytes to the stream.
215    pub async fn write_all(&mut self, mut buf: &[u8]) -> io::Result<()> {
216        while !buf.is_empty() {
217            let n = self.write(buf).await?;
218            if n == 0 {
219                return Err(io::Error::new(
220                    io::ErrorKind::WriteZero,
221                    "failed to write whole buffer",
222                ));
223            }
224            buf = &buf[n..];
225        }
226        Ok(())
227    }
228
229    /// Poll for read readiness without performing IO.
230    ///
231    /// Returns `Ready(Ok(()))` if the socket has been reported readable
232    /// by epoll. Returns `Pending` if not yet ready. Use this for
233    /// sans-IO codecs that want to check readiness before feeding bytes.
234    pub fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
235        if let Err(e) = self.ensure_registered(cx) {
236            return Poll::Ready(Err(e));
237        }
238        if let Some(token) = self.token {
239            if self.io.readiness(token).readable {
240                return Poll::Ready(Ok(()));
241            }
242        }
243        Poll::Pending
244    }
245
246    /// Poll for write readiness without performing IO.
247    pub fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
248        if let Err(e) = self.ensure_registered(cx) {
249            return Poll::Ready(Err(e));
250        }
251        if let Some(token) = self.token {
252            if self.io.readiness(token).writable {
253                return Poll::Ready(Ok(()));
254            }
255        }
256        Poll::Pending
257    }
258
259    /// Wait until the stream is readable.
260    ///
261    /// Returns when epoll reports the socket as readable. After this
262    /// returns, [`try_read`](Self::try_read) should succeed.
263    pub async fn readable(&mut self) -> io::Result<()> {
264        std::future::poll_fn(|cx| self.poll_read_ready(cx)).await
265    }
266
267    /// Wait until the stream is writable.
268    pub async fn writable(&mut self) -> io::Result<()> {
269        std::future::poll_fn(|cx| self.poll_write_ready(cx)).await
270    }
271
272    // Note: after a successful read or WouldBlock, the readable flag is
273    // Correctly implementing them requires tracking readiness state from
274    // epoll events (like tokio's internal readiness tracking). Zero-length
275    // reads/writes don't reliably probe socket readiness on Linux.
276    // Use poll_read/poll_write or try_read/try_write instead.
277
278    // =========================================================================
279    // Split
280    // =========================================================================
281
282    /// Split into borrowed read and write halves.
283    ///
284    /// Both halves borrow the stream — they can be used concurrently
285    /// within a single task but cannot be moved to different tasks.
286    pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
287        let ptr = std::ptr::from_mut(self);
288        (
289            ReadHalf {
290                stream: ptr,
291                _marker: std::marker::PhantomData,
292            },
293            WriteHalf {
294                stream: ptr,
295                _marker: std::marker::PhantomData,
296            },
297        )
298    }
299
300    /// Split into owned read and write halves.
301    ///
302    /// The halves can be moved to different spawned tasks on the same
303    /// single-threaded runtime (`!Send` — not across threads). The IO
304    /// driver automatically updates the task pointer when a half is
305    /// polled from a different task. Use [`OwnedReadHalf::reunite`]
306    /// to reassemble the stream.
307    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
308        use std::rc::Rc;
309        let shared = Rc::new(std::cell::UnsafeCell::new(self));
310        (
311            OwnedReadHalf {
312                stream: Rc::clone(&shared),
313            },
314            OwnedWriteHalf { stream: shared },
315        )
316    }
317
318    // =========================================================================
319    // Registration (internal)
320    // =========================================================================
321
322    /// Ensure registered with mio and the correct task waker.
323    ///
324    /// First call: registers with mio. Subsequent calls: checks if the
325    /// task pointer changed (stream moved to a different task via
326    /// `into_split`). If so, updates the IO driver's waker.
327    #[inline(always)]
328    fn ensure_registered(&mut self, cx: &Context<'_>) -> io::Result<()> {
329        let task_ptr = waker_to_ptr(cx);
330        if let Some(token) = self.token {
331            // Already registered — check if task changed.
332            if task_ptr != self.registered_task {
333                self.io.set_waker(token, cx.waker().clone());
334                self.registered_task = task_ptr;
335            }
336            return Ok(());
337        }
338        self.do_register(task_ptr, cx.waker().clone())
339    }
340
341    #[cold]
342    fn do_register(&mut self, task_ptr: *mut u8, waker: Waker) -> io::Result<()> {
343        let interest = Interest::READABLE | Interest::WRITABLE;
344        let token = self.io.register(&mut self.inner, interest, waker)?;
345        self.token = Some(token);
346        self.registered_task = task_ptr;
347        Ok(())
348    }
349}
350
351impl AsyncRead for TcpStream {
352    fn poll_read(
353        self: Pin<&mut Self>,
354        cx: &mut Context<'_>,
355        buf: &mut [u8],
356    ) -> Poll<io::Result<usize>> {
357        let this = self.get_mut();
358        if let Err(e) = this.ensure_registered(cx) {
359            return Poll::Ready(Err(e));
360        }
361        match this.inner.read(buf) {
362            Ok(n) => Poll::Ready(Ok(n)),
363            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
364                // Clear readable — wait for next epoll notification.
365                if let Some(token) = this.token {
366                    this.io.clear_readable(token);
367                }
368                Poll::Pending
369            }
370            Err(e) => Poll::Ready(Err(e)),
371        }
372    }
373}
374
375impl AsyncWrite for TcpStream {
376    fn poll_write(
377        self: Pin<&mut Self>,
378        cx: &mut Context<'_>,
379        buf: &[u8],
380    ) -> Poll<io::Result<usize>> {
381        let this = self.get_mut();
382        if let Err(e) = this.ensure_registered(cx) {
383            return Poll::Ready(Err(e));
384        }
385        match this.inner.write(buf) {
386            Ok(n) => Poll::Ready(Ok(n)),
387            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
388                if let Some(token) = this.token {
389                    this.io.clear_writable(token);
390                }
391                Poll::Pending
392            }
393            Err(e) => Poll::Ready(Err(e)),
394        }
395    }
396
397    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
398        let this = self.get_mut();
399        if let Err(e) = this.ensure_registered(cx) {
400            return Poll::Ready(Err(e));
401        }
402        match this.inner.flush() {
403            Ok(()) => Poll::Ready(Ok(())),
404            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
405                if let Some(token) = this.token {
406                    this.io.clear_writable(token);
407                }
408                Poll::Pending
409            }
410            Err(e) => Poll::Ready(Err(e)),
411        }
412    }
413
414    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
415        let this = self.get_mut();
416        match this.inner.shutdown(std::net::Shutdown::Write) {
417            Ok(()) => Poll::Ready(Ok(())),
418            Err(e) if e.kind() == io::ErrorKind::NotConnected => Poll::Ready(Ok(())),
419            Err(e) => Poll::Ready(Err(e)),
420        }
421    }
422}
423
424impl std::fmt::Debug for TcpStream {
425    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426        f.debug_struct("TcpStream")
427            .field("fd", &self.inner.as_raw_fd())
428            .field("registered", &self.token.is_some())
429            .finish()
430    }
431}
432
433impl AsFd for TcpStream {
434    fn as_fd(&self) -> BorrowedFd<'_> {
435        self.inner.as_fd()
436    }
437}
438
439impl AsRawFd for TcpStream {
440    fn as_raw_fd(&self) -> RawFd {
441        self.inner.as_raw_fd()
442    }
443}
444
445impl Drop for TcpStream {
446    fn drop(&mut self) {
447        if let Some(token) = self.token {
448            // SAFETY: IoHandle valid (Runtime lifetime).
449            let _ = unsafe { self.io.deregister(&mut self.inner, token) };
450        }
451    }
452}
453
454// =============================================================================
455// ReadHalf / WriteHalf (borrowed split)
456// =============================================================================
457
458/// Borrowed read half of a [`TcpStream`].
459///
460/// Created by [`TcpStream::split`]. Borrows the stream — cannot be moved
461/// to a different task. Implements [`AsyncRead`].
462pub struct ReadHalf<'a> {
463    stream: *mut TcpStream,
464    // Tie lifetime to the borrow of the stream.
465    _marker: std::marker::PhantomData<&'a mut TcpStream>,
466}
467
468// The split constructor actually gives us two raw pointers to the same stream.
469// This is safe because ReadHalf only reads and WriteHalf only writes — no
470// aliased mutation of the same fields. Single-threaded.
471impl ReadHalf<'_> {
472    fn stream(&mut self) -> &mut TcpStream {
473        // SAFETY: Borrowed from split(), single-threaded, read-only side.
474        unsafe { &mut *self.stream }
475    }
476}
477
478impl AsyncRead for ReadHalf<'_> {
479    fn poll_read(
480        self: Pin<&mut Self>,
481        cx: &mut Context<'_>,
482        buf: &mut [u8],
483    ) -> Poll<io::Result<usize>> {
484        let this = self.get_mut();
485        Pin::new(this.stream()).poll_read(cx, buf)
486    }
487}
488
489/// Borrowed write half of a [`TcpStream`].
490///
491/// Created by [`TcpStream::split`]. Borrows the stream — cannot be moved
492/// to a different task. Implements [`AsyncWrite`].
493pub struct WriteHalf<'a> {
494    stream: *mut TcpStream,
495    _marker: std::marker::PhantomData<&'a mut TcpStream>,
496}
497
498impl WriteHalf<'_> {
499    fn stream(&mut self) -> &mut TcpStream {
500        // SAFETY: Borrowed from split(), single-threaded, write-only side.
501        unsafe { &mut *self.stream }
502    }
503}
504
505impl AsyncWrite for WriteHalf<'_> {
506    fn poll_write(
507        self: Pin<&mut Self>,
508        cx: &mut Context<'_>,
509        buf: &[u8],
510    ) -> Poll<io::Result<usize>> {
511        let this = self.get_mut();
512        Pin::new(this.stream()).poll_write(cx, buf)
513    }
514
515    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
516        let this = self.get_mut();
517        Pin::new(this.stream()).poll_flush(cx)
518    }
519
520    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
521        let this = self.get_mut();
522        Pin::new(this.stream()).poll_shutdown(cx)
523    }
524}
525
526// =============================================================================
527// OwnedReadHalf / OwnedWriteHalf (owned split)
528// =============================================================================
529
530/// Owned read half of a [`TcpStream`].
531///
532/// Created by [`TcpStream::into_split`]. Can be moved to a different task.
533pub struct OwnedReadHalf {
534    stream: std::rc::Rc<std::cell::UnsafeCell<TcpStream>>,
535}
536
537impl OwnedReadHalf {
538    /// Reassemble the stream from its halves.
539    ///
540    /// Returns `Err` if the halves don't belong to the same stream.
541    pub fn reunite(self, write: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
542        if std::rc::Rc::ptr_eq(&self.stream, &write.stream) {
543            drop(write);
544            let cell = std::rc::Rc::try_unwrap(self.stream).map_err(|_| ReuniteError)?;
545            Ok(cell.into_inner())
546        } else {
547            Err(ReuniteError)
548        }
549    }
550
551    /// Returns the peer address.
552    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
553        // SAFETY: single-threaded, immutable field access.
554        unsafe { &*self.stream.get() }.peer_addr()
555    }
556
557    /// Returns the local address.
558    pub fn local_addr(&self) -> io::Result<SocketAddr> {
559        unsafe { &*self.stream.get() }.local_addr()
560    }
561}
562
563impl AsyncRead for OwnedReadHalf {
564    fn poll_read(
565        self: Pin<&mut Self>,
566        cx: &mut Context<'_>,
567        buf: &mut [u8],
568    ) -> Poll<io::Result<usize>> {
569        // SAFETY: single-threaded. Only the read half calls poll_read.
570        let stream = unsafe { &mut *self.stream.get() };
571        Pin::new(stream).poll_read(cx, buf)
572    }
573}
574
575/// Owned write half of a [`TcpStream`].
576///
577/// Created by [`TcpStream::into_split`]. Can be moved to a different task.
578pub struct OwnedWriteHalf {
579    stream: std::rc::Rc<std::cell::UnsafeCell<TcpStream>>,
580}
581
582impl OwnedWriteHalf {
583    /// Reassemble the stream from its halves.
584    pub fn reunite(self, read: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
585        read.reunite(self)
586    }
587
588    /// Returns the peer address.
589    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
590        unsafe { &*self.stream.get() }.peer_addr()
591    }
592
593    /// Returns the local address.
594    pub fn local_addr(&self) -> io::Result<SocketAddr> {
595        unsafe { &*self.stream.get() }.local_addr()
596    }
597}
598
599impl AsyncWrite for OwnedWriteHalf {
600    fn poll_write(
601        self: Pin<&mut Self>,
602        cx: &mut Context<'_>,
603        buf: &[u8],
604    ) -> Poll<io::Result<usize>> {
605        // SAFETY: single-threaded. Only the write half calls poll_write.
606        let stream = unsafe { &mut *self.stream.get() };
607        Pin::new(stream).poll_write(cx, buf)
608    }
609
610    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
611        let stream = unsafe { &mut *self.stream.get() };
612        Pin::new(stream).poll_flush(cx)
613    }
614
615    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
616        let stream = unsafe { &mut *self.stream.get() };
617        Pin::new(stream).poll_shutdown(cx)
618    }
619}
620
/// Error returned by [`OwnedReadHalf::reunite`] when the two halves come
/// from different streams.
#[derive(Debug)]
pub struct ReuniteError;

impl std::fmt::Display for ReuniteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("halves do not belong to the same TcpStream")
    }
}

/// No underlying cause: `source()` is always `None`.
impl std::error::Error for ReuniteError {}
632
633// =============================================================================
634// TcpListener
635// =============================================================================
636
637/// Async TCP listener backed by mio.
638///
639/// Bind with [`TcpListener::bind`] or [`TcpSocket::listen`], then call
640/// [`accept`](Self::accept) to await incoming connections.
641pub struct TcpListener {
642    inner: mio::net::TcpListener,
643    io: IoHandle,
644    token: Option<Token>,
645    registered_task: *mut u8,
646}
647
648impl TcpListener {
649    /// Bind to `addr`. Registration deferred to first `accept` poll.
650    pub fn bind(addr: SocketAddr, io: IoHandle) -> io::Result<Self> {
651        let inner = mio::net::TcpListener::bind(addr)?;
652        Ok(Self {
653            inner,
654            io,
655            token: None,
656            registered_task: std::ptr::null_mut(),
657        })
658    }
659
660    /// Convert from a `std::net::TcpListener`.
661    pub fn from_std(listener: std::net::TcpListener, io: IoHandle) -> io::Result<Self> {
662        let inner = mio::net::TcpListener::from_std(listener);
663        Ok(Self {
664            inner,
665            io,
666            token: None,
667            registered_task: std::ptr::null_mut(),
668        })
669    }
670
671    /// Returns the local address this listener is bound to.
672    pub fn local_addr(&self) -> io::Result<SocketAddr> {
673        self.inner.local_addr()
674    }
675
676    /// Get IP_TTL.
677    pub fn ttl(&self) -> io::Result<u32> {
678        socket2::SockRef::from(&self.inner).ttl()
679    }
680
681    /// Set IP_TTL.
682    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
683        socket2::SockRef::from(&self.inner).set_ttl(ttl)
684    }
685
686    /// Accept a new TCP connection.
687    pub fn accept(&mut self) -> Accept<'_> {
688        Accept { listener: self }
689    }
690
691    /// Ensure registered with mio and the correct task waker.
692    #[inline(always)]
693    fn ensure_registered(&mut self, cx: &Context<'_>) -> io::Result<()> {
694        let task_ptr = waker_to_ptr(cx);
695        if let Some(token) = self.token {
696            if task_ptr != self.registered_task {
697                self.io.set_waker(token, cx.waker().clone());
698                self.registered_task = task_ptr;
699            }
700            return Ok(());
701        }
702        self.do_register(task_ptr, cx.waker().clone())
703    }
704
705    #[cold]
706    fn do_register(&mut self, task_ptr: *mut u8, waker: Waker) -> io::Result<()> {
707        let token = self
708            .io
709            .register(&mut self.inner, Interest::READABLE, waker)?;
710        self.token = Some(token);
711        self.registered_task = task_ptr;
712        Ok(())
713    }
714}
715
716impl std::fmt::Debug for TcpListener {
717    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
718        f.debug_struct("TcpListener")
719            .field("fd", &self.inner.as_raw_fd())
720            .field("registered", &self.token.is_some())
721            .finish()
722    }
723}
724
725impl AsFd for TcpListener {
726    fn as_fd(&self) -> BorrowedFd<'_> {
727        self.inner.as_fd()
728    }
729}
730
731impl AsRawFd for TcpListener {
732    fn as_raw_fd(&self) -> RawFd {
733        self.inner.as_raw_fd()
734    }
735}
736
737impl Drop for TcpListener {
738    fn drop(&mut self) {
739        if let Some(token) = self.token {
740            let _ = unsafe { self.io.deregister(&mut self.inner, token) };
741        }
742    }
743}
744
745/// Future returned by [`TcpListener::accept`].
746pub struct Accept<'a> {
747    listener: &'a mut TcpListener,
748}
749
750impl std::future::Future for Accept<'_> {
751    type Output = io::Result<(TcpStream, SocketAddr)>;
752
753    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
754        let this = self.get_mut();
755        if let Err(e) = this.listener.ensure_registered(cx) {
756            return Poll::Ready(Err(e));
757        }
758        match this.listener.inner.accept() {
759            Ok((stream, addr)) => {
760                let tcp = TcpStream::new(stream, this.listener.io);
761                Poll::Ready(Ok((tcp, addr)))
762            }
763            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
764            Err(e) => Poll::Ready(Err(e)),
765        }
766    }
767}
768
769// =============================================================================
770// TcpSocket — pre-bind configuration
771// =============================================================================
772
773/// TCP socket builder for configuring options before bind/connect.
774///
775/// Wraps `socket2::Socket` to provide access to socket options that
776/// must be set before binding (SO_REUSEADDR, SO_REUSEPORT, buffer
777/// sizes, etc.).
778///
779/// # Examples
780///
781/// ```ignore
782/// let socket = TcpSocket::new_v4()?;
783/// socket.set_reuseaddr(true)?;
784/// socket.set_recv_buffer_size(1024 * 1024)?;
785/// let listener = socket.listen(1024, io)?;
786/// ```
787pub struct TcpSocket {
788    inner: socket2::Socket,
789}
790
791impl TcpSocket {
792    /// Create a new IPv4 TCP socket.
793    pub fn new_v4() -> io::Result<Self> {
794        let inner = socket2::Socket::new(
795            socket2::Domain::IPV4,
796            socket2::Type::STREAM,
797            Some(socket2::Protocol::TCP),
798        )?;
799        inner.set_nonblocking(true)?;
800        Ok(Self { inner })
801    }
802
803    /// Create a new IPv6 TCP socket.
804    pub fn new_v6() -> io::Result<Self> {
805        let inner = socket2::Socket::new(
806            socket2::Domain::IPV6,
807            socket2::Type::STREAM,
808            Some(socket2::Protocol::TCP),
809        )?;
810        inner.set_nonblocking(true)?;
811        Ok(Self { inner })
812    }
813
814    // -- Socket options --
815
816    /// Set SO_REUSEADDR.
817    pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> {
818        self.inner.set_reuse_address(reuseaddr)
819    }
820
821    /// Get SO_REUSEADDR.
822    pub fn reuseaddr(&self) -> io::Result<bool> {
823        self.inner.reuse_address()
824    }
825
826    /// Set SO_REUSEPORT (Unix only).
827    #[cfg(unix)]
828    pub fn set_reuseport(&self, reuseport: bool) -> io::Result<()> {
829        self.inner.set_reuse_port(reuseport)
830    }
831
832    /// Get SO_REUSEPORT (Unix only).
833    #[cfg(unix)]
834    pub fn reuseport(&self) -> io::Result<bool> {
835        self.inner.reuse_port()
836    }
837
838    /// Set SO_KEEPALIVE.
839    pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
840        self.inner.set_keepalive(keepalive)
841    }
842
843    /// Get SO_KEEPALIVE.
844    pub fn keepalive(&self) -> io::Result<bool> {
845        self.inner.keepalive()
846    }
847
848    /// Set TCP_NODELAY.
849    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
850        self.inner.set_nodelay(nodelay)
851    }
852
853    /// Get TCP_NODELAY.
854    pub fn nodelay(&self) -> io::Result<bool> {
855        self.inner.nodelay()
856    }
857
858    /// Set SO_LINGER.
859    pub fn set_linger(&self, duration: Option<Duration>) -> io::Result<()> {
860        self.inner.set_linger(duration)
861    }
862
863    /// Get SO_LINGER.
864    pub fn linger(&self) -> io::Result<Option<Duration>> {
865        self.inner.linger()
866    }
867
868    /// Set SO_SNDBUF.
869    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
870        self.inner.set_send_buffer_size(size)
871    }
872
873    /// Get SO_SNDBUF.
874    pub fn send_buffer_size(&self) -> io::Result<usize> {
875        self.inner.send_buffer_size()
876    }
877
878    /// Set SO_RCVBUF.
879    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
880        self.inner.set_recv_buffer_size(size)
881    }
882
883    /// Get SO_RCVBUF.
884    pub fn recv_buffer_size(&self) -> io::Result<usize> {
885        self.inner.recv_buffer_size()
886    }
887
888    /// Set IP_TTL.
889    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
890        self.inner.set_ttl(ttl)
891    }
892
893    /// Get IP_TTL.
894    pub fn ttl(&self) -> io::Result<u32> {
895        self.inner.ttl()
896    }
897
898    // -- Bind, connect, listen --
899
900    /// Bind the socket to `addr`.
901    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
902        self.inner.bind(&addr.into())
903    }
904
905    /// Connect to `addr` and return a [`TcpStream`].
906    ///
907    /// The connection completes asynchronously (non-blocking socket).
908    /// The first read or write will detect when the connection is
909    /// established.
910    pub fn connect(self, addr: SocketAddr, io: IoHandle) -> io::Result<TcpStream> {
911        // Non-blocking connect returns EINPROGRESS/EALREADY — that's
912        // normal, not an error. Suppress these.
913        match self.inner.connect(&addr.into()) {
914            Ok(()) => {}
915            Err(e)
916                if e.raw_os_error() == Some(libc::EINPROGRESS)
917                    || e.raw_os_error() == Some(libc::EALREADY) => {}
918            Err(e) => return Err(e),
919        }
920        let std_stream: std::net::TcpStream = self.inner.into();
921        let mio_stream = mio::net::TcpStream::from_std(std_stream);
922        Ok(TcpStream::new(mio_stream, io))
923    }
924
925    /// Start listening with the given backlog and return a [`TcpListener`].
926    pub fn listen(self, backlog: i32, io: IoHandle) -> io::Result<TcpListener> {
927        self.inner.listen(backlog)?;
928        let std_listener: std::net::TcpListener = self.inner.into();
929        let mio_listener = mio::net::TcpListener::from_std(std_listener);
930        Ok(TcpListener {
931            inner: mio_listener,
932            io,
933            token: None,
934            registered_task: std::ptr::null_mut(),
935        })
936    }
937}
938
939impl std::fmt::Debug for TcpSocket {
940    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
941        f.debug_struct("TcpSocket")
942            .field("fd", &self.inner.as_raw_fd())
943            .finish()
944    }
945}
946
947impl AsFd for TcpSocket {
948    fn as_fd(&self) -> BorrowedFd<'_> {
949        self.inner.as_fd()
950    }
951}
952
953impl AsRawFd for TcpSocket {
954    fn as_raw_fd(&self) -> RawFd {
955        self.inner.as_raw_fd()
956    }
957}
958
959// =============================================================================
960// Tests
961// =============================================================================
962
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Runtime, spawn_boxed};
    use nexus_rt::WorldBuilder;
    use std::cell::Cell;
    use std::rc::Rc;

    /// Round-trips a small payload over loopback. One task accepts a
    /// connection and echoes a single read back; another connects, sends
    /// `hello`, and checks that the same bytes come back.
    #[test]
    fn tcp_echo() {
        let builder = WorldBuilder::new();
        let mut world = builder.build();
        let mut rt = Runtime::new(&mut world);

        let completed = Rc::new(Cell::new(false));
        let completed_in_task = Rc::clone(&completed);

        rt.block_on(async move {
            let server = TcpListener::bind("127.0.0.1:0".parse().unwrap(), crate::context::io())
                .expect("bind failed");
            let server_addr = server.local_addr().unwrap();

            // Server side: accept one connection and echo back one read.
            spawn_boxed(async move {
                let mut server = server;
                let (mut conn, _peer) = server.accept().await.unwrap();
                let mut scratch = [0u8; 64];
                let len = conn.read(&mut scratch).await.unwrap();
                conn.write_all(&scratch[..len]).await.unwrap();
            });

            // Client side: short delay, then connect, send, and verify.
            let client_io = crate::context::io();
            spawn_boxed(async move {
                crate::context::sleep(std::time::Duration::from_millis(10)).await;
                let mut client = TcpStream::connect(server_addr, client_io).unwrap();
                client.write_all(b"hello").await.unwrap();
                let mut reply = [0u8; 64];
                let len = client.read(&mut reply).await.unwrap();
                assert_eq!(&reply[..len], b"hello");
                completed_in_task.set(true);
            });

            // Keep the runtime alive long enough for both tasks to finish.
            crate::context::sleep(std::time::Duration::from_millis(500)).await;
        });

        assert!(completed.get(), "echo exchange never completed");
    }

    /// Options set on a fresh IPv4 socket can be read back.
    #[test]
    fn tcp_socket_builder() {
        let sock = TcpSocket::new_v4().unwrap();

        sock.set_reuseaddr(true).unwrap();
        assert!(sock.reuseaddr().unwrap());

        sock.set_nodelay(true).unwrap();
        assert!(sock.nodelay().unwrap());

        sock.set_send_buffer_size(65536).unwrap();
        // Buffer size may be rounded up by the kernel.
        let effective = sock.send_buffer_size().unwrap();
        assert!(effective >= 65536);
    }
}