use std::fmt;
use std::future::Future;
use std::ops::Deref;
use std::pin::Pin;
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::task::{Context, Poll};
use tokio::sync::futures::OwnedNotified;
use crate::buffer::Buffer;
use crate::{BuddyArena, FixedArena};
/// Selects which built-in waiter strategy an async arena is built with.
///
/// The tests in this file pass a policy to `build_async`, e.g.
/// `build_async(AsyncPolicy::Notify)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AsyncPolicy {
    /// Use the waiters backed by a single shared `tokio::sync::Notify`.
    Notify,
    /// Use the waiters backed by the Treiber-stack of per-waiter notifiers.
    TreiberWaiters,
}
/// A wake-up mechanism that blocked async allocations wait on.
///
/// `allocate_with_waiter` in this file drives an implementation as follows:
/// it calls [`Waiter::register`], then `prepare` on the registration, then
/// retries the allocation, and either `revoke`s the registration (success) or
/// awaits it (failure) before looping.
pub trait Waiter: Send + Sync + 'static {
/// Registration handle, which is also the future a waiter awaits.
type Registration: WaitRegistration;
/// Creates a fresh registration for one waiting attempt.
fn register(&self) -> Self::Registration;
/// Asks the implementation to wake one registered waiter.
fn wake_one(&self);
}
/// A single waiter's registration; awaiting it completes with `()`.
///
/// In `allocate_with_waiter`, `prepare` is called before the allocation is
/// retried and `revoke` is called when that retry succeeds and the
/// registration will not be awaited.
pub trait WaitRegistration: Future<Output = ()> {
/// Called on the pinned registration before the allocation retry.
fn prepare(self: Pin<&mut Self>);
/// Called on the pinned registration when the wait is no longer needed.
fn revoke(self: Pin<&mut Self>);
}
/// Crate-internal trait exposing only the wake operation.
///
/// It has no associated types, which lets `WakeHandle` store it as
/// `Arc<dyn WakeOne>`.
pub(crate) trait WakeOne: Send + Sync {
/// Wakes one waiter.
fn wake_one(&self);
}
impl<W: Waiter> WakeOne for W {
fn wake_one(&self) {
Waiter::wake_one(self);
}
}
/// Type-erased, shared handle used to wake one waiter.
pub(crate) struct WakeHandle {
// Any `Waiter`, stored behind the object-safe `WakeOne` trait.
inner: Arc<dyn WakeOne>,
}
impl WakeHandle {
    /// Wraps a shared waiter set, erasing its concrete type.
    pub(crate) fn new<W: Waiter>(waiters: Arc<W>) -> Self {
        // The field initialiser coerces `Arc<W>` into `Arc<dyn WakeOne>`.
        Self { inner: waiters }
    }

    /// Wakes one waiter of the wrapped waiter set.
    pub(crate) fn wake(&self) {
        let target: &dyn WakeOne = &*self.inner;
        target.wake_one();
    }
}
/// Waiter set in which every waiter shares one `tokio::sync::Notify`.
///
/// Cloning yields another handle to the same notifier (the `Arc` is cloned).
#[derive(Clone, Default)]
pub struct NotifyWaiters {
// Notifier shared by all registrations created from this set.
notify: Arc<tokio::sync::Notify>,
}
impl NotifyWaiters {
pub fn new() -> Self {
Self {
notify: Arc::new(tokio::sync::Notify::new()),
}
}
}
impl Waiter for NotifyWaiters {
type Registration = NotifyRegistration;
fn register(&self) -> Self::Registration {
NotifyRegistration {
future: self.notify.clone().notified_owned(),
}
}
fn wake_one(&self) {
self.notify.notify_one();
}
}
/// Registration produced by [`NotifyWaiters`]; wraps an owned `Notified` future.
pub struct NotifyRegistration {
// Future obtained from `notified_owned()` on the shared notifier.
future: OwnedNotified,
}
impl WaitRegistration for NotifyRegistration {
/// Calls `enable()` on the inner future; its boolean result is ignored.
fn prepare(self: Pin<&mut Self>) {
let _ = self.project_future().enable();
}
/// Intentionally does nothing.
// NOTE(review): presumably relies on tokio's `Notified` cleaning up after
// itself when dropped — confirm against the tokio documentation.
fn revoke(self: Pin<&mut Self>) {}
}
impl Future for NotifyRegistration {
    type Output = ();

    /// Polls the wrapped notification future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let notified = self.project_future();
        Future::poll(notified, cx)
    }
}
impl NotifyRegistration {
/// Projects the pinned registration onto its pinned inner future.
fn project_future(self: Pin<&mut Self>) -> Pin<&mut OwnedNotified> {
// SAFETY (review note): treats `future` as structurally pinned. No code
// visible in this file moves `future` out of a pinned registration.
unsafe { self.map_unchecked_mut(|this| &mut this.future) }
}
}
/// One parked waiter, shared between its registration and the waiter stack.
struct WaiterNode {
    /// Node below this one on the stack. Written only while the node is still
    /// private to the pushing thread.
    next: AtomicPtr<WaiterNode>,
    /// Per-waiter notifier the owning registration's future listens on.
    notify: Arc<tokio::sync::Notify>,
    /// Set by whichever side settles the node first: a waker that is about to
    /// notify it, or the owning registration when it revokes or is dropped.
    claimed: AtomicBool,
}

/// Intrusive LIFO stack of parked waiters.
///
/// Pushing is lock-free. Popping is serialised through `pending_wakes` so
/// that at most one thread unlinks nodes at any time. With a single popper a
/// node cannot be freed (or recycled at the same address) between reading its
/// `next` pointer and the compare-exchange, which rules out the
/// use-after-free / ABA hazard of a multi-consumer Treiber pop.
struct TreiberStack {
    head: AtomicPtr<WaiterNode>,
    /// Wake requests that were issued but not yet serviced, including the one
    /// currently being serviced. The thread that moves it from 0 to 1 services
    /// every request until it is back at 0.
    pending_wakes: std::sync::atomic::AtomicUsize,
}

// SAFETY: every pointer reachable from `head` was produced by `Arc::into_raw`
// on a `WaiterNode`, whose fields are all `Send + Sync`. Nodes are only
// dereferenced by the pushing thread before publication, by the single active
// popper, or in `Drop` with exclusive access.
unsafe impl Send for TreiberStack {}
unsafe impl Sync for TreiberStack {}

impl TreiberStack {
    /// Creates an empty stack with no outstanding wake requests.
    fn new() -> Self {
        Self {
            head: AtomicPtr::new(ptr::null_mut()),
            pending_wakes: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Publishes `node` on top of the stack; the stack takes over this strong
    /// reference until the node is popped or the stack is dropped.
    fn push(&self, node: Arc<WaiterNode>) {
        let node = Arc::into_raw(node) as *mut WaiterNode;
        let mut head = self.head.load(Ordering::Relaxed);
        loop {
            // SAFETY: `node` is kept alive by the reference taken above and is
            // not reachable by other threads until the exchange succeeds.
            unsafe { (*node).next.store(head, Ordering::Relaxed) };
            match self
                .head
                .compare_exchange_weak(head, node, Ordering::Release, Ordering::Relaxed)
            {
                Ok(_) => return,
                Err(observed) => head = observed,
            }
        }
    }

    /// Wakes the most recently parked waiter that is still waiting, if any.
    ///
    /// Concurrent callers do not pop concurrently: a caller that finds another
    /// thread already servicing wake-ups only records its request and returns;
    /// the servicing thread performs it on the caller's behalf.
    fn wake_one(&self) {
        if self.pending_wakes.fetch_add(1, Ordering::AcqRel) != 0 {
            return;
        }
        loop {
            self.wake_next();
            // `1` means our request was the last outstanding one.
            if self.pending_wakes.fetch_sub(1, Ordering::AcqRel) == 1 {
                return;
            }
        }
    }

    /// Pops nodes until one is claimed for waking or the stack is empty.
    ///
    /// Must only be called from the servicing loop in `wake_one`, which
    /// guarantees this is the only thread unlinking nodes.
    fn wake_next(&self) {
        let mut head = self.head.load(Ordering::Acquire);
        while !head.is_null() {
            // SAFETY: `head` is on the stack, so the stack's strong reference
            // keeps it alive, and no other thread pops while we do.
            let next = unsafe { (*head).next.load(Ordering::Relaxed) };
            match self
                .head
                .compare_exchange_weak(head, next, Ordering::AcqRel, Ordering::Acquire)
            {
                Ok(_) => {
                    // SAFETY: `head` came from `Arc::into_raw` in `push` and
                    // was just unlinked, so we take over the stack's reference.
                    let node = unsafe { Arc::from_raw(head as *const WaiterNode) };
                    if !node.claimed.swap(true, Ordering::AcqRel) {
                        node.notify.notify_one();
                        return;
                    }
                    // The owner gave up first; try the next node.
                    head = self.head.load(Ordering::Acquire);
                }
                Err(observed) => head = observed,
            }
        }
    }
}

impl Drop for TreiberStack {
    /// Releases the strong reference held for every node still on the stack.
    fn drop(&mut self) {
        let mut current = *self.head.get_mut();
        while !current.is_null() {
            // SAFETY: `&mut self` gives exclusive access, and each pointer on
            // the stack came from `Arc::into_raw` in `push`.
            let node = unsafe { Arc::from_raw(current as *const WaiterNode) };
            current = node.next.load(Ordering::Relaxed);
        }
    }
}

/// Waiter set that parks each waiter on its own notifier, tracked in a stack.
///
/// `wake_one` wakes the most recently parked waiter that has not given up.
pub struct TreiberWaiters {
    stack: Arc<TreiberStack>,
}

impl TreiberWaiters {
    /// Creates an empty waiter set.
    pub fn new() -> Self {
        Self {
            stack: Arc::new(TreiberStack::new()),
        }
    }
}

impl Default for TreiberWaiters {
    fn default() -> Self {
        Self::new()
    }
}

impl Waiter for TreiberWaiters {
    type Registration = TreiberRegistration;

    /// Creates an unpublished registration with its own notifier.
    fn register(&self) -> Self::Registration {
        let notify = Arc::new(tokio::sync::Notify::new());
        let future = Arc::clone(&notify).notified_owned();
        let node = Arc::new(WaiterNode {
            next: AtomicPtr::new(ptr::null_mut()),
            notify,
            claimed: AtomicBool::new(false),
        });
        TreiberRegistration {
            node,
            stack: Arc::clone(&self.stack),
            future,
            published: false,
            settled: false,
        }
    }

    fn wake_one(&self) {
        self.stack.wake_one();
    }
}

/// Registration produced by [`TreiberWaiters`].
///
/// Giving up — via `revoke` or by dropping the registration, e.g. when the
/// awaiting task is cancelled — marks the node so wakers skip it. If a
/// wake-up had already been spent on this registration but was never
/// observed by `poll`, it is passed on to the next waiter instead of being
/// lost.
pub struct TreiberRegistration {
    node: Arc<WaiterNode>,
    stack: Arc<TreiberStack>,
    future: OwnedNotified,
    /// Whether the node has been pushed onto the stack.
    published: bool,
    /// Whether this registration has consumed its wake-up or given up.
    settled: bool,
}

impl WaitRegistration for TreiberRegistration {
    /// Enables the notification future and publishes the node (once).
    fn prepare(mut self: Pin<&mut Self>) {
        let _ = self.as_mut().project_future().enable();
        // SAFETY: only bookkeeping fields are touched; `future` is not moved.
        let this = unsafe { self.get_unchecked_mut() };
        if !this.published {
            this.stack.push(Arc::clone(&this.node));
            this.published = true;
        }
    }

    /// Withdraws the registration; see [`TreiberRegistration`] for details.
    fn revoke(self: Pin<&mut Self>) {
        // SAFETY: `settle` only touches flags and shared handles; `future` is
        // not moved.
        let this = unsafe { self.get_unchecked_mut() };
        this.settle();
    }
}

impl Future for TreiberRegistration {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let outcome = self.as_mut().project_future().poll(cx);
        if outcome.is_ready() {
            // The wake-up has been observed; nothing is left to forward.
            // SAFETY: only a plain flag is written; `future` is not moved.
            let this = unsafe { self.get_unchecked_mut() };
            this.settled = true;
        }
        outcome
    }
}

impl Drop for TreiberRegistration {
    /// Ensures a cancelled waiter neither absorbs nor loses a wake-up.
    fn drop(&mut self) {
        self.settle();
    }
}

impl TreiberRegistration {
    /// Projects the pinned registration onto its pinned inner future.
    fn project_future(self: Pin<&mut Self>) -> Pin<&mut OwnedNotified> {
        // SAFETY: `future` is structurally pinned: it is never moved out of a
        // pinned registration, and `Drop` does not move it either.
        unsafe { self.map_unchecked_mut(|this| &mut this.future) }
    }

    /// Gives up the registration exactly once.
    ///
    /// Claims the node so wakers skip it. If a waker had claimed it first, a
    /// wake-up was delivered here without being consumed, so it is forwarded.
    fn settle(&mut self) {
        if self.settled {
            return;
        }
        self.settled = true;
        if self.node.claimed.swap(true, Ordering::AcqRel) {
            self.stack.wake_one();
        }
    }
}
/// Closed set of the waiter implementations shipped in this file.
///
/// Lets one concrete type stand for either strategy; it is the default waiter
/// type parameter of `AsyncFixedArena`.
#[doc(hidden)]
pub enum BuiltInWaiters {
/// Waiters sharing a single notifier.
Notify(NotifyWaiters),
/// Waiters tracked in the Treiber stack.
Treiber(TreiberWaiters),
}
impl Waiter for BuiltInWaiters {
type Registration = BuiltInRegistration;
fn register(&self) -> Self::Registration {
match self {
Self::Notify(waiters) => BuiltInRegistration::Notify(waiters.register()),
Self::Treiber(waiters) => BuiltInRegistration::Treiber(waiters.register()),
}
}
fn wake_one(&self) {
match self {
Self::Notify(waiters) => Waiter::wake_one(waiters),
Self::Treiber(waiters) => Waiter::wake_one(waiters),
}
}
}
/// Registration type of [`BuiltInWaiters`]: one variant per built-in strategy.
#[doc(hidden)]
pub enum BuiltInRegistration {
/// Registration created by `NotifyWaiters`.
Notify(NotifyRegistration),
/// Registration created by `TreiberWaiters`.
Treiber(TreiberRegistration),
}
impl WaitRegistration for BuiltInRegistration {
    /// Forwards `prepare` to the active variant, keeping it pinned.
    fn prepare(self: Pin<&mut Self>) {
        // SAFETY: the variant payload is re-pinned immediately below and is
        // never moved out of `self`.
        let this = unsafe { self.get_unchecked_mut() };
        match this {
            Self::Notify(inner) => {
                // SAFETY: `inner` lives inside the pinned enum.
                let pinned = unsafe { Pin::new_unchecked(inner) };
                pinned.prepare()
            }
            Self::Treiber(inner) => {
                // SAFETY: `inner` lives inside the pinned enum.
                let pinned = unsafe { Pin::new_unchecked(inner) };
                pinned.prepare()
            }
        }
    }

    /// Forwards `revoke` to the active variant, keeping it pinned.
    fn revoke(self: Pin<&mut Self>) {
        // SAFETY: the variant payload is re-pinned immediately below and is
        // never moved out of `self`.
        let this = unsafe { self.get_unchecked_mut() };
        match this {
            Self::Notify(inner) => {
                // SAFETY: `inner` lives inside the pinned enum.
                let pinned = unsafe { Pin::new_unchecked(inner) };
                pinned.revoke()
            }
            Self::Treiber(inner) => {
                // SAFETY: `inner` lives inside the pinned enum.
                let pinned = unsafe { Pin::new_unchecked(inner) };
                pinned.revoke()
            }
        }
    }
}
impl Future for BuiltInRegistration {
    type Output = ();

    /// Polls whichever registration variant is active.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: the variant payload is re-pinned immediately below and is
        // never moved out of `self`.
        let this = unsafe { self.get_unchecked_mut() };
        match this {
            Self::Notify(inner) => {
                // SAFETY: `inner` lives inside the pinned enum.
                let pinned = unsafe { Pin::new_unchecked(inner) };
                pinned.poll(cx)
            }
            Self::Treiber(inner) => {
                // SAFETY: `inner` lives inside the pinned enum.
                let pinned = unsafe { Pin::new_unchecked(inner) };
                pinned.poll(cx)
            }
        }
    }
}
/// Repeatedly runs `try_allocate` until it yields a value, parking on
/// `waiters` between failed attempts.
///
/// The first attempt is made before registering anything, so an allocation
/// that succeeds immediately does not pay for creating, preparing and
/// revoking a registration.
///
/// On the contended path the order is: register, `prepare`, retry, then
/// either `revoke` and return (retry succeeded) or await the registration
/// (retry failed). Preparing before the retry means a wake-up issued between
/// the failed attempt and the await is not missed.
async fn allocate_with_waiter<W, T, F>(waiters: &W, mut try_allocate: F) -> T
where
    W: Waiter,
    F: FnMut() -> Option<T>,
{
    // Uncontended fast path: no registration needed.
    if let Some(value) = try_allocate() {
        return value;
    }
    loop {
        let registration = waiters.register();
        tokio::pin!(registration);
        registration.as_mut().prepare();
        if let Some(value) = try_allocate() {
            registration.as_mut().revoke();
            return value;
        }
        registration.await;
    }
}
/// A `FixedArena` paired with a shared waiter set for async allocation.
///
/// Dereferences to the wrapped `FixedArena`, so the synchronous API remains
/// available.
pub struct AsyncFixedArena<W = BuiltInWaiters> {
    inner: FixedArena,
    waiters: Arc<W>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `W: Clone` bound even
// though only the `Arc<W>` handle is cloned, making the arena non-clonable for
// waiter types that are not `Clone` (such as the default `BuiltInWaiters`).
impl<W> Clone for AsyncFixedArena<W> {
    /// Returns a handle with a clone of the arena and the same waiter set.
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            waiters: Arc::clone(&self.waiters),
        }
    }
}
impl<W> AsyncFixedArena<W> {
/// Pairs an arena with the waiter set its async allocations park on.
pub(crate) fn new(inner: FixedArena, waiters: Arc<W>) -> Self {
Self { inner, waiters }
}
}
impl<W: Waiter> AsyncFixedArena<W> {
    /// Allocates a buffer, waiting asynchronously while the synchronous
    /// `allocate` keeps returning an error.
    pub async fn allocate_async(&self) -> Buffer {
        let waiters: &W = &self.waiters;
        let attempt = || self.inner.allocate().ok();
        allocate_with_waiter(waiters, attempt).await
    }
}
/// Exposes the wrapped `FixedArena` so its methods can be called directly.
impl<W> Deref for AsyncFixedArena<W> {
type Target = FixedArena;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<W> fmt::Debug for AsyncFixedArena<W> {
    /// Formats only the wrapped arena; the waiter set is left out, so `W`
    /// does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("AsyncFixedArena");
        out.field("inner", &self.inner);
        out.finish()
    }
}
/// A `BuddyArena` paired with a shared waiter set for async allocation.
///
/// Dereferences to the wrapped `BuddyArena`, so the synchronous API remains
/// available.
pub struct AsyncBuddyArena<W = NotifyWaiters> {
    inner: BuddyArena,
    waiters: Arc<W>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `W: Clone` bound even
// though only the `Arc<W>` handle is cloned, making the arena non-clonable for
// custom waiter types that are not `Clone`.
impl<W> Clone for AsyncBuddyArena<W> {
    /// Returns a handle with a clone of the arena and the same waiter set.
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            waiters: Arc::clone(&self.waiters),
        }
    }
}
impl<W> AsyncBuddyArena<W> {
/// Pairs an arena with the waiter set its async allocations park on.
pub(crate) fn new(inner: BuddyArena, waiters: Arc<W>) -> Self {
Self { inner, waiters }
}
}
impl<W: Waiter> AsyncBuddyArena<W> {
    /// Allocates a buffer for a request of `len`, waiting asynchronously
    /// while the synchronous `allocate(len)` keeps returning an error.
    pub async fn allocate_async(&self, len: std::num::NonZeroUsize) -> Buffer {
        let waiters: &W = &self.waiters;
        let attempt = || self.inner.allocate(len).ok();
        allocate_with_waiter(waiters, attempt).await
    }
}
/// Exposes the wrapped `BuddyArena` so its methods can be called directly.
impl<W> Deref for AsyncBuddyArena<W> {
type Target = BuddyArena;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<W> fmt::Debug for AsyncBuddyArena<W> {
    /// Formats only the wrapped arena; the waiter set is left out, so `W`
    /// does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("AsyncBuddyArena");
        out.field("inner", &self.inner);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use bytes::BufMut;
use tokio::time::{Duration, timeout};
use crate::BuddyArena;
use crate::FixedArena;
use super::*;
/// Test helper: converts `n` to `NonZeroUsize`, panicking when `n` is zero.
fn nz(n: usize) -> NonZeroUsize {
    let candidate = NonZeroUsize::new(n);
    candidate.unwrap()
}
// Custom waiter used by the tests: delegates to `NotifyWaiters` while counting
// how often `register` and `wake_one` are called. Clones share the counters.
#[derive(Clone)]
struct CountingWaiters {
inner: NotifyWaiters,
// Number of `register` calls.
registrations: Arc<AtomicUsize>,
// Number of `wake_one` calls.
wakes: Arc<AtomicUsize>,
}
impl CountingWaiters {
    /// Creates a counting waiter set with both counters at zero.
    fn new() -> Self {
        let zeroed = || Arc::new(AtomicUsize::new(0));
        Self {
            inner: NotifyWaiters::new(),
            registrations: zeroed(),
            wakes: zeroed(),
        }
    }

    /// Number of `register` calls observed so far.
    fn registrations(&self) -> usize {
        let counter: &AtomicUsize = &self.registrations;
        counter.load(AtomicOrdering::Relaxed)
    }

    /// Number of `wake_one` calls observed so far.
    fn wakes(&self) -> usize {
        let counter: &AtomicUsize = &self.wakes;
        counter.load(AtomicOrdering::Relaxed)
    }
}
// Registration type of `CountingWaiters`; a thin wrapper around the
// `NotifyRegistration` it forwards to.
struct CountingRegistration {
inner: NotifyRegistration,
}
impl Waiter for CountingWaiters {
type Registration = CountingRegistration;
fn register(&self) -> Self::Registration {
self.registrations.fetch_add(1, AtomicOrdering::Relaxed);
CountingRegistration {
inner: self.inner.register(),
}
}
fn wake_one(&self) {
self.wakes.fetch_add(1, AtomicOrdering::Relaxed);
Waiter::wake_one(&self.inner);
}
}
// Both operations are forwarded unchanged to the wrapped registration.
impl WaitRegistration for CountingRegistration {
fn prepare(self: Pin<&mut Self>) {
self.project_inner().prepare();
}
fn revoke(self: Pin<&mut Self>) {
self.project_inner().revoke();
}
}
// Polling is forwarded unchanged to the wrapped registration.
impl Future for CountingRegistration {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project_inner().poll(cx)
}
}
impl CountingRegistration {
// Projects the pinned wrapper onto its pinned inner registration.
fn project_inner(self: Pin<&mut Self>) -> Pin<&mut NotifyRegistration> {
// SAFETY (review note): treats `inner` as structurally pinned; nothing in
// this test module moves it out of a pinned `CountingRegistration`.
unsafe { self.map_unchecked_mut(|this| &mut this.inner) }
}
}
// Notify policy, one slot: allocate, write, freeze and drop the buffer, then
// a second `allocate_async` must complete.
#[tokio::test]
async fn allocate_async_basic() {
let arena = FixedArena::builder(nz(1), nz(32))
.build_async(AsyncPolicy::Notify)
.unwrap();
let mut buf = arena.allocate_async().await;
buf.put_slice(b"data");
let bytes = buf.freeze();
drop(bytes);
let _buf2 = arena.allocate_async().await;
}
// Notify policy, one slot: a spawned `allocate_async` starts while the slot is
// held and must finish within two seconds once the holder is dropped.
#[tokio::test]
async fn allocate_async_waits_then_succeeds() {
let arena = Arc::new(
FixedArena::builder(nz(1), nz(32))
.build_async(AsyncPolicy::Notify)
.unwrap(),
);
let mut buf = arena.allocate_async().await;
buf.put_slice(b"blocking");
let bytes = buf.freeze();
let arena2 = Arc::clone(&arena);
let handle = tokio::spawn(async move {
let buf = arena2.allocate_async().await;
buf.capacity()
});
// Give the spawned task time to start waiting before freeing the slot.
tokio::time::sleep(Duration::from_millis(50)).await;
drop(bytes);
let cap = timeout(Duration::from_secs(2), handle)
.await
.expect("should not timeout")
.expect("task should not panic");
assert_eq!(cap, 32);
}
// Notify policy: the synchronous `allocate` reached through `Deref` still
// returns `ArenaFull` immediately instead of waiting.
#[tokio::test]
async fn sync_allocate_still_fast_fails() {
let arena = FixedArena::builder(nz(1), nz(32))
.build_async(AsyncPolicy::Notify)
.unwrap();
let _buf = arena.allocate().unwrap();
let err = arena.allocate().unwrap_err();
assert_eq!(err, crate::AllocError::ArenaFull);
}
// Notify policy, two slots: two spawned waiters must both complete after both
// held buffers are dropped.
#[tokio::test]
async fn multiple_waiters_all_served() {
let arena = Arc::new(
FixedArena::builder(nz(2), nz(32))
.build_async(AsyncPolicy::Notify)
.unwrap(),
);
let buf1 = arena.allocate().unwrap();
let buf2 = arena.allocate().unwrap();
let a1 = Arc::clone(&arena);
let h1 = tokio::spawn(async move { a1.allocate_async().await.capacity() });
let a2 = Arc::clone(&arena);
let h2 = tokio::spawn(async move { a2.allocate_async().await.capacity() });
// Give both tasks time to start waiting before freeing the slots.
tokio::time::sleep(Duration::from_millis(50)).await;
drop(buf1);
drop(buf2);
let (r1, r2) = tokio::join!(
timeout(Duration::from_secs(2), h1),
timeout(Duration::from_secs(2), h2),
);
assert_eq!(r1.unwrap().unwrap(), 32);
assert_eq!(r2.unwrap().unwrap(), 32);
}
// The async wrapper derefs to `FixedArena`, so its accessors are callable.
#[tokio::test]
async fn deref_exposes_sync_methods() {
let arena = FixedArena::builder(nz(4), nz(64))
.build_async(AsyncPolicy::Notify)
.unwrap();
assert_eq!(arena.slot_count(), 4);
assert_eq!(arena.slot_capacity(), 64);
}
// Treiber policy, one slot: allocate, write, freeze and drop the buffer, then
// a second `allocate_async` must complete.
#[tokio::test]
async fn treiber_allocate_async_basic() {
let arena = FixedArena::builder(nz(1), nz(32))
.build_async(AsyncPolicy::TreiberWaiters)
.unwrap();
let mut buf = arena.allocate_async().await;
buf.put_slice(b"data");
let bytes = buf.freeze();
drop(bytes);
let _buf2 = arena.allocate_async().await;
}
// Treiber policy, one slot: a spawned `allocate_async` starts while the slot
// is held and must finish within two seconds once the holder is dropped.
#[tokio::test]
async fn treiber_waits_then_succeeds() {
let arena = Arc::new(
FixedArena::builder(nz(1), nz(32))
.build_async(AsyncPolicy::TreiberWaiters)
.unwrap(),
);
let mut buf = arena.allocate_async().await;
buf.put_slice(b"blocking");
let bytes = buf.freeze();
let arena2 = Arc::clone(&arena);
let handle = tokio::spawn(async move {
let buf = arena2.allocate_async().await;
buf.capacity()
});
// Give the spawned task time to start waiting before freeing the slot.
tokio::time::sleep(Duration::from_millis(50)).await;
drop(bytes);
let cap = timeout(Duration::from_secs(2), handle)
.await
.expect("should not timeout")
.expect("task should not panic");
assert_eq!(cap, 32);
}
// Treiber policy, two slots: two spawned waiters must both complete after both
// held buffers are dropped.
#[tokio::test]
async fn treiber_multiple_waiters() {
let arena = Arc::new(
FixedArena::builder(nz(2), nz(32))
.build_async(AsyncPolicy::TreiberWaiters)
.unwrap(),
);
let buf1 = arena.allocate().unwrap();
let buf2 = arena.allocate().unwrap();
let a1 = Arc::clone(&arena);
let h1 = tokio::spawn(async move { a1.allocate_async().await.capacity() });
let a2 = Arc::clone(&arena);
let h2 = tokio::spawn(async move { a2.allocate_async().await.capacity() });
// Give both tasks time to start waiting before freeing the slots.
tokio::time::sleep(Duration::from_millis(50)).await;
drop(buf1);
drop(buf2);
let (r1, r2) = tokio::join!(
timeout(Duration::from_secs(2), h1),
timeout(Duration::from_secs(2), h2),
);
assert_eq!(r1.unwrap().unwrap(), 32);
assert_eq!(r2.unwrap().unwrap(), 32);
}
// Treiber policy: aborting a waiting task must not consume the slot; after the
// holder is dropped a synchronous `allocate` succeeds.
#[tokio::test]
async fn treiber_cancellation_no_leak() {
let arena = Arc::new(
FixedArena::builder(nz(1), nz(32))
.build_async(AsyncPolicy::TreiberWaiters)
.unwrap(),
);
let buf = arena.allocate().unwrap();
let arena2 = Arc::clone(&arena);
let handle = tokio::spawn(async move { arena2.allocate_async().await });
// Let the task start waiting, then cancel it while it is parked.
tokio::time::sleep(Duration::from_millis(50)).await;
handle.abort();
let _ = handle.await;
drop(buf);
let _buf2 = arena.allocate().unwrap();
}
// Treiber policy: the synchronous `allocate` reached through `Deref` still
// returns `ArenaFull` immediately instead of waiting.
#[tokio::test]
async fn treiber_sync_still_fast_fails() {
let arena = FixedArena::builder(nz(1), nz(32))
.build_async(AsyncPolicy::TreiberWaiters)
.unwrap();
let _buf = arena.allocate().unwrap();
let err = arena.allocate().unwrap_err();
assert_eq!(err, crate::AllocError::ArenaFull);
}
// Buddy arena: a spawned 2048-byte `allocate_async` is started while another
// 2048-byte buffer is held, and must finish with capacity 2048 after that
// buffer is dropped.
#[tokio::test]
async fn buddy_allocate_async_waits_then_succeeds() {
let arena = Arc::new(
BuddyArena::builder(nz(4096), nz(512))
.build_async()
.unwrap(),
);
let buf = arena.allocate(nz(2048)).unwrap();
let arena2 = Arc::clone(&arena);
let handle = tokio::spawn(async move {
let buf = arena2.allocate_async(nz(2048)).await;
buf.capacity()
});
tokio::time::sleep(Duration::from_millis(50)).await;
drop(buf);
let cap = timeout(Duration::from_secs(2), handle)
.await
.expect("should not timeout")
.expect("task should not panic");
assert_eq!(cap, 2048);
}
// Buddy arena fully occupied by two 2048-byte buffers: two spawned waiters
// must both complete after both buffers are dropped.
#[tokio::test]
async fn buddy_multiple_waiters_all_served() {
let arena = Arc::new(
BuddyArena::builder(nz(4096), nz(512))
.build_async()
.unwrap(),
);
let buf1 = arena.allocate(nz(2048)).unwrap();
let buf2 = arena.allocate(nz(2048)).unwrap();
let a1 = Arc::clone(&arena);
let h1 = tokio::spawn(async move { a1.allocate_async(nz(2048)).await.capacity() });
let a2 = Arc::clone(&arena);
let h2 = tokio::spawn(async move { a2.allocate_async(nz(2048)).await.capacity() });
// Give both tasks time to start waiting before freeing memory.
tokio::time::sleep(Duration::from_millis(50)).await;
drop(buf1);
drop(buf2);
let (r1, r2) = tokio::join!(
timeout(Duration::from_secs(2), h1),
timeout(Duration::from_secs(2), h2),
);
assert_eq!(r1.unwrap().unwrap(), 2048);
assert_eq!(r2.unwrap().unwrap(), 2048);
}
// Buddy arena: a 4096-byte request must stay pending while only one of the two
// 2048-byte buffers has been freed, and complete once both are freed.
#[tokio::test]
async fn buddy_large_request_unblocks_after_coalesce() {
let arena = Arc::new(
BuddyArena::builder(nz(4096), nz(512))
.build_async()
.unwrap(),
);
let buf1 = arena.allocate(nz(2048)).unwrap();
let buf2 = arena.allocate(nz(2048)).unwrap();
let arena2 = Arc::clone(&arena);
let handle = tokio::spawn(async move {
let buf = arena2.allocate_async(nz(4096)).await;
buf.capacity()
});
tokio::time::sleep(Duration::from_millis(50)).await;
drop(buf1);
// Half the arena is free now, which is not enough for the 4096-byte request.
tokio::time::sleep(Duration::from_millis(25)).await;
assert!(!handle.is_finished());
drop(buf2);
let cap = timeout(Duration::from_secs(2), handle)
.await
.expect("should not timeout")
.expect("task should not panic");
assert_eq!(cap, 4096);
}
// Buddy arena: aborting a waiting task must not hold on to memory; after the
// holder is dropped the full 4096 bytes can be allocated synchronously.
#[tokio::test]
async fn buddy_cancellation_does_not_leak() {
let arena = Arc::new(
BuddyArena::builder(nz(4096), nz(512))
.build_async()
.unwrap(),
);
let buf = arena.allocate(nz(4096)).unwrap();
let arena2 = Arc::clone(&arena);
let handle = tokio::spawn(async move { arena2.allocate_async(nz(512)).await });
// Let the task start waiting, then cancel it while it is parked.
tokio::time::sleep(Duration::from_millis(50)).await;
handle.abort();
let _ = handle.await;
drop(buf);
let _buf2 = arena.allocate(nz(4096)).unwrap();
}
// A fixed arena built with the custom `CountingWaiters` serves a blocked
// waiter and goes through the custom `register` / `wake_one` at least once.
#[tokio::test]
async fn fixed_custom_waiter_supported() {
let waiters = CountingWaiters::new();
let arena = Arc::new(
FixedArena::builder(nz(1), nz(32))
.build_async_with(waiters.clone())
.unwrap(),
);
let buf = arena.allocate().unwrap();
let arena2 = Arc::clone(&arena);
let handle = tokio::spawn(async move { arena2.allocate_async().await.capacity() });
tokio::time::sleep(Duration::from_millis(50)).await;
drop(buf);
let cap = timeout(Duration::from_secs(2), handle)
.await
.expect("should not timeout")
.expect("task should not panic");
assert_eq!(cap, 32);
assert!(waiters.registrations() >= 1);
assert!(waiters.wakes() >= 1);
}
// A buddy arena built with the custom `CountingWaiters` serves a blocked
// waiter and goes through the custom `register` / `wake_one` at least once.
#[tokio::test]
async fn buddy_custom_waiter_supported() {
let waiters = CountingWaiters::new();
let arena = Arc::new(
BuddyArena::builder(nz(4096), nz(512))
.build_async_with(waiters.clone())
.unwrap(),
);
let buf = arena.allocate(nz(2048)).unwrap();
let arena2 = Arc::clone(&arena);
let handle = tokio::spawn(async move { arena2.allocate_async(nz(2048)).await.capacity() });
tokio::time::sleep(Duration::from_millis(50)).await;
drop(buf);
let cap = timeout(Duration::from_secs(2), handle)
.await
.expect("should not timeout")
.expect("task should not panic");
assert_eq!(cap, 2048);
assert!(waiters.registrations() >= 1);
assert!(waiters.wakes() >= 1);
}
}