Skip to main content

moduvex_runtime/net/
tcp_stream.rs

1//! Async `TcpStream` — non-blocking bidirectional TCP byte stream.
2//!
3//! Implements [`AsyncRead`] and [`AsyncWrite`] using `libc::read` / `libc::write`.
4//! The underlying fd is registered with the reactor; `readable()` / `writable()`
5//! futures from `IoSource` are used to suspend until the OS signals readiness.
6
7use std::future::Future;
8use std::io;
9use std::net::SocketAddr;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13use crate::platform::sys::{set_nonblocking, Interest};
14use crate::reactor::source::{next_token, IoSource};
15
16use super::sockaddr::{reclaim_raw_sockaddr, sockaddr_to_socketaddr, socketaddr_to_raw};
17use super::{AsyncRead, AsyncWrite};
18
19// ── TcpStream ─────────────────────────────────────────────────────────────────
20
21/// Async TCP stream. Implements `AsyncRead` + `AsyncWrite`.
22pub struct TcpStream {
23    source: IoSource,
24}
25
26impl TcpStream {
27    /// Connect to `addr` asynchronously.
28    ///
29    /// Creates a non-blocking socket and starts a `connect()` call. Returns a
30    /// [`ConnectFuture`] that resolves once the TCP handshake completes.
31    pub fn connect(addr: SocketAddr) -> ConnectFuture {
32        ConnectFuture::new(addr)
33    }
34
35    /// Wrap an already-connected raw file descriptor in a `TcpStream`.
36    ///
37    /// `fd` must be a connected, non-blocking TCP socket.
38    ///
39    /// # Errors
40    /// Returns `Err` if reactor registration fails.
41    pub(crate) fn from_raw_fd(fd: i32) -> io::Result<Self> {
42        // Register for both directions so we can arm either on demand.
43        let source = IoSource::new(fd, next_token(), Interest::READABLE | Interest::WRITABLE)?;
44        Ok(Self { source })
45    }
46
47    /// Return the peer address of the connection.
48    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
49        peer_addr(self.source.raw())
50    }
51
52    /// Return the local address of the connection.
53    pub fn local_addr(&self) -> io::Result<SocketAddr> {
54        local_addr(self.source.raw())
55    }
56}
57
58impl Drop for TcpStream {
59    fn drop(&mut self) {
60        let fd = self.source.raw();
61        // IoSource Drop deregisters from reactor; close the fd here.
62        // SAFETY: we own `fd` exclusively; it is valid until this drop runs.
63        unsafe { libc::close(fd) };
64    }
65}
66
67#[cfg(unix)]
68impl std::os::unix::io::AsRawFd for TcpStream {
69    /// Return the underlying file descriptor.
70    ///
71    /// The fd remains valid for the lifetime of this `TcpStream`. Callers
72    /// must not close or duplicate it without coordinating with the owner.
73    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
74        self.source.raw()
75    }
76}
77
78// ── AsyncRead ─────────────────────────────────────────────────────────────────
79
80impl AsyncRead for TcpStream {
81    fn poll_read(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84        buf: &mut [u8],
85    ) -> Poll<io::Result<usize>> {
86        let fd = self.source.raw();
87
88        // Try the read immediately — may already have data in the kernel buffer.
89        // SAFETY: `fd` is a valid non-blocking socket; `buf` is a valid slice.
90        let n = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
91        if n > 0 {
92            return Poll::Ready(Ok(n as usize));
93        }
94        if n == 0 {
95            return Poll::Ready(Ok(0)); // EOF
96        }
97
98        let err = io::Error::last_os_error();
99        if err.kind() != io::ErrorKind::WouldBlock {
100            return Poll::Ready(Err(err));
101        }
102
103        // No data yet — register waker and wait for READABLE event.
104        match Pin::new(&mut self.source.readable()).poll(cx) {
105            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
106            Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
107        }
108    }
109}
110
111// ── AsyncWrite ────────────────────────────────────────────────────────────────
112
113impl AsyncWrite for TcpStream {
114    fn poll_write(
115        self: Pin<&mut Self>,
116        cx: &mut Context<'_>,
117        buf: &[u8],
118    ) -> Poll<io::Result<usize>> {
119        let fd = self.source.raw();
120
121        // SAFETY: `fd` is a valid non-blocking socket; `buf` is a valid slice.
122        let n = unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len()) };
123        if n >= 0 {
124            return Poll::Ready(Ok(n as usize));
125        }
126
127        let err = io::Error::last_os_error();
128        if err.kind() != io::ErrorKind::WouldBlock {
129            return Poll::Ready(Err(err));
130        }
131
132        // Socket send buffer full — wait for WRITABLE event.
133        match Pin::new(&mut self.source.writable()).poll(cx) {
134            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
135            Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
136        }
137    }
138
139    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
140        // TCP sockets are kernel-buffered — flush is a no-op.
141        Poll::Ready(Ok(()))
142    }
143
144    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145        let fd = self.source.raw();
146        // SAFETY: `fd` is a valid socket; SHUT_WR is a documented constant.
147        let rc = unsafe { libc::shutdown(fd, libc::SHUT_WR) };
148        if rc == -1 {
149            Poll::Ready(Err(io::Error::last_os_error()))
150        } else {
151            Poll::Ready(Ok(()))
152        }
153    }
154}
155
156// ── ConnectFuture ─────────────────────────────────────────────────────────────
157
/// Future returned by [`TcpStream::connect`].
///
/// Phase 1 (first poll): creates the socket and calls `connect()`, which
/// usually returns EINPROGRESS. Unless it connected immediately, the fd is then
/// registered with the reactor for WRITABLE.
///
/// Phase 2 (later polls): stores the task's waker for that registration and
/// inspects `SO_ERROR` and write-readiness to decide whether the connect has
/// finished.
///
/// Dropping the future mid-connect deregisters and closes the socket.
pub struct ConnectFuture {
    state: ConnectState,
}

/// State machine driven by `ConnectFuture::poll`.
enum ConnectState {
    /// Not yet started — stores the address for lazy socket creation.
    Init(SocketAddr),
    /// Socket created, connect() in progress; waiting for WRITABLE.
    Connecting {
        /// The non-blocking socket; owned by the future while in this state.
        fd: i32,
        /// Reactor token the fd is registered under. Used to store the write
        /// waker and to deregister.
        token: usize,
        /// Set to `true` when the fd is registered. Nothing in this file reads
        /// it (it is always matched with `..`), so it currently carries no
        /// information.
        registered: bool,
    },
    /// Done (stream returned or error returned).
    Done,
}
180
181impl ConnectFuture {
182    fn new(addr: SocketAddr) -> Self {
183        Self {
184            state: ConnectState::Init(addr),
185        }
186    }
187}
188
189impl Future for ConnectFuture {
190    type Output = io::Result<TcpStream>;
191
192    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
193        loop {
194            match &mut self.state {
195                ConnectState::Init(addr) => {
196                    let addr = *addr;
197                    match start_connect(addr) {
198                        Err(e) => {
199                            self.state = ConnectState::Done;
200                            return Poll::Ready(Err(e));
201                        }
202                        Ok((fd, connected)) => {
203                            if connected {
204                                // Instant connect — wrap fd directly.
205                                self.state = ConnectState::Done;
206                                return Poll::Ready(TcpStream::from_raw_fd(fd));
207                            }
208                            // Register fd for WRITABLE in the reactor so we
209                            // get woken when connect() completes.
210                            let token = next_token();
211                            if let Err(e) = crate::reactor::with_reactor(|r| {
212                                r.register(fd, token, Interest::WRITABLE)
213                            }) {
214                                unsafe { libc::close(fd) };
215                                self.state = ConnectState::Done;
216                                return Poll::Ready(Err(e));
217                            }
218                            self.state = ConnectState::Connecting {
219                                fd,
220                                token,
221                                registered: true,
222                            };
223                            // Fall through to Connecting arm.
224                        }
225                    }
226                }
227
228                ConnectState::Connecting { fd, token, .. } => {
229                    let fd = *fd;
230                    let token = *token;
231
232                    // Store waker so reactor wakes us on WRITABLE.
233                    crate::reactor::with_reactor_mut(|r| {
234                        r.wakers.set_write_waker(token, cx.waker().clone());
235                    });
236
237                    // Check if connect completed (may have raced since last poll).
238                    match get_so_error(fd) {
239                        Err(e) => {
240                            // Clean up reactor registration.
241                            let _ = crate::reactor::with_reactor_mut(|r| {
242                                r.deregister_with_token(fd, token)
243                            });
244                            self.state = ConnectState::Done;
245                            return Poll::Ready(Err(e));
246                        }
247                        Ok(Some(os_err)) => {
248                            let _ = crate::reactor::with_reactor_mut(|r| {
249                                r.deregister_with_token(fd, token)
250                            });
251                            unsafe { libc::close(fd) };
252                            self.state = ConnectState::Done;
253                            return Poll::Ready(Err(io::Error::from_raw_os_error(os_err)));
254                        }
255                        Ok(None) => {
256                            // SO_ERROR == 0 means connected. But we may be
257                            // polled here before the WRITABLE event fires on
258                            // the very first poll (connect still in progress).
259                            // Distinguish by checking if the socket is writable NOW.
260                            if is_writable_now(fd) {
261                                // Connect complete — deregister old token, wrap fd.
262                                let _ = crate::reactor::with_reactor_mut(|r| {
263                                    r.deregister_with_token(fd, token)
264                                });
265                                self.state = ConnectState::Done;
266                                return Poll::Ready(TcpStream::from_raw_fd(fd));
267                            }
268                            // Not writable yet — waker stored above, wait.
269                            return Poll::Pending;
270                        }
271                    }
272                }
273
274                ConnectState::Done => {
275                    return Poll::Ready(Err(io::Error::other(
276                        "ConnectFuture polled after completion",
277                    )));
278                }
279            }
280        }
281    }
282}
283
284impl Drop for ConnectFuture {
285    fn drop(&mut self) {
286        if let ConnectState::Connecting { fd, token, .. } = self.state {
287            // Clean up reactor and close fd if the future is dropped mid-connect.
288            let _ = crate::reactor::with_reactor_mut(|r| r.deregister_with_token(fd, token));
289            // SAFETY: fd is a valid socket we own; future is being dropped.
290            unsafe { libc::close(fd) };
291        }
292    }
293}
294
295/// Non-blocking poll: returns true if `fd` is writable right now.
296///
297/// Uses `poll(2)` with a zero timeout to probe write-readiness.
298/// Unlike `select(2)`, this has no FD_SETSIZE limit.
299fn is_writable_now(fd: i32) -> bool {
300    // SAFETY: pollfd is a plain C struct; poll is a documented POSIX syscall.
301    unsafe {
302        let mut pfd = libc::pollfd {
303            fd,
304            events: libc::POLLOUT,
305            revents: 0,
306        };
307        let n = libc::poll(&mut pfd, 1, 0);
308        n > 0 && (pfd.revents & libc::POLLOUT) != 0
309    }
310}
311
312// ── Unix helpers ──────────────────────────────────────────────────────────────
313
314/// Create a non-blocking TCP socket and call `connect()`.
315///
316/// Returns `(fd, connected)` where `connected` is `true` if the connection
317/// completed immediately (rare, e.g. loopback).
318fn start_connect(addr: SocketAddr) -> io::Result<(i32, bool)> {
319    let family = match addr {
320        SocketAddr::V4(_) => libc::AF_INET,
321        SocketAddr::V6(_) => libc::AF_INET6,
322    };
323    // SAFETY: documented syscall with valid AF_INET/AF_INET6 constants.
324    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
325    if fd == -1 {
326        return Err(io::Error::last_os_error());
327    }
328    set_nonblocking(fd)?;
329
330    let (sa, sa_len) = socketaddr_to_raw(addr);
331    // SAFETY: `fd` is a valid socket; `sa`/`sa_len` describe a valid sockaddr.
332    let rc = unsafe { libc::connect(fd, sa, sa_len) };
333    // SAFETY: we used Box::into_raw in socketaddr_to_raw; reclaim the Box now.
334    unsafe { reclaim_raw_sockaddr(sa, addr) };
335
336    if rc == 0 {
337        return Ok((fd, true)); // instant connect
338    }
339
340    let err = io::Error::last_os_error();
341    // EINPROGRESS (or EAGAIN on some platforms) means "in progress" — normal.
342    if err.raw_os_error() == Some(libc::EINPROGRESS) {
343        return Ok((fd, false));
344    }
345
346    // Real error — close and propagate.
347    unsafe { libc::close(fd) };
348    Err(err)
349}
350
351/// Read `SO_ERROR` on `fd` to check connect completion status.
352///
353/// Returns `Ok(None)` on success, `Ok(Some(errno))` on connect failure,
354/// `Err(...)` if getsockopt itself fails.
355fn get_so_error(fd: i32) -> io::Result<Option<i32>> {
356    let mut val: libc::c_int = 0;
357    let mut len = std::mem::size_of_val(&val) as libc::socklen_t;
358    // SAFETY: `fd` is a valid socket; `val`/`len` are correctly sized.
359    let rc = unsafe {
360        libc::getsockopt(
361            fd,
362            libc::SOL_SOCKET,
363            libc::SO_ERROR,
364            &mut val as *mut libc::c_int as *mut libc::c_void,
365            &mut len,
366        )
367    };
368    if rc == -1 {
369        return Err(io::Error::last_os_error());
370    }
371    Ok(if val == 0 { None } else { Some(val) })
372}
373
374/// Query the peer address of `fd`.
375fn peer_addr(fd: i32) -> io::Result<SocketAddr> {
376    let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
377    let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
378    // SAFETY: `fd` is a valid connected socket; `addr` is large enough.
379    let rc = unsafe { libc::getpeername(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
380    if rc == -1 {
381        return Err(io::Error::last_os_error());
382    }
383    sockaddr_to_socketaddr(&addr, len)
384}
385
386/// Query the local address of `fd`.
387fn local_addr(fd: i32) -> io::Result<SocketAddr> {
388    let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
389    let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
390    // SAFETY: `fd` is a valid socket; `addr` is large enough.
391    let rc = unsafe { libc::getsockname(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
392    if rc == -1 {
393        return Err(io::Error::last_os_error());
394    }
395    sockaddr_to_socketaddr(&addr, len)
396}
397
398// ── Tests ─────────────────────────────────────────────────────────────────────
399
#[cfg(test)]
mod tests {
    // End-to-end tests over real loopback sockets. Every test runs inside
    // `block_on_with_spawn`, so server and client halves can be driven as
    // separate tasks via `crate::spawn`.
    use super::*;
    use crate::executor::block_on_with_spawn;
    use crate::net::TcpListener;

    /// Poll-based async read: keeps polling until `buf` is full or EOF is hit.
    ///
    /// On EOF the unfilled tail of `buf` is left untouched, so callers detect
    /// a short read only through their own content assertions. Panics on any
    /// I/O error.
    async fn read_exact(stream: &mut TcpStream, buf: &mut [u8]) {
        use std::future::poll_fn;
        let mut filled = 0;
        while filled < buf.len() {
            let n = poll_fn(|cx| Pin::new(&mut *stream).poll_read(cx, &mut buf[filled..]))
                .await
                .expect("read_exact: io error");
            if n == 0 {
                break;
            } // EOF
            filled += n;
        }
    }

    /// Poll-based async write: keeps polling until all bytes are sent.
    ///
    /// Panics on any I/O error. NOTE(review): a `poll_write` that returned
    /// `Ok(0)` for a non-empty slice would make this loop spin forever.
    async fn write_all(stream: &mut TcpStream, buf: &[u8]) {
        use std::future::poll_fn;
        let mut sent = 0;
        while sent < buf.len() {
            let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, &buf[sent..]))
                .await
                .expect("write_all: io error");
            sent += n;
        }
    }

    // Despite the name this is one-way: the client sends "hello" and
    // half-closes; the server task returns what it read (nothing is echoed).
    #[test]
    fn tcp_connect_and_echo() {
        block_on_with_spawn(async {
            // Bind a listener on a random port.
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();

            // Spawn a server task that accepts one connection and reads 5 bytes.
            let server = crate::spawn(async move {
                let (mut stream, _peer) = listener.accept().await.unwrap();
                let mut buf = [0u8; 5];
                read_exact(&mut stream, &mut buf).await;
                buf
            });

            // Connect as client and send "hello".
            let mut client = TcpStream::connect(addr).await.unwrap();
            write_all(&mut client, b"hello").await;

            // Shutdown write side so server's read returns EOF after 5 bytes.
            use std::future::poll_fn;
            poll_fn(|cx| Pin::new(&mut client).poll_shutdown(cx))
                .await
                .expect("shutdown failed");

            let received = server.await.unwrap();
            assert_eq!(&received, b"hello");
        });
    }

    // ── Additional TCP stream tests ────────────────────────────────────────

    // Client runs in a spawned task; the main task accepts and reads.
    #[test]
    fn tcp_stream_connect_and_write_read() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move {
                let mut client = TcpStream::connect(addr).await.unwrap();
                write_all(&mut client, b"hello").await;
            });
            let (mut server, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 5];
            read_exact(&mut server, &mut buf).await;
            assert_eq!(&buf, b"hello");
            jh.await.unwrap();
        });
    }

    // Full round trip: the server task reads 4 bytes and writes them back.
    #[test]
    fn tcp_stream_echo_roundtrip() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            // Server echoes back
            let jh = crate::spawn(async move {
                let (mut conn, _) = listener.accept().await.unwrap();
                let mut buf = [0u8; 4];
                read_exact(&mut conn, &mut buf).await;
                write_all(&mut conn, &buf).await;
            });
            let mut client = TcpStream::connect(addr).await.unwrap();
            write_all(&mut client, b"ping").await;
            let mut buf = [0u8; 4];
            read_exact(&mut client, &mut buf).await;
            assert_eq!(&buf, b"ping");
            jh.await.unwrap();
        });
    }

    // NOTE(review): assumes no service is bound to 127.0.0.1:1 on the test host.
    #[test]
    fn tcp_stream_connect_refused_returns_err() {
        // Nothing listening on port 1 — connection should be refused.
        let result = block_on_with_spawn(async {
            TcpStream::connect("127.0.0.1:1".parse().unwrap()).await
        });
        assert!(result.is_err());
    }

    // The join handle is dropped without being awaited; whether that detaches
    // or cancels the accept task depends on the executor (not visible here).
    #[test]
    fn tcp_stream_local_and_peer_addr() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let server_addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move { listener.accept().await.unwrap() });
            let client = TcpStream::connect(server_addr).await.unwrap();
            assert_eq!(client.peer_addr().unwrap(), server_addr);
            assert_eq!(client.local_addr().unwrap().ip().to_string(), "127.0.0.1");
            drop(jh);
        });
    }

    // 4 KiB payload, so the reader may need several reads to collect it all.
    #[test]
    fn tcp_stream_large_payload() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let payload_size = 4096usize;
            let jh = crate::spawn(async move {
                let mut client = TcpStream::connect(addr).await.unwrap();
                let data = vec![0xABu8; payload_size];
                write_all(&mut client, &data).await;
            });
            let (mut server, _) = listener.accept().await.unwrap();
            let mut buf = vec![0u8; payload_size];
            read_exact(&mut server, &mut buf).await;
            assert!(buf.iter().all(|&b| b == 0xAB));
            jh.await.unwrap();
        });
    }

    // Reuses one listener for three back-to-back connections, each sending its index.
    #[test]
    fn tcp_stream_multiple_connections_sequential() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            for i in 0u8..3 {
                let a = addr;
                let jh = crate::spawn(async move {
                    let mut client = TcpStream::connect(a).await.unwrap();
                    write_all(&mut client, &[i]).await;
                });
                let (mut server, _) = listener.accept().await.unwrap();
                let mut buf = [0u8; 1];
                read_exact(&mut server, &mut buf).await;
                assert_eq!(buf[0], i);
                jh.await.unwrap();
            }
        });
    }

    // Only checks that shutdown(SHUT_WR) succeeds on a freshly connected stream.
    #[test]
    fn tcp_stream_shutdown_write_half() {
        use std::future::poll_fn;
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move {
                let mut client = TcpStream::connect(addr).await.unwrap();
                // Shutdown write half
                poll_fn(|cx| Pin::new(&mut client).poll_shutdown(cx))
                    .await
                    .unwrap();
            });
            let (_server, _) = listener.accept().await.unwrap();
            jh.await.unwrap();
        });
    }
}