tokio 1.47.4

An event-driven, non-blocking I/O platform for writing asynchronous I/O backed applications.
Documentation
#![allow(unknown_lints, unexpected_cfgs)]
#![cfg(tokio_unstable)]

use std::sync::{atomic::AtomicUsize, Arc, Mutex};

use tokio::task::yield_now;

/// Checks that the before/after task-poll hooks fire on a multi-threaded runtime.
///
/// The spawned task yields three times, so it is polled four times in total:
/// three polls that return `Pending`, plus the final poll that completes it.
/// Each hook must therefore run exactly four times. The last task ID recorded
/// by each hook must be the spawned task's ID.
#[cfg(not(target_os = "wasi"))]
#[test]
fn callbacks_fire_multi_thread() {
    // `Ordering` is `Copy`, so the `move` closures below each get their own copy.
    let relaxed = std::sync::atomic::Ordering::Relaxed;

    // Written by the hooks, read by the assertions at the end of the test.
    let starts = Arc::new(AtomicUsize::new(0));
    let stops = Arc::new(AtomicUsize::new(0));
    let last_before_id: Arc<Mutex<Option<tokio::task::Id>>> = Arc::new(Mutex::new(None));
    let last_after_id: Arc<Mutex<Option<tokio::task::Id>>> = Arc::new(Mutex::new(None));

    let rt = {
        // Hand clones of the shared state to the hooks; the originals stay here.
        let starts = Arc::clone(&starts);
        let stops = Arc::clone(&stops);
        let last_before_id = Arc::clone(&last_before_id);
        let last_after_id = Arc::clone(&last_after_id);

        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .on_before_task_poll(move |meta| {
                *last_before_id.lock().unwrap() = Some(meta.id());
                starts.fetch_add(1, relaxed);
            })
            .on_after_task_poll(move |meta| {
                *last_after_id.lock().unwrap() = Some(meta.id());
                stops.fetch_add(1, relaxed);
            })
            .build()
            .unwrap()
    };

    let handle = rt.spawn(async {
        for _ in 0..3 {
            yield_now().await;
        }
    });
    let expected_id = handle.id();

    rt.block_on(handle).expect("task should succeed");
    // The JoinHandle can resolve before the worker has run the after-poll hook
    // for the final poll. Dropping the runtime makes sure the workers have
    // exited, and so have invoked every hook, before anything is inspected.
    drop(rt);

    assert_eq!(last_before_id.lock().unwrap().unwrap(), expected_id);
    assert_eq!(last_after_id.lock().unwrap().unwrap(), expected_id);

    // Three `yield_now` polls plus the completing poll.
    let expected_polls = 4;
    assert_eq!(starts.load(relaxed), expected_polls, "unexpected number of poll starts");
    assert_eq!(stops.load(relaxed), expected_polls, "unexpected number of poll stops");
}

/// Checks that the before/after task-poll hooks fire on a current-thread runtime.
///
/// The spawned task yields three times, so it is polled four times in total:
/// three polls that return `Pending`, plus the final poll that completes it.
/// Each hook must run exactly that many times. The last task ID recorded by
/// each hook must be the spawned task's ID.
#[test]
fn callbacks_fire_current_thread() {
    let poll_start_counter = Arc::new(AtomicUsize::new(0));
    let poll_stop_counter = Arc::new(AtomicUsize::new(0));
    let poll_start = poll_start_counter.clone();
    let poll_stop = poll_stop_counter.clone();

    // Last task ID observed by each hook; written from the hooks and read below.
    let before_task_poll_callback_task_id: Arc<Mutex<Option<tokio::task::Id>>> =
        Arc::new(Mutex::new(None));
    let after_task_poll_callback_task_id: Arc<Mutex<Option<tokio::task::Id>>> =
        Arc::new(Mutex::new(None));

    let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id);
    let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id);
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .on_before_task_poll(move |task_meta| {
            before_task_poll_callback_task_id_ref
                .lock()
                .unwrap()
                .replace(task_meta.id());
            poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        })
        .on_after_task_poll(move |task_meta| {
            after_task_poll_callback_task_id_ref
                .lock()
                .unwrap()
                .replace(task_meta.id());
            poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        })
        .build()
        .unwrap();

    let task = rt.spawn(async {
        yield_now().await;
        yield_now().await;
        yield_now().await;
    });

    let spawned_task_id = task.id();

    // Do not discard the join result. If the task panics, report that directly
    // instead of letting it surface later as an unexplained poll-count mismatch.
    // This matches the multi-threaded variant of this test.
    rt.block_on(task).expect("task should succeed");
    // Shut the runtime down before inspecting what the hooks recorded.
    drop(rt);

    assert_eq!(
        before_task_poll_callback_task_id.lock().unwrap().unwrap(),
        spawned_task_id
    );
    assert_eq!(
        after_task_poll_callback_task_id.lock().unwrap().unwrap(),
        spawned_task_id
    );

    // Three `yield_now` polls plus the completing poll.
    let expected_count = 4;
    assert_eq!(
        poll_start.load(std::sync::atomic::Ordering::Relaxed),
        expected_count,
        "unexpected number of poll starts"
    );
    assert_eq!(
        poll_stop.load(std::sync::atomic::Ordering::Relaxed),
        expected_count,
        "unexpected number of poll stops"
    );
}