use aligned_buffer::{
    alloc::{AllocError, Allocator, BufferAllocator, Global, RawBuffer},
    SharedAlignedBuffer, UniqueAlignedBuffer, DEFAULT_BUFFER_ALIGNMENT,
};
use crossbeam_queue::ArrayQueue;
use std::{
    alloc::Layout,
    mem::ManuallyDrop,
    ptr::NonNull,
    sync::{Arc, Weak},
};

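/// A [`UniqueAlignedBuffer`] whose allocation is offered back to an
/// [`AlignedBufferPool`] on drop, subject to the pool's retention policy.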
pub type UniquePooledAlignedBuffer<
    P = RetainAllRetentionPolicy,
    const ALIGNMENT: usize = DEFAULT_BUFFER_ALIGNMENT,
    A = Global,
> = UniqueAlignedBuffer<
    ALIGNMENT,
    BufferPoolAllocator<P, ALIGNMENT, WeakAlignedBufferPool<P, ALIGNMENT, A>, A>,
>;

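/// A [`SharedAlignedBuffer`] backed by an [`AlignedBufferPool`]; the
/// allocation is offered back to the pool once the last reference is dropped.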
pub type SharedPooledAlignedBuffer<
    P = RetainAllRetentionPolicy,
    const ALIGNMENT: usize = DEFAULT_BUFFER_ALIGNMENT,
    A = Global,
> = SharedAlignedBuffer<
    ALIGNMENT,
    BufferPoolAllocator<P, ALIGNMENT, WeakAlignedBufferPool<P, ALIGNMENT, A>, A>,
>;

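/// Decides whether a dropped buffer should be returned to the pool or
/// deallocated outright.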
pub trait BufferRetentionPolicy: Clone {
    /// Returns `true` if a buffer with the given capacity (in bytes) should
    /// be kept in the pool.
    fn should_retain(&self, capacity: usize) -> bool;
}

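/// A retention policy that keeps every buffer, regardless of size.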
#[derive(Default, Clone, Copy)]
pub struct RetainAllRetentionPolicy;

impl BufferRetentionPolicy for RetainAllRetentionPolicy {
    #[inline(always)]
    fn should_retain(&self, _: usize) -> bool {
        true
    }
}

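/// A retention policy that keeps only buffers whose capacity does not exceed
/// `SIZE` bytes; anything larger is deallocated on drop.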
#[derive(Default, Clone, Copy)]
pub struct ConstMaxSizeRetentionPolicy<const SIZE: usize>;

impl<const SIZE: usize> BufferRetentionPolicy for ConstMaxSizeRetentionPolicy<SIZE> {
    #[inline(always)]
    fn should_retain(&self, capacity: usize) -> bool {
        capacity <= SIZE
    }
}

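/// A weak handle to a buffer pool, used by [`BufferPoolAllocator`] so that
/// live buffers do not keep the pool itself alive.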
pub(crate) trait WeakAlignedBufferPoolRef<
    P: BufferRetentionPolicy,
    const ALIGNMENT: usize,
    A: Allocator + Clone,
>: Clone
{
    /// Runs `f` against the pool's inner state if the pool is still alive.
    fn with<F>(&self, f: F)
    where
        F: FnOnce(&AlignedBufferPoolInner<P, ALIGNMENT, Self, A>);
}

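/// An allocator that delegates to `A` for memory operations and, when a whole
/// buffer is deallocated, offers it back to its originating pool.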
pub struct BufferPoolAllocator<
    P: BufferRetentionPolicy,
    const ALIGNMENT: usize,
    R,
    A: Allocator + Clone,
> {
    policy: P,
    alloc: A,
    pool_ref: R,
}

unsafe impl<
        P: BufferRetentionPolicy,
        const ALIGNMENT: usize,
        R: WeakAlignedBufferPoolRef<P, ALIGNMENT, A>,
        A: Allocator + Clone,
    > Allocator for BufferPoolAllocator<P, ALIGNMENT, R, A>
{
    #[inline(always)]
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        self.alloc.allocate(layout)
    }

    #[inline(always)]
    fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        self.alloc.allocate_zeroed(layout)
    }

    #[inline(always)]
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        self.alloc.deallocate(ptr, layout)
    }

    #[inline(always)]
    unsafe fn grow(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        self.alloc.grow(ptr, old_layout, new_layout)
    }

    #[inline(always)]
    unsafe fn grow_zeroed(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        self.alloc.grow_zeroed(ptr, old_layout, new_layout)
    }

    #[inline(always)]
    unsafe fn shrink(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        self.alloc.shrink(ptr, old_layout, new_layout)
    }
}

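// The `Allocator` methods above delegate straight to the inner allocator;
// pooling only intercepts the deallocation of whole buffers below.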
impl<
        P: BufferRetentionPolicy,
        const ALIGNMENT: usize,
        R: WeakAlignedBufferPoolRef<P, ALIGNMENT, A>,
        A: Allocator + Clone,
    > BufferAllocator<ALIGNMENT> for BufferPoolAllocator<P, ALIGNMENT, R, A>
{
    unsafe fn deallocate_buffer(&self, raw: RawBuffer<ALIGNMENT>) {
        // Frees the raw allocation on drop unless defused via `ManuallyDrop`.
        struct DeallocOnDrop<'a, const ALIGNMENT: usize, A: Allocator> {
            allocator: &'a A,
            raw: RawBuffer<ALIGNMENT>,
        }

        impl<'a, const ALIGNMENT: usize, A: Allocator> Drop for DeallocOnDrop<'a, ALIGNMENT, A> {
            fn drop(&mut self) {
                let (ptr, layout) = self.raw.alloc_info();
                unsafe { self.allocator.deallocate(ptr, layout) }
            }
        }

        // If the buffer is not retained, or the pool is already gone, this
        // guard falls back to a plain deallocation.
        let guard = DeallocOnDrop {
            allocator: &self.alloc,
            raw,
        };

        if self.policy.should_retain(raw.capacity().size()) {
            let alloc = self.clone();
            self.pool_ref.with(move |pool| {
                // Defuse the guard: from here on the allocation is either
                // owned by the pool or freed explicitly below.
                let unguard = ManuallyDrop::new(guard);
                let buf = UniqueAlignedBuffer::from_raw_parts_in(
                    raw.buf_ptr(),
                    0,
                    raw.capacity(),
                    alloc,
                );
                if let Err(rejected) = pool.pool.push(buf) {
                    // The pool is full. Forget the rejected buffer so its
                    // `Drop` does not re-enter `deallocate_buffer`, then free
                    // the allocation through the guard.
                    std::mem::forget(rejected);
                    drop(ManuallyDrop::into_inner(unguard));
                }
            });
        }
    }
}

impl<
        P: BufferRetentionPolicy,
        const ALIGNMENT: usize,
        R: WeakAlignedBufferPoolRef<P, ALIGNMENT, A>,
        A: Allocator + Clone,
    > Clone for BufferPoolAllocator<P, ALIGNMENT, R, A>
{
    fn clone(&self) -> Self {
        Self {
            policy: self.policy.clone(),
            alloc: self.alloc.clone(),
            pool_ref: self.pool_ref.clone(),
        }
    }
}

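/// Shared state of a buffer pool: the queue of retained buffers plus the
/// allocator handed to every buffer created through the pool.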
pub(crate) struct AlignedBufferPoolInner<
    P: BufferRetentionPolicy,
    const ALIGNMENT: usize,
    R: WeakAlignedBufferPoolRef<P, ALIGNMENT, A>,
    A: Allocator + Clone,
> {
    pool: ArrayQueue<UniqueAlignedBuffer<ALIGNMENT, BufferPoolAllocator<P, ALIGNMENT, R, A>>>,
    alloc: BufferPoolAllocator<P, ALIGNMENT, R, A>,
}

impl<
        P: BufferRetentionPolicy,
        const ALIGNMENT: usize,
        R: WeakAlignedBufferPoolRef<P, ALIGNMENT, A>,
        A: Allocator + Clone,
    > AlignedBufferPoolInner<P, ALIGNMENT, R, A>
{
    pub fn new(policy: P, alloc: A, self_ref: R, capacity: usize) -> Self {
        Self {
            pool: ArrayQueue::new(capacity),
            alloc: BufferPoolAllocator {
                policy,
                alloc,
                pool_ref: self_ref,
            },
        }
    }

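    /// Pops a pooled buffer if one is available, otherwise allocates a fresh
    /// (empty) buffer tied to this pool.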
    #[inline]
    pub fn get(&self) -> UniqueAlignedBuffer<ALIGNMENT, BufferPoolAllocator<P, ALIGNMENT, R, A>> {
        if let Some(buf) = self.pool.pop() {
            buf
        } else {
            UniqueAlignedBuffer::new_in(self.alloc.clone())
        }
    }
}

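/// A thread-safe pool of aligned buffers. Buffers obtained via
/// [`get`](Self::get) are offered back to the pool when dropped, subject to
/// the retention policy `P`.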
pub struct AlignedBufferPool<
    P: BufferRetentionPolicy = RetainAllRetentionPolicy,
    const ALIGNMENT: usize = DEFAULT_BUFFER_ALIGNMENT,
    A: Allocator + Clone = Global,
> {
    inner: Arc<AlignedBufferPoolInner<P, ALIGNMENT, WeakAlignedBufferPool<P, ALIGNMENT, A>, A>>,
}

impl<P: BufferRetentionPolicy, const ALIGNMENT: usize> AlignedBufferPool<P, ALIGNMENT, Global>
where
    P: Default,
{
    /// Creates a pool that retains at most `capacity` buffers, using the
    /// default policy.
    pub fn new(capacity: usize) -> Self {
        Self::with_policy(P::default(), capacity)
    }

    /// Creates a pool that retains at most `capacity` buffers, allocating
    /// from [`Global`].
    pub fn with_capacity(capacity: usize) -> Self {
        Self::with_capacity_in(capacity, Global)
    }
}

impl<P: BufferRetentionPolicy, const ALIGNMENT: usize> AlignedBufferPool<P, ALIGNMENT, Global> {
    /// Creates a pool with an explicit retention policy, allocating from
    /// [`Global`].
    pub fn with_policy(policy: P, capacity: usize) -> Self {
        Self::new_in(policy, capacity, Global)
    }
}

impl<P: BufferRetentionPolicy, const ALIGNMENT: usize, A: Allocator + Clone>
    AlignedBufferPool<P, ALIGNMENT, A>
where
    P: Default,
{
    /// Creates a pool with the default policy, allocating from `alloc`.
    pub fn with_capacity_in(capacity: usize, alloc: A) -> Self {
        Self::new_in(P::default(), capacity, alloc)
    }
}

impl<P: BufferRetentionPolicy, const ALIGNMENT: usize, A: Allocator + Clone>
    AlignedBufferPool<P, ALIGNMENT, A>
{
    /// Creates a pool with the given retention policy, retained-buffer
    /// capacity, and backing allocator.
    pub fn new_in(policy: P, capacity: usize, alloc: A) -> Self {
        Self {
            inner: Arc::new_cyclic(|weak| {
                // Give the inner state a weak self-reference so dropped
                // buffers can find their way back without keeping the pool
                // alive.
                let weak = WeakAlignedBufferPool {
                    inner: weak.clone(),
                };
                AlignedBufferPoolInner::new(policy, alloc, weak, capacity)
            }),
        }
    }

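    /// Fetches a buffer from the pool, or allocates a new empty one if the
    /// pool has none available.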
    #[inline]
    pub fn get(&self) -> UniquePooledAlignedBuffer<P, ALIGNMENT, A> {
        self.inner.get()
    }

    /// Returns a weak handle to this pool that does not keep it alive.
    pub fn weak(&self) -> WeakAlignedBufferPool<P, ALIGNMENT, A> {
        WeakAlignedBufferPool {
            inner: Arc::downgrade(&self.inner),
        }
    }
}

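/// A weak reference to an [`AlignedBufferPool`]. Dropped buffers are only
/// returned to the pool while at least one strong handle still exists.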
pub struct WeakAlignedBufferPool<
    P: BufferRetentionPolicy = RetainAllRetentionPolicy,
    const ALIGNMENT: usize = DEFAULT_BUFFER_ALIGNMENT,
    A: Allocator + Clone = Global,
> {
    inner: Weak<AlignedBufferPoolInner<P, ALIGNMENT, WeakAlignedBufferPool<P, ALIGNMENT, A>, A>>,
}

impl<P: BufferRetentionPolicy, const ALIGNMENT: usize, A: Allocator + Clone> Clone
    for WeakAlignedBufferPool<P, ALIGNMENT, A>
{
    #[inline]
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
        }
    }
}

impl<P: BufferRetentionPolicy, const ALIGNMENT: usize, A: Allocator + Clone>
    WeakAlignedBufferPoolRef<P, ALIGNMENT, A> for WeakAlignedBufferPool<P, ALIGNMENT, A>
{
    fn with<F>(&self, f: F)
    where
        F: FnOnce(&AlignedBufferPoolInner<P, ALIGNMENT, Self, A>),
    {
        // Only run `f` if the pool still exists; otherwise the caller's
        // fallback (plain deallocation) takes over.
        if let Some(inner) = self.inner.upgrade() {
            f(&inner);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pool_reuses_buffers() {
        let pool = AlignedBufferPool::<RetainAllRetentionPolicy, 64>::with_capacity(2);
        let mut buf = pool.get();
        buf.extend([1, 2, 3]);
        drop(buf);

        // The dropped buffer should have been retained, so the next `get`
        // returns it with its capacity intact.
        let buf = pool.get();
        assert!(buf.capacity() >= 3);
    }
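
    // A sketch exercising `ConstMaxSizeRetentionPolicy`. It assumes `extend`
    // grows the buffer's capacity to at least the number of bytes written, so
    // a 16-byte buffer exceeds the 4-byte retention limit and is deallocated
    // rather than returned to the pool.
    #[test]
    fn max_size_policy_discards_large_buffers() {
        let pool = AlignedBufferPool::<ConstMaxSizeRetentionPolicy<4>, 64>::with_capacity(2);
        let mut buf = pool.get();
        buf.extend([0u8; 16]);
        drop(buf);

        // The buffer outgrew the policy's limit, so the pool stays empty.
        assert_eq!(pool.inner.pool.len(), 0);
    }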
}