1use crate::alloc::{AllocError, Allocator};
2
3#[cfg(not(feature = "boxed"))]
4use alloc_crate::boxed::Box;
5
6#[cfg(feature = "boxed")]
7use crate::boxed::Box;
8
9use core::{alloc::Layout, cell::Cell, marker::PhantomData, ptr::NonNull};
10
11struct AllocatorNode<A> {
12 allocator: A,
13 next_allocator: AllocatorRef<A>,
14 _owns: PhantomData<AllocatorNode<A>>,
15}
16
17impl<A: Allocator> AllocatorNode<A> {
18 fn new(allocator: A) -> Self {
19 Self {
20 allocator,
21 next_allocator: AllocatorRef::new(None),
22 _owns: PhantomData,
23 }
24 }
25
26 fn with_next(allocator: A, next: NonNull<Self>) -> Self {
27 Self {
28 allocator,
29 next_allocator: AllocatorRef::new(Some(next)),
30 _owns: PhantomData,
31 }
32 }
33
34 fn allocate_and_track(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
35 let (layout_with_footer, footer_offset) =
36 layout.extend(Layout::new::<AllocationFooter<A>>()).unwrap();
37
38 let buf_ptr = self.allocator.allocate(layout_with_footer)?.cast::<u8>();
39
40 unsafe {
42 core::ptr::write(
43 buf_ptr.as_ptr().add(footer_offset).cast(),
44 AllocationFooter::<A> {
45 allocator_node: NonNull::from(self),
46 },
47 );
48 };
49
50 Ok(NonNull::slice_from_raw_parts(buf_ptr, layout.size()))
51 }
52
53 unsafe fn grow_and_track(
54 &self,
55 ptr: NonNull<u8>,
56 old_layout: Layout,
57 new_layout: Layout,
58 ) -> Result<NonNull<[u8]>, AllocError> {
59 let (old_layout_with_footer, _) = old_layout
60 .extend(Layout::new::<AllocationFooter<A>>())
61 .unwrap();
62
63 let (new_layout_with_footer, footer_offset) = new_layout
64 .extend(Layout::new::<AllocationFooter<A>>())
65 .unwrap();
66
67 let buf_ptr = unsafe {
68 self.allocator
69 .grow(ptr, old_layout_with_footer, new_layout_with_footer)
70 }?
71 .cast::<u8>();
72
73 unsafe {
75 core::ptr::write(
76 buf_ptr.as_ptr().add(footer_offset).cast(),
77 AllocationFooter::<A> {
78 allocator_node: NonNull::from(self),
79 },
80 );
81 };
82
83 Ok(NonNull::slice_from_raw_parts(buf_ptr, new_layout.size()))
84 }
85
86 unsafe fn grow_zeroed_and_track(
87 &self,
88 ptr: NonNull<u8>,
89 old_layout: Layout,
90 new_layout: Layout,
91 ) -> Result<NonNull<[u8]>, AllocError> {
92 let (old_layout_with_footer, old_footer_offset) = old_layout
93 .extend(Layout::new::<AllocationFooter<A>>())
94 .unwrap();
95
96 let (new_layout_with_footer, new_footer_offset) = new_layout
97 .extend(Layout::new::<AllocationFooter<A>>())
98 .unwrap();
99
100 let buf_ptr = unsafe {
101 self.allocator
102 .grow_zeroed(ptr, old_layout_with_footer, new_layout_with_footer)
103 }?
104 .cast::<u8>();
105
106 unsafe {
109 core::ptr::write_bytes(
110 buf_ptr.as_ptr().add(old_footer_offset),
111 0,
112 core::mem::size_of::<AllocationFooter<A>>(),
113 );
114 };
115
116 unsafe {
118 core::ptr::write(
119 buf_ptr.as_ptr().add(new_footer_offset).cast(),
120 AllocationFooter::<A> {
121 allocator_node: NonNull::from(self),
122 },
123 );
124 };
125
126 Ok(NonNull::slice_from_raw_parts(buf_ptr, new_layout.size()))
127 }
128
129 unsafe fn shrink_and_track(
130 &self,
131 ptr: NonNull<u8>,
132 old_layout: Layout,
133 new_layout: Layout,
134 ) -> Result<NonNull<[u8]>, AllocError> {
135 let (old_layout_with_footer, _) = old_layout
136 .extend(Layout::new::<AllocationFooter<A>>())
137 .unwrap();
138
139 let (new_layout_with_footer, footer_offset) = new_layout
140 .extend(Layout::new::<AllocationFooter<A>>())
141 .unwrap();
142
143 let buf_ptr = unsafe {
144 self.allocator
145 .shrink(ptr, old_layout_with_footer, new_layout_with_footer)
146 }?
147 .cast::<u8>();
148
149 unsafe {
151 core::ptr::write(
152 buf_ptr.as_ptr().add(footer_offset).cast(),
153 AllocationFooter::<A> {
154 allocator_node: NonNull::from(self),
155 },
156 );
157 };
158
159 Ok(NonNull::slice_from_raw_parts(buf_ptr, new_layout.size()))
160 }
161
162 unsafe fn ref_from_allocation<'a>(layout: Layout, ptr: NonNull<u8>) -> (Layout, &'a Self) {
164 let (layout_with_footer, footer_offset) = layout.extend(Layout::new::<Self>()).unwrap();
165
166 let footer_ptr: *mut AllocationFooter<A> = ptr.as_ptr().add(footer_offset).cast();
167
168 let allocator_node = NonNull::new_unchecked(footer_ptr)
169 .as_ref()
170 .allocator_node
171 .as_ref();
172
173 (layout_with_footer, allocator_node)
174 }
175}
176
177type AllocatorRef<A> = Cell<Option<NonNull<AllocatorNode<A>>>>;
178
179pub struct ChainAllocator<A, F> {
224 next_allocator: AllocatorRef<A>,
225 _owns: PhantomData<AllocatorNode<A>>,
226 allocator_factory: F,
227}
228
229unsafe impl<A: Send, F> Send for ChainAllocator<A, F> {}
232
233impl<A, F> Drop for ChainAllocator<A, F> {
234 fn drop(&mut self) {
235 while let Some(alloc_node_ptr) = self.next_allocator.get() {
236 let alloc_node = unsafe { Box::from_raw(alloc_node_ptr.as_ptr()) };
238 self.next_allocator.set(alloc_node.next_allocator.get());
239 }
240 }
241}
242
243unsafe impl<A, F> Allocator for ChainAllocator<A, F>
244where
245 A: Allocator,
246 F: Fn() -> A,
247{
248 #[inline]
249 fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
250 if core::mem::size_of::<A>() == 0 {
253 let zero_sized_allocator = (self.allocator_factory)();
254 return zero_sized_allocator.allocate(layout);
255 }
256
257 match self.next_allocator.get() {
258 Some(next_allocator_node_ptr) => {
259 let next_allocator_node = unsafe { next_allocator_node_ptr.as_ref() };
261
262 match next_allocator_node.allocate_and_track(layout) {
263 Ok(buf_ptr) => Ok(buf_ptr),
264 Err(_) => {
265 let allocator = (self.allocator_factory)();
266 let allocator_node =
267 AllocatorNode::with_next(allocator, next_allocator_node_ptr);
268
269 self.allocate_and_track_node(allocator_node, layout)
270 }
271 }
272 }
273 None => {
274 let allocator = (self.allocator_factory)();
275 let allocator_node = AllocatorNode::new(allocator);
276
277 self.allocate_and_track_node(allocator_node, layout)
278 }
279 }
280 }
281
282 #[inline]
283 unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
284 if core::mem::size_of::<A>() == 0 {
288 let zero_sized_allocator = (self.allocator_factory)();
289 return zero_sized_allocator.deallocate(ptr, layout);
290 }
291
292 let (layout_with_footer, allocator_node) =
293 AllocatorNode::<A>::ref_from_allocation(layout, ptr);
294
295 allocator_node.allocator.deallocate(ptr, layout_with_footer);
296 }
297
298 unsafe fn grow(
299 &self,
300 ptr: NonNull<u8>,
301 old_layout: Layout,
302 new_layout: Layout,
303 ) -> Result<NonNull<[u8]>, AllocError> {
304 debug_assert!(
305 new_layout.size() >= old_layout.size(),
306 "`new_layout.size()` must be greater than or equal to `old_layout.size()`"
307 );
308
309 if core::mem::size_of::<A>() == 0 {
310 let zero_sized_allocator = (self.allocator_factory)();
311 return zero_sized_allocator.grow(ptr, old_layout, new_layout);
312 }
313
314 let (_, allocator_node) = AllocatorNode::<A>::ref_from_allocation(old_layout, ptr);
315
316 if let Ok(ptr) = allocator_node.grow_and_track(ptr, old_layout, new_layout) {
317 return Ok(ptr);
318 }
319
320 let new_ptr = self.allocate(new_layout)?;
321
322 unsafe {
328 core::ptr::copy_nonoverlapping(
329 ptr.as_ptr(),
330 new_ptr.cast().as_ptr(),
331 old_layout.size(),
332 );
333 self.deallocate(ptr, old_layout);
334 }
335
336 Ok(new_ptr)
337 }
338
339 unsafe fn grow_zeroed(
340 &self,
341 ptr: NonNull<u8>,
342 old_layout: Layout,
343 new_layout: Layout,
344 ) -> Result<NonNull<[u8]>, AllocError> {
345 debug_assert!(
346 new_layout.size() >= old_layout.size(),
347 "`new_layout.size()` must be greater than or equal to `old_layout.size()`"
348 );
349
350 if core::mem::size_of::<A>() == 0 {
351 let zero_sized_allocator = (self.allocator_factory)();
352 return zero_sized_allocator.grow_zeroed(ptr, old_layout, new_layout);
353 }
354
355 let (_, allocator_node) = AllocatorNode::<A>::ref_from_allocation(old_layout, ptr);
356
357 if let Ok(ptr) = allocator_node.grow_zeroed_and_track(ptr, old_layout, new_layout) {
358 return Ok(ptr);
359 }
360
361 let new_ptr = self.allocate_zeroed(new_layout)?;
362
363 unsafe {
369 core::ptr::copy_nonoverlapping(
370 ptr.as_ptr(),
371 new_ptr.cast().as_ptr(),
372 old_layout.size(),
373 );
374 self.deallocate(ptr, old_layout);
375 }
376
377 Ok(new_ptr)
378 }
379
380 unsafe fn shrink(
381 &self,
382 ptr: NonNull<u8>,
383 old_layout: Layout,
384 new_layout: Layout,
385 ) -> Result<NonNull<[u8]>, AllocError> {
386 debug_assert!(
387 new_layout.size() <= old_layout.size(),
388 "`new_layout.size()` must be smaller than or equal to `old_layout.size()`"
389 );
390
391 if core::mem::size_of::<A>() == 0 {
392 let zero_sized_allocator = (self.allocator_factory)();
393 return zero_sized_allocator.grow_zeroed(ptr, old_layout, new_layout);
394 }
395
396 let (_, allocator_node) = AllocatorNode::<A>::ref_from_allocation(old_layout, ptr);
397
398 if let Ok(ptr) = allocator_node.shrink_and_track(ptr, old_layout, new_layout) {
399 return Ok(ptr);
400 }
401
402 let new_ptr = self.allocate(new_layout)?;
403
404 unsafe {
410 core::ptr::copy_nonoverlapping(
411 ptr.as_ptr(),
412 new_ptr.cast().as_ptr(),
413 new_layout.size(),
414 );
415 self.deallocate(ptr, old_layout);
416 }
417
418 Ok(new_ptr)
419 }
420}
421
422#[repr(transparent)]
423struct AllocationFooter<A> {
424 allocator_node: NonNull<AllocatorNode<A>>,
425}
426
427impl<A: Allocator, F> ChainAllocator<A, F> {
428 fn allocate_and_track_node(
429 &self,
430 allocator_node: AllocatorNode<A>,
431 layout: Layout,
432 ) -> Result<NonNull<[u8]>, AllocError>
433 where
434 A: Allocator,
435 {
436 #[cfg(feature = "boxed")]
437 let allocator_node = Box::try_new(allocator_node)?;
438
439 #[cfg(not(feature = "boxed"))]
440 let allocator_node = Box::new(allocator_node);
441
442 let allocation = allocator_node.allocate_and_track(layout)?;
443
444 self.next_allocator.set(Some(unsafe {
446 NonNull::new_unchecked(Box::into_raw(allocator_node))
447 }));
448
449 Ok(allocation)
450 }
451}
452
453impl<A, F> ChainAllocator<A, F>
454where
455 F: Fn() -> A,
456{
457 #[inline]
459 #[must_use]
460 pub const fn new(allocator_factory: F) -> Self {
461 Self {
462 next_allocator: AllocatorRef::new(None),
463 allocator_factory,
464 _owns: PhantomData,
465 }
466 }
467
468 pub fn allocator_count(&self) -> usize {
470 let mut next_allocator = self.next_allocator.get();
471
472 let mut count = 0;
473 while let Some(alloc_node_ptr) = next_allocator {
474 next_allocator = unsafe { alloc_node_ptr.as_ref() }.next_allocator.get();
477 count += 1;
478 }
479 count
480 }
481}
482
#[cfg(test)]
mod test {
    use super::*;
    use crate::linear_allocator::LinearAllocator;
    use core::mem::size_of;

    #[test]
    fn should_alloc_zeroed() {
        let chain_allocator =
            ChainAllocator::new(|| LinearAllocator::with_capacity(32 + size_of::<*const ()>()));

        let layout = Layout::array::<u8>(32).unwrap();
        let allocation = chain_allocator.allocate_zeroed(layout).unwrap();

        let slice = unsafe { allocation.as_ref() };
        assert_eq!(slice.len(), 32);
        assert_eq!(slice, [0; 32]);
        assert_eq!(chain_allocator.allocator_count(), 1);

        unsafe { chain_allocator.deallocate(allocation.cast(), layout) };
    }

    #[test]
    fn should_chain_new_allocator_when_exhausted() {
        // Each backing allocator has room for exactly one 32-byte block plus its footer.
        let chain_allocator =
            ChainAllocator::new(|| LinearAllocator::with_capacity(32 + size_of::<*const ()>()));

        let layout = Layout::array::<u8>(32).unwrap();

        let first = chain_allocator.allocate(layout).unwrap();
        assert_eq!(chain_allocator.allocator_count(), 1);

        // The first allocator is full, so the chain must link a second one.
        let second = chain_allocator.allocate(layout).unwrap();
        assert_eq!(chain_allocator.allocator_count(), 2);
        assert_ne!(first.cast::<u8>(), second.cast::<u8>());

        unsafe {
            chain_allocator.deallocate(second.cast(), layout);
            chain_allocator.deallocate(first.cast(), layout);
        }
    }

    #[test]
    fn should_grow_allocation() {
        let chain_allocator =
            ChainAllocator::new(|| LinearAllocator::with_capacity(128 + size_of::<*const ()>()));

        let old_layout = Layout::array::<u8>(32).unwrap();
        let old_allocation = chain_allocator.allocate(old_layout).unwrap();

        let new_layout = Layout::array::<u8>(64).unwrap();

        let new_allocation = unsafe {
            chain_allocator
                .grow(old_allocation.cast(), old_layout, new_layout)
                .unwrap()
        };

        let slice = unsafe { new_allocation.as_ref() };
        assert_eq!(slice.len(), 64);

        unsafe { chain_allocator.deallocate(new_allocation.cast(), new_layout) };
    }

    #[test]
    fn should_grow_zeroed_allocation() {
        let chain_allocator =
            ChainAllocator::new(|| LinearAllocator::with_capacity(128 + size_of::<*const ()>()));

        let old_layout = Layout::array::<u8>(32).unwrap();
        let mut old_allocation = chain_allocator.allocate(old_layout).unwrap();

        {
            let slice = unsafe { old_allocation.as_mut() };
            slice.fill(1);
            assert_eq!(slice, [1; 32]);
        }

        let new_layout = Layout::array::<u8>(64).unwrap();
        let new_allocation = unsafe {
            chain_allocator
                .grow_zeroed(old_allocation.cast(), old_layout, new_layout)
                .unwrap()
        };

        let slice = unsafe { new_allocation.as_ref() };

        assert_eq!(slice.len(), 64);
        assert_eq!(slice[..32], [1; 32]);
        assert_eq!(slice[32..], [0; 32]);

        unsafe { chain_allocator.deallocate(new_allocation.cast(), new_layout) };
    }

    #[test]
    fn should_shrink_allocation() {
        let chain_allocator =
            ChainAllocator::new(|| LinearAllocator::with_capacity(128 + size_of::<*const ()>()));

        let old_layout = Layout::array::<u8>(64).unwrap();
        let mut old_allocation = chain_allocator.allocate(old_layout).unwrap();

        {
            let slice = unsafe { old_allocation.as_mut() };
            slice.fill(1);
            assert_eq!(slice, [1; 64]);
        }

        let new_layout = Layout::array::<u8>(32).unwrap();
        let new_allocation = unsafe {
            chain_allocator
                .shrink(old_allocation.cast(), old_layout, new_layout)
                .unwrap()
        };

        let slice = unsafe { new_allocation.as_ref() };

        assert_eq!(slice.len(), 32);
        assert_eq!(slice, [1; 32]);

        unsafe { chain_allocator.deallocate(new_allocation.cast(), new_layout) };
    }
}