piece/chain_allocator.rs

use crate::alloc::{AllocError, Allocator};

#[cfg(not(feature = "boxed"))]
use alloc_crate::boxed::Box;

#[cfg(feature = "boxed")]
use crate::boxed::Box;

use core::{alloc::Layout, cell::Cell, marker::PhantomData, ptr::NonNull};

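/// A node in the chain of allocators. Each node owns one allocator of type `A`
/// and points to the node that was created before it (if any).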
struct AllocatorNode<A> {
    allocator: A,
    next_allocator: AllocatorRef<A>,
    _owns: PhantomData<AllocatorNode<A>>,
}

impl<A: Allocator> AllocatorNode<A> {
    fn new(allocator: A) -> Self {
        Self {
            allocator,
            next_allocator: AllocatorRef::new(None),
            _owns: PhantomData,
        }
    }

    fn with_next(allocator: A, next: NonNull<Self>) -> Self {
        Self {
            allocator,
            next_allocator: AllocatorRef::new(Some(next)),
            _owns: PhantomData,
        }
    }

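    // Each `*_and_track` method extends the requested layout with an
    // `AllocationFooter` and writes the footer just past the caller's data, so a
    // later `deallocate`/`grow`/`shrink` can find the node that owns the allocation.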
    fn allocate_and_track(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        let (layout_with_footer, footer_offset) =
            layout.extend(Layout::new::<AllocationFooter<A>>()).unwrap();

        let buf_ptr = self.allocator.allocate(layout_with_footer)?.cast::<u8>();

        // SAFETY: buf_ptr is valid because allocation succeeded
        unsafe {
            core::ptr::write(
                buf_ptr.as_ptr().add(footer_offset).cast(),
                AllocationFooter::<A> {
                    allocator_node: NonNull::from(self),
                },
            );
        };

        Ok(NonNull::slice_from_raw_parts(buf_ptr, layout.size()))
    }

    unsafe fn grow_and_track(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        let (old_layout_with_footer, _) = old_layout
            .extend(Layout::new::<AllocationFooter<A>>())
            .unwrap();

        let (new_layout_with_footer, footer_offset) = new_layout
            .extend(Layout::new::<AllocationFooter<A>>())
            .unwrap();

        let buf_ptr = unsafe {
            self.allocator
                .grow(ptr, old_layout_with_footer, new_layout_with_footer)
        }?
        .cast::<u8>();

        // SAFETY: buf_ptr is valid because allocation succeeded
        unsafe {
            core::ptr::write(
                buf_ptr.as_ptr().add(footer_offset).cast(),
                AllocationFooter::<A> {
                    allocator_node: NonNull::from(self),
                },
            );
        };

        Ok(NonNull::slice_from_raw_parts(buf_ptr, new_layout.size()))
    }

    unsafe fn grow_zeroed_and_track(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        let (old_layout_with_footer, old_footer_offset) = old_layout
            .extend(Layout::new::<AllocationFooter<A>>())
            .unwrap();

        let (new_layout_with_footer, new_footer_offset) = new_layout
            .extend(Layout::new::<AllocationFooter<A>>())
            .unwrap();

        let buf_ptr = unsafe {
            self.allocator
                .grow_zeroed(ptr, old_layout_with_footer, new_layout_with_footer)
        }?
        .cast::<u8>();

        // Zero the old footer location: it sits past `old_layout.size()`, so the
        // caller expects those bytes to be zero after a zeroed grow.
        // SAFETY: the new buffer is at least as large as the old one, so the old
        // footer offset is still in bounds.
        unsafe {
            core::ptr::write_bytes(
                buf_ptr.as_ptr().add(old_footer_offset),
                0,
                core::mem::size_of::<AllocationFooter<A>>(),
            );
        };

        // SAFETY: buf_ptr is valid because allocation succeeded
        unsafe {
            core::ptr::write(
                buf_ptr.as_ptr().add(new_footer_offset).cast(),
                AllocationFooter::<A> {
                    allocator_node: NonNull::from(self),
                },
            );
        };

        Ok(NonNull::slice_from_raw_parts(buf_ptr, new_layout.size()))
    }

    unsafe fn shrink_and_track(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        let (old_layout_with_footer, _) = old_layout
            .extend(Layout::new::<AllocationFooter<A>>())
            .unwrap();

        let (new_layout_with_footer, footer_offset) = new_layout
            .extend(Layout::new::<AllocationFooter<A>>())
            .unwrap();

        let buf_ptr = unsafe {
            self.allocator
                .shrink(ptr, old_layout_with_footer, new_layout_with_footer)
        }?
        .cast::<u8>();

        // SAFETY: buf_ptr is valid because allocation succeeded
        unsafe {
            core::ptr::write(
                buf_ptr.as_ptr().add(footer_offset).cast(),
                AllocationFooter::<A> {
                    allocator_node: NonNull::from(self),
                },
            );
        };

        Ok(NonNull::slice_from_raw_parts(buf_ptr, new_layout.size()))
    }

    // SAFETY: layout and ptr have been used before to make an allocation
    unsafe fn ref_from_allocation<'a>(layout: Layout, ptr: NonNull<u8>) -> (Layout, &'a Self) {
        let (layout_with_footer, footer_offset) = layout.extend(Layout::new::<Self>()).unwrap();

        let footer_ptr: *mut AllocationFooter<A> = ptr.as_ptr().add(footer_offset).cast();

        let allocator_node = NonNull::new_unchecked(footer_ptr)
            .as_ref()
            .allocator_node
            .as_ref();

        (layout_with_footer, allocator_node)
    }
}

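/// Interior-mutable, nullable link to an [`AllocatorNode`] in the chain.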
type AllocatorRef<A> = Cell<Option<NonNull<AllocatorNode<A>>>>;

/// A [`ChainAllocator<A>`] creates a new allocator of type `A` whenever the existing allocators
/// of this type are exhausted.
///
/// This is useful with a [`LinearAllocator`], for example: when all of its memory is used, the
/// [`ChainAllocator`] simply creates a new one. That way you can work with fixed-size allocators
/// without worrying that your program will run out of memory.
///
/// There is some overhead when using the [`ChainAllocator`]. Currently, every allocation carries
/// an extra pointer that refers back to its allocator, to make deallocation possible. Also, the
/// allocators themselves are heap-allocated using [`Box`].
///
/// # Usage:
/// ```
/// #![cfg_attr(not(feature = "stable"), feature(allocator_api))]
/// #[cfg(feature="vec")]
/// {
///     use core::mem::size_of;
///
///     use piece::vec::Vec;
///     use piece::LinearAllocator;
///     use piece::ChainAllocator;
///
///     // Make room for the allocator pointer
///     let chain_allocator = ChainAllocator::new(|| {
///         LinearAllocator::with_capacity(32 * size_of::<i32>() + size_of::<*const ()>())
///     });
///
///     // Create two vectors that each fill a whole `LinearAllocator`, so
///     // each `Vec` gets its own allocator
///     let mut vec1 = Vec::with_capacity_in(32, &chain_allocator);
///     let mut vec2 = Vec::with_capacity_in(32, &chain_allocator);
///
///     vec1.extend_from_slice(&[1, 2, 3, 4, 5]);
///     vec2.extend_from_slice(&[6, 7, 8, 9, 10]);
///
///     assert_eq!(vec1, &[1, 2, 3, 4, 5]);
///     assert_eq!(vec2, &[6, 7, 8, 9, 10]);
///
///     assert_eq!(2, chain_allocator.allocator_count());
/// }
/// ```
///
/// [`LinearAllocator`]: crate::LinearAllocator
pub struct ChainAllocator<A, F> {
    next_allocator: AllocatorRef<A>,
    _owns: PhantomData<AllocatorNode<A>>,
    allocator_factory: F,
}

// SAFETY: It's safe to send a `ChainAllocator` across threads because there's no way to get a
// reference to its allocation nodes from the outside, so no aliasing can happen. The allocators
// (`A`) and the factory (`F`) are moved along with it, so both must be `Send`.
unsafe impl<A: Send, F: Send> Send for ChainAllocator<A, F> {}

impl<A, F> Drop for ChainAllocator<A, F> {
    fn drop(&mut self) {
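        // Walk the chain, reboxing every node so that the node (and its allocator)
        // is dropped and its memory released.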
        while let Some(alloc_node_ptr) = self.next_allocator.get() {
            // SAFETY: alloc_node_ptr came from `Box::into_raw` and is never dereferenced again
            // once the node has been reboxed
            let alloc_node = unsafe { Box::from_raw(alloc_node_ptr.as_ptr()) };
            self.next_allocator.set(alloc_node.next_allocator.get());
        }
    }
}

unsafe impl<A, F> Allocator for ChainAllocator<A, F>
where
    A: Allocator,
    F: Fn() -> A,
{
    #[inline]
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        // No need to track zero-sized allocators (like `Global`): they are free to create and
        // all instances behave the same
        if core::mem::size_of::<A>() == 0 {
            let zero_sized_allocator = (self.allocator_factory)();
            return zero_sized_allocator.allocate(layout);
        }

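        // Try the most recently created allocator first; if it cannot satisfy the
        // request (or none exists yet), create a fresh allocator and push it onto
        // the head of the chain.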
        match self.next_allocator.get() {
            Some(next_allocator_node_ptr) => {
                // SAFETY: `ChainAllocator` is not `Sync`, so no other thread can access this node
                // while we hold the reference
                let next_allocator_node = unsafe { next_allocator_node_ptr.as_ref() };

                match next_allocator_node.allocate_and_track(layout) {
                    Ok(buf_ptr) => Ok(buf_ptr),
                    Err(_) => {
                        let allocator = (self.allocator_factory)();
                        let allocator_node =
                            AllocatorNode::with_next(allocator, next_allocator_node_ptr);

                        self.allocate_and_track_node(allocator_node, layout)
                    }
                }
            }
            None => {
                let allocator = (self.allocator_factory)();
                let allocator_node = AllocatorNode::new(allocator);

                self.allocate_and_track_node(allocator_node, layout)
            }
        }
    }

    #[inline]
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        // No need to track zero-sized allocators (like `Global`): they are free to create and
        // all instances behave the same

        if core::mem::size_of::<A>() == 0 {
            let zero_sized_allocator = (self.allocator_factory)();
            return zero_sized_allocator.deallocate(ptr, layout);
        }

        let (layout_with_footer, allocator_node) =
            AllocatorNode::<A>::ref_from_allocation(layout, ptr);

        allocator_node.allocator.deallocate(ptr, layout_with_footer);
    }

    unsafe fn grow(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        debug_assert!(
            new_layout.size() >= old_layout.size(),
            "`new_layout.size()` must be greater than or equal to `old_layout.size()`"
        );

        if core::mem::size_of::<A>() == 0 {
            let zero_sized_allocator = (self.allocator_factory)();
            return zero_sized_allocator.grow(ptr, old_layout, new_layout);
        }

        let (_, allocator_node) = AllocatorNode::<A>::ref_from_allocation(old_layout, ptr);

        if let Ok(ptr) = allocator_node.grow_and_track(ptr, old_layout, new_layout) {
            return Ok(ptr);
        }

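        // The owning allocator could not grow the block in place, so allocate from
        // the chain (possibly creating a new allocator) and copy the old contents.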
        let new_ptr = self.allocate(new_layout)?;

        // SAFETY: because `new_layout.size()` must be greater than or equal to
        // `old_layout.size()`, both the old and new memory allocation are valid for reads and
        // writes for `old_layout.size()` bytes. Also, because the old allocation wasn't yet
        // deallocated, it cannot overlap `new_ptr`. Thus, the call to `copy_nonoverlapping` is
        // safe. The safety contract for `dealloc` must be upheld by the caller.
        unsafe {
            core::ptr::copy_nonoverlapping(
                ptr.as_ptr(),
                new_ptr.cast().as_ptr(),
                old_layout.size(),
            );
            self.deallocate(ptr, old_layout);
        }

        Ok(new_ptr)
    }

    unsafe fn grow_zeroed(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        debug_assert!(
            new_layout.size() >= old_layout.size(),
            "`new_layout.size()` must be greater than or equal to `old_layout.size()`"
        );

        if core::mem::size_of::<A>() == 0 {
            let zero_sized_allocator = (self.allocator_factory)();
            return zero_sized_allocator.grow_zeroed(ptr, old_layout, new_layout);
        }

        let (_, allocator_node) = AllocatorNode::<A>::ref_from_allocation(old_layout, ptr);

        if let Ok(ptr) = allocator_node.grow_zeroed_and_track(ptr, old_layout, new_layout) {
            return Ok(ptr);
        }

        let new_ptr = self.allocate_zeroed(new_layout)?;

        // SAFETY: because `new_layout.size()` must be greater than or equal to
        // `old_layout.size()`, both the old and new memory allocation are valid for reads and
        // writes for `old_layout.size()` bytes. Also, because the old allocation wasn't yet
        // deallocated, it cannot overlap `new_ptr`. Thus, the call to `copy_nonoverlapping` is
        // safe. The safety contract for `dealloc` must be upheld by the caller.
        unsafe {
            core::ptr::copy_nonoverlapping(
                ptr.as_ptr(),
                new_ptr.cast().as_ptr(),
                old_layout.size(),
            );
            self.deallocate(ptr, old_layout);
        }

        Ok(new_ptr)
    }

    unsafe fn shrink(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        debug_assert!(
            new_layout.size() <= old_layout.size(),
            "`new_layout.size()` must be smaller than or equal to `old_layout.size()`"
        );

        if core::mem::size_of::<A>() == 0 {
            let zero_sized_allocator = (self.allocator_factory)();
            return zero_sized_allocator.shrink(ptr, old_layout, new_layout);
        }

        let (_, allocator_node) = AllocatorNode::<A>::ref_from_allocation(old_layout, ptr);

        if let Ok(ptr) = allocator_node.shrink_and_track(ptr, old_layout, new_layout) {
            return Ok(ptr);
        }

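        // The owning allocator could not shrink the block, so allocate a smaller
        // block from the chain and copy the prefix that still fits.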
        let new_ptr = self.allocate(new_layout)?;

        // SAFETY: because `new_layout.size()` must be lower than or equal to
        // `old_layout.size()`, both the old and new memory allocation are valid for reads and
        // writes for `new_layout.size()` bytes. Also, because the old allocation wasn't yet
        // deallocated, it cannot overlap `new_ptr`. Thus, the call to `copy_nonoverlapping` is
        // safe. The safety contract for `dealloc` must be upheld by the caller.
        unsafe {
            core::ptr::copy_nonoverlapping(
                ptr.as_ptr(),
                new_ptr.cast().as_ptr(),
                new_layout.size(),
            );
            self.deallocate(ptr, old_layout);
        }

        Ok(new_ptr)
    }
}

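/// Footer written immediately after every allocation's payload. It points back to
/// the [`AllocatorNode`] that made the allocation, so the chain can route
/// `deallocate`, `grow`, and `shrink` calls to the right allocator.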
#[repr(transparent)]
struct AllocationFooter<A> {
    allocator_node: NonNull<AllocatorNode<A>>,
}

impl<A: Allocator, F> ChainAllocator<A, F> {
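    /// Boxes a freshly created node, allocates from it, and makes it the new head
    /// of the chain.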
    fn allocate_and_track_node(
        &self,
        allocator_node: AllocatorNode<A>,
        layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError>
    where
        A: Allocator,
    {
        #[cfg(feature = "boxed")]
        let allocator_node = Box::try_new(allocator_node)?;

        #[cfg(not(feature = "boxed"))]
        let allocator_node = Box::new(allocator_node);

        let allocation = allocator_node.allocate_and_track(layout)?;

        // SAFETY: pointers from `Box` are always valid
        self.next_allocator.set(Some(unsafe {
            NonNull::new_unchecked(Box::into_raw(allocator_node))
        }));

        Ok(allocation)
    }
}

impl<A, F> ChainAllocator<A, F>
where
    F: Fn() -> A,
{
    /// Creates an empty [`ChainAllocator<A>`]. `allocator_factory` must return a fresh allocator
    /// every time it is called.
    #[inline]
    #[must_use]
    pub const fn new(allocator_factory: F) -> Self {
        Self {
            next_allocator: AllocatorRef::new(None),
            allocator_factory,
            _owns: PhantomData,
        }
    }

    /// Returns the number of allocators created by this [`ChainAllocator<A>`].
    pub fn allocator_count(&self) -> usize {
        let mut next_allocator = self.next_allocator.get();

        let mut count = 0;
        while let Some(alloc_node_ptr) = next_allocator {
            // SAFETY: allocation nodes are never exposed outside this crate, so the node is still
            // alive and not mutably aliased
            next_allocator = unsafe { alloc_node_ptr.as_ref() }.next_allocator.get();
            count += 1;
        }
        count
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::linear_allocator::LinearAllocator;
    use core::mem::size_of;

    #[test]
    fn should_alloc_zeroed() {
        let chain_allocator =
            ChainAllocator::new(|| LinearAllocator::with_capacity(32 + size_of::<*const ()>()));

        let layout = Layout::array::<u8>(32).unwrap();
        let allocation = chain_allocator.allocate_zeroed(layout).unwrap();

        let slice = unsafe { allocation.as_ref() };
        assert_eq!(slice.len(), 32);
        assert_eq!(slice, [0; 32]);
        assert_eq!(chain_allocator.allocator_count(), 1);

        unsafe { chain_allocator.deallocate(allocation.cast(), layout) };
    }

    #[test]
    fn should_grow_allocation() {
        let chain_allocator =
            ChainAllocator::new(|| LinearAllocator::with_capacity(128 + size_of::<*const ()>()));

        let old_layout = Layout::array::<u8>(32).unwrap();
        let old_allocation = chain_allocator.allocate(old_layout).unwrap();

        let new_layout = Layout::array::<u8>(64).unwrap();

        let new_allocation = unsafe {
            chain_allocator
                .grow(old_allocation.cast(), old_layout, new_layout)
                .unwrap()
        };

        let slice = unsafe { new_allocation.as_ref() };
        assert_eq!(slice.len(), 64);

        unsafe { chain_allocator.deallocate(new_allocation.cast(), new_layout) };
    }

    #[test]
    fn should_grow_zeroed_allocation() {
        let chain_allocator =
            ChainAllocator::new(|| LinearAllocator::with_capacity(128 + size_of::<*const ()>()));

        let old_layout = Layout::array::<u8>(32).unwrap();
        let mut old_allocation = chain_allocator.allocate(old_layout).unwrap();

        {
            let slice = unsafe { old_allocation.as_mut() };
            slice.fill(1);
            assert_eq!(slice, [1; 32]);
        }

        let new_layout = Layout::array::<u8>(64).unwrap();
        let new_allocation = unsafe {
            chain_allocator
                .grow_zeroed(old_allocation.cast(), old_layout, new_layout)
                .unwrap()
        };

        let slice = unsafe { new_allocation.as_ref() };

        assert_eq!(slice.len(), 64);
        assert_eq!(slice[..32], [1; 32]);
        assert_eq!(slice[32..], [0; 32]);

        unsafe { chain_allocator.deallocate(new_allocation.cast(), new_layout) };
    }

    #[test]
    fn should_shrink_allocation() {
        let chain_allocator =
            ChainAllocator::new(|| LinearAllocator::with_capacity(128 + size_of::<*const ()>()));

        let old_layout = Layout::array::<u8>(64).unwrap();
        let mut old_allocation = chain_allocator.allocate(old_layout).unwrap();

        {
            let slice = unsafe { old_allocation.as_mut() };
            slice.fill(1);
            assert_eq!(slice, [1; 64]);
        }

        let new_layout = Layout::array::<u8>(32).unwrap();
        let new_allocation = unsafe {
            chain_allocator
                .shrink(old_allocation.cast(), old_layout, new_layout)
                .unwrap()
        };

        let slice = unsafe { new_allocation.as_ref() };

        assert_eq!(slice.len(), 32);
        assert_eq!(slice, [1; 32]);

        unsafe { chain_allocator.deallocate(new_allocation.cast(), new_layout) };
    }
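
    // A sketch of an additional test, mirroring the doc example on `ChainAllocator`:
    // exhausting the first `LinearAllocator` should make the chain create a second one.
    // The capacity arithmetic (payload plus one pointer for the footer) follows the
    // existing tests above.
    #[test]
    fn should_create_second_allocator_when_first_is_exhausted() {
        let chain_allocator =
            ChainAllocator::new(|| LinearAllocator::with_capacity(32 + size_of::<*const ()>()));

        let layout = Layout::array::<u8>(32).unwrap();

        // The first allocation fills the first allocator completely.
        let first = chain_allocator.allocate(layout).unwrap();
        // The second allocation no longer fits, so a new allocator is created.
        let second = chain_allocator.allocate(layout).unwrap();

        assert_eq!(chain_allocator.allocator_count(), 2);

        unsafe {
            chain_allocator.deallocate(first.cast(), layout);
            chain_allocator.deallocate(second.cast(), layout);
        }
    }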
}