#![warn(rust_2018_idioms)]
#![cfg(all(feature = "full", tokio_unstable, target_has_atomic = "64"))]
use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use tokio::runtime::Builder;
const TASKS: usize = 8;
const ITERATIONS: usize = 64;
/// Verifies that the `on_task_spawn` hook fires once per spawned task and
/// observes a distinct task ID for every one of them.
#[test]
fn spawn_task_hook_fires() {
    let count = Arc::new(AtomicUsize::new(0));
    let count2 = Arc::clone(&count);

    // Record every task ID the hook sees so uniqueness can be checked in
    // addition to the raw invocation count.
    let ids = Arc::new(Mutex::new(HashSet::new()));
    let ids2 = Arc::clone(&ids);

    let runtime = Builder::new_current_thread()
        .on_task_spawn(move |data| {
            ids2.lock().unwrap().insert(data.id());
            count2.fetch_add(1, Ordering::SeqCst);
        })
        .build()
        .unwrap();

    // The tasks never complete, and the runtime is never driven
    // (no `block_on`), so this test relies on the hook having run by the
    // time `spawn` returns.
    for _ in 0..TASKS {
        runtime.spawn(std::future::pending::<()>());
    }

    let count_realized = count.load(Ordering::SeqCst);
    assert_eq!(
        TASKS, count_realized,
        "Total number of spawned task hook invocations was incorrect, expected {TASKS}, got {}",
        count_realized
    );

    // Report the number of distinct IDs (not the invocation count) when
    // this check fails.
    let count_ids_realized = ids.lock().unwrap().len();
    assert_eq!(
        TASKS, count_ids_realized,
        "Total number of unique task IDs seen by the spawn hook was incorrect, expected {TASKS}, got {}",
        count_ids_realized
    );
}
/// Spawns `TASKS` immediately-ready tasks on a current-thread runtime, lets the
/// runtime run by yielding `ITERATIONS` times inside `block_on`, and then checks
/// that the `on_task_terminate` hook was invoked once per task.
#[test]
fn terminate_task_hook_fires() {
    let terminated = Arc::new(AtomicUsize::new(0));

    let runtime = {
        // The hook gets its own handle to the shared counter.
        let terminated = Arc::clone(&terminated);
        Builder::new_current_thread()
            .on_task_terminate(move |_meta| {
                terminated.fetch_add(1, Ordering::SeqCst);
            })
            .build()
            .unwrap()
    };

    // Join handles are dropped; completion is observed only through the hook.
    (0..TASKS).for_each(|_| {
        runtime.spawn(std::future::ready(()));
    });

    // Yielding hands control back to the scheduler so the spawned tasks can
    // run before the counter is inspected.
    runtime.block_on(async {
        for _ in 0..ITERATIONS {
            tokio::task::yield_now().await;
        }
    });

    let observed = terminated.load(Ordering::SeqCst);
    assert_eq!(TASKS, observed);
}
/// On a current-thread runtime, checks three things:
/// - the spawn and poll hooks all see a spawn location inside this file
///   (enforced by `mk_spawn_location_hook`);
/// - exactly two tasks are spawned;
/// - every before-poll hook call is paired with an after-poll hook call.
#[test]
fn task_hook_spawn_location_current_thread() {
    let spawns = Arc::new(AtomicUsize::new(0));
    let poll_starts = Arc::new(AtomicUsize::new(0));
    let poll_ends = Arc::new(AtomicUsize::new(0));

    let spawn_hook = mk_spawn_location_hook("(current_thread) on_task_spawn", &spawns);
    let before_poll_hook =
        mk_spawn_location_hook("(current_thread) on_before_task_poll", &poll_starts);
    let after_poll_hook =
        mk_spawn_location_hook("(current_thread) on_after_task_poll", &poll_ends);

    let runtime = Builder::new_current_thread()
        .on_task_spawn(spawn_hook)
        .on_before_task_poll(before_poll_hook)
        .on_after_task_poll(after_poll_hook)
        .build()
        .unwrap();

    // Task #1 is spawned through the runtime handle and yields once, so it is
    // polled more than once.
    let yielding = runtime.spawn(async move { tokio::task::yield_now().await });

    runtime.block_on(async move {
        yielding.await.unwrap();
        // Task #2 is spawned from inside the runtime context.
        tokio::spawn(async move {}).await.unwrap();
        for _ in 0..ITERATIONS {
            tokio::task::yield_now().await;
        }
    });

    let total_spawns = spawns.load(Ordering::SeqCst);
    assert_eq!(total_spawns, 2);

    let poll_starts = poll_starts.load(Ordering::SeqCst);
    assert!(poll_starts > 2);

    let total_poll_ends = poll_ends.load(Ordering::SeqCst);
    assert_eq!(poll_starts, total_poll_ends);
}
/// Multi-threaded counterpart of `task_hook_spawn_location_current_thread`.
///
/// Checks that the spawn and poll hooks all see a spawn location inside this
/// file, that exactly two tasks are spawned, and that before/after poll hook
/// invocations are balanced.
#[cfg_attr(
    target_os = "wasi",
    ignore = "WASI does not support multi-threaded runtime"
)]
#[test]
fn task_hook_spawn_location_multi_thread() {
    let spawns = Arc::new(AtomicUsize::new(0));
    let poll_starts = Arc::new(AtomicUsize::new(0));
    let poll_ends = Arc::new(AtomicUsize::new(0));

    let runtime = Builder::new_multi_thread()
        .on_task_spawn(mk_spawn_location_hook(
            "(multi_thread) on_task_spawn",
            &spawns,
        ))
        .on_before_task_poll(mk_spawn_location_hook(
            "(multi_thread) on_before_task_poll",
            &poll_starts,
        ))
        .on_after_task_poll(mk_spawn_location_hook(
            "(multi_thread) on_after_task_poll",
            &poll_ends,
        ))
        .build()
        .unwrap();

    // Task #1 yields once, so it is polled more than once.
    let task = runtime.spawn(async move { tokio::task::yield_now().await });
    runtime.block_on(async move {
        task.await.unwrap();
        // Task #2 is spawned from inside the runtime context.
        tokio::spawn(async move {}).await.unwrap();
        for _ in 0..ITERATIONS {
            tokio::task::yield_now().await;
        }
    });

    // Shut the runtime down before reading the counters. This gives worker
    // threads time to finish any in-flight hook calls.
    runtime.shutdown_timeout(std::time::Duration::from_secs(60));

    // Plain reads suffice here. The previous `fetch_add(0, ..)` was a
    // read-modify-write disguised as a read, inconsistent with the
    // current-thread test.
    assert_eq!(spawns.load(Ordering::SeqCst), 2);
    let poll_starts = poll_starts.load(Ordering::SeqCst);
    assert!(poll_starts > 2);
    assert_eq!(poll_starts, poll_ends.load(Ordering::SeqCst));
}
/// Builds a task hook that asserts the task's recorded spawn location is in
/// this file, then increments `count`.
///
/// `event` labels both the diagnostic line written to stderr and the assertion
/// message. The returned closure holds its own clone of `count` and increments
/// it with `SeqCst` ordering on every invocation.
///
/// # Panics
///
/// The returned closure panics if `data.spawned_at().file()` differs from
/// `file!()` (this file). That is the expected outcome only when every
/// observed task is spawned from this file.
fn mk_spawn_location_hook(
    event: &'static str,
    count: &Arc<AtomicUsize>,
) -> impl Fn(&tokio::runtime::TaskMeta<'_>) {
    // `count` is already `&Arc<_>`; clone it directly instead of through an
    // extra, immediately-dereferenced borrow.
    let count = Arc::clone(count);
    move |data| {
        eprintln!("{event} ({:?}): {:?}", data.id(), data.spawned_at());
        assert_eq!(
            data.spawned_at().file(),
            file!(),
            "incorrect spawn location in {event} hook",
        );
        count.fetch_add(1, Ordering::SeqCst);
    }
}