#![warn(rust_2018_idioms)]
#![cfg(not(target_os = "wasi"))]
use std::rc::Rc;
use std::sync::Arc;
use tokio::sync::Barrier;
use tokio_util::task;
/// `spawn_pinned` accepts a closure whose future holds `!Send` data (an `Rc`),
/// and the value that future resolves to is delivered through the handle.
#[tokio::test]
async fn can_spawn_not_send_future() {
    let pool = task::LocalPoolHandle::new(1);
    let handle = pool.spawn_pinned(|| {
        // `Rc` is `!Send`; the future below captures it by value.
        let shared = Rc::new("test");
        async move { shared.to_string() }
    });
    let result = handle.await.unwrap();
    assert_eq!(result, "test");
}
/// Discarding the handle returned by `spawn_pinned` does not stop the task:
/// its output still arrives through a side channel.
#[test]
fn can_drop_future_and_still_get_output() {
    let pool = task::LocalPoolHandle::new(1);
    let (tx, rx) = std::sync::mpsc::channel();
    // The handle is dropped immediately; only the channel observes the task.
    drop(pool.spawn_pinned(move || {
        let shared = Rc::new("test");
        async move {
            let _ = tx.send(shared.to_string());
        }
    }));
    let received = rx.recv();
    assert_eq!(received, Ok("test".to_string()));
}
/// A pool with zero workers is rejected by an assertion inside the constructor.
#[test]
#[should_panic(expected = "assertion failed: pool_size > 0")]
fn cannot_create_zero_sized_pool() {
    let _empty_pool = task::LocalPoolHandle::new(0);
}
/// Two `!Send` futures can be spawned on a two-worker pool, and each handle
/// yields the output of its own task.
#[tokio::test]
async fn can_spawn_multiple_futures() {
    let pool = task::LocalPoolHandle::new(2);

    let first = pool.spawn_pinned(|| {
        let shared = Rc::new("test1");
        async move { shared.to_string() }
    });
    let second = pool.spawn_pinned(|| {
        let shared = Rc::new("test2");
        async move { shared.to_string() }
    });

    let first_output = first.await.unwrap();
    assert_eq!(first_output, "test1");
    let second_output = second.await.unwrap();
    assert_eq!(second_output, "test2");
}
/// A panic inside the spawned future is reported as a panic `JoinError`
/// carrying the original message, and the single worker keeps accepting tasks.
#[tokio::test]
async fn task_panic_propagates() {
    let pool = task::LocalPoolHandle::new(1);

    let failing = pool.spawn_pinned(|| async {
        panic!("Test panic");
    });
    let outcome = failing.await;
    assert!(outcome.is_err());
    let join_err = outcome.unwrap_err();
    assert!(join_err.is_panic());
    // The panic payload is recovered by downcasting it to `&str`.
    let message: &str = *join_err.into_panic().downcast().unwrap();
    assert_eq!(message, "Test panic");

    // The pool has only one worker, so this succeeds only if it survived the panic.
    let follow_up = pool.spawn_pinned(|| async { "test" });
    let outcome = follow_up.await;
    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap(), "test");
}
/// A panic raised by the spawn closure itself, before any future is created,
/// is reported as a panic `JoinError`, and the single worker stays usable.
#[tokio::test]
async fn callback_panic_does_not_kill_worker() {
    let pool = task::LocalPoolHandle::new(1);

    let failing = pool.spawn_pinned(|| {
        panic!("Test panic");
        // Never reached; it exists only so the closure has a future return type.
        #[allow(unreachable_code)]
        async {}
    });
    let outcome = failing.await;
    assert!(outcome.is_err());
    let join_err = outcome.unwrap_err();
    assert!(join_err.is_panic());
    // The panic payload is recovered by downcasting it to `&str`.
    let message: &str = *join_err.into_panic().downcast().unwrap();
    assert_eq!(message, "Test panic");

    // The pool has only one worker, so this succeeds only if it survived the panic.
    let follow_up = pool.spawn_pinned(|| async { "test" });
    let outcome = follow_up.await;
    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap(), "test");
}
#[tokio::test]
async fn task_cancellation_propagates() {
let pool = task::LocalPoolHandle::new(1);
let notify_dropped = Arc::new(());
let weak_notify_dropped = Arc::downgrade(¬ify_dropped);
let (start_sender, start_receiver) = tokio::sync::oneshot::channel();
let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>();
let join_handle = pool.spawn_pinned(|| async move {
let _drop_sender = drop_sender;
let _notify_dropped = notify_dropped;
let _ = start_sender.send(());
futures::future::pending::<()>().await;
});
let _ = start_receiver.await;
join_handle.abort();
let _ = drop_receiver.await;
assert!(weak_notify_dropped.upgrade().is_none());
}
/// While one task is still in flight, a newly spawned task lands on the other
/// worker of a two-worker pool.
#[tokio::test]
async fn tasks_are_balanced() {
    let pool = task::LocalPoolHandle::new(2);

    // First task: announce that it started, then park until released.
    let (started1_tx, started1_rx) = tokio::sync::oneshot::channel();
    let (release1_tx, release1_rx) = tokio::sync::oneshot::channel();
    let busy = pool.spawn_pinned(|| async move {
        let _ = started1_tx.send(());
        let _ = release1_rx.await;
        std::thread::current().id()
    });
    let _ = started1_rx.await;

    // Second task is spawned while the first is still parked on its worker.
    let (started2_tx, started2_rx) = tokio::sync::oneshot::channel();
    let quick = pool.spawn_pinned(|| async move {
        let _ = started2_tx.send(());
        std::thread::current().id()
    });
    let _ = started2_rx.await;

    let _ = release1_tx.send(());
    let busy_thread = busy.await.unwrap();
    let quick_thread = quick.await.unwrap();
    assert_ne!(busy_thread, quick_thread);
}
/// `spawn_pinned_by_idx` places each task on the requested worker. The
/// per-worker load counters reflect that placement, and tasks sent to
/// different workers run on different threads.
#[tokio::test]
async fn spawn_by_idx() {
    let pool = task::LocalPoolHandle::new(3);
    // The three tasks plus this test body meet at the barrier, so no task can
    // finish before the loads below have been sampled.
    let barrier = Arc::new(Barrier::new(4));
    let (b_first, b_second, b_third) = (barrier.clone(), barrier.clone(), barrier.clone());

    let on_worker0 = pool.spawn_pinned_by_idx(
        || async move {
            b_first.wait().await;
            std::thread::current().id()
        },
        0,
    );
    // A second task for worker 0; its handle is discarded immediately.
    let _ = pool.spawn_pinned_by_idx(
        || async move {
            b_second.wait().await;
            std::thread::current().id()
        },
        0,
    );
    let on_worker1 = pool.spawn_pinned_by_idx(
        || async move {
            b_third.wait().await;
            std::thread::current().id()
        },
        1,
    );

    let loads = pool.get_task_loads_for_each_worker();
    barrier.wait().await;
    assert_eq!(loads[0], 2);
    assert_eq!(loads[1], 1);
    assert_eq!(loads[2], 0);

    let worker0_thread = on_worker0.await.unwrap();
    let worker1_thread = on_worker1.await.unwrap();
    assert_ne!(worker0_thread, worker1_thread);
}