use alloc::vec::Vec;
use burn_backend::Element;
use burn_std::Bytes;
use core::mem;
use ndarray::{ArcArray, ArrayView, IxDyn};
/// Backing storage for an n-dimensional array of `E`.
///
/// The storage is either a byte buffer that is viewed in place as elements of `E`
/// (`Borrowed`), or an `ndarray` shared array (`Owned`).
#[derive(Debug)]
pub enum NdArrayStorage<E: Element> {
/// A byte buffer reinterpreted in place as `E` elements laid out with `shape`.
///
/// The accessor methods on this type build a raw-pointer view over `bytes` without
/// re-validating it, so values of this variant should be created through
/// `NdArrayStorage::from_borrowed`, which checks alignment and buffer size first.
Borrowed {
/// Raw buffer holding the element data.
bytes: Bytes,
/// Axis lengths used when viewing `bytes` as an array.
shape: Vec<usize>,
},
/// An `ndarray` `ArcArray` holding the elements directly.
Owned(ArcArray<E, IxDyn>),
}
impl<E: Element> Clone for NdArrayStorage<E> {
    /// Produces a storage of the same variant as `self`.
    ///
    /// `Owned` clones the underlying `ArcArray`; `Borrowed` clones both the byte buffer
    /// and the shape vector. Whether the buffer clone copies data is decided by
    /// `Bytes::clone`, not by this impl.
    fn clone(&self) -> Self {
        match self {
            Self::Owned(array) => {
                let array = array.clone();
                Self::Owned(array)
            }
            Self::Borrowed { bytes, shape } => {
                let bytes = bytes.clone();
                let shape = shape.clone();
                Self::Borrowed { bytes, shape }
            }
        }
    }
}
impl<E: Element> NdArrayStorage<E> {
    /// Reports whether `bytes` may be reinterpreted in place as a contiguous array of `E`
    /// with the given `shape`.
    ///
    /// All of the following must hold:
    /// * the start of the buffer is aligned for `E`;
    /// * the product of the non-zero axis lengths fits in `isize`, which
    ///   [`ArrayView::from_shape_ptr`] requires even when some other axis is zero;
    /// * `bytes` holds at least `element_count * size_of::<E>()` bytes.
    fn can_view_in_place(bytes: &Bytes, shape: &[usize]) -> bool {
        if !(bytes.as_ptr() as usize).is_multiple_of(mem::align_of::<E>()) {
            return false;
        }

        // Zero-length axes are skipped here on purpose: a shape such as
        // `[0, usize::MAX]` has zero elements, yet it still violates ndarray's
        // requirement on the product of the non-zero axis lengths.
        let Some(non_zero_product) = shape
            .iter()
            .filter(|&&dim| dim != 0)
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
        else {
            return false;
        };
        if non_zero_product > isize::MAX as usize {
            return false;
        }

        let num_elements = if shape.contains(&0) {
            0
        } else {
            non_zero_product
        };
        num_elements
            .checked_mul(mem::size_of::<E>())
            .is_some_and(|required| bytes.len() >= required)
    }

    /// Wraps `bytes` as borrowed storage with the given `shape`, without copying.
    ///
    /// # Errors
    ///
    /// Hands `bytes` and `shape` back unchanged when the buffer cannot be viewed in
    /// place: the pointer is misaligned for `E`, the shape's element count (or the
    /// product of its non-zero axis lengths) overflows, or the buffer is too small for
    /// the shape.
    pub fn from_borrowed(bytes: Bytes, shape: Vec<usize>) -> Result<Self, (Bytes, Vec<usize>)> {
        if Self::can_view_in_place(&bytes, &shape) {
            Ok(Self::Borrowed { bytes, shape })
        } else {
            Err((bytes, shape))
        }
    }

    /// Wraps an existing shared array as owned storage.
    #[inline]
    pub fn from_owned(array: ArcArray<E, IxDyn>) -> Self {
        Self::Owned(array)
    }

    /// Returns `true` only for owned storage whose array reports itself as unique.
    ///
    /// Borrowed storage always reports `false`, so callers that gate in-place
    /// mutation on uniqueness never write through a borrowed buffer.
    #[inline]
    pub fn is_unique(&self) -> bool {
        match self {
            Self::Borrowed { .. } => false,
            Self::Owned(arr) => arr.is_unique(),
        }
    }

    /// Returns a read-only view of the elements.
    ///
    /// For borrowed storage the view points directly at the byte buffer; no data is
    /// copied.
    #[inline]
    pub fn view(&self) -> ArrayView<'_, E, IxDyn> {
        match self {
            Self::Borrowed { bytes, shape } => {
                let ptr = bytes.as_ptr() as *const E;
                let dim = IxDyn(shape);
                // SAFETY: `from_borrowed` is the validating constructor for this
                // variant. It checks that the buffer is aligned for `E`, that the
                // product of the non-zero axis lengths fits in `isize`, and that the
                // buffer holds at least as many bytes as the shape needs. The
                // returned view borrows `self`, so `bytes` outlives it.
                unsafe { ArrayView::from_shape_ptr(dim, ptr) }
            }
            Self::Owned(arr) => arr.view(),
        }
    }

    /// Consumes the storage and returns a shared array.
    ///
    /// Owned storage is returned as is. Borrowed storage is copied into a freshly
    /// allocated array before the byte buffer is dropped.
    pub fn into_owned(self) -> ArcArray<E, IxDyn> {
        match self {
            Self::Owned(arr) => arr,
            borrowed @ Self::Borrowed { .. } => {
                // Bind the copy first so the temporary view is gone before
                // `borrowed` is dropped at the end of this arm.
                let copied = borrowed.view().to_owned();
                copied.into_shared()
            }
        }
    }

    /// Alias for [`Self::into_owned`].
    #[inline]
    pub fn into_shared(self) -> ArcArray<E, IxDyn> {
        self.into_owned()
    }

    /// Returns the axis lengths of the stored array.
    pub fn shape(&self) -> &[usize] {
        match self {
            Self::Borrowed { shape, .. } => shape,
            Self::Owned(arr) => arr.shape(),
        }
    }

    /// Returns the number of axes.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.shape().len()
    }

    /// Returns the total number of elements (the product of all axis lengths).
    #[inline]
    pub fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Returns `true` when the storage holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` for the `Borrowed` variant.
    #[inline]
    pub fn is_borrowed(&self) -> bool {
        matches!(self, Self::Borrowed { .. })
    }

    /// Returns `true` for the `Owned` variant.
    #[inline]
    pub fn is_owned(&self) -> bool {
        matches!(self, Self::Owned(_))
    }

    /// Converts borrowed storage to owned in place and returns the owned array.
    ///
    /// Borrowed data is copied into a new array, which then replaces `self`. Storage
    /// that is already owned is returned untouched.
    pub fn ensure_owned(&mut self) -> &mut ArcArray<E, IxDyn> {
        if self.is_borrowed() {
            // The copy is fully built before `self` is overwritten, so the view
            // never outlives the buffer it reads from.
            let copied = self.view().to_owned().into_shared();
            *self = Self::Owned(copied);
        }
        match self {
            Self::Owned(arr) => arr,
            // The branch above has just replaced any `Borrowed` value.
            Self::Borrowed { .. } => unreachable!(),
        }
    }
}
impl<E: Element> From<ArcArray<E, IxDyn>> for NdArrayStorage<E> {
    /// Wraps a shared array as owned storage; equivalent to
    /// [`NdArrayStorage::from_owned`].
    fn from(array: ArcArray<E, IxDyn>) -> Self {
        Self::from_owned(array)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use burn_std::Bytes;

    /// Builds borrowed 2x2 `f32` storage over `[1.0, 2.0, 3.0, 4.0]` and returns it
    /// together with the address of the source buffer.
    fn borrowed_2x2() -> (NdArrayStorage<f32>, *const u8) {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = Bytes::from_elems(data);
        let source_ptr = bytes.as_ptr();
        let storage =
            NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create");
        (storage, source_ptr)
    }

    /// Builds owned 2x2 `f32` storage filled with ones.
    fn owned_2x2() -> NdArrayStorage<f32> {
        let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared();
        NdArrayStorage::from_owned(array)
    }

    #[test]
    fn test_borrowed_is_not_unique() {
        let (storage, _) = borrowed_2x2();
        assert!(!storage.is_unique());
        assert!(storage.is_borrowed());
    }

    #[test]
    fn test_owned_unique_when_single_ref() {
        let storage = owned_2x2();
        assert!(storage.is_unique());
        assert!(storage.is_owned());
    }

    #[test]
    fn test_owned_not_unique_when_cloned() {
        let storage = owned_2x2();
        let _clone = storage.clone();
        assert!(!storage.is_unique());
    }

    #[test]
    fn test_view_zero_copy() {
        let (storage, _) = borrowed_2x2();
        let view = storage.view();
        assert_eq!(view[[0, 0]], 1.0);
        assert_eq!(view[[1, 1]], 4.0);
    }

    #[test]
    fn test_into_owned_copies_borrowed() {
        let (storage, _) = borrowed_2x2();
        let owned = storage.into_owned();
        assert_eq!(owned[[0, 0]], 1.0);
        assert_eq!(owned[[1, 1]], 4.0);
    }

    #[test]
    fn test_from_borrowed_validates_alignment() {
        use burn_std::AllocationProperty;

        // Aligned case: a buffer built from `Vec<f32>` must be accepted.
        let aligned_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let aligned_bytes = Bytes::from_elems(aligned_data);
        assert_eq!(
            (aligned_bytes.as_ptr() as usize) % core::mem::align_of::<f32>(),
            0,
            "Test setup: f32 data should be properly aligned"
        );
        let result = NdArrayStorage::<f32>::from_borrowed(aligned_bytes, vec![2, 2]);
        assert!(
            result.is_ok(),
            "from_borrowed should succeed for properly aligned data"
        );

        // Misaligned case: slice a static buffer at an offset that breaks f32 alignment.
        let buffer: &[u8] = &[0u8; 32];
        let shared = bytes::Bytes::from_static(buffer);
        let base = shared.as_ptr() as usize;
        let align = core::mem::align_of::<f32>();
        let misalign_offset = (1..align)
            .find(|&off| !(base + off).is_multiple_of(align))
            .expect("Should find a misaligned offset");
        let sliced = shared.slice(misalign_offset..(misalign_offset + 16));
        let misaligned_bytes = Bytes::from_shared(sliced, AllocationProperty::Other);
        assert_ne!(
            (misaligned_bytes.as_ptr() as usize) % align,
            0,
            "Test setup: sliced data should be misaligned for f32"
        );
        let result = NdArrayStorage::<f32>::from_borrowed(misaligned_bytes, vec![4]);
        assert!(
            result.is_err(),
            "from_borrowed should return Err for misaligned data"
        );
    }

    #[test]
    fn test_insufficient_size_returns_err() {
        // Two elements cannot back a shape that needs four.
        let data: Vec<f32> = vec![1.0, 2.0];
        let bytes = Bytes::from_elems(data);
        let result = NdArrayStorage::<f32>::from_borrowed(bytes, vec![4]);
        assert!(
            result.is_err(),
            "from_borrowed should return Err when bytes are too small"
        );
    }

    #[test]
    fn test_overflowing_shape_returns_err() {
        // The element count of this shape does not fit in `usize`.
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = Bytes::from_elems(data);
        let result = NdArrayStorage::<f32>::from_borrowed(bytes, vec![usize::MAX, 2]);
        assert!(
            result.is_err(),
            "from_borrowed should return Err when the element count overflows"
        );
    }

    #[test]
    fn test_zero_copy_native_allocation() {
        let (storage, original_ptr) = borrowed_2x2();
        let view = storage.view();
        let view_ptr = view.as_ptr() as *const u8;
        assert_eq!(
            original_ptr, view_ptr,
            "ZERO-COPY REGRESSION: native allocation view() must return pointer to original bytes"
        );
        assert_eq!(view[[0, 0]], 1.0);
        assert_eq!(view[[0, 1]], 2.0);
        assert_eq!(view[[1, 0]], 3.0);
        assert_eq!(view[[1, 1]], 4.0);
    }

    #[test]
    fn test_zero_copy_shared_bytes_pointer_identity() {
        use burn_std::AllocationProperty;

        // Sixteen raw bytes, enough for the four f32 elements of the 2x2 shape.
        let data: &[u8] = &[
            0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64,
        ];
        let shared = bytes::Bytes::from_static(data);
        let original_ptr = shared.as_ptr();
        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
        let storage =
            NdArrayStorage::<f32>::from_borrowed(bytes, vec![2, 2]).expect("should create");
        let view_ptr = storage.view().as_ptr() as *const u8;
        assert_eq!(
            original_ptr, view_ptr,
            "ZERO-COPY REGRESSION: SharedBytes view must point to original static data"
        );
        let cloned = storage.clone();
        let cloned_ptr = cloned.view().as_ptr() as *const u8;
        assert_eq!(
            original_ptr, cloned_ptr,
            "ZERO-COPY REGRESSION: SharedBytes clone must share memory"
        );
    }

    #[test]
    fn test_clone_borrowed_stays_borrowed() {
        let (storage, _) = borrowed_2x2();
        let cloned = storage.clone();
        assert!(
            storage.is_borrowed(),
            "ZERO-COPY REGRESSION: original should remain borrowed after clone"
        );
        assert!(
            cloned.is_borrowed(),
            "ZERO-COPY REGRESSION: clone should be borrowed type"
        );
        assert!(
            !storage.is_unique(),
            "ZERO-COPY REGRESSION: original should not be unique after clone"
        );
        assert!(
            !cloned.is_unique(),
            "ZERO-COPY REGRESSION: clone should not be unique"
        );
        assert_eq!(storage.view(), cloned.view(), "Clone should have same data");
    }

    #[test]
    fn test_zero_copy_triggers_copy_on_mutation() {
        let (storage, original_ptr) = borrowed_2x2();
        assert!(storage.is_borrowed(), "should start as borrowed");
        let owned = storage.into_owned();
        let owned_ptr = owned.as_ptr() as *const u8;
        assert_ne!(
            original_ptr, owned_ptr,
            "into_owned() on borrowed data MUST allocate new memory (copy-on-write)"
        );
    }

    #[test]
    fn test_borrowed_reports_not_unique() {
        let (storage, _) = borrowed_2x2();
        assert!(
            !storage.is_unique(),
            "ZERO-COPY REGRESSION: borrowed storage MUST report is_unique() == false \
to trigger copy-on-write. If this is true, mutations will corrupt shared data!"
        );
    }

    #[test]
    fn test_ensure_owned_converts_borrowed_in_place() {
        let (mut storage, original_ptr) = borrowed_2x2();
        let owned_ptr = {
            let owned = storage.ensure_owned();
            assert_eq!(owned[[0, 0]], 1.0);
            assert_eq!(owned[[1, 1]], 4.0);
            owned.as_ptr() as *const u8
        };
        assert!(
            storage.is_owned(),
            "ensure_owned() must leave the storage in the owned state"
        );
        assert_ne!(
            original_ptr, owned_ptr,
            "ensure_owned() on borrowed data must copy into new memory"
        );
    }

    #[test]
    fn test_ensure_owned_keeps_owned_array() {
        let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared();
        let array_ptr = array.as_ptr();
        let mut storage = NdArrayStorage::from_owned(array);
        assert_eq!(
            storage.ensure_owned().as_ptr(),
            array_ptr,
            "ensure_owned() on owned data must not reallocate"
        );
        assert!(storage.is_owned());
    }

    #[test]
    fn test_shape_accessors() {
        let (borrowed, _) = borrowed_2x2();
        assert_eq!(borrowed.shape(), &[2, 2]);
        assert_eq!(borrowed.ndim(), 2);
        assert_eq!(borrowed.len(), 4);
        assert!(!borrowed.is_empty());

        let owned = owned_2x2();
        assert_eq!(owned.shape(), &[2, 2]);
        assert_eq!(owned.ndim(), 2);
        assert_eq!(owned.len(), 4);
        assert!(!owned.is_empty());
    }
}