bab/
buffer_pool.rs

1use core::{
2    alloc::Layout,
3    cell::{Cell, UnsafeCell},
4    future::Future,
5    pin::Pin,
6    sync::atomic::{AtomicU32, AtomicUsize, Ordering},
7    task::{Context, Poll},
8};
9
10#[cfg(feature = "alloc")]
11use alloc::{
12    alloc::{alloc, dealloc},
13    boxed::Box,
14    vec::Vec,
15};
16#[cfg(feature = "std")]
17use std::alloc::{alloc, dealloc};
18
19use crossbeam_utils::CachePadded;
20use thid::ThreadLocal;
21use waitq::{Fulfillment, IFulfillment, Waiter, WaiterQueue};
22
23use crate::{
24    buffer::{Buffer, BufferPtr},
25    free_stack::FreeStack,
26};
27
/// Per-thread stock of free buffers, kept as an intrusive singly-linked list
/// threaded through the buffers' `next` pointers.
pub(crate) struct LocalStock {
    // Head of the thread-local free list (None when the stock is empty).
    head: Cell<Option<BufferPtr>>,
    // Boundary marker: the chain *after* this buffer holds exactly one full
    // batch that `release_overflow` can split off and return to the shared
    // free stack.
    watermark: Cell<Option<BufferPtr>>,
    // Number of buffers currently linked into the local list.
    count: Cell<u32>,
}

// SAFETY: a LocalStock is only accessed by the thread owning its ThreadLocal
// slot, or by the single shutdown site once no other thread can touch the
// pool (see shutdown_now_try_drop), so sending it between threads is sound.
unsafe impl Send for LocalStock {}
35
impl LocalStock {
    /// Create an empty stock.
    fn new() -> Self {
        Self {
            head: Cell::new(None),
            watermark: Cell::new(None),
            count: Cell::new(0),
        }
    }

    /// Pop one buffer off the local free list, or return `None` if empty.
    fn try_acquire(&self) -> Option<BufferPtr> {
        self.head.get().map(|head_ptr| {
            // We can take a buffer from the local batch - advance the local head and return.
            debug_assert!(self.count.get() > 0);
            self.count.set(self.count.get() - 1);
            self.head.set(unsafe { head_ptr.get_next() });
            // Detach the returned buffer so it doesn't keep a dangling link
            // into the rest of the stock.
            unsafe {
                head_ptr.set_next(None);
            }

            head_ptr
        })
    }
}
59
/// Per-thread accounting state for the pool.
pub(crate) struct LocalState {
    // How many buffers from this pool the owning thread currently holds.
    buffers_in_use: Cell<u32>,
    // One entry per buffer in the pool; allocated (and leaked) via
    // Box::into_raw in `new_heap` and reclaimed in `BufferPool::drop`.
    local_buffer_state: *const [LocalBufferState],
}

// SAFETY: only touched by its owning thread until shutdown, when exclusive
// access is guaranteed (same argument as for LocalStock).
unsafe impl Send for LocalState {}
66
impl LocalState {
    /// Allocate state for `total_buffer_count` buffers on the heap. The boxed
    /// slice is intentionally leaked here; `BufferPool::drop` reclaims it.
    #[cfg(any(feature = "std", feature = "alloc"))]
    fn new_heap(total_buffer_count: usize) -> Self {
        let local_buffer_state = Box::into_raw(
            (0..total_buffer_count)
                .map(|_| LocalBufferState {
                    ref_count: Cell::new(0),
                    shared_rc_contribution: Cell::new(0),
                })
                .collect::<Vec<_>>()
                .into_boxed_slice(),
        );

        Self {
            buffers_in_use: Cell::new(0),
            local_buffer_state,
        }
    }

    /// Get this thread's state for the buffer with the given id.
    /// `buffer_id` must be in bounds - it indexes the leaked slice directly.
    #[inline]
    pub(crate) fn local_buffer_state(&self, buffer_id: usize) -> &LocalBufferState {
        // addr_of! projects to the element without materializing an
        // intermediate reference to the whole slice.
        unsafe { &*core::ptr::addr_of!((*self.local_buffer_state)[buffer_id]) }
    }
}
91
/// Per-thread, per-buffer reference-count bookkeeping.
pub(crate) struct LocalBufferState {
    // This thread's (non-atomic) reference count for the buffer.
    pub(crate) ref_count: Cell<u32>,
    // NOTE(review): presumably the portion of this thread's count already
    // reflected in the buffer's shared atomic count - confirm in buffer.rs.
    pub(crate) shared_rc_contribution: Cell<u32>,
}
96
/// A fixed-size pool of equally-sized buffers carved out of one allocation.
///
/// Buffers circulate in batches: a shared `FreeStack` holds whole batches,
/// while each thread keeps a small `LocalStock` to avoid contention.
pub struct BufferPool {
    // Start of the single allocation holding every Buffer.
    pub(crate) alloc: *mut Buffer,
    alloc_layout: Layout,
    // Size of a single buffer, including padding for alignment, in self.alloc
    buffer_padded_size: usize,
    total_buffer_count: usize,
    // Usable size of each buffer's data area, in bytes.
    buffer_size: usize,
    // Number of buffers per batch on the shared free stack.
    batch_size: u32,
    free_stack: FreeStack,
    // Tasks parked in `acquire`, waiting for buffers to be released.
    waiter_queue: WaiterQueue<BufferPtr>,
    // Unfortunately needs to be in an UnsafeCell so that we can mutably access it during shutdown.
    local_stock: UnsafeCell<ThreadLocal<CachePadded<LocalStock>>>,
    local_state: ThreadLocal<CachePadded<LocalState>>,
    // Counts HeapBufferPool handles plus threads that currently hold buffers;
    // reaching 0 starts shutdown.
    ref_count: AtomicUsize,
    // Buffers accounted for during shutdown; once it reaches
    // total_buffer_count, handle_drop_fn frees the pool.
    shutdown_released_buffers: AtomicU32,
    handle_drop_fn: fn(*mut Self),
}
114
/// Result of `decrement_local_buffers_in_use`, telling the caller whether the
/// pool just transitioned into (or was already in) shutdown.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(crate) enum BufferPoolShutdownStatus {
    /// The pool still has active references.
    NotShutdown,
    /// This call released the pool's last reference.
    ShutdownNow,
    /// The pool was already shutting down before this call.
    AlreadyShutdown,
}
121
/// Guard object returned by BufferPool::register_thread - see there for more details.
pub struct BufferPoolThreadGuard<'a> {
    buffer_pool: &'a BufferPool,
}

impl Drop for BufferPoolThreadGuard<'_> {
    fn drop(&mut self) {
        // Undo the increment done in `register_thread`.
        // NOTE(review): the returned ShutdownNow status is discarded here -
        // confirm that shutdown is always driven from another site (e.g.
        // HeapBufferPool::drop) when the guard holds the last reference.
        self.buffer_pool
            .decrement_local_buffers_in_use(self.buffer_pool.local_state());
    }
}
133
// SAFETY: interior state (local_stock, local_state) is thread-local and the
// shared state is coordinated through atomics and the waiter queue lock -
// assumed by design; confirm against the thid/waitq crate guarantees.
unsafe impl Send for BufferPool {}
unsafe impl Sync for BufferPool {}
136
137impl BufferPool {
    /// Get the total number of buffers from this pool that are in circulation.
    pub fn total_buffer_count(&self) -> usize {
        self.total_buffer_count
    }

    /// Get the size in bytes of each buffer in this pool.
    /// (This is the usable data size; alignment padding is tracked
    /// separately in `buffer_padded_size`.)
    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }
147
    /// Get the buffer at index `id` within the pool's single allocation.
    /// `id` must be < total_buffer_count; out-of-range ids index past the
    /// allocation (not checked here).
    #[inline]
    pub fn buffer_by_id(&self, id: u32) -> BufferPtr {
        // Buffers live at fixed strides (buffer_padded_size) within `alloc`.
        let buffer_raw = unsafe { self.alloc.byte_add(id as usize * self.buffer_padded_size) };
        BufferPtr::from_ptr(buffer_raw).unwrap()
    }
153
    /// Register the current thread as a known user of buffers until the returned guard object is
    /// dropped. This is purely an optimization and is optional to do. It prevents an atomic
    /// reference count from being unnecessarily incremented and decremented when buffers are
    /// received and released by this thread.
    pub fn register_thread(&self) -> BufferPoolThreadGuard<'_> {
        // Holding one "buffer in use" for the guard's lifetime keeps the
        // thread's pool reference alive across individual acquire/release
        // cycles, so only the Cell (not the atomic) is touched per buffer.
        self.increment_local_buffers_in_use(self.local_state());
        BufferPoolThreadGuard { buffer_pool: self }
    }
162
    /// Get (lazily creating) the calling thread's local buffer stock.
    #[inline]
    pub(crate) fn local_stock(&self) -> &LocalStock {
        // SAFETY: the UnsafeCell is only accessed mutably during shutdown,
        // after ref_count has hit 0 and no thread can reach this method
        // anymore (see shutdown_now_try_drop).
        let local_stock = unsafe { &*self.local_stock.get() };
        local_stock.get_or(|| CachePadded::new(LocalStock::new()))
    }
168
169    #[inline]
170    pub(crate) fn local_state(&self) -> &LocalState {
171        self.local_state
172            .get_or(|| CachePadded::new(LocalState::new_heap(self.total_buffer_count as usize)))
173    }
174
    /// Record that this thread holds one more buffer. The first buffer a
    /// thread takes also acquires a reference on the pool itself (unless the
    /// pool is already shutting down), keeping the pool alive while any
    /// thread holds buffers.
    pub(crate) fn increment_local_buffers_in_use(&self, local_state: &LocalState) {
        let prev = local_state
            .buffers_in_use
            .replace(local_state.buffers_in_use.get() + 1);
        if prev == 0 {
            let mut ref_count = self.ref_count.load(Ordering::Relaxed);
            // Only attempt to acquire a reference if the BufferPool isn't already shutting down.
            // (A CAS loop rather than fetch_add, so we never resurrect a
            // ref_count that has already reached 0.)
            while ref_count > 0 {
                match self.ref_count.compare_exchange(
                    ref_count,
                    ref_count + 1,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    // Successfully acquired reference to BufferPool.
                    Ok(_) => break,
                    Err(new_ref_count) => {
                        ref_count = new_ref_count;
                    }
                }
            }
        }
    }
198
    /// Record that this thread holds one fewer buffer, releasing the thread's
    /// pool reference when its count reaches zero. Returns the pool's
    /// shutdown status so the caller can run shutdown if this released the
    /// last reference.
    pub(crate) fn decrement_local_buffers_in_use(
        &self,
        local_state: &LocalState,
    ) -> BufferPoolShutdownStatus {
        let prev = local_state
            .buffers_in_use
            .replace(local_state.buffers_in_use.get() - 1);
        if prev == 1 {
            // This thread no longer has any buffers from this pool in circulation.

            if !self.is_shutting_down() {
                let prev_ref_count = self.ref_count.fetch_sub(1, Ordering::AcqRel);
                if prev_ref_count == 1 {
                    // This was the last active buffer on the last active thread. There may still
                    // be some unreleased buffers, but they will all be buffers sent from one thread
                    // and not yet received by their destination thread.
                    //
                    // It's important that after the fetch_sub above, `self.local_stock` won't be
                    // written to ever again by any thread, and any thread that reads the
                    // ref_count == 0 value sees the latest self.local_stock.count values. This is
                    // why we use AcqRel ordering in the fetch_sub above, and Acquire ordering in
                    // BufferPool::is_shutting_down.
                    BufferPoolShutdownStatus::ShutdownNow
                } else {
                    BufferPoolShutdownStatus::NotShutdown
                }
            } else {
                BufferPoolShutdownStatus::AlreadyShutdown
            }
        } else {
            BufferPoolShutdownStatus::NotShutdown
        }
    }
233
234    fn is_shutting_down(&self) -> bool {
235        match self
236            .ref_count
237            .compare_exchange(0, 0, Ordering::Acquire, Ordering::Relaxed)
238        {
239            Ok(_) => true,
240            Err(_) => false,
241        }
242    }
243
244    pub async fn acquire(&self) -> BufferPtr {
245        if let Some(buffer) = self.try_acquire() {
246            return buffer;
247        }
248
249        // Need to wait for a buffer to become available.
250        let buffer = Acquire {
251            buffer_pool: self,
252            waiter: &Waiter::new(&self.waiter_queue),
253        }
254        .await;
255
256        buffer
257    }
258
259    pub fn try_acquire(&self) -> Option<BufferPtr> {
260        let local_stock = self.local_stock();
261        if let Some(local_buffer) = local_stock.try_acquire() {
262            return Some(local_buffer);
263        }
264
265        self.try_acquire_batch(local_stock)
266    }
267
268    fn try_acquire_batch(&self, local_stock: &LocalStock) -> Option<BufferPtr> {
269        debug_assert!(local_stock.head.get().is_none());
270        debug_assert_eq!(local_stock.count.get(), 0);
271
272        if let Some(batch_head) = self.try_take_batch(local_stock) {
273            local_stock.head.set(unsafe { batch_head.swap_next(None) });
274            local_stock.count.set(self.batch_size as u32 - 1);
275            Some(batch_head)
276        } else {
277            None
278        }
279    }
280
    /// Pop one whole batch (buffers linked via their `next` pointers) off
    /// the shared free stack. The local stock must be empty when called.
    fn try_take_batch(&self, local_stock: &LocalStock) -> Option<BufferPtr> {
        debug_assert!(local_stock.head.get().is_none());
        debug_assert_eq!(local_stock.count.get(), 0);

        self.free_stack.pop()
    }
287
    /// Return a buffer to the pool.
    ///
    /// # Safety
    /// `buffer` must belong to this pool, have no outstanding references,
    /// and be detached from any list (`next == None`, checked in debug).
    pub unsafe fn release(&self, buffer: BufferPtr) {
        debug_assert_eq!(unsafe { buffer.get_next() }, None);

        // Try to hand the buffer directly to a local waiter first.
        // NOTE(review): assumes notify_one_local returns None when the buffer
        // was consumed by a waiter - confirm against waitq's documentation.
        if self.waiter_queue.notify_one_local(buffer).is_none() {
            return;
        }

        let local_stock = self.local_stock();

        // Release the buffer into local stock
        if local_stock.count.get() == self.batch_size {
            // Store the local batch watermark
            // (the buffers currently in the list form exactly one full batch
            // that will hang off this buffer once it becomes the new head).
            local_stock.watermark.set(Some(buffer));
        }
        unsafe {
            buffer.set_next(local_stock.head.get());
        }
        local_stock.head.set(Some(buffer));
        local_stock.count.set(local_stock.count.get() + 1);

        // Shed a batch back to the shared stack if the stock grew too large.
        self.release_overflow(local_stock);
    }
310
    /// If the local stock has reached at least 1.5 batches, split full
    /// batches off at the watermark and either push them onto the shared
    /// free stack or hand them directly to waiting `acquire` calls.
    fn release_overflow(&self, local_stock: &LocalStock) {
        if local_stock.count.get() < (self.batch_size as u32 * 3) / 2 {
            return;
        }

        // Release batch

        while let Some(watermark) = local_stock.watermark.take() {
            // Everything after the watermark is exactly one full batch.
            let release_head = unsafe { watermark.swap_next(None) }.unwrap();
            let release_count = self.batch_size;
            local_stock
                .count
                .set(local_stock.count.get() - self.batch_size as u32);

            // Lazily taken waiter-queue lock, held only if the stack is empty.
            let mut waiter_queue_guard = None;
            self.free_stack.push_if(release_head, |free_count| {
                if free_count == 0 {
                    // Only need to try to notify a waiter if the stock was empty.
                    let guard = waiter_queue_guard.get_or_insert_with(|| self.waiter_queue.lock());

                    if guard.waiter_count() > 0 {
                        waiter_queue_guard
                            .take()
                            .expect("bug: missing lock guard")
                            .notify(release_head, release_count as usize);
                        // Don't push onto free stack since we've used the released buffers to
                        // fulfill waiters.
                        return false;
                    }
                }

                true
            });
            // TODO there's a more efficient way to do this than re-chasing all the pointers every
            // time around.
            self.find_watermark(local_stock);
        }
    }
349
    /// Add `release_count` already-linked buffers (starting at
    /// `release_head`) to the local stock, then shed any overflow back to
    /// the shared stack. Used e.g. when a waiter receives a whole batch but
    /// only needs one buffer.
    fn release_many(&self, release_head: BufferPtr, release_count: usize) {
        let local_stock = self.local_stock();

        // Find the local tail so we can add the extra buffers.
        let mut tail = local_stock.head.get();
        while let Some(next) = tail {
            let new_tail = unsafe { next.get_next() };
            if new_tail.is_none() {
                break;
            }
            tail = new_tail;
        }

        if let Some(tail) = tail {
            // Append newly released buffers to end of local stockpile.
            unsafe {
                tail.set_next(Some(release_head));
            }
            local_stock
                .count
                .set(local_stock.count.get() + release_count as u32);
        } else {
            // Local stockpile is empty.
            debug_assert_eq!(local_stock.head.get(), None);
            debug_assert_eq!(local_stock.count.get(), 0);
            local_stock.head.set(Some(release_head));
            local_stock.count.set(release_count as u32);
        }

        // Always need to re-find the watermark since we appended the new buffers to the tail of
        // the local stockpile.
        self.find_watermark(local_stock);
        self.release_overflow(local_stock);
    }
384
    /// Recompute the watermark by walking `count - batch_size - 1` hops from
    /// the head, so that exactly `batch_size` buffers sit after it, ready to
    /// be split off by `release_overflow`. Leaves the watermark untouched
    /// when the stock holds at most one batch.
    fn find_watermark(&self, local_stock: &LocalStock) {
        if local_stock.count.get() > self.batch_size as u32 {
            let mut watermark = local_stock.head.get().unwrap();
            for _ in 0..local_stock.count.get() - self.batch_size as u32 - 1 {
                watermark = unsafe { watermark.get_next() }.unwrap();
            }
            local_stock.watermark.set(Some(watermark));
        }
    }
394
    /// Begin pool shutdown: count every buffer still parked in thread-local
    /// stocks or on the shared free stack as released, and free the pool once
    /// all `total_buffer_count` buffers are accounted for. Called by the site
    /// that dropped the pool's last reference.
    pub(crate) fn shutdown_now_try_drop(buffer_pool: *mut BufferPool) {
        let this = unsafe { &*buffer_pool };
        // SAFETY: Once buffer_pool.ref_count == 0, only the site that last decremented ref_count
        // will call `shutdown_now_try_drop`, and no other code will access buffer_pool.local_stock.
        let local_stock = unsafe { &mut *this.local_stock.get() };

        let mut released_buffers = 0;
        for local_stock in local_stock.iter_mut() {
            released_buffers += local_stock.count.get();
        }
        // Each entry on the free stack is one full batch.
        while let Some(_) = this.free_stack.pop() {
            released_buffers += this.batch_size;
        }

        // Stash total_buffer_count so we can use it later - once we add the released_buffers to
        // this.shutdown_released_buffers, we aren't allowed to access `this` since another
        // thread can drop it (until we confirm that didn't happen).
        let total_buffer_count = this.total_buffer_count as u32;

        let prev_released_buffers = this
            .shutdown_released_buffers
            .fetch_add(released_buffers, Ordering::Release);

        if prev_released_buffers + released_buffers == total_buffer_count {
            // Enforce that any previous access to self from another thread *happens before*
            // deleting the object on this thread.
            // See comment in source of `Arc::drop`.
            this.shutdown_released_buffers.load(Ordering::Acquire);

            // All buffers have been released - time to drop
            let handle_drop_fn = unsafe { (*buffer_pool).handle_drop_fn };
            handle_drop_fn(buffer_pool as *mut BufferPool);
        }
    }
429
430    pub(crate) fn already_shutdown_release_buffer(buffer_pool: *mut BufferPool) {
431        let this = unsafe { &*buffer_pool };
432
433        // Stash total_buffer_count so we can use it later - once we increment
434        // this.shutdown_released_buffers, we aren't allowed to access `this` since another
435        // thread can drop it (until we confirm that didn't happen).
436        let total_buffer_count = this.total_buffer_count as u32;
437
438        let prev_released_buffers = this
439            .shutdown_released_buffers
440            .fetch_add(1, Ordering::Release);
441
442        if prev_released_buffers + 1 == total_buffer_count as u32 {
443            // Enforce that that any previous access to self from another thread *happens before*
444            // deleting the object on this thread.
445            // See comment in source of `Arc::drop`.
446            this.shutdown_released_buffers.load(Ordering::Acquire);
447
448            // All buffers have been released - time to drop
449            let handle_drop_fn = unsafe { (*buffer_pool).handle_drop_fn };
450            handle_drop_fn(buffer_pool as *mut BufferPool);
451        }
452    }
453}
454
455impl Drop for BufferPool {
456    fn drop(&mut self) {
457        // Drop all local buffer state arrays
458        for local_state in self.local_state.iter_mut() {
459            let _ =
460                unsafe { Box::from_raw(local_state.local_buffer_state as *mut [LocalBufferState]) };
461        }
462
463        let _ = unsafe { dealloc(self.alloc as *mut u8, self.alloc_layout) };
464    }
465}
466
/// Owning, cloneable handle to a heap-allocated `BufferPool`. The pool itself
/// stays alive after the last handle drops until every buffer is released.
pub struct HeapBufferPool {
    ptr: *const BufferPool,
}
470
impl HeapBufferPool {
    /// Create a pool of `batch_count * batch_size` buffers of `buffer_size`
    /// data bytes each, all carved out of one heap allocation.
    pub fn new(buffer_size: usize, batch_count: usize, batch_size: usize) -> Self {
        // Helpers adapted from core::alloc::Layout till they're stable.
        fn padding_needed_for_layout(layout: Layout) -> usize {
            let len = layout.size();
            let align = layout.align();

            (len.wrapping_add(align).wrapping_sub(1) & !align.wrapping_sub(1)).wrapping_sub(len)
        }
        // Layout for `n` contiguous copies of `layout`, plus the stride
        // (padded size) between consecutive elements.
        fn repeat_layout(layout: Layout, n: usize) -> (Layout, usize) {
            let padded_size = layout.size() + padding_needed_for_layout(layout);
            let alloc_size = padded_size.checked_mul(n).unwrap();

            let layout = Layout::from_size_align(alloc_size, layout.align()).unwrap();
            (layout, padded_size)
        }

        let total_buffer_count = batch_count * batch_size;
        let buffer_layout = Buffer::layout_with_data(buffer_size);
        let (alloc_layout, buffer_padded_size) = repeat_layout(buffer_layout, total_buffer_count);
        // NOTE(review): `alloc` returns null on allocation failure and this
        // is not checked before use - consider handle_alloc_error here.
        let alloc = unsafe { alloc(alloc_layout) } as *mut Buffer;

        let buffer_pool = Box::new(BufferPool {
            alloc,
            alloc_layout,
            buffer_padded_size,
            free_stack: FreeStack::new(batch_count),
            waiter_queue: WaiterQueue::new(),
            total_buffer_count,
            buffer_size,
            batch_size: batch_size as u32,
            local_stock: UnsafeCell::new(ThreadLocal::new()),
            local_state: ThreadLocal::new(),
            // One reference held collectively by the HeapBufferPool handles.
            ref_count: AtomicUsize::new(1),
            shutdown_released_buffers: AtomicU32::new(0),
            handle_drop_fn: |buffer_pool| {
                let _ = unsafe { Box::from_raw(buffer_pool) };
            },
        });
        let buffer_pool_ptr = Box::into_raw(buffer_pool);
        let buffer_pool = unsafe { &*buffer_pool_ptr };

        // Initialize buffer descriptors
        for id in 0..total_buffer_count {
            let buffer = buffer_pool.buffer_by_id(id as u32);
            unsafe {
                Buffer::initialize(
                    buffer.as_ptr_mut(),
                    buffer_pool_ptr,
                    id as usize,
                    buffer_size,
                );
            }
        }

        // Link the buffers into `batch_count` batches and seed the shared
        // free stack with every batch.
        let mut next_buffer_id = 0;
        for _ in 0..batch_count {
            let new_batch_head = buffer_pool.buffer_by_id(next_buffer_id);
            next_buffer_id += 1;

            // Build the rest of the batch as a chain, then hang it off the head.
            let mut head = None;
            for _ in 1..batch_size {
                let next = head;
                let new_head = buffer_pool.buffer_by_id(next_buffer_id);
                head = Some(new_head);
                next_buffer_id += 1;
                unsafe {
                    new_head.set_next(next);
                }
            }
            unsafe {
                new_batch_head.set_next(head);
            }

            // `|_| true`: unconditionally push - no waiters can exist yet.
            buffer_pool.free_stack.push_if(new_batch_head, |_| true);
        }

        Self {
            ptr: buffer_pool_ptr,
        }
    }
}
553
// SAFETY: HeapBufferPool is just a pointer to a BufferPool, which is itself
// Send + Sync, and the handle reference count it manipulates is atomic.
unsafe impl Send for HeapBufferPool {}
unsafe impl Sync for HeapBufferPool {}
556
impl core::ops::Deref for HeapBufferPool {
    type Target = BufferPool;

    /// Borrow the underlying pool shared by all handles.
    fn deref(&self) -> &Self::Target {
        // SAFETY: BufferPool is never dereferenced mutably except on drop.
        unsafe { &*self.ptr }
    }
}
565
impl Clone for HeapBufferPool {
    fn clone(&self) -> Self {
        // Relaxed suffices for taking an additional reference, as in
        // `Arc::clone`; ordering only matters when the count drops to zero.
        self.ref_count.fetch_add(1, Ordering::Relaxed);
        Self { ptr: self.ptr }
    }
}
572
impl Drop for HeapBufferPool {
    fn drop(&mut self) {
        // Release this handle's reference; the last handle starts shutdown.
        let prev_rc = self.ref_count.fetch_sub(1, Ordering::Release);
        if prev_rc == 1 {
            // Enforce that any previous access to self from another thread *happens before*
            // deleting the object on this thread.
            // See comment in source of `Arc::drop`.
            self.ref_count.load(Ordering::Acquire);

            // All HeapBufferPool handles have been dropped.
            BufferPool::shutdown_now_try_drop(self.ptr as *mut _);
        }
    }
}
587
// A fulfillment is an intrusive linked list of buffers; the element count is
// tracked externally by waitq's `Fulfillment::count`.
impl IFulfillment for BufferPtr {
    /// Detach and return the first buffer, leaving `self` pointing at the
    /// remainder. Must not be called on a single-element list - the `unwrap`
    /// requires a non-empty tail.
    fn take_one(&mut self) -> Self {
        let ptr = *self;
        *self = unsafe { ptr.swap_next(None) }.unwrap();
        ptr
    }

    /// Append `other`'s chain to the end of this chain by walking to the
    /// tail. `_other_count` is unused: the chains are None-terminated.
    fn append(&mut self, other: Self, _other_count: usize) {
        let mut tail = *self;
        while let Some(next) = unsafe { tail.get_next() } {
            tail = next;
        }

        unsafe {
            tail.set_next(Some(other));
        }
    }
}
606
/// Future used by `BufferPool::acquire`: waits on the pool's waiter queue
/// until a buffer (or whole batch) is handed over.
pub struct Acquire<'a> {
    buffer_pool: &'a BufferPool,
    waiter: &'a Waiter<'a, BufferPtr>,
}

impl<'a> Acquire<'a> {
    /// Project the pin through to the waiter.
    fn waiter(self: Pin<&'_ Self>) -> Pin<&'_ Waiter<'a, BufferPtr>> {
        // SAFETY: `waiter` is pinned when `self` is.
        unsafe { self.map_unchecked(|s| s.waiter) }
    }
}
618
impl Future for Acquire<'_> {
    type Output = BufferPtr;

    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        let buffer_pool = self.buffer_pool;
        let local_stock = self.buffer_pool.local_stock();
        // Poll the waiter; the closure is a last-chance attempt to
        // self-fulfill from the local stock or the shared free stack before
        // actually parking.
        let Poll::Ready(fulfillment) = self.as_ref().waiter().poll_fulfillment(context, || {
            if let Some(local_head) = local_stock.head.replace(None) {
                // This case can happen if waitq notify_one_local fails to notify the local
                // head.
                Some(Fulfillment {
                    inner: local_head,
                    count: local_stock.count.replace(0) as usize,
                })
            } else {
                buffer_pool
                    .try_take_batch(local_stock)
                    .map(|ptr| Fulfillment {
                        inner: ptr,
                        count: buffer_pool.batch_size as usize,
                    })
            }
        }) else {
            return Poll::Pending;
        };

        // We may have been handed more than one buffer; keep the head for the
        // caller and return the rest to the local stock.
        let extra_head = unsafe { fulfillment.inner.swap_next(None) };
        let extra_count = fulfillment.count as usize - 1;

        if let Some(extra_head) = extra_head {
            debug_assert!(extra_count > 0);
            self.buffer_pool.release_many(extra_head, extra_count);
        }

        Poll::Ready(fulfillment.inner)
    }
}
656
impl Drop for Acquire<'_> {
    fn drop(&mut self) {
        // If the waiter was fulfilled between its last poll and this drop,
        // the delivered buffers would otherwise leak - put them back.
        if let Some(fulfillment) = self.waiter.cancel() {
            self.buffer_pool
                .release_many(fulfillment.inner, fulfillment.count as usize);
        }
    }
}
665
666#[cfg(test)]
667mod test {
668    use super::*;
669
    /// Exercises chaining buffers into a Fulfillment and detaching them again.
    #[test]
    fn test_buffer_fulfillment_append_and_take_one() {
        let batch_count = 16;
        let batch_size = 16;
        let buffer_pool = HeapBufferPool::new(16, batch_count, batch_size);

        let a = buffer_pool.try_acquire().unwrap();
        let b = buffer_pool.try_acquire().unwrap();
        let c = buffer_pool.try_acquire().unwrap();
        for buffer in [a, b, c] {
            unsafe {
                buffer.initialize_rc(1, 0, 0);
            }
        }

        // NOTE(review): `count()` appears to report the chain length starting
        // at the buffer (consistent with the assertions below) - confirm in
        // buffer.rs.
        assert_eq!(a.count(), 1);
        assert_eq!(b.count(), 1);
        assert_eq!(c.count(), 1);

        let mut f = Fulfillment { inner: a, count: 1 };

        // Appending links b (and then c) behind a, growing the chains.
        f.append(Fulfillment { inner: b, count: 1 });
        assert_eq!((f.inner, f.count), (a, 2));
        assert_eq!(a.count(), 2);
        assert_eq!(b.count(), 1);
        assert_eq!(c.count(), 1);
        f.append(Fulfillment { inner: c, count: 1 });
        assert_eq!((f.inner, f.count), (a, 3));
        assert_eq!(a.count(), 3);
        assert_eq!(b.count(), 2);
        assert_eq!(c.count(), 1);

        // take_one detaches buffers from the front, one at a time.
        let taken = f.take_one();
        assert_eq!(taken, a);
        assert_eq!((f.inner, f.count), (b, 2));
        assert_eq!(a.count(), 1);
        assert_eq!(b.count(), 2);
        assert_eq!(c.count(), 1);

        let taken = f.take_one();
        assert_eq!(taken, b);
        assert_eq!((f.inner, f.count), (c, 1));
        assert_eq!(a.count(), 1);
        assert_eq!(b.count(), 1);
        assert_eq!(c.count(), 1);

        // Return the buffers so the pool can shut down cleanly.
        for buffer in [a, b, c] {
            unsafe {
                buffer.release_ref(1);
            }
        }
    }
722
    /// Drops the last pool handle while buffers are still outstanding; the
    /// pool must stay alive until every buffer is released afterwards.
    #[test]
    fn test_buffer_pool_shutdown_send_packet() {
        let batch_count = 16;
        let batch_size = 16;
        let buffer_pool = HeapBufferPool::new(16, batch_count, batch_size);

        let a = buffer_pool.try_acquire().unwrap();
        let b = buffer_pool.try_acquire().unwrap();

        unsafe {
            a.initialize_rc(1, 0, 0);
            b.initialize_rc(1, 1, 1);
        }

        // Shutdown begins here, but a and b are still in circulation.
        drop(buffer_pool);

        unsafe {
            a.release_ref(1);
            // Exercise the cross-thread send/receive accounting on b before
            // its final release completes the shutdown.
            assert_eq!(b.send_bulk(1), 1);
            b.receive(1);
            b.release_ref(1);
        }
    }
746
    /// Drains the pool, queues more `acquire` futures than there are buffers
    /// (forcing real waiters), then releases everything and checks each
    /// waiter eventually got a usable buffer.
    #[cfg(feature = "std")]
    #[test]
    fn test_buffer_pool_local_acquire_waiter() {
        use std::rc::Rc;

        let batch_count = 16;
        let batch_size = 2;
        let waiter_count = 8;
        let buffer_pool = HeapBufferPool::new(64, batch_count, batch_size);

        let ex = async_executor::LocalExecutor::new();

        pollster::block_on(ex.run(async {
            // Acquire all the buffers
            let channel = Rc::new(async_unsync::unbounded::channel());
            let acquire_starts = Rc::new(async_unsync::semaphore::Semaphore::new(0));

            for _ in 0..batch_count * batch_size {
                let buf: BufferPtr = buffer_pool.acquire().await;
                let data = unsafe {
                    core::slice::from_raw_parts_mut(buf.data(), buffer_pool.buffer_size())
                };
                unsafe {
                    buf.initialize_rc(1, 0, 0);
                }
                data[..4].copy_from_slice(&[1, 2, 3, 4]);
                channel.send(buf).unwrap();
            }

            // Spawn some acquires to force local acquire waiters to get created.
            for _ in 0..waiter_count {
                let buffer_pool = buffer_pool.clone();
                let channel = channel.clone();
                let acquire_starts = acquire_starts.clone();
                ex.spawn(async move {
                    acquire_starts.add_permits(1);
                    let buf: BufferPtr = buffer_pool.acquire().await;
                    let data = unsafe {
                        core::slice::from_raw_parts_mut(buf.data(), buffer_pool.buffer_size())
                    };
                    unsafe {
                        buf.initialize_rc(1, 0, 0);
                    }
                    data[..4].copy_from_slice(&[1, 2, 3, 4]);
                    channel.send(buf).unwrap();
                })
                .detach();
            }

            // Wait until every spawned task has at least started its acquire.
            for _ in 0..waiter_count {
                acquire_starts.acquire().await.unwrap().forget();
            }

            // Release all the buffers
            for _ in 0..batch_count * batch_size + waiter_count {
                let buf = channel.recv().await.unwrap();
                let data = unsafe {
                    core::slice::from_raw_parts_mut(buf.data(), buffer_pool.buffer_size())
                };
                assert_eq!(&data[..4], &[1, 2, 3, 4]);
                unsafe {
                    buf.release_ref(1);
                }
            }

            // Every sent buffer has been drained back out.
            assert!(channel.try_recv().is_err());
        }));
    }
815}