use super::{
AllocationController, AllocationError, AllocationProperty, SplitError,
default_controller::{MAX_ALIGN, NativeAllocationController},
};
use alloc::boxed::Box;
use core::cell::UnsafeCell;
use core::mem::MaybeUninit;
use core::ptr::NonNull;
use core::sync::atomic::{AtomicBool, Ordering};
/// Allocation controller backed by a shared [`bytes::Bytes`] buffer, copied on first mutation.
///
/// Read-only operations (`memory`, `split`, `duplicate`, `copy_into`) are served straight from
/// `bytes` without copying. The first mutable access (`memory_mut` or `try_detach`) copies the
/// data into a `MAX_ALIGN`-aligned buffer owned by a [`NativeAllocationController`]; from then on
/// reads and writes go through that private copy, and `split`/`duplicate` are refused.
pub struct SharedBytesAllocationController {
    /// Shared source data; authoritative until the private copy is created.
    bytes: bytes::Bytes,
    /// Private copy of `bytes`, created by `init_mutable`. It is only ever written through
    /// `&mut self`, so a shared borrow obtained via `&self` can never overlap a write.
    controller: UnsafeCell<Option<Box<dyn AllocationController>>>,
    /// `true` once `controller` holds the private copy.
    init: AtomicBool,
    /// Value reported by `AllocationController::property`.
    property: AllocationProperty,
}

impl SharedBytesAllocationController {
    /// Wraps `bytes` without copying; `property` is reported verbatim by `property()`.
    pub fn new(bytes: bytes::Bytes, property: AllocationProperty) -> Self {
        Self {
            bytes,
            controller: UnsafeCell::new(None),
            init: AtomicBool::new(false),
            property,
        }
    }

    /// Creates the private `MAX_ALIGN`-aligned copy of `bytes` on the first call; later calls
    /// are no-ops.
    ///
    /// Takes `&mut self` (all callers already hold it) so the write into `controller` goes
    /// through the safe [`UnsafeCell::get_mut`] rather than a raw pointer behind `&self`.
    ///
    /// # Panics
    ///
    /// Panics if [`NativeAllocationController::alloc_with_data`] fails to allocate the copy.
    fn init_mutable(&mut self) {
        if *self.init.get_mut() {
            return;
        }
        let data: &[u8] = &self.bytes;
        let controller = NativeAllocationController::alloc_with_data(data, MAX_ALIGN)
            .unwrap_or_else(|e| {
                panic!(
                    "failed to allocate MAX_ALIGN buffer for copy-on-write (len: {}, error: {:?})",
                    data.len(),
                    e
                )
            });
        *self.controller.get_mut() = Some(Box::new(controller));
        *self.init.get_mut() = true;
    }
}
impl AllocationController for SharedBytesAllocationController {
    /// Always reports `MAX_ALIGN`, the alignment requested for the private copy.
    ///
    /// NOTE(review): before the first mutation the memory is the shared `bytes` buffer, whose
    /// pointer is not guaranteed to be `MAX_ALIGN`-aligned (e.g. `Bytes::from_static`). Confirm
    /// that no caller relies on this value for unchecked reinterpretation of `memory()`.
    fn alloc_align(&self) -> usize {
        MAX_ALIGN
    }

    /// Returns the property supplied to `new`.
    fn property(&self) -> AllocationProperty {
        self.property
    }

    /// Returns the private copy once it exists, otherwise the shared bytes themselves.
    fn memory(&self) -> &[MaybeUninit<u8>] {
        if self.init.load(Ordering::Relaxed) {
            // SAFETY: `controller` is written only by `init_mutable`, which is reached solely
            // from `&mut self` methods, so no write can overlap this shared borrow.
            let owned = unsafe { &*self.controller.get() };
            owned
                .as_ref()
                .expect("controller must be Some when init is true")
                .memory()
        } else {
            let shared: &[u8] = &self.bytes;
            // SAFETY: `MaybeUninit<u8>` has the same size and alignment as `u8`, and every
            // initialised byte is a valid `MaybeUninit<u8>`.
            unsafe { core::slice::from_raw_parts(shared.as_ptr().cast(), shared.len()) }
        }
    }

    /// Materialises the private copy (if needed) and returns it mutably.
    unsafe fn memory_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        self.init_mutable();
        // Holding `&mut self`, so the safe `get_mut` gives exclusive access to the cell.
        let owned = self
            .controller
            .get_mut()
            .as_mut()
            .expect("controller must be Some after init_mutable()");
        // SAFETY: the caller's obligations for `memory_mut` are forwarded unchanged to the
        // inner controller.
        unsafe { owned.memory_mut() }
    }

    /// Splits the shared buffer at `offset` without copying (both halves share `bytes`).
    ///
    /// # Errors
    ///
    /// - `SplitError::Unsupported` once the private copy exists.
    /// - `SplitError::InvalidOffset` if `offset` exceeds the buffer length.
    fn split(
        &mut self,
        offset: usize,
    ) -> Result<(Box<dyn AllocationController>, Box<dyn AllocationController>), SplitError> {
        if self.init.load(Ordering::Relaxed) {
            return Err(SplitError::Unsupported);
        }
        if offset > self.bytes.len() {
            return Err(SplitError::InvalidOffset);
        }
        let left = self.bytes.slice(..offset);
        let right = self.bytes.slice(offset..);
        Ok((
            Box::new(SharedBytesAllocationController::new(left, self.property)),
            Box::new(SharedBytesAllocationController::new(right, self.property)),
        ))
    }

    /// Returns a new controller sharing the same `bytes` (refcount clone), or `None` once the
    /// private copy exists.
    fn duplicate(&self) -> Option<Box<dyn AllocationController>> {
        if self.init.load(Ordering::Relaxed) {
            return None;
        }
        Some(Box::new(SharedBytesAllocationController::new(
            self.bytes.clone(),
            self.property,
        )))
    }

    /// Copies `min(buf.len(), memory().len())` leading bytes of the current memory into `buf`.
    unsafe fn copy_into(&self, buf: &mut [u8]) {
        let src = self.memory();
        let copy_len = buf.len().min(src.len());
        // SAFETY: reinterprets `MaybeUninit<u8>` as `u8` (identical layout). Before mutation the
        // source is the fully initialised shared buffer. After mutation this relies on the
        // first `copy_len` bytes of the private copy being initialised (presumably guaranteed
        // by `alloc_with_data` copying `bytes` in plus the caller's contract — TODO confirm).
        let src = unsafe { core::slice::from_raw_parts(src.as_ptr().cast::<u8>(), copy_len) };
        buf[..copy_len].copy_from_slice(src);
    }

    /// Growing is not supported for shared buffers.
    fn grow(&mut self, _size: usize, _align: usize) -> Result<(), AllocationError> {
        Err(AllocationError::UnsupportedOperation)
    }

    /// Materialises the private copy (if needed) and delegates detaching to it.
    fn try_detach(&mut self) -> Option<NonNull<u8>> {
        self.init_mutable();
        self.controller
            .get_mut()
            .as_mut()
            .expect("controller must be Some after init_mutable()")
            .try_detach()
    }
}
#[cfg(test)]
mod tests {
    use super::super::Bytes;
    use super::*;

    #[test_log::test]
    fn test_from_static() {
        static DATA: &[u8] = &[1, 2, 3, 4, 5];
        let shared = bytes::Bytes::from_static(DATA);
        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
        assert_eq!(&bytes[..], &[1, 2, 3, 4, 5]);
        assert_eq!(bytes.len(), 5);
    }

    #[test_log::test]
    fn test_from_vec() {
        let shared = bytes::Bytes::from(alloc::vec![10, 20, 30]);
        let bytes = Bytes::from_shared(shared, AllocationProperty::Native);
        assert_eq!(&bytes[..], &[10, 20, 30]);
        assert_eq!(bytes.len(), 3);
    }

    #[test_log::test]
    fn test_split() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4, 5, 6]);
        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
        let (left, right) = bytes.split(3).unwrap();
        assert_eq!(&left[..], &[1, 2, 3]);
        assert_eq!(&right[..], &[4, 5, 6]);
    }

    #[test_log::test]
    fn test_split_at_zero() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4]);
        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
        let (left, right) = bytes.split(0).unwrap();
        assert_eq!(left.len(), 0);
        assert_eq!(&right[..], &[1, 2, 3, 4]);
    }

    #[test_log::test]
    fn test_split_at_len() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4]);
        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
        let len = bytes.len();
        let (left, right) = bytes.split(len).unwrap();
        assert_eq!(&left[..], &[1, 2, 3, 4]);
        assert_eq!(right.len(), 0);
    }

    #[test_log::test]
    fn test_duplicate() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3]);
        let controller = SharedBytesAllocationController::new(shared, AllocationProperty::Other);
        let dup = controller.duplicate().expect("duplicate should succeed");
        assert_eq!(dup.memory().len(), 3);
    }

    #[test_log::test]
    fn test_copy_into() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4, 5]);
        let controller = SharedBytesAllocationController::new(shared, AllocationProperty::Other);
        let mut buf = [0u8; 3];
        unsafe { controller.copy_into(&mut buf) };
        assert_eq!(buf, [1, 2, 3]);
    }

    /// After the first mutation the controller owns a private copy, so cheap structural
    /// sharing via `split`/`duplicate` must be refused.
    #[test_log::test]
    fn test_split_and_duplicate_refused_after_mutation() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4]);
        let mut controller =
            SharedBytesAllocationController::new(shared, AllocationProperty::Other);
        let memory = unsafe { controller.memory_mut() };
        memory[0] = MaybeUninit::new(9);
        assert!(matches!(controller.split(2), Err(SplitError::Unsupported)));
        assert!(controller.duplicate().is_none());
    }

    /// `copy_into` must read from the private copy once it exists, not the shared bytes.
    #[test_log::test]
    fn test_copy_into_after_mutation() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4]);
        let mut controller =
            SharedBytesAllocationController::new(shared, AllocationProperty::Other);
        let memory = unsafe { controller.memory_mut() };
        memory[0] = MaybeUninit::new(9);
        let mut buf = [0u8; 4];
        unsafe { controller.copy_into(&mut buf) };
        assert_eq!(buf, [9, 2, 3, 4]);
    }

    #[test_log::test]
    fn test_property_file() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3]);
        let controller = SharedBytesAllocationController::new(shared, AllocationProperty::File);
        assert!(matches!(controller.property(), AllocationProperty::File));
    }

    #[test_log::test]
    fn test_bytes_from_shared_with_file_property() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4]);
        let bytes = Bytes::from_shared(shared, AllocationProperty::File);
        assert!(matches!(bytes.property(), AllocationProperty::File));
        assert_eq!(&bytes[..], &[1, 2, 3, 4]);
    }

    #[test_log::test]
    fn test_split_preserves_property() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4, 5, 6]);
        let bytes = Bytes::from_shared(shared, AllocationProperty::File);
        let (left, right) = bytes.split(3).unwrap();
        assert!(matches!(left.property(), AllocationProperty::File));
        assert!(matches!(right.property(), AllocationProperty::File));
    }

    #[test_log::test]
    fn test_duplicate_preserves_property() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3]);
        let bytes = Bytes::from_shared(shared, AllocationProperty::File);
        let cloned = bytes.clone();
        assert!(matches!(cloned.property(), AllocationProperty::File));
    }

    #[test_log::test]
    fn test_alignment_reports_max_align() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3]);
        let controller = SharedBytesAllocationController::new(shared, AllocationProperty::Other);
        assert_eq!(controller.alloc_align(), MAX_ALIGN);
    }

    #[test_log::test]
    fn test_grow_fails() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3]);
        let mut controller =
            SharedBytesAllocationController::new(shared, AllocationProperty::Other);
        let result = controller.grow(100, 1);
        assert!(matches!(result, Err(AllocationError::UnsupportedOperation)));
    }

    #[test_log::test]
    fn test_try_detach_always_succeeds() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4]);
        let mut controller =
            SharedBytesAllocationController::new(shared, AllocationProperty::Other);
        let ptr = controller.try_detach();
        assert!(ptr.is_some(), "try_detach should always succeed");
        if let Some(ptr) = ptr {
            // The detached buffer is now owned by the test; free it with the layout the
            // native controller is assumed to have used (capacity rounded up to MAX_ALIGN).
            unsafe {
                let capacity = 4usize.next_multiple_of(MAX_ALIGN);
                let layout = core::alloc::Layout::from_size_align(capacity, MAX_ALIGN)
                    .expect("valid layout");
                alloc::alloc::dealloc(ptr.as_ptr(), layout);
            }
        }
    }

    #[test_log::test]
    fn test_try_into_vec_succeeds_for_u8() {
        let bytes = Bytes::from_elems(alloc::vec![1u8, 2, 3, 4]);
        let vec = bytes.try_into_vec::<u8>().expect("alignment matches");
        assert_eq!(vec, alloc::vec![1, 2, 3, 4]);
    }

    #[test_log::test]
    fn test_try_into_vec_succeeds_for_f32() {
        let bytes = Bytes::from_elems(alloc::vec![1.0f32, 2.0, 3.0, 4.0]);
        let vec = bytes.try_into_vec::<f32>().expect("alignment matches");
        assert_eq!(vec, alloc::vec![1.0f32, 2.0, 3.0, 4.0]);
    }

    #[test_log::test]
    fn test_try_into_vec_succeeds_for_f64() {
        let bytes = Bytes::from_elems(alloc::vec![1.0f64, 2.0]);
        let vec = bytes.try_into_vec::<f64>().expect("alignment matches");
        assert_eq!(vec, alloc::vec![1.0f64, 2.0]);
    }

    #[test_log::test]
    fn test_try_into_vec_fails_for_shared_bytes() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4]);
        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
        assert!(bytes.try_into_vec::<u8>().is_err());
    }

    #[test_log::test]
    fn test_copy_on_write() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3, 4, 5]);
        let mut bytes = Bytes::from_shared(shared, AllocationProperty::Other);
        bytes[0] = 99;
        assert_eq!(bytes[0], 99);
        assert_eq!(&bytes[1..], &[2, 3, 4, 5]);
    }

    #[test_log::test]
    fn test_clone_before_mutation_is_cheap() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3]);
        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
        let cloned = bytes.clone();
        assert_eq!(&bytes[..], &cloned[..]);
    }

    #[test_log::test]
    fn test_clone_after_mutation_copies() {
        let shared = bytes::Bytes::from_static(&[1, 2, 3]);
        let mut bytes = Bytes::from_shared(shared, AllocationProperty::Other);
        bytes[0] = 99;
        let cloned = bytes.clone();
        assert_eq!(&bytes[..], &cloned[..]);
        assert_eq!(cloned[0], 99);
    }

    #[test_log::test]
    fn test_slices_from_static_region() {
        static EMBEDDED_DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let shared = bytes::Bytes::from_static(EMBEDDED_DATA);
        let first_4 = shared.slice(0..4);
        let last_6 = shared.slice(4..10);
        let bytes_first = Bytes::from_shared(first_4, AllocationProperty::Other);
        let bytes_last = Bytes::from_shared(last_6, AllocationProperty::Other);
        assert_eq!(&bytes_first[..], &[1, 2, 3, 4]);
        assert_eq!(&bytes_last[..], &[5, 6, 7, 8, 9, 10]);
        assert_eq!(bytes_first.len(), 4);
        assert_eq!(bytes_last.len(), 6);
        assert_eq!(bytes_first.align(), MAX_ALIGN);
        assert_eq!(bytes_last.align(), MAX_ALIGN);
    }

    #[test_log::test]
    fn test_multiple_slices_share_underlying_data() {
        static DATA: &[u8] = &[0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE, 0xBA, 0xBE];
        let shared = bytes::Bytes::from_static(DATA);
        let slice_a = Bytes::from_shared(shared.slice(0..4), AllocationProperty::Other);
        let slice_b = Bytes::from_shared(shared.slice(4..8), AllocationProperty::Other);
        let slice_c = Bytes::from_shared(shared.slice(2..6), AllocationProperty::Other);
        assert_eq!(&slice_a[..], &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(&slice_b[..], &[0xCA, 0xFE, 0xBA, 0xBE]);
        assert_eq!(&slice_c[..], &[0xBE, 0xEF, 0xCA, 0xFE]);
    }
}