#[cfg(test)]
mod tests;

use core::cell::{Cell, UnsafeCell};
use core::fmt::{self, Arguments};
use core::marker::PhantomData;
use core::mem::{align_of, forget, replace, size_of, MaybeUninit};
use core::num::NonZeroU8;
use core::ops::{Deref, DerefMut};
use core::ptr;
use core::slice;

use musli::buf::Error;
use musli::{Allocator, Buf};

use crate::DEFAULT_STACK_BUFFER;

/// The alignment of the backing buffer and of every allocated region.
const ALIGNMENT: usize = 8;
/// The size of a [`Header`] in bytes.
const HEADER_U32: u32 = size_of::<Header>() as u32;
/// The maximum size of the buffer and of any single allocation.
const MAX_BYTES: u32 = i32::MAX as u32;

const _: () = {
    if ALIGNMENT % align_of::<Header>() != 0 {
        panic!("Header is not aligned by 8");
    }
};

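/// A fixed-size, 8-byte aligned buffer that can be used as backing storage
/// for a [`Stack`] allocator.
///
/// The contents start out uninitialized and the buffer defaults to
/// [`DEFAULT_STACK_BUFFER`] bytes in size.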
#[repr(align(8))]
pub struct StackBuffer<const N: usize = DEFAULT_STACK_BUFFER> {
    data: [MaybeUninit<u8>; N],
}

impl<const C: usize> StackBuffer<C> {
    pub const fn new() -> Self {
        Self {
            // SAFETY: An array of `MaybeUninit<u8>` is allowed to be
            // uninitialized.
            data: unsafe { MaybeUninit::uninit().assume_init() },
        }
    }
}

impl<const C: usize> Default for StackBuffer<C> {
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

impl<const C: usize> Deref for StackBuffer<C> {
    type Target = [MaybeUninit<u8>];

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl<const C: usize> DerefMut for StackBuffer<C> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.data
    }
}

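/// A no-std friendly allocator which hands out regions of a fixed,
/// caller-provided buffer. Region data grows from the front of the buffer
/// while region headers are carved out of the back, so no system allocation
/// ever takes place.
///
/// A minimal usage sketch (marked `ignore` since the exact re-export paths of
/// `StackBuffer` and `Stack` depend on how this crate exposes them):
///
/// ```ignore
/// use musli::{Allocator, Buf};
///
/// let mut buffer = StackBuffer::<1024>::new();
/// let alloc = Stack::new(&mut buffer);
///
/// let mut buf = alloc.alloc().expect("buffer should have capacity");
/// assert!(buf.write(b"hello"));
/// assert_eq!(buf.as_slice(), b"hello");
/// ```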
pub struct Stack<'a> {
    internal: UnsafeCell<Internal>,
    _marker: PhantomData<&'a mut [MaybeUninit<u8>]>,
}

impl<'a> Stack<'a> {
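    /// Construct a new stack allocator wrapping the provided buffer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is larger than `MAX_BYTES` bytes or is not
    /// 8-byte aligned.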
    pub fn new(buffer: &'a mut [MaybeUninit<u8>]) -> Self {
        assert!(
            buffer.len() <= MAX_BYTES as usize,
            "Buffer too large 0-{}",
            MAX_BYTES
        );

        assert!(
            buffer.as_ptr() as usize % ALIGNMENT == 0,
            "Provided buffer at {:08x} is not aligned by 8",
            buffer.as_ptr() as usize
        );

        let size = buffer.len() as u32;
        // Round the usable size down to a multiple of the alignment.
        let size = size - size % (ALIGNMENT as u32);

        Self {
            internal: UnsafeCell::new(Internal {
                free: None,
                head: None,
                tail: None,
                bytes: 0,
                headers: 0,
                occupied: 0,
                size,
                data: buffer.as_mut_ptr(),
            }),
            _marker: PhantomData,
        }
    }
}

impl Allocator for Stack<'_> {
    type Buf<'this> = StackBuf<'this> where Self: 'this;

    #[inline(always)]
    fn alloc(&self) -> Option<Self::Buf<'_>> {
        // SAFETY: The allocator is neither `Send` nor `Sync`, so the internal
        // state is only accessed by one call at a time.
        let region = unsafe { (*self.internal.get()).alloc(0)? };

        Some(StackBuf {
            region: Cell::new(region.id),
            internal: &self.internal,
        })
    }
}

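/// A single buffer allocated out of a [`Stack`] allocator.
///
/// The buffer tracks which region backs it and grows that region on demand;
/// the region is returned to the allocator when the buffer is dropped.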
pub struct StackBuf<'a> {
    /// The identifier of the region currently backing this buffer.
    region: Cell<HeaderId>,
    /// Shared access to the allocator's internal state.
    internal: &'a UnsafeCell<Internal>,
}

impl<'a> Buf for StackBuf<'a> {
    #[inline]
    fn write(&mut self, bytes: &[u8]) -> bool {
        if bytes.is_empty() {
            return true;
        }

        if bytes.len() > MAX_BYTES as usize {
            return false;
        }

        let bytes_len = bytes.len() as u32;

        // SAFETY: Only one reference to the internal state is live for the
        // duration of this call.
        unsafe {
            let i = &mut *self.internal.get();

            let region = i.region(self.region.get());
            let len = region.len;

            // Grow the region if its remaining capacity cannot hold the
            // incoming bytes.
            let mut region = 'out: {
                if region.cap - len >= bytes_len {
                    break 'out region;
                }

                let requested = len + bytes_len;

                let Some(region) = i.realloc(self.region.get(), len, requested) else {
                    return false;
                };

                self.region.set(region.id);
                region
            };

            let dst = i.data.wrapping_add((region.start + len) as usize).cast();

            ptr::copy_nonoverlapping(bytes.as_ptr(), dst, bytes.len());
            region.len += bytes.len() as u32;
            true
        }
    }

    #[inline]
    fn write_buffer<B>(&mut self, buf: B) -> bool
    where
        B: Buf,
    {
        'out: {
            let other_ptr = buf.as_slice().as_ptr().cast();

            // SAFETY: Only one reference to the internal state is live for
            // the duration of this call.
            unsafe {
                let i = &mut *self.internal.get();
                let mut this = i.region(self.region.get());

                debug_assert!(this.cap >= this.len);

                let data_cap_ptr = this.data_cap_ptr(i.data);

                // The incoming buffer can only be merged if its data starts
                // exactly where this region's capacity ends.
                if !ptr::eq(data_cap_ptr.cast_const(), other_ptr) {
                    break 'out;
                }

                let Some(next) = this.next else {
                    break 'out;
                };

                // The following region backs `buf`; its bytes are adopted
                // below, so it must not be dropped and freed.
                forget(buf);

                let next = i.region(next);

                // Close any gap between this region's length and capacity so
                // that the two regions become contiguous.
                let diff = this.cap - this.len;

                if diff > 0 {
                    let to_ptr = data_cap_ptr.wrapping_sub(diff as usize);
                    ptr::copy(data_cap_ptr, to_ptr, next.len as usize);
                }

                let old = i.free_region(next);
                this.cap += old.cap;
                this.len += old.len;
                return true;
            }
        }

        self.write(buf.as_slice())
    }

    #[inline(always)]
    fn len(&self) -> usize {
        unsafe {
            let i = &*self.internal.get();
            i.header(self.region.get()).len as usize
        }
    }

    #[inline(always)]
    fn as_slice(&self) -> &[u8] {
        unsafe {
            let i = &*self.internal.get();
            let this = i.header(self.region.get());
            let ptr = i.data.wrapping_add(this.start as usize).cast();
            slice::from_raw_parts(ptr, this.len as usize)
        }
    }

    #[inline(always)]
    fn write_fmt(&mut self, arguments: Arguments<'_>) -> Result<(), Error> {
        fmt::write(self, arguments).map_err(|_| Error)
    }
}

impl fmt::Write for StackBuf<'_> {
    #[inline]
    fn write_str(&mut self, s: &str) -> fmt::Result {
        if !self.write(s.as_bytes()) {
            return Err(fmt::Error);
        }

        Ok(())
    }
}

impl Drop for StackBuf<'_> {
    fn drop(&mut self) {
        // SAFETY: Only one reference to the internal state is live for the
        // duration of this call.
        unsafe {
            (*self.internal.get()).free(self.region.get());
        }
    }
}

/// A handle to a region, pairing its identifier with a pointer to its header.
struct Region {
    id: HeaderId,
    ptr: *mut Header,
}

impl Region {
    /// Get a pointer to one past the end of the region's capacity.
    #[inline]
    unsafe fn data_cap_ptr(&self, data: *mut MaybeUninit<u8>) -> *mut MaybeUninit<u8> {
        data.wrapping_add((self.start + self.cap) as usize)
    }

    /// Get a pointer to the base of the region's data.
    #[inline]
    unsafe fn data_base_ptr(&self, data: *mut MaybeUninit<u8>) -> *mut MaybeUninit<u8> {
        data.wrapping_add(self.start as usize)
    }
}

impl Deref for Region {
    type Target = Header;

    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: A region is only ever constructed with a pointer to a live
        // header inside of the buffer.
        unsafe { &*self.ptr }
    }
}

impl DerefMut for Region {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: A region is only ever constructed with a pointer to a live
        // header inside of the buffer.
        unsafe { &mut *self.ptr }
    }
}

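/// The identifier of a region header.
///
/// Identifiers are 1-based indexes into the header area at the back of the
/// buffer; being non-zero lets `Option<HeaderId>` stay a single byte.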
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(test, derive(PartialOrd, Ord, Hash))]
#[repr(transparent)]
struct HeaderId(NonZeroU8);

impl HeaderId {
    /// Construct a new header identifier.
    ///
    /// # Safety
    ///
    /// The provided value must be non-zero.
    #[inline]
    const unsafe fn new_unchecked(value: u8) -> Self {
        Self(NonZeroU8::new_unchecked(value))
    }

    /// Get the underlying value of the identifier.
    #[inline]
    fn get(self) -> u8 {
        self.0.get()
    }
}

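/// The shared internal state of a [`Stack`] allocator.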
struct Internal {
    /// The first header in the free list, if any.
    free: Option<HeaderId>,
    /// The first region in the linked list of regions.
    head: Option<HeaderId>,
    /// The last region in the linked list of regions.
    tail: Option<HeaderId>,
    /// The number of bytes of region data allocated from the front of the
    /// buffer.
    bytes: u32,
    /// The number of headers allocated from the back of the buffer.
    headers: u8,
    /// The number of regions currently in the `Occupy` state.
    occupied: u8,
    /// The usable size of the buffer; shrinks as headers are allocated.
    size: u32,
    /// Pointer to the start of the backing buffer.
    data: *mut MaybeUninit<u8>,
}

impl Internal {
    /// Get a reference to the header identified by `at`.
    #[inline]
    fn header(&self, at: HeaderId) -> &Header {
        // SAFETY: Once an identifier has been handed out, the corresponding
        // header is initialized and stays inside the buffer.
        unsafe {
            &*self
                .data
                .wrapping_add(self.region_to_addr(at))
                .cast::<Header>()
        }
    }

    /// Get a mutable pointer to the header identified by `at`.
    #[inline]
    fn header_mut(&mut self, at: HeaderId) -> *mut Header {
        self.data
            .wrapping_add(self.region_to_addr(at))
            .cast::<Header>()
    }

    /// Construct a [`Region`] handle for the header identified by `id`.
    #[inline]
    fn region(&mut self, id: HeaderId) -> Region {
        Region {
            id,
            ptr: self.header_mut(id),
        }
    }

    /// Unlink the given header from the doubly-linked list of regions.
    unsafe fn unlink(&mut self, header: &Header) {
        if let Some(next) = header.next {
            (*self.header_mut(next)).prev = header.prev;
        } else {
            self.tail = header.prev;
        }

        if let Some(prev) = header.prev {
            (*self.header_mut(prev)).next = header.next;
        } else {
            self.head = header.next;
        }
    }

    /// Unlink the region from its current position and push it onto the back
    /// of the region list.
    unsafe fn replace_back(&mut self, region: &mut Region) {
        let prev = region.prev.take();
        let next = region.next.take();

        if let Some(prev) = prev {
            (*self.header_mut(prev)).next = next;
        }

        if let Some(next) = next {
            (*self.header_mut(next)).prev = prev;
        }

        if self.head == Some(region.id) {
            self.head = next;
        }

        self.push_back(region);
    }

    /// Push the region onto the back of the region list.
    unsafe fn push_back(&mut self, region: &mut Region) {
        if self.head.is_none() {
            self.head = Some(region.id);
        }

        if let Some(tail) = self.tail.replace(region.id) {
            region.prev = Some(tail);
            (*self.region(tail).ptr).next = Some(region.id);
        }
    }

    /// Free a region, putting its header onto the free list and unlinking it
    /// from the region list. Returns the old header.
    unsafe fn free_region(&mut self, region: Region) -> Header {
        let old = region.ptr.replace(Header {
            start: 0,
            len: 0,
            cap: 0,
            state: State::Free,
            next_free: self.free.replace(region.id),
            prev: None,
            next: None,
        });

        self.unlink(&old);
        old
    }

    /// Allocate a new region with the requested capacity.
    ///
    /// Reuses an `Occupy` region with enough capacity if one exists, then a
    /// header from the free list, and otherwise carves a new header out of
    /// the back of the buffer.
    unsafe fn alloc(&mut self, requested: u32) -> Option<Region> {
        if self.occupied > 0 {
            if let Some(mut region) =
                self.find_region(|h| h.state == State::Occupy && h.cap >= requested)
            {
                self.occupied -= 1;
                region.state = State::Used;
                return Some(region);
            }
        }

        let mut region = 'out: {
            // Reuse a header from the free list if possible.
            if let Some(mut region) = self.pop_free() {
                let bytes = self.bytes + requested;

                if bytes > self.size {
                    return None;
                }

                region.start = self.bytes;
                region.state = State::Used;
                region.cap = requested;

                self.bytes = bytes;
                break 'out region;
            }

            // Otherwise allocate a new header from the back of the buffer,
            // which shrinks the usable size.
            let bytes = self.bytes + requested;
            let headers = self.headers.checked_add(1)?;
            let size = self.size.checked_sub(HEADER_U32)?;

            if bytes > size {
                return None;
            }

            let start = replace(&mut self.bytes, bytes);
            self.headers = headers;
            self.size = size;

            let region = self.region(HeaderId::new_unchecked(headers));

            region.ptr.write(Header {
                start,
                len: 0,
                cap: requested,
                state: State::Used,
                next_free: None,
                prev: None,
                next: None,
            });

            region
        };

        self.push_back(&mut region);
        Some(region)
    }

    /// Free the region identified by `region`.
    unsafe fn free(&mut self, region: HeaderId) {
        let mut region = self.region(region);

        debug_assert_eq!(region.state, State::Used);
        debug_assert_eq!(region.next_free, None);

        // The region is the last one in the buffer, so its bytes can be
        // reclaimed directly.
        if region.next.is_none() {
            self.free_tail(region);
            return;
        }

        // There is no previous region to merge into, so mark the region as
        // occupied and keep it linked for later reuse.
        let Some(prev) = region.prev else {
            self.occupied += 1;
            region.state = State::Occupy;
            region.len = 0;
            return;
        };

        let mut prev = self.region(prev);
        debug_assert!(matches!(prev.state, State::Occupy | State::Used));

        // Merge the freed capacity into the previous region.
        let region = self.free_region(region);

        prev.cap += region.cap;

        if region.next.is_none() {
            self.bytes = region.start;
        }
    }

    /// Free the last region in the buffer, reclaiming its bytes.
    unsafe fn free_tail(&mut self, current: Region) {
        debug_assert_eq!(self.tail, Some(current.id));

        let current = self.free_region(current);
        debug_assert_eq!(current.next, None);
        self.bytes -= current.cap;

        let Some(prev) = current.prev else {
            return;
        };

        let prev = self.region(prev);

        // If the new tail is occupied (freed but still linked), free it as
        // well and reclaim its bytes.
        if prev.state == State::Occupy {
            let prev = self.free_region(prev);
            self.bytes -= prev.cap;
            self.occupied -= 1;
        }
    }

    /// Grow the region `from` to hold at least `requested` bytes, preserving
    /// the first `len` bytes of its data.
    unsafe fn realloc(&mut self, from: HeaderId, len: u32, requested: u32) -> Option<Region> {
        let mut from = self.region(from);

        // The region is the last one in the buffer, so it can be expanded in
        // place.
        if from.next.is_none() {
            let additional = requested - from.cap;

            if self.bytes + additional > self.size {
                return None;
            }

            from.cap += additional;
            self.bytes += additional;
            return Some(from);
        }

        // Try to merge with an immediately preceding occupied region, moving
        // the existing data down into it.
        'bail: {
            let Some(prev) = from.prev else {
                break 'bail;
            };

            let mut prev = self.region(prev);

            if prev.state != State::Occupy || prev.cap + len < requested {
                break 'bail;
            }

            let prev_ptr = prev.data_base_ptr(self.data);
            let from_ptr = from.data_base_ptr(self.data);

            let from = self.free_region(from);

            ptr::copy(from_ptr, prev_ptr, from.len as usize);

            prev.state = State::Used;
            prev.cap += from.cap;
            prev.len = from.len;
            return Some(prev);
        }

        // The region has no capacity and therefore no data to preserve, so it
        // can simply be moved to the end of the buffer.
        if from.cap == 0 {
            let bytes = self.bytes + requested;

            if bytes > self.size {
                return None;
            }

            from.start = self.bytes;
            from.cap = requested;

            self.replace_back(&mut from);
            self.bytes = bytes;
            return Some(from);
        }

        // Fall back to allocating a new region and copying the data over.
        let mut to = self.alloc(requested)?;

        let from_data = self
            .data
            .wrapping_add(from.start as usize)
            .cast::<u8>()
            .cast_const();

        let to_data = self.data.wrapping_add(to.start as usize).cast::<u8>();

        ptr::copy_nonoverlapping(from_data, to_data, len as usize);
        to.len = len;
        self.free(from.id);
        Some(to)
    }

    /// Find the first region in the region list matching the given condition.
    unsafe fn find_region<T>(&mut self, mut condition: T) -> Option<Region>
    where
        T: FnMut(&Header) -> bool,
    {
        let mut next = self.head;

        while let Some(id) = next {
            let ptr = self.header_mut(id);

            if condition(&*ptr) {
                return Some(Region { id, ptr });
            }

            next = (*ptr).next;
        }

        None
    }

    /// Pop the next header off the free list, if any.
    unsafe fn pop_free(&mut self) -> Option<Region> {
        let id = self.free.take()?;
        let ptr = self.header_mut(id);
        self.free = (*ptr).next_free.take();
        Some(Region { id, ptr })
    }

    #[inline]
    fn region_to_addr(&self, at: HeaderId) -> usize {
        region_to_addr(self.size, self.headers, at)
    }
}

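/// Translate a header identifier into a byte offset into the buffer.
///
/// Headers are allocated from the back of the buffer growing downwards, so
/// identifier 1 refers to the header closest to the end while the most
/// recently allocated header sits at offset `size`.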
#[inline]
fn region_to_addr(size: u32, headers: u8, at: HeaderId) -> usize {
    (size + u32::from(headers - at.get()) * HEADER_U32) as usize
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum State {
    /// The header is not in use and is available on the free list.
    Free = 0,
    /// The region has been freed but is still linked between live regions, so
    /// its capacity can be reclaimed or reused by a neighbour.
    Occupy,
    /// The region is in use by a live buffer.
    Used,
}

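/// The header describing a single region of the buffer.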
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(align(8))]
struct Header {
    /// Offset of the region's data from the start of the buffer.
    start: u32,
    /// The number of initialized bytes in the region.
    len: u32,
    /// The capacity of the region in bytes.
    cap: u32,
    /// The current state of the region.
    state: State,
    /// The next header in the free list, if this header has been freed.
    next_free: Option<HeaderId>,
    /// The previous region in the region list.
    prev: Option<HeaderId>,
    /// The next region in the region list.
    next: Option<HeaderId>,
}