1use std::future::Future;
8use std::io;
9use std::net::SocketAddr;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13use crate::platform::sys::{set_nonblocking, Interest};
14use crate::reactor::source::{next_token, IoSource};
15
16use super::{AsyncRead, AsyncWrite};
17
/// A TCP connection whose reads and writes are exposed through
/// `AsyncRead` / `AsyncWrite`.
///
/// The stream owns its socket descriptor: dropping it closes the fd.
pub struct TcpStream {
    /// Reactor registration for the socket; `raw()` yields the underlying fd.
    source: IoSource,
}
24
25impl TcpStream {
26 pub fn connect(addr: SocketAddr) -> ConnectFuture {
31 ConnectFuture::new(addr)
32 }
33
34 pub(crate) fn from_raw_fd(fd: i32) -> io::Result<Self> {
41 let source = IoSource::new(fd, next_token(), Interest::READABLE | Interest::WRITABLE)?;
43 Ok(Self { source })
44 }
45
46 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
48 peer_addr(self.source.raw())
49 }
50
51 pub fn local_addr(&self) -> io::Result<SocketAddr> {
53 local_addr(self.source.raw())
54 }
55}
56
57impl Drop for TcpStream {
58 fn drop(&mut self) {
59 let fd = self.source.raw();
60 unsafe { libc::close(fd) };
63 }
64}
65
66impl AsyncRead for TcpStream {
69 fn poll_read(
70 self: Pin<&mut Self>,
71 cx: &mut Context<'_>,
72 buf: &mut [u8],
73 ) -> Poll<io::Result<usize>> {
74 let fd = self.source.raw();
75
76 let n = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
79 if n > 0 {
80 return Poll::Ready(Ok(n as usize));
81 }
82 if n == 0 {
83 return Poll::Ready(Ok(0)); }
85
86 let err = io::Error::last_os_error();
87 if err.kind() != io::ErrorKind::WouldBlock {
88 return Poll::Ready(Err(err));
89 }
90
91 match Pin::new(&mut self.source.readable()).poll(cx) {
93 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
94 Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
95 }
96 }
97}
98
99impl AsyncWrite for TcpStream {
102 fn poll_write(
103 self: Pin<&mut Self>,
104 cx: &mut Context<'_>,
105 buf: &[u8],
106 ) -> Poll<io::Result<usize>> {
107 let fd = self.source.raw();
108
109 let n = unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len()) };
111 if n >= 0 {
112 return Poll::Ready(Ok(n as usize));
113 }
114
115 let err = io::Error::last_os_error();
116 if err.kind() != io::ErrorKind::WouldBlock {
117 return Poll::Ready(Err(err));
118 }
119
120 match Pin::new(&mut self.source.writable()).poll(cx) {
122 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
123 Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
124 }
125 }
126
127 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
128 Poll::Ready(Ok(()))
130 }
131
132 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
133 let fd = self.source.raw();
134 let rc = unsafe { libc::shutdown(fd, libc::SHUT_WR) };
136 if rc == -1 {
137 Poll::Ready(Err(io::Error::last_os_error()))
138 } else {
139 Poll::Ready(Ok(()))
140 }
141 }
142}
143
/// Future returned by [`TcpStream::connect`].
///
/// Resolves once the non-blocking connect completes or fails.
pub struct ConnectFuture {
    /// Current step of the connection state machine.
    state: ConnectState,
}
153
/// Progress of a [`ConnectFuture`].
enum ConnectState {
    /// Not started yet; holds the destination address.
    Init(SocketAddr),
    /// `connect(2)` is in progress (`EINPROGRESS`). The socket is registered
    /// with the reactor for write readiness.
    Connecting {
        /// Socket descriptor. While in this state it is closed by
        /// `ConnectFuture`'s `Drop` if the future is abandoned.
        fd: i32,
        /// Reactor token the fd was registered under.
        token: usize,
        /// NOTE(review): always set to `true` and never read in this file —
        /// candidate for removal.
        registered: bool,
    },
    /// The future has produced its output; polling again yields an error.
    Done,
}
168
169impl ConnectFuture {
170 fn new(addr: SocketAddr) -> Self {
171 Self {
172 state: ConnectState::Init(addr),
173 }
174 }
175}
176
177impl Future for ConnectFuture {
178 type Output = io::Result<TcpStream>;
179
180 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
181 loop {
182 match &mut self.state {
183 ConnectState::Init(addr) => {
184 let addr = *addr;
185 match start_connect(addr) {
186 Err(e) => {
187 self.state = ConnectState::Done;
188 return Poll::Ready(Err(e));
189 }
190 Ok((fd, connected)) => {
191 if connected {
192 self.state = ConnectState::Done;
194 return Poll::Ready(TcpStream::from_raw_fd(fd));
195 }
196 let token = next_token();
199 if let Err(e) = crate::reactor::with_reactor(|r| {
200 r.register(fd, token, Interest::WRITABLE)
201 }) {
202 unsafe { libc::close(fd) };
203 self.state = ConnectState::Done;
204 return Poll::Ready(Err(e));
205 }
206 self.state = ConnectState::Connecting {
207 fd,
208 token,
209 registered: true,
210 };
211 }
213 }
214 }
215
216 ConnectState::Connecting { fd, token, .. } => {
217 let fd = *fd;
218 let token = *token;
219
220 crate::reactor::with_reactor_mut(|r| {
222 r.wakers.set_write_waker(token, cx.waker().clone());
223 });
224
225 match get_so_error(fd) {
227 Err(e) => {
228 let _ = crate::reactor::with_reactor_mut(|r| {
230 r.deregister_with_token(fd, token)
231 });
232 self.state = ConnectState::Done;
233 return Poll::Ready(Err(e));
234 }
235 Ok(Some(os_err)) => {
236 let _ = crate::reactor::with_reactor_mut(|r| {
237 r.deregister_with_token(fd, token)
238 });
239 unsafe { libc::close(fd) };
240 self.state = ConnectState::Done;
241 return Poll::Ready(Err(io::Error::from_raw_os_error(os_err)));
242 }
243 Ok(None) => {
244 if is_writable_now(fd) {
249 let _ = crate::reactor::with_reactor_mut(|r| {
251 r.deregister_with_token(fd, token)
252 });
253 self.state = ConnectState::Done;
254 return Poll::Ready(TcpStream::from_raw_fd(fd));
255 }
256 return Poll::Pending;
258 }
259 }
260 }
261
262 ConnectState::Done => {
263 return Poll::Ready(Err(io::Error::other(
264 "ConnectFuture polled after completion",
265 )));
266 }
267 }
268 }
269 }
270}
271
272impl Drop for ConnectFuture {
273 fn drop(&mut self) {
274 if let ConnectState::Connecting { fd, token, .. } = self.state {
275 let _ = crate::reactor::with_reactor_mut(|r| r.deregister_with_token(fd, token));
277 unsafe { libc::close(fd) };
279 }
280 }
281}
282
283fn is_writable_now(fd: i32) -> bool {
287 unsafe {
289 let mut write_set: libc::fd_set = std::mem::zeroed();
290 libc::FD_ZERO(&mut write_set);
291 libc::FD_SET(fd, &mut write_set);
292 let mut tv = libc::timeval {
293 tv_sec: 0,
294 tv_usec: 0,
295 };
296 let n = libc::select(
297 fd + 1,
298 std::ptr::null_mut(),
299 &mut write_set,
300 std::ptr::null_mut(),
301 &mut tv,
302 );
303 n > 0 && libc::FD_ISSET(fd, &write_set)
304 }
305}
306
307fn start_connect(addr: SocketAddr) -> io::Result<(i32, bool)> {
314 let family = match addr {
315 SocketAddr::V4(_) => libc::AF_INET,
316 SocketAddr::V6(_) => libc::AF_INET6,
317 };
318 let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
320 if fd == -1 {
321 return Err(io::Error::last_os_error());
322 }
323 set_nonblocking(fd)?;
324
325 let (sa, sa_len) = socketaddr_to_raw(addr);
326 let rc = unsafe { libc::connect(fd, sa, sa_len) };
328 unsafe { reclaim_raw_sockaddr(sa, addr) };
330
331 if rc == 0 {
332 return Ok((fd, true)); }
334
335 let err = io::Error::last_os_error();
336 if err.raw_os_error() == Some(libc::EINPROGRESS) {
338 return Ok((fd, false));
339 }
340
341 unsafe { libc::close(fd) };
343 Err(err)
344}
345
346fn get_so_error(fd: i32) -> io::Result<Option<i32>> {
351 let mut val: libc::c_int = 0;
352 let mut len = std::mem::size_of_val(&val) as libc::socklen_t;
353 let rc = unsafe {
355 libc::getsockopt(
356 fd,
357 libc::SOL_SOCKET,
358 libc::SO_ERROR,
359 &mut val as *mut libc::c_int as *mut libc::c_void,
360 &mut len,
361 )
362 };
363 if rc == -1 {
364 return Err(io::Error::last_os_error());
365 }
366 Ok(if val == 0 { None } else { Some(val) })
367}
368
369fn peer_addr(fd: i32) -> io::Result<SocketAddr> {
371 let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
372 let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
373 let rc = unsafe { libc::getpeername(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
375 if rc == -1 {
376 return Err(io::Error::last_os_error());
377 }
378 sockaddr_to_socketaddr(&addr, len)
379}
380
381fn local_addr(fd: i32) -> io::Result<SocketAddr> {
383 let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
384 let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
385 let rc = unsafe { libc::getsockname(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
387 if rc == -1 {
388 return Err(io::Error::last_os_error());
389 }
390 sockaddr_to_socketaddr(&addr, len)
391}
392
393fn socketaddr_to_raw(addr: SocketAddr) -> (*const libc::sockaddr, libc::socklen_t) {
397 match addr {
398 SocketAddr::V4(v4) => {
399 let octets = v4.ip().octets();
400 let mut sin: libc::sockaddr_in = unsafe { std::mem::zeroed() };
402 sin.sin_family = libc::AF_INET as libc::sa_family_t;
403 sin.sin_port = v4.port().to_be();
404 sin.sin_addr = libc::in_addr {
405 s_addr: u32::from_be_bytes(octets).to_be(),
406 };
407 let boxed = Box::new(sin);
408 let ptr = Box::into_raw(boxed) as *const libc::sockaddr;
409 let len = std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t;
410 (ptr, len)
411 }
412 SocketAddr::V6(v6) => {
413 let mut sin6: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
415 sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
416 sin6.sin6_port = v6.port().to_be();
417 sin6.sin6_flowinfo = v6.flowinfo();
418 sin6.sin6_addr = libc::in6_addr {
419 s6_addr: v6.ip().octets(),
420 };
421 sin6.sin6_scope_id = v6.scope_id();
422 let boxed = Box::new(sin6);
423 let ptr = Box::into_raw(boxed) as *const libc::sockaddr;
424 let len = std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t;
425 (ptr, len)
426 }
427 }
428}
429
430unsafe fn reclaim_raw_sockaddr(ptr: *const libc::sockaddr, addr: SocketAddr) {
433 match addr {
434 SocketAddr::V4(_) => drop(Box::from_raw(ptr as *mut libc::sockaddr_in)),
435 SocketAddr::V6(_) => drop(Box::from_raw(ptr as *mut libc::sockaddr_in6)),
436 }
437}
438
439fn sockaddr_to_socketaddr(
441 addr: &libc::sockaddr_in6,
442 len: libc::socklen_t,
443) -> io::Result<SocketAddr> {
444 let family = addr.sin6_family as libc::c_int;
445 match family {
446 libc::AF_INET if len >= std::mem::size_of::<libc::sockaddr_in>() as u32 => {
447 let v4: &libc::sockaddr_in =
449 unsafe { &*(addr as *const _ as *const libc::sockaddr_in) };
450 let ip = std::net::Ipv4Addr::from(u32::from_be(v4.sin_addr.s_addr));
451 let port = u16::from_be(v4.sin_port);
452 Ok(SocketAddr::V4(std::net::SocketAddrV4::new(ip, port)))
453 }
454 libc::AF_INET6 if len >= std::mem::size_of::<libc::sockaddr_in6>() as u32 => {
455 let ip = std::net::Ipv6Addr::from(addr.sin6_addr.s6_addr);
456 let port = u16::from_be(addr.sin6_port);
457 Ok(SocketAddr::V6(std::net::SocketAddrV6::new(
458 ip,
459 port,
460 addr.sin6_flowinfo,
461 addr.sin6_scope_id,
462 )))
463 }
464 _ => Err(io::Error::new(
465 io::ErrorKind::InvalidData,
466 format!("unsupported address family: {family}"),
467 )),
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::block_on_with_spawn;
    use crate::net::TcpListener;
    use std::future::poll_fn;

    /// Fills `buf` from `stream`, stopping early at EOF. Panics on any I/O error.
    async fn read_exact(stream: &mut TcpStream, buf: &mut [u8]) {
        let mut filled = 0;
        while filled < buf.len() {
            let n = poll_fn(|cx| Pin::new(&mut *stream).poll_read(cx, &mut buf[filled..]))
                .await
                .expect("read_exact: io error");
            if n == 0 {
                // EOF before `buf` was full: leave the tail untouched and let
                // the caller's assertion report the short read.
                return;
            }
            filled += n;
        }
    }

    /// Writes every byte of `buf` to `stream`. Panics on any I/O error.
    async fn write_all(stream: &mut TcpStream, buf: &[u8]) {
        let mut rest = buf;
        while !rest.is_empty() {
            let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, rest))
                .await
                .expect("write_all: io error");
            rest = &rest[n..];
        }
    }

    #[test]
    #[ignore = "requires IoSource ReadableFuture to resolve Ready — net integration pending"]
    fn tcp_connect_and_echo() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();

            // Server side: accept one connection and read exactly five bytes.
            let server = crate::spawn(async move {
                let (mut conn, _peer) = listener.accept().await.unwrap();
                let mut received = [0u8; 5];
                read_exact(&mut conn, &mut received).await;
                received
            });

            // Client side: connect, send the payload, then half-close.
            let mut client = TcpStream::connect(addr).await.unwrap();
            write_all(&mut client, b"hello").await;
            poll_fn(|cx| Pin::new(&mut client).poll_shutdown(cx))
                .await
                .expect("shutdown failed");

            assert_eq!(&server.await.unwrap(), b"hello");
        });
    }
}