1use std::io::{self, Read, Write};
15use std::net::SocketAddr;
16use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd};
17use std::pin::Pin;
18use std::task::{Context, Poll, Waker};
19use std::time::Duration;
20
21use mio::{Interest, Token};
22
23use super::{AsyncRead, AsyncWrite, waker_to_ptr};
24use crate::io::IoHandle;
25
26pub struct TcpStream {
38 inner: mio::net::TcpStream,
39 io: IoHandle,
40 token: Option<Token>,
41 registered_task: *mut u8,
45}
46
47impl TcpStream {
48 pub(crate) fn new(inner: mio::net::TcpStream, io: IoHandle) -> Self {
50 Self {
51 inner,
52 io,
53 token: None,
54 registered_task: std::ptr::null_mut(),
55 }
56 }
57
58 pub fn connect(addr: SocketAddr, io: IoHandle) -> io::Result<Self> {
64 let inner = mio::net::TcpStream::connect(addr)?;
65 Ok(Self::new(inner, io))
66 }
67
68 pub fn from_std(stream: std::net::TcpStream, io: IoHandle) -> io::Result<Self> {
72 let inner = mio::net::TcpStream::from_std(stream);
73 Ok(Self::new(inner, io))
74 }
75
76 pub fn into_std(mut self) -> io::Result<std::net::TcpStream> {
80 if let Some(token) = self.token.take() {
81 let _ = unsafe { self.io.deregister(&mut self.inner, token) };
83 }
84 let fd = self.inner.as_raw_fd();
85 std::mem::forget(self); Ok(unsafe { std::net::TcpStream::from_raw_fd(fd) })
88 }
89
90 pub fn local_addr(&self) -> io::Result<SocketAddr> {
96 self.inner.local_addr()
97 }
98
99 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
101 self.inner.peer_addr()
102 }
103
104 fn socket_ref(&self) -> socket2::SockRef<'_> {
110 socket2::SockRef::from(&self.inner)
111 }
112
113 pub fn nodelay(&self) -> io::Result<bool> {
115 self.inner.nodelay()
116 }
117
118 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
120 self.inner.set_nodelay(nodelay)
121 }
122
123 pub fn ttl(&self) -> io::Result<u32> {
125 self.socket_ref().ttl()
126 }
127
128 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
130 self.socket_ref().set_ttl(ttl)
131 }
132
133 pub fn linger(&self) -> io::Result<Option<Duration>> {
135 self.socket_ref().linger()
136 }
137
138 pub fn set_linger(&self, duration: Option<Duration>) -> io::Result<()> {
140 self.socket_ref().set_linger(duration)
141 }
142
143 pub fn keepalive(&self) -> io::Result<bool> {
145 self.socket_ref().keepalive()
146 }
147
148 pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
150 self.socket_ref().set_keepalive(keepalive)
151 }
152
153 pub fn send_buffer_size(&self) -> io::Result<usize> {
155 self.socket_ref().send_buffer_size()
156 }
157
158 pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
160 self.socket_ref().set_send_buffer_size(size)
161 }
162
163 pub fn recv_buffer_size(&self) -> io::Result<usize> {
165 self.socket_ref().recv_buffer_size()
166 }
167
168 pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
170 self.socket_ref().set_recv_buffer_size(size)
171 }
172
173 pub fn take_error(&self) -> io::Result<Option<io::Error>> {
175 self.socket_ref().take_error()
176 }
177
178 pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
184 (&self.inner).read(buf)
185 }
186
187 pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
189 (&self.inner).write(buf)
190 }
191
192 pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
194 let buf = unsafe { &mut *(buf as *mut [u8] as *mut [std::mem::MaybeUninit<u8>]) };
196 self.socket_ref().peek(buf)
197 }
198
199 pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
206 std::future::poll_fn(|cx| Pin::new(&mut *self).poll_read(cx, buf)).await
207 }
208
209 pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
211 std::future::poll_fn(|cx| Pin::new(&mut *self).poll_write(cx, buf)).await
212 }
213
214 pub async fn write_all(&mut self, mut buf: &[u8]) -> io::Result<()> {
216 while !buf.is_empty() {
217 let n = self.write(buf).await?;
218 if n == 0 {
219 return Err(io::Error::new(
220 io::ErrorKind::WriteZero,
221 "failed to write whole buffer",
222 ));
223 }
224 buf = &buf[n..];
225 }
226 Ok(())
227 }
228
229 pub fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
235 if let Err(e) = self.ensure_registered(cx) {
236 return Poll::Ready(Err(e));
237 }
238 if let Some(token) = self.token {
239 if self.io.readiness(token).readable {
240 return Poll::Ready(Ok(()));
241 }
242 }
243 Poll::Pending
244 }
245
246 pub fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
248 if let Err(e) = self.ensure_registered(cx) {
249 return Poll::Ready(Err(e));
250 }
251 if let Some(token) = self.token {
252 if self.io.readiness(token).writable {
253 return Poll::Ready(Ok(()));
254 }
255 }
256 Poll::Pending
257 }
258
259 pub async fn readable(&mut self) -> io::Result<()> {
264 std::future::poll_fn(|cx| self.poll_read_ready(cx)).await
265 }
266
267 pub async fn writable(&mut self) -> io::Result<()> {
269 std::future::poll_fn(|cx| self.poll_write_ready(cx)).await
270 }
271
272 pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
287 let ptr = std::ptr::from_mut(self);
288 (
289 ReadHalf {
290 stream: ptr,
291 _marker: std::marker::PhantomData,
292 },
293 WriteHalf {
294 stream: ptr,
295 _marker: std::marker::PhantomData,
296 },
297 )
298 }
299
300 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
308 use std::rc::Rc;
309 let shared = Rc::new(std::cell::UnsafeCell::new(self));
310 (
311 OwnedReadHalf {
312 stream: Rc::clone(&shared),
313 },
314 OwnedWriteHalf { stream: shared },
315 )
316 }
317
318 #[inline(always)]
328 fn ensure_registered(&mut self, cx: &Context<'_>) -> io::Result<()> {
329 let task_ptr = waker_to_ptr(cx);
330 if let Some(token) = self.token {
331 if task_ptr != self.registered_task {
333 self.io.set_waker(token, cx.waker().clone());
334 self.registered_task = task_ptr;
335 }
336 return Ok(());
337 }
338 self.do_register(task_ptr, cx.waker().clone())
339 }
340
341 #[cold]
342 fn do_register(&mut self, task_ptr: *mut u8, waker: Waker) -> io::Result<()> {
343 let interest = Interest::READABLE | Interest::WRITABLE;
344 let token = self.io.register(&mut self.inner, interest, waker)?;
345 self.token = Some(token);
346 self.registered_task = task_ptr;
347 Ok(())
348 }
349}
350
351impl AsyncRead for TcpStream {
352 fn poll_read(
353 self: Pin<&mut Self>,
354 cx: &mut Context<'_>,
355 buf: &mut [u8],
356 ) -> Poll<io::Result<usize>> {
357 let this = self.get_mut();
358 if let Err(e) = this.ensure_registered(cx) {
359 return Poll::Ready(Err(e));
360 }
361 match this.inner.read(buf) {
362 Ok(n) => Poll::Ready(Ok(n)),
363 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
364 if let Some(token) = this.token {
366 this.io.clear_readable(token);
367 }
368 Poll::Pending
369 }
370 Err(e) => Poll::Ready(Err(e)),
371 }
372 }
373}
374
375impl AsyncWrite for TcpStream {
376 fn poll_write(
377 self: Pin<&mut Self>,
378 cx: &mut Context<'_>,
379 buf: &[u8],
380 ) -> Poll<io::Result<usize>> {
381 let this = self.get_mut();
382 if let Err(e) = this.ensure_registered(cx) {
383 return Poll::Ready(Err(e));
384 }
385 match this.inner.write(buf) {
386 Ok(n) => Poll::Ready(Ok(n)),
387 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
388 if let Some(token) = this.token {
389 this.io.clear_writable(token);
390 }
391 Poll::Pending
392 }
393 Err(e) => Poll::Ready(Err(e)),
394 }
395 }
396
397 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
398 let this = self.get_mut();
399 if let Err(e) = this.ensure_registered(cx) {
400 return Poll::Ready(Err(e));
401 }
402 match this.inner.flush() {
403 Ok(()) => Poll::Ready(Ok(())),
404 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
405 if let Some(token) = this.token {
406 this.io.clear_writable(token);
407 }
408 Poll::Pending
409 }
410 Err(e) => Poll::Ready(Err(e)),
411 }
412 }
413
414 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
415 let this = self.get_mut();
416 match this.inner.shutdown(std::net::Shutdown::Write) {
417 Ok(()) => Poll::Ready(Ok(())),
418 Err(e) if e.kind() == io::ErrorKind::NotConnected => Poll::Ready(Ok(())),
419 Err(e) => Poll::Ready(Err(e)),
420 }
421 }
422}
423
424impl std::fmt::Debug for TcpStream {
425 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426 f.debug_struct("TcpStream")
427 .field("fd", &self.inner.as_raw_fd())
428 .field("registered", &self.token.is_some())
429 .finish()
430 }
431}
432
433impl AsFd for TcpStream {
434 fn as_fd(&self) -> BorrowedFd<'_> {
435 self.inner.as_fd()
436 }
437}
438
439impl AsRawFd for TcpStream {
440 fn as_raw_fd(&self) -> RawFd {
441 self.inner.as_raw_fd()
442 }
443}
444
445impl Drop for TcpStream {
446 fn drop(&mut self) {
447 if let Some(token) = self.token {
448 let _ = unsafe { self.io.deregister(&mut self.inner, token) };
450 }
451 }
452}
453
454pub struct ReadHalf<'a> {
463 stream: *mut TcpStream,
464 _marker: std::marker::PhantomData<&'a mut TcpStream>,
466}
467
468impl ReadHalf<'_> {
472 fn stream(&mut self) -> &mut TcpStream {
473 unsafe { &mut *self.stream }
475 }
476}
477
478impl AsyncRead for ReadHalf<'_> {
479 fn poll_read(
480 self: Pin<&mut Self>,
481 cx: &mut Context<'_>,
482 buf: &mut [u8],
483 ) -> Poll<io::Result<usize>> {
484 let this = self.get_mut();
485 Pin::new(this.stream()).poll_read(cx, buf)
486 }
487}
488
489pub struct WriteHalf<'a> {
494 stream: *mut TcpStream,
495 _marker: std::marker::PhantomData<&'a mut TcpStream>,
496}
497
498impl WriteHalf<'_> {
499 fn stream(&mut self) -> &mut TcpStream {
500 unsafe { &mut *self.stream }
502 }
503}
504
505impl AsyncWrite for WriteHalf<'_> {
506 fn poll_write(
507 self: Pin<&mut Self>,
508 cx: &mut Context<'_>,
509 buf: &[u8],
510 ) -> Poll<io::Result<usize>> {
511 let this = self.get_mut();
512 Pin::new(this.stream()).poll_write(cx, buf)
513 }
514
515 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
516 let this = self.get_mut();
517 Pin::new(this.stream()).poll_flush(cx)
518 }
519
520 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
521 let this = self.get_mut();
522 Pin::new(this.stream()).poll_shutdown(cx)
523 }
524}
525
526pub struct OwnedReadHalf {
534 stream: std::rc::Rc<std::cell::UnsafeCell<TcpStream>>,
535}
536
537impl OwnedReadHalf {
538 pub fn reunite(self, write: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
542 if std::rc::Rc::ptr_eq(&self.stream, &write.stream) {
543 drop(write);
544 let cell = std::rc::Rc::try_unwrap(self.stream).map_err(|_| ReuniteError)?;
545 Ok(cell.into_inner())
546 } else {
547 Err(ReuniteError)
548 }
549 }
550
551 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
553 unsafe { &*self.stream.get() }.peer_addr()
555 }
556
557 pub fn local_addr(&self) -> io::Result<SocketAddr> {
559 unsafe { &*self.stream.get() }.local_addr()
560 }
561}
562
563impl AsyncRead for OwnedReadHalf {
564 fn poll_read(
565 self: Pin<&mut Self>,
566 cx: &mut Context<'_>,
567 buf: &mut [u8],
568 ) -> Poll<io::Result<usize>> {
569 let stream = unsafe { &mut *self.stream.get() };
571 Pin::new(stream).poll_read(cx, buf)
572 }
573}
574
575pub struct OwnedWriteHalf {
579 stream: std::rc::Rc<std::cell::UnsafeCell<TcpStream>>,
580}
581
582impl OwnedWriteHalf {
583 pub fn reunite(self, read: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
585 read.reunite(self)
586 }
587
588 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
590 unsafe { &*self.stream.get() }.peer_addr()
591 }
592
593 pub fn local_addr(&self) -> io::Result<SocketAddr> {
595 unsafe { &*self.stream.get() }.local_addr()
596 }
597}
598
599impl AsyncWrite for OwnedWriteHalf {
600 fn poll_write(
601 self: Pin<&mut Self>,
602 cx: &mut Context<'_>,
603 buf: &[u8],
604 ) -> Poll<io::Result<usize>> {
605 let stream = unsafe { &mut *self.stream.get() };
607 Pin::new(stream).poll_write(cx, buf)
608 }
609
610 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
611 let stream = unsafe { &mut *self.stream.get() };
612 Pin::new(stream).poll_flush(cx)
613 }
614
615 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
616 let stream = unsafe { &mut *self.stream.get() };
617 Pin::new(stream).poll_shutdown(cx)
618 }
619}
620
/// Error returned when two owned halves cannot be joined back into a
/// single `TcpStream`.
#[derive(Debug)]
pub struct ReuniteError;
624
625impl std::fmt::Display for ReuniteError {
626 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
627 write!(f, "halves do not belong to the same TcpStream")
628 }
629}
630
631impl std::error::Error for ReuniteError {}
632
633pub struct TcpListener {
642 inner: mio::net::TcpListener,
643 io: IoHandle,
644 token: Option<Token>,
645 registered_task: *mut u8,
646}
647
648impl TcpListener {
649 pub fn bind(addr: SocketAddr, io: IoHandle) -> io::Result<Self> {
651 let inner = mio::net::TcpListener::bind(addr)?;
652 Ok(Self {
653 inner,
654 io,
655 token: None,
656 registered_task: std::ptr::null_mut(),
657 })
658 }
659
660 pub fn from_std(listener: std::net::TcpListener, io: IoHandle) -> io::Result<Self> {
662 let inner = mio::net::TcpListener::from_std(listener);
663 Ok(Self {
664 inner,
665 io,
666 token: None,
667 registered_task: std::ptr::null_mut(),
668 })
669 }
670
671 pub fn local_addr(&self) -> io::Result<SocketAddr> {
673 self.inner.local_addr()
674 }
675
676 pub fn ttl(&self) -> io::Result<u32> {
678 socket2::SockRef::from(&self.inner).ttl()
679 }
680
681 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
683 socket2::SockRef::from(&self.inner).set_ttl(ttl)
684 }
685
686 pub fn accept(&mut self) -> Accept<'_> {
688 Accept { listener: self }
689 }
690
691 #[inline(always)]
693 fn ensure_registered(&mut self, cx: &Context<'_>) -> io::Result<()> {
694 let task_ptr = waker_to_ptr(cx);
695 if let Some(token) = self.token {
696 if task_ptr != self.registered_task {
697 self.io.set_waker(token, cx.waker().clone());
698 self.registered_task = task_ptr;
699 }
700 return Ok(());
701 }
702 self.do_register(task_ptr, cx.waker().clone())
703 }
704
705 #[cold]
706 fn do_register(&mut self, task_ptr: *mut u8, waker: Waker) -> io::Result<()> {
707 let token = self
708 .io
709 .register(&mut self.inner, Interest::READABLE, waker)?;
710 self.token = Some(token);
711 self.registered_task = task_ptr;
712 Ok(())
713 }
714}
715
716impl std::fmt::Debug for TcpListener {
717 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
718 f.debug_struct("TcpListener")
719 .field("fd", &self.inner.as_raw_fd())
720 .field("registered", &self.token.is_some())
721 .finish()
722 }
723}
724
725impl AsFd for TcpListener {
726 fn as_fd(&self) -> BorrowedFd<'_> {
727 self.inner.as_fd()
728 }
729}
730
731impl AsRawFd for TcpListener {
732 fn as_raw_fd(&self) -> RawFd {
733 self.inner.as_raw_fd()
734 }
735}
736
737impl Drop for TcpListener {
738 fn drop(&mut self) {
739 if let Some(token) = self.token {
740 let _ = unsafe { self.io.deregister(&mut self.inner, token) };
741 }
742 }
743}
744
745pub struct Accept<'a> {
747 listener: &'a mut TcpListener,
748}
749
750impl std::future::Future for Accept<'_> {
751 type Output = io::Result<(TcpStream, SocketAddr)>;
752
753 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
754 let this = self.get_mut();
755 if let Err(e) = this.listener.ensure_registered(cx) {
756 return Poll::Ready(Err(e));
757 }
758 match this.listener.inner.accept() {
759 Ok((stream, addr)) => {
760 let tcp = TcpStream::new(stream, this.listener.io);
761 Poll::Ready(Ok((tcp, addr)))
762 }
763 Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
764 Err(e) => Poll::Ready(Err(e)),
765 }
766 }
767}
768
769pub struct TcpSocket {
788 inner: socket2::Socket,
789}
790
791impl TcpSocket {
792 pub fn new_v4() -> io::Result<Self> {
794 let inner = socket2::Socket::new(
795 socket2::Domain::IPV4,
796 socket2::Type::STREAM,
797 Some(socket2::Protocol::TCP),
798 )?;
799 inner.set_nonblocking(true)?;
800 Ok(Self { inner })
801 }
802
803 pub fn new_v6() -> io::Result<Self> {
805 let inner = socket2::Socket::new(
806 socket2::Domain::IPV6,
807 socket2::Type::STREAM,
808 Some(socket2::Protocol::TCP),
809 )?;
810 inner.set_nonblocking(true)?;
811 Ok(Self { inner })
812 }
813
814 pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> {
818 self.inner.set_reuse_address(reuseaddr)
819 }
820
821 pub fn reuseaddr(&self) -> io::Result<bool> {
823 self.inner.reuse_address()
824 }
825
826 #[cfg(unix)]
828 pub fn set_reuseport(&self, reuseport: bool) -> io::Result<()> {
829 self.inner.set_reuse_port(reuseport)
830 }
831
832 #[cfg(unix)]
834 pub fn reuseport(&self) -> io::Result<bool> {
835 self.inner.reuse_port()
836 }
837
838 pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
840 self.inner.set_keepalive(keepalive)
841 }
842
843 pub fn keepalive(&self) -> io::Result<bool> {
845 self.inner.keepalive()
846 }
847
848 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
850 self.inner.set_nodelay(nodelay)
851 }
852
853 pub fn nodelay(&self) -> io::Result<bool> {
855 self.inner.nodelay()
856 }
857
858 pub fn set_linger(&self, duration: Option<Duration>) -> io::Result<()> {
860 self.inner.set_linger(duration)
861 }
862
863 pub fn linger(&self) -> io::Result<Option<Duration>> {
865 self.inner.linger()
866 }
867
868 pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
870 self.inner.set_send_buffer_size(size)
871 }
872
873 pub fn send_buffer_size(&self) -> io::Result<usize> {
875 self.inner.send_buffer_size()
876 }
877
878 pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
880 self.inner.set_recv_buffer_size(size)
881 }
882
883 pub fn recv_buffer_size(&self) -> io::Result<usize> {
885 self.inner.recv_buffer_size()
886 }
887
888 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
890 self.inner.set_ttl(ttl)
891 }
892
893 pub fn ttl(&self) -> io::Result<u32> {
895 self.inner.ttl()
896 }
897
898 pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
902 self.inner.bind(&addr.into())
903 }
904
905 pub fn connect(self, addr: SocketAddr, io: IoHandle) -> io::Result<TcpStream> {
911 match self.inner.connect(&addr.into()) {
914 Ok(()) => {}
915 Err(e)
916 if e.raw_os_error() == Some(libc::EINPROGRESS)
917 || e.raw_os_error() == Some(libc::EALREADY) => {}
918 Err(e) => return Err(e),
919 }
920 let std_stream: std::net::TcpStream = self.inner.into();
921 let mio_stream = mio::net::TcpStream::from_std(std_stream);
922 Ok(TcpStream::new(mio_stream, io))
923 }
924
925 pub fn listen(self, backlog: i32, io: IoHandle) -> io::Result<TcpListener> {
927 self.inner.listen(backlog)?;
928 let std_listener: std::net::TcpListener = self.inner.into();
929 let mio_listener = mio::net::TcpListener::from_std(std_listener);
930 Ok(TcpListener {
931 inner: mio_listener,
932 io,
933 token: None,
934 registered_task: std::ptr::null_mut(),
935 })
936 }
937}
938
939impl std::fmt::Debug for TcpSocket {
940 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
941 f.debug_struct("TcpSocket")
942 .field("fd", &self.inner.as_raw_fd())
943 .finish()
944 }
945}
946
947impl AsFd for TcpSocket {
948 fn as_fd(&self) -> BorrowedFd<'_> {
949 self.inner.as_fd()
950 }
951}
952
953impl AsRawFd for TcpSocket {
954 fn as_raw_fd(&self) -> RawFd {
955 self.inner.as_raw_fd()
956 }
957}
958
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Runtime, spawn_boxed};
    use nexus_rt::WorldBuilder;
    use std::cell::Cell;
    use std::rc::Rc;

    /// Round-trips "hello" through a local listener and checks that the
    /// client receives the same bytes back.
    #[test]
    fn tcp_echo() {
        let builder = WorldBuilder::new();
        let mut world = builder.build();
        let mut rt = Runtime::new(&mut world);

        // Set by the client task once the echoed bytes have been verified.
        let finished = Rc::new(Cell::new(false));
        let finished_in_task = Rc::clone(&finished);

        rt.block_on(async move {
            let server = TcpListener::bind("127.0.0.1:0".parse().unwrap(), crate::context::io())
                .expect("bind failed");
            let server_addr = server.local_addr().unwrap();

            // Server: accept one connection and echo a single read back.
            spawn_boxed(async move {
                let mut server = server;
                let (mut conn, _peer) = server.accept().await.unwrap();
                let mut scratch = [0u8; 64];
                let received = conn.read(&mut scratch).await.unwrap();
                conn.write_all(&scratch[..received]).await.unwrap();
            });

            let client_io = crate::context::io();
            let flag = finished_in_task;

            // Client: connect after a short delay, send, and verify the echo.
            spawn_boxed(async move {
                crate::context::sleep(std::time::Duration::from_millis(10)).await;
                let mut client = TcpStream::connect(server_addr, client_io).unwrap();
                client.write_all(b"hello").await.unwrap();
                let mut reply = [0u8; 64];
                let got = client.read(&mut reply).await.unwrap();
                assert_eq!(&reply[..got], b"hello");
                flag.set(true);
            });

            // Keep the runtime alive long enough for both tasks to finish.
            crate::context::sleep(std::time::Duration::from_millis(500)).await;
        });

        assert!(finished.get(), "echo exchange never completed");
    }

    /// Options set on a `TcpSocket` can be read back.
    #[test]
    fn tcp_socket_builder() {
        let sock = TcpSocket::new_v4().unwrap();

        sock.set_reuseaddr(true).unwrap();
        assert!(sock.reuseaddr().unwrap());

        sock.set_nodelay(true).unwrap();
        assert!(sock.nodelay().unwrap());

        sock.set_send_buffer_size(65536).unwrap();
        assert!(sock.send_buffer_size().unwrap() >= 65536);
    }
}