use crate::memory::MemoryError;
use alloc::{slice, vec::Vec};
use core::{iter, mem::ManuallyDrop};
/// A contiguous byte buffer backed either by an owned heap allocation
/// (originally a `Vec<u8>`) or by a caller-provided `&'static mut [u8]`.
///
/// The storage is kept as raw parts and turned back into a `Vec` only when the
/// buffer is heap-backed and has to grow or be freed.
#[derive(Debug)]
pub struct ByteBuffer {
    /// Start of the storage; valid for reads and writes of `len` initialized bytes.
    pub(super) ptr: *mut u8,
    /// Number of bytes currently exposed through `data` / `data_mut`.
    pub(super) len: usize,
    /// Capacity of the heap allocation, or the full length of the static slice
    /// (the upper bound a static buffer may grow to).
    capacity: usize,
    /// `true` if the storage is a `&'static mut [u8]`, which is never freed and
    /// cannot grow beyond `capacity`.
    is_static: bool,
}
// SAFETY: `ByteBuffer` exclusively owns its storage (a `Vec<u8>` allocation or a
// `&'static mut [u8]`), and both of those are `Send`. The raw pointer field is the
// only reason this impl is not derived automatically.
// NOTE(review): this assumes code in the parent module, which can reach the
// `pub(super)` fields, never creates other aliases to the storage.
unsafe impl Send for ByteBuffer {}
// SAFETY: through `&ByteBuffer` this type only offers read access (`len`, `data`);
// every mutating method takes `&mut self`.
unsafe impl Sync for ByteBuffer {}
/// Splits `vec` into `(pointer, length, capacity)` without freeing its allocation.
///
/// After this call the caller is responsible for the allocation, for example by
/// rebuilding the vector with `Vec::from_raw_parts`.
fn vec_into_raw_parts(vec: Vec<u8>) -> (*mut u8, usize, usize) {
    // Suppress the destructor so the allocation outlives this function.
    let mut leaked = ManuallyDrop::new(vec);
    let length = leaked.len();
    let capacity = leaked.capacity();
    let pointer = leaked.as_mut_ptr();
    (pointer, length, capacity)
}
impl ByteBuffer {
    /// Creates a heap-backed buffer of `size` bytes, all initialized to zero.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError::OutOfSystemMemory`] if space for `size` bytes
    /// cannot be reserved.
    pub fn new(size: usize) -> Result<Self, MemoryError> {
        let mut vec = Vec::new();
        // Reserve fallibly so that allocation failure is reported as an error
        // instead of going through the global allocation-error handler.
        if vec.try_reserve(size).is_err() {
            return Err(MemoryError::OutOfSystemMemory);
        }
        vec.extend(iter::repeat_n(0x00_u8, size));
        let (ptr, len, capacity) = vec_into_raw_parts(vec);
        Ok(Self {
            ptr,
            len,
            capacity,
            is_static: false,
        })
    }

    /// Creates a buffer backed by the caller-provided static `buffer`, exposing
    /// its first `size` bytes.
    ///
    /// The first `size` bytes of `buffer` are zeroed. Bytes past `size` are left
    /// untouched; they become available to later [`grow`](Self::grow) calls,
    /// which zero them as they are exposed.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError::InvalidStaticBufferSize`] if `size` exceeds
    /// `buffer.len()`.
    pub fn new_static(buffer: &'static mut [u8], size: usize) -> Result<Self, MemoryError> {
        let Some(bytes) = buffer.get_mut(..size) else {
            return Err(MemoryError::InvalidStaticBufferSize);
        };
        bytes.fill(0x00_u8);
        Ok(Self {
            ptr: buffer.as_mut_ptr(),
            len: size,
            capacity: buffer.len(),
            is_static: true,
        })
    }

    /// Grows the buffer to `new_size` bytes, zero-initializing the new bytes.
    ///
    /// For heap-backed buffers the storage may be reallocated, so `ptr` can
    /// change. On error the buffer is left unchanged.
    ///
    /// # Errors
    ///
    /// - [`MemoryError::OutOfSystemMemory`] if a heap-backed buffer cannot
    ///   reserve the additional space.
    /// - [`MemoryError::InvalidStaticBufferSize`] if a static buffer would
    ///   exceed the length of its backing slice.
    ///
    /// # Panics
    ///
    /// Panics if `new_size` is smaller than the current length.
    pub fn grow(&mut self, new_size: usize) -> Result<(), MemoryError> {
        assert!(
            self.len() <= new_size,
            "new_size ({new_size}) must not be smaller than the current length ({})",
            self.len(),
        );
        match self.get_vec() {
            Some(vec) => self.grow_vec(vec, new_size),
            None => self.grow_static(new_size),
        }
    }

    /// Grows the heap allocation `vec` (which aliases this buffer's storage)
    /// to `new_size` bytes and adopts the resulting raw parts.
    fn grow_vec(
        &mut self,
        mut vec: ManuallyDrop<Vec<u8>>,
        new_size: usize,
    ) -> Result<(), MemoryError> {
        debug_assert!(vec.len() <= new_size);
        let additional = new_size - vec.len();
        // On failure `vec` stays wrapped in `ManuallyDrop`, so the allocation is
        // not freed and `self` still refers to it unchanged.
        if vec.try_reserve(additional).is_err() {
            return Err(MemoryError::OutOfSystemMemory);
        }
        // Capacity was reserved above, so this only writes zeros.
        vec.resize(new_size, 0x00_u8);
        // `try_reserve` may have reallocated. Record the new pointer, length and
        // capacity so that `self` owns the grown allocation.
        (self.ptr, self.len, self.capacity) = vec_into_raw_parts(ManuallyDrop::into_inner(vec));
        Ok(())
    }

    /// Extends a static buffer to `new_size` bytes within its backing slice,
    /// zeroing the newly exposed bytes.
    fn grow_static(&mut self, new_size: usize) -> Result<(), MemoryError> {
        if new_size > self.capacity {
            return Err(MemoryError::InvalidStaticBufferSize);
        }
        let old_len = self.len;
        self.len = new_size;
        // Bytes past the old length still hold whatever the static slice
        // contained before, so clear them.
        self.data_mut()[old_len..].fill(0x00_u8);
        Ok(())
    }

    /// Returns the number of bytes currently in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the buffer currently holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the buffer contents as a shared slice.
    pub fn data(&self) -> &[u8] {
        // SAFETY: `ptr` points to at least `len` initialized bytes. These are
        // either the initialized prefix of the owned `Vec` allocation, or a
        // prefix of the `&'static mut [u8]` with `len <= capacity == slice length`
        // (enforced by `new_static` and `grow_static`). The returned borrow is
        // tied to `&self`, so no `&mut self` method can run while it is alive.
        unsafe { slice::from_raw_parts(self.ptr, self.len) }
    }

    /// Returns the buffer contents as a mutable slice.
    pub fn data_mut(&mut self) -> &mut [u8] {
        // SAFETY: the validity argument is the same as in `data`. The `&mut self`
        // receiver makes this the only live view handed out by this type.
        unsafe { slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Reconstructs the owning `Vec` of a heap-backed buffer, or returns `None`
    /// for a static buffer.
    ///
    /// The `Vec` is wrapped in `ManuallyDrop` because `self` still refers to the
    /// same allocation. The caller must either store the (possibly updated) raw
    /// parts back into `self`, or explicitly take ownership to free it.
    fn get_vec(&mut self) -> Option<ManuallyDrop<Vec<u8>>> {
        if self.is_static {
            return None;
        }
        // SAFETY: for non-static buffers, `ptr`, `len` and `capacity` were
        // produced by `vec_into_raw_parts` on a `Vec<u8>` (in `new` or
        // `grow_vec`). NOTE(review): this assumes the parent module never
        // rewrites the `pub(super)` fields independently.
        let vec = unsafe { Vec::from_raw_parts(self.ptr, self.len, self.capacity) };
        Some(ManuallyDrop::new(vec))
    }
}
impl Drop for ByteBuffer {
    /// Frees the heap allocation of a `Vec`-backed buffer.
    ///
    /// Static buffers are left untouched: `get_vec` returns `None` for them, so
    /// the borrowed `&'static mut [u8]` is never deallocated.
    fn drop(&mut self) {
        if let Some(vec) = self.get_vec() {
            // Unwrap the reconstructed `Vec` from `ManuallyDrop` and drop it
            // explicitly so that its allocation is released.
            drop(ManuallyDrop::into_inner(vec));
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_basic_allocation_deallocation() {
        let buffer = ByteBuffer::new(10).unwrap();
        assert_eq!(buffer.len(), 10);
    }

    #[test]
    fn test_basic_data_manipulation() {
        let mut buffer = ByteBuffer::new(10).unwrap();
        assert_eq!(buffer.len(), 10);
        let data = buffer.data();
        assert_eq!(data, &[0; 10]);
        let data = buffer.data_mut();
        data[4] = 4;
        let data = buffer.data();
        assert_eq!(data, &[0, 0, 0, 0, 4, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_static_buffer_initialization() {
        static mut BUF: [u8; 10] = [7; 10];
        let buf = unsafe { &mut *core::ptr::addr_of_mut!(BUF) };
        let mut buffer = ByteBuffer::new_static(buf, 5).unwrap();
        assert_eq!(buffer.len(), 5);
        // The exposed prefix is zeroed even though the static started as 7s.
        assert_eq!(buffer.data(), &[0; 5]);
        let data = buffer.data_mut();
        data[0] = 1;
        unsafe {
            assert_eq!(BUF[0], 1);
        }
    }

    #[test]
    fn test_static_buffer_rejects_oversized_size() {
        static mut BUF: [u8; 4] = [7; 4];
        let buf = unsafe { &mut *core::ptr::addr_of_mut!(BUF) };
        assert!(matches!(
            ByteBuffer::new_static(buf, 5).unwrap_err(),
            MemoryError::InvalidStaticBufferSize
        ));
    }

    #[test]
    fn test_growing_buffer() {
        let mut buffer = ByteBuffer::new(5).unwrap();
        buffer.grow(10).unwrap();
        assert_eq!(buffer.len(), 10);
        assert_eq!(buffer.data(), &[0; 10]);
    }

    #[test]
    fn test_growing_preserves_contents() {
        let mut buffer = ByteBuffer::new(4).unwrap();
        buffer.data_mut().copy_from_slice(&[1, 2, 3, 4]);
        buffer.grow(6).unwrap();
        assert_eq!(buffer.data(), &[1, 2, 3, 4, 0, 0]);
    }

    #[test]
    fn test_growing_static() {
        static mut BUF: [u8; 10] = [7; 10];
        let buf = unsafe { &mut *core::ptr::addr_of_mut!(BUF) };
        let mut buffer = ByteBuffer::new_static(buf, 5).unwrap();
        assert_eq!(buffer.len(), 5);
        assert_eq!(buffer.data(), &[0; 5]);
        buffer.grow(8).unwrap();
        assert_eq!(buffer.len(), 8);
        assert_eq!(buffer.data(), &[0; 8]);
        buffer.grow(10).unwrap();
        assert_eq!(buffer.len(), 10);
        assert_eq!(buffer.data(), &[0; 10]);
    }

    #[test]
    fn test_growing_static_preserves_contents() {
        static mut BUF: [u8; 8] = [7; 8];
        let buf = unsafe { &mut *core::ptr::addr_of_mut!(BUF) };
        let mut buffer = ByteBuffer::new_static(buf, 2).unwrap();
        buffer.data_mut().copy_from_slice(&[1, 2]);
        buffer.grow(4).unwrap();
        assert_eq!(buffer.data(), &[1, 2, 0, 0]);
    }

    #[test]
    fn test_static_buffer_overflow() {
        static mut BUF: [u8; 5] = [7; 5];
        let buf = unsafe { &mut *core::ptr::addr_of_mut!(BUF) };
        let mut buffer = ByteBuffer::new_static(buf, 5).unwrap();
        assert!(buffer.grow(10).is_err());
    }

    #[test]
    fn test_failed_static_grow_leaves_buffer_unchanged() {
        static mut BUF: [u8; 4] = [7; 4];
        let buf = unsafe { &mut *core::ptr::addr_of_mut!(BUF) };
        let mut buffer = ByteBuffer::new_static(buf, 2).unwrap();
        buffer.data_mut().copy_from_slice(&[1, 2]);
        assert!(matches!(
            buffer.grow(5).unwrap_err(),
            MemoryError::InvalidStaticBufferSize
        ));
        assert_eq!(buffer.len(), 2);
        assert_eq!(buffer.data(), &[1, 2]);
    }

    #[test]
    #[should_panic]
    fn test_shrinking_panics() {
        let mut buffer = ByteBuffer::new(10).unwrap();
        let _ = buffer.grow(5);
    }

    #[test]
    fn out_of_memory_works() {
        let mut buffer = ByteBuffer::new(0).unwrap();
        assert!(matches!(
            buffer.grow(usize::MAX).unwrap_err(),
            MemoryError::OutOfSystemMemory
        ));
        assert_eq!(buffer.len(), 0);
        assert_eq!(buffer.data().first(), None);
        assert!(buffer.grow(1).is_ok());
        assert!(matches!(
            buffer.grow(usize::MAX).unwrap_err(),
            MemoryError::OutOfSystemMemory
        ));
        assert_eq!(buffer.len(), 1);
        assert_eq!(buffer.data().first(), Some(&0x00_u8));
    }
}