use std::alloc::{GlobalAlloc, Layout, System};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use nexus_async_rt::CancellationToken;
/// Allocator wrapper that tallies allocation requests while a switch is on.
///
/// Every request is forwarded unchanged to [`System`]; the wrapper only adds
/// bookkeeping on the allocation path.
struct CountingAllocator {
    /// While `true`, each call to `alloc` increments `allocs`.
    counting_active: AtomicBool,
    /// Number of `alloc` calls observed while `counting_active` was set.
    allocs: AtomicUsize,
}

impl CountingAllocator {
    /// Builds an allocator with counting switched off and the tally at zero.
    const fn new() -> Self {
        Self {
            counting_active: AtomicBool::new(false),
            allocs: AtomicUsize::new(0),
        }
    }

    /// Adds one to the tally, but only while counting is switched on.
    ///
    /// Uses nothing but atomics, so it is safe to call from inside the
    /// allocator itself (it never allocates).
    fn record_if_active(&self) {
        if self.counting_active.load(Ordering::Relaxed) {
            self.allocs.fetch_add(1, Ordering::Relaxed);
        }
    }
}

// SAFETY: both methods pass the caller's arguments straight through to
// `System`, so the `GlobalAlloc` contract is exactly the one `System` upholds.
//
// `realloc` and `alloc_zeroed` are intentionally not overridden, so the
// trait's provided implementations are used for them.
unsafe impl GlobalAlloc for CountingAllocator {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        self.record_if_active();
        // SAFETY: the caller guarantees `layout` satisfies `GlobalAlloc::alloc`'s
        // requirements; it is forwarded untouched.
        unsafe { System.alloc(layout) }
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        // Frees are never counted; only allocation requests matter here.
        // SAFETY: `ptr` was produced by `System` through `alloc` above with this
        // same `layout`, as the caller of `dealloc` is required to guarantee.
        unsafe { System.dealloc(ptr, layout) }
    }
}

/// The allocator instance installed for the whole test binary.
///
/// Counting starts switched off, so allocations are ignored until
/// `counting_active` is set.
#[global_allocator]
static ALLOC: CountingAllocator = CountingAllocator::new();
fn start_counting() {
ALLOC.allocs.store(0, Ordering::Relaxed);
ALLOC.counting_active.store(true, Ordering::Relaxed);
}
fn stop_counting() -> usize {
ALLOC.counting_active.store(false, Ordering::Relaxed);
ALLOC.allocs.load(Ordering::Relaxed)
}
/// Builds a [`Waker`] that sets `flag` to `true` whenever it is woken.
///
/// The waker stores a raw pointer to `flag` and carries no lifetime, so the
/// caller must keep `flag` alive for as long as the waker or any clone of it
/// may be woken. Because the flag is a plain `Cell`, the waker must only be
/// used on the thread that owns `flag`.
///
/// Cloning yields another waker pointing at the same flag; dropping a waker
/// does nothing.
fn tracking_waker(flag: &std::cell::Cell<bool>) -> Waker {
    type Flag = std::cell::Cell<bool>;

    /// Clone: hand out the same data pointer with the same vtable.
    unsafe fn clone_raw(ptr: *const ()) -> RawWaker {
        RawWaker::new(ptr, &VTABLE)
    }

    /// Shared by `wake` and `wake_by_ref`: mark the flag as woken.
    unsafe fn mark_woken(ptr: *const ()) {
        // SAFETY: `ptr` was created from a `&Cell<bool>` in `tracking_waker`,
        // and the caller of `tracking_waker` keeps that cell alive.
        let flag = unsafe { &*(ptr as *const Flag) };
        flag.set(true);
    }

    /// Drop: the waker owns nothing, so there is nothing to release.
    unsafe fn release(_ptr: *const ()) {}

    static VTABLE: RawWakerVTable =
        RawWakerVTable::new(clone_raw, mark_woken, mark_woken, release);

    let data = flag as *const Flag as *const ();
    // SAFETY: `data` points at a live `Cell<bool>` and every vtable entry
    // treats it as exactly that; none of them takes ownership of it.
    unsafe { Waker::from_raw(RawWaker::new(data, &VTABLE)) }
}
/// Polls `f` exactly once using a task context built around waker `w`.
///
/// Returns whatever the future's `poll` returns: `Poll::Ready(output)` or
/// `Poll::Pending`.
fn poll_with<F: std::future::Future>(f: Pin<&mut F>, w: &Waker) -> Poll<F::Output> {
    let mut context = Context::from_waker(w);
    std::future::Future::poll(f, &mut context)
}
/// Re-polling a pending `cancelled()` future with a rotating set of wakers
/// must not allocate once the future has been warmed up.
///
/// Steps: poll once, run a few warm-up rounds over every waker, then count
/// allocations across a fixed number of further polls that cycle through the
/// wakers. Finally the token is cancelled and the future polled one last time.
#[test]
fn no_allocation_on_repoll_across_wakers() {
    const WAKER_COUNT: usize = 5;
    const WARMUP_ROUNDS: usize = 2;
    const MEASURED_POLLS: usize = 100;

    let token = CancellationToken::new();
    // Declared before `wakers` so the flags outlive the raw pointers into them.
    let flags: Vec<std::cell::Cell<bool>> = vec![std::cell::Cell::new(false); WAKER_COUNT];
    let wakers: Vec<Waker> = flags.iter().map(tracking_waker).collect();
    let mut fut = Box::pin(token.cancelled());

    // First poll, made with the first waker.
    assert!(poll_with(fut.as_mut(), &wakers[0]).is_pending());

    // Warm-up: every waker in order, repeated, all outside the counted window.
    for waker in wakers.iter().cycle().take(WARMUP_ROUNDS * wakers.len()) {
        assert!(poll_with(fut.as_mut(), waker).is_pending());
    }

    // Measured window: cycle through the wakers; nothing in this loop other
    // than the poll itself can allocate.
    start_counting();
    for waker in wakers.iter().cycle().take(MEASURED_POLLS) {
        assert!(poll_with(fut.as_mut(), waker).is_pending());
    }
    let allocs = stop_counting();

    assert_eq!(
        allocs, 0,
        "PR 3 (BUG-3 fix): 100 re-polls across cycling wakers must \
         allocate 0 times. Pre-fix this was 100 (one Box<WaiterNode> \
         per re-poll). Got {allocs}."
    );

    // Cancel and poll once more; the result is deliberately not inspected.
    token.cancel();
    let _ = poll_with(fut.as_mut(), &wakers[0]);
}