tokio 1.37.0

An event-driven, non-blocking I/O platform for writing asynchronous I/O backed applications.
Documentation
#![warn(rust_2018_idioms)]
#![cfg(all(feature = "full", not(target_os = "wasi")))] // Wasi doesn't support bind

use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tokio_test::assert_ok;

use std::io;
use std::net::{IpAddr, SocketAddr};

macro_rules! test_accept {
    ($(($ident:ident, $target:expr),)*) => {
        $(
            #[tokio::test]
            async fn $ident() {
                let listener = assert_ok!(TcpListener::bind($target).await);
                let addr = listener.local_addr().unwrap();

                let (tx, rx) = oneshot::channel();

                tokio::spawn(async move {
                    let (socket, _) = assert_ok!(listener.accept().await);
                    assert_ok!(tx.send(socket));
                });

                let cli = assert_ok!(TcpStream::connect(&addr).await);
                let srv = assert_ok!(rx.await);

                assert_eq!(cli.local_addr().unwrap(), srv.peer_addr().unwrap());
            }
        )*
    }
}

test_accept! {
    (ip_str, "127.0.0.1:0"),
    (host_str, "localhost:0"),
    (socket_addr, "127.0.0.1:0".parse::<SocketAddr>().unwrap()),
    (str_port_tuple, ("127.0.0.1", 0)),
    (ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
}

use std::pin::Pin;
use std::sync::{
    atomic::{AtomicUsize, Ordering::SeqCst},
    Arc,
};
use std::task::{Context, Poll};
use tokio_stream::{Stream, StreamExt};

struct TrackPolls<'a> {
    npolls: Arc<AtomicUsize>,
    listener: &'a mut TcpListener,
}

impl<'a> Stream for TrackPolls<'a> {
    type Item = io::Result<(TcpStream, SocketAddr)>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.npolls.fetch_add(1, SeqCst);
        self.listener.poll_accept(cx).map(Some)
    }
}

#[tokio::test]
async fn no_extra_poll() {
    let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
    let addr = listener.local_addr().unwrap();

    let (tx, rx) = oneshot::channel();
    let (accepted_tx, mut accepted_rx) = mpsc::unbounded_channel();

    tokio::spawn(async move {
        let mut incoming = TrackPolls {
            npolls: Arc::new(AtomicUsize::new(0)),
            listener: &mut listener,
        };
        assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
        while incoming.next().await.is_some() {
            accepted_tx.send(()).unwrap();
        }
    });

    let npolls = assert_ok!(rx.await);
    tokio::task::yield_now().await;

    // should have been polled exactly once: the initial poll
    assert_eq!(npolls.load(SeqCst), 1);

    let _ = assert_ok!(TcpStream::connect(&addr).await);
    accepted_rx.recv().await.unwrap();

    // should have been polled twice more: once to yield Some(), then once to yield Pending
    assert_eq!(npolls.load(SeqCst), 1 + 2);
}

#[tokio::test]
async fn accept_many() {
    use futures::future::poll_fn;
    use std::future::Future;
    use std::sync::atomic::AtomicBool;

    const N: usize = 50;

    let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
    let listener = Arc::new(listener);
    let addr = listener.local_addr().unwrap();
    let connected = Arc::new(AtomicBool::new(false));

    let (pending_tx, mut pending_rx) = mpsc::unbounded_channel();
    let (notified_tx, mut notified_rx) = mpsc::unbounded_channel();

    for _ in 0..N {
        let listener = listener.clone();
        let connected = connected.clone();
        let pending_tx = pending_tx.clone();
        let notified_tx = notified_tx.clone();

        tokio::spawn(async move {
            let accept = listener.accept();
            tokio::pin!(accept);

            let mut polled = false;

            poll_fn(|cx| {
                if !polled {
                    polled = true;
                    assert!(Pin::new(&mut accept).poll(cx).is_pending());
                    pending_tx.send(()).unwrap();
                    Poll::Pending
                } else if connected.load(SeqCst) {
                    notified_tx.send(()).unwrap();
                    Poll::Ready(())
                } else {
                    Poll::Pending
                }
            })
            .await;

            pending_tx.send(()).unwrap();
        });
    }

    // Wait for all tasks to have polled at least once
    for _ in 0..N {
        pending_rx.recv().await.unwrap();
    }

    // Establish a TCP connection
    connected.store(true, SeqCst);
    let _sock = TcpStream::connect(addr).await.unwrap();

    // Wait for all notifications
    for _ in 0..N {
        notified_rx.recv().await.unwrap();
    }
}