1use std::future::Future;
8use std::io;
9use std::net::SocketAddr;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13use crate::platform::sys::{set_nonblocking, Interest};
14use crate::reactor::source::{next_token, IoSource};
15
16use super::sockaddr::{reclaim_raw_sockaddr, sockaddr_to_socketaddr, socketaddr_to_raw};
17use super::{AsyncRead, AsyncWrite};
18
/// A TCP stream backed by a raw socket file descriptor.
///
/// The descriptor is reached through the wrapped [`IoSource`]; the
/// `AsyncRead` / `AsyncWrite` impls in this file issue `libc::read` /
/// `libc::write` on it, and the `Drop` impl closes it.
pub struct TcpStream {
    /// I/O source wrapping the socket fd (created in `from_raw_fd`).
    source: IoSource,
}
25
26impl TcpStream {
27 pub fn connect(addr: SocketAddr) -> ConnectFuture {
32 ConnectFuture::new(addr)
33 }
34
35 pub(crate) fn from_raw_fd(fd: i32) -> io::Result<Self> {
42 let source = IoSource::new(fd, next_token(), Interest::READABLE | Interest::WRITABLE)?;
44 Ok(Self { source })
45 }
46
47 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
49 peer_addr(self.source.raw())
50 }
51
52 pub fn local_addr(&self) -> io::Result<SocketAddr> {
54 local_addr(self.source.raw())
55 }
56}
57
58impl Drop for TcpStream {
59 fn drop(&mut self) {
60 let fd = self.source.raw();
61 unsafe { libc::close(fd) };
64 }
65}
66
/// Exposes the underlying socket descriptor on Unix targets.
///
/// The descriptor is only borrowed: the stream still closes it on drop.
#[cfg(unix)]
impl std::os::unix::io::AsRawFd for TcpStream {
    /// Returns the raw fd held by the wrapped `IoSource`.
    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
        self.source.raw()
    }
}
77
78impl AsyncRead for TcpStream {
81 fn poll_read(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 buf: &mut [u8],
85 ) -> Poll<io::Result<usize>> {
86 let fd = self.source.raw();
87
88 let n = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
91 if n > 0 {
92 return Poll::Ready(Ok(n as usize));
93 }
94 if n == 0 {
95 return Poll::Ready(Ok(0)); }
97
98 let err = io::Error::last_os_error();
99 if err.kind() != io::ErrorKind::WouldBlock {
100 return Poll::Ready(Err(err));
101 }
102
103 match Pin::new(&mut self.source.readable()).poll(cx) {
105 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
106 Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
107 }
108 }
109}
110
111impl AsyncWrite for TcpStream {
114 fn poll_write(
115 self: Pin<&mut Self>,
116 cx: &mut Context<'_>,
117 buf: &[u8],
118 ) -> Poll<io::Result<usize>> {
119 let fd = self.source.raw();
120
121 let n = unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len()) };
123 if n >= 0 {
124 return Poll::Ready(Ok(n as usize));
125 }
126
127 let err = io::Error::last_os_error();
128 if err.kind() != io::ErrorKind::WouldBlock {
129 return Poll::Ready(Err(err));
130 }
131
132 match Pin::new(&mut self.source.writable()).poll(cx) {
134 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
135 Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
136 }
137 }
138
139 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
140 Poll::Ready(Ok(()))
142 }
143
144 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145 let fd = self.source.raw();
146 let rc = unsafe { libc::shutdown(fd, libc::SHUT_WR) };
148 if rc == -1 {
149 Poll::Ready(Err(io::Error::last_os_error()))
150 } else {
151 Poll::Ready(Ok(()))
152 }
153 }
154}
155
/// Future returned by [`TcpStream::connect`].
///
/// Resolves to `io::Result<TcpStream>`. Internally a small state machine
/// (see `ConnectState`) that is advanced each time the future is polled.
pub struct ConnectFuture {
    /// Current phase of the connection attempt.
    state: ConnectState,
}
165
/// Phases of a [`ConnectFuture`].
enum ConnectState {
    /// Not started yet; holds the address to connect to.
    Init(SocketAddr),
    /// `connect` is in flight on a non-blocking socket.
    Connecting {
        /// The socket descriptor; owned by the future while in this state
        /// (the future's `Drop` impl closes it).
        fd: i32,
        /// Token under which `fd` was registered with the reactor.
        token: usize,
        /// Set to `true` when this state is constructed.
        // NOTE(review): never read in the code visible here — confirm whether
        // it is still needed.
        registered: bool,
    },
    /// The future has produced its result; polling again yields an error.
    Done,
}
180
181impl ConnectFuture {
182 fn new(addr: SocketAddr) -> Self {
183 Self {
184 state: ConnectState::Init(addr),
185 }
186 }
187}
188
189impl Future for ConnectFuture {
190 type Output = io::Result<TcpStream>;
191
192 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
193 loop {
194 match &mut self.state {
195 ConnectState::Init(addr) => {
196 let addr = *addr;
197 match start_connect(addr) {
198 Err(e) => {
199 self.state = ConnectState::Done;
200 return Poll::Ready(Err(e));
201 }
202 Ok((fd, connected)) => {
203 if connected {
204 self.state = ConnectState::Done;
206 return Poll::Ready(TcpStream::from_raw_fd(fd));
207 }
208 let token = next_token();
211 if let Err(e) = crate::reactor::with_reactor(|r| {
212 r.register(fd, token, Interest::WRITABLE)
213 }) {
214 unsafe { libc::close(fd) };
215 self.state = ConnectState::Done;
216 return Poll::Ready(Err(e));
217 }
218 self.state = ConnectState::Connecting {
219 fd,
220 token,
221 registered: true,
222 };
223 }
225 }
226 }
227
228 ConnectState::Connecting { fd, token, .. } => {
229 let fd = *fd;
230 let token = *token;
231
232 crate::reactor::with_reactor_mut(|r| {
234 r.wakers.set_write_waker(token, cx.waker().clone());
235 });
236
237 match get_so_error(fd) {
239 Err(e) => {
240 let _ = crate::reactor::with_reactor_mut(|r| {
242 r.deregister_with_token(fd, token)
243 });
244 self.state = ConnectState::Done;
245 return Poll::Ready(Err(e));
246 }
247 Ok(Some(os_err)) => {
248 let _ = crate::reactor::with_reactor_mut(|r| {
249 r.deregister_with_token(fd, token)
250 });
251 unsafe { libc::close(fd) };
252 self.state = ConnectState::Done;
253 return Poll::Ready(Err(io::Error::from_raw_os_error(os_err)));
254 }
255 Ok(None) => {
256 if is_writable_now(fd) {
261 let _ = crate::reactor::with_reactor_mut(|r| {
263 r.deregister_with_token(fd, token)
264 });
265 self.state = ConnectState::Done;
266 return Poll::Ready(TcpStream::from_raw_fd(fd));
267 }
268 return Poll::Pending;
270 }
271 }
272 }
273
274 ConnectState::Done => {
275 return Poll::Ready(Err(io::Error::other(
276 "ConnectFuture polled after completion",
277 )));
278 }
279 }
280 }
281 }
282}
283
284impl Drop for ConnectFuture {
285 fn drop(&mut self) {
286 if let ConnectState::Connecting { fd, token, .. } = self.state {
287 let _ = crate::reactor::with_reactor_mut(|r| r.deregister_with_token(fd, token));
289 unsafe { libc::close(fd) };
291 }
292 }
293}
294
295fn is_writable_now(fd: i32) -> bool {
300 unsafe {
302 let mut pfd = libc::pollfd {
303 fd,
304 events: libc::POLLOUT,
305 revents: 0,
306 };
307 let n = libc::poll(&mut pfd, 1, 0);
308 n > 0 && (pfd.revents & libc::POLLOUT) != 0
309 }
310}
311
312fn start_connect(addr: SocketAddr) -> io::Result<(i32, bool)> {
319 let family = match addr {
320 SocketAddr::V4(_) => libc::AF_INET,
321 SocketAddr::V6(_) => libc::AF_INET6,
322 };
323 let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
325 if fd == -1 {
326 return Err(io::Error::last_os_error());
327 }
328 set_nonblocking(fd)?;
329
330 let (sa, sa_len) = socketaddr_to_raw(addr);
331 let rc = unsafe { libc::connect(fd, sa, sa_len) };
333 unsafe { reclaim_raw_sockaddr(sa, addr) };
335
336 if rc == 0 {
337 return Ok((fd, true)); }
339
340 let err = io::Error::last_os_error();
341 if err.raw_os_error() == Some(libc::EINPROGRESS) {
343 return Ok((fd, false));
344 }
345
346 unsafe { libc::close(fd) };
348 Err(err)
349}
350
351fn get_so_error(fd: i32) -> io::Result<Option<i32>> {
356 let mut val: libc::c_int = 0;
357 let mut len = std::mem::size_of_val(&val) as libc::socklen_t;
358 let rc = unsafe {
360 libc::getsockopt(
361 fd,
362 libc::SOL_SOCKET,
363 libc::SO_ERROR,
364 &mut val as *mut libc::c_int as *mut libc::c_void,
365 &mut len,
366 )
367 };
368 if rc == -1 {
369 return Err(io::Error::last_os_error());
370 }
371 Ok(if val == 0 { None } else { Some(val) })
372}
373
374fn peer_addr(fd: i32) -> io::Result<SocketAddr> {
376 let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
377 let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
378 let rc = unsafe { libc::getpeername(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
380 if rc == -1 {
381 return Err(io::Error::last_os_error());
382 }
383 sockaddr_to_socketaddr(&addr, len)
384}
385
386fn local_addr(fd: i32) -> io::Result<SocketAddr> {
388 let mut addr: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
389 let mut len = std::mem::size_of_val(&addr) as libc::socklen_t;
390 let rc = unsafe { libc::getsockname(fd, &mut addr as *mut _ as *mut libc::sockaddr, &mut len) };
392 if rc == -1 {
393 return Err(io::Error::last_os_error());
394 }
395 sockaddr_to_socketaddr(&addr, len)
396}
397
/// Loopback tests for `TcpStream` and `ConnectFuture`, run on the crate's own
/// executor (`block_on_with_spawn`) against a `TcpListener` bound to an
/// OS-assigned port.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::block_on_with_spawn;
    use crate::net::TcpListener;

    /// Reads into `buf` until it is full or the peer signals end of stream.
    ///
    /// Panics on any I/O error. On early EOF the loop stops and the tail of
    /// `buf` is left untouched; callers detect that via their own asserts.
    async fn read_exact(stream: &mut TcpStream, buf: &mut [u8]) {
        use std::future::poll_fn;
        let mut filled = 0;
        while filled < buf.len() {
            let n = poll_fn(|cx| Pin::new(&mut *stream).poll_read(cx, &mut buf[filled..]))
                .await
                .expect("read_exact: io error");
            if n == 0 {
                // End of stream before the buffer was filled.
                break;
            }
            filled += n;
        }
    }

    /// Writes all of `buf`, looping over partial writes. Panics on I/O error.
    async fn write_all(stream: &mut TcpStream, buf: &[u8]) {
        use std::future::poll_fn;
        let mut sent = 0;
        while sent < buf.len() {
            let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, &buf[sent..]))
                .await
                .expect("write_all: io error");
            sent += n;
        }
    }

    /// Client connects, sends "hello", shuts down its write half; the spawned
    /// server task accepts and returns the five bytes it read.
    #[test]
    fn tcp_connect_and_echo() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();

            let server = crate::spawn(async move {
                let (mut stream, _peer) = listener.accept().await.unwrap();
                let mut buf = [0u8; 5];
                read_exact(&mut stream, &mut buf).await;
                buf
            });

            let mut client = TcpStream::connect(addr).await.unwrap();
            write_all(&mut client, b"hello").await;

            use std::future::poll_fn;
            poll_fn(|cx| Pin::new(&mut client).poll_shutdown(cx))
                .await
                .expect("shutdown failed");

            let received = server.await.unwrap();
            assert_eq!(&received, b"hello");
        });
    }

    /// Roles reversed from the test above: the spawned task is the client and
    /// the main task accepts and reads.
    #[test]
    fn tcp_stream_connect_and_write_read() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move {
                let mut client = TcpStream::connect(addr).await.unwrap();
                write_all(&mut client, b"hello").await;
            });
            let (mut server, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 5];
            read_exact(&mut server, &mut buf).await;
            assert_eq!(&buf, b"hello");
            jh.await.unwrap();
        });
    }

    /// Full round trip: the server task echoes four bytes back to the client.
    #[test]
    fn tcp_stream_echo_roundtrip() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move {
                let (mut conn, _) = listener.accept().await.unwrap();
                let mut buf = [0u8; 4];
                read_exact(&mut conn, &mut buf).await;
                write_all(&mut conn, &buf).await;
            });
            let mut client = TcpStream::connect(addr).await.unwrap();
            write_all(&mut client, b"ping").await;
            let mut buf = [0u8; 4];
            read_exact(&mut client, &mut buf).await;
            assert_eq!(&buf, b"ping");
            jh.await.unwrap();
        });
    }

    /// Connecting to a port with no listener must resolve to an error.
    // NOTE(review): assumes nothing listens on 127.0.0.1:1 on the test host.
    #[test]
    fn tcp_stream_connect_refused_returns_err() {
        let result = block_on_with_spawn(async {
            TcpStream::connect("127.0.0.1:1".parse().unwrap()).await
        });
        assert!(result.is_err());
    }

    /// `peer_addr` matches the listener address and `local_addr` is loopback.
    #[test]
    fn tcp_stream_local_and_peer_addr() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let server_addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move { listener.accept().await.unwrap() });
            let client = TcpStream::connect(server_addr).await.unwrap();
            assert_eq!(client.peer_addr().unwrap(), server_addr);
            assert_eq!(client.local_addr().unwrap().ip().to_string(), "127.0.0.1");
            // The accept task's result is not needed; the handle is dropped
            // without being awaited.
            drop(jh);
        });
    }

    /// Transfers a 4 KiB payload and checks every byte arrived intact.
    #[test]
    fn tcp_stream_large_payload() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let payload_size = 4096usize;
            let jh = crate::spawn(async move {
                let mut client = TcpStream::connect(addr).await.unwrap();
                let data = vec![0xABu8; payload_size];
                write_all(&mut client, &data).await;
            });
            let (mut server, _) = listener.accept().await.unwrap();
            let mut buf = vec![0u8; payload_size];
            read_exact(&mut server, &mut buf).await;
            assert!(buf.iter().all(|&b| b == 0xAB));
            jh.await.unwrap();
        });
    }

    /// Three connections in sequence on the same listener, each carrying one
    /// distinct byte.
    #[test]
    fn tcp_stream_multiple_connections_sequential() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            for i in 0u8..3 {
                let a = addr;
                let jh = crate::spawn(async move {
                    let mut client = TcpStream::connect(a).await.unwrap();
                    write_all(&mut client, &[i]).await;
                });
                let (mut server, _) = listener.accept().await.unwrap();
                let mut buf = [0u8; 1];
                read_exact(&mut server, &mut buf).await;
                assert_eq!(buf[0], i);
                jh.await.unwrap();
            }
        });
    }

    /// `poll_shutdown` on a freshly connected stream succeeds.
    #[test]
    fn tcp_stream_shutdown_write_half() {
        use std::future::poll_fn;
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move {
                let mut client = TcpStream::connect(addr).await.unwrap();
                poll_fn(|cx| Pin::new(&mut client).poll_shutdown(cx))
                    .await
                    .unwrap();
            });
            // Keep the accepted stream alive until the client task finishes.
            let (_server, _) = listener.accept().await.unwrap();
            jh.await.unwrap();
        });
    }
}