use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::slice;
use std::sync::Arc;
use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
/// Memory-alignment requirement for buffer data, expressed in bytes.
///
/// `None` means "no requirement" (1-byte alignment). The fixed variants
/// correspond to common SIMD register widths; `Custom(n)` carries an
/// arbitrary byte count. Callers should supply a non-zero power of two for
/// `Custom`, but `is_aligned` tolerates an invalid zero value without
/// panicking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Alignment {
    /// No alignment requirement (equivalent to 1-byte alignment).
    None,
    /// 16-byte alignment (128-bit SIMD).
    Align16,
    /// 32-byte alignment (256-bit SIMD).
    Align32,
    /// 64-byte alignment (512-bit SIMD / typical cache line).
    Align64,
    /// Caller-specified alignment in bytes.
    Custom(usize),
}
impl Alignment {
    /// Returns the alignment requirement in bytes.
    #[must_use]
    pub const fn bytes(&self) -> usize {
        match self {
            Alignment::None => 1,
            Alignment::Align16 => 16,
            Alignment::Align32 => 32,
            Alignment::Align64 => 64,
            Alignment::Custom(n) => *n,
        }
    }
    /// Returns `true` if `ptr` satisfies this alignment.
    ///
    /// A zero `Custom(0)` alignment is invalid and reports `false` instead of
    /// panicking (the previous `% 0` caused a division-by-zero panic).
    #[must_use]
    pub fn is_aligned<T>(&self, ptr: *const T) -> bool {
        match self.bytes() {
            // Zero is not a valid alignment; nothing can satisfy it.
            0 => false,
            b => (ptr as usize) % b == 0,
        }
    }
    /// Picks a SIMD-friendly alignment from the element size of `T`:
    /// 8+-byte elements get 64-byte alignment, 4-7 byte elements 32-byte,
    /// anything smaller 16-byte.
    #[must_use]
    pub const fn optimal_for_simd<T>() -> Self {
        let size = mem::size_of::<T>();
        if size >= 8 {
            Alignment::Align64
        } else if size >= 4 {
            Alignment::Align32
        } else {
            Alignment::Align16
        }
    }
}
impl Default for Alignment {
    /// Defaults to no alignment requirement.
    fn default() -> Self {
        Alignment::None
    }
}
/// Describes the in-memory layout of an n-dimensional array: shape, byte
/// strides, element count and size, an alignment hint, and contiguity flags.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryLayout {
    pub shape: Vec<usize>,
    pub strides: Vec<isize>,
    pub len: usize,
    pub element_size: usize,
    pub alignment: Alignment,
    pub is_contiguous: bool,
    pub is_c_contiguous: bool,
    pub is_f_contiguous: bool,
}
impl MemoryLayout {
    /// Builds a C-order (row-major) contiguous layout for elements of `T`.
    #[must_use]
    pub fn contiguous<T>(shape: &[usize]) -> Self {
        let element_size = mem::size_of::<T>();
        Self {
            shape: shape.to_vec(),
            strides: Self::compute_c_strides(shape, element_size),
            len: shape.iter().product(),
            element_size,
            alignment: Alignment::optimal_for_simd::<T>(),
            is_contiguous: true,
            is_c_contiguous: true,
            // A 0-D or 1-D array is both C- and Fortran-contiguous.
            is_f_contiguous: shape.len() <= 1,
        }
    }
    /// Builds a Fortran-order (column-major) contiguous layout for `T`.
    #[must_use]
    pub fn fortran_contiguous<T>(shape: &[usize]) -> Self {
        let element_size = mem::size_of::<T>();
        Self {
            shape: shape.to_vec(),
            strides: Self::compute_f_strides(shape, element_size),
            len: shape.iter().product(),
            element_size,
            alignment: Alignment::optimal_for_simd::<T>(),
            is_contiguous: true,
            // A 0-D or 1-D array is both C- and Fortran-contiguous.
            is_c_contiguous: shape.len() <= 1,
            is_f_contiguous: true,
        }
    }
    /// Row-major byte strides: the last axis moves by `element_size`, each
    /// earlier axis by the product of all later extents.
    fn compute_c_strides(shape: &[usize], element_size: usize) -> Vec<isize> {
        if shape.is_empty() {
            return Vec::new();
        }
        let mut acc = element_size as isize;
        let mut strides: Vec<isize> = shape
            .iter()
            .rev()
            .map(|&dim| {
                let stride = acc;
                acc *= dim as isize;
                stride
            })
            .collect();
        strides.reverse();
        strides
    }
    /// Column-major byte strides: the first axis moves by `element_size`,
    /// each later axis by the product of all earlier extents.
    fn compute_f_strides(shape: &[usize], element_size: usize) -> Vec<isize> {
        if shape.is_empty() {
            return Vec::new();
        }
        let mut acc = element_size as isize;
        shape
            .iter()
            .map(|&dim| {
                let stride = acc;
                acc *= dim as isize;
                stride
            })
            .collect()
    }
    /// Number of dimensions.
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }
    /// Two layouts are compatible when both are contiguous and they agree on
    /// shape and element size.
    #[must_use]
    pub fn is_compatible(&self, other: &Self) -> bool {
        self.is_contiguous
            && other.is_contiguous
            && self.shape == other.shape
            && self.element_size == other.element_size
    }
    /// Total size of the described data in bytes.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.len * self.element_size
    }
}
/// Read-only access to a contiguous region of memory plus its layout metadata.
pub trait ContiguousMemory {
    /// Raw base pointer to the underlying bytes.
    fn as_ptr(&self) -> *const u8;
    /// Layout metadata (shape, strides, contiguity) describing the memory.
    fn layout(&self) -> &MemoryLayout;
    /// Whether the memory is contiguous, as recorded in the layout.
    fn is_contiguous(&self) -> bool {
        self.layout().is_contiguous
    }
    /// Total size of the region in bytes (`len * element_size`).
    fn size_bytes(&self) -> usize {
        self.layout().size_bytes()
    }
}
/// Mutable extension of [`ContiguousMemory`].
pub trait ContiguousMemoryMut: ContiguousMemory {
    /// Raw mutable base pointer to the underlying bytes.
    fn as_mut_ptr(&mut self) -> *mut u8;
}
/// A borrowed, read-only view over a contiguous array of `T`, carrying layout
/// metadata — like `&'a [T]` with extra shape/stride information attached.
#[derive(Debug)]
pub struct SharedArrayView<'a, T> {
    // Base pointer to the first element; valid for `len` reads during `'a`.
    ptr: *const T,
    // Number of elements reachable from `ptr`.
    len: usize,
    // Shape/stride/alignment metadata describing the view.
    layout: MemoryLayout,
    // Ties the view's validity to the lifetime of the borrowed data.
    _marker: PhantomData<&'a T>,
}
impl<'a, T> SharedArrayView<'a, T> {
    /// Creates a 1-D view borrowing `data`.
    #[must_use]
    pub fn from_slice(data: &'a [T]) -> Self {
        let layout = MemoryLayout::contiguous::<T>(&[data.len()]);
        Self {
            ptr: data.as_ptr(),
            len: data.len(),
            layout,
            _marker: PhantomData,
        }
    }
    /// Creates a view from a raw pointer and a layout.
    ///
    /// # Safety
    /// `ptr` must be valid for `layout.len` reads of `T` for the lifetime
    /// `'a`, and the memory must not be mutated during `'a`.
    pub unsafe fn from_raw_parts(ptr: *const T, layout: MemoryLayout) -> Self {
        Self {
            ptr,
            len: layout.len,
            layout,
            _marker: PhantomData,
        }
    }
    /// Number of elements in the view.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len
    }
    /// True when the view contains no elements.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }
    /// Layout metadata recorded at construction.
    #[must_use]
    pub const fn layout(&self) -> &MemoryLayout {
        &self.layout
    }
    /// Shape of the viewed array.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.layout.shape
    }
    /// Unchecked element access.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`.
    #[must_use]
    pub unsafe fn get_unchecked(&self, index: usize) -> &T {
        &*self.ptr.add(index)
    }
    /// Bounds-checked element access; `None` when out of range.
    pub fn get(&self, index: usize) -> Option<&T> {
        if index < self.len {
            // SAFETY: `index` was just checked against `len`.
            Some(unsafe { self.get_unchecked(index) })
        } else {
            None
        }
    }
    /// Borrows the whole view as a slice; `None` when non-contiguous.
    ///
    /// Note the returned slice borrows the original data for `'a`, not just
    /// for the lifetime of `self`.
    pub fn as_slice(&self) -> Option<&'a [T]> {
        if self.layout.is_contiguous {
            // SAFETY: a contiguous view has `len` consecutive initialized
            // elements starting at `ptr`, live for `'a`.
            Some(unsafe { slice::from_raw_parts(self.ptr, self.len) })
        } else {
            None
        }
    }
    /// Sub-view over the half-open element range `[start, end)`.
    ///
    /// # Errors
    /// Returns a validation error when `start > end` or `end > self.len()`.
    pub fn slice(&self, start: usize, end: usize) -> CoreResult<SharedArrayView<'a, T>> {
        if start > end || end > self.len {
            return Err(CoreError::ValidationError(
                ErrorContext::new(format!(
                    "Invalid slice range [{start}, {end}) for length {len}",
                    len = self.len
                ))
                .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }
        let new_len = end - start;
        let new_layout = MemoryLayout::contiguous::<T>(&[new_len]);
        Ok(SharedArrayView {
            // SAFETY comment: offsetting within the view is in-bounds because
            // `start <= end <= self.len` was validated above.
            ptr: unsafe { self.ptr.add(start) },
            len: new_len,
            layout: new_layout,
            _marker: PhantomData,
        })
    }
    /// Whether the base pointer satisfies the layout's alignment hint.
    #[must_use]
    pub fn is_simd_aligned(&self) -> bool {
        self.layout.alignment.is_aligned(self.ptr)
    }
}
impl<'a, T> ContiguousMemory for SharedArrayView<'a, T> {
    /// Base pointer to the viewed elements, as raw bytes.
    fn as_ptr(&self) -> *const u8 {
        self.ptr as *const u8
    }
    /// Layout metadata recorded at construction.
    fn layout(&self) -> &MemoryLayout {
        &self.layout
    }
}
// SAFETY: the view only hands out shared `&T` access, so moving it across
// threads is sound when `T: Sync` (mirroring `&T: Send`). The `T: Send + Sync`
// bound here is stricter than strictly required, but sound.
unsafe impl<T: Send + Sync> Send for SharedArrayView<'_, T> {}
// SAFETY: `&SharedArrayView` only exposes `&T`, so sharing it across threads
// is sound exactly when `T: Sync` (mirroring `&T: Sync`).
unsafe impl<T: Sync> Sync for SharedArrayView<'_, T> {}
/// A borrowed, mutable view over a contiguous array of `T`, carrying layout
/// metadata — like `&'a mut [T]` with extra shape/stride information.
#[derive(Debug)]
pub struct SharedArrayViewMut<'a, T> {
    // Base pointer to the first element; valid for `len` reads/writes in `'a`.
    ptr: *mut T,
    // Number of elements reachable from `ptr`.
    len: usize,
    // Shape/stride/alignment metadata describing the view.
    layout: MemoryLayout,
    // Ties the view's validity and exclusivity to the mutable borrow.
    _marker: PhantomData<&'a mut T>,
}
impl<'a, T> SharedArrayViewMut<'a, T> {
    /// Creates a mutable 1-D view borrowing `data` exclusively.
    #[must_use]
    pub fn from_slice(data: &'a mut [T]) -> Self {
        let layout = MemoryLayout::contiguous::<T>(&[data.len()]);
        Self {
            ptr: data.as_mut_ptr(),
            len: data.len(),
            layout,
            _marker: PhantomData,
        }
    }
    /// Creates a mutable view from a raw pointer and a layout.
    ///
    /// # Safety
    /// `ptr` must be valid for `layout.len` reads and writes of `T` for the
    /// lifetime `'a`, with no other access to that memory during `'a`.
    pub unsafe fn from_raw_parts(ptr: *mut T, layout: MemoryLayout) -> Self {
        Self {
            ptr,
            len: layout.len,
            layout,
            _marker: PhantomData,
        }
    }
    /// Number of elements in the view.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len
    }
    /// True when the view contains no elements.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }
    /// Layout metadata recorded at construction.
    #[must_use]
    pub const fn layout(&self) -> &MemoryLayout {
        &self.layout
    }
    /// Bounds-checked shared element access; `None` when out of range.
    pub fn get(&self, index: usize) -> Option<&T> {
        if index < self.len {
            // SAFETY: `index` was just checked against `len`.
            Some(unsafe { &*self.ptr.add(index) })
        } else {
            None
        }
    }
    /// Bounds-checked mutable element access; `None` when out of range.
    pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
        if index < self.len {
            // SAFETY: `index` is in bounds and `&mut self` ensures exclusivity.
            Some(unsafe { &mut *self.ptr.add(index) })
        } else {
            None
        }
    }
    /// Read-only view sharing this view's pointer and (cloned) layout.
    #[must_use]
    pub fn as_view(&self) -> SharedArrayView<'_, T> {
        SharedArrayView {
            ptr: self.ptr,
            len: self.len,
            layout: self.layout.clone(),
            _marker: PhantomData,
        }
    }
    /// Borrows the whole view as a slice; `None` when non-contiguous.
    pub fn as_slice(&self) -> Option<&[T]> {
        if self.layout.is_contiguous {
            // SAFETY: contiguity guarantees `len` consecutive initialized
            // elements starting at `ptr`.
            Some(unsafe { slice::from_raw_parts(self.ptr, self.len) })
        } else {
            None
        }
    }
    /// Borrows the whole view as a mutable slice; `None` when non-contiguous.
    pub fn as_mut_slice(&mut self) -> Option<&mut [T]> {
        if self.layout.is_contiguous {
            // SAFETY: as in `as_slice`, plus `&mut self` grants exclusivity.
            Some(unsafe { slice::from_raw_parts_mut(self.ptr, self.len) })
        } else {
            None
        }
    }
}
impl<'a, T> ContiguousMemory for SharedArrayViewMut<'a, T> {
    /// Base pointer to the viewed elements, as raw bytes.
    fn as_ptr(&self) -> *const u8 {
        self.ptr as *const u8
    }
    /// Layout metadata recorded at construction.
    fn layout(&self) -> &MemoryLayout {
        &self.layout
    }
}
impl<'a, T> ContiguousMemoryMut for SharedArrayViewMut<'a, T> {
    /// Mutable base pointer to the viewed elements, as raw bytes.
    fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr as *mut u8
    }
}
// SAFETY: the view is an exclusive borrow of the underlying elements, so
// moving it across threads is sound when `T: Send` (mirroring `&mut T: Send`).
unsafe impl<T: Send> Send for SharedArrayViewMut<'_, T> {}
// SAFETY: `&SharedArrayViewMut` only exposes shared access to `T`, which
// requires `T: Sync` (mirroring `&mut T: Sync`); the extra `Send` bound is
// stricter than strictly required, but sound.
unsafe impl<T: Send + Sync> Sync for SharedArrayViewMut<'_, T> {}
/// Type-erased, reference-counted byte buffer that remembers the element type
/// it was built from, so the bytes can later be reinterpreted safely.
#[derive(Debug)]
pub struct ZeroCopyBuffer {
    // Shared backing bytes; clones of the buffer share this allocation.
    data: Arc<[u8]>,
    // Layout recorded at construction (element count/size, contiguity).
    layout: MemoryLayout,
    // TypeId of the element type the bytes were produced from.
    type_id: std::any::TypeId,
    // Human-readable element type name, kept for diagnostics.
    type_name: &'static str,
}
impl ZeroCopyBuffer {
    /// Builds a buffer by bitwise-copying the elements of `data` into a shared
    /// `Arc<[u8]>` allocation; clones of the buffer then share that allocation
    /// without further copying.
    ///
    /// The element destructors are intentionally not run: their bit-copies
    /// live on inside the byte buffer, so running `Drop` would leave the
    /// copies referencing freed resources. Consequently, for element types
    /// with a non-trivial `Drop`, the elements' owned resources are leaked by
    /// design — this type is really only suitable for plain-old-data
    /// elements. Unlike the previous implementation (which used
    /// `mem::forget(data)`), the `Vec`'s own heap allocation is freed rather
    /// than leaked.
    pub fn from_vec<T: 'static + Clone>(mut data: Vec<T>) -> Self {
        let layout = MemoryLayout::contiguous::<T>(&[data.len()]);
        let type_id = std::any::TypeId::of::<T>();
        let type_name = std::any::type_name::<T>();
        let byte_len = data.len() * mem::size_of::<T>();
        let ptr = data.as_ptr() as *const u8;
        // SAFETY: `data` owns `byte_len` initialized bytes starting at `ptr`;
        // the slice is only used to copy them into `arc_bytes` below.
        let bytes = unsafe { slice::from_raw_parts(ptr, byte_len) };
        let arc_bytes: Arc<[u8]> = bytes.into();
        // Forget the elements WITHOUT leaking the Vec's allocation.
        // SAFETY: shrinking the length to 0 only skips the element
        // destructors (which must not run — their bytes now live inside
        // `arc_bytes`); dropping the emptied Vec then frees its allocation.
        unsafe { data.set_len(0) };
        drop(data);
        Self {
            data: arc_bytes,
            layout,
            type_id,
            type_name,
        }
    }
    /// Reinterprets the backing bytes as `&[T]`.
    ///
    /// Returns `None` when `T` is not the recorded element type, when the
    /// layout is non-contiguous, or when the backing allocation is
    /// insufficiently aligned for `T`. The alignment check is required for
    /// soundness: an `Arc<[u8]>` only guarantees byte alignment, so casting
    /// to `*const T` without it would be undefined behavior for over-aligned
    /// element types.
    pub fn as_typed<T: 'static>(&self) -> Option<&[T]> {
        if std::any::TypeId::of::<T>() != self.type_id {
            return None;
        }
        if !self.layout.is_contiguous {
            return None;
        }
        let ptr = self.data.as_ptr();
        // Guard against a misaligned backing allocation before casting.
        if (ptr as usize) % mem::align_of::<T>() != 0 {
            return None;
        }
        // SAFETY: type identity, contiguity, and alignment were all checked
        // above, and the allocation holds `layout.len` elements' worth of
        // initialized bytes (recorded at construction).
        Some(unsafe { slice::from_raw_parts(ptr as *const T, self.layout.len) })
    }
    /// Layout metadata recorded at construction.
    #[must_use]
    pub const fn layout(&self) -> &MemoryLayout {
        &self.layout
    }
    /// Human-readable name of the element type the buffer was built from.
    #[must_use]
    pub const fn type_name(&self) -> &'static str {
        self.type_name
    }
    /// Raw backing bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }
    /// Number of elements (not bytes) in the buffer.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.layout.len
    }
    /// True when the buffer holds no elements.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.layout.len == 0
    }
}
impl Clone for ZeroCopyBuffer {
    /// Cheap clone: only the `Arc` reference count and the small metadata are
    /// copied; the backing bytes are shared, not duplicated.
    fn clone(&self) -> Self {
        let data = Arc::clone(&self.data);
        Self {
            data,
            layout: self.layout.clone(),
            type_id: self.type_id,
            type_name: self.type_name,
        }
    }
}
/// A typed, bounds-checked sub-range of a [`ZeroCopyBuffer`].
#[derive(Debug)]
pub struct ZeroCopySlice<'a, T> {
    // Buffer whose bytes this slice reinterprets as `T`.
    buffer: &'a ZeroCopyBuffer,
    // First element index (inclusive).
    start: usize,
    // Last element index (exclusive).
    end: usize,
    // Records the element type without storing any `T`.
    _marker: PhantomData<T>,
}
impl<'a, T: 'static> ZeroCopySlice<'a, T> {
    /// Creates a typed view over the half-open element range `[start, end)`
    /// of `buffer`.
    ///
    /// # Errors
    /// Returns a validation error when `T` does not match the buffer's
    /// recorded element type, or when the range is out of bounds.
    pub fn new(buffer: &'a ZeroCopyBuffer, start: usize, end: usize) -> CoreResult<Self> {
        // The requested element type must match what the buffer was built from.
        if std::any::TypeId::of::<T>() != buffer.type_id {
            let msg = format!(
                "Type mismatch: buffer is {}, requested {}",
                buffer.type_name,
                std::any::type_name::<T>()
            );
            return Err(CoreError::ValidationError(
                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
            ));
        }
        // The half-open range must lie within the buffer.
        if start > end || end > buffer.layout.len {
            let msg = format!(
                "Invalid slice range [{}, {}) for length {}",
                start, end, buffer.layout.len
            );
            return Err(CoreError::ValidationError(
                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
            ));
        }
        Ok(Self {
            buffer,
            start,
            end,
            _marker: PhantomData,
        })
    }
    /// Borrows the selected elements as a plain slice.
    #[must_use]
    pub fn as_slice(&self) -> &[T] {
        let typed = self.buffer.as_typed::<T>().expect("Type already validated");
        &typed[self.start..self.end]
    }
    /// Number of elements selected by the slice.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.end - self.start
    }
    /// True when the slice selects no elements.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.start == self.end
    }
}
impl<T: 'static> Deref for ZeroCopySlice<'_, T> {
    type Target = [T];
    /// Lets the slice be used anywhere a `&[T]` is expected.
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}
/// Owned, contiguous array storage paired with layout metadata, bridging
/// plain `Vec` data to the view types in this module.
#[derive(Debug)]
pub struct ArrayBridge<T> {
    // Owned element storage; always contiguous.
    data: Vec<T>,
    // Layout metadata; constructors and `reshape` keep `layout.len`
    // equal to `data.len()`.
    layout: MemoryLayout,
}
impl<T: Clone> ArrayBridge<T> {
    /// Wraps `data` as a 1-D array.
    #[must_use]
    pub fn from_vec(data: Vec<T>) -> Self {
        let layout = MemoryLayout::contiguous::<T>(&[data.len()]);
        Self { data, layout }
    }
    /// Copies `data` into a new 1-D array.
    #[must_use]
    pub fn from_slice(data: &[T]) -> Self {
        Self::from_vec(data.to_vec())
    }
    /// Wraps `data` with an explicit multi-dimensional shape.
    ///
    /// # Errors
    /// Returns a validation error when `data.len()` differs from the product
    /// of `shape`.
    pub fn with_shape(data: Vec<T>, shape: &[usize]) -> CoreResult<Self> {
        let expected_len: usize = shape.iter().product();
        if data.len() != expected_len {
            return Err(CoreError::ValidationError(
                ErrorContext::new(format!(
                    "Data length {actual} does not match shape {shape:?} (expected {expected})",
                    actual = data.len(),
                    expected = expected_len
                ))
                .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }
        let layout = MemoryLayout::contiguous::<T>(shape);
        Ok(Self { data, layout })
    }
    /// Read-only view of the data.
    ///
    /// The view carries this bridge's full layout, so a shape set via
    /// `with_shape` or `reshape` is preserved. (Previously the view was
    /// rebuilt with `from_slice`, silently flattening the shape to 1-D.)
    #[must_use]
    pub fn view(&self) -> SharedArrayView<'_, T> {
        // SAFETY: `data` is a live, contiguous allocation of exactly
        // `layout.len` elements, borrowed immutably for the view's lifetime.
        unsafe { SharedArrayView::from_raw_parts(self.data.as_ptr(), self.layout.clone()) }
    }
    /// Mutable view of the data, carrying the full layout like `view`.
    #[must_use]
    pub fn view_mut(&mut self) -> SharedArrayViewMut<'_, T> {
        // SAFETY: as in `view`, plus the `&mut self` borrow guarantees
        // exclusive access for the view's lifetime.
        unsafe { SharedArrayViewMut::from_raw_parts(self.data.as_mut_ptr(), self.layout.clone()) }
    }
    /// Borrows the elements as a flat slice.
    #[must_use]
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }
    /// Mutably borrows the elements as a flat slice.
    #[must_use]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.data
    }
    /// Layout metadata (shape, strides, contiguity).
    #[must_use]
    pub const fn layout(&self) -> &MemoryLayout {
        &self.layout
    }
    /// Current shape.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.layout.shape
    }
    /// Total number of elements.
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }
    /// True when the array holds no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
    /// Consumes the bridge, returning the underlying storage.
    #[must_use]
    pub fn into_vec(self) -> Vec<T> {
        self.data
    }
    /// Re-interprets the data with a new shape of the same total length.
    ///
    /// # Errors
    /// Returns a validation error when the element count would change.
    pub fn reshape(&mut self, new_shape: &[usize]) -> CoreResult<()> {
        let expected_len: usize = new_shape.iter().product();
        if self.data.len() != expected_len {
            return Err(CoreError::ValidationError(
                ErrorContext::new(format!(
                    "Cannot reshape array of length {} to shape {new_shape:?}",
                    self.data.len()
                ))
                .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }
        self.layout = MemoryLayout::contiguous::<T>(new_shape);
        Ok(())
    }
}
impl<T: Clone> Clone for ArrayBridge<T> {
    /// Deep-copies both the element storage and the layout metadata.
    fn clone(&self) -> Self {
        let data = self.data.clone();
        let layout = self.layout.clone();
        Self { data, layout }
    }
}
impl<T> Deref for ArrayBridge<T> {
    type Target = [T];
    /// Lets the bridge be used anywhere a `&[T]` is expected.
    fn deref(&self) -> &Self::Target {
        self.data.as_slice()
    }
}
impl<T> DerefMut for ArrayBridge<T> {
    /// Mutable counterpart of `Deref`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.data.as_mut_slice()
    }
}
/// Owned buffer with layout metadata (alias for [`ArrayBridge`]).
pub type TypedBuffer<T> = ArrayBridge<T>;
/// Borrowed read-only buffer view (alias for [`SharedArrayView`]).
pub type BufferRef<'a, T> = SharedArrayView<'a, T>;
/// Borrowed mutable buffer view (alias for [`SharedArrayViewMut`]).
pub type BufferMut<'a, T> = SharedArrayViewMut<'a, T>;
/// Borrowed read-only array view (alias for [`SharedArrayView`]).
pub type BorrowedArray<'a, T> = SharedArrayView<'a, T>;
/// Owned array with layout metadata (alias for [`ArrayBridge`]).
pub type OwnedArray<T> = ArrayBridge<T>;
#[cfg(test)]
mod tests {
    use super::*;
    /// Construction and bounds-checked access on a read-only view.
    #[test]
    fn test_shared_array_view() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let view = SharedArrayView::from_slice(&data);
        assert_eq!(view.len(), 5);
        assert!(!view.is_empty());
        assert_eq!(view.get(0), Some(&1.0));
        assert_eq!(view.get(4), Some(&5.0));
        // Out-of-bounds access returns None rather than panicking.
        assert_eq!(view.get(5), None);
    }
    /// Sub-views index relative to their own start, not the parent's.
    #[test]
    fn test_shared_array_view_slice() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let view = SharedArrayView::from_slice(&data);
        let subview = view.slice(1, 4).expect("Slice should succeed");
        assert_eq!(subview.len(), 3);
        assert_eq!(subview.get(0), Some(&2.0));
        assert_eq!(subview.get(2), Some(&4.0));
    }
    /// Writes through a mutable view are visible in the underlying Vec.
    #[test]
    fn test_shared_array_view_mut() {
        let mut data = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let mut view = SharedArrayViewMut::from_slice(&mut data);
        if let Some(elem) = view.get_mut(2) {
            *elem = 10.0;
        }
        assert_eq!(view.get(2), Some(&10.0));
        assert_eq!(data[2], 10.0);
    }
    /// C-contiguous layout reports the expected dimensions and sizes.
    #[test]
    fn test_memory_layout() {
        let layout = MemoryLayout::contiguous::<f64>(&[3, 4]);
        assert_eq!(layout.ndim(), 2);
        assert_eq!(layout.len, 12);
        assert_eq!(layout.element_size, 8);
        assert!(layout.is_contiguous);
        assert!(layout.is_c_contiguous);
    }
    /// Shape validation and reshaping on an owned bridge.
    #[test]
    fn test_array_bridge() {
        let data = vec![1, 2, 3, 4, 5, 6];
        let mut bridge = ArrayBridge::with_shape(data, &[2, 3]).expect("Shape should be valid");
        assert_eq!(bridge.shape(), &[2, 3]);
        assert_eq!(bridge.len(), 6);
        bridge.reshape(&[3, 2]).expect("Reshape should succeed");
        assert_eq!(bridge.shape(), &[3, 2]);
    }
    /// Typed reinterpretation succeeds for the recorded type and fails
    /// for a mismatched type.
    #[test]
    fn test_zero_copy_buffer() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let buffer = ZeroCopyBuffer::from_vec(data);
        assert_eq!(buffer.len(), 4);
        assert_eq!(buffer.type_name(), "f32");
        let typed: &[f32] = buffer.as_typed().expect("Type should match");
        assert_eq!(typed, &[1.0f32, 2.0, 3.0, 4.0]);
        // Requesting the wrong element type yields None, not UB.
        let wrong: Option<&[f64]> = buffer.as_typed();
        assert!(wrong.is_none());
    }
    /// Typed sub-ranges of a buffer expose the expected elements.
    #[test]
    fn test_zero_copy_slice() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let buffer = ZeroCopyBuffer::from_vec(data);
        let slice: ZeroCopySlice<'_, f64> =
            ZeroCopySlice::new(&buffer, 1, 4).expect("Slice should be valid");
        assert_eq!(slice.len(), 3);
        assert_eq!(slice.as_slice(), &[2.0, 3.0, 4.0]);
    }
    /// Byte widths reported by each alignment variant.
    #[test]
    fn test_alignment() {
        assert_eq!(Alignment::None.bytes(), 1);
        assert_eq!(Alignment::Align16.bytes(), 16);
        assert_eq!(Alignment::Align32.bytes(), 32);
        assert_eq!(Alignment::Align64.bytes(), 64);
        assert_eq!(Alignment::Custom(128).bytes(), 128);
    }
    /// Trait-object-style access via the ContiguousMemory trait.
    #[test]
    fn test_contiguous_memory_trait() {
        let data = vec![1.0f64, 2.0, 3.0];
        let view = SharedArrayView::from_slice(&data);
        assert!(view.is_contiguous());
        // 3 elements x 8 bytes per f64.
        assert_eq!(view.size_bytes(), 24);
    }
}