musli_allocator/stack.rs

#[cfg(test)]
mod tests;

use core::cell::{Cell, UnsafeCell};
use core::fmt::{self, Arguments};
use core::marker::PhantomData;
use core::mem::{align_of, forget, replace, size_of, MaybeUninit};
use core::num::NonZeroU8;
use core::ops::{Deref, DerefMut};
use core::ptr;
use core::slice;

use musli::buf::Error;
use musli::{Allocator, Buf};

use crate::DEFAULT_STACK_BUFFER;

/// Required alignment.
const ALIGNMENT: usize = 8;
/// The size of a header.
const HEADER_U32: u32 = size_of::<Header>() as u32;
// We keep the maximum number of bytes below 2^31, since that ensures that the
// addition of two magnitudes never overflows.
const MAX_BYTES: u32 = i32::MAX as u32;

const _: () = {
    if ALIGNMENT % align_of::<Header>() != 0 {
        panic!("Header is not aligned by 8");
    }
};

/// A buffer that can be used to store data on the stack.
///
/// See the [module level documentation][super] for more information.
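///
/// # Examples
///
/// A minimal construction sketch; the import path assumes this crate is used
/// as `musli_allocator`:
///
/// ```
/// // NB: the crate path is an assumption, adjust it to the actual crate name.
/// use musli_allocator::StackBuffer;
///
/// // A 1024-byte buffer with the 8-byte alignment expected by `Stack::new`.
/// let buffer = StackBuffer::<1024>::new();
/// assert_eq!(buffer.len(), 1024);
/// ```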
#[repr(align(8))]
pub struct StackBuffer<const N: usize = DEFAULT_STACK_BUFFER> {
    data: [MaybeUninit<u8>; N],
}

impl<const C: usize> StackBuffer<C> {
    /// Construct a new buffer.
    pub const fn new() -> Self {
        Self {
            // SAFETY: This is safe to initialize, since it's just an array of
            // contiguous uninitialized memory.
            data: unsafe { MaybeUninit::uninit().assume_init() },
        }
    }
}

impl<const C: usize> Default for StackBuffer<C> {
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

impl<const C: usize> Deref for StackBuffer<C> {
    type Target = [MaybeUninit<u8>];

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl<const C: usize> DerefMut for StackBuffer<C> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.data
    }
}

/// A no-std compatible fixed-memory allocator that can be used with the `musli`
/// crate.
///
/// It is geared towards handling few allocations, but they can be arbitrarily
/// large. It is optimized to work best when allocations are short lived and
/// "merged back" into one previously allocated region through
/// `Buf::write_buffer`.
///
/// It's also optimized to write to one allocation "at a time". So once an
/// allocation has been grown once, it will be put in a region where it is
/// unlikely to need to be moved again, usually the last region which has access
/// to the remainder of the provided buffer.
///
/// For the moment, this allocator only supports 255 unique allocations, which
/// is fine for use with the `musli` crate, but might be a limitation for other
/// use-cases.
///
/// # Design
///
/// The allocator takes a buffer of contiguous memory. This is dynamically
/// divided into two parts:
///
/// * One part which grows upwards from the base, constituting the memory being
///   allocated.
/// * Its metadata, growing downward from the end of the buffer, containing
///   headers for all allocated regions.
///
/// By designing the allocator so that the allocated memory and its metadata are
/// separate, neighbouring regions can efficiently be merged as they are written
/// or freed.
///
/// Each allocation is sparse, meaning it does not try to over-allocate memory.
/// This ensures that subsequent regions with initialized memory can be merged
/// efficiently, but degrades performance for many small writes performed across
/// multiple allocations concurrently.
///
/// Below is an illustration of this, where `a` and `b` are two allocations to
/// which we write one byte at a time. Here `x` indicates an occupied gap left
/// behind in the memory region.
///
/// ```text
/// a
/// ab
/// # a moved to end
/// xbaa
/// # b moved to 0
/// bbaa
/// # aa not moved
/// bbaaa
/// # bb moved to end
/// xxaaabbb
/// # aaa moved to 0
/// aaaaxbbb
/// # bbb not moved
/// aaaaxbbbb
/// # aaaa not moved
/// aaaaabbbb
/// # bbbbb not moved
/// aaaaabbbbb
/// # aaaaa moved to end
/// xxxxxbbbbbaaaaaa
/// # bbbbb moved to 0
/// bbbbbbxxxxaaaaaa
/// ```
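///
/// # Examples
///
/// A minimal usage sketch of the "merge back" pattern described above; the
/// import path assumes this crate is used as `musli_allocator`:
///
/// ```
/// // NB: the crate path is an assumption, adjust it to the actual crate name.
/// use musli::{Allocator, Buf};
/// use musli_allocator::{Stack, StackBuffer};
///
/// let mut buffer = StackBuffer::<1024>::new();
/// let alloc = Stack::new(&mut buffer);
///
/// let mut a = alloc.alloc().expect("allocation failed");
/// let mut b = alloc.alloc().expect("allocation failed");
///
/// assert!(a.write(b"Hello"));
/// assert!(b.write(b" World"));
///
/// // Writing `b` back into `a` lets neighbouring regions be merged rather
/// // than copied where possible; either way the bytes end up in `a`.
/// assert!(a.write_buffer(b));
/// assert_eq!(a.as_slice(), b"Hello World");
/// ```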
pub struct Stack<'a> {
    // This must be an unsafe cell, since it's mutably accessed through
    // immutable references. We simply make sure that those accesses do not
    // clobber each other, which we can do since the API is restricted through
    // the `Buf` trait.
    internal: UnsafeCell<Internal>,
    // The underlying buffer being borrowed.
    _marker: PhantomData<&'a mut [MaybeUninit<u8>]>,
}

impl<'a> Stack<'a> {
    /// Build a new no-std allocator.
    ///
    /// The buffer must be aligned by 8 bytes, and its length should be a
    /// multiple of 8 bytes.
    ///
    /// See the [type-level documentation][Stack] for more information.
    ///
    /// # Panics
    ///
    /// This method panics if called with a buffer larger than 2^31 bytes, or
    /// if the provided buffer is not aligned by 8.
    ///
    /// An easy way to align a buffer is to use [`StackBuffer`] when
    /// constructing it.
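    ///
    /// # Examples
    ///
    /// A minimal construction sketch; the import path assumes this crate is
    /// used as `musli_allocator`:
    ///
    /// ```
    /// // NB: the crate path is an assumption, adjust it to the actual name.
    /// use musli_allocator::{Stack, StackBuffer};
    ///
    /// // `StackBuffer` provides the alignment required here.
    /// let mut buffer = StackBuffer::<4096>::new();
    /// let alloc = Stack::new(&mut buffer);
    /// ```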
    pub fn new(buffer: &'a mut [MaybeUninit<u8>]) -> Self {
        assert!(
            buffer.len() <= MAX_BYTES as usize,
            "Buffer too large 0-{}",
            MAX_BYTES
        );

        assert!(
            buffer.as_ptr() as usize % ALIGNMENT == 0,
            "Provided buffer at {:08x} is not aligned by 8",
            buffer.as_ptr() as usize
        );

        let size = buffer.len() as u32;

        // Ensure the buffer is aligned for headers.
        let size = size - size % (ALIGNMENT as u32);

        Self {
            internal: UnsafeCell::new(Internal {
                free: None,
                head: None,
                tail: None,
                bytes: 0,
                headers: 0,
                occupied: 0,
                size,
                data: buffer.as_mut_ptr(),
            }),
            _marker: PhantomData,
        }
    }
}

impl Allocator for Stack<'_> {
    type Buf<'this> = StackBuf<'this> where Self: 'this;

    #[inline(always)]
    fn alloc(&self) -> Option<Self::Buf<'_>> {
        // SAFETY: We have exclusive access to the internal state, and it's only
        // held for the duration of this call.
        let region = unsafe { (*self.internal.get()).alloc(0)? };

        Some(StackBuf {
            region: Cell::new(region.id),
            internal: &self.internal,
        })
    }
}

/// A no-std allocated buffer.
pub struct StackBuf<'a> {
    region: Cell<HeaderId>,
    internal: &'a UnsafeCell<Internal>,
}

impl<'a> Buf for StackBuf<'a> {
    #[inline]
    fn write(&mut self, bytes: &[u8]) -> bool {
        if bytes.is_empty() {
            return true;
        }

        if bytes.len() > MAX_BYTES as usize {
            return false;
        }

        let bytes_len = bytes.len() as u32;

        // SAFETY: Due to invariants in the `Buf` trait we know that these
        // cannot be used incorrectly.
        unsafe {
            let i = &mut *self.internal.get();

            let region = i.region(self.region.get());
            let len = region.len;

            // Ensure the region can fit the requested bytes.
            let mut region = 'out: {
                // The region can already fit the requested bytes.
                if region.cap - len >= bytes_len {
                    break 'out region;
                };

                let requested = len + bytes_len;

                let Some(region) = i.realloc(self.region.get(), len, requested) else {
                    return false;
                };

                self.region.set(region.id);
                region
            };

            let dst = i.data.wrapping_add((region.start + len) as usize).cast();

            ptr::copy_nonoverlapping(bytes.as_ptr(), dst, bytes.len());
            region.len += bytes.len() as u32;
            true
        }
    }

    #[inline]
    fn write_buffer<B>(&mut self, buf: B) -> bool
    where
        B: Buf,
    {
        'out: {
            // NB: Placing this here to make miri happy, since accessing the
            // slice will mean mutably accessing the internal state.
            let other_ptr = buf.as_slice().as_ptr().cast();

            unsafe {
                let i = &mut *self.internal.get();
                let mut this = i.region(self.region.get());

                debug_assert!(this.cap >= this.len);

                let data_cap_ptr = this.data_cap_ptr(i.data);

                // If the other region immediately follows this region, we can
                // optimize the write by simply growing the current region and
                // de-allocating the second, since the only conclusion is that
                // they share the same allocator.
                if !ptr::eq(data_cap_ptr.cast_const(), other_ptr) {
                    break 'out;
                }

                let Some(next) = this.next else {
                    break 'out;
                };

                // Prevent the other buffer from being dropped, since we're
                // taking care of the allocation in here directly instead.
                forget(buf);

                let next = i.region(next);

                let diff = this.cap - this.len;

                // Data needs to be shuffled back to the end of the initialized
                // region.
                if diff > 0 {
                    let to_ptr = data_cap_ptr.wrapping_sub(diff as usize);
                    ptr::copy(data_cap_ptr, to_ptr, next.len as usize);
                }

                let old = i.free_region(next);
                this.cap += old.cap;
                this.len += old.len;
                return true;
            }
        }

        self.write(buf.as_slice())
    }

    #[inline(always)]
    fn len(&self) -> usize {
        unsafe {
            let i = &*self.internal.get();
            i.header(self.region.get()).len as usize
        }
    }

    #[inline(always)]
    fn as_slice(&self) -> &[u8] {
        unsafe {
            let i = &*self.internal.get();
            let this = i.header(self.region.get());
            let ptr = i.data.wrapping_add(this.start as usize).cast();
            slice::from_raw_parts(ptr, this.len as usize)
        }
    }

    #[inline(always)]
    fn write_fmt(&mut self, arguments: Arguments<'_>) -> Result<(), Error> {
        fmt::write(self, arguments).map_err(|_| Error)
    }
}

impl fmt::Write for StackBuf<'_> {
    #[inline]
    fn write_str(&mut self, s: &str) -> fmt::Result {
        if !self.write(s.as_bytes()) {
            return Err(fmt::Error);
        }

        Ok(())
    }
}

impl Drop for StackBuf<'_> {
    fn drop(&mut self) {
        // SAFETY: We have exclusive access to the internal state.
        unsafe {
            (*self.internal.get()).free(self.region.get());
        }
    }
}

struct Region {
    id: HeaderId,
    ptr: *mut Header,
}

impl Region {
    #[inline]
    unsafe fn data_cap_ptr(&self, data: *mut MaybeUninit<u8>) -> *mut MaybeUninit<u8> {
        data.wrapping_add((self.start + self.cap) as usize)
    }

    #[inline]
    unsafe fn data_base_ptr(&self, data: *mut MaybeUninit<u8>) -> *mut MaybeUninit<u8> {
        data.wrapping_add(self.start as usize)
    }
}

impl Deref for Region {
    type Target = Header;

    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: Construction of the region is unsafe, so the caller must
        // ensure that it's used correctly after that.
        unsafe { &*self.ptr }
    }
}

impl DerefMut for Region {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: Construction of the region is unsafe, so the caller must
        // ensure that it's used correctly after that.
        unsafe { &mut *self.ptr }
    }
}

/// The identifier of a region.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(test, derive(PartialOrd, Ord, Hash))]
#[repr(transparent)]
struct HeaderId(NonZeroU8);

impl HeaderId {
    /// Create a new region identifier.
    ///
    /// # Safety
    ///
    /// The given value must be non-zero.
    #[inline]
    const unsafe fn new_unchecked(value: u8) -> Self {
        Self(NonZeroU8::new_unchecked(value))
    }

    /// Get the value of the region identifier.
    #[inline]
    fn get(self) -> u8 {
        self.0.get()
    }
}

struct Internal {
    // The first free region.
    free: Option<HeaderId>,
    // The identifier of the head region.
    head: Option<HeaderId>,
    // The identifier of the tail region.
    tail: Option<HeaderId>,
    // The number of bytes allocated in the data region.
    bytes: u32,
    // The number of headers in use.
    headers: u8,
    /// The number of occupied regions.
    occupied: u8,
    /// The size of the buffer being wrapped.
    size: u32,
    // The slab of regions and allocations.
    //
    // Allocated memory grows from the bottom upwards, because this allows
    // copying writes to be optimized.
    //
    // Region metadata is written to the end growing downwards.
    data: *mut MaybeUninit<u8>,
}

impl Internal {
    /// Get the header corresponding to the given id.
    #[inline]
    fn header(&self, at: HeaderId) -> &Header {
        // SAFETY: Once we've coerced to `&self`, we guarantee that the header
        // is only accessed immutably.
        unsafe {
            &*self
                .data
                .wrapping_add(self.region_to_addr(at))
                .cast::<Header>()
        }
    }

    /// Get the mutable header pointer corresponding to the given id.
    #[inline]
    fn header_mut(&mut self, at: HeaderId) -> *mut Header {
        self.data
            .wrapping_add(self.region_to_addr(at))
            .cast::<Header>()
    }

    /// Get the mutable region corresponding to the given id.
    #[inline]
    fn region(&mut self, id: HeaderId) -> Region {
        Region {
            id,
            ptr: self.header_mut(id),
        }
    }

    unsafe fn unlink(&mut self, header: &Header) {
        if let Some(next) = header.next {
            (*self.header_mut(next)).prev = header.prev;
        } else {
            self.tail = header.prev;
        }

        if let Some(prev) = header.prev {
            (*self.header_mut(prev)).next = header.next;
        } else {
            self.head = header.next;
        }
    }

    unsafe fn replace_back(&mut self, region: &mut Region) {
        let prev = region.prev.take();
        let next = region.next.take();

        if let Some(prev) = prev {
            (*self.header_mut(prev)).next = next;
        }

        if let Some(next) = next {
            (*self.header_mut(next)).prev = prev;
        }

        if self.head == Some(region.id) {
            self.head = next;
        }

        self.push_back(region);
    }

    unsafe fn push_back(&mut self, region: &mut Region) {
        if self.head.is_none() {
            self.head = Some(region.id);
        }

        if let Some(tail) = self.tail.replace(region.id) {
            region.prev = Some(tail);
            (*self.region(tail).ptr).next = Some(region.id);
        }
    }

    /// Free a region.
    unsafe fn free_region(&mut self, region: Region) -> Header {
        let old = region.ptr.replace(Header {
            start: 0,
            len: 0,
            cap: 0,
            state: State::Free,
            next_free: self.free.replace(region.id),
            prev: None,
            next: None,
        });

        self.unlink(&old);
        old
    }

    /// Allocate a region.
    ///
    /// # Safety
    ///
    /// The caller must ensure exclusive access to the internal state.
    unsafe fn alloc(&mut self, requested: u32) -> Option<Region> {
        if self.occupied > 0 {
            if let Some(mut region) =
                self.find_region(|h| h.state == State::Occupy && h.cap >= requested)
            {
                self.occupied -= 1;
                region.state = State::Used;
                return Some(region);
            }
        }

        let mut region = 'out: {
            if let Some(mut region) = self.pop_free() {
                let bytes = self.bytes + requested;

                if bytes > self.size {
                    return None;
                }

                region.start = self.bytes;
                region.state = State::Used;
                region.cap = requested;

                self.bytes = bytes;
                break 'out region;
            }

            let bytes = self.bytes + requested;
            let headers = self.headers.checked_add(1)?;
            let size = self.size.checked_sub(HEADER_U32)?;

            if bytes > size {
                return None;
            }

            let start = replace(&mut self.bytes, bytes);
            self.headers = headers;
            self.size = size;

            let region = self.region(HeaderId::new_unchecked(headers));

            // We need to write a full header, since we're allocating a new one.
            region.ptr.write(Header {
                start,
                len: 0,
                cap: requested,
                state: State::Used,
                next_free: None,
                prev: None,
                next: None,
            });

            region
        };

        self.push_back(&mut region);
        Some(region)
    }

    unsafe fn free(&mut self, region: HeaderId) {
        let mut region = self.region(region);

        debug_assert_eq!(region.state, State::Used);
        debug_assert_eq!(region.next_free, None);

        // Just free up the last region in the slab.
        if region.next.is_none() {
            self.free_tail(region);
            return;
        }

        // If there is no previous region, then mark this region as occupied.
        let Some(prev) = region.prev else {
            self.occupied += 1;
            region.state = State::Occupy;
            region.len = 0;
            return;
        };

        let mut prev = self.region(prev);
        debug_assert!(matches!(prev.state, State::Occupy | State::Used));

        // Merge the freed region's capacity into the previous region.
        let region = self.free_region(region);

        prev.cap += region.cap;

        // The current header being freed is the last in the list.
        if region.next.is_none() {
            self.bytes = region.start;
        }
    }

    /// Free the tail starting at the `current` region.
    unsafe fn free_tail(&mut self, current: Region) {
        debug_assert_eq!(self.tail, Some(current.id));

        let current = self.free_region(current);
        debug_assert_eq!(current.next, None);
        self.bytes -= current.cap;

        let Some(prev) = current.prev else {
            return;
        };

        let prev = self.region(prev);

        // The prior region is occupied, so we can free that as well.
        if prev.state == State::Occupy {
            let prev = self.free_region(prev);
            self.bytes -= prev.cap;
            self.occupied -= 1;
        }
    }

    unsafe fn realloc(&mut self, from: HeaderId, len: u32, requested: u32) -> Option<Region> {
        let mut from = self.region(from);

        // This is the last region in the slab, so we can just expand it.
        if from.next.is_none() {
            let additional = requested - from.cap;

            if self.bytes + additional > self.size {
                return None;
            }

            from.cap += additional;
            self.bytes += additional;
            return Some(from);
        }

        // Try to merge with a preceding region, if the requested memory can
        // fit in it.
        'bail: {
            // Check if the immediate prior region can fit the requested allocation.
            let Some(prev) = from.prev else {
                break 'bail;
            };

            let mut prev = self.region(prev);

            if prev.state != State::Occupy || prev.cap + len < requested {
                break 'bail;
            }

            let prev_ptr = prev.data_base_ptr(self.data);
            let from_ptr = from.data_base_ptr(self.data);

            let from = self.free_region(from);

            ptr::copy(from_ptr, prev_ptr, from.len as usize);

            prev.state = State::Used;
            prev.cap += from.cap;
            prev.len = from.len;
            return Some(prev);
        }

        // There is no data allocated in the current region, so we can simply
        // re-link it to the end of the chain of allocation.
        if from.cap == 0 {
            let bytes = self.bytes + requested;

            if bytes > self.size {
                return None;
            }

            from.start = self.bytes;
            from.cap = requested;

            self.replace_back(&mut from);
            self.bytes = bytes;
            return Some(from);
        }

        let mut to = self.alloc(requested)?;

        let from_data = self
            .data
            .wrapping_add(from.start as usize)
            .cast::<u8>()
            .cast_const();

        let to_data = self.data.wrapping_add(to.start as usize).cast::<u8>();

        ptr::copy_nonoverlapping(from_data, to_data, len as usize);
        to.len = len;
        self.free(from.id);
        Some(to)
    }

    unsafe fn find_region<T>(&mut self, mut condition: T) -> Option<Region>
    where
        T: FnMut(&Header) -> bool,
    {
        let mut next = self.head;

        while let Some(id) = next {
            let ptr = self.header_mut(id);

            if condition(&*ptr) {
                return Some(Region { id, ptr });
            }

            next = (*ptr).next;
        }

        None
    }

    unsafe fn pop_free(&mut self) -> Option<Region> {
        let id = self.free.take()?;
        let ptr = self.header_mut(id);
        self.free = (*ptr).next_free.take();
        Some(Region { id, ptr })
    }

    #[inline]
    fn region_to_addr(&self, at: HeaderId) -> usize {
        region_to_addr(self.size, self.headers, at)
    }
}

#[inline]
fn region_to_addr(size: u32, headers: u8, at: HeaderId) -> usize {
    (size + u32::from(headers - at.get()) * HEADER_U32) as usize
}

/// The state of an allocated region.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum State {
    /// The region is fully free and doesn't occupy any memory.
    ///
    /// # Requirements
    ///
    /// - The range must be zero-sized at offset 0.
    /// - The region must not be linked.
    /// - The region must be in the free list.
    Free = 0,
    /// The region is occupied.
    ///
    /// # Requirements
    ///
    /// - The range must point to a non-zero slice of memory.
    /// - The region must be linked.
    /// - The region must be in the occupied list.
    Occupy,
    /// The region is used by an active allocation.
    Used,
}

/// The header of a region.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(align(8))]
struct Header {
    // Start of the allocated region as a multiple of 8.
    start: u32,
    // The length of the region.
    len: u32,
    // The capacity of the region.
    cap: u32,
    // The state of the region.
    state: State,
    // Link to the next free region.
    next_free: Option<HeaderId>,
    // The previous neighbouring region.
    prev: Option<HeaderId>,
    // The next neighbouring region.
    next: Option<HeaderId>,
}