use std::{alloc::Layout, ptr::NonNull};
use crate::{Allocator, StackArena};
#[derive(Debug)]
pub struct ObjectStack {
arena: StackArena,
partial: bool,
}
impl ObjectStack {
    /// Creates an empty object stack backed by a fresh [`StackArena`].
    #[inline]
    pub fn new() -> Self {
        Self {
            arena: StackArena::new(),
            partial: false,
        }
    }

    /// Returns the number of objects held by the underlying arena.
    ///
    /// An object that is still being built with [`extend`](Self::extend) is
    /// counted too.
    #[inline]
    pub fn len(&self) -> usize {
        self.arena.len()
    }

    /// Returns `true` if the underlying arena holds no objects.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.arena.is_empty()
    }

    /// Copies the bytes of `object` into a new object on top of the stack and
    /// returns a pointer to the copy.
    ///
    /// Any object that was being built with [`extend`](Self::extend) is
    /// implicitly finished. A later `extend` starts a new object instead of
    /// growing the one pushed here.
    ///
    /// # Panics
    /// Panics if the arena fails to allocate.
    #[inline]
    pub fn push_bytes<P: AsRef<[u8]>>(&mut self, object: P) -> NonNull<[u8]> {
        let data = object.as_ref();
        // SAFETY: `Layout::for_value` produces a valid layout describing `data`.
        let ptr = unsafe { self.arena.allocate(Layout::for_value(data)) }.unwrap();
        // SAFETY: `ptr` was just allocated with room for `data.len()` bytes, and
        // a fresh allocation cannot overlap the live borrow `data`.
        unsafe {
            ptr.cast::<u8>()
                .copy_from_nonoverlapping(NonNull::from(data).cast(), data.len());
        }
        // The new object is now on top, so there is no in-progress object left
        // for `extend` to grow.
        self.partial = false;
        ptr
    }

    /// Copies `*object` onto the stack and returns a typed pointer to the copy.
    ///
    /// Behaves exactly like [`push`](Self::push) with a copy of the value.
    ///
    /// # Panics
    /// Panics if the arena fails to allocate.
    #[inline]
    pub fn push_copy<T: Copy>(&mut self, object: &T) -> NonNull<T> {
        self.push(*object)
    }

    /// Moves `object` onto the stack and returns a typed pointer to it.
    ///
    /// Nothing in `ObjectStack` runs `T`'s destructor. Call
    /// [`drop_in_place`](Self::drop_in_place) before popping the object if `T`
    /// owns resources. Any object that was being built with
    /// [`extend`](Self::extend) is implicitly finished.
    ///
    /// # Panics
    /// Panics if the arena fails to allocate.
    #[inline]
    pub fn push<T>(&mut self, object: T) -> NonNull<T> {
        // SAFETY: `Layout::new::<T>()` is always a valid layout.
        let ptr = unsafe { self.arena.allocate(Layout::new::<T>()) }
            .unwrap()
            .cast::<T>();
        // SAFETY: `ptr` was just allocated with `T`'s size and alignment, so it
        // is valid for a single write of a `T`.
        unsafe { ptr.write(object) };
        // Without this, a following `extend` would try to grow this typed
        // object as if it were an align-1 byte buffer.
        self.partial = false;
        ptr
    }

    /// Runs the destructor of the value behind `ptr` without releasing its memory.
    ///
    /// # Safety
    /// `ptr` must point to a live, initialized `T` (for example one returned by
    /// [`push`](Self::push)) that has not already been dropped. The value must
    /// not be read or dropped again afterwards.
    #[inline]
    pub unsafe fn drop_in_place<T>(&mut self, ptr: NonNull<T>) {
        // SAFETY: guaranteed by the caller per the contract above.
        unsafe { ptr.drop_in_place() };
    }

    /// Removes the top object and clears any in-progress `extend` state.
    #[inline]
    pub fn pop(&mut self) {
        self.arena.pop();
        self.partial = false;
    }

    /// Appends `value` to the object being built, starting a new byte object
    /// if none is in progress.
    ///
    /// The object stays open until [`finish`](Self::finish) is called.
    ///
    /// # Panics
    /// Panics if the arena fails to allocate or grow, or if the combined object
    /// size would exceed `isize::MAX`.
    #[inline]
    pub fn extend<P: AsRef<[u8]>>(&mut self, value: P) {
        let data = value.as_ref();
        // Reserve room for `data` and remember where inside the object it goes.
        let (store, offset) = if self.partial {
            let partial_object = self
                .arena
                .top()
                .expect("ObjectStack: in-progress object missing from arena");
            // SAFETY: `partial_object` is the live top byte object of the arena.
            let old_layout = Layout::for_value(unsafe { partial_object.as_ref() });
            // Checked constructor: `grow` must never receive a size above `isize::MAX`.
            let new_layout =
                Layout::from_size_align(old_layout.size() + data.len(), old_layout.align())
                    .expect("ObjectStack: object size exceeds isize::MAX");
            // SAFETY: `partial_object` was allocated by `self.arena` as a byte
            // object described by `old_layout`. `new_layout` has the same
            // alignment and is at least as large.
            let grown = unsafe {
                self.arena
                    .grow(partial_object.cast(), old_layout, new_layout)
            }
            .unwrap();
            (grown.cast::<u8>(), old_layout.size())
        } else {
            // SAFETY: `Layout::for_value` produces a valid layout describing `data`.
            let fresh = unsafe { self.arena.allocate(Layout::for_value(data)) }.unwrap();
            (fresh.cast::<u8>(), 0)
        };
        // SAFETY: `store` is valid for `offset + data.len()` bytes, and the
        // destination range `[offset, offset + data.len())` is newly reserved
        // space that does not overlap `data`.
        unsafe {
            store
                .add(offset)
                .copy_from_nonoverlapping(NonNull::from(data).cast(), data.len());
        }
        self.partial = true;
    }

    /// Seals the object built by [`extend`](Self::extend) and returns it.
    ///
    /// # Panics
    /// In debug builds, panics if no object is being built. Also panics if the
    /// arena reports no top object.
    #[inline]
    pub fn finish(&mut self) -> NonNull<[u8]> {
        debug_assert!(self.partial, "ObjectStack::finish called without extend");
        self.partial = false;
        self.arena.top().unwrap()
    }

    /// Releases the object `data` back to the arena.
    ///
    /// The arena releases `data` together with every object allocated after
    /// it, so any in-progress `extend` object is discarded as well.
    #[inline]
    pub fn rollback(&mut self, data: &[u8]) {
        // SAFETY: the pointer and layout describe `data` exactly.
        // NOTE(review): this safe fn forwards a caller-supplied pointer; soundness
        // relies on `StackArena::deallocate` rejecting pointers it does not own
        // (a foreign slice panics in the tests) — confirm in the arena.
        unsafe {
            self.arena
                .deallocate(NonNull::from(data).cast(), Layout::for_value(data));
        }
        // The in-progress object, if any, sat at or above `data` and is gone now.
        self.partial = false;
    }
}
impl std::fmt::Write for ObjectStack {
    /// Appends the UTF-8 bytes of `s` to the object currently being built, or
    /// starts a new one if none is in progress.
    ///
    /// Always returns `Ok(())`. Allocation failure panics inside `extend`
    /// rather than being reported as `fmt::Error`.
    #[inline]
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        let bytes: &[u8] = s.as_bytes();
        self.extend(bytes);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Write;

    /// `write!` and `extend` append to the same in-progress object.
    #[test]
    fn test_lifecycle() {
        let mut stack = ObjectStack::new();
        write!(&mut stack, "ab").expect("write");
        let s = "c";
        stack.extend(s);
        let p = unsafe { stack.finish().as_ref() };
        assert_eq!(p, b"abc");
    }

    #[test]
    fn test_push_pop() {
        let mut stack = ObjectStack::new();
        stack.push_bytes(b"hello");
        assert_eq!(stack.len(), 1);
        assert!(!stack.is_empty());
        stack.push_bytes(b"world");
        assert_eq!(stack.len(), 2);
        stack.pop();
        assert_eq!(stack.len(), 1);
        stack.pop();
        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
    }

    #[test]
    fn test_extend_small_data() {
        let mut stack = ObjectStack::new();
        stack.extend(b"hello");
        stack.extend(b" world");
        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, b"hello world");
    }

    #[test]
    fn test_extend_large_data() {
        let mut stack = ObjectStack::new();
        let large_data = [b'x'; 20];
        stack.extend(large_data);
        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, &large_data[..]);
    }

    #[test]
    fn test_extend_after_finish() {
        let mut stack = ObjectStack::new();
        stack.extend("first");
        let first = unsafe { stack.finish().as_ref() };
        assert_eq!(first, b"first");
        stack.extend(b"second");
        let second = unsafe { stack.finish().as_ref() };
        assert_eq!(second, b"second");
        // Building the second object must not disturb the first.
        assert_eq!(first, b"first");
    }

    #[test]
    fn test_free() {
        let mut stack = ObjectStack::new();
        stack.push_bytes(b"first");
        stack.extend(b"second");
        let second = unsafe { stack.finish().as_ref() };
        stack.extend(b"third");
        let _third = unsafe { stack.finish().as_ref() };
        // Rolling back `second` also releases `third`, which was allocated after it.
        stack.rollback(second);
        assert_eq!(stack.len(), 1);
        stack.extend(b"fourth");
        let fourth = unsafe { stack.finish().as_ref() };
        assert_eq!(fourth, b"fourth");
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_write_trait() {
        let mut stack = ObjectStack::new();
        write!(&mut stack, "Hello, {}!", "world").unwrap();
        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, b"Hello, world!");
    }

    #[test]
    fn test_empty_data() {
        let mut stack = ObjectStack::new();
        stack.extend(b"");
        let data = unsafe { stack.finish().as_ref() };
        assert_eq!(data, b"");
        stack.push_bytes(b"");
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_multiple_operations() {
        let mut stack = ObjectStack::new();
        stack.push_bytes(b"item1");
        stack.extend(b"item2-part1");
        stack.extend(b"-part2");
        let item2 = unsafe { stack.finish().as_ref() };
        write!(&mut stack, "item3").unwrap();
        let item3 = unsafe { stack.finish().as_ref() };
        assert_eq!(item2, b"item2-part1-part2");
        assert_eq!(item3, b"item3");
        assert_eq!(stack.len(), 3);
        stack.pop();
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_extend_exact_capacity() {
        let mut stack = ObjectStack::new();
        let data = [b'x'; 10];
        stack.extend(data);
        stack.extend(b"more");
        let result = unsafe { stack.finish().as_ref() };
        let mut expected = data.to_vec();
        expected.extend_from_slice(b"more");
        assert_eq!(result, expected.as_slice());
    }

    #[test]
    fn test_free_all() {
        let mut stack = ObjectStack::new();
        let first = stack.push_bytes(b"first");
        stack.extend(b"second");
        let _second = unsafe { stack.finish().as_ref() };
        stack.rollback(unsafe { first.as_ref() });
        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
    }

    #[test]
    #[should_panic]
    fn test_free_nonexistent() {
        let mut stack = ObjectStack::new();
        stack.push_bytes(b"object");
        assert_eq!(stack.len(), 1);
        let dummy = b"nonexistent";
        stack.rollback(dummy);
        assert_eq!(stack.len(), 1);
    }

    #[test]
    fn test_cross_chunk_allocation_deallocation() {
        let mut stack = ObjectStack {
            arena: StackArena::with_chunk_size(8),
            partial: false,
        };
        let small1 = stack.push_bytes("small1");
        stack.push("small2");
        assert_eq!(stack.len(), 2);
        stack.pop();
        assert_eq!(stack.len(), 1);

        let source = "start- middle- this-is-a-longer-string-to-trigger-new-chunk";
        for part in source.split(' ') {
            stack.extend(part);
        }
        assert_eq!(stack.len(), 2);
        let large = stack.finish();
        assert_eq!(stack.len(), 2);
        // The object grown across chunk boundaries must hold every part, in order.
        let expected: String = source.split(' ').collect();
        assert_eq!(unsafe { large.as_ref() }, expected.as_bytes());
        assert_eq!(unsafe { small1.as_ref() }, b"small1");

        stack.pop();
        assert_eq!(stack.len(), 1);
        stack.pop();
        assert!(stack.is_empty());

        stack.extend("single-extend-object");
        let single = stack.finish();
        assert_eq!(unsafe { single.as_ref() }, b"single-extend-object");
        stack.pop();
        assert_eq!(stack.len(), 0);
    }

    #[test]
    fn test_push_copy() {
        let mut stack = ObjectStack::new();
        let int_val = 42;
        let int_ptr = stack.push_copy(&int_val);
        unsafe {
            let retrieved = *int_ptr.as_ref();
            assert_eq!(retrieved, 42);
        }
        // 2.5 is exactly representable and avoids clippy's `approx_constant` (PI) lint.
        let float_val = 2.5;
        let float_ptr = stack.push_copy(&float_val);
        unsafe {
            let retrieved = *float_ptr.as_ref();
            assert_eq!(retrieved, 2.5);
        }
        let arr = [1, 2, 3, 4, 5];
        let arr_ptr = stack.push_copy(&arr);
        unsafe {
            let retrieved = *arr_ptr.as_ref();
            assert_eq!(retrieved, [1, 2, 3, 4, 5]);
        }
        #[derive(Debug, Copy, Clone, PartialEq)]
        struct Point {
            x: i32,
            y: i32,
        }
        let point = Point { x: 10, y: 20 };
        let point_ptr = stack.push_copy(&point);
        unsafe {
            let retrieved = *point_ptr.as_ref();
            assert_eq!(retrieved, point);
        }
        assert_eq!(stack.len(), 4);
    }

    #[test]
    fn test_push_move() {
        let mut stack = ObjectStack::new();
        let string = String::from("hello world");
        let string_len = string.len();
        let string_ptr = stack.push(string);
        unsafe {
            let retrieved = string_ptr.as_ref();
            assert_eq!(retrieved, "hello world");
            assert_eq!(retrieved.len(), string_len);
            stack.drop_in_place(string_ptr);
        }
        struct Person {
            name: String,
            age: u32,
        }
        let person = Person {
            name: String::from("Alice"),
            age: 30,
        };
        let person_ptr = stack.push(person);
        unsafe {
            let retrieved: &Person = person_ptr.as_ref();
            assert_eq!(retrieved.name, "Alice");
            assert_eq!(retrieved.age, 30);
            stack.drop_in_place(person_ptr);
        }
        let vec = vec![1, 2, 3, 4, 5];
        let vec_len = vec.len();
        let vec_ptr = stack.push(vec);
        unsafe {
            let retrieved: &Vec<i32> = vec_ptr.as_ref();
            assert_eq!(retrieved.len(), vec_len);
            assert_eq!(retrieved, &vec![1, 2, 3, 4, 5]);
            stack.drop_in_place(vec_ptr);
        }
        assert_eq!(stack.len(), 3);
    }

    #[test]
    fn test_drop_in_place() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        /// Records in `drop_flag` whether its destructor has run.
        struct DropDetector {
            _data: String,
            drop_flag: Arc<AtomicBool>,
        }
        impl Drop for DropDetector {
            fn drop(&mut self) {
                self.drop_flag.store(true, Ordering::SeqCst);
            }
        }

        let mut stack = ObjectStack::new();
        let drop_flag = Arc::new(AtomicBool::new(false));
        let detector = DropDetector {
            _data: "test data".to_string(),
            drop_flag: Arc::clone(&drop_flag),
        };
        let ptr = stack.push(detector);
        assert!(!drop_flag.load(Ordering::SeqCst));
        unsafe {
            stack.drop_in_place(ptr);
        }
        assert!(drop_flag.load(Ordering::SeqCst));
        assert_eq!(stack.len(), 1);
        stack.pop();
        assert_eq!(stack.len(), 0);
    }
}