Skip to main content

nexus_async_rt/net/
tcp.rs

1//! Async TCP stream, listener, and pre-bind socket configuration.
2//!
3//! Wraps mio's TCP types with the runtime's IO driver for readiness-based
4//! async IO. Sockets register with mio lazily on first poll — the task
5//! pointer comes from the `Context`'s waker.
6//!
7//! # Split
8//!
9//! [`TcpStream::split`] borrows the stream into separate read/write halves
10//! for concurrent IO within a single task. [`TcpStream::into_split`]
11//! consumes the stream into owned halves that can be moved to different
12//! tasks.
13
14use std::io::{self, Read, Write};
15use std::net::SocketAddr;
16use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd};
17use std::pin::Pin;
18use std::task::{Context, Poll, Waker};
19use std::time::Duration;
20
21use mio::{Interest, Token};
22
23use super::{AsyncRead, AsyncWrite, waker_to_ptr};
24use crate::io::IoHandle;
25
26// =============================================================================
27// TcpStream
28// =============================================================================
29
30/// Async TCP stream backed by mio.
31///
32/// Created via [`TcpListener::accept`], [`TcpStream::connect`], or
33/// [`TcpSocket::connect`]. Implements [`AsyncRead`] and [`AsyncWrite`].
34///
35/// The stream registers with mio lazily on the first read or write.
36/// Uses edge-triggered epoll — registration happens once and persists.
37pub struct TcpStream {
38    inner: mio::net::TcpStream,
39    io: IoHandle,
40    token: Option<Token>,
41    /// Task pointer from the last registration. Used to detect when the
42    /// stream moves to a different task (e.g., via `into_split`) and
43    /// reregister with the IO driver to wake the correct task.
44    registered_task: *mut u8,
45}
46
47impl TcpStream {
48    /// Wrap a mio TcpStream. Registration deferred to first poll.
49    pub(crate) fn new(inner: mio::net::TcpStream, io: IoHandle) -> Self {
50        Self {
51            inner,
52            io,
53            token: None,
54            registered_task: std::ptr::null_mut(),
55        }
56    }
57
58    /// Initiate an async TCP connection to `addr`.
59    ///
60    /// The connection completes asynchronously. The first read or write
61    /// will register with mio and detect when the connection is
62    /// established.
63    ///
64    /// # Panics
65    ///
66    /// Panics if called outside a [`Runtime::block_on`](crate::Runtime::block_on)
67    /// context — fetches the runtime's [`IoHandle`] internally.
68    pub fn connect(addr: SocketAddr) -> io::Result<Self> {
69        let inner = mio::net::TcpStream::connect(addr)?;
70        Ok(Self::new(inner, IoHandle::current()))
71    }
72
73    /// Convert from a `std::net::TcpStream`.
74    ///
75    /// The stream must be set to non-blocking mode before calling this.
76    ///
77    /// # Panics
78    ///
79    /// Panics if called outside a runtime context.
80    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
81        let inner = mio::net::TcpStream::from_std(stream);
82        Ok(Self::new(inner, IoHandle::current()))
83    }
84
85    /// Convert into a `std::net::TcpStream`.
86    ///
87    /// Deregisters from mio. The returned stream is still non-blocking.
88    pub fn into_std(mut self) -> io::Result<std::net::TcpStream> {
89        if let Some(token) = self.token.take() {
90            // SAFETY: IoHandle valid (Runtime lifetime).
91            let _ = unsafe { self.io.deregister(&mut self.inner, token) };
92        }
93        let fd = self.inner.as_raw_fd();
94        std::mem::forget(self); // skip Drop (already deregistered)
95        // SAFETY: fd is valid, we own it.
96        Ok(unsafe { std::net::TcpStream::from_raw_fd(fd) })
97    }
98
99    // =========================================================================
100    // Address
101    // =========================================================================
102
103    /// Returns the local address of this stream.
104    pub fn local_addr(&self) -> io::Result<SocketAddr> {
105        self.inner.local_addr()
106    }
107
108    /// Returns the remote address of this stream.
109    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
110        self.inner.peer_addr()
111    }
112
113    // =========================================================================
114    // Socket options (via socket2)
115    // =========================================================================
116
117    /// Helper: get a socket2::Socket reference for option access.
118    fn socket_ref(&self) -> socket2::SockRef<'_> {
119        socket2::SockRef::from(&self.inner)
120    }
121
122    /// Get TCP_NODELAY.
123    pub fn nodelay(&self) -> io::Result<bool> {
124        self.inner.nodelay()
125    }
126
127    /// Set TCP_NODELAY (disable Nagle's algorithm).
128    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
129        self.inner.set_nodelay(nodelay)
130    }
131
132    /// Get IP_TTL.
133    pub fn ttl(&self) -> io::Result<u32> {
134        self.socket_ref().ttl()
135    }
136
137    /// Set IP_TTL.
138    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
139        self.socket_ref().set_ttl(ttl)
140    }
141
142    /// Get SO_LINGER.
143    pub fn linger(&self) -> io::Result<Option<Duration>> {
144        self.socket_ref().linger()
145    }
146
147    /// Set SO_LINGER.
148    pub fn set_linger(&self, duration: Option<Duration>) -> io::Result<()> {
149        self.socket_ref().set_linger(duration)
150    }
151
152    /// Get SO_KEEPALIVE.
153    pub fn keepalive(&self) -> io::Result<bool> {
154        self.socket_ref().keepalive()
155    }
156
157    /// Set SO_KEEPALIVE.
158    pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
159        self.socket_ref().set_keepalive(keepalive)
160    }
161
162    /// Get SO_SNDBUF.
163    pub fn send_buffer_size(&self) -> io::Result<usize> {
164        self.socket_ref().send_buffer_size()
165    }
166
167    /// Set SO_SNDBUF.
168    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
169        self.socket_ref().set_send_buffer_size(size)
170    }
171
172    /// Get SO_RCVBUF.
173    pub fn recv_buffer_size(&self) -> io::Result<usize> {
174        self.socket_ref().recv_buffer_size()
175    }
176
177    /// Set SO_RCVBUF.
178    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
179        self.socket_ref().set_recv_buffer_size(size)
180    }
181
182    /// Get SO_ERROR and clear it.
183    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
184        self.socket_ref().take_error()
185    }
186
187    // =========================================================================
188    // Non-blocking try methods (no context needed)
189    // =========================================================================
190
191    /// Try to read without blocking. Returns `WouldBlock` if not ready.
192    pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
193        (&self.inner).read(buf)
194    }
195
196    /// Try to write without blocking. Returns `WouldBlock` if not ready.
197    pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
198        (&self.inner).write(buf)
199    }
200
201    /// Read without consuming from the buffer (MSG_PEEK).
202    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
203        // SAFETY: u8 and MaybeUninit<u8> have the same layout.
204        let buf = unsafe { &mut *(buf as *mut [u8] as *mut [std::mem::MaybeUninit<u8>]) };
205        self.socket_ref().peek(buf)
206    }
207
208    // =========================================================================
209    // Async convenience methods
210    // =========================================================================
211
212    /// Read bytes from the stream. Returns when at least 1 byte is read
213    /// or EOF (0 bytes).
214    pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
215        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_read(cx, buf)).await
216    }
217
218    /// Write bytes to the stream. Returns when at least 1 byte is written.
219    pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
220        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_write(cx, buf)).await
221    }
222
223    /// Write all bytes to the stream.
224    pub async fn write_all(&mut self, mut buf: &[u8]) -> io::Result<()> {
225        while !buf.is_empty() {
226            let n = self.write(buf).await?;
227            if n == 0 {
228                return Err(io::Error::new(
229                    io::ErrorKind::WriteZero,
230                    "failed to write whole buffer",
231                ));
232            }
233            buf = &buf[n..];
234        }
235        Ok(())
236    }
237
238    /// Poll for read readiness without performing IO.
239    ///
240    /// Returns `Ready(Ok(()))` if the socket has been reported readable
241    /// by epoll. Returns `Pending` if not yet ready. Use this for
242    /// sans-IO codecs that want to check readiness before feeding bytes.
243    pub fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
244        if let Err(e) = self.ensure_registered(cx) {
245            return Poll::Ready(Err(e));
246        }
247        if let Some(token) = self.token {
248            if self.io.readiness(token).readable {
249                return Poll::Ready(Ok(()));
250            }
251        }
252        Poll::Pending
253    }
254
255    /// Poll for write readiness without performing IO.
256    pub fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
257        if let Err(e) = self.ensure_registered(cx) {
258            return Poll::Ready(Err(e));
259        }
260        if let Some(token) = self.token {
261            if self.io.readiness(token).writable {
262                return Poll::Ready(Ok(()));
263            }
264        }
265        Poll::Pending
266    }
267
268    /// Wait until the stream is readable.
269    ///
270    /// Returns when epoll reports the socket as readable. After this
271    /// returns, [`try_read`](Self::try_read) should succeed.
272    pub async fn readable(&mut self) -> io::Result<()> {
273        std::future::poll_fn(|cx| self.poll_read_ready(cx)).await
274    }
275
276    /// Wait until the stream is writable.
277    pub async fn writable(&mut self) -> io::Result<()> {
278        std::future::poll_fn(|cx| self.poll_write_ready(cx)).await
279    }
280
281    // Note: after a successful read or WouldBlock, the readable flag is
282    // Correctly implementing them requires tracking readiness state from
283    // epoll events (like tokio's internal readiness tracking). Zero-length
284    // reads/writes don't reliably probe socket readiness on Linux.
285    // Use poll_read/poll_write or try_read/try_write instead.
286
287    // =========================================================================
288    // Split
289    // =========================================================================
290
291    /// Split into borrowed read and write halves.
292    ///
293    /// Both halves borrow the stream — they can be used concurrently
294    /// within a single task but cannot be moved to different tasks.
295    pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
296        let ptr = std::ptr::from_mut(self);
297        (
298            ReadHalf {
299                stream: ptr,
300                _marker: std::marker::PhantomData,
301            },
302            WriteHalf {
303                stream: ptr,
304                _marker: std::marker::PhantomData,
305            },
306        )
307    }
308
309    /// Split into owned read and write halves.
310    ///
311    /// The halves can be moved to different spawned tasks on the same
312    /// single-threaded runtime (`!Send` — not across threads). The IO
313    /// driver automatically updates the task pointer when a half is
314    /// polled from a different task. Use [`OwnedReadHalf::reunite`]
315    /// to reassemble the stream.
316    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
317        use std::rc::Rc;
318        let shared = Rc::new(std::cell::UnsafeCell::new(self));
319        (
320            OwnedReadHalf {
321                stream: Rc::clone(&shared),
322            },
323            OwnedWriteHalf { stream: shared },
324        )
325    }
326
327    // =========================================================================
328    // Registration (internal)
329    // =========================================================================
330
331    /// Ensure registered with mio and the correct task waker.
332    ///
333    /// First call: registers with mio. Subsequent calls: checks if the
334    /// task pointer changed (stream moved to a different task via
335    /// `into_split`). If so, updates the IO driver's waker.
336    #[inline(always)]
337    fn ensure_registered(&mut self, cx: &Context<'_>) -> io::Result<()> {
338        let task_ptr = waker_to_ptr(cx);
339        if let Some(token) = self.token {
340            // Already registered — check if task changed.
341            if task_ptr != self.registered_task {
342                self.io.set_waker(token, cx.waker().clone());
343                self.registered_task = task_ptr;
344            }
345            return Ok(());
346        }
347        self.do_register(task_ptr, cx.waker().clone())
348    }
349
350    #[cold]
351    fn do_register(&mut self, task_ptr: *mut u8, waker: Waker) -> io::Result<()> {
352        let interest = Interest::READABLE | Interest::WRITABLE;
353        let token = self.io.register(&mut self.inner, interest, waker)?;
354        self.token = Some(token);
355        self.registered_task = task_ptr;
356        Ok(())
357    }
358}
359
360impl AsyncRead for TcpStream {
361    fn poll_read(
362        self: Pin<&mut Self>,
363        cx: &mut Context<'_>,
364        buf: &mut [u8],
365    ) -> Poll<io::Result<usize>> {
366        let this = self.get_mut();
367        if let Err(e) = this.ensure_registered(cx) {
368            return Poll::Ready(Err(e));
369        }
370        match this.inner.read(buf) {
371            Ok(n) => Poll::Ready(Ok(n)),
372            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
373                // Clear readable — wait for next epoll notification.
374                if let Some(token) = this.token {
375                    this.io.clear_readable(token);
376                }
377                Poll::Pending
378            }
379            Err(e) => Poll::Ready(Err(e)),
380        }
381    }
382}
383
384impl AsyncWrite for TcpStream {
385    fn poll_write(
386        self: Pin<&mut Self>,
387        cx: &mut Context<'_>,
388        buf: &[u8],
389    ) -> Poll<io::Result<usize>> {
390        let this = self.get_mut();
391        if let Err(e) = this.ensure_registered(cx) {
392            return Poll::Ready(Err(e));
393        }
394        match this.inner.write(buf) {
395            Ok(n) => Poll::Ready(Ok(n)),
396            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
397                if let Some(token) = this.token {
398                    this.io.clear_writable(token);
399                }
400                Poll::Pending
401            }
402            Err(e) => Poll::Ready(Err(e)),
403        }
404    }
405
406    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
407        let this = self.get_mut();
408        if let Err(e) = this.ensure_registered(cx) {
409            return Poll::Ready(Err(e));
410        }
411        match this.inner.flush() {
412            Ok(()) => Poll::Ready(Ok(())),
413            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
414                if let Some(token) = this.token {
415                    this.io.clear_writable(token);
416                }
417                Poll::Pending
418            }
419            Err(e) => Poll::Ready(Err(e)),
420        }
421    }
422
423    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
424        let this = self.get_mut();
425        match this.inner.shutdown(std::net::Shutdown::Write) {
426            Ok(()) => Poll::Ready(Ok(())),
427            Err(e) if e.kind() == io::ErrorKind::NotConnected => Poll::Ready(Ok(())),
428            Err(e) => Poll::Ready(Err(e)),
429        }
430    }
431}
432
433impl std::fmt::Debug for TcpStream {
434    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435        f.debug_struct("TcpStream")
436            .field("fd", &self.inner.as_raw_fd())
437            .field("registered", &self.token.is_some())
438            .finish()
439    }
440}
441
442impl AsFd for TcpStream {
443    fn as_fd(&self) -> BorrowedFd<'_> {
444        self.inner.as_fd()
445    }
446}
447
448impl AsRawFd for TcpStream {
449    fn as_raw_fd(&self) -> RawFd {
450        self.inner.as_raw_fd()
451    }
452}
453
454impl Drop for TcpStream {
455    fn drop(&mut self) {
456        if let Some(token) = self.token {
457            // SAFETY: IoHandle valid (Runtime lifetime).
458            let _ = unsafe { self.io.deregister(&mut self.inner, token) };
459        }
460    }
461}
462
463// =============================================================================
464// ReadHalf / WriteHalf (borrowed split)
465// =============================================================================
466
467/// Borrowed read half of a [`TcpStream`].
468///
469/// Created by [`TcpStream::split`]. Borrows the stream — cannot be moved
470/// to a different task. Implements [`AsyncRead`].
471pub struct ReadHalf<'a> {
472    stream: *mut TcpStream,
473    // Tie lifetime to the borrow of the stream.
474    _marker: std::marker::PhantomData<&'a mut TcpStream>,
475}
476
477// The split constructor actually gives us two raw pointers to the same stream.
478// This is safe because ReadHalf only reads and WriteHalf only writes — no
479// aliased mutation of the same fields. Single-threaded.
480impl ReadHalf<'_> {
481    fn stream(&mut self) -> &mut TcpStream {
482        // SAFETY: Borrowed from split(), single-threaded, read-only side.
483        unsafe { &mut *self.stream }
484    }
485}
486
487impl AsyncRead for ReadHalf<'_> {
488    fn poll_read(
489        self: Pin<&mut Self>,
490        cx: &mut Context<'_>,
491        buf: &mut [u8],
492    ) -> Poll<io::Result<usize>> {
493        let this = self.get_mut();
494        Pin::new(this.stream()).poll_read(cx, buf)
495    }
496}
497
498/// Borrowed write half of a [`TcpStream`].
499///
500/// Created by [`TcpStream::split`]. Borrows the stream — cannot be moved
501/// to a different task. Implements [`AsyncWrite`].
502pub struct WriteHalf<'a> {
503    stream: *mut TcpStream,
504    _marker: std::marker::PhantomData<&'a mut TcpStream>,
505}
506
507impl WriteHalf<'_> {
508    fn stream(&mut self) -> &mut TcpStream {
509        // SAFETY: Borrowed from split(), single-threaded, write-only side.
510        unsafe { &mut *self.stream }
511    }
512}
513
514impl AsyncWrite for WriteHalf<'_> {
515    fn poll_write(
516        self: Pin<&mut Self>,
517        cx: &mut Context<'_>,
518        buf: &[u8],
519    ) -> Poll<io::Result<usize>> {
520        let this = self.get_mut();
521        Pin::new(this.stream()).poll_write(cx, buf)
522    }
523
524    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
525        let this = self.get_mut();
526        Pin::new(this.stream()).poll_flush(cx)
527    }
528
529    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
530        let this = self.get_mut();
531        Pin::new(this.stream()).poll_shutdown(cx)
532    }
533}
534
535// =============================================================================
536// OwnedReadHalf / OwnedWriteHalf (owned split)
537// =============================================================================
538
539/// Owned read half of a [`TcpStream`].
540///
541/// Created by [`TcpStream::into_split`]. Can be moved to a different task.
542pub struct OwnedReadHalf {
543    stream: std::rc::Rc<std::cell::UnsafeCell<TcpStream>>,
544}
545
546impl OwnedReadHalf {
547    /// Reassemble the stream from its halves.
548    ///
549    /// Returns `Err` if the halves don't belong to the same stream.
550    pub fn reunite(self, write: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
551        if std::rc::Rc::ptr_eq(&self.stream, &write.stream) {
552            drop(write);
553            let cell = std::rc::Rc::try_unwrap(self.stream).map_err(|_| ReuniteError)?;
554            Ok(cell.into_inner())
555        } else {
556            Err(ReuniteError)
557        }
558    }
559
560    /// Returns the peer address.
561    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
562        // SAFETY: single-threaded, immutable field access.
563        unsafe { &*self.stream.get() }.peer_addr()
564    }
565
566    /// Returns the local address.
567    pub fn local_addr(&self) -> io::Result<SocketAddr> {
568        unsafe { &*self.stream.get() }.local_addr()
569    }
570}
571
572impl AsyncRead for OwnedReadHalf {
573    fn poll_read(
574        self: Pin<&mut Self>,
575        cx: &mut Context<'_>,
576        buf: &mut [u8],
577    ) -> Poll<io::Result<usize>> {
578        // SAFETY: single-threaded. Only the read half calls poll_read.
579        let stream = unsafe { &mut *self.stream.get() };
580        Pin::new(stream).poll_read(cx, buf)
581    }
582}
583
584/// Owned write half of a [`TcpStream`].
585///
586/// Created by [`TcpStream::into_split`]. Can be moved to a different task.
587pub struct OwnedWriteHalf {
588    stream: std::rc::Rc<std::cell::UnsafeCell<TcpStream>>,
589}
590
591impl OwnedWriteHalf {
592    /// Reassemble the stream from its halves.
593    pub fn reunite(self, read: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
594        read.reunite(self)
595    }
596
597    /// Returns the peer address.
598    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
599        unsafe { &*self.stream.get() }.peer_addr()
600    }
601
602    /// Returns the local address.
603    pub fn local_addr(&self) -> io::Result<SocketAddr> {
604        unsafe { &*self.stream.get() }.local_addr()
605    }
606}
607
608impl AsyncWrite for OwnedWriteHalf {
609    fn poll_write(
610        self: Pin<&mut Self>,
611        cx: &mut Context<'_>,
612        buf: &[u8],
613    ) -> Poll<io::Result<usize>> {
614        // SAFETY: single-threaded. Only the write half calls poll_write.
615        let stream = unsafe { &mut *self.stream.get() };
616        Pin::new(stream).poll_write(cx, buf)
617    }
618
619    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
620        let stream = unsafe { &mut *self.stream.get() };
621        Pin::new(stream).poll_flush(cx)
622    }
623
624    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
625        let stream = unsafe { &mut *self.stream.get() };
626        Pin::new(stream).poll_shutdown(cx)
627    }
628}
629
/// Error from [`OwnedReadHalf::reunite`] / [`OwnedWriteHalf::reunite`] when
/// the two halves were not split from the same [`TcpStream`].
#[derive(Debug)]
pub struct ReuniteError;

impl std::fmt::Display for ReuniteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("halves do not belong to the same TcpStream")
    }
}

impl std::error::Error for ReuniteError {}
641
642// =============================================================================
643// TcpListener
644// =============================================================================
645
646/// Async TCP listener backed by mio.
647///
648/// Bind with [`TcpListener::bind`] or [`TcpSocket::listen`], then call
649/// [`accept`](Self::accept) to await incoming connections.
650pub struct TcpListener {
651    inner: mio::net::TcpListener,
652    io: IoHandle,
653    token: Option<Token>,
654    registered_task: *mut u8,
655}
656
657impl TcpListener {
658    /// Bind to `addr`. Registration deferred to first `accept` poll.
659    ///
660    /// # Panics
661    ///
662    /// Panics if called outside a [`Runtime::block_on`](crate::Runtime::block_on)
663    /// context — fetches the runtime's [`IoHandle`] internally.
664    pub fn bind(addr: SocketAddr) -> io::Result<Self> {
665        let inner = mio::net::TcpListener::bind(addr)?;
666        Ok(Self {
667            inner,
668            io: IoHandle::current(),
669            token: None,
670            registered_task: std::ptr::null_mut(),
671        })
672    }
673
674    /// Convert from a `std::net::TcpListener`.
675    ///
676    /// # Panics
677    ///
678    /// Panics if called outside a runtime context.
679    pub fn from_std(listener: std::net::TcpListener) -> io::Result<Self> {
680        let inner = mio::net::TcpListener::from_std(listener);
681        Ok(Self {
682            inner,
683            io: IoHandle::current(),
684            token: None,
685            registered_task: std::ptr::null_mut(),
686        })
687    }
688
689    /// Returns the local address this listener is bound to.
690    pub fn local_addr(&self) -> io::Result<SocketAddr> {
691        self.inner.local_addr()
692    }
693
694    /// Get IP_TTL.
695    pub fn ttl(&self) -> io::Result<u32> {
696        socket2::SockRef::from(&self.inner).ttl()
697    }
698
699    /// Set IP_TTL.
700    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
701        socket2::SockRef::from(&self.inner).set_ttl(ttl)
702    }
703
704    /// Accept a new TCP connection.
705    pub fn accept(&mut self) -> Accept<'_> {
706        Accept { listener: self }
707    }
708
709    /// Ensure registered with mio and the correct task waker.
710    #[inline(always)]
711    fn ensure_registered(&mut self, cx: &Context<'_>) -> io::Result<()> {
712        let task_ptr = waker_to_ptr(cx);
713        if let Some(token) = self.token {
714            if task_ptr != self.registered_task {
715                self.io.set_waker(token, cx.waker().clone());
716                self.registered_task = task_ptr;
717            }
718            return Ok(());
719        }
720        self.do_register(task_ptr, cx.waker().clone())
721    }
722
723    #[cold]
724    fn do_register(&mut self, task_ptr: *mut u8, waker: Waker) -> io::Result<()> {
725        let token = self
726            .io
727            .register(&mut self.inner, Interest::READABLE, waker)?;
728        self.token = Some(token);
729        self.registered_task = task_ptr;
730        Ok(())
731    }
732}
733
734impl std::fmt::Debug for TcpListener {
735    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
736        f.debug_struct("TcpListener")
737            .field("fd", &self.inner.as_raw_fd())
738            .field("registered", &self.token.is_some())
739            .finish()
740    }
741}
742
743impl AsFd for TcpListener {
744    fn as_fd(&self) -> BorrowedFd<'_> {
745        self.inner.as_fd()
746    }
747}
748
749impl AsRawFd for TcpListener {
750    fn as_raw_fd(&self) -> RawFd {
751        self.inner.as_raw_fd()
752    }
753}
754
755impl Drop for TcpListener {
756    fn drop(&mut self) {
757        if let Some(token) = self.token {
758            let _ = unsafe { self.io.deregister(&mut self.inner, token) };
759        }
760    }
761}
762
763/// Future returned by [`TcpListener::accept`].
764pub struct Accept<'a> {
765    listener: &'a mut TcpListener,
766}
767
768impl std::future::Future for Accept<'_> {
769    type Output = io::Result<(TcpStream, SocketAddr)>;
770
771    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
772        let this = self.get_mut();
773        if let Err(e) = this.listener.ensure_registered(cx) {
774            return Poll::Ready(Err(e));
775        }
776        match this.listener.inner.accept() {
777            Ok((stream, addr)) => {
778                let tcp = TcpStream::new(stream, this.listener.io);
779                Poll::Ready(Ok((tcp, addr)))
780            }
781            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
782            Err(e) => Poll::Ready(Err(e)),
783        }
784    }
785}
786
787// =============================================================================
788// TcpSocket — pre-bind configuration
789// =============================================================================
790
791/// TCP socket builder for configuring options before bind/connect.
792///
793/// Wraps `socket2::Socket` to provide access to socket options that
794/// must be set before binding (SO_REUSEADDR, SO_REUSEPORT, buffer
795/// sizes, etc.).
796///
797/// # Examples
798///
799/// ```ignore
800/// let socket = TcpSocket::new_v4()?;
801/// socket.set_reuseaddr(true)?;
802/// socket.set_recv_buffer_size(1024 * 1024)?;
803/// let listener = socket.listen(1024, io)?;
804/// ```
805pub struct TcpSocket {
806    inner: socket2::Socket,
807}
808
809impl TcpSocket {
810    /// Create a new IPv4 TCP socket.
811    pub fn new_v4() -> io::Result<Self> {
812        let inner = socket2::Socket::new(
813            socket2::Domain::IPV4,
814            socket2::Type::STREAM,
815            Some(socket2::Protocol::TCP),
816        )?;
817        inner.set_nonblocking(true)?;
818        Ok(Self { inner })
819    }
820
821    /// Create a new IPv6 TCP socket.
822    pub fn new_v6() -> io::Result<Self> {
823        let inner = socket2::Socket::new(
824            socket2::Domain::IPV6,
825            socket2::Type::STREAM,
826            Some(socket2::Protocol::TCP),
827        )?;
828        inner.set_nonblocking(true)?;
829        Ok(Self { inner })
830    }
831
832    // -- Socket options --
833
834    /// Set SO_REUSEADDR.
835    pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> {
836        self.inner.set_reuse_address(reuseaddr)
837    }
838
839    /// Get SO_REUSEADDR.
840    pub fn reuseaddr(&self) -> io::Result<bool> {
841        self.inner.reuse_address()
842    }
843
844    /// Set SO_REUSEPORT (Unix only).
845    #[cfg(unix)]
846    pub fn set_reuseport(&self, reuseport: bool) -> io::Result<()> {
847        self.inner.set_reuse_port(reuseport)
848    }
849
850    /// Get SO_REUSEPORT (Unix only).
851    #[cfg(unix)]
852    pub fn reuseport(&self) -> io::Result<bool> {
853        self.inner.reuse_port()
854    }
855
856    /// Set SO_KEEPALIVE.
857    pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
858        self.inner.set_keepalive(keepalive)
859    }
860
861    /// Get SO_KEEPALIVE.
862    pub fn keepalive(&self) -> io::Result<bool> {
863        self.inner.keepalive()
864    }
865
866    /// Set TCP_NODELAY.
867    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
868        self.inner.set_nodelay(nodelay)
869    }
870
871    /// Get TCP_NODELAY.
872    pub fn nodelay(&self) -> io::Result<bool> {
873        self.inner.nodelay()
874    }
875
876    /// Set SO_LINGER.
877    pub fn set_linger(&self, duration: Option<Duration>) -> io::Result<()> {
878        self.inner.set_linger(duration)
879    }
880
881    /// Get SO_LINGER.
882    pub fn linger(&self) -> io::Result<Option<Duration>> {
883        self.inner.linger()
884    }
885
886    /// Set SO_SNDBUF.
887    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
888        self.inner.set_send_buffer_size(size)
889    }
890
891    /// Get SO_SNDBUF.
892    pub fn send_buffer_size(&self) -> io::Result<usize> {
893        self.inner.send_buffer_size()
894    }
895
896    /// Set SO_RCVBUF.
897    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
898        self.inner.set_recv_buffer_size(size)
899    }
900
901    /// Get SO_RCVBUF.
902    pub fn recv_buffer_size(&self) -> io::Result<usize> {
903        self.inner.recv_buffer_size()
904    }
905
906    /// Set IP_TTL.
907    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
908        self.inner.set_ttl(ttl)
909    }
910
911    /// Get IP_TTL.
912    pub fn ttl(&self) -> io::Result<u32> {
913        self.inner.ttl()
914    }
915
916    // -- Bind, connect, listen --
917
918    /// Bind the socket to `addr`.
919    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
920        self.inner.bind(&addr.into())
921    }
922
923    /// Connect to `addr` and return a [`TcpStream`].
924    ///
925    /// The connection completes asynchronously (non-blocking socket).
926    /// The first read or write will detect when the connection is
927    /// established.
928    ///
929    /// # Panics
930    ///
931    /// Panics if called outside a [`Runtime::block_on`](crate::Runtime::block_on)
932    /// context — fetches the runtime's [`IoHandle`] internally.
933    pub fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
934        // Non-blocking connect returns EINPROGRESS/EALREADY — that's
935        // normal, not an error. Suppress these.
936        match self.inner.connect(&addr.into()) {
937            Ok(()) => {}
938            Err(e)
939                if e.raw_os_error() == Some(libc::EINPROGRESS)
940                    || e.raw_os_error() == Some(libc::EALREADY) => {}
941            Err(e) => return Err(e),
942        }
943        let std_stream: std::net::TcpStream = self.inner.into();
944        let mio_stream = mio::net::TcpStream::from_std(std_stream);
945        Ok(TcpStream::new(mio_stream, IoHandle::current()))
946    }
947
948    /// Start listening with the given backlog and return a [`TcpListener`].
949    ///
950    /// # Panics
951    ///
952    /// Panics if called outside a runtime context.
953    pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
954        self.inner.listen(backlog)?;
955        let std_listener: std::net::TcpListener = self.inner.into();
956        let mio_listener = mio::net::TcpListener::from_std(std_listener);
957        Ok(TcpListener {
958            inner: mio_listener,
959            io: IoHandle::current(),
960            token: None,
961            registered_task: std::ptr::null_mut(),
962        })
963    }
964}
965
966impl std::fmt::Debug for TcpSocket {
967    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
968        f.debug_struct("TcpSocket")
969            .field("fd", &self.inner.as_raw_fd())
970            .finish()
971    }
972}
973
974impl AsFd for TcpSocket {
975    fn as_fd(&self) -> BorrowedFd<'_> {
976        self.inner.as_fd()
977    }
978}
979
980impl AsRawFd for TcpSocket {
981    fn as_raw_fd(&self) -> RawFd {
982        self.inner.as_raw_fd()
983    }
984}
985
986// =============================================================================
987// Tests
988// =============================================================================
989
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Runtime, spawn_boxed};
    use nexus_rt::WorldBuilder;
    use std::cell::Cell;
    use std::rc::Rc;

    /// Round-trips a short payload through a loopback listener/client pair
    /// driven by the runtime, then checks that the client saw the echo.
    #[test]
    #[cfg_attr(miri, ignore)] // Requires real TCP sockets — not miri-compatible.
    fn tcp_echo() {
        let builder = WorldBuilder::new();
        let mut world = builder.build();
        let mut rt = Runtime::new(&mut world);

        let completed = Rc::new(Cell::new(false));
        let completed_in_task = Rc::clone(&completed);

        rt.block_on(async move {
            let listener =
                TcpListener::bind("127.0.0.1:0".parse().unwrap()).expect("bind failed");
            let addr = listener.local_addr().unwrap();

            // Server: accept one connection and echo back a single read.
            spawn_boxed(async move {
                let mut listener = listener;
                let (mut conn, _) = listener.accept().await.unwrap();
                let mut scratch = [0u8; 64];
                let len = conn.read(&mut scratch).await.unwrap();
                conn.write_all(&scratch[..len]).await.unwrap();
            });

            // Client: give the server a head start, then send and verify the echo.
            spawn_boxed(async move {
                crate::context::sleep(std::time::Duration::from_millis(10)).await;
                let mut client = TcpStream::connect(addr).unwrap();
                client.write_all(b"hello").await.unwrap();
                let mut scratch = [0u8; 64];
                let len = client.read(&mut scratch).await.unwrap();
                assert_eq!(&scratch[..len], b"hello");
                completed_in_task.set(true);
            });

            // Keep `block_on` alive long enough for both tasks to finish.
            crate::context::sleep(std::time::Duration::from_millis(500)).await;
        });

        assert!(completed.get(), "echo exchange never completed");
    }

    /// Exercises a few pre-bind option setters/getters on a fresh IPv4 socket.
    #[test]
    #[cfg_attr(miri, ignore)] // Requires real TCP sockets — not miri-compatible.
    fn tcp_socket_builder() {
        const SNDBUF: usize = 65536;

        let sock = TcpSocket::new_v4().unwrap();

        sock.set_reuseaddr(true).unwrap();
        assert!(sock.reuseaddr().unwrap());

        sock.set_nodelay(true).unwrap();
        assert!(sock.nodelay().unwrap());

        sock.set_send_buffer_size(SNDBUF).unwrap();
        // Buffer size may be rounded up by the kernel.
        assert!(sock.send_buffer_size().unwrap() >= SNDBUF);
    }
}