use core::{
    alloc::Layout,
    cell::{Cell, UnsafeCell},
    future::Future,
    pin::Pin,
    sync::atomic::{AtomicU32, AtomicUsize, Ordering},
    task::{Context, Poll},
};

#[cfg(feature = "alloc")]
use alloc::{
    alloc::{alloc, dealloc},
    boxed::Box,
    vec::Vec,
};
#[cfg(feature = "std")]
use std::alloc::{alloc, dealloc};

use crossbeam_utils::CachePadded;
use thid::ThreadLocal;
use waitq::{Fulfillment, IFulfillment, Waiter, WaiterQueue};

use crate::{
    buffer::{Buffer, BufferPtr},
    free_stack::FreeStack,
};

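/// Per-thread stash of free buffers, kept as an intrusive singly linked list
/// threaded through each buffer's `next` pointer. `head` is the most recently
/// released buffer, `count` is the list length, and `watermark` marks the node
/// after which exactly one batch of buffers remains (see `find_watermark`).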
pub(crate) struct LocalStock {
    head: Cell<Option<BufferPtr>>,
    watermark: Cell<Option<BufferPtr>>,
    count: Cell<u32>,
}

unsafe impl Send for LocalStock {}

impl LocalStock {
    fn new() -> Self {
        Self {
            head: Cell::new(None),
            watermark: Cell::new(None),
            count: Cell::new(0),
        }
    }

    fn try_acquire(&self) -> Option<BufferPtr> {
        self.head.get().map(|head_ptr| {
            debug_assert!(self.count.get() > 0);
            self.count.set(self.count.get() - 1);
            self.head.set(unsafe { head_ptr.get_next() });
            unsafe {
                head_ptr.set_next(None);
            }

            head_ptr
        })
    }
}

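/// Per-thread bookkeeping: how many buffers (and thread registrations) the
/// thread currently has outstanding, plus thread-local reference-count state
/// for every buffer in the pool. The thread holds one pool reference while
/// `buffers_in_use` is non-zero.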
pub(crate) struct LocalState {
    buffers_in_use: Cell<u32>,
    local_buffer_state: *const [LocalBufferState],
}

unsafe impl Send for LocalState {}

impl LocalState {
    #[cfg(any(feature = "std", feature = "alloc"))]
    fn new_heap(total_buffer_count: usize) -> Self {
        let local_buffer_state = Box::into_raw(
            (0..total_buffer_count)
                .map(|_| LocalBufferState {
                    ref_count: Cell::new(0),
                    shared_rc_contribution: Cell::new(0),
                })
                .collect::<Vec<_>>()
                .into_boxed_slice(),
        );

        Self {
            buffers_in_use: Cell::new(0),
            local_buffer_state,
        }
    }

    #[inline]
    pub(crate) fn local_buffer_state(&self, buffer_id: usize) -> &LocalBufferState {
        unsafe { &*core::ptr::addr_of!((*self.local_buffer_state)[buffer_id]) }
    }
}

pub(crate) struct LocalBufferState {
    pub(crate) ref_count: Cell<u32>,
    pub(crate) shared_rc_contribution: Cell<u32>,
}

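/// A fixed-capacity pool of buffers carved out of one contiguous allocation.
/// Free buffers circulate in batches of `batch_size`: each thread keeps a small
/// local stock and exchanges whole batches with the shared `free_stack`, while
/// asynchronous acquirers park on `waiter_queue` until buffers are released.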
pub struct BufferPool {
    pub(crate) alloc: *mut Buffer,
    alloc_layout: Layout,
    buffer_padded_size: usize,
    total_buffer_count: usize,
    buffer_size: usize,
    batch_size: u32,
    free_stack: FreeStack,
    waiter_queue: WaiterQueue<BufferPtr>,
    local_stock: UnsafeCell<ThreadLocal<CachePadded<LocalStock>>>,
    local_state: ThreadLocal<CachePadded<LocalState>>,
    ref_count: AtomicUsize,
    shutdown_released_buffers: AtomicU32,
    handle_drop_fn: fn(*mut Self),
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(crate) enum BufferPoolShutdownStatus {
    NotShutdown,
    ShutdownNow,
    AlreadyShutdown,
}

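/// Guard returned by [`BufferPool::register_thread`]. While it is alive the
/// calling thread counts as having the pool in use; dropping it releases that
/// registration (and with it the thread's reference on the pool).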
pub struct BufferPoolThreadGuard<'a> {
    buffer_pool: &'a BufferPool,
}

impl Drop for BufferPoolThreadGuard<'_> {
    fn drop(&mut self) {
        self.buffer_pool
            .decrement_local_buffers_in_use(self.buffer_pool.local_state());
    }
}

unsafe impl Send for BufferPool {}
unsafe impl Sync for BufferPool {}

impl BufferPool {
    pub fn total_buffer_count(&self) -> usize {
        self.total_buffer_count
    }

    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    #[inline]
    pub fn buffer_by_id(&self, id: u32) -> BufferPtr {
        let buffer_raw = unsafe { self.alloc.byte_add(id as usize * self.buffer_padded_size) };
        BufferPtr::from_ptr(buffer_raw).unwrap()
    }

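    /// Registers the calling thread as a user of the pool, so that the pool is
    /// kept alive for as long as the returned guard is held.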
    pub fn register_thread(&self) -> BufferPoolThreadGuard<'_> {
        self.increment_local_buffers_in_use(self.local_state());
        BufferPoolThreadGuard { buffer_pool: self }
    }

    #[inline]
    pub(crate) fn local_stock(&self) -> &LocalStock {
        let local_stock = unsafe { &*self.local_stock.get() };
        local_stock.get_or(|| CachePadded::new(LocalStock::new()))
    }

    #[inline]
    pub(crate) fn local_state(&self) -> &LocalState {
        self.local_state
            .get_or(|| CachePadded::new(LocalState::new_heap(self.total_buffer_count as usize)))
    }

    pub(crate) fn increment_local_buffers_in_use(&self, local_state: &LocalState) {
        let prev = local_state
            .buffers_in_use
            .replace(local_state.buffers_in_use.get() + 1);
        if prev == 0 {
            // First buffer in use on this thread: take a reference on the pool,
            // unless the reference count has already reached zero (shutdown).
            let mut ref_count = self.ref_count.load(Ordering::Relaxed);
            while ref_count > 0 {
                match self.ref_count.compare_exchange(
                    ref_count,
                    ref_count + 1,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => break,
                    Err(new_ref_count) => {
                        ref_count = new_ref_count;
                    }
                }
            }
        }
    }

    pub(crate) fn decrement_local_buffers_in_use(
        &self,
        local_state: &LocalState,
    ) -> BufferPoolShutdownStatus {
        let prev = local_state
            .buffers_in_use
            .replace(local_state.buffers_in_use.get() - 1);
        if prev == 1 {
            if !self.is_shutting_down() {
                // Last buffer in use on this thread: drop the thread's reference
                // on the pool; if it was the final reference, shut down now.
                let prev_ref_count = self.ref_count.fetch_sub(1, Ordering::AcqRel);
                if prev_ref_count == 1 {
                    BufferPoolShutdownStatus::ShutdownNow
                } else {
                    BufferPoolShutdownStatus::NotShutdown
                }
            } else {
                BufferPoolShutdownStatus::AlreadyShutdown
            }
        } else {
            BufferPoolShutdownStatus::NotShutdown
        }
    }

    fn is_shutting_down(&self) -> bool {
        match self
            .ref_count
            .compare_exchange(0, 0, Ordering::Acquire, Ordering::Relaxed)
        {
            Ok(_) => true,
            Err(_) => false,
        }
    }

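    /// Acquires a buffer, asynchronously waiting on the pool's waiter queue if
    /// no buffer is currently free.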
    pub async fn acquire(&self) -> BufferPtr {
        if let Some(buffer) = self.try_acquire() {
            return buffer;
        }

        let buffer = Acquire {
            buffer_pool: self,
            waiter: &Waiter::new(&self.waiter_queue),
        }
        .await;

        buffer
    }

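    /// Tries to acquire a buffer without waiting: first from the thread-local
    /// stock, then by taking a whole batch from the shared free stack. Returns
    /// `None` if no buffer is currently free.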
    pub fn try_acquire(&self) -> Option<BufferPtr> {
        let local_stock = self.local_stock();
        if let Some(local_buffer) = local_stock.try_acquire() {
            return Some(local_buffer);
        }

        self.try_acquire_batch(local_stock)
    }

    fn try_acquire_batch(&self, local_stock: &LocalStock) -> Option<BufferPtr> {
        debug_assert!(local_stock.head.get().is_none());
        debug_assert_eq!(local_stock.count.get(), 0);

        if let Some(batch_head) = self.try_take_batch(local_stock) {
            local_stock.head.set(unsafe { batch_head.swap_next(None) });
            local_stock.count.set(self.batch_size as u32 - 1);
            Some(batch_head)
        } else {
            None
        }
    }

    fn try_take_batch(&self, local_stock: &LocalStock) -> Option<BufferPtr> {
        debug_assert!(local_stock.head.get().is_none());
        debug_assert_eq!(local_stock.count.get(), 0);

        self.free_stack.pop()
    }

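    /// Returns a buffer to the pool: it is handed directly to a pending waiter
    /// if there is one, otherwise pushed onto the thread-local stock (full
    /// batches are spilled back to the shared free stack).
    ///
    /// # Safety
    ///
    /// `buffer` must have been acquired from this pool, must no longer be
    /// referenced elsewhere, and must not be linked into any list (its `next`
    /// pointer must be `None`).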
    pub unsafe fn release(&self, buffer: BufferPtr) {
        debug_assert_eq!(unsafe { buffer.get_next() }, None);

        if self.waiter_queue.notify_one_local(buffer).is_none() {
            return;
        }

        let local_stock = self.local_stock();

        if local_stock.count.get() == self.batch_size {
            local_stock.watermark.set(Some(buffer));
        }
        unsafe {
            buffer.set_next(local_stock.head.get());
        }
        local_stock.head.set(Some(buffer));
        local_stock.count.set(local_stock.count.get() + 1);

        self.release_overflow(local_stock);
    }

    fn release_overflow(&self, local_stock: &LocalStock) {
        if local_stock.count.get() < (self.batch_size as u32 * 3) / 2 {
            return;
        }

        // The local stock has grown to at least one and a half batches: split
        // off whole batches at the watermark and return them to the shared free
        // stack (when the free stack is empty and waiters are queued, a batch is
        // handed to the waiters instead).
        while let Some(watermark) = local_stock.watermark.take() {
            let release_head = unsafe { watermark.swap_next(None) }.unwrap();
            let release_count = self.batch_size;
            local_stock
                .count
                .set(local_stock.count.get() - self.batch_size as u32);

            let mut waiter_queue_guard = None;
            self.free_stack.push_if(release_head, |free_count| {
                if free_count == 0 {
                    let guard = waiter_queue_guard.get_or_insert_with(|| self.waiter_queue.lock());

                    if guard.waiter_count() > 0 {
                        waiter_queue_guard
                            .take()
                            .expect("bug: missing lock guard")
                            .notify(release_head, release_count as usize);
                        return false;
                    }
                }

                true
            });
            self.find_watermark(local_stock);
        }
    }

    fn release_many(&self, release_head: BufferPtr, release_count: usize) {
        let local_stock = self.local_stock();

        // Find the tail of the local stock so the released chain can be appended.
        let mut tail = local_stock.head.get();
        while let Some(next) = tail {
            let new_tail = unsafe { next.get_next() };
            if new_tail.is_none() {
                break;
            }
            tail = new_tail;
        }

        if let Some(tail) = tail {
            unsafe {
                tail.set_next(Some(release_head));
            }
            local_stock
                .count
                .set(local_stock.count.get() + release_count as u32);
        } else {
            debug_assert_eq!(local_stock.head.get(), None);
            debug_assert_eq!(local_stock.count.get(), 0);
            local_stock.head.set(Some(release_head));
            local_stock.count.set(release_count as u32);
        }

        self.find_watermark(local_stock);
        self.release_overflow(local_stock);
    }

    fn find_watermark(&self, local_stock: &LocalStock) {
        // Position the watermark so that exactly one batch follows it; the
        // buffers up to and including the watermark stay in the local stock.
        if local_stock.count.get() > self.batch_size as u32 {
            let mut watermark = local_stock.head.get().unwrap();
            for _ in 0..local_stock.count.get() - self.batch_size as u32 - 1 {
                watermark = unsafe { watermark.get_next() }.unwrap();
            }
            local_stock.watermark.set(Some(watermark));
        }
    }

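    /// Called once the pool's reference count has dropped to zero: counts every
    /// buffer still sitting in the local stocks and the shared free stack as
    /// released, and frees the pool as soon as all buffers are accounted for.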
    pub(crate) fn shutdown_now_try_drop(buffer_pool: *mut BufferPool) {
        let this = unsafe { &*buffer_pool };
        let local_stock = unsafe { &mut *this.local_stock.get() };

        // Buffers still parked in per-thread stocks and the shared free stack
        // count as released for shutdown accounting.
        let mut released_buffers = 0;
        for local_stock in local_stock.iter_mut() {
            released_buffers += local_stock.count.get();
        }
        while let Some(_) = this.free_stack.pop() {
            released_buffers += this.batch_size;
        }

        let total_buffer_count = this.total_buffer_count as u32;

        let prev_released_buffers = this
            .shutdown_released_buffers
            .fetch_add(released_buffers, Ordering::Release);

        if prev_released_buffers + released_buffers == total_buffer_count {
            // Synchronize with the `Release` increments from other threads
            // before freeing the pool.
            this.shutdown_released_buffers.load(Ordering::Acquire);

            let handle_drop_fn = unsafe { (*buffer_pool).handle_drop_fn };
            handle_drop_fn(buffer_pool as *mut BufferPool);
        }
    }

    pub(crate) fn already_shutdown_release_buffer(buffer_pool: *mut BufferPool) {
        let this = unsafe { &*buffer_pool };

        let total_buffer_count = this.total_buffer_count as u32;

        let prev_released_buffers = this
            .shutdown_released_buffers
            .fetch_add(1, Ordering::Release);

        if prev_released_buffers + 1 == total_buffer_count as u32 {
            this.shutdown_released_buffers.load(Ordering::Acquire);

            let handle_drop_fn = unsafe { (*buffer_pool).handle_drop_fn };
            handle_drop_fn(buffer_pool as *mut BufferPool);
        }
    }
}

impl Drop for BufferPool {
    fn drop(&mut self) {
        for local_state in self.local_state.iter_mut() {
            let _ =
                unsafe { Box::from_raw(local_state.local_buffer_state as *mut [LocalBufferState]) };
        }

        let _ = unsafe { dealloc(self.alloc as *mut u8, self.alloc_layout) };
    }
}

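/// Owning handle to a heap-allocated [`BufferPool`]. Cloning a handle bumps the
/// pool's reference count; dropping the last handle begins shutdown, and the
/// pool memory is freed once every buffer has been returned.
///
/// A minimal usage sketch, mirroring the tests below:
///
/// ```ignore
/// // 8 batches of 16 buffers, 4096 bytes each.
/// let pool = HeapBufferPool::new(4096, 8, 16);
///
/// let buf = pool.try_acquire().expect("a fresh pool has free buffers");
/// unsafe { buf.initialize_rc(1, 0, 0) };
/// // ... fill and use the buffer ...
/// unsafe { buf.release_ref(1) };
/// ```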
pub struct HeapBufferPool {
    ptr: *const BufferPool,
}

impl HeapBufferPool {
    pub fn new(buffer_size: usize, batch_count: usize, batch_size: usize) -> Self {
        fn padding_needed_for_layout(layout: Layout) -> usize {
            let len = layout.size();
            let align = layout.align();

            (len.wrapping_add(align).wrapping_sub(1) & !align.wrapping_sub(1)).wrapping_sub(len)
        }
        fn repeat_layout(layout: Layout, n: usize) -> (Layout, usize) {
            let padded_size = layout.size() + padding_needed_for_layout(layout);
            let alloc_size = padded_size.checked_mul(n).unwrap();

            let layout = Layout::from_size_align(alloc_size, layout.align()).unwrap();
            (layout, padded_size)
        }

        let total_buffer_count = batch_count * batch_size;
        let buffer_layout = Buffer::layout_with_data(buffer_size);
        let (alloc_layout, buffer_padded_size) = repeat_layout(buffer_layout, total_buffer_count);
        let alloc = unsafe { alloc(alloc_layout) } as *mut Buffer;

        let buffer_pool = Box::new(BufferPool {
            alloc,
            alloc_layout,
            buffer_padded_size,
            free_stack: FreeStack::new(batch_count),
            waiter_queue: WaiterQueue::new(),
            total_buffer_count,
            buffer_size,
            batch_size: batch_size as u32,
            local_stock: UnsafeCell::new(ThreadLocal::new()),
            local_state: ThreadLocal::new(),
            ref_count: AtomicUsize::new(1),
            shutdown_released_buffers: AtomicU32::new(0),
            handle_drop_fn: |buffer_pool| {
                let _ = unsafe { Box::from_raw(buffer_pool) };
            },
        });
        let buffer_pool_ptr = Box::into_raw(buffer_pool);
        let buffer_pool = unsafe { &*buffer_pool_ptr };

        // Initialize every buffer header in the contiguous allocation.
        for id in 0..total_buffer_count {
            let buffer = buffer_pool.buffer_by_id(id as u32);
            unsafe {
                Buffer::initialize(
                    buffer.as_ptr_mut(),
                    buffer_pool_ptr,
                    id as usize,
                    buffer_size,
                );
            }
        }

        // Link the buffers into batches and push each batch onto the free stack.
        let mut next_buffer_id = 0;
        for _ in 0..batch_count {
            let new_batch_head = buffer_pool.buffer_by_id(next_buffer_id);
            next_buffer_id += 1;

            let mut head = None;
            for _ in 1..batch_size {
                let next = head;
                let new_head = buffer_pool.buffer_by_id(next_buffer_id);
                head = Some(new_head);
                next_buffer_id += 1;
                unsafe {
                    new_head.set_next(next);
                }
            }
            unsafe {
                new_batch_head.set_next(head);
            }

            buffer_pool.free_stack.push_if(new_batch_head, |_| true);
        }

        Self {
            ptr: buffer_pool_ptr,
        }
    }
}

unsafe impl Send for HeapBufferPool {}
unsafe impl Sync for HeapBufferPool {}

impl core::ops::Deref for HeapBufferPool {
    type Target = BufferPool;

    fn deref(&self) -> &Self::Target {
        unsafe { &*self.ptr }
    }
}

impl Clone for HeapBufferPool {
    fn clone(&self) -> Self {
        self.ref_count.fetch_add(1, Ordering::Relaxed);
        Self { ptr: self.ptr }
    }
}

impl Drop for HeapBufferPool {
    fn drop(&mut self) {
        let prev_rc = self.ref_count.fetch_sub(1, Ordering::Release);
        if prev_rc == 1 {
            self.ref_count.load(Ordering::Acquire);

            BufferPool::shutdown_now_try_drop(self.ptr as *mut _);
        }
    }
}

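// Waiter-queue fulfillments are chains of buffers linked through their
// intrusive `next` pointers: `take_one` pops the head of the chain, `append`
// splices another chain onto the tail.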
impl IFulfillment for BufferPtr {
    fn take_one(&mut self) -> Self {
        let ptr = *self;
        *self = unsafe { ptr.swap_next(None) }.unwrap();
        ptr
    }

    fn append(&mut self, other: Self, _other_count: usize) {
        let mut tail = *self;
        while let Some(next) = unsafe { tail.get_next() } {
            tail = next;
        }

        unsafe {
            tail.set_next(Some(other));
        }
    }
}

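/// Future returned by [`BufferPool::acquire`]. It parks a [`Waiter`] on the
/// pool's waiter queue; when the waiter is fulfilled with a chain of buffers,
/// the future keeps one buffer and returns the rest to the thread-local stock.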
pub struct Acquire<'a> {
    buffer_pool: &'a BufferPool,
    waiter: &'a Waiter<'a, BufferPtr>,
}

impl<'a> Acquire<'a> {
    fn waiter(self: Pin<&'_ Self>) -> Pin<&'_ Waiter<'a, BufferPtr>> {
        unsafe { self.map_unchecked(|s| s.waiter) }
    }
}

impl Future for Acquire<'_> {
    type Output = BufferPtr;

    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        let buffer_pool = self.buffer_pool;
        let local_stock = self.buffer_pool.local_stock();
        let Poll::Ready(fulfillment) = self.as_ref().waiter().poll_fulfillment(context, || {
            // Try to fulfil the waiter locally: drain the thread-local stock if
            // it holds buffers, otherwise take a fresh batch from the free stack.
            if let Some(local_head) = local_stock.head.replace(None) {
                Some(Fulfillment {
                    inner: local_head,
                    count: local_stock.count.replace(0) as usize,
                })
            } else {
                buffer_pool
                    .try_take_batch(local_stock)
                    .map(|ptr| Fulfillment {
                        inner: ptr,
                        count: buffer_pool.batch_size as usize,
                    })
            }
        }) else {
            return Poll::Pending;
        };

        // Keep the head of the fulfilled chain and return any extra buffers to
        // the local stock.
        let extra_head = unsafe { fulfillment.inner.swap_next(None) };
        let extra_count = fulfillment.count as usize - 1;

        if let Some(extra_head) = extra_head {
            debug_assert!(extra_count > 0);
            self.buffer_pool.release_many(extra_head, extra_count);
        }

        Poll::Ready(fulfillment.inner)
    }
}

impl Drop for Acquire<'_> {
    fn drop(&mut self) {
        if let Some(fulfillment) = self.waiter.cancel() {
            self.buffer_pool
                .release_many(fulfillment.inner, fulfillment.count as usize);
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_buffer_fulfillment_append_and_take_one() {
        let batch_count = 16;
        let batch_size = 16;
        let buffer_pool = HeapBufferPool::new(16, batch_count, batch_size);

        let a = buffer_pool.try_acquire().unwrap();
        let b = buffer_pool.try_acquire().unwrap();
        let c = buffer_pool.try_acquire().unwrap();
        for buffer in [a, b, c] {
            unsafe {
                buffer.initialize_rc(1, 0, 0);
            }
        }

        assert_eq!(a.count(), 1);
        assert_eq!(b.count(), 1);
        assert_eq!(c.count(), 1);

        let mut f = Fulfillment { inner: a, count: 1 };

        f.append(Fulfillment { inner: b, count: 1 });
        assert_eq!((f.inner, f.count), (a, 2));
        assert_eq!(a.count(), 2);
        assert_eq!(b.count(), 1);
        assert_eq!(c.count(), 1);
        f.append(Fulfillment { inner: c, count: 1 });
        assert_eq!((f.inner, f.count), (a, 3));
        assert_eq!(a.count(), 3);
        assert_eq!(b.count(), 2);
        assert_eq!(c.count(), 1);

        let taken = f.take_one();
        assert_eq!(taken, a);
        assert_eq!((f.inner, f.count), (b, 2));
        assert_eq!(a.count(), 1);
        assert_eq!(b.count(), 2);
        assert_eq!(c.count(), 1);

        let taken = f.take_one();
        assert_eq!(taken, b);
        assert_eq!((f.inner, f.count), (c, 1));
        assert_eq!(a.count(), 1);
        assert_eq!(b.count(), 1);
        assert_eq!(c.count(), 1);

        for buffer in [a, b, c] {
            unsafe {
                buffer.release_ref(1);
            }
        }
    }

    #[test]
    fn test_buffer_pool_shutdown_send_packet() {
        let batch_count = 16;
        let batch_size = 16;
        let buffer_pool = HeapBufferPool::new(16, batch_count, batch_size);

        let a = buffer_pool.try_acquire().unwrap();
        let b = buffer_pool.try_acquire().unwrap();

        unsafe {
            a.initialize_rc(1, 0, 0);
            b.initialize_rc(1, 1, 1);
        }

        drop(buffer_pool);

        unsafe {
            a.release_ref(1);
            assert_eq!(b.send_bulk(1), 1);
            b.receive(1);
            b.release_ref(1);
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_buffer_pool_local_acquire_waiter() {
        use std::rc::Rc;

        let batch_count = 16;
        let batch_size = 2;
        let waiter_count = 8;
        let buffer_pool = HeapBufferPool::new(64, batch_count, batch_size);

        let ex = async_executor::LocalExecutor::new();

        pollster::block_on(ex.run(async {
            let channel = Rc::new(async_unsync::unbounded::channel());
            let acquire_starts = Rc::new(async_unsync::semaphore::Semaphore::new(0));

            for _ in 0..batch_count * batch_size {
                let buf: BufferPtr = buffer_pool.acquire().await;
                let data = unsafe {
                    core::slice::from_raw_parts_mut(buf.data(), buffer_pool.buffer_size())
                };
                unsafe {
                    buf.initialize_rc(1, 0, 0);
                }
                data[..4].copy_from_slice(&[1, 2, 3, 4]);
                channel.send(buf).unwrap();
            }

            for _ in 0..waiter_count {
                let buffer_pool = buffer_pool.clone();
                let channel = channel.clone();
                let acquire_starts = acquire_starts.clone();
                ex.spawn(async move {
                    acquire_starts.add_permits(1);
                    let buf: BufferPtr = buffer_pool.acquire().await;
                    let data = unsafe {
                        core::slice::from_raw_parts_mut(buf.data(), buffer_pool.buffer_size())
                    };
                    unsafe {
                        buf.initialize_rc(1, 0, 0);
                    }
                    data[..4].copy_from_slice(&[1, 2, 3, 4]);
                    channel.send(buf).unwrap();
                })
                .detach();
            }

            for _ in 0..waiter_count {
                acquire_starts.acquire().await.unwrap().forget();
            }

            for _ in 0..batch_count * batch_size + waiter_count {
                let buf = channel.recv().await.unwrap();
                let data = unsafe {
                    core::slice::from_raw_parts_mut(buf.data(), buffer_pool.buffer_size())
                };
                assert_eq!(&data[..4], &[1, 2, 3, 4]);
                unsafe {
                    buf.release_ref(1);
                }
            }

            assert!(channel.try_recv().is_err());
        }));
    }
}