#![cfg_attr(feature = "fail-on-warnings", deny(warnings))]
#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
#![allow(clippy::multiple_crate_versions)]
use std::{
cell::RefCell,
future::Future,
sync::{LazyLock, RwLock, atomic::AtomicU64},
};
use switchy::unsync::util::CancellationToken;
/// Next worker ID to hand out. `AtomicU64::new` is `const`, so no lazy wrapper is needed.
static WORKER_THREAD_ID_COUNTER: AtomicU64 = AtomicU64::new(1);

thread_local! {
    /// The calling thread's worker ID, drawn from `WORKER_THREAD_ID_COUNTER` the first time this
    /// thread reads it. The value never changes afterwards, so it is stored directly instead of
    /// behind a `RefCell`.
    static WORKER_THREAD_ID: u64 =
        WORKER_THREAD_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}

/// Returns the calling thread's worker ID.
///
/// The ID is assigned the first time a thread calls this function and stays the same for the
/// rest of that thread's life. IDs start at 1 and every thread receives a distinct value.
#[must_use]
pub fn worker_thread_id() -> u64 {
    WORKER_THREAD_ID.with(|id| *id)
}
thread_local! {
    /// The current thread's simulation cancellation token.
    ///
    /// A thread-local is only ever accessed by its owning thread, so `RefCell` alone provides the
    /// interior mutability needed to swap the token on reset; no lock is required.
    static SIMULATOR_CANCELLATION_TOKEN: RefCell<CancellationToken> =
        RefCell::new(CancellationToken::new());
}

/// Replaces the current thread's cancellation token with a fresh, uncancelled one.
///
/// Clones captured before the reset (for example by an in-flight
/// [`run_until_simulation_cancelled`] call) still refer to the previous token.
pub fn reset_simulator_cancellation_token() {
    SIMULATOR_CANCELLATION_TOKEN.set(CancellationToken::new());
}

/// Returns `true` if the simulation has been cancelled, either process-wide (see
/// [`cancel_global_simulation`]) or on the current thread (see [`cancel_simulation`]).
#[must_use]
pub fn is_simulator_cancelled() -> bool {
    is_global_simulator_cancelled()
        || SIMULATOR_CANCELLATION_TOKEN.with_borrow(CancellationToken::is_cancelled)
}

/// Cancels the current thread's simulation token.
///
/// Only the calling thread is affected; use [`cancel_global_simulation`] to cancel the
/// simulation on every thread.
pub fn cancel_simulation() {
    SIMULATOR_CANCELLATION_TOKEN.with_borrow(CancellationToken::cancel);
}

/// Process-wide simulation cancellation token, shared by all threads.
static GLOBAL_SIMULATOR_CANCELLATION_TOKEN: LazyLock<RwLock<CancellationToken>> =
    LazyLock::new(|| RwLock::new(CancellationToken::new()));

/// Acquires read access to the global token, ignoring lock poisoning.
///
/// Within this module the token is only ever replaced by a single assignment, so a panic while
/// the write lock was held cannot leave a half-updated value behind; continuing to use the
/// stored token is therefore safe.
fn read_global_token() -> std::sync::RwLockReadGuard<'static, CancellationToken> {
    GLOBAL_SIMULATOR_CANCELLATION_TOKEN
        .read()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Replaces the process-wide cancellation token with a fresh, uncancelled one.
pub fn reset_global_simulator_cancellation_token() {
    // Poisoning is ignored for the same reason as in `read_global_token`: the stored value is
    // always a complete token.
    *GLOBAL_SIMULATOR_CANCELLATION_TOKEN
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner) = CancellationToken::new();
}

/// Returns `true` if the process-wide simulation token has been cancelled.
#[must_use]
pub fn is_global_simulator_cancelled() -> bool {
    read_global_token().is_cancelled()
}

/// Cancels the process-wide simulation token, which every thread observes through
/// [`is_simulator_cancelled`] and [`run_until_simulation_cancelled`].
pub fn cancel_global_simulation() {
    read_global_token().cancel();
}

/// Drives `fut` until it completes or the simulation is cancelled.
///
/// The process-wide token and the token of the thread that first polls the returned future are
/// both captured at that first poll; resetting either token afterwards does not affect a call
/// that is already in progress.
///
/// Returns `Some(output)` if `fut` completes first, or `None` if either token is cancelled
/// first.
pub async fn run_until_simulation_cancelled<F>(fut: F) -> Option<F::Output>
where
    F: Future,
{
    // The read guard is a temporary and is released at the end of this statement, so no lock
    // is held across the `select!` below.
    let global_token = read_global_token().clone();
    let local_token = SIMULATOR_CANCELLATION_TOKEN.with_borrow(CancellationToken::clone);
    switchy::unsync::select! {
        resp = fut => Some(resp),
        () = global_token.cancelled() => None,
        () = local_token.cancelled() => None,
    }
}
#[cfg(test)]
mod tests {
    use serial_test::serial;

    use super::*;

    #[test_log::test]
    #[serial]
    fn test_worker_thread_id_is_stable_within_a_thread() {
        let first = worker_thread_id();
        let second = worker_thread_id();
        assert_eq!(
            first, second,
            "Repeated calls on one thread must return the same ID"
        );
    }

    #[test_log::test]
    #[serial]
    fn test_worker_thread_id_uniqueness_across_threads() {
        let id1 = worker_thread_id();
        let handle = std::thread::spawn(worker_thread_id);
        let id2 = handle.join().unwrap();
        assert_ne!(id1, id2);
    }

    #[test_log::test]
    #[serial]
    fn test_local_cancellation_isolated_between_threads() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        cancel_simulation();
        assert!(is_simulator_cancelled());
        let handle = std::thread::spawn(|| {
            reset_simulator_cancellation_token();
            is_simulator_cancelled()
        });
        let other_thread_cancelled = handle.join().unwrap();
        assert!(
            !other_thread_cancelled,
            "Local cancellation should not affect other threads"
        );
    }

    #[test_log::test]
    #[serial]
    fn test_reset_simulator_cancellation_token() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        cancel_simulation();
        assert!(is_simulator_cancelled());
        reset_simulator_cancellation_token();
        assert!(!is_simulator_cancelled());
    }

    #[test_log::test]
    #[serial]
    fn test_cancel_simulation_sets_cancelled_state() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        assert!(!is_simulator_cancelled());
        cancel_simulation();
        assert!(is_simulator_cancelled());
    }

    #[test_log::test]
    #[serial]
    fn test_is_simulator_cancelled_respects_global_cancellation() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        assert!(!is_simulator_cancelled());
        cancel_global_simulation();
        assert!(is_simulator_cancelled());
    }

    #[test_log::test]
    #[serial]
    fn test_global_cancellation_independent_from_local() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        cancel_simulation();
        assert!(!is_global_simulator_cancelled());
        assert!(is_simulator_cancelled());
    }

    #[test_log::test]
    #[serial]
    fn test_reset_global_simulator_cancellation_token() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        cancel_global_simulation();
        assert!(is_global_simulator_cancelled());
        reset_global_simulator_cancellation_token();
        assert!(!is_global_simulator_cancelled());
    }

    #[test_log::test(switchy_async::test)]
    #[serial]
    async fn test_run_until_simulation_cancelled_completes_normally() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        let result = run_until_simulation_cancelled(async { 42 }).await;
        assert_eq!(result, Some(42));
    }

    #[test_log::test(switchy_async::test)]
    #[serial]
    async fn test_run_until_simulation_cancelled_with_local_cancellation() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        cancel_simulation();
        let work_task = async {
            std::future::pending::<()>().await;
            42
        };
        let result = run_until_simulation_cancelled(work_task).await;
        assert_eq!(result, None);
    }

    #[test_log::test(switchy_async::test)]
    #[serial]
    async fn test_run_until_simulation_cancelled_with_global_cancellation() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        cancel_global_simulation();
        let work_task = async {
            std::future::pending::<()>().await;
            42
        };
        let result = run_until_simulation_cancelled(work_task).await;
        assert_eq!(result, None);
    }

    #[test_log::test]
    #[serial]
    fn test_global_cancellation_affects_other_threads() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        assert!(!is_global_simulator_cancelled());
        cancel_global_simulation();
        let handle = std::thread::spawn(|| {
            reset_simulator_cancellation_token();
            is_simulator_cancelled()
        });
        let other_thread_sees_cancellation = handle.join().unwrap();
        assert!(
            other_thread_sees_cancellation,
            "Global cancellation should be visible to all threads"
        );
    }

    #[test_log::test]
    #[serial]
    fn test_worker_thread_ids_are_unique_across_concurrent_threads() {
        let handles: Vec<_> = (0..5)
            .map(|_| std::thread::spawn(worker_thread_id))
            .collect();
        let mut ids: Vec<u64> = handles.into_iter().map(|h| h.join().unwrap()).collect();
        ids.sort_unstable();
        let original_len = ids.len();
        ids.dedup();
        assert_eq!(ids.len(), original_len, "All thread IDs should be unique");
        assert!(ids.iter().all(|&id| id >= 1), "All IDs should be >= 1");
    }

    #[test_log::test]
    #[serial]
    fn test_worker_thread_ids_are_monotonically_increasing() {
        // Each thread is joined before the next one is spawned, so every thread draws its ID
        // strictly after the previous one did.
        let ids: Vec<u64> = (0..5)
            .map(|_| std::thread::spawn(worker_thread_id).join().unwrap())
            .collect();
        assert!(
            ids.windows(2).all(|pair| pair[0] < pair[1]),
            "IDs of sequentially started threads should strictly increase: {ids:?}"
        );
    }

    #[test_log::test]
    #[serial]
    fn test_is_simulator_cancelled_with_both_local_and_global_cancelled() {
        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();
        cancel_simulation();
        cancel_global_simulation();
        assert!(is_simulator_cancelled());
        assert!(is_global_simulator_cancelled());
        reset_global_simulator_cancellation_token();
        assert!(is_simulator_cancelled());
        assert!(!is_global_simulator_cancelled());
        reset_simulator_cancellation_token();
        assert!(!is_simulator_cancelled());
    }

    #[test_log::test]
    #[serial]
    fn test_global_cancellation_from_multiple_threads_is_thread_safe() {
        reset_global_simulator_cancellation_token();
        let handles: Vec<_> = (0..10)
            .map(|_| {
                std::thread::spawn(|| {
                    cancel_global_simulation();
                    is_global_simulator_cancelled()
                })
            })
            .collect();
        for handle in handles {
            let result = handle.join().unwrap();
            assert!(result, "All threads should see global cancellation");
        }
        assert!(is_global_simulator_cancelled());
    }

    #[test_log::test(switchy_async::test)]
    #[serial]
    async fn test_run_until_simulation_cancelled_with_concurrent_cancellation() {
        use std::sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        };

        reset_global_simulator_cancellation_token();
        reset_simulator_cancellation_token();

        let work_started = Arc::new(AtomicBool::new(false));
        let work_started_clone = Arc::clone(&work_started);
        let work_task = async move {
            work_started_clone.store(true, Ordering::SeqCst);
            std::future::pending::<()>().await;
            42
        };

        // The canceller branch never completes, so this `select!` can only finish through
        // `run_until_simulation_cancelled` observing the cancellation. If the canceller could
        // finish on its own, the value it produced would let the assertion pass even when the
        // cancellation went unnoticed.
        let result = switchy::unsync::select! {
            result = run_until_simulation_cancelled(work_task) => result,
            () = async {
                while !work_started.load(Ordering::SeqCst) {
                    switchy::unsync::task::yield_now().await;
                }
                cancel_simulation();
                std::future::pending::<()>().await;
            } => unreachable!("the canceller branch never completes"),
        };
        assert_eq!(result, None, "Task should be cancelled");
    }
}