Skip to main content

moduvex_runtime/net/
tcp_stream.rs

1//! Async `TcpStream` — non-blocking bidirectional TCP byte stream.
2//!
3//! Implements [`AsyncRead`] and [`AsyncWrite`] using `libc::read` / `libc::write`.
4//! The underlying fd is registered with the reactor; `readable()` / `writable()`
5//! futures from `IoSource` are used to suspend until the OS signals readiness.
6
7use std::future::Future;
8use std::io;
9use std::net::SocketAddr;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13use crate::platform::sys::{set_nonblocking, Interest};
14use crate::reactor::source::{next_token, IoSource};
15
16use super::{AsyncRead, AsyncWrite};
17
18// ── TcpStream ─────────────────────────────────────────────────────────────────
19
20/// Async TCP stream. Implements `AsyncRead` + `AsyncWrite`.
21pub struct TcpStream {
22    source: IoSource,
23}
24
25impl TcpStream {
26    /// Connect to `addr` asynchronously.
27    ///
28    /// Creates a non-blocking socket and starts a `connect()` call. Returns a
29    /// [`ConnectFuture`] that resolves once the TCP handshake completes.
30    pub fn connect(addr: SocketAddr) -> ConnectFuture {
31        ConnectFuture::new(addr)
32    }
33
34    /// Wrap an already-connected raw file descriptor in a `TcpStream`.
35    ///
36    /// `fd` must be a connected, non-blocking TCP socket.
37    ///
38    /// # Errors
39    /// Returns `Err` if reactor registration fails.
40    pub(crate) fn from_raw_fd(fd: i32) -> io::Result<Self> {
41        // Register for both directions so we can arm either on demand.
42        let source = IoSource::new(fd, next_token(), Interest::READABLE | Interest::WRITABLE)?;
43        Ok(Self { source })
44    }
45
46    /// Return the peer address of the connection.
47    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
48        peer_addr(self.source.raw())
49    }
50
51    /// Return the local address of the connection.
52    pub fn local_addr(&self) -> io::Result<SocketAddr> {
53        local_addr(self.source.raw())
54    }
55}
56
57impl Drop for TcpStream {
58    fn drop(&mut self) {
59        let fd = self.source.raw();
60        // IoSource Drop deregisters from reactor; close the fd here.
61        // SAFETY: we own `fd` exclusively; it is valid until this drop runs.
62        unsafe { libc::close(fd) };
63    }
64}
65
66// ── AsyncRead ─────────────────────────────────────────────────────────────────
67
68impl AsyncRead for TcpStream {
69    fn poll_read(
70        self: Pin<&mut Self>,
71        cx: &mut Context<'_>,
72        buf: &mut [u8],
73    ) -> Poll<io::Result<usize>> {
74        let fd = self.source.raw();
75
76        // Try the read immediately — may already have data in the kernel buffer.
77        // SAFETY: `fd` is a valid non-blocking socket; `buf` is a valid slice.
78        let n = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
79        if n > 0 {
80            return Poll::Ready(Ok(n as usize));
81        }
82        if n == 0 {
83            return Poll::Ready(Ok(0)); // EOF
84        }
85
86        let err = io::Error::last_os_error();
87        if err.kind() != io::ErrorKind::WouldBlock {
88            return Poll::Ready(Err(err));
89        }
90
91        // No data yet — register waker and wait for READABLE event.
92        match Pin::new(&mut self.source.readable()).poll(cx) {
93            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
94            Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
95        }
96    }
97}
98
99// ── AsyncWrite ────────────────────────────────────────────────────────────────
100
101impl AsyncWrite for TcpStream {
102    fn poll_write(
103        self: Pin<&mut Self>,
104        cx: &mut Context<'_>,
105        buf: &[u8],
106    ) -> Poll<io::Result<usize>> {
107        let fd = self.source.raw();
108
109        // SAFETY: `fd` is a valid non-blocking socket; `buf` is a valid slice.
110        let n = unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len()) };
111        if n >= 0 {
112            return Poll::Ready(Ok(n as usize));
113        }
114
115        let err = io::Error::last_os_error();
116        if err.kind() != io::ErrorKind::WouldBlock {
117            return Poll::Ready(Err(err));
118        }
119
120        // Socket send buffer full — wait for WRITABLE event.
121        match Pin::new(&mut self.source.writable()).poll(cx) {
122            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
123            Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
124        }
125    }
126
127    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
128        // TCP sockets are kernel-buffered — flush is a no-op.
129        Poll::Ready(Ok(()))
130    }
131
132    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
133        let fd = self.source.raw();
134        // SAFETY: `fd` is a valid socket; SHUT_WR is a documented constant.
135        let rc = unsafe { libc::shutdown(fd, libc::SHUT_WR) };
136        if rc == -1 {
137            Poll::Ready(Err(io::Error::last_os_error()))
138        } else {
139            Poll::Ready(Ok(()))
140        }
141    }
142}
143
144// ── ConnectFuture ─────────────────────────────────────────────────────────────
145
146/// Future returned by [`TcpStream::connect`].
147///
148/// Phase 1: creates the socket and calls `connect()` (returns EINPROGRESS).
149/// Phase 2: stores waker in reactor registry; on WRITABLE event, checks SO_ERROR.
150pub struct ConnectFuture {
151    state: ConnectState,
152}
153
/// Internal state machine of [`ConnectFuture`].
enum ConnectState {
    /// Not yet started — stores the address for lazy socket creation.
    Init(SocketAddr),
    /// Socket created and registered for `WRITABLE`; `connect()` is in
    /// progress. In this state the future owns `fd` and its reactor
    /// registration under `token`, both released by `ConnectFuture`'s `Drop`
    /// if the future is abandoned.
    Connecting {
        fd: i32,
        token: usize,
        /// Set to `true` once the fd has been registered with the reactor.
        /// NOTE(review): never read anywhere in this file — candidate for
        /// removal together with its single construction site.
        registered: bool,
    },
    /// Finished: the stream or error has been returned; `Drop` releases
    /// nothing in this state.
    Done,
}
168
169impl ConnectFuture {
170    fn new(addr: SocketAddr) -> Self {
171        Self {
172            state: ConnectState::Init(addr),
173        }
174    }
175}
176
177impl Future for ConnectFuture {
178    type Output = io::Result<TcpStream>;
179
180    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
181        loop {
182            match &mut self.state {
183                ConnectState::Init(addr) => {
184                    let addr = *addr;
185                    match start_connect(addr) {
186                        Err(e) => {
187                            self.state = ConnectState::Done;
188                            return Poll::Ready(Err(e));
189                        }
190                        Ok((fd, connected)) => {
191                            if connected {
192                                // Instant connect — wrap fd directly.
193                                self.state = ConnectState::Done;
194                                return Poll::Ready(TcpStream::from_raw_fd(fd));
195                            }
196                            // Register fd for WRITABLE in the reactor so we
197                            // get woken when connect() completes.
198                            let token = next_token();
199                            if let Err(e) = crate::reactor::with_reactor(|r| {
200                                r.register(fd, token, Interest::WRITABLE)
201                            }) {
202                                unsafe { libc::close(fd) };
203                                self.state = ConnectState::Done;
204                                return Poll::Ready(Err(e));
205                            }
206                            self.state = ConnectState::Connecting {
207                                fd,
208                                token,
209                                registered: true,
210                            };
211                            // Fall through to Connecting arm.
212                        }
213                    }
214                }
215
216                ConnectState::Connecting { fd, token, .. } => {
217                    let fd = *fd;
218                    let token = *token;
219
220                    // Store waker so reactor wakes us on WRITABLE.
221                    crate::reactor::with_reactor_mut(|r| {
222                        r.wakers.set_write_waker(token, cx.waker().clone());
223                    });
224
225                    // Check if connect completed (may have raced since last poll).
226                    match get_so_error(fd) {
227                        Err(e) => {
228                            // Clean up reactor registration.
229                            let _ = crate::reactor::with_reactor_mut(|r| {
230                                r.deregister_with_token(fd, token)
231                            });
232                            self.state = ConnectState::Done;
233                            return Poll::Ready(Err(e));
234                        }
235                        Ok(Some(os_err)) => {
236                            let _ = crate::reactor::with_reactor_mut(|r| {
237                                r.deregister_with_token(fd, token)
238                            });
239                            unsafe { libc::close(fd) };
240                            self.state = ConnectState::Done;
241                            return Poll::Ready(Err(io::Error::from_raw_os_error(os_err)));
242                        }
243                        Ok(None) => {
244                            // SO_ERROR == 0 means connected. But we may be
245                            // polled here before the WRITABLE event fires on
246                            // the very first poll (connect still in progress).
247                            // Distinguish by checking if the socket is writable NOW.
248                            if is_writable_now(fd) {
249                                // Connect complete — deregister old token, wrap fd.
250                                let _ = crate::reactor::with_reactor_mut(|r| {
251                                    r.deregister_with_token(fd, token)
252                                });
253                                self.state = ConnectState::Done;
254                                return Poll::Ready(TcpStream::from_raw_fd(fd));
255                            }
256                            // Not writable yet — waker stored above, wait.
257                            return Poll::Pending;
258                        }
259                    }
260                }
261
262                ConnectState::Done => {
263                    return Poll::Ready(Err(io::Error::other(
264                        "ConnectFuture polled after completion",
265                    )));
266                }
267            }
268        }
269    }
270}
271
272impl Drop for ConnectFuture {
273    fn drop(&mut self) {
274        if let ConnectState::Connecting { fd, token, .. } = self.state {
275            // Clean up reactor and close fd if the future is dropped mid-connect.
276            let _ = crate::reactor::with_reactor_mut(|r| r.deregister_with_token(fd, token));
277            // SAFETY: fd is a valid socket we own; future is being dropped.
278            unsafe { libc::close(fd) };
279        }
280    }
281}
282
283/// Non-blocking poll: returns true if `fd` is writable right now.
284///
285/// Uses `select` with a zero timeout to probe write-readiness.
286fn is_writable_now(fd: i32) -> bool {
287    // SAFETY: all libc types are initialized; select is a documented syscall.
288    unsafe {
289        let mut write_set: libc::fd_set = std::mem::zeroed();
290        libc::FD_ZERO(&mut write_set);
291        libc::FD_SET(fd, &mut write_set);
292        let mut tv = libc::timeval {
293            tv_sec: 0,
294            tv_usec: 0,
295        };
296        let n = libc::select(
297            fd + 1,
298            std::ptr::null_mut(),
299            &mut write_set,
300            std::ptr::null_mut(),
301            &mut tv,
302        );
303        n > 0 && libc::FD_ISSET(fd, &write_set)
304    }
305}
306
307// ── Unix helpers ──────────────────────────────────────────────────────────────
308
309/// Create a non-blocking TCP socket and call `connect()`.
310///
311/// Returns `(fd, connected)` where `connected` is `true` if the connection
312/// completed immediately (rare, e.g. loopback).
313fn start_connect(addr: SocketAddr) -> io::Result<(i32, bool)> {
314    let family = match addr {
315        SocketAddr::V4(_) => libc::AF_INET,
316        SocketAddr::V6(_) => libc::AF_INET6,
317    };
318    // SAFETY: documented syscall with valid AF_INET/AF_INET6 constants.
319    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
320    if fd == -1 {
321        return Err(io::Error::last_os_error());
322    }
323    set_nonblocking(fd)?;
324
325    let (sa, sa_len) = socketaddr_to_raw(addr);
326    // SAFETY: `fd` is a valid socket; `sa`/`sa_len` describe a valid sockaddr.
327    let rc = unsafe { libc::connect(fd, sa, sa_len) };
328    // SAFETY: we used Box::into_raw in socketaddr_to_raw; reclaim the Box now.
329    unsafe { reclaim_raw_sockaddr(sa, addr) };
330
331    if rc == 0 {
332        return Ok((fd, true)); // instant connect
333    }
334
335    let err = io::Error::last_os_error();
336    // EINPROGRESS (or EAGAIN on some platforms) means "in progress" — normal.
337    if err.raw_os_error() == Some(libc::EINPROGRESS) {
338        return Ok((fd, false));
339    }
340
341    // Real error — close and propagate.
342    unsafe { libc::close(fd) };
343    Err(err)
344}
345
346/// Read `SO_ERROR` on `fd` to check connect completion status.
347///
348/// Returns `Ok(None)` on success, `Ok(Some(errno))` on connect failure,
349/// `Err(...)` if getsockopt itself fails.
350fn get_so_error(fd: i32) -> io::Result<Option<i32>> {
351    let mut val: libc::c_int = 0;
352    let mut len = std::mem::size_of_val(&val) as libc::socklen_t;
353    // SAFETY: `fd` is a valid socket; `val`/`len` are correctly sized.
354    let rc = unsafe {
355        libc::getsockopt(
356            fd,
357            libc::SOL_SOCKET,
358            libc::SO_ERROR,
359            &mut val as *mut libc::c_int as *mut libc::c_void,
360            &mut len,
361        )
362    };
363    if rc == -1 {
364        return Err(io::Error::last_os_error());
365    }
366    Ok(if val == 0 { None } else { Some(val) })
367}
368
369/// Query the peer address of `fd`.
370fn peer_addr(fd: i32) -> io::Result<SocketAddr> {
371    let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
372    let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
373    // SAFETY: `fd` is a valid connected socket; `addr` is large enough.
374    let rc = unsafe { libc::getpeername(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
375    if rc == -1 {
376        return Err(io::Error::last_os_error());
377    }
378    sockaddr_to_socketaddr(&addr, len)
379}
380
381/// Query the local address of `fd`.
382fn local_addr(fd: i32) -> io::Result<SocketAddr> {
383    let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
384    let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
385    // SAFETY: `fd` is a valid socket; `addr` is large enough.
386    let rc = unsafe { libc::getsockname(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
387    if rc == -1 {
388        return Err(io::Error::last_os_error());
389    }
390    sockaddr_to_socketaddr(&addr, len)
391}
392
393/// Convert `SocketAddr` to a heap-allocated raw sockaddr pointer.
394///
395/// Caller must call `reclaim_raw_sockaddr` after the syscall to free memory.
396fn socketaddr_to_raw(addr: SocketAddr) -> (*const libc::sockaddr, libc::socklen_t) {
397    match addr {
398        SocketAddr::V4(v4) => {
399            let octets = v4.ip().octets();
400            // SAFETY: zeroed() gives a valid bit pattern; we fill every field.
401            let mut sin: libc::sockaddr_in = unsafe { std::mem::zeroed() };
402            sin.sin_family = libc::AF_INET as libc::sa_family_t;
403            sin.sin_port = v4.port().to_be();
404            sin.sin_addr = libc::in_addr {
405                s_addr: u32::from_be_bytes(octets).to_be(),
406            };
407            let boxed = Box::new(sin);
408            let ptr = Box::into_raw(boxed) as *const libc::sockaddr;
409            let len = std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t;
410            (ptr, len)
411        }
412        SocketAddr::V6(v6) => {
413            // SAFETY: zeroed() gives a valid bit pattern; we fill every field.
414            let mut sin6: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
415            sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
416            sin6.sin6_port = v6.port().to_be();
417            sin6.sin6_flowinfo = v6.flowinfo();
418            sin6.sin6_addr = libc::in6_addr {
419                s6_addr: v6.ip().octets(),
420            };
421            sin6.sin6_scope_id = v6.scope_id();
422            let boxed = Box::new(sin6);
423            let ptr = Box::into_raw(boxed) as *const libc::sockaddr;
424            let len = std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t;
425            (ptr, len)
426        }
427    }
428}
429
430/// # Safety
431/// `ptr` must have been produced by `socketaddr_to_raw` with matching `addr`.
432unsafe fn reclaim_raw_sockaddr(ptr: *const libc::sockaddr, addr: SocketAddr) {
433    match addr {
434        SocketAddr::V4(_) => drop(Box::from_raw(ptr as *mut libc::sockaddr_in)),
435        SocketAddr::V6(_) => drop(Box::from_raw(ptr as *mut libc::sockaddr_in6)),
436    }
437}
438
439/// Convert a kernel-filled `sockaddr_in6` buffer (may be `sockaddr_in`) to `SocketAddr`.
440fn sockaddr_to_socketaddr(
441    addr: &libc::sockaddr_in6,
442    len: libc::socklen_t,
443) -> io::Result<SocketAddr> {
444    let family = addr.sin6_family as libc::c_int;
445    match family {
446        libc::AF_INET if len >= std::mem::size_of::<libc::sockaddr_in>() as u32 => {
447            // SAFETY: kernel wrote AF_INET data of correct size.
448            let v4: &libc::sockaddr_in =
449                unsafe { &*(addr as *const _ as *const libc::sockaddr_in) };
450            let ip = std::net::Ipv4Addr::from(u32::from_be(v4.sin_addr.s_addr));
451            let port = u16::from_be(v4.sin_port);
452            Ok(SocketAddr::V4(std::net::SocketAddrV4::new(ip, port)))
453        }
454        libc::AF_INET6 if len >= std::mem::size_of::<libc::sockaddr_in6>() as u32 => {
455            let ip = std::net::Ipv6Addr::from(addr.sin6_addr.s6_addr);
456            let port = u16::from_be(addr.sin6_port);
457            Ok(SocketAddr::V6(std::net::SocketAddrV6::new(
458                ip,
459                port,
460                addr.sin6_flowinfo,
461                addr.sin6_scope_id,
462            )))
463        }
464        _ => Err(io::Error::new(
465            io::ErrorKind::InvalidData,
466            format!("unsupported address family: {family}"),
467        )),
468    }
469}
470
471// ── Tests ─────────────────────────────────────────────────────────────────────
472
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::block_on_with_spawn;
    use crate::net::TcpListener;
    use std::future::poll_fn;

    /// Poll-based async read: keeps reading until `buf` is full or EOF is hit.
    async fn read_exact(stream: &mut TcpStream, buf: &mut [u8]) {
        let mut filled = 0;
        while filled < buf.len() {
            let n = poll_fn(|cx| Pin::new(&mut *stream).poll_read(cx, &mut buf[filled..]))
                .await
                .expect("read_exact: io error");
            if n == 0 {
                break; // EOF
            }
            filled += n;
        }
    }

    /// Poll-based async write: keeps writing until all of `buf` is sent.
    ///
    /// Panics if a write accepts zero bytes of a non-empty buffer, which would
    /// otherwise loop forever.
    async fn write_all(stream: &mut TcpStream, buf: &[u8]) {
        let mut sent = 0;
        while sent < buf.len() {
            let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, &buf[sent..]))
                .await
                .expect("write_all: io error");
            assert!(n > 0, "write_all: write accepted 0 bytes");
            sent += n;
        }
    }

    #[test]
    #[ignore = "requires IoSource ReadableFuture to resolve Ready — net integration pending"]
    fn tcp_connect_and_echo() {
        block_on_with_spawn(async {
            // Bind a listener on a random port.
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();

            // Spawn a server task that accepts one connection and reads 5 bytes.
            let server = crate::spawn(async move {
                let (mut stream, _peer) = listener.accept().await.unwrap();
                let mut buf = [0u8; 5];
                read_exact(&mut stream, &mut buf).await;
                buf
            });

            // Connect as client and send "hello".
            let mut client = TcpStream::connect(addr).await.unwrap();
            write_all(&mut client, b"hello").await;

            // Shut down the write side so the server's read returns EOF after 5 bytes.
            poll_fn(|cx| Pin::new(&mut client).poll_shutdown(cx))
                .await
                .expect("shutdown failed");

            let received = server.await.unwrap();
            assert_eq!(&received, b"hello");
        });
    }
}